All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 1m23s
`rerank.maybe_rerank` calls `base_search(limit=…, **base_kwargs)` on both
the rerank-on and rerank-off paths. Commit 242f668 moved the closure into
hybrid_search.py and renamed its parameter to `limit_inner`, so every call
to `/api/precedent-library/search` raised a TypeError (HTTP 500) regardless
of the VOYAGE_RERANK_ENABLED flag. The closure's parameter is named `limit`
again so that the keyword call binds. The sibling `search_documents_hybrid`
was unaffected because it uses `lambda **kw:`, which absorbs the kwarg.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
216 lines
7.1 KiB
Python
216 lines
7.1 KiB
Python
"""Hybrid (text + image) search wrappers.
|
|
|
|
Layered on top of ``rerank.maybe_rerank``. When ``MULTIMODAL_ENABLED`` is
|
|
true the result comes from a weighted merge of:
|
|
|
|
• text side: cosine on chunks → optional rerank-2 cross-encoder
|
|
• image side: cosine on per-page voyage-multimodal-3 embeddings
|
|
|
|
rerank-2 is a *text* cross-encoder, so image-side rows are NOT passed
|
|
through it; they keep their cosine score and merge alongside the
|
|
(possibly reranked) text rows. Image-only pages with no overlapping
|
|
text chunk are surfaced as ``match_type='image'`` so scanned-only or
|
|
visual-heavy content still appears in results.
|
|
|
|
When ``MULTIMODAL_ENABLED`` is false this module degenerates to plain
|
|
``rerank.maybe_rerank`` — callers can wrap unconditionally and let env
|
|
control behaviour.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Any
|
|
from uuid import UUID
|
|
|
|
from legal_mcp import config
|
|
from legal_mcp.services import db, embeddings, rerank
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def search_documents_hybrid(
    query: str,
    query_text_embedding: list[float],
    *,
    limit: int,
    case_id: UUID | None = None,
    section_type: str | None = None,
    practice_area: str | None = None,
    appeal_subtype: str | None = None,
) -> list[dict]:
    """Search document chunks, optionally fused with page-image matches.

    The text side always goes through ``rerank.maybe_rerank`` layered on
    ``db.search_similar``. With ``config.MULTIMODAL_ENABLED`` off, the first
    ``limit`` text rows are returned unchanged. With it on, the query is
    additionally embedded by ``embeddings.embed_query_for_multimodal`` and
    matched through ``db.search_document_images_similar``; the two lists
    are fused by ``_merge`` on ``document_id`` and cut to ``limit``.

    Any exception raised by the image side is logged as a warning and the
    text-only rows are returned instead, so the image side is best-effort.
    """
    if config.MULTIMODAL_ENABLED:
        # When fusing, ask each side for at least VOYAGE_RERANK_FETCH_K rows.
        candidate_count = max(limit, config.VOYAGE_RERANK_FETCH_K)
    else:
        candidate_count = limit

    def _text_search(**search_kwargs):
        # Binds the query embedding and forwards every keyword it receives
        # (``limit`` plus the filters below) to ``db.search_similar``.
        return db.search_similar(
            query_embedding=query_text_embedding, **search_kwargs,
        )

    text_hits = await rerank.maybe_rerank(
        query=query,
        base_search=_text_search,
        limit=candidate_count,
        case_id=case_id,
        section_type=section_type,
        practice_area=practice_area,
        appeal_subtype=appeal_subtype,
    )
    if not config.MULTIMODAL_ENABLED:
        return text_hits[:limit]

    # NOTE(review): ``section_type`` is passed to the text search but not to
    # the image search below, so image-only rows are not constrained by
    # it — confirm this is intended.
    try:
        image_query_embedding = await embeddings.embed_query_for_multimodal(query)
        image_hits = await db.search_document_images_similar(
            image_query_embedding,
            limit=candidate_count,
            case_id=case_id,
            practice_area=practice_area,
            appeal_subtype=appeal_subtype,
        )
    except Exception as exc:
        # Deliberate best-effort: a failing image side must not fail the search.
        logger.warning("Hybrid: image side failed, returning text only: %s", exc)
        return text_hits[:limit]

    fused = _merge(
        text_hits, image_hits,
        id_field="document_id",
        text_weight=config.MULTIMODAL_TEXT_WEIGHT,
    )
    return fused[:limit]
|
|
|
|
|
|
async def search_precedent_library_hybrid(
    query: str,
    query_text_embedding: list[float],
    *,
    limit: int,
    practice_area: str = "",
    court: str = "",
    precedent_level: str = "",
    appeal_subtype: str = "",
    is_binding: bool | None = None,
    subject_tag: str = "",
    include_halachot: bool = True,
) -> list[dict]:
    """Search the precedent library, optionally fused with page-image matches.

    The text side goes through ``rerank.maybe_rerank`` layered on
    ``db.search_precedent_library_semantic``. With
    ``config.MULTIMODAL_ENABLED`` off, the first ``limit`` text rows are
    returned unchanged. With it on, the query is additionally embedded by
    ``embeddings.embed_query_for_multimodal`` and matched through
    ``db.search_precedent_images_similar``; the two lists are fused by
    ``_merge`` on ``case_law_id`` and cut to ``limit``.

    Any exception raised by the image side is logged as a warning and the
    text-only rows are returned instead, so the image side is best-effort.
    """
    # Filters that both the text search and the image search receive.
    shared_filters = {
        "practice_area": practice_area,
        "court": court,
        "precedent_level": precedent_level,
        "appeal_subtype": appeal_subtype,
        "is_binding": is_binding,
    }

    if config.MULTIMODAL_ENABLED:
        # When fusing, ask each side for at least VOYAGE_RERANK_FETCH_K rows.
        candidate_count = max(limit, config.VOYAGE_RERANK_FETCH_K)
    else:
        candidate_count = limit

    async def _semantic_search(limit: int) -> list[dict]:
        # The parameter MUST be called ``limit``: ``rerank.maybe_rerank``
        # invokes ``base_search(limit=...)`` by keyword, and this closure
        # (unlike a ``**kw`` lambda) would raise TypeError on any other name.
        return await db.search_precedent_library_semantic(
            query_embedding=query_text_embedding,
            **shared_filters,
            subject_tag=subject_tag,
            limit=limit,
            include_halachot=include_halachot,
        )

    text_hits = await rerank.maybe_rerank(
        query=query, base_search=_semantic_search, limit=candidate_count,
    )
    if not config.MULTIMODAL_ENABLED:
        return text_hits[:limit]

    # NOTE(review): ``subject_tag`` filters the text search but is not passed
    # to the image search below, so image-only rows are not constrained by
    # it — confirm this is intended.
    try:
        image_query_embedding = await embeddings.embed_query_for_multimodal(query)
        image_hits = await db.search_precedent_images_similar(
            image_query_embedding,
            limit=candidate_count,
            **shared_filters,
        )
    except Exception as exc:
        # Deliberate best-effort: a failing image side must not fail the search.
        logger.warning("Hybrid: image side failed, returning text only: %s", exc)
        return text_hits[:limit]

    fused = _merge(
        text_hits, image_hits,
        id_field="case_law_id",
        text_weight=config.MULTIMODAL_TEXT_WEIGHT,
    )
    return fused[:limit]
|
|
|
|
|
|
def _merge(
    text_rows: list[dict],
    img_rows: list[dict],
    id_field: str,
    text_weight: float,
) -> list[dict]:
    """Fuse ranked text rows and ranked image rows with Reciprocal Rank Fusion.

    Rank-based fusion is used because the two score scales are not
    comparable: voyage-3 cosine scores sit around 0.4-0.5 while
    voyage-multimodal-3 scores sit around 0.2-0.25, so a weighted sum of
    raw scores would let the text side dominate.

    Each output row is a copy of its input row with ``score`` replaced by::

        text_weight / (k + text_rank) + (1 - text_weight) / (k + image_rank)

    where ``k`` is ``config.MULTIMODAL_RRF_K``, ranks are 1-based positions
    in the input lists, and a term is omitted when the row is absent from
    that list.

    * A text row whose ``(id_field, page_number)`` also occurs in
      ``img_rows`` becomes ``match_type='text+image'`` and gains
      ``image_score``, ``image_rank`` and ``image_thumbnail_path``.
    * Any other text row becomes ``match_type='text'`` with
      ``image_score=0.0`` (no ``image_rank`` key is added). Rows with no
      ``page_number`` — e.g. halachot in precedent results — always land
      here; there is no case-level image boost.
    * An image row that no text row joined becomes ``match_type='image'``
      with ``text_score=0.0``, empty ``content`` and
      ``section_type='image'`` (no ``text_rank`` key is added).

    If ``img_rows`` holds the same join key more than once, the last
    occurrence is the one used for the join.

    Returns the fused rows ordered by descending ``score``; ties keep
    text rows before image-only rows, each in input order.
    """
    from legal_mcp import config as _cfg

    image_weight = 1.0 - text_weight
    rrf_k = _cfg.MULTIMODAL_RRF_K

    # join key -> (1-based image rank, image row); later duplicates overwrite.
    image_index: dict[tuple, tuple[int, dict]] = {}
    for position, image_row in enumerate(img_rows, start=1):
        join_key = (str(image_row[id_field]), image_row.get("page_number"))
        image_index[join_key] = (position, image_row)

    joined_keys: set = set()
    fused: list[dict] = []

    for position, text_row in enumerate(text_rows, start=1):
        row_id = str(text_row[id_field])
        page = text_row.get("page_number")
        # Rows without a page number can never join an image row.
        join_key = None if page is None else (row_id, page)
        image_match = image_index.get(join_key) if join_key is not None else None

        text_term = text_weight / (rrf_k + position)
        if image_match is not None:
            image_term = image_weight / (rrf_k + image_match[0])
        else:
            image_term = 0.0

        out = dict(text_row)
        out["text_score"] = float(text_row.get("score", 0.0))
        out["text_rank"] = position
        if image_match is not None:
            image_position, image_hit = image_match
            out.update({
                "image_score": float(image_hit.get("score", 0.0)),
                "image_rank": image_position,
                "image_thumbnail_path": image_hit.get("image_thumbnail_path"),
                "match_type": "text+image",
            })
            joined_keys.add(join_key)
        else:
            out.update({"image_score": 0.0, "match_type": "text"})
        out["score"] = text_term + image_term
        fused.append(out)

    # Image rows that no text row claimed are surfaced on their own.
    for position, image_row in enumerate(img_rows, start=1):
        join_key = (str(image_row[id_field]), image_row.get("page_number"))
        if join_key in joined_keys:
            continue
        out = dict(image_row)
        out.update({
            "text_score": 0.0,
            "image_score": float(image_row.get("score", 0.0)),
            "image_rank": position,
            "score": image_weight / (rrf_k + position),
            "match_type": "image",
            "content": "",
            "section_type": "image",
        })
        fused.append(out)

    # Stable sort on the negated score: highest first, ties in append order.
    return sorted(fused, key=lambda row: -float(row["score"]))
|