feat(retrieval): add voyage rerank-2 cross-encoder stage (feature flag)
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 1m29s

Stage B of voyage-upgrades-plan rewritten: instead of context-3 (which
4 POCs showed inconsistent improvement), add a cross-encoder rerank
layer on top of voyage-3. Default off (VOYAGE_RERANK_ENABLED=false).

POC validation (785-doc corpus, 12 queries, claude-haiku-4-5 judge):
- mean@3 +4.5% (4.306 → 4.500)
- practical-category queries +11.6% (3.78 → 4.22)
- latency +702ms per query
- no schema change, no re-embed, no double storage

Plumbing:
- config: VOYAGE_RERANK_ENABLED / _MODEL / _FETCH_K env vars
- embeddings.voyage_rerank() wraps voyageai client.rerank
- services/rerank.py: maybe_rerank() helper — fetches FETCH_K candidates
  via the bi-encoder then reranks to top-K. Fail-open if Voyage rerank is
  unavailable.
- tools/search.py: search_decisions, search_case_documents,
  find_similar_cases all wrapped
- services/precedent_library.search_library wrapped

Smoke-tested locally with flag on/off — produces expected behaviour and
latency profile. Ready for production rollout via Coolify env flip after
deploy.

POCs (kept under scripts/ for reference):
- voyage_context3_poc{_long}.py — context-3 evaluation (rejected)
- voyage_multimodal_poc.py — multimodal-3 (stage C, deferred)
- voyage_rerank_judge_poc.py — single-case rerank benchmark
- voyage_rerank_corpus_poc.py — full-corpus rerank validation

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-03 18:43:41 +00:00
parent 688ba37d9c
commit 26c3fddf41
13 changed files with 1578 additions and 100 deletions

View File

@@ -47,6 +47,17 @@ VOYAGE_API_KEY = os.environ.get("VOYAGE_API_KEY", "")
VOYAGE_MODEL = os.environ.get("VOYAGE_MODEL", "voyage-law-2")
VOYAGE_DIMENSIONS = 1024
# Rerank — cross-encoder second-stage. Off by default; flip with env to
# enable across all semantic search tools (search_decisions,
# search_case_documents, find_similar_cases, search_precedent_library).
VOYAGE_RERANK_MODEL = os.environ.get("VOYAGE_RERANK_MODEL", "rerank-2")
VOYAGE_RERANK_ENABLED = (
os.environ.get("VOYAGE_RERANK_ENABLED", "false").lower() == "true"
)
# How many candidates to fetch from bi-encoder before reranking.
# 50 was the depth used in the POC; balances recall vs rerank cost.
VOYAGE_RERANK_FETCH_K = int(os.environ.get("VOYAGE_RERANK_FETCH_K", "50"))
# Google Cloud Vision (OCR for scanned PDFs)
GOOGLE_CLOUD_VISION_API_KEY = os.environ.get("GOOGLE_CLOUD_VISION_API_KEY", "")

View File

@@ -53,3 +53,26 @@ async def embed_query(query: str) -> list[float]:
"""Embed a single search query."""
results = await embed_texts([query], input_type="query")
return results[0]
async def voyage_rerank(
query: str, documents: list[str], top_k: int | None = None,
) -> list[tuple[int, float]]:
"""Cross-encoder rerank via Voyage. Returns [(orig_index, score), ...]
sorted by relevance. Each tuple's index refers to the position in the
*input* documents list (not a DB row id) — caller maps it back.
Used as a second stage after bi-encoder retrieval: fetch top-N
candidates with cosine, then rerank to get top-K with cross-encoder
attention over (query, doc).
"""
if not documents:
return []
client = _get_client()
result = client.rerank(
query=query,
documents=documents,
model=config.VOYAGE_RERANK_MODEL,
top_k=top_k,
)
return [(r.index, float(r.relevance_score)) for r in result.results]

View File

