Files
legal-ai/scripts/voyage_rerank_corpus_poc.py
Chaim 26c3fddf41
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 1m29s
feat(retrieval): add voyage rerank-2 cross-encoder stage (feature flag)
Stage B of voyage-upgrades-plan rewritten: instead of context-3 (which
4 POCs showed inconsistent improvement), add a cross-encoder rerank
layer on top of voyage-3. Default off (VOYAGE_RERANK_ENABLED=false).

POC validation (785-doc corpus, 12 queries, claude-haiku-4-5 judge):
- mean@3 +4.5% (4.306 → 4.500)
- practical-category queries +11.6% (3.78 → 4.22)
- latency +702ms per query
- no schema change, no re-embed, no double storage

Plumbing:
- config: VOYAGE_RERANK_ENABLED / _MODEL / _FETCH_K env vars
- embeddings.voyage_rerank() wraps voyageai client.rerank
- services/rerank.py: maybe_rerank() helper — fetches FETCH_K candidates
  via the bi-encoder then reranks to top-K. Fail-open if Voyage rerank is
  unavailable.
- tools/search.py: search_decisions, search_case_documents,
  find_similar_cases all wrapped
- services/precedent_library.search_library wrapped

Smoke-tested locally with flag on/off — produces expected behaviour and
latency profile. Ready for production rollout via Coolify env flip after
deploy.

POCs (kept under scripts/ for reference):
- voyage_context3_poc{_long}.py — context-3 evaluation (rejected)
- voyage_multimodal_poc.py — multimodal-3 (stage C, deferred)
- voyage_rerank_judge_poc.py — single-case rerank benchmark
- voyage_rerank_corpus_poc.py — full-corpus rerank validation

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-03 18:43:41 +00:00

319 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""POC #5 — full precedent_library corpus benchmark.

Tests R1 (voyage-3) vs R2 (voyage-3 + rerank-2) on the *real* corpus that
search_precedent_library queries against:

    precedent_chunks — 385 rows from 3 precedent cases
    halachot         — 400 rule statements with reasoning summaries

Total: 785 documents. The MCP tool merges results from both tables so the
benchmark mirrors production retrieval. R3 (context-3) is dropped — it
would require windowed re-embedding of 3 cases which we already proved
doesn't help (POC #2). The question now is: does rerank-2's +9% on a
single case generalize to a heterogeneous corpus?

Also measures end-to-end latency: pure voyage-3 vs voyage-3 + rerank.

Usage:
    /home/chaim/legal-ai/mcp-server/.venv/bin/python \\
        /home/chaim/legal-ai/scripts/voyage_rerank_corpus_poc.py
