"""RAG retrieval telemetry — closed-loop feedback (TaskMaster #50). Logs every semantic search call so we can compute nDCG@10 over time, spot retrieval drift, and feed the rerank training set. Design notes ------------ - **All writes are fire-and-forget**: callers wrap us in ``try/except`` but we also swallow our own DB errors so a telemetry hiccup can never fail a search. The log itself is also written via a detached task — the search returns to the caller immediately and the row lands in the DB on the side. - **search_decisions / search_case_documents** return document chunks from active cases, not ``case_law`` rows. Their telemetry rows leave ``top_case_law_ids`` empty; nDCG aggregation ignores them. - **Auto-inferred feedback**: once a final decision is exported, we scan its ``decision_paragraphs.citations`` JSONB, pull the ``case_law_id`` values, and mark them as ``relevance_score=3`` on any search_log for the same case where the precedent appeared in the top-K. This gives us a "cited == relevant" ground truth signal without asking the chair to label results by hand. """ from __future__ import annotations import asyncio import logging from typing import Any, Iterable from uuid import UUID from legal_mcp.services import db logger = logging.getLogger(__name__) _VALID_SOURCES = {"cited_in_decision", "chair_marked", "auto_inferred"} def _coerce_case_law_ids(results: Iterable[Any], limit: int = 10) -> list[UUID]: """Pull up to ``limit`` ``case_law_id`` UUIDs from search results. Tolerates rows missing the field, non-UUID strings, and ``None`` values. Preserves order (= ranking). """ out: list[UUID] = [] seen: set[str] = set() for r in results: if len(out) >= limit: break if not isinstance(r, dict): continue raw = r.get("case_law_id") if raw is None: continue s = str(raw) if s in seen: continue try: out.append(UUID(s)) seen.add(s) except (ValueError, AttributeError): continue return out async def _insert_log( *, search_type: str, query: str, practice_area: str | None, case_id: UUID | None, user_agent: str | None, result_count: int, top_case_law_ids: list[UUID], duration_ms: int | None, ) -> UUID | None: try: pool = await db.get_pool() async with pool.acquire() as conn: row = await conn.fetchrow( """ INSERT INTO search_logs ( search_type, query, practice_area, case_id, user_agent, result_count, top_case_law_ids, duration_ms ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id """, search_type, query[:2000], # guard against pathologically long queries practice_area or None, case_id, user_agent or None, int(result_count), top_case_law_ids or None, duration_ms, ) return row["id"] if row else None except Exception: logger.exception("telemetry.log_search: insert failed (swallowed)") return None async def log_search( *, search_type: str, query: str, results: Iterable[dict], duration_ms: int | None = None, practice_area: str | None = None, case_id: UUID | str | None = None, user_agent: str | None = None, ) -> UUID | None: """Record a search call. Never raises. Args: search_type: one of 'precedent_library', 'internal_decisions', 'decisions', 'case_documents', 'similar_cases'. query: the raw user query. results: iterable of result dicts. We pull ``case_law_id`` from the first 10 to populate ``top_case_law_ids``. duration_ms: search latency in milliseconds. practice_area: optional filter applied to the search. case_id: optional case context (when the search was scoped to or triggered from a specific case). user_agent: 'writer' / 'researcher' / 'analyst' / 'manual'. Returns: The ``search_logs.id`` UUID if the row was written, else None. Most callers ignore this; auto-inference uses it later via ``infer_relevance_from_citations``. """ # Snapshot results immediately — callers may keep iterating. snapshot = list(results) if not isinstance(results, list) else results top_ids = _coerce_case_law_ids(snapshot, limit=10) case_uuid: UUID | None if case_id is None: case_uuid = None elif isinstance(case_id, UUID): case_uuid = case_id else: try: case_uuid = UUID(str(case_id)) except (ValueError, AttributeError): case_uuid = None return await _insert_log( search_type=search_type, query=query, practice_area=practice_area, case_id=case_uuid, user_agent=user_agent, result_count=len(snapshot), top_case_law_ids=top_ids, duration_ms=duration_ms, ) def log_search_bg( *, search_type: str, query: str, results: Iterable[dict], duration_ms: int | None = None, practice_area: str | None = None, case_id: UUID | str | None = None, user_agent: str | None = None, ) -> None: """Fire-and-forget variant. Schedules the insert as a detached task. Use this from hot search paths so the caller returns to the user immediately. Errors are logged inside ``log_search``. """ # Snapshot eagerly so the caller can mutate/iterate results freely. snapshot = list(results) if not isinstance(results, list) else list(results) try: loop = asyncio.get_running_loop() except RuntimeError: # No running loop — caller is sync. Best-effort: skip telemetry. return loop.create_task( log_search( search_type=search_type, query=query, results=snapshot, duration_ms=duration_ms, practice_area=practice_area, case_id=case_id, user_agent=user_agent, ) ) # ────────────────────────────────────────────────────────────────────── # Auto-inferred relevance feedback # ────────────────────────────────────────────────────────────────────── def _extract_citations_from_jsonb(citations: Any) -> list[UUID]: """Parse ``decision_paragraphs.citations`` JSONB into UUID list. Stored shape: ``[{"case_law_id": "...", "text": "...", "type": ...}]``. Tolerates string form (asyncpg returns it as JSON string when the column registration didn't auto-decode). """ import json as _json if not citations: return [] if isinstance(citations, (bytes, bytearray)): try: citations = _json.loads(citations.decode("utf-8")) except (ValueError, UnicodeDecodeError): return [] elif isinstance(citations, str): try: citations = _json.loads(citations) except ValueError: return [] if not isinstance(citations, list): return [] out: list[UUID] = [] seen: set[str] = set() for item in citations: if not isinstance(item, dict): continue raw = item.get("case_law_id") if not raw: continue s = str(raw) if s in seen: continue try: out.append(UUID(s)) seen.add(s) except (ValueError, AttributeError): continue return out async def _gather_cited_case_law_ids(case_id: UUID) -> list[UUID]: """Pull every distinct ``case_law_id`` cited anywhere in the case's decision paragraphs. """ pool = await db.get_pool() async with pool.acquire() as conn: rows = await conn.fetch( """ SELECT dp.citations FROM decision_paragraphs dp JOIN decision_blocks db ON db.id = dp.block_id JOIN decisions d ON d.id = db.decision_id WHERE d.case_id = $1 AND dp.citations IS NOT NULL AND jsonb_array_length(dp.citations) > 0 """, case_id, ) seen: set[str] = set() out: list[UUID] = [] for r in rows: for clid in _extract_citations_from_jsonb(r["citations"]): s = str(clid) if s not in seen: seen.add(s) out.append(clid) return out async def infer_relevance_from_citations( case_id: UUID | str, *, relevance_score: int = 3, feedback_source: str = "cited_in_decision", ) -> dict: """For each precedent cited in the case's draft, write a relevance row against every search_log where that precedent appeared in the top-K for the same case. Idempotent: the ``UNIQUE(search_log_id, case_law_id, feedback_source)`` constraint on ``search_relevance_feedback`` prevents duplicates. Returns: ``{"cited_precedents": int, "feedback_rows_inserted": int, "searches_matched": int}``. """ if relevance_score not in (0, 1, 2, 3): raise ValueError("relevance_score must be in 0..3") if feedback_source not in _VALID_SOURCES: raise ValueError(f"feedback_source must be one of {_VALID_SOURCES!r}") case_uuid = case_id if isinstance(case_id, UUID) else UUID(str(case_id)) cited = await _gather_cited_case_law_ids(case_uuid) if not cited: return { "cited_precedents": 0, "feedback_rows_inserted": 0, "searches_matched": 0, } pool = await db.get_pool() inserted = 0 matched_searches: set[str] = set() async with pool.acquire() as conn: # For each cited precedent, find all logs where it appeared in # top_case_law_ids for this case, and record its rank. for clid in cited: rows = await conn.fetch( """ SELECT id, top_case_law_ids FROM search_logs WHERE case_id = $1 AND top_case_law_ids IS NOT NULL AND $2 = ANY(top_case_law_ids) """, case_uuid, clid, ) for row in rows: top_ids = row["top_case_law_ids"] or [] # asyncpg returns uuid[] as list[UUID] try: rank = top_ids.index(clid) + 1 except ValueError: continue result = await conn.execute( """ INSERT INTO search_relevance_feedback ( search_log_id, case_law_id, rank, relevance_score, feedback_source ) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (search_log_id, case_law_id, feedback_source) DO NOTHING """, row["id"], clid, rank, relevance_score, feedback_source, ) # ``execute`` returns 'INSERT 0 1' or 'INSERT 0 0' for # the no-op path; count only the writes. if result.endswith(" 1"): inserted += 1 matched_searches.add(str(row["id"])) return { "cited_precedents": len(cited), "feedback_rows_inserted": inserted, "searches_matched": len(matched_searches), } async def infer_relevance_for_all_finalized_cases(limit: int | None = None) -> dict: """Bulk-run auto-inference for every case whose draft is final/exported. Useful for back-filling after V18 schema lands and a few decisions have already been written. Skips cases with no cited precedents silently (they contribute zero to the totals). """ pool = await db.get_pool() sql = """ SELECT DISTINCT c.id FROM cases c JOIN decisions d ON d.case_id = c.id WHERE c.status IN ('final', 'exported') """ if limit is not None and limit > 0: sql += " LIMIT $1" async with pool.acquire() as conn: rows = await conn.fetch(sql, *([limit] if limit else [])) totals = { "cases_processed": 0, "cited_precedents": 0, "feedback_rows_inserted": 0, "searches_matched": 0, } for r in rows: stats = await infer_relevance_from_citations(r["id"]) totals["cases_processed"] += 1 totals["cited_precedents"] += stats["cited_precedents"] totals["feedback_rows_inserted"] += stats["feedback_rows_inserted"] totals["searches_matched"] += stats["searches_matched"] return totals