"""POC #5 — full precedent_library corpus benchmark.

Tests R1 (voyage-3) vs R2 (voyage-3 + rerank-2) on the *real* corpus that
search_precedent_library queries against:

  precedent_chunks — 385 rows from 3 precedent cases
  halachot         — 400 rule statements with reasoning summaries
  Total: 785 documents.

The MCP tool merges results from both tables so the benchmark mirrors
production retrieval.

R3 (context-3) is dropped — it would require windowed re-embedding of 3
cases which we already proved doesn't help (POC #2). The question now is:
does rerank-2's +9% on a single case generalize to a heterogeneous corpus?

Also measures end-to-end latency: pure voyage-3 vs voyage-3 + rerank.

Usage:
  /home/chaim/legal-ai/mcp-server/.venv/bin/python \\
    /home/chaim/legal-ai/scripts/voyage_rerank_corpus_poc.py
"""
from __future__ import annotations

import asyncio
import json
import math
import os
import re
import subprocess
import sys
import time
from collections import defaultdict

ENV_PATH = os.path.expanduser("~/.env")


def _load_env_file(path: str) -> None:
    """Populate ``os.environ`` from a ``KEY=VALUE`` file, if it exists.

    Blank lines, ``#`` comments and lines without ``=`` are skipped.
    Variables already present in the environment are never overwritten
    (``setdefault``). Whitespace around the key and value is trimmed and
    one pair of matching surrounding quotes is removed from the value, so
    ``KEY = "secret"`` yields ``secret`` rather than ``"secret"``.
    """
    if not os.path.isfile(path):
        return
    # Explicit encoding: do not depend on the process locale.
    with open(path, encoding="utf-8") as env_file:
        for raw_line in env_file:
            entry = raw_line.strip()
            if not entry or entry.startswith("#") or "=" not in entry:
                continue
            key, _, value = entry.partition("=")
            key = key.strip()
            value = value.strip()
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]
            if key:
                os.environ.setdefault(key, value)


# Must run before the third-party imports below, which is why those
# imports carry `noqa: E402`.
_load_env_file(ENV_PATH)

import asyncpg  # noqa: E402
import voyageai  # noqa: E402

TEXT_MODEL = "voyage-3"
RERANK_MODEL = "rerank-2"
JUDGE_MODEL = "claude-haiku-4-5-20251001"
TOP_VEC = 50  # voyage-3 retrieve depth
TOP_K = 10  # final returned to "agent"
JUDGE_K = 5  # how many top results to actually judge per retriever

# 12 queries spanning typical use cases by Daphna's agents:
# precedent search for citing in decision blocks י-יא.
QUERIES = [
    # K — keyword
    ("K1", "פיצויים לפי סעיף 197"),
    ("K2", "תמ\"א 38 והשבחה"),
    ("K3", "כלל הנטרול בשמאות"),
    # C — conceptual
    ("C1", "תכלית היטל ההשבחה"),
    ("C2", "מה מקנה לבעלים זכות לפיצוי"),
    ("C3", "ההבחנה בין השבחה לפיצויים"),
    # N — narrative / context-aware
    ("N1", "מה נקבע לגבי תמ\"א 38 בפסיקה"),
    ("N2", "ההלכה לעניין נטרול ציפיות"),
    ("N3", "תכנית פוגעת ושומה"),
    # P — practical (drafting needs — what an agent typically asks)
    ("P1", "פסיקה שדנה בתכנית מתאר ארצית"),
    ("P2", "מתי מותר לוועדה לדחות פיצויים"),
    ("P3", "שיקול דעת הוועדה המקומית"),
]


def cosine(a, b):
    """Return the cosine similarity of vectors *a* and *b*.

    Returns 0.0 when either vector has zero norm (including an empty
    vector) instead of dividing by zero. The dot product uses ``zip``, so
    it only covers the common prefix if the lengths differ.
    """
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if not norm_a or not norm_b:
        return 0.0
    return dot / (norm_a * norm_b)


def parse_pgvector(s):
    """Parse a bracketed, comma-separated vector string into floats.

    Accepts text such as ``"[0.1,0.2]"``. Surrounding whitespace is
    tolerated and an empty vector (``"[]"``) yields ``[]`` rather than
    raising ``ValueError`` on ``float("")``.
    """
    body = s.strip().strip("[]")
    return [float(token) for token in body.split(",") if token.strip()]


BATCH_JUDGE_PROMPT = """אתה שופט רלוונטיות במשפט ישראלי.
לפניך שאילתה ומספר פסקאות מפסקי דין/הלכות. דרג כל פסקה 1-5 לפי רלוונטיות.

5 — תשובה ישירה למה שנשאל
4 — מאד רלוונטי, מכיל מידע ליבה
3 — רלוונטי חלקית, נוגע בעקיפין
2 — מעט קשור, רעש סביב הנושא
1 — לא רלוונטי בכלל

השאילתה: {query}

הפסקאות:
{chunks_block}

החזר JSON בלבד: {{"scores": {{"<id>": <1-5>, ...}}}}
ללא טקסט נוסף, ללא ```."""

