Upload progress: Redis-backed store + flushed SSE + client fallback
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 3m24s
All checks were successful
Build & Deploy / build-and-deploy (push) Successful in 3m24s
The previous in-memory _progress dict and polling SSE handler left a 30s silent
tail after completion, after which the task entry was popped. HTTP/2 framing in
the proxy chain (Traefik) buffered the small chunks until the stream closed, so
when a transient blip caused EventSource to reconnect after the entry was gone,
the server returned 404 and the UI got stuck on the "מתחיל…" ("Starting…")
placeholder forever. Reproduced live: 445 bytes were withheld for 31s.
Changes:
• web/progress_store.py — ProgressStore wraps Redis with a 5-minute TTL, atomic
GETDEL, and a dict-like API. Best-effort: Redis errors are logged and swallowed
so observability outages don't break uploads.
• web/app.py — _progress is now Redis-backed; every set/get/active/pop is
awaited. The SSE handler emits a keepalive on each tick where the payload is
unchanged (forces an HTTP/2 flush), drops the 30s post-completion sleep, and
returns a terminal {"status":"unknown"} payload instead of 404 when the task is
gone — so EventSource closes cleanly instead of reconnect-looping. The new
_SSE_HEADERS dict sets X-Accel-Buffering: no on both SSE responses.
• web-ui useProgress(taskId, caseNumber) — 10s fallback that invalidates
the case detail and synthesizes a "completed" event if no SSE message has
arrived; treats "unknown" as terminal and triggers a refetch from the source
of truth.
• upload-sheet wires caseNumber through and renders "unknown" as completed.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -43,6 +43,7 @@ function statusLabel(event: ProgressEvent | null): string {
|
|||||||
if (event.status === "processing")
|
if (event.status === "processing")
|
||||||
return event.step ? `בעיבוד · ${event.step}` : "בעיבוד";
|
return event.step ? `בעיבוד · ${event.step}` : "בעיבוד";
|
||||||
if (event.status === "completed") return "הושלם";
|
if (event.status === "completed") return "הושלם";
|
||||||
|
if (event.status === "unknown") return "הושלם";
|
||||||
if (event.status === "failed") return event.error ?? "נכשל";
|
if (event.status === "failed") return event.error ?? "נכשל";
|
||||||
return event.status;
|
return event.status;
|
||||||
}
|
}
|
||||||
@@ -52,15 +53,16 @@ function progressPercent(event: ProgressEvent | null): number {
|
|||||||
if (event.status === "queued") return 10;
|
if (event.status === "queued") return 10;
|
||||||
if (event.status === "processing") return 55;
|
if (event.status === "processing") return 55;
|
||||||
if (event.status === "completed") return 100;
|
if (event.status === "completed") return 100;
|
||||||
|
if (event.status === "unknown") return 100;
|
||||||
if (event.status === "failed") return 100;
|
if (event.status === "failed") return 100;
|
||||||
return 25;
|
return 25;
|
||||||
}
|
}
|
||||||
|
|
||||||
function UploadRowView({ row }: { row: UploadRow }) {
|
function UploadRowView({ row, caseNumber }: { row: UploadRow; caseNumber: string }) {
|
||||||
const progress = useProgress(row.taskId);
|
const progress = useProgress(row.taskId, caseNumber);
|
||||||
const pct = row.error ? 100 : progressPercent(progress);
|
const pct = row.error ? 100 : progressPercent(progress);
|
||||||
const failed = row.error || progress?.status === "failed";
|
const failed = row.error || progress?.status === "failed";
|
||||||
const done = progress?.status === "completed";
|
const done = progress?.status === "completed" || progress?.status === "unknown";
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<li className="rounded-lg border border-rule bg-parchment/40 px-4 py-3 space-y-2">
|
<li className="rounded-lg border border-rule bg-parchment/40 px-4 py-3 space-y-2">
|
||||||
@@ -197,7 +199,7 @@ export function UploadSheet({ caseNumber }: { caseNumber: string }) {
|
|||||||
{rows.length > 0 && (
|
{rows.length > 0 && (
|
||||||
<ul className="space-y-2">
|
<ul className="space-y-2">
|
||||||
{rows.map((row) => (
|
{rows.map((row) => (
|
||||||
<UploadRowView key={row.id} row={row} />
|
<UploadRowView key={row.id} row={row} caseNumber={caseNumber} />
|
||||||
))}
|
))}
|
||||||
</ul>
|
</ul>
|
||||||
)}
|
)}
|
||||||
|
|||||||
@@ -22,7 +22,10 @@ export type UploadTaggedResponse = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
export type ProgressEvent = {
|
export type ProgressEvent = {
|
||||||
status: "queued" | "processing" | "completed" | "failed" | string;
|
/* "unknown" is sent by the backend when the task TTL expired or the
|
||||||
|
* caller subscribed before any state was published. Treat it as a
|
||||||
|
* terminal hint to refetch case state from the source of truth. */
|
||||||
|
status: "queued" | "processing" | "completed" | "failed" | "unknown" | string;
|
||||||
filename?: string;
|
filename?: string;
|
||||||
step?: string;
|
step?: string;
|
||||||
error?: string;
|
error?: string;
|
||||||
@@ -191,28 +194,54 @@ export function useExtractAppraiserFacts(caseNumber: string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
export function useProgress(taskId: string | null) {
|
export function useProgress(taskId: string | null, caseNumber?: string) {
|
||||||
const [event, setEvent] = useState<ProgressEvent | null>(null);
|
const [event, setEvent] = useState<ProgressEvent | null>(null);
|
||||||
|
const qc = useQueryClient();
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!taskId) return;
|
if (!taskId) return;
|
||||||
setEvent(null);
|
setEvent(null);
|
||||||
|
|
||||||
|
/* Self-heal fallback: if no SSE message arrives within 10s — usually
|
||||||
|
* because the proxy chain held the chunks or the EventSource is
|
||||||
|
* silently retrying — synthesize a refresh by invalidating the case
|
||||||
|
* detail. The actual document state is in the case detail anyway, so
|
||||||
|
* the UI heals from the source of truth without depending on SSE. */
|
||||||
|
let firstMessageReceived = false;
|
||||||
|
const fallback = window.setTimeout(() => {
|
||||||
|
if (firstMessageReceived) return;
|
||||||
|
if (caseNumber) qc.invalidateQueries({ queryKey: casesKeys.detail(caseNumber) });
|
||||||
|
setEvent({ status: "completed" });
|
||||||
|
}, 10_000);
|
||||||
|
|
||||||
const close = openSSE<ProgressEvent>(
|
const close = openSSE<ProgressEvent>(
|
||||||
`/api/progress/${encodeURIComponent(taskId)}`,
|
`/api/progress/${encodeURIComponent(taskId)}`,
|
||||||
{
|
{
|
||||||
onMessage: (data) => {
|
onMessage: (data) => {
|
||||||
|
firstMessageReceived = true;
|
||||||
setEvent(data);
|
setEvent(data);
|
||||||
if (data.status === "completed" || data.status === "failed") {
|
if (
|
||||||
/* Close from within the callback — the backend ends the stream
|
data.status === "completed" ||
|
||||||
* naturally, but closing eagerly avoids the auto-reconnect loop
|
data.status === "failed" ||
|
||||||
* EventSource does after EOF. */
|
data.status === "unknown"
|
||||||
|
) {
|
||||||
|
/* Close from within the callback so EventSource does not
|
||||||
|
* auto-reconnect after the server's EOF. For "unknown" we
|
||||||
|
* also nudge a case-detail refetch — the task state is gone
|
||||||
|
* but the document row will tell us the truth. */
|
||||||
|
if (data.status === "unknown" && caseNumber) {
|
||||||
|
qc.invalidateQueries({ queryKey: casesKeys.detail(caseNumber) });
|
||||||
|
}
|
||||||
close();
|
close();
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
return () => close();
|
return () => {
|
||||||
}, [taskId]);
|
window.clearTimeout(fallback);
|
||||||
|
close();
|
||||||
|
};
|
||||||
|
}, [taskId, caseNumber, qc]);
|
||||||
|
|
||||||
return event;
|
return event;
|
||||||
}
|
}
|
||||||
|
|||||||
154
web/app.py
154
web/app.py
@@ -35,6 +35,7 @@ from legal_mcp.tools import cases as cases_tools, search as search_tools, workfl
|
|||||||
_web_dir = Path(__file__).resolve().parent
|
_web_dir = Path(__file__).resolve().parent
|
||||||
sys.path.insert(0, str(_web_dir.parent))
|
sys.path.insert(0, str(_web_dir.parent))
|
||||||
from web.gitea_client import commit_and_push, create_repo, setup_remote_and_push
|
from web.gitea_client import commit_and_push, create_repo, setup_remote_and_push
|
||||||
|
from web.progress_store import ProgressStore
|
||||||
from web.paperclip_client import (
|
from web.paperclip_client import (
|
||||||
archive_project as pc_archive_project,
|
archive_project as pc_archive_project,
|
||||||
create_project as pc_create_project,
|
create_project as pc_create_project,
|
||||||
@@ -56,8 +57,12 @@ UPLOAD_DIR = config.DATA_DIR / "uploads"
|
|||||||
ALLOWED_EXTENSIONS = {".pdf", ".docx", ".rtf", ".txt", ".md"}
|
ALLOWED_EXTENSIONS = {".pdf", ".docx", ".rtf", ".txt", ".md"}
|
||||||
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
|
MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
|
||||||
|
|
||||||
# In-memory progress tracking
|
# Progress tracking — backed by Redis with TTL.
|
||||||
_progress: dict[str, dict] = {}
|
# Each entry is a JSON-serialized dict keyed by task_id and auto-expires
|
||||||
|
# after PROGRESS_TTL_SECONDS so terminal states remain observable to late
|
||||||
|
# SSE subscribers (a 404 on reconnect was the root cause of stuck UI rows).
|
||||||
|
PROGRESS_TTL_SECONDS = 300
|
||||||
|
_progress = ProgressStore(config.REDIS_URL, ttl_seconds=PROGRESS_TTL_SECONDS)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@@ -66,6 +71,7 @@ async def lifespan(app: FastAPI):
|
|||||||
await db.init_schema()
|
await db.init_schema()
|
||||||
yield
|
yield
|
||||||
await db.close_pool()
|
await db.close_pool()
|
||||||
|
await _progress.close()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="העלאת מסמכים משפטיים", lifespan=lifespan)
|
app = FastAPI(title="העלאת מסמכים משפטיים", lifespan=lifespan)
|
||||||
@@ -165,7 +171,7 @@ async def classify_file(req: ClassifyRequest):
|
|||||||
raise HTTPException(400, "case_number required for case documents")
|
raise HTTPException(400, "case_number required for case documents")
|
||||||
|
|
||||||
task_id = str(uuid4())
|
task_id = str(uuid4())
|
||||||
_progress[task_id] = {"status": "queued", "filename": req.filename}
|
await _progress.set(task_id, {"status": "queued", "filename": req.filename})
|
||||||
|
|
||||||
asyncio.create_task(_process_file(task_id, source, req))
|
asyncio.create_task(_process_file(task_id, source, req))
|
||||||
|
|
||||||
@@ -229,7 +235,7 @@ async def training_upload(req: TrainingUploadRequest):
|
|||||||
)
|
)
|
||||||
|
|
||||||
task_id = str(uuid4())
|
task_id = str(uuid4())
|
||||||
_progress[task_id] = {"status": "queued", "filename": req.filename}
|
await _progress.set(task_id, {"status": "queued", "filename": req.filename})
|
||||||
asyncio.create_task(_process_proofread_training(task_id, source, req))
|
asyncio.create_task(_process_proofread_training(task_id, source, req))
|
||||||
return {"task_id": task_id}
|
return {"task_id": task_id}
|
||||||
|
|
||||||
@@ -244,11 +250,11 @@ async def _process_proofread_training(
|
|||||||
title = req.title or source.stem.split("_", 1)[-1]
|
title = req.title or source.stem.split("_", 1)[-1]
|
||||||
|
|
||||||
# 1. Proofread (strip Nevo additions)
|
# 1. Proofread (strip Nevo additions)
|
||||||
_progress[task_id] = {"status": "processing", "filename": req.filename, "step": "proofreading"}
|
await _progress.set(task_id, {"status": "processing", "filename": req.filename, "step": "proofreading"})
|
||||||
clean_text, stats = await proofreader.proofread(source)
|
clean_text, stats = await proofreader.proofread(source)
|
||||||
|
|
||||||
# 2. Save proofread .md to training dir (alongside original)
|
# 2. Save proofread .md to training dir (alongside original)
|
||||||
_progress[task_id] = {"status": "processing", "filename": req.filename, "step": "saving"}
|
await _progress.set(task_id, {"status": "processing", "filename": req.filename, "step": "saving"})
|
||||||
training_dir = config.TRAINING_DIR
|
training_dir = config.TRAINING_DIR
|
||||||
proofread_dir = training_dir / "proofread"
|
proofread_dir = training_dir / "proofread"
|
||||||
training_dir.mkdir(parents=True, exist_ok=True)
|
training_dir.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -270,7 +276,7 @@ async def _process_proofread_training(
|
|||||||
d_date = date_type.fromisoformat(req.decision_date)
|
d_date = date_type.fromisoformat(req.decision_date)
|
||||||
|
|
||||||
# 4. Add to style corpus
|
# 4. Add to style corpus
|
||||||
_progress[task_id] = {"status": "processing", "filename": req.filename, "step": "corpus"}
|
await _progress.set(task_id, {"status": "processing", "filename": req.filename, "step": "corpus"})
|
||||||
corpus_id = await db.add_to_style_corpus(
|
corpus_id = await db.add_to_style_corpus(
|
||||||
document_id=None,
|
document_id=None,
|
||||||
decision_number=req.decision_number,
|
decision_number=req.decision_number,
|
||||||
@@ -280,7 +286,7 @@ async def _process_proofread_training(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 5. Chunk + embed
|
# 5. Chunk + embed
|
||||||
_progress[task_id] = {"status": "processing", "filename": req.filename, "step": "chunking"}
|
await _progress.set(task_id, {"status": "processing", "filename": req.filename, "step": "chunking"})
|
||||||
chunks = chunker.chunk_document(clean_text)
|
chunks = chunker.chunk_document(clean_text)
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
if chunks:
|
if chunks:
|
||||||
@@ -296,9 +302,9 @@ async def _process_proofread_training(
|
|||||||
doc_id, extracted_text=clean_text, extraction_status="completed"
|
doc_id, extracted_text=clean_text, extraction_status="completed"
|
||||||
)
|
)
|
||||||
|
|
||||||
_progress[task_id] = {
|
await _progress.set(task_id, {
|
||||||
"status": "processing", "filename": req.filename, "step": "embedding",
|
"status": "processing", "filename": req.filename, "step": "embedding",
|
||||||
}
|
})
|
||||||
texts = [c.content for c in chunks]
|
texts = [c.content for c in chunks]
|
||||||
embs = await embeddings.embed_texts(texts, input_type="document")
|
embs = await embeddings.embed_texts(texts, input_type="document")
|
||||||
chunk_dicts = [
|
chunk_dicts = [
|
||||||
@@ -317,7 +323,7 @@ async def _process_proofread_training(
|
|||||||
# 6. Cleanup upload
|
# 6. Cleanup upload
|
||||||
source.unlink(missing_ok=True)
|
source.unlink(missing_ok=True)
|
||||||
|
|
||||||
_progress[task_id] = {
|
await _progress.set(task_id, {
|
||||||
"status": "completed",
|
"status": "completed",
|
||||||
"filename": req.filename,
|
"filename": req.filename,
|
||||||
"result": {
|
"result": {
|
||||||
@@ -327,10 +333,10 @@ async def _process_proofread_training(
|
|||||||
"chunks": chunk_count,
|
"chunks": chunk_count,
|
||||||
"proofread_stats": stats,
|
"proofread_stats": stats,
|
||||||
},
|
},
|
||||||
}
|
})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Training upload failed for %s", req.filename)
|
logger.exception("Training upload failed for %s", req.filename)
|
||||||
_progress[task_id] = {"status": "failed", "error": str(e), "filename": req.filename}
|
await _progress.set(task_id, {"status": "failed", "error": str(e), "filename": req.filename})
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/training/patterns")
|
@app.get("/api/training/patterns")
|
||||||
@@ -942,16 +948,24 @@ async def training_corpus_list():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _get_active_tasks() -> list[dict]:
|
# Headers that defeat proxy buffering for SSE streams. `X-Accel-Buffering: no`
|
||||||
"""Extract active (non-terminal) tasks from _progress dict."""
|
# is honored by nginx/Traefik (and matches what Coolify deploys); without it,
|
||||||
|
# small text/event-stream chunks are held in HTTP/2 frames until the stream
|
||||||
|
# closes — which is exactly the bug the previous progress endpoint exhibited.
|
||||||
|
_SSE_HEADERS = {
|
||||||
|
"Cache-Control": "no-cache, no-transform",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_active_tasks() -> list[dict]:
|
||||||
|
"""Extract active (non-terminal) tasks from the progress store."""
|
||||||
items = []
|
items = []
|
||||||
for task_id, data in list(_progress.items()):
|
for task_id, data in await _progress.active():
|
||||||
status = data.get("status", "unknown")
|
|
||||||
if status in ("completed", "failed"):
|
|
||||||
continue
|
|
||||||
items.append({
|
items.append({
|
||||||
"task_id": task_id,
|
"task_id": task_id,
|
||||||
"status": status,
|
"status": data.get("status", "unknown"),
|
||||||
"step": data.get("step", ""),
|
"step": data.get("step", ""),
|
||||||
"filename": data.get("filename", ""),
|
"filename": data.get("filename", ""),
|
||||||
"error": data.get("error", ""),
|
"error": data.get("error", ""),
|
||||||
@@ -962,7 +976,7 @@ def _get_active_tasks() -> list[dict]:
|
|||||||
@app.get("/api/system/tasks")
|
@app.get("/api/system/tasks")
|
||||||
async def system_tasks():
|
async def system_tasks():
|
||||||
"""List all active background tasks (one-shot)."""
|
"""List all active background tasks (one-shot)."""
|
||||||
items = _get_active_tasks()
|
items = await _get_active_tasks()
|
||||||
return {"active": items, "count": len(items)}
|
return {"active": items, "count": len(items)}
|
||||||
|
|
||||||
|
|
||||||
@@ -971,49 +985,66 @@ async def system_tasks_stream():
|
|||||||
"""SSE stream — pushes active-task snapshots when anything changes.
|
"""SSE stream — pushes active-task snapshots when anything changes.
|
||||||
|
|
||||||
Replaces client-side polling. Clients connect once and receive
|
Replaces client-side polling. Clients connect once and receive
|
||||||
events whenever the task set changes. Also sends a heartbeat every
|
events whenever the task set changes. A short keepalive runs every
|
||||||
15s to keep proxies from timing out.
|
tick so proxies flush HTTP/2 frames promptly.
|
||||||
"""
|
"""
|
||||||
async def event_gen():
|
async def event_gen():
|
||||||
last_snapshot: str | None = None
|
last_snapshot: str | None = None
|
||||||
last_heartbeat = time.time()
|
last_heartbeat = time.time()
|
||||||
# Emit initial state immediately
|
|
||||||
while True:
|
while True:
|
||||||
snapshot = json.dumps(
|
active = await _get_active_tasks()
|
||||||
{"active": _get_active_tasks(), "count": len(_get_active_tasks())},
|
snapshot = json.dumps({"active": active, "count": len(active)}, ensure_ascii=False)
|
||||||
ensure_ascii=False,
|
|
||||||
)
|
|
||||||
now = time.time()
|
now = time.time()
|
||||||
if snapshot != last_snapshot:
|
if snapshot != last_snapshot:
|
||||||
yield f"event: tasks\ndata: {snapshot}\n\n"
|
yield f"event: tasks\ndata: {snapshot}\n\n"
|
||||||
last_snapshot = snapshot
|
last_snapshot = snapshot
|
||||||
last_heartbeat = now
|
last_heartbeat = now
|
||||||
elif now - last_heartbeat > 15:
|
elif now - last_heartbeat > 5:
|
||||||
yield ": heartbeat\n\n"
|
yield ": heartbeat\n\n"
|
||||||
last_heartbeat = now
|
last_heartbeat = now
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
return StreamingResponse(event_gen(), media_type="text/event-stream")
|
return StreamingResponse(event_gen(), media_type="text/event-stream", headers=_SSE_HEADERS)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/progress/{task_id}")
|
@app.get("/api/progress/{task_id}")
|
||||||
async def progress_stream(task_id: str):
|
async def progress_stream(task_id: str):
|
||||||
"""SSE stream of processing progress."""
|
"""SSE stream of processing progress for a single upload task.
|
||||||
if task_id not in _progress:
|
|
||||||
raise HTTPException(404, "Task not found")
|
|
||||||
|
|
||||||
|
Behavior:
|
||||||
|
• Late subscribers (task already cleaned up) get a terminal
|
||||||
|
``{"status":"unknown"}`` payload and a clean stream close — never
|
||||||
|
a 404. EventSource treats 404 as a transient error and reconnects
|
||||||
|
forever, leaving the UI stuck on the placeholder; we avoid that.
|
||||||
|
• A heartbeat is emitted every iteration so HTTP/2 framing in the
|
||||||
|
proxy chain flushes immediately. The previous 30-second silent
|
||||||
|
tail after completion (and the proxy buffering it caused) was
|
||||||
|
the original cause of stuck-spinner uploads.
|
||||||
|
• Cleanup is delegated to Redis TTL — the store auto-expires
|
||||||
|
entries after PROGRESS_TTL_SECONDS, so we don't hand-roll any
|
||||||
|
post-completion sleep here.
|
||||||
|
"""
|
||||||
async def event_stream():
|
async def event_stream():
|
||||||
|
last_payload: str | None = None
|
||||||
while True:
|
while True:
|
||||||
data = _progress.get(task_id, {})
|
data = await _progress.get(task_id)
|
||||||
yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
if data is None:
|
||||||
|
# Either the task never existed or its TTL expired. Emit
|
||||||
|
# a single terminal payload so the client closes cleanly
|
||||||
|
# and falls back to refetching the case detail.
|
||||||
|
yield f"data: {json.dumps({'status': 'unknown'})}\n\n"
|
||||||
|
return
|
||||||
|
payload = json.dumps(data, ensure_ascii=False)
|
||||||
|
if payload != last_payload:
|
||||||
|
yield f"data: {payload}\n\n"
|
||||||
|
last_payload = payload
|
||||||
|
else:
|
||||||
|
yield ": keepalive\n\n"
|
||||||
if data.get("status") in ("completed", "failed"):
|
if data.get("status") in ("completed", "failed"):
|
||||||
break
|
return
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
# Clean up after a delay
|
|
||||||
await asyncio.sleep(30)
|
|
||||||
_progress.pop(task_id, None)
|
|
||||||
|
|
||||||
return StreamingResponse(event_stream(), media_type="text/event-stream")
|
return StreamingResponse(event_stream(), media_type="text/event-stream", headers=_SSE_HEADERS)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
@@ -1385,8 +1416,7 @@ async def system_diagnostics():
|
|||||||
active_tasks = [
|
active_tasks = [
|
||||||
{"task_id": tid, "filename": d.get("filename", ""),
|
{"task_id": tid, "filename": d.get("filename", ""),
|
||||||
"status": d.get("status", ""), "step": d.get("step", "")}
|
"status": d.get("status", ""), "step": d.get("step", "")}
|
||||||
for tid, d in _progress.items()
|
for tid, d in await _progress.active()
|
||||||
if d.get("status") not in ("completed", "failed")
|
|
||||||
]
|
]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -2988,7 +3018,7 @@ async def api_upload_tagged_document(
|
|||||||
|
|
||||||
# Process in background
|
# Process in background
|
||||||
task_id = str(uuid4())
|
task_id = str(uuid4())
|
||||||
_progress[task_id] = {"status": "queued", "filename": new_filename}
|
await _progress.set(task_id, {"status": "queued", "filename": new_filename})
|
||||||
asyncio.create_task(_process_tagged_document(task_id, dest, case_number, case_id, UUID(doc["id"]), doc_type, new_filename))
|
asyncio.create_task(_process_tagged_document(task_id, dest, case_number, case_id, UUID(doc["id"]), doc_type, new_filename))
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -3002,7 +3032,7 @@ async def api_upload_tagged_document(
|
|||||||
async def _process_tagged_document(task_id: str, dest: Path, case_number: str, case_id: UUID, doc_id: UUID, doc_type: str, display_name: str):
|
async def _process_tagged_document(task_id: str, dest: Path, case_number: str, case_id: UUID, doc_id: UUID, doc_type: str, display_name: str):
|
||||||
"""Process an uploaded tagged document in the background."""
|
"""Process an uploaded tagged document in the background."""
|
||||||
try:
|
try:
|
||||||
_progress[task_id] = {"status": "processing", "filename": display_name, "step": "extracting"}
|
await _progress.set(task_id, {"status": "processing", "filename": display_name, "step": "extracting"})
|
||||||
result = await processor.process_document(doc_id, case_id)
|
result = await processor.process_document(doc_id, case_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -3013,16 +3043,16 @@ async def _process_tagged_document(task_id: str, dest: Path, case_number: str, c
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Git commit/push failed for %s (non-critical)", display_name)
|
logger.warning("Git commit/push failed for %s (non-critical)", display_name)
|
||||||
|
|
||||||
_progress[task_id] = {
|
await _progress.set(task_id, {
|
||||||
"status": "completed",
|
"status": "completed",
|
||||||
"filename": display_name,
|
"filename": display_name,
|
||||||
"result": result,
|
"result": result,
|
||||||
"case_number": case_number,
|
"case_number": case_number,
|
||||||
"doc_type": doc_type,
|
"doc_type": doc_type,
|
||||||
}
|
})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Processing failed for %s", display_name)
|
logger.exception("Processing failed for %s", display_name)
|
||||||
_progress[task_id] = {"status": "failed", "error": str(e), "filename": display_name}
|
await _progress.set(task_id, {"status": "failed", "error": str(e), "filename": display_name})
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/cases/{case_number}/documents/{doc_id}/reprocess")
|
@app.post("/api/cases/{case_number}/documents/{doc_id}/reprocess")
|
||||||
@@ -3304,23 +3334,23 @@ async def _process_file(task_id: str, source: Path, req: ClassifyRequest):
|
|||||||
await _process_training_document(task_id, source, req)
|
await _process_training_document(task_id, source, req)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Processing failed for %s", req.filename)
|
logger.exception("Processing failed for %s", req.filename)
|
||||||
_progress[task_id] = {"status": "failed", "error": str(e), "filename": req.filename}
|
await _progress.set(task_id, {"status": "failed", "error": str(e), "filename": req.filename})
|
||||||
|
|
||||||
|
|
||||||
async def _process_case_document(task_id: str, source: Path, req: ClassifyRequest):
|
async def _process_case_document(task_id: str, source: Path, req: ClassifyRequest):
|
||||||
"""Process a case document (mirrors documents.document_upload logic)."""
|
"""Process a case document (mirrors documents.document_upload logic)."""
|
||||||
_progress[task_id] = {"status": "validating", "filename": req.filename}
|
await _progress.set(task_id, {"status": "validating", "filename": req.filename})
|
||||||
|
|
||||||
case = await db.get_case_by_number(req.case_number)
|
case = await db.get_case_by_number(req.case_number)
|
||||||
if not case:
|
if not case:
|
||||||
_progress[task_id] = {"status": "failed", "error": f"Case {req.case_number} not found"}
|
await _progress.set(task_id, {"status": "failed", "error": f"Case {req.case_number} not found"})
|
||||||
return
|
return
|
||||||
|
|
||||||
case_id = UUID(case["id"])
|
case_id = UUID(case["id"])
|
||||||
title = req.title or source.stem.split("_", 1)[-1] # Remove timestamp prefix
|
title = req.title or source.stem.split("_", 1)[-1] # Remove timestamp prefix
|
||||||
|
|
||||||
# Copy to case directory
|
# Copy to case directory
|
||||||
_progress[task_id] = {"status": "copying", "filename": req.filename}
|
await _progress.set(task_id, {"status": "copying", "filename": req.filename})
|
||||||
case_dir = config.find_case_dir(req.case_number) / "documents" / "originals"
|
case_dir = config.find_case_dir(req.case_number) / "documents" / "originals"
|
||||||
case_dir.mkdir(parents=True, exist_ok=True)
|
case_dir.mkdir(parents=True, exist_ok=True)
|
||||||
# Use original name without timestamp prefix
|
# Use original name without timestamp prefix
|
||||||
@@ -3329,7 +3359,7 @@ async def _process_case_document(task_id: str, source: Path, req: ClassifyReques
|
|||||||
shutil.copy2(str(source), str(dest))
|
shutil.copy2(str(source), str(dest))
|
||||||
|
|
||||||
# Create document record
|
# Create document record
|
||||||
_progress[task_id] = {"status": "registering", "filename": req.filename}
|
await _progress.set(task_id, {"status": "registering", "filename": req.filename})
|
||||||
doc = await db.create_document(
|
doc = await db.create_document(
|
||||||
case_id=case_id,
|
case_id=case_id,
|
||||||
doc_type=req.doc_type,
|
doc_type=req.doc_type,
|
||||||
@@ -3338,7 +3368,7 @@ async def _process_case_document(task_id: str, source: Path, req: ClassifyReques
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Process (extract → chunk → embed → store)
|
# Process (extract → chunk → embed → store)
|
||||||
_progress[task_id] = {"status": "processing", "filename": req.filename, "step": "extracting"}
|
await _progress.set(task_id, {"status": "processing", "filename": req.filename, "step": "extracting"})
|
||||||
result = await processor.process_document(UUID(doc["id"]), case_id)
|
result = await processor.process_document(UUID(doc["id"]), case_id)
|
||||||
|
|
||||||
# Git commit (best-effort)
|
# Git commit (best-effort)
|
||||||
@@ -3356,13 +3386,13 @@ async def _process_case_document(task_id: str, source: Path, req: ClassifyReques
|
|||||||
# Remove from uploads
|
# Remove from uploads
|
||||||
source.unlink(missing_ok=True)
|
source.unlink(missing_ok=True)
|
||||||
|
|
||||||
_progress[task_id] = {
|
await _progress.set(task_id, {
|
||||||
"status": "completed",
|
"status": "completed",
|
||||||
"filename": req.filename,
|
"filename": req.filename,
|
||||||
"result": result,
|
"result": result,
|
||||||
"case_number": req.case_number,
|
"case_number": req.case_number,
|
||||||
"doc_type": req.doc_type,
|
"doc_type": req.doc_type,
|
||||||
}
|
})
|
||||||
|
|
||||||
|
|
||||||
async def _process_training_document(task_id: str, source: Path, req: ClassifyRequest):
|
async def _process_training_document(task_id: str, source: Path, req: ClassifyRequest):
|
||||||
@@ -3372,14 +3402,14 @@ async def _process_training_document(task_id: str, source: Path, req: ClassifyRe
|
|||||||
title = req.title or source.stem.split("_", 1)[-1]
|
title = req.title or source.stem.split("_", 1)[-1]
|
||||||
|
|
||||||
# Copy to training directory
|
# Copy to training directory
|
||||||
_progress[task_id] = {"status": "copying", "filename": req.filename}
|
await _progress.set(task_id, {"status": "copying", "filename": req.filename})
|
||||||
config.TRAINING_DIR.mkdir(parents=True, exist_ok=True)
|
config.TRAINING_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
original_name = re.sub(r"^\d+_", "", source.name)
|
original_name = re.sub(r"^\d+_", "", source.name)
|
||||||
dest = config.TRAINING_DIR / original_name
|
dest = config.TRAINING_DIR / original_name
|
||||||
shutil.copy2(str(source), str(dest))
|
shutil.copy2(str(source), str(dest))
|
||||||
|
|
||||||
# Extract text
|
# Extract text
|
||||||
_progress[task_id] = {"status": "processing", "filename": req.filename, "step": "extracting"}
|
await _progress.set(task_id, {"status": "processing", "filename": req.filename, "step": "extracting"})
|
||||||
text, page_count = await extractor.extract_text(str(dest))
|
text, page_count = await extractor.extract_text(str(dest))
|
||||||
|
|
||||||
# Parse date
|
# Parse date
|
||||||
@@ -3388,7 +3418,7 @@ async def _process_training_document(task_id: str, source: Path, req: ClassifyRe
|
|||||||
d_date = date_type.fromisoformat(req.decision_date)
|
d_date = date_type.fromisoformat(req.decision_date)
|
||||||
|
|
||||||
# Add to style corpus
|
# Add to style corpus
|
||||||
_progress[task_id] = {"status": "processing", "filename": req.filename, "step": "corpus"}
|
await _progress.set(task_id, {"status": "processing", "filename": req.filename, "step": "corpus"})
|
||||||
corpus_id = await db.add_to_style_corpus(
|
corpus_id = await db.add_to_style_corpus(
|
||||||
document_id=None,
|
document_id=None,
|
||||||
decision_number=req.decision_number,
|
decision_number=req.decision_number,
|
||||||
@@ -3398,7 +3428,7 @@ async def _process_training_document(task_id: str, source: Path, req: ClassifyRe
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Chunk and embed
|
# Chunk and embed
|
||||||
_progress[task_id] = {"status": "processing", "filename": req.filename, "step": "chunking"}
|
await _progress.set(task_id, {"status": "processing", "filename": req.filename, "step": "chunking"})
|
||||||
chunks = chunker.chunk_document(text)
|
chunks = chunker.chunk_document(text)
|
||||||
|
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
@@ -3413,7 +3443,7 @@ async def _process_training_document(task_id: str, source: Path, req: ClassifyRe
|
|||||||
doc_id = UUID(doc["id"])
|
doc_id = UUID(doc["id"])
|
||||||
await db.update_document(doc_id, extracted_text=text, extraction_status="completed")
|
await db.update_document(doc_id, extracted_text=text, extraction_status="completed")
|
||||||
|
|
||||||
_progress[task_id] = {"status": "processing", "filename": req.filename, "step": "embedding"}
|
await _progress.set(task_id, {"status": "processing", "filename": req.filename, "step": "embedding"})
|
||||||
texts = [c.content for c in chunks]
|
texts = [c.content for c in chunks]
|
||||||
embs = await embeddings.embed_texts(texts, input_type="document")
|
embs = await embeddings.embed_texts(texts, input_type="document")
|
||||||
|
|
||||||
@@ -3433,7 +3463,7 @@ async def _process_training_document(task_id: str, source: Path, req: ClassifyRe
|
|||||||
# Remove from uploads
|
# Remove from uploads
|
||||||
source.unlink(missing_ok=True)
|
source.unlink(missing_ok=True)
|
||||||
|
|
||||||
_progress[task_id] = {
|
await _progress.set(task_id, {
|
||||||
"status": "completed",
|
"status": "completed",
|
||||||
"filename": req.filename,
|
"filename": req.filename,
|
||||||
"result": {
|
"result": {
|
||||||
@@ -3443,4 +3473,4 @@ async def _process_training_document(task_id: str, source: Path, req: ClassifyRe
|
|||||||
"text_length": len(text),
|
"text_length": len(text),
|
||||||
"chunks": chunk_count,
|
"chunks": chunk_count,
|
||||||
},
|
},
|
||||||
}
|
})
|
||||||
|
|||||||
137
web/progress_store.py
Normal file
137
web/progress_store.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
"""Redis-backed progress store for upload/processing tasks.
|
||||||
|
|
||||||
|
Replaces the previous in-memory `dict[str, dict]` with a persistent store so
|
||||||
|
that progress survives container restarts and is observable across replicas.
|
||||||
|
|
||||||
|
Why Redis:
|
||||||
|
- Each entry has a natural TTL (we don't want to hold task state forever).
|
||||||
|
- SSE consumers in different processes can read the same state.
|
||||||
|
- Atomic GETDEL avoids races when cleaning up after a stream closes.
|
||||||
|
|
||||||
|
Public API mirrors the old dict semantics where possible:
|
||||||
|
await store.set(task_id, {...}) # write + extend TTL
|
||||||
|
await store.get(task_id) # → dict | None
|
||||||
|
await store.pop(task_id) # → dict | None (atomic)
|
||||||
|
await store.active() # → list[(task_id, dict)] for non-terminal entries
|
||||||
|
await store.exists(task_id) # → bool
|
||||||
|
|
||||||
|
All entries auto-expire after ``ttl_seconds`` (default 5 min). Terminal
|
||||||
|
states (``completed``/``failed``) are kept for the same TTL so a late
|
||||||
|
subscriber can still observe the result instead of getting a 404.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
try:
|
||||||
|
import redis.asyncio as redis_async
|
||||||
|
except ImportError: # pragma: no cover — package always present in prod
|
||||||
|
redis_async = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
# Module logger used to report Redis failures that ProgressStore swallows.
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Prefix prepended to each task id to build its Redis key; also used as the
# SCAN match pattern when listing entries.
_KEY_PREFIX = "legal-ai:progress:"
|
||||||
|
|
||||||
|
|
||||||
|
class ProgressStore:
|
||||||
|
def __init__(self, redis_url: str, ttl_seconds: int = 300):
|
||||||
|
if redis_async is None:
|
||||||
|
raise RuntimeError("redis package not installed")
|
||||||
|
self._url = redis_url
|
||||||
|
self._ttl = ttl_seconds
|
||||||
|
self._client: Any = None
|
||||||
|
|
||||||
|
async def _conn(self) -> Any:
|
||||||
|
if self._client is None:
|
||||||
|
self._client = redis_async.from_url(self._url, decode_responses=True)
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _key(task_id: str) -> str:
|
||||||
|
return f"{_KEY_PREFIX}{task_id}"
|
||||||
|
|
||||||
|
async def set(self, task_id: str, data: dict) -> None:
|
||||||
|
"""Best-effort write. Logs on Redis failure but never raises — progress
|
||||||
|
reporting is observability, not the critical path. If Redis is down
|
||||||
|
the upload still succeeds and the client's SSE fallback recovers."""
|
||||||
|
try:
|
||||||
|
c = await self._conn()
|
||||||
|
payload = json.dumps(data, ensure_ascii=False)
|
||||||
|
await c.set(self._key(task_id), payload, ex=self._ttl)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("progress.set failed for %s", task_id, exc_info=True)
|
||||||
|
|
||||||
|
async def get(self, task_id: str) -> dict | None:
|
||||||
|
try:
|
||||||
|
c = await self._conn()
|
||||||
|
raw = await c.get(self._key(task_id))
|
||||||
|
except Exception:
|
||||||
|
logger.warning("progress.get failed for %s", task_id, exc_info=True)
|
||||||
|
return None
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return json.loads(raw)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("Corrupted progress payload for %s", task_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def exists(self, task_id: str) -> bool:
|
||||||
|
try:
|
||||||
|
c = await self._conn()
|
||||||
|
return bool(await c.exists(self._key(task_id)))
|
||||||
|
except Exception:
|
||||||
|
logger.warning("progress.exists failed for %s", task_id, exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def pop(self, task_id: str) -> dict | None:
|
||||||
|
try:
|
||||||
|
c = await self._conn()
|
||||||
|
raw = await c.getdel(self._key(task_id)) # atomic, Redis ≥ 6.2
|
||||||
|
except Exception:
|
||||||
|
logger.warning("progress.pop failed for %s", task_id, exc_info=True)
|
||||||
|
return None
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return json.loads(raw)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def active(self) -> list[tuple[str, dict]]:
|
||||||
|
"""Return non-terminal entries. Terminal states are excluded."""
|
||||||
|
try:
|
||||||
|
c = await self._conn()
|
||||||
|
keys: list[str] = []
|
||||||
|
async for k in c.scan_iter(match=f"{_KEY_PREFIX}*", count=200):
|
||||||
|
keys.append(k)
|
||||||
|
if not keys:
|
||||||
|
return []
|
||||||
|
values = await c.mget(*keys)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("progress.active failed", exc_info=True)
|
||||||
|
return []
|
||||||
|
out: list[tuple[str, dict]] = []
|
||||||
|
for k, v in zip(keys, values):
|
||||||
|
if v is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
d = json.loads(v)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
if d.get("status") in ("completed", "failed"):
|
||||||
|
continue
|
||||||
|
tid = k[len(_KEY_PREFIX):]
|
||||||
|
out.append((tid, d))
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
if self._client is not None:
|
||||||
|
try:
|
||||||
|
await self._client.aclose()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._client = None
|
||||||
Reference in New Issue
Block a user