#!/usr/bin/env python3
"""Backfill aggregated ``legal_arguments`` rows for existing cases.

Every non-archived case that already has rows in ``claims`` but none in
``legal_arguments`` is passed to
``argument_aggregator.aggregate_claims_to_arguments``.

Usage (run with the mcp-server venv — pgvector + asyncpg are vendored there):

    PY=/home/chaim/legal-ai/mcp-server/.venv/bin/python

    # Default = dry-run (lists what would be processed):
    $PY scripts/backfill_legal_arguments.py

    # Process all cases that need it:
    $PY scripts/backfill_legal_arguments.py --apply

    # Re-aggregate even cases that already have arguments:
    $PY scripts/backfill_legal_arguments.py --apply --force

    # Only process specific cases:
    $PY scripts/backfill_legal_arguments.py --apply --case 1017-03-26 1018-03-26

Run this from the local dev machine, not the container:
``argument_aggregator`` calls ``claude_session``, which needs the Claude CLI.
"""

from __future__ import annotations

import argparse
import asyncio
import os
import sys
from pathlib import Path
from uuid import UUID

# REPO_ROOT is the parent of the directory holding this script; prepending
# mcp-server/src to sys.path makes the ``legal_mcp`` package importable.
REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT.joinpath("mcp-server", "src")))

# Default DB connection, built from POSTGRES_* environment variables and used
# only when POSTGRES_URL is not already set.
if "POSTGRES_URL" not in os.environ: pg_user = os.environ.get("POSTGRES_USER", "legal_ai") pg_pw = os.environ.get("POSTGRES_PASSWORD", "") pg_host = os.environ.get("POSTGRES_HOST", "127.0.0.1") pg_port = os.environ.get("POSTGRES_PORT", "5433") pg_db = os.environ.get("POSTGRES_DB", "legal_ai") os.environ["POSTGRES_URL"] = ( f"postgres://{pg_user}:{pg_pw}@{pg_host}:{pg_port}/{pg_db}" ) async def _list_cases_needing_backfill(force: bool) -> list[dict]: """Find cases that have claims but no aggregated arguments (or all, when ``force`` is True).""" from legal_mcp.services import db pool = await db.get_pool() async with pool.acquire() as conn: rows = await conn.fetch( """ SELECT c.id, c.case_number, c.status, COUNT(DISTINCT cl.id) AS claim_count, COUNT(DISTINCT la.id) AS arg_count FROM cases c LEFT JOIN claims cl ON cl.case_id = c.id LEFT JOIN legal_arguments la ON la.case_id = c.id WHERE c.archived_at IS NULL GROUP BY c.id, c.case_number, c.status HAVING COUNT(DISTINCT cl.id) > 0 ORDER BY c.case_number """ ) out: list[dict] = [] for r in rows: d = dict(r) if force or d["arg_count"] == 0: out.append(d) return out async def _process_case(case: dict, force: bool) -> dict: from legal_mcp.services import argument_aggregator case_id = UUID(str(case["id"])) case_number = case["case_number"] print( f"[backfill] {case_number}: {case['claim_count']} claims, " f"{case['arg_count']} existing args — aggregating (force={force})...", flush=True, ) try: result = await argument_aggregator.aggregate_claims_to_arguments( case_id, force=force, ) except Exception as e: # noqa: BLE001 return { "case_number": case_number, "status": "error", "error": str(e), } print( f"[backfill] {case_number}: status={result.get('status')} " f"total={result.get('total')} by_party={result.get('by_party')}", flush=True, ) return {"case_number": case_number, **result} async def main() -> int: parser = argparse.ArgumentParser( description="Backfill legal_arguments for cases with extracted claims.", ) 
parser.add_argument( "--apply", action="store_true", help="Actually run aggregation (default: dry-run).", ) parser.add_argument( "--force", action="store_true", help="Re-aggregate even cases that already have arguments.", ) parser.add_argument( "--case", nargs="*", default=[], help="Only process these case numbers (e.g. --case 1017-03-26 1018-03-26).", ) args = parser.parse_args() cases = await _list_cases_needing_backfill(force=args.force) if args.case: wanted = set(args.case) cases = [c for c in cases if c["case_number"] in wanted] if not cases: print("[backfill] No cases need processing.") return 0 print(f"[backfill] {len(cases)} case(s) to process:") for c in cases: print( f" - {c['case_number']:<14} status={c['status']:<20} " f"claims={c['claim_count']:<4} args={c['arg_count']}", ) if not args.apply: print("\n[backfill] dry-run — pass --apply to actually run.") return 0 print() results: list[dict] = [] for case in cases: r = await _process_case(case, force=args.force) results.append(r) print("\n[backfill] === Summary ===") for r in results: print( f" {r['case_number']:<14} status={r.get('status', 'unknown'):<22} " f"total={r.get('total', 0)}", ) errors = [r for r in results if r.get("status") == "error"] return 1 if errors else 0 if __name__ == "__main__": sys.exit(asyncio.run(main()))