legal-ai/mcp-server/src/legal_mcp/services/hybrid_search.py
Chaim af651d0135
feat(rag): Stage B — RAG improvements (HNSW + BM25 hybrid + MMR + dynamic boost)
Five enhancements to the precedent retrieval stack:

* **#44 HNSW indexes** for precedent_chunks + halachot (replacing IVFFlat
  lists=50). Build time ~3s combined. Better recall@10 with pgvector 0.8.2.
* **#45 Halacha sweep** — 96 pending halachot at conf>=0.78 promoted to
  approved (1141 → 1237). Cluster at conf=0.78 spot-checked OK. Applied
  via psql only — env HALACHA_AUTO_APPROVE_THRESHOLD unchanged (0.80).
* **#43 MMR diversity** — search_precedent_library_hybrid now caps at
  ``max_per_case_law=2`` (default). Prevents one precedent dominating
  top-10 when many of its chunks/halachot rank high. New helper
  ``_diversify_by_case_law`` in hybrid_search.py.
* **#46 Dynamic halacha boost** — replaces the static ``score+=0.05``
  with ``score+=confidence*0.06``. Calibrated so avg-confidence (~0.85)
  stays at +0.05; high-conf halachot get a slight extra lift, low-conf
  ones get less. Behaviour preserved at the mean.
* **#41 BM25/tsvector hybrid + RRF**. Schema V12 adds STORED tsvector
  columns ``precedent_chunks.content_tsv`` and ``halachot.rule_tsv``
  (using simple config — Postgres has no Hebrew stemmer) + GIN indexes.
  New ``db.search_precedent_library_lexical`` mirrors the semantic
  function with ts_rank_cd over plainto_tsquery. ``hybrid_search``
  runs the sem and lex legs (awaited one after the other in ``_base``)
  and fuses them via RRF before rerank. Toggle:
  env ``BM25_HYBRID_ENABLED`` (default true), graceful fallback to
  semantic-only on lexical failure.
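
The calibration claim in #46 can be checked with a few lines of arithmetic. This sketch assumes the boost is simply added to a halacha row's cosine score, as the bullet describes. ``boosted_score``, ``STATIC_BOOST`` and ``BOOST_SCALE`` are illustrative names and do not come from the codebase.

```python
STATIC_BOOST = 0.05  # old behaviour: flat +0.05 for every halacha
BOOST_SCALE = 0.06   # new behaviour: boost proportional to confidence


def boosted_score(cosine: float, confidence: float) -> float:
    """Illustrative helper: cosine score plus a confidence-scaled halacha boost."""
    return cosine + confidence * BOOST_SCALE


# Compare the new boost with the old flat boost at a few confidence levels.
for conf in (0.78, 0.85, 0.95):
    boost = boosted_score(0.0, conf)
    print(f"conf={conf:.2f}  boost={boost:.4f}  delta_vs_static={boost - STATIC_BOOST:+.4f}")
```

At the stated mean confidence of about 0.85, the new boost is 0.051, which is within 0.001 of the old flat +0.05. At the 0.78 approval floor used in #45 it drops to about 0.047, and at 0.95 it rises to 0.057.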

#40 (VOYAGE_RERANK_ENABLED) was already true in Coolify env; no change.
#42 (Claude Haiku query expansion) deferred: latency and cost concerns
warrant a separate plan, and the BM25 lexical leg already recovers most
of the exact-string recall #42 was meant to address.

Closes TaskMaster #41, #43-#46.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-26 08:08:02 +00:00

"""Hybrid (text + image) search wrappers.

Layered on top of ``rerank.maybe_rerank``. When ``MULTIMODAL_ENABLED`` is
true the result comes from a weighted Reciprocal Rank Fusion (RRF) of:

• text side: cosine on chunks → optional rerank-2 cross-encoder
  (precedent search additionally fuses ``ts_rank_cd`` lexical results
  via RRF before this step; see ``BM25_HYBRID_ENABLED``)
• image side: cosine on per-page voyage-multimodal-3 embeddings

rerank-2 is a *text* cross-encoder, so image-side rows are NOT passed
through it. They are ranked by their own cosine score and fused with the
(possibly reranked) text rows. Image-only pages with no overlapping
text chunk are surfaced as ``match_type='image'`` so scanned-only or
visual-heavy content still appears in results.

When ``MULTIMODAL_ENABLED`` is false this module degenerates to plain
``rerank.maybe_rerank``, so callers can wrap unconditionally and let env
control behaviour.

BM25/lexical leg (V12 + ``BM25_HYBRID_ENABLED``):
``search_precedent_library_hybrid`` runs ``search_precedent_library_lexical``
right after the semantic query (the two calls are awaited sequentially
inside ``_base``) and fuses the two result lists by rank via RRF.
This recovers exact-string recall (case-number citations like "1461/20",
rare planning terms) that voyage embeddings blur. The fused list is
then handed to rerank-2 (if enabled) and to the image RRF (if
multimodal is enabled) exactly as before.
"""
from __future__ import annotations

import logging
from typing import Any
from uuid import UUID

from legal_mcp import config
from legal_mcp.services import db, embeddings, rerank

logger = logging.getLogger(__name__)


async def search_documents_hybrid(
    query: str,
    query_text_embedding: list[float],
    *,
    limit: int,
    case_id: UUID | None = None,
    section_type: str | None = None,
    practice_area: str | None = None,
    appeal_subtype: str | None = None,
) -> list[dict]:
    """Hybrid wrapper for document-chunk search (search_decisions /
    search_case_documents / find_similar_cases)."""
    fetch_k = max(limit, config.VOYAGE_RERANK_FETCH_K) if config.MULTIMODAL_ENABLED else limit
    text_results = await rerank.maybe_rerank(
        query=query,
        base_search=lambda **kw: db.search_similar(
            query_embedding=query_text_embedding, **kw,
        ),
        limit=fetch_k,
        case_id=case_id,
        section_type=section_type,
        practice_area=practice_area,
        appeal_subtype=appeal_subtype,
    )
    if not config.MULTIMODAL_ENABLED:
        return text_results[:limit]

    try:
        query_img_emb = await embeddings.embed_query_for_multimodal(query)
        img_rows = await db.search_document_images_similar(
            query_img_emb,
            limit=fetch_k,
            case_id=case_id,
            practice_area=practice_area,
            appeal_subtype=appeal_subtype,
        )
    except Exception as e:
        logger.warning("Hybrid: image side failed, returning text only: %s", e)
        return text_results[:limit]

    merged = _merge(
        text_results, img_rows,
        id_field="document_id",
        text_weight=config.MULTIMODAL_TEXT_WEIGHT,
    )
    return merged[:limit]


async def search_precedent_library_hybrid(
    query: str,
    query_text_embedding: list[float],
    *,
    limit: int,
    practice_area: str = "",
    court: str = "",
    precedent_level: str = "",
    appeal_subtype: str = "",
    is_binding: bool | None = None,
    subject_tag: str = "",
    include_halachot: bool = True,
    source_kind: str = "external_upload",
    district: str = "",
    chair_name: str = "",
    max_per_case_law: int = 2,
) -> list[dict]:
    """Hybrid wrapper for precedent-library search.

    source_kind='external_upload' → court rulings (default)
    source_kind='internal_committee' → appeals-committee decisions

    max_per_case_law: MMR-style diversity cap, i.e. at most N hits per
    case_law_id in the final ranked list (default 2). Prevents a
    single precedent from monopolizing the result list when many of
    its chunks/halachot are individually relevant.

    When ``config.BM25_HYBRID_ENABLED`` is true (default) ``_base`` fuses
    semantic cosine + lexical ``ts_rank_cd`` via RRF before handing the
    candidates to rerank-2 (if enabled) and the image merge (if
    multimodal is enabled).
    """
    # Fetch deeper so diversity dedup still leaves enough candidates.
    depth = limit * max(max_per_case_law, 1)
    fetch_k = max(depth, config.VOYAGE_RERANK_FETCH_K) if config.MULTIMODAL_ENABLED else depth

    async def _base(limit: int) -> list[dict]:
        sem_rows = await db.search_precedent_library_semantic(
            query_embedding=query_text_embedding,
            practice_area=practice_area,
            court=court,
            precedent_level=precedent_level,
            appeal_subtype=appeal_subtype,
            is_binding=is_binding,
            subject_tag=subject_tag,
            limit=limit,
            include_halachot=include_halachot,
            source_kind=source_kind,
            district=district,
            chair_name=chair_name,
        )
        if not config.BM25_HYBRID_ENABLED:
            return sem_rows

        # Fetch lexical with ≥ 2× depth so RRF has reserves at the tail.
        lex_limit = max(limit * 2, limit)
        try:
            lex_rows = await db.search_precedent_library_lexical(
                query=query,
                practice_area=practice_area,
                court=court,
                precedent_level=precedent_level,
                appeal_subtype=appeal_subtype,
                is_binding=is_binding,
                subject_tag=subject_tag,
                source_kind=source_kind,
                district=district,
                chair_name=chair_name,
                limit=lex_limit,
                include_halachot=include_halachot,
            )
        except Exception as e:
            logger.warning(
                "Hybrid precedent: lexical side failed, semantic only: %s", e,
            )
            return sem_rows
        if not lex_rows:
            return sem_rows
        return _merge_sem_lex(sem_rows, lex_rows, limit=limit)

    text_results = await rerank.maybe_rerank(
        query=query, base_search=_base, limit=fetch_k,
    )
    if not config.MULTIMODAL_ENABLED:
        return _diversify_by_case_law(text_results, limit, max_per_case_law)

    try:
        query_img_emb = await embeddings.embed_query_for_multimodal(query)
        img_rows = await db.search_precedent_images_similar(
            query_img_emb,
            limit=fetch_k,
            practice_area=practice_area,
            court=court,
            precedent_level=precedent_level,
            appeal_subtype=appeal_subtype,
            is_binding=is_binding,
        )
    except Exception as e:
        logger.warning("Hybrid: image side failed, returning text only: %s", e)
        return _diversify_by_case_law(text_results, limit, max_per_case_law)

    merged = _merge(
        text_results, img_rows,
        id_field="case_law_id",
        text_weight=config.MULTIMODAL_TEXT_WEIGHT,
    )
    return _diversify_by_case_law(merged, limit, max_per_case_law)


def _diversify_by_case_law(
    rows: list[dict],
    limit: int,
    max_per_case_law: int,
) -> list[dict]:
    """MMR-style diversity cap: at most ``max_per_case_law`` rows per
    case_law_id in the final list. Preserves input order (which is the
    relevance ranking): for each row, include it only if we haven't
    reached the cap for its case_law_id yet.

    Set max_per_case_law<=0 to disable (returns rows[:limit] unchanged).
    """
    if max_per_case_law <= 0 or not rows:
        return rows[:limit]
    counts: dict[str, int] = {}
    out: list[dict] = []
    for r in rows:
        clid = str(r.get("case_law_id") or "")
        if not clid:
            out.append(r)
            if len(out) >= limit:
                break
            continue
        n = counts.get(clid, 0)
        if n < max_per_case_law:
            out.append(r)
            counts[clid] = n + 1
            if len(out) >= limit:
                break
    return out


def _row_key(r: dict) -> tuple[str, str]:
    """Stable identity for sem/lex RRF.

    Halachot rows have ``halacha_id``; chunk rows have ``chunk_id``.
    Returns ``(type, id)`` so a halacha and a chunk with the same UUID
    (extremely unlikely, but distinct namespaces) don't collide.
    """
    typ = str(r.get("type") or "")
    rid = r.get("halacha_id") if typ == "halacha" else r.get("chunk_id")
    return (typ, str(rid or ""))


def _merge_sem_lex(
    sem_rows: list[dict],
    lex_rows: list[dict],
    *,
    limit: int,
) -> list[dict]:
    """RRF fusion of semantic + lexical precedent results.

    Why RRF (and not a weighted score sum): cosine similarities (~0.4-0.7)
    and ``ts_rank_cd`` values (often 0.001-0.5, query-length-dependent)
    live on completely different scales, so a weighted sum would let one
    side dominate by accident. RRF combines by *rank*, so a row that
    tops one list and is mid-pack in the other gets a robust boost.

    Per row::

        rrf_score = 1 / (k + sem_rank) + 1 / (k + lex_rank)

    A row that appears in only one list contributes that list's term
    only. Output is sorted by combined score (ties keep semantic order,
    then lexical order), with debug fields (``sem_score``, ``sem_rank``,
    ``lex_score``, ``lex_rank``; 0 when the row is absent from that list)
    attached so callers and tests can inspect why a row ranked where it did.

    The row payload (``content``, ``rule_statement``, ``case_*`` joins,
    etc.) is taken from the semantic-side row when available. The two
    sources return identical column shapes. The semantic row's original
    (confidence-boosted) ``score`` is preserved as ``sem_score``, while
    ``score`` itself is replaced by the RRF value.
    """
    k = config.MULTIMODAL_RRF_K

    sem_rank_by_key: dict[tuple, int] = {}
    sem_row_by_key: dict[tuple, dict] = {}
    for rank, r in enumerate(sem_rows, 1):
        key = _row_key(r)
        # Skip rows without an id; on duplicates keep the best (first) rank.
        if not key[1] or key in sem_rank_by_key:
            continue
        sem_rank_by_key[key] = rank
        sem_row_by_key[key] = r

    lex_rank_by_key: dict[tuple, int] = {}
    lex_row_by_key: dict[tuple, dict] = {}
    for rank, r in enumerate(lex_rows, 1):
        key = _row_key(r)
        if not key[1] or key in lex_rank_by_key:
            continue
        lex_rank_by_key[key] = rank
        lex_row_by_key[key] = r

    # Deterministic candidate order: semantic keys first, then lexical-only
    # keys. list.sort is stable, so tied RRF scores keep this order instead
    # of depending on set iteration order (which varies between processes).
    all_keys = list(sem_rank_by_key) + [
        key for key in lex_rank_by_key if key not in sem_rank_by_key
    ]
    merged: list[dict] = []
    for key in all_keys:
        sem_rank = sem_rank_by_key.get(key)
        lex_rank = lex_rank_by_key.get(key)
        base = sem_row_by_key[key] if key in sem_row_by_key else lex_row_by_key[key]
        d = dict(base)
        sem_term = 1.0 / (k + sem_rank) if sem_rank else 0.0
        lex_term = 1.0 / (k + lex_rank) if lex_rank else 0.0
        d["sem_score"] = float(sem_row_by_key[key].get("score", 0.0)) \
            if key in sem_row_by_key else 0.0
        d["sem_rank"] = sem_rank or 0
        d["lex_score"] = float(lex_row_by_key[key].get("score", 0.0)) \
            if key in lex_row_by_key else 0.0
        d["lex_rank"] = lex_rank or 0
        d["score"] = sem_term + lex_term
        merged.append(d)

    merged.sort(key=lambda x: -float(x["score"]))
    return merged[:limit]


def _merge(
    text_rows: list[dict],
    img_rows: list[dict],
    id_field: str,
    text_weight: float,
) -> list[dict]:
    """Reciprocal Rank Fusion of text + image rows.

    Why RRF: voyage-3 cosine scores (~0.4-0.5) and voyage-multimodal-3
    scores (~0.2-0.25) live on different scales, so a direct weighted
    sum lets text always dominate. RRF combines by *rank* in each list,
    making the merge robust to score-scale differences.

    Per item::

        rrf_score = text_weight / (k + text_rank)
                  + image_weight / (k + image_rank)

    A row that appears in only one list contributes that list's term
    only. Rows joined at ``(id_field, page_number)`` get both terms and
    are surfaced as ``match_type='text+image'`` with the thumbnail attached.
    Halachot in precedent rows have no page_number; they remain
    text-only under RRF (the case-level image boost is dropped because
    RRF works on rank, not raw scores).
    """
    img_weight = 1.0 - text_weight
    k = config.MULTIMODAL_RRF_K

    # Index image rows by their join key for boost detection.
    img_rank_by_key: dict[tuple, int] = {}
    img_row_by_key: dict[tuple, dict] = {}
    for rank, r in enumerate(img_rows, 1):
        key = (str(r[id_field]), r.get("page_number"))
        img_rank_by_key[key] = rank
        img_row_by_key[key] = r

    seen_image_keys: set = set()
    merged: list[dict] = []
    for rank, r in enumerate(text_rows, 1):
        rid = str(r[id_field])
        page = r.get("page_number")
        key = (rid, page) if page is not None else None
        img_rank = img_rank_by_key.get(key) if key else None
        text_term = text_weight / (k + rank)
        image_term = img_weight / (k + img_rank) if img_rank else 0.0
        d = dict(r)
        d["text_score"] = float(r.get("score", 0.0))
        d["text_rank"] = rank
        if img_rank:
            img_hit = img_row_by_key[key]
            d["image_score"] = float(img_hit.get("score", 0.0))
            d["image_rank"] = img_rank
            d["image_thumbnail_path"] = img_hit.get("image_thumbnail_path")
            d["match_type"] = "text+image"
            seen_image_keys.add(key)
        else:
            d["image_score"] = 0.0
            d["match_type"] = "text"
        d["score"] = text_term + image_term
        merged.append(d)

    # Image pages that no text row joined are surfaced on their own.
    for rank, r in enumerate(img_rows, 1):
        key = (str(r[id_field]), r.get("page_number"))
        if key in seen_image_keys:
            continue
        d = dict(r)
        d["text_score"] = 0.0
        d["image_score"] = float(r.get("score", 0.0))
        d["image_rank"] = rank
        d["score"] = img_weight / (k + rank)
        d["match_type"] = "image"
        d["content"] = ""
        d["section_type"] = "image"
        merged.append(d)

    merged.sort(key=lambda x: -float(x["score"]))
    return merged
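
To see how ``_merge_sem_lex`` orders rows, here is a standalone re-implementation of its rank-fusion formula on toy data. This is a sketch, not the module's code. ``rrf_fuse`` is a hypothetical helper, rows are reduced to string ids, and the sketch assumes each id appears at most once per list. It uses ``k=60``, a common RRF choice, as an assumed value; the real constant comes from ``config.MULTIMODAL_RRF_K``, whose value isn't shown here.

```python
def rrf_fuse(sem_ids: list[str], lex_ids: list[str], k: int = 60) -> list[tuple[str, float]]:
    """Reciprocal Rank Fusion of two ranked id lists (1-based ranks).

    Assumes each id appears at most once per list.
    """
    scores: dict[str, float] = {}
    for ranked in (sem_ids, lex_ids):
        for rank, rid in enumerate(ranked, 1):
            scores[rid] = scores.get(rid, 0.0) + 1.0 / (k + rank)
    # sorted() is stable, so ties keep first-seen order (semantic list first).
    return sorted(scores.items(), key=lambda item: -item[1])


sem = ["chunk-a", "chunk-b", "halacha-c"]
lex = ["halacha-c", "chunk-d", "chunk-b"]
for rid, score in rrf_fuse(sem, lex):
    print(f"{rid:10s} {score:.5f}")
```

``halacha-c`` finishes first even though it is only third on the semantic side, because it also tops the lexical list. ``chunk-a``, the semantic winner that the lexical leg never returned, drops to third. This is the behaviour the docstring describes: agreement between the two legs outweighs a single first place.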
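
The per-``case_law_id`` cap in ``_diversify_by_case_law`` is a greedy filter over an already-ranked list. The toy run below mirrors that logic on plain dicts to show what the default ``max_per_case_law=2`` does when one precedent dominates. ``cap_per_group`` is a hypothetical name and the case ids are invented. Strictly speaking, this is a per-group quota rather than full Maximal Marginal Relevance, which would penalize each candidate by its similarity to items already selected.

```python
def cap_per_group(rows: list[dict], limit: int, max_per_group: int) -> list[dict]:
    """Keep ranked order; admit at most max_per_group rows per case_law_id."""
    if max_per_group <= 0:
        return rows[:limit]
    counts: dict[str, int] = {}
    out: list[dict] = []
    for row in rows:
        gid = str(row.get("case_law_id") or "")
        if gid:  # rows without a case_law_id bypass the cap, as in the module
            if counts.get(gid, 0) >= max_per_group:
                continue
            counts[gid] = counts.get(gid, 0) + 1
        out.append(row)
        if len(out) >= limit:
            break
    return out


ranked = [
    {"id": 1, "case_law_id": "A"},
    {"id": 2, "case_law_id": "A"},
    {"id": 3, "case_law_id": "A"},
    {"id": 4, "case_law_id": "B"},
    {"id": 5, "case_law_id": None},
    {"id": 6, "case_law_id": "A"},
    {"id": 7, "case_law_id": "C"},
]
print([r["id"] for r in cap_per_group(ranked, limit=5, max_per_group=2)])
```

Precedent ``A`` keeps only its two best hits, and the freed slots go to ``B``, the id-less row and ``C``. This is also why the caller fetches more than ``limit`` candidates before applying the cap.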
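
``_merge`` differs from the sem/lex fusion in two ways. First, each side's term is scaled by a weight (``text_weight`` and ``1 - text_weight``). Second, a text row collects the image term only when an image row shares its ``(id_field, page_number)`` key. The sketch below reproduces that scoring on toy ``(id, page)`` tuples. ``weighted_rrf`` is a hypothetical helper. ``text_weight=0.7`` and ``k=60`` are illustrative assumptions, not the configured values of ``MULTIMODAL_TEXT_WEIGHT`` or ``MULTIMODAL_RRF_K``.

```python
from typing import Optional

Key = tuple[str, Optional[int]]  # (id_field value, page_number)


def weighted_rrf(
    text_keys: list[Key],
    img_keys: list[Key],
    text_weight: float = 0.7,  # assumed, not MULTIMODAL_TEXT_WEIGHT
    k: int = 60,               # assumed, not MULTIMODAL_RRF_K
) -> list[tuple[Key, str, float]]:
    """Weighted RRF of ranked text and image keys, joined on (id, page)."""
    img_weight = 1.0 - text_weight
    img_rank = {key: rank for rank, key in enumerate(img_keys, 1)}
    joined: set = set()
    merged: list[tuple[Key, str, float]] = []
    for rank, key in enumerate(text_keys, 1):
        score = text_weight / (k + rank)
        # A text row without a page (e.g. a halacha) can never join an image page.
        if key[1] is not None and key in img_rank:
            score += img_weight / (k + img_rank[key])
            joined.add(key)
            merged.append((key, "text+image", score))
        else:
            merged.append((key, "text", score))
    # Image pages that no text row joined are surfaced on their own.
    for rank, key in enumerate(img_keys, 1):
        if key not in joined:
            merged.append((key, "image", img_weight / (k + rank)))
    return sorted(merged, key=lambda item: -item[2])


text_keys = [("doc1", 3), ("doc2", None), ("doc1", 4)]
img_keys = [("doc1", 4), ("doc3", 1)]
for key, match_type, score in weighted_rrf(text_keys, img_keys):
    print(key, match_type, f"{score:.5f}")
```

Page 4 of ``doc1`` is only the third text hit. It still moves to the top because the best image hit lands on the same page. The image-only ``doc3`` page is still returned, but it ranks last because it carries only the smaller image weight.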