feat(retrieval): add voyage rerank-2 cross-encoder stage (feature flag)
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 1m29s
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 1m29s
Stage B of voyage-upgrades-plan rewritten: instead of context-3 (which
4 POCs showed inconsistent improvement), add a cross-encoder rerank
layer on top of voyage-3. Default off (VOYAGE_RERANK_ENABLED=false).
POC validation (785-doc corpus, 12 queries, claude-haiku-4-5 judge):
- mean@3 +4.5% (4.306 → 4.500)
- practical-category queries +11.6% (3.78 → 4.22)
- latency +702ms per query
- no schema change, no re-embed, no double storage
Plumbing:
- config: VOYAGE_RERANK_ENABLED / _MODEL / _FETCH_K env vars
- embeddings.voyage_rerank() wraps voyageai client.rerank
- services/rerank.py: maybe_rerank() helper — fetches FETCH_K candidates
via the bi-encoder then reranks to top-K. Fail-open if Voyage rerank is
unavailable.
- tools/search.py: search_decisions, search_case_documents,
find_similar_cases all wrapped
- services/precedent_library.search_library wrapped
Smoke-tested locally with flag on/off — produces expected behaviour and
latency profile. Ready for production rollout via Coolify env flip after
deploy.
POCs (kept under scripts/ for reference):
- voyage_context3_poc{_long}.py — context-3 evaluation (rejected)
- voyage_multimodal_poc.py — multimodal-3 (stage C, deferred)
- voyage_rerank_judge_poc.py — single-case rerank benchmark
- voyage_rerank_corpus_poc.py — full-corpus rerank validation
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
361
scripts/voyage_rerank_judge_poc.py
Normal file
361
scripts/voyage_rerank_judge_poc.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""POC #4: Comprehensive retrieval benchmark with LLM-as-judge.
|
||||
|
||||
Compares 3 retrievers on אהרון ברק 403/17 (219 chunks):
|
||||
R1 — voyage-3 (current production baseline)
|
||||
R2 — voyage-3 + voyage-rerank-2 (retrieve 50, rerank, top-10)
|
||||
R3 — voyage-context-3 (windowed, from POC #2)
|
||||
|
||||
Judges relevance with claude-haiku-4-5 — for each (query, chunk) pair the
|
||||
judge returns 1-5. Aggregates: mean relevance@3, @5, @10, MRR (rank of
|
||||
first 4+ chunk), per-query winner.
|
||||
|
||||
18 queries grouped into 3 categories so we can see *which* query types
|
||||
benefit from which retriever:
|
||||
K — keyword/lexical (term-heavy, specific entity)
|
||||
C — conceptual (abstract idea, principle)
|
||||
N — narrative/contextual (requires document-internal reference)
|
||||
|
||||
Usage (key passed via env, NOT stored in script):
|
||||
ANTHROPIC_API_KEY=... \\
|
||||
/home/chaim/legal-ai/mcp-server/.venv/bin/python \\
|
||||
/home/chaim/legal-ai/scripts/voyage_rerank_judge_poc.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
# Minimal ~/.env loader: copies KEY=VALUE lines into os.environ without
# overriding variables that are already set in the real environment.
ENV_PATH = os.path.expanduser("~/.env")
if os.path.isfile(ENV_PATH):
    # Explicit encoding so the file parses the same under any locale.
    with open(ENV_PATH, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # Skip blanks, comments and lines that are not assignments.
            if not line or line.startswith("#") or "=" not in line:
                continue
            k, v = line.split("=", 1)
            # Tolerate "KEY = value" spacing; an empty key (a line such as
            # "=value") is skipped because os.environ rejects it.
            k = k.strip()
            if k:
                os.environ.setdefault(k, v.strip())
|
||||
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
import asyncpg # noqa: E402
|
||||
import voyageai # noqa: E402
|
||||
|
||||
|
||||
# Precedent whose chunks form the benchmark corpus; main() passes it as the
# case_law_id filter when loading precedent_chunks.
CASE_ID = "e151fc25-cf12-4563-b638-a86323f8413b"  # Aharon Barak 403/17
# Embedding model used for query embeddings in the R1 baseline retriever.
TEXT_MODEL = "voyage-3"
# Contextualized embedding model used for the R3 retriever (chunks + queries).
CONTEXT_MODEL = "voyage-context-3"
# Cross-encoder model passed to voyage.rerank in the R2 retriever.
RERANK_MODEL = "rerank-2"
# Model name handed to the `claude` CLI for relevance judging.
JUDGE_MODEL = "claude-haiku-4-5-20251001"

# Context-3 windowing: each window holds up to WINDOW_SIZE consecutive chunks
# and windows start WINDOW_STRIDE chunks apart, so neighbours overlap by
# WINDOW_SIZE - WINDOW_STRIDE chunks.
WINDOW_SIZE = 80
WINDOW_STRIDE = 70
|
||||
|
||||
# 18 queries × 3 retrievers × top-5 = up to 270 (query, chunk) judgments.
# main() judges the union of all retrievers' results in one batched call per
# query, so the judge is invoked once per query rather than 270 times.
# Original cost estimate for unbatched judging: ~$0.05 with haiku.
# Each entry is (qid, query text); the first letter of qid is the category.
QUERIES = [
    # K — keyword/lexical
    ("K1", "תכנית רחביה הוראות בנייה"),
    ("K2", "תמ\"א 38"),
    ("K3", "תכנית 9988"),
    ("K4", "סעיף 197 לחוק התכנון והבניה"),
    ("K5", "השופט גרוסקופף"),
    ("K6", "ועדה מקומית ירושלים"),
    # C — conceptual / abstract principles
    ("C1", "כלל הנטרול של זכויות תכנוניות"),
    ("C2", "אינטרס הציבור בתכנון"),
    ("C3", "תכלית היטל ההשבחה"),
    ("C4", "תכנית פוגעת לעומת תכנית משביחה"),
    ("C5", "ההבחנה בין השבחה לפיצויים"),
    ("C6", "מהותו של היטל ההשבחה"),
    # N — narrative / context-dependent
    ("N1", "מה נקבע לגבי תמ\"א 38 בפסק הדין"),
    ("N2", "מסקנת בית המשפט בעניין רובע 3"),
    ("N3", "ההלכה שנקבעה בעניין שמעוני"),
    ("N4", "ההבדל בין המקרה שלפנינו לעניין רון"),
    ("N5", "סוף דבר ותוצאת פסק הדין"),
    ("N6", "הסכמת השופטים האחרים לחוות הדעת"),
]
|
||||
|
||||
|
||||
def cosine(a, b):
    """Return the cosine similarity between numeric vectors *a* and *b*.

    The dot product pairs elements with ``zip``, so a longer vector is
    silently truncated there. If either vector has zero magnitude the
    result is 0.0 instead of a division error.
    """
    inner = sum(p * q for p, q in zip(a, b))
    norm_a = math.sqrt(sum(p * p for p in a))
    norm_b = math.sqrt(sum(q * q for q in b))
    if not norm_a or not norm_b:
        return 0.0
    return inner / (norm_a * norm_b)
|
||||
|
||||
|
||||
def parse_pgvector(s):
    """Parse a pgvector text literal such as ``"[0.1,0.2]"`` into floats.

    Surrounding whitespace is tolerated and an empty vector (``"[]"`` or an
    empty string) yields ``[]`` instead of raising ``ValueError`` from
    ``float("")``.

    Raises:
        ValueError: if any component is not a valid float literal.
    """
    body = s.strip().strip("[]").strip()
    if not body:
        return []
    return [float(part) for part in body.split(",")]
|
||||
|
||||
|
||||
def build_windows(n, size, stride):
    """Split ``range(n)`` into overlapping half-open windows.

    Windows are ``(start, end)`` pairs, each at most *size* long, whose
    starts are *stride* apart. The final window is clipped to end exactly
    at *n*, and no further window is emitted once *n* has been reached.

    Args:
        n: number of items to cover; ``n <= 0`` yields an empty list.
        size: maximum window length, must be >= 1.
        stride: distance between consecutive window starts, must be >= 1.

    Returns:
        List of ``(start, end)`` tuples in ascending order.

    Raises:
        ValueError: if *size* or *stride* is not positive. A non-positive
            stride would otherwise never advance and loop forever.
    """
    if size < 1 or stride < 1:
        raise ValueError(
            f"size and stride must be positive (got size={size}, "
            f"stride={stride})"
        )
    out = []
    s = 0
    while s < n:
        e = min(s + size, n)
        out.append((s, e))
        if e == n:
            # The tail is covered; stepping again would only add a
            # redundant window contained in this one.
            break
        s += stride
    return out
|
||||
|
||||
|
||||
def central_window(idx, windows):
    """Pick the window in which position *idx* sits most centrally.

    For every half-open ``(start, end)`` window containing *idx*, the margin
    is the distance from *idx* to the nearer edge of that window. The
    position of the window with the largest margin is returned; on a tie the
    earliest window wins. Returns -1 when no window contains *idx*.
    """
    margins = [
        (min(idx - lo, (hi - 1) - idx), pos)
        for pos, (lo, hi) in enumerate(windows)
        if lo <= idx < hi
    ]
    if not margins:
        return -1
    # max() keeps the first of equal maxima, matching "earliest wins".
    _, winner = max(margins, key=lambda pair: pair[0])
    return winner
|
||||
|
||||
|
||||
BATCH_JUDGE_PROMPT = """אתה שופט רלוונטיות במשפט ישראלי.
|
||||
לפניך שאילתה ומספר פסקאות מפסק דין. דרג כל פסקה בנפרד 1-5 לפי רלוונטיות.
|
||||
|
||||
סולם:
|
||||
5 — תשובה ישירה ומדויקת לשאילתה
|
||||
4 — מאד רלוונטי, מכיל מידע ליבה
|
||||
3 — רלוונטי חלקית, נוגע בעקיפין בנושא
|
||||
2 — מעט קשור, רעש סביב הנושא
|
||||
1 — לא רלוונטי בכלל
|
||||
|
||||
השאילתה:
|
||||
{query}
|
||||
|
||||
הפסקאות:
|
||||
{chunks_block}
|
||||
|
||||
החזר JSON בלבד, בפורמט: {{"scores": {{"<id>": <1-5>, ...}}}}
|
||||
ללא טקסט נוסף, ללא explanations, ללא ```."""
|
||||
|
||||
|
||||
def batch_judge(query: str,
                items: list[tuple[int, str]]) -> dict[int, int]:
    """Judge a list of (chunk_idx, content) pairs in a single CLI call.

    Each chunk is flattened to one line, truncated to 1500 characters and
    sent to the ``claude`` CLI together with the query.

    Args:
        query: the search query the chunks are judged against.
        items: ``(chunk_idx, content)`` pairs; ``chunk_idx`` is echoed back
            by the judge as the key of its score.

    Returns:
        dict[chunk_idx → score 1-5]. Chunks the judge skipped or scored
        invalidly are simply absent. Any failure of the call itself
        (timeout, CLI missing, non-zero exit, unparseable output) is
        reported on stdout and yields ``{}`` so one bad call cannot abort
        the whole benchmark.
    """
    if not items:
        # Nothing to judge (e.g. every retriever failed): skip the CLI call.
        return {}

    chunks_block_lines = []
    for ci, content in items:
        snippet = content.replace("\n", " ").strip()[:1500]
        chunks_block_lines.append(f"<id={ci}>\n{snippet}\n</id>")
    prompt = BATCH_JUDGE_PROMPT.format(
        query=query,
        chunks_block="\n\n".join(chunks_block_lines),
    )

    try:
        proc = subprocess.run(
            ["claude", "-p", "--model", JUDGE_MODEL],
            input=prompt, capture_output=True, text=True, timeout=120,
        )
    except (subprocess.TimeoutExpired, OSError) as e:
        # TimeoutExpired is raised by the timeout= argument; OSError covers
        # a missing or non-executable CLI binary.
        print(f" [judge call fail: {e}]")
        return {}
    if proc.returncode != 0:
        print(f" [judge call fail: exit {proc.returncode}; "
              f"err={proc.stderr.strip()[:200]!r}]")
        return {}
    return _parse_judge_scores(proc.stdout)


def _parse_judge_scores(out: str) -> dict[int, int]:
    """Extract ``{chunk_idx: score}`` from the judge's raw JSON reply."""
    out = out.strip()
    # Strip ```json fences if any
    out = re.sub(r"^```(?:json)?\s*", "", out)
    out = re.sub(r"\s*```$", "", out)
    try:
        data = json.loads(out)
    except json.JSONDecodeError as e:
        print(f" [judge parse fail: {e}; out={out[:200]!r}]")
        return {}
    raw = data.get("scores") if isinstance(data, dict) else None
    if not isinstance(raw, dict):
        print(f" [judge parse fail: no 'scores' object; out={out[:200]!r}]")
        return {}

    scores: dict[int, int] = {}
    for key, value in raw.items():
        # The prompt tags chunks as <id=N>, so accept both "N" and "id=N".
        key_match = re.fullmatch(r"\s*(?:id=)?(\d+)\s*", str(key))
        score_text = str(value).strip()
        # Validate per entry so one malformed pair does not discard the rest.
        if key_match and re.fullmatch(r"[1-5]", score_text):
            scores[int(key_match.group(1))] = int(score_text)
    return scores
|
||||
|
||||
|
||||
async def main():
    """Run the retrieval benchmark for CASE_ID and report judged relevance.

    Steps:
      1. Check that the ``claude`` CLI is callable, then load the case's
         chunks and their stored embeddings from Postgres.
      2. Embed the chunks with CONTEXT_MODEL in overlapping windows and keep,
         for each chunk, the embedding from the window where it sits most
         centrally.
      3. For each query, run the three retrievers, judge the union of their
         top-JUDGE_TOP_K chunks with one batch_judge call, and derive
         mean@3, mean@5 and MRR (reciprocal rank of the first chunk scored
         4 or higher) per retriever.
      4. Print overall, per-category and per-query tables and dump the raw
         rows to /tmp/voyage_rerank_judge_results.json.

    Requires VOYAGE_API_KEY and POSTGRES_PASSWORD in the environment
    (KeyError otherwise). Exits via ``sys.exit`` when the CLI check fails or
    the case has no chunks.
    """
    voyage_key = os.environ["VOYAGE_API_KEY"]
    pg_pw = os.environ["POSTGRES_PASSWORD"]

    # Verify Claude CLI is available (uses OAuth from ~/.claude/.credentials)
    try:
        subprocess.run(["claude", "--version"], capture_output=True,
                       text=True, timeout=10, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError,
            subprocess.TimeoutExpired):
        # subprocess.run signals a timeout with subprocess.TimeoutExpired,
        # which is not a subclass of the builtin TimeoutError.
        sys.exit("claude CLI not found or not authenticated")

    voyage = voyageai.Client(api_key=voyage_key)

    # Load chunks + voyage-3 embeddings
    pool = await asyncpg.create_pool(
        host="127.0.0.1", port=5433, user="legal_ai",
        password=pg_pw, database="legal_ai",
        min_size=1, max_size=2,
    )
    try:
        rows = await pool.fetch("""
            SELECT chunk_index, content, embedding::text AS emb_text
            FROM precedent_chunks
            WHERE case_law_id = $1
            ORDER BY chunk_index
        """, CASE_ID)
    finally:
        # The database is only needed for this one query; release it now so
        # it is neither leaked on a later error nor held open during the
        # long embedding and judging phase.
        await pool.close()
    if not rows:
        sys.exit(f"no precedent_chunks rows found for case {CASE_ID}")

    chunks = [r["content"] for r in rows]
    chunk_indices = [r["chunk_index"] for r in rows]
    baseline_embs = [parse_pgvector(r["emb_text"]) for r in rows]
    n = len(chunks)
    print(f"[load] {n} chunks loaded")

    # Compute context-3 (windowed) embeddings — same as POC #2
    windows = build_windows(n, WINDOW_SIZE, WINDOW_STRIDE)
    print(f"[context-3] embedding {len(windows)} windows…")
    win_embs = []
    for s, e in windows:
        result = voyage.contextualized_embed(
            inputs=[chunks[s:e]],
            model=CONTEXT_MODEL,
            input_type="document",
        )
        win_embs.append(result.results[0].embeddings)
    # Chunks in the overlap of two windows get two embeddings; keep the one
    # from the window where the chunk is farther from the edges.
    context_embs = []
    for i in range(n):
        w = central_window(i, windows)
        s, _ = windows[w]
        context_embs.append(win_embs[w][i - s])
    print("[context-3] done")

    # Retrieval functions — each returns positions into `chunks`, best first.
    def r1_baseline(query: str, k: int = 10) -> list[int]:
        q = voyage.embed([query], model=TEXT_MODEL,
                         input_type="query").embeddings[0]
        scores = sorted(
            [(cosine(q, e), i) for i, e in enumerate(baseline_embs)],
            reverse=True,
        )
        return [i for _, i in scores[:k]]

    def r2_rerank(query: str, k: int = 10) -> list[int]:
        # 1) voyage-3 retrieve top-50
        cands = r1_baseline(query, k=50)
        cand_texts = [chunks[i] for i in cands]
        # 2) voyage-rerank-2 over the 50
        rr = voyage.rerank(
            query=query, documents=cand_texts,
            model=RERANK_MODEL, top_k=k,
        )
        # rr.results: list of RerankingResult(index=..., relevance_score=...)
        # `index` refers to position in cand_texts → map back to chunk idx
        return [cands[r.index] for r in rr.results]

    def r3_context(query: str, k: int = 10) -> list[int]:
        q = voyage.contextualized_embed(
            inputs=[[query]],
            model=CONTEXT_MODEL,
            input_type="query",
        ).results[0].embeddings[0]
        scores = sorted(
            [(cosine(q, e), i) for i, e in enumerate(context_embs)],
            reverse=True,
        )
        return [i for _, i in scores[:k]]

    retrievers = [("R1-voyage3", r1_baseline),
                  ("R2-rerank2", r2_rerank),
                  ("R3-context3", r3_context)]

    def avg(xs):
        """Mean of *xs*, or 0 for an empty list."""
        return sum(xs) / len(xs) if xs else 0

    # Run all queries × all retrievers, judging top-5 per pair.
    # Strategy: for each query, gather the union of all retrievers' top-K
    # and judge them in ONE batched CLI call → one call per query.
    all_results = []
    JUDGE_TOP_K = 5
    print(f"\n[judge] running {len(QUERIES)} queries × "
          f"{len(retrievers)} retrievers × top-{JUDGE_TOP_K} — batched per query…")

    for qid, query in QUERIES:
        print(f"\n[{qid}] {query}")
        # Collect retrievals first
        retr_results = {}
        for r_name, r_fn in retrievers:
            try:
                retr_results[r_name] = r_fn(query, k=JUDGE_TOP_K)
            except Exception as e:
                # Deliberate best-effort: a failing retriever is recorded as
                # an empty result so the other retrievers are still scored.
                print(f" {r_name}: FAILED — {e}")
                retr_results[r_name] = []
        # Union of unique chunk indices to judge
        union = sorted({i for top in retr_results.values() for i in top})
        items = [(i, chunks[i]) for i in union]
        print(f" judging {len(items)} unique chunks via batch CLI…")
        scores_map = batch_judge(query, items)
        # Build per-retriever score lists
        for r_name, top in retr_results.items():
            # Chunks the judge did not score count as 0.
            scores = [scores_map.get(i, 0) for i in top]
            mean3 = sum(scores[:3]) / 3 if len(scores) >= 3 else 0
            mean5 = sum(scores) / len(scores) if scores else 0
            mrr = 0.0
            for r, s in enumerate(scores):
                if s >= 4:
                    mrr = 1.0 / (r + 1)
                    break
            print(f" {r_name}: chunks={[chunk_indices[i] for i in top]} "
                  f"scores={scores} mean@3={mean3:.2f} mean@5={mean5:.2f} "
                  f"MRR={mrr:.3f}")
            all_results.append({
                "qid": qid, "category": qid[0], "query": query,
                "retriever": r_name,
                "chunks": [chunk_indices[i] for i in top],
                "scores": scores,
                "mean3": mean3, "mean5": mean5, "mrr": mrr,
            })

    # Aggregate
    print("\n" + "=" * 100)
    print("AGGREGATED RESULTS")
    print("=" * 100)

    by_retriever = defaultdict(lambda: {"mean3": [], "mean5": [], "mrr": []})
    by_cat_retriever = defaultdict(
        lambda: {"mean3": [], "mean5": [], "mrr": []})
    for r in all_results:
        cat_key = (r["category"], r["retriever"])
        for metric in ("mean3", "mean5", "mrr"):
            by_retriever[r["retriever"]][metric].append(r[metric])
            by_cat_retriever[cat_key][metric].append(r[metric])

    # Derive the count from QUERIES so the header stays right if it changes.
    print(f"\nOverall (across all {len(QUERIES)} queries):")
    print(f"{'retriever':<14} {'mean@3':>8} {'mean@5':>8} {'MRR':>8}")
    for r_name, _ in retrievers:
        m = by_retriever[r_name]
        print(f"{r_name:<14} {avg(m['mean3']):>8.3f} "
              f"{avg(m['mean5']):>8.3f} {avg(m['mrr']):>8.3f}")

    print("\nBy category (K=keyword, C=conceptual, N=narrative):")
    print(f"{'cat':<3} {'retriever':<14} {'mean@3':>8} {'mean@5':>8} {'MRR':>8}")
    for cat in ["K", "C", "N"]:
        for r_name, _ in retrievers:
            m = by_cat_retriever[(cat, r_name)]
            print(f"{cat:<3} {r_name:<14} {avg(m['mean3']):>8.3f} "
                  f"{avg(m['mean5']):>8.3f} {avg(m['mrr']):>8.3f}")

    print("\nPer-query winner (highest mean@3, ties shown):")
    print(f"{'qid':<4} {'query':<45} {'winner':<24} {'scores'}")
    query_text = dict(QUERIES)
    by_query = defaultdict(list)
    for r in all_results:
        by_query[r["qid"]].append(r)
    for qid, results in sorted(by_query.items()):
        max_score = max(r["mean3"] for r in results)
        winners = [r["retriever"] for r in results if r["mean3"] == max_score]
        scores = " | ".join(f"{r['retriever'][:7]}={r['mean3']:.2f}"
                            for r in results)
        q_str = query_text[qid][:42]
        print(f"{qid:<4} {q_str:<45} {','.join(w[:8] for w in winners):<24} "
              f"{scores}")

    # Save raw results to JSON for further analysis
    out_path = "/tmp/voyage_rerank_judge_results.json"
    # ensure_ascii=False writes the Hebrew queries verbatim, so the file must
    # be opened as UTF-8 regardless of the process locale.
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, ensure_ascii=False, indent=2)
    print(f"\nRaw results saved to {out_path}")
|
||||
|
||||
|
||||
# Script entry point: run the async benchmark to completion.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
Reference in New Issue
Block a user