# Markdown code fences the judge sometimes wraps around its JSON answer.
_FENCE_OPEN_RE = re.compile(r"^```(?:json)?\s*")
_FENCE_CLOSE_RE = re.compile(r"\s*```$")


def build_judge_prompt(query: str, items: list[tuple[str, str]]) -> str:
    """Render BATCH_JUDGE_PROMPT for *query* and the ``(id, text)`` *items*.

    Each passage is flattened to one line, truncated to 1500 characters
    and wrapped in a ``<chunk id="...">`` tag. The id has to appear in the
    prompt, otherwise the judge cannot key its scores by document id.
    """
    blocks = []
    for cid, content in items:
        snippet = content.replace("\n", " ").strip()[:1500]
        blocks.append(f'<chunk id="{cid}">\n{snippet}\n</chunk>')
    return BATCH_JUDGE_PROMPT.format(
        query=query, chunks_block="\n\n".join(blocks))


def parse_judge_output(out: str) -> dict[str, int]:
    """Extract ``{id: score}`` from the judge's raw text output.

    Strips an optional Markdown code fence, parses the JSON and keeps only
    entries whose value is an integer in 1..5 (given as an int, an
    integral float such as ``4.0``, or a numeric string).

    Raises:
        ValueError: if the text is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError`` subclass) or the payload / its ``scores``
            member is not a JSON object.
    """
    text = _FENCE_OPEN_RE.sub("", out.strip())
    text = _FENCE_CLOSE_RE.sub("", text)
    data = json.loads(text)
    if not isinstance(data, dict):
        raise ValueError(
            f"expected a JSON object, got {type(data).__name__}")
    raw = data.get("scores", {})
    if not isinstance(raw, dict):
        raise ValueError(
            f"expected 'scores' to be an object, got {type(raw).__name__}")

    scores: dict[str, int] = {}
    for key, value in raw.items():
        if isinstance(value, bool):  # bool is an int subclass; not a score
            continue
        try:
            number = float(value)
        except (TypeError, ValueError):
            continue
        if number.is_integer() and 1 <= number <= 5:
            scores[str(key)] = int(number)
    return scores


def batch_judge(query: str, items: list[tuple[str, str]]) -> dict[str, int]:
    """Judge (id, text) pairs via claude CLI. Returns {id: score}.

    Best-effort: a timeout, a non-zero exit status or an unparsable reply
    prints a diagnostic and returns ``{}`` so one bad judge call does not
    abort the whole benchmark. An empty *items* list returns ``{}``
    without invoking the CLI.
    """
    if not items:
        return {}
    prompt = build_judge_prompt(query, items)
    try:
        proc = subprocess.run(
            ["claude", "-p", "--model", JUDGE_MODEL],
            input=prompt, capture_output=True, text=True,
            encoding="utf-8",  # Hebrew prompt: do not depend on the locale
            timeout=180,
        )
    except subprocess.TimeoutExpired:
        # subprocess raises TimeoutExpired, not the builtin TimeoutError.
        print("  [judge timed out after 180s]")
        return {}
    if proc.returncode != 0:
        print(f"  [judge exited with {proc.returncode}; "
              f"stderr={proc.stderr.strip()[:200]!r}]")
        return {}
    out = proc.stdout.strip()
    try:
        return parse_judge_output(out)
    except ValueError as e:
        print(f"  [judge parse fail: {e}; out={out[:200]!r}]")
        return {}


def rank_by_cosine(q_emb, embs, k):
    """Return the indices of the *k* vectors in *embs* most similar to *q_emb*.

    Sorted by descending cosine similarity; ties are broken by descending
    index because ``(score, index)`` tuples are sorted in reverse.
    """
    ranked = sorted(
        ((cosine(q_emb, emb), idx) for idx, emb in enumerate(embs)),
        reverse=True,
    )
    return [idx for _, idx in ranked[:k]]


def score_metrics(scores):
    """Return ``(mean@3, mean-over-all, MRR)`` for one ranked score list.

    * mean@3 is 0 when fewer than three scores are available.
    * The second value averages every score given (0 for an empty list).
    * MRR is the reciprocal rank of the first score >= 4, or 0.0 if none.
    """
    mean3 = sum(scores[:3]) / 3 if len(scores) >= 3 else 0
    mean_all = sum(scores) / len(scores) if scores else 0
    mrr = next(
        (1.0 / rank for rank, s in enumerate(scores, start=1) if s >= 4),
        0.0,
    )
    return mean3, mean_all, mrr


