Files
legal-ai/scripts/voyage_rerank_judge_poc.py
Chaim 26c3fddf41
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 1m29s
feat(retrieval): add voyage rerank-2 cross-encoder stage (feature flag)
Stage B of voyage-upgrades-plan rewritten: instead of context-3 (which
4 POCs showed inconsistent improvement), add a cross-encoder rerank
layer on top of voyage-3. Default off (VOYAGE_RERANK_ENABLED=false).

POC validation (785-doc corpus, 12 queries, claude-haiku-4-5 judge):
- mean@3 +4.5% (4.306 → 4.500)
- practical-category queries +11.6% (3.78 → 4.22)
- latency +702ms per query
- no schema change, no re-embed, no double storage

Plumbing:
- config: VOYAGE_RERANK_ENABLED / _MODEL / _FETCH_K env vars
- embeddings.voyage_rerank() wraps voyageai client.rerank
- services/rerank.py: maybe_rerank() helper — fetches FETCH_K candidates
  via the bi-encoder then reranks to top-K. Fail-open if Voyage rerank is
  unavailable.
- tools/search.py: search_decisions, search_case_documents,
  find_similar_cases all wrapped
- services/precedent_library.search_library wrapped

Smoke-tested locally with flag on/off — produces expected behaviour and
latency profile. Ready for production rollout via Coolify env flip after
deploy.

POCs (kept under scripts/ for reference):
- voyage_context3_poc{_long}.py — context-3 evaluation (rejected)
- voyage_multimodal_poc.py — multimodal-3 (stage C, deferred)
- voyage_rerank_judge_poc.py — single-case rerank benchmark
- voyage_rerank_corpus_poc.py — full-corpus rerank validation

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-03 18:43:41 +00:00

