"""POC #4: Comprehensive retrieval benchmark with LLM-as-judge. Compares 3 retrievers on אהרון ברק 403/17 (219 chunks): R1 — voyage-3 (current production baseline) R2 — voyage-3 + voyage-rerank-2 (retrieve 50, rerank, top-10) R3 — voyage-context-3 (windowed, from POC #2) Judges relevance with claude-haiku-4-5 — for each (query, chunk) pair the judge returns 1-5. Aggregates: mean relevance@3, @5, @10, MRR (rank of first 4+ chunk), per-query winner. 20 queries grouped into 3 categories so we can see *which* query types benefit from which retriever: K — keyword/lexical (term-heavy, specific entity) C — conceptual (abstract idea, principle) N — narrative/contextual (requires document-internal reference) Usage (key passed via env, NOT stored in script): ANTHROPIC_API_KEY=... \\ /home/chaim/legal-ai/mcp-server/.venv/bin/python \\ /home/chaim/legal-ai/scripts/voyage_rerank_judge_poc.py """ from __future__ import annotations import asyncio import json import math import os import sys import time from collections import defaultdict ENV_PATH = os.path.expanduser("~/.env") if os.path.isfile(ENV_PATH): with open(ENV_PATH) as f: for line in f: line = line.strip() if line and not line.startswith("#") and "=" in line: k, v = line.split("=", 1) os.environ.setdefault(k, v) import re import subprocess import asyncpg # noqa: E402 import voyageai # noqa: E402 CASE_ID = "e151fc25-cf12-4563-b638-a86323f8413b" # אהרון ברק 403/17 TEXT_MODEL = "voyage-3" CONTEXT_MODEL = "voyage-context-3" RERANK_MODEL = "rerank-2" JUDGE_MODEL = "claude-haiku-4-5-20251001" WINDOW_SIZE = 80 WINDOW_STRIDE = 70 # 18 queries × 3 retrievers × top-5 = 270 judge calls. ~$0.05 with haiku. 
QUERIES = [
    # K — keyword/lexical
    ("K1", "תכנית רחביה הוראות בנייה"),
    ("K2", "תמ\"א 38"),
    ("K3", "תכנית 9988"),
    ("K4", "סעיף 197 לחוק התכנון והבניה"),
    ("K5", "השופט גרוסקופף"),
    ("K6", "ועדה מקומית ירושלים"),
    # C — conceptual / abstract principles
    ("C1", "כלל הנטרול של זכויות תכנוניות"),
    ("C2", "אינטרס הציבור בתכנון"),
    ("C3", "תכלית היטל ההשבחה"),
    ("C4", "תכנית פוגעת לעומת תכנית משביחה"),
    ("C5", "ההבחנה בין השבחה לפיצויים"),
    ("C6", "מהותו של היטל ההשבחה"),
    # N — narrative / context-dependent
    ("N1", "מה נקבע לגבי תמ\"א 38 בפסק הדין"),
    ("N2", "מסקנת בית המשפט בעניין רובע 3"),
    ("N3", "ההלכה שנקבעה בעניין שמעוני"),
    ("N4", "ההבדל בין המקרה שלפנינו לעניין רון"),
    ("N5", "סוף דבר ותוצאת פסק הדין"),
    ("N6", "הסכמת השופטים האחרים לחוות הדעת"),
]


def cosine(a, b):
    """Return the cosine similarity of two equal-length numeric vectors.

    Returns 0.0 when either vector has zero norm.

    Raises:
        ValueError: if the vectors differ in length (previously the longer
            one was silently truncated, producing a meaningless score).
    """
    if len(a) != len(b):
        raise ValueError(
            f"vector length mismatch: {len(a)} != {len(b)}")
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0


def parse_pgvector(s):
    """Parse a pgvector text literal such as ``"[0.1,0.2]"`` into floats.

    Surrounding whitespace is tolerated and an empty vector (``"[]"``)
    yields an empty list.
    """
    body = s.strip().strip("[]").strip()
    if not body:
        return []
    return [float(x) for x in body.split(",")]


def build_windows(n, size, stride):
    """Return ``(start, end)`` half-open windows covering ``range(n)``.

    Windows are ``size`` long (the last one may be shorter) and start every
    ``stride`` items; generation stops with the first window that reaches
    ``n``. ``n <= 0`` yields an empty list. If ``stride > size`` the windows
    leave gaps.

    Raises:
        ValueError: if ``size`` or ``stride`` is not positive (a non-positive
            stride would never terminate).
    """
    if size <= 0 or stride <= 0:
        raise ValueError(
            f"size and stride must be positive (got size={size}, "
            f"stride={stride})")
    out = []
    s = 0
    while s < n:
        e = min(s + size, n)
        out.append((s, e))
        if e == n:
            break
        s += stride
    return out


def central_window(idx, windows):
    """Return the index of the window in which *idx* sits most centrally.

    "Most central" means the largest distance to the nearer window edge; on
    ties the earliest window wins. Returns -1 when no window contains *idx*.
    """
    best, best_d = -1, -1
    for w_idx, (s, e) in enumerate(windows):
        if not (s <= idx < e):
            continue
        d = min(idx - s, (e - 1) - idx)
        if d > best_d:
            best_d = d
            best = w_idx
    return best


BATCH_JUDGE_PROMPT = """אתה שופט רלוונטיות במשפט ישראלי. לפניך שאילתה ומספר פסקאות מפסק דין.
דרג כל פסקה בנפרד 1-5 לפי רלוונטיות.

סולם:
5 — תשובה ישירה ומדויקת לשאילתה
4 — מאד רלוונטי, מכיל מידע ליבה
3 — רלוונטי חלקית, נוגע בעקיפין בנושא
2 — מעט קשור, רעש סביב הנושא
1 — לא רלוונטי בכלל

השאילתה: {query}

הפסקאות:
{chunks_block}

החזר JSON בלבד, בפורמט:
{{"scores": {{"<chunk_id>": <1-5>, ...}}}}

ללא טקסט נוסף, ללא explanations, ללא ```."""


def render_judge_chunks(items: list[tuple[int, str]]) -> str:
    """Render ``(chunk_idx, content)`` pairs as id-tagged blocks for the prompt.

    Each chunk is flattened to one line, truncated to 1500 characters and
    wrapped in ``<chunk id="...">`` so the judge can key its scores by id.
    """
    blocks = []
    for ci, content in items:
        snippet = content.replace("\n", " ").strip()[:1500]
        blocks.append(f'<chunk id="{ci}">\n{snippet}\n</chunk>')
    return "\n\n".join(blocks)


def parse_judge_scores(raw_output: str) -> dict[int, int]:
    """Extract ``{chunk_idx: score}`` from the judge's raw stdout.

    Optional Markdown code fences are stripped. Entries whose key is not an
    integer or whose score is not an integer in 1..5 are dropped one by one.
    Output that is not a JSON object with a ``scores`` object yields ``{}``
    and a diagnostic line is printed.
    """
    out = raw_output.strip()
    # Strip ```json fences if any
    out = re.sub(r"^```(?:json)?\s*", "", out)
    out = re.sub(r"\s*```$", "", out)
    try:
        data = json.loads(out)
    except json.JSONDecodeError as e:
        print(f"  [judge parse fail: {e}; out={out[:200]!r}]")
        return {}
    raw = data.get("scores") if isinstance(data, dict) else None
    if not isinstance(raw, dict):
        print(f"  [judge parse fail: no 'scores' object; out={out[:200]!r}]")
        return {}
    scores: dict[int, int] = {}
    for key, value in raw.items():
        if not str(value).isdigit():
            continue
        try:
            idx, score = int(key), int(value)
        except (TypeError, ValueError):
            # One malformed entry must not discard the rest of the batch.
            continue
        if 1 <= score <= 5:
            scores[idx] = score
    return scores


def batch_judge(query: str, items: list[tuple[int, str]]) -> dict[int, int]:
    """Judge a list of (chunk_idx, content) pairs in a single CLI call.

    Returns:
        dict[chunk_idx → score 1-5]. Chunks the judge did not score validly
        are simply absent; an empty dict is returned when *items* is empty,
        the CLI cannot be run or times out, or its output cannot be parsed.
    """
    if not items:
        return {}
    prompt = BATCH_JUDGE_PROMPT.format(
        query=query,
        chunks_block=render_judge_chunks(items),
    )
    try:
        proc = subprocess.run(
            ["claude", "-p", "--model", JUDGE_MODEL],
            input=prompt,
            capture_output=True,
            text=True,
            encoding="utf-8",  # the prompt is Hebrew; don't rely on the locale
            timeout=120,
        )
    except (subprocess.TimeoutExpired, OSError) as e:
        # subprocess timeouts are TimeoutExpired, not the builtin TimeoutError.
        print(f"  [judge call failed: {e}]")
        return {}
    if proc.returncode != 0:
        print(f"  [judge CLI exited with {proc.returncode}: "
              f"{proc.stderr.strip()[:200]!r}]")
    return parse_judge_scores(proc.stdout)


def summarize_scores(scores):
    """Return ``(mean@3, mean_over_all, mrr)`` for one ranked score list.

    mean@3 is 0 when fewer than 3 scores exist; the overall mean is 0 for an
    empty list; MRR is the reciprocal rank of the first score >= 4, or 0.0.
    """
    mean3 = sum(scores[:3]) / 3 if len(scores) >= 3 else 0
    mean_all = sum(scores) / len(scores) if scores else 0
    mrr = 0.0
    for rank, score in enumerate(scores, start=1):
        if score >= 4:
            mrr = 1.0 / rank
            break
    return mean3, mean_all, mrr


def _avg(xs):
    """Return the arithmetic mean of *xs*, or 0 for an empty sequence."""
    return sum(xs) / len(xs) if xs else 0


async def main():
    """Run the benchmark: load chunks, embed, retrieve, judge, and report.

    Reads VOYAGE_API_KEY and POSTGRES_PASSWORD from the environment, prints
    per-query and aggregated tables, and writes the raw results to
    /tmp/voyage_rerank_judge_results.json.
    """
    voyage_key = os.environ["VOYAGE_API_KEY"]
    pg_pw = os.environ["POSTGRES_PASSWORD"]

    # Verify Claude CLI is available (uses OAuth from ~/.claude/.credentials)
    try:
        subprocess.run(["claude", "--version"], capture_output=True,
                       text=True, timeout=10, check=True)
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError):
        sys.exit("claude CLI not found or not authenticated")

    voyage = voyageai.Client(api_key=voyage_key)

    # Load chunks + voyage-3 embeddings
    pool = await asyncpg.create_pool(
        host="127.0.0.1", port=5433, user="legal_ai",
        password=pg_pw, database="legal_ai", min_size=1, max_size=2,
    )
    try:
        rows = await pool.fetch("""
            SELECT chunk_index, content, embedding::text AS emb_text
            FROM precedent_chunks
            WHERE case_law_id = $1
            ORDER BY chunk_index
        """, CASE_ID)
    finally:
        # The pool is only needed for this one query; always release it.
        await pool.close()

    if not rows:
        sys.exit(f"no chunks found for case {CASE_ID}")
    if any(r["emb_text"] is None for r in rows):
        sys.exit("some chunks have no stored embedding; cannot run baseline")

    chunks = [r["content"] for r in rows]
    chunk_indices = [r["chunk_index"] for r in rows]
    baseline_embs = [parse_pgvector(r["emb_text"]) for r in rows]
    n = len(chunks)
    print(f"[load] {n} chunks loaded")

    # Compute context-3 (windowed) embeddings — same as POC #2
    windows = build_windows(n, WINDOW_SIZE, WINDOW_STRIDE)
    print(f"[context-3] embedding {len(windows)} windows…")
    win_embs = []
    for s, e in windows:
        result = voyage.contextualized_embed(
            inputs=[chunks[s:e]],
            model=CONTEXT_MODEL,
            input_type="document",
        )
        win_embs.append(result.results[0].embeddings)
    # Each chunk takes its embedding from the window where it is most central.
    context_embs = []
    for i in range(n):
        w = central_window(i, windows)
        s, _ = windows[w]
        context_embs.append(win_embs[w][i - s])
    print("[context-3] done")

    # Retrieval functions: each returns positions into `chunks`, best first.
    def r1_baseline(query: str, k: int = 10) -> list[int]:
        q = voyage.embed([query], model=TEXT_MODEL,
                         input_type="query").embeddings[0]
        scores = sorted(
            [(cosine(q, e), i) for i, e in enumerate(baseline_embs)],
            reverse=True,
        )
        return [i for _, i in scores[:k]]

    def r2_rerank(query: str, k: int = 10) -> list[int]:
        # 1) voyage-3 retrieve top-50
        cands = r1_baseline(query, k=50)
        cand_texts = [chunks[i] for i in cands]
        # 2) voyage-rerank-2 over the 50
        rr = voyage.rerank(
            query=query,
            documents=cand_texts,
            model=RERANK_MODEL,
            top_k=k,
        )
        # rr.results: list of RerankingResult(index=..., relevance_score=...)
        # `index` refers to position in cand_texts → map back to chunk idx
        return [cands[r.index] for r in rr.results]

    def r3_context(query: str, k: int = 10) -> list[int]:
        q = voyage.contextualized_embed(
            inputs=[[query]],
            model=CONTEXT_MODEL,
            input_type="query",
        ).results[0].embeddings[0]
        scores = sorted(
            [(cosine(q, e), i) for i, e in enumerate(context_embs)],
            reverse=True,
        )
        return [i for _, i in scores[:k]]

    retrievers = [("R1-voyage3", r1_baseline),
                  ("R2-rerank2", r2_rerank),
                  ("R3-context3", r3_context)]

    # Run all queries × all retrievers, judging top-5 per pair.
    # Strategy: for each query, gather the union of all retrievers' top-K
    # and judge them in ONE batched CLI call → one call per query instead of
    # one per (query, retriever, chunk).
    all_results = []
    JUDGE_TOP_K = 5
    print(f"\n[judge] running {len(QUERIES)} queries × "
          f"{len(retrievers)} retrievers × top-{JUDGE_TOP_K} — batched per query…")
    for qid, query in QUERIES:
        print(f"\n[{qid}] {query}")
        # Collect retrievals first
        retr_results = {}
        for r_name, r_fn in retrievers:
            try:
                retr_results[r_name] = r_fn(query, k=JUDGE_TOP_K)
            except Exception as e:
                # Best effort: a failing retriever scores zero for this query.
                print(f"  {r_name}: FAILED — {e}")
                retr_results[r_name] = []
        # Union of unique chunk indices to judge
        union = sorted({i for top in retr_results.values() for i in top})
        items = [(i, chunks[i]) for i in union]
        print(f"  judging {len(items)} unique chunks via batch CLI…")
        scores_map = batch_judge(query, items)
        # Build per-retriever score lists (unscored chunks count as 0)
        for r_name, top in retr_results.items():
            scores = [scores_map.get(i, 0) for i in top]
            mean3, mean5, mrr = summarize_scores(scores)
            print(f"  {r_name}: chunks={[chunk_indices[i] for i in top]} "
                  f"scores={scores} mean@3={mean3:.2f} mean@5={mean5:.2f} "
                  f"MRR={mrr:.3f}")
            all_results.append({
                "qid": qid,
                "category": qid[0],
                "query": query,
                "retriever": r_name,
                "chunks": [chunk_indices[i] for i in top],
                "scores": scores,
                "mean3": mean3,
                "mean5": mean5,
                "mrr": mrr,
            })

    # Aggregate
    print("\n" + "=" * 100)
    print("AGGREGATED RESULTS")
    print("=" * 100)

    metrics = ("mean3", "mean5", "mrr")
    by_retriever = defaultdict(lambda: {"mean3": [], "mean5": [], "mrr": []})
    by_cat_retriever = defaultdict(
        lambda: {"mean3": [], "mean5": [], "mrr": []})
    for r in all_results:
        cat_key = (r["category"], r["retriever"])
        for metric in metrics:
            by_retriever[r["retriever"]][metric].append(r[metric])
            by_cat_retriever[cat_key][metric].append(r[metric])

    print(f"\nOverall (across all {len(QUERIES)} queries):")
    print(f"{'retriever':<14} {'mean@3':>8} {'mean@5':>8} {'MRR':>8}")
    for r_name, _ in retrievers:
        m = by_retriever[r_name]
        print(f"{r_name:<14} {_avg(m['mean3']):>8.3f} "
              f"{_avg(m['mean5']):>8.3f} {_avg(m['mrr']):>8.3f}")

    print("\nBy category (K=keyword, C=conceptual, N=narrative):")
    print(f"{'cat':<3} {'retriever':<14} {'mean@3':>8} {'mean@5':>8} {'MRR':>8}")
    for cat in ["K", "C", "N"]:
        for r_name, _ in retrievers:
            m = by_cat_retriever[(cat, r_name)]
            print(f"{cat:<3} {r_name:<14} {_avg(m['mean3']):>8.3f} "
                  f"{_avg(m['mean5']):>8.3f} {_avg(m['mrr']):>8.3f}")

    print("\nPer-query winner (highest mean@3, ties shown):")
    print(f"{'qid':<4} {'query':<45} {'winner':<24} {'scores'}")
    query_text = dict(QUERIES)
    by_query = defaultdict(list)
    for r in all_results:
        by_query[r["qid"]].append(r)
    for qid, results in sorted(by_query.items()):
        max_score = max(r["mean3"] for r in results)
        winners = [r["retriever"] for r in results if r["mean3"] == max_score]
        scores = " | ".join(f"{r['retriever'][:7]}={r['mean3']:.2f}"
                            for r in results)
        q_str = query_text[qid][:42]
        print(f"{qid:<4} {q_str:<45} {','.join(w[:8] for w in winners):<24} "
              f"{scores}")

    # Save raw results to JSON for further analysis
    out_path = "/tmp/voyage_rerank_judge_results.json"
    # ensure_ascii=False emits raw Hebrew, so the file must be opened as UTF-8.
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2)
    print(f"\nRaw results saved to {out_path}")


if __name__ == "__main__":
    asyncio.run(main())