feat(retrieval): add voyage rerank-2 cross-encoder stage (feature flag)
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 1m29s
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 1m29s
Stage B of voyage-upgrades-plan rewritten: instead of context-3 (for which
4 POCs showed inconsistent improvement), add a cross-encoder rerank
layer on top of voyage-3. Default off (VOYAGE_RERANK_ENABLED=false).
POC validation (785-doc corpus, 12 queries, claude-haiku-4-5 judge):
- mean@3 +4.5% (4.306 → 4.500)
- practical-category queries +11.6% (3.78 → 4.22)
- latency +702ms per query
- no schema change, no re-embed, no double storage
Plumbing:
- config: VOYAGE_RERANK_ENABLED / _MODEL / _FETCH_K env vars
- embeddings.voyage_rerank() wraps voyageai client.rerank
- services/rerank.py: maybe_rerank() helper — fetches FETCH_K candidates
via the bi-encoder then reranks to top-K. Fail-open if Voyage rerank is
unavailable.
- tools/search.py: search_decisions, search_case_documents,
find_similar_cases all wrapped
- services/precedent_library.search_library wrapped
Smoke-tested locally with flag on/off — produces expected behaviour and
latency profile. Ready for production rollout via Coolify env flip after
deploy.
POCs (kept under scripts/ for reference):
- voyage_context3_poc{_long}.py — context-3 evaluation (rejected)
- voyage_multimodal_poc.py — multimodal-3 (stage C, deferred)
- voyage_rerank_judge_poc.py — single-case rerank benchmark
- voyage_rerank_corpus_poc.py — full-corpus rerank validation
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -47,6 +47,17 @@ VOYAGE_API_KEY = os.environ.get("VOYAGE_API_KEY", "")
|
||||
VOYAGE_MODEL = os.environ.get("VOYAGE_MODEL", "voyage-law-2")
VOYAGE_DIMENSIONS = 1024

# Rerank — cross-encoder second-stage. Off by default; flip with env to
# enable across all semantic search tools (search_decisions,
# search_case_documents, find_similar_cases, search_precedent_library).
VOYAGE_RERANK_MODEL = os.environ.get("VOYAGE_RERANK_MODEL", "rerank-2")
# Only the literal value "true" (any case) enables the flag. Surrounding
# whitespace is stripped so a value such as "true " pasted into a deployment
# UI does not silently leave the feature off.
VOYAGE_RERANK_ENABLED = (
    os.environ.get("VOYAGE_RERANK_ENABLED", "false").strip().lower() == "true"
)
# How many candidates to fetch from bi-encoder before reranking.
# 50 was the depth used in the POC; balances recall vs rerank cost.
# An empty, malformed or non-positive value falls back to 50 rather than
# raising at import time: this module is imported by the whole server, and a
# bad value for an optional feature should not prevent it from starting.
try:
    VOYAGE_RERANK_FETCH_K = int(os.environ.get("VOYAGE_RERANK_FETCH_K", "50"))
except ValueError:
    VOYAGE_RERANK_FETCH_K = 50
if VOYAGE_RERANK_FETCH_K < 1:
    VOYAGE_RERANK_FETCH_K = 50

