"""FastAPI ↔ legal-chat-service streaming bridge. The browser hits ``/api/training/chat/conversations/{id}/messages`` on the legal-ai container. The container is sealed off from the host's ``claude`` CLI (intentional — see ``claude_session.py`` docstring), so we forward each request to the pm2-managed ``legal-chat-service`` over loopback (``host.docker.internal:8770``). Responsibilities: - Save the user message to ``chat_messages`` before streaming starts. - Open an HTTP streaming connection to the host service. - Forward each SSE event to the browser as-is, accumulating the assistant text and any ``session_id`` so we can persist them once the stream closes. - Persist the assistant turn + the CLI's session_id at end-of-stream. """ from __future__ import annotations import json import logging import os from typing import AsyncIterator from uuid import UUID import httpx from fastapi import HTTPException from fastapi.responses import StreamingResponse from legal_mcp.services import db from web import chat_system_prompt logger = logging.getLogger(__name__) # legal-chat-service lives on the host (pm2-managed, bound to 0.0.0.0:8770). # From inside the container we reach it via the docker bridge gateway — # 10.0.1.1 is the host on docker0 (the same address Paperclip uses for # its 3100 bridge). Override with CHAT_SERVICE_URL if running outside # Docker (local dev). # # Coolify's `custom_docker_run_options: --add-host=host.docker.internal:host-gateway` # turned out NOT to apply to dockerimage-built apps as of Coolify 4.0.0, # so the explicit IP is the reliable path. The cloud-level firewall # (Oracle security list) keeps port 8770 unreachable from the public # internet, matching the security posture of Paperclip's 3100. 
# Base URL of the host-side chat service; the default is the Docker bridge
# gateway address, overridable through the environment for local dev.
CHAT_SERVICE_URL = os.environ.get(
    "CHAT_SERVICE_URL",
    "http://10.0.1.1:8770",
)
# Overall / read timeout (seconds) for one streamed upstream response.
CHAT_SERVICE_TIMEOUT_S = float(os.environ.get("CHAT_SERVICE_TIMEOUT_S", "3600"))

# Response headers sent with the SSE stream to the browser.
_SSE_HEADERS = {
    "Cache-Control": "no-cache, no-transform",
    "X-Accel-Buffering": "no",
    "Connection": "keep-alive",
}

# Field prefix of an SSE line that carries an event payload.
_SSE_DATA_PREFIX = "data: "


def _sse_event(event: dict) -> bytes:
    """Encode *event* as one complete SSE ``data:`` frame (blank-line terminated)."""
    return f"data: {json.dumps(event, ensure_ascii=False)}\n\n".encode("utf-8")


def _parse_sse_data(line: str) -> dict | None:
    """Return the JSON object carried by an SSE ``data:`` line, or ``None``.

    ``None`` is returned for non-``data:`` lines, for payloads that are not
    valid JSON, and for valid JSON that is not an object (list, string,
    number, ...) — callers only ever inspect keys, so anything else is noise.
    """
    if not line.startswith(_SSE_DATA_PREFIX):
        return None
    try:
        event = json.loads(line[len(_SSE_DATA_PREFIX):])
    except json.JSONDecodeError:
        return None
    return event if isinstance(event, dict) else None


async def _persist_assistant_turn(
    conversation_id: UUID,
    text_parts: list[str],
    events_log: list[dict],
    session_id: str | None,
) -> None:
    """Best-effort save of the assistant text and the new CLI session id.

    Failures are logged and swallowed: the response has already been
    streamed to the browser, so there is nobody left to report them to.
    """
    try:
        full_text = "".join(text_parts).strip()
        if full_text:
            await db.add_chat_message(
                conversation_id,
                role="assistant",
                content=full_text,
                raw_events=events_log,
            )
        if session_id:
            await db.update_chat_conversation_session_id(
                conversation_id,
                session_id,
            )
    except Exception:
        logger.exception("failed to persist assistant turn for conv=%s", conversation_id)


