ops: switch embeddings to voyage-3 + plan for context-3 + multimodal-3.5
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 7s
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 7s
Phase A — voyage-3 migration (executed): - VOYAGE_MODEL=voyage-3 set in Coolify (legal-ai app) and ~/.env - scripts/reembed_voyage.py: re-embeds document_chunks (6157), case_law_embeddings (9), precedent_chunks (385), and halachot (400) using the new model. paragraph_embeddings was empty. 6951 rows re-embedded in 93s, ~75 rows/sec. - Same 1024 dim → no schema change needed. Why voyage-3 over voyage-law-2: benchmark on 3 Hebrew legal queries with real passages from the corpus gave voyage-3 perfect ordering on 3/3 tests AND the largest separation (+0.483 vs voyage-law-2's +0.238). voyage-4 family had bigger separation but missed top-1 on the hardest test. Phase B (voyage-context-3) and Phase C (voyage-multimodal-3.5 for scanned + appraiser docs) are designed in docs/voyage-upgrades-plan.md but deferred — to be picked up in a fresh conversation. The plan includes: - Phase B: contextualized embeddings refactor (~49% recall lift on legal docs per Anthropic's research). Same dim, but ingestion pipeline must pass full doc context per chunk. - Phase C: page-level image embeddings via voyage-multimodal-3.5, stored in a parallel *_image_embeddings table. Hybrid text+image search. Targets appraiser report tables and scanned PDFs where current OCR loses layout. After this commit: MCP server needs a /mcp reconnect to pick up the new VOYAGE_MODEL env, and the legal-ai container will pick it up on its next redeploy. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
170
scripts/reembed_voyage.py
Normal file
170
scripts/reembed_voyage.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""Re-embed all Voyage-stored vectors with the model in env VOYAGE_MODEL.
|
||||
|
||||
Use after changing VOYAGE_MODEL in env (e.g. voyage-law-2 → voyage-3).
|
||||
The script reads each table that stores embeddings, batches the source
|
||||
text through the new model (Voyage allows 128 inputs / call), and
|
||||
UPDATEs the rows in place.
|
||||
|
||||
Tables touched:
|
||||
- document_chunks (content)
|
||||
- paragraph_embeddings (joined with decision_paragraphs.content)
|
||||
- case_law_embeddings (chunk_text)
|
||||
- precedent_chunks (content)
|
||||
- halachot (rule_statement + reasoning_summary)
|
||||
|
||||
Run from the legal-ai venv with VOYAGE_API_KEY + VOYAGE_MODEL +
|
||||
POSTGRES_* set in env (or ~/.env). Idempotent — safe to re-run.
|
||||
|
||||
Usage:
|
||||
/home/chaim/legal-ai/mcp-server/.venv/bin/python \\
|
||||
/home/chaim/legal-ai/scripts/reembed_voyage.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
def load_env_file(path: str) -> None:
    """Populate ``os.environ`` from a simple ``KEY=VALUE`` file.

    Does nothing when *path* is not an existing file. Blank lines, lines
    starting with ``#`` and lines without ``=`` are skipped. An optional
    leading ``export `` is dropped, whitespace around the key and value is
    removed, and one pair of matching surrounding quotes is stripped from
    the value. Variables that are already set in the environment are left
    untouched (``setdefault``), so the real environment wins over the file.
    """
    if not os.path.isfile(path):
        return
    with open(path, encoding="utf-8") as env_file:
        for raw_line in env_file:
            line = raw_line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            if line.startswith("export "):
                line = line[len("export "):].lstrip()
            key, value = line.split("=", 1)
            key = key.strip()
            value = value.strip()
            # KEY="value" / KEY='value' must not keep the quotes, otherwise
            # secrets such as API keys are sent with literal quote characters.
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]
            if key:
                os.environ.setdefault(key, value)


# Load ~/.env if present
ENV_PATH = os.path.expanduser("~/.env")
load_env_file(ENV_PATH)
|
||||
|
||||
import asyncpg # noqa: E402
|
||||
import voyageai # noqa: E402
|
||||
|
||||
|
||||
# Embedding model name; taken from the environment, "voyage-3" when unset.
VOYAGE_MODEL = os.environ.get("VOYAGE_MODEL", "voyage-3")
# Voyage allows 128 inputs per call; stay below that to leave headroom
# for token limits.
BATCH = 100


