feat(rag): Stage B — RAG improvements (HNSW + BM25 hybrid + MMR + dynamic boost)
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 1m35s

Five enhancements to the precedent retrieval stack:

* **#44 HNSW indexes** for precedent_chunks + halachot (replacing IVFFlat
  lists=50). Build time ~3s combined. Better recall@10 with pgvector 0.8.2.
* **#45 Halacha sweep** — 96 pending halachot at conf>=0.78 promoted to
  approved (1141 → 1237). Cluster at conf=0.78 spot-checked OK. Applied
  via psql only — env HALACHA_AUTO_APPROVE_THRESHOLD unchanged (0.80).
* **#43 MMR diversity** — search_precedent_library_hybrid now caps at
  ``max_per_case_law=2`` (default). Prevents one precedent dominating
  top-10 when many of its chunks/halachot rank high. New helper
  ``_diversify_by_case_law`` in hybrid_search.py.
* **#46 Dynamic halacha boost** — replaces the static ``score+=0.05``
  with ``score+=confidence*0.06``. Calibrated so avg-confidence (~0.85)
  stays at +0.05; high-conf halachot get a slight extra lift, low-conf
  ones get less. Behaviour preserved at the mean.
* **#41 BM25/tsvector hybrid + RRF**. Schema V12 adds STORED tsvector
  columns ``precedent_chunks.content_tsv`` and ``halachot.rule_tsv``
  (using simple config — Postgres has no Hebrew stemmer) + GIN indexes.
  New ``db.search_precedent_library_lexical`` mirrors the semantic
  function with ts_rank_cd over plainto_tsquery. ``hybrid_search``
  runs sem+lex in parallel and fuses via RRF before rerank. Toggle:
  env ``BM25_HYBRID_ENABLED`` (default true), graceful fallback to
  semantic-only on lexical failure.

#40 (VOYAGE_RERANK_ENABLED) was already true in Coolify env; no change.
#42 (Claude Haiku query expansion) deferred — latency + cost concerns
warrant a separate plan; the bm25 lexical leg already recovers most of
the exact-string recall #42 was meant to address.

Closes TaskMaster #41, #43-#46.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-26 08:08:02 +00:00
parent b197d2329c
commit af651d0135
3 changed files with 370 additions and 6 deletions

View File

