"""Legal document chunker - splits text into sections and chunks for RAG.""" from __future__ import annotations import re from dataclasses import dataclass, field from legal_mcp import config # Hebrew legal section headers. # Covers both appeals committee decisions and external court rulings — # court rulings use slightly different vocabulary (פסק דין, נימוקים, סוף דבר). SECTION_PATTERNS = [ (r"רקע\s*עובדתי|רקע\s*כללי|העובדות|הרקע", "facts"), (r"טענות\s*העוררי[םן]|טענות\s*המערערי[םן]|עיקר\s*טענות\s*העוררי[םן]", "appellant_claims"), (r"טענות\s*המשיבי[םן]|תשובת\s*המשיבי[םן]|עיקר\s*טענות\s*המשיבי[םן]", "respondent_claims"), (r"דיון\s*והכרעה|דיון|הכרעה|ניתוח\s*משפטי|המסגרת\s*המשפטית|נימוקים", "legal_analysis"), (r"מסקנ[הות]|סיכום|סוף\s*דבר", "conclusion"), (r"פסק[- ]?דין|החלטה|לפיכך\s*אני\s*מחליט|התוצאה", "ruling"), (r"מבוא|פתיחה|לפניי", "intro"), ] @dataclass class Chunk: content: str section_type: str = "other" page_number: int | None = None chunk_index: int = 0 def chunk_document( text: str, chunk_size: int = config.CHUNK_SIZE_TOKENS, overlap: int = config.CHUNK_OVERLAP_TOKENS, page_offsets: list[int] | None = None, ) -> list[Chunk]: """Split a legal document into chunks, respecting section boundaries. When ``page_offsets`` is supplied (from a PDF extraction), each chunk is tagged with the page number of its first character — used by the multimodal hybrid retriever to join (text chunk, image at same page) and surface text+image matches. """ if not text.strip(): return [] sections = _split_into_sections(text) chunks: list[Chunk] = [] idx = 0 for section_type, section_text in sections: section_chunks = _split_section(section_text, chunk_size, overlap) for chunk_text in section_chunks: chunks.append(Chunk( content=chunk_text, section_type=section_type, chunk_index=idx, )) idx += 1 if page_offsets: _assign_pages(chunks, text, page_offsets) return chunks def _assign_pages(chunks: list[Chunk], text: str, page_offsets: list[int]) -> None: """Locate each chunk's first character in ``text`` and tag with the page that contains that offset. Mutates chunks in-place. Chunks have overlap so we search forward from a position slightly past the previous chunk's start. Falls back to a global search if the forward scan misses (rare — happens only when overlap is bigger than the advance distance below). """ from legal_mcp.services.extractor import page_at_offset pos = 0 for c in chunks: idx = text.find(c.content, pos) if idx < 0: idx = text.find(c.content) if idx < 0: continue c.page_number = page_at_offset(idx, page_offsets) # advance past the chunk's halfway point — overlap is < 50% so # the next chunk's starting point will be after this cursor. pos = idx + max(1, len(c.content) // 2) def _split_into_sections(text: str) -> list[tuple[str, str]]: """Split text into (section_type, text) pairs based on Hebrew headers.""" # Find all section headers and their positions markers: list[tuple[int, str]] = [] for pattern, section_type in SECTION_PATTERNS: for match in re.finditer(pattern, text): markers.append((match.start(), section_type)) if not markers: # No sections found - treat as single block return [("other", text)] markers.sort(key=lambda x: x[0]) sections: list[tuple[str, str]] = [] # Text before first section if markers[0][0] > 0: intro_text = text[: markers[0][0]].strip() if intro_text: sections.append(("intro", intro_text)) # Each section for i, (pos, section_type) in enumerate(markers): end = markers[i + 1][0] if i + 1 < len(markers) else len(text) section_text = text[pos:end].strip() if section_text: sections.append((section_type, section_text)) return sections def _split_section(text: str, chunk_size: int, overlap: int) -> list[str]: """Split a section into overlapping chunks by paragraphs. Uses approximate token counting (Hebrew ~1.5 chars per token). """ if not text.strip(): return [] paragraphs = [p.strip() for p in text.split("\n") if p.strip()] chunks: list[str] = [] current: list[str] = [] current_tokens = 0 for para in paragraphs: para_tokens = _estimate_tokens(para) if current_tokens + para_tokens > chunk_size and current: chunks.append("\n".join(current)) # Keep overlap overlap_paras: list[str] = [] overlap_tokens = 0 for p in reversed(current): pt = _estimate_tokens(p) if overlap_tokens + pt > overlap: break overlap_paras.insert(0, p) overlap_tokens += pt current = overlap_paras current_tokens = overlap_tokens current.append(para) current_tokens += para_tokens if current: chunks.append("\n".join(current)) return chunks def _estimate_tokens(text: str) -> int: """Rough token estimate for Hebrew text (~1.5 chars per token).""" return max(1, len(text) // 2)