async def stream_chat_message(
    conversation_id: UUID,
    user_message: str,
) -> StreamingResponse:
    """Open SSE stream, forward events, persist when done.

    Args:
        conversation_id: Conversation the message belongs to.
        user_message: The user's prompt text; saved before streaming starts.

    Returns:
        A FastAPI StreamingResponse the route can return directly. Upstream
        failures are reported in-band as ``{"type": "error", ...}`` events.

    Raises:
        HTTPException: 404 when the conversation does not exist; 500 when
            building the first-turn system prompt fails.
    """
    conv = await db.get_chat_conversation(conversation_id)
    if not conv:
        raise HTTPException(404, "conversation not found")

    # Persist the user turn immediately so a network drop doesn't lose it.
    await db.add_chat_message(
        conversation_id,
        role="user",
        content=user_message,
    )

    # No stored session id means this is the first turn, which is the only
    # time a system prompt is built and sent.
    resume_session_id = conv.get("claude_session_id")
    system_block: str | None = None
    if not resume_session_id:
        try:
            system_block = await chat_system_prompt.build_system_prompt(
                corpus_id=conv.get("style_corpus_id"),
            )
        except Exception as e:
            logger.exception("system prompt build failed")
            raise HTTPException(500, f"system prompt failed: {e}") from e

    payload = {
        "prompt": user_message,
        "system": system_block,
        "resume_session_id": resume_session_id,
    }
    # Tolerate a trailing slash in the configured base URL.
    start_url = f"{CHAT_SERVICE_URL.rstrip('/')}/chat/start"

    async def proxy_stream() -> AsyncIterator[bytes]:
        accumulated_text: list[str] = []
        events_log: list[dict] = []
        new_session_id: str | None = None
        # True while a forwarded upstream event has not yet been closed by
        # its blank line; our own error frames must not merge into it.
        frame_open = False

        def error_frame(message: str) -> bytes:
            """Build an error event, first closing any half-forwarded frame."""
            terminator = b"\n" if frame_open else b""
            return terminator + _sse_event({"type": "error", "message": message})

        try:
            timeout_cfg = httpx.Timeout(
                CHAT_SERVICE_TIMEOUT_S,
                connect=10.0,
                read=CHAT_SERVICE_TIMEOUT_S,
            )
            async with httpx.AsyncClient(timeout=timeout_cfg) as client:
                async with client.stream("POST", start_url, json=payload) as upstream:
                    if upstream.status_code != 200:
                        body = await upstream.aread()
                        msg = body.decode("utf-8", errors="replace")[:300]
                        yield error_frame(f"chat-service {upstream.status_code}: {msg}")
                        return

                    async for line in upstream.aiter_lines():
                        if not line:
                            # Blank line = end of one SSE event.
                            frame_open = False
                            yield b"\n"
                            continue

                        # Forward verbatim so the browser sees the same
                        # SSE framing the host emits.
                        frame_open = True
                        yield (line + "\n").encode("utf-8")

                        # Mirror events: capture text + session_id for
                        # persistence. Lines that are not JSON-object
                        # ``data:`` payloads are forwarded but not mirrored.
                        event = _parse_sse_data(line)
                        if event is None:
                            continue
                        events_log.append(event)

                        kind = event.get("type")
                        text = event.get("text")
                        # Only strings may enter the buffer; anything else
                        # would break the final "".join and lose the turn.
                        has_text = isinstance(text, str) and bool(text)
                        if kind == "session_id" and event.get("value"):
                            new_session_id = event["value"]
                        elif kind == "text_delta" and has_text:
                            accumulated_text.append(text)
                        elif kind == "done" and has_text and not accumulated_text:
                            # Use the final text only when no deltas arrived.
                            accumulated_text.append(text)
        except (httpx.ConnectError, httpx.ConnectTimeout):
            # ConnectTimeout is not a ConnectError subclass, so both are
            # listed: either way the host service could not be reached.
            # User-facing Hebrew text: "Cannot reach legal-chat-service at
            # <url>. Make sure pm2 is running it: `pm2 status ...`".
            yield error_frame(
                f"לא ניתן להגיע ל-legal-chat-service בכתובת {CHAT_SERVICE_URL}. "
                "ודא ש-pm2 מריץ אותו: `pm2 status legal-chat-service`."
            )
            return
        except Exception as e:
            logger.exception("chat proxy failed")
            # Some exceptions stringify to "", so fall back to the class
            # name rather than sending the browser a blank message.
            yield error_frame(str(e) or type(e).__name__)
            return

        # End of stream — persist the assistant turn.
        # NOTE(review): if the server cancels this generator on client
        # disconnect, control presumably never reaches this point and the
        # turn is not saved — confirm whether that is intended.
        await _persist_assistant_turn(
            conversation_id,
            accumulated_text,
            events_log,
            new_session_id,
        )

    return StreamingResponse(
        proxy_stream(),
        media_type="text/event-stream",
        headers=_SSE_HEADERS,
    )