#!/usr/bin/env python3
"""FU-5 (GAP-11, INV-RET4/G8) — retrieval eval harness: precision/recall/MRR/nDCG.

Runs the PRODUCTION retrieval path (the same service functions the MCP search
tools call) over the labeled gold-set (data/eval/gold-set.jsonl, built by
scripts/eval_gold_bootstrap.py) and reports retrieval quality. This is the
empirical measurement INV-RET4 requires: no more tuning RRF weights / k /
embedder "by feel".

Metrics per query (relevant = gold case_law_ids; ranked = retrieved case_law_ids):
  • precision@k = |top-k ∩ relevant| / k
  • recall@k    = |top-k ∩ relevant| / |relevant|
  • MRR         = 1 / rank-of-first-relevant (0 if none retrieved)
  • nDCG@k      = DCG@k / IDCG@k (binary gains, log2 discount)
Aggregated as the mean overall, per corpus, and per practice_area.

"CI gate" by discipline: run before AND after any retrieval-layer change (RRF
weights, k, chunk threshold, embedder, rerank) and compare to the committed
data/eval/baseline.json. Retrieval needs the prod DB + Voyage, so this is a
re-runnable script, not automated CI.

Usage (mcp-server venv; needs POSTGRES + VOYAGE_API_KEY for live runs):
    PY=/home/chaim/legal-ai/mcp-server/.venv/bin/python
    $PY scripts/eval_retrieval.py --self-test          # offline metric unit tests (no DB)
    POSTGRES_PASSWORD=… VOYAGE_API_KEY=… POSTGRES_HOST=127.0.0.1 POSTGRES_PORT=5433 \
      $PY scripts/eval_retrieval.py                    # run eval, write report + baseline delta
    … $PY scripts/eval_retrieval.py --update-baseline  # adopt current run as the new baseline
"""
from __future__ import annotations

import argparse
import asyncio
import json
import math
import os
import sys
from collections.abc import Mapping
from datetime import datetime, timezone
from pathlib import Path
from urllib.parse import quote

REPO_ROOT = Path(__file__).resolve().parent.parent
# Make the mcp-server package importable; the service modules are imported
# lazily inside the functions that need them so --self-test stays offline.
sys.path.insert(0, str(REPO_ROOT / "mcp-server" / "src"))


def _default_postgres_url(env: Mapping[str, str]) -> str:
    """Build the fallback ``postgres://`` URL from ``POSTGRES_*`` variables.

    The user and password are percent-encoded so that characters with a
    special meaning inside a URI (``@``, ``/``, ``:``, ``#``, ``%`` …) cannot
    corrupt the host/port/database parts of the URL.

    Args:
        env: Mapping to read ``POSTGRES_USER``, ``POSTGRES_PASSWORD``,
            ``POSTGRES_HOST``, ``POSTGRES_PORT`` and ``POSTGRES_DB`` from.

    Returns:
        The assembled connection URL, using the script's defaults for any
        variable that is not set.
    """
    user = quote(env.get("POSTGRES_USER", "legal_ai"), safe="")
    password = quote(env.get("POSTGRES_PASSWORD", ""), safe="")
    host = env.get("POSTGRES_HOST", "127.0.0.1")
    port = env.get("POSTGRES_PORT", "5433")
    database = env.get("POSTGRES_DB", "legal_ai")
    return f"postgres://{user}:{password}@{host}:{port}/{database}"


# Must be set before any legal_mcp module is imported (those imports happen
# lazily below); an explicit POSTGRES_URL always wins.
if "POSTGRES_URL" not in os.environ:
    os.environ["POSTGRES_URL"] = _default_postgres_url(os.environ)

EVAL_DIR = REPO_ROOT / "data" / "eval"
GOLD_PATH = EVAL_DIR / "gold-set.jsonl"
BASELINE_PATH = EVAL_DIR / "baseline.json"
K_VALUES = (5, 10)


# ── metrics (pure, unit-tested offline) ──────────────────────────────────────
def precision_at_k(ranked: list[str], relevant: set[str], k: int) -> float:
    """Return the fraction of the top-``k`` slots occupied by relevant ids.

    The denominator is always ``k`` (not the number of results returned), so a
    short result list is penalised. Returns 0.0 for ``k <= 0``.
    """
    if k <= 0:
        return 0.0
    topk = ranked[:k]
    return sum(1 for r in topk if r in relevant) / k


def recall_at_k(ranked: list[str], relevant: set[str], k: int) -> float:
    """Return the fraction of ``relevant`` ids found in the top ``k`` results.

    Returns 0.0 when ``relevant`` is empty.
    """
    if not relevant:
        return 0.0
    topk = ranked[:k]
    return sum(1 for r in topk if r in relevant) / len(relevant)


def mrr(ranked: list[str], relevant: set[str]) -> float:
    """Return the reciprocal rank (1-based) of the first relevant id, else 0.0."""
    for i, r in enumerate(ranked, start=1):
        if r in relevant:
            return 1.0 / i
    return 0.0


def ndcg_at_k(ranked: list[str], relevant: set[str], k: int) -> float:
    """Return nDCG@k with binary gains and a log2 position discount.

    The ideal DCG assumes ``min(len(relevant), k)`` relevant ids packed into
    the first positions. Returns 0.0 when ``relevant`` is empty or the ideal
    DCG is zero.
    """
    if not relevant:
        return 0.0
    dcg = sum(
        1.0 / math.log2(i + 1)
        for i, r in enumerate(ranked[:k], start=1)
        if r in relevant
    )
    ideal_hits = min(len(relevant), k)
    idcg = sum(1.0 / math.log2(i + 1) for i in range(1, ideal_hits + 1))
    return dcg / idcg if idcg else 0.0


def _self_test() -> int:
    """Run the offline metric checks; return 0 if all pass, 1 otherwise."""
    # ranked positions:    1    2    3    4
    ranked = ["A", "B", "C", "D"]
    rel = {"B", "D"}  # relevant at ranks 2 and 4
    ok = True

    def chk(name, got, exp):
        nonlocal ok
        good = abs(got - exp) < 1e-9
        ok = ok and good
        print(f"  {name:14} got={got:.6f} exp={exp:.6f} {'ok' if good else 'FAIL'}")

    chk("P@2", precision_at_k(ranked, rel, 2), 1 / 2)  # B hit → 1/2
    chk("P@4", precision_at_k(ranked, rel, 4), 2 / 4)  # B,D → 2/4
    chk("R@2", recall_at_k(ranked, rel, 2), 1 / 2)  # 1 of 2 found
    chk("R@4", recall_at_k(ranked, rel, 4), 2 / 2)  # both found
    chk("MRR", mrr(ranked, rel), 1 / 2)  # first rel at rank 2
    # nDCG@4: DCG = 1/log2(3) + 1/log2(5); IDCG = 1/log2(2)+1/log2(3)
    dcg = 1 / math.log2(3) + 1 / math.log2(5)
    idcg = 1 / math.log2(2) + 1 / math.log2(3)
    chk("nDCG@4", ndcg_at_k(ranked, rel, 4), dcg / idcg)
    chk("MRR-none", mrr(ranked, {"Z"}), 0.0)
    chk("R@k-empty", recall_at_k(ranked, set(), 4), 0.0)
    print("ALL PASS" if ok else "*** FAILURES ***")
    return 0 if ok else 1


# ── retrieval (production path) ──────────────────────────────────────────────
def _ranked_ids(results: list[dict]) -> list[str]:
    """Ranked, de-duplicated case_law_ids from a result list (order = ranking).

    Non-dict entries and entries without a ``case_law_id`` are skipped; ids
    are converted to ``str`` and only their first occurrence is kept.
    """
    out: list[str] = []
    seen: set[str] = set()
    for r in results or []:
        if not isinstance(r, dict):
            continue
        cid = r.get("case_law_id")
        if cid is None:
            continue
        s = str(cid)
        if s not in seen:
            seen.add(s)
            out.append(s)
    return out


async def _retrieve(corpus: str, query: str, practice_area: str, limit: int) -> list[str]:
    """Run the search service for ``corpus`` and return its ranked id list.

    An unrecognised corpus name yields an empty list rather than an error.
    """
    # Imported lazily so --self-test works without the mcp-server package.
    from legal_mcp.services import precedent_library, internal_decisions

    if corpus == "precedent_library":
        res = await precedent_library.search_library(
            query=query, practice_area=practice_area, limit=limit
        )
    elif corpus == "internal_decisions":
        res = await internal_decisions.search_internal(
            query=query, practice_area=practice_area, limit=limit
        )
    else:
        return []
    return _ranked_ids(res)


def _retrieval_config() -> dict:
    """Capture the retrieval knobs the run reflects — a baseline is only
    comparable to another run under the SAME config (multimodal/rerank/weights
    change results)."""
    from legal_mcp import config as cfg

    return {
        "MULTIMODAL_ENABLED": cfg.MULTIMODAL_ENABLED,
        "VOYAGE_RERANK_ENABLED": cfg.VOYAGE_RERANK_ENABLED,
        "VOYAGE_MODEL": cfg.VOYAGE_MODEL,
        "MULTIMODAL_TEXT_WEIGHT": cfg.MULTIMODAL_TEXT_WEIGHT,
        "MULTIMODAL_RRF_K": cfg.MULTIMODAL_RRF_K,
        "BM25_HYBRID_ENABLED": cfg.BM25_HYBRID_ENABLED,
    }


def _load_gold() -> list[dict]:
    """Load the gold set (one JSON object per non-blank line).

    Returns:
        The parsed records, or an empty list when the file does not exist.

    Raises:
        ValueError: If a line is not valid JSON; the message carries the file
            path and 1-based line number.
    """
    if not GOLD_PATH.exists():
        return []
    out = []
    text = GOLD_PATH.read_text(encoding="utf-8")
    for lineno, line in enumerate(text.splitlines(), start=1):
        line = line.strip()
        if not line:
            continue
        try:
            out.append(json.loads(line))
        except json.JSONDecodeError as err:
            raise ValueError(f"{GOLD_PATH}:{lineno}: invalid JSON in gold set: {err}") from err
    return out


def _mean(vals: list[float]) -> float:
    """Return the arithmetic mean of ``vals`` (0.0 for an empty list)."""
    return sum(vals) / len(vals) if vals else 0.0


def _aggregate(per_query: list[dict]) -> dict:
    """Mean of every metric across the given per-query records.

    Metric names are taken from the first record; means are rounded to 4
    decimals. Returns an empty dict for an empty input.
    """
    agg: dict[str, float] = {}
    if not per_query:
        return agg
    keys = list(per_query[0]["metrics"])
    for mk in keys:
        agg[mk] = round(_mean([q["metrics"][mk] for q in per_query]), 4)
    return agg


def _score_query(g: dict, ranked: list[str]) -> dict:
    """Build the per-query report record for one gold row.

    Args:
        g: Gold-set row; reads ``id``, ``corpus``, ``query`` and the optional
            ``practice_area`` / ``relevant_case_law_ids`` keys.
        ranked: Retrieved ids in ranking order, as produced by ``_ranked_ids``.

    Returns:
        A dict with the query's identity, counts, rank of the first relevant
        hit (``None`` if none) and a ``metrics`` dict keyed ``P@k``, ``R@k``,
        ``nDCG@k`` for each k in ``K_VALUES`` followed by ``MRR``.
    """
    # _ranked_ids() stringifies retrieved ids, so gold ids must be normalised
    # the same way or non-string ids (e.g. integers) would never match.
    relevant = {
        str(cid) for cid in (g.get("relevant_case_law_ids") or []) if cid is not None
    }
    m: dict[str, float] = {}
    for k in K_VALUES:
        m[f"P@{k}"] = precision_at_k(ranked, relevant, k)
        m[f"R@{k}"] = recall_at_k(ranked, relevant, k)
        m[f"nDCG@{k}"] = ndcg_at_k(ranked, relevant, k)
    m["MRR"] = mrr(ranked, relevant)
    return {
        "id": g["id"],
        "corpus": g["corpus"],
        "practice_area": g.get("practice_area", ""),
        "query": g["query"],
        "n_relevant": len(relevant),
        "n_retrieved": len(ranked),
        "first_rank": next((i for i, r in enumerate(ranked, 1) if r in relevant), None),
        "metrics": m,
    }


async def _run(gold: list[dict] | None = None) -> dict:
    """Evaluate every gold query and aggregate the metrics.

    Args:
        gold: Already-loaded gold rows; loaded from ``GOLD_PATH`` when omitted.

    Returns:
        A dict with ``gold_size``, ``retrieval_config``, the ``overall`` means,
        the ``by_corpus`` / ``by_practice_area`` breakdowns and the raw
        ``per_query`` records.
    """
    if gold is None:
        gold = _load_gold()
    kmax = max(K_VALUES)
    per_query: list[dict] = []
    # Queries are issued one at a time, in gold-set order.
    for g in gold:
        ranked = await _retrieve(g["corpus"], g["query"], g.get("practice_area", ""), kmax)
        per_query.append(_score_query(g, ranked))
    corpora = sorted({q["corpus"] for q in per_query})
    pas = sorted({q["practice_area"] for q in per_query if q["practice_area"]})
    return {
        "gold_size": len(gold),
        "retrieval_config": _retrieval_config(),
        "overall": _aggregate(per_query),
        "by_corpus": {
            c: _aggregate([q for q in per_query if q["corpus"] == c]) for c in corpora
        },
        "by_practice_area": {
            p: _aggregate([q for q in per_query if q["practice_area"] == p]) for p in pas
        },
        "per_query": per_query,
    }


def _ts() -> str:
    """Return the current UTC time as a compact ``YYYYMMDDTHHMMSSZ`` stamp."""
    return datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")


