"""
Semantic cache - matches similar prompts, not just exact hashes.

Responses are keyed by a SHA-256 hash of (prompt, model). Lookups first try
that exact key. They then fall back to embedding similarity over the live
cache entries.
"""

import numpy as np
from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import datetime, timedelta
from typing import Optional, Dict, List
import json
import hashlib

from rag import embed_text, cosine_similarity
from compression import count_tokens


def _prompt_hash(prompt: str, model: Optional[str]) -> str:
    """Return the SHA-256 hex key for a (prompt, model) pair.

    semantic_cache_store writes rows under this key, and
    semantic_cache_lookup reads them back with it, so both must stay in sync.
    """
    return hashlib.sha256(
        json.dumps({"prompt": prompt, "model": model}, sort_keys=True).encode()
    ).hexdigest()


def _not_expired(cache_cls, now: datetime):
    """Return a SQL clause that is true when the entry has no expiry or expires after `now`."""
    return cache_cls.expires_at.is_(None) | (cache_cls.expires_at > now)


def _hit_payload(entry, similarity: float) -> Dict:
    """Build the dict semantic_cache_lookup returns for a cache hit."""
    return {
        "response": entry.response,
        "similarity": similarity,
        "model": entry.model,
        "tokens_saved": (entry.tokens_in or 0) + (entry.tokens_out or 0),
        "cached_at": entry.created_at,
    }


async def semantic_cache_lookup(
    prompt: str,
    db: AsyncSession,
    model: Optional[str] = None,
    min_similarity: float = 0.85,
    max_age_hours: int = 168,  # 1 week
) -> Optional[Dict]:
    """Find a cached response for `prompt`, or for a semantically similar one.

    The lookup runs in two stages:

    1. Exact match. The (prompt, model) hash written by semantic_cache_store
       is looked up directly. A hit is returned with similarity 1.0 and no
       embeddings are computed.
    2. Semantic match. Every remaining live entry is scored by cosine
       similarity. The best entry scoring >= `min_similarity` is returned.

    NOTE(review): Cache rows written by semantic_cache_store hold neither the
    original prompt nor its embedding. Stage 2 therefore compares the query
    prompt against each cached *response*, so semantic hits are only
    approximate. Storing a prompt embedding on Cache would fix this, but it
    needs a schema change in models.py.

    Args:
        prompt: The incoming prompt.
        db: Async SQLAlchemy session.
        model: If given, entries recorded for a different model are ignored.
            Entries with no model recorded still match.
        min_similarity: Minimum cosine similarity for a semantic hit.
        max_age_hours: Entries created longer ago than this are ignored.

    Returns:
        None on a miss. Otherwise a dict with the keys "response",
        "similarity", "model", "tokens_saved" and "cached_at".
    """
    from models import Cache

    now = datetime.now()
    live_filters = (
        _not_expired(Cache, now),
        Cache.created_at > now - timedelta(hours=max_age_hours),
    )

    # Stage 1: exact (prompt, model) hit. This is cheap and needs no embeddings.
    exact = await db.execute(
        select(Cache)
        .where(Cache.hash == _prompt_hash(prompt, model), *live_filters)
        .limit(1)
    )
    exact_hit = exact.scalars().first()
    if exact_hit is not None:
        return _hit_payload(exact_hit, 1.0)

    # Stage 2: semantic match over the live entries.
    stmt = select(Cache).where(*live_filters)
    if model:
        # Keep entries whose model is unset/empty or equal to the requested one.
        # Filtering in SQL avoids embedding rows we would discard anyway.
        stmt = stmt.where(
            Cache.model.is_(None) | (Cache.model == "") | (Cache.model == model)
        )
    candidates = (await db.execute(stmt)).scalars().all()
    if not candidates:
        return None

    query_embedding = embed_text(prompt)
    best_match = None
    best_score = 0.0
    for entry in candidates:
        # See NOTE(review) above: only the response text is available to embed.
        score = cosine_similarity(query_embedding, embed_text(entry.response))
        if score >= min_similarity and score > best_score:
            best_match, best_score = entry, score

    if best_match is None:
        return None
    return _hit_payload(best_match, best_score)


