All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 1m35s
Five enhancements to the precedent retrieval stack: * **#44 HNSW indexes** for precedent_chunks + halachot (replacing IVFFlat lists=50). Build time ~3s combined. Better recall@10 with pgvector 0.8.2. * **#45 Halacha sweep** — 96 pending halachot at conf>=0.78 promoted to approved (1141 → 1237). Cluster at conf=0.78 spot-checked OK. Applied via psql only — env HALACHA_AUTO_APPROVE_THRESHOLD unchanged (0.80). * **#43 MMR diversity** — search_precedent_library_hybrid now caps at ``max_per_case_law=2`` (default). Prevents one precedent dominating top-10 when many of its chunks/halachot rank high. New helper ``_diversify_by_case_law`` in hybrid_search.py. * **#46 Dynamic halacha boost** — replaces the static ``score+=0.05`` with ``score+=confidence*0.06``. Calibrated so avg-confidence (~0.85) stays at +0.05; high-conf halachot get a slight extra lift, low-conf ones get less. Behaviour preserved at the mean. * **#41 BM25/tsvector hybrid + RRF**. Schema V12 adds STORED tsvector columns ``precedent_chunks.content_tsv`` and ``halachot.rule_tsv`` (using simple config — Postgres has no Hebrew stemmer) + GIN indexes. New ``db.search_precedent_library_lexical`` mirrors the semantic function with ts_rank_cd over plainto_tsquery. ``hybrid_search`` runs sem+lex in parallel and fuses via RRF before rerank. Toggle: env ``BM25_HYBRID_ENABLED`` (default true), graceful fallback to semantic-only on lexical failure. #40 (VOYAGE_RERANK_ENABLED) was already true in Coolify env; no change. #42 (Claude Haiku query expansion) deferred — latency + cost concerns warrant a separate plan; the bm25 lexical leg already recovers most of the exact-string recall #42 was meant to address. Closes TaskMaster #41, #43-#46. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
390 lines
14 KiB
Python
390 lines
14 KiB
Python
"""Hybrid (text + image) search wrappers.
|
||
|
||
Layered on top of ``rerank.maybe_rerank``. When ``MULTIMODAL_ENABLED`` is
|
||
true the result comes from a weighted merge of:
|
||
|
||
• text side: cosine on chunks → optional rerank-2 cross-encoder
|
||
(precedent search additionally fuses ``ts_rank_cd`` lexical results
|
||
via RRF before this step — see ``BM25_HYBRID_ENABLED``)
|
||
• image side: cosine on per-page voyage-multimodal-3 embeddings
|
||
|
||
rerank-2 is a *text* cross-encoder, so image-side rows are NOT passed
|
||
through it; they keep their cosine score and merge alongside the
|
||
(possibly reranked) text rows. Image-only pages with no overlapping
|
||
text chunk are surfaced as ``match_type='image'`` so scanned-only or
|
||
visual-heavy content still appears in results.
|
||
|
||
When ``MULTIMODAL_ENABLED`` is false this module degenerates to plain
|
||
``rerank.maybe_rerank`` — callers can wrap unconditionally and let env
|
||
control behaviour.
|
||
|
||
BM25/lexical leg (V12 + ``BM25_HYBRID_ENABLED``):
|
||
``search_precedent_library_hybrid`` runs ``search_precedent_library_lexical``
|
||
after the semantic side (the two queries are awaited one after the other) and fuses the two by rank via RRF.
|
||
This recovers exact-string recall (case-number citations like "1461/20",
|
||
rare planning terms) that voyage embeddings blur. The fused list is
|
||
then handed to rerank-2 (if enabled) and to the image RRF (if
|
||
multimodal is enabled) exactly as before.
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from typing import Any
|
||
from uuid import UUID
|
||
|
||
from legal_mcp import config
|
||
from legal_mcp.services import db, embeddings, rerank
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
async def search_documents_hybrid(
    query: str,
    query_text_embedding: list[float],
    *,
    limit: int,
    case_id: UUID | None = None,
    section_type: str | None = None,
    practice_area: str | None = None,
    appeal_subtype: str | None = None,
) -> list[dict]:
    """Search document chunks, optionally fused with page-image matches.

    Used for document-chunk search (search_decisions /
    search_case_documents / find_similar_cases).

    The text side always goes through ``rerank.maybe_rerank`` on top of
    ``db.search_similar``. When ``config.MULTIMODAL_ENABLED`` is true the
    query is additionally embedded for the multimodal index, page images
    are searched with the same case/practice-area/appeal-subtype filters,
    and both lists are fused by ``_merge`` on ``document_id``.

    Returns at most ``limit`` rows. If the image side raises, a warning is
    logged and the text-only results are returned instead.
    """
    # Over-fetch only when an image merge will follow; otherwise the text
    # side is asked for exactly ``limit`` rows.
    if config.MULTIMODAL_ENABLED:
        candidate_depth = max(limit, config.VOYAGE_RERANK_FETCH_K)
    else:
        candidate_depth = limit

    def _text_search(**filters: Any):
        # ``maybe_rerank`` supplies limit + filters; the embedding is bound here.
        return db.search_similar(
            query_embedding=query_text_embedding, **filters,
        )

    text_hits = await rerank.maybe_rerank(
        query=query,
        base_search=_text_search,
        limit=candidate_depth,
        case_id=case_id,
        section_type=section_type,
        practice_area=practice_area,
        appeal_subtype=appeal_subtype,
    )

    if config.MULTIMODAL_ENABLED:
        try:
            image_query_vec = await embeddings.embed_query_for_multimodal(query)
            image_hits = await db.search_document_images_similar(
                image_query_vec,
                limit=candidate_depth,
                case_id=case_id,
                practice_area=practice_area,
                appeal_subtype=appeal_subtype,
            )
        except Exception as e:
            # Best effort: a broken image leg must not take down text search.
            logger.warning("Hybrid: image side failed, returning text only: %s", e)
        else:
            fused = _merge(
                text_hits, image_hits,
                id_field="document_id",
                text_weight=config.MULTIMODAL_TEXT_WEIGHT,
            )
            return fused[:limit]

    return text_hits[:limit]
|
||
|
||
|
||
async def search_precedent_library_hybrid(
    query: str,
    query_text_embedding: list[float],
    *,
    limit: int,
    practice_area: str = "",
    court: str = "",
    precedent_level: str = "",
    appeal_subtype: str = "",
    is_binding: bool | None = None,
    subject_tag: str = "",
    include_halachot: bool = True,
    source_kind: str = "external_upload",
    district: str = "",
    chair_name: str = "",
    max_per_case_law: int = 2,
) -> list[dict]:
    """Search the precedent library with lexical, rerank and image fusion.

    source_kind='external_upload'    → court rulings (default)
    source_kind='internal_committee' → appeals-committee decisions

    max_per_case_law caps how many hits a single ``case_law_id`` may
    contribute to the final list (default 2), so one precedent cannot
    fill the whole result when many of its chunks/halachot rank high.

    Pipeline:
      1. ``_base`` fetches semantic rows and, when
         ``config.BM25_HYBRID_ENABLED`` is true, lexical rows at twice
         the depth; the two are fused by ``_merge_sem_lex``. A failing or
         empty lexical leg falls back to the semantic rows alone.
      2. ``rerank.maybe_rerank`` wraps ``_base``.
      3. When ``config.MULTIMODAL_ENABLED`` is true, precedent page
         images are searched and fused via ``_merge`` on ``case_law_id``;
         an image-side failure is logged and the text rows are kept.
      4. ``_diversify_by_case_law`` applies the per-precedent cap and
         truncates to ``limit``.
    """
    # Fetch deeper so the diversity cap still leaves enough candidates.
    diversity_depth = limit * max(max_per_case_law, 1)
    if config.MULTIMODAL_ENABLED:
        fetch_k = max(diversity_depth, config.VOYAGE_RERANK_FETCH_K)
    else:
        fetch_k = max(diversity_depth, limit)

    # Filters shared verbatim by the semantic and lexical queries.
    shared_filters = dict(
        practice_area=practice_area,
        court=court,
        precedent_level=precedent_level,
        appeal_subtype=appeal_subtype,
        is_binding=is_binding,
        subject_tag=subject_tag,
        include_halachot=include_halachot,
        source_kind=source_kind,
        district=district,
        chair_name=chair_name,
    )

    async def _base(limit: int) -> list[dict]:
        sem_rows = await db.search_precedent_library_semantic(
            query_embedding=query_text_embedding,
            limit=limit,
            **shared_filters,
        )
        if not config.BM25_HYBRID_ENABLED:
            return sem_rows

        # Fetch lexical with ≥ 2× depth so RRF has reserves at the tail.
        lex_depth = max(limit * 2, limit)
        try:
            lex_rows = await db.search_precedent_library_lexical(
                query=query,
                limit=lex_depth,
                **shared_filters,
            )
        except Exception as e:
            logger.warning(
                "Hybrid precedent: lexical side failed, semantic only: %s", e,
            )
            return sem_rows

        if lex_rows:
            return _merge_sem_lex(sem_rows, lex_rows, limit=limit)
        return sem_rows

    text_results = await rerank.maybe_rerank(
        query=query, base_search=_base, limit=fetch_k,
    )

    # Default to the text ranking; replaced only if the image leg succeeds.
    candidates = text_results
    if config.MULTIMODAL_ENABLED:
        try:
            image_query_vec = await embeddings.embed_query_for_multimodal(query)
            image_hits = await db.search_precedent_images_similar(
                image_query_vec,
                limit=fetch_k,
                practice_area=practice_area,
                court=court,
                precedent_level=precedent_level,
                appeal_subtype=appeal_subtype,
                is_binding=is_binding,
            )
        except Exception as e:
            logger.warning("Hybrid: image side failed, returning text only: %s", e)
        else:
            candidates = _merge(
                text_results, image_hits,
                id_field="case_law_id",
                text_weight=config.MULTIMODAL_TEXT_WEIGHT,
            )

    return _diversify_by_case_law(candidates, limit, max_per_case_law)
|
||
|
||
|
||
def _diversify_by_case_law(
|
||
rows: list[dict],
|
||
limit: int,
|
||
max_per_case_law: int,
|
||
) -> list[dict]:
|
||
"""MMR-style diversity cap: at most ``max_per_case_law`` rows per
|
||
case_law_id in the final list. Preserves input order (which is the
|
||
relevance ranking) — for each row, include it only if we haven't
|
||
reached the cap for its case_law_id yet.
|
||
|
||
Set max_per_case_law<=0 to disable (returns rows[:limit] unchanged).
|
||
"""
|
||
if max_per_case_law <= 0 or not rows:
|
||
return rows[:limit]
|
||
counts: dict[str, int] = {}
|
||
out: list[dict] = []
|
||
for r in rows:
|
||
clid = str(r.get("case_law_id") or "")
|
||
if not clid:
|
||
out.append(r)
|
||
if len(out) >= limit:
|
||
break
|
||
continue
|
||
n = counts.get(clid, 0)
|
||
if n < max_per_case_law:
|
||
out.append(r)
|
||
counts[clid] = n + 1
|
||
if len(out) >= limit:
|
||
break
|
||
return out
|
||
|
||
|
||
def _row_key(r: dict) -> tuple[str, str]:
|
||
"""Stable identity for sem/lex RRF.
|
||
|
||
Halachot rows have ``halacha_id``; chunk rows have ``chunk_id``.
|
||
Returns ``(type, id)`` so a halacha and a chunk with the same UUID
|
||
(extremely unlikely, but distinct namespaces) don't collide.
|
||
"""
|
||
typ = str(r.get("type") or "")
|
||
rid = r.get("halacha_id") if typ == "halacha" else r.get("chunk_id")
|
||
return (typ, str(rid or ""))
|
||
|
||
|
||
def _merge_sem_lex(
|
||
sem_rows: list[dict],
|
||
lex_rows: list[dict],
|
||
*,
|
||
limit: int,
|
||
) -> list[dict]:
|
||
"""RRF fusion of semantic + lexical precedent results.
|
||
|
||
Why RRF (and not weighted score sum): cosine similarities (~0.4-0.7)
|
||
and ``ts_rank_cd`` values (often 0.001-0.5, query-length-dependent)
|
||
live on completely different scales — a weighted sum would let one
|
||
side dominate by accident. RRF combines by *rank*, so a row that
|
||
tops one list and is mid-pack in the other gets a robust boost.
|
||
|
||
Per row::
|
||
|
||
rrf_score = 1 / (k + sem_rank) + 1 / (k + lex_rank)
|
||
|
||
A row that appears in only one list contributes that list's term
|
||
only. Output is sorted by combined score, with extra debug fields
|
||
(``sem_score``, ``sem_rank``, ``lex_score``, ``lex_rank``) attached
|
||
so callers and tests can inspect why a row ranked where it did.
|
||
|
||
The row payload (``content``, ``rule_statement``, ``case_*`` joins,
|
||
etc.) is taken from the semantic-side row when available — the two
|
||
sources return identical column shapes, but semantic rows carry the
|
||
confidence-boosted ``score`` that the rest of the pipeline expects.
|
||
"""
|
||
k = config.MULTIMODAL_RRF_K
|
||
sem_rank_by_key: dict[tuple, int] = {}
|
||
sem_row_by_key: dict[tuple, dict] = {}
|
||
for rank, r in enumerate(sem_rows, 1):
|
||
key = _row_key(r)
|
||
if not key[1]:
|
||
continue
|
||
sem_rank_by_key[key] = rank
|
||
sem_row_by_key[key] = r
|
||
|
||
lex_rank_by_key: dict[tuple, int] = {}
|
||
lex_row_by_key: dict[tuple, dict] = {}
|
||
for rank, r in enumerate(lex_rows, 1):
|
||
key = _row_key(r)
|
||
if not key[1]:
|
||
continue
|
||
lex_rank_by_key[key] = rank
|
||
lex_row_by_key[key] = r
|
||
|
||
all_keys = set(sem_rank_by_key) | set(lex_rank_by_key)
|
||
merged: list[dict] = []
|
||
for key in all_keys:
|
||
sem_rank = sem_rank_by_key.get(key)
|
||
lex_rank = lex_rank_by_key.get(key)
|
||
base = sem_row_by_key.get(key) or lex_row_by_key.get(key)
|
||
if base is None:
|
||
continue
|
||
d = dict(base)
|
||
sem_term = 1.0 / (k + sem_rank) if sem_rank else 0.0
|
||
lex_term = 1.0 / (k + lex_rank) if lex_rank else 0.0
|
||
d["sem_score"] = float(sem_row_by_key[key]["score"]) \
|
||
if key in sem_row_by_key else 0.0
|
||
d["sem_rank"] = sem_rank or 0
|
||
d["lex_score"] = float(lex_row_by_key[key]["score"]) \
|
||
if key in lex_row_by_key else 0.0
|
||
d["lex_rank"] = lex_rank or 0
|
||
d["score"] = sem_term + lex_term
|
||
merged.append(d)
|
||
|
||
merged.sort(key=lambda x: -float(x["score"]))
|
||
return merged[:limit]
|
||
|
||
|
||
def _merge(
|
||
text_rows: list[dict],
|
||
img_rows: list[dict],
|
||
id_field: str,
|
||
text_weight: float,
|
||
) -> list[dict]:
|
||
"""Reciprocal Rank Fusion of text + image rows.
|
||
|
||
Why RRF: voyage-3 cosine scores (~0.4-0.5) and voyage-multimodal-3
|
||
scores (~0.2-0.25) live on different scales — a direct weighted
|
||
sum lets text always dominate. RRF combines by *rank* in each list,
|
||
making the merge robust to score-scale differences.
|
||
|
||
Per item::
|
||
|
||
rrf_score = text_weight / (k + text_rank)
|
||
+ image_weight / (k + image_rank)
|
||
|
||
A row that appears in only one list contributes that list's term
|
||
only. Rows joined at ``(id_field, page_number)`` get both terms —
|
||
surfaced as ``match_type='text+image'`` with the thumbnail attached.
|
||
|
||
Halachot in precedent rows have no page_number; they remain
|
||
text-only under RRF (the case-level image boost is dropped — RRF
|
||
works on rank, not raw scores).
|
||
"""
|
||
from legal_mcp import config as _cfg
|
||
img_weight = 1.0 - text_weight
|
||
k = _cfg.MULTIMODAL_RRF_K
|
||
|
||
# Index image rows by their join key for boost detection.
|
||
img_rank_by_key: dict[tuple, int] = {}
|
||
img_row_by_key: dict[tuple, dict] = {}
|
||
for rank, r in enumerate(img_rows, 1):
|
||
key = (str(r[id_field]), r.get("page_number"))
|
||
img_rank_by_key[key] = rank
|
||
img_row_by_key[key] = r
|
||
|
||
seen_image_keys: set = set()
|
||
merged: list[dict] = []
|
||
for rank, r in enumerate(text_rows, 1):
|
||
rid = str(r[id_field])
|
||
page = r.get("page_number")
|
||
key = (rid, page) if page is not None else None
|
||
img_rank = img_rank_by_key.get(key) if key else None
|
||
text_term = text_weight / (k + rank)
|
||
image_term = img_weight / (k + img_rank) if img_rank else 0.0
|
||
d = dict(r)
|
||
d["text_score"] = float(r.get("score", 0.0))
|
||
d["text_rank"] = rank
|
||
if img_rank:
|
||
img_hit = img_row_by_key[key]
|
||
d["image_score"] = float(img_hit.get("score", 0.0))
|
||
d["image_rank"] = img_rank
|
||
d["image_thumbnail_path"] = img_hit.get("image_thumbnail_path")
|
||
d["match_type"] = "text+image"
|
||
seen_image_keys.add(key)
|
||
else:
|
||
d["image_score"] = 0.0
|
||
d["match_type"] = "text"
|
||
d["score"] = text_term + image_term
|
||
merged.append(d)
|
||
|
||
for rank, r in enumerate(img_rows, 1):
|
||
key = (str(r[id_field]), r.get("page_number"))
|
||
if key in seen_image_keys:
|
||
continue
|
||
d = dict(r)
|
||
d["text_score"] = 0.0
|
||
d["image_score"] = float(r.get("score", 0.0))
|
||
d["image_rank"] = rank
|
||
d["score"] = img_weight / (k + rank)
|
||
d["match_type"] = "image"
|
||
d["content"] = ""
|
||
d["section_type"] = "image"
|
||
merged.append(d)
|
||
|
||
merged.sort(key=lambda x: -float(x["score"]))
|
||
return merged
|