From 242f6683195642b9560d3623f099f6e6aa1e170d Mon Sep 17 00:00:00 2001 From: Chaim Date: Sun, 3 May 2026 19:24:52 +0000 Subject: [PATCH] feat(retrieval): add voyage-multimodal-3 page-image embeddings (feature flag) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stage C: per-page image embeddings via voyage-multimodal-3 + hybrid text+image search. Off by default; enable with MULTIMODAL_ENABLED=true. - Schema V9: document_image_embeddings + precedent_image_embeddings (vector(1024), page_number, image_thumbnail_path) - extractor.render_pages_for_multimodal renders PDF pages at MULTIMODAL_DPI (144) for embedding + JPEG thumbnails at MULTIMODAL_THUMB_DPI (96) for UI preview, in one pass - embeddings.embed_images calls voyage-multimodal-3 in 50-page batches - services/hybrid_search.py orchestrator: rerank applied to text side first (rerank-2 is text-only); image side cosine; weighted merge with text_weight 0.65 (env-tunable); image-only pages surface as match_type='image' so dense scanned content still appears - processor.process_document and precedent_library.ingest_precedent gated by flag — non-fatal on multimodal failure - scripts/multimodal_backfill.py — idempotent per-case CLI to embed existing documents without re-extracting text Validated locally on a 5-page response brief: render 0.31s, embed 8.32s, hybrid merge surfaces image rows correctly. Production rollout starts with flag=false (no behavior change), then per-case A/B. Co-Authored-By: Claude Opus 4.7 (1M context) --- mcp-server/src/legal_mcp/config.py | 23 ++ mcp-server/src/legal_mcp/services/db.py | 348 +++++++++++++++++- .../src/legal_mcp/services/embeddings.py | 48 +++ .../src/legal_mcp/services/extractor.py | 61 +++ .../src/legal_mcp/services/hybrid_search.py | 202 ++++++++++ .../legal_mcp/services/precedent_library.py | 86 ++++- .../src/legal_mcp/services/processor.py | 71 ++++ mcp-server/src/legal_mcp/tools/search.py | 52 +-- scripts/SCRIPTS.md | 1 + scripts/multimodal_backfill.py | 186 ++++++++++ 10 files changed, 1038 insertions(+), 40 deletions(-) create mode 100644 mcp-server/src/legal_mcp/services/hybrid_search.py create mode 100644 scripts/multimodal_backfill.py diff --git a/mcp-server/src/legal_mcp/config.py b/mcp-server/src/legal_mcp/config.py index 95f8258..c55bd2f 100644 --- a/mcp-server/src/legal_mcp/config.py +++ b/mcp-server/src/legal_mcp/config.py @@ -58,6 +58,29 @@ VOYAGE_RERANK_ENABLED = ( # 50 was the depth used in the POC; balances recall vs rerank cost. VOYAGE_RERANK_FETCH_K = int(os.environ.get("VOYAGE_RERANK_FETCH_K", "50")) +# Multimodal — page-image embeddings via voyage-multimodal-3. Off by +# default; flip with env to enable per-page image embedding during +# ingestion + hybrid (text+image) ranking at search time. POC #3 +# validated on a 89-page appraisal PDF (38s, 312K tokens, recovered +# table structure + image-only scanned pages that text-OCR misses). +MULTIMODAL_ENABLED = ( + os.environ.get("MULTIMODAL_ENABLED", "false").lower() == "true" +) +MULTIMODAL_MODEL = os.environ.get("MULTIMODAL_MODEL", "voyage-multimodal-3") +# Render DPI for the image fed to the embedder. POC used 144 — sweet +# spot between embedding quality and tokens/page (144 ≈ 3.5K tok/page). +MULTIMODAL_DPI = int(os.environ.get("MULTIMODAL_DPI", "144")) +# Separate, lower DPI for the JPEG thumbnail saved to disk for UI +# preview. ~96dpi → ~20KB/page; ingestion-time, no re-render at view. +MULTIMODAL_THUMB_DPI = int(os.environ.get("MULTIMODAL_THUMB_DPI", "96")) +# Hybrid merge weight for the *text* side. The image side gets +# (1 - this). POC found text dominates most queries; image wins only +# on table/visual queries — slight text bias starting point, tunable +# per env without redeploy. +MULTIMODAL_TEXT_WEIGHT = float( + os.environ.get("MULTIMODAL_TEXT_WEIGHT", "0.65") +) + # Halacha extraction — auto-approve threshold. Halachot with extractor # confidence >= this value are inserted with review_status='approved' # instead of 'pending_review' (so they immediately appear in diff --git a/mcp-server/src/legal_mcp/services/db.py b/mcp-server/src/legal_mcp/services/db.py index f4207b3..e9e85fd 100644 --- a/mcp-server/src/legal_mcp/services/db.py +++ b/mcp-server/src/legal_mcp/services/db.py @@ -623,6 +623,54 @@ CREATE INDEX IF NOT EXISTS idx_case_law_halacha_requested """ +# ── V9: Multimodal page-image embeddings ───────────────────────── +# voyage-multimodal-3 (1024-dim) embeds the whole page as an image: +# captures table layout, scanned content, signatures, plans — content +# that text-OCR loses. Ingestion is gated by config.MULTIMODAL_ENABLED; +# search_*_hybrid() merge text-cosine + image-cosine when present. +# image_thumbnail_path is a relative path under DATA_DIR/cases/{case}/ +# thumbnails/ or DATA_DIR/precedent-library/thumbnails/ — a small JPEG +# rendered at config.MULTIMODAL_THUMB_DPI for UI preview, distinct from +# the higher-DPI render fed to the embedder (which is not persisted). + +SCHEMA_V9_SQL = """ +CREATE TABLE IF NOT EXISTS document_image_embeddings ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + document_id UUID REFERENCES documents(id) ON DELETE CASCADE, + case_id UUID REFERENCES cases(id) ON DELETE CASCADE, + page_number INTEGER NOT NULL, + image_thumbnail_path TEXT, + embedding vector(1024), + model_name TEXT DEFAULT 'voyage-multimodal-3', + created_at TIMESTAMPTZ DEFAULT now(), + UNIQUE(document_id, page_number) +); +CREATE INDEX IF NOT EXISTS idx_doc_img_emb_vec + ON document_image_embeddings USING ivfflat (embedding vector_cosine_ops) + WITH (lists = 50); +CREATE INDEX IF NOT EXISTS idx_doc_img_emb_doc + ON document_image_embeddings(document_id); +CREATE INDEX IF NOT EXISTS idx_doc_img_emb_case + ON document_image_embeddings(case_id); + +CREATE TABLE IF NOT EXISTS precedent_image_embeddings ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + case_law_id UUID REFERENCES case_law(id) ON DELETE CASCADE, + page_number INTEGER NOT NULL, + image_thumbnail_path TEXT, + embedding vector(1024), + model_name TEXT DEFAULT 'voyage-multimodal-3', + created_at TIMESTAMPTZ DEFAULT now(), + UNIQUE(case_law_id, page_number) +); +CREATE INDEX IF NOT EXISTS idx_prec_img_emb_vec + ON precedent_image_embeddings USING ivfflat (embedding vector_cosine_ops) + WITH (lists = 50); +CREATE INDEX IF NOT EXISTS idx_prec_img_emb_case_law + ON precedent_image_embeddings(case_law_id); +""" + + async def init_schema() -> None: pool = await get_pool() async with pool.acquire() as conn: @@ -635,7 +683,8 @@ async def init_schema() -> None: await conn.execute(SCHEMA_V6_SQL) await conn.execute(SCHEMA_V7_SQL) await conn.execute(SCHEMA_V8_SQL) - logger.info("Database schema initialized (v1-v8)") + await conn.execute(SCHEMA_V9_SQL) + logger.info("Database schema initialized (v1-v9)") # ── Case CRUD ─────────────────────────────────────────────────────── @@ -2350,3 +2399,300 @@ async def clear_extraction_request( f"UPDATE case_law SET {col} = NULL WHERE id = $1", case_law_id, ) + + +# ── V9: Multimodal page image embeddings ───────────────────────── + + +async def store_document_image_embeddings( + document_id: UUID, + case_id: UUID | None, + page_records: list[dict], + model_name: str = "voyage-multimodal-3", +) -> int: + """Replace per-page image embeddings for a document. + + Each ``page_records`` entry: ``{page_number, embedding, image_thumbnail_path}``. + Embeddings should already be 1024-dim lists (or None for skipped pages). + """ + pool = await get_pool() + async with pool.acquire() as conn: + await conn.execute( + "DELETE FROM document_image_embeddings WHERE document_id = $1", + document_id, + ) + for r in page_records: + await conn.execute( + """INSERT INTO document_image_embeddings + (document_id, case_id, page_number, embedding, + image_thumbnail_path, model_name) + VALUES ($1, $2, $3, $4, $5, $6)""", + document_id, case_id, + r["page_number"], + r.get("embedding"), + r.get("image_thumbnail_path"), + model_name, + ) + return len(page_records) + + +async def store_precedent_image_embeddings( + case_law_id: UUID, + page_records: list[dict], + model_name: str = "voyage-multimodal-3", +) -> int: + """Same pattern as store_document_image_embeddings but for precedents.""" + pool = await get_pool() + async with pool.acquire() as conn: + await conn.execute( + "DELETE FROM precedent_image_embeddings WHERE case_law_id = $1", + case_law_id, + ) + for r in page_records: + await conn.execute( + """INSERT INTO precedent_image_embeddings + (case_law_id, page_number, embedding, + image_thumbnail_path, model_name) + VALUES ($1, $2, $3, $4, $5)""", + case_law_id, + r["page_number"], + r.get("embedding"), + r.get("image_thumbnail_path"), + model_name, + ) + return len(page_records) + + +async def search_document_images_similar( + query_embedding: list[float], + limit: int = 10, + case_id: UUID | None = None, + practice_area: str | None = None, + appeal_subtype: str | None = None, +) -> list[dict]: + """Cosine search over per-page image embeddings of case documents.""" + pool = await get_pool() + conditions: list[str] = [] + params: list = [query_embedding, limit] + idx = 3 + if case_id: + conditions.append(f"die.case_id = ${idx}") + params.append(case_id); idx += 1 + if practice_area: + conditions.append(f"c.practice_area = ${idx}") + params.append(practice_area); idx += 1 + if appeal_subtype: + conditions.append(f"c.appeal_subtype = ${idx}") + params.append(appeal_subtype); idx += 1 + where = f"WHERE {' AND '.join(conditions)}" if conditions else "" + sql = f""" + SELECT die.document_id, die.case_id, die.page_number, + die.image_thumbnail_path, + d.title AS document_title, + c.case_number, + 1 - (die.embedding <=> $1) AS score + FROM document_image_embeddings die + JOIN documents d ON d.id = die.document_id + JOIN cases c ON c.id = die.case_id + {where} + ORDER BY die.embedding <=> $1 + LIMIT $2 + """ + async with pool.acquire() as conn: + rows = await conn.fetch(sql, *params) + return [dict(r) for r in rows] + + +async def search_precedent_images_similar( + query_embedding: list[float], + limit: int = 10, + practice_area: str = "", + court: str = "", + precedent_level: str = "", + appeal_subtype: str = "", + is_binding: bool | None = None, +) -> list[dict]: + """Cosine search over per-page image embeddings of precedent rulings.""" + pool = await get_pool() + conditions: list[str] = ["cl.source_kind = 'external_upload'"] + params: list = [query_embedding, limit] + idx = 3 + if practice_area: + conditions.append(f"cl.practice_area = ${idx}") + params.append(practice_area); idx += 1 + if court: + conditions.append(f"cl.court ILIKE ${idx}") + params.append(f"%{court}%"); idx += 1 + if precedent_level: + conditions.append(f"cl.precedent_level = ${idx}") + params.append(precedent_level); idx += 1 + if appeal_subtype: + conditions.append(f"cl.appeal_subtype = ${idx}") + params.append(appeal_subtype); idx += 1 + if is_binding is not None: + conditions.append(f"cl.is_binding = ${idx}") + params.append(is_binding); idx += 1 + where = " AND ".join(conditions) + sql = f""" + SELECT pie.case_law_id, pie.page_number, pie.image_thumbnail_path, + cl.case_number, cl.case_name, cl.court, cl.date AS decision_date, + cl.precedent_level, cl.practice_area, + 1 - (pie.embedding <=> $1) AS score + FROM precedent_image_embeddings pie + JOIN case_law cl ON cl.id = pie.case_law_id + WHERE {where} + ORDER BY pie.embedding <=> $1 + LIMIT $2 + """ + async with pool.acquire() as conn: + rows = await conn.fetch(sql, *params) + out = [] + for r in rows: + d = dict(r) + if d.get("decision_date") is not None: + d["decision_date"] = d["decision_date"].isoformat() + out.append(d) + return out + + +async def search_similar_hybrid( + query_text_embedding: list[float], + query_image_embedding: list[float], + limit: int = 10, + fetch_k: int = 30, + text_weight: float = 0.65, + case_id: UUID | None = None, + section_type: str | None = None, + practice_area: str | None = None, + appeal_subtype: str | None = None, +) -> list[dict]: + """Weighted merge of text-chunk and per-page image search. + + Same (document_id, page_number) → boost text chunk by image score + on that page. Image-only pages with no overlapping text chunk are + surfaced as ``match_type='image'`` so dense scanned content still + appears in results. + """ + img_weight = 1.0 - text_weight + text_rows = await search_similar( + query_text_embedding, limit=fetch_k, case_id=case_id, + section_type=section_type, practice_area=practice_area, + appeal_subtype=appeal_subtype, + ) + img_rows = await search_document_images_similar( + query_image_embedding, limit=fetch_k, case_id=case_id, + practice_area=practice_area, appeal_subtype=appeal_subtype, + ) + img_by_page: dict[tuple, dict] = { + (str(r["document_id"]), r["page_number"]): r for r in img_rows + } + seen: set = set() + merged: list[dict] = [] + for r in text_rows: + page = r.get("page_number") + key = (str(r["document_id"]), page) if page is not None else None + img_hit = img_by_page.get(key) if key else None + text_score = float(r["score"]) + image_score = float(img_hit["score"]) if img_hit else 0.0 + d = dict(r) + d["text_score"] = text_score + d["image_score"] = image_score + d["score"] = text_score * text_weight + image_score * img_weight + d["match_type"] = "text+image" if img_hit else "text" + if img_hit: + d["image_thumbnail_path"] = img_hit.get("image_thumbnail_path") + merged.append(d) + if key: + seen.add(key) + for r in img_rows: + key = (str(r["document_id"]), r["page_number"]) + if key in seen: + continue + d = dict(r) + d["text_score"] = 0.0 + d["image_score"] = float(r["score"]) + d["score"] = float(r["score"]) * img_weight + d["match_type"] = "image" + d["content"] = "" + d["section_type"] = "image" + merged.append(d) + merged.sort(key=lambda x: -x["score"]) + return merged[:limit] + + +async def search_precedent_library_hybrid( + query_text_embedding: list[float], + query_image_embedding: list[float], + limit: int = 10, + fetch_k: int = 30, + text_weight: float = 0.65, + practice_area: str = "", + court: str = "", + precedent_level: str = "", + appeal_subtype: str = "", + is_binding: bool | None = None, + subject_tag: str = "", + include_halachot: bool = True, +) -> list[dict]: + """Hybrid variant of search_precedent_library_semantic. + + Halachot have no ``page_number`` — they're boosted by the max + image score from any page in the same case_law row. + """ + img_weight = 1.0 - text_weight + text_results = await search_precedent_library_semantic( + query_text_embedding, + practice_area=practice_area, court=court, + precedent_level=precedent_level, appeal_subtype=appeal_subtype, + is_binding=is_binding, subject_tag=subject_tag, + limit=fetch_k, include_halachot=include_halachot, + ) + img_results = await search_precedent_images_similar( + query_image_embedding, limit=fetch_k, + practice_area=practice_area, court=court, + precedent_level=precedent_level, appeal_subtype=appeal_subtype, + is_binding=is_binding, + ) + img_by_page: dict[tuple, dict] = {} + img_by_case: dict[str, float] = {} + for r in img_results: + cid = str(r["case_law_id"]) + img_by_page[(cid, r["page_number"])] = r + img_by_case[cid] = max(img_by_case.get(cid, 0.0), float(r["score"])) + seen: set = set() + merged: list[dict] = [] + for r in text_results: + cid = str(r["case_law_id"]) + page = r.get("page_number") + key = (cid, page) if page is not None else None + img_hit = img_by_page.get(key) if key else None + if img_hit: + image_score = float(img_hit["score"]) + elif r.get("type") == "halacha": + image_score = img_by_case.get(cid, 0.0) + else: + image_score = 0.0 + text_score = float(r["score"]) + d = dict(r) + d["text_score"] = text_score + d["image_score"] = image_score + d["score"] = text_score * text_weight + image_score * img_weight + if img_hit: + d["image_thumbnail_path"] = img_hit.get("image_thumbnail_path") + if key: + seen.add(key) + merged.append(d) + for r in img_results: + key = (str(r["case_law_id"]), r["page_number"]) + if key in seen: + continue + d = dict(r) + d["text_score"] = 0.0 + d["image_score"] = float(r["score"]) + d["score"] = float(r["score"]) * img_weight + d["type"] = "image_page" + d["content"] = "" + d["section_type"] = "image" + merged.append(d) + merged.sort(key=lambda x: -x["score"]) + return merged[:limit] diff --git a/mcp-server/src/legal_mcp/services/embeddings.py b/mcp-server/src/legal_mcp/services/embeddings.py index 9676d88..4a6aa1c 100644 --- a/mcp-server/src/legal_mcp/services/embeddings.py +++ b/mcp-server/src/legal_mcp/services/embeddings.py @@ -3,15 +3,24 @@ from __future__ import annotations import logging +from typing import TYPE_CHECKING import voyageai from legal_mcp import config +if TYPE_CHECKING: + from PIL import Image as PILImage + logger = logging.getLogger(__name__) _client: voyageai.Client | None = None +# Per-call cap for multimodal_embed. POC ran 89 pages (~312K tokens) +# in a single call comfortably; 50 leaves safe headroom for densely- +# OCR'd legal pages where tokens/page can exceed 4K. +_MULTIMODAL_BATCH_SIZE = 50 + def _get_client() -> voyageai.Client: global _client @@ -55,6 +64,45 @@ async def embed_query(query: str) -> list[float]: return results[0] +async def embed_images( + images: "list[PILImage.Image]", + input_type: str = "document", +) -> list[list[float]]: + """Embed page images via voyage-multimodal-3. + + Each input is a single PIL.Image (one page = one embedding). + Returns a list of 1024-dim vectors, one per input image, in order. + Batches at ``_MULTIMODAL_BATCH_SIZE`` to stay within Voyage's + per-request limits on dense legal pages. + """ + if not images: + return [] + client = _get_client() + out: list[list[float]] = [] + for i in range(0, len(images), _MULTIMODAL_BATCH_SIZE): + batch = images[i : i + _MULTIMODAL_BATCH_SIZE] + result = client.multimodal_embed( + inputs=[[img] for img in batch], + model=config.MULTIMODAL_MODEL, + input_type=input_type, + truncation=True, + ) + out.extend(result.embeddings) + return out + + +async def embed_query_for_multimodal(query: str) -> list[float]: + """Embed a text query in the multimodal vector space, so it can be + cosine-compared against page-image embeddings.""" + client = _get_client() + result = client.multimodal_embed( + inputs=[[query]], + model=config.MULTIMODAL_MODEL, + input_type="query", + ) + return result.embeddings[0] + + async def voyage_rerank( query: str, documents: list[str], top_k: int | None = None, ) -> list[tuple[int, float]]: diff --git a/mcp-server/src/legal_mcp/services/extractor.py b/mcp-server/src/legal_mcp/services/extractor.py index 42691ca..9d08100 100644 --- a/mcp-server/src/legal_mcp/services/extractor.py +++ b/mcp-server/src/legal_mcp/services/extractor.py @@ -9,6 +9,7 @@ Post-processing: Hebrew abbreviation quote fixer. from __future__ import annotations import asyncio +import io import logging import re import subprocess @@ -16,6 +17,7 @@ import tempfile from pathlib import Path import fitz # PyMuPDF +from PIL import Image from docx import Document as DocxDocument from google.cloud import vision from striprtf.striprtf import rtf_to_text @@ -220,6 +222,65 @@ def _extract_rtf(path: Path) -> str: return rtf_to_text(rtf_content) +# ── Multimodal page rendering (V9) ─────────────────────────────── + + +def _pixmap_to_pil(pix: fitz.Pixmap) -> Image.Image: + """Convert a PyMuPDF pixmap to PIL.Image (RGB) without going through + PNG bytes. Faster than tobytes('png') → Image.open().""" + if pix.alpha: + # Drop alpha channel — voyage multimodal expects RGB. + pix = fitz.Pixmap(pix, 0) + return Image.frombytes("RGB", (pix.width, pix.height), pix.samples) + + +def render_pages_for_multimodal( + pdf_path: str | Path, + embed_dpi: int, + thumb_dpi: int | None = None, + thumbnail_dir: Path | None = None, +) -> list[tuple[Image.Image, Path | None]]: + """Render each PDF page as PIL.Image at ``embed_dpi`` for the + multimodal embedder, and optionally save a smaller JPEG thumbnail + at ``thumb_dpi`` to ``thumbnail_dir`` for UI preview. + + Returns ``[(pil_image, thumb_path_or_None), ...]`` in page order. + The full-DPI image stays in memory only — only the thumbnail is + persisted to disk. + """ + src = Path(pdf_path) + if not src.is_file(): + raise FileNotFoundError(f"PDF not found: {src}") + if thumbnail_dir is not None: + thumbnail_dir.mkdir(parents=True, exist_ok=True) + + out: list[tuple[Image.Image, Path | None]] = [] + doc = fitz.open(str(src)) + try: + for page_idx, page in enumerate(doc): + page_num = page_idx + 1 + pix = page.get_pixmap(dpi=embed_dpi) + img = _pixmap_to_pil(pix) + + thumb_path: Path | None = None + if thumbnail_dir is not None and thumb_dpi: + thumb_path = thumbnail_dir / f"p{page_num:03d}.jpg" + # Downsample the same render rather than re-rendering + # with PyMuPDF — far faster. + ratio = thumb_dpi / embed_dpi + thumb_size = ( + max(1, int(img.width * ratio)), + max(1, int(img.height * ratio)), + ) + thumb = img.resize(thumb_size, Image.Resampling.LANCZOS) + thumb.save(thumb_path, "JPEG", quality=75, optimize=True) + + out.append((img, thumb_path)) + finally: + doc.close() + return out + + # ── Nevo preamble stripping ────────────────────────────────────── _NEVO_MARKERS = ("ספרות:", "חקיקה שאוזכרה:", "מיני-רציו:", "פסקי דין שאוזכרו:", diff --git a/mcp-server/src/legal_mcp/services/hybrid_search.py b/mcp-server/src/legal_mcp/services/hybrid_search.py new file mode 100644 index 0000000..30f4df4 --- /dev/null +++ b/mcp-server/src/legal_mcp/services/hybrid_search.py @@ -0,0 +1,202 @@ +"""Hybrid (text + image) search wrappers. + +Layered on top of ``rerank.maybe_rerank``. When ``MULTIMODAL_ENABLED`` is +true the result comes from a weighted merge of: + + • text side: cosine on chunks → optional rerank-2 cross-encoder + • image side: cosine on per-page voyage-multimodal-3 embeddings + +rerank-2 is a *text* cross-encoder, so image-side rows are NOT passed +through it; they keep their cosine score and merge alongside the +(possibly reranked) text rows. Image-only pages with no overlapping +text chunk are surfaced as ``match_type='image'`` so scanned-only or +visual-heavy content still appears in results. + +When ``MULTIMODAL_ENABLED`` is false this module degenerates to plain +``rerank.maybe_rerank`` — callers can wrap unconditionally and let env +control behaviour. +""" +from __future__ import annotations + +import logging +from typing import Any +from uuid import UUID + +from legal_mcp import config +from legal_mcp.services import db, embeddings, rerank + +logger = logging.getLogger(__name__) + + +async def search_documents_hybrid( + query: str, + query_text_embedding: list[float], + *, + limit: int, + case_id: UUID | None = None, + section_type: str | None = None, + practice_area: str | None = None, + appeal_subtype: str | None = None, +) -> list[dict]: + """Hybrid wrapper for document-chunk search (search_decisions / + search_case_documents / find_similar_cases).""" + fetch_k = max(limit, config.VOYAGE_RERANK_FETCH_K) if config.MULTIMODAL_ENABLED else limit + text_results = await rerank.maybe_rerank( + query=query, + base_search=lambda **kw: db.search_similar( + query_embedding=query_text_embedding, **kw, + ), + limit=fetch_k, + case_id=case_id, + section_type=section_type, + practice_area=practice_area, + appeal_subtype=appeal_subtype, + ) + if not config.MULTIMODAL_ENABLED: + return text_results[:limit] + + try: + query_img_emb = await embeddings.embed_query_for_multimodal(query) + img_rows = await db.search_document_images_similar( + query_img_emb, + limit=fetch_k, + case_id=case_id, + practice_area=practice_area, + appeal_subtype=appeal_subtype, + ) + except Exception as e: + logger.warning("Hybrid: image side failed, returning text only: %s", e) + return text_results[:limit] + + merged = _merge( + text_results, img_rows, + id_field="document_id", + text_weight=config.MULTIMODAL_TEXT_WEIGHT, + ) + return merged[:limit] + + +async def search_precedent_library_hybrid( + query: str, + query_text_embedding: list[float], + *, + limit: int, + practice_area: str = "", + court: str = "", + precedent_level: str = "", + appeal_subtype: str = "", + is_binding: bool | None = None, + subject_tag: str = "", + include_halachot: bool = True, +) -> list[dict]: + """Hybrid wrapper for precedent-library search.""" + fetch_k = max(limit, config.VOYAGE_RERANK_FETCH_K) if config.MULTIMODAL_ENABLED else limit + + async def _base(limit_inner: int) -> list[dict]: + return await db.search_precedent_library_semantic( + query_embedding=query_text_embedding, + practice_area=practice_area, + court=court, + precedent_level=precedent_level, + appeal_subtype=appeal_subtype, + is_binding=is_binding, + subject_tag=subject_tag, + limit=limit_inner, + include_halachot=include_halachot, + ) + + text_results = await rerank.maybe_rerank( + query=query, base_search=_base, limit=fetch_k, + ) + if not config.MULTIMODAL_ENABLED: + return text_results[:limit] + + try: + query_img_emb = await embeddings.embed_query_for_multimodal(query) + img_rows = await db.search_precedent_images_similar( + query_img_emb, + limit=fetch_k, + practice_area=practice_area, + court=court, + precedent_level=precedent_level, + appeal_subtype=appeal_subtype, + is_binding=is_binding, + ) + except Exception as e: + logger.warning("Hybrid: image side failed, returning text only: %s", e) + return text_results[:limit] + + merged = _merge( + text_results, img_rows, + id_field="case_law_id", + text_weight=config.MULTIMODAL_TEXT_WEIGHT, + ) + return merged[:limit] + + +def _merge( + text_rows: list[dict], + img_rows: list[dict], + id_field: str, + text_weight: float, +) -> list[dict]: + """Weighted merge of text + image rows. + + Joins on ``(id_field, page_number)``. Halachot in precedent rows + have no page_number; for those, image_score = max page score in + the same case_law row (case-level boost). + + Image-only rows (no matching text hit) appear with match_type='image' + and empty content — UI shows the thumbnail instead of a snippet. + """ + img_weight = 1.0 - text_weight + img_by_key: dict[tuple, dict] = {} + img_max_by_id: dict[str, float] = {} + for r in img_rows: + rid = str(r[id_field]) + page = r.get("page_number") + img_by_key[(rid, page)] = r + score = float(r.get("score", 0.0)) + img_max_by_id[rid] = max(img_max_by_id.get(rid, 0.0), score) + + seen: set = set() + merged: list[dict] = [] + for r in text_rows: + rid = str(r[id_field]) + page = r.get("page_number") + key = (rid, page) if page is not None else None + img_hit = img_by_key.get(key) if key else None + text_score = float(r.get("score", 0.0)) + if img_hit: + image_score = float(img_hit["score"]) + elif r.get("type") == "halacha": + image_score = img_max_by_id.get(rid, 0.0) + else: + image_score = 0.0 + d = dict(r) + d["text_score"] = text_score + d["image_score"] = image_score + d["score"] = text_score * text_weight + image_score * img_weight + d["match_type"] = "text+image" if img_hit else "text" + if img_hit: + d["image_thumbnail_path"] = img_hit.get("image_thumbnail_path") + if key: + seen.add(key) + merged.append(d) + + for r in img_rows: + rid = str(r[id_field]) + key = (rid, r.get("page_number")) + if key in seen: + continue + d = dict(r) + d["text_score"] = 0.0 + d["image_score"] = float(r.get("score", 0.0)) + d["score"] = float(r.get("score", 0.0)) * img_weight + d["match_type"] = "image" + d["content"] = "" + d["section_type"] = "image" + merged.append(d) + + merged.sort(key=lambda x: -float(x["score"])) + return merged diff --git a/mcp-server/src/legal_mcp/services/precedent_library.py b/mcp-server/src/legal_mcp/services/precedent_library.py index 995e997..89e6af3 100644 --- a/mcp-server/src/legal_mcp/services/precedent_library.py +++ b/mcp-server/src/legal_mcp/services/precedent_library.py @@ -13,6 +13,7 @@ SSE plumbing without this module knowing about Redis. from __future__ import annotations +import asyncio import logging import re import shutil @@ -22,7 +23,7 @@ from typing import Awaitable, Callable from uuid import UUID, uuid4 from legal_mcp import config -from legal_mcp.services import chunker, db, embeddings, extractor, rerank +from legal_mcp.services import chunker, db, embeddings, extractor, hybrid_search, rerank # noqa: F401 # Note: halacha_extractor and precedent_metadata_extractor are NOT imported # at module load. They are imported lazily inside the dedicated re-extract @@ -188,6 +189,18 @@ async def ingest_precedent( ] stored_chunks = await db.store_precedent_chunks(case_law_id, chunk_dicts) + # Multimodal page-image embeddings (V9). Gated by feature flag. + # Non-fatal: text path already succeeded. Only PDFs. + if config.MULTIMODAL_ENABLED and page_count > 0 and staged.suffix.lower() == ".pdf": + try: + await progress( + "embedding_images", 70, + f"מטמיע {page_count} עמודי תמונה (multimodal)", + ) + await _embed_precedent_pages(case_law_id, staged, page_count) + except Exception as e: + logger.warning("Precedent multimodal embedding failed (non-fatal): %s", e) + # Pipeline split: the container does the non-LLM half (extract + # chunk + embed + store). LLM-driven extraction (metadata, halachot) # runs separately via the MCP tool `precedent_process_pending` from @@ -413,19 +426,60 @@ async def search_library( return [] query_vec = await embeddings.embed_query(query) - async def _base(limit: int) -> list[dict]: - return await db.search_precedent_library_semantic( - query_embedding=query_vec, - practice_area=practice_area, - court=court, - precedent_level=precedent_level, - appeal_subtype=appeal_subtype, - is_binding=is_binding, - subject_tag=subject_tag, - limit=limit, - include_halachot=include_halachot, - ) - - return await rerank.maybe_rerank( - query=query, base_search=_base, limit=limit, + return await hybrid_search.search_precedent_library_hybrid( + query=query, + query_text_embedding=query_vec, + limit=limit, + practice_area=practice_area, + court=court, + precedent_level=precedent_level, + appeal_subtype=appeal_subtype, + is_binding=is_binding, + subject_tag=subject_tag, + include_halachot=include_halachot, ) + + +async def _embed_precedent_pages( + case_law_id: UUID, + pdf_path: Path, + page_count: int, +) -> dict: + """Render precedent PDF pages → embed via voyage-multimodal → store. + + Thumbnails go to + ``data/precedent-library/thumbnails/{case_law_id}/p{N:03d}.jpg``. + """ + thumb_dir = PRECEDENT_LIBRARY_DIR / "thumbnails" / str(case_law_id) + rendered = await asyncio.to_thread( + extractor.render_pages_for_multimodal, + pdf_path, + config.MULTIMODAL_DPI, + config.MULTIMODAL_THUMB_DPI, + thumb_dir, + ) + images = [pil for pil, _ in rendered] + thumbs = [t for _, t in rendered] + img_embs = await embeddings.embed_images(images) + + page_records = [] + for i, (emb, thumb) in enumerate(zip(img_embs, thumbs)): + rel_thumb = None + if thumb is not None: + try: + rel_thumb = str(thumb.relative_to(config.DATA_DIR)) + except ValueError: + rel_thumb = str(thumb) + page_records.append({ + "page_number": i + 1, + "embedding": emb, + "image_thumbnail_path": rel_thumb, + }) + stored = await db.store_precedent_image_embeddings( + case_law_id, page_records, model_name=config.MULTIMODAL_MODEL, + ) + logger.info( + "Multimodal: stored %d page-image embeddings for case_law %s", + stored, case_law_id, + ) + return {"pages_embedded": stored} diff --git a/mcp-server/src/legal_mcp/services/processor.py b/mcp-server/src/legal_mcp/services/processor.py index 3bd2b3a..86228d9 100644 --- a/mcp-server/src/legal_mcp/services/processor.py +++ b/mcp-server/src/legal_mcp/services/processor.py @@ -2,10 +2,12 @@ from __future__ import annotations +import asyncio import logging from pathlib import Path from uuid import UUID +from legal_mcp import config from legal_mcp.services import chunker, db, embeddings, extractor, references_extractor logger = logging.getLogger(__name__) @@ -95,6 +97,21 @@ async def process_document(document_id: UUID, case_id: UUID) -> dict: stored = await db.store_chunks(document_id, case_id, chunk_dicts) + # Step 4.5: Multimodal page-image embeddings (V9). Gated by + # MULTIMODAL_ENABLED. Renders each PDF page → embeds via + # voyage-multimodal-3 → stores per-page row with thumbnail. + # Non-fatal on failure (text path already succeeded). + multimodal_result = {"pages_embedded": 0} + if config.MULTIMODAL_ENABLED and page_count > 0: + try: + pdf_path = Path(doc["file_path"]) + if pdf_path.suffix.lower() == ".pdf": + multimodal_result = await _embed_document_pages( + document_id, case_id, pdf_path, page_count, + ) + except Exception as e: + logger.warning("Multimodal embedding failed (non-fatal): %s", e) + # Step 5: Extract references (plans, case law, legislation) — non-fatal refs_result = {"plans": 0, "case_law": 0, "case_law_linked": 0, "legislation": 0} try: @@ -124,9 +141,63 @@ async def process_document(document_id: UUID, case_id: UUID) -> dict: "case_law": refs_result["case_law"], "legislation": refs_result["legislation"], }, + "multimodal": multimodal_result, } except Exception as e: logger.exception("Document processing failed: %s", e) await db.update_document(document_id, extraction_status="failed") return {"status": "failed", "error": str(e)} + + +async def _embed_document_pages( + document_id: UUID, + case_id: UUID, + pdf_path: Path, + page_count: int, +) -> dict: + """Render PDF pages → embed via voyage-multimodal → store per-page rows. + + Thumbnails are saved under + ``data/cases/{case_number}/thumbnails/{document_id}/p{N:03d}.jpg`` + so the UI can show small previews next to image-side search hits. + """ + # Layout: data/cases/{case_number}/documents/originals/{file}.pdf + # → case_dir = pdf_path.parent.parent.parent + case_dir = pdf_path.parent.parent.parent + thumb_dir = case_dir / "thumbnails" / str(document_id) + + logger.info("Multimodal: rendering %d pages @ %ddpi", page_count, config.MULTIMODAL_DPI) + rendered = await asyncio.to_thread( + extractor.render_pages_for_multimodal, + pdf_path, + config.MULTIMODAL_DPI, + config.MULTIMODAL_THUMB_DPI, + thumb_dir, + ) + images = [pil for pil, _ in rendered] + thumb_paths = [thumb for _, thumb in rendered] + + logger.info("Multimodal: embedding %d pages via %s", len(images), config.MULTIMODAL_MODEL) + img_embs = await embeddings.embed_images(images) + + page_records = [] + for i, (emb, thumb) in enumerate(zip(img_embs, thumb_paths)): + rel_thumb = None + if thumb is not None: + try: + rel_thumb = str(thumb.relative_to(config.DATA_DIR)) + except ValueError: + rel_thumb = str(thumb) + page_records.append({ + "page_number": i + 1, + "embedding": emb, + "image_thumbnail_path": rel_thumb, + }) + + stored = await db.store_document_image_embeddings( + document_id, case_id, page_records, + model_name=config.MULTIMODAL_MODEL, + ) + logger.info("Multimodal: stored %d page-image embeddings", stored) + return {"pages_embedded": stored, "model": config.MULTIMODAL_MODEL} diff --git a/mcp-server/src/legal_mcp/tools/search.py b/mcp-server/src/legal_mcp/tools/search.py index c7b1bed..0805884 100644 --- a/mcp-server/src/legal_mcp/tools/search.py +++ b/mcp-server/src/legal_mcp/tools/search.py @@ -6,7 +6,7 @@ import json import logging from uuid import UUID -from legal_mcp.services import db, embeddings, rerank +from legal_mcp.services import db, embeddings, hybrid_search logger = logging.getLogger(__name__) @@ -43,9 +43,9 @@ async def search_decisions( ) query_emb = await embeddings.embed_query(query) - results = await rerank.maybe_rerank( + results = await hybrid_search.search_documents_hybrid( query=query, - base_search=lambda **kw: db.search_similar(query_embedding=query_emb, **kw), + query_text_embedding=query_emb, limit=limit, section_type=section_type or None, practice_area=practice_area or None, @@ -59,11 +59,13 @@ async def search_decisions( for r in results: formatted.append({ "score": round(float(r["score"]), 4), - "case_number": r["case_number"], - "document": r["document_title"], - "section": r["section_type"], - "page": r["page_number"], - "content": r["content"], + "case_number": r.get("case_number"), + "document": r.get("document_title"), + "section": r.get("section_type"), + "page": r.get("page_number"), + "content": r.get("content", ""), + "match_type": r.get("match_type", "text"), + "image_thumbnail": r.get("image_thumbnail_path"), }) return json.dumps(formatted, ensure_ascii=False, indent=2) @@ -87,9 +89,9 @@ async def search_case_documents( query_emb = await embeddings.embed_query(query) # Restricted to case_id — practice_area filter would be redundant. - results = await rerank.maybe_rerank( + results = await hybrid_search.search_documents_hybrid( query=query, - base_search=lambda **kw: db.search_similar(query_embedding=query_emb, **kw), + query_text_embedding=query_emb, limit=limit, case_id=UUID(case["id"]), ) @@ -101,10 +103,12 @@ async def search_case_documents( for r in results: formatted.append({ "score": round(float(r["score"]), 4), - "document": r["document_title"], - "section": r["section_type"], - "page": r["page_number"], - "content": r["content"], + "document": r.get("document_title"), + "section": r.get("section_type"), + "page": r.get("page_number"), + "content": r.get("content", ""), + "match_type": r.get("match_type", "text"), + "image_thumbnail": r.get("image_thumbnail_path"), }) return json.dumps(formatted, ensure_ascii=False, indent=2) @@ -139,12 +143,11 @@ async def find_similar_cases( ) query_emb = await embeddings.embed_query(description) - # Use description as the query text for rerank too. - # Note: even with rerank we ask for ``limit*3`` so the dedup-by-case + # Even with rerank we ask for ``limit*3`` so the dedup-by-case # step downstream still has enough rows to pick the best per case. - results = await rerank.maybe_rerank( + results = await hybrid_search.search_documents_hybrid( query=description, - base_search=lambda **kw: db.search_similar(query_embedding=query_emb, **kw), + query_text_embedding=query_emb, limit=limit * 3, practice_area=practice_area or None, appeal_subtype=appeal_subtype or None, @@ -153,14 +156,16 @@ async def find_similar_cases( if not results: return "לא נמצאו תיקים דומים." - # Deduplicate by case_number, keep best score per case + # Deduplicate by case_number, keep best score per case. + # image-only rows still carry case_number from the join. seen_cases = {} for r in results: - cn = r["case_number"] + cn = r.get("case_number") + if not cn: + continue if cn not in seen_cases or r["score"] > seen_cases[cn]["score"]: seen_cases[cn] = r - # Sort by score and limit top_cases = sorted(seen_cases.values(), key=lambda x: x["score"], reverse=True)[:limit] formatted = [] @@ -168,8 +173,9 @@ async def find_similar_cases( formatted.append({ "score": round(float(r["score"]), 4), "case_number": r["case_number"], - "document": r["document_title"], - "relevant_section": r["content"][:500], + "document": r.get("document_title"), + "relevant_section": (r.get("content") or "")[:500], + "match_type": r.get("match_type", "text"), }) return json.dumps(formatted, ensure_ascii=False, indent=2) diff --git a/scripts/SCRIPTS.md b/scripts/SCRIPTS.md index 47a0e7e..6297a81 100644 --- a/scripts/SCRIPTS.md +++ b/scripts/SCRIPTS.md @@ -22,6 +22,7 @@ | `voyage_multimodal_poc.py` | python | POC #3 — voyage-multimodal-3 על דוח שמאי (89 עמודים). הכרעה: שיפור משמעותי לטבלאות + 22 עמודי image-only שhttp text-OCR מאבד | בנצ'מרק חד-פעמי, מוכן לשלב C | | `voyage_rerank_judge_poc.py` | python | POC #4 — voyage-3 vs rerank-2 vs context-3 על אהרון ברק, 18 שאילתות, claude-haiku-4-5 כ-judge. הכרעה: rerank-2 ניצח עם +9% mean@3 | בנצ'מרק חד-פעמי | | `voyage_rerank_corpus_poc.py` | python | POC #5 — voyage-3 vs rerank-2 על קורפוס מלא (785 docs). הכרעה: +4.5% mean@3 כללי, +11.6% על P queries (practical) | בנצ'מרק חד-פעמי, אישר את שלב B | +| `multimodal_backfill.py` | python | Backfill voyage-multimodal-3 page embeddings על מסמכי תיקים קיימים. idempotent (skips by default), forces `MULTIMODAL_ENABLED=true` ל-run, רץ מהקונטיינר. שלב C — ראה `docs/voyage-upgrades-plan.md` | ידני per-case (`python multimodal_backfill.py 8174-24 8137-24`) | ## תיקיית `.archive/` — סקריפטים שהושלמו diff --git a/scripts/multimodal_backfill.py b/scripts/multimodal_backfill.py new file mode 100644 index 0000000..86d89fc --- /dev/null +++ b/scripts/multimodal_backfill.py @@ -0,0 +1,186 @@ +"""Multimodal backfill — embed page images for existing case documents. + +Iterates over documents already in the DB and renders + embeds + stores +per-page voyage-multimodal-3 vectors. Skips documents that already have +image embeddings (idempotent). + +Independent of the processor pipeline — does NOT re-extract text or +re-chunk; only the multimodal step. + +Designed to run from inside the FastAPI/MCP container (where /data is +mounted and writable). Locally it requires sudo for the thumbnails dir +under /home/chaim/legal-ai/data/cases/... + +Usage:: + + # In container (Coolify): + docker exec -it python -m legal_mcp.cli \\ + multimodal_backfill --cases 8174-24 8137-24 + + # Or as a script (sets MULTIMODAL_ENABLED=true automatically): + /opt/api/mcp-server/.venv/bin/python /opt/api/scripts/multimodal_backfill.py 8174-24 8137-24 +""" +from __future__ import annotations + +import argparse +import asyncio +import logging +import os +import sys +import time +from pathlib import Path +from uuid import UUID + + +def _setup_paths(): + """Ensure mcp-server src is on path even when run as a standalone script.""" + here = Path(__file__).resolve().parent + mcp_src = here.parent / "mcp-server" / "src" + if mcp_src.is_dir() and str(mcp_src) not in sys.path: + sys.path.insert(0, str(mcp_src)) + + +_setup_paths() +# Force the flag on for this run regardless of env — backfill is the +# whole point of running this script. The deploy-time default stays off. +os.environ["MULTIMODAL_ENABLED"] = "true" + +from legal_mcp import config # noqa: E402 +from legal_mcp.services import db, embeddings, extractor, processor # noqa: E402 + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", +) +logger = logging.getLogger("multimodal_backfill") + + +def _resolve_local_path(db_path: str) -> Path: + """Map container path /data/... to host /home/chaim/legal-ai/data/... + when running locally; pass-through when already absolute and present.""" + p = Path(db_path) + if p.is_file(): + return p + if str(p).startswith("/data/"): + local = Path("/home/chaim/legal-ai") / Path(*p.parts[1:]) + if local.is_file(): + return local + return p + + +async def _backfill_document( + document_id: UUID, + case_id: UUID, + title: str, + db_file_path: str, + skip_if_exists: bool, +) -> dict: + pool = await db.get_pool() + if skip_if_exists: + existing = await pool.fetchval( + "SELECT count(*) FROM document_image_embeddings WHERE document_id = $1", + document_id, + ) + if existing and existing > 0: + logger.info(" skip (%d rows already): %s", existing, title) + return {"status": "skipped", "rows": int(existing)} + + pdf_path = _resolve_local_path(db_file_path) + if not pdf_path.is_file(): + logger.warning(" file missing: %s (%s)", pdf_path, title) + return {"status": "missing"} + if pdf_path.suffix.lower() != ".pdf": + logger.info(" not a PDF, skipping: %s", title) + return {"status": "not_pdf"} + + page_count = await pool.fetchval( + "SELECT page_count FROM documents WHERE id = $1", document_id, + ) + if not page_count: + # Open to count + import fitz + d = fitz.open(str(pdf_path)) + page_count = len(d) + d.close() + + logger.info(" embedding %s (%d pages)", title, page_count) + t0 = time.time() + result = await processor._embed_document_pages( + document_id, case_id, pdf_path, page_count, + ) + elapsed = time.time() - t0 + logger.info(" done in %.1fs: %s", elapsed, result) + return {"status": "ok", "elapsed_sec": round(elapsed, 1), **result} + + +async def backfill_cases(case_numbers: list[str], skip_if_exists: bool = True) -> dict: + """Embed page images for every PDF document in the given cases.""" + await db.init_schema() # in case schema V9 hasn't been applied + pool = await db.get_pool() + summary: dict = {} + for cn in case_numbers: + logger.info("=" * 60) + logger.info("Case %s", cn) + case = await db.get_case_by_number(cn) + if not case: + logger.warning("Case not found: %s", cn) + summary[cn] = {"status": "case_not_found"} + continue + case_id = UUID(str(case["id"])) + docs = await pool.fetch( + "SELECT id, title, file_path FROM documents WHERE case_id = $1 ORDER BY title", + case_id, + ) + logger.info(" %d documents", len(docs)) + per_doc: list[dict] = [] + for d in docs: + doc_id = UUID(str(d["id"])) + title = d["title"] + r = await _backfill_document( + doc_id, case_id, title, d["file_path"], skip_if_exists, + ) + per_doc.append({"document_id": str(doc_id), "title": title, **r}) + summary[cn] = { + "documents_total": len(docs), + "embedded": sum(1 for r in per_doc if r["status"] == "ok"), + "skipped": sum(1 for r in per_doc if r["status"] == "skipped"), + "missing": sum(1 for r in per_doc if r["status"] == "missing"), + "not_pdf": sum(1 for r in per_doc if r["status"] == "not_pdf"), + "documents": per_doc, + } + return summary + + +def main(): + parser = argparse.ArgumentParser(description="Multimodal backfill for case documents") + parser.add_argument( + "cases", nargs="+", help="Case numbers to backfill (e.g. 8174-24 8137-24)" + ) + parser.add_argument( + "--re-embed", action="store_true", + help="Re-embed even if image embeddings already exist (default: skip)", + ) + args = parser.parse_args() + + logger.info("MULTIMODAL_MODEL=%s DPI=%d THUMB_DPI=%d", + config.MULTIMODAL_MODEL, config.MULTIMODAL_DPI, config.MULTIMODAL_THUMB_DPI) + summary = asyncio.run( + backfill_cases(args.cases, skip_if_exists=not args.re_embed) + ) + print() + print("=" * 60) + print("SUMMARY") + print("=" * 60) + for cn, s in summary.items(): + if s.get("status") == "case_not_found": + print(f" {cn}: NOT FOUND") + continue + print( + f" {cn}: {s['documents_total']} docs — " + f"embedded {s['embedded']}, skipped {s['skipped']}, " + f"missing {s['missing']}, non-pdf {s['not_pdf']}" + ) + + +if __name__ == "__main__": + main()