"""Hybrid (text + image) search wrappers. Layered on top of ``rerank.maybe_rerank``. When ``MULTIMODAL_ENABLED`` is true the result comes from a weighted merge of: • text side: cosine on chunks → optional rerank-2 cross-encoder (precedent search additionally fuses ``ts_rank_cd`` lexical results via RRF before this step — see ``BM25_HYBRID_ENABLED``) • image side: cosine on per-page voyage-multimodal-3 embeddings rerank-2 is a *text* cross-encoder, so image-side rows are NOT passed through it; they keep their cosine score and merge alongside the (possibly reranked) text rows. Image-only pages with no overlapping text chunk are surfaced as ``match_type='image'`` so scanned-only or visual-heavy content still appears in results. When ``MULTIMODAL_ENABLED`` is false this module degenerates to plain ``rerank.maybe_rerank`` — callers can wrap unconditionally and let env control behaviour. BM25/lexical leg (V12 + ``BM25_HYBRID_ENABLED``): ``search_precedent_library_hybrid`` runs ``search_precedent_library_lexical`` in parallel with the semantic side and fuses the two by rank via RRF. This recovers exact-string recall (case-number citations like "1461/20", rare planning terms) that voyage embeddings blur. The fused list is then handed to rerank-2 (if enabled) and to the image RRF (if multimodal is enabled) exactly as before. """ from __future__ import annotations import logging from typing import Any from uuid import UUID from legal_mcp import config from legal_mcp.services import db, embeddings, rerank logger = logging.getLogger(__name__) async def search_documents_hybrid( query: str, query_text_embedding: list[float], *, limit: int, case_id: UUID | None = None, section_type: str | None = None, practice_area: str | None = None, appeal_subtype: str | None = None, ) -> list[dict]: """Hybrid wrapper for document-chunk search (search_decisions / search_case_documents / find_similar_cases).""" fetch_k = max(limit, config.VOYAGE_RERANK_FETCH_K) if config.MULTIMODAL_ENABLED else limit text_results = await rerank.maybe_rerank( query=query, base_search=lambda **kw: db.search_similar( query_embedding=query_text_embedding, **kw, ), limit=fetch_k, case_id=case_id, section_type=section_type, practice_area=practice_area, appeal_subtype=appeal_subtype, ) if not config.MULTIMODAL_ENABLED: return text_results[:limit] try: query_img_emb = await embeddings.embed_query_for_multimodal(query) img_rows = await db.search_document_images_similar( query_img_emb, limit=fetch_k, case_id=case_id, practice_area=practice_area, appeal_subtype=appeal_subtype, ) except Exception as e: logger.warning("Hybrid: image side failed, returning text only: %s", e) return text_results[:limit] merged = _merge( text_results, img_rows, id_field="document_id", text_weight=config.MULTIMODAL_TEXT_WEIGHT, ) return merged[:limit] async def search_precedent_library_hybrid( query: str, query_text_embedding: list[float], *, limit: int, practice_area: str = "", court: str = "", precedent_level: str = "", appeal_subtype: str = "", is_binding: bool | None = None, subject_tag: str = "", include_halachot: bool = True, source_kind: str = "external_upload", district: str = "", chair_name: str = "", max_per_case_law: int = 2, ) -> list[dict]: """Hybrid wrapper for precedent-library search. source_kind='external_upload' → court rulings (default) source_kind='internal_committee' → appeals-committee decisions max_per_case_law: MMR-style diversity cap — at most N hits per case_law_id in the final ranked list (default 2). Prevents a single precedent from monopolizing the result list when many of its chunks/halachot are individually relevant. When ``config.BM25_HYBRID_ENABLED`` is true (default) ``_base`` fuses semantic cosine + lexical ``ts_rank_cd`` via RRF before handing the candidates to rerank-2 (if enabled) and the image merge (if multimodal is enabled). """ # Fetch deeper so diversity dedup still leaves enough candidates. fetch_k = max(limit * max(max_per_case_law, 1), config.VOYAGE_RERANK_FETCH_K) \ if config.MULTIMODAL_ENABLED else max(limit * max(max_per_case_law, 1), limit) async def _base(limit: int) -> list[dict]: sem_rows = await db.search_precedent_library_semantic( query_embedding=query_text_embedding, practice_area=practice_area, court=court, precedent_level=precedent_level, appeal_subtype=appeal_subtype, is_binding=is_binding, subject_tag=subject_tag, limit=limit, include_halachot=include_halachot, source_kind=source_kind, district=district, chair_name=chair_name, ) if not config.BM25_HYBRID_ENABLED: return sem_rows # Fetch lexical with ≥ 2× depth so RRF has reserves at the tail. lex_limit = max(limit * 2, limit) try: lex_rows = await db.search_precedent_library_lexical( query=query, practice_area=practice_area, court=court, precedent_level=precedent_level, appeal_subtype=appeal_subtype, is_binding=is_binding, subject_tag=subject_tag, source_kind=source_kind, district=district, chair_name=chair_name, limit=lex_limit, include_halachot=include_halachot, ) except Exception as e: logger.warning( "Hybrid precedent: lexical side failed, semantic only: %s", e, ) return sem_rows if not lex_rows: return sem_rows return _merge_sem_lex(sem_rows, lex_rows, limit=limit) text_results = await rerank.maybe_rerank( query=query, base_search=_base, limit=fetch_k, ) if not config.MULTIMODAL_ENABLED: return _diversify_by_case_law(text_results, limit, max_per_case_law) try: query_img_emb = await embeddings.embed_query_for_multimodal(query) img_rows = await db.search_precedent_images_similar( query_img_emb, limit=fetch_k, practice_area=practice_area, court=court, precedent_level=precedent_level, appeal_subtype=appeal_subtype, is_binding=is_binding, ) except Exception as e: logger.warning("Hybrid: image side failed, returning text only: %s", e) return _diversify_by_case_law(text_results, limit, max_per_case_law) merged = _merge( text_results, img_rows, id_field="case_law_id", text_weight=config.MULTIMODAL_TEXT_WEIGHT, ) return _diversify_by_case_law(merged, limit, max_per_case_law) def _diversify_by_case_law( rows: list[dict], limit: int, max_per_case_law: int, ) -> list[dict]: """MMR-style diversity cap: at most ``max_per_case_law`` rows per case_law_id in the final list. Preserves input order (which is the relevance ranking) — for each row, include it only if we haven't reached the cap for its case_law_id yet. Set max_per_case_law<=0 to disable (returns rows[:limit] unchanged). """ if max_per_case_law <= 0 or not rows: return rows[:limit] counts: dict[str, int] = {} out: list[dict] = [] for r in rows: clid = str(r.get("case_law_id") or "") if not clid: out.append(r) if len(out) >= limit: break continue n = counts.get(clid, 0) if n < max_per_case_law: out.append(r) counts[clid] = n + 1 if len(out) >= limit: break return out def _row_key(r: dict) -> tuple[str, str]: """Stable identity for sem/lex RRF. Halachot rows have ``halacha_id``; chunk rows have ``chunk_id``. Returns ``(type, id)`` so a halacha and a chunk with the same UUID (extremely unlikely, but distinct namespaces) don't collide. """ typ = str(r.get("type") or "") rid = r.get("halacha_id") if typ == "halacha" else r.get("chunk_id") return (typ, str(rid or "")) def _merge_sem_lex( sem_rows: list[dict], lex_rows: list[dict], *, limit: int, ) -> list[dict]: """RRF fusion of semantic + lexical precedent results. Why RRF (and not weighted score sum): cosine similarities (~0.4-0.7) and ``ts_rank_cd`` values (often 0.001-0.5, query-length-dependent) live on completely different scales — a weighted sum would let one side dominate by accident. RRF combines by *rank*, so a row that tops one list and is mid-pack in the other gets a robust boost. Per row:: rrf_score = 1 / (k + sem_rank) + 1 / (k + lex_rank) A row that appears in only one list contributes that list's term only. Output is sorted by combined score, with extra debug fields (``sem_score``, ``sem_rank``, ``lex_score``, ``lex_rank``) attached so callers and tests can inspect why a row ranked where it did. The row payload (``content``, ``rule_statement``, ``case_*`` joins, etc.) is taken from the semantic-side row when available — the two sources return identical column shapes, but semantic rows carry the confidence-boosted ``score`` that the rest of the pipeline expects. """ k = config.MULTIMODAL_RRF_K sem_rank_by_key: dict[tuple, int] = {} sem_row_by_key: dict[tuple, dict] = {} for rank, r in enumerate(sem_rows, 1): key = _row_key(r) if not key[1]: continue sem_rank_by_key[key] = rank sem_row_by_key[key] = r lex_rank_by_key: dict[tuple, int] = {} lex_row_by_key: dict[tuple, dict] = {} for rank, r in enumerate(lex_rows, 1): key = _row_key(r) if not key[1]: continue lex_rank_by_key[key] = rank lex_row_by_key[key] = r all_keys = set(sem_rank_by_key) | set(lex_rank_by_key) merged: list[dict] = [] for key in all_keys: sem_rank = sem_rank_by_key.get(key) lex_rank = lex_rank_by_key.get(key) base = sem_row_by_key.get(key) or lex_row_by_key.get(key) if base is None: continue d = dict(base) sem_term = 1.0 / (k + sem_rank) if sem_rank else 0.0 lex_term = 1.0 / (k + lex_rank) if lex_rank else 0.0 d["sem_score"] = float(sem_row_by_key[key]["score"]) \ if key in sem_row_by_key else 0.0 d["sem_rank"] = sem_rank or 0 d["lex_score"] = float(lex_row_by_key[key]["score"]) \ if key in lex_row_by_key else 0.0 d["lex_rank"] = lex_rank or 0 d["score"] = sem_term + lex_term merged.append(d) merged.sort(key=lambda x: -float(x["score"])) return merged[:limit] def _merge( text_rows: list[dict], img_rows: list[dict], id_field: str, text_weight: float, ) -> list[dict]: """Reciprocal Rank Fusion of text + image rows. Why RRF: voyage-3 cosine scores (~0.4-0.5) and voyage-multimodal-3 scores (~0.2-0.25) live on different scales — a direct weighted sum lets text always dominate. RRF combines by *rank* in each list, making the merge robust to score-scale differences. Per item:: rrf_score = text_weight / (k + text_rank) + image_weight / (k + image_rank) A row that appears in only one list contributes that list's term only. Rows joined at ``(id_field, page_number)`` get both terms — surfaced as ``match_type='text+image'`` with the thumbnail attached. Halachot in precedent rows have no page_number; they remain text-only under RRF (the case-level image boost is dropped — RRF works on rank, not raw scores). """ from legal_mcp import config as _cfg img_weight = 1.0 - text_weight k = _cfg.MULTIMODAL_RRF_K # Index image rows by their join key for boost detection. img_rank_by_key: dict[tuple, int] = {} img_row_by_key: dict[tuple, dict] = {} for rank, r in enumerate(img_rows, 1): key = (str(r[id_field]), r.get("page_number")) img_rank_by_key[key] = rank img_row_by_key[key] = r seen_image_keys: set = set() merged: list[dict] = [] for rank, r in enumerate(text_rows, 1): rid = str(r[id_field]) page = r.get("page_number") key = (rid, page) if page is not None else None img_rank = img_rank_by_key.get(key) if key else None text_term = text_weight / (k + rank) image_term = img_weight / (k + img_rank) if img_rank else 0.0 d = dict(r) d["text_score"] = float(r.get("score", 0.0)) d["text_rank"] = rank if img_rank: img_hit = img_row_by_key[key] d["image_score"] = float(img_hit.get("score", 0.0)) d["image_rank"] = img_rank d["image_thumbnail_path"] = img_hit.get("image_thumbnail_path") d["match_type"] = "text+image" seen_image_keys.add(key) else: d["image_score"] = 0.0 d["match_type"] = "text" d["score"] = text_term + image_term merged.append(d) for rank, r in enumerate(img_rows, 1): key = (str(r[id_field]), r.get("page_number")) if key in seen_image_keys: continue d = dict(r) d["text_score"] = 0.0 d["image_score"] = float(r.get("score", 0.0)) d["image_rank"] = rank d["score"] = img_weight / (k + rank) d["match_type"] = "image" d["content"] = "" d["section_type"] = "image" merged.append(d) merged.sort(key=lambda x: -float(x["score"])) return merged