#!/usr/bin/env python3
"""Generate embeddings for decision blocks and case law.

Creates:
- paragraph_embeddings: for each decision block with content
- case_law_embeddings: for each case law summary

Safe to re-run: decision blocks that already have decision_paragraphs rows and
case_law rows that already have a case_law_embeddings row are skipped.
"""

import asyncio
import sys
from pathlib import Path
from uuid import UUID

sys.path.insert(0, str(Path(__file__).parent.parent / "mcp-server" / "src"))

from legal_mcp.services.db import get_pool, init_schema, close_pool
from legal_mcp.services.embeddings import embed_texts
from legal_mcp import config

# Blocks with at most this many words are embedded as one chunk, verbatim.
SINGLE_CHUNK_MAX_WORDS = 600
# Longer blocks are cut into overlapping word windows
# (overlap = CHUNK_WINDOW_WORDS - CHUNK_STRIDE_WORDS).
CHUNK_WINDOW_WORDS = 500
CHUNK_STRIDE_WORDS = 400
# Windows with this many words or fewer are discarded.
CHUNK_MIN_WORDS = 50


def _split_into_chunks(content: str) -> list:
    """Split block text into the chunk strings that get embedded.

    Short content (<= SINGLE_CHUNK_MAX_WORDS words) is returned unchanged as a
    single chunk. Longer content is cut into windows of CHUNK_WINDOW_WORDS
    words advancing by CHUNK_STRIDE_WORDS, each re-joined with single spaces.
    """
    words = content.split()
    if len(words) <= SINGLE_CHUNK_MAX_WORDS:
        return [content]

    chunks = []
    for start in range(0, len(words), CHUNK_STRIDE_WORDS):
        window = words[start:start + CHUNK_WINDOW_WORDS]
        if len(window) > CHUNK_MIN_WORDS:
            chunks.append(" ".join(window))
        if start + CHUNK_WINDOW_WORDS >= len(words):
            # This window already reaches the end of the text; any later
            # window would be a strict subset of it and only add a duplicate
            # embedding.
            break
    return chunks


def _require_same_length(texts, embeddings) -> None:
    """Raise if the embedding service did not return one vector per text.

    Without this check, zip() would silently drop the unmatched rows.
    """
    if len(embeddings) != len(texts):
        raise RuntimeError(
            f"embed_texts returned {len(embeddings)} embeddings "
            f"for {len(texts)} texts"
        )


async def generate_block_embeddings(conn) -> int:
    """Create paragraphs for unprocessed decision blocks and embed them.

    A block qualifies when it has more than 10 words, is not one of the
    excluded block types, and has no decision_paragraphs rows yet. Each
    qualifying block is split into chunks (see _split_into_chunks). Each chunk
    is inserted into decision_paragraphs with paragraph_number set to its
    1-based position within the block, and its embedding goes into
    paragraph_embeddings.

    Everything runs in one transaction. If the embedding call fails, no
    paragraphs are left behind without embeddings, which the skip filter
    would otherwise never revisit.

    Args:
        conn: An acquired database connection (asyncpg-style API:
            fetch/fetchval/executemany/transaction).

    Returns:
        Number of paragraph_embeddings rows inserted.
    """
    blocks = await conn.fetch(
        """SELECT db.id as block_id, db.decision_id, db.block_id as block_type,
                  db.content, db.word_count, c.case_number
           FROM decision_blocks db
           JOIN decisions d ON d.id = db.decision_id
           JOIN cases c ON c.id = d.case_id
           WHERE db.word_count > 10
             AND db.block_id NOT IN ('block-alef', 'block-bet', 'block-gimel', 'block-dalet')
             AND NOT EXISTS (
                 SELECT 1 FROM decision_paragraphs dp WHERE dp.block_id = db.id
             )
           ORDER BY c.case_number, db.block_index"""
    )

    if not blocks:
        print(" אין בלוקים ליצירת embeddings")
        return 0

    print(f" מעבד {len(blocks)} בלוקים...")

    async with conn.transaction():
        para_ids = []
        texts = []
        for block in blocks:
            chunks = _split_into_chunks(block["content"])
            for paragraph_number, chunk_text in enumerate(chunks, start=1):
                para_id = await conn.fetchval(
                    """INSERT INTO decision_paragraphs
                           (block_id, paragraph_number, content, word_count)
                       VALUES ($1, $2, $3, $4)
                       ON CONFLICT DO NOTHING
                       RETURNING id""",
                    block["block_id"],
                    paragraph_number,
                    chunk_text,
                    len(chunk_text.split()),
                )
                # RETURNING yields no row (None) when the insert hit a conflict.
                if para_id is not None:
                    para_ids.append(para_id)
                    texts.append(chunk_text)

        if not para_ids:
            print(" אין פסקאות חדשות")
            return 0

        print(f" {len(para_ids)} פסקאות נוצרו, מייצר embeddings...")

        embeddings = await embed_texts(texts, input_type="document")
        _require_same_length(texts, embeddings)

        await conn.executemany(
            """INSERT INTO paragraph_embeddings (paragraph_id, embedding)
               VALUES ($1, $2)""",
            list(zip(para_ids, embeddings)),
        )

    return len(para_ids)


