#!/usr/bin/env python3 """Link claims to discussion paragraphs using semantic similarity. For each claim, finds the most similar paragraph in block-yod of the same decision. Updates claims.addressed_in_paragraph with the paragraph number. """ import asyncio import sys from pathlib import Path from uuid import UUID sys.path.insert(0, str(Path(__file__).parent.parent / "mcp-server" / "src")) from legal_mcp.services.db import get_pool, init_schema, close_pool from legal_mcp.services.embeddings import embed_texts async def main(): await init_schema() pool = await get_pool() async with pool.acquire() as conn: # Get all cases with both claims and discussion blocks cases = await conn.fetch( """SELECT DISTINCT c.id as case_id, c.case_number FROM cases c JOIN claims cl ON cl.case_id = c.id JOIN decisions d ON d.case_id = c.id JOIN decision_blocks db ON db.decision_id = d.id AND db.block_id = 'block-yod' AND db.word_count > 0 ORDER BY c.case_number""" ) total_linked = 0 for case in cases: case_id = case["case_id"] case_number = case["case_number"] async with pool.acquire() as conn: # Get claims for this case claims = await conn.fetch( "SELECT id, claim_text, party_role, claim_index FROM claims WHERE case_id = $1 ORDER BY claim_index", case_id, ) # Get discussion paragraphs (split block-yod into paragraphs) yod_content = await conn.fetchval( """SELECT db.content FROM decision_blocks db JOIN decisions d ON d.id = db.decision_id WHERE d.case_id = $1 AND db.block_id = 'block-yod'""", case_id, ) if not yod_content or not claims: continue # Split discussion into paragraphs disc_paragraphs = [p.strip() for p in yod_content.split("\n") if p.strip() and len(p.strip()) > 30] if not disc_paragraphs: continue print(f"\n{case_number}: {len(claims)} טענות ← {len(disc_paragraphs)} פסקאות דיון") # Embed all claims and discussion paragraphs claim_texts = [c["claim_text"][:500] for c in claims] all_texts = claim_texts + disc_paragraphs embeddings = await embed_texts(all_texts, input_type="document") claim_embeddings = embeddings[:len(claims)] disc_embeddings = embeddings[len(claims):] # For each claim, find the best matching discussion paragraph linked = 0 async with pool.acquire() as conn: for i, claim in enumerate(claims): claim_emb = claim_embeddings[i] # Cosine similarity best_score = -1 best_para_idx = -1 for j, disc_emb in enumerate(disc_embeddings): dot = sum(a * b for a, b in zip(claim_emb, disc_emb)) norm_a = sum(a * a for a in claim_emb) ** 0.5 norm_b = sum(b * b for b in disc_emb) ** 0.5 score = dot / (norm_a * norm_b) if norm_a > 0 and norm_b > 0 else 0 if score > best_score: best_score = score best_para_idx = j if best_para_idx >= 0 and best_score > 0.3: # paragraph_number is 1-indexed para_num = best_para_idx + 1 await conn.execute( "UPDATE claims SET addressed_in_paragraph = $1 WHERE id = $2", para_num, claim["id"], ) linked += 1 total_linked += linked print(f" קושרו: {linked}/{len(claims)} טענות (ציון מינימלי: 0.3)") # Summary async with pool.acquire() as conn: total_claims = await conn.fetchval("SELECT count(*) FROM claims") linked_claims = await conn.fetchval("SELECT count(*) FROM claims WHERE addressed_in_paragraph IS NOT NULL") await close_pool() print(f"\n{'='*50}") print(f"סיכום: {linked_claims}/{total_claims} טענות קושרו לפסקאות דיון ({linked_claims/total_claims*100:.0f}%)") if __name__ == "__main__": asyncio.run(main())