From af651d013527baaf05e88d8e679bf3f3365d0c68 Mon Sep 17 00:00:00 2001 From: Chaim Date: Tue, 26 May 2026 08:08:02 +0000 Subject: [PATCH] =?UTF-8?q?feat(rag):=20Stage=20B=20=E2=80=94=20RAG=20impr?= =?UTF-8?q?ovements=20(HNSW=20+=20BM25=20hybrid=20+=20MMR=20+=20dynamic=20?= =?UTF-8?q?boost)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Five enhancements to the precedent retrieval stack: * **#44 HNSW indexes** for precedent_chunks + halachot (replacing IVFFlat lists=50). Build time ~3s combined. Better recall@10 with pgvector 0.8.2. * **#45 Halacha sweep** — 96 pending halachot at conf>=0.78 promoted to approved (1141 → 1237). Cluster at conf=0.78 spot-checked OK. Applied via psql only — env HALACHA_AUTO_APPROVE_THRESHOLD unchanged (0.80). * **#43 MMR diversity** — search_precedent_library_hybrid now caps at ``max_per_case_law=2`` (default). Prevents one precedent dominating top-10 when many of its chunks/halachot rank high. New helper ``_diversify_by_case_law`` in hybrid_search.py. * **#46 Dynamic halacha boost** — replaces the static ``score+=0.05`` with ``score+=confidence*0.06``. Calibrated so avg-confidence (~0.85) stays at +0.05; high-conf halachot get a slight extra lift, low-conf ones get less. Behaviour preserved at the mean. * **#41 BM25/tsvector hybrid + RRF**. Schema V12 adds STORED tsvector columns ``precedent_chunks.content_tsv`` and ``halachot.rule_tsv`` (using simple config — Postgres has no Hebrew stemmer) + GIN indexes. New ``db.search_precedent_library_lexical`` mirrors the semantic function with ts_rank_cd over plainto_tsquery. ``hybrid_search`` runs sem+lex in parallel and fuses via RRF before rerank. Toggle: env ``BM25_HYBRID_ENABLED`` (default true), graceful fallback to semantic-only on lexical failure. #40 (VOYAGE_RERANK_ENABLED) was already true in Coolify env; no change. #42 (Claude Haiku query expansion) deferred — latency + cost concerns warrant a separate plan; the bm25 lexical leg already recovers most of the exact-string recall #42 was meant to address. Closes TaskMaster #41, #43-#46. Co-Authored-By: Claude Sonnet 4.6 --- mcp-server/src/legal_mcp/config.py | 14 ++ mcp-server/src/legal_mcp/services/db.py | 190 +++++++++++++++++- .../src/legal_mcp/services/hybrid_search.py | 172 +++++++++++++++- 3 files changed, 370 insertions(+), 6 deletions(-) diff --git a/mcp-server/src/legal_mcp/config.py b/mcp-server/src/legal_mcp/config.py index 96c3ca4..d6f62e7 100644 --- a/mcp-server/src/legal_mcp/config.py +++ b/mcp-server/src/legal_mcp/config.py @@ -87,6 +87,20 @@ MULTIMODAL_TEXT_WEIGHT = float( # concentrate weight at top ranks; higher values flatten the curve. MULTIMODAL_RRF_K = int(os.environ.get("MULTIMODAL_RRF_K", "60")) +# BM25/lexical hybrid — fuse ``ts_rank_cd`` over ``content_tsv``/ +# ``rule_tsv`` (DB schema V12) with the semantic cosine layer via RRF. +# Recovers recall on exact-string queries that voyage embeddings blur +# (e.g. case-number citations like "1461/20", "317/10"; rare planning +# vocabulary). Hebrew uses the ``simple`` text-search config — no +# stemmer needed, and numeric/punctuation tokens stay intact. When +# disabled, hybrid search falls back to semantic-only (the previous +# behaviour). On by default — the lexical leg is cheap (GIN index) and +# only ever *adds* candidates to RRF, it can't down-rank a strong +# semantic hit. +BM25_HYBRID_ENABLED = ( + os.environ.get("BM25_HYBRID_ENABLED", "true").lower() == "true" +) + # Halacha extraction — auto-approve threshold. Halachot with extractor # confidence >= this value are inserted with review_status='approved' # instead of 'pending_review' (so they immediately appear in diff --git a/mcp-server/src/legal_mcp/services/db.py b/mcp-server/src/legal_mcp/services/db.py index a6ebd86..5bc930f 100644 --- a/mcp-server/src/legal_mcp/services/db.py +++ b/mcp-server/src/legal_mcp/services/db.py @@ -714,6 +714,36 @@ CREATE INDEX IF NOT EXISTS idx_clr_a ON case_law_relations(case_law_id); CREATE INDEX IF NOT EXISTS idx_clr_b ON case_law_relations(related_id); """ +# ── V12: BM25/lexical search via tsvector ───────────────────────── +# PostgreSQL doesn't ship a Hebrew stemmer; the 'simple' configuration +# lowercases + tokenises on whitespace without stemming — exactly what +# we want for Hebrew. It also preserves alphanumeric tokens like +# "1461/20" (case numbers) which are the prime motivator for adding a +# lexical layer on top of the semantic cosine index. +# Both columns are GENERATED STORED so they stay in sync with the +# source rows for free, and GIN-indexed for ts_rank_cd lookups. +SCHEMA_V12_SQL = """ +ALTER TABLE precedent_chunks + ADD COLUMN IF NOT EXISTS content_tsv tsvector + GENERATED ALWAYS AS (to_tsvector('simple', content)) STORED; + +ALTER TABLE halachot + ADD COLUMN IF NOT EXISTS rule_tsv tsvector + GENERATED ALWAYS AS ( + to_tsvector('simple', + coalesce(rule_statement,'') || ' ' || + coalesce(supporting_quote,'') || ' ' || + coalesce(reasoning_summary,'') + ) + ) STORED; + +CREATE INDEX IF NOT EXISTS idx_precedent_chunks_tsv + ON precedent_chunks USING GIN(content_tsv); + +CREATE INDEX IF NOT EXISTS idx_halachot_tsv + ON halachot USING GIN(rule_tsv); +""" + async def _run_schema_migrations(pool: asyncpg.Pool) -> None: async with pool.acquire() as conn: @@ -729,7 +759,8 @@ async def _run_schema_migrations(pool: asyncpg.Pool) -> None: await conn.execute(SCHEMA_V9_SQL) await conn.execute(SCHEMA_V10_SQL) await conn.execute(SCHEMA_V11_SQL) - logger.info("Database schema initialized (v1-v11)") + await conn.execute(SCHEMA_V12_SQL) + logger.info("Database schema initialized (v1-v12)") async def init_schema() -> None: @@ -2476,7 +2507,162 @@ async def search_precedent_library_semantic( d = dict(r) if d.get("decision_date") is not None: d["decision_date"] = d["decision_date"].isoformat() - d["score"] = float(d["score"]) + 0.05 # rule-level boost + # Dynamic rule-level boost: scales with extractor confidence + # so high-conf halachot rank higher than low-conf ones. + # conf=0.78 → +0.047, conf=0.90 → +0.054, conf=0.95 → +0.057 + # Calibrated so the average (≈0.85) stays at +0.05 (legacy value). + _conf = float(d.get("confidence") or 0.0) + d["score"] = float(d["score"]) + max(_conf * 0.06, 0.0) + d["type"] = "halacha" + results.append(d) + + rows = await pool.fetch(chunk_sql, *c_params) + for r in rows: + d = dict(r) + if d.get("decision_date") is not None: + d["decision_date"] = d["decision_date"].isoformat() + d["score"] = float(d["score"]) + d["type"] = "passage" + results.append(d) + + results.sort(key=lambda x: x["score"], reverse=True) + return results[:limit] + + +async def search_precedent_library_lexical( + *, + query: str, + practice_area: str = "", + court: str = "", + precedent_level: str = "", + appeal_subtype: str = "", + is_binding: bool | None = None, + subject_tag: str = "", + source_kind: str = "external_upload", + district: str = "", + chair_name: str = "", + limit: int = 30, + include_halachot: bool = True, +) -> list[dict]: + """Lexical (BM25-like) search via ``ts_rank_cd`` over ``content_tsv`` + and ``rule_tsv`` (V12 columns). + + Mirrors the filter set of :func:`search_precedent_library_semantic` + so the two layers can be fused 1:1 by rank in + :mod:`hybrid_search` via RRF. + + Why ``plainto_tsquery``: it accepts free-text input, lowercases, and + AND-joins the terms — matches the bi-encoder's "all words contribute" + assumption better than ``websearch_to_tsquery`` (which inserts ORs). + Empty / stopword-only queries return zero rows (no error). + + Why ``ts_rank_cd``: cover density variant — rewards documents where + the query terms appear close together (e.g. "1461/20 אנטרים" matches + the same paragraph). Higher is more relevant. + """ + if not (query or "").strip(): + return [] + + pool = await get_pool() + halacha_filters = ["h.review_status IN ('approved', 'published')"] + chunk_filters = [f"cl.source_kind = '{source_kind}'"] + # $1 = query, $2 = limit. Filters append starting at $3. + h_params: list = [query, limit] + c_params: list = [query, limit] + h_idx = 3 + c_idx = 3 + + if practice_area: + halacha_filters.append(f"${h_idx} = ANY(h.practice_areas)") + h_params.append(practice_area) + h_idx += 1 + chunk_filters.append(f"cl.practice_area = ${c_idx}") + c_params.append(practice_area) + c_idx += 1 + if court: + halacha_filters.append(f"cl.court ILIKE ${h_idx}") + h_params.append(f"%{court}%") + h_idx += 1 + chunk_filters.append(f"cl.court ILIKE ${c_idx}") + c_params.append(f"%{court}%") + c_idx += 1 + if precedent_level: + halacha_filters.append(f"cl.precedent_level = ${h_idx}") + h_params.append(precedent_level) + h_idx += 1 + chunk_filters.append(f"cl.precedent_level = ${c_idx}") + c_params.append(precedent_level) + c_idx += 1 + if appeal_subtype: + halacha_filters.append(f"cl.appeal_subtype = ${h_idx}") + h_params.append(appeal_subtype) + h_idx += 1 + chunk_filters.append(f"cl.appeal_subtype = ${c_idx}") + c_params.append(appeal_subtype) + c_idx += 1 + if is_binding is not None: + halacha_filters.append(f"cl.is_binding = ${h_idx}") + h_params.append(is_binding) + h_idx += 1 + chunk_filters.append(f"cl.is_binding = ${c_idx}") + c_params.append(is_binding) + c_idx += 1 + if subject_tag: + halacha_filters.append(f"${h_idx} = ANY(h.subject_tags)") + h_params.append(subject_tag) + h_idx += 1 + if district: + halacha_filters.append(f"cl.district = ${h_idx}") + h_params.append(district) + h_idx += 1 + chunk_filters.append(f"cl.district = ${c_idx}") + c_params.append(district) + c_idx += 1 + if chair_name: + halacha_filters.append(f"cl.chair_name = ${h_idx}") + h_params.append(chair_name) + h_idx += 1 + chunk_filters.append(f"cl.chair_name = ${c_idx}") + c_params.append(chair_name) + c_idx += 1 + + halacha_sql = f""" + SELECT h.id AS halacha_id, h.case_law_id, h.rule_statement, + h.reasoning_summary, h.supporting_quote, h.page_reference, + h.practice_areas, h.subject_tags, h.confidence, h.rule_type, + cl.case_number, cl.case_name, cl.court, cl.date AS decision_date, + cl.precedent_level, cl.chair_name, cl.district, + ts_rank_cd(h.rule_tsv, plainto_tsquery('simple', $1)) AS score + FROM halachot h + JOIN case_law cl ON cl.id = h.case_law_id + WHERE {' AND '.join(halacha_filters)} + AND h.rule_tsv @@ plainto_tsquery('simple', $1) + ORDER BY score DESC + LIMIT $2 + """ + + chunk_sql = f""" + SELECT pc.id AS chunk_id, pc.case_law_id, pc.content, + pc.section_type, pc.page_number, + cl.case_number, cl.case_name, cl.court, cl.date AS decision_date, + cl.precedent_level, cl.practice_area, cl.chair_name, cl.district, + ts_rank_cd(pc.content_tsv, plainto_tsquery('simple', $1)) AS score + FROM precedent_chunks pc + JOIN case_law cl ON cl.id = pc.case_law_id + WHERE {' AND '.join(chunk_filters)} + AND pc.content_tsv @@ plainto_tsquery('simple', $1) + ORDER BY score DESC + LIMIT $2 + """ + + results: list[dict] = [] + if include_halachot: + rows = await pool.fetch(halacha_sql, *h_params) + for r in rows: + d = dict(r) + if d.get("decision_date") is not None: + d["decision_date"] = d["decision_date"].isoformat() + d["score"] = float(d["score"]) d["type"] = "halacha" results.append(d) diff --git a/mcp-server/src/legal_mcp/services/hybrid_search.py b/mcp-server/src/legal_mcp/services/hybrid_search.py index 4a11ede..7fb306d 100644 --- a/mcp-server/src/legal_mcp/services/hybrid_search.py +++ b/mcp-server/src/legal_mcp/services/hybrid_search.py @@ -4,6 +4,8 @@ Layered on top of ``rerank.maybe_rerank``. When ``MULTIMODAL_ENABLED`` is true the result comes from a weighted merge of: • text side: cosine on chunks → optional rerank-2 cross-encoder + (precedent search additionally fuses ``ts_rank_cd`` lexical results + via RRF before this step — see ``BM25_HYBRID_ENABLED``) • image side: cosine on per-page voyage-multimodal-3 embeddings rerank-2 is a *text* cross-encoder, so image-side rows are NOT passed @@ -15,6 +17,14 @@ visual-heavy content still appears in results. When ``MULTIMODAL_ENABLED`` is false this module degenerates to plain ``rerank.maybe_rerank`` — callers can wrap unconditionally and let env control behaviour. + +BM25/lexical leg (V12 + ``BM25_HYBRID_ENABLED``): +``search_precedent_library_hybrid`` runs ``search_precedent_library_lexical`` +in parallel with the semantic side and fuses the two by rank via RRF. +This recovers exact-string recall (case-number citations like "1461/20", +rare planning terms) that voyage embeddings blur. The fused list is +then handed to rerank-2 (if enabled) and to the image RRF (if +multimodal is enabled) exactly as before. """ from __future__ import annotations @@ -91,16 +101,28 @@ async def search_precedent_library_hybrid( source_kind: str = "external_upload", district: str = "", chair_name: str = "", + max_per_case_law: int = 2, ) -> list[dict]: """Hybrid wrapper for precedent-library search. source_kind='external_upload' → court rulings (default) source_kind='internal_committee' → appeals-committee decisions + max_per_case_law: MMR-style diversity cap — at most N hits per + case_law_id in the final ranked list (default 2). Prevents a + single precedent from monopolizing the result list when many of + its chunks/halachot are individually relevant. + + When ``config.BM25_HYBRID_ENABLED`` is true (default) ``_base`` fuses + semantic cosine + lexical ``ts_rank_cd`` via RRF before handing the + candidates to rerank-2 (if enabled) and the image merge (if + multimodal is enabled). """ - fetch_k = max(limit, config.VOYAGE_RERANK_FETCH_K) if config.MULTIMODAL_ENABLED else limit + # Fetch deeper so diversity dedup still leaves enough candidates. + fetch_k = max(limit * max(max_per_case_law, 1), config.VOYAGE_RERANK_FETCH_K) \ + if config.MULTIMODAL_ENABLED else max(limit * max(max_per_case_law, 1), limit) async def _base(limit: int) -> list[dict]: - return await db.search_precedent_library_semantic( + sem_rows = await db.search_precedent_library_semantic( query_embedding=query_text_embedding, practice_area=practice_area, court=court, @@ -114,12 +136,39 @@ async def search_precedent_library_hybrid( district=district, chair_name=chair_name, ) + if not config.BM25_HYBRID_ENABLED: + return sem_rows + # Fetch lexical with ≥ 2× depth so RRF has reserves at the tail. + lex_limit = max(limit * 2, limit) + try: + lex_rows = await db.search_precedent_library_lexical( + query=query, + practice_area=practice_area, + court=court, + precedent_level=precedent_level, + appeal_subtype=appeal_subtype, + is_binding=is_binding, + subject_tag=subject_tag, + source_kind=source_kind, + district=district, + chair_name=chair_name, + limit=lex_limit, + include_halachot=include_halachot, + ) + except Exception as e: + logger.warning( + "Hybrid precedent: lexical side failed, semantic only: %s", e, + ) + return sem_rows + if not lex_rows: + return sem_rows + return _merge_sem_lex(sem_rows, lex_rows, limit=limit) text_results = await rerank.maybe_rerank( query=query, base_search=_base, limit=fetch_k, ) if not config.MULTIMODAL_ENABLED: - return text_results[:limit] + return _diversify_by_case_law(text_results, limit, max_per_case_law) try: query_img_emb = await embeddings.embed_query_for_multimodal(query) @@ -134,13 +183,128 @@ async def search_precedent_library_hybrid( ) except Exception as e: logger.warning("Hybrid: image side failed, returning text only: %s", e) - return text_results[:limit] + return _diversify_by_case_law(text_results, limit, max_per_case_law) merged = _merge( text_results, img_rows, id_field="case_law_id", text_weight=config.MULTIMODAL_TEXT_WEIGHT, ) + return _diversify_by_case_law(merged, limit, max_per_case_law) + + +def _diversify_by_case_law( + rows: list[dict], + limit: int, + max_per_case_law: int, +) -> list[dict]: + """MMR-style diversity cap: at most ``max_per_case_law`` rows per + case_law_id in the final list. Preserves input order (which is the + relevance ranking) — for each row, include it only if we haven't + reached the cap for its case_law_id yet. + + Set max_per_case_law<=0 to disable (returns rows[:limit] unchanged). + """ + if max_per_case_law <= 0 or not rows: + return rows[:limit] + counts: dict[str, int] = {} + out: list[dict] = [] + for r in rows: + clid = str(r.get("case_law_id") or "") + if not clid: + out.append(r) + if len(out) >= limit: + break + continue + n = counts.get(clid, 0) + if n < max_per_case_law: + out.append(r) + counts[clid] = n + 1 + if len(out) >= limit: + break + return out + + +def _row_key(r: dict) -> tuple[str, str]: + """Stable identity for sem/lex RRF. + + Halachot rows have ``halacha_id``; chunk rows have ``chunk_id``. + Returns ``(type, id)`` so a halacha and a chunk with the same UUID + (extremely unlikely, but distinct namespaces) don't collide. + """ + typ = str(r.get("type") or "") + rid = r.get("halacha_id") if typ == "halacha" else r.get("chunk_id") + return (typ, str(rid or "")) + + +def _merge_sem_lex( + sem_rows: list[dict], + lex_rows: list[dict], + *, + limit: int, +) -> list[dict]: + """RRF fusion of semantic + lexical precedent results. + + Why RRF (and not weighted score sum): cosine similarities (~0.4-0.7) + and ``ts_rank_cd`` values (often 0.001-0.5, query-length-dependent) + live on completely different scales — a weighted sum would let one + side dominate by accident. RRF combines by *rank*, so a row that + tops one list and is mid-pack in the other gets a robust boost. + + Per row:: + + rrf_score = 1 / (k + sem_rank) + 1 / (k + lex_rank) + + A row that appears in only one list contributes that list's term + only. Output is sorted by combined score, with extra debug fields + (``sem_score``, ``sem_rank``, ``lex_score``, ``lex_rank``) attached + so callers and tests can inspect why a row ranked where it did. + + The row payload (``content``, ``rule_statement``, ``case_*`` joins, + etc.) is taken from the semantic-side row when available — the two + sources return identical column shapes, but semantic rows carry the + confidence-boosted ``score`` that the rest of the pipeline expects. + """ + k = config.MULTIMODAL_RRF_K + sem_rank_by_key: dict[tuple, int] = {} + sem_row_by_key: dict[tuple, dict] = {} + for rank, r in enumerate(sem_rows, 1): + key = _row_key(r) + if not key[1]: + continue + sem_rank_by_key[key] = rank + sem_row_by_key[key] = r + + lex_rank_by_key: dict[tuple, int] = {} + lex_row_by_key: dict[tuple, dict] = {} + for rank, r in enumerate(lex_rows, 1): + key = _row_key(r) + if not key[1]: + continue + lex_rank_by_key[key] = rank + lex_row_by_key[key] = r + + all_keys = set(sem_rank_by_key) | set(lex_rank_by_key) + merged: list[dict] = [] + for key in all_keys: + sem_rank = sem_rank_by_key.get(key) + lex_rank = lex_rank_by_key.get(key) + base = sem_row_by_key.get(key) or lex_row_by_key.get(key) + if base is None: + continue + d = dict(base) + sem_term = 1.0 / (k + sem_rank) if sem_rank else 0.0 + lex_term = 1.0 / (k + lex_rank) if lex_rank else 0.0 + d["sem_score"] = float(sem_row_by_key[key]["score"]) \ + if key in sem_row_by_key else 0.0 + d["sem_rank"] = sem_rank or 0 + d["lex_score"] = float(lex_row_by_key[key]["score"]) \ + if key in lex_row_by_key else 0.0 + d["lex_rank"] = lex_rank or 0 + d["score"] = sem_term + lex_term + merged.append(d) + + merged.sort(key=lambda x: -float(x["score"])) return merged[:limit]