From 26c3fddf41ae851d67d4658c727f0f3cd5fc70ff Mon Sep 17 00:00:00 2001 From: Chaim Date: Sun, 3 May 2026 18:43:41 +0000 Subject: [PATCH] feat(retrieval): add voyage rerank-2 cross-encoder stage (feature flag) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stage B of voyage-upgrades-plan rewritten: instead of context-3 (which 4 POCs showed inconsistent improvement), add a cross-encoder rerank layer on top of voyage-3. Default off (VOYAGE_RERANK_ENABLED=false). POC validation (785-doc corpus, 12 queries, claude-haiku-4-5 judge): - mean@3 +4.5% (4.306 → 4.500) - practical-category queries +11.6% (3.78 → 4.22) - latency +702ms per query - no schema change, no re-embed, no double storage Plumbing: - config: VOYAGE_RERANK_ENABLED / _MODEL / _FETCH_K env vars - embeddings.voyage_rerank() wraps voyageai client.rerank - services/rerank.py: maybe_rerank() helper — fetches FETCH_K candidates via the bi-encoder then reranks to top-K. Fail-open if Voyage rerank is unavailable. - tools/search.py: search_decisions, search_case_documents, find_similar_cases all wrapped - services/precedent_library.search_library wrapped Smoke-tested locally with flag on/off — produces expected behaviour and latency profile. Ready for production rollout via Coolify env flip after deploy. POCs (kept under scripts/ for reference): - voyage_context3_poc{_long}.py — context-3 evaluation (rejected) - voyage_multimodal_poc.py — multimodal-3 (stage C, deferred) - voyage_rerank_judge_poc.py — single-case rerank benchmark - voyage_rerank_corpus_poc.py — full-corpus rerank validation Co-Authored-By: Claude Opus 4.7 (1M context) --- docs/architecture.md | 10 +- docs/voyage-upgrades-plan.md | 159 ++++---- mcp-server/src/legal_mcp/config.py | 11 + .../src/legal_mcp/services/embeddings.py | 23 ++ .../legal_mcp/services/precedent_library.py | 33 +- mcp-server/src/legal_mcp/services/rerank.py | 103 +++++ mcp-server/src/legal_mcp/tools/search.py | 22 +- scripts/SCRIPTS.md | 5 + scripts/voyage_context3_poc.py | 182 +++++++++ scripts/voyage_context3_poc_long.py | 238 ++++++++++++ scripts/voyage_multimodal_poc.py | 213 +++++++++++ scripts/voyage_rerank_corpus_poc.py | 318 +++++++++++++++ scripts/voyage_rerank_judge_poc.py | 361 ++++++++++++++++++ 13 files changed, 1578 insertions(+), 100 deletions(-) create mode 100644 mcp-server/src/legal_mcp/services/rerank.py create mode 100644 scripts/voyage_context3_poc.py create mode 100644 scripts/voyage_context3_poc_long.py create mode 100644 scripts/voyage_multimodal_poc.py create mode 100644 scripts/voyage_rerank_corpus_poc.py create mode 100644 scripts/voyage_rerank_judge_poc.py diff --git a/docs/architecture.md b/docs/architecture.md index d74bb8a..48a56bf 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -40,7 +40,7 @@ Local (developer machine, pm2): External: ← Claude API (Opus 4.7 for agents) - ← Voyage AI (voyage-3-large, 1024-dim embeddings) + ← Voyage AI (voyage-3, 1024-dim embeddings) ← Infisical (secret management) ← Gmail SMTP (agent notifications) ``` @@ -59,7 +59,7 @@ External: - מפעיל OCR (Google Vision) אם PDF ללא טקסט - מריץ proofreader להסרת artifacts מ-Nevo - מחלץ טקסט ל-`documents.extracted_text` - - מפצל ל-chunks של ~500 מילים, מחשב embeddings (voyage-3-large, 1024D), שומר ב-`document_chunks` + - מפצל ל-chunks של ~500 מילים, מחשב embeddings (voyage-3, 1024D), שומר ב-`document_chunks` 4. סטטוס תיק: `new` → `proofread` ### שלב 2 — ניתוח משפטי (legal-researcher + analyst) @@ -223,7 +223,7 @@ legal-qa מריץ 6 בדיקות איכות: `case_law`, `statutory_provisions`, `transition_phrases`, `lessons_learned`, `style_corpus`, `style_patterns` ### Layer 4: Semantic Search (RAG) -`document_embeddings`, `paragraph_embeddings`, `case_law_embeddings` (pgvector 1024-dim, voyage-3-large) +`document_embeddings`, `paragraph_embeddings`, `case_law_embeddings` (pgvector 1024-dim, voyage-3) ### Layer 5 — Multi-tenancy `companies`, `tag_company_mappings` (appeal_subtype → company_id) @@ -283,7 +283,9 @@ legal-qa מריץ 6 בדיקות איכות: ## טכנולוגיות עיקריות - **Database**: PostgreSQL 15 + pgvector 0.8.1 -- **Embeddings**: Voyage AI (`voyage-3-large`, 1024-dim) +- **Embeddings**: Voyage AI (`voyage-3`, 1024-dim) + cross-encoder rerank (`rerank-2`) + - bi-encoder: voyage-3 לכל chunk (חד-פעמי בעת ingestion) + - cross-encoder: rerank-2 לכל query (top-50 → top-K), feature flag `VOYAGE_RERANK_ENABLED` - **Agents**: Claude Opus 4.7 (via Paperclip pm2) - **DOCX manipulation**: `python-docx` 1.2+ ו-`lxml` 5.2+ (XML surgery) - **Frontend**: Next.js + TanStack Query + Tailwind diff --git a/docs/voyage-upgrades-plan.md b/docs/voyage-upgrades-plan.md index 2119569..042e447 100644 --- a/docs/voyage-upgrades-plan.md +++ b/docs/voyage-upgrades-plan.md @@ -52,109 +52,114 @@ voyage-3 **מנצח כפול** — דירוג מושלם + מרווחים גדו --- -## שלב B — voyage-context-3 (לביצוע בשיחה החדשה) +## שלב B — voyage-rerank-2 (Cross-encoder reranking) -### הבעיה שהוא פותר +> **שינוי מהותי מהתכנית המקורית.** המקור היה ל-context-3. POC רחב +> (4 בנצ'מרקים) הראה ש-context-3 לא משפר עקבית, ובחלק מהמקרים מציג +> רגרסיה. במקום זאת, **rerank-2** (cross-encoder) הצליח לתת שיפור של +> +4.5% mean@3 על קורפוס מלא של 785 docs, **+11.6% על שאילתות +> מעשיות** (P-category — בדיוק התרחיש של legal-writer/legal-researcher), +> בלי שינוי schema, בלי re-embed, ובלי double storage. -Embeddings רגילים מטמיעים **chunk בנפרד** מהקשר המסמך. פסקה שאומרת -"כפי שקבענו לעיל, הפטור אינו חל" — לא יודעת על "פטור ממה" / "לעיל -איפה". פסיקה משפטית מלאה בהפניות הקשר תלויות (ראה סעיף 7 לעיל; להבדיל -מהמקרה ב-וע 1126/25; וכו') — והן אבודות לחלוטין. +### למה rerank-2 ולא context-3? -### מה voyage-context-3 עושה אחרת +POC #4 (אהרון ברק, 18 שאילתות, claude-haiku-4-5 כ-judge): -API שונה: `client.contextualized_embed(inputs=[[full_doc, chunk_1], ...])`. -כל chunk מוטמע **עם המסמך כולו כקונטקסט**. ה-embedding "יודע" שזו פסקה -14 מתוך פסק דין על תמ"א 38 — והקשרים פנימיים נשמרים. +| Retriever | mean@3 | mean@5 | MRR | +|---|---|---|---| +| voyage-3 (baseline) | 3.278 | 3.300 | 0.741 | +| **voyage-3 + rerank-2** | **3.574** | **3.467** | **0.769** | +| voyage-context-3 (windowed) | 3.481 | 3.378 | 0.685 | -Anthropic פרסמו מדידה: **שיפור 49% בדיוק חיפוש** למסמכים משפטיים -ארוכים. +POC #5 (קורפוס מלא 785 docs, 12 שאילתות): + +| Retriever | mean@3 | קטגוריה P (practical) | +|---|---|---| +| voyage-3 | 4.306 | 3.78 | +| **voyage-3 + rerank-2** | **4.500 (+4.5%)** | **4.22 (+11.6%)** | + +context-3 גם נכשל בקטגוריות keyword שהן 60%+ מהשאילתות בפועל אצל דפנה. + +### איך rerank-2 עובד + +Two-stage retrieval: +1. **שלב bi-encoder (כמו היום)**: voyage-3 מטמיע את ה-query, מחזיר + top-50 chunks דרך cosine similarity על `pgvector` (מהיר, ~390ms). +2. **שלב cross-encoder (חדש)**: rerank-2 מקבל `(query, document)` עבור + כל אחד מ-50 הdocuments, ומחזיר ציון רלוונטיות מדויק יותר. + הreranker רואה את ה-query ואת ה-doc ביחד דרך attention מלא, + לעומת bi-encoder שרק מחשב cosine בין שני embeddings בלתי-תלויים. +3. החזרה: top-K (10) המדורגים מחדש. + +**עלות**: +702ms latency (bi-encoder=393ms → +rerank=1095ms). +**עלות tokens**: zero לאחסון (רק חישוב per-query). ### תכנית יישום -#### B.1 — Refactor של pipeline ה-ingestion +#### B.1 — `voyage_rerank()` ב-`embeddings.py` -קוד נוכחי (`embeddings.py`): ```python -embs = await embed_texts(chunk_texts, input_type="document") +async def voyage_rerank( + query: str, documents: list[str], top_k: int = 10, +) -> list[tuple[int, float]]: + """Cross-encoder rerank via Voyage. Returns [(orig_index, score), ...].""" + if not documents: + return [] + client = _get_client() + result = client.rerank( + query=query, documents=documents, + model=config.VOYAGE_RERANK_MODEL, # "rerank-2" + top_k=top_k, + ) + return [(r.index, r.relevance_score) for r in result.results] ``` -קוד חדש: +#### B.2 — Feature flag ב-`config.py` + ```python -embs = await embed_texts_with_context( - document_full_text=text, - chunks=chunk_texts, - input_type="document", +VOYAGE_RERANK_MODEL = os.environ.get("VOYAGE_RERANK_MODEL", "rerank-2") +VOYAGE_RERANK_ENABLED = ( + os.environ.get("VOYAGE_RERANK_ENABLED", "false").lower() == "true" ) +VOYAGE_RERANK_FETCH_K = int(os.environ.get("VOYAGE_RERANK_FETCH_K", "50")) ``` -מקומות שצריכים שינוי: -- `mcp-server/.../services/embeddings.py` — פונקציה חדשה `embed_with_context` - שעוטפת `client.contextualized_embed` -- `mcp-server/.../services/processor.py` — `process_document()` מעביר - את `text` המלא + chunks -- `mcp-server/.../services/precedent_library.py` — `ingest_precedent` - מעביר `text` + chunks -- `mcp-server/.../services/halacha_extractor.py` — לכל הלכה, מעביר - את הפסק המלא כקונטקסט (`case_law.full_text`) + `rule_statement` - שמוטמע +הdefault הוא `false` — הקוד יישמר אך לא יורץ עד שיופעל ידנית. -#### B.2 — Query embedding נפרד +#### B.3 — אינטגרציה ב-3 search functions -Queries מטמיעים בלי קונטקסט (`client.embed()` רגיל עם -`model="voyage-context-3"` ו-`input_type="query"`). חשוב: queries -ו-documents חייבים להיות באותו model space. +ב-`db.py`: +- `search_similar` (document_chunks) — נוסיף פרמטר `rerank: bool = False`. + אם True: שולפים top-`VOYAGE_RERANK_FETCH_K` במקום `limit`, + מעבירים דרך rerank, מחזירים top-`limit`. +- `search_precedent_library_semantic` — אותו דבר. הuance: היום יש + boost של +0.05 ל-halachot. כש-rerank פעיל, ה-boost מתבטל ו-rerank + מוחל על המאוחד (chunks + halachot ביחד) — cross-encoder יבחר נכון + בלי boost מלאכותי. +- `search_similar_paragraphs` / `search_similar_case_law` (ב-style + corpus) — אותו דבר. -ב-`embeddings.py:embed_query()` — מחליפים model אבל לא ה-API. - -#### B.3 — Re-embed של הקורפוס הקיים - -```python -# Pseudo-code -for table in [document_chunks, precedent_chunks, halachot, ...]: - rows = SELECT id, content, parent_doc_id FROM table - for row in rows: - full_doc = SELECT full_text FROM parent_table WHERE id = row.parent_doc_id - embedding = contextualized_embed(full_doc, row.content) - UPDATE table SET embedding = embedding WHERE id = row.id -``` - -הבעיה: כל chunk שולח את **המסמך כולו** כקונטקסט. לכן עלות לטוקן -עולה משמעותית. אומדן: 178K תווים × 50 chunks = 8.9M תווים פר פסיקה, -פי-50 לעומת voyage-3. החישוב לקורפוס הנוכחי (~7K rows): שווה ערך -לכ-700M תווים. בtier החינמי של voyage קיים מגבלה — חשוב לבדוק לפני -הרצה גדולה. - -**Mitigation**: לחלץ summary של 500-1000 תווים מכל מסמך (קלוד עושה -את זה היום ב-`metadata_extractor`) ולהעביר ה-summary במקום הטקסט המלא. -שמירת 95% מהיתרון בעלות 5%. +ב-`tools/search.py` — כל הtools (`search_decisions`, `search_case_documents`, +`find_similar_cases`, `precedent_search_library`) יעבירו +`rerank=config.VOYAGE_RERANK_ENABLED` לקריאות ה-DB. #### B.4 — Schema -אין שינוי. אותו `vector(1024)` column. +אין שינוי. אותם vectors, אותו pgvector. -#### B.5 — Benchmark לפני החלטה סופית +#### B.5 — Rollout -לפני re-embed של 6951 rows: -1. לקחת 10 שאילתות אמיתיות + passages עם תיוג נכון -2. להריץ benchmark voyage-3 vs voyage-context-3 (אותו pipeline כמו - `/tmp/voyage_compare.py`) -3. אם השיפור < 15% → לא שווה את העלות. נשאר ב-voyage-3 -4. אם השיפור ≥ 15% → ל-go ל-context-3 +1. שינוי קוד + push + deploy עם feature flag = `false` +2. אימות ש-baseline ממשיך לעבוד (לא רגרסיה) +3. הפעלה ידנית: `VOYAGE_RERANK_ENABLED=true` ב-Coolify env +4. שאילתות אמיתיות מדפנה / סוכנים — observation +5. אם רגרסיה — kill switch בשניות (`false` בחזרה) +6. אם כל מתעקפם — להגדיר `true` כdefault (in-code) אחרי שבוע יציב -#### B.6 — בדיקת זמן + עלות +#### B.6 — Tier check -לאחר ה-benchmark: -- אם בtier החינמי לא מספיק טוקנים → לבחור: רק documents (לא - re-embed הקיים), רק פסיקה חדשה והלאה -- או: לעבור ל-context-3 רק על קורפוס הפסיקה (4 פסיקות, ~785 chunks - + halachot) — הקרפוס הקריטי ביותר ל-`search_precedent_library` - -### החלטות שנשארו פתוחות (תיקח החלטה בשיחה החדשה) - -- ✋ Re-embed הכל בבת-אחת או רק חדש? -- ✋ context-3 לכל הקורפוסים או רק לפסיקה (הקריטי ביותר)? -- ✋ Document context = full_text או summary של 1K? +Voyage Tier 1: 2M TPM, 2000 RPM ל-rerank-2. עומס שלנו (~עשרות +queries בשעה במקרה רגיל) — מתחת ל-1% מהמכסה. --- diff --git a/mcp-server/src/legal_mcp/config.py b/mcp-server/src/legal_mcp/config.py index 5a82648..611c5ba 100644 --- a/mcp-server/src/legal_mcp/config.py +++ b/mcp-server/src/legal_mcp/config.py @@ -47,6 +47,17 @@ VOYAGE_API_KEY = os.environ.get("VOYAGE_API_KEY", "") VOYAGE_MODEL = os.environ.get("VOYAGE_MODEL", "voyage-law-2") VOYAGE_DIMENSIONS = 1024 +# Rerank — cross-encoder second-stage. Off by default; flip with env to +# enable across all semantic search tools (search_decisions, +# search_case_documents, find_similar_cases, search_precedent_library). +VOYAGE_RERANK_MODEL = os.environ.get("VOYAGE_RERANK_MODEL", "rerank-2") +VOYAGE_RERANK_ENABLED = ( + os.environ.get("VOYAGE_RERANK_ENABLED", "false").lower() == "true" +) +# How many candidates to fetch from bi-encoder before reranking. +# 50 was the depth used in the POC; balances recall vs rerank cost. +VOYAGE_RERANK_FETCH_K = int(os.environ.get("VOYAGE_RERANK_FETCH_K", "50")) + # Google Cloud Vision (OCR for scanned PDFs) GOOGLE_CLOUD_VISION_API_KEY = os.environ.get("GOOGLE_CLOUD_VISION_API_KEY", "") diff --git a/mcp-server/src/legal_mcp/services/embeddings.py b/mcp-server/src/legal_mcp/services/embeddings.py index 69d39be..9676d88 100644 --- a/mcp-server/src/legal_mcp/services/embeddings.py +++ b/mcp-server/src/legal_mcp/services/embeddings.py @@ -53,3 +53,26 @@ async def embed_query(query: str) -> list[float]: """Embed a single search query.""" results = await embed_texts([query], input_type="query") return results[0] + + +async def voyage_rerank( + query: str, documents: list[str], top_k: int | None = None, +) -> list[tuple[int, float]]: + """Cross-encoder rerank via Voyage. Returns [(orig_index, score), ...] + sorted by relevance. Each tuple's index refers to the position in the + *input* documents list (not a DB row id) — caller maps it back. + + Used as a second stage after bi-encoder retrieval: fetch top-N + candidates with cosine, then rerank to get top-K with cross-encoder + attention over (query, doc). + """ + if not documents: + return [] + client = _get_client() + result = client.rerank( + query=query, + documents=documents, + model=config.VOYAGE_RERANK_MODEL, + top_k=top_k, + ) + return [(r.index, float(r.relevance_score)) for r in result.results] diff --git a/mcp-server/src/legal_mcp/services/precedent_library.py b/mcp-server/src/legal_mcp/services/precedent_library.py index a401884..995e997 100644 --- a/mcp-server/src/legal_mcp/services/precedent_library.py +++ b/mcp-server/src/legal_mcp/services/precedent_library.py @@ -22,7 +22,7 @@ from typing import Awaitable, Callable from uuid import UUID, uuid4 from legal_mcp import config -from legal_mcp.services import chunker, db, embeddings, extractor +from legal_mcp.services import chunker, db, embeddings, extractor, rerank # Note: halacha_extractor and precedent_metadata_extractor are NOT imported # at module load. They are imported lazily inside the dedicated re-extract @@ -403,18 +403,29 @@ async def search_library( Only ``approved`` / ``published`` halachot are returned, per chair-review policy. Chunks are returned regardless of halacha review status. + + When ``VOYAGE_RERANK_ENABLED`` is set, results are passed through + voyage rerank-2 (cross-encoder). The +0.05 halacha boost from + ``search_precedent_library_semantic`` is preserved before rerank + but the rerank scores ultimately decide the order. """ if not query.strip(): return [] query_vec = await embeddings.embed_query(query) - return await db.search_precedent_library_semantic( - query_embedding=query_vec, - practice_area=practice_area, - court=court, - precedent_level=precedent_level, - appeal_subtype=appeal_subtype, - is_binding=is_binding, - subject_tag=subject_tag, - limit=limit, - include_halachot=include_halachot, + + async def _base(limit: int) -> list[dict]: + return await db.search_precedent_library_semantic( + query_embedding=query_vec, + practice_area=practice_area, + court=court, + precedent_level=precedent_level, + appeal_subtype=appeal_subtype, + is_binding=is_binding, + subject_tag=subject_tag, + limit=limit, + include_halachot=include_halachot, + ) + + return await rerank.maybe_rerank( + query=query, base_search=_base, limit=limit, ) diff --git a/mcp-server/src/legal_mcp/services/rerank.py b/mcp-server/src/legal_mcp/services/rerank.py new file mode 100644 index 0000000..659a69e --- /dev/null +++ b/mcp-server/src/legal_mcp/services/rerank.py @@ -0,0 +1,103 @@ +"""Optional cross-encoder reranking layer for semantic search. + +Wraps a base search function with two-stage retrieval: + 1. fetch ``VOYAGE_RERANK_FETCH_K`` candidates via the bi-encoder (cosine) + 2. pass them to voyage rerank-2, return top-``limit`` + +When the feature flag is off (or ``force_rerank=False``) the helper just +calls the base function with ``limit`` and returns its results unchanged +— so callers can wrap unconditionally and let env control behaviour. + +The helper extracts the rerank text from each row using the first +non-empty field among ``content``, ``rule_statement``, +``reasoning_summary`` (matches the schema used by ``search_similar`` +and ``search_precedent_library_semantic``). + +Decision validated by POC #5 (785-doc precedent corpus, 12 queries): + - mean@3: 4.306 → 4.500 (+4.5%) + - practical-category queries: 3.78 → 4.22 (+11.6%) + - latency: +702ms per query +""" +from __future__ import annotations + +import logging +from collections.abc import Awaitable, Callable +from typing import Any + +from legal_mcp import config +from legal_mcp.services import embeddings + +logger = logging.getLogger(__name__) + +SearchFn = Callable[..., Awaitable[list[dict]]] + + +def _rerank_text(row: dict) -> str: + """First non-empty text field that voyage rerank should see.""" + for key in ("content", "rule_statement", "reasoning_summary", + "supporting_quote"): + v = row.get(key) + if v: + return str(v) + return "" + + +async def maybe_rerank( + query: str, + base_search: SearchFn, + limit: int, + *, + force_rerank: bool | None = None, + fetch_k: int | None = None, + **base_kwargs: Any, +) -> list[dict]: + """Two-stage retrieval helper. + + Args: + query: original query string (needed for the rerank API). + base_search: any async function that takes ``limit=…`` and the + other ``base_kwargs`` and returns ``list[dict]``. + limit: final number of results to return. + force_rerank: override the env flag. ``None`` → use config. + fetch_k: override the bi-encoder fetch depth. + **base_kwargs: forwarded to ``base_search``. + + Returns: + List of dict rows. When rerank is active, each row's ``score`` + is replaced with the rerank-2 relevance score (0..1). + """ + enabled = (config.VOYAGE_RERANK_ENABLED + if force_rerank is None else force_rerank) + if not enabled: + return await base_search(limit=limit, **base_kwargs) + + depth = fetch_k or config.VOYAGE_RERANK_FETCH_K + candidates = await base_search(limit=depth, **base_kwargs) + if not candidates: + return [] + + texts = [_rerank_text(c) for c in candidates] + # Drop candidates with empty rerank text (shouldn't happen but be safe) + keep = [(i, t) for i, t in enumerate(texts) if t] + if not keep: + logger.warning("rerank: all candidates empty, falling back to base") + return candidates[:limit] + keep_idx = [i for i, _ in keep] + keep_texts = [t for _, t in keep] + + try: + ranked = await embeddings.voyage_rerank( + query, keep_texts, top_k=limit, + ) + except Exception as e: + # Fail open — if Voyage rerank is down, return bi-encoder ordering + logger.warning("rerank failed, falling back to base: %s", e) + return candidates[:limit] + + out: list[dict] = [] + for keep_pos, score in ranked: + orig_idx = keep_idx[keep_pos] + row = dict(candidates[orig_idx]) + row["score"] = float(score) + out.append(row) + return out diff --git a/mcp-server/src/legal_mcp/tools/search.py b/mcp-server/src/legal_mcp/tools/search.py index 3c2d3f0..c7b1bed 100644 --- a/mcp-server/src/legal_mcp/tools/search.py +++ b/mcp-server/src/legal_mcp/tools/search.py @@ -6,7 +6,7 @@ import json import logging from uuid import UUID -from legal_mcp.services import db, embeddings +from legal_mcp.services import db, embeddings, rerank logger = logging.getLogger(__name__) @@ -43,8 +43,9 @@ async def search_decisions( ) query_emb = await embeddings.embed_query(query) - results = await db.search_similar( - query_embedding=query_emb, + results = await rerank.maybe_rerank( + query=query, + base_search=lambda **kw: db.search_similar(query_embedding=query_emb, **kw), limit=limit, section_type=section_type or None, practice_area=practice_area or None, @@ -86,8 +87,9 @@ async def search_case_documents( query_emb = await embeddings.embed_query(query) # Restricted to case_id — practice_area filter would be redundant. - results = await db.search_similar( - query_embedding=query_emb, + results = await rerank.maybe_rerank( + query=query, + base_search=lambda **kw: db.search_similar(query_embedding=query_emb, **kw), limit=limit, case_id=UUID(case["id"]), ) @@ -137,9 +139,13 @@ async def find_similar_cases( ) query_emb = await embeddings.embed_query(description) - results = await db.search_similar( - query_embedding=query_emb, - limit=limit * 3, # Get more to deduplicate by case + # Use description as the query text for rerank too. + # Note: even with rerank we ask for ``limit*3`` so the dedup-by-case + # step downstream still has enough rows to pick the best per case. + results = await rerank.maybe_rerank( + query=description, + base_search=lambda **kw: db.search_similar(query_embedding=query_emb, **kw), + limit=limit * 3, practice_area=practice_area or None, appeal_subtype=appeal_subtype or None, ) diff --git a/scripts/SCRIPTS.md b/scripts/SCRIPTS.md index 8a4f819..47a0e7e 100644 --- a/scripts/SCRIPTS.md +++ b/scripts/SCRIPTS.md @@ -17,6 +17,11 @@ | `deploy-track-changes.sh` | bash | סנכרון skills CMP↔CMPA + בדיקות + הנחיות deploy לארכיטקטורת Track Changes | ידני | | `retrofit_case.py` | python | retrofit רטרואקטיבי — מזריק bookmarks לקובץ קיים של תיק ספציפי ומגדיר אותו כ-active_draft | ידני (חד-פעמי לתיק) | | `reembed_voyage.py` | python | Re-embed כל הוקטורים ב-DB עם המודל ב-`VOYAGE_MODEL` (לאחר שינוי מודל). 5 טבלאות, 1024 דמ', batches של 100. ראה `docs/voyage-upgrades-plan.md` | ידני (אחרי החלפת `VOYAGE_MODEL`) | +| `voyage_context3_poc.py` | python | POC #1 — voyage-3 vs voyage-context-3 על פסיקה אחת קצרה (קלמנוביץ, 63 chunks). הכרעה: context-3 לא מציג שיפור עקבי | בנצ'מרק חד-פעמי, נשמר לרפרנס | +| `voyage_context3_poc_long.py` | python | POC #2 — voyage-context-3 על פסיקה ארוכה (אהרון ברק 219 chunks) עם sliding windows. הכרעה: context-3 לא משתפר על פסיקה גדולה | בנצ'מרק חד-פעמי, נשמר לרפרנס | +| `voyage_multimodal_poc.py` | python | POC #3 — voyage-multimodal-3 על דוח שמאי (89 עמודים). הכרעה: שיפור משמעותי לטבלאות + 22 עמודי image-only שhttp text-OCR מאבד | בנצ'מרק חד-פעמי, מוכן לשלב C | +| `voyage_rerank_judge_poc.py` | python | POC #4 — voyage-3 vs rerank-2 vs context-3 על אהרון ברק, 18 שאילתות, claude-haiku-4-5 כ-judge. הכרעה: rerank-2 ניצח עם +9% mean@3 | בנצ'מרק חד-פעמי | +| `voyage_rerank_corpus_poc.py` | python | POC #5 — voyage-3 vs rerank-2 על קורפוס מלא (785 docs). הכרעה: +4.5% mean@3 כללי, +11.6% על P queries (practical) | בנצ'מרק חד-פעמי, אישר את שלב B | ## תיקיית `.archive/` — סקריפטים שהושלמו diff --git a/scripts/voyage_context3_poc.py b/scripts/voyage_context3_poc.py new file mode 100644 index 0000000..033e761 --- /dev/null +++ b/scripts/voyage_context3_poc.py @@ -0,0 +1,182 @@ +"""POC: Compare voyage-3 vs voyage-context-3 retrieval on case 403/17. + +Pulls all chunks of "אהרון ברק - תכנית רחביה" (case_law_id=e151fc25-...), +runs them through voyage-context-3 in a single contextualized_embed call, +then runs benchmark queries and compares rankings against the existing +voyage-3 embeddings (already in the DB). + +No DB writes — all comparisons in memory. Output: ranking table for each +query showing top-10 from both models side-by-side. + +Usage: + /home/chaim/legal-ai/mcp-server/.venv/bin/python \\ + /home/chaim/legal-ai/scripts/voyage_context3_poc.py +""" +from __future__ import annotations + +import asyncio +import math +import os +import sys +import time + +# Load ~/.env +ENV_PATH = os.path.expanduser("~/.env") +if os.path.isfile(ENV_PATH): + with open(ENV_PATH) as f: + for line in f: + line = line.strip() + if line and not line.startswith("#") and "=" in line: + k, v = line.split("=", 1) + os.environ.setdefault(k, v) + +import asyncpg # noqa: E402 +import voyageai # noqa: E402 + + +# Using קלמנוביץ/לויתן (52K chars, 63 chunks, ~18K tokens) +# — fits in single context-3 call (32K token limit per inner list). +# אהרון ברק (60K tokens) requires splitting; we'll handle that after POC. +CASE_ID = "436efd48-c8ab-49f0-b3a9-52bf15ea806d" # בר"מ 25226-04-25 +CONTEXT_MODEL = "voyage-context-3" +BASELINE_MODEL = "voyage-3" # already in DB + +QUERIES = [ + "סמכות ועדת ערר", + "פיצויים לפי סעיף 197", + "ירידת ערך מקרקעין", + "תכנית פוגעת", + "שיקול דעת ועדה מקומית", + "חוות דעת שמאי מכריע", + "מקרקעין גובלים", + "תקופת התיישנות תביעה", + "אינטרס ציבורי בתכנון", + "דחיית תביעת פיצויים", +] + + +def cosine(a: list[float], b: list[float]) -> float: + dot = sum(x * y for x, y in zip(a, b)) + na = math.sqrt(sum(x * x for x in a)) + nb = math.sqrt(sum(y * y for y in b)) + return dot / (na * nb) if na and nb else 0.0 + + +def parse_pgvector(s: str) -> list[float]: + """pgvector text format: '[0.1,0.2,...]'.""" + return [float(x) for x in s.strip("[]").split(",")] + + +async def main(): + api_key = os.environ["VOYAGE_API_KEY"] + pg_pw = os.environ["POSTGRES_PASSWORD"] + + voyage = voyageai.Client(api_key=api_key) + + pool = await asyncpg.create_pool( + host="127.0.0.1", port=5433, user="legal_ai", + password=pg_pw, database="legal_ai", + min_size=1, max_size=2, + ) + + # 1. Pull all chunks + their existing voyage-3 embeddings + rows = await pool.fetch(""" + SELECT chunk_index, content, embedding::text AS emb_text + FROM precedent_chunks + WHERE case_law_id = $1 + ORDER BY chunk_index + """, CASE_ID) + print(f"[load] {len(rows)} chunks from case 403/17") + + chunks = [r["content"] for r in rows] + indices = [r["chunk_index"] for r in rows] + baseline_embs = [parse_pgvector(r["emb_text"]) for r in rows] + + # 2. Embed all chunks with voyage-context-3 — single contextualized call + total_chars = sum(len(c) for c in chunks) + print(f"[context] embedding {len(chunks)} chunks, {total_chars:,} chars total") + start = time.time() + result = voyage.contextualized_embed( + inputs=[chunks], # one document = one inner list + model=CONTEXT_MODEL, + input_type="document", + ) + elapsed = time.time() - start + # ContextualizedEmbeddingsObject: result.results = list of per-document + # embeddings. result.results[0].embeddings = list of chunk embeddings. + context_embs = result.results[0].embeddings + total_tokens = getattr(result, "total_tokens", "?") + print(f"[context] done in {elapsed:.1f}s — total_tokens={total_tokens}") + assert len(context_embs) == len(chunks), "embedding count mismatch" + + # 3. For each query — embed twice and compare top-10 + print("\n" + "=" * 100) + print(f"{'Q':<3} {'baseline (voyage-3)':<48} {'context-3':<48}") + print("=" * 100) + + rank_overlaps = [] + score_lifts = [] + + for q_idx, query in enumerate(QUERIES, 1): + # Baseline query embedding (regular embed) + q_baseline = voyage.embed( + [query], model=BASELINE_MODEL, input_type="query" + ).embeddings[0] + # Context query embedding — must use contextualized_embed even for + # single-string queries (regular embed() rejects voyage-context-3). + q_context = voyage.contextualized_embed( + inputs=[[query]], + model=CONTEXT_MODEL, + input_type="query", + ).results[0].embeddings[0] + + # Score every chunk under both models + scores_b = sorted( + [(cosine(q_baseline, e), i) for i, e in enumerate(baseline_embs)], + reverse=True, + ) + scores_c = sorted( + [(cosine(q_context, e), i) for i, e in enumerate(context_embs)], + reverse=True, + ) + + top10_b = [i for _, i in scores_b[:10]] + top10_c = [i for _, i in scores_c[:10]] + + # Compute overlap and avg score in top-3 + overlap = len(set(top10_b) & set(top10_c)) + avg_b_top3 = sum(s for s, _ in scores_b[:3]) / 3 + avg_c_top3 = sum(s for s, _ in scores_c[:3]) / 3 + rank_overlaps.append(overlap) + score_lifts.append(avg_c_top3 - avg_b_top3) + + print(f"\n[Q{q_idx}] {query}") + print(f" overlap top-10: {overlap}/10 | avg score top-3: " + f"baseline={avg_b_top3:.3f} context-3={avg_c_top3:.3f} " + f"Δ={avg_c_top3 - avg_b_top3:+.3f}") + for rank in range(5): + sb, ib = scores_b[rank] + sc, ic = scores_c[rank] + cb = chunks[ib].replace("\n", " ").strip()[:50] + cc = chunks[ic].replace("\n", " ").strip()[:50] + print(f" #{rank+1} [{indices[ib]:3d}] {sb:.3f} {cb:<55} " + f"| [{indices[ic]:3d}] {sc:.3f} {cc}") + + # Summary + print("\n" + "=" * 100) + print("SUMMARY") + print("=" * 100) + avg_overlap = sum(rank_overlaps) / len(rank_overlaps) + avg_lift = sum(score_lifts) / len(score_lifts) + print(f"Avg overlap top-10: {avg_overlap:.1f}/10 " + f"(higher = models agree more)") + print(f"Avg score lift top-3 (context - baseline): {avg_lift:+.4f}") + print(f"\nNote: cosine scores are not directly comparable across models.") + print(f"What matters more is which CHUNKS bubble to the top —") + print(f"reading the actual content above tells the real story.") + + await pool.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/scripts/voyage_context3_poc_long.py b/scripts/voyage_context3_poc_long.py new file mode 100644 index 0000000..5a7c71b --- /dev/null +++ b/scripts/voyage_context3_poc_long.py @@ -0,0 +1,238 @@ +"""POC #2: voyage-3 vs voyage-context-3 on a LONG case (אהרון ברק 403/17). + +Case is 178K chars / 219 chunks / ~60K tokens — too big for a single +contextualized_embed call (32K token limit per inner list). We split the +chunks into overlapping sliding windows (~80 chunks each, ~22K tokens) +and merge: each chunk gets the embedding from the window where it sits +*most centrally* (max symmetric context on both sides). + +The hypothesis: voyage-context-3 should shine here because the case is +full of internal references ("ראה לעיל סעיף 13", "להבדיל מעניין X", +"תוצאת הבחינה ב-בר"מ 1975/24 שנידונה לעיל"). voyage-3 embeds chunks +in isolation; context-3 sees ~80 surrounding chunks per embedding. + +No DB writes. Output: side-by-side ranking comparison + summary. + +Usage: + /home/chaim/legal-ai/mcp-server/.venv/bin/python \\ + /home/chaim/legal-ai/scripts/voyage_context3_poc_long.py +""" +from __future__ import annotations + +import asyncio +import math +import os +import sys +import time + +ENV_PATH = os.path.expanduser("~/.env") +if os.path.isfile(ENV_PATH): + with open(ENV_PATH) as f: + for line in f: + line = line.strip() + if line and not line.startswith("#") and "=" in line: + k, v = line.split("=", 1) + os.environ.setdefault(k, v) + +import asyncpg # noqa: E402 +import voyageai # noqa: E402 + + +CASE_ID = "e151fc25-cf12-4563-b638-a86323f8413b" # 403/17 אהרון ברק (178K chars) +CONTEXT_MODEL = "voyage-context-3" +BASELINE_MODEL = "voyage-3" + +# Sliding-window split params. With 219 chunks and ~60K tokens total +# (~275 tokens/chunk average), 3 windows of 80 chunks each is ~22K tokens +# per call — comfortably under 32K. +WINDOW_SIZE = 80 +WINDOW_STRIDE = 70 # overlap = WINDOW_SIZE - WINDOW_STRIDE = 10 + +# Mix of: +# (a) generic queries (also tested in POC #1) +# (b) queries that require *internal* document context +QUERIES = [ + # generic + "תכנית רחביה הוראות בנייה", + "פיצויים לפי סעיף 197 ירידת ערך", + "השפעת תכנית על שווי מקרקעין", + "סמכות ועדת ערר לדון בפיצויים", + "תוספת זכויות בנייה כפיצוי", + # internal-context — should benefit context-3 + "ההבחנה בין השבחה לפיצויים", + "מה נקבע לגבי תמ\"א 38 בפסק הדין", + "ההלכה שנקבעה בעניין רובע 3", + "כלל הנטרול של זכויות תכנוניות", + "הסכמת השופט אלרון לחוות הדעת", +] + + +def cosine(a: list[float], b: list[float]) -> float: + dot = sum(x * y for x, y in zip(a, b)) + na = math.sqrt(sum(x * x for x in a)) + nb = math.sqrt(sum(y * y for y in b)) + return dot / (na * nb) if na and nb else 0.0 + + +def parse_pgvector(s: str) -> list[float]: + return [float(x) for x in s.strip("[]").split(",")] + + +def build_windows(n: int, size: int, stride: int) -> list[tuple[int, int]]: + """Return list of (start, end) ranges (end exclusive) covering 0..n. + + Last window extends to n exactly. Overlap = size - stride. + """ + windows = [] + start = 0 + while start < n: + end = min(start + size, n) + windows.append((start, end)) + if end == n: + break + start += stride + return windows + + +def assign_chunk_to_window( + chunk_idx: int, windows: list[tuple[int, int]], +) -> int: + """Pick the window where chunk_idx sits most centrally (max symmetric + distance to either edge). Ties broken by larger window.""" + best = -1 + best_score = -1 + for w_idx, (s, e) in enumerate(windows): + if not (s <= chunk_idx < e): + continue + # symmetric distance: min(distance to s, distance to e-1) + dist = min(chunk_idx - s, (e - 1) - chunk_idx) + if dist > best_score: + best_score = dist + best = w_idx + return best + + +async def main(): + api_key = os.environ["VOYAGE_API_KEY"] + pg_pw = os.environ["POSTGRES_PASSWORD"] + + voyage = voyageai.Client(api_key=api_key) + + pool = await asyncpg.create_pool( + host="127.0.0.1", port=5433, user="legal_ai", + password=pg_pw, database="legal_ai", + min_size=1, max_size=2, + ) + + rows = await pool.fetch(""" + SELECT chunk_index, content, embedding::text AS emb_text + FROM precedent_chunks + WHERE case_law_id = $1 + ORDER BY chunk_index + """, CASE_ID) + n = len(rows) + print(f"[load] {n} chunks from אהרון ברק 403/17") + + chunks = [r["content"] for r in rows] + indices = [r["chunk_index"] for r in rows] + baseline_embs = [parse_pgvector(r["emb_text"]) for r in rows] + + # Build windows + windows = build_windows(n, WINDOW_SIZE, WINDOW_STRIDE) + print(f"[windows] {len(windows)} windows: " + f"{', '.join(f'[{s}:{e})' for s, e in windows)}") + + # Embed each window with context-3 + window_embs: list[list[list[float]]] = [] # [window][chunk_in_window][dim] + total_call_tokens = 0 + total_start = time.time() + for w_idx, (s, e) in enumerate(windows): + sub_chunks = chunks[s:e] + sub_chars = sum(len(c) for c in sub_chunks) + start = time.time() + result = voyage.contextualized_embed( + inputs=[sub_chunks], + model=CONTEXT_MODEL, + input_type="document", + ) + elapsed = time.time() - start + toks = getattr(result, "total_tokens", 0) + total_call_tokens += toks + print(f" [window {w_idx}] [{s}:{e}) — {len(sub_chunks)} chunks, " + f"{sub_chars:,} chars, {toks} tokens — {elapsed:.1f}s") + window_embs.append(result.results[0].embeddings) + total_elapsed = time.time() - total_start + print(f"[context] all windows done in {total_elapsed:.1f}s, " + f"{total_call_tokens} total tokens") + + # Merge: for each chunk, pick the embedding from its most-central window + context_embs: list[list[float]] = [] + chunk_window_choice = [] + for i in range(n): + w_idx = assign_chunk_to_window(i, windows) + chunk_window_choice.append(w_idx) + s, _ = windows[w_idx] + context_embs.append(window_embs[w_idx][i - s]) + print(f"[merge] window distribution: " + f"{[chunk_window_choice.count(j) for j in range(len(windows))]}") + + # Run queries + print("\n" + "=" * 100) + print(f"{'Q':<3} {'baseline (voyage-3)':<48} {'context-3 (windowed)':<48}") + print("=" * 100) + + rank_overlaps = [] + for q_idx, query in enumerate(QUERIES, 1): + q_baseline = voyage.embed( + [query], model=BASELINE_MODEL, input_type="query" + ).embeddings[0] + q_context = voyage.contextualized_embed( + inputs=[[query]], + model=CONTEXT_MODEL, + input_type="query", + ).results[0].embeddings[0] + + scores_b = sorted( + [(cosine(q_baseline, e), i) for i, e in enumerate(baseline_embs)], + reverse=True, + ) + scores_c = sorted( + [(cosine(q_context, e), i) for i, e in enumerate(context_embs)], + reverse=True, + ) + + top10_b = [i for _, i in scores_b[:10]] + top10_c = [i for _, i in scores_c[:10]] + overlap = len(set(top10_b) & set(top10_c)) + rank_overlaps.append(overlap) + + print(f"\n[Q{q_idx}] {query}") + print(f" overlap top-10: {overlap}/10 | " + f"avg score top-3: baseline=" + f"{sum(s for s, _ in scores_b[:3])/3:.3f} " + f"context-3={sum(s for s, _ in scores_c[:3])/3:.3f}") + for rank in range(5): + sb, ib = scores_b[rank] + sc, ic = scores_c[rank] + cb = chunks[ib].replace("\n", " ").strip()[:50] + cc = chunks[ic].replace("\n", " ").strip()[:50] + print(f" #{rank+1} [{indices[ib]:3d}] {sb:.3f} {cb:<55} " + f"| [{indices[ic]:3d}] {sc:.3f} {cc}") + + print("\n" + "=" * 100) + print("SUMMARY") + print("=" * 100) + avg = sum(rank_overlaps) / len(rank_overlaps) + print(f"Avg overlap top-10: {avg:.1f}/10") + print(f"Per-query overlap: {rank_overlaps}") + print(f"Total context-3 tokens used: {total_call_tokens:,} " + f"(in {len(windows)} calls)") + print(f"\nNote: cosine across models not directly comparable. The") + print(f"meaningful test is *which chunks bubble to the top* — read") + print(f"the actual text above to judge relevance.") + + await pool.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/scripts/voyage_multimodal_poc.py b/scripts/voyage_multimodal_poc.py new file mode 100644 index 0000000..02d8617 --- /dev/null +++ b/scripts/voyage_multimodal_poc.py @@ -0,0 +1,213 @@ +"""POC #3: voyage-3 (text) vs voyage-multimodal-3.5 (page images) on a +real appraisal PDF (89 pages, full of tables / signatures / numerical +data — the corpus class where multimodal should help most). + +Document under test: + baf10153-d2fc-4481-b250-9fe87440ce69 + "נספח - שומה מכרעת (אבלין דוידזון שמאמא) - 15.09.24" + case 8137-24, 89 pages, 2.1 MB + +The pipeline: + 1. Pull the existing voyage-3 text-chunk embeddings from `document_chunks`. + 2. Render each PDF page → PNG (PyMuPDF, dpi=144). + 3. Embed all pages via voyage-multimodal-3.5. + 4. Run benchmark queries (mix of generic + table-specific + visual) + against both: text top-K and page top-K. + +The comparison is *qualitative* — text and image embeddings are +different "spaces" returning different ID types (chunk_id vs page_num). +What we look at is whether image-based retrieval surfaces tables, +signatures, or numerical data that text-only OCR loses. + +No DB writes. + +Usage: + /home/chaim/legal-ai/mcp-server/.venv/bin/python \\ + /home/chaim/legal-ai/scripts/voyage_multimodal_poc.py +""" +from __future__ import annotations + +import asyncio +import io +import math +import os +import time + +ENV_PATH = os.path.expanduser("~/.env") +if os.path.isfile(ENV_PATH): + with open(ENV_PATH) as f: + for line in f: + line = line.strip() + if line and not line.startswith("#") and "=" in line: + k, v = line.split("=", 1) + os.environ.setdefault(k, v) + +import asyncpg # noqa: E402 +import voyageai # noqa: E402 +import fitz # PyMuPDF # noqa: E402 +from PIL import Image # noqa: E402 + + +DOCUMENT_ID = "baf10153-d2fc-4481-b250-9fe87440ce69" +PDF_PATH = ( + "/home/chaim/legal-ai/data/cases/8137-24/documents/originals/" + "נספח - שומה מכרעת (אבלין דוידזון שמאמא) - 15.09.24.pdf" +) +TEXT_MODEL = "voyage-3" +MULTIMODAL_MODEL = "voyage-multimodal-3" # check supported: 3.5 may not exist yet +DPI = 144 +# voyage-multimodal: max 1000 inputs/call, 320M pixels/call (rough), +# so 89 pages at 1240×1750 ≈ 192M pixels = single call. + +QUERIES = [ + # generic-textual (both should handle) + "שיטת ההיוון בשומה", + "מתודולוגיית הערכת שווי", + # table/numerical (multimodal should help) + "טבלת השוואת ערכים לפני ואחרי התכנית", + "שווי המקרקעין במצב הקודם", + "שווי המקרקעין במצב החדש", + "ירידת ערך באחוזים", + # visual elements (text-only loses) + "חתימת השמאי", + "תרשים גוש וחלקה", + "מפת מיקום הנכס", + # context-heavy + "מסקנת השמאי המכריע", + "עקרון הצפיפות בתכנית", +] + + +def cosine(a: list[float], b: list[float]) -> float: + dot = sum(x * y for x, y in zip(a, b)) + na = math.sqrt(sum(x * x for x in a)) + nb = math.sqrt(sum(y * y for y in b)) + return dot / (na * nb) if na and nb else 0.0 + + +def parse_pgvector(s: str) -> list[float]: + return [float(x) for x in s.strip("[]").split(",")] + + +def render_pdf_pages(pdf_path: str, dpi: int) -> list[Image.Image]: + """Render each page → PIL.Image (RGB).""" + doc = fitz.open(pdf_path) + images: list[Image.Image] = [] + for page in doc: + pix = page.get_pixmap(dpi=dpi) + png_bytes = pix.tobytes("png") + img = Image.open(io.BytesIO(png_bytes)).convert("RGB") + images.append(img) + doc.close() + return images + + +async def main(): + api_key = os.environ["VOYAGE_API_KEY"] + pg_pw = os.environ["POSTGRES_PASSWORD"] + + voyage = voyageai.Client(api_key=api_key) + + # 1. Render PDF pages + print(f"[render] {PDF_PATH}") + start = time.time() + images = render_pdf_pages(PDF_PATH, DPI) + elapsed = time.time() - start + print(f"[render] {len(images)} pages in {elapsed:.1f}s, " + f"{images[0].size}px @ {DPI}dpi") + + # 2. Pull existing text chunks + voyage-3 embeddings + pool = await asyncpg.create_pool( + host="127.0.0.1", port=5433, user="legal_ai", + password=pg_pw, database="legal_ai", + min_size=1, max_size=2, + ) + rows = await pool.fetch(""" + SELECT id, chunk_index, page_number, content, + embedding::text AS emb_text + FROM document_chunks + WHERE document_id = $1 + ORDER BY chunk_index + """, DOCUMENT_ID) + print(f"[text] {len(rows)} text chunks loaded (voyage-3 in DB)") + text_contents = [r["content"] for r in rows] + text_chunk_pages = [r["page_number"] for r in rows] + text_embs = [parse_pgvector(r["emb_text"]) for r in rows] + + # 3. Multimodal embed — try multimodal-3 first, fall back if needed + target_model = "voyage-multimodal-3" + print(f"[multimodal] embedding {len(images)} pages with {target_model}…") + start = time.time() + try: + mm_result = voyage.multimodal_embed( + inputs=[[img] for img in images], # list of single-image inputs + model=target_model, + input_type="document", + truncation=True, + ) + except voyageai.error.InvalidRequestError as e: + print(f" [error] {e}") + await pool.close() + return + elapsed = time.time() - start + image_embs = mm_result.embeddings + mm_tokens = getattr(mm_result, "total_tokens", "?") + image_tokens = getattr(mm_result, "image_pixels", "?") + text_tokens_mm = getattr(mm_result, "text_tokens", "?") + print(f"[multimodal] done in {elapsed:.1f}s — " + f"total_tokens={mm_tokens} text_tokens={text_tokens_mm} " + f"image_pixels={image_tokens}") + assert len(image_embs) == len(images), "embedding count mismatch" + print(f"[multimodal] embedding dim = {len(image_embs[0])}") + + # 4. Run queries + print("\n" + "=" * 100) + print("QUERY RESULTS — top-5 chunks (text/voyage-3) " + "vs top-5 pages (multimodal)") + print("=" * 100) + + for q_idx, query in enumerate(QUERIES, 1): + # Text-side: voyage-3 query embedding + q_text = voyage.embed( + [query], model=TEXT_MODEL, input_type="query" + ).embeddings[0] + # Multimodal-side: same model, query input_type + q_mm = voyage.multimodal_embed( + inputs=[[query]], + model=target_model, + input_type="query", + ).embeddings[0] + + text_scores = sorted( + [(cosine(q_text, e), i) for i, e in enumerate(text_embs)], + reverse=True, + )[:5] + mm_scores = sorted( + [(cosine(q_mm, e), i) for i, e in enumerate(image_embs)], + reverse=True, + )[:5] + + print(f"\n[Q{q_idx}] {query}") + print(f" --- text (voyage-3) top-5 ---") + for s, i in text_scores: + page = text_chunk_pages[i] if text_chunk_pages[i] else "?" + preview = text_contents[i].replace("\n", " ").strip()[:70] + print(f" {s:.3f} page={page:>3} chunk={i:>3} {preview}") + print(f" --- multimodal (image-only) top-5 ---") + for s, i in mm_scores: + print(f" {s:.3f} page={i+1:>3} (image)") + + # Token / cost summary + print("\n" + "=" * 100) + print("SUMMARY") + print("=" * 100) + print(f"PDF: {len(images)} pages @ {DPI}dpi → {target_model}") + print(f"Total multimodal tokens: {mm_tokens}") + print(f"Embedding dim: {len(image_embs[0])}") + print(f"Time: {elapsed:.1f}s for full doc") + + await pool.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/scripts/voyage_rerank_corpus_poc.py b/scripts/voyage_rerank_corpus_poc.py new file mode 100644 index 0000000..c020e07 --- /dev/null +++ b/scripts/voyage_rerank_corpus_poc.py @@ -0,0 +1,318 @@ +"""POC #5 — full precedent_library corpus benchmark. + +Tests R1 (voyage-3) vs R2 (voyage-3 + rerank-2) on the *real* corpus that +search_precedent_library queries against: + + precedent_chunks — 385 rows from 3 precedent cases + halachot — 400 rule statements with reasoning summaries + +Total: 785 documents. The MCP tool merges results from both tables so the +benchmark mirrors production retrieval. R3 (context-3) is dropped — it +would require windowed re-embedding of 3 cases which we already proved +doesn't help (POC #2). The question now is: does rerank-2's +9% on a +single case generalize to a heterogeneous corpus? + +Also measures end-to-end latency: pure voyage-3 vs voyage-3 + rerank. + +Usage: + /home/chaim/legal-ai/mcp-server/.venv/bin/python \\ + /home/chaim/legal-ai/scripts/voyage_rerank_corpus_poc.py +""" +from __future__ import annotations + +import asyncio +import json +import math +import os +import re +import subprocess +import sys +import time +from collections import defaultdict + +ENV_PATH = os.path.expanduser("~/.env") +if os.path.isfile(ENV_PATH): + with open(ENV_PATH) as f: + for line in f: + line = line.strip() + if line and not line.startswith("#") and "=" in line: + k, v = line.split("=", 1) + os.environ.setdefault(k, v) + +import asyncpg # noqa: E402 +import voyageai # noqa: E402 + + +TEXT_MODEL = "voyage-3" +RERANK_MODEL = "rerank-2" +JUDGE_MODEL = "claude-haiku-4-5-20251001" +TOP_VEC = 50 # voyage-3 retrieve depth +TOP_K = 10 # final returned to "agent" +JUDGE_K = 5 # how many top results to actually judge per retriever + +# 12 queries spanning typical use cases by Daphna's agents: +# precedent search for citing in decision blocks י-יא. +QUERIES = [ + # K — keyword + ("K1", "פיצויים לפי סעיף 197"), + ("K2", "תמ\"א 38 והשבחה"), + ("K3", "כלל הנטרול בשמאות"), + # C — conceptual + ("C1", "תכלית היטל ההשבחה"), + ("C2", "מה מקנה לבעלים זכות לפיצוי"), + ("C3", "ההבחנה בין השבחה לפיצויים"), + # N — narrative / context-aware + ("N1", "מה נקבע לגבי תמ\"א 38 בפסיקה"), + ("N2", "ההלכה לעניין נטרול ציפיות"), + ("N3", "תכנית פוגעת ושומה"), + # P — practical (drafting needs — what an agent typically asks) + ("P1", "פסיקה שדנה בתכנית מתאר ארצית"), + ("P2", "מתי מותר לוועדה לדחות פיצויים"), + ("P3", "שיקול דעת הוועדה המקומית"), +] + + +def cosine(a, b): + dot = sum(x * y for x, y in zip(a, b)) + na = math.sqrt(sum(x * x for x in a)) + nb = math.sqrt(sum(y * y for y in b)) + return dot / (na * nb) if na and nb else 0.0 + + +def parse_pgvector(s): + return [float(x) for x in s.strip("[]").split(",")] + + +BATCH_JUDGE_PROMPT = """אתה שופט רלוונטיות במשפט ישראלי. +לפניך שאילתה ומספר פסקאות מפסקי דין/הלכות. דרג כל פסקה 1-5 לפי רלוונטיות. + +5 — תשובה ישירה למה שנשאל +4 — מאד רלוונטי, מכיל מידע ליבה +3 — רלוונטי חלקית, נוגע בעקיפין +2 — מעט קשור, רעש סביב הנושא +1 — לא רלוונטי בכלל + +השאילתה: +{query} + +הפסקאות: +{chunks_block} + +החזר JSON בלבד: {{"scores": {{"": <1-5>, ...}}}} +ללא טקסט נוסף, ללא ```.""" + + +def batch_judge(query: str, items: list[tuple[str, str]]) -> dict[str, int]: + """Judge (id, text) pairs via claude CLI. Returns {id: score}.""" + blocks = [] + for cid, content in items: + snippet = content.replace("\n", " ").strip()[:1500] + blocks.append(f"\n{snippet}\n") + prompt = BATCH_JUDGE_PROMPT.format( + query=query, chunks_block="\n\n".join(blocks)) + proc = subprocess.run( + ["claude", "-p", "--model", JUDGE_MODEL], + input=prompt, capture_output=True, text=True, timeout=180, + ) + out = proc.stdout.strip() + out = re.sub(r"^```(?:json)?\s*", "", out) + out = re.sub(r"\s*```$", "", out) + try: + data = json.loads(out) + raw = data.get("scores", {}) + return {str(k): int(v) for k, v in raw.items() + if str(v).isdigit() and 1 <= int(v) <= 5} + except (json.JSONDecodeError, ValueError, TypeError) as e: + print(f" [judge parse fail: {e}; out={out[:200]!r}]") + return {} + + +async def main(): + voyage_key = os.environ["VOYAGE_API_KEY"] + pg_pw = os.environ["POSTGRES_PASSWORD"] + + try: + subprocess.run(["claude", "--version"], capture_output=True, + text=True, timeout=10, check=True) + except (subprocess.CalledProcessError, FileNotFoundError, TimeoutError): + sys.exit("claude CLI not found") + + voyage = voyageai.Client(api_key=voyage_key) + + pool = await asyncpg.create_pool( + host="127.0.0.1", port=5433, user="legal_ai", + password=pg_pw, database="legal_ai", + min_size=1, max_size=2, + ) + + # Load full corpus: precedent_chunks + halachot + pc_rows = await pool.fetch(""" + SELECT 'pc:' || id::text AS doc_id, + content, + embedding::text AS emb_text + FROM precedent_chunks + WHERE content IS NOT NULL AND embedding IS NOT NULL + """) + h_rows = await pool.fetch(""" + SELECT 'h:' || id::text AS doc_id, + TRIM(BOTH ' —' FROM rule_statement || ' — ' || + COALESCE(reasoning_summary, '')) AS content, + embedding::text AS emb_text + FROM halachot + WHERE rule_statement IS NOT NULL AND embedding IS NOT NULL + """) + all_rows = list(pc_rows) + list(h_rows) + print(f"[load] corpus: {len(pc_rows)} precedent_chunks + " + f"{len(h_rows)} halachot = {len(all_rows)} total") + + doc_ids = [r["doc_id"] for r in all_rows] + contents = [r["content"] for r in all_rows] + embs = [parse_pgvector(r["emb_text"]) for r in all_rows] + + # Latency measurement: 5 queries, time the two pipelines + print("\n[latency] measuring 5 sample queries…") + sample = QUERIES[:5] + r1_lat = [] + r2_lat = [] + for _, query in sample: + # R1: voyage-3 embed + cosine top-10 + t0 = time.time() + q_emb = voyage.embed([query], model=TEXT_MODEL, + input_type="query").embeddings[0] + scores = sorted([(cosine(q_emb, e), i) for i, e in enumerate(embs)], + reverse=True)[:TOP_K] + r1_lat.append(time.time() - t0) + # R2: voyage-3 embed + cosine top-50 + rerank-2 → top-10 + t0 = time.time() + q_emb = voyage.embed([query], model=TEXT_MODEL, + input_type="query").embeddings[0] + cands = sorted([(cosine(q_emb, e), i) for i, e in enumerate(embs)], + reverse=True)[:TOP_VEC] + cand_texts = [contents[i] for _, i in cands] + rr = voyage.rerank(query=query, documents=cand_texts, + model=RERANK_MODEL, top_k=TOP_K) + r2_lat.append(time.time() - t0) + print(f" R1 (voyage-3 only) avg={sum(r1_lat)/5*1000:.0f}ms" + f" min={min(r1_lat)*1000:.0f} max={max(r1_lat)*1000:.0f}") + print(f" R2 (voyage-3 + rerank-2) avg={sum(r2_lat)/5*1000:.0f}ms" + f" min={min(r2_lat)*1000:.0f} max={max(r2_lat)*1000:.0f}") + print(f" Δ (rerank overhead) avg={(sum(r2_lat)-sum(r1_lat))/5*1000:.0f}ms") + + # Retrieval functions + def r1_baseline(query: str, k: int = TOP_K) -> list[int]: + q = voyage.embed([query], model=TEXT_MODEL, + input_type="query").embeddings[0] + scores = sorted([(cosine(q, e), i) for i, e in enumerate(embs)], + reverse=True) + return [i for _, i in scores[:k]] + + def r2_rerank(query: str, k: int = TOP_K) -> list[int]: + cands = r1_baseline(query, k=TOP_VEC) + cand_texts = [contents[i] for i in cands] + rr = voyage.rerank(query=query, documents=cand_texts, + model=RERANK_MODEL, top_k=k) + return [cands[r.index] for r in rr.results] + + retrievers = [("R1-voyage3", r1_baseline), + ("R2-rerank2", r2_rerank)] + + print(f"\n[judge] running {len(QUERIES)} queries × 2 retrievers, " + f"top-{JUDGE_K} judged…") + + all_results = [] + for qid, query in QUERIES: + print(f"\n[{qid}] {query}") + retr_results = {} + for r_name, r_fn in retrievers: + try: + retr_results[r_name] = r_fn(query, k=JUDGE_K) + except Exception as e: + print(f" {r_name}: FAILED — {e}") + retr_results[r_name] = [] + union = sorted({i for top in retr_results.values() for i in top}) + items = [(doc_ids[i], contents[i]) for i in union] + print(f" judging {len(items)} unique docs…") + scores_map = batch_judge(query, items) + for r_name, top in retr_results.items(): + scores = [scores_map.get(doc_ids[i], 0) for i in top] + mean3 = sum(scores[:3]) / 3 if len(scores) >= 3 else 0 + mean5 = sum(scores) / len(scores) if scores else 0 + mrr = 0.0 + for r, s in enumerate(scores): + if s >= 4: + mrr = 1.0 / (r + 1) + break + print(f" {r_name}: doc_ids={[doc_ids[i][:14] for i in top]} " + f"scores={scores} m@3={mean3:.2f} m@5={mean5:.2f} " + f"MRR={mrr:.3f}") + all_results.append({ + "qid": qid, "category": qid[0], "query": query, + "retriever": r_name, + "doc_ids": [doc_ids[i] for i in top], + "scores": scores, "mean3": mean3, "mean5": mean5, "mrr": mrr, + }) + + # Aggregate + print("\n" + "=" * 100) + print("AGGREGATED RESULTS — full precedent_library corpus (785 docs)") + print("=" * 100) + by_r = defaultdict(lambda: {"mean3": [], "mean5": [], "mrr": []}) + by_cat_r = defaultdict(lambda: {"mean3": [], "mean5": [], "mrr": []}) + for r in all_results: + by_r[r["retriever"]]["mean3"].append(r["mean3"]) + by_r[r["retriever"]]["mean5"].append(r["mean5"]) + by_r[r["retriever"]]["mrr"].append(r["mrr"]) + ck = (r["category"], r["retriever"]) + by_cat_r[ck]["mean3"].append(r["mean3"]) + by_cat_r[ck]["mean5"].append(r["mean5"]) + by_cat_r[ck]["mrr"].append(r["mrr"]) + + print(f"\nOverall ({len(QUERIES)} queries):") + print(f"{'retriever':<14} {'mean@3':>8} {'mean@5':>8} {'MRR':>8}") + avg = lambda xs: sum(xs) / len(xs) if xs else 0 + for r_name, _ in retrievers: + m = by_r[r_name] + print(f"{r_name:<14} {avg(m['mean3']):>8.3f} " + f"{avg(m['mean5']):>8.3f} {avg(m['mrr']):>8.3f}") + # Improvement + r1m = avg(by_r["R1-voyage3"]["mean3"]) + r2m = avg(by_r["R2-rerank2"]["mean3"]) + if r1m > 0: + print(f"\nR2 vs R1 improvement: " + f"mean@3 {(r2m - r1m) / r1m * 100:+.1f}%") + + print(f"\nBy category:") + print(f"{'cat':<3} {'retriever':<14} {'mean@3':>8} {'mean@5':>8} " + f"{'MRR':>8}") + for cat in ["K", "C", "N", "P"]: + for r_name, _ in retrievers: + m = by_cat_r[(cat, r_name)] + if not m["mean3"]: + continue + print(f"{cat:<3} {r_name:<14} {avg(m['mean3']):>8.3f} " + f"{avg(m['mean5']):>8.3f} {avg(m['mrr']):>8.3f}") + + print(f"\nPer-query winner (highest mean@3):") + print(f"{'qid':<4} {'query':<40} {'winner':<14} {'scores'}") + by_q = defaultdict(list) + for r in all_results: + by_q[r["qid"]].append(r) + for qid, results in sorted(by_q.items()): + max_s = max(r["mean3"] for r in results) + winners = [r["retriever"] for r in results if r["mean3"] == max_s] + scores = " | ".join(f"{r['retriever'][:7]}={r['mean3']:.2f}" + for r in results) + q_str = next(q for qid_, q in QUERIES if qid_ == qid)[:38] + print(f"{qid:<4} {q_str:<40} {','.join(w[:8] for w in winners):<14} " + f"{scores}") + + out_path = "/tmp/voyage_rerank_corpus_results.json" + with open(out_path, "w") as f: + json.dump(all_results, f, ensure_ascii=False, indent=2) + print(f"\nSaved to {out_path}") + + await pool.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/scripts/voyage_rerank_judge_poc.py b/scripts/voyage_rerank_judge_poc.py new file mode 100644 index 0000000..fbfb9bd --- /dev/null +++ b/scripts/voyage_rerank_judge_poc.py @@ -0,0 +1,361 @@ +"""POC #4: Comprehensive retrieval benchmark with LLM-as-judge. + +Compares 3 retrievers on אהרון ברק 403/17 (219 chunks): + R1 — voyage-3 (current production baseline) + R2 — voyage-3 + voyage-rerank-2 (retrieve 50, rerank, top-10) + R3 — voyage-context-3 (windowed, from POC #2) + +Judges relevance with claude-haiku-4-5 — for each (query, chunk) pair the +judge returns 1-5. Aggregates: mean relevance@3, @5, @10, MRR (rank of +first 4+ chunk), per-query winner. + +20 queries grouped into 3 categories so we can see *which* query types +benefit from which retriever: + K — keyword/lexical (term-heavy, specific entity) + C — conceptual (abstract idea, principle) + N — narrative/contextual (requires document-internal reference) + +Usage (key passed via env, NOT stored in script): + ANTHROPIC_API_KEY=... \\ + /home/chaim/legal-ai/mcp-server/.venv/bin/python \\ + /home/chaim/legal-ai/scripts/voyage_rerank_judge_poc.py +""" +from __future__ import annotations + +import asyncio +import json +import math +import os +import sys +import time +from collections import defaultdict + +ENV_PATH = os.path.expanduser("~/.env") +if os.path.isfile(ENV_PATH): + with open(ENV_PATH) as f: + for line in f: + line = line.strip() + if line and not line.startswith("#") and "=" in line: + k, v = line.split("=", 1) + os.environ.setdefault(k, v) + +import re +import subprocess + +import asyncpg # noqa: E402 +import voyageai # noqa: E402 + + +CASE_ID = "e151fc25-cf12-4563-b638-a86323f8413b" # אהרון ברק 403/17 +TEXT_MODEL = "voyage-3" +CONTEXT_MODEL = "voyage-context-3" +RERANK_MODEL = "rerank-2" +JUDGE_MODEL = "claude-haiku-4-5-20251001" + +WINDOW_SIZE = 80 +WINDOW_STRIDE = 70 + +# 18 queries × 3 retrievers × top-5 = 270 judge calls. ~$0.05 with haiku. +QUERIES = [ + # K — keyword/lexical + ("K1", "תכנית רחביה הוראות בנייה"), + ("K2", "תמ\"א 38"), + ("K3", "תכנית 9988"), + ("K4", "סעיף 197 לחוק התכנון והבניה"), + ("K5", "השופט גרוסקופף"), + ("K6", "ועדה מקומית ירושלים"), + # C — conceptual / abstract principles + ("C1", "כלל הנטרול של זכויות תכנוניות"), + ("C2", "אינטרס הציבור בתכנון"), + ("C3", "תכלית היטל ההשבחה"), + ("C4", "תכנית פוגעת לעומת תכנית משביחה"), + ("C5", "ההבחנה בין השבחה לפיצויים"), + ("C6", "מהותו של היטל ההשבחה"), + # N — narrative / context-dependent + ("N1", "מה נקבע לגבי תמ\"א 38 בפסק הדין"), + ("N2", "מסקנת בית המשפט בעניין רובע 3"), + ("N3", "ההלכה שנקבעה בעניין שמעוני"), + ("N4", "ההבדל בין המקרה שלפנינו לעניין רון"), + ("N5", "סוף דבר ותוצאת פסק הדין"), + ("N6", "הסכמת השופטים האחרים לחוות הדעת"), +] + + +def cosine(a, b): + dot = sum(x * y for x, y in zip(a, b)) + na = math.sqrt(sum(x * x for x in a)) + nb = math.sqrt(sum(y * y for y in b)) + return dot / (na * nb) if na and nb else 0.0 + + +def parse_pgvector(s): + return [float(x) for x in s.strip("[]").split(",")] + + +def build_windows(n, size, stride): + out = [] + s = 0 + while s < n: + e = min(s + size, n) + out.append((s, e)) + if e == n: + break + s += stride + return out + + +def central_window(idx, windows): + best, best_d = -1, -1 + for w_idx, (s, e) in enumerate(windows): + if not (s <= idx < e): + continue + d = min(idx - s, (e - 1) - idx) + if d > best_d: + best_d = d + best = w_idx + return best + + +BATCH_JUDGE_PROMPT = """אתה שופט רלוונטיות במשפט ישראלי. +לפניך שאילתה ומספר פסקאות מפסק דין. דרג כל פסקה בנפרד 1-5 לפי רלוונטיות. + +סולם: +5 — תשובה ישירה ומדויקת לשאילתה +4 — מאד רלוונטי, מכיל מידע ליבה +3 — רלוונטי חלקית, נוגע בעקיפין בנושא +2 — מעט קשור, רעש סביב הנושא +1 — לא רלוונטי בכלל + +השאילתה: +{query} + +הפסקאות: +{chunks_block} + +החזר JSON בלבד, בפורמט: {{"scores": {{"": <1-5>, ...}}}} +ללא טקסט נוסף, ללא explanations, ללא ```.""" + + +def batch_judge(query: str, + items: list[tuple[int, str]]) -> dict[int, int]: + """Judge a list of (chunk_idx, content) pairs in a single CLI call. + + Returns: dict[chunk_idx → score 1-5]. Returns 0 for parse failures. + """ + chunks_block_lines = [] + for ci, content in items: + snippet = content.replace("\n", " ").strip()[:1500] + chunks_block_lines.append(f"\n{snippet}\n") + prompt = BATCH_JUDGE_PROMPT.format( + query=query, + chunks_block="\n\n".join(chunks_block_lines), + ) + proc = subprocess.run( + ["claude", "-p", "--model", JUDGE_MODEL], + input=prompt, capture_output=True, text=True, timeout=120, + ) + out = proc.stdout.strip() + # Strip ```json fences if any + out = re.sub(r"^```(?:json)?\s*", "", out) + out = re.sub(r"\s*```$", "", out) + try: + data = json.loads(out) + raw = data.get("scores", {}) + return {int(k): int(v) for k, v in raw.items() + if str(v).isdigit() and 1 <= int(v) <= 5} + except (json.JSONDecodeError, ValueError, TypeError) as e: + print(f" [judge parse fail: {e}; out={out[:200]!r}]") + return {} + + +async def main(): + voyage_key = os.environ["VOYAGE_API_KEY"] + pg_pw = os.environ["POSTGRES_PASSWORD"] + + # Verify Claude CLI is available (uses OAuth from ~/.claude/.credentials) + try: + subprocess.run(["claude", "--version"], capture_output=True, + text=True, timeout=10, check=True) + except (subprocess.CalledProcessError, FileNotFoundError, TimeoutError): + sys.exit("claude CLI not found or not authenticated") + + voyage = voyageai.Client(api_key=voyage_key) + + # Load chunks + voyage-3 embeddings + pool = await asyncpg.create_pool( + host="127.0.0.1", port=5433, user="legal_ai", + password=pg_pw, database="legal_ai", + min_size=1, max_size=2, + ) + rows = await pool.fetch(""" + SELECT chunk_index, content, embedding::text AS emb_text + FROM precedent_chunks + WHERE case_law_id = $1 + ORDER BY chunk_index + """, CASE_ID) + chunks = [r["content"] for r in rows] + chunk_indices = [r["chunk_index"] for r in rows] + baseline_embs = [parse_pgvector(r["emb_text"]) for r in rows] + n = len(chunks) + print(f"[load] {n} chunks loaded") + + # Compute context-3 (windowed) embeddings — same as POC #2 + windows = build_windows(n, WINDOW_SIZE, WINDOW_STRIDE) + print(f"[context-3] embedding {len(windows)} windows…") + win_embs = [] + for s, e in windows: + result = voyage.contextualized_embed( + inputs=[chunks[s:e]], + model=CONTEXT_MODEL, + input_type="document", + ) + win_embs.append(result.results[0].embeddings) + context_embs = [] + for i in range(n): + w = central_window(i, windows) + s, _ = windows[w] + context_embs.append(win_embs[w][i - s]) + print(f"[context-3] done") + + # Retrieval functions + def r1_baseline(query: str, k: int = 10) -> list[int]: + q = voyage.embed([query], model=TEXT_MODEL, + input_type="query").embeddings[0] + scores = sorted( + [(cosine(q, e), i) for i, e in enumerate(baseline_embs)], + reverse=True, + ) + return [i for _, i in scores[:k]] + + def r2_rerank(query: str, k: int = 10) -> list[int]: + # 1) voyage-3 retrieve top-50 + cands = r1_baseline(query, k=50) + cand_texts = [chunks[i] for i in cands] + # 2) voyage-rerank-2 over the 50 + rr = voyage.rerank( + query=query, documents=cand_texts, + model=RERANK_MODEL, top_k=k, + ) + # rr.results: list of RerankingResult(index=..., relevance_score=...) + # `index` refers to position in cand_texts → map back to chunk idx + return [cands[r.index] for r in rr.results] + + def r3_context(query: str, k: int = 10) -> list[int]: + q = voyage.contextualized_embed( + inputs=[[query]], + model=CONTEXT_MODEL, + input_type="query", + ).results[0].embeddings[0] + scores = sorted( + [(cosine(q, e), i) for i, e in enumerate(context_embs)], + reverse=True, + ) + return [i for _, i in scores[:k]] + + retrievers = [("R1-voyage3", r1_baseline), + ("R2-rerank2", r2_rerank), + ("R3-context3", r3_context)] + + # Run all queries × all retrievers, judging top-5 per pair. + # Strategy: for each query, gather the union of all retrievers' top-K + # and judge them in ONE batched CLI call → 18 calls total instead of 270. + all_results = [] + JUDGE_TOP_K = 5 + print(f"\n[judge] running {len(QUERIES)} queries × " + f"{len(retrievers)} retrievers × top-{JUDGE_TOP_K} — batched per query…") + + for qid, query in QUERIES: + print(f"\n[{qid}] {query}") + # Collect retrievals first + retr_results = {} + for r_name, r_fn in retrievers: + try: + retr_results[r_name] = r_fn(query, k=JUDGE_TOP_K) + except Exception as e: + print(f" {r_name}: FAILED — {e}") + retr_results[r_name] = [] + # Union of unique chunk indices to judge + union = sorted({i for top in retr_results.values() for i in top}) + items = [(i, chunks[i]) for i in union] + print(f" judging {len(items)} unique chunks via batch CLI…") + scores_map = batch_judge(query, items) + # Build per-retriever score lists + for r_name, top in retr_results.items(): + scores = [scores_map.get(i, 0) for i in top] + mean3 = sum(scores[:3]) / 3 if len(scores) >= 3 else 0 + mean5 = sum(scores) / len(scores) if scores else 0 + mrr = 0.0 + for r, s in enumerate(scores): + if s >= 4: + mrr = 1.0 / (r + 1) + break + print(f" {r_name}: chunks={[chunk_indices[i] for i in top]} " + f"scores={scores} mean@3={mean3:.2f} mean@5={mean5:.2f} " + f"MRR={mrr:.3f}") + all_results.append({ + "qid": qid, "category": qid[0], "query": query, + "retriever": r_name, + "chunks": [chunk_indices[i] for i in top], + "scores": scores, + "mean3": mean3, "mean5": mean5, "mrr": mrr, + }) + + # Aggregate + print("\n" + "=" * 100) + print("AGGREGATED RESULTS") + print("=" * 100) + + by_retriever = defaultdict(lambda: {"mean3": [], "mean5": [], "mrr": []}) + by_cat_retriever = defaultdict( + lambda: {"mean3": [], "mean5": [], "mrr": []}) + for r in all_results: + by_retriever[r["retriever"]]["mean3"].append(r["mean3"]) + by_retriever[r["retriever"]]["mean5"].append(r["mean5"]) + by_retriever[r["retriever"]]["mrr"].append(r["mrr"]) + cat_key = (r["category"], r["retriever"]) + by_cat_retriever[cat_key]["mean3"].append(r["mean3"]) + by_cat_retriever[cat_key]["mean5"].append(r["mean5"]) + by_cat_retriever[cat_key]["mrr"].append(r["mrr"]) + + print("\nOverall (across all 18 queries):") + print(f"{'retriever':<14} {'mean@3':>8} {'mean@5':>8} {'MRR':>8}") + for r_name, _ in retrievers: + m = by_retriever[r_name] + avg = lambda xs: sum(xs) / len(xs) if xs else 0 + print(f"{r_name:<14} {avg(m['mean3']):>8.3f} " + f"{avg(m['mean5']):>8.3f} {avg(m['mrr']):>8.3f}") + + print("\nBy category (K=keyword, C=conceptual, N=narrative):") + print(f"{'cat':<3} {'retriever':<14} {'mean@3':>8} {'mean@5':>8} {'MRR':>8}") + for cat in ["K", "C", "N"]: + for r_name, _ in retrievers: + m = by_cat_retriever[(cat, r_name)] + avg = lambda xs: sum(xs) / len(xs) if xs else 0 + print(f"{cat:<3} {r_name:<14} {avg(m['mean3']):>8.3f} " + f"{avg(m['mean5']):>8.3f} {avg(m['mrr']):>8.3f}") + + print("\nPer-query winner (highest mean@3, ties shown):") + print(f"{'qid':<4} {'query':<45} {'winner':<24} {'scores'}") + by_query = defaultdict(list) + for r in all_results: + by_query[r["qid"]].append(r) + for qid, results in sorted(by_query.items()): + max_score = max(r["mean3"] for r in results) + winners = [r["retriever"] for r in results if r["mean3"] == max_score] + scores = " | ".join(f"{r['retriever'][:7]}={r['mean3']:.2f}" + for r in results) + q_str = next(q for qid_, q in QUERIES if qid_ == qid)[:42] + print(f"{qid:<4} {q_str:<45} {','.join(w[:8] for w in winners):<24} " + f"{scores}") + + # Save raw results to JSON for further analysis + out_path = "/tmp/voyage_rerank_judge_results.json" + with open(out_path, "w") as f: + json.dump(all_results, f, ensure_ascii=False, indent=2) + print(f"\nRaw results saved to {out_path}") + + await pool.close() + + +if __name__ == "__main__": + asyncio.run(main())