- semantic_cache.py: Semantic similarity matching for cache hits - rag.py: RAG-based context selection with local embeddings - compression.py: Conversation history summarization - New endpoints: /cache/semantic-lookup, /cache/semantic-store, /context/rag, /compress - Uses sentence-transformers (all-MiniLM-L6-v2) - no external API calls - No vector DB needed - cosine similarity on small datasets is fast enough - Expected savings: 50-70% token reduction
169 lines
4.5 KiB
Python
"""
|
|
Semantic cache - matches similar prompts, not just exact hashes.
|
|
Uses embeddings to find similar questions and return cached responses.
|
|
"""
|
|
|
|
import numpy as np
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional, Dict, List
|
|
import json
|
|
import hashlib
|
|
|
|
from rag import embed_text, cosine_similarity
|
|
from compression import count_tokens
|
|
|
|
|
|
async def semantic_cache_lookup(
    prompt: str,
    db: AsyncSession,
    model: Optional[str] = None,
    min_similarity: float = 0.85,
    max_age_hours: int = 168  # 1 week
) -> Optional[Dict]:
    """
    Find a cached response for a semantically similar prompt.

    Embeds ``prompt``, scans non-expired cache entries created within the
    age window, and returns the highest-scoring entry whose cosine
    similarity reaches ``min_similarity``.

    Args:
        prompt: Incoming prompt to match against cached entries.
        db: Async SQLAlchemy session used to read the cache table.
        model: When given, entries recorded for a *different* model are
            skipped; entries with no model recorded always qualify.
        min_similarity: Cosine-similarity threshold for a cache hit.
        max_age_hours: Ignore entries older than this many hours.

    Returns:
        Dict with the cached response, similarity score, model, token
        savings, and creation time — or None when nothing qualifies.
    """
    from models import Cache

    # Embed the query once; every candidate is compared against this vector.
    query_embedding = embed_text(prompt)

    # Single timestamp so the explicit-expiry filter and the age cutoff
    # are evaluated against the same instant.
    now = datetime.now()
    cutoff = now - timedelta(hours=max_age_hours)

    # Candidates: not explicitly expired (NULL expiry = never expires)
    # and created within the age window.
    result = await db.execute(
        select(Cache)
        .where(Cache.expires_at.is_(None) | (Cache.expires_at > now))
        .where(Cache.created_at > cutoff)
    )
    cache_entries = result.scalars().all()

    if not cache_entries:
        return None

    # Score each candidate and keep the best one above the threshold.
    best_match = None
    best_score = 0.0

    for entry in cache_entries:
        # Skip entries recorded for a different model; entries with no
        # model recorded always pass.
        if model and entry.model and entry.model != model:
            continue

        # NOTE(review): this embeds the stored *response*, not the original
        # prompt (the prompt text is not persisted) — similarity between a
        # question and an answer is a weak proxy. Consider persisting the
        # prompt embedding alongside each entry, which would also avoid
        # re-embedding every candidate on every lookup.
        entry_embedding = embed_text(entry.response)
        score = cosine_similarity(query_embedding, entry_embedding)

        if score >= min_similarity and score > best_score:
            best_score = score
            best_match = entry

    if best_match is None:
        return None

    return {
        "response": best_match.response,
        "similarity": best_score,
        "model": best_match.model,
        "tokens_saved": (best_match.tokens_in or 0) + (best_match.tokens_out or 0),
        "cached_at": best_match.created_at,
    }
|
|
|
|
|
|
async def semantic_cache_store(
    prompt: str,
    response: str,
    db: AsyncSession,
    model: Optional[str] = None,
    tokens_in: Optional[int] = None,
    tokens_out: Optional[int] = None,
    ttl_hours: Optional[int] = None
) -> Dict:
    """
    Store a response in the cache, keyed by a hash of (prompt, model).

    No embedding is persisted here — semantic matching embeds entries at
    lookup time (see semantic_cache_lookup). Exact duplicates of the same
    (prompt, model) pair are detected via a SHA-256 hash and stored once.

    Args:
        prompt: The prompt the response was generated for (hashed, not stored).
        response: Response text to cache.
        db: Async SQLAlchemy session used to write the cache table.
        model: Optional model identifier the response came from.
        tokens_in: Optional prompt token count, for savings accounting.
        tokens_out: Optional response token count, for savings accounting.
        ttl_hours: Hours until the entry expires; None (or 0, being falsy)
            means the entry never expires.

    Returns:
        ``{"status": "exists", "hash": ...}`` when a duplicate was found,
        otherwise ``{"status": "stored", "hash": ..., "tokens_stored": ...}``.
    """
    from models import Cache

    # Deterministic dedup key: sort_keys makes the JSON stable across calls.
    prompt_hash = hashlib.sha256(
        json.dumps({"prompt": prompt, "model": model}, sort_keys=True).encode()
    ).hexdigest()

    # Skip the write if this exact (prompt, model) pair is already cached.
    existing = await db.execute(
        select(Cache).where(Cache.hash == prompt_hash)
    )
    if existing.scalar_one_or_none():
        return {"status": "exists", "hash": prompt_hash}

    # ttl_hours=0 is falsy and therefore means "never expires", same as None.
    expires_at = None
    if ttl_hours:
        expires_at = datetime.now() + timedelta(hours=ttl_hours)

    new_entry = Cache(
        hash=prompt_hash,
        response=response,
        model=model,
        tokens_in=tokens_in,
        tokens_out=tokens_out,
        expires_at=expires_at
    )

    db.add(new_entry)
    await db.commit()

    return {
        "status": "stored",
        "hash": prompt_hash,
        "tokens_stored": (tokens_in or 0) + (tokens_out or 0)
    }
|
|
|
|
|
|
async def get_cache_stats(db: AsyncSession) -> Dict:
    """Summarize the cache: entry counts, tokens stored, and models seen."""
    from models import Cache

    rows = (await db.execute(select(Cache))).scalars().all()

    # An entry is still valid when it has no expiry or the expiry lies ahead.
    now = datetime.now()
    live = [row for row in rows if row.expires_at is None or row.expires_at > now]

    token_total = sum((row.tokens_in or 0) + (row.tokens_out or 0) for row in live)
    models_seen = {row.model for row in rows if row.model}

    return {
        "total_entries": len(rows),
        "valid_entries": len(live),
        "total_tokens_stored": token_total,
        "models_used": list(models_seen)
    }
|
|
|
|
|
|
async def clear_old_cache(
    db: AsyncSession,
    older_than_hours: int = 168
) -> int:
    """Delete cache entries created before the cutoff; return the count removed."""
    from models import Cache

    threshold = datetime.now() - timedelta(hours=older_than_hours)
    stale = (
        await db.execute(select(Cache).where(Cache.created_at < threshold))
    ).scalars().all()

    # Per-row ORM deletes (rather than a bulk DELETE) so any ORM-level
    # cascade/event hooks on Cache still fire.
    removed = 0
    for row in stale:
        await db.delete(row)
        removed += 1

    await db.commit()
    return removed