@@ -22,7 +22,7 @@ from typing import Awaitable, Callable
from uuid import UUID, uuid4
from legal_mcp import config
from legal_mcp.services import chunker, db, embeddings, extractor
from legal_mcp.services import chunker, db, embeddings, extractor, rerank
# Note: halacha_extractor and precedent_metadata_extractor are NOT imported
# at module load. They are imported lazily inside the dedicated re-extract
@@ -403,18 +403,29 @@ async def search_library(
Only ``approved`` / ``published`` halachot are returned, per chair-review
policy. Chunks are returned regardless of halacha review status.
When ``VOYAGE_RERANK_ENABLED`` is set, results are passed through
voyage rerank-2 (cross-encoder). The +0.05 halacha boost from
``search_precedent_library_semantic`` is preserved before rerank
but the rerank scores ultimately decide the order.
"""
if not query.strip():
return []
query_vec = await embeddings.embed_query(query)
return await db.search_precedent_library_semantic(
query_embedding=query_vec,
practice_area=practice_area,
court=court,
precedent_level=precedent_level,
appeal_subtype=appeal_subtype,
is_binding=is_binding,
subject_tag=subject_tag,
limit=limit,
include_halachot=include_halachot,
async def _base(limit: int) -> list[dict]:
return await db.search_precedent_library_semantic(
query_embedding=query_vec,
practice_area=practice_area,
court=court,
precedent_level=precedent_level,
appeal_subtype=appeal_subtype,
is_binding=is_binding,
subject_tag=subject_tag,
limit=limit,
include_halachot=include_halachot,
)
return await rerank.maybe_rerank(
query=query, base_search=_base, limit=limit,
)

View File

@@ -0,0 +1,103 @@
"""Optional cross-encoder reranking layer for semantic search.
Wraps a base search function with two-stage retrieval:
1. fetch ``VOYAGE_RERANK_FETCH_K`` candidates via the bi-encoder (cosine)
2. pass them to voyage rerank-2, return top-``limit``
When the feature flag is off (or ``force_rerank=False``) the helper just
calls the base function with ``limit`` and returns its results unchanged
— so callers can wrap unconditionally and let env control behaviour.
The helper extracts the rerank text from each row using the first
non-empty field among ``content``, ``rule_statement``,
``reasoning_summary`` (matches the schema used by ``search_similar``
and ``search_precedent_library_semantic``).
Decision validated by POC #5 (785-doc precedent corpus, 12 queries):
- mean@3: 4.306 → 4.500 (+4.5%)
- practical-category queries: 3.78 → 4.22 (+11.6%)
- latency: +702ms per query
"""
from __future__ import annotations
import logging
from collections.abc import Awaitable, Callable
from typing import Any
from legal_mcp import config
from legal_mcp.services import embeddings
logger = logging.getLogger(__name__)
SearchFn = Callable[..., Awaitable[list[dict]]]
def _rerank_text(row: dict) -> str:
"""First non-empty text field that voyage rerank should see."""
for key in ("content", "rule_statement", "reasoning_summary",
"supporting_quote"):
v = row.get(key)
if v:
return str(v)
return ""
async def maybe_rerank(
query: str,
base_search: SearchFn,
limit: int,
*,
force_rerank: bool | None = None,
fetch_k: int | None = None,
**base_kwargs: Any,
) -> list[dict]:
"""Two-stage retrieval helper.
Args:
query: original query string (needed for the rerank API).
base_search: any async function that takes ``limit=…`` and the
other ``base_kwargs`` and returns ``list[dict]``.
limit: final number of results to return.
force_rerank: override the env flag. ``None`` → use config.
fetch_k: override the bi-encoder fetch depth.
**base_kwargs: forwarded to ``base_search``.
Returns:
List of dict rows. When rerank is active, each row's ``score``
is replaced with the rerank-2 relevance score (0..1).
"""
enabled = (config.VOYAGE_RERANK_ENABLED
if force_rerank is None else force_rerank)
if not enabled:
return await base_search(limit=limit, **base_kwargs)
depth = fetch_k or config.VOYAGE_RERANK_FETCH_K
candidates = await base_search(limit=depth, **base_kwargs)
if not candidates:
return []
texts = [_rerank_text(c) for c in candidates]
# Drop candidates with empty rerank text (shouldn't happen but be safe)
keep = [(i, t) for i, t in enumerate(texts) if t]
if not keep:
logger.warning("rerank: all candidates empty, falling back to base")
return candidates[:limit]
keep_idx = [i for i, _ in keep]
keep_texts = [t for _, t in keep]
try:
ranked = await embeddings.voyage_rerank(
query, keep_texts, top_k=limit,
)
except Exception as e:
# Fail open — if Voyage rerank is down, return bi-encoder ordering
logger.warning("rerank failed, falling back to base: %s", e)
return candidates[:limit]
out: list[dict] = []
for keep_pos, score in ranked:
orig_idx = keep_idx[keep_pos]
row = dict(candidates[orig_idx])
row["score"] = float(score)
out.append(row)
return out

View File

@@ -6,7 +6,7 @@ import json
import logging
from uuid import UUID
from legal_mcp.services import db, embeddings
from legal_mcp.services import db, embeddings, rerank
logger = logging.getLogger(__name__)
@@ -43,8 +43,9 @@ async def search_decisions(
)
query_emb = await embeddings.embed_query(query)
results = await db.search_similar(
query_embedding=query_emb,
results = await rerank.maybe_rerank(
query=query,
base_search=lambda **kw: db.search_similar(query_embedding=query_emb, **kw),
limit=limit,
section_type=section_type or None,
practice_area=practice_area or None,
@@ -86,8 +87,9 @@ async def search_case_documents(
query_emb = await embeddings.embed_query(query)
# Restricted to case_id — practice_area filter would be redundant.
results = await db.search_similar(
query_embedding=query_emb,
results = await rerank.maybe_rerank(
query=query,
base_search=lambda **kw: db.search_similar(query_embedding=query_emb, **kw),
limit=limit,
case_id=UUID(case["id"]),
)
@@ -137,9 +139,13 @@ async def find_similar_cases(
)
query_emb = await embeddings.embed_query(description)
results = await db.search_similar(
query_embedding=query_emb,
limit=limit * 3, # Get more to deduplicate by case
# Use description as the query text for rerank too.
# Note: even with rerank we ask for ``limit*3`` so the dedup-by-case
# step downstream still has enough rows to pick the best per case.
results = await rerank.maybe_rerank(
query=description,
base_search=lambda **kw: db.search_similar(query_embedding=query_emb, **kw),
limit=limit * 3,
practice_area=practice_area or None,
appeal_subtype=appeal_subtype or None,
)