legal-ai/scripts/eval_retrieval.py
Chaim 6ff2e36bf9 feat(eval): FU-5 — retrieval eval harness + halacha backlog visibility (#63)
Covers GAP-11 (INV-RET4/G8) and GAP-14 (INV-QA1/G10). Retrieval quality had never
been measured (it was only observed through telemetry), and the halacha review backlog
was invisible (the 10/19 gap was found by accident).

Unit B — backlog visibility (pure code, container):
- metrics.halacha_backlog(conn) → {pending_review, approved, rejected, published,
  total, oldest_pending_at}; surfaced in metrics.get_dashboard() (get_metrics MCP
  tool) and /api/system/diagnostics. Live count revealed 178 pending / 1552 total,
  oldest from 2026-05-03 — previously invisible.
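A minimal sketch of the kind of query behind metrics.halacha_backlog(conn). It uses sqlite3 as a self-contained stand-in for the production Postgres connection. The table name halachot, its status/created_at columns and the sample rows are hypothetical; the commit does not show the real schema.

```python
import sqlite3
from typing import Any

STATUSES = ("pending_review", "approved", "rejected", "published")


def halacha_backlog(conn: sqlite3.Connection) -> dict[str, Any]:
    """Counts per review status, the total, and the oldest pending timestamp.
    Assumes a hypothetical table halachot(status TEXT, created_at TEXT) holding
    ISO-8601 timestamps, so MIN() also picks the chronologically oldest one."""
    counts = dict.fromkeys(STATUSES, 0)
    for status, n in conn.execute("SELECT status, COUNT(*) FROM halachot GROUP BY status"):
        if status in counts:
            counts[status] = n
    # total counts every row, including any status outside the four above
    (total,) = conn.execute("SELECT COUNT(*) FROM halachot").fetchone()
    (oldest,) = conn.execute(
        "SELECT MIN(created_at) FROM halachot WHERE status = 'pending_review'"
    ).fetchone()
    return {**counts, "total": total, "oldest_pending_at": oldest}


# Made-up rows, only to exercise the query.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE halachot (status TEXT, created_at TEXT)")
conn.executemany(
    "INSERT INTO halachot VALUES (?, ?)",
    [
        ("pending_review", "2026-01-10T09:00:00Z"),
        ("pending_review", "2026-02-01T12:00:00Z"),
        ("approved", "2025-12-01T08:00:00Z"),
    ],
)
print(halacha_backlog(conn))
```

Reporting oldest_pending_at next to the count is what turns a silent backlog into a visible one. A count alone can look stable while one item ages.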

Unit A — retrieval eval harness (host-side scripts):
- scripts/eval_gold_bootstrap.py — seeds data/eval/gold-set.jsonl. Two sources:
  citations (cited==relevant via search_relevance_feedback — empty until decisions
  cite precedents) and known_item (query=case_name → relevant=self; a real
  citation-free signal, the methodology #52 checked by hand). Idempotent; preserves
  source='chair' rows.
- scripts/eval_retrieval.py — runs the production retrieval path (search_library /
  search_internal) over the gold-set; computes precision@k, recall@k, MRR, nDCG@k
  (k=5,10); aggregates overall + per-corpus + per-practice_area; writes a report and
  a delta vs committed baseline.json (which records the retrieval_config it reflects).
  --self-test unit-checks the metric math offline.
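The known_item source and the idempotent, chair-preserving reseed described for eval_gold_bootstrap.py can be sketched as a pure merge. This illustration makes several assumptions and is not the script's actual code. The id scheme and the input shape of cases (case_law_id, case_name, practice_area) are hypothetical. The row fields match what eval_retrieval.py reads (id, corpus, query, practice_area, relevant_case_law_ids), plus the source column mentioned above.

```python
def known_item_rows(cases: list[dict], corpus: str) -> list[dict]:
    """One gold row per case: querying the case name should retrieve the case."""
    return [
        {
            "id": f"known_item:{corpus}:{c['case_law_id']}",  # hypothetical id scheme
            "source": "known_item",
            "corpus": corpus,
            "practice_area": c.get("practice_area", ""),
            "query": c["case_name"],
            "relevant_case_law_ids": [str(c["case_law_id"])],
        }
        for c in cases
    ]


def merge_gold(existing: list[dict], seeded: list[dict]) -> list[dict]:
    """Idempotent reseed: keep every source='chair' row untouched, replace all
    other rows with the fresh seed, and emit a stable order (sorted by id)."""
    merged = {r["id"]: r for r in existing if r.get("source") == "chair"}
    for row in seeded:
        merged.setdefault(row["id"], row)  # never overwrite a chair row
    return sorted(merged.values(), key=lambda r: r["id"])
```

Running the merge twice with the same seed gives the same list. A chair-edited row with a matching id always survives.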

Gold-set strategy = hybrid (chair decision): bootstrap + chair review. The citation
source is empty today (0 cited precedents in decisions), so the seed is known-item
(77 queries: 54 internal_decisions + 23 precedent_library). The gold-set is
PROVISIONAL until Dafna reviews it (the domain chair-gate).
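Every seeded query is known-item, with exactly one relevant id (the case itself). For such queries, the metric definitions in eval_retrieval.py collapse to functions of the single rank r at which that id comes back. IDCG@k = 1, so nDCG@k = 1/log2(r+1). MRR = 1/r. R@k is 0 or 1. P@k can never exceed 1/k. The helper below is a hypothetical restatement of those formulas for that case; it is not part of the harness.

```python
from __future__ import annotations

import math


def known_item_metrics(rank: int | None, k: int) -> dict[str, float]:
    """Metrics for a query with exactly one relevant id, found at 1-based
    `rank`, or never retrieved when rank is None. Hypothetical helper."""
    hit = rank is not None and rank <= k
    return {
        f"P@{k}": 1 / k if hit else 0.0,  # at most one hit can land in the top-k
        f"R@{k}": 1.0 if hit else 0.0,  # recall is all-or-nothing
        "MRR": 1 / rank if rank is not None else 0.0,
        f"nDCG@{k}": 1 / math.log2(rank + 1) if hit else 0.0,  # IDCG@k = 1
    }


for r in (1, 2, 3, 10):
    m = known_item_metrics(r, 10)
    print(r, round(m["MRR"], 3), round(m["nDCG@10"], 3))
```

At r = 2, MRR is 0.5 while nDCG@10 is about 0.631. For 2 <= r <= k, 1/log2(r+1) > 1/r. The harness retrieves only kmax = 10 results, so on an all-known-item gold-set mean nDCG@10 is at least mean MRR. P@10 equals R@10 / 10, so it adds no information here.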

Baseline (production config: multimodal+rerank on): R@10=0.987, MRR=0.837,
nDCG@10=0.872. Finding: MULTIMODAL_ENABLED=true slightly lowers known-item recall
(image-page results displace exact name matches) — relevant to #15. precedent_library
weaker than internal (R@10 0.957 vs 1.0) — one external precedent unfindable by name.
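The baseline recall figures can be cross-checked by hand. This assumes each known-item query has exactly one relevant id, so per-query R@10 is 1 for a hit and 0 for a miss. It also assumes the one unfindable precedent is the only miss, since internal R@10 is 1.0. Under those assumptions the reported numbers follow from the 54/23 split:

```python
from fractions import Fraction

# Gold-set sizes from this commit: 77 known-item queries = 54 internal + 23 precedent.
N_INTERNAL, N_PRECEDENT = 54, 23
PRECEDENT_MISSES = 1  # "one external precedent unfindable by name"
INTERNAL_MISSES = 0  # internal R@10 is reported as 1.0

# One relevant id per query, so per-query R@10 is 1 (hit) or 0 (miss)
# and a mean over queries is just hits / queries.
r10_internal = Fraction(N_INTERNAL - INTERNAL_MISSES, N_INTERNAL)
r10_precedent = Fraction(N_PRECEDENT - PRECEDENT_MISSES, N_PRECEDENT)
# "overall" averages every per-query record, i.e. a query-weighted mean.
r10_overall = Fraction(
    N_INTERNAL + N_PRECEDENT - INTERNAL_MISSES - PRECEDENT_MISSES,
    N_INTERNAL + N_PRECEDENT,
)
# For contrast: an unweighted mean of the two corpus means.
r10_macro = (r10_internal + r10_precedent) / 2

print(f"{float(r10_precedent):.3f} {float(r10_overall):.3f} {float(r10_macro):.3f}")
# 0.957 0.987 0.978
```

The last value shows that "overall" in the report is a mean over queries, because the harness aggregates all per-query records together. As a result the larger internal corpus weighs more than it would in a plain average of the two corpus means.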

"CI gate" realized as discipline (re-runnable harness + committed baseline + run
before/after any retrieval-layer change), because retrieval needs the prod DB + Voyage
and no CI runner has that access.
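The harness prints deltas and warns on a retrieval_config mismatch, but it does not itself pass or fail a run. For a single yes/no answer when running before/after a change, a small comparison over the baseline.json shape could look like the sketch below. compare_to_baseline and its tol noise allowance are hypothetical; they are not part of the repo.

```python
def compare_to_baseline(current: dict, baseline: dict, tol: float = 0.005) -> dict:
    """Hypothetical before/after check over the harness's report shape:
    {"retrieval_config": {...}, "overall": {metric: value, ...}}."""
    cfg_cur = current.get("retrieval_config", {})
    cfg_base = baseline.get("retrieval_config", {})
    # Knobs whose value differs (or that exist on only one side).
    changed = sorted(
        key for key in cfg_cur.keys() | cfg_base.keys()
        if cfg_cur.get(key) != cfg_base.get(key)
    )
    regressions: dict[str, float] = {}
    for metric, base_val in baseline.get("overall", {}).items():
        cur_val = current.get("overall", {}).get(metric)
        # Flag only drops larger than the noise allowance `tol` (an assumption).
        if isinstance(cur_val, (int, float)) and cur_val < base_val - tol:
            regressions[metric] = round(cur_val - base_val, 4)
    return {"config_changed": changed, "regressions": regressions}
```

An intended change, such as a new MULTIMODAL_TEXT_WEIGHT, should appear in config_changed together with any regressions it causes. If config_changed is non-empty when no knob was touched, that points to environment drift rather than a retrieval change.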

Spec: docs/superpowers/specs/2026-05-31-fu5-eval-harness-design.md

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-05-31 14:58:13 +00:00

295 lines · 12 KiB · Python

#!/usr/bin/env python3
"""FU-5 (GAP-11, INV-RET4/G8) — retrieval eval harness: precision/recall/MRR/nDCG.

Runs the PRODUCTION retrieval path (the same service functions the MCP search tools
call) over the labeled gold-set (data/eval/gold-set.jsonl, built by
scripts/eval_gold_bootstrap.py) and reports retrieval quality. This is the empirical
measurement INV-RET4 requires: no more tuning RRF weights / k / embedder "by feel".

Metrics per query (relevant = gold case_law_ids; ranked = retrieved case_law_ids):
  • precision@k = |top-k ∩ relevant| / k
  • recall@k    = |top-k ∩ relevant| / |relevant|
  • MRR         = 1 / rank-of-first-relevant (0 if none retrieved)
  • nDCG@k      = DCG@k / IDCG@k (binary gains, log2 discount)
Aggregated as the mean overall, per corpus, and per practice_area.

"CI gate" by discipline: run before AND after any retrieval-layer change (RRF weights,
k, chunk threshold, embedder, rerank) and compare to the committed data/eval/baseline.json.
Retrieval needs the prod DB + Voyage, so this is a re-runnable script, not automated CI.

Usage (mcp-server venv; needs POSTGRES + VOYAGE_API_KEY for live runs):
  PY=/home/chaim/legal-ai/mcp-server/.venv/bin/python
  $PY scripts/eval_retrieval.py --self-test          # offline metric unit tests (no DB)
  POSTGRES_PASSWORD=… VOYAGE_API_KEY=… POSTGRES_HOST=127.0.0.1 POSTGRES_PORT=5433 \
    $PY scripts/eval_retrieval.py                    # run eval, write report + baseline delta
  … $PY scripts/eval_retrieval.py --update-baseline  # adopt current run as the new baseline
"""
from __future__ import annotations

import argparse
import asyncio
import json
import math
import os
import sys
from datetime import datetime, timezone
from pathlib import Path

REPO_ROOT = Path(__file__).resolve().parent.parent
# Make the mcp-server package (legal_mcp) importable from this host-side script.
sys.path.insert(0, str(REPO_ROOT / "mcp-server" / "src"))

# Derive POSTGRES_URL from the discrete POSTGRES_* variables unless it is already set.
if "POSTGRES_URL" not in os.environ:
    os.environ["POSTGRES_URL"] = (
        f"postgres://{os.environ.get('POSTGRES_USER','legal_ai')}:"
        f"{os.environ.get('POSTGRES_PASSWORD','')}@"
        f"{os.environ.get('POSTGRES_HOST','127.0.0.1')}:"
        f"{os.environ.get('POSTGRES_PORT','5433')}/"
        f"{os.environ.get('POSTGRES_DB','legal_ai')}"
    )

EVAL_DIR = REPO_ROOT / "data" / "eval"
GOLD_PATH = EVAL_DIR / "gold-set.jsonl"
BASELINE_PATH = EVAL_DIR / "baseline.json"
K_VALUES = (5, 10)


# ── metrics (pure, unit-tested offline) ──────────────────────────────────────
def precision_at_k(ranked: list[str], relevant: set[str], k: int) -> float:
    if k <= 0:
        return 0.0
    topk = ranked[:k]
    return sum(1 for r in topk if r in relevant) / k


def recall_at_k(ranked: list[str], relevant: set[str], k: int) -> float:
    if not relevant:
        return 0.0
    topk = ranked[:k]
    return sum(1 for r in topk if r in relevant) / len(relevant)


def mrr(ranked: list[str], relevant: set[str]) -> float:
    for i, r in enumerate(ranked, start=1):
        if r in relevant:
            return 1.0 / i
    return 0.0


def ndcg_at_k(ranked: list[str], relevant: set[str], k: int) -> float:
    if not relevant:
        return 0.0
    dcg = sum((1.0 / math.log2(i + 1)) for i, r in enumerate(ranked[:k], start=1) if r in relevant)
    ideal_hits = min(len(relevant), k)
    idcg = sum(1.0 / math.log2(i + 1) for i in range(1, ideal_hits + 1))
    return dcg / idcg if idcg else 0.0


def _self_test() -> int:
    # ranked positions: 1 2 3 4
    ranked = ["A", "B", "C", "D"]
    rel = {"B", "D"}  # relevant at ranks 2 and 4
    ok = True

    def chk(name, got, exp):
        nonlocal ok
        good = abs(got - exp) < 1e-9
        ok = ok and good
        print(f" {name:14} got={got:.6f} exp={exp:.6f} {'ok' if good else 'FAIL'}")

    chk("P@2", precision_at_k(ranked, rel, 2), 1 / 2)  # B hit → 1/2
    chk("P@4", precision_at_k(ranked, rel, 4), 2 / 4)  # B,D → 2/4
    chk("R@2", recall_at_k(ranked, rel, 2), 1 / 2)  # 1 of 2 found
    chk("R@4", recall_at_k(ranked, rel, 4), 2 / 2)  # both found
    chk("MRR", mrr(ranked, rel), 1 / 2)  # first rel at rank 2
    # nDCG@4: DCG = 1/log2(3) + 1/log2(5); IDCG = 1/log2(2) + 1/log2(3)
    dcg = 1 / math.log2(3) + 1 / math.log2(5)
    idcg = 1 / math.log2(2) + 1 / math.log2(3)
    chk("nDCG@4", ndcg_at_k(ranked, rel, 4), dcg / idcg)
    chk("MRR-none", mrr(ranked, {"Z"}), 0.0)
    chk("R@k-empty", recall_at_k(ranked, set(), 4), 0.0)
    print("ALL PASS" if ok else "*** FAILURES ***")
    return 0 if ok else 1


# ── retrieval (production path) ──────────────────────────────────────────────
def _ranked_ids(results: list[dict]) -> list[str]:
    """Ranked, de-duplicated case_law_ids from a result list (order = ranking)."""
    out: list[str] = []
    seen: set[str] = set()
    for r in results or []:
        if not isinstance(r, dict):
            continue
        cid = r.get("case_law_id")
        if cid is None:
            continue
        s = str(cid)
        if s not in seen:
            seen.add(s)
            out.append(s)
    return out


async def _retrieve(corpus: str, query: str, practice_area: str, limit: int) -> list[str]:
    from legal_mcp.services import precedent_library, internal_decisions

    if corpus == "precedent_library":
        res = await precedent_library.search_library(query=query, practice_area=practice_area, limit=limit)
    elif corpus == "internal_decisions":
        res = await internal_decisions.search_internal(query=query, practice_area=practice_area, limit=limit)
    else:
        return []
    return _ranked_ids(res)


def _retrieval_config() -> dict:
    """Capture the retrieval knobs the run reflects — a baseline is only comparable
    to another run under the SAME config (multimodal/rerank/weights change results)."""
    from legal_mcp import config as cfg

    return {
        "MULTIMODAL_ENABLED": cfg.MULTIMODAL_ENABLED,
        "VOYAGE_RERANK_ENABLED": cfg.VOYAGE_RERANK_ENABLED,
        "VOYAGE_MODEL": cfg.VOYAGE_MODEL,
        "MULTIMODAL_TEXT_WEIGHT": cfg.MULTIMODAL_TEXT_WEIGHT,
        "MULTIMODAL_RRF_K": cfg.MULTIMODAL_RRF_K,
        "BM25_HYBRID_ENABLED": cfg.BM25_HYBRID_ENABLED,
    }


def _load_gold() -> list[dict]:
    if not GOLD_PATH.exists():
        return []
    out = []
    for line in GOLD_PATH.read_text(encoding="utf-8").splitlines():
        line = line.strip()
        if line:
            out.append(json.loads(line))
    return out


def _mean(vals: list[float]) -> float:
    return sum(vals) / len(vals) if vals else 0.0


def _aggregate(per_query: list[dict]) -> dict:
    """Mean of every metric across the given per-query records."""
    agg: dict[str, float] = {}
    if not per_query:
        return agg
    keys = [k for k in per_query[0]["metrics"]]
    for mk in keys:
        agg[mk] = round(_mean([q["metrics"][mk] for q in per_query]), 4)
    return agg


async def _run() -> dict:
    gold = _load_gold()
    kmax = max(K_VALUES)
    per_query: list[dict] = []
    for g in gold:
        # Normalize to str so gold ids compare equal to _ranked_ids() output.
        relevant = {str(x) for x in g.get("relevant_case_law_ids") or []}
        ranked = await _retrieve(g["corpus"], g["query"], g.get("practice_area", ""), kmax)
        m: dict[str, float] = {}
        for k in K_VALUES:
            m[f"P@{k}"] = precision_at_k(ranked, relevant, k)
            m[f"R@{k}"] = recall_at_k(ranked, relevant, k)
            m[f"nDCG@{k}"] = ndcg_at_k(ranked, relevant, k)
        m["MRR"] = mrr(ranked, relevant)
        per_query.append({
            "id": g["id"], "corpus": g["corpus"], "practice_area": g.get("practice_area", ""),
            "query": g["query"], "n_relevant": len(relevant), "n_retrieved": len(ranked),
            "first_rank": next((i for i, r in enumerate(ranked, 1) if r in relevant), None),
            "metrics": m,
        })
    corpora = sorted({q["corpus"] for q in per_query})
    pas = sorted({q["practice_area"] for q in per_query if q["practice_area"]})
    return {
        "gold_size": len(gold),
        "retrieval_config": _retrieval_config(),
        "overall": _aggregate(per_query),
        "by_corpus": {c: _aggregate([q for q in per_query if q["corpus"] == c]) for c in corpora},
        "by_practice_area": {p: _aggregate([q for q in per_query if q["practice_area"] == p]) for p in pas},
        "per_query": per_query,
    }


def _ts() -> str:
    return datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")


def _delta_table(cur: dict, base: dict | None) -> str:
    lines = ["| metric | current | baseline | Δ |", "|---|---|---|---|"]
    base_overall = (base or {}).get("overall", {})
    for mk, cv in cur["overall"].items():
        bv = base_overall.get(mk)
        d = f"{cv - bv:+.4f}" if isinstance(bv, (int, float)) else ""
        lines.append(f"| {mk} | {cv:.4f} | {bv if bv is not None else ''} | {d} |")
    return "\n".join(lines)


def _write_report(result: dict, base: dict | None, ts: str) -> tuple[Path, Path]:
    EVAL_DIR.mkdir(parents=True, exist_ok=True)
    jp = EVAL_DIR / f"eval-report-{ts}.json"
    mp = EVAL_DIR / f"eval-report-{ts}.md"
    jp.write_text(json.dumps(result, ensure_ascii=False, indent=2), encoding="utf-8")
    cfg = result.get("retrieval_config", {})
    cfg_line = " · ".join(f"{k}={v}" for k, v in cfg.items())
    base_cfg = (base or {}).get("retrieval_config")
    cfg_warn = ""
    if base_cfg and base_cfg != cfg:
        cfg_warn = "\n> ⚠ retrieval_config differs from baseline — deltas are NOT apples-to-apples.\n"
    # Report title in Hebrew: "FU-5: retrieval evaluation report <ts>"
    lines = [f"# FU-5 — דוח הערכת-אחזור — {ts}\n",
             f"- gold queries: {result['gold_size']}",
             f"- retrieval_config: {cfg_line}",
             f"- baseline: {'data/eval/baseline.json' if base else '(none yet)'}",
             cfg_warn,
             "## Overall (mean) — delta vs baseline\n", _delta_table(result, base), "",
             "## Per corpus\n"]
    if result["by_corpus"]:
        metric_keys = list(next(iter(result["by_corpus"].values())).keys())
        lines.append("| corpus | " + " | ".join(metric_keys) + " |")
        lines.append("|" + "---|" * (len(metric_keys) + 1))
        for c, agg in result["by_corpus"].items():
            lines.append(f"| {c} | " + " | ".join(f"{agg[k]:.4f}" for k in metric_keys) + " |")
    else:
        lines.append("(none)")
    mp.write_text("\n".join(lines) + "\n", encoding="utf-8")
    return jp, mp


async def main() -> int:
    ap = argparse.ArgumentParser(description="FU-5 retrieval eval harness")
    ap.add_argument("--self-test", action="store_true", help="run offline metric unit tests and exit")
    ap.add_argument("--update-baseline", action="store_true", help="write current run as data/eval/baseline.json")
    args = ap.parse_args()
    if args.self_test:
        return _self_test()
    gold = _load_gold()
    if not gold:
        print(f"gold-set empty ({GOLD_PATH}). Run scripts/eval_gold_bootstrap.py first.", file=sys.stderr)
        return 2
    result = await _run()
    base = json.loads(BASELINE_PATH.read_text(encoding="utf-8")) if BASELINE_PATH.exists() else None
    ts = _ts()
    jp, mp = _write_report(result, base, ts)
    print(f"EVAL: {result['gold_size']} queries")
    for mk, v in result["overall"].items():
        bv = (base or {}).get("overall", {}).get(mk)
        # Delta vs baseline printed as " (+0.0123)"; empty when the baseline lacks this metric.
        d = f" ({v - bv:+.4f})" if isinstance(bv, (int, float)) else ""
        print(f" {mk:8} {v:.4f}{d}")
    print(f" report: {mp}")
    if args.update_baseline:
        snapshot = {k: result[k] for k in ("gold_size", "retrieval_config", "overall", "by_corpus", "by_practice_area")}
        snapshot["generated_at"] = ts
        BASELINE_PATH.write_text(json.dumps(snapshot, ensure_ascii=False, indent=2), encoding="utf-8")
        print(f" baseline updated: {BASELINE_PATH}")
    return 0


if __name__ == "__main__":
    sys.exit(asyncio.run(main()))