362 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""POC #4: Comprehensive retrieval benchmark with LLM-as-judge.
Compares 3 retrievers on אהרון ברק 403/17 (219 chunks):
R1 — voyage-3 (current production baseline)
R2 — voyage-3 + voyage-rerank-2 (retrieve 50, rerank, top-10)
R3 — voyage-context-3 (windowed, from POC #2)
Judges relevance with claude-haiku-4-5 — for each (query, chunk) pair the
judge returns 1-5. Aggregates: mean relevance@3, @5, @10, MRR (rank of
first 4+ chunk), per-query winner.
18 queries grouped into 3 categories so we can see *which* query types
benefit from which retriever:
K — keyword/lexical (term-heavy, specific entity)
C — conceptual (abstract idea, principle)
N — narrative/contextual (requires document-internal reference)
Usage (key passed via env, NOT stored in script):
ANTHROPIC_API_KEY=... \\
/home/chaim/legal-ai/mcp-server/.venv/bin/python \\
/home/chaim/legal-ai/scripts/voyage_rerank_judge_poc.py
"""
from __future__ import annotations
import asyncio
import json
import math
import os
import sys
import time
from collections import defaultdict
ENV_PATH = os.path.expanduser("~/.env")


def _load_env_file(path):
    """Populate ``os.environ`` from a simple ``KEY=VALUE`` file, if present.

    Blank lines, lines starting with ``#`` and lines without ``=`` are
    skipped. Each remaining line is split on the first ``=``; whitespace
    around the key and the value is dropped. Variables that are already set
    in the environment are left untouched (``setdefault``), so values passed
    on the command line win over the file.

    Args:
        path: Location of the env file. A missing file is silently ignored.
    """
    if not os.path.isfile(path):
        return
    # Explicit encoding so parsing does not depend on the process locale.
    with open(path, encoding="utf-8") as env_file:
        for raw_line in env_file:
            entry = raw_line.strip()
            if not entry or entry.startswith("#") or "=" not in entry:
                continue
            key, value = entry.split("=", 1)
            key, value = key.strip(), value.strip()
            # A line such as "=value" has no usable name; skip it instead of
            # handing an empty key to os.environ.
            if key:
                os.environ.setdefault(key, value)


# Runs at import time, before the third-party imports below, so credentials
# kept in the env file are visible to everything that follows.
_load_env_file(ENV_PATH)
import re
import subprocess
import asyncpg # noqa: E402
import voyageai # noqa: E402
# Value bound to `case_law_id` when main() loads the chunks to benchmark.
CASE_ID = "e151fc25-cf12-4563-b638-a86323f8413b" # case "Aharon Barak 403/17" (אהרון ברק 403/17)
# Model used to embed the query for the R1 baseline (and the R2 candidates).
TEXT_MODEL = "voyage-3"
# Model used for the contextualized (windowed) embeddings of R3.
CONTEXT_MODEL = "voyage-context-3"
# Model passed to voyage.rerank for R2.
RERANK_MODEL = "rerank-2"
# Model name handed to the `claude` CLI via --model for relevance judging.
JUDGE_MODEL = "claude-haiku-4-5-20251001"
# Windowing for the context-3 embeddings: windows of WINDOW_SIZE chunks whose
# start advances by WINDOW_STRIDE (see build_windows).
WINDOW_SIZE = 80
WINDOW_STRIDE = 70
# 18 queries × 3 retrievers × top-5 = up to 270 (query, chunk) pairs to judge;
# main() judges the per-query union in one CLI call, i.e. one call per query.
# Original cost estimate for unbatched judging: ~$0.05 with haiku.
# Each entry is (qid, query text). The first letter of the qid is used as the
# category key when results are aggregated.
QUERIES = [
    # K — keyword/lexical
    ("K1", "תכנית רחביה הוראות בנייה"),
    ("K2", "תמ\"א 38"),
    ("K3", "תכנית 9988"),
    ("K4", "סעיף 197 לחוק התכנון והבניה"),
    ("K5", "השופט גרוסקופף"),
    ("K6", "ועדה מקומית ירושלים"),
    # C — conceptual / abstract principles
    ("C1", "כלל הנטרול של זכויות תכנוניות"),
    ("C2", "אינטרס הציבור בתכנון"),
    ("C3", "תכלית היטל ההשבחה"),
    ("C4", "תכנית פוגעת לעומת תכנית משביחה"),
    ("C5", "ההבחנה בין השבחה לפיצויים"),
    ("C6", "מהותו של היטל ההשבחה"),
    # N — narrative / context-dependent
    ("N1", "מה נקבע לגבי תמ\"א 38 בפסק הדין"),
    ("N2", "מסקנת בית המשפט בעניין רובע 3"),
    ("N3", "ההלכה שנקבעה בעניין שמעוני"),
    ("N4", "ההבדל בין המקרה שלפנינו לעניין רון"),
    ("N5", "סוף דבר ותוצאת פסק הדין"),
    ("N6", "הסכמת השופטים האחרים לחוות הדעת"),
]
def cosine(a, b):
    """Return the cosine similarity between numeric vectors *a* and *b*.

    The dot product pairs elements with ``zip``, so trailing elements of a
    longer vector do not contribute to it. If either vector has zero
    magnitude the similarity is defined here as ``0.0`` rather than raising
    a division error.
    """
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if not norm_a or not norm_b:
        return 0.0
    dot_product = sum(x * y for x, y in zip(a, b))
    return dot_product / (norm_a * norm_b)
def parse_pgvector(s):
    """Parse a pgvector text literal such as ``"[0.1,0.2]"`` into floats.

    Args:
        s: The vector rendered as text: comma-separated numbers wrapped in
            square brackets. Surrounding whitespace is tolerated.

    Returns:
        A list of floats; an empty list for an empty vector (``"[]"``).

    Raises:
        ValueError: If a component is not a valid float literal.
    """
    # Trim outer whitespace first so the brackets are actually at the ends.
    body = s.strip().strip("[]").strip()
    if not body:
        # "".split(",") would yield [""] and float("") raises.
        return []
    return [float(component) for component in body.split(",")]
def build_windows(n, size, stride):
    """Split ``range(n)`` into half-open windows ``(start, end)``.

    Window starts advance by *stride*; each window spans up to *size* items
    and the final window is clipped to end exactly at *n*. Generation stops
    as soon as a window reaches *n*.

    Args:
        n: Number of items to cover. Values <= 0 give an empty list.
        size: Maximum number of items per window; must be positive.
        stride: Distance between consecutive window starts; must be positive.

    Returns:
        List of ``(start, end)`` tuples in ascending start order.

    Raises:
        ValueError: If *size* or *stride* is not positive. A non-positive
            stride would otherwise never advance and loop forever.
    """
    if size <= 0:
        raise ValueError(f"size must be positive, got {size}")
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    windows = []
    for start in range(0, n, stride):
        end = min(start + size, n)
        windows.append((start, end))
        if end == n:
            # The tail is covered; later starts would only add sub-windows.
            break
    return windows
def central_window(idx, windows):
    """Pick the window in which *idx* lies farthest from the nearest edge.

    *windows* is a sequence of half-open ``(start, end)`` pairs. Among the
    windows containing *idx*, the one maximizing the distance to its closer
    boundary wins; on a tie the earliest window is kept. Returns the
    position of that window in *windows*, or ``-1`` when no window contains
    *idx*.
    """
    candidates = [
        (min(idx - start, (end - 1) - idx), position)
        for position, (start, end) in enumerate(windows)
        if start <= idx < end
    ]
    if not candidates:
        return -1
    # max() returns the first maximal element, which preserves the
    # earliest-window-wins tie-break.
    _, winner = max(candidates, key=lambda candidate: candidate[0])
    return winner
# Hebrew prompt template for the relevance judge. batch_judge() fills it via
# str.format with `query` and `chunks_block`; the doubled braces are literal
# braces in the formatted text. It tells the judge to rate every paragraph
# 1-5 (5 = direct, accurate answer … 1 = not relevant at all) and to reply
# with JSON only, shaped as {"scores": {"<id>": <1-5>, ...}}, with no extra
# text and no code fences.
BATCH_JUDGE_PROMPT = """אתה שופט רלוונטיות במשפט ישראלי.
לפניך שאילתה ומספר פסקאות מפסק דין. דרג כל פסקה בנפרד 1-5 לפי רלוונטיות.
סולם:
5 — תשובה ישירה ומדויקת לשאילתה
4 — מאד רלוונטי, מכיל מידע ליבה
3 — רלוונטי חלקית, נוגע בעקיפין בנושא
2 — מעט קשור, רעש סביב הנושא
1 — לא רלוונטי בכלל
השאילתה:
{query}
הפסקאות:
{chunks_block}
החזר JSON בלבד, בפורמט: {{"scores": {{"<id>": <1-5>, ...}}}}
ללא טקסט נוסף, ללא explanations, ללא ```."""
def batch_judge(query: str,
                items: list[tuple[int, str]]) -> dict[int, int]:
    """Judge a list of (chunk_idx, content) pairs in a single CLI call.

    Each chunk is flattened to one line, truncated to 1500 characters and
    wrapped in an ``<id=N>`` tag; the combined prompt is piped to the
    ``claude`` CLI, whose reply is expected to be JSON of the form
    ``{"scores": {"<id>": <1-5>, ...}}``.

    Args:
        query: The search query the chunks are judged against.
        items: ``(chunk_idx, content)`` pairs to score.

    Returns:
        Mapping of chunk_idx to a score in 1-5. Entries whose id is unknown
        or whose score is not an integer in range are dropped individually.
        An empty dict is returned when *items* is empty, when the CLI cannot
        be run or times out, or when its output is not a JSON object with a
        ``scores`` object.
    """
    if not items:
        # Nothing to judge — do not spend a CLI call on an empty prompt.
        return {}

    chunk_blocks = []
    for ci, content in items:
        snippet = content.replace("\n", " ").strip()[:1500]
        chunk_blocks.append(f"<id={ci}>\n{snippet}\n</id>")
    prompt = BATCH_JUDGE_PROMPT.format(
        query=query,
        chunks_block="\n\n".join(chunk_blocks),
    )

    try:
        proc = subprocess.run(
            ["claude", "-p", "--model", JUDGE_MODEL],
            input=prompt, capture_output=True, text=True, timeout=120,
        )
    except (subprocess.TimeoutExpired, OSError) as e:
        # One slow or failed judge call must not abort the whole benchmark.
        print(f" [judge call fail: {e}]")
        return {}
    if proc.returncode != 0:
        # Still try to parse stdout below; this line only explains a failure.
        print(f" [judge CLI exit code {proc.returncode}: "
              f"{proc.stderr.strip()[:200]!r}]")

    out = proc.stdout.strip()
    # Strip ```json fences if any
    out = re.sub(r"^```(?:json)?\s*", "", out)
    out = re.sub(r"\s*```$", "", out)
    try:
        data = json.loads(out)
    except (json.JSONDecodeError, ValueError, TypeError) as e:
        print(f" [judge parse fail: {e}; out={out[:200]!r}]")
        return {}
    raw = data.get("scores") if isinstance(data, dict) else None
    if not isinstance(raw, dict):
        print(f" [judge parse fail: no 'scores' object; out={out[:200]!r}]")
        return {}

    known_ids = {ci for ci, _ in items}
    scores = {}
    for key, value in raw.items():
        # Accept "12" as well as echoes of the tag such as "id=12" / "<12>".
        id_match = re.search(r"\d+", str(key))
        if id_match is None or not str(value).isdecimal():
            continue
        chunk_idx, score = int(id_match.group()), int(value)
        if chunk_idx in known_ids and 1 <= score <= 5:
            scores[chunk_idx] = score
    return scores
async def main():
    """Run the three-retriever benchmark end to end.

    Steps:
      1. Read ``VOYAGE_API_KEY`` and ``POSTGRES_PASSWORD`` from the
         environment (``KeyError`` if missing) and check that the ``claude``
         CLI responds to ``--version``.
      2. Load the chunks and stored embeddings of ``CASE_ID`` from Postgres.
      3. Compute windowed context-3 embeddings for the same chunks.
      4. For every query, retrieve top-``JUDGE_TOP_K`` with each retriever,
         judge the union of retrieved chunks in one batched call, and derive
         mean@3, mean@5 and MRR (rank of the first chunk scored >= 4).
         Chunks the judge did not score count as 0.
      5. Print overall / per-category / per-query tables and dump the raw
         per-(query, retriever) records to
         ``/tmp/voyage_rerank_judge_results.json``.
    """
    voyage_key = os.environ["VOYAGE_API_KEY"]
    pg_pw = os.environ["POSTGRES_PASSWORD"]

    # Verify Claude CLI is available (uses OAuth from ~/.claude/.credentials)
    try:
        subprocess.run(["claude", "--version"], capture_output=True,
                       text=True, timeout=10, check=True)
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired,
            OSError):
        # subprocess.run signals a timeout with subprocess.TimeoutExpired,
        # which is not a subclass of the builtin TimeoutError.
        sys.exit("claude CLI not found or not authenticated")

    voyage = voyageai.Client(api_key=voyage_key)

    # Load chunks + voyage-3 embeddings
    pool = await asyncpg.create_pool(
        host="127.0.0.1", port=5433, user="legal_ai",
        password=pg_pw, database="legal_ai",
        min_size=1, max_size=2,
    )
    try:
        rows = await pool.fetch("""
            SELECT chunk_index, content, embedding::text AS emb_text
            FROM precedent_chunks
            WHERE case_law_id = $1
            ORDER BY chunk_index
        """, CASE_ID)
    finally:
        # The database is only needed for this one query; release the pool
        # here so it is closed even if anything later raises.
        await pool.close()
    if not rows:
        sys.exit(f"no chunks found for case {CASE_ID}")

    chunks = [r["content"] for r in rows]
    chunk_indices = [r["chunk_index"] for r in rows]
    baseline_embs = [parse_pgvector(r["emb_text"]) for r in rows]
    n = len(chunks)
    print(f"[load] {n} chunks loaded")

    # Compute context-3 (windowed) embeddings — same as POC #2
    windows = build_windows(n, WINDOW_SIZE, WINDOW_STRIDE)
    print(f"[context-3] embedding {len(windows)} windows…")
    win_embs = []
    for s, e in windows:
        result = voyage.contextualized_embed(
            inputs=[chunks[s:e]],
            model=CONTEXT_MODEL,
            input_type="document",
        )
        win_embs.append(result.results[0].embeddings)
    # Windows overlap, so a chunk may have several embeddings; keep the one
    # from the window where the chunk sits farthest from the window edges.
    context_embs = []
    for i in range(n):
        w = central_window(i, windows)
        s, _ = windows[w]
        context_embs.append(win_embs[w][i - s])
    print("[context-3] done")

    # Retrieval functions — each returns positions into `chunks`, best first.
    def r1_baseline(query: str, k: int = 10) -> list[int]:
        q = voyage.embed([query], model=TEXT_MODEL,
                         input_type="query").embeddings[0]
        scores = sorted(
            [(cosine(q, e), i) for i, e in enumerate(baseline_embs)],
            reverse=True,
        )
        return [i for _, i in scores[:k]]

    def r2_rerank(query: str, k: int = 10) -> list[int]:
        # 1) voyage-3 retrieve top-50
        cands = r1_baseline(query, k=50)
        cand_texts = [chunks[i] for i in cands]
        # 2) voyage-rerank-2 over the 50
        rr = voyage.rerank(
            query=query, documents=cand_texts,
            model=RERANK_MODEL, top_k=k,
        )
        # rr.results: list of RerankingResult(index=..., relevance_score=...)
        # `index` refers to position in cand_texts → map back to chunk idx
        return [cands[r.index] for r in rr.results]

    def r3_context(query: str, k: int = 10) -> list[int]:
        q = voyage.contextualized_embed(
            inputs=[[query]],
            model=CONTEXT_MODEL,
            input_type="query",
        ).results[0].embeddings[0]
        scores = sorted(
            [(cosine(q, e), i) for i, e in enumerate(context_embs)],
            reverse=True,
        )
        return [i for _, i in scores[:k]]

    retrievers = [("R1-voyage3", r1_baseline),
                  ("R2-rerank2", r2_rerank),
                  ("R3-context3", r3_context)]

    # Run all queries × all retrievers, judging top-5 per pair.
    # Strategy: for each query, gather the union of all retrievers' top-K
    # and judge them in ONE batched CLI call → one call per query instead of
    # one per (query, retriever, chunk).
    all_results = []
    JUDGE_TOP_K = 5
    print(f"\n[judge] running {len(QUERIES)} queries × "
          f"{len(retrievers)} retrievers × top-{JUDGE_TOP_K} — batched per query…")
    for qid, query in QUERIES:
        print(f"\n[{qid}] {query}")
        # Collect retrievals first
        retr_results = {}
        for r_name, r_fn in retrievers:
            try:
                retr_results[r_name] = r_fn(query, k=JUDGE_TOP_K)
            except Exception as e:
                # Best effort: a failing retriever scores as "no results"
                # for this query instead of aborting the whole run.
                print(f" {r_name}: FAILED — {e}")
                retr_results[r_name] = []
        # Union of unique chunk indices to judge
        union = sorted({i for top in retr_results.values() for i in top})
        items = [(i, chunks[i]) for i in union]
        print(f" judging {len(items)} unique chunks via batch CLI…")
        scores_map = batch_judge(query, items)
        # Build per-retriever score lists
        for r_name, top in retr_results.items():
            scores = [scores_map.get(i, 0) for i in top]
            mean3 = sum(scores[:3]) / 3 if len(scores) >= 3 else 0
            mean5 = sum(scores) / len(scores) if scores else 0
            mrr = 0.0
            for r, s in enumerate(scores):
                if s >= 4:
                    mrr = 1.0 / (r + 1)
                    break
            print(f" {r_name}: chunks={[chunk_indices[i] for i in top]} "
                  f"scores={scores} mean@3={mean3:.2f} mean@5={mean5:.2f} "
                  f"MRR={mrr:.3f}")
            all_results.append({
                "qid": qid, "category": qid[0], "query": query,
                "retriever": r_name,
                "chunks": [chunk_indices[i] for i in top],
                "scores": scores,
                "mean3": mean3, "mean5": mean5, "mrr": mrr,
            })

    # Aggregate
    print("\n" + "=" * 100)
    print("AGGREGATED RESULTS")
    print("=" * 100)

    def avg(xs):
        return sum(xs) / len(xs) if xs else 0

    by_retriever = defaultdict(lambda: {"mean3": [], "mean5": [], "mrr": []})
    by_cat_retriever = defaultdict(
        lambda: {"mean3": [], "mean5": [], "mrr": []})
    for r in all_results:
        cat_key = (r["category"], r["retriever"])
        for metric in ("mean3", "mean5", "mrr"):
            by_retriever[r["retriever"]][metric].append(r[metric])
            by_cat_retriever[cat_key][metric].append(r[metric])

    print(f"\nOverall (across all {len(QUERIES)} queries):")
    print(f"{'retriever':<14} {'mean@3':>8} {'mean@5':>8} {'MRR':>8}")
    for r_name, _ in retrievers:
        m = by_retriever[r_name]
        print(f"{r_name:<14} {avg(m['mean3']):>8.3f} "
              f"{avg(m['mean5']):>8.3f} {avg(m['mrr']):>8.3f}")

    print("\nBy category (K=keyword, C=conceptual, N=narrative):")
    print(f"{'cat':<3} {'retriever':<14} {'mean@3':>8} {'mean@5':>8} {'MRR':>8}")
    for cat in ["K", "C", "N"]:
        for r_name, _ in retrievers:
            m = by_cat_retriever[(cat, r_name)]
            print(f"{cat:<3} {r_name:<14} {avg(m['mean3']):>8.3f} "
                  f"{avg(m['mean5']):>8.3f} {avg(m['mrr']):>8.3f}")

    print("\nPer-query winner (highest mean@3, ties shown):")
    print(f"{'qid':<4} {'query':<45} {'winner':<24} {'scores'}")
    by_query = defaultdict(list)
    for r in all_results:
        by_query[r["qid"]].append(r)
    for qid, results in sorted(by_query.items()):
        max_score = max(r["mean3"] for r in results)
        winners = [r["retriever"] for r in results if r["mean3"] == max_score]
        scores = " | ".join(f"{r['retriever'][:7]}={r['mean3']:.2f}"
                            for r in results)
        # Every record carries its query text, so no lookup in QUERIES.
        q_str = results[0]["query"][:42]
        print(f"{qid:<4} {q_str:<45} {','.join(w[:8] for w in winners):<24} "
              f"{scores}")

    # Save raw results to JSON for further analysis
    out_path = "/tmp/voyage_rerank_judge_results.json"
    # ensure_ascii=False writes the Hebrew queries verbatim, so the file
    # encoding must be explicit rather than inherited from the locale.
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2)
    print(f"\nRaw results saved to {out_path}")
# Script entry point: run the async benchmark only when executed directly,
# not when the module is imported.
if __name__ == "__main__":
    asyncio.run(main())