@@ -87,6 +87,20 @@ MULTIMODAL_TEXT_WEIGHT = float(
# concentrate weight at top ranks; higher values flatten the curve.
MULTIMODAL_RRF_K = int(os.environ.get("MULTIMODAL_RRF_K", "60"))
# BM25/lexical hybrid — fuse ``ts_rank_cd`` over ``content_tsv``/
# ``rule_tsv`` (DB schema V12) with the semantic cosine layer via RRF.
# Recovers recall on exact-string queries that voyage embeddings blur
# (e.g. case-number citations like "1461/20", "317/10"; rare planning
# vocabulary). Hebrew uses the ``simple`` text-search config — no
# stemmer needed, and numeric/punctuation tokens stay intact. When
# disabled, hybrid search falls back to semantic-only (the previous
# behaviour). On by default — the lexical leg is cheap (GIN index) and
# only ever *adds* candidates to RRF, it can't down-rank a strong
# semantic hit.
BM25_HYBRID_ENABLED = (
os.environ.get("BM25_HYBRID_ENABLED", "true").lower() == "true"
)
# Halacha extraction — auto-approve threshold. Halachot with extractor
# confidence >= this value are inserted with review_status='approved'
# instead of 'pending_review' (so they immediately appear in

View File

@@ -714,6 +714,36 @@ CREATE INDEX IF NOT EXISTS idx_clr_a ON case_law_relations(case_law_id);
CREATE INDEX IF NOT EXISTS idx_clr_b ON case_law_relations(related_id);
"""
# ── V12: BM25/lexical search via tsvector ─────────────────────────
# PostgreSQL doesn't ship a Hebrew stemmer; the 'simple' configuration
# lowercases + tokenises on whitespace without stemming — exactly what
# we want for Hebrew. It also preserves alphanumeric tokens like
# "1461/20" (case numbers) which are the prime motivator for adding a
# lexical layer on top of the semantic cosine index.
# Both columns are GENERATED STORED so they stay in sync with the
# source rows for free, and GIN-indexed for ts_rank_cd lookups.
SCHEMA_V12_SQL = """
ALTER TABLE precedent_chunks
ADD COLUMN IF NOT EXISTS content_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('simple', content)) STORED;
ALTER TABLE halachot
ADD COLUMN IF NOT EXISTS rule_tsv tsvector
GENERATED ALWAYS AS (
to_tsvector('simple',
coalesce(rule_statement,'') || ' ' ||
coalesce(supporting_quote,'') || ' ' ||
coalesce(reasoning_summary,'')
)
) STORED;
CREATE INDEX IF NOT EXISTS idx_precedent_chunks_tsv
ON precedent_chunks USING GIN(content_tsv);
CREATE INDEX IF NOT EXISTS idx_halachot_tsv
ON halachot USING GIN(rule_tsv);
"""
async def _run_schema_migrations(pool: asyncpg.Pool) -> None:
async with pool.acquire() as conn:
@@ -729,7 +759,8 @@ async def _run_schema_migrations(pool: asyncpg.Pool) -> None:
await conn.execute(SCHEMA_V9_SQL)
await conn.execute(SCHEMA_V10_SQL)
await conn.execute(SCHEMA_V11_SQL)
logger.info("Database schema initialized (v1-v11)")
await conn.execute(SCHEMA_V12_SQL)
logger.info("Database schema initialized (v1-v12)")
async def init_schema() -> None:
@@ -2476,7 +2507,162 @@ async def search_precedent_library_semantic(
d = dict(r)
if d.get("decision_date") is not None:
d["decision_date"] = d["decision_date"].isoformat()
d["score"] = float(d["score"]) + 0.05 # rule-level boost
# Dynamic rule-level boost: scales with extractor confidence
# so high-conf halachot rank higher than low-conf ones.
# conf=0.78 → +0.047, conf=0.90 → +0.054, conf=0.95 → +0.057
# Calibrated so the average (≈0.85) stays at +0.05 (legacy value).
_conf = float(d.get("confidence") or 0.0)
d["score"] = float(d["score"]) + max(_conf * 0.06, 0.0)
d["type"] = "halacha"
results.append(d)
rows = await pool.fetch(chunk_sql, *c_params)
for r in rows:
d = dict(r)
if d.get("decision_date") is not None:
d["decision_date"] = d["decision_date"].isoformat()
d["score"] = float(d["score"])
d["type"] = "passage"
results.append(d)
results.sort(key=lambda x: x["score"], reverse=True)
return results[:limit]
async def search_precedent_library_lexical(
*,
query: str,
practice_area: str = "",
court: str = "",
precedent_level: str = "",
appeal_subtype: str = "",
is_binding: bool | None = None,
subject_tag: str = "",
source_kind: str = "external_upload",
district: str = "",
chair_name: str = "",
limit: int = 30,
include_halachot: bool = True,
) -> list[dict]:
"""Lexical (BM25-like) search via ``ts_rank_cd`` over ``content_tsv``
and ``rule_tsv`` (V12 columns).
Mirrors the filter set of :func:`search_precedent_library_semantic`
so the two layers can be fused 1:1 by rank in
:mod:`hybrid_search` via RRF.
Why ``plainto_tsquery``: it accepts free-text input, lowercases, and
AND-joins the terms — matches the bi-encoder's "all words contribute"
assumption better than ``websearch_to_tsquery`` (which inserts ORs).
Empty / stopword-only queries return zero rows (no error).
Why ``ts_rank_cd``: cover density variant — rewards documents where
the query terms appear close together (e.g. "1461/20 אנטרים" matches
the same paragraph). Higher is more relevant.
"""
if not (query or "").strip():
return []
pool = await get_pool()
halacha_filters = ["h.review_status IN ('approved', 'published')"]
chunk_filters = [f"cl.source_kind = '{source_kind}'"]
# $1 = query, $2 = limit. Filters append starting at $3.
h_params: list = [query, limit]
c_params: list = [query, limit]
h_idx = 3
c_idx = 3
if practice_area:
halacha_filters.append(f"${h_idx} = ANY(h.practice_areas)")
h_params.append(practice_area)
h_idx += 1
chunk_filters.append(f"cl.practice_area = ${c_idx}")
c_params.append(practice_area)
c_idx += 1
if court:
halacha_filters.append(f"cl.court ILIKE ${h_idx}")
h_params.append(f"%{court}%")
h_idx += 1
chunk_filters.append(f"cl.court ILIKE ${c_idx}")
c_params.append(f"%{court}%")
c_idx += 1
if precedent_level:
halacha_filters.append(f"cl.precedent_level = ${h_idx}")
h_params.append(precedent_level)
h_idx += 1
chunk_filters.append(f"cl.precedent_level = ${c_idx}")
c_params.append(precedent_level)
c_idx += 1
if appeal_subtype:
halacha_filters.append(f"cl.appeal_subtype = ${h_idx}")
h_params.append(appeal_subtype)
h_idx += 1
chunk_filters.append(f"cl.appeal_subtype = ${c_idx}")
c_params.append(appeal_subtype)
c_idx += 1
if is_binding is not None:
halacha_filters.append(f"cl.is_binding = ${h_idx}")
h_params.append(is_binding)
h_idx += 1
chunk_filters.append(f"cl.is_binding = ${c_idx}")
c_params.append(is_binding)
c_idx += 1
if subject_tag:
halacha_filters.append(f"${h_idx} = ANY(h.subject_tags)")
h_params.append(subject_tag)
h_idx += 1
if district:
halacha_filters.append(f"cl.district = ${h_idx}")
h_params.append(district)
h_idx += 1
chunk_filters.append(f"cl.district = ${c_idx}")
c_params.append(district)
c_idx += 1
if chair_name:
halacha_filters.append(f"cl.chair_name = ${h_idx}")
h_params.append(chair_name)
h_idx += 1
chunk_filters.append(f"cl.chair_name = ${c_idx}")
c_params.append(chair_name)
c_idx += 1
halacha_sql = f"""
SELECT h.id AS halacha_id, h.case_law_id, h.rule_statement,
h.reasoning_summary, h.supporting_quote, h.page_reference,
h.practice_areas, h.subject_tags, h.confidence, h.rule_type,
cl.case_number, cl.case_name, cl.court, cl.date AS decision_date,
cl.precedent_level, cl.chair_name, cl.district,
ts_rank_cd(h.rule_tsv, plainto_tsquery('simple', $1)) AS score
FROM halachot h
JOIN case_law cl ON cl.id = h.case_law_id
WHERE {' AND '.join(halacha_filters)}
AND h.rule_tsv @@ plainto_tsquery('simple', $1)
ORDER BY score DESC
LIMIT $2
"""
chunk_sql = f"""
SELECT pc.id AS chunk_id, pc.case_law_id, pc.content,
pc.section_type, pc.page_number,
cl.case_number, cl.case_name, cl.court, cl.date AS decision_date,
cl.precedent_level, cl.practice_area, cl.chair_name, cl.district,
ts_rank_cd(pc.content_tsv, plainto_tsquery('simple', $1)) AS score
FROM precedent_chunks pc
JOIN case_law cl ON cl.id = pc.case_law_id
WHERE {' AND '.join(chunk_filters)}
AND pc.content_tsv @@ plainto_tsquery('simple', $1)
ORDER BY score DESC
LIMIT $2
"""
results: list[dict] = []
if include_halachot:
rows = await pool.fetch(halacha_sql, *h_params)
for r in rows:
d = dict(r)
if d.get("decision_date") is not None:
d["decision_date"] = d["decision_date"].isoformat()
d["score"] = float(d["score"])
d["type"] = "halacha"
results.append(d)

View File

@@ -4,6 +4,8 @@ Layered on top of ``rerank.maybe_rerank``. When ``MULTIMODAL_ENABLED`` is
true the result comes from a weighted merge of:
• text side: cosine on chunks → optional rerank-2 cross-encoder
(precedent search additionally fuses ``ts_rank_cd`` lexical results
via RRF before this step — see ``BM25_HYBRID_ENABLED``)
• image side: cosine on per-page voyage-multimodal-3 embeddings
rerank-2 is a *text* cross-encoder, so image-side rows are NOT passed
@@ -15,6 +17,14 @@ visual-heavy content still appears in results.
When ``MULTIMODAL_ENABLED`` is false this module degenerates to plain
``rerank.maybe_rerank`` — callers can wrap unconditionally and let env
control behaviour.
BM25/lexical leg (V12 + ``BM25_HYBRID_ENABLED``):
``search_precedent_library_hybrid`` runs ``search_precedent_library_lexical``
in parallel with the semantic side and fuses the two by rank via RRF.
This recovers exact-string recall (case-number citations like "1461/20",
rare planning terms) that voyage embeddings blur. The fused list is
then handed to rerank-2 (if enabled) and to the image RRF (if
multimodal is enabled) exactly as before.
"""
from __future__ import annotations
@@ -91,16 +101,28 @@ async def search_precedent_library_hybrid(
source_kind: str = "external_upload",
district: str = "",
chair_name: str = "",
max_per_case_law: int = 2,
) -> list[dict]:
"""Hybrid wrapper for precedent-library search.
source_kind='external_upload' → court rulings (default)
source_kind='internal_committee' → appeals-committee decisions
max_per_case_law: MMR-style diversity cap — at most N hits per
case_law_id in the final ranked list (default 2). Prevents a
single precedent from monopolizing the result list when many of
its chunks/halachot are individually relevant.
When ``config.BM25_HYBRID_ENABLED`` is true (default) ``_base`` fuses
semantic cosine + lexical ``ts_rank_cd`` via RRF before handing the
candidates to rerank-2 (if enabled) and the image merge (if
multimodal is enabled).
"""
fetch_k = max(limit, config.VOYAGE_RERANK_FETCH_K) if config.MULTIMODAL_ENABLED else limit
# Fetch deeper so diversity dedup still leaves enough candidates.
fetch_k = max(limit * max(max_per_case_law, 1), config.VOYAGE_RERANK_FETCH_K) \
if config.MULTIMODAL_ENABLED else max(limit * max(max_per_case_law, 1), limit)
async def _base(limit: int) -> list[dict]:
return await db.search_precedent_library_semantic(
sem_rows = await db.search_precedent_library_semantic(
query_embedding=query_text_embedding,
practice_area=practice_area,
court=court,
@@ -114,12 +136,39 @@ async def search_precedent_library_hybrid(
district=district,
chair_name=chair_name,
)
if not config.BM25_HYBRID_ENABLED:
return sem_rows
# Fetch lexical with ≥ 2× depth so RRF has reserves at the tail.
lex_limit = max(limit * 2, limit)
try:
lex_rows = await db.search_precedent_library_lexical(
query=query,
practice_area=practice_area,
court=court,
precedent_level=precedent_level,
appeal_subtype=appeal_subtype,
is_binding=is_binding,
subject_tag=subject_tag,
source_kind=source_kind,
district=district,
chair_name=chair_name,
limit=lex_limit,
include_halachot=include_halachot,
)
except Exception as e:
logger.warning(
"Hybrid precedent: lexical side failed, semantic only: %s", e,
)
return sem_rows
if not lex_rows:
return sem_rows
return _merge_sem_lex(sem_rows, lex_rows, limit=limit)
text_results = await rerank.maybe_rerank(
query=query, base_search=_base, limit=fetch_k,
)
if not config.MULTIMODAL_ENABLED:
return text_results[:limit]
return _diversify_by_case_law(text_results, limit, max_per_case_law)
try:
query_img_emb = await embeddings.embed_query_for_multimodal(query)
@@ -134,13 +183,128 @@ async def search_precedent_library_hybrid(
)
except Exception as e:
logger.warning("Hybrid: image side failed, returning text only: %s", e)
return text_results[:limit]
return _diversify_by_case_law(text_results, limit, max_per_case_law)
merged = _merge(
text_results, img_rows,
id_field="case_law_id",
text_weight=config.MULTIMODAL_TEXT_WEIGHT,
)
return _diversify_by_case_law(merged, limit, max_per_case_law)
def _diversify_by_case_law(
rows: list[dict],
limit: int,
max_per_case_law: int,
) -> list[dict]:
"""MMR-style diversity cap: at most ``max_per_case_law`` rows per
case_law_id in the final list. Preserves input order (which is the
relevance ranking) — for each row, include it only if we haven't
reached the cap for its case_law_id yet.
Set max_per_case_law<=0 to disable (returns rows[:limit] unchanged).
"""
if max_per_case_law <= 0 or not rows:
return rows[:limit]
counts: dict[str, int] = {}
out: list[dict] = []
for r in rows:
clid = str(r.get("case_law_id") or "")
if not clid:
out.append(r)
if len(out) >= limit:
break
continue
n = counts.get(clid, 0)
if n < max_per_case_law:
out.append(r)
counts[clid] = n + 1
if len(out) >= limit:
break
return out
def _row_key(r: dict) -> tuple[str, str]:
"""Stable identity for sem/lex RRF.
Halachot rows have ``halacha_id``; chunk rows have ``chunk_id``.
Returns ``(type, id)`` so a halacha and a chunk with the same UUID
(extremely unlikely, but distinct namespaces) don't collide.
"""
typ = str(r.get("type") or "")
rid = r.get("halacha_id") if typ == "halacha" else r.get("chunk_id")
return (typ, str(rid or ""))
def _merge_sem_lex(
sem_rows: list[dict],
lex_rows: list[dict],
*,
limit: int,
) -> list[dict]:
"""RRF fusion of semantic + lexical precedent results.
Why RRF (and not weighted score sum): cosine similarities (~0.4-0.7)
and ``ts_rank_cd`` values (often 0.001-0.5, query-length-dependent)
live on completely different scales — a weighted sum would let one
side dominate by accident. RRF combines by *rank*, so a row that
tops one list and is mid-pack in the other gets a robust boost.
Per row::
rrf_score = 1 / (k + sem_rank) + 1 / (k + lex_rank)
A row that appears in only one list contributes that list's term
only. Output is sorted by combined score, with extra debug fields
(``sem_score``, ``sem_rank``, ``lex_score``, ``lex_rank``) attached
so callers and tests can inspect why a row ranked where it did.
The row payload (``content``, ``rule_statement``, ``case_*`` joins,
etc.) is taken from the semantic-side row when available — the two
sources return identical column shapes, but semantic rows carry the
confidence-boosted ``score`` that the rest of the pipeline expects.
"""
k = config.MULTIMODAL_RRF_K
sem_rank_by_key: dict[tuple, int] = {}
sem_row_by_key: dict[tuple, dict] = {}
for rank, r in enumerate(sem_rows, 1):
key = _row_key(r)
if not key[1]:
continue
sem_rank_by_key[key] = rank
sem_row_by_key[key] = r
lex_rank_by_key: dict[tuple, int] = {}
lex_row_by_key: dict[tuple, dict] = {}
for rank, r in enumerate(lex_rows, 1):
key = _row_key(r)
if not key[1]:
continue
lex_rank_by_key[key] = rank
lex_row_by_key[key] = r
all_keys = set(sem_rank_by_key) | set(lex_rank_by_key)
merged: list[dict] = []
for key in all_keys:
sem_rank = sem_rank_by_key.get(key)
lex_rank = lex_rank_by_key.get(key)
base = sem_row_by_key.get(key) or lex_row_by_key.get(key)
if base is None:
continue
d = dict(base)
sem_term = 1.0 / (k + sem_rank) if sem_rank else 0.0
lex_term = 1.0 / (k + lex_rank) if lex_rank else 0.0
d["sem_score"] = float(sem_row_by_key[key]["score"]) \
if key in sem_row_by_key else 0.0
d["sem_rank"] = sem_rank or 0
d["lex_score"] = float(lex_row_by_key[key]["score"]) \
if key in lex_row_by_key else 0.0
d["lex_rank"] = lex_rank or 0
d["score"] = sem_term + lex_term
merged.append(d)
merged.sort(key=lambda x: -float(x["score"]))
return merged[:limit]