feat(retrieval): add voyage-multimodal-3 page-image embeddings (feature flag)
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 1m50s
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 1m50s
Stage C: per-page image embeddings via voyage-multimodal-3 + hybrid text+image search. Off by default; enable with MULTIMODAL_ENABLED=true. - Schema V9: document_image_embeddings + precedent_image_embeddings (vector(1024), page_number, image_thumbnail_path) - extractor.render_pages_for_multimodal renders PDF pages at MULTIMODAL_DPI (144) for embedding + JPEG thumbnails at MULTIMODAL_THUMB_DPI (96) for UI preview, in one pass - embeddings.embed_images calls voyage-multimodal-3 in 50-page batches - services/hybrid_search.py orchestrator: rerank applied to text side first (rerank-2 is text-only); image side cosine; weighted merge with text_weight 0.65 (env-tunable); image-only pages surface as match_type='image' so dense scanned content still appears - processor.process_document and precedent_library.ingest_precedent gated by flag — non-fatal on multimodal failure - scripts/multimodal_backfill.py — idempotent per-case CLI to embed existing documents without re-extracting text Validated locally on a 5-page response brief: render 0.31s, embed 8.32s, hybrid merge surfaces image rows correctly. Production rollout starts with flag=false (no behavior change), then per-case A/B. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
202
mcp-server/src/legal_mcp/services/hybrid_search.py
Normal file
202
mcp-server/src/legal_mcp/services/hybrid_search.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""Hybrid (text + image) search wrappers.
|
||||
|
||||
Layered on top of ``rerank.maybe_rerank``. When ``MULTIMODAL_ENABLED`` is
|
||||
true the result comes from a weighted merge of:
|
||||
|
||||
• text side: cosine on chunks → optional rerank-2 cross-encoder
|
||||
• image side: cosine on per-page voyage-multimodal-3 embeddings
|
||||
|
||||
rerank-2 is a *text* cross-encoder, so image-side rows are NOT passed
|
||||
through it; they keep their cosine score and merge alongside the
|
||||
(possibly reranked) text rows. Image-only pages with no overlapping
|
||||
text chunk are surfaced as ``match_type='image'`` so scanned-only or
|
||||
visual-heavy content still appears in results.
|
||||
|
||||
When ``MULTIMODAL_ENABLED`` is false this module degenerates to plain
|
||||
``rerank.maybe_rerank`` — callers can wrap unconditionally and let env
|
||||
control behaviour.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from legal_mcp import config
|
||||
from legal_mcp.services import db, embeddings, rerank
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def search_documents_hybrid(
|
||||
query: str,
|
||||
query_text_embedding: list[float],
|
||||
*,
|
||||
limit: int,
|
||||
case_id: UUID | None = None,
|
||||
section_type: str | None = None,
|
||||
practice_area: str | None = None,
|
||||
appeal_subtype: str | None = None,
|
||||
) -> list[dict]:
|
||||
"""Hybrid wrapper for document-chunk search (search_decisions /
|
||||
search_case_documents / find_similar_cases)."""
|
||||
fetch_k = max(limit, config.VOYAGE_RERANK_FETCH_K) if config.MULTIMODAL_ENABLED else limit
|
||||
text_results = await rerank.maybe_rerank(
|
||||
query=query,
|
||||
base_search=lambda **kw: db.search_similar(
|
||||
query_embedding=query_text_embedding, **kw,
|
||||
),
|
||||
limit=fetch_k,
|
||||
case_id=case_id,
|
||||
section_type=section_type,
|
||||
practice_area=practice_area,
|
||||
appeal_subtype=appeal_subtype,
|
||||
)
|
||||
if not config.MULTIMODAL_ENABLED:
|
||||
return text_results[:limit]
|
||||
|
||||
try:
|
||||
query_img_emb = await embeddings.embed_query_for_multimodal(query)
|
||||
img_rows = await db.search_document_images_similar(
|
||||
query_img_emb,
|
||||
limit=fetch_k,
|
||||
case_id=case_id,
|
||||
practice_area=practice_area,
|
||||
appeal_subtype=appeal_subtype,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Hybrid: image side failed, returning text only: %s", e)
|
||||
return text_results[:limit]
|
||||
|
||||
merged = _merge(
|
||||
text_results, img_rows,
|
||||
id_field="document_id",
|
||||
text_weight=config.MULTIMODAL_TEXT_WEIGHT,
|
||||
)
|
||||
return merged[:limit]
|
||||
|
||||
|
||||
async def search_precedent_library_hybrid(
|
||||
query: str,
|
||||
query_text_embedding: list[float],
|
||||
*,
|
||||
limit: int,
|
||||
practice_area: str = "",
|
||||
court: str = "",
|
||||
precedent_level: str = "",
|
||||
appeal_subtype: str = "",
|
||||
is_binding: bool | None = None,
|
||||
subject_tag: str = "",
|
||||
include_halachot: bool = True,
|
||||
) -> list[dict]:
|
||||
"""Hybrid wrapper for precedent-library search."""
|
||||
fetch_k = max(limit, config.VOYAGE_RERANK_FETCH_K) if config.MULTIMODAL_ENABLED else limit
|
||||
|
||||
async def _base(limit_inner: int) -> list[dict]:
|
||||
return await db.search_precedent_library_semantic(
|
||||
query_embedding=query_text_embedding,
|
||||
practice_area=practice_area,
|
||||
court=court,
|
||||
precedent_level=precedent_level,
|
||||
appeal_subtype=appeal_subtype,
|
||||
is_binding=is_binding,
|
||||
subject_tag=subject_tag,
|
||||
limit=limit_inner,
|
||||
include_halachot=include_halachot,
|
||||
)
|
||||
|
||||
text_results = await rerank.maybe_rerank(
|
||||
query=query, base_search=_base, limit=fetch_k,
|
||||
)
|
||||
if not config.MULTIMODAL_ENABLED:
|
||||
return text_results[:limit]
|
||||
|
||||
try:
|
||||
query_img_emb = await embeddings.embed_query_for_multimodal(query)
|
||||
img_rows = await db.search_precedent_images_similar(
|
||||
query_img_emb,
|
||||
limit=fetch_k,
|
||||
practice_area=practice_area,
|
||||
court=court,
|
||||
precedent_level=precedent_level,
|
||||
appeal_subtype=appeal_subtype,
|
||||
is_binding=is_binding,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Hybrid: image side failed, returning text only: %s", e)
|
||||
return text_results[:limit]
|
||||
|
||||
merged = _merge(
|
||||
text_results, img_rows,
|
||||
id_field="case_law_id",
|
||||
text_weight=config.MULTIMODAL_TEXT_WEIGHT,
|
||||
)
|
||||
return merged[:limit]
|
||||
|
||||
|
||||
def _merge(
|
||||
text_rows: list[dict],
|
||||
img_rows: list[dict],
|
||||
id_field: str,
|
||||
text_weight: float,
|
||||
) -> list[dict]:
|
||||
"""Weighted merge of text + image rows.
|
||||
|
||||
Joins on ``(id_field, page_number)``. Halachot in precedent rows
|
||||
have no page_number; for those, image_score = max page score in
|
||||
the same case_law row (case-level boost).
|
||||
|
||||
Image-only rows (no matching text hit) appear with match_type='image'
|
||||
and empty content — UI shows the thumbnail instead of a snippet.
|
||||
"""
|
||||
img_weight = 1.0 - text_weight
|
||||
img_by_key: dict[tuple, dict] = {}
|
||||
img_max_by_id: dict[str, float] = {}
|
||||
for r in img_rows:
|
||||
rid = str(r[id_field])
|
||||
page = r.get("page_number")
|
||||
img_by_key[(rid, page)] = r
|
||||
score = float(r.get("score", 0.0))
|
||||
img_max_by_id[rid] = max(img_max_by_id.get(rid, 0.0), score)
|
||||
|
||||
seen: set = set()
|
||||
merged: list[dict] = []
|
||||
for r in text_rows:
|
||||
rid = str(r[id_field])
|
||||
page = r.get("page_number")
|
||||
key = (rid, page) if page is not None else None
|
||||
img_hit = img_by_key.get(key) if key else None
|
||||
text_score = float(r.get("score", 0.0))
|
||||
if img_hit:
|
||||
image_score = float(img_hit["score"])
|
||||
elif r.get("type") == "halacha":
|
||||
image_score = img_max_by_id.get(rid, 0.0)
|
||||
else:
|
||||
image_score = 0.0
|
||||
d = dict(r)
|
||||
d["text_score"] = text_score
|
||||
d["image_score"] = image_score
|
||||
d["score"] = text_score * text_weight + image_score * img_weight
|
||||
d["match_type"] = "text+image" if img_hit else "text"
|
||||
if img_hit:
|
||||
d["image_thumbnail_path"] = img_hit.get("image_thumbnail_path")
|
||||
if key:
|
||||
seen.add(key)
|
||||
merged.append(d)
|
||||
|
||||
for r in img_rows:
|
||||
rid = str(r[id_field])
|
||||
key = (rid, r.get("page_number"))
|
||||
if key in seen:
|
||||
continue
|
||||
d = dict(r)
|
||||
d["text_score"] = 0.0
|
||||
d["image_score"] = float(r.get("score", 0.0))
|
||||
d["score"] = float(r.get("score", 0.0)) * img_weight
|
||||
d["match_type"] = "image"
|
||||
d["content"] = ""
|
||||
d["section_type"] = "image"
|
||||
merged.append(d)
|
||||
|
||||
merged.sort(key=lambda x: -float(x["score"]))
|
||||
return merged
|
||||
Reference in New Issue
Block a user