"""Re-embed all Voyage-stored vectors with the model in env VOYAGE_MODEL. Use after changing VOYAGE_MODEL in env (e.g. voyage-law-2 → voyage-3). The script reads each table that stores embeddings, batches the source text through the new model (Voyage allows 128 inputs / call), and UPDATEs the rows in place. Tables touched: - document_chunks (content) - paragraph_embeddings (joined with decision_paragraphs.content) - case_law_embeddings (chunk_text) - precedent_chunks (content) - halachot (rule_statement + reasoning_summary) Run from the legal-ai venv with VOYAGE_API_KEY + VOYAGE_MODEL + POSTGRES_* set in env (or ~/.env). Idempotent — safe to re-run. Usage: /home/chaim/legal-ai/mcp-server/.venv/bin/python \\ /home/chaim/legal-ai/scripts/reembed_voyage.py """ from __future__ import annotations import asyncio import os import sys import time # Load ~/.env if present ENV_PATH = os.path.expanduser("~/.env") if os.path.isfile(ENV_PATH): with open(ENV_PATH) as f: for line in f: line = line.strip() if line and not line.startswith("#") and "=" in line: k, v = line.split("=", 1) os.environ.setdefault(k, v) import asyncpg # noqa: E402 import voyageai # noqa: E402 VOYAGE_MODEL = os.environ.get("VOYAGE_MODEL", "voyage-3") BATCH = 100 # Voyage allows 128, leave headroom for token limits # (table, primary key, source-text SQL, update SQL with $1=embedding $2=id) TABLES = [ ( "document_chunks", "SELECT id, content FROM document_chunks WHERE content IS NOT NULL AND content <> ''", "UPDATE document_chunks SET embedding = $1 WHERE id = $2", ), ( "paragraph_embeddings", # paragraph_embeddings stores embedding only — text is in decision_paragraphs "SELECT pe.id, dp.content " "FROM paragraph_embeddings pe " "JOIN decision_paragraphs dp ON dp.id = pe.paragraph_id " "WHERE dp.content IS NOT NULL AND dp.content <> ''", "UPDATE paragraph_embeddings SET embedding = $1 WHERE id = $2", ), ( "case_law_embeddings", "SELECT id, chunk_text FROM case_law_embeddings " "WHERE chunk_text IS NOT NULL AND chunk_text <> ''", "UPDATE case_law_embeddings SET embedding = $1 WHERE id = $2", ), ( "precedent_chunks", "SELECT id, content FROM precedent_chunks WHERE content IS NOT NULL AND content <> ''", "UPDATE precedent_chunks SET embedding = $1 WHERE id = $2", ), ( "halachot", # Embed rule_statement + reasoning_summary, matching the original # storage in halacha_extractor.extract(). "SELECT id, " " TRIM(BOTH ' —' FROM rule_statement || ' — ' || COALESCE(reasoning_summary, '')) " " AS embed_text " "FROM halachot WHERE rule_statement IS NOT NULL AND rule_statement <> ''", "UPDATE halachot SET embedding = $1 WHERE id = $2", ), ] async def embed_batch(client, texts: list[str]) -> list[list[float]]: """Voyage embed_texts with explicit input_type='document' for storage.""" return client.embed(texts, model=VOYAGE_MODEL, input_type="document").embeddings async def reembed_table( pool: asyncpg.Pool, voyage, label: str, select_sql: str, update_sql: str, ) -> dict: rows = await pool.fetch(select_sql) n = len(rows) print(f"\n[{label}] {n} rows") if n == 0: return {"table": label, "rows": 0, "elapsed": 0.0} start = time.time() done = 0 for i in range(0, n, BATCH): batch_rows = rows[i:i + BATCH] texts = [r[1] for r in batch_rows] ids = [r[0] for r in batch_rows] try: embeddings = await embed_batch(voyage, texts) except Exception as e: print(f" [{label}] batch {i // BATCH} failed: {e}", file=sys.stderr) continue # Update each row async with pool.acquire() as conn: async with conn.transaction(): for emb, rid in zip(embeddings, ids): # asyncpg accepts list[float] for vector via asyncpg-pgvector; # but pgvector type is inferred via str cast on the wire await conn.execute(update_sql, str(emb), rid) done += len(batch_rows) elapsed = time.time() - start print(f" [{label}] {done}/{n} ({done/n*100:.1f}%) " f"elapsed={elapsed:.0f}s rate={done/max(elapsed,0.1):.1f}/s") elapsed = time.time() - start return {"table": label, "rows": n, "elapsed": elapsed} async def main(): api_key = os.environ.get("VOYAGE_API_KEY") if not api_key: sys.exit("VOYAGE_API_KEY not set (export it or add to ~/.env)") pg_host = os.environ.get("POSTGRES_HOST", "127.0.0.1") pg_port = int(os.environ.get("POSTGRES_PORT", "5433")) pg_user = os.environ.get("POSTGRES_USER", "legal_ai") pg_pw = os.environ.get("POSTGRES_PASSWORD", "") pg_db = os.environ.get("POSTGRES_DB", "legal_ai") if not pg_pw: sys.exit("POSTGRES_PASSWORD not set") print(f"Re-embed all tables with model: {VOYAGE_MODEL}") print(f"DB: {pg_user}@{pg_host}:{pg_port}/{pg_db}") voyage = voyageai.Client(api_key=api_key) pool = await asyncpg.create_pool( host=pg_host, port=pg_port, user=pg_user, password=pg_pw, database=pg_db, min_size=1, max_size=4, ) # pgvector needs explicit codec setup so we can pass list[float] async def _init(conn: asyncpg.Connection) -> None: await conn.execute("SET search_path = public") await pool.__aenter__() # noqa — enter context to ensure init summary = [] try: for label, select_sql, update_sql in TABLES: r = await reembed_table(pool, voyage, label, select_sql, update_sql) summary.append(r) finally: await pool.close() total_rows = sum(r["rows"] for r in summary) total_time = sum(r["elapsed"] for r in summary) print(f"\n{'=' * 60}\nDONE — {total_rows} rows in {total_time:.0f}s") for r in summary: print(f" {r['table']:30s} {r['rows']:>6} rows {r['elapsed']:>5.0f}s") print(f"\nModel: {VOYAGE_MODEL}") if __name__ == "__main__": asyncio.run(main())