feat(retrieval): add voyage rerank-2 cross-encoder stage (feature flag)
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 1m29s

Stage B of voyage-upgrades-plan rewritten: instead of context-3 (which
4 POCs showed inconsistent improvement), add a cross-encoder rerank
layer on top of voyage-3. Default off (VOYAGE_RERANK_ENABLED=false).

POC validation (785-doc corpus, 12 queries, claude-haiku-4-5 judge):
- mean@3 +4.5% (4.306 → 4.500)
- practical-category queries +11.6% (3.78 → 4.22)
- latency +702ms per query
- no schema change, no re-embed, no double storage

Plumbing:
- config: VOYAGE_RERANK_ENABLED / _MODEL / _FETCH_K env vars
- embeddings.voyage_rerank() wraps voyageai client.rerank
- services/rerank.py: maybe_rerank() helper — fetches FETCH_K candidates
  via the bi-encoder then reranks to top-K. Fail-open if Voyage rerank is
  unavailable.
- tools/search.py: search_decisions, search_case_documents,
  find_similar_cases all wrapped
- services/precedent_library.search_library wrapped

Smoke-tested locally with flag on/off — produces expected behaviour and
latency profile. Ready for production rollout via Coolify env flip after
deploy.

POCs (kept under scripts/ for reference):
- voyage_context3_poc{_long}.py — context-3 evaluation (rejected)
- voyage_multimodal_poc.py — multimodal-3 (stage C, deferred)
- voyage_rerank_judge_poc.py — single-case rerank benchmark
- voyage_rerank_corpus_poc.py — full-corpus rerank validation

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-03 18:43:41 +00:00
parent 688ba37d9c
commit 26c3fddf41
13 changed files with 1578 additions and 100 deletions

View File