def _select_text(table: str, column: str) -> str:
    """SELECT (id, text) for a table whose source text is one non-empty column."""
    return (
        f"SELECT id, {column} FROM {table} "
        f"WHERE {column} IS NOT NULL AND {column} <> ''"
    )


def _update_embedding(table: str) -> str:
    """UPDATE statement for *table* with $1 = embedding and $2 = row id."""
    return f"UPDATE {table} SET embedding = $1 WHERE id = $2"


# paragraph_embeddings stores the embedding only; the text lives in
# decision_paragraphs, so the source text comes from a join.
_PARAGRAPH_SELECT = " ".join(
    [
        "SELECT pe.id, dp.content",
        "FROM paragraph_embeddings pe",
        "JOIN decision_paragraphs dp ON dp.id = pe.paragraph_id",
        "WHERE dp.content IS NOT NULL AND dp.content <> ''",
    ]
)

# Embed rule_statement + reasoning_summary, matching the original storage
# in halacha_extractor.extract().
_HALACHOT_SELECT = (
    "SELECT id, "
    " TRIM(BOTH ' —' FROM rule_statement || ' — ' || COALESCE(reasoning_summary, '')) "
    " AS embed_text "
    "FROM halachot WHERE rule_statement IS NOT NULL AND rule_statement <> ''"
)

# One entry per table: (label, select SQL returning (id, text) rows,
# update SQL with $1 = embedding and $2 = id).
TABLES = [
    (
        "document_chunks",
        _select_text("document_chunks", "content"),
        _update_embedding("document_chunks"),
    ),
    (
        "paragraph_embeddings",
        _PARAGRAPH_SELECT,
        _update_embedding("paragraph_embeddings"),
    ),
    (
        "case_law_embeddings",
        _select_text("case_law_embeddings", "chunk_text"),
        _update_embedding("case_law_embeddings"),
    ),
    (
        "precedent_chunks",
        _select_text("precedent_chunks", "content"),
        _update_embedding("precedent_chunks"),
    ),
    (
        "halachot",
        _HALACHOT_SELECT,
        _update_embedding("halachot"),
    ),
]
|
||||
|
||||
|
||||
async def embed_batch(client, texts: list[str]) -> list[list[float]]:
    """Embed *texts* with ``VOYAGE_MODEL`` using ``input_type="document"``.

    Args:
        client: Object exposing a synchronous
            ``embed(texts, model=..., input_type=...)`` method whose result
            has an ``embeddings`` attribute.
        texts: Source strings to embed.

    Returns:
        The ``embeddings`` attribute of the client's response, or ``[]``
        when *texts* is empty (no request is made in that case).
    """
    if not texts:
        return []
    # client.embed is a blocking call; run it in the default executor so
    # this coroutine does not stall the event loop while it waits.
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(
        None,
        lambda: client.embed(texts, model=VOYAGE_MODEL, input_type="document"),
    )
    return response.embeddings
