Initial commit: MCP server + web upload interface
Ezer Mishpati - AI legal decision drafting system with: - MCP server (FastMCP) with document processing pipeline - Web upload interface (FastAPI) for file upload and classification - pgvector-based semantic search - Hebrew legal document chunking and embedding
This commit is contained in:
0
mcp-server/src/legal_mcp/services/__init__.py
Normal file
0
mcp-server/src/legal_mcp/services/__init__.py
Normal file
130
mcp-server/src/legal_mcp/services/chunker.py
Normal file
130
mcp-server/src/legal_mcp/services/chunker.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Legal document chunker - splits text into sections and chunks for RAG."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from legal_mcp import config
|
||||
|
||||
# Hebrew legal section headers
|
||||
SECTION_PATTERNS = [
|
||||
(r"רקע\s*עובדתי|רקע\s*כללי|העובדות|הרקע", "facts"),
|
||||
(r"טענות\s*העוררי[םן]|טענות\s*המערערי[םן]|עיקר\s*טענות\s*העוררי[םן]", "appellant_claims"),
|
||||
(r"טענות\s*המשיבי[םן]|תשובת\s*המשיבי[םן]|עיקר\s*טענות\s*המשיבי[םן]", "respondent_claims"),
|
||||
(r"דיון\s*והכרעה|דיון|הכרעה|ניתוח\s*משפטי|המסגרת\s*המשפטית", "legal_analysis"),
|
||||
(r"מסקנ[הות]|סיכום", "conclusion"),
|
||||
(r"החלטה|לפיכך\s*אני\s*מחליט|התוצאה", "ruling"),
|
||||
(r"מבוא|פתיחה|לפניי", "intro"),
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Chunk:
|
||||
content: str
|
||||
section_type: str = "other"
|
||||
page_number: int | None = None
|
||||
chunk_index: int = 0
|
||||
|
||||
|
||||
def chunk_document(
|
||||
text: str,
|
||||
chunk_size: int = config.CHUNK_SIZE_TOKENS,
|
||||
overlap: int = config.CHUNK_OVERLAP_TOKENS,
|
||||
) -> list[Chunk]:
|
||||
"""Split a legal document into chunks, respecting section boundaries."""
|
||||
if not text.strip():
|
||||
return []
|
||||
|
||||
sections = _split_into_sections(text)
|
||||
chunks: list[Chunk] = []
|
||||
idx = 0
|
||||
|
||||
for section_type, section_text in sections:
|
||||
section_chunks = _split_section(section_text, chunk_size, overlap)
|
||||
for chunk_text in section_chunks:
|
||||
chunks.append(Chunk(
|
||||
content=chunk_text,
|
||||
section_type=section_type,
|
||||
chunk_index=idx,
|
||||
))
|
||||
idx += 1
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def _split_into_sections(text: str) -> list[tuple[str, str]]:
|
||||
"""Split text into (section_type, text) pairs based on Hebrew headers."""
|
||||
# Find all section headers and their positions
|
||||
markers: list[tuple[int, str]] = []
|
||||
|
||||
for pattern, section_type in SECTION_PATTERNS:
|
||||
for match in re.finditer(pattern, text):
|
||||
markers.append((match.start(), section_type))
|
||||
|
||||
if not markers:
|
||||
# No sections found - treat as single block
|
||||
return [("other", text)]
|
||||
|
||||
markers.sort(key=lambda x: x[0])
|
||||
|
||||
sections: list[tuple[str, str]] = []
|
||||
|
||||
# Text before first section
|
||||
if markers[0][0] > 0:
|
||||
intro_text = text[: markers[0][0]].strip()
|
||||
if intro_text:
|
||||
sections.append(("intro", intro_text))
|
||||
|
||||
# Each section
|
||||
for i, (pos, section_type) in enumerate(markers):
|
||||
end = markers[i + 1][0] if i + 1 < len(markers) else len(text)
|
||||
section_text = text[pos:end].strip()
|
||||
if section_text:
|
||||
sections.append((section_type, section_text))
|
||||
|
||||
return sections
|
||||
|
||||
|
||||
def _split_section(text: str, chunk_size: int, overlap: int) -> list[str]:
|
||||
"""Split a section into overlapping chunks by paragraphs.
|
||||
|
||||
Uses approximate token counting (Hebrew ~1.5 chars per token).
|
||||
"""
|
||||
if not text.strip():
|
||||
return []
|
||||
|
||||
paragraphs = [p.strip() for p in text.split("\n") if p.strip()]
|
||||
chunks: list[str] = []
|
||||
current: list[str] = []
|
||||
current_tokens = 0
|
||||
|
||||
for para in paragraphs:
|
||||
para_tokens = _estimate_tokens(para)
|
||||
|
||||
if current_tokens + para_tokens > chunk_size and current:
|
||||
chunks.append("\n".join(current))
|
||||
# Keep overlap
|
||||
overlap_paras: list[str] = []
|
||||
overlap_tokens = 0
|
||||
for p in reversed(current):
|
||||
pt = _estimate_tokens(p)
|
||||
if overlap_tokens + pt > overlap:
|
||||
break
|
||||
overlap_paras.insert(0, p)
|
||||
overlap_tokens += pt
|
||||
current = overlap_paras
|
||||
current_tokens = overlap_tokens
|
||||
|
||||
current.append(para)
|
||||
current_tokens += para_tokens
|
||||
|
||||
if current:
|
||||
chunks.append("\n".join(current))
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def _estimate_tokens(text: str) -> int:
|
||||
"""Rough token estimate for Hebrew text (~1.5 chars per token)."""
|
||||
return max(1, len(text) // 2)
|
||||
440
mcp-server/src/legal_mcp/services/db.py
Normal file
440
mcp-server/src/legal_mcp/services/db.py
Normal file
@@ -0,0 +1,440 @@
|
||||
"""Database service - asyncpg connection pool and queries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import date
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import asyncpg
|
||||
from pgvector.asyncpg import register_vector
|
||||
|
||||
from legal_mcp import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_pool: asyncpg.Pool | None = None
|
||||
|
||||
|
||||
async def get_pool() -> asyncpg.Pool:
|
||||
global _pool
|
||||
if _pool is None:
|
||||
# First, ensure pgvector extension exists (before registering type codec)
|
||||
conn = await asyncpg.connect(config.POSTGRES_URL)
|
||||
await conn.execute('CREATE EXTENSION IF NOT EXISTS vector')
|
||||
await conn.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp"')
|
||||
await conn.close()
|
||||
|
||||
_pool = await asyncpg.create_pool(
|
||||
config.POSTGRES_URL,
|
||||
min_size=2,
|
||||
max_size=10,
|
||||
init=_init_connection,
|
||||
)
|
||||
return _pool
|
||||
|
||||
|
||||
async def _init_connection(conn: asyncpg.Connection) -> None:
|
||||
await register_vector(conn)
|
||||
|
||||
|
||||
async def close_pool() -> None:
|
||||
global _pool
|
||||
if _pool:
|
||||
await _pool.close()
|
||||
_pool = None
|
||||
|
||||
|
||||
# ── Schema ──────────────────────────────────────────────────────────
|
||||
|
||||
SCHEMA_SQL = """
|
||||
|
||||
CREATE TABLE IF NOT EXISTS cases (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
case_number TEXT UNIQUE NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
appellants JSONB DEFAULT '[]',
|
||||
respondents JSONB DEFAULT '[]',
|
||||
subject TEXT DEFAULT '',
|
||||
property_address TEXT DEFAULT '',
|
||||
permit_number TEXT DEFAULT '',
|
||||
committee_type TEXT DEFAULT 'ועדה מקומית',
|
||||
status TEXT DEFAULT 'new',
|
||||
hearing_date DATE,
|
||||
decision_date DATE,
|
||||
tags JSONB DEFAULT '[]',
|
||||
notes TEXT DEFAULT '',
|
||||
created_at TIMESTAMPTZ DEFAULT now(),
|
||||
updated_at TIMESTAMPTZ DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS documents (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
case_id UUID REFERENCES cases(id) ON DELETE CASCADE,
|
||||
doc_type TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
file_path TEXT NOT NULL,
|
||||
extracted_text TEXT DEFAULT '',
|
||||
extraction_status TEXT DEFAULT 'pending',
|
||||
page_count INTEGER,
|
||||
metadata JSONB DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS document_chunks (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
document_id UUID REFERENCES documents(id) ON DELETE CASCADE,
|
||||
case_id UUID REFERENCES cases(id) ON DELETE CASCADE,
|
||||
chunk_index INTEGER NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
section_type TEXT DEFAULT 'other',
|
||||
embedding vector(1024),
|
||||
page_number INTEGER,
|
||||
created_at TIMESTAMPTZ DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS style_corpus (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
document_id UUID REFERENCES documents(id) ON DELETE SET NULL,
|
||||
decision_number TEXT,
|
||||
decision_date DATE,
|
||||
subject_categories JSONB DEFAULT '[]',
|
||||
full_text TEXT NOT NULL,
|
||||
summary TEXT DEFAULT '',
|
||||
outcome TEXT DEFAULT '',
|
||||
key_principles JSONB DEFAULT '[]',
|
||||
created_at TIMESTAMPTZ DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS style_patterns (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
pattern_type TEXT NOT NULL,
|
||||
pattern_text TEXT NOT NULL,
|
||||
frequency INTEGER DEFAULT 1,
|
||||
context TEXT DEFAULT '',
|
||||
examples JSONB DEFAULT '[]',
|
||||
created_at TIMESTAMPTZ DEFAULT now()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_embedding
|
||||
ON document_chunks USING ivfflat (embedding vector_cosine_ops)
|
||||
WITH (lists = 100);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_case ON document_chunks(case_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_chunks_doc ON document_chunks(document_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_docs_case ON documents(case_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_cases_status ON cases(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_cases_number ON cases(case_number);
|
||||
"""
|
||||
|
||||
|
||||
async def init_schema() -> None:
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(SCHEMA_SQL)
|
||||
logger.info("Database schema initialized")
|
||||
|
||||
|
||||
# ── Case CRUD ───────────────────────────────────────────────────────
|
||||
|
||||
async def create_case(
|
||||
case_number: str,
|
||||
title: str,
|
||||
appellants: list[str] | None = None,
|
||||
respondents: list[str] | None = None,
|
||||
subject: str = "",
|
||||
property_address: str = "",
|
||||
permit_number: str = "",
|
||||
committee_type: str = "ועדה מקומית",
|
||||
hearing_date: date | None = None,
|
||||
notes: str = "",
|
||||
) -> dict:
|
||||
pool = await get_pool()
|
||||
case_id = uuid4()
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""INSERT INTO cases (id, case_number, title, appellants, respondents,
|
||||
subject, property_address, permit_number, committee_type,
|
||||
hearing_date, notes)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)""",
|
||||
case_id, case_number, title,
|
||||
json.dumps(appellants or []),
|
||||
json.dumps(respondents or []),
|
||||
subject, property_address, permit_number, committee_type,
|
||||
hearing_date, notes,
|
||||
)
|
||||
return await get_case(case_id)
|
||||
|
||||
|
||||
async def get_case(case_id: UUID) -> dict | None:
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow("SELECT * FROM cases WHERE id = $1", case_id)
|
||||
if row is None:
|
||||
return None
|
||||
return _row_to_case(row)
|
||||
|
||||
|
||||
async def get_case_by_number(case_number: str) -> dict | None:
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"SELECT * FROM cases WHERE case_number = $1", case_number
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
return _row_to_case(row)
|
||||
|
||||
|
||||
async def list_cases(status: str | None = None, limit: int = 50) -> list[dict]:
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
if status:
|
||||
rows = await conn.fetch(
|
||||
"SELECT * FROM cases WHERE status = $1 ORDER BY updated_at DESC LIMIT $2",
|
||||
status, limit,
|
||||
)
|
||||
else:
|
||||
rows = await conn.fetch(
|
||||
"SELECT * FROM cases ORDER BY updated_at DESC LIMIT $1", limit
|
||||
)
|
||||
return [_row_to_case(r) for r in rows]
|
||||
|
||||
|
||||
async def update_case(case_id: UUID, **fields) -> dict | None:
|
||||
if not fields:
|
||||
return await get_case(case_id)
|
||||
pool = await get_pool()
|
||||
set_clauses = []
|
||||
values = []
|
||||
for i, (key, val) in enumerate(fields.items(), start=2):
|
||||
if key in ("appellants", "respondents", "tags"):
|
||||
val = json.dumps(val)
|
||||
set_clauses.append(f"{key} = ${i}")
|
||||
values.append(val)
|
||||
set_clauses.append("updated_at = now()")
|
||||
sql = f"UPDATE cases SET {', '.join(set_clauses)} WHERE id = $1"
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(sql, case_id, *values)
|
||||
return await get_case(case_id)
|
||||
|
||||
|
||||
def _row_to_case(row: asyncpg.Record) -> dict:
|
||||
d = dict(row)
|
||||
for field in ("appellants", "respondents", "tags"):
|
||||
if isinstance(d.get(field), str):
|
||||
d[field] = json.loads(d[field])
|
||||
d["id"] = str(d["id"])
|
||||
return d
|
||||
|
||||
|
||||
# ── Document CRUD ───────────────────────────────────────────────────
|
||||
|
||||
async def create_document(
|
||||
case_id: UUID,
|
||||
doc_type: str,
|
||||
title: str,
|
||||
file_path: str,
|
||||
page_count: int | None = None,
|
||||
) -> dict:
|
||||
pool = await get_pool()
|
||||
doc_id = uuid4()
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""INSERT INTO documents (id, case_id, doc_type, title, file_path, page_count)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)""",
|
||||
doc_id, case_id, doc_type, title, file_path, page_count,
|
||||
)
|
||||
row = await conn.fetchrow("SELECT * FROM documents WHERE id = $1", doc_id)
|
||||
return _row_to_doc(row)
|
||||
|
||||
|
||||
async def update_document(doc_id: UUID, **fields) -> None:
|
||||
if not fields:
|
||||
return
|
||||
pool = await get_pool()
|
||||
set_clauses = []
|
||||
values = []
|
||||
for i, (key, val) in enumerate(fields.items(), start=2):
|
||||
if key == "metadata":
|
||||
val = json.dumps(val)
|
||||
set_clauses.append(f"{key} = ${i}")
|
||||
values.append(val)
|
||||
sql = f"UPDATE documents SET {', '.join(set_clauses)} WHERE id = $1"
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(sql, doc_id, *values)
|
||||
|
||||
|
||||
async def get_document(doc_id: UUID) -> dict | None:
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow("SELECT * FROM documents WHERE id = $1", doc_id)
|
||||
return _row_to_doc(row) if row else None
|
||||
|
||||
|
||||
async def list_documents(case_id: UUID) -> list[dict]:
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"SELECT * FROM documents WHERE case_id = $1 ORDER BY created_at", case_id
|
||||
)
|
||||
return [_row_to_doc(r) for r in rows]
|
||||
|
||||
|
||||
async def get_document_text(doc_id: UUID) -> str:
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"SELECT extracted_text FROM documents WHERE id = $1", doc_id
|
||||
)
|
||||
return row["extracted_text"] if row else ""
|
||||
|
||||
|
||||
def _row_to_doc(row: asyncpg.Record) -> dict:
|
||||
d = dict(row)
|
||||
d["id"] = str(d["id"])
|
||||
d["case_id"] = str(d["case_id"])
|
||||
if isinstance(d.get("metadata"), str):
|
||||
d["metadata"] = json.loads(d["metadata"])
|
||||
return d
|
||||
|
||||
|
||||
# ── Chunks & Vectors ───────────────────────────────────────────────
|
||||
|
||||
async def store_chunks(
|
||||
document_id: UUID,
|
||||
case_id: UUID | None,
|
||||
chunks: list[dict],
|
||||
) -> int:
|
||||
"""Store document chunks with embeddings. Each chunk dict has:
|
||||
content, section_type, embedding (list[float]), page_number, chunk_index
|
||||
"""
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
# Delete existing chunks for this document
|
||||
await conn.execute(
|
||||
"DELETE FROM document_chunks WHERE document_id = $1", document_id
|
||||
)
|
||||
for chunk in chunks:
|
||||
await conn.execute(
|
||||
"""INSERT INTO document_chunks
|
||||
(document_id, case_id, chunk_index, content, section_type, embedding, page_number)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)""",
|
||||
document_id, case_id,
|
||||
chunk["chunk_index"],
|
||||
chunk["content"],
|
||||
chunk.get("section_type", "other"),
|
||||
chunk["embedding"],
|
||||
chunk.get("page_number"),
|
||||
)
|
||||
return len(chunks)
|
||||
|
||||
|
||||
async def search_similar(
|
||||
query_embedding: list[float],
|
||||
limit: int = 10,
|
||||
case_id: UUID | None = None,
|
||||
section_type: str | None = None,
|
||||
) -> list[dict]:
|
||||
"""Cosine similarity search on document chunks."""
|
||||
pool = await get_pool()
|
||||
conditions = []
|
||||
params: list = [query_embedding, limit]
|
||||
param_idx = 3
|
||||
|
||||
if case_id:
|
||||
conditions.append(f"dc.case_id = ${param_idx}")
|
||||
params.append(case_id)
|
||||
param_idx += 1
|
||||
if section_type:
|
||||
conditions.append(f"dc.section_type = ${param_idx}")
|
||||
params.append(section_type)
|
||||
param_idx += 1
|
||||
|
||||
where = f"WHERE {' AND '.join(conditions)}" if conditions else ""
|
||||
|
||||
sql = f"""
|
||||
SELECT dc.content, dc.section_type, dc.page_number,
|
||||
dc.document_id, dc.case_id,
|
||||
d.title AS document_title,
|
||||
c.case_number,
|
||||
1 - (dc.embedding <=> $1) AS score
|
||||
FROM document_chunks dc
|
||||
JOIN documents d ON d.id = dc.document_id
|
||||
JOIN cases c ON c.id = dc.case_id
|
||||
{where}
|
||||
ORDER BY dc.embedding <=> $1
|
||||
LIMIT $2
|
||||
"""
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(sql, *params)
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
# ── Style corpus ────────────────────────────────────────────────────
|
||||
|
||||
async def add_to_style_corpus(
|
||||
document_id: UUID | None,
|
||||
decision_number: str,
|
||||
decision_date: date | None,
|
||||
subject_categories: list[str],
|
||||
full_text: str,
|
||||
summary: str = "",
|
||||
outcome: str = "",
|
||||
key_principles: list[str] | None = None,
|
||||
) -> UUID:
|
||||
pool = await get_pool()
|
||||
corpus_id = uuid4()
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(
|
||||
"""INSERT INTO style_corpus
|
||||
(id, document_id, decision_number, decision_date,
|
||||
subject_categories, full_text, summary, outcome, key_principles)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)""",
|
||||
corpus_id, document_id, decision_number, decision_date,
|
||||
json.dumps(subject_categories), full_text, summary, outcome,
|
||||
json.dumps(key_principles or []),
|
||||
)
|
||||
return corpus_id
|
||||
|
||||
|
||||
async def get_style_patterns(pattern_type: str | None = None) -> list[dict]:
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
if pattern_type:
|
||||
rows = await conn.fetch(
|
||||
"SELECT * FROM style_patterns WHERE pattern_type = $1 ORDER BY frequency DESC",
|
||||
pattern_type,
|
||||
)
|
||||
else:
|
||||
rows = await conn.fetch(
|
||||
"SELECT * FROM style_patterns ORDER BY pattern_type, frequency DESC"
|
||||
)
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
async def upsert_style_pattern(
|
||||
pattern_type: str,
|
||||
pattern_text: str,
|
||||
context: str = "",
|
||||
examples: list[str] | None = None,
|
||||
) -> None:
|
||||
pool = await get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
existing = await conn.fetchrow(
|
||||
"SELECT id, frequency FROM style_patterns WHERE pattern_type = $1 AND pattern_text = $2",
|
||||
pattern_type, pattern_text,
|
||||
)
|
||||
if existing:
|
||||
await conn.execute(
|
||||
"UPDATE style_patterns SET frequency = frequency + 1 WHERE id = $1",
|
||||
existing["id"],
|
||||
)
|
||||
else:
|
||||
await conn.execute(
|
||||
"""INSERT INTO style_patterns (pattern_type, pattern_text, context, examples)
|
||||
VALUES ($1, $2, $3, $4)""",
|
||||
pattern_type, pattern_text, context,
|
||||
json.dumps(examples or []),
|
||||
)
|
||||
55
mcp-server/src/legal_mcp/services/embeddings.py
Normal file
55
mcp-server/src/legal_mcp/services/embeddings.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Embedding service using Voyage AI API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import voyageai
|
||||
|
||||
from legal_mcp import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_client: voyageai.Client | None = None
|
||||
|
||||
|
||||
def _get_client() -> voyageai.Client:
|
||||
global _client
|
||||
if _client is None:
|
||||
_client = voyageai.Client(api_key=config.VOYAGE_API_KEY)
|
||||
return _client
|
||||
|
||||
|
||||
async def embed_texts(texts: list[str], input_type: str = "document") -> list[list[float]]:
|
||||
"""Embed a batch of texts using Voyage AI.
|
||||
|
||||
Args:
|
||||
texts: List of texts to embed (max 128 per call).
|
||||
input_type: "document" for indexing, "query" for search queries.
|
||||
|
||||
Returns:
|
||||
List of embedding vectors (1024 dimensions each).
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
client = _get_client()
|
||||
all_embeddings = []
|
||||
|
||||
# Voyage AI supports up to 128 texts per batch
|
||||
for i in range(0, len(texts), 128):
|
||||
batch = texts[i : i + 128]
|
||||
result = client.embed(
|
||||
batch,
|
||||
model=config.VOYAGE_MODEL,
|
||||
input_type=input_type,
|
||||
)
|
||||
all_embeddings.extend(result.embeddings)
|
||||
|
||||
return all_embeddings
|
||||
|
||||
|
||||
async def embed_query(query: str) -> list[float]:
|
||||
"""Embed a single search query."""
|
||||
results = await embed_texts([query], input_type="query")
|
||||
return results[0]
|
||||
126
mcp-server/src/legal_mcp/services/extractor.py
Normal file
126
mcp-server/src/legal_mcp/services/extractor.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Text extraction from PDF, DOCX, and RTF files.
|
||||
|
||||
Primary PDF extraction: Claude Vision API (for scanned documents).
|
||||
Fallback: PyMuPDF direct text extraction (for born-digital PDFs).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import anthropic
|
||||
import fitz # PyMuPDF
|
||||
from docx import Document as DocxDocument
|
||||
from striprtf.striprtf import rtf_to_text
|
||||
|
||||
from legal_mcp import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_anthropic_client: anthropic.Anthropic | None = None
|
||||
|
||||
|
||||
def _get_anthropic() -> anthropic.Anthropic:
|
||||
global _anthropic_client
|
||||
if _anthropic_client is None:
|
||||
_anthropic_client = anthropic.Anthropic(api_key=config.ANTHROPIC_API_KEY)
|
||||
return _anthropic_client
|
||||
|
||||
|
||||
async def extract_text(file_path: str) -> tuple[str, int]:
|
||||
"""Extract text from a document file.
|
||||
|
||||
Returns:
|
||||
Tuple of (extracted_text, page_count).
|
||||
page_count is 0 for non-PDF files.
|
||||
"""
|
||||
path = Path(file_path)
|
||||
suffix = path.suffix.lower()
|
||||
|
||||
if suffix == ".pdf":
|
||||
return await _extract_pdf(path)
|
||||
elif suffix == ".docx":
|
||||
return _extract_docx(path), 0
|
||||
elif suffix == ".rtf":
|
||||
return _extract_rtf(path), 0
|
||||
elif suffix == ".txt":
|
||||
return path.read_text(encoding="utf-8"), 0
|
||||
else:
|
||||
raise ValueError(f"Unsupported file type: {suffix}")
|
||||
|
||||
|
||||
async def _extract_pdf(path: Path) -> tuple[str, int]:
|
||||
"""Extract text from PDF. Try direct text first, fall back to Claude Vision for scanned pages."""
|
||||
doc = fitz.open(str(path))
|
||||
page_count = len(doc)
|
||||
pages_text: list[str] = []
|
||||
|
||||
for page_num in range(page_count):
|
||||
page = doc[page_num]
|
||||
# Try direct text extraction first
|
||||
text = page.get_text().strip()
|
||||
|
||||
if len(text) > 50:
|
||||
# Sufficient text found - born-digital page
|
||||
pages_text.append(text)
|
||||
logger.debug("Page %d: direct text extraction (%d chars)", page_num + 1, len(text))
|
||||
else:
|
||||
# Likely scanned - use Claude Vision
|
||||
logger.info("Page %d: using Claude Vision OCR", page_num + 1)
|
||||
pix = page.get_pixmap(dpi=200)
|
||||
img_bytes = pix.tobytes("png")
|
||||
ocr_text = await _ocr_with_claude(img_bytes, page_num + 1)
|
||||
pages_text.append(ocr_text)
|
||||
|
||||
doc.close()
|
||||
return "\n\n".join(pages_text), page_count
|
||||
|
||||
|
||||
async def _ocr_with_claude(image_bytes: bytes, page_num: int) -> str:
|
||||
"""OCR a single page image using Claude Vision API."""
|
||||
client = _get_anthropic()
|
||||
b64_image = base64.b64encode(image_bytes).decode("utf-8")
|
||||
|
||||
message = client.messages.create(
|
||||
model="claude-sonnet-4-20250514",
|
||||
max_tokens=4096,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": b64_image,
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": (
|
||||
"חלץ את כל הטקסט מהתמונה הזו. זהו מסמך משפטי בעברית. "
|
||||
"שמור על מבנה הפסקאות המקורי. "
|
||||
"החזר רק את הטקסט המחולץ, ללא הערות נוספות."
|
||||
),
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
return message.content[0].text
|
||||
|
||||
|
||||
def _extract_docx(path: Path) -> str:
|
||||
"""Extract text from DOCX file."""
|
||||
doc = DocxDocument(str(path))
|
||||
paragraphs = [p.text for p in doc.paragraphs if p.text.strip()]
|
||||
return "\n\n".join(paragraphs)
|
||||
|
||||
|
||||
def _extract_rtf(path: Path) -> str:
|
||||
"""Extract text from RTF file."""
|
||||
rtf_content = path.read_text(encoding="utf-8", errors="replace")
|
||||
return rtf_to_text(rtf_content)
|
||||
79
mcp-server/src/legal_mcp/services/processor.py
Normal file
79
mcp-server/src/legal_mcp/services/processor.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Document processing pipeline: extract → chunk → embed → store."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from legal_mcp.services import chunker, db, embeddings, extractor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def process_document(document_id: UUID, case_id: UUID) -> dict:
|
||||
"""Full processing pipeline for a document.
|
||||
|
||||
1. Extract text from file
|
||||
2. Split into chunks
|
||||
3. Generate embeddings
|
||||
4. Store chunks + embeddings in DB
|
||||
|
||||
Returns processing summary.
|
||||
"""
|
||||
doc = await db.get_document(document_id)
|
||||
if not doc:
|
||||
raise ValueError(f"Document {document_id} not found")
|
||||
|
||||
await db.update_document(document_id, extraction_status="processing")
|
||||
|
||||
try:
|
||||
# Step 1: Extract text
|
||||
logger.info("Extracting text from %s", doc["file_path"])
|
||||
text, page_count = await extractor.extract_text(doc["file_path"])
|
||||
|
||||
await db.update_document(
|
||||
document_id,
|
||||
extracted_text=text,
|
||||
page_count=page_count,
|
||||
)
|
||||
|
||||
# Step 2: Chunk
|
||||
logger.info("Chunking document (%d chars)", len(text))
|
||||
chunks = chunker.chunk_document(text)
|
||||
|
||||
if not chunks:
|
||||
await db.update_document(document_id, extraction_status="completed")
|
||||
return {"status": "completed", "chunks": 0, "message": "No text to chunk"}
|
||||
|
||||
# Step 3: Embed
|
||||
logger.info("Generating embeddings for %d chunks", len(chunks))
|
||||
texts = [c.content for c in chunks]
|
||||
embs = await embeddings.embed_texts(texts, input_type="document")
|
||||
|
||||
# Step 4: Store
|
||||
chunk_dicts = [
|
||||
{
|
||||
"content": c.content,
|
||||
"section_type": c.section_type,
|
||||
"embedding": emb,
|
||||
"page_number": c.page_number,
|
||||
"chunk_index": c.chunk_index,
|
||||
}
|
||||
for c, emb in zip(chunks, embs)
|
||||
]
|
||||
|
||||
stored = await db.store_chunks(document_id, case_id, chunk_dicts)
|
||||
await db.update_document(document_id, extraction_status="completed")
|
||||
|
||||
logger.info("Document processed: %d chunks stored", stored)
|
||||
return {
|
||||
"status": "completed",
|
||||
"chunks": stored,
|
||||
"pages": page_count,
|
||||
"text_length": len(text),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Document processing failed: %s", e)
|
||||
await db.update_document(document_id, extraction_status="failed")
|
||||
return {"status": "failed", "error": str(e)}
|
||||
121
mcp-server/src/legal_mcp/services/style_analyzer.py
Normal file
121
mcp-server/src/legal_mcp/services/style_analyzer.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Style analyzer - extracts writing patterns from Dafna's decision corpus."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
import anthropic
|
||||
|
||||
from legal_mcp import config
|
||||
from legal_mcp.services import db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ANALYSIS_PROMPT = """\
|
||||
אתה מנתח סגנון כתיבה משפטית. לפניך החלטות משפטיות שנכתבו על ידי אותה יושבת ראש של ועדת ערר.
|
||||
|
||||
נתח את ההחלטות וחלץ את דפוסי הכתיבה הבאים:
|
||||
|
||||
1. **נוסחאות פתיחה** (opening_formula) - איך מתחילות ההחלטות
|
||||
2. **ביטויי מעבר** (transition) - ביטויים שמחברים בין חלקי ההחלטה
|
||||
3. **סגנון ציטוט** (citation_style) - איך מצטטים חקיקה ופסיקה
|
||||
4. **מבנה ניתוח** (analysis_structure) - איך בנוי הניתוח המשפטי
|
||||
5. **נוסחאות סיום** (closing_formula) - איך מסתיימות ההחלטות
|
||||
6. **ביטויים אופייניים** (characteristic_phrase) - ביטויים ייחודיים שחוזרים
|
||||
|
||||
לכל דפוס, תן:
|
||||
- הטקסט המדויק של הדפוס
|
||||
- הקשר (באיזה חלק של ההחלטה הוא מופיע)
|
||||
- דוגמה מתוך הטקסט
|
||||
|
||||
החזר את התוצאות בפורמט הבא (JSON array):
|
||||
```json
|
||||
[
|
||||
{{
|
||||
"type": "opening_formula",
|
||||
"text": "לפניי ערר על החלטת...",
|
||||
"context": "פתיחת ההחלטה",
|
||||
"example": "לפניי ערר על החלטת הוועדה המקומית לתכנון ובניה ירושלים"
|
||||
}}
|
||||
]
|
||||
```
|
||||
|
||||
ההחלטות:
|
||||
{decisions}
|
||||
"""
|
||||
|
||||
|
||||
async def analyze_corpus() -> dict:
|
||||
"""Analyze the style corpus and extract/update patterns.
|
||||
|
||||
Returns summary of patterns found.
|
||||
"""
|
||||
pool = await db.get_pool()
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"SELECT full_text, decision_number FROM style_corpus ORDER BY decision_date DESC LIMIT 20"
|
||||
)
|
||||
|
||||
if not rows:
|
||||
return {"error": "אין החלטות בקורפוס. העלה החלטות קודמות תחילה."}
|
||||
|
||||
# Prepare text for analysis
|
||||
decisions_text = ""
|
||||
for row in rows:
|
||||
decisions_text += f"\n\n--- החלטה {row['decision_number'] or 'ללא מספר'} ---\n"
|
||||
# Limit each decision to ~3000 chars to fit context
|
||||
text = row["full_text"]
|
||||
if len(text) > 3000:
|
||||
text = text[:1500] + "\n...\n" + text[-1500:]
|
||||
decisions_text += text
|
||||
|
||||
# Call Claude to analyze patterns
|
||||
client = anthropic.Anthropic(api_key=config.ANTHROPIC_API_KEY)
|
||||
message = client.messages.create(
|
||||
model="claude-sonnet-4-6",
|
||||
max_tokens=16384,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": ANALYSIS_PROMPT.format(decisions=decisions_text),
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
response_text = message.content[0].text
|
||||
|
||||
# Extract JSON from response - prefer code-block fenced JSON
|
||||
import json
|
||||
code_block = re.search(r"```(?:json)?\s*(\[[\s\S]*?\])\s*```", response_text)
|
||||
if code_block:
|
||||
json_str = code_block.group(1)
|
||||
else:
|
||||
# Fallback: find the last JSON array (skip prose brackets)
|
||||
all_arrays = list(re.finditer(r"\[[\s\S]*?\]", response_text))
|
||||
if not all_arrays:
|
||||
return {"error": "Could not parse analysis results", "raw": response_text}
|
||||
json_str = all_arrays[-1].group()
|
||||
|
||||
try:
|
||||
patterns = json.loads(json_str)
|
||||
except json.JSONDecodeError as e:
|
||||
return {"error": f"JSON parse error: {e}", "raw": response_text}
|
||||
|
||||
# Store patterns
|
||||
count = 0
|
||||
for pattern in patterns:
|
||||
await db.upsert_style_pattern(
|
||||
pattern_type=pattern.get("type", "other"),
|
||||
pattern_text=pattern.get("text", ""),
|
||||
context=pattern.get("context", ""),
|
||||
examples=[pattern.get("example", "")],
|
||||
)
|
||||
count += 1
|
||||
|
||||
return {
|
||||
"patterns_found": count,
|
||||
"decisions_analyzed": len(rows),
|
||||
"pattern_types": list({p.get("type") for p in patterns}),
|
||||
}
|
||||
Reference in New Issue
Block a user