"""Optional cross-encoder reranking layer for semantic search. Wraps a base search function with two-stage retrieval: 1. fetch ``VOYAGE_RERANK_FETCH_K`` candidates via the bi-encoder (cosine) 2. pass them to voyage rerank-2, return top-``limit`` When the feature flag is off (or ``force_rerank=False``) the helper just calls the base function with ``limit`` and returns its results unchanged — so callers can wrap unconditionally and let env control behaviour. The helper extracts the rerank text from each row using the first non-empty field among ``content``, ``rule_statement``, ``reasoning_summary`` (matches the schema used by ``search_similar`` and ``search_precedent_library_semantic``). Decision validated by POC #5 (785-doc precedent corpus, 12 queries): - mean@3: 4.306 → 4.500 (+4.5%) - practical-category queries: 3.78 → 4.22 (+11.6%) - latency: +702ms per query """ from __future__ import annotations import logging from collections.abc import Awaitable, Callable from typing import Any from legal_mcp import config from legal_mcp.services import embeddings logger = logging.getLogger(__name__) SearchFn = Callable[..., Awaitable[list[dict]]] def _rerank_text(row: dict) -> str: """First non-empty text field that voyage rerank should see.""" for key in ("content", "rule_statement", "reasoning_summary", "supporting_quote"): v = row.get(key) if v: return str(v) return "" async def maybe_rerank( query: str, base_search: SearchFn, limit: int, *, force_rerank: bool | None = None, fetch_k: int | None = None, **base_kwargs: Any, ) -> list[dict]: """Two-stage retrieval helper. Args: query: original query string (needed for the rerank API). base_search: any async function that takes ``limit=…`` and the other ``base_kwargs`` and returns ``list[dict]``. limit: final number of results to return. force_rerank: override the env flag. ``None`` → use config. fetch_k: override the bi-encoder fetch depth. **base_kwargs: forwarded to ``base_search``. Returns: List of dict rows. When rerank is active, each row's ``score`` is replaced with the rerank-2 relevance score (0..1). """ enabled = (config.VOYAGE_RERANK_ENABLED if force_rerank is None else force_rerank) if not enabled: return await base_search(limit=limit, **base_kwargs) depth = fetch_k or config.VOYAGE_RERANK_FETCH_K candidates = await base_search(limit=depth, **base_kwargs) if not candidates: return [] texts = [_rerank_text(c) for c in candidates] # Drop candidates with empty rerank text (shouldn't happen but be safe) keep = [(i, t) for i, t in enumerate(texts) if t] if not keep: logger.warning("rerank: all candidates empty, falling back to base") return candidates[:limit] keep_idx = [i for i, _ in keep] keep_texts = [t for _, t in keep] try: ranked = await embeddings.voyage_rerank( query, keep_texts, top_k=limit, ) except Exception as e: # Fail open — if Voyage rerank is down, return bi-encoder ordering logger.warning("rerank failed, falling back to base: %s", e) return candidates[:limit] out: list[dict] = [] for keep_pos, score in ranked: orig_idx = keep_idx[keep_pos] row = dict(candidates[orig_idx]) row["score"] = float(score) out.append(row) return out