"""Redis-backed progress store for upload/processing tasks. Replaces the previous in-memory `dict[str, dict]` with a persistent store so that progress survives container restarts and is observable across replicas. Why Redis: - Each entry has a natural TTL (we don't want to hold task state forever). - SSE consumers in different processes can read the same state. - Atomic GETDEL avoids races when cleaning up after a stream closes. Public API mirrors the old dict semantics where possible: await store.set(task_id, {...}) # write + extend TTL await store.get(task_id) # → dict | None await store.pop(task_id) # → dict | None (atomic) await store.active() # → list[(task_id, dict)] for non-terminal entries await store.exists(task_id) # → bool All entries auto-expire after ``ttl_seconds`` (default 5 min). Terminal states (``completed``/``failed``) are kept for the same TTL so a late subscriber can still observe the result instead of getting a 404. """ from __future__ import annotations import json import logging from typing import Any try: import redis.asyncio as redis_async except ImportError: # pragma: no cover — package always present in prod redis_async = None # type: ignore[assignment] logger = logging.getLogger(__name__) _KEY_PREFIX = "legal-ai:progress:" class ProgressStore: def __init__(self, redis_url: str, ttl_seconds: int = 300): if redis_async is None: raise RuntimeError("redis package not installed") self._url = redis_url self._ttl = ttl_seconds self._client: Any = None async def _conn(self) -> Any: if self._client is None: self._client = redis_async.from_url(self._url, decode_responses=True) return self._client @staticmethod def _key(task_id: str) -> str: return f"{_KEY_PREFIX}{task_id}" async def set(self, task_id: str, data: dict) -> None: """Best-effort write. Logs on Redis failure but never raises — progress reporting is observability, not the critical path. If Redis is down the upload still succeeds and the client's SSE fallback recovers.""" try: c = await self._conn() payload = json.dumps(data, ensure_ascii=False) await c.set(self._key(task_id), payload, ex=self._ttl) except Exception: logger.warning("progress.set failed for %s", task_id, exc_info=True) async def get(self, task_id: str) -> dict | None: try: c = await self._conn() raw = await c.get(self._key(task_id)) except Exception: logger.warning("progress.get failed for %s", task_id, exc_info=True) return None if raw is None: return None try: return json.loads(raw) except json.JSONDecodeError: logger.warning("Corrupted progress payload for %s", task_id) return None async def exists(self, task_id: str) -> bool: try: c = await self._conn() return bool(await c.exists(self._key(task_id))) except Exception: logger.warning("progress.exists failed for %s", task_id, exc_info=True) return False async def pop(self, task_id: str) -> dict | None: try: c = await self._conn() raw = await c.getdel(self._key(task_id)) # atomic, Redis ≥ 6.2 except Exception: logger.warning("progress.pop failed for %s", task_id, exc_info=True) return None if raw is None: return None try: return json.loads(raw) except json.JSONDecodeError: return None async def active(self) -> list[tuple[str, dict]]: """Return non-terminal entries. Terminal states are excluded.""" try: c = await self._conn() keys: list[str] = [] async for k in c.scan_iter(match=f"{_KEY_PREFIX}*", count=200): keys.append(k) if not keys: return [] values = await c.mget(*keys) except Exception: logger.warning("progress.active failed", exc_info=True) return [] out: list[tuple[str, dict]] = [] for k, v in zip(keys, values): if v is None: continue try: d = json.loads(v) except json.JSONDecodeError: continue if d.get("status") in ("completed", "failed"): continue tid = k[len(_KEY_PREFIX):] out.append((tid, d)) return out async def close(self) -> None: if self._client is not None: try: await self._client.aclose() except Exception: pass self._client = None