@@ -0,0 +1,318 @@
"""POC #5 — full precedent_library corpus benchmark.
Tests R1 (voyage-3) vs R2 (voyage-3 + rerank-2) on the *real* corpus that
search_precedent_library queries against:
precedent_chunks — 385 rows from 3 precedent cases
halachot — 400 rule statements with reasoning summaries
Total: 785 documents. The MCP tool merges results from both tables so the
benchmark mirrors production retrieval. R3 (context-3) is dropped — it
would require windowed re-embedding of 3 cases which we already proved
doesn't help (POC #2). The question now is: does rerank-2's +9% on a
single case generalize to a heterogeneous corpus?
Also measures end-to-end latency: pure voyage-3 vs voyage-3 + rerank.
Usage:
/home/chaim/legal-ai/mcp-server/.venv/bin/python \\
/home/chaim/legal-ai/scripts/voyage_rerank_corpus_poc.py
"""
from __future__ import annotations
import asyncio
import json
import math
import os
import re
import subprocess
import sys
import time
from collections import defaultdict
ENV_PATH = os.path.expanduser("~/.env")
if os.path.isfile(ENV_PATH):
with open(ENV_PATH) as f:
for line in f:
line = line.strip()
if line and not line.startswith("#") and "=" in line:
k, v = line.split("=", 1)
os.environ.setdefault(k, v)
import asyncpg # noqa: E402
import voyageai # noqa: E402
TEXT_MODEL = "voyage-3"
RERANK_MODEL = "rerank-2"
JUDGE_MODEL = "claude-haiku-4-5-20251001"
TOP_VEC = 50 # voyage-3 retrieve depth
TOP_K = 10 # final returned to "agent"
JUDGE_K = 5 # how many top results to actually judge per retriever
# 12 queries spanning typical use cases by Daphna's agents:
# precedent search for citing in decision blocks י-יא.
QUERIES = [
# K — keyword
("K1", "פיצויים לפי סעיף 197"),
("K2", "תמ\"א 38 והשבחה"),
("K3", "כלל הנטרול בשמאות"),
# C — conceptual
("C1", "תכלית היטל ההשבחה"),
("C2", "מה מקנה לבעלים זכות לפיצוי"),
("C3", "ההבחנה בין השבחה לפיצויים"),
# N — narrative / context-aware
("N1", "מה נקבע לגבי תמ\"א 38 בפסיקה"),
("N2", "ההלכה לעניין נטרול ציפיות"),
("N3", "תכנית פוגעת ושומה"),
# P — practical (drafting needs — what an agent typically asks)
("P1", "פסיקה שדנה בתכנית מתאר ארצית"),
("P2", "מתי מותר לוועדה לדחות פיצויים"),
("P3", "שיקול דעת הוועדה המקומית"),
]
def cosine(a, b):
dot = sum(x * y for x, y in zip(a, b))
na = math.sqrt(sum(x * x for x in a))
nb = math.sqrt(sum(y * y for y in b))
return dot / (na * nb) if na and nb else 0.0
def parse_pgvector(s):
return [float(x) for x in s.strip("[]").split(",")]
BATCH_JUDGE_PROMPT = """אתה שופט רלוונטיות במשפט ישראלי.
לפניך שאילתה ומספר פסקאות מפסקי דין/הלכות. דרג כל פסקה 1-5 לפי רלוונטיות.
5 — תשובה ישירה למה שנשאל
4 — מאד רלוונטי, מכיל מידע ליבה
3 — רלוונטי חלקית, נוגע בעקיפין
2 — מעט קשור, רעש סביב הנושא
1 — לא רלוונטי בכלל
השאילתה:
{query}
הפסקאות:
{chunks_block}
החזר JSON בלבד: {{"scores": {{"<id>": <1-5>, ...}}}}
ללא טקסט נוסף, ללא ```."""
def batch_judge(query: str, items: list[tuple[str, str]]) -> dict[str, int]:
"""Judge (id, text) pairs via claude CLI. Returns {id: score}."""
blocks = []
for cid, content in items:
snippet = content.replace("\n", " ").strip()[:1500]
blocks.append(f"<id={cid}>\n{snippet}\n</id>")
prompt = BATCH_JUDGE_PROMPT.format(
query=query, chunks_block="\n\n".join(blocks))
proc = subprocess.run(
["claude", "-p", "--model", JUDGE_MODEL],
input=prompt, capture_output=True, text=True, timeout=180,
)
out = proc.stdout.strip()
out = re.sub(r"^```(?:json)?\s*", "", out)
out = re.sub(r"\s*```$", "", out)
try:
data = json.loads(out)
raw = data.get("scores", {})
return {str(k): int(v) for k, v in raw.items()
if str(v).isdigit() and 1 <= int(v) <= 5}
except (json.JSONDecodeError, ValueError, TypeError) as e:
print(f" [judge parse fail: {e}; out={out[:200]!r}]")
return {}
async def main():
voyage_key = os.environ["VOYAGE_API_KEY"]
pg_pw = os.environ["POSTGRES_PASSWORD"]
try:
subprocess.run(["claude", "--version"], capture_output=True,
text=True, timeout=10, check=True)
except (subprocess.CalledProcessError, FileNotFoundError, TimeoutError):
sys.exit("claude CLI not found")
voyage = voyageai.Client(api_key=voyage_key)
pool = await asyncpg.create_pool(
host="127.0.0.1", port=5433, user="legal_ai",
password=pg_pw, database="legal_ai",
min_size=1, max_size=2,
)
# Load full corpus: precedent_chunks + halachot
pc_rows = await pool.fetch("""
SELECT 'pc:' || id::text AS doc_id,
content,
embedding::text AS emb_text
FROM precedent_chunks
WHERE content IS NOT NULL AND embedding IS NOT NULL
""")
h_rows = await pool.fetch("""
SELECT 'h:' || id::text AS doc_id,
TRIM(BOTH '' FROM rule_statement || '' ||
COALESCE(reasoning_summary, '')) AS content,
embedding::text AS emb_text
FROM halachot
WHERE rule_statement IS NOT NULL AND embedding IS NOT NULL
""")
all_rows = list(pc_rows) + list(h_rows)
print(f"[load] corpus: {len(pc_rows)} precedent_chunks + "
f"{len(h_rows)} halachot = {len(all_rows)} total")
doc_ids = [r["doc_id"] for r in all_rows]
contents = [r["content"] for r in all_rows]
embs = [parse_pgvector(r["emb_text"]) for r in all_rows]
# Latency measurement: 5 queries, time the two pipelines
print("\n[latency] measuring 5 sample queries…")
sample = QUERIES[:5]
r1_lat = []
r2_lat = []
for _, query in sample:
# R1: voyage-3 embed + cosine top-10
t0 = time.time()
q_emb = voyage.embed([query], model=TEXT_MODEL,
input_type="query").embeddings[0]
scores = sorted([(cosine(q_emb, e), i) for i, e in enumerate(embs)],
reverse=True)[:TOP_K]
r1_lat.append(time.time() - t0)
# R2: voyage-3 embed + cosine top-50 + rerank-2 → top-10
t0 = time.time()
q_emb = voyage.embed([query], model=TEXT_MODEL,
input_type="query").embeddings[0]
cands = sorted([(cosine(q_emb, e), i) for i, e in enumerate(embs)],
reverse=True)[:TOP_VEC]
cand_texts = [contents[i] for _, i in cands]
rr = voyage.rerank(query=query, documents=cand_texts,
model=RERANK_MODEL, top_k=TOP_K)
r2_lat.append(time.time() - t0)
print(f" R1 (voyage-3 only) avg={sum(r1_lat)/5*1000:.0f}ms"
f" min={min(r1_lat)*1000:.0f} max={max(r1_lat)*1000:.0f}")
print(f" R2 (voyage-3 + rerank-2) avg={sum(r2_lat)/5*1000:.0f}ms"
f" min={min(r2_lat)*1000:.0f} max={max(r2_lat)*1000:.0f}")
print(f" Δ (rerank overhead) avg={(sum(r2_lat)-sum(r1_lat))/5*1000:.0f}ms")
# Retrieval functions
def r1_baseline(query: str, k: int = TOP_K) -> list[int]:
q = voyage.embed([query], model=TEXT_MODEL,
input_type="query").embeddings[0]
scores = sorted([(cosine(q, e), i) for i, e in enumerate(embs)],
reverse=True)
return [i for _, i in scores[:k]]
def r2_rerank(query: str, k: int = TOP_K) -> list[int]:
cands = r1_baseline(query, k=TOP_VEC)
cand_texts = [contents[i] for i in cands]
rr = voyage.rerank(query=query, documents=cand_texts,
model=RERANK_MODEL, top_k=k)
return [cands[r.index] for r in rr.results]
retrievers = [("R1-voyage3", r1_baseline),
("R2-rerank2", r2_rerank)]
print(f"\n[judge] running {len(QUERIES)} queries × 2 retrievers, "
f"top-{JUDGE_K} judged…")
all_results = []
for qid, query in QUERIES:
print(f"\n[{qid}] {query}")
retr_results = {}
for r_name, r_fn in retrievers:
try:
retr_results[r_name] = r_fn(query, k=JUDGE_K)
except Exception as e:
print(f" {r_name}: FAILED — {e}")
retr_results[r_name] = []
union = sorted({i for top in retr_results.values() for i in top})
items = [(doc_ids[i], contents[i]) for i in union]
print(f" judging {len(items)} unique docs…")
scores_map = batch_judge(query, items)
for r_name, top in retr_results.items():
scores = [scores_map.get(doc_ids[i], 0) for i in top]
mean3 = sum(scores[:3]) / 3 if len(scores) >= 3 else 0
mean5 = sum(scores) / len(scores) if scores else 0
mrr = 0.0
for r, s in enumerate(scores):
if s >= 4:
mrr = 1.0 / (r + 1)
break
print(f" {r_name}: doc_ids={[doc_ids[i][:14] for i in top]} "
f"scores={scores} m@3={mean3:.2f} m@5={mean5:.2f} "
f"MRR={mrr:.3f}")
all_results.append({
"qid": qid, "category": qid[0], "query": query,
"retriever": r_name,
"doc_ids": [doc_ids[i] for i in top],
"scores": scores, "mean3": mean3, "mean5": mean5, "mrr": mrr,
})
# Aggregate
print("\n" + "=" * 100)
print("AGGREGATED RESULTS — full precedent_library corpus (785 docs)")
print("=" * 100)
by_r = defaultdict(lambda: {"mean3": [], "mean5": [], "mrr": []})
by_cat_r = defaultdict(lambda: {"mean3": [], "mean5": [], "mrr": []})
for r in all_results:
by_r[r["retriever"]]["mean3"].append(r["mean3"])
by_r[r["retriever"]]["mean5"].append(r["mean5"])
by_r[r["retriever"]]["mrr"].append(r["mrr"])
ck = (r["category"], r["retriever"])
by_cat_r[ck]["mean3"].append(r["mean3"])
by_cat_r[ck]["mean5"].append(r["mean5"])
by_cat_r[ck]["mrr"].append(r["mrr"])
print(f"\nOverall ({len(QUERIES)} queries):")
print(f"{'retriever':<14} {'mean@3':>8} {'mean@5':>8} {'MRR':>8}")
avg = lambda xs: sum(xs) / len(xs) if xs else 0
for r_name, _ in retrievers:
m = by_r[r_name]
print(f"{r_name:<14} {avg(m['mean3']):>8.3f} "
f"{avg(m['mean5']):>8.3f} {avg(m['mrr']):>8.3f}")
# Improvement
r1m = avg(by_r["R1-voyage3"]["mean3"])
r2m = avg(by_r["R2-rerank2"]["mean3"])
if r1m > 0:
print(f"\nR2 vs R1 improvement: "
f"mean@3 {(r2m - r1m) / r1m * 100:+.1f}%")
print(f"\nBy category:")
print(f"{'cat':<3} {'retriever':<14} {'mean@3':>8} {'mean@5':>8} "
f"{'MRR':>8}")
for cat in ["K", "C", "N", "P"]:
for r_name, _ in retrievers:
m = by_cat_r[(cat, r_name)]
if not m["mean3"]:
continue
print(f"{cat:<3} {r_name:<14} {avg(m['mean3']):>8.3f} "
f"{avg(m['mean5']):>8.3f} {avg(m['mrr']):>8.3f}")
print(f"\nPer-query winner (highest mean@3):")
print(f"{'qid':<4} {'query':<40} {'winner':<14} {'scores'}")
by_q = defaultdict(list)
for r in all_results:
by_q[r["qid"]].append(r)
for qid, results in sorted(by_q.items()):
max_s = max(r["mean3"] for r in results)
winners = [r["retriever"] for r in results if r["mean3"] == max_s]
scores = " | ".join(f"{r['retriever'][:7]}={r['mean3']:.2f}"
for r in results)
q_str = next(q for qid_, q in QUERIES if qid_ == qid)[:38]
print(f"{qid:<4} {q_str:<40} {','.join(w[:8] for w in winners):<14} "
f"{scores}")
out_path = "/tmp/voyage_rerank_corpus_results.json"
with open(out_path, "w") as f:
json.dump(all_results, f, ensure_ascii=False, indent=2)
print(f"\nSaved to {out_path}")
await pool.close()
if __name__ == "__main__":
asyncio.run(main())