diff --git a/mcp-server/src/legal_mcp/config.py b/mcp-server/src/legal_mcp/config.py index c55bd2f..96c3ca4 100644 --- a/mcp-server/src/legal_mcp/config.py +++ b/mcp-server/src/legal_mcp/config.py @@ -73,13 +73,19 @@ MULTIMODAL_DPI = int(os.environ.get("MULTIMODAL_DPI", "144")) # Separate, lower DPI for the JPEG thumbnail saved to disk for UI # preview. ~96dpi → ~20KB/page; ingestion-time, no re-render at view. MULTIMODAL_THUMB_DPI = int(os.environ.get("MULTIMODAL_THUMB_DPI", "96")) -# Hybrid merge weight for the *text* side. The image side gets -# (1 - this). POC found text dominates most queries; image wins only -# on table/visual queries — slight text bias starting point, tunable -# per env without redeploy. +# Hybrid merge: Reciprocal Rank Fusion (RRF) bias for the *text* side. +# voyage-3 cosine scores (~0.4-0.5) and voyage-multimodal-3 scores +# (~0.20-0.25) live on different scales; a direct weighted sum lets +# text always dominate. RRF is rank-based and robust to that. The +# weight here biases the contribution of each side: 0.5 = balanced +# (vanilla RRF), >0.5 favours text, <0.5 favours image. Tunable per +# env without redeploy. MULTIMODAL_TEXT_WEIGHT = float( - os.environ.get("MULTIMODAL_TEXT_WEIGHT", "0.65") + os.environ.get("MULTIMODAL_TEXT_WEIGHT", "0.5") ) +# RRF damping constant. Standard literature value is 60: lower values +# concentrate weight at top ranks; higher values flatten the curve. +MULTIMODAL_RRF_K = int(os.environ.get("MULTIMODAL_RRF_K", "60")) # Halacha extraction — auto-approve threshold. Halachot with extractor # confidence >= this value are inserted with review_status='approved' diff --git a/mcp-server/src/legal_mcp/services/hybrid_search.py b/mcp-server/src/legal_mcp/services/hybrid_search.py index 30f4df4..8007266 100644 --- a/mcp-server/src/legal_mcp/services/hybrid_search.py +++ b/mcp-server/src/legal_mcp/services/hybrid_search.py @@ -140,59 +140,72 @@ def _merge( id_field: str, text_weight: float, ) -> list[dict]: - """Weighted merge of text + image rows. + """Reciprocal Rank Fusion of text + image rows. - Joins on ``(id_field, page_number)``. Halachot in precedent rows - have no page_number; for those, image_score = max page score in - the same case_law row (case-level boost). + Why RRF: voyage-3 cosine scores (~0.4-0.5) and voyage-multimodal-3 + scores (~0.2-0.25) live on different scales — a direct weighted + sum lets text always dominate. RRF combines by *rank* in each list, + making the merge robust to score-scale differences. - Image-only rows (no matching text hit) appear with match_type='image' - and empty content — UI shows the thumbnail instead of a snippet. + Per item:: + + rrf_score = text_weight / (k + text_rank) + + image_weight / (k + image_rank) + + A row that appears in only one list contributes that list's term + only. Rows joined at ``(id_field, page_number)`` get both terms — + surfaced as ``match_type='text+image'`` with the thumbnail attached. + + Halachot in precedent rows have no page_number; they remain + text-only under RRF (the case-level image boost is dropped — RRF + works on rank, not raw scores). """ + from legal_mcp import config as _cfg img_weight = 1.0 - text_weight - img_by_key: dict[tuple, dict] = {} - img_max_by_id: dict[str, float] = {} - for r in img_rows: - rid = str(r[id_field]) - page = r.get("page_number") - img_by_key[(rid, page)] = r - score = float(r.get("score", 0.0)) - img_max_by_id[rid] = max(img_max_by_id.get(rid, 0.0), score) + k = _cfg.MULTIMODAL_RRF_K - seen: set = set() + # Index image rows by their join key for boost detection. + img_rank_by_key: dict[tuple, int] = {} + img_row_by_key: dict[tuple, dict] = {} + for rank, r in enumerate(img_rows, 1): + key = (str(r[id_field]), r.get("page_number")) + img_rank_by_key[key] = rank + img_row_by_key[key] = r + + seen_image_keys: set = set() merged: list[dict] = [] - for r in text_rows: + for rank, r in enumerate(text_rows, 1): rid = str(r[id_field]) page = r.get("page_number") key = (rid, page) if page is not None else None - img_hit = img_by_key.get(key) if key else None - text_score = float(r.get("score", 0.0)) - if img_hit: - image_score = float(img_hit["score"]) - elif r.get("type") == "halacha": - image_score = img_max_by_id.get(rid, 0.0) - else: - image_score = 0.0 + img_rank = img_rank_by_key.get(key) if key else None + text_term = text_weight / (k + rank) + image_term = img_weight / (k + img_rank) if img_rank else 0.0 d = dict(r) - d["text_score"] = text_score - d["image_score"] = image_score - d["score"] = text_score * text_weight + image_score * img_weight - d["match_type"] = "text+image" if img_hit else "text" - if img_hit: + d["text_score"] = float(r.get("score", 0.0)) + d["text_rank"] = rank + if img_rank: + img_hit = img_row_by_key[key] + d["image_score"] = float(img_hit.get("score", 0.0)) + d["image_rank"] = img_rank d["image_thumbnail_path"] = img_hit.get("image_thumbnail_path") - if key: - seen.add(key) + d["match_type"] = "text+image" + seen_image_keys.add(key) + else: + d["image_score"] = 0.0 + d["match_type"] = "text" + d["score"] = text_term + image_term merged.append(d) - for r in img_rows: - rid = str(r[id_field]) - key = (rid, r.get("page_number")) - if key in seen: + for rank, r in enumerate(img_rows, 1): + key = (str(r[id_field]), r.get("page_number")) + if key in seen_image_keys: continue d = dict(r) d["text_score"] = 0.0 d["image_score"] = float(r.get("score", 0.0)) - d["score"] = float(r.get("score", 0.0)) * img_weight + d["image_rank"] = rank + d["score"] = img_weight / (k + rank) d["match_type"] = "image" d["content"] = "" d["section_type"] = "image"