# Google Cloud Vision (OCR for scanned PDFs)
GOOGLE_CLOUD_VISION_API_KEY = os.environ.get("GOOGLE_CLOUD_VISION_API_KEY", "")
|
||||
|
||||
|
||||
@@ -53,3 +53,26 @@ async def embed_query(query: str) -> list[float]:
|
||||
"""Embed a single search query."""
|
||||
results = await embed_texts([query], input_type="query")
|
||||
return results[0]
|
||||
|
||||
|
||||
async def voyage_rerank(
|
||||
query: str, documents: list[str], top_k: int | None = None,
|
||||
) -> list[tuple[int, float]]:
|
||||
"""Cross-encoder rerank via Voyage. Returns [(orig_index, score), ...]
|
||||
sorted by relevance. Each tuple's index refers to the position in the
|
||||
*input* documents list (not a DB row id) — caller maps it back.
|
||||
|
||||
Used as a second stage after bi-encoder retrieval: fetch top-N
|
||||
candidates with cosine, then rerank to get top-K with cross-encoder
|
||||
attention over (query, doc).
|
||||
"""
|
||||
if not documents:
|
||||
return []
|
||||
client = _get_client()
|
||||
result = client.rerank(
|
||||
query=query,
|
||||
documents=documents,
|
||||
model=config.VOYAGE_RERANK_MODEL,
|
||||
top_k=top_k,
|
||||
)
|
||||
return [(r.index, float(r.relevance_score)) for r in result.results]
|
||||
|
||||
@@ -22,7 +22,7 @@ from typing import Awaitable, Callable
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from legal_mcp import config
|
||||
from legal_mcp.services import chunker, db, embeddings, extractor
|
||||
from legal_mcp.services import chunker, db, embeddings, extractor, rerank
|
||||
|
||||
# Note: halacha_extractor and precedent_metadata_extractor are NOT imported
|
||||
# at module load. They are imported lazily inside the dedicated re-extract
|
||||
@@ -403,18 +403,29 @@ async def search_library(
|
||||
|
||||
Only ``approved`` / ``published`` halachot are returned, per chair-review
|
||||
policy. Chunks are returned regardless of halacha review status.
|
||||
|
||||
When ``VOYAGE_RERANK_ENABLED`` is set, results are passed through
|
||||
voyage rerank-2 (cross-encoder). The +0.05 halacha boost from
|
||||
``search_precedent_library_semantic`` is preserved before rerank
|
||||
but the rerank scores ultimately decide the order.
|
||||
"""
|
||||
if not query.strip():
|
||||
return []
|
||||
query_vec = await embeddings.embed_query(query)
|
||||
return await db.search_precedent_library_semantic(
|
||||
query_embedding=query_vec,
|
||||
practice_area=practice_area,
|
||||
court=court,
|
||||
precedent_level=precedent_level,
|
||||
appeal_subtype=appeal_subtype,
|
||||
is_binding=is_binding,
|
||||
subject_tag=subject_tag,
|
||||
limit=limit,
|
||||
include_halachot=include_halachot,
|
||||
|
||||
async def _base(limit: int) -> list[dict]:
|
||||
return await db.search_precedent_library_semantic(
|
||||
query_embedding=query_vec,
|
||||
practice_area=practice_area,
|
||||
court=court,
|
||||
precedent_level=precedent_level,
|
||||
appeal_subtype=appeal_subtype,
|
||||
is_binding=is_binding,
|
||||
subject_tag=subject_tag,
|
||||
limit=limit,
|
||||
include_halachot=include_halachot,
|
||||
)
|
||||
|
||||
return await rerank.maybe_rerank(
|
||||
query=query, base_search=_base, limit=limit,
|
||||
)
|
||||
|
||||
103
mcp-server/src/legal_mcp/services/rerank.py
Normal file
103
mcp-server/src/legal_mcp/services/rerank.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Optional cross-encoder reranking layer for semantic search.
|
||||
|
||||
Wraps a base search function with two-stage retrieval:
|
||||
1. fetch ``VOYAGE_RERANK_FETCH_K`` candidates via the bi-encoder (cosine)
|
||||
2. pass them to voyage rerank-2, return top-``limit``
|
||||
|
||||
When the feature flag is off (or ``force_rerank=False``) the helper just
|
||||
calls the base function with ``limit`` and returns its results unchanged
|
||||
— so callers can wrap unconditionally and let env control behaviour.
|
||||
|
||||
The helper extracts the rerank text from each row using the first
|
||||
non-empty field among ``content``, ``rule_statement``,
|
||||
``reasoning_summary``, ``supporting_quote`` (matches the schema used by ``search_similar``
|
||||
and ``search_precedent_library_semantic``).
|
||||
|
||||
Decision validated by POC #5 (785-doc precedent corpus, 12 queries):
|
||||
- mean@3: 4.306 → 4.500 (+4.5%)
|
||||
- practical-category queries: 3.78 → 4.22 (+11.6%)
|
||||
- latency: +702ms per query
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from legal_mcp import config
|
||||
from legal_mcp.services import embeddings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SearchFn = Callable[..., Awaitable[list[dict]]]
|
||||
|
||||
|
||||
def _rerank_text(row: dict) -> str:
|
||||
"""First non-empty text field that voyage rerank should see."""
|
||||
for key in ("content", "rule_statement", "reasoning_summary",
|
||||
"supporting_quote"):
|
||||
v = row.get(key)
|
||||
if v:
|
||||
return str(v)
|
||||
return ""
|
||||
|
||||
|
||||
async def maybe_rerank(
    query: str,
    base_search: SearchFn,
    limit: int,
    *,
    force_rerank: bool | None = None,
    fetch_k: int | None = None,
    **base_kwargs: Any,
) -> list[dict]:
    """Two-stage retrieval helper.

    Args:
        query: original query string (needed for the rerank API).
        base_search: any async function that takes ``limit=…`` and the
            other ``base_kwargs`` and returns ``list[dict]``.
        limit: final number of results to return.
        force_rerank: override the env flag. ``None`` → use config.
        fetch_k: override the bi-encoder fetch depth. The effective depth
            is never smaller than ``limit``.
        **base_kwargs: forwarded to ``base_search``.

    Returns:
        List of dict rows. When rerank is disabled the base results are
        returned unchanged. When rerank is active, each returned row is a
        copy of the candidate with ``score`` replaced by the reranker's
        relevance score. If the rerank call raises, or no candidate has any
        rerank text, the first ``limit`` base results are returned instead.
    """
    enabled = (config.VOYAGE_RERANK_ENABLED
               if force_rerank is None else force_rerank)
    if not enabled:
        return await base_search(limit=limit, **base_kwargs)

    # Never fetch fewer candidates than the caller wants back. Callers may
    # ask for more rows than the configured depth (e.g. ``limit * 3`` for a
    # later dedup step); without this floor, enabling rerank would return
    # fewer rows than the same call with rerank disabled.
    depth = max(fetch_k or config.VOYAGE_RERANK_FETCH_K, limit)
    candidates = await base_search(limit=depth, **base_kwargs)
    if not candidates:
        return []

    # Pair each usable text with its position in ``candidates`` so the
    # reranker's indices (positions in ``keep_texts``) can be mapped back.
    # Candidates with empty rerank text are dropped (shouldn't happen but
    # be safe).
    keep = [
        (idx, text)
        for idx, text in enumerate(_rerank_text(c) for c in candidates)
        if text
    ]
    if not keep:
        logger.warning("rerank: all candidates empty, falling back to base")
        return candidates[:limit]
    keep_idx = [idx for idx, _ in keep]
    keep_texts = [text for _, text in keep]

    try:
        ranked = await embeddings.voyage_rerank(
            query, keep_texts, top_k=limit,
        )
    except Exception as e:
        # Fail open — if Voyage rerank is down, return bi-encoder ordering
        logger.warning("rerank failed, falling back to base: %s", e)
        return candidates[:limit]

    out: list[dict] = []
    for keep_pos, score in ranked:
        # Copy the row so the caller's candidate dicts are not mutated.
        row = dict(candidates[keep_idx[keep_pos]])
        row["score"] = float(score)
        out.append(row)
    return out
|
||||
@@ -6,7 +6,7 @@ import json
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from legal_mcp.services import db, embeddings
|
||||
from legal_mcp.services import db, embeddings, rerank
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -43,8 +43,9 @@ async def search_decisions(
|
||||
)
|
||||
|
||||
query_emb = await embeddings.embed_query(query)
|
||||
results = await db.search_similar(
|
||||
query_embedding=query_emb,
|
||||
results = await rerank.maybe_rerank(
|
||||
query=query,
|
||||
base_search=lambda **kw: db.search_similar(query_embedding=query_emb, **kw),
|
||||
limit=limit,
|
||||
section_type=section_type or None,
|
||||
practice_area=practice_area or None,
|
||||
@@ -86,8 +87,9 @@ async def search_case_documents(
|
||||
|
||||
query_emb = await embeddings.embed_query(query)
|
||||
# Restricted to case_id — practice_area filter would be redundant.
|
||||
results = await db.search_similar(
|
||||
query_embedding=query_emb,
|
||||
results = await rerank.maybe_rerank(
|
||||
query=query,
|
||||
base_search=lambda **kw: db.search_similar(query_embedding=query_emb, **kw),
|
||||
limit=limit,
|
||||
case_id=UUID(case["id"]),
|
||||
)
|
||||
@@ -137,9 +139,13 @@ async def find_similar_cases(
|
||||
)
|
||||
|
||||
query_emb = await embeddings.embed_query(description)
|
||||
results = await db.search_similar(
|
||||
query_embedding=query_emb,
|
||||
limit=limit * 3, # Get more to deduplicate by case
|
||||
# Use description as the query text for rerank too.
|
||||
# Note: even with rerank we ask for ``limit*3`` so the dedup-by-case
|
||||
# step downstream still has enough rows to pick the best per case.
|
||||
results = await rerank.maybe_rerank(
|
||||
query=description,
|
||||
base_search=lambda **kw: db.search_similar(query_embedding=query_emb, **kw),
|
||||
limit=limit * 3,
|
||||
practice_area=practice_area or None,
|
||||
appeal_subtype=appeal_subtype or None,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user