"""Embedding service using Voyage AI API.""" from __future__ import annotations import logging from typing import TYPE_CHECKING from legal_mcp import config if TYPE_CHECKING: import voyageai from PIL import Image as PILImage logger = logging.getLogger(__name__) # voyageai is imported lazily inside _get_client to keep MCP server startup # fast — loading voyageai eagerly costs ~450ms and Claude Code's first tool # call can hit a "No such tool available" race if the server isn't ready yet. _client: "voyageai.Client | None" = None # Per-call cap for multimodal_embed. POC ran 89 pages (~312K tokens) # in a single call comfortably; 50 leaves safe headroom for densely- # OCR'd legal pages where tokens/page can exceed 4K. _MULTIMODAL_BATCH_SIZE = 50 def _get_client() -> "voyageai.Client": global _client if _client is None: import voyageai _client = voyageai.Client(api_key=config.VOYAGE_API_KEY) return _client async def embed_texts(texts: list[str], input_type: str = "document") -> list[list[float]]: """Embed a batch of texts using Voyage AI. Args: texts: List of texts to embed (max 128 per call). input_type: "document" for indexing, "query" for search queries. Returns: List of embedding vectors (1024 dimensions each). """ if not texts: return [] client = _get_client() all_embeddings = [] # Voyage AI supports up to 128 texts per batch for i in range(0, len(texts), 128): batch = texts[i : i + 128] result = client.embed( batch, model=config.VOYAGE_MODEL, input_type=input_type, ) all_embeddings.extend(result.embeddings) return all_embeddings async def embed_query(query: str) -> list[float]: """Embed a single search query.""" results = await embed_texts([query], input_type="query") return results[0] async def embed_images( images: "list[PILImage.Image]", input_type: str = "document", ) -> list[list[float]]: """Embed page images via voyage-multimodal-3. Each input is a single PIL.Image (one page = one embedding). Returns a list of 1024-dim vectors, one per input image, in order. Batches at ``_MULTIMODAL_BATCH_SIZE`` to stay within Voyage's per-request limits on dense legal pages. """ if not images: return [] client = _get_client() out: list[list[float]] = [] for i in range(0, len(images), _MULTIMODAL_BATCH_SIZE): batch = images[i : i + _MULTIMODAL_BATCH_SIZE] result = client.multimodal_embed( inputs=[[img] for img in batch], model=config.MULTIMODAL_MODEL, input_type=input_type, truncation=True, ) out.extend(result.embeddings) return out async def embed_query_for_multimodal(query: str) -> list[float]: """Embed a text query in the multimodal vector space, so it can be cosine-compared against page-image embeddings.""" client = _get_client() result = client.multimodal_embed( inputs=[[query]], model=config.MULTIMODAL_MODEL, input_type="query", ) return result.embeddings[0] async def voyage_rerank( query: str, documents: list[str], top_k: int | None = None, ) -> list[tuple[int, float]]: """Cross-encoder rerank via Voyage. Returns [(orig_index, score), ...] sorted by relevance. Each tuple's index refers to the position in the *input* documents list (not a DB row id) — caller maps it back. Used as a second stage after bi-encoder retrieval: fetch top-N candidates with cosine, then rerank to get top-K with cross-encoder attention over (query, doc). """ if not documents: return [] client = _get_client() result = client.rerank( query=query, documents=documents, model=config.VOYAGE_RERANK_MODEL, top_k=top_k, ) return [(r.index, float(r.relevance_score)) for r in result.results]