feat(retrieval): add voyage rerank-2 cross-encoder stage (feature flag)
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 1m29s
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 1m29s
Stage B of voyage-upgrades-plan rewritten: instead of context-3 (which
4 POCs showed inconsistent improvement), add a cross-encoder rerank
layer on top of voyage-3. Default off (VOYAGE_RERANK_ENABLED=false).
POC validation (785-doc corpus, 12 queries, claude-haiku-4-5 judge):
- mean@3 +4.5% (4.306 → 4.500)
- practical-category queries +11.6% (3.78 → 4.22)
- latency +702ms per query
- no schema change, no re-embed, no double storage
Plumbing:
- config: VOYAGE_RERANK_ENABLED / _MODEL / _FETCH_K env vars
- embeddings.voyage_rerank() wraps voyageai client.rerank
- services/rerank.py: maybe_rerank() helper — fetches FETCH_K candidates
via the bi-encoder then reranks to top-K. Fail-open if Voyage rerank is
unavailable.
- tools/search.py: search_decisions, search_case_documents,
find_similar_cases all wrapped
- services/precedent_library.search_library wrapped
Smoke-tested locally with flag on/off — produces expected behaviour and
latency profile. Ready for production rollout via Coolify env flip after
deploy.
POCs (kept under scripts/ for reference):
- voyage_context3_poc{_long}.py — context-3 evaluation (rejected)
- voyage_multimodal_poc.py — multimodal-3 (stage C, deferred)
- voyage_rerank_judge_poc.py — single-case rerank benchmark
- voyage_rerank_corpus_poc.py — full-corpus rerank validation
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -40,7 +40,7 @@ Local (developer machine, pm2):
|
||||
|
||||
External:
|
||||
← Claude API (Opus 4.7 for agents)
|
||||
← Voyage AI (voyage-3-large, 1024-dim embeddings)
|
||||
← Voyage AI (voyage-3, 1024-dim embeddings)
|
||||
← Infisical (secret management)
|
||||
← Gmail SMTP (agent notifications)
|
||||
```
|
||||
@@ -59,7 +59,7 @@ External:
|
||||
- מפעיל OCR (Google Vision) אם PDF ללא טקסט
|
||||
- מריץ proofreader להסרת artifacts מ-Nevo
|
||||
- מחלץ טקסט ל-`documents.extracted_text`
|
||||
- מפצל ל-chunks של ~500 מילים, מחשב embeddings (voyage-3-large, 1024D), שומר ב-`document_chunks`
|
||||
- מפצל ל-chunks של ~500 מילים, מחשב embeddings (voyage-3, 1024D), שומר ב-`document_chunks`
|
||||
4. סטטוס תיק: `new` → `proofread`
|
||||
|
||||
### שלב 2 — ניתוח משפטי (legal-researcher + analyst)
|
||||
@@ -223,7 +223,7 @@ legal-qa מריץ 6 בדיקות איכות:
|
||||
`case_law`, `statutory_provisions`, `transition_phrases`, `lessons_learned`, `style_corpus`, `style_patterns`
|
||||
|
||||
### Layer 4: Semantic Search (RAG)
|
||||
`document_embeddings`, `paragraph_embeddings`, `case_law_embeddings` (pgvector 1024-dim, voyage-3-large)
|
||||
`document_embeddings`, `paragraph_embeddings`, `case_law_embeddings` (pgvector 1024-dim, voyage-3)
|
||||
|
||||
### Layer 5 — Multi-tenancy
|
||||
`companies`, `tag_company_mappings` (appeal_subtype → company_id)
|
||||
@@ -283,7 +283,9 @@ legal-qa מריץ 6 בדיקות איכות:
|
||||
## טכנולוגיות עיקריות
|
||||
|
||||
- **Database**: PostgreSQL 15 + pgvector 0.8.1
|
||||
- **Embeddings**: Voyage AI (`voyage-3-large`, 1024-dim)
|
||||
- **Embeddings**: Voyage AI (`voyage-3`, 1024-dim) + cross-encoder rerank (`rerank-2`)
|
||||
- bi-encoder: voyage-3 לכל chunk (חד-פעמי בעת ingestion)
|
||||
- cross-encoder: rerank-2 לכל query (top-50 → top-K), feature flag `VOYAGE_RERANK_ENABLED`
|
||||
- **Agents**: Claude Opus 4.7 (via Paperclip pm2)
|
||||
- **DOCX manipulation**: `python-docx` 1.2+ ו-`lxml` 5.2+ (XML surgery)
|
||||
- **Frontend**: Next.js + TanStack Query + Tailwind
|
||||
|
||||
@@ -52,109 +52,114 @@ voyage-3 **מנצח כפול** — דירוג מושלם + מרווחים גדו
|
||||
|
||||
---
|
||||
|
||||
## שלב B — voyage-context-3 (לביצוע בשיחה החדשה)
|
||||
## שלב B — voyage-rerank-2 (Cross-encoder reranking)
|
||||
|
||||
### הבעיה שהוא פותר
|
||||
> **שינוי מהותי מהתכנית המקורית.** המקור היה ל-context-3. POC רחב
|
||||
> (4 בנצ'מרקים) הראה ש-context-3 לא משפר עקבית, ובחלק מהמקרים מציג
|
||||
> רגרסיה. במקום זאת, **rerank-2** (cross-encoder) הצליח לתת שיפור של
|
||||
> +4.5% mean@3 על קורפוס מלא של 785 docs, **+11.6% על שאילתות
|
||||
> מעשיות** (P-category — בדיוק התרחיש של legal-writer/legal-researcher),
|
||||
> בלי שינוי schema, בלי re-embed, ובלי double storage.
|
||||
|
||||
Embeddings רגילים מטמיעים **chunk בנפרד** מהקשר המסמך. פסקה שאומרת
|
||||
"כפי שקבענו לעיל, הפטור אינו חל" — לא יודעת על "פטור ממה" / "לעיל
|
||||
איפה". פסיקה משפטית מלאה בהפניות הקשר תלויות (ראה סעיף 7 לעיל; להבדיל
|
||||
מהמקרה ב-וע 1126/25; וכו') — והן אבודות לחלוטין.
|
||||
### למה rerank-2 ולא context-3?
|
||||
|
||||
### מה voyage-context-3 עושה אחרת
|
||||
POC #4 (אהרון ברק, 18 שאילתות, claude-haiku-4-5 כ-judge):
|
||||
|
||||
API שונה: `client.contextualized_embed(inputs=[[full_doc, chunk_1], ...])`.
|
||||
כל chunk מוטמע **עם המסמך כולו כקונטקסט**. ה-embedding "יודע" שזו פסקה
|
||||
14 מתוך פסק דין על תמ"א 38 — והקשרים פנימיים נשמרים.
|
||||
| Retriever | mean@3 | mean@5 | MRR |
|
||||
|---|---|---|---|
|
||||
| voyage-3 (baseline) | 3.278 | 3.300 | 0.741 |
|
||||
| **voyage-3 + rerank-2** | **3.574** | **3.467** | **0.769** |
|
||||
| voyage-context-3 (windowed) | 3.481 | 3.378 | 0.685 |
|
||||
|
||||
Anthropic פרסמו מדידה: **שיפור 49% בדיוק חיפוש** למסמכים משפטיים
|
||||
ארוכים.
|
||||
POC #5 (קורפוס מלא 785 docs, 12 שאילתות):
|
||||
|
||||
| Retriever | mean@3 | קטגוריה P (practical) |
|
||||
|---|---|---|
|
||||
| voyage-3 | 4.306 | 3.78 |
|
||||
| **voyage-3 + rerank-2** | **4.500 (+4.5%)** | **4.22 (+11.6%)** |
|
||||
|
||||
context-3 גם נכשל בקטגוריות keyword שהן 60%+ מהשאילתות בפועל אצל דפנה.
|
||||
|
||||
### איך rerank-2 עובד
|
||||
|
||||
Two-stage retrieval:
|
||||
1. **שלב bi-encoder (כמו היום)**: voyage-3 מטמיע את ה-query, מחזיר
|
||||
top-50 chunks דרך cosine similarity על `pgvector` (מהיר, ~390ms).
|
||||
2. **שלב cross-encoder (חדש)**: rerank-2 מקבל `(query, document)` עבור
|
||||
כל אחד מ-50 הdocuments, ומחזיר ציון רלוונטיות מדויק יותר.
|
||||
הreranker רואה את ה-query ואת ה-doc ביחד דרך attention מלא,
|
||||
לעומת bi-encoder שרק מחשב cosine בין שני embeddings בלתי-תלויים.
|
||||
3. החזרה: top-K (10) המדורגים מחדש.
|
||||
|
||||
**עלות**: +702ms latency (bi-encoder=393ms → +rerank=1095ms).
|
||||
**עלות tokens**: zero לאחסון (רק חישוב per-query).
|
||||
|
||||
### תכנית יישום
|
||||
|
||||
#### B.1 — Refactor של pipeline ה-ingestion
|
||||
#### B.1 — `voyage_rerank()` ב-`embeddings.py`
|
||||
|
||||
קוד נוכחי (`embeddings.py`):
|
||||
```python
|
||||
embs = await embed_texts(chunk_texts, input_type="document")
|
||||
async def voyage_rerank(
|
||||
query: str, documents: list[str], top_k: int = 10,
|
||||
) -> list[tuple[int, float]]:
|
||||
"""Cross-encoder rerank via Voyage. Returns [(orig_index, score), ...]."""
|
||||
if not documents:
|
||||
return []
|
||||
client = _get_client()
|
||||
result = client.rerank(
|
||||
query=query, documents=documents,
|
||||
model=config.VOYAGE_RERANK_MODEL, # "rerank-2"
|
||||
top_k=top_k,
|
||||
)
|
||||
return [(r.index, r.relevance_score) for r in result.results]
|
||||
```
|
||||
|
||||
קוד חדש:
|
||||
#### B.2 — Feature flag ב-`config.py`
|
||||
|
||||
```python
|
||||
embs = await embed_texts_with_context(
|
||||
document_full_text=text,
|
||||
chunks=chunk_texts,
|
||||
input_type="document",
|
||||
VOYAGE_RERANK_MODEL = os.environ.get("VOYAGE_RERANK_MODEL", "rerank-2")
|
||||
VOYAGE_RERANK_ENABLED = (
|
||||
os.environ.get("VOYAGE_RERANK_ENABLED", "false").lower() == "true"
|
||||
)
|
||||
VOYAGE_RERANK_FETCH_K = int(os.environ.get("VOYAGE_RERANK_FETCH_K", "50"))
|
||||
```
|
||||
|
||||
מקומות שצריכים שינוי:
|
||||
- `mcp-server/.../services/embeddings.py` — פונקציה חדשה `embed_with_context`
|
||||
שעוטפת `client.contextualized_embed`
|
||||
- `mcp-server/.../services/processor.py` — `process_document()` מעביר
|
||||
את `text` המלא + chunks
|
||||
- `mcp-server/.../services/precedent_library.py` — `ingest_precedent`
|
||||
מעביר `text` + chunks
|
||||
- `mcp-server/.../services/halacha_extractor.py` — לכל הלכה, מעביר
|
||||
את הפסק המלא כקונטקסט (`case_law.full_text`) + `rule_statement`
|
||||
שמוטמע
|
||||
הdefault הוא `false` — הקוד יישמר אך לא יורץ עד שיופעל ידנית.
|
||||
|
||||
#### B.2 — Query embedding נפרד
|
||||
#### B.3 — אינטגרציה ב-3 search functions
|
||||
|
||||
Queries מטמיעים בלי קונטקסט (`client.embed()` רגיל עם
|
||||
`model="voyage-context-3"` ו-`input_type="query"`). חשוב: queries
|
||||
ו-documents חייבים להיות באותו model space.
|
||||
ב-`db.py`:
|
||||
- `search_similar` (document_chunks) — נוסיף פרמטר `rerank: bool = False`.
|
||||
אם True: שולפים top-`VOYAGE_RERANK_FETCH_K` במקום `limit`,
|
||||
מעבירים דרך rerank, מחזירים top-`limit`.
|
||||
- `search_precedent_library_semantic` — אותו דבר. הuance: היום יש
|
||||
boost של +0.05 ל-halachot. כש-rerank פעיל, ה-boost מתבטל ו-rerank
|
||||
מוחל על המאוחד (chunks + halachot ביחד) — cross-encoder יבחר נכון
|
||||
בלי boost מלאכותי.
|
||||
- `search_similar_paragraphs` / `search_similar_case_law` (ב-style
|
||||
corpus) — אותו דבר.
|
||||
|
||||
ב-`embeddings.py:embed_query()` — מחליפים model אבל לא ה-API.
|
||||
|
||||
#### B.3 — Re-embed של הקורפוס הקיים
|
||||
|
||||
```python
|
||||
# Pseudo-code
|
||||
for table in [document_chunks, precedent_chunks, halachot, ...]:
|
||||
rows = SELECT id, content, parent_doc_id FROM table
|
||||
for row in rows:
|
||||
full_doc = SELECT full_text FROM parent_table WHERE id = row.parent_doc_id
|
||||
embedding = contextualized_embed(full_doc, row.content)
|
||||
UPDATE table SET embedding = embedding WHERE id = row.id
|
||||
```
|
||||
|
||||
הבעיה: כל chunk שולח את **המסמך כולו** כקונטקסט. לכן עלות לטוקן
|
||||
עולה משמעותית. אומדן: 178K תווים × 50 chunks = 8.9M תווים פר פסיקה,
|
||||
פי-50 לעומת voyage-3. החישוב לקורפוס הנוכחי (~7K rows): שווה ערך
|
||||
לכ-700M תווים. בtier החינמי של voyage קיים מגבלה — חשוב לבדוק לפני
|
||||
הרצה גדולה.
|
||||
|
||||
**Mitigation**: לחלץ summary של 500-1000 תווים מכל מסמך (קלוד עושה
|
||||
את זה היום ב-`metadata_extractor`) ולהעביר ה-summary במקום הטקסט המלא.
|
||||
שמירת 95% מהיתרון בעלות 5%.
|
||||
ב-`tools/search.py` — כל הtools (`search_decisions`, `search_case_documents`,
|
||||
`find_similar_cases`, `precedent_search_library`) יעבירו
|
||||
`rerank=config.VOYAGE_RERANK_ENABLED` לקריאות ה-DB.
|
||||
|
||||
#### B.4 — Schema
|
||||
|
||||
אין שינוי. אותו `vector(1024)` column.
|
||||
אין שינוי. אותם vectors, אותו pgvector.
|
||||
|
||||
#### B.5 — Benchmark לפני החלטה סופית
|
||||
#### B.5 — Rollout
|
||||
|
||||
לפני re-embed של 6951 rows:
|
||||
1. לקחת 10 שאילתות אמיתיות + passages עם תיוג נכון
|
||||
2. להריץ benchmark voyage-3 vs voyage-context-3 (אותו pipeline כמו
|
||||
`/tmp/voyage_compare.py`)
|
||||
3. אם השיפור < 15% → לא שווה את העלות. נשאר ב-voyage-3
|
||||
4. אם השיפור ≥ 15% → ל-go ל-context-3
|
||||
1. שינוי קוד + push + deploy עם feature flag = `false`
|
||||
2. אימות ש-baseline ממשיך לעבוד (לא רגרסיה)
|
||||
3. הפעלה ידנית: `VOYAGE_RERANK_ENABLED=true` ב-Coolify env
|
||||
4. שאילתות אמיתיות מדפנה / סוכנים — observation
|
||||
5. אם רגרסיה — kill switch בשניות (`false` בחזרה)
|
||||
6. אם כל מתעקפם — להגדיר `true` כdefault (in-code) אחרי שבוע יציב
|
||||
|
||||
#### B.6 — בדיקת זמן + עלות
|
||||
#### B.6 — Tier check
|
||||
|
||||
לאחר ה-benchmark:
|
||||
- אם בtier החינמי לא מספיק טוקנים → לבחור: רק documents (לא
|
||||
re-embed הקיים), רק פסיקה חדשה והלאה
|
||||
- או: לעבור ל-context-3 רק על קורפוס הפסיקה (4 פסיקות, ~785 chunks
|
||||
+ halachot) — הקרפוס הקריטי ביותר ל-`search_precedent_library`
|
||||
|
||||
### החלטות שנשארו פתוחות (תיקח החלטה בשיחה החדשה)
|
||||
|
||||
- ✋ Re-embed הכל בבת-אחת או רק חדש?
|
||||
- ✋ context-3 לכל הקורפוסים או רק לפסיקה (הקריטי ביותר)?
|
||||
- ✋ Document context = full_text או summary של 1K?
|
||||
Voyage Tier 1: 2M TPM, 2000 RPM ל-rerank-2. עומס שלנו (~עשרות
|
||||
queries בשעה במקרה רגיל) — מתחת ל-1% מהמכסה.
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -47,6 +47,17 @@ VOYAGE_API_KEY = os.environ.get("VOYAGE_API_KEY", "")
|
||||
VOYAGE_MODEL = os.environ.get("VOYAGE_MODEL", "voyage-law-2")
|
||||
VOYAGE_DIMENSIONS = 1024
|
||||
|
||||
# Rerank — cross-encoder second-stage. Off by default; flip with env to
|
||||
# enable across all semantic search tools (search_decisions,
|
||||
# search_case_documents, find_similar_cases, search_precedent_library).
|
||||
VOYAGE_RERANK_MODEL = os.environ.get("VOYAGE_RERANK_MODEL", "rerank-2")
|
||||
VOYAGE_RERANK_ENABLED = (
|
||||
os.environ.get("VOYAGE_RERANK_ENABLED", "false").lower() == "true"
|
||||
)
|
||||
# How many candidates to fetch from bi-encoder before reranking.
|
||||
# 50 was the depth used in the POC; balances recall vs rerank cost.
|
||||
VOYAGE_RERANK_FETCH_K = int(os.environ.get("VOYAGE_RERANK_FETCH_K", "50"))
|
||||
|
||||
# Google Cloud Vision (OCR for scanned PDFs)
|
||||
GOOGLE_CLOUD_VISION_API_KEY = os.environ.get("GOOGLE_CLOUD_VISION_API_KEY", "")
|
||||
|
||||
|
||||
@@ -53,3 +53,26 @@ async def embed_query(query: str) -> list[float]:
|
||||
"""Embed a single search query."""
|
||||
results = await embed_texts([query], input_type="query")
|
||||
return results[0]
|
||||
|
||||
|
||||
async def voyage_rerank(
|
||||
query: str, documents: list[str], top_k: int | None = None,
|
||||
) -> list[tuple[int, float]]:
|
||||
"""Cross-encoder rerank via Voyage. Returns [(orig_index, score), ...]
|
||||
sorted by relevance. Each tuple's index refers to the position in the
|
||||
*input* documents list (not a DB row id) — caller maps it back.
|
||||
|
||||
Used as a second stage after bi-encoder retrieval: fetch top-N
|
||||
candidates with cosine, then rerank to get top-K with cross-encoder
|
||||
attention over (query, doc).
|
||||
"""
|
||||
if not documents:
|
||||
return []
|
||||
client = _get_client()
|
||||
result = client.rerank(
|
||||
query=query,
|
||||
documents=documents,
|
||||
model=config.VOYAGE_RERANK_MODEL,
|
||||
top_k=top_k,
|
||||
)
|
||||
return [(r.index, float(r.relevance_score)) for r in result.results]
|
||||
|
||||
@@ -22,7 +22,7 @@ from typing import Awaitable, Callable
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from legal_mcp import config
|
||||
from legal_mcp.services import chunker, db, embeddings, extractor
|
||||
from legal_mcp.services import chunker, db, embeddings, extractor, rerank
|
||||
|
||||
# Note: halacha_extractor and precedent_metadata_extractor are NOT imported
|
||||
# at module load. They are imported lazily inside the dedicated re-extract
|
||||
@@ -403,18 +403,29 @@ async def search_library(
|
||||
|
||||
Only ``approved`` / ``published`` halachot are returned, per chair-review
|
||||
policy. Chunks are returned regardless of halacha review status.
|
||||
|
||||
When ``VOYAGE_RERANK_ENABLED`` is set, results are passed through
|
||||
voyage rerank-2 (cross-encoder). The +0.05 halacha boost from
|
||||
``search_precedent_library_semantic`` is preserved before rerank
|
||||
but the rerank scores ultimately decide the order.
|
||||
"""
|
||||
if not query.strip():
|
||||
return []
|
||||
query_vec = await embeddings.embed_query(query)
|
||||
return await db.search_precedent_library_semantic(
|
||||
query_embedding=query_vec,
|
||||
practice_area=practice_area,
|
||||
court=court,
|
||||
precedent_level=precedent_level,
|
||||
appeal_subtype=appeal_subtype,
|
||||
is_binding=is_binding,
|
||||
subject_tag=subject_tag,
|
||||
limit=limit,
|
||||
include_halachot=include_halachot,
|
||||
|
||||
async def _base(limit: int) -> list[dict]:
|
||||
return await db.search_precedent_library_semantic(
|
||||
query_embedding=query_vec,
|
||||
practice_area=practice_area,
|
||||
court=court,
|
||||
precedent_level=precedent_level,
|
||||
appeal_subtype=appeal_subtype,
|
||||
is_binding=is_binding,
|
||||
subject_tag=subject_tag,
|
||||
limit=limit,
|
||||
include_halachot=include_halachot,
|
||||
)
|
||||
|
||||
return await rerank.maybe_rerank(
|
||||
query=query, base_search=_base, limit=limit,
|
||||
)
|
||||
|
||||
103
mcp-server/src/legal_mcp/services/rerank.py
Normal file
103
mcp-server/src/legal_mcp/services/rerank.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Optional cross-encoder reranking layer for semantic search.
|
||||
|
||||
Wraps a base search function with two-stage retrieval:
|
||||
1. fetch ``VOYAGE_RERANK_FETCH_K`` candidates via the bi-encoder (cosine)
|
||||
2. pass them to voyage rerank-2, return top-``limit``
|
||||
|
||||
When the feature flag is off (or ``force_rerank=False``) the helper just
|
||||
calls the base function with ``limit`` and returns its results unchanged
|
||||
— so callers can wrap unconditionally and let env control behaviour.
|
||||
|
||||
The helper extracts the rerank text from each row using the first
|
||||
non-empty field among ``content``, ``rule_statement``,
|
||||
``reasoning_summary`` (matches the schema used by ``search_similar``
|
||||
and ``search_precedent_library_semantic``).
|
||||
|
||||
Decision validated by POC #5 (785-doc precedent corpus, 12 queries):
|
||||
- mean@3: 4.306 → 4.500 (+4.5%)
|
||||
- practical-category queries: 3.78 → 4.22 (+11.6%)
|
||||
- latency: +702ms per query
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from legal_mcp import config
|
||||
from legal_mcp.services import embeddings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SearchFn = Callable[..., Awaitable[list[dict]]]
|
||||
|
||||
|
||||
def _rerank_text(row: dict) -> str:
|
||||
"""First non-empty text field that voyage rerank should see."""
|
||||
for key in ("content", "rule_statement", "reasoning_summary",
|
||||
"supporting_quote"):
|
||||
v = row.get(key)
|
||||
if v:
|
||||
return str(v)
|
||||
return ""
|
||||
|
||||
|
||||
async def maybe_rerank(
|
||||
query: str,
|
||||
base_search: SearchFn,
|
||||
limit: int,
|
||||
*,
|
||||
force_rerank: bool | None = None,
|
||||
fetch_k: int | None = None,
|
||||
**base_kwargs: Any,
|
||||
) -> list[dict]:
|
||||
"""Two-stage retrieval helper.
|
||||
|
||||
Args:
|
||||
query: original query string (needed for the rerank API).
|
||||
base_search: any async function that takes ``limit=…`` and the
|
||||
other ``base_kwargs`` and returns ``list[dict]``.
|
||||
limit: final number of results to return.
|
||||
force_rerank: override the env flag. ``None`` → use config.
|
||||
fetch_k: override the bi-encoder fetch depth.
|
||||
**base_kwargs: forwarded to ``base_search``.
|
||||
|
||||
Returns:
|
||||
List of dict rows. When rerank is active, each row's ``score``
|
||||
is replaced with the rerank-2 relevance score (0..1).
|
||||
"""
|
||||
enabled = (config.VOYAGE_RERANK_ENABLED
|
||||
if force_rerank is None else force_rerank)
|
||||
if not enabled:
|
||||
return await base_search(limit=limit, **base_kwargs)
|
||||
|
||||
depth = fetch_k or config.VOYAGE_RERANK_FETCH_K
|
||||
candidates = await base_search(limit=depth, **base_kwargs)
|
||||
if not candidates:
|
||||
return []
|
||||
|
||||
texts = [_rerank_text(c) for c in candidates]
|
||||
# Drop candidates with empty rerank text (shouldn't happen but be safe)
|
||||
keep = [(i, t) for i, t in enumerate(texts) if t]
|
||||
if not keep:
|
||||
logger.warning("rerank: all candidates empty, falling back to base")
|
||||
return candidates[:limit]
|
||||
keep_idx = [i for i, _ in keep]
|
||||
keep_texts = [t for _, t in keep]
|
||||
|
||||
try:
|
||||
ranked = await embeddings.voyage_rerank(
|
||||
query, keep_texts, top_k=limit,
|
||||
)
|
||||
except Exception as e:
|
||||
# Fail open — if Voyage rerank is down, return bi-encoder ordering
|
||||
logger.warning("rerank failed, falling back to base: %s", e)
|
||||
return candidates[:limit]
|
||||
|
||||
out: list[dict] = []
|
||||
for keep_pos, score in ranked:
|
||||
orig_idx = keep_idx[keep_pos]
|
||||
row = dict(candidates[orig_idx])
|
||||
row["score"] = float(score)
|
||||
out.append(row)
|
||||
return out
|
||||
@@ -6,7 +6,7 @@ import json
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from legal_mcp.services import db, embeddings
|
||||
from legal_mcp.services import db, embeddings, rerank
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -43,8 +43,9 @@ async def search_decisions(
|
||||
)
|
||||
|
||||
query_emb = await embeddings.embed_query(query)
|
||||
results = await db.search_similar(
|
||||
query_embedding=query_emb,
|
||||
results = await rerank.maybe_rerank(
|
||||
query=query,
|
||||
base_search=lambda **kw: db.search_similar(query_embedding=query_emb, **kw),
|
||||
limit=limit,
|
||||
section_type=section_type or None,
|
||||
practice_area=practice_area or None,
|
||||
@@ -86,8 +87,9 @@ async def search_case_documents(
|
||||
|
||||
query_emb = await embeddings.embed_query(query)
|
||||
# Restricted to case_id — practice_area filter would be redundant.
|
||||
results = await db.search_similar(
|
||||
query_embedding=query_emb,
|
||||
results = await rerank.maybe_rerank(
|
||||
query=query,
|
||||
base_search=lambda **kw: db.search_similar(query_embedding=query_emb, **kw),
|
||||
limit=limit,
|
||||
case_id=UUID(case["id"]),
|
||||
)
|
||||
@@ -137,9 +139,13 @@ async def find_similar_cases(
|
||||
)
|
||||
|
||||
query_emb = await embeddings.embed_query(description)
|
||||
results = await db.search_similar(
|
||||
query_embedding=query_emb,
|
||||
limit=limit * 3, # Get more to deduplicate by case
|
||||
# Use description as the query text for rerank too.
|
||||
# Note: even with rerank we ask for ``limit*3`` so the dedup-by-case
|
||||
# step downstream still has enough rows to pick the best per case.
|
||||
results = await rerank.maybe_rerank(
|
||||
query=description,
|
||||
base_search=lambda **kw: db.search_similar(query_embedding=query_emb, **kw),
|
||||
limit=limit * 3,
|
||||
practice_area=practice_area or None,
|
||||
appeal_subtype=appeal_subtype or None,
|
||||
)
|
||||
|
||||
@@ -17,6 +17,11 @@
|
||||
| `deploy-track-changes.sh` | bash | סנכרון skills CMP↔CMPA + בדיקות + הנחיות deploy לארכיטקטורת Track Changes | ידני |
|
||||
| `retrofit_case.py` | python | retrofit רטרואקטיבי — מזריק bookmarks לקובץ קיים של תיק ספציפי ומגדיר אותו כ-active_draft | ידני (חד-פעמי לתיק) |
|
||||
| `reembed_voyage.py` | python | Re-embed כל הוקטורים ב-DB עם המודל ב-`VOYAGE_MODEL` (לאחר שינוי מודל). 5 טבלאות, 1024 דמ', batches של 100. ראה `docs/voyage-upgrades-plan.md` | ידני (אחרי החלפת `VOYAGE_MODEL`) |
|
||||
| `voyage_context3_poc.py` | python | POC #1 — voyage-3 vs voyage-context-3 על פסיקה אחת קצרה (קלמנוביץ, 63 chunks). הכרעה: context-3 לא מציג שיפור עקבי | בנצ'מרק חד-פעמי, נשמר לרפרנס |
|
||||
| `voyage_context3_poc_long.py` | python | POC #2 — voyage-context-3 על פסיקה ארוכה (אהרון ברק 219 chunks) עם sliding windows. הכרעה: context-3 לא משתפר על פסיקה גדולה | בנצ'מרק חד-פעמי, נשמר לרפרנס |
|
||||
| `voyage_multimodal_poc.py` | python | POC #3 — voyage-multimodal-3 על דוח שמאי (89 עמודים). הכרעה: שיפור משמעותי לטבלאות + 22 עמודי image-only שhttp text-OCR מאבד | בנצ'מרק חד-פעמי, מוכן לשלב C |
|
||||
| `voyage_rerank_judge_poc.py` | python | POC #4 — voyage-3 vs rerank-2 vs context-3 על אהרון ברק, 18 שאילתות, claude-haiku-4-5 כ-judge. הכרעה: rerank-2 ניצח עם +9% mean@3 | בנצ'מרק חד-פעמי |
|
||||
| `voyage_rerank_corpus_poc.py` | python | POC #5 — voyage-3 vs rerank-2 על קורפוס מלא (785 docs). הכרעה: +4.5% mean@3 כללי, +11.6% על P queries (practical) | בנצ'מרק חד-פעמי, אישר את שלב B |
|
||||
|
||||
## תיקיית `.archive/` — סקריפטים שהושלמו
|
||||
|
||||
|
||||
182
scripts/voyage_context3_poc.py
Normal file
182
scripts/voyage_context3_poc.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""POC: Compare voyage-3 vs voyage-context-3 retrieval on case 403/17.
|
||||
|
||||
Pulls all chunks of "אהרון ברק - תכנית רחביה" (case_law_id=e151fc25-...),
|
||||
runs them through voyage-context-3 in a single contextualized_embed call,
|
||||
then runs benchmark queries and compares rankings against the existing
|
||||
voyage-3 embeddings (already in the DB).
|
||||
|
||||
No DB writes — all comparisons in memory. Output: ranking table for each
|
||||
query showing top-10 from both models side-by-side.
|
||||
|
||||
Usage:
|
||||
/home/chaim/legal-ai/mcp-server/.venv/bin/python \\
|
||||
/home/chaim/legal-ai/scripts/voyage_context3_poc.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
# Load ~/.env
|
||||
ENV_PATH = os.path.expanduser("~/.env")
|
||||
if os.path.isfile(ENV_PATH):
|
||||
with open(ENV_PATH) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#") and "=" in line:
|
||||
k, v = line.split("=", 1)
|
||||
os.environ.setdefault(k, v)
|
||||
|
||||
import asyncpg # noqa: E402
|
||||
import voyageai # noqa: E402
|
||||
|
||||
|
||||
# Using קלמנוביץ/לויתן (52K chars, 63 chunks, ~18K tokens)
|
||||
# — fits in single context-3 call (32K token limit per inner list).
|
||||
# אהרון ברק (60K tokens) requires splitting; we'll handle that after POC.
|
||||
CASE_ID = "436efd48-c8ab-49f0-b3a9-52bf15ea806d" # בר"מ 25226-04-25
|
||||
CONTEXT_MODEL = "voyage-context-3"
|
||||
BASELINE_MODEL = "voyage-3" # already in DB
|
||||
|
||||
QUERIES = [
|
||||
"סמכות ועדת ערר",
|
||||
"פיצויים לפי סעיף 197",
|
||||
"ירידת ערך מקרקעין",
|
||||
"תכנית פוגעת",
|
||||
"שיקול דעת ועדה מקומית",
|
||||
"חוות דעת שמאי מכריע",
|
||||
"מקרקעין גובלים",
|
||||
"תקופת התיישנות תביעה",
|
||||
"אינטרס ציבורי בתכנון",
|
||||
"דחיית תביעת פיצויים",
|
||||
]
|
||||
|
||||
|
||||
def cosine(a: list[float], b: list[float]) -> float:
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
na = math.sqrt(sum(x * x for x in a))
|
||||
nb = math.sqrt(sum(y * y for y in b))
|
||||
return dot / (na * nb) if na and nb else 0.0
|
||||
|
||||
|
||||
def parse_pgvector(s: str) -> list[float]:
|
||||
"""pgvector text format: '[0.1,0.2,...]'."""
|
||||
return [float(x) for x in s.strip("[]").split(",")]
|
||||
|
||||
|
||||
async def main():
|
||||
api_key = os.environ["VOYAGE_API_KEY"]
|
||||
pg_pw = os.environ["POSTGRES_PASSWORD"]
|
||||
|
||||
voyage = voyageai.Client(api_key=api_key)
|
||||
|
||||
pool = await asyncpg.create_pool(
|
||||
host="127.0.0.1", port=5433, user="legal_ai",
|
||||
password=pg_pw, database="legal_ai",
|
||||
min_size=1, max_size=2,
|
||||
)
|
||||
|
||||
# 1. Pull all chunks + their existing voyage-3 embeddings
|
||||
rows = await pool.fetch("""
|
||||
SELECT chunk_index, content, embedding::text AS emb_text
|
||||
FROM precedent_chunks
|
||||
WHERE case_law_id = $1
|
||||
ORDER BY chunk_index
|
||||
""", CASE_ID)
|
||||
print(f"[load] {len(rows)} chunks from case 403/17")
|
||||
|
||||
chunks = [r["content"] for r in rows]
|
||||
indices = [r["chunk_index"] for r in rows]
|
||||
baseline_embs = [parse_pgvector(r["emb_text"]) for r in rows]
|
||||
|
||||
# 2. Embed all chunks with voyage-context-3 — single contextualized call
|
||||
total_chars = sum(len(c) for c in chunks)
|
||||
print(f"[context] embedding {len(chunks)} chunks, {total_chars:,} chars total")
|
||||
start = time.time()
|
||||
result = voyage.contextualized_embed(
|
||||
inputs=[chunks], # one document = one inner list
|
||||
model=CONTEXT_MODEL,
|
||||
input_type="document",
|
||||
)
|
||||
elapsed = time.time() - start
|
||||
# ContextualizedEmbeddingsObject: result.results = list of per-document
|
||||
# embeddings. result.results[0].embeddings = list of chunk embeddings.
|
||||
context_embs = result.results[0].embeddings
|
||||
total_tokens = getattr(result, "total_tokens", "?")
|
||||
print(f"[context] done in {elapsed:.1f}s — total_tokens={total_tokens}")
|
||||
assert len(context_embs) == len(chunks), "embedding count mismatch"
|
||||
|
||||
# 3. For each query — embed twice and compare top-10
|
||||
print("\n" + "=" * 100)
|
||||
print(f"{'Q':<3} {'baseline (voyage-3)':<48} {'context-3':<48}")
|
||||
print("=" * 100)
|
||||
|
||||
rank_overlaps = []
|
||||
score_lifts = []
|
||||
|
||||
for q_idx, query in enumerate(QUERIES, 1):
|
||||
# Baseline query embedding (regular embed)
|
||||
q_baseline = voyage.embed(
|
||||
[query], model=BASELINE_MODEL, input_type="query"
|
||||
).embeddings[0]
|
||||
# Context query embedding — must use contextualized_embed even for
|
||||
# single-string queries (regular embed() rejects voyage-context-3).
|
||||
q_context = voyage.contextualized_embed(
|
||||
inputs=[[query]],
|
||||
model=CONTEXT_MODEL,
|
||||
input_type="query",
|
||||
).results[0].embeddings[0]
|
||||
|
||||
# Score every chunk under both models
|
||||
scores_b = sorted(
|
||||
[(cosine(q_baseline, e), i) for i, e in enumerate(baseline_embs)],
|
||||
reverse=True,
|
||||
)
|
||||
scores_c = sorted(
|
||||
[(cosine(q_context, e), i) for i, e in enumerate(context_embs)],
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
top10_b = [i for _, i in scores_b[:10]]
|
||||
top10_c = [i for _, i in scores_c[:10]]
|
||||
|
||||
# Compute overlap and avg score in top-3
|
||||
overlap = len(set(top10_b) & set(top10_c))
|
||||
avg_b_top3 = sum(s for s, _ in scores_b[:3]) / 3
|
||||
avg_c_top3 = sum(s for s, _ in scores_c[:3]) / 3
|
||||
rank_overlaps.append(overlap)
|
||||
score_lifts.append(avg_c_top3 - avg_b_top3)
|
||||
|
||||
print(f"\n[Q{q_idx}] {query}")
|
||||
print(f" overlap top-10: {overlap}/10 | avg score top-3: "
|
||||
f"baseline={avg_b_top3:.3f} context-3={avg_c_top3:.3f} "
|
||||
f"Δ={avg_c_top3 - avg_b_top3:+.3f}")
|
||||
for rank in range(5):
|
||||
sb, ib = scores_b[rank]
|
||||
sc, ic = scores_c[rank]
|
||||
cb = chunks[ib].replace("\n", " ").strip()[:50]
|
||||
cc = chunks[ic].replace("\n", " ").strip()[:50]
|
||||
print(f" #{rank+1} [{indices[ib]:3d}] {sb:.3f} {cb:<55} "
|
||||
f"| [{indices[ic]:3d}] {sc:.3f} {cc}")
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 100)
|
||||
print("SUMMARY")
|
||||
print("=" * 100)
|
||||
avg_overlap = sum(rank_overlaps) / len(rank_overlaps)
|
||||
avg_lift = sum(score_lifts) / len(score_lifts)
|
||||
print(f"Avg overlap top-10: {avg_overlap:.1f}/10 "
|
||||
f"(higher = models agree more)")
|
||||
print(f"Avg score lift top-3 (context - baseline): {avg_lift:+.4f}")
|
||||
print(f"\nNote: cosine scores are not directly comparable across models.")
|
||||
print(f"What matters more is which CHUNKS bubble to the top —")
|
||||
print(f"reading the actual content above tells the real story.")
|
||||
|
||||
await pool.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
238
scripts/voyage_context3_poc_long.py
Normal file
238
scripts/voyage_context3_poc_long.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""POC #2: voyage-3 vs voyage-context-3 on a LONG case (אהרון ברק 403/17).
|
||||
|
||||
Case is 178K chars / 219 chunks / ~60K tokens — too big for a single
|
||||
contextualized_embed call (32K token limit per inner list). We split the
|
||||
chunks into overlapping sliding windows (~80 chunks each, ~22K tokens)
|
||||
and merge: each chunk gets the embedding from the window where it sits
|
||||
*most centrally* (max symmetric context on both sides).
|
||||
|
||||
The hypothesis: voyage-context-3 should shine here because the case is
|
||||
full of internal references ("ראה לעיל סעיף 13", "להבדיל מעניין X",
|
||||
"תוצאת הבחינה ב-בר"מ 1975/24 שנידונה לעיל"). voyage-3 embeds chunks
|
||||
in isolation; context-3 sees ~80 surrounding chunks per embedding.
|
||||
|
||||
No DB writes. Output: side-by-side ranking comparison + summary.
|
||||
|
||||
Usage:
|
||||
/home/chaim/legal-ai/mcp-server/.venv/bin/python \\
|
||||
/home/chaim/legal-ai/scripts/voyage_context3_poc_long.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
ENV_PATH = os.path.expanduser("~/.env")
|
||||
if os.path.isfile(ENV_PATH):
|
||||
with open(ENV_PATH) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#") and "=" in line:
|
||||
k, v = line.split("=", 1)
|
||||
os.environ.setdefault(k, v)
|
||||
|
||||
import asyncpg # noqa: E402
|
||||
import voyageai # noqa: E402
|
||||
|
||||
|
||||
CASE_ID = "e151fc25-cf12-4563-b638-a86323f8413b" # 403/17 אהרון ברק (178K chars)
|
||||
CONTEXT_MODEL = "voyage-context-3"
|
||||
BASELINE_MODEL = "voyage-3"
|
||||
|
||||
# Sliding-window split params. With 219 chunks and ~60K tokens total
|
||||
# (~275 tokens/chunk average), 3 windows of 80 chunks each is ~22K tokens
|
||||
# per call — comfortably under 32K.
|
||||
WINDOW_SIZE = 80
|
||||
WINDOW_STRIDE = 70 # overlap = WINDOW_SIZE - WINDOW_STRIDE = 10
|
||||
|
||||
# Mix of:
|
||||
# (a) generic queries (also tested in POC #1)
|
||||
# (b) queries that require *internal* document context
|
||||
QUERIES = [
|
||||
# generic
|
||||
"תכנית רחביה הוראות בנייה",
|
||||
"פיצויים לפי סעיף 197 ירידת ערך",
|
||||
"השפעת תכנית על שווי מקרקעין",
|
||||
"סמכות ועדת ערר לדון בפיצויים",
|
||||
"תוספת זכויות בנייה כפיצוי",
|
||||
# internal-context — should benefit context-3
|
||||
"ההבחנה בין השבחה לפיצויים",
|
||||
"מה נקבע לגבי תמ\"א 38 בפסק הדין",
|
||||
"ההלכה שנקבעה בעניין רובע 3",
|
||||
"כלל הנטרול של זכויות תכנוניות",
|
||||
"הסכמת השופט אלרון לחוות הדעת",
|
||||
]
|
||||
|
||||
|
||||
def cosine(a: list[float], b: list[float]) -> float:
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
na = math.sqrt(sum(x * x for x in a))
|
||||
nb = math.sqrt(sum(y * y for y in b))
|
||||
return dot / (na * nb) if na and nb else 0.0
|
||||
|
||||
|
||||
def parse_pgvector(s: str) -> list[float]:
|
||||
return [float(x) for x in s.strip("[]").split(",")]
|
||||
|
||||
|
||||
def build_windows(n: int, size: int, stride: int) -> list[tuple[int, int]]:
|
||||
"""Return list of (start, end) ranges (end exclusive) covering 0..n.
|
||||
|
||||
Last window extends to n exactly. Overlap = size - stride.
|
||||
"""
|
||||
windows = []
|
||||
start = 0
|
||||
while start < n:
|
||||
end = min(start + size, n)
|
||||
windows.append((start, end))
|
||||
if end == n:
|
||||
break
|
||||
start += stride
|
||||
return windows
|
||||
|
||||
|
||||
def assign_chunk_to_window(
|
||||
chunk_idx: int, windows: list[tuple[int, int]],
|
||||
) -> int:
|
||||
"""Pick the window where chunk_idx sits most centrally (max symmetric
|
||||
distance to either edge). Ties broken by larger window."""
|
||||
best = -1
|
||||
best_score = -1
|
||||
for w_idx, (s, e) in enumerate(windows):
|
||||
if not (s <= chunk_idx < e):
|
||||
continue
|
||||
# symmetric distance: min(distance to s, distance to e-1)
|
||||
dist = min(chunk_idx - s, (e - 1) - chunk_idx)
|
||||
if dist > best_score:
|
||||
best_score = dist
|
||||
best = w_idx
|
||||
return best
|
||||
|
||||
|
||||
async def main():
|
||||
api_key = os.environ["VOYAGE_API_KEY"]
|
||||
pg_pw = os.environ["POSTGRES_PASSWORD"]
|
||||
|
||||
voyage = voyageai.Client(api_key=api_key)
|
||||
|
||||
pool = await asyncpg.create_pool(
|
||||
host="127.0.0.1", port=5433, user="legal_ai",
|
||||
password=pg_pw, database="legal_ai",
|
||||
min_size=1, max_size=2,
|
||||
)
|
||||
|
||||
rows = await pool.fetch("""
|
||||
SELECT chunk_index, content, embedding::text AS emb_text
|
||||
FROM precedent_chunks
|
||||
WHERE case_law_id = $1
|
||||
ORDER BY chunk_index
|
||||
""", CASE_ID)
|
||||
n = len(rows)
|
||||
print(f"[load] {n} chunks from אהרון ברק 403/17")
|
||||
|
||||
chunks = [r["content"] for r in rows]
|
||||
indices = [r["chunk_index"] for r in rows]
|
||||
baseline_embs = [parse_pgvector(r["emb_text"]) for r in rows]
|
||||
|
||||
# Build windows
|
||||
windows = build_windows(n, WINDOW_SIZE, WINDOW_STRIDE)
|
||||
print(f"[windows] {len(windows)} windows: "
|
||||
f"{', '.join(f'[{s}:{e})' for s, e in windows)}")
|
||||
|
||||
# Embed each window with context-3
|
||||
window_embs: list[list[list[float]]] = [] # [window][chunk_in_window][dim]
|
||||
total_call_tokens = 0
|
||||
total_start = time.time()
|
||||
for w_idx, (s, e) in enumerate(windows):
|
||||
sub_chunks = chunks[s:e]
|
||||
sub_chars = sum(len(c) for c in sub_chunks)
|
||||
start = time.time()
|
||||
result = voyage.contextualized_embed(
|
||||
inputs=[sub_chunks],
|
||||
model=CONTEXT_MODEL,
|
||||
input_type="document",
|
||||
)
|
||||
elapsed = time.time() - start
|
||||
toks = getattr(result, "total_tokens", 0)
|
||||
total_call_tokens += toks
|
||||
print(f" [window {w_idx}] [{s}:{e}) — {len(sub_chunks)} chunks, "
|
||||
f"{sub_chars:,} chars, {toks} tokens — {elapsed:.1f}s")
|
||||
window_embs.append(result.results[0].embeddings)
|
||||
total_elapsed = time.time() - total_start
|
||||
print(f"[context] all windows done in {total_elapsed:.1f}s, "
|
||||
f"{total_call_tokens} total tokens")
|
||||
|
||||
# Merge: for each chunk, pick the embedding from its most-central window
|
||||
context_embs: list[list[float]] = []
|
||||
chunk_window_choice = []
|
||||
for i in range(n):
|
||||
w_idx = assign_chunk_to_window(i, windows)
|
||||
chunk_window_choice.append(w_idx)
|
||||
s, _ = windows[w_idx]
|
||||
context_embs.append(window_embs[w_idx][i - s])
|
||||
print(f"[merge] window distribution: "
|
||||
f"{[chunk_window_choice.count(j) for j in range(len(windows))]}")
|
||||
|
||||
# Run queries
|
||||
print("\n" + "=" * 100)
|
||||
print(f"{'Q':<3} {'baseline (voyage-3)':<48} {'context-3 (windowed)':<48}")
|
||||
print("=" * 100)
|
||||
|
||||
rank_overlaps = []
|
||||
for q_idx, query in enumerate(QUERIES, 1):
|
||||
q_baseline = voyage.embed(
|
||||
[query], model=BASELINE_MODEL, input_type="query"
|
||||
).embeddings[0]
|
||||
q_context = voyage.contextualized_embed(
|
||||
inputs=[[query]],
|
||||
model=CONTEXT_MODEL,
|
||||
input_type="query",
|
||||
).results[0].embeddings[0]
|
||||
|
||||
scores_b = sorted(
|
||||
[(cosine(q_baseline, e), i) for i, e in enumerate(baseline_embs)],
|
||||
reverse=True,
|
||||
)
|
||||
scores_c = sorted(
|
||||
[(cosine(q_context, e), i) for i, e in enumerate(context_embs)],
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
top10_b = [i for _, i in scores_b[:10]]
|
||||
top10_c = [i for _, i in scores_c[:10]]
|
||||
overlap = len(set(top10_b) & set(top10_c))
|
||||
rank_overlaps.append(overlap)
|
||||
|
||||
print(f"\n[Q{q_idx}] {query}")
|
||||
print(f" overlap top-10: {overlap}/10 | "
|
||||
f"avg score top-3: baseline="
|
||||
f"{sum(s for s, _ in scores_b[:3])/3:.3f} "
|
||||
f"context-3={sum(s for s, _ in scores_c[:3])/3:.3f}")
|
||||
for rank in range(5):
|
||||
sb, ib = scores_b[rank]
|
||||
sc, ic = scores_c[rank]
|
||||
cb = chunks[ib].replace("\n", " ").strip()[:50]
|
||||
cc = chunks[ic].replace("\n", " ").strip()[:50]
|
||||
print(f" #{rank+1} [{indices[ib]:3d}] {sb:.3f} {cb:<55} "
|
||||
f"| [{indices[ic]:3d}] {sc:.3f} {cc}")
|
||||
|
||||
print("\n" + "=" * 100)
|
||||
print("SUMMARY")
|
||||
print("=" * 100)
|
||||
avg = sum(rank_overlaps) / len(rank_overlaps)
|
||||
print(f"Avg overlap top-10: {avg:.1f}/10")
|
||||
print(f"Per-query overlap: {rank_overlaps}")
|
||||
print(f"Total context-3 tokens used: {total_call_tokens:,} "
|
||||
f"(in {len(windows)} calls)")
|
||||
print(f"\nNote: cosine across models not directly comparable. The")
|
||||
print(f"meaningful test is *which chunks bubble to the top* — read")
|
||||
print(f"the actual text above to judge relevance.")
|
||||
|
||||
await pool.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
213
scripts/voyage_multimodal_poc.py
Normal file
213
scripts/voyage_multimodal_poc.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""POC #3: voyage-3 (text) vs voyage-multimodal-3.5 (page images) on a
|
||||
real appraisal PDF (89 pages, full of tables / signatures / numerical
|
||||
data — the corpus class where multimodal should help most).
|
||||
|
||||
Document under test:
|
||||
baf10153-d2fc-4481-b250-9fe87440ce69
|
||||
"נספח - שומה מכרעת (אבלין דוידזון שמאמא) - 15.09.24"
|
||||
case 8137-24, 89 pages, 2.1 MB
|
||||
|
||||
The pipeline:
|
||||
1. Pull the existing voyage-3 text-chunk embeddings from `document_chunks`.
|
||||
2. Render each PDF page → PNG (PyMuPDF, dpi=144).
|
||||
3. Embed all pages via voyage-multimodal-3.5.
|
||||
4. Run benchmark queries (mix of generic + table-specific + visual)
|
||||
against both: text top-K and page top-K.
|
||||
|
||||
The comparison is *qualitative* — text and image embeddings are
|
||||
different "spaces" returning different ID types (chunk_id vs page_num).
|
||||
What we look at is whether image-based retrieval surfaces tables,
|
||||
signatures, or numerical data that text-only OCR loses.
|
||||
|
||||
No DB writes.
|
||||
|
||||
Usage:
|
||||
/home/chaim/legal-ai/mcp-server/.venv/bin/python \\
|
||||
/home/chaim/legal-ai/scripts/voyage_multimodal_poc.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
|
||||
ENV_PATH = os.path.expanduser("~/.env")
|
||||
if os.path.isfile(ENV_PATH):
|
||||
with open(ENV_PATH) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#") and "=" in line:
|
||||
k, v = line.split("=", 1)
|
||||
os.environ.setdefault(k, v)
|
||||
|
||||
import asyncpg # noqa: E402
|
||||
import voyageai # noqa: E402
|
||||
import fitz # PyMuPDF # noqa: E402
|
||||
from PIL import Image # noqa: E402
|
||||
|
||||
|
||||
DOCUMENT_ID = "baf10153-d2fc-4481-b250-9fe87440ce69"
|
||||
PDF_PATH = (
|
||||
"/home/chaim/legal-ai/data/cases/8137-24/documents/originals/"
|
||||
"נספח - שומה מכרעת (אבלין דוידזון שמאמא) - 15.09.24.pdf"
|
||||
)
|
||||
TEXT_MODEL = "voyage-3"
|
||||
MULTIMODAL_MODEL = "voyage-multimodal-3" # check supported: 3.5 may not exist yet
|
||||
DPI = 144
|
||||
# voyage-multimodal: max 1000 inputs/call, 320M pixels/call (rough),
|
||||
# so 89 pages at 1240×1750 ≈ 192M pixels = single call.
|
||||
|
||||
QUERIES = [
|
||||
# generic-textual (both should handle)
|
||||
"שיטת ההיוון בשומה",
|
||||
"מתודולוגיית הערכת שווי",
|
||||
# table/numerical (multimodal should help)
|
||||
"טבלת השוואת ערכים לפני ואחרי התכנית",
|
||||
"שווי המקרקעין במצב הקודם",
|
||||
"שווי המקרקעין במצב החדש",
|
||||
"ירידת ערך באחוזים",
|
||||
# visual elements (text-only loses)
|
||||
"חתימת השמאי",
|
||||
"תרשים גוש וחלקה",
|
||||
"מפת מיקום הנכס",
|
||||
# context-heavy
|
||||
"מסקנת השמאי המכריע",
|
||||
"עקרון הצפיפות בתכנית",
|
||||
]
|
||||
|
||||
|
||||
def cosine(a: list[float], b: list[float]) -> float:
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
na = math.sqrt(sum(x * x for x in a))
|
||||
nb = math.sqrt(sum(y * y for y in b))
|
||||
return dot / (na * nb) if na and nb else 0.0
|
||||
|
||||
|
||||
def parse_pgvector(s: str) -> list[float]:
|
||||
return [float(x) for x in s.strip("[]").split(",")]
|
||||
|
||||
|
||||
def render_pdf_pages(pdf_path: str, dpi: int) -> list[Image.Image]:
|
||||
"""Render each page → PIL.Image (RGB)."""
|
||||
doc = fitz.open(pdf_path)
|
||||
images: list[Image.Image] = []
|
||||
for page in doc:
|
||||
pix = page.get_pixmap(dpi=dpi)
|
||||
png_bytes = pix.tobytes("png")
|
||||
img = Image.open(io.BytesIO(png_bytes)).convert("RGB")
|
||||
images.append(img)
|
||||
doc.close()
|
||||
return images
|
||||
|
||||
|
||||
async def main():
|
||||
api_key = os.environ["VOYAGE_API_KEY"]
|
||||
pg_pw = os.environ["POSTGRES_PASSWORD"]
|
||||
|
||||
voyage = voyageai.Client(api_key=api_key)
|
||||
|
||||
# 1. Render PDF pages
|
||||
print(f"[render] {PDF_PATH}")
|
||||
start = time.time()
|
||||
images = render_pdf_pages(PDF_PATH, DPI)
|
||||
elapsed = time.time() - start
|
||||
print(f"[render] {len(images)} pages in {elapsed:.1f}s, "
|
||||
f"{images[0].size}px @ {DPI}dpi")
|
||||
|
||||
# 2. Pull existing text chunks + voyage-3 embeddings
|
||||
pool = await asyncpg.create_pool(
|
||||
host="127.0.0.1", port=5433, user="legal_ai",
|
||||
password=pg_pw, database="legal_ai",
|
||||
min_size=1, max_size=2,
|
||||
)
|
||||
rows = await pool.fetch("""
|
||||
SELECT id, chunk_index, page_number, content,
|
||||
embedding::text AS emb_text
|
||||
FROM document_chunks
|
||||
WHERE document_id = $1
|
||||
ORDER BY chunk_index
|
||||
""", DOCUMENT_ID)
|
||||
print(f"[text] {len(rows)} text chunks loaded (voyage-3 in DB)")
|
||||
text_contents = [r["content"] for r in rows]
|
||||
text_chunk_pages = [r["page_number"] for r in rows]
|
||||
text_embs = [parse_pgvector(r["emb_text"]) for r in rows]
|
||||
|
||||
# 3. Multimodal embed — try multimodal-3 first, fall back if needed
|
||||
target_model = "voyage-multimodal-3"
|
||||
print(f"[multimodal] embedding {len(images)} pages with {target_model}…")
|
||||
start = time.time()
|
||||
try:
|
||||
mm_result = voyage.multimodal_embed(
|
||||
inputs=[[img] for img in images], # list of single-image inputs
|
||||
model=target_model,
|
||||
input_type="document",
|
||||
truncation=True,
|
||||
)
|
||||
except voyageai.error.InvalidRequestError as e:
|
||||
print(f" [error] {e}")
|
||||
await pool.close()
|
||||
return
|
||||
elapsed = time.time() - start
|
||||
image_embs = mm_result.embeddings
|
||||
mm_tokens = getattr(mm_result, "total_tokens", "?")
|
||||
image_tokens = getattr(mm_result, "image_pixels", "?")
|
||||
text_tokens_mm = getattr(mm_result, "text_tokens", "?")
|
||||
print(f"[multimodal] done in {elapsed:.1f}s — "
|
||||
f"total_tokens={mm_tokens} text_tokens={text_tokens_mm} "
|
||||
f"image_pixels={image_tokens}")
|
||||
assert len(image_embs) == len(images), "embedding count mismatch"
|
||||
print(f"[multimodal] embedding dim = {len(image_embs[0])}")
|
||||
|
||||
# 4. Run queries
|
||||
print("\n" + "=" * 100)
|
||||
print("QUERY RESULTS — top-5 chunks (text/voyage-3) "
|
||||
"vs top-5 pages (multimodal)")
|
||||
print("=" * 100)
|
||||
|
||||
for q_idx, query in enumerate(QUERIES, 1):
|
||||
# Text-side: voyage-3 query embedding
|
||||
q_text = voyage.embed(
|
||||
[query], model=TEXT_MODEL, input_type="query"
|
||||
).embeddings[0]
|
||||
# Multimodal-side: same model, query input_type
|
||||
q_mm = voyage.multimodal_embed(
|
||||
inputs=[[query]],
|
||||
model=target_model,
|
||||
input_type="query",
|
||||
).embeddings[0]
|
||||
|
||||
text_scores = sorted(
|
||||
[(cosine(q_text, e), i) for i, e in enumerate(text_embs)],
|
||||
reverse=True,
|
||||
)[:5]
|
||||
mm_scores = sorted(
|
||||
[(cosine(q_mm, e), i) for i, e in enumerate(image_embs)],
|
||||
reverse=True,
|
||||
)[:5]
|
||||
|
||||
print(f"\n[Q{q_idx}] {query}")
|
||||
print(f" --- text (voyage-3) top-5 ---")
|
||||
for s, i in text_scores:
|
||||
page = text_chunk_pages[i] if text_chunk_pages[i] else "?"
|
||||
preview = text_contents[i].replace("\n", " ").strip()[:70]
|
||||
print(f" {s:.3f} page={page:>3} chunk={i:>3} {preview}")
|
||||
print(f" --- multimodal (image-only) top-5 ---")
|
||||
for s, i in mm_scores:
|
||||
print(f" {s:.3f} page={i+1:>3} (image)")
|
||||
|
||||
# Token / cost summary
|
||||
print("\n" + "=" * 100)
|
||||
print("SUMMARY")
|
||||
print("=" * 100)
|
||||
print(f"PDF: {len(images)} pages @ {DPI}dpi → {target_model}")
|
||||
print(f"Total multimodal tokens: {mm_tokens}")
|
||||
print(f"Embedding dim: {len(image_embs[0])}")
|
||||
print(f"Time: {elapsed:.1f}s for full doc")
|
||||
|
||||
await pool.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
318
scripts/voyage_rerank_corpus_poc.py
Normal file
318
scripts/voyage_rerank_corpus_poc.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""POC #5 — full precedent_library corpus benchmark.
|
||||
|
||||
Tests R1 (voyage-3) vs R2 (voyage-3 + rerank-2) on the *real* corpus that
|
||||
search_precedent_library queries against:
|
||||
|
||||
precedent_chunks — 385 rows from 3 precedent cases
|
||||
halachot — 400 rule statements with reasoning summaries
|
||||
|
||||
Total: 785 documents. The MCP tool merges results from both tables so the
|
||||
benchmark mirrors production retrieval. R3 (context-3) is dropped — it
|
||||
would require windowed re-embedding of 3 cases which we already proved
|
||||
doesn't help (POC #2). The question now is: does rerank-2's +9% on a
|
||||
single case generalize to a heterogeneous corpus?
|
||||
|
||||
Also measures end-to-end latency: pure voyage-3 vs voyage-3 + rerank.
|
||||
|
||||
Usage:
|
||||
/home/chaim/legal-ai/mcp-server/.venv/bin/python \\
|
||||
/home/chaim/legal-ai/scripts/voyage_rerank_corpus_poc.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
ENV_PATH = os.path.expanduser("~/.env")
|
||||
if os.path.isfile(ENV_PATH):
|
||||
with open(ENV_PATH) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#") and "=" in line:
|
||||
k, v = line.split("=", 1)
|
||||
os.environ.setdefault(k, v)
|
||||
|
||||
import asyncpg # noqa: E402
|
||||
import voyageai # noqa: E402
|
||||
|
||||
|
||||
TEXT_MODEL = "voyage-3"
|
||||
RERANK_MODEL = "rerank-2"
|
||||
JUDGE_MODEL = "claude-haiku-4-5-20251001"
|
||||
TOP_VEC = 50 # voyage-3 retrieve depth
|
||||
TOP_K = 10 # final returned to "agent"
|
||||
JUDGE_K = 5 # how many top results to actually judge per retriever
|
||||
|
||||
# 12 queries spanning typical use cases by Daphna's agents:
|
||||
# precedent search for citing in decision blocks י-יא.
|
||||
QUERIES = [
|
||||
# K — keyword
|
||||
("K1", "פיצויים לפי סעיף 197"),
|
||||
("K2", "תמ\"א 38 והשבחה"),
|
||||
("K3", "כלל הנטרול בשמאות"),
|
||||
# C — conceptual
|
||||
("C1", "תכלית היטל ההשבחה"),
|
||||
("C2", "מה מקנה לבעלים זכות לפיצוי"),
|
||||
("C3", "ההבחנה בין השבחה לפיצויים"),
|
||||
# N — narrative / context-aware
|
||||
("N1", "מה נקבע לגבי תמ\"א 38 בפסיקה"),
|
||||
("N2", "ההלכה לעניין נטרול ציפיות"),
|
||||
("N3", "תכנית פוגעת ושומה"),
|
||||
# P — practical (drafting needs — what an agent typically asks)
|
||||
("P1", "פסיקה שדנה בתכנית מתאר ארצית"),
|
||||
("P2", "מתי מותר לוועדה לדחות פיצויים"),
|
||||
("P3", "שיקול דעת הוועדה המקומית"),
|
||||
]
|
||||
|
||||
|
||||
def cosine(a, b):
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
na = math.sqrt(sum(x * x for x in a))
|
||||
nb = math.sqrt(sum(y * y for y in b))
|
||||
return dot / (na * nb) if na and nb else 0.0
|
||||
|
||||
|
||||
def parse_pgvector(s):
|
||||
return [float(x) for x in s.strip("[]").split(",")]
|
||||
|
||||
|
||||
BATCH_JUDGE_PROMPT = """אתה שופט רלוונטיות במשפט ישראלי.
|
||||
לפניך שאילתה ומספר פסקאות מפסקי דין/הלכות. דרג כל פסקה 1-5 לפי רלוונטיות.
|
||||
|
||||
5 — תשובה ישירה למה שנשאל
|
||||
4 — מאד רלוונטי, מכיל מידע ליבה
|
||||
3 — רלוונטי חלקית, נוגע בעקיפין
|
||||
2 — מעט קשור, רעש סביב הנושא
|
||||
1 — לא רלוונטי בכלל
|
||||
|
||||
השאילתה:
|
||||
{query}
|
||||
|
||||
הפסקאות:
|
||||
{chunks_block}
|
||||
|
||||
החזר JSON בלבד: {{"scores": {{"<id>": <1-5>, ...}}}}
|
||||
ללא טקסט נוסף, ללא ```."""
|
||||
|
||||
|
||||
def batch_judge(query: str, items: list[tuple[str, str]]) -> dict[str, int]:
|
||||
"""Judge (id, text) pairs via claude CLI. Returns {id: score}."""
|
||||
blocks = []
|
||||
for cid, content in items:
|
||||
snippet = content.replace("\n", " ").strip()[:1500]
|
||||
blocks.append(f"<id={cid}>\n{snippet}\n</id>")
|
||||
prompt = BATCH_JUDGE_PROMPT.format(
|
||||
query=query, chunks_block="\n\n".join(blocks))
|
||||
proc = subprocess.run(
|
||||
["claude", "-p", "--model", JUDGE_MODEL],
|
||||
input=prompt, capture_output=True, text=True, timeout=180,
|
||||
)
|
||||
out = proc.stdout.strip()
|
||||
out = re.sub(r"^```(?:json)?\s*", "", out)
|
||||
out = re.sub(r"\s*```$", "", out)
|
||||
try:
|
||||
data = json.loads(out)
|
||||
raw = data.get("scores", {})
|
||||
return {str(k): int(v) for k, v in raw.items()
|
||||
if str(v).isdigit() and 1 <= int(v) <= 5}
|
||||
except (json.JSONDecodeError, ValueError, TypeError) as e:
|
||||
print(f" [judge parse fail: {e}; out={out[:200]!r}]")
|
||||
return {}
|
||||
|
||||
|
||||
async def main():
|
||||
voyage_key = os.environ["VOYAGE_API_KEY"]
|
||||
pg_pw = os.environ["POSTGRES_PASSWORD"]
|
||||
|
||||
try:
|
||||
subprocess.run(["claude", "--version"], capture_output=True,
|
||||
text=True, timeout=10, check=True)
|
||||
except (subprocess.CalledProcessError, FileNotFoundError, TimeoutError):
|
||||
sys.exit("claude CLI not found")
|
||||
|
||||
voyage = voyageai.Client(api_key=voyage_key)
|
||||
|
||||
pool = await asyncpg.create_pool(
|
||||
host="127.0.0.1", port=5433, user="legal_ai",
|
||||
password=pg_pw, database="legal_ai",
|
||||
min_size=1, max_size=2,
|
||||
)
|
||||
|
||||
# Load full corpus: precedent_chunks + halachot
|
||||
pc_rows = await pool.fetch("""
|
||||
SELECT 'pc:' || id::text AS doc_id,
|
||||
content,
|
||||
embedding::text AS emb_text
|
||||
FROM precedent_chunks
|
||||
WHERE content IS NOT NULL AND embedding IS NOT NULL
|
||||
""")
|
||||
h_rows = await pool.fetch("""
|
||||
SELECT 'h:' || id::text AS doc_id,
|
||||
TRIM(BOTH ' —' FROM rule_statement || ' — ' ||
|
||||
COALESCE(reasoning_summary, '')) AS content,
|
||||
embedding::text AS emb_text
|
||||
FROM halachot
|
||||
WHERE rule_statement IS NOT NULL AND embedding IS NOT NULL
|
||||
""")
|
||||
all_rows = list(pc_rows) + list(h_rows)
|
||||
print(f"[load] corpus: {len(pc_rows)} precedent_chunks + "
|
||||
f"{len(h_rows)} halachot = {len(all_rows)} total")
|
||||
|
||||
doc_ids = [r["doc_id"] for r in all_rows]
|
||||
contents = [r["content"] for r in all_rows]
|
||||
embs = [parse_pgvector(r["emb_text"]) for r in all_rows]
|
||||
|
||||
# Latency measurement: 5 queries, time the two pipelines
|
||||
print("\n[latency] measuring 5 sample queries…")
|
||||
sample = QUERIES[:5]
|
||||
r1_lat = []
|
||||
r2_lat = []
|
||||
for _, query in sample:
|
||||
# R1: voyage-3 embed + cosine top-10
|
||||
t0 = time.time()
|
||||
q_emb = voyage.embed([query], model=TEXT_MODEL,
|
||||
input_type="query").embeddings[0]
|
||||
scores = sorted([(cosine(q_emb, e), i) for i, e in enumerate(embs)],
|
||||
reverse=True)[:TOP_K]
|
||||
r1_lat.append(time.time() - t0)
|
||||
# R2: voyage-3 embed + cosine top-50 + rerank-2 → top-10
|
||||
t0 = time.time()
|
||||
q_emb = voyage.embed([query], model=TEXT_MODEL,
|
||||
input_type="query").embeddings[0]
|
||||
cands = sorted([(cosine(q_emb, e), i) for i, e in enumerate(embs)],
|
||||
reverse=True)[:TOP_VEC]
|
||||
cand_texts = [contents[i] for _, i in cands]
|
||||
rr = voyage.rerank(query=query, documents=cand_texts,
|
||||
model=RERANK_MODEL, top_k=TOP_K)
|
||||
r2_lat.append(time.time() - t0)
|
||||
print(f" R1 (voyage-3 only) avg={sum(r1_lat)/5*1000:.0f}ms"
|
||||
f" min={min(r1_lat)*1000:.0f} max={max(r1_lat)*1000:.0f}")
|
||||
print(f" R2 (voyage-3 + rerank-2) avg={sum(r2_lat)/5*1000:.0f}ms"
|
||||
f" min={min(r2_lat)*1000:.0f} max={max(r2_lat)*1000:.0f}")
|
||||
print(f" Δ (rerank overhead) avg={(sum(r2_lat)-sum(r1_lat))/5*1000:.0f}ms")
|
||||
|
||||
# Retrieval functions
|
||||
def r1_baseline(query: str, k: int = TOP_K) -> list[int]:
|
||||
q = voyage.embed([query], model=TEXT_MODEL,
|
||||
input_type="query").embeddings[0]
|
||||
scores = sorted([(cosine(q, e), i) for i, e in enumerate(embs)],
|
||||
reverse=True)
|
||||
return [i for _, i in scores[:k]]
|
||||
|
||||
def r2_rerank(query: str, k: int = TOP_K) -> list[int]:
|
||||
cands = r1_baseline(query, k=TOP_VEC)
|
||||
cand_texts = [contents[i] for i in cands]
|
||||
rr = voyage.rerank(query=query, documents=cand_texts,
|
||||
model=RERANK_MODEL, top_k=k)
|
||||
return [cands[r.index] for r in rr.results]
|
||||
|
||||
retrievers = [("R1-voyage3", r1_baseline),
|
||||
("R2-rerank2", r2_rerank)]
|
||||
|
||||
print(f"\n[judge] running {len(QUERIES)} queries × 2 retrievers, "
|
||||
f"top-{JUDGE_K} judged…")
|
||||
|
||||
all_results = []
|
||||
for qid, query in QUERIES:
|
||||
print(f"\n[{qid}] {query}")
|
||||
retr_results = {}
|
||||
for r_name, r_fn in retrievers:
|
||||
try:
|
||||
retr_results[r_name] = r_fn(query, k=JUDGE_K)
|
||||
except Exception as e:
|
||||
print(f" {r_name}: FAILED — {e}")
|
||||
retr_results[r_name] = []
|
||||
union = sorted({i for top in retr_results.values() for i in top})
|
||||
items = [(doc_ids[i], contents[i]) for i in union]
|
||||
print(f" judging {len(items)} unique docs…")
|
||||
scores_map = batch_judge(query, items)
|
||||
for r_name, top in retr_results.items():
|
||||
scores = [scores_map.get(doc_ids[i], 0) for i in top]
|
||||
mean3 = sum(scores[:3]) / 3 if len(scores) >= 3 else 0
|
||||
mean5 = sum(scores) / len(scores) if scores else 0
|
||||
mrr = 0.0
|
||||
for r, s in enumerate(scores):
|
||||
if s >= 4:
|
||||
mrr = 1.0 / (r + 1)
|
||||
break
|
||||
print(f" {r_name}: doc_ids={[doc_ids[i][:14] for i in top]} "
|
||||
f"scores={scores} m@3={mean3:.2f} m@5={mean5:.2f} "
|
||||
f"MRR={mrr:.3f}")
|
||||
all_results.append({
|
||||
"qid": qid, "category": qid[0], "query": query,
|
||||
"retriever": r_name,
|
||||
"doc_ids": [doc_ids[i] for i in top],
|
||||
"scores": scores, "mean3": mean3, "mean5": mean5, "mrr": mrr,
|
||||
})
|
||||
|
||||
# Aggregate
|
||||
print("\n" + "=" * 100)
|
||||
print("AGGREGATED RESULTS — full precedent_library corpus (785 docs)")
|
||||
print("=" * 100)
|
||||
by_r = defaultdict(lambda: {"mean3": [], "mean5": [], "mrr": []})
|
||||
by_cat_r = defaultdict(lambda: {"mean3": [], "mean5": [], "mrr": []})
|
||||
for r in all_results:
|
||||
by_r[r["retriever"]]["mean3"].append(r["mean3"])
|
||||
by_r[r["retriever"]]["mean5"].append(r["mean5"])
|
||||
by_r[r["retriever"]]["mrr"].append(r["mrr"])
|
||||
ck = (r["category"], r["retriever"])
|
||||
by_cat_r[ck]["mean3"].append(r["mean3"])
|
||||
by_cat_r[ck]["mean5"].append(r["mean5"])
|
||||
by_cat_r[ck]["mrr"].append(r["mrr"])
|
||||
|
||||
print(f"\nOverall ({len(QUERIES)} queries):")
|
||||
print(f"{'retriever':<14} {'mean@3':>8} {'mean@5':>8} {'MRR':>8}")
|
||||
avg = lambda xs: sum(xs) / len(xs) if xs else 0
|
||||
for r_name, _ in retrievers:
|
||||
m = by_r[r_name]
|
||||
print(f"{r_name:<14} {avg(m['mean3']):>8.3f} "
|
||||
f"{avg(m['mean5']):>8.3f} {avg(m['mrr']):>8.3f}")
|
||||
# Improvement
|
||||
r1m = avg(by_r["R1-voyage3"]["mean3"])
|
||||
r2m = avg(by_r["R2-rerank2"]["mean3"])
|
||||
if r1m > 0:
|
||||
print(f"\nR2 vs R1 improvement: "
|
||||
f"mean@3 {(r2m - r1m) / r1m * 100:+.1f}%")
|
||||
|
||||
print(f"\nBy category:")
|
||||
print(f"{'cat':<3} {'retriever':<14} {'mean@3':>8} {'mean@5':>8} "
|
||||
f"{'MRR':>8}")
|
||||
for cat in ["K", "C", "N", "P"]:
|
||||
for r_name, _ in retrievers:
|
||||
m = by_cat_r[(cat, r_name)]
|
||||
if not m["mean3"]:
|
||||
continue
|
||||
print(f"{cat:<3} {r_name:<14} {avg(m['mean3']):>8.3f} "
|
||||
f"{avg(m['mean5']):>8.3f} {avg(m['mrr']):>8.3f}")
|
||||
|
||||
print(f"\nPer-query winner (highest mean@3):")
|
||||
print(f"{'qid':<4} {'query':<40} {'winner':<14} {'scores'}")
|
||||
by_q = defaultdict(list)
|
||||
for r in all_results:
|
||||
by_q[r["qid"]].append(r)
|
||||
for qid, results in sorted(by_q.items()):
|
||||
max_s = max(r["mean3"] for r in results)
|
||||
winners = [r["retriever"] for r in results if r["mean3"] == max_s]
|
||||
scores = " | ".join(f"{r['retriever'][:7]}={r['mean3']:.2f}"
|
||||
for r in results)
|
||||
q_str = next(q for qid_, q in QUERIES if qid_ == qid)[:38]
|
||||
print(f"{qid:<4} {q_str:<40} {','.join(w[:8] for w in winners):<14} "
|
||||
f"{scores}")
|
||||
|
||||
out_path = "/tmp/voyage_rerank_corpus_results.json"
|
||||
with open(out_path, "w") as f:
|
||||
json.dump(all_results, f, ensure_ascii=False, indent=2)
|
||||
print(f"\nSaved to {out_path}")
|
||||
|
||||
await pool.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
361
scripts/voyage_rerank_judge_poc.py
Normal file
361
scripts/voyage_rerank_judge_poc.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""POC #4: Comprehensive retrieval benchmark with LLM-as-judge.
|
||||
|
||||
Compares 3 retrievers on אהרון ברק 403/17 (219 chunks):
|
||||
R1 — voyage-3 (current production baseline)
|
||||
R2 — voyage-3 + voyage-rerank-2 (retrieve 50, rerank, top-10)
|
||||
R3 — voyage-context-3 (windowed, from POC #2)
|
||||
|
||||
Judges relevance with claude-haiku-4-5 — for each (query, chunk) pair the
|
||||
judge returns 1-5. Aggregates: mean relevance@3, @5, @10, MRR (rank of
|
||||
first 4+ chunk), per-query winner.
|
||||
|
||||
20 queries grouped into 3 categories so we can see *which* query types
|
||||
benefit from which retriever:
|
||||
K — keyword/lexical (term-heavy, specific entity)
|
||||
C — conceptual (abstract idea, principle)
|
||||
N — narrative/contextual (requires document-internal reference)
|
||||
|
||||
Usage (key passed via env, NOT stored in script):
|
||||
ANTHROPIC_API_KEY=... \\
|
||||
/home/chaim/legal-ai/mcp-server/.venv/bin/python \\
|
||||
/home/chaim/legal-ai/scripts/voyage_rerank_judge_poc.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
ENV_PATH = os.path.expanduser("~/.env")
|
||||
if os.path.isfile(ENV_PATH):
|
||||
with open(ENV_PATH) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#") and "=" in line:
|
||||
k, v = line.split("=", 1)
|
||||
os.environ.setdefault(k, v)
|
||||
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
import asyncpg # noqa: E402
|
||||
import voyageai # noqa: E402
|
||||
|
||||
|
||||
CASE_ID = "e151fc25-cf12-4563-b638-a86323f8413b" # אהרון ברק 403/17
|
||||
TEXT_MODEL = "voyage-3"
|
||||
CONTEXT_MODEL = "voyage-context-3"
|
||||
RERANK_MODEL = "rerank-2"
|
||||
JUDGE_MODEL = "claude-haiku-4-5-20251001"
|
||||
|
||||
WINDOW_SIZE = 80
|
||||
WINDOW_STRIDE = 70
|
||||
|
||||
# 18 queries × 3 retrievers × top-5 = 270 judge calls. ~$0.05 with haiku.
|
||||
QUERIES = [
|
||||
# K — keyword/lexical
|
||||
("K1", "תכנית רחביה הוראות בנייה"),
|
||||
("K2", "תמ\"א 38"),
|
||||
("K3", "תכנית 9988"),
|
||||
("K4", "סעיף 197 לחוק התכנון והבניה"),
|
||||
("K5", "השופט גרוסקופף"),
|
||||
("K6", "ועדה מקומית ירושלים"),
|
||||
# C — conceptual / abstract principles
|
||||
("C1", "כלל הנטרול של זכויות תכנוניות"),
|
||||
("C2", "אינטרס הציבור בתכנון"),
|
||||
("C3", "תכלית היטל ההשבחה"),
|
||||
("C4", "תכנית פוגעת לעומת תכנית משביחה"),
|
||||
("C5", "ההבחנה בין השבחה לפיצויים"),
|
||||
("C6", "מהותו של היטל ההשבחה"),
|
||||
# N — narrative / context-dependent
|
||||
("N1", "מה נקבע לגבי תמ\"א 38 בפסק הדין"),
|
||||
("N2", "מסקנת בית המשפט בעניין רובע 3"),
|
||||
("N3", "ההלכה שנקבעה בעניין שמעוני"),
|
||||
("N4", "ההבדל בין המקרה שלפנינו לעניין רון"),
|
||||
("N5", "סוף דבר ותוצאת פסק הדין"),
|
||||
("N6", "הסכמת השופטים האחרים לחוות הדעת"),
|
||||
]
|
||||
|
||||
|
||||
def cosine(a, b):
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
na = math.sqrt(sum(x * x for x in a))
|
||||
nb = math.sqrt(sum(y * y for y in b))
|
||||
return dot / (na * nb) if na and nb else 0.0
|
||||
|
||||
|
||||
def parse_pgvector(s):
|
||||
return [float(x) for x in s.strip("[]").split(",")]
|
||||
|
||||
|
||||
def build_windows(n, size, stride):
|
||||
out = []
|
||||
s = 0
|
||||
while s < n:
|
||||
e = min(s + size, n)
|
||||
out.append((s, e))
|
||||
if e == n:
|
||||
break
|
||||
s += stride
|
||||
return out
|
||||
|
||||
|
||||
def central_window(idx, windows):
|
||||
best, best_d = -1, -1
|
||||
for w_idx, (s, e) in enumerate(windows):
|
||||
if not (s <= idx < e):
|
||||
continue
|
||||
d = min(idx - s, (e - 1) - idx)
|
||||
if d > best_d:
|
||||
best_d = d
|
||||
best = w_idx
|
||||
return best
|
||||
|
||||
|
||||
BATCH_JUDGE_PROMPT = """אתה שופט רלוונטיות במשפט ישראלי.
|
||||
לפניך שאילתה ומספר פסקאות מפסק דין. דרג כל פסקה בנפרד 1-5 לפי רלוונטיות.
|
||||
|
||||
סולם:
|
||||
5 — תשובה ישירה ומדויקת לשאילתה
|
||||
4 — מאד רלוונטי, מכיל מידע ליבה
|
||||
3 — רלוונטי חלקית, נוגע בעקיפין בנושא
|
||||
2 — מעט קשור, רעש סביב הנושא
|
||||
1 — לא רלוונטי בכלל
|
||||
|
||||
השאילתה:
|
||||
{query}
|
||||
|
||||
הפסקאות:
|
||||
{chunks_block}
|
||||
|
||||
החזר JSON בלבד, בפורמט: {{"scores": {{"<id>": <1-5>, ...}}}}
|
||||
ללא טקסט נוסף, ללא explanations, ללא ```."""
|
||||
|
||||
|
||||
def batch_judge(query: str,
|
||||
items: list[tuple[int, str]]) -> dict[int, int]:
|
||||
"""Judge a list of (chunk_idx, content) pairs in a single CLI call.
|
||||
|
||||
Returns: dict[chunk_idx → score 1-5]. Returns 0 for parse failures.
|
||||
"""
|
||||
chunks_block_lines = []
|
||||
for ci, content in items:
|
||||
snippet = content.replace("\n", " ").strip()[:1500]
|
||||
chunks_block_lines.append(f"<id={ci}>\n{snippet}\n</id>")
|
||||
prompt = BATCH_JUDGE_PROMPT.format(
|
||||
query=query,
|
||||
chunks_block="\n\n".join(chunks_block_lines),
|
||||
)
|
||||
proc = subprocess.run(
|
||||
["claude", "-p", "--model", JUDGE_MODEL],
|
||||
input=prompt, capture_output=True, text=True, timeout=120,
|
||||
)
|
||||
out = proc.stdout.strip()
|
||||
# Strip ```json fences if any
|
||||
out = re.sub(r"^```(?:json)?\s*", "", out)
|
||||
out = re.sub(r"\s*```$", "", out)
|
||||
try:
|
||||
data = json.loads(out)
|
||||
raw = data.get("scores", {})
|
||||
return {int(k): int(v) for k, v in raw.items()
|
||||
if str(v).isdigit() and 1 <= int(v) <= 5}
|
||||
except (json.JSONDecodeError, ValueError, TypeError) as e:
|
||||
print(f" [judge parse fail: {e}; out={out[:200]!r}]")
|
||||
return {}
|
||||
|
||||
|
||||
async def main():
|
||||
voyage_key = os.environ["VOYAGE_API_KEY"]
|
||||
pg_pw = os.environ["POSTGRES_PASSWORD"]
|
||||
|
||||
# Verify Claude CLI is available (uses OAuth from ~/.claude/.credentials)
|
||||
try:
|
||||
subprocess.run(["claude", "--version"], capture_output=True,
|
||||
text=True, timeout=10, check=True)
|
||||
except (subprocess.CalledProcessError, FileNotFoundError, TimeoutError):
|
||||
sys.exit("claude CLI not found or not authenticated")
|
||||
|
||||
voyage = voyageai.Client(api_key=voyage_key)
|
||||
|
||||
# Load chunks + voyage-3 embeddings
|
||||
pool = await asyncpg.create_pool(
|
||||
host="127.0.0.1", port=5433, user="legal_ai",
|
||||
password=pg_pw, database="legal_ai",
|
||||
min_size=1, max_size=2,
|
||||
)
|
||||
rows = await pool.fetch("""
|
||||
SELECT chunk_index, content, embedding::text AS emb_text
|
||||
FROM precedent_chunks
|
||||
WHERE case_law_id = $1
|
||||
ORDER BY chunk_index
|
||||
""", CASE_ID)
|
||||
chunks = [r["content"] for r in rows]
|
||||
chunk_indices = [r["chunk_index"] for r in rows]
|
||||
baseline_embs = [parse_pgvector(r["emb_text"]) for r in rows]
|
||||
n = len(chunks)
|
||||
print(f"[load] {n} chunks loaded")
|
||||
|
||||
# Compute context-3 (windowed) embeddings — same as POC #2
|
||||
windows = build_windows(n, WINDOW_SIZE, WINDOW_STRIDE)
|
||||
print(f"[context-3] embedding {len(windows)} windows…")
|
||||
win_embs = []
|
||||
for s, e in windows:
|
||||
result = voyage.contextualized_embed(
|
||||
inputs=[chunks[s:e]],
|
||||
model=CONTEXT_MODEL,
|
||||
input_type="document",
|
||||
)
|
||||
win_embs.append(result.results[0].embeddings)
|
||||
context_embs = []
|
||||
for i in range(n):
|
||||
w = central_window(i, windows)
|
||||
s, _ = windows[w]
|
||||
context_embs.append(win_embs[w][i - s])
|
||||
print(f"[context-3] done")
|
||||
|
||||
# Retrieval functions
|
||||
def r1_baseline(query: str, k: int = 10) -> list[int]:
|
||||
q = voyage.embed([query], model=TEXT_MODEL,
|
||||
input_type="query").embeddings[0]
|
||||
scores = sorted(
|
||||
[(cosine(q, e), i) for i, e in enumerate(baseline_embs)],
|
||||
reverse=True,
|
||||
)
|
||||
return [i for _, i in scores[:k]]
|
||||
|
||||
def r2_rerank(query: str, k: int = 10) -> list[int]:
|
||||
# 1) voyage-3 retrieve top-50
|
||||
cands = r1_baseline(query, k=50)
|
||||
cand_texts = [chunks[i] for i in cands]
|
||||
# 2) voyage-rerank-2 over the 50
|
||||
rr = voyage.rerank(
|
||||
query=query, documents=cand_texts,
|
||||
model=RERANK_MODEL, top_k=k,
|
||||
)
|
||||
# rr.results: list of RerankingResult(index=..., relevance_score=...)
|
||||
# `index` refers to position in cand_texts → map back to chunk idx
|
||||
return [cands[r.index] for r in rr.results]
|
||||
|
||||
def r3_context(query: str, k: int = 10) -> list[int]:
|
||||
q = voyage.contextualized_embed(
|
||||
inputs=[[query]],
|
||||
model=CONTEXT_MODEL,
|
||||
input_type="query",
|
||||
).results[0].embeddings[0]
|
||||
scores = sorted(
|
||||
[(cosine(q, e), i) for i, e in enumerate(context_embs)],
|
||||
reverse=True,
|
||||
)
|
||||
return [i for _, i in scores[:k]]
|
||||
|
||||
retrievers = [("R1-voyage3", r1_baseline),
|
||||
("R2-rerank2", r2_rerank),
|
||||
("R3-context3", r3_context)]
|
||||
|
||||
# Run all queries × all retrievers, judging top-5 per pair.
|
||||
# Strategy: for each query, gather the union of all retrievers' top-K
|
||||
# and judge them in ONE batched CLI call → 18 calls total instead of 270.
|
||||
all_results = []
|
||||
JUDGE_TOP_K = 5
|
||||
print(f"\n[judge] running {len(QUERIES)} queries × "
|
||||
f"{len(retrievers)} retrievers × top-{JUDGE_TOP_K} — batched per query…")
|
||||
|
||||
for qid, query in QUERIES:
|
||||
print(f"\n[{qid}] {query}")
|
||||
# Collect retrievals first
|
||||
retr_results = {}
|
||||
for r_name, r_fn in retrievers:
|
||||
try:
|
||||
retr_results[r_name] = r_fn(query, k=JUDGE_TOP_K)
|
||||
except Exception as e:
|
||||
print(f" {r_name}: FAILED — {e}")
|
||||
retr_results[r_name] = []
|
||||
# Union of unique chunk indices to judge
|
||||
union = sorted({i for top in retr_results.values() for i in top})
|
||||
items = [(i, chunks[i]) for i in union]
|
||||
print(f" judging {len(items)} unique chunks via batch CLI…")
|
||||
scores_map = batch_judge(query, items)
|
||||
# Build per-retriever score lists
|
||||
for r_name, top in retr_results.items():
|
||||
scores = [scores_map.get(i, 0) for i in top]
|
||||
mean3 = sum(scores[:3]) / 3 if len(scores) >= 3 else 0
|
||||
mean5 = sum(scores) / len(scores) if scores else 0
|
||||
mrr = 0.0
|
||||
for r, s in enumerate(scores):
|
||||
if s >= 4:
|
||||
mrr = 1.0 / (r + 1)
|
||||
break
|
||||
print(f" {r_name}: chunks={[chunk_indices[i] for i in top]} "
|
||||
f"scores={scores} mean@3={mean3:.2f} mean@5={mean5:.2f} "
|
||||
f"MRR={mrr:.3f}")
|
||||
all_results.append({
|
||||
"qid": qid, "category": qid[0], "query": query,
|
||||
"retriever": r_name,
|
||||
"chunks": [chunk_indices[i] for i in top],
|
||||
"scores": scores,
|
||||
"mean3": mean3, "mean5": mean5, "mrr": mrr,
|
||||
})
|
||||
|
||||
# Aggregate
|
||||
print("\n" + "=" * 100)
|
||||
print("AGGREGATED RESULTS")
|
||||
print("=" * 100)
|
||||
|
||||
by_retriever = defaultdict(lambda: {"mean3": [], "mean5": [], "mrr": []})
|
||||
by_cat_retriever = defaultdict(
|
||||
lambda: {"mean3": [], "mean5": [], "mrr": []})
|
||||
for r in all_results:
|
||||
by_retriever[r["retriever"]]["mean3"].append(r["mean3"])
|
||||
by_retriever[r["retriever"]]["mean5"].append(r["mean5"])
|
||||
by_retriever[r["retriever"]]["mrr"].append(r["mrr"])
|
||||
cat_key = (r["category"], r["retriever"])
|
||||
by_cat_retriever[cat_key]["mean3"].append(r["mean3"])
|
||||
by_cat_retriever[cat_key]["mean5"].append(r["mean5"])
|
||||
by_cat_retriever[cat_key]["mrr"].append(r["mrr"])
|
||||
|
||||
print("\nOverall (across all 18 queries):")
|
||||
print(f"{'retriever':<14} {'mean@3':>8} {'mean@5':>8} {'MRR':>8}")
|
||||
for r_name, _ in retrievers:
|
||||
m = by_retriever[r_name]
|
||||
avg = lambda xs: sum(xs) / len(xs) if xs else 0
|
||||
print(f"{r_name:<14} {avg(m['mean3']):>8.3f} "
|
||||
f"{avg(m['mean5']):>8.3f} {avg(m['mrr']):>8.3f}")
|
||||
|
||||
print("\nBy category (K=keyword, C=conceptual, N=narrative):")
|
||||
print(f"{'cat':<3} {'retriever':<14} {'mean@3':>8} {'mean@5':>8} {'MRR':>8}")
|
||||
for cat in ["K", "C", "N"]:
|
||||
for r_name, _ in retrievers:
|
||||
m = by_cat_retriever[(cat, r_name)]
|
||||
avg = lambda xs: sum(xs) / len(xs) if xs else 0
|
||||
print(f"{cat:<3} {r_name:<14} {avg(m['mean3']):>8.3f} "
|
||||
f"{avg(m['mean5']):>8.3f} {avg(m['mrr']):>8.3f}")
|
||||
|
||||
print("\nPer-query winner (highest mean@3, ties shown):")
|
||||
print(f"{'qid':<4} {'query':<45} {'winner':<24} {'scores'}")
|
||||
by_query = defaultdict(list)
|
||||
for r in all_results:
|
||||
by_query[r["qid"]].append(r)
|
||||
for qid, results in sorted(by_query.items()):
|
||||
max_score = max(r["mean3"] for r in results)
|
||||
winners = [r["retriever"] for r in results if r["mean3"] == max_score]
|
||||
scores = " | ".join(f"{r['retriever'][:7]}={r['mean3']:.2f}"
|
||||
for r in results)
|
||||
q_str = next(q for qid_, q in QUERIES if qid_ == qid)[:42]
|
||||
print(f"{qid:<4} {q_str:<45} {','.join(w[:8] for w in winners):<24} "
|
||||
f"{scores}")
|
||||
|
||||
# Save raw results to JSON for further analysis
|
||||
out_path = "/tmp/voyage_rerank_judge_results.json"
|
||||
with open(out_path, "w") as f:
|
||||
json.dump(all_results, f, ensure_ascii=False, indent=2)
|
||||
print(f"\nRaw results saved to {out_path}")
|
||||
|
||||
await pool.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user