"""
from __future__ import annotations
import asyncio
import json
import math
import os
import re
import subprocess
import sys
import time
from collections import defaultdict
ENV_PATH = os.path.expanduser("~/.env")


def load_env_file(path):
    """Seed ``os.environ`` from a simple ``KEY=VALUE`` file.

    Blank lines, ``#`` comments and lines without ``=`` are skipped. Only
    the first ``=`` splits the line, so values may themselves contain ``=``.
    Variables already present in the environment are never overwritten.
    A missing file is silently ignored.

    Args:
        path: Location of the env file.
    """
    if not os.path.isfile(path):
        return
    # Explicit encoding: do not depend on the process locale.
    with open(path, encoding="utf-8") as env_file:
        for raw_line in env_file:
            entry = raw_line.strip()
            if not entry or entry.startswith("#") or "=" not in entry:
                continue
            key, value = entry.split("=", 1)
            # Tolerate "KEY = value": without stripping, the variable would be
            # registered as "KEY " and never found by later lookups.
            key = key.strip()
            if not key:
                # A line such as "=foo" has no usable name; os.environ would
                # reject an empty key.
                continue
            os.environ.setdefault(key, value.strip())


load_env_file(ENV_PATH)
import asyncpg # noqa: E402
import voyageai # noqa: E402
# Model identifiers: the voyage embedding model, the voyage reranker, and the
# model name passed to the claude CLI for relevance judging.
TEXT_MODEL = "voyage-3"
RERANK_MODEL = "rerank-2"
JUDGE_MODEL = "claude-haiku-4-5-20251001"
# Retrieval depths used throughout the benchmark.
TOP_VEC = 50 # voyage-3 retrieve depth
TOP_K = 10 # final returned to "agent"
JUDGE_K = 5 # how many top results to actually judge per retriever
# 12 queries spanning typical use cases by Daphna's agents:
# precedent search for citing in decision blocks י-יא.
# Each entry is (query id, Hebrew query text); the first letter of the id is
# the category code used when results are aggregated per category.
QUERIES = [
# K — keyword
("K1", "פיצויים לפי סעיף 197"),
("K2", "תמ\"א 38 והשבחה"),
("K3", "כלל הנטרול בשמאות"),
# C — conceptual
("C1", "תכלית היטל ההשבחה"),
("C2", "מה מקנה לבעלים זכות לפיצוי"),
("C3", "ההבחנה בין השבחה לפיצויים"),
# N — narrative / context-aware
("N1", "מה נקבע לגבי תמ\"א 38 בפסיקה"),
("N2", "ההלכה לעניין נטרול ציפיות"),
("N3", "תכנית פוגעת ושומה"),
# P — practical (drafting needs — what an agent typically asks)
("P1", "פסיקה שדנה בתכנית מתאר ארצית"),
("P2", "מתי מותר לוועדה לדחות פיצויים"),
("P3", "שיקול דעת הוועדה המקומית"),
]
def cosine(a, b):
    """Return the cosine similarity of two equal-length numeric vectors.

    Yields 0.0 when either vector has zero magnitude (including the empty
    vector), so callers never hit a division by zero.
    """
    inner = sum(p * q for p, q in zip(a, b))
    norm_a = math.sqrt(sum(p * p for p in a))
    norm_b = math.sqrt(sum(q * q for q in b))
    if not (norm_a and norm_b):
        return 0.0
    return inner / (norm_a * norm_b)
def parse_pgvector(s):
    """Parse a bracketed, comma-separated vector string into a list of floats.

    Example: ``"[0.1,0.2,-3]"`` -> ``[0.1, 0.2, -3.0]``.
    """
    inner = s.strip("[]")
    return list(map(float, inner.split(",")))
# Hebrew-language prompt template for the relevance judge. It describes a
# 1-5 rating scale (5 = direct answer, 1 = not relevant at all) and asks for
# a JSON-only reply shaped like {"scores": {"<id>": <1-5>, ...}}.
# Filled with str.format using the `query` and `chunks_block` fields; the
# doubled braces are literal braces that survive formatting.
BATCH_JUDGE_PROMPT = """אתה שופט רלוונטיות במשפט ישראלי.
לפניך שאילתה ומספר פסקאות מפסקי דין/הלכות. דרג כל פסקה 1-5 לפי רלוונטיות.
5 — תשובה ישירה למה שנשאל
4 — מאד רלוונטי, מכיל מידע ליבה
3 — רלוונטי חלקית, נוגע בעקיפין
2 — מעט קשור, רעש סביב הנושא
1 — לא רלוונטי בכלל
השאילתה:
{query}
הפסקאות:
{chunks_block}
החזר JSON בלבד: {{"scores": {{"<id>": <1-5>, ...}}}}
ללא טקסט נוסף, ללא ```."""
def _parse_judge_scores(out: str) -> dict[str, int]:
    """Extract ``{id: score}`` from the judge's raw reply; ``{}`` if unusable.

    Only integer scores in the range 1-5 are kept. A single malformed entry
    is skipped rather than invalidating the whole batch.
    """
    out = out.strip()
    # The prompt asks for bare JSON, but tolerate a surrounding code fence.
    out = re.sub(r"^```(?:json)?\s*", "", out)
    out = re.sub(r"\s*```$", "", out)
    try:
        data = json.loads(out)
    except (json.JSONDecodeError, ValueError, TypeError) as e:
        print(f" [judge parse fail: {e}; out={out[:200]!r}]")
        return {}
    raw = data.get("scores", {}) if isinstance(data, dict) else None
    if not isinstance(raw, dict):
        # Valid JSON, but not the {"scores": {...}} object that was requested.
        print(f" [judge parse fail: unexpected JSON shape; out={out[:200]!r}]")
        return {}

    scores: dict[str, int] = {}
    for key, value in raw.items():
        if isinstance(value, bool):
            continue  # bool is an int subclass; True must not count as 1
        if isinstance(value, float):
            if not value.is_integer():
                continue
            value = int(value)  # accept 4.0 as 4
        elif isinstance(value, str):
            try:
                value = int(value.strip())
            except ValueError:
                continue
        elif not isinstance(value, int):
            continue
        if 1 <= value <= 5:
            scores[str(key)] = value
    return scores