|
||||
|
||||
|
||||
async def reembed_table(
    pool: asyncpg.Pool, voyage, label: str, select_sql: str, update_sql: str,
) -> dict:
    """Re-embed every row returned by *select_sql* and store the new vectors.

    Rows are processed in slices of ``BATCH``. Each slice is embedded with
    ``embed_batch`` and written back with *update_sql* inside one
    transaction. A slice whose embedding call fails (or returns a different
    number of vectors than inputs) is logged to stderr and skipped; database
    errors are not caught and propagate to the caller.

    Args:
        pool: Connection pool used for the SELECT and the UPDATEs.
        voyage: Embedding client handed to ``embed_batch``.
        label: Table name used in progress output and in the result.
        select_sql: Query returning ``(id, text)`` rows.
        update_sql: Statement taking ``$1`` = embedding and ``$2`` = id.

    Returns:
        ``{"table", "rows", "failed", "elapsed"}`` where ``rows`` counts the
        rows actually re-embedded, ``failed`` the rows in skipped slices and
        ``elapsed`` the seconds spent on this table.
    """
    rows = await pool.fetch(select_sql)
    n = len(rows)
    print(f"\n[{label}] {n} rows")
    if n == 0:
        return {"table": label, "rows": 0, "failed": 0, "elapsed": 0.0}

    start = time.monotonic()
    done = 0
    failed = 0
    for i in range(0, n, BATCH):
        batch_rows = rows[i:i + BATCH]
        ids = [r[0] for r in batch_rows]
        texts = [r[1] for r in batch_rows]
        try:
            embeddings = await embed_batch(voyage, texts)
            # zip() would silently drop rows on a short response.
            if len(embeddings) != len(ids):
                raise ValueError(
                    f"expected {len(ids)} embeddings, got {len(embeddings)}"
                )
        except Exception as e:
            failed += len(batch_rows)
            print(f" [{label}] batch {i // BATCH} failed: {e}", file=sys.stderr)
            continue

        # The vector is sent as its text form (str of the float list), so no
        # pgvector codec has to be registered on the connection.
        params = [(str(emb), rid) for emb, rid in zip(embeddings, ids)]
        async with pool.acquire() as conn:
            async with conn.transaction():
                await conn.executemany(update_sql, params)

        done += len(batch_rows)
        elapsed = time.monotonic() - start
        print(f" [{label}] {done}/{n} ({done/n*100:.1f}%) "
              f"elapsed={elapsed:.0f}s rate={done/max(elapsed,0.1):.1f}/s")

    elapsed = time.monotonic() - start
    if failed:
        print(f" [{label}] {failed}/{n} rows were NOT re-embedded",
              file=sys.stderr)
    return {"table": label, "rows": done, "failed": failed, "elapsed": elapsed}
|
||||
|
||||
|
||||
async def main():
    """Re-embed every table listed in ``TABLES`` and print a summary.

    Reads ``VOYAGE_API_KEY`` and the ``POSTGRES_*`` settings from the
    environment and exits with a message when the API key or the database
    password is missing. The tables are processed one after another through
    ``reembed_table``; the connection pool is closed when the loop ends,
    including when it ends with an exception. If any table reports failed
    rows, the process exits with a non-zero status after the summary.
    """
    api_key = os.environ.get("VOYAGE_API_KEY")
    if not api_key:
        sys.exit("VOYAGE_API_KEY not set (export it or add to ~/.env)")

    pg_host = os.environ.get("POSTGRES_HOST", "127.0.0.1")
    pg_port = int(os.environ.get("POSTGRES_PORT", "5433"))
    pg_user = os.environ.get("POSTGRES_USER", "legal_ai")
    pg_pw = os.environ.get("POSTGRES_PASSWORD", "")
    pg_db = os.environ.get("POSTGRES_DB", "legal_ai")
    if not pg_pw:
        sys.exit("POSTGRES_PASSWORD not set")

    print(f"Re-embed all tables with model: {VOYAGE_MODEL}")
    print(f"DB: {pg_user}@{pg_host}:{pg_port}/{pg_db}")

    voyage = voyageai.Client(api_key=api_key)

    summary = []
    # The pool is used as an async context manager so it is initialised on
    # entry and closed on exit without manual __aenter__/close calls.
    async with asyncpg.create_pool(
        host=pg_host, port=pg_port, user=pg_user,
        password=pg_pw, database=pg_db,
        min_size=1, max_size=4,
    ) as pool:
        for label, select_sql, update_sql in TABLES:
            r = await reembed_table(pool, voyage, label, select_sql, update_sql)
            summary.append(r)

    total_rows = sum(r["rows"] for r in summary)
    total_time = sum(r["elapsed"] for r in summary)
    print(f"\n{'=' * 60}\nDONE — {total_rows} rows in {total_time:.0f}s")
    for r in summary:
        print(f" {r['table']:30s} {r['rows']:>6} rows {r['elapsed']:>5.0f}s")
    print(f"\nModel: {VOYAGE_MODEL}")

    # A partially re-embedded table mixes vectors from two models, so make
    # that visible to whoever (or whatever) launched the script.
    failed_rows = sum(r.get("failed", 0) for r in summary)
    if failed_rows:
        sys.exit(f"{failed_rows} rows were NOT re-embedded; re-run the script")


if __name__ == "__main__":
    asyncio.run(main())
|
||||
Reference in New Issue
Block a user