legal-ai/mcp-server/src/legal_mcp/services/rerank.py
Chaim 26c3fddf41
feat(retrieval): add voyage rerank-2 cross-encoder stage (feature flag)
Stage B of voyage-upgrades-plan rewritten: instead of context-3 (for
which 4 POCs showed inconsistent improvement), add a cross-encoder rerank
layer on top of voyage-3. Default off (VOYAGE_RERANK_ENABLED=false).

POC validation (785-doc corpus, 12 queries, claude-haiku-4-5 judge):
- mean@3 +4.5% (4.306 → 4.500)
- practical-category queries +11.6% (3.78 → 4.22)
- latency +702ms per query
- no schema change, no re-embed, no double storage

Plumbing:
- config: VOYAGE_RERANK_ENABLED / _MODEL / _FETCH_K env vars
- embeddings.voyage_rerank() wraps voyageai client.rerank
- services/rerank.py: maybe_rerank() helper — fetches FETCH_K candidates
  via the bi-encoder then reranks to top-K. Fail-open if Voyage rerank is
  unavailable.
- tools/search.py: search_decisions, search_case_documents,
  find_similar_cases all wrapped
- services/precedent_library.search_library wrapped
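
The env-var plumbing above could be read roughly as follows. This is a minimal sketch, not the contents of the real config module: the `RerankSettings` class, the `load_rerank_settings` helper, the accepted truthy strings, and the fetch depth default of 50 are all assumptions for illustration. Only the variable names, the default-off flag, and the rerank-2 model name come from this commit.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional

# Assumption: which strings count as "on" is not stated in the commit.
_TRUTHY = {"1", "true", "yes", "on"}


@dataclass(frozen=True)
class RerankSettings:
    """Hypothetical container for the three rerank env vars."""
    enabled: bool
    model: str
    fetch_k: int


def load_rerank_settings(env: Optional[Mapping[str, str]] = None) -> RerankSettings:
    """Parse the rerank flags from ``env`` (defaults to ``os.environ``)."""
    source = os.environ if env is None else env
    raw_flag = source.get("VOYAGE_RERANK_ENABLED", "false")
    return RerankSettings(
        enabled=raw_flag.strip().lower() in _TRUTHY,
        model=source.get("VOYAGE_RERANK_MODEL", "rerank-2"),
        # Assumption: 50 is a placeholder default, not the project's value.
        fetch_k=int(source.get("VOYAGE_RERANK_FETCH_K", "50")),
    )
```

Passing a plain dict instead of `os.environ` keeps the parsing testable without touching the process environment.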

Smoke-tested locally with flag on/off — produces expected behaviour and
latency profile. Ready for production rollout via Coolify env flip after
deploy.

POCs (kept under scripts/ for reference):
- voyage_context3_poc{_long}.py — context-3 evaluation (rejected)
- voyage_multimodal_poc.py — multimodal-3 (stage C, deferred)
- voyage_rerank_judge_poc.py — single-case rerank benchmark
- voyage_rerank_corpus_poc.py — full-corpus rerank validation

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-03 18:43:41 +00:00


"""Optional cross-encoder reranking layer for semantic search.

Wraps a base search function with two-stage retrieval:

1. fetch ``VOYAGE_RERANK_FETCH_K`` candidates via the bi-encoder (cosine)
2. pass them to voyage rerank-2, return top-``limit``

When the feature flag is off (or ``force_rerank=False``) the helper just
calls the base function with ``limit`` and returns its results unchanged,
so callers can wrap unconditionally and let env control behaviour.

The helper extracts the rerank text from each row using the first
non-empty field among ``content``, ``rule_statement``,
``reasoning_summary``, ``supporting_quote`` (matches the schema used by
``search_similar`` and ``search_precedent_library_semantic``).

Decision validated by POC #5 (785-doc precedent corpus, 12 queries):

- mean@3: 4.306 → 4.500 (+4.5%)
- practical-category queries: 3.78 → 4.22 (+11.6%)
- latency: +702ms per query
"""
from __future__ import annotations

import logging
from collections.abc import Awaitable, Callable
from typing import Any

from legal_mcp import config
from legal_mcp.services import embeddings

logger = logging.getLogger(__name__)

SearchFn = Callable[..., Awaitable[list[dict]]]


def _rerank_text(row: dict) -> str:
    """First non-empty text field that voyage rerank should see."""
    for key in ("content", "rule_statement", "reasoning_summary",
                "supporting_quote"):
        v = row.get(key)
        if v:
            return str(v)
    return ""


async def maybe_rerank(
    query: str,
    base_search: SearchFn,
    limit: int,
    *,
    force_rerank: bool | None = None,
    fetch_k: int | None = None,
    **base_kwargs: Any,
) -> list[dict]:
    """Two-stage retrieval helper.

    Args:
        query: original query string (needed for the rerank API).
        base_search: any async function that takes ``limit=…`` and the
            other ``base_kwargs`` and returns ``list[dict]``.
        limit: final number of results to return.
        force_rerank: override the env flag. ``None`` → use config.
        fetch_k: override the bi-encoder fetch depth.
        **base_kwargs: forwarded to ``base_search``.

    Returns:
        List of dict rows. When rerank is active, each row's ``score``
        is replaced with the rerank-2 relevance score (0..1).
    """
    enabled = (config.VOYAGE_RERANK_ENABLED
               if force_rerank is None else force_rerank)
    if not enabled:
        return await base_search(limit=limit, **base_kwargs)

    depth = fetch_k or config.VOYAGE_RERANK_FETCH_K
    candidates = await base_search(limit=depth, **base_kwargs)
    if not candidates:
        return []

    texts = [_rerank_text(c) for c in candidates]
    # Drop candidates with empty rerank text (shouldn't happen but be safe)
    keep = [(i, t) for i, t in enumerate(texts) if t]
    if not keep:
        logger.warning("rerank: all candidates empty, falling back to base")
        return candidates[:limit]
    keep_idx = [i for i, _ in keep]
    keep_texts = [t for _, t in keep]

    try:
        ranked = await embeddings.voyage_rerank(
            query, keep_texts, top_k=limit,
        )
    except Exception as e:
        # Fail open: if Voyage rerank is down, return bi-encoder ordering
        logger.warning("rerank failed, falling back to base: %s", e)
        return candidates[:limit]

    out: list[dict] = []
    for keep_pos, score in ranked:
        # Map the position within keep_texts back to the candidate row.
        orig_idx = keep_idx[keep_pos]
        row = dict(candidates[orig_idx])
        row["score"] = float(score)
        out.append(row)
    return out
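
The fetch-then-rerank flow and its fail-open path can be tried in isolation with stand-ins. The sketch below is self-contained and does not import `legal_mcp`. `fake_base_search`, `fake_rerank`, `broken_rerank`, `two_stage`, the sample rows, and the word-overlap scoring are all hypothetical. The only detail taken from the module above is that the reranker yields `(position, score)` pairs, best first.

```python
import asyncio


async def fake_base_search(*, limit: int, **_: object) -> list[dict]:
    """Stand-in for a bi-encoder search; rows are already in cosine order."""
    rows = [
        {"id": "a", "content": "zoning appeal procedure", "score": 0.91},
        {"id": "b", "content": "", "rule_statement": "building permit deadline",
         "score": 0.88},
        {"id": "c", "content": "permit deadline extension granted", "score": 0.80},
    ]
    return rows[:limit]


async def fake_rerank(query: str, texts: list[str], top_k: int) -> list[tuple[int, float]]:
    """Stand-in reranker: scores each text by word overlap with the query."""
    q = set(query.lower().split())
    scored = [(i, len(q & set(t.lower().split())) / len(q))
              for i, t in enumerate(texts)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]


async def broken_rerank(query: str, texts: list[str], top_k: int) -> list[tuple[int, float]]:
    """Simulates the rerank service being unavailable."""
    raise RuntimeError("rerank service unavailable")


async def two_stage(query: str, limit: int, fetch_k: int, rerank=fake_rerank) -> list[dict]:
    """Condensed version of the pattern: fetch deep, rerank, fail open."""
    candidates = await fake_base_search(limit=fetch_k)
    texts = [c.get("content") or c.get("rule_statement") or "" for c in candidates]
    try:
        ranked = await rerank(query, texts, top_k=limit)
    except Exception:
        return candidates[:limit]  # fail open: keep bi-encoder ordering
    return [{**candidates[pos], "score": float(score)} for pos, score in ranked]


reranked = asyncio.run(two_stage("permit deadline extension", limit=2, fetch_k=3))
fallback = asyncio.run(
    two_stage("permit deadline extension", limit=2, fetch_k=3, rerank=broken_rerank)
)
print([r["id"] for r in reranked])  # ['c', 'b']
print([r["id"] for r in fallback])  # ['a', 'b']
```

Row `c` sits last in the bi-encoder ordering but moves to the top after reranking. When the reranker raises, the caller still gets the first `limit` bi-encoder rows with their original scores.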