def batch_judge(query: str, items: list[tuple[str, str]]) -> dict[str, int]:
    """Judge (id, text) pairs via the claude CLI.

    Each text is flattened to one line and truncated to 1500 characters, then
    all items are sent in a single prompt built from BATCH_JUDGE_PROMPT.

    Args:
        query: The search query the passages are rated against.
        items: ``(id, text)`` pairs to rate.

    Returns:
        ``{id: score}`` with integer scores in 1-5. Ids the judge omitted or
        scored invalidly are absent. If the CLI cannot be run, times out, or
        replies with unparsable output, a diagnostic is printed and ``{}`` is
        returned, so one bad judge call does not abort the benchmark.
    """
    blocks = []
    for cid, content in items:
        snippet = content.replace("\n", " ").strip()[:1500]
        blocks.append(f"<id={cid}>\n{snippet}\n</id>")
    prompt = BATCH_JUDGE_PROMPT.format(
        query=query, chunks_block="\n\n".join(blocks))
    try:
        proc = subprocess.run(
            ["claude", "-p", "--model", JUDGE_MODEL],
            input=prompt, capture_output=True, text=True, timeout=180,
        )
    except (subprocess.TimeoutExpired, OSError) as e:
        # subprocess.run raises TimeoutExpired on timeout and OSError when
        # the binary cannot be executed; neither should kill the run.
        print(f" [judge call fail: {e}]")
        return {}
    if proc.returncode != 0:
        # Surface the failure, but still try stdout in case it holds a reply.
        print(f" [judge exit code {proc.returncode}: "
              f"{proc.stderr.strip()[:200]!r}]")
    return _parse_judge_scores(proc.stdout)