def _case_law_text(row) -> str:
    """Build the searchable text for one case_law row.

    NULL summary/key_quote/case_name values are treated as empty strings, so
    the literal text "None" never ends up in the embedded text.
    """
    case_name = row["case_name"] or ""
    summary = row["summary"] or ""
    text = f"{row['case_number']} {case_name}: {summary}"
    if row["key_quote"]:
        text += f" ציטוט: {row['key_quote']}"
    return text


async def generate_case_law_embeddings(conn) -> int:
    """Embed every case_law row that has content but no embedding yet.

    The stored chunk_text is exactly the text that was embedded, including
    the key quote. Previously the quote was dropped, so quote-only rows
    stored just "number name: ".

    Args:
        conn: An acquired database connection (asyncpg-style API).

    Returns:
        Number of case_law_embeddings rows inserted.
    """
    cases = await conn.fetch(
        """SELECT cl.id, cl.case_number, cl.case_name, cl.summary, cl.key_quote
           FROM case_law cl
           WHERE (cl.summary != '' OR cl.key_quote != '')
             AND NOT EXISTS (
                 SELECT 1 FROM case_law_embeddings cle
                 WHERE cle.case_law_id = cl.id
             )"""
    )

    if not cases:
        print(" אין פסיקה חדשה ליצירת embeddings")
        return 0

    print(f" מייצר embeddings ל-{len(cases)} תקדימים...")

    texts = [_case_law_text(c) for c in cases]
    embeddings = await embed_texts(texts, input_type="document")
    _require_same_length(texts, embeddings)

    async with conn.transaction():
        await conn.executemany(
            """INSERT INTO case_law_embeddings (case_law_id, chunk_text, embedding)
               VALUES ($1, $2, $3)""",
            [
                (case["id"], text, embedding)
                for case, text, embedding in zip(cases, texts, embeddings)
            ],
        )

    return len(cases)


async def main():
    """Run both embedding stages, then print table totals and the model name.

    The pool is closed in a finally clause, after the connection has been
    released, so it is cleaned up even when a stage fails.
    """
    await init_schema()
    pool = await get_pool()
    try:
        async with pool.acquire() as conn:
            print("שלב 1: embeddings לבלוקי החלטה")
            block_count = await generate_block_embeddings(conn)
            print(f" ✅ {block_count} embeddings נוצרו")

            print("\nשלב 2: embeddings לפסיקה")
            cl_count = await generate_case_law_embeddings(conn)
            print(f" ✅ {cl_count} embeddings נוצרו")

            # Summary
            para_total = await conn.fetchval("SELECT count(*) FROM paragraph_embeddings")
            cl_total = await conn.fetchval("SELECT count(*) FROM case_law_embeddings")
    finally:
        await close_pool()

    print("\nסיכום:")
    print(f" סה\"כ paragraph_embeddings: {para_total}")
    print(f" סה\"כ case_law_embeddings: {cl_total}")
    print(f" מודל: {config.VOYAGE_MODEL}")


if __name__ == "__main__":
    asyncio.run(main())