feat(retrieval): add voyage rerank-2 cross-encoder stage (feature flag)
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 1m29s
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 1m29s
Stage B of voyage-upgrades-plan rewritten: instead of context-3 (which
4 POCs showed inconsistent improvement), add a cross-encoder rerank
layer on top of voyage-3. Default off (VOYAGE_RERANK_ENABLED=false).
POC validation (785-doc corpus, 12 queries, claude-haiku-4-5 judge):
- mean@3 +4.5% (4.306 → 4.500)
- practical-category queries +11.6% (3.78 → 4.22)
- latency +702ms per query
- no schema change, no re-embed, no double storage
Plumbing:
- config: VOYAGE_RERANK_ENABLED / _MODEL / _FETCH_K env vars
- embeddings.voyage_rerank() wraps voyageai client.rerank
- services/rerank.py: maybe_rerank() helper — fetches FETCH_K candidates
via the bi-encoder then reranks to top-K. Fail-open if Voyage rerank is
unavailable.
- tools/search.py: search_decisions, search_case_documents,
find_similar_cases all wrapped
- services/precedent_library.search_library wrapped
Smoke-tested locally with flag on/off — produces expected behaviour and
latency profile. Ready for production rollout via Coolify env flip after
deploy.
POCs (kept under scripts/ for reference):
- voyage_context3_poc{_long}.py — context-3 evaluation (rejected)
- voyage_multimodal_poc.py — multimodal-3 (stage C, deferred)
- voyage_rerank_judge_poc.py — single-case rerank benchmark
- voyage_rerank_corpus_poc.py — full-corpus rerank validation
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
182
scripts/voyage_context3_poc.py
Normal file
182
scripts/voyage_context3_poc.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""POC: Compare voyage-3 vs voyage-context-3 retrieval on case 403/17.
|
||||
|
||||
Pulls all chunks of "אהרון ברק - תכנית רחביה" (case_law_id=e151fc25-...),
|
||||
runs them through voyage-context-3 in a single contextualized_embed call,
|
||||
then runs benchmark queries and compares rankings against the existing
|
||||
voyage-3 embeddings (already in the DB).
|
||||
|
||||
No DB writes — all comparisons in memory. Output: ranking table for each
|
||||
query showing top-10 from both models side-by-side.
|
||||
|
||||
Usage:
|
||||
/home/chaim/legal-ai/mcp-server/.venv/bin/python \\
|
||||
/home/chaim/legal-ai/scripts/voyage_context3_poc.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
# Load ~/.env
|
||||
ENV_PATH = os.path.expanduser("~/.env")
|
||||
if os.path.isfile(ENV_PATH):
|
||||
with open(ENV_PATH) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#") and "=" in line:
|
||||
k, v = line.split("=", 1)
|
||||
os.environ.setdefault(k, v)
|
||||
|
||||
import asyncpg # noqa: E402
|
||||
import voyageai # noqa: E402
|
||||
|
||||
|
||||
# Using קלמנוביץ/לויתן (52K chars, 63 chunks, ~18K tokens)
|
||||
# — fits in single context-3 call (32K token limit per inner list).
|
||||
# אהרון ברק (60K tokens) requires splitting; we'll handle that after POC.
|
||||
CASE_ID = "436efd48-c8ab-49f0-b3a9-52bf15ea806d" # בר"מ 25226-04-25
|
||||
CONTEXT_MODEL = "voyage-context-3"
|
||||
BASELINE_MODEL = "voyage-3" # already in DB
|
||||
|
||||
QUERIES = [
|
||||
"סמכות ועדת ערר",
|
||||
"פיצויים לפי סעיף 197",
|
||||
"ירידת ערך מקרקעין",
|
||||
"תכנית פוגעת",
|
||||
"שיקול דעת ועדה מקומית",
|
||||
"חוות דעת שמאי מכריע",
|
||||
"מקרקעין גובלים",
|
||||
"תקופת התיישנות תביעה",
|
||||
"אינטרס ציבורי בתכנון",
|
||||
"דחיית תביעת פיצויים",
|
||||
]
|
||||
|
||||
|
||||
def cosine(a: list[float], b: list[float]) -> float:
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
na = math.sqrt(sum(x * x for x in a))
|
||||
nb = math.sqrt(sum(y * y for y in b))
|
||||
return dot / (na * nb) if na and nb else 0.0
|
||||
|
||||
|
||||
def parse_pgvector(s: str) -> list[float]:
|
||||
"""pgvector text format: '[0.1,0.2,...]'."""
|
||||
return [float(x) for x in s.strip("[]").split(",")]
|
||||
|
||||
|
||||
async def main():
|
||||
api_key = os.environ["VOYAGE_API_KEY"]
|
||||
pg_pw = os.environ["POSTGRES_PASSWORD"]
|
||||
|
||||
voyage = voyageai.Client(api_key=api_key)
|
||||
|
||||
pool = await asyncpg.create_pool(
|
||||
host="127.0.0.1", port=5433, user="legal_ai",
|
||||
password=pg_pw, database="legal_ai",
|
||||
min_size=1, max_size=2,
|
||||
)
|
||||
|
||||
# 1. Pull all chunks + their existing voyage-3 embeddings
|
||||
rows = await pool.fetch("""
|
||||
SELECT chunk_index, content, embedding::text AS emb_text
|
||||
FROM precedent_chunks
|
||||
WHERE case_law_id = $1
|
||||
ORDER BY chunk_index
|
||||
""", CASE_ID)
|
||||
print(f"[load] {len(rows)} chunks from case 403/17")
|
||||
|
||||
chunks = [r["content"] for r in rows]
|
||||
indices = [r["chunk_index"] for r in rows]
|
||||
baseline_embs = [parse_pgvector(r["emb_text"]) for r in rows]
|
||||
|
||||
# 2. Embed all chunks with voyage-context-3 — single contextualized call
|
||||
total_chars = sum(len(c) for c in chunks)
|
||||
print(f"[context] embedding {len(chunks)} chunks, {total_chars:,} chars total")
|
||||
start = time.time()
|
||||
result = voyage.contextualized_embed(
|
||||
inputs=[chunks], # one document = one inner list
|
||||
model=CONTEXT_MODEL,
|
||||
input_type="document",
|
||||
)
|
||||
elapsed = time.time() - start
|
||||
# ContextualizedEmbeddingsObject: result.results = list of per-document
|
||||
# embeddings. result.results[0].embeddings = list of chunk embeddings.
|
||||
context_embs = result.results[0].embeddings
|
||||
total_tokens = getattr(result, "total_tokens", "?")
|
||||
print(f"[context] done in {elapsed:.1f}s — total_tokens={total_tokens}")
|
||||
assert len(context_embs) == len(chunks), "embedding count mismatch"
|
||||
|
||||
# 3. For each query — embed twice and compare top-10
|
||||
print("\n" + "=" * 100)
|
||||
print(f"{'Q':<3} {'baseline (voyage-3)':<48} {'context-3':<48}")
|
||||
print("=" * 100)
|
||||
|
||||
rank_overlaps = []
|
||||
score_lifts = []
|
||||
|
||||
for q_idx, query in enumerate(QUERIES, 1):
|
||||
# Baseline query embedding (regular embed)
|
||||
q_baseline = voyage.embed(
|
||||
[query], model=BASELINE_MODEL, input_type="query"
|
||||
).embeddings[0]
|
||||
# Context query embedding — must use contextualized_embed even for
|
||||
# single-string queries (regular embed() rejects voyage-context-3).
|
||||
q_context = voyage.contextualized_embed(
|
||||
inputs=[[query]],
|
||||
model=CONTEXT_MODEL,
|
||||
input_type="query",
|
||||
).results[0].embeddings[0]
|
||||
|
||||
# Score every chunk under both models
|
||||
scores_b = sorted(
|
||||
[(cosine(q_baseline, e), i) for i, e in enumerate(baseline_embs)],
|
||||
reverse=True,
|
||||
)
|
||||
scores_c = sorted(
|
||||
[(cosine(q_context, e), i) for i, e in enumerate(context_embs)],
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
top10_b = [i for _, i in scores_b[:10]]
|
||||
top10_c = [i for _, i in scores_c[:10]]
|
||||
|
||||
# Compute overlap and avg score in top-3
|
||||
overlap = len(set(top10_b) & set(top10_c))
|
||||
avg_b_top3 = sum(s for s, _ in scores_b[:3]) / 3
|
||||
avg_c_top3 = sum(s for s, _ in scores_c[:3]) / 3
|
||||
rank_overlaps.append(overlap)
|
||||
score_lifts.append(avg_c_top3 - avg_b_top3)
|
||||
|
||||
print(f"\n[Q{q_idx}] {query}")
|
||||
print(f" overlap top-10: {overlap}/10 | avg score top-3: "
|
||||
f"baseline={avg_b_top3:.3f} context-3={avg_c_top3:.3f} "
|
||||
f"Δ={avg_c_top3 - avg_b_top3:+.3f}")
|
||||
for rank in range(5):
|
||||
sb, ib = scores_b[rank]
|
||||
sc, ic = scores_c[rank]
|
||||
cb = chunks[ib].replace("\n", " ").strip()[:50]
|
||||
cc = chunks[ic].replace("\n", " ").strip()[:50]
|
||||
print(f" #{rank+1} [{indices[ib]:3d}] {sb:.3f} {cb:<55} "
|
||||
f"| [{indices[ic]:3d}] {sc:.3f} {cc}")
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 100)
|
||||
print("SUMMARY")
|
||||
print("=" * 100)
|
||||
avg_overlap = sum(rank_overlaps) / len(rank_overlaps)
|
||||
avg_lift = sum(score_lifts) / len(score_lifts)
|
||||
print(f"Avg overlap top-10: {avg_overlap:.1f}/10 "
|
||||
f"(higher = models agree more)")
|
||||
print(f"Avg score lift top-3 (context - baseline): {avg_lift:+.4f}")
|
||||
print(f"\nNote: cosine scores are not directly comparable across models.")
|
||||
print(f"What matters more is which CHUNKS bubble to the top —")
|
||||
print(f"reading the actual content above tells the real story.")
|
||||
|
||||
await pool.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user