feat(retrieval): add voyage rerank-2 cross-encoder stage (feature flag)
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 1m29s

Stage B of voyage-upgrades-plan rewritten: instead of context-3 (which
4 POCs showed inconsistent improvement), add a cross-encoder rerank
layer on top of voyage-3. Default off (VOYAGE_RERANK_ENABLED=false).

POC validation (785-doc corpus, 12 queries, claude-haiku-4-5 judge):
- mean@3 +4.5% (4.306 → 4.500)
- practical-category queries +11.6% (3.78 → 4.22)
- latency +702ms per query
- no schema change, no re-embed, no double storage

Plumbing:
- config: VOYAGE_RERANK_ENABLED / _MODEL / _FETCH_K env vars
- embeddings.voyage_rerank() wraps voyageai client.rerank
- services/rerank.py: maybe_rerank() helper — fetches FETCH_K candidates
  via the bi-encoder then reranks to top-K. Fail-open if Voyage rerank is
  unavailable.
- tools/search.py: search_decisions, search_case_documents,
  find_similar_cases all wrapped
- services/precedent_library.search_library wrapped

Smoke-tested locally with flag on/off — produces expected behaviour and
latency profile. Ready for production rollout via Coolify env flip after
deploy.

POCs (kept under scripts/ for reference):
- voyage_context3_poc{_long}.py — context-3 evaluation (rejected)
- voyage_multimodal_poc.py — multimodal-3 (stage C, deferred)
- voyage_rerank_judge_poc.py — single-case rerank benchmark
- voyage_rerank_corpus_poc.py — full-corpus rerank validation

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-03 18:43:41 +00:00
parent 688ba37d9c
commit 26c3fddf41
13 changed files with 1578 additions and 100 deletions

View File