def _mean(xs):
    """Arithmetic mean of *xs*, or 0 for an empty sequence."""
    return sum(xs) / len(xs) if xs else 0


async def main():
    """Run the R1-vs-R2 benchmark: load corpus, time, judge, aggregate, save.

    Reads ``VOYAGE_API_KEY`` and ``POSTGRES_PASSWORD`` from the environment,
    requires the ``claude`` CLI, prints all tables to stdout and writes the
    per-query results to ``/tmp/voyage_rerank_corpus_results.json``.
    """
    missing = [name for name in ("VOYAGE_API_KEY", "POSTGRES_PASSWORD")
               if name not in os.environ]
    if missing:
        sys.exit(f"missing environment variables: {', '.join(missing)}")
    voyage_key = os.environ["VOYAGE_API_KEY"]
    pg_pw = os.environ["POSTGRES_PASSWORD"]

    try:
        subprocess.run(["claude", "--version"], capture_output=True,
                       text=True, timeout=10, check=True)
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired,
            FileNotFoundError):
        sys.exit("claude CLI not found")

    voyage = voyageai.Client(api_key=voyage_key)
    pool = await asyncpg.create_pool(
        host="127.0.0.1", port=5433, user="legal_ai", password=pg_pw,
        database="legal_ai", min_size=1, max_size=2,
    )

    # Load full corpus: precedent_chunks + halachot. The pool is only
    # needed for these two queries, so release it even if one fails.
    try:
        pc_rows = await pool.fetch("""
            SELECT 'pc:' || id::text AS doc_id,
                   content,
                   embedding::text AS emb_text
            FROM precedent_chunks
            WHERE content IS NOT NULL AND embedding IS NOT NULL
        """)
        h_rows = await pool.fetch("""
            SELECT 'h:' || id::text AS doc_id,
                   TRIM(BOTH ' —' FROM rule_statement || ' — ' ||
                        COALESCE(reasoning_summary, '')) AS content,
                   embedding::text AS emb_text
            FROM halachot
            WHERE rule_statement IS NOT NULL AND embedding IS NOT NULL
        """)
    finally:
        await pool.close()

    all_rows = list(pc_rows) + list(h_rows)
    print(f"[load] corpus: {len(pc_rows)} precedent_chunks + "
          f"{len(h_rows)} halachot = {len(all_rows)} total")
    if not all_rows:
        sys.exit("[load] corpus is empty — nothing to benchmark")

    doc_ids = [r["doc_id"] for r in all_rows]
    contents = [r["content"] for r in all_rows]
    embs = [parse_pgvector(r["emb_text"]) for r in all_rows]

    # Retrieval functions
    def r1_baseline(query: str, k: int = TOP_K) -> list[int]:
        """R1: embed the query with voyage-3, return top-k corpus indices."""
        q_emb = voyage.embed([query], model=TEXT_MODEL,
                             input_type="query").embeddings[0]
        return rank_by_cosine(q_emb, embs, k)

    def r2_rerank(query: str, k: int = TOP_K) -> list[int]:
        """R2: R1 top-TOP_VEC candidates reranked by rerank-2 down to k."""
        cands = r1_baseline(query, k=TOP_VEC)
        cand_texts = [contents[i] for i in cands]
        rr = voyage.rerank(query=query, documents=cand_texts,
                           model=RERANK_MODEL, top_k=k)
        # r.index points into cand_texts; map back to corpus indices.
        return [cands[r.index] for r in rr.results]

    retrievers = [("R1-voyage3", r1_baseline), ("R2-rerank2", r2_rerank)]

    # Latency measurement: time both full pipelines on a few queries.
    sample = QUERIES[:5]
    print(f"\n[latency] measuring {len(sample)} sample queries…")
    r1_lat = []
    r2_lat = []
    for _, query in sample:
        # perf_counter is monotonic; time.time() can jump with clock changes.
        t0 = time.perf_counter()
        r1_baseline(query)
        r1_lat.append(time.perf_counter() - t0)

        t0 = time.perf_counter()
        r2_rerank(query)
        r2_lat.append(time.perf_counter() - t0)

    r1_avg = _mean(r1_lat)
    r2_avg = _mean(r2_lat)
    print(f"  R1 (voyage-3 only)       avg={r1_avg * 1000:.0f}ms"
          f"  min={min(r1_lat) * 1000:.0f}  max={max(r1_lat) * 1000:.0f}")
    print(f"  R2 (voyage-3 + rerank-2) avg={r2_avg * 1000:.0f}ms"
          f"  min={min(r2_lat) * 1000:.0f}  max={max(r2_lat) * 1000:.0f}")
    print(f"  Δ (rerank overhead)      avg={(r2_avg - r1_avg) * 1000:.0f}ms")

    print(f"\n[judge] running {len(QUERIES)} queries × 2 retrievers, "
          f"top-{JUDGE_K} judged…")

    all_results = []
    for qid, query in QUERIES:
        print(f"\n[{qid}] {query}")
        retr_results = {}
        for r_name, r_fn in retrievers:
            try:
                retr_results[r_name] = r_fn(query, k=JUDGE_K)
            except Exception as e:
                # Keep the run going; the retriever is scored on an empty
                # result list (all metrics 0) for this query.
                print(f"  {r_name}: FAILED — {e}")
                retr_results[r_name] = []

        # Judge each distinct document once, even if both retrievers hit it.
        union = sorted({i for top in retr_results.values() for i in top})
        items = [(doc_ids[i], contents[i]) for i in union]
        print(f"  judging {len(items)} unique docs…")
        scores_map = batch_judge(query, items)

        for r_name, top in retr_results.items():
            # Documents the judge did not score count as 0.
            scores = [scores_map.get(doc_ids[i], 0) for i in top]
            mean3, mean5, mrr = score_metrics(scores)
            print(f"  {r_name}: doc_ids={[doc_ids[i][:14] for i in top]} "
                  f"scores={scores} m@3={mean3:.2f} m@5={mean5:.2f} "
                  f"MRR={mrr:.3f}")
            all_results.append({
                "qid": qid,
                "category": qid[0],
                "query": query,
                "retriever": r_name,
                "doc_ids": [doc_ids[i] for i in top],
                "scores": scores,
                "mean3": mean3,
                "mean5": mean5,
                "mrr": mrr,
            })

    # Aggregate
    print("\n" + "=" * 100)
    print(f"AGGREGATED RESULTS — full precedent_library corpus "
          f"({len(all_rows)} docs)")
    print("=" * 100)

    by_r = defaultdict(lambda: {"mean3": [], "mean5": [], "mrr": []})
    by_cat_r = defaultdict(lambda: {"mean3": [], "mean5": [], "mrr": []})
    for r in all_results:
        for bucket in (by_r[r["retriever"]],
                       by_cat_r[(r["category"], r["retriever"])]):
            bucket["mean3"].append(r["mean3"])
            bucket["mean5"].append(r["mean5"])
            bucket["mrr"].append(r["mrr"])

    print(f"\nOverall ({len(QUERIES)} queries):")
    print(f"{'retriever':<14} {'mean@3':>8} {'mean@5':>8} {'MRR':>8}")
    for r_name, _ in retrievers:
        m = by_r[r_name]
        print(f"{r_name:<14} {_mean(m['mean3']):>8.3f} "
              f"{_mean(m['mean5']):>8.3f} {_mean(m['mrr']):>8.3f}")

    # Improvement (skipped when the baseline is 0 to avoid dividing by zero)
    r1m = _mean(by_r["R1-voyage3"]["mean3"])
    r2m = _mean(by_r["R2-rerank2"]["mean3"])
    if r1m > 0:
        print(f"\nR2 vs R1 improvement: "
              f"mean@3 {(r2m - r1m) / r1m * 100:+.1f}%")

    print("\nBy category:")
    print(f"{'cat':<3} {'retriever':<14} {'mean@3':>8} {'mean@5':>8} "
          f"{'MRR':>8}")
    for cat in ["K", "C", "N", "P"]:
        for r_name, _ in retrievers:
            m = by_cat_r[(cat, r_name)]
            if not m["mean3"]:
                continue
            print(f"{cat:<3} {r_name:<14} {_mean(m['mean3']):>8.3f} "
                  f"{_mean(m['mean5']):>8.3f} {_mean(m['mrr']):>8.3f}")

    print("\nPer-query winner (highest mean@3):")
    print(f"{'qid':<4} {'query':<40} {'winner':<14} {'scores'}")
    by_q = defaultdict(list)
    for r in all_results:
        by_q[r["qid"]].append(r)
    for qid, results in sorted(by_q.items()):
        max_s = max(r["mean3"] for r in results)
        # Ties list every retriever that reached the maximum.
        winners = [r["retriever"] for r in results if r["mean3"] == max_s]
        scores = " | ".join(f"{r['retriever'][:7]}={r['mean3']:.2f}"
                            for r in results)
        q_str = results[0]["query"][:38]
        print(f"{qid:<4} {q_str:<40} "
              f"{','.join(w[:8] for w in winners):<14} {scores}")

    out_path = "/tmp/voyage_rerank_corpus_results.json"
    # ensure_ascii=False writes raw Hebrew, so pin the file encoding.
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2)
    print(f"\nSaved to {out_path}")


if __name__ == "__main__":
    asyncio.run(main())