Files
legal-ai/scripts/backfill_chunk_pages.py
Chaim 8a815ecff5
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 16s
fix(retrieval): rewrite chunk-page retrofit to skip OCR
The first-pass retrofit re-extracted via extractor.extract_text, which
re-runs Google Vision OCR on scanned pages. OCR is non-deterministic,
so the new text didn't match the chunk content stored in the DB
(produced by the original OCR run) — only ~7% of chunks were located.

New approach (no OCR cost):

1. Use the stored documents.extracted_text from the DB — the exact
   text the chunks were produced from, so chunk lookups match.
2. Anchor page boundaries via PyMuPDF direct text reads (free, no
   OCR). Pages with usable direct text are anchored by snippet match;
   OCR-only pages are linearly interpolated between anchors.
3. Search each chunk in extracted_text using a whitespace-tolerant
   helper — needed because the chunker joins paragraphs with a single
   '\n' while extracted_text uses '\n\n' as the page separator.

Verified on 8174-24 (5 docs, 307 chunks) + 8137-24 (9 docs, 512
chunks): 100% chunks tagged, 13s total, $0 cost.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-03 20:04:33 +00:00

347 lines
12 KiB
Python

"""Backfill page_number on existing document_chunks (no re-OCR).
Why this exists: the legacy chunker did not track which page each chunk
came from. After the page-tracking fix, new uploads carry page_number
correctly, but existing chunks have ``page_number=NULL`` in the DB.
That blocks the multimodal hybrid retriever's text+image boost (which
joins (chunk, image) on (document_id, page_number)).
What it does (per case, per document):
1. Load stored ``documents.extracted_text`` from the DB. This is
the exact text that was used to produce the existing chunks —
so chunk content lookups against it match verbatim.
2. Open the PDF with PyMuPDF and call ``page.get_text()`` on each
page (cheap, no OCR). For pages with usable direct text we get
a clean snippet; for fully-scanned pages we get little/nothing.
3. Anchor: for each page with a usable snippet, search the snippet
in ``extracted_text`` to recover that page's start offset.
4. Interpolate: for OCR-only pages with no anchor, position is
linearly interpolated between the nearest anchored neighbors
(or uniformly when no anchors exist at all).
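5. For every chunk row (sorted by chunk_index), locate the first
~120 characters of the chunk's content in ``extracted_text``
(verbatim first, then with a whitespace-tolerant fallback), look up
the page from the offsets, and ``UPDATE document_chunks SET
page_number = ?``.
Idempotent: once every chunk of a document has a page_number, a later
run without --force skips that document.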
Cost: zero. Runs in seconds even for the 89-page appraisal report.
Usage:
docker cp scripts/backfill_chunk_pages.py <c>:/tmp/
docker exec <c> python /tmp/backfill_chunk_pages.py 8174-24 8137-24
"""
from __future__ import annotations
import argparse
import asyncio
import logging
import re
import sys
import time
from pathlib import Path
from uuid import UUID
def _setup_paths():
    """Put the sibling ``mcp-server/src`` directory at the front of ``sys.path``.

    The directory is looked up relative to this file's parent directory.
    Nothing happens when it does not exist or is already on ``sys.path``.
    """
    scripts_dir = Path(__file__).resolve().parent
    candidate = scripts_dir.parent / "mcp-server" / "src"
    if not candidate.is_dir():
        return
    entry = str(candidate)
    if entry in sys.path:
        return
    sys.path.insert(0, entry)


# Must run before the project imports further down the file.
_setup_paths()
import fitz # PyMuPDF # noqa: E402
from legal_mcp.services import db # noqa: E402
# Timestamped INFO-level logging for the whole script run.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger("backfill_chunk_pages")

# Number of leading characters of a page's stripped direct text used as
# the anchoring snippet. Long enough to be unique, short enough to
# survive minor whitespace variation between PyMuPDF direct extraction
# and the stored OCR text.
ANCHOR_SNIPPET_LEN = 80

# Minimum direct-text length (characters, after stripping) on a page to
# attempt anchoring at all; shorter pages are left un-anchored and get
# an interpolated offset instead.
MIN_DIRECT_LEN = 60
# Default root under which ``/data/...`` paths stored in the DB are
# looked up when they do not exist as-is (presumably a host-side
# checkout used when running outside the container — confirm).
_DEFAULT_LOCAL_ROOT = Path("/home/chaim/legal-ai")


def _resolve_local_path(db_path: str, local_root: Path = _DEFAULT_LOCAL_ROOT) -> Path:
    """Map a file path stored in the DB to a path that exists locally.

    Args:
        db_path: Path string as stored in ``documents.file_path``.
        local_root: Directory under which a ``/data/...`` path is retried
            (``/data/x.pdf`` becomes ``<local_root>/data/x.pdf``). Accepts
            a ``Path`` or a string.

    Returns:
        ``db_path`` itself when it is an existing file; otherwise the
        re-rooted path when ``db_path`` starts with ``/data/`` and the
        re-rooted file exists; otherwise ``db_path`` unchanged (the caller
        is expected to check ``is_file()``).
    """
    p = Path(db_path)
    if p.is_file():
        return p
    if str(p).startswith("/data/"):
        # parts[0] is the root "/"; keep "data/..." and hang it off local_root.
        local = Path(local_root).joinpath(*p.parts[1:])
        if local.is_file():
            return local
    return p
def _norm_whitespace(s: str) -> str:
    """Return ``s`` with each whitespace run replaced by a single space.

    Leading and trailing whitespace is dropped. Used for cross-source
    matching, where PyMuPDF direct extraction and the stored OCR text may
    place line breaks differently.
    """
    words = s.split()
    return " ".join(words)
def _find_anchored_snippet(
    extracted_text: str, snippet: str, search_start: int = 0,
) -> int:
    """Locate ``snippet`` in ``extracted_text``, tolerating whitespace differences.

    A verbatim match is tried first. If that fails, any run of whitespace
    in ``snippet`` is allowed to match any non-empty run of whitespace in
    ``extracted_text`` (so ``"a\\nb"`` matches ``"a\\n\\nb"``), and leading
    or trailing whitespace in ``snippet`` is ignored.

    Args:
        extracted_text: The text to search in.
        snippet: The text to look for.
        search_start: Offset in ``extracted_text`` at which both the
            verbatim and the whitespace-tolerant search begin.

    Returns:
        Offset in the original ``extracted_text`` of the first character of
        the match, or -1 when there is no match (including when ``snippet``
        is whitespace-only and not present verbatim).
    """
    # Direct match first — fastest path.
    idx = extracted_text.find(snippet, search_start)
    if idx >= 0:
        return idx

    tokens = snippet.split()
    if not tokens:
        return -1

    # Searching the original text with a pattern (rather than a normalized
    # copy) yields the exact original offset and honours search_start.
    pattern = re.compile(r"\s+".join(re.escape(tok) for tok in tokens))
    match = pattern.search(extracted_text, search_start)
    return match.start() if match else -1
def _interpolate_page_offsets(anchors: list[int | None], text_len: int) -> list[int]:
    """Fill un-anchored pages and return one non-decreasing offset per page."""
    n_pages = len(anchors)
    if n_pages == 0:
        return []

    known = list(anchors)
    # Force the first page to start at 0 if it was not anchored.
    if known[0] is None:
        known[0] = 0

    # Used only past the last anchor: the document-wide average page length.
    avg_page_len = text_len / n_pages

    offsets: list[int] = []
    prev_i = 0  # index of the nearest anchored page at or before i
    for i, anchor in enumerate(known):
        if anchor is not None:
            offsets.append(anchor)
            prev_i = i
            continue
        prev_pos = known[prev_i]
        next_i = next(
            (j for j in range(i + 1, n_pages) if known[j] is not None), None,
        )
        if next_i is not None:
            # Linear interpolation between the surrounding anchors.
            ratio = (i - prev_i) / (next_i - prev_i)
            offsets.append(int(prev_pos + ratio * (known[next_i] - prev_pos)))
        else:
            # No anchor after this page: step forward from the last anchor
            # by the average page length.
            offsets.append(int(prev_pos + avg_page_len * (i - prev_i)))

    # Monotone-clip: an out-of-order anchor must not make offsets go backwards.
    for i in range(1, n_pages):
        if offsets[i] < offsets[i - 1]:
            offsets[i] = offsets[i - 1]
    return offsets


def _compute_page_offsets(pdf_path: Path, extracted_text: str) -> list[int]:
    """Return the start offset of each PDF page within ``extracted_text``.

    Pages whose direct PyMuPDF text (stripped) has at least
    ``MIN_DIRECT_LEN`` characters are anchored by searching their first
    ``ANCHOR_SNIPPET_LEN`` characters in ``extracted_text``. Each search
    starts just after the previous anchor and falls back to a search from
    the beginning of the text. Remaining pages are linearly interpolated
    between neighbouring anchors, or stepped forward by the average page
    length after the last anchor.

    Args:
        pdf_path: Path of the PDF to open.
        extracted_text: The stored text the offsets refer to.

    Returns:
        A non-decreasing list with one offset per page; empty when the PDF
        has no pages.
    """
    doc = fitz.open(str(pdf_path))
    try:
        anchors: list[int | None] = [None] * len(doc)
        search_from = 0
        for i, page in enumerate(doc):
            direct = page.get_text().strip()
            if len(direct) < MIN_DIRECT_LEN:
                continue
            snippet = direct[:ANCHOR_SNIPPET_LEN]
            pos = _find_anchored_snippet(extracted_text, snippet, search_from)
            if pos < 0:
                # Try a global search before giving up.
                pos = _find_anchored_snippet(extracted_text, snippet, 0)
            if pos >= 0:
                anchors[i] = pos
                # Start the next search *after* this anchor; otherwise a page
                # with the same leading text (e.g. a running header) would
                # match this very position again.
                search_from = pos + 1
    finally:
        # Close the document even when text extraction raises.
        doc.close()

    return _interpolate_page_offsets(anchors, len(extracted_text))
def _page_at_offset(offset: int, page_offsets: list[int]) -> int:
    """Return the 1-based page number that contains ``offset``.

    Scans ``page_offsets`` in order and stops at the first page whose start
    lies beyond ``offset``; the last page seen before that wins. Returns 1
    when ``page_offsets`` is empty or ``offset`` precedes the first start.
    """
    found = 1
    for page_no, start in enumerate(page_offsets, start=1):
        if start > offset:
            break
        found = page_no
    return found
async def _backfill_document(
    document_id: UUID,
    title: str,
    db_file_path: str,
    force: bool,
) -> dict:
    """Assign ``page_number`` to the chunks of one document.

    Args:
        document_id: ``documents.id`` of the document to process.
        title: Document title, used only in log messages.
        db_file_path: ``documents.file_path`` as stored in the DB.
        force: When false, a document whose chunks all already have a
            ``page_number`` is skipped.

    Returns:
        A dict with a ``"status"`` key: ``"no_chunks"``, ``"skipped"``,
        ``"missing"``, ``"not_pdf"``, ``"no_extracted_text"`` or ``"ok"``.
        The ``"ok"`` result also carries timing, page and chunk counters.
    """
    pool = await db.get_pool()
    chunks = await pool.fetch(
        "SELECT id, chunk_index, content, page_number FROM document_chunks "
        "WHERE document_id = $1 ORDER BY chunk_index",
        document_id,
    )
    if not chunks:
        return {"status": "no_chunks"}

    already_tagged = all(c["page_number"] is not None for c in chunks)
    if already_tagged and not force:
        logger.info(" skip (all %d chunks already tagged): %s", len(chunks), title)
        return {"status": "skipped", "chunks": len(chunks)}

    # A NULL/empty file_path would otherwise crash in Path(); report it the
    # same way as a file that is not on disk.
    if not db_file_path:
        logger.warning(" file missing: no file_path recorded (%s)", title)
        return {"status": "missing"}
    pdf_path = _resolve_local_path(db_file_path)
    if not pdf_path.is_file():
        logger.warning(" file missing: %s (%s)", pdf_path, title)
        return {"status": "missing"}
    if pdf_path.suffix.lower() != ".pdf":
        return {"status": "not_pdf"}

    doc_row = await pool.fetchrow(
        "SELECT extracted_text FROM documents WHERE id = $1", document_id,
    )
    extracted_text = doc_row["extracted_text"] if doc_row else None
    if not extracted_text:
        return {"status": "no_extracted_text"}

    t0 = time.perf_counter()
    page_offsets = _compute_page_offsets(pdf_path, extracted_text)
    # NOTE: this counts pages whose start offset strictly advances past the
    # previous page's. Interpolated pages usually satisfy that too, so the
    # figure is an upper bound on the number of truly anchored pages.
    n_anchored = sum(
        1
        for i, start in enumerate(page_offsets)
        if i == 0 or start > page_offsets[i - 1]
    )

    # The chunker joins paragraphs with single `\n` while extracted_text
    # has `\n\n` between pages, so verbatim search misses cross-page
    # chunks. Use the whitespace-tolerant helper that returns an offset
    # in the *original* text.
    pos = 0
    updated = 0
    not_found = 0
    for c in chunks:
        content = c["content"]
        if not content:
            continue
        # Anchoring on the chunk's first ~120 chars is enough to
        # disambiguate across the document.
        snippet = content[:120]
        idx = _find_anchored_snippet(extracted_text, snippet, pos)
        if idx < 0:
            # Chunks may overlap or be out of order; retry from the start.
            idx = _find_anchored_snippet(extracted_text, snippet, 0)
        if idx < 0:
            not_found += 1
            continue
        page = _page_at_offset(idx, page_offsets)
        await pool.execute(
            "UPDATE document_chunks SET page_number = $1 WHERE id = $2",
            page, c["id"],
        )
        updated += 1
        # Advance only half a chunk so an overlapping next chunk is still
        # found by the forward search.
        pos = idx + max(1, len(content) // 2)

    elapsed = time.perf_counter() - t0
    logger.info(
        " %s: %d pages, %d anchors, updated %d/%d chunks (%d not found) in %.2fs",
        title, len(page_offsets), n_anchored, updated, len(chunks), not_found, elapsed,
    )
    return {
        "status": "ok",
        "elapsed_sec": round(elapsed, 2),
        "pages": len(page_offsets),
        "anchors": n_anchored,
        "chunks_total": len(chunks),
        "chunks_updated": updated,
        "chunks_not_found": not_found,
    }
async def backfill_cases(case_numbers: list[str], force: bool) -> dict:
    """Backfill chunk page numbers for every document of the given cases.

    Args:
        case_numbers: Case numbers to process, in order.
        force: Passed through to the per-document backfill; when false,
            fully tagged documents are skipped.

    Returns:
        A dict keyed by case number. A case that cannot be found maps to
        ``{"status": "case_not_found"}``; any other case maps to per-status
        document counts, the total of updated chunks, and the list of
        per-document results under ``"documents"``.
    """
    pool = await db.get_pool()
    summary: dict = {}
    for cn in case_numbers:
        logger.info("=" * 60)
        logger.info("Case %s", cn)
        case = await db.get_case_by_number(cn)
        if not case:
            logger.warning("Case not found: %s", cn)
            summary[cn] = {"status": "case_not_found"}
            continue

        case_id = UUID(str(case["id"]))
        docs = await pool.fetch(
            "SELECT id, title, file_path FROM documents WHERE case_id = $1 ORDER BY title",
            case_id,
        )
        logger.info(" %d documents", len(docs))

        per_doc: list[dict] = []
        for d in docs:
            r = await _backfill_document(
                UUID(str(d["id"])), d["title"], d["file_path"], force,
            )
            per_doc.append({"document_id": str(d["id"]), "title": d["title"], **r})

        statuses = [r["status"] for r in per_doc]
        summary[cn] = {
            "documents_total": len(docs),
            "ok": statuses.count("ok"),
            "skipped": statuses.count("skipped"),
            "missing": statuses.count("missing"),
            "no_chunks": statuses.count("no_chunks"),
            "no_extracted_text": statuses.count("no_extracted_text"),
            # Non-PDF documents get their own bucket so the per-status
            # counts add up to documents_total.
            "not_pdf": statuses.count("not_pdf"),
            "chunks_updated": sum(r.get("chunks_updated", 0) for r in per_doc),
            "documents": per_doc,
        }
    return summary
def main():
    """Parse the command line, run the backfill and print a per-case summary."""
    cli = argparse.ArgumentParser(description="Backfill page_number on existing chunks (no OCR)")
    cli.add_argument("cases", nargs="+", help="Case numbers (e.g. 8174-24 8137-24)")
    cli.add_argument(
        "--force",
        action="store_true",
        help="Re-process even if all chunks already have page_number (default: skip)",
    )
    opts = cli.parse_args()

    report = asyncio.run(backfill_cases(opts.cases, force=opts.force))

    banner = "=" * 60
    print()
    print(banner)
    print("SUMMARY")
    print(banner)
    for case_number, stats in report.items():
        if stats.get("status") == "case_not_found":
            print(f" {case_number}: NOT FOUND")
        else:
            print(
                f" {case_number}: {stats['documents_total']} docs — "
                f"ok {stats['ok']}, skipped {stats['skipped']}, "
                f"missing {stats['missing']}, chunks_updated {stats['chunks_updated']}"
            )


if __name__ == "__main__":
    main()