async def main():
    """Benchmark R1 (voyage-3) against R2 (voyage-3 + rerank-2).

    Steps:
      1. Read VOYAGE_API_KEY / POSTGRES_PASSWORD from the environment and
         check that the ``claude`` CLI can be invoked.
      2. Load every embedded row of ``precedent_chunks`` and ``halachot``.
      3. Time both retrieval pipelines on the first five queries.
      4. For every query, judge the top-JUDGE_K results of both retrievers
         with ``batch_judge`` and compute mean@3, mean@5 and MRR (reciprocal
         rank of the first result scored >= 4).
      5. Print overall / per-category / per-query tables and dump the raw
         per-query results as JSON under /tmp.

    Raises:
        KeyError: if a required environment variable is missing.
        SystemExit: if the ``claude`` CLI probe fails.
    """
    voyage_key = os.environ["VOYAGE_API_KEY"]
    pg_pw = os.environ["POSTGRES_PASSWORD"]
    try:
        subprocess.run(["claude", "--version"], capture_output=True,
                       text=True, timeout=10, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError,
            subprocess.TimeoutExpired, TimeoutError):
        # subprocess.run signals a timeout with subprocess.TimeoutExpired,
        # which is not a subclass of the builtin TimeoutError.
        sys.exit("claude CLI not found")

    voyage = voyageai.Client(api_key=voyage_key)
    pool = await asyncpg.create_pool(
        host="127.0.0.1", port=5433, user="legal_ai",
        password=pg_pw, database="legal_ai",
        min_size=1, max_size=2,
    )
    # The database is only needed to load the corpus, so release the pool as
    # soon as loading is done, including when a query fails.
    try:
        # Load full corpus: precedent_chunks + halachot
        pc_rows = await pool.fetch("""
            SELECT 'pc:' || id::text AS doc_id,
            content,
            embedding::text AS emb_text
            FROM precedent_chunks
            WHERE content IS NOT NULL AND embedding IS NOT NULL
            """)
        # NOTE(review): as written, the '' literals join rule_statement and
        # reasoning_summary with no separator and make TRIM a no-op; confirm
        # whether a whitespace separator was intended here.
        h_rows = await pool.fetch("""
            SELECT 'h:' || id::text AS doc_id,
            TRIM(BOTH '' FROM rule_statement || '' ||
            COALESCE(reasoning_summary, '')) AS content,
            embedding::text AS emb_text
            FROM halachot
            WHERE rule_statement IS NOT NULL AND embedding IS NOT NULL
            """)
    finally:
        await pool.close()

    all_rows = list(pc_rows) + list(h_rows)
    print(f"[load] corpus: {len(pc_rows)} precedent_chunks + "
          f"{len(h_rows)} halachot = {len(all_rows)} total")
    doc_ids = [r["doc_id"] for r in all_rows]
    contents = [r["content"] for r in all_rows]
    embs = [parse_pgvector(r["emb_text"]) for r in all_rows]

    # Retrieval functions. Both return indices into doc_ids / contents / embs.
    def r1_baseline(query: str, k: int = TOP_K) -> list[int]:
        """R1: embed the query with voyage-3, rank the corpus by cosine."""
        q = voyage.embed([query], model=TEXT_MODEL,
                         input_type="query").embeddings[0]
        scores = sorted([(cosine(q, e), i) for i, e in enumerate(embs)],
                        reverse=True)
        return [i for _, i in scores[:k]]

    def r2_rerank(query: str, k: int = TOP_K) -> list[int]:
        """R2: take R1's top-TOP_VEC candidates, rerank-2 them down to k."""
        cands = r1_baseline(query, k=TOP_VEC)
        cand_texts = [contents[i] for i in cands]
        rr = voyage.rerank(query=query, documents=cand_texts,
                           model=RERANK_MODEL, top_k=k)
        # r.index is the position inside cand_texts; map back to corpus index.
        return [cands[r.index] for r in rr.results]

    retrievers = [("R1-voyage3", r1_baseline),
                  ("R2-rerank2", r2_rerank)]

    # Latency measurement: time the two pipelines on a few sample queries,
    # using the same retriever functions that are evaluated below.
    sample = QUERIES[:5]
    print(f"\n[latency] measuring {len(sample)} sample queries…")
    r1_lat = []
    r2_lat = []
    for _, query in sample:
        # perf_counter is monotonic, so intervals are immune to clock changes.
        t0 = time.perf_counter()
        r1_baseline(query, k=TOP_K)
        r1_lat.append(time.perf_counter() - t0)
        t0 = time.perf_counter()
        r2_rerank(query, k=TOP_K)
        r2_lat.append(time.perf_counter() - t0)
    n_lat = len(sample)
    if n_lat:
        print(f" R1 (voyage-3 only) avg={sum(r1_lat)/n_lat*1000:.0f}ms"
              f" min={min(r1_lat)*1000:.0f} max={max(r1_lat)*1000:.0f}")
        print(f" R2 (voyage-3 + rerank-2) avg={sum(r2_lat)/n_lat*1000:.0f}ms"
              f" min={min(r2_lat)*1000:.0f} max={max(r2_lat)*1000:.0f}")
        print(f" Δ (rerank overhead) "
              f"avg={(sum(r2_lat)-sum(r1_lat))/n_lat*1000:.0f}ms")

    print(f"\n[judge] running {len(QUERIES)} queries × 2 retrievers, "
          f"top-{JUDGE_K} judged…")
    all_results = []
    for qid, query in QUERIES:
        print(f"\n[{qid}] {query}")
        retr_results = {}
        for r_name, r_fn in retrievers:
            try:
                retr_results[r_name] = r_fn(query, k=JUDGE_K)
            except Exception as e:
                # Deliberately best-effort: a failed retriever is reported
                # and scored as empty so the remaining queries still run.
                print(f" {r_name}: FAILED — {e}")
                retr_results[r_name] = []
        # Judge each distinct document once, even if both retrievers return it.
        union = sorted({i for top in retr_results.values() for i in top})
        items = [(doc_ids[i], contents[i]) for i in union]
        print(f" judging {len(items)} unique docs…")
        # Skip the CLI round-trip when there is nothing to judge.
        scores_map = batch_judge(query, items) if items else {}
        for r_name, top in retr_results.items():
            # Documents the judge did not score count as 0.
            scores = [scores_map.get(doc_ids[i], 0) for i in top]
            mean3 = sum(scores[:3]) / 3 if len(scores) >= 3 else 0
            mean5 = sum(scores) / len(scores) if scores else 0
            mrr = 0.0
            for r, s in enumerate(scores):
                if s >= 4:
                    mrr = 1.0 / (r + 1)
                    break
            print(f" {r_name}: doc_ids={[doc_ids[i][:14] for i in top]} "
                  f"scores={scores} m@3={mean3:.2f} m@5={mean5:.2f} "
                  f"MRR={mrr:.3f}")
            all_results.append({
                "qid": qid, "category": qid[0], "query": query,
                "retriever": r_name,
                "doc_ids": [doc_ids[i] for i in top],
                "scores": scores, "mean3": mean3, "mean5": mean5, "mrr": mrr,
            })

    # Aggregate
    print("\n" + "=" * 100)
    print(f"AGGREGATED RESULTS — full precedent_library corpus "
          f"({len(all_rows)} docs)")
    print("=" * 100)
    by_r = defaultdict(lambda: {"mean3": [], "mean5": [], "mrr": []})
    by_cat_r = defaultdict(lambda: {"mean3": [], "mean5": [], "mrr": []})
    for r in all_results:
        ck = (r["category"], r["retriever"])
        for metric in ("mean3", "mean5", "mrr"):
            by_r[r["retriever"]][metric].append(r[metric])
            by_cat_r[ck][metric].append(r[metric])

    def avg(xs):
        """Arithmetic mean, or 0 for an empty list."""
        return sum(xs) / len(xs) if xs else 0

    print(f"\nOverall ({len(QUERIES)} queries):")
    print(f"{'retriever':<14} {'mean@3':>8} {'mean@5':>8} {'MRR':>8}")
    for r_name, _ in retrievers:
        m = by_r[r_name]
        print(f"{r_name:<14} {avg(m['mean3']):>8.3f} "
              f"{avg(m['mean5']):>8.3f} {avg(m['mrr']):>8.3f}")

    # Improvement
    r1m = avg(by_r["R1-voyage3"]["mean3"])
    r2m = avg(by_r["R2-rerank2"]["mean3"])
    if r1m > 0:
        print(f"\nR2 vs R1 improvement: "
              f"mean@3 {(r2m - r1m) / r1m * 100:+.1f}%")

    print("\nBy category:")
    print(f"{'cat':<3} {'retriever':<14} {'mean@3':>8} {'mean@5':>8} "
          f"{'MRR':>8}")
    for cat in ["K", "C", "N", "P"]:
        for r_name, _ in retrievers:
            m = by_cat_r[(cat, r_name)]
            if not m["mean3"]:
                continue
            print(f"{cat:<3} {r_name:<14} {avg(m['mean3']):>8.3f} "
                  f"{avg(m['mean5']):>8.3f} {avg(m['mrr']):>8.3f}")

    print("\nPer-query winner (highest mean@3):")
    print(f"{'qid':<4} {'query':<40} {'winner':<14} {'scores'}")
    by_q = defaultdict(list)
    for r in all_results:
        by_q[r["qid"]].append(r)
    query_text = dict(QUERIES)
    for qid, results in sorted(by_q.items()):
        max_s = max(r["mean3"] for r in results)
        # Ties list every retriever that reached the best mean@3.
        winners = [r["retriever"] for r in results if r["mean3"] == max_s]
        scores = " | ".join(f"{r['retriever'][:7]}={r['mean3']:.2f}"
                            for r in results)
        q_str = query_text[qid][:38]
        print(f"{qid:<4} {q_str:<40} {','.join(w[:8] for w in winners):<14} "
              f"{scores}")

    out_path = "/tmp/voyage_rerank_corpus_results.json"
    # ensure_ascii=False writes raw Hebrew, so pin the encoding instead of
    # relying on the locale default.
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2)
    print(f"\nSaved to {out_path}")
if __name__ == "__main__":
    # Script entry point: drive the async benchmark to completion.
    asyncio.run(main())