@@ -0,0 +1,238 @@
"""POC #2: voyage-3 vs voyage-context-3 on a LONG case (אהרון ברק 403/17).
Case is 178K chars / 219 chunks / ~60K tokens — too big for a single
contextualized_embed call (32K token limit per inner list). We split the
chunks into overlapping sliding windows (~80 chunks each, ~22K tokens)
and merge: each chunk gets the embedding from the window where it sits
*most centrally* (max symmetric context on both sides).
The hypothesis: voyage-context-3 should shine here because the case is
full of internal references ("ראה לעיל סעיף 13", "להבדיל מעניין X",
"תוצאת הבחינה ב-בר"מ 1975/24 שנידונה לעיל"). voyage-3 embeds chunks
in isolation; context-3 sees ~80 surrounding chunks per embedding.
No DB writes. Output: side-by-side ranking comparison + summary.
Usage:
/home/chaim/legal-ai/mcp-server/.venv/bin/python \\
/home/chaim/legal-ai/scripts/voyage_context3_poc_long.py
"""
from __future__ import annotations
import asyncio
import math
import os
import sys
import time
ENV_PATH = os.path.expanduser("~/.env")
if os.path.isfile(ENV_PATH):
with open(ENV_PATH) as f:
for line in f:
line = line.strip()
if line and not line.startswith("#") and "=" in line:
k, v = line.split("=", 1)
os.environ.setdefault(k, v)
import asyncpg # noqa: E402
import voyageai # noqa: E402
CASE_ID = "e151fc25-cf12-4563-b638-a86323f8413b" # 403/17 אהרון ברק (178K chars)
CONTEXT_MODEL = "voyage-context-3"
BASELINE_MODEL = "voyage-3"
# Sliding-window split params. With 219 chunks and ~60K tokens total
# (~275 tokens/chunk average), 3 windows of 80 chunks each is ~22K tokens
# per call — comfortably under 32K.
WINDOW_SIZE = 80
WINDOW_STRIDE = 70 # overlap = WINDOW_SIZE - WINDOW_STRIDE = 10
# Mix of:
# (a) generic queries (also tested in POC #1)
# (b) queries that require *internal* document context
QUERIES = [
# generic
"תכנית רחביה הוראות בנייה",
"פיצויים לפי סעיף 197 ירידת ערך",
"השפעת תכנית על שווי מקרקעין",
"סמכות ועדת ערר לדון בפיצויים",
"תוספת זכויות בנייה כפיצוי",
# internal-context — should benefit context-3
"ההבחנה בין השבחה לפיצויים",
"מה נקבע לגבי תמ\"א 38 בפסק הדין",
"ההלכה שנקבעה בעניין רובע 3",
"כלל הנטרול של זכויות תכנוניות",
"הסכמת השופט אלרון לחוות הדעת",
]
def cosine(a: list[float], b: list[float]) -> float:
dot = sum(x * y for x, y in zip(a, b))
na = math.sqrt(sum(x * x for x in a))
nb = math.sqrt(sum(y * y for y in b))
return dot / (na * nb) if na and nb else 0.0
def parse_pgvector(s: str) -> list[float]:
return [float(x) for x in s.strip("[]").split(",")]
def build_windows(n: int, size: int, stride: int) -> list[tuple[int, int]]:
"""Return list of (start, end) ranges (end exclusive) covering 0..n.
Last window extends to n exactly. Overlap = size - stride.
"""
windows = []
start = 0
while start < n:
end = min(start + size, n)
windows.append((start, end))
if end == n:
break
start += stride
return windows
def assign_chunk_to_window(
chunk_idx: int, windows: list[tuple[int, int]],
) -> int:
"""Pick the window where chunk_idx sits most centrally (max symmetric
distance to either edge). Ties broken by larger window."""
best = -1
best_score = -1
for w_idx, (s, e) in enumerate(windows):
if not (s <= chunk_idx < e):
continue
# symmetric distance: min(distance to s, distance to e-1)
dist = min(chunk_idx - s, (e - 1) - chunk_idx)
if dist > best_score:
best_score = dist
best = w_idx
return best
async def main():
api_key = os.environ["VOYAGE_API_KEY"]
pg_pw = os.environ["POSTGRES_PASSWORD"]
voyage = voyageai.Client(api_key=api_key)
pool = await asyncpg.create_pool(
host="127.0.0.1", port=5433, user="legal_ai",
password=pg_pw, database="legal_ai",
min_size=1, max_size=2,
)
rows = await pool.fetch("""
SELECT chunk_index, content, embedding::text AS emb_text
FROM precedent_chunks
WHERE case_law_id = $1
ORDER BY chunk_index
""", CASE_ID)
n = len(rows)
print(f"[load] {n} chunks from אהרון ברק 403/17")
chunks = [r["content"] for r in rows]
indices = [r["chunk_index"] for r in rows]
baseline_embs = [parse_pgvector(r["emb_text"]) for r in rows]
# Build windows
windows = build_windows(n, WINDOW_SIZE, WINDOW_STRIDE)
print(f"[windows] {len(windows)} windows: "
f"{', '.join(f'[{s}:{e})' for s, e in windows)}")
# Embed each window with context-3
window_embs: list[list[list[float]]] = [] # [window][chunk_in_window][dim]
total_call_tokens = 0
total_start = time.time()
for w_idx, (s, e) in enumerate(windows):
sub_chunks = chunks[s:e]
sub_chars = sum(len(c) for c in sub_chunks)
start = time.time()
result = voyage.contextualized_embed(
inputs=[sub_chunks],
model=CONTEXT_MODEL,
input_type="document",
)
elapsed = time.time() - start
toks = getattr(result, "total_tokens", 0)
total_call_tokens += toks
print(f" [window {w_idx}] [{s}:{e}) — {len(sub_chunks)} chunks, "
f"{sub_chars:,} chars, {toks} tokens — {elapsed:.1f}s")
window_embs.append(result.results[0].embeddings)
total_elapsed = time.time() - total_start
print(f"[context] all windows done in {total_elapsed:.1f}s, "
f"{total_call_tokens} total tokens")
# Merge: for each chunk, pick the embedding from its most-central window
context_embs: list[list[float]] = []
chunk_window_choice = []
for i in range(n):
w_idx = assign_chunk_to_window(i, windows)
chunk_window_choice.append(w_idx)
s, _ = windows[w_idx]
context_embs.append(window_embs[w_idx][i - s])
print(f"[merge] window distribution: "
f"{[chunk_window_choice.count(j) for j in range(len(windows))]}")
# Run queries
print("\n" + "=" * 100)
print(f"{'Q':<3} {'baseline (voyage-3)':<48} {'context-3 (windowed)':<48}")
print("=" * 100)
rank_overlaps = []
for q_idx, query in enumerate(QUERIES, 1):
q_baseline = voyage.embed(
[query], model=BASELINE_MODEL, input_type="query"
).embeddings[0]
q_context = voyage.contextualized_embed(
inputs=[[query]],
model=CONTEXT_MODEL,
input_type="query",
).results[0].embeddings[0]
scores_b = sorted(
[(cosine(q_baseline, e), i) for i, e in enumerate(baseline_embs)],
reverse=True,
)
scores_c = sorted(
[(cosine(q_context, e), i) for i, e in enumerate(context_embs)],
reverse=True,
)
top10_b = [i for _, i in scores_b[:10]]
top10_c = [i for _, i in scores_c[:10]]
overlap = len(set(top10_b) & set(top10_c))
rank_overlaps.append(overlap)
print(f"\n[Q{q_idx}] {query}")
print(f" overlap top-10: {overlap}/10 | "
f"avg score top-3: baseline="
f"{sum(s for s, _ in scores_b[:3])/3:.3f} "
f"context-3={sum(s for s, _ in scores_c[:3])/3:.3f}")
for rank in range(5):
sb, ib = scores_b[rank]
sc, ic = scores_c[rank]
cb = chunks[ib].replace("\n", " ").strip()[:50]
cc = chunks[ic].replace("\n", " ").strip()[:50]
print(f" #{rank+1} [{indices[ib]:3d}] {sb:.3f} {cb:<55} "
f"| [{indices[ic]:3d}] {sc:.3f} {cc}")
print("\n" + "=" * 100)
print("SUMMARY")
print("=" * 100)
avg = sum(rank_overlaps) / len(rank_overlaps)
print(f"Avg overlap top-10: {avg:.1f}/10")
print(f"Per-query overlap: {rank_overlaps}")
print(f"Total context-3 tokens used: {total_call_tokens:,} "
f"(in {len(windows)} calls)")
print(f"\nNote: cosine across models not directly comparable. The")
print(f"meaningful test is *which chunks bubble to the top* — read")
print(f"the actual text above to judge relevance.")
await pool.close()
if __name__ == "__main__":
asyncio.run(main())