async def semantic_cache_store(
    prompt: str,
    response: str,
    db: AsyncSession,
    model: Optional[str] = None,
    tokens_in: Optional[int] = None,
    tokens_out: Optional[int] = None,
    ttl_hours: Optional[int] = None,
) -> Dict:
    """Store `response` under the exact (prompt, model) hash.

    Only the hash is persisted, not the prompt text or an embedding. An
    entry that already exists under the same hash and has not expired is
    left untouched. An expired one is refreshed in place; otherwise the
    stale row would block re-caching this prompt until it is deleted.

    Args:
        prompt: Prompt the response answers (used only to build the hash).
        response: Response text to cache.
        db: Async SQLAlchemy session. The transaction is committed on write.
        model: Model name. It is part of the hash and stored on the row.
        tokens_in: Prompt token count, recorded for savings statistics.
        tokens_out: Response token count, recorded for savings statistics.
        ttl_hours: Lifetime of the entry. None or 0 means it never expires.

    Returns:
        {"status": "exists", "hash": ...} if a live entry was already
        present. Otherwise {"status": "stored", "hash": ...,
        "tokens_stored": ...}.
    """
    from models import Cache

    prompt_hash = _prompt_hash(prompt, model)
    now = datetime.now()
    # Falsy ttl (None or 0) keeps the original "no expiry" meaning.
    expires_at = now + timedelta(hours=ttl_hours) if ttl_hours else None

    # limit(1) + first() tolerates duplicate hashes instead of raising
    # MultipleResultsFound, which scalar_one_or_none() would do.
    result = await db.execute(
        select(Cache).where(Cache.hash == prompt_hash).limit(1)
    )
    existing = result.scalars().first()

    if existing is not None:
        if existing.expires_at is None or existing.expires_at > now:
            return {"status": "exists", "hash": prompt_hash}
        # Expired: overwrite in place. Resetting created_at makes the entry
        # pass semantic_cache_lookup's max-age filter again.
        existing.response = response
        existing.model = model
        existing.tokens_in = tokens_in
        existing.tokens_out = tokens_out
        existing.expires_at = expires_at
        existing.created_at = now
    else:
        db.add(
            Cache(
                hash=prompt_hash,
                response=response,
                model=model,
                tokens_in=tokens_in,
                tokens_out=tokens_out,
                expires_at=expires_at,
            )
        )

    await db.commit()
    return {
        "status": "stored",
        "hash": prompt_hash,
        "tokens_stored": (tokens_in or 0) + (tokens_out or 0),
    }


async def get_cache_stats(db: AsyncSession) -> Dict:
    """Return aggregate cache statistics, computed in SQL.

    Returns:
        A dict with these keys:
          "total_entries": count of every row.
          "valid_entries": count of rows with no expiry or a future expiry.
          "total_tokens_stored": tokens_in + tokens_out summed over valid
              rows, with NULL counted as 0.
          "models_used": distinct non-empty model names across all rows.
    """
    from models import Cache

    now = datetime.now()
    row_tokens = func.coalesce(Cache.tokens_in, 0) + func.coalesce(Cache.tokens_out, 0)

    total_entries = (
        await db.execute(select(func.count()).select_from(Cache))
    ).scalar_one()

    valid_entries, total_tokens = (
        await db.execute(
            select(func.count(), func.coalesce(func.sum(row_tokens), 0))
            .select_from(Cache)
            .where(_not_expired(Cache, now))
        )
    ).one()

    models_used = (
        await db.execute(
            select(Cache.model)
            .where(Cache.model.is_not(None), Cache.model != "")
            .distinct()
        )
    ).scalars().all()

    return {
        "total_entries": int(total_entries),
        "valid_entries": int(valid_entries),
        "total_tokens_stored": int(total_tokens),
        "models_used": list(models_used),
    }


async def clear_old_cache(
    db: AsyncSession,
    older_than_hours: int = 168,
) -> int:
    """Delete cache entries created more than `older_than_hours` ago.

    Entries are selected by created_at only; expires_at is not consulted.

    Returns:
        The number of rows deleted, as reported by the driver.
    """
    from models import Cache

    cutoff = datetime.now() - timedelta(hours=older_than_hours)
    # A single bulk DELETE instead of loading and deleting rows one by one.
    # NOTE(review): bulk deletes bypass ORM-level cascades/events. This
    # assumes Cache has none; confirm in models.py.
    result = await db.execute(delete(Cache).where(Cache.created_at < cutoff))
    deleted = result.rowcount
    await db.commit()
    return deleted