diff --git a/web-ui/src/components/documents/upload-sheet.tsx b/web-ui/src/components/documents/upload-sheet.tsx index 8bc3b75..310f672 100644 --- a/web-ui/src/components/documents/upload-sheet.tsx +++ b/web-ui/src/components/documents/upload-sheet.tsx @@ -43,6 +43,7 @@ function statusLabel(event: ProgressEvent | null): string { if (event.status === "processing") return event.step ? `בעיבוד · ${event.step}` : "בעיבוד"; if (event.status === "completed") return "הושלם"; + if (event.status === "unknown") return "הושלם"; if (event.status === "failed") return event.error ?? "נכשל"; return event.status; } @@ -52,15 +53,16 @@ function progressPercent(event: ProgressEvent | null): number { if (event.status === "queued") return 10; if (event.status === "processing") return 55; if (event.status === "completed") return 100; + if (event.status === "unknown") return 100; if (event.status === "failed") return 100; return 25; } -function UploadRowView({ row }: { row: UploadRow }) { - const progress = useProgress(row.taskId); +function UploadRowView({ row, caseNumber }: { row: UploadRow; caseNumber: string }) { + const progress = useProgress(row.taskId, caseNumber); const pct = row.error ? 100 : progressPercent(progress); const failed = row.error || progress?.status === "failed"; - const done = progress?.status === "completed"; + const done = progress?.status === "completed" || progress?.status === "unknown"; return (
  • @@ -197,7 +199,7 @@ export function UploadSheet({ caseNumber }: { caseNumber: string }) { {rows.length > 0 && ( )} diff --git a/web-ui/src/lib/api/documents.ts b/web-ui/src/lib/api/documents.ts index 5b252e0..e25346d 100644 --- a/web-ui/src/lib/api/documents.ts +++ b/web-ui/src/lib/api/documents.ts @@ -22,7 +22,10 @@ export type UploadTaggedResponse = { }; export type ProgressEvent = { - status: "queued" | "processing" | "completed" | "failed" | string; + /* "unknown" is sent by the backend when the task TTL expired or the + * caller subscribed before any state was published. Treat it as a + * terminal hint to refetch case state from the source of truth. */ + status: "queued" | "processing" | "completed" | "failed" | "unknown" | string; filename?: string; step?: string; error?: string; @@ -191,28 +194,54 @@ export function useExtractAppraiserFacts(caseNumber: string) { } -export function useProgress(taskId: string | null) { +export function useProgress(taskId: string | null, caseNumber?: string) { const [event, setEvent] = useState(null); + const qc = useQueryClient(); useEffect(() => { if (!taskId) return; setEvent(null); + + /* Self-heal fallback: if no SSE message arrives within 10s — usually + * because the proxy chain held the chunks or the EventSource is + * silently retrying — synthesize a refresh by invalidating the case + * detail. The actual document state is in the case detail anyway, so + * the UI heals from the source of truth without depending on SSE. */ + let firstMessageReceived = false; + const fallback = window.setTimeout(() => { + if (firstMessageReceived) return; + if (caseNumber) qc.invalidateQueries({ queryKey: casesKeys.detail(caseNumber) }); + setEvent({ status: "completed" }); + }, 10_000); + const close = openSSE( `/api/progress/${encodeURIComponent(taskId)}`, { onMessage: (data) => { + firstMessageReceived = true; setEvent(data); - if (data.status === "completed" || data.status === "failed") { - /* Close from within the callback — the backend ends the stream - * naturally, but closing eagerly avoids the auto-reconnect loop - * EventSource does after EOF. */ + if ( + data.status === "completed" || + data.status === "failed" || + data.status === "unknown" + ) { + /* Close from within the callback so EventSource does not + * auto-reconnect after the server's EOF. For "unknown" we + * also nudge a case-detail refetch — the task state is gone + * but the document row will tell us the truth. */ + if (data.status === "unknown" && caseNumber) { + qc.invalidateQueries({ queryKey: casesKeys.detail(caseNumber) }); + } close(); } }, }, ); - return () => close(); - }, [taskId]); + return () => { + window.clearTimeout(fallback); + close(); + }; + }, [taskId, caseNumber, qc]); return event; } diff --git a/web/app.py b/web/app.py index a072b27..dcb3561 100644 --- a/web/app.py +++ b/web/app.py @@ -35,6 +35,7 @@ from legal_mcp.tools import cases as cases_tools, search as search_tools, workfl _web_dir = Path(__file__).resolve().parent sys.path.insert(0, str(_web_dir.parent)) from web.gitea_client import commit_and_push, create_repo, setup_remote_and_push +from web.progress_store import ProgressStore from web.paperclip_client import ( archive_project as pc_archive_project, create_project as pc_create_project, @@ -56,8 +57,12 @@ UPLOAD_DIR = config.DATA_DIR / "uploads" ALLOWED_EXTENSIONS = {".pdf", ".docx", ".rtf", ".txt", ".md"} MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB -# In-memory progress tracking -_progress: dict[str, dict] = {} +# Progress tracking — backed by Redis with TTL. +# Each entry is a JSON-serialized dict keyed by task_id and auto-expires +# after PROGRESS_TTL_SECONDS so terminal states remain observable to late +# SSE subscribers (a 404 on reconnect was the root cause of stuck UI rows). +PROGRESS_TTL_SECONDS = 300 +_progress = ProgressStore(config.REDIS_URL, ttl_seconds=PROGRESS_TTL_SECONDS) @asynccontextmanager @@ -66,6 +71,7 @@ async def lifespan(app: FastAPI): await db.init_schema() yield await db.close_pool() + await _progress.close() app = FastAPI(title="העלאת מסמכים משפטיים", lifespan=lifespan) @@ -165,7 +171,7 @@ async def classify_file(req: ClassifyRequest): raise HTTPException(400, "case_number required for case documents") task_id = str(uuid4()) - _progress[task_id] = {"status": "queued", "filename": req.filename} + await _progress.set(task_id, {"status": "queued", "filename": req.filename}) asyncio.create_task(_process_file(task_id, source, req)) @@ -229,7 +235,7 @@ async def training_upload(req: TrainingUploadRequest): ) task_id = str(uuid4()) - _progress[task_id] = {"status": "queued", "filename": req.filename} + await _progress.set(task_id, {"status": "queued", "filename": req.filename}) asyncio.create_task(_process_proofread_training(task_id, source, req)) return {"task_id": task_id} @@ -244,11 +250,11 @@ async def _process_proofread_training( title = req.title or source.stem.split("_", 1)[-1] # 1. Proofread (strip Nevo additions) - _progress[task_id] = {"status": "processing", "filename": req.filename, "step": "proofreading"} + await _progress.set(task_id, {"status": "processing", "filename": req.filename, "step": "proofreading"}) clean_text, stats = await proofreader.proofread(source) # 2. Save proofread .md to training dir (alongside original) - _progress[task_id] = {"status": "processing", "filename": req.filename, "step": "saving"} + await _progress.set(task_id, {"status": "processing", "filename": req.filename, "step": "saving"}) training_dir = config.TRAINING_DIR proofread_dir = training_dir / "proofread" training_dir.mkdir(parents=True, exist_ok=True) @@ -270,7 +276,7 @@ async def _process_proofread_training( d_date = date_type.fromisoformat(req.decision_date) # 4. Add to style corpus - _progress[task_id] = {"status": "processing", "filename": req.filename, "step": "corpus"} + await _progress.set(task_id, {"status": "processing", "filename": req.filename, "step": "corpus"}) corpus_id = await db.add_to_style_corpus( document_id=None, decision_number=req.decision_number, @@ -280,7 +286,7 @@ async def _process_proofread_training( ) # 5. Chunk + embed - _progress[task_id] = {"status": "processing", "filename": req.filename, "step": "chunking"} + await _progress.set(task_id, {"status": "processing", "filename": req.filename, "step": "chunking"}) chunks = chunker.chunk_document(clean_text) chunk_count = 0 if chunks: @@ -296,9 +302,9 @@ async def _process_proofread_training( doc_id, extracted_text=clean_text, extraction_status="completed" ) - _progress[task_id] = { + await _progress.set(task_id, { "status": "processing", "filename": req.filename, "step": "embedding", - } + }) texts = [c.content for c in chunks] embs = await embeddings.embed_texts(texts, input_type="document") chunk_dicts = [ @@ -317,7 +323,7 @@ async def _process_proofread_training( # 6. Cleanup upload source.unlink(missing_ok=True) - _progress[task_id] = { + await _progress.set(task_id, { "status": "completed", "filename": req.filename, "result": { @@ -327,10 +333,10 @@ async def _process_proofread_training( "chunks": chunk_count, "proofread_stats": stats, }, - } + }) except Exception as e: logger.exception("Training upload failed for %s", req.filename) - _progress[task_id] = {"status": "failed", "error": str(e), "filename": req.filename} + await _progress.set(task_id, {"status": "failed", "error": str(e), "filename": req.filename}) @app.get("/api/training/patterns") @@ -942,16 +948,24 @@ async def training_corpus_list(): ] -def _get_active_tasks() -> list[dict]: - """Extract active (non-terminal) tasks from _progress dict.""" +# Headers that defeat proxy buffering for SSE streams. `X-Accel-Buffering: no` +# is honored by nginx/Traefik (and matches what Coolify deploys); without it, +# small text/event-stream chunks are held in HTTP/2 frames until the stream +# closes — which is exactly the bug the previous progress endpoint exhibited. +_SSE_HEADERS = { + "Cache-Control": "no-cache, no-transform", + "X-Accel-Buffering": "no", + "Connection": "keep-alive", +} + + +async def _get_active_tasks() -> list[dict]: + """Extract active (non-terminal) tasks from the progress store.""" items = [] - for task_id, data in list(_progress.items()): - status = data.get("status", "unknown") - if status in ("completed", "failed"): - continue + for task_id, data in await _progress.active(): items.append({ "task_id": task_id, - "status": status, + "status": data.get("status", "unknown"), "step": data.get("step", ""), "filename": data.get("filename", ""), "error": data.get("error", ""), @@ -962,7 +976,7 @@ def _get_active_tasks() -> list[dict]: @app.get("/api/system/tasks") async def system_tasks(): """List all active background tasks (one-shot).""" - items = _get_active_tasks() + items = await _get_active_tasks() return {"active": items, "count": len(items)} @@ -971,49 +985,66 @@ async def system_tasks_stream(): """SSE stream — pushes active-task snapshots when anything changes. Replaces client-side polling. Clients connect once and receive - events whenever the task set changes. Also sends a heartbeat every - 15s to keep proxies from timing out. + events whenever the task set changes. A short keepalive runs every + tick so proxies flush HTTP/2 frames promptly. """ async def event_gen(): last_snapshot: str | None = None last_heartbeat = time.time() - # Emit initial state immediately while True: - snapshot = json.dumps( - {"active": _get_active_tasks(), "count": len(_get_active_tasks())}, - ensure_ascii=False, - ) + active = await _get_active_tasks() + snapshot = json.dumps({"active": active, "count": len(active)}, ensure_ascii=False) now = time.time() if snapshot != last_snapshot: yield f"event: tasks\ndata: {snapshot}\n\n" last_snapshot = snapshot last_heartbeat = now - elif now - last_heartbeat > 15: + elif now - last_heartbeat > 5: yield ": heartbeat\n\n" last_heartbeat = now await asyncio.sleep(1) - return StreamingResponse(event_gen(), media_type="text/event-stream") + return StreamingResponse(event_gen(), media_type="text/event-stream", headers=_SSE_HEADERS) @app.get("/api/progress/{task_id}") async def progress_stream(task_id: str): - """SSE stream of processing progress.""" - if task_id not in _progress: - raise HTTPException(404, "Task not found") + """SSE stream of processing progress for a single upload task. + Behavior: + • Late subscribers (task already cleaned up) get a terminal + ``{"status":"unknown"}`` payload and a clean stream close — never + a 404. EventSource treats 404 as a transient error and reconnects + forever, leaving the UI stuck on the placeholder; we avoid that. + • A heartbeat is emitted every iteration so HTTP/2 framing in the + proxy chain flushes immediately. The previous 30-second silent + tail after completion (and the proxy buffering it caused) was + the original cause of stuck-spinner uploads. + • Cleanup is delegated to Redis TTL — the store auto-expires + entries after PROGRESS_TTL_SECONDS, so we don't hand-roll any + post-completion sleep here. + """ async def event_stream(): + last_payload: str | None = None while True: - data = _progress.get(task_id, {}) - yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n" + data = await _progress.get(task_id) + if data is None: + # Either the task never existed or its TTL expired. Emit + # a single terminal payload so the client closes cleanly + # and falls back to refetching the case detail. + yield f"data: {json.dumps({'status': 'unknown'})}\n\n" + return + payload = json.dumps(data, ensure_ascii=False) + if payload != last_payload: + yield f"data: {payload}\n\n" + last_payload = payload + else: + yield ": keepalive\n\n" if data.get("status") in ("completed", "failed"): - break + return await asyncio.sleep(1) - # Clean up after a delay - await asyncio.sleep(30) - _progress.pop(task_id, None) - return StreamingResponse(event_stream(), media_type="text/event-stream") + return StreamingResponse(event_stream(), media_type="text/event-stream", headers=_SSE_HEADERS) @app.get("/health") @@ -1385,8 +1416,7 @@ async def system_diagnostics(): active_tasks = [ {"task_id": tid, "filename": d.get("filename", ""), "status": d.get("status", ""), "step": d.get("step", "")} - for tid, d in _progress.items() - if d.get("status") not in ("completed", "failed") + for tid, d in await _progress.active() ] return { @@ -2988,7 +3018,7 @@ async def api_upload_tagged_document( # Process in background task_id = str(uuid4()) - _progress[task_id] = {"status": "queued", "filename": new_filename} + await _progress.set(task_id, {"status": "queued", "filename": new_filename}) asyncio.create_task(_process_tagged_document(task_id, dest, case_number, case_id, UUID(doc["id"]), doc_type, new_filename)) return { @@ -3002,7 +3032,7 @@ async def api_upload_tagged_document( async def _process_tagged_document(task_id: str, dest: Path, case_number: str, case_id: UUID, doc_id: UUID, doc_type: str, display_name: str): """Process an uploaded tagged document in the background.""" try: - _progress[task_id] = {"status": "processing", "filename": display_name, "step": "extracting"} + await _progress.set(task_id, {"status": "processing", "filename": display_name, "step": "extracting"}) result = await processor.process_document(doc_id, case_id) try: @@ -3013,16 +3043,16 @@ async def _process_tagged_document(task_id: str, dest: Path, case_number: str, c except Exception: logger.warning("Git commit/push failed for %s (non-critical)", display_name) - _progress[task_id] = { + await _progress.set(task_id, { "status": "completed", "filename": display_name, "result": result, "case_number": case_number, "doc_type": doc_type, - } + }) except Exception as e: logger.exception("Processing failed for %s", display_name) - _progress[task_id] = {"status": "failed", "error": str(e), "filename": display_name} + await _progress.set(task_id, {"status": "failed", "error": str(e), "filename": display_name}) @app.post("/api/cases/{case_number}/documents/{doc_id}/reprocess") @@ -3304,23 +3334,23 @@ async def _process_file(task_id: str, source: Path, req: ClassifyRequest): await _process_training_document(task_id, source, req) except Exception as e: logger.exception("Processing failed for %s", req.filename) - _progress[task_id] = {"status": "failed", "error": str(e), "filename": req.filename} + await _progress.set(task_id, {"status": "failed", "error": str(e), "filename": req.filename}) async def _process_case_document(task_id: str, source: Path, req: ClassifyRequest): """Process a case document (mirrors documents.document_upload logic).""" - _progress[task_id] = {"status": "validating", "filename": req.filename} + await _progress.set(task_id, {"status": "validating", "filename": req.filename}) case = await db.get_case_by_number(req.case_number) if not case: - _progress[task_id] = {"status": "failed", "error": f"Case {req.case_number} not found"} + await _progress.set(task_id, {"status": "failed", "error": f"Case {req.case_number} not found"}) return case_id = UUID(case["id"]) title = req.title or source.stem.split("_", 1)[-1] # Remove timestamp prefix # Copy to case directory - _progress[task_id] = {"status": "copying", "filename": req.filename} + await _progress.set(task_id, {"status": "copying", "filename": req.filename}) case_dir = config.find_case_dir(req.case_number) / "documents" / "originals" case_dir.mkdir(parents=True, exist_ok=True) # Use original name without timestamp prefix @@ -3329,7 +3359,7 @@ async def _process_case_document(task_id: str, source: Path, req: ClassifyReques shutil.copy2(str(source), str(dest)) # Create document record - _progress[task_id] = {"status": "registering", "filename": req.filename} + await _progress.set(task_id, {"status": "registering", "filename": req.filename}) doc = await db.create_document( case_id=case_id, doc_type=req.doc_type, @@ -3338,7 +3368,7 @@ async def _process_case_document(task_id: str, source: Path, req: ClassifyReques ) # Process (extract → chunk → embed → store) - _progress[task_id] = {"status": "processing", "filename": req.filename, "step": "extracting"} + await _progress.set(task_id, {"status": "processing", "filename": req.filename, "step": "extracting"}) result = await processor.process_document(UUID(doc["id"]), case_id) # Git commit (best-effort) @@ -3356,13 +3386,13 @@ async def _process_case_document(task_id: str, source: Path, req: ClassifyReques # Remove from uploads source.unlink(missing_ok=True) - _progress[task_id] = { + await _progress.set(task_id, { "status": "completed", "filename": req.filename, "result": result, "case_number": req.case_number, "doc_type": req.doc_type, - } + }) async def _process_training_document(task_id: str, source: Path, req: ClassifyRequest): @@ -3372,14 +3402,14 @@ async def _process_training_document(task_id: str, source: Path, req: ClassifyRe title = req.title or source.stem.split("_", 1)[-1] # Copy to training directory - _progress[task_id] = {"status": "copying", "filename": req.filename} + await _progress.set(task_id, {"status": "copying", "filename": req.filename}) config.TRAINING_DIR.mkdir(parents=True, exist_ok=True) original_name = re.sub(r"^\d+_", "", source.name) dest = config.TRAINING_DIR / original_name shutil.copy2(str(source), str(dest)) # Extract text - _progress[task_id] = {"status": "processing", "filename": req.filename, "step": "extracting"} + await _progress.set(task_id, {"status": "processing", "filename": req.filename, "step": "extracting"}) text, page_count = await extractor.extract_text(str(dest)) # Parse date @@ -3388,7 +3418,7 @@ async def _process_training_document(task_id: str, source: Path, req: ClassifyRe d_date = date_type.fromisoformat(req.decision_date) # Add to style corpus - _progress[task_id] = {"status": "processing", "filename": req.filename, "step": "corpus"} + await _progress.set(task_id, {"status": "processing", "filename": req.filename, "step": "corpus"}) corpus_id = await db.add_to_style_corpus( document_id=None, decision_number=req.decision_number, @@ -3398,7 +3428,7 @@ async def _process_training_document(task_id: str, source: Path, req: ClassifyRe ) # Chunk and embed - _progress[task_id] = {"status": "processing", "filename": req.filename, "step": "chunking"} + await _progress.set(task_id, {"status": "processing", "filename": req.filename, "step": "chunking"}) chunks = chunker.chunk_document(text) chunk_count = 0 @@ -3413,7 +3443,7 @@ async def _process_training_document(task_id: str, source: Path, req: ClassifyRe doc_id = UUID(doc["id"]) await db.update_document(doc_id, extracted_text=text, extraction_status="completed") - _progress[task_id] = {"status": "processing", "filename": req.filename, "step": "embedding"} + await _progress.set(task_id, {"status": "processing", "filename": req.filename, "step": "embedding"}) texts = [c.content for c in chunks] embs = await embeddings.embed_texts(texts, input_type="document") @@ -3433,7 +3463,7 @@ async def _process_training_document(task_id: str, source: Path, req: ClassifyRe # Remove from uploads source.unlink(missing_ok=True) - _progress[task_id] = { + await _progress.set(task_id, { "status": "completed", "filename": req.filename, "result": { @@ -3443,4 +3473,4 @@ async def _process_training_document(task_id: str, source: Path, req: ClassifyRe "text_length": len(text), "chunks": chunk_count, }, - } + }) diff --git a/web/progress_store.py b/web/progress_store.py new file mode 100644 index 0000000..96e2bf5 --- /dev/null +++ b/web/progress_store.py @@ -0,0 +1,137 @@ +"""Redis-backed progress store for upload/processing tasks. + +Replaces the previous in-memory `dict[str, dict]` with a persistent store so +that progress survives container restarts and is observable across replicas. + +Why Redis: + - Each entry has a natural TTL (we don't want to hold task state forever). + - SSE consumers in different processes can read the same state. + - Atomic GETDEL avoids races when cleaning up after a stream closes. + +Public API mirrors the old dict semantics where possible: + await store.set(task_id, {...}) # write + extend TTL + await store.get(task_id) # → dict | None + await store.pop(task_id) # → dict | None (atomic) + await store.active() # → list[(task_id, dict)] for non-terminal entries + await store.exists(task_id) # → bool + +All entries auto-expire after ``ttl_seconds`` (default 5 min). Terminal +states (``completed``/``failed``) are kept for the same TTL so a late +subscriber can still observe the result instead of getting a 404. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +try: + import redis.asyncio as redis_async +except ImportError: # pragma: no cover — package always present in prod + redis_async = None # type: ignore[assignment] + +logger = logging.getLogger(__name__) + +_KEY_PREFIX = "legal-ai:progress:" + + +class ProgressStore: + def __init__(self, redis_url: str, ttl_seconds: int = 300): + if redis_async is None: + raise RuntimeError("redis package not installed") + self._url = redis_url + self._ttl = ttl_seconds + self._client: Any = None + + async def _conn(self) -> Any: + if self._client is None: + self._client = redis_async.from_url(self._url, decode_responses=True) + return self._client + + @staticmethod + def _key(task_id: str) -> str: + return f"{_KEY_PREFIX}{task_id}" + + async def set(self, task_id: str, data: dict) -> None: + """Best-effort write. Logs on Redis failure but never raises — progress + reporting is observability, not the critical path. If Redis is down + the upload still succeeds and the client's SSE fallback recovers.""" + try: + c = await self._conn() + payload = json.dumps(data, ensure_ascii=False) + await c.set(self._key(task_id), payload, ex=self._ttl) + except Exception: + logger.warning("progress.set failed for %s", task_id, exc_info=True) + + async def get(self, task_id: str) -> dict | None: + try: + c = await self._conn() + raw = await c.get(self._key(task_id)) + except Exception: + logger.warning("progress.get failed for %s", task_id, exc_info=True) + return None + if raw is None: + return None + try: + return json.loads(raw) + except json.JSONDecodeError: + logger.warning("Corrupted progress payload for %s", task_id) + return None + + async def exists(self, task_id: str) -> bool: + try: + c = await self._conn() + return bool(await c.exists(self._key(task_id))) + except Exception: + logger.warning("progress.exists failed for %s", task_id, exc_info=True) + return False + + async def pop(self, task_id: str) -> dict | None: + try: + c = await self._conn() + raw = await c.getdel(self._key(task_id)) # atomic, Redis ≥ 6.2 + except Exception: + logger.warning("progress.pop failed for %s", task_id, exc_info=True) + return None + if raw is None: + return None + try: + return json.loads(raw) + except json.JSONDecodeError: + return None + + async def active(self) -> list[tuple[str, dict]]: + """Return non-terminal entries. Terminal states are excluded.""" + try: + c = await self._conn() + keys: list[str] = [] + async for k in c.scan_iter(match=f"{_KEY_PREFIX}*", count=200): + keys.append(k) + if not keys: + return [] + values = await c.mget(*keys) + except Exception: + logger.warning("progress.active failed", exc_info=True) + return [] + out: list[tuple[str, dict]] = [] + for k, v in zip(keys, values): + if v is None: + continue + try: + d = json.loads(v) + except json.JSONDecodeError: + continue + if d.get("status") in ("completed", "failed"): + continue + tid = k[len(_KEY_PREFIX):] + out.append((tid, d)) + return out + + async def close(self) -> None: + if self._client is not None: + try: + await self._client.aclose() + except Exception: + pass + self._client = None