feat(eval): FU-5 — retrieval eval harness + halacha backlog visibility (#63)
Covers GAP-11 (INV-RET4/G8) and GAP-14 (INV-QA1/G10). Retrieval quality was
never measured (only telemetry observation) and the halacha review backlog was
invisible (the 10/19 gap was found by accident).
Unit B — backlog visibility (pure code, container):
- metrics.halacha_backlog(conn) → {pending_review, approved, rejected, published,
total, oldest_pending_at}; surfaced in metrics.get_dashboard() (get_metrics MCP
tool) and /api/system/diagnostics. Live count revealed 178 pending / 1552 total,
oldest from 2026-05-03 — previously invisible.
Unit A — retrieval eval harness (host-side scripts):
- scripts/eval_gold_bootstrap.py — seeds data/eval/gold-set.jsonl. Two sources:
citations (cited==relevant via search_relevance_feedback — empty until decisions
cite precedents) and known_item (query=case_name → relevant=self; a real
citation-free signal, the methodology #52 checked by hand). Idempotent; preserves
source='chair' rows.
- scripts/eval_retrieval.py — runs the production retrieval path (search_library /
search_internal) over the gold-set; computes precision@k, recall@k, MRR, nDCG@k
(k=5,10); aggregates overall + per-corpus + per-practice_area; writes a report and
a delta vs committed baseline.json (which records the retrieval_config it reflects).
--self-test unit-checks the metric math offline.
Gold-set strategy = hybrid (chair decision): bootstrap + chair review. The citation
source is empty today (0 cited precedents in decisions), so the seed is known-item
(77 queries: 54 internal_decisions + 23 precedent_library). The gold-set is
PROVISIONAL until Dafna reviews it (the domain chair-gate).
Baseline (production config: multimodal+rerank on): R@10=0.987, MRR=0.837,
nDCG@10=0.872. Finding: MULTIMODAL_ENABLED=true slightly lowers known-item recall
(image-page results displace exact name matches) — relevant to #15. precedent_library
weaker than internal (R@10 0.957 vs 1.0) — one external precedent unfindable by name.
"CI gate" realized as discipline (re-runnable harness + committed baseline + run
before/after any retrieval-layer change) — retrieval needs prod DB + Voyage, no CI
runner has that access.
Spec: docs/superpowers/specs/2026-05-31-fu5-eval-harness-design.md
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
294
scripts/eval_retrieval.py
Normal file
294
scripts/eval_retrieval.py
Normal file
@@ -0,0 +1,294 @@
|
||||
#!/usr/bin/env python3
|
||||
"""FU-5 (GAP-11, INV-RET4/G8) — retrieval eval harness: precision/recall/MRR/nDCG.
|
||||
|
||||
Runs the PRODUCTION retrieval path (the same service functions the MCP search tools
|
||||
call) over the labeled gold-set (data/eval/gold-set.jsonl, built by
|
||||
scripts/eval_gold_bootstrap.py) and reports retrieval quality. This is the empirical
|
||||
measurement INV-RET4 requires: no more tuning RRF weights / k / embedder "by feel".
|
||||
|
||||
Metrics per query (relevant = gold case_law_ids; ranked = retrieved case_law_ids):
|
||||
• precision@k = |top-k ∩ relevant| / k
|
||||
• recall@k = |top-k ∩ relevant| / |relevant|
|
||||
• MRR = 1 / rank-of-first-relevant (0 if none retrieved)
|
||||
• nDCG@k = DCG@k / IDCG@k (binary gains, log2 discount)
|
||||
Aggregated as the mean overall, per corpus, and per practice_area.
|
||||
|
||||
"CI gate" by discipline: run before AND after any retrieval-layer change (RRF weights,
|
||||
k, chunk threshold, embedder, rerank) and compare to the committed data/eval/baseline.json.
|
||||
Retrieval needs the prod DB + Voyage, so this is a re-runnable script, not automated CI.
|
||||
|
||||
Usage (mcp-server venv; needs POSTGRES + VOYAGE_API_KEY for live runs):
|
||||
PY=/home/chaim/legal-ai/mcp-server/.venv/bin/python
|
||||
$PY scripts/eval_retrieval.py --self-test # offline metric unit tests (no DB)
|
||||
POSTGRES_PASSWORD=… VOYAGE_API_KEY=… POSTGRES_HOST=127.0.0.1 POSTGRES_PORT=5433 \
|
||||
$PY scripts/eval_retrieval.py # run eval, write report + baseline delta
|
||||
… $PY scripts/eval_retrieval.py --update-baseline # adopt current run as the new baseline
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent
|
||||
sys.path.insert(0, str(REPO_ROOT / "mcp-server" / "src"))
|
||||
|
||||
if "POSTGRES_URL" not in os.environ:
|
||||
os.environ["POSTGRES_URL"] = (
|
||||
f"postgres://{os.environ.get('POSTGRES_USER','legal_ai')}:"
|
||||
f"{os.environ.get('POSTGRES_PASSWORD','')}@"
|
||||
f"{os.environ.get('POSTGRES_HOST','127.0.0.1')}:"
|
||||
f"{os.environ.get('POSTGRES_PORT','5433')}/"
|
||||
f"{os.environ.get('POSTGRES_DB','legal_ai')}"
|
||||
)
|
||||
|
||||
EVAL_DIR = REPO_ROOT / "data" / "eval"
|
||||
GOLD_PATH = EVAL_DIR / "gold-set.jsonl"
|
||||
BASELINE_PATH = EVAL_DIR / "baseline.json"
|
||||
K_VALUES = (5, 10)
|
||||
|
||||
|
||||
# ── metrics (pure, unit-tested offline) ──────────────────────────────────────
|
||||
def precision_at_k(ranked: list[str], relevant: set[str], k: int) -> float:
|
||||
if k <= 0:
|
||||
return 0.0
|
||||
topk = ranked[:k]
|
||||
return sum(1 for r in topk if r in relevant) / k
|
||||
|
||||
|
||||
def recall_at_k(ranked: list[str], relevant: set[str], k: int) -> float:
|
||||
if not relevant:
|
||||
return 0.0
|
||||
topk = ranked[:k]
|
||||
return sum(1 for r in topk if r in relevant) / len(relevant)
|
||||
|
||||
|
||||
def mrr(ranked: list[str], relevant: set[str]) -> float:
|
||||
for i, r in enumerate(ranked, start=1):
|
||||
if r in relevant:
|
||||
return 1.0 / i
|
||||
return 0.0
|
||||
|
||||
|
||||
def ndcg_at_k(ranked: list[str], relevant: set[str], k: int) -> float:
|
||||
if not relevant:
|
||||
return 0.0
|
||||
dcg = sum((1.0 / math.log2(i + 1)) for i, r in enumerate(ranked[:k], start=1) if r in relevant)
|
||||
ideal_hits = min(len(relevant), k)
|
||||
idcg = sum(1.0 / math.log2(i + 1) for i in range(1, ideal_hits + 1))
|
||||
return dcg / idcg if idcg else 0.0
|
||||
|
||||
|
||||
def _self_test() -> int:
|
||||
# ranked positions: 1 2 3 4
|
||||
ranked = ["A", "B", "C", "D"]
|
||||
rel = {"B", "D"} # relevant at ranks 2 and 4
|
||||
ok = True
|
||||
|
||||
def chk(name, got, exp):
|
||||
nonlocal ok
|
||||
good = abs(got - exp) < 1e-9
|
||||
ok = ok and good
|
||||
print(f" {name:14} got={got:.6f} exp={exp:.6f} {'ok' if good else 'FAIL'}")
|
||||
|
||||
chk("P@2", precision_at_k(ranked, rel, 2), 1 / 2) # B hit → 1/2
|
||||
chk("P@4", precision_at_k(ranked, rel, 4), 2 / 4) # B,D → 2/4
|
||||
chk("R@2", recall_at_k(ranked, rel, 2), 1 / 2) # 1 of 2 found
|
||||
chk("R@4", recall_at_k(ranked, rel, 4), 2 / 2) # both found
|
||||
chk("MRR", mrr(ranked, rel), 1 / 2) # first rel at rank 2
|
||||
# nDCG@4: DCG = 1/log2(3) + 1/log2(5); IDCG = 1/log2(2)+1/log2(3)
|
||||
dcg = 1 / math.log2(3) + 1 / math.log2(5)
|
||||
idcg = 1 / math.log2(2) + 1 / math.log2(3)
|
||||
chk("nDCG@4", ndcg_at_k(ranked, rel, 4), dcg / idcg)
|
||||
chk("MRR-none", mrr(ranked, {"Z"}), 0.0)
|
||||
chk("R@k-empty", recall_at_k(ranked, set(), 4), 0.0)
|
||||
print("ALL PASS" if ok else "*** FAILURES ***")
|
||||
return 0 if ok else 1
|
||||
|
||||
|
||||
# ── retrieval (production path) ──────────────────────────────────────────────
|
||||
def _ranked_ids(results: list[dict]) -> list[str]:
|
||||
"""Ranked, de-duplicated case_law_ids from a result list (order = ranking)."""
|
||||
out: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for r in results or []:
|
||||
if not isinstance(r, dict):
|
||||
continue
|
||||
cid = r.get("case_law_id")
|
||||
if cid is None:
|
||||
continue
|
||||
s = str(cid)
|
||||
if s not in seen:
|
||||
seen.add(s)
|
||||
out.append(s)
|
||||
return out
|
||||
|
||||
|
||||
async def _retrieve(corpus: str, query: str, practice_area: str, limit: int) -> list[str]:
|
||||
from legal_mcp.services import precedent_library, internal_decisions
|
||||
if corpus == "precedent_library":
|
||||
res = await precedent_library.search_library(query=query, practice_area=practice_area, limit=limit)
|
||||
elif corpus == "internal_decisions":
|
||||
res = await internal_decisions.search_internal(query=query, practice_area=practice_area, limit=limit)
|
||||
else:
|
||||
return []
|
||||
return _ranked_ids(res)
|
||||
|
||||
|
||||
def _retrieval_config() -> dict:
|
||||
"""Capture the retrieval knobs the run reflects — a baseline is only comparable
|
||||
to another run under the SAME config (multimodal/rerank/weights change results)."""
|
||||
from legal_mcp import config as cfg
|
||||
return {
|
||||
"MULTIMODAL_ENABLED": cfg.MULTIMODAL_ENABLED,
|
||||
"VOYAGE_RERANK_ENABLED": cfg.VOYAGE_RERANK_ENABLED,
|
||||
"VOYAGE_MODEL": cfg.VOYAGE_MODEL,
|
||||
"MULTIMODAL_TEXT_WEIGHT": cfg.MULTIMODAL_TEXT_WEIGHT,
|
||||
"MULTIMODAL_RRF_K": cfg.MULTIMODAL_RRF_K,
|
||||
"BM25_HYBRID_ENABLED": cfg.BM25_HYBRID_ENABLED,
|
||||
}
|
||||
|
||||
|
||||
def _load_gold() -> list[dict]:
|
||||
if not GOLD_PATH.exists():
|
||||
return []
|
||||
out = []
|
||||
for line in GOLD_PATH.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
out.append(json.loads(line))
|
||||
return out
|
||||
|
||||
|
||||
def _mean(vals: list[float]) -> float:
|
||||
return sum(vals) / len(vals) if vals else 0.0
|
||||
|
||||
|
||||
def _aggregate(per_query: list[dict]) -> dict:
|
||||
"""Mean of every metric across the given per-query records."""
|
||||
agg: dict[str, float] = {}
|
||||
if not per_query:
|
||||
return agg
|
||||
keys = [k for k in per_query[0]["metrics"]]
|
||||
for mk in keys:
|
||||
agg[mk] = round(_mean([q["metrics"][mk] for q in per_query]), 4)
|
||||
return agg
|
||||
|
||||
|
||||
async def _run() -> dict:
|
||||
gold = _load_gold()
|
||||
kmax = max(K_VALUES)
|
||||
per_query: list[dict] = []
|
||||
for g in gold:
|
||||
relevant = set(g.get("relevant_case_law_ids") or [])
|
||||
ranked = await _retrieve(g["corpus"], g["query"], g.get("practice_area", ""), kmax)
|
||||
m: dict[str, float] = {}
|
||||
for k in K_VALUES:
|
||||
m[f"P@{k}"] = precision_at_k(ranked, relevant, k)
|
||||
m[f"R@{k}"] = recall_at_k(ranked, relevant, k)
|
||||
m[f"nDCG@{k}"] = ndcg_at_k(ranked, relevant, k)
|
||||
m["MRR"] = mrr(ranked, relevant)
|
||||
per_query.append({
|
||||
"id": g["id"], "corpus": g["corpus"], "practice_area": g.get("practice_area", ""),
|
||||
"query": g["query"], "n_relevant": len(relevant), "n_retrieved": len(ranked),
|
||||
"first_rank": next((i for i, r in enumerate(ranked, 1) if r in relevant), None),
|
||||
"metrics": m,
|
||||
})
|
||||
|
||||
corpora = sorted({q["corpus"] for q in per_query})
|
||||
pas = sorted({q["practice_area"] for q in per_query if q["practice_area"]})
|
||||
return {
|
||||
"gold_size": len(gold),
|
||||
"retrieval_config": _retrieval_config(),
|
||||
"overall": _aggregate(per_query),
|
||||
"by_corpus": {c: _aggregate([q for q in per_query if q["corpus"] == c]) for c in corpora},
|
||||
"by_practice_area": {p: _aggregate([q for q in per_query if q["practice_area"] == p]) for p in pas},
|
||||
"per_query": per_query,
|
||||
}
|
||||
|
||||
|
||||
def _ts() -> str:
|
||||
return datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
|
||||
|
||||
def _delta_table(cur: dict, base: dict | None) -> str:
|
||||
lines = ["| metric | current | baseline | Δ |", "|---|---|---|---|"]
|
||||
base_overall = (base or {}).get("overall", {})
|
||||
for mk, cv in cur["overall"].items():
|
||||
bv = base_overall.get(mk)
|
||||
d = f"{cv - bv:+.4f}" if isinstance(bv, (int, float)) else "—"
|
||||
lines.append(f"| {mk} | {cv:.4f} | {bv if bv is not None else '—'} | {d} |")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _write_report(result: dict, base: dict | None, ts: str) -> tuple[Path, Path]:
|
||||
EVAL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
jp = EVAL_DIR / f"eval-report-{ts}.json"
|
||||
mp = EVAL_DIR / f"eval-report-{ts}.md"
|
||||
jp.write_text(json.dumps(result, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
cfg = result.get("retrieval_config", {})
|
||||
cfg_line = " · ".join(f"{k}={v}" for k, v in cfg.items())
|
||||
base_cfg = (base or {}).get("retrieval_config")
|
||||
cfg_warn = ""
|
||||
if base_cfg and base_cfg != cfg:
|
||||
cfg_warn = "\n> ⚠ retrieval_config differs from baseline — deltas are NOT apples-to-apples.\n"
|
||||
lines = [f"# FU-5 — דוח הערכת-אחזור — {ts}\n",
|
||||
f"- gold queries: {result['gold_size']}",
|
||||
f"- retrieval_config: {cfg_line}",
|
||||
f"- baseline: {'data/eval/baseline.json' if base else '(none yet)'}",
|
||||
cfg_warn,
|
||||
"## Overall (mean) — delta vs baseline\n", _delta_table(result, base), "",
|
||||
"## Per corpus\n"]
|
||||
if result["by_corpus"]:
|
||||
metric_keys = list(next(iter(result["by_corpus"].values())).keys())
|
||||
lines.append("| corpus | " + " | ".join(metric_keys) + " |")
|
||||
lines.append("|" + "---|" * (len(metric_keys) + 1))
|
||||
for c, agg in result["by_corpus"].items():
|
||||
lines.append(f"| {c} | " + " | ".join(f"{agg[k]:.4f}" for k in metric_keys) + " |")
|
||||
else:
|
||||
lines.append("(none)")
|
||||
mp.write_text("\n".join(lines) + "\n", encoding="utf-8")
|
||||
return jp, mp
|
||||
|
||||
|
||||
async def main() -> int:
|
||||
ap = argparse.ArgumentParser(description="FU-5 retrieval eval harness")
|
||||
ap.add_argument("--self-test", action="store_true", help="run offline metric unit tests and exit")
|
||||
ap.add_argument("--update-baseline", action="store_true", help="write current run as data/eval/baseline.json")
|
||||
args = ap.parse_args()
|
||||
|
||||
if args.self_test:
|
||||
return _self_test()
|
||||
|
||||
gold = _load_gold()
|
||||
if not gold:
|
||||
print(f"gold-set empty ({GOLD_PATH}). Run scripts/eval_gold_bootstrap.py first.", file=sys.stderr)
|
||||
return 2
|
||||
|
||||
result = await _run()
|
||||
base = json.loads(BASELINE_PATH.read_text(encoding="utf-8")) if BASELINE_PATH.exists() else None
|
||||
ts = _ts()
|
||||
jp, mp = _write_report(result, base, ts)
|
||||
|
||||
print(f"EVAL: {result['gold_size']} queries")
|
||||
for mk, v in result["overall"].items():
|
||||
bv = (base or {}).get("overall", {}).get(mk)
|
||||
d = f" (Δ {v - bv:+.4f})" if isinstance(bv, (int, float)) else ""
|
||||
print(f" {mk:8} {v:.4f}{d}")
|
||||
print(f" report: {mp}")
|
||||
|
||||
if args.update_baseline:
|
||||
snapshot = {k: result[k] for k in ("gold_size", "retrieval_config", "overall", "by_corpus", "by_practice_area")}
|
||||
snapshot["generated_at"] = ts
|
||||
BASELINE_PATH.write_text(json.dumps(snapshot, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
print(f" baseline updated: {BASELINE_PATH}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(asyncio.run(main()))
|
||||
Reference in New Issue
Block a user