fix(retrieval): rewrite chunk-page retrofit to skip OCR
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 16s

The first-pass retrofit re-extracted via extractor.extract_text, which
re-runs Google Vision OCR on scanned pages. OCR is non-deterministic,
so the new text didn't match the chunk content stored in the DB
(produced by the original OCR run) — only ~7% of chunks were located.

New approach (no OCR cost):

1. Use the stored documents.extracted_text from the DB — the exact
   text the chunks were produced from, so chunk lookups match.
2. Anchor page boundaries via PyMuPDF direct text reads (free, no
   OCR). Pages with usable direct text are anchored by snippet match;
   OCR-only pages are linearly interpolated between anchors.
3. Search each chunk in extracted_text using a whitespace-tolerant
   helper — needed because the chunker joins paragraphs with single
   '\\n' while extracted_text uses '\\n\\n' as page separators.

Verified on 8174-24 (5 docs, 307 chunks) + 8137-24 (9 docs, 512
chunks): 100% chunks tagged, 13s total, $0 cost.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-03 20:04:33 +00:00
parent 81ccf3a888
commit 8a815ecff5

View File

@@ -1,36 +1,42 @@
"""Backfill page_number on existing document_chunks. """Backfill page_number on existing document_chunks (no re-OCR).
Why this exists: the legacy chunker did not track which page each chunk Why this exists: the legacy chunker did not track which page each chunk
came from. After the page-tracking fix, new uploads carry page_number came from. After the page-tracking fix, new uploads carry page_number
correctly, but existing chunks have ``page_number=NULL`` in the DB. correctly, but existing chunks have ``page_number=NULL`` in the DB.
That blocks the multimodal hybrid retriever's text+image boost (it That blocks the multimodal hybrid retriever's text+image boost (which
joins (chunk, image) on (document_id, page_number)). joins (chunk, image) on (document_id, page_number)).
What it does (per case): What it does (per case, per document):
1. List every document in the case
2. For each document with NULL page_number chunks: 1. Load stored ``documents.extracted_text`` from the DB. This is
a. Re-extract via extractor.extract_text (re-runs OCR if needed the exact text that was used to produce the existing chunks
~$0.0015/page on Google Vision; idempotent on the DB side) so chunk content lookups against it match verbatim.
b. Compute page_offsets from the re-extracted text 2. Open the PDF with PyMuPDF and call ``page.get_text()`` on each
c. For every chunk row (sorted by chunk_index), search its page (cheap, no OCR). For pages with usable direct text we get
content in the re-extracted text → look up page → UPDATE a clean snippet; for fully-scanned pages we get little/nothing.
3. Skip documents whose chunks already have non-null page_number 3. Anchor: for each page with a usable snippet, search the snippet
in ``extracted_text`` to recover that page's start offset.
4. Interpolate: for OCR-only pages with no anchor, position is
linearly interpolated between the nearest anchored neighbors
(or uniformly when no anchors exist at all).
5. For every chunk row (sorted by chunk_index), find the chunk's
content in ``extracted_text`` (verbatim match), look up the
page from the offsets, and ``UPDATE document_chunks SET
page_number = ?``.
Idempotent: a second run with no --force is a no-op. Idempotent: a second run with no --force is a no-op.
Designed to run from inside the FastAPI/MCP container (where /data Cost: zero. Runs in seconds even for the 89-page appraisal report.
is mounted and Google Vision creds are present). Locally it requires
GOOGLE_CLOUD_VISION_API_KEY in ~/.env.
Usage: Usage:
docker exec -it <legal-ai-container> python /tmp/backfill_chunk_pages.py 8174-24 8137-24 docker cp scripts/backfill_chunk_pages.py <c>:/tmp/
docker exec <c> python /tmp/backfill_chunk_pages.py 8174-24 8137-24
""" """
from __future__ import annotations from __future__ import annotations
import argparse import argparse
import asyncio import asyncio
import logging import logging
import os
import sys import sys
import time import time
from pathlib import Path from pathlib import Path
@@ -45,7 +51,8 @@ def _setup_paths():
_setup_paths() _setup_paths()
from legal_mcp.services import db, extractor # noqa: E402 import fitz # PyMuPDF # noqa: E402
from legal_mcp.services import db # noqa: E402
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
@@ -54,6 +61,14 @@ logging.basicConfig(
logger = logging.getLogger("backfill_chunk_pages") logger = logging.getLogger("backfill_chunk_pages")
# Snippet length for page anchoring. Long enough to be unique, short
# enough to survive minor whitespace variation between PyMuPDF direct
# extraction and the stored OCR text.
ANCHOR_SNIPPET_LEN = 80
# Minimum direct-text length on a page to attempt anchoring at all.
MIN_DIRECT_LEN = 60
def _resolve_local_path(db_path: str) -> Path: def _resolve_local_path(db_path: str) -> Path:
p = Path(db_path) p = Path(db_path)
if p.is_file(): if p.is_file():
@@ -65,6 +80,123 @@ def _resolve_local_path(db_path: str) -> Path:
return p return p
def _norm_whitespace(s: str) -> str:
"""Collapse runs of whitespace; helps cross-source matching where
PyMuPDF direct extraction may differ from the stored OCR text in
line-break placement."""
return " ".join(s.split())
def _find_anchored_snippet(
extracted_text: str, snippet: str, search_start: int = 0,
) -> int:
"""Search for ``snippet`` in ``extracted_text``, tolerant to
whitespace differences. Returns the offset in the original
extracted_text, or -1."""
# Direct match first — fastest path
idx = extracted_text.find(snippet, search_start)
if idx >= 0:
return idx
# Whitespace-normalized fallback
norm_text = _norm_whitespace(extracted_text)
norm_snip = _norm_whitespace(snippet)
if not norm_snip:
return -1
norm_idx = norm_text.find(norm_snip)
if norm_idx < 0:
return -1
# Map norm offset back to original — count chars until we've passed
# `norm_idx` non-collapsed characters in the original.
orig_pos = 0
norm_pos = 0
in_ws = False
for ch in extracted_text:
if norm_pos == norm_idx:
return orig_pos
if ch.isspace():
if not in_ws:
norm_pos += 1
in_ws = True
else:
in_ws = False
norm_pos += 1
orig_pos += 1
return -1
def _compute_page_offsets(pdf_path: Path, extracted_text: str) -> list[int]:
"""Return ``page_offsets`` (start char offset of each page in
``extracted_text``), using direct PyMuPDF reads for anchoring and
linear interpolation for OCR-only pages."""
doc = fitz.open(str(pdf_path))
n_pages = len(doc)
anchors: list[int | None] = [None] * n_pages
last_pos = 0
for i, page in enumerate(doc):
direct = page.get_text().strip()
if len(direct) < MIN_DIRECT_LEN:
continue
# Take the first ANCHOR_SNIPPET_LEN chars after stripping
snippet = direct[:ANCHOR_SNIPPET_LEN]
pos = _find_anchored_snippet(extracted_text, snippet, last_pos)
if pos < 0:
# try a global search before giving up
pos = _find_anchored_snippet(extracted_text, snippet, 0)
if pos >= 0:
anchors[i] = pos
last_pos = pos
doc.close()
# Force first page to start at 0 if not already anchored
if anchors[0] is None:
anchors[0] = 0
# Fill gaps via linear interpolation between the nearest anchors;
# extrapolate beyond the last anchor by the average page length.
page_offsets: list[int] = [0] * n_pages
for i in range(n_pages):
if anchors[i] is not None:
page_offsets[i] = anchors[i]
continue
# Find prev anchored
prev_i = i - 1
while prev_i >= 0 and anchors[prev_i] is None:
prev_i -= 1
# Find next anchored
next_i = i + 1
while next_i < n_pages and anchors[next_i] is None:
next_i += 1
prev_pos = anchors[prev_i] if prev_i >= 0 else 0
if next_i < n_pages:
next_pos = anchors[next_i]
ratio = (i - prev_i) / (next_i - prev_i)
page_offsets[i] = int(prev_pos + ratio * (next_pos - prev_pos))
else:
# Extrapolate: assume uniform distribution beyond last anchor
# using page-density inferred from prior anchors (or fall
# back to total_text/n_pages).
avg = len(extracted_text) / max(1, n_pages)
page_offsets[i] = int(prev_pos + avg * (i - prev_i))
# Monotone-clip just in case interpolation ever goes backwards
for i in range(1, n_pages):
if page_offsets[i] < page_offsets[i - 1]:
page_offsets[i] = page_offsets[i - 1]
return page_offsets
def _page_at_offset(offset: int, page_offsets: list[int]) -> int:
if not page_offsets:
return 1
page = 1
for i, start in enumerate(page_offsets):
if start <= offset:
page = i + 1
else:
break
return page
async def _backfill_document( async def _backfill_document(
document_id: UUID, document_id: UUID,
title: str, title: str,
@@ -73,7 +205,6 @@ async def _backfill_document(
) -> dict: ) -> dict:
pool = await db.get_pool() pool = await db.get_pool()
# Fetch chunks for this document
chunks = await pool.fetch( chunks = await pool.fetch(
"SELECT id, chunk_index, content, page_number FROM document_chunks " "SELECT id, chunk_index, content, page_number FROM document_chunks "
"WHERE document_id = $1 ORDER BY chunk_index", "WHERE document_id = $1 ORDER BY chunk_index",
@@ -94,15 +225,21 @@ async def _backfill_document(
if pdf_path.suffix.lower() != ".pdf": if pdf_path.suffix.lower() != ".pdf":
return {"status": "not_pdf"} return {"status": "not_pdf"}
logger.info(" re-extracting %s (%d chunks, %d need page)", doc_row = await pool.fetchrow(
title, len(chunks), n_null) "SELECT extracted_text FROM documents WHERE id = $1", document_id,
t0 = time.time() )
text, page_count, page_offsets = await extractor.extract_text(str(pdf_path)) extracted_text = doc_row["extracted_text"] if doc_row else None
elapsed = time.time() - t0 if not extracted_text:
if not page_offsets: return {"status": "no_extracted_text"}
return {"status": "no_offsets"}
# Walk chunks, find each in the re-extracted text, assign page t0 = time.time()
page_offsets = _compute_page_offsets(pdf_path, extracted_text)
n_anchored = sum(1 for i in range(len(page_offsets)) if i == 0 or page_offsets[i] > page_offsets[i - 1])
# The chunker joins paragraphs with single `\n` while extracted_text
# has `\n\n` between pages, so verbatim search misses cross-page
# chunks. Use the whitespace-tolerant helper that returns an offset
# in the *original* text.
pos = 0 pos = 0
updated = 0 updated = 0
not_found = 0 not_found = 0
@@ -110,29 +247,34 @@ async def _backfill_document(
content = c["content"] content = c["content"]
if not content: if not content:
continue continue
idx = text.find(content, pos) # Use a unique slice from the chunk to anchor in extracted_text
# — anchoring on the chunk's first ~120 chars is enough to
# disambiguate across the document.
snippet = content[: min(len(content), 120)]
idx = _find_anchored_snippet(extracted_text, snippet, pos)
if idx < 0: if idx < 0:
idx = text.find(content) # global fallback idx = _find_anchored_snippet(extracted_text, snippet, 0)
if idx < 0: if idx < 0:
not_found += 1 not_found += 1
continue continue
page = extractor.page_at_offset(idx, page_offsets) page = _page_at_offset(idx, page_offsets)
await pool.execute( await pool.execute(
"UPDATE document_chunks SET page_number = $1 WHERE id = $2", "UPDATE document_chunks SET page_number = $1 WHERE id = $2",
page, c["id"], page, c["id"],
) )
updated += 1 updated += 1
# advance roughly past midpoint — chunks have overlap
pos = idx + max(1, len(content) // 2) pos = idx + max(1, len(content) // 2)
elapsed = time.time() - t0
logger.info( logger.info(
" done in %.1fs: extracted %d pages, updated %d/%d chunks, " " %s%d pages, %d anchors, updated %d/%d chunks (%d not found) in %.2fs",
"%d not found", elapsed, page_count, updated, len(chunks), not_found, title, len(page_offsets), n_anchored, updated, len(chunks), not_found, elapsed,
) )
return { return {
"status": "ok", "status": "ok",
"elapsed_sec": round(elapsed, 1), "elapsed_sec": round(elapsed, 2),
"pages": page_count, "pages": len(page_offsets),
"anchors": n_anchored,
"chunks_total": len(chunks), "chunks_total": len(chunks),
"chunks_updated": updated, "chunks_updated": updated,
"chunks_not_found": not_found, "chunks_not_found": not_found,
@@ -168,7 +310,7 @@ async def backfill_cases(case_numbers: list[str], force: bool) -> dict:
"skipped": sum(1 for r in per_doc if r["status"] == "skipped"), "skipped": sum(1 for r in per_doc if r["status"] == "skipped"),
"missing": sum(1 for r in per_doc if r["status"] == "missing"), "missing": sum(1 for r in per_doc if r["status"] == "missing"),
"no_chunks": sum(1 for r in per_doc if r["status"] == "no_chunks"), "no_chunks": sum(1 for r in per_doc if r["status"] == "no_chunks"),
"no_offsets": sum(1 for r in per_doc if r["status"] == "no_offsets"), "no_extracted_text": sum(1 for r in per_doc if r["status"] == "no_extracted_text"),
"chunks_updated": sum(r.get("chunks_updated", 0) for r in per_doc), "chunks_updated": sum(r.get("chunks_updated", 0) for r in per_doc),
"documents": per_doc, "documents": per_doc,
} }
@@ -176,11 +318,11 @@ async def backfill_cases(case_numbers: list[str], force: bool) -> dict:
def main(): def main():
parser = argparse.ArgumentParser(description="Backfill page_number on existing chunks") parser = argparse.ArgumentParser(description="Backfill page_number on existing chunks (no OCR)")
parser.add_argument("cases", nargs="+", help="Case numbers (e.g. 8174-24 8137-24)") parser.add_argument("cases", nargs="+", help="Case numbers (e.g. 8174-24 8137-24)")
parser.add_argument( parser.add_argument(
"--force", action="store_true", "--force", action="store_true",
help="Re-extract even if all chunks already have page_number (default: skip)", help="Re-process even if all chunks already have page_number (default: skip)",
) )
args = parser.parse_args() args = parser.parse_args()
@@ -195,8 +337,8 @@ def main():
continue continue
print( print(
f" {cn}: {s['documents_total']} docs — " f" {cn}: {s['documents_total']} docs — "
f"ok {s['ok']}, skipped {s['skipped']}, missing {s['missing']}, " f"ok {s['ok']}, skipped {s['skipped']}, "
f"chunks_updated {s['chunks_updated']}" f"missing {s['missing']}, chunks_updated {s['chunks_updated']}"
) )