def _delta_table(cur: dict, base: dict | None) -> str:
    """Render a markdown table of current vs. baseline ``overall`` metrics.

    Metrics missing from the baseline show ``—`` for both the baseline and
    the delta; numeric baselines are formatted to 4 decimals like the current
    value.
    """
    lines = ["| metric | current | baseline | Δ |", "|---|---|---|---|"]
    base_overall = (base or {}).get("overall") or {}
    for mk, cv in cur["overall"].items():
        bv = base_overall.get(mk)
        if isinstance(bv, (int, float)):
            base_cell = f"{bv:.4f}"
            delta_cell = f"{cv - bv:+.4f}"
        else:
            base_cell = "—" if bv is None else str(bv)
            delta_cell = "—"
        lines.append(f"| {mk} | {cv:.4f} | {base_cell} | {delta_cell} |")
    return "\n".join(lines)


def _write_report(result: dict, base: dict | None, ts: str) -> tuple[Path, Path]:
    """Write the JSON and markdown reports for one run into ``EVAL_DIR``.

    Args:
        result: The dict returned by ``_run``.
        base: The loaded baseline, or ``None`` when there is none yet.
        ts: Timestamp used in both file names and the report title.

    Returns:
        ``(json_path, markdown_path)`` of the files written.
    """
    EVAL_DIR.mkdir(parents=True, exist_ok=True)
    jp = EVAL_DIR / f"eval-report-{ts}.json"
    mp = EVAL_DIR / f"eval-report-{ts}.md"
    jp.write_text(json.dumps(result, ensure_ascii=False, indent=2), encoding="utf-8")

    cfg = result.get("retrieval_config", {})
    cfg_line = " · ".join(f"{k}={v}" for k, v in cfg.items())
    base_cfg = (base or {}).get("retrieval_config")
    cfg_warn = ""
    if base_cfg and base_cfg != cfg:
        cfg_warn = "\n> ⚠ retrieval_config differs from baseline — deltas are NOT apples-to-apples.\n"
    lines = [
        f"# FU-5 — דוח הערכת-אחזור — {ts}\n",
        f"- gold queries: {result['gold_size']}",
        f"- retrieval_config: {cfg_line}",
        f"- baseline: {'data/eval/baseline.json' if base else '(none yet)'}",
        cfg_warn,
        "## Overall (mean) — delta vs baseline\n",
        _delta_table(result, base),
        "",
        "## Per corpus\n",
    ]
    if result["by_corpus"]:
        # Column order follows the metric order of the first corpus.
        metric_keys = list(next(iter(result["by_corpus"].values())).keys())
        lines.append("| corpus | " + " | ".join(metric_keys) + " |")
        lines.append("|" + "---|" * (len(metric_keys) + 1))
        for c, agg in result["by_corpus"].items():
            lines.append(f"| {c} | " + " | ".join(f"{agg[k]:.4f}" for k in metric_keys) + " |")
    else:
        lines.append("(none)")
    mp.write_text("\n".join(lines) + "\n", encoding="utf-8")
    return jp, mp


async def main() -> int:
    """CLI entry point.

    Returns:
        0 on success, the self-test status for ``--self-test``, or 2 when the
        gold set is missing or empty.
    """
    ap = argparse.ArgumentParser(description="FU-5 retrieval eval harness")
    ap.add_argument("--self-test", action="store_true",
                    help="run offline metric unit tests and exit")
    ap.add_argument("--update-baseline", action="store_true",
                    help="write current run as data/eval/baseline.json")
    args = ap.parse_args()

    if args.self_test:
        return _self_test()

    gold = _load_gold()
    if not gold:
        print(f"gold-set empty ({GOLD_PATH}). Run scripts/eval_gold_bootstrap.py first.",
              file=sys.stderr)
        return 2

    # Reuse the rows loaded above instead of reading the gold set twice.
    result = await _run(gold)
    base = (
        json.loads(BASELINE_PATH.read_text(encoding="utf-8"))
        if BASELINE_PATH.exists()
        else None
    )
    ts = _ts()
    jp, mp = _write_report(result, base, ts)

    print(f"EVAL: {result['gold_size']} queries")
    base_overall = (base or {}).get("overall") or {}
    for mk, v in result["overall"].items():
        bv = base_overall.get(mk)
        d = f"  (Δ {v - bv:+.4f})" if isinstance(bv, (int, float)) else ""
        print(f"  {mk:8} {v:.4f}{d}")
    print(f"  report: {mp}")

    if args.update_baseline:
        # per_query is deliberately left out of the baseline snapshot.
        snapshot = {
            k: result[k]
            for k in ("gold_size", "retrieval_config", "overall", "by_corpus", "by_practice_area")
        }
        snapshot["generated_at"] = ts
        BASELINE_PATH.write_text(
            json.dumps(snapshot, ensure_ascii=False, indent=2), encoding="utf-8"
        )
        print(f"  baseline updated: {BASELINE_PATH}")
    return 0


if __name__ == "__main__":
    sys.exit(asyncio.run(main()))