#!/usr/bin/env python3 """Compute nDCG@10 over the RAG retrieval feedback table (TaskMaster #50). Outputs aggregated metrics as JSON: { "generated_at": "2026-05-26T12:34:56+00:00", "k": 10, "summary": { "total_searches_with_feedback": int, "total_searches_logged": int, "feedback_coverage_pct": float, "avg_ndcg_at_10": float | null }, "by_search_type": [ {"search_type": "precedent_library", "searches_with_feedback": int, "avg_ndcg_at_10": float | null}, ... ], "by_week": [ {"week_start": "2026-05-19", "search_type": "precedent_library", "searches_with_feedback": int, "avg_ndcg_at_10": float | null}, ... ], "top_cited_case_law": [ {"case_law_id": "...", "case_number": "...", "case_name": "...", "cite_count": int}, ... ] } Run: python ~/legal-ai/scripts/compute_ndcg.py python ~/legal-ai/scripts/compute_ndcg.py --weeks 12 --k 10 python ~/legal-ai/scripts/compute_ndcg.py --pretty """ from __future__ import annotations import argparse import asyncio import json import math import os import sys from datetime import datetime, timezone from pathlib import Path import asyncpg # Allow running as a standalone script — no package install required. REPO_ROOT = Path(__file__).resolve().parent.parent sys.path.insert(0, str(REPO_ROOT / "mcp-server" / "src")) def _postgres_url() -> str: """Resolve POSTGRES_URL the same way the MCP server does.""" url = os.environ.get("POSTGRES_URL") if url: return url user = os.environ.get("POSTGRES_USER", "legal_ai") pw = os.environ.get("POSTGRES_PASSWORD", "") host = os.environ.get("POSTGRES_HOST", "127.0.0.1") port = os.environ.get("POSTGRES_PORT", "5433") db = os.environ.get("POSTGRES_DB", "legal_ai") return f"postgres://{user}:{pw}@{host}:{port}/{db}" def dcg(relevances: list[int]) -> float: """Discounted Cumulative Gain at the length of ``relevances``. Uses the "gain = 2^rel - 1" form so high-relevance hits get significantly more weight than marginal ones — matches the convention used by most IR papers and TREC-EVAL. """ total = 0.0 for i, rel in enumerate(relevances, start=1): gain = (2 ** rel) - 1 total += gain / math.log2(i + 1) return total def ndcg_at_k(rel_at_rank: dict[int, int], k: int) -> float | None: """Compute nDCG@k. Args: rel_at_rank: ``{rank (1-based): relevance_score (0..3)}``. Ranks above ``k`` are ignored. Missing ranks count as 0. k: cutoff. Returns: nDCG in [0,1], or ``None`` if there's nothing to score (no relevant hits in the top-k -> IDCG = 0). """ actual = [rel_at_rank.get(r, 0) for r in range(1, k + 1)] if not any(actual): return None ideal = sorted(actual, reverse=True) idcg = dcg(ideal) if idcg == 0: return None return dcg(actual) / idcg async def _fetch_feedback_rows(conn: asyncpg.Connection, weeks: int | None) -> list[dict]: """Pull all (search_log_id, rank, relevance_score, search_type, created_at) rows where there's at least one feedback row. Restricting to recent weeks keeps the scan cheap on a growing log. """ where = "" params: list = [] if weeks is not None and weeks > 0: where = "WHERE sl.created_at >= NOW() - ($1::int * INTERVAL '1 week')" params.append(weeks) sql = f""" SELECT sl.id::text AS search_log_id, sl.search_type AS search_type, sl.created_at AS created_at, srf.rank AS rank, srf.relevance_score AS relevance_score FROM search_relevance_feedback srf JOIN search_logs sl ON sl.id = srf.search_log_id {where} """ rows = await conn.fetch(sql, *params) return [dict(r) for r in rows] async def _fetch_corpus_totals(conn: asyncpg.Connection, weeks: int | None) -> dict[str, int]: """Total search_logs count (overall and by type) — used for coverage %.""" where = "" params: list = [] if weeks is not None and weeks > 0: where = "WHERE created_at >= NOW() - ($1::int * INTERVAL '1 week')" params.append(weeks) total_row = await conn.fetchrow( f"SELECT COUNT(*) AS n FROM search_logs {where}", *params, ) by_type = await conn.fetch( f"SELECT search_type, COUNT(*) AS n FROM search_logs {where} GROUP BY search_type", *params, ) return { "_total": int(total_row["n"]) if total_row else 0, **{r["search_type"]: int(r["n"]) for r in by_type}, } async def _fetch_top_cited(conn: asyncpg.Connection, limit: int = 20) -> list[dict]: """Most-cited case_law (from auto-inferred feedback).""" rows = await conn.fetch( """ SELECT cl.id::text AS case_law_id, cl.case_number AS case_number, cl.case_name AS case_name, COUNT(*) AS cite_count FROM search_relevance_feedback srf JOIN case_law cl ON cl.id = srf.case_law_id WHERE srf.feedback_source = 'cited_in_decision' GROUP BY cl.id, cl.case_number, cl.case_name ORDER BY COUNT(*) DESC LIMIT $1 """, limit, ) return [dict(r) for r in rows] def _aggregate( feedback_rows: list[dict], k: int, ) -> tuple[dict[str, float], dict[tuple[str, str], float], int]: """Group feedback by search_log, compute per-log nDCG, then aggregate by search_type and by (week, search_type).""" by_log: dict[str, dict] = {} for row in feedback_rows: slid = row["search_log_id"] if slid not in by_log: by_log[slid] = { "search_type": row["search_type"], "created_at": row["created_at"], "rels": {}, } rank = int(row["rank"]) if 1 <= rank <= k: by_log[slid]["rels"][rank] = int(row["relevance_score"]) type_ndcg: dict[str, list[float]] = {} week_ndcg: dict[tuple[str, str], list[float]] = {} total_logs_with_feedback = 0 for entry in by_log.values(): score = ndcg_at_k(entry["rels"], k) if score is None: continue total_logs_with_feedback += 1 type_ndcg.setdefault(entry["search_type"], []).append(score) week_start = entry["created_at"].date() # Round down to ISO week Monday. week_start = week_start.fromordinal( week_start.toordinal() - week_start.weekday() ) wkey = (week_start.isoformat(), entry["search_type"]) week_ndcg.setdefault(wkey, []).append(score) type_avg = {t: sum(v) / len(v) for t, v in type_ndcg.items() if v} week_avg = {k_: sum(v) / len(v) for k_, v in week_ndcg.items() if v} return type_avg, week_avg, total_logs_with_feedback async def compute(weeks: int | None, k: int) -> dict: conn = await asyncpg.connect(_postgres_url()) try: fb_rows = await _fetch_feedback_rows(conn, weeks) totals = await _fetch_corpus_totals(conn, weeks) top_cited = await _fetch_top_cited(conn) finally: await conn.close() type_avg, week_avg, logs_scored = _aggregate(fb_rows, k) total_logs = totals.get("_total", 0) overall_avg = ( sum(v * len([s for s in type_avg]) for v in []) or None # placeholder ) # Recompute overall_avg cleanly: micro-average over all per-log scores. all_scores: list[float] = [] for v in [type_avg[t] for t in type_avg]: # type_avg already collapsed per-type — instead, re-run aggregation # over fb_rows by reusing the per-log calc, micro-averaged. pass # Simpler: redo with per-log granularity for overall mean. by_log_overall: dict[str, dict[int, int]] = {} log_to_type: dict[str, str] = {} for row in fb_rows: slid = row["search_log_id"] by_log_overall.setdefault(slid, {}) rank = int(row["rank"]) if 1 <= rank <= k: by_log_overall[slid][rank] = int(row["relevance_score"]) log_to_type[slid] = row["search_type"] per_log_scores: list[float] = [] for slid, rels in by_log_overall.items(): s = ndcg_at_k(rels, k) if s is not None: per_log_scores.append(s) overall_avg = (sum(per_log_scores) / len(per_log_scores)) if per_log_scores else None by_search_type = [] for t, totals_n in sorted(totals.items()): if t == "_total": continue by_search_type.append({ "search_type": t, "searches_logged": totals_n, "searches_with_feedback": sum( 1 for slid, tp in log_to_type.items() if tp == t ), "avg_ndcg_at_k": round(type_avg[t], 4) if t in type_avg else None, }) by_week = [ { "week_start": week, "search_type": stype, "avg_ndcg_at_k": round(score, 4), } for (week, stype), score in sorted(week_avg.items()) ] return { "generated_at": datetime.now(timezone.utc).isoformat(), "k": k, "window_weeks": weeks, "summary": { "total_searches_logged": total_logs, "total_searches_with_feedback": logs_scored, "feedback_coverage_pct": ( round(100 * logs_scored / total_logs, 2) if total_logs else 0.0 ), "avg_ndcg_at_k": round(overall_avg, 4) if overall_avg is not None else None, }, "by_search_type": by_search_type, "by_week": by_week, "top_cited_case_law": [ {**r, "cite_count": int(r["cite_count"])} for r in top_cited ], } def main() -> int: p = argparse.ArgumentParser(description="Compute nDCG@k from search_relevance_feedback") p.add_argument("--k", type=int, default=10, help="cutoff (default: 10)") p.add_argument( "--weeks", type=int, default=None, help="restrict to the last N weeks (default: all time)", ) p.add_argument("--pretty", action="store_true", help="indented JSON output") args = p.parse_args() result = asyncio.run(compute(weeks=args.weeks, k=args.k)) indent = 2 if args.pretty else None print(json.dumps(result, ensure_ascii=False, indent=indent, default=str)) return 0 if __name__ == "__main__": raise SystemExit(main())