ai-skills-api/semantic_cache.py
Lukas Parsons 82fd963577 Add token-saving patterns: semantic cache, RAG, compression
- semantic_cache.py: Semantic similarity matching for cache hits
- rag.py: RAG-based context selection with local embeddings
- compression.py: Conversation history summarization
- New endpoints: /cache/semantic-lookup, /cache/semantic-store, /context/rag, /compress
- Uses sentence-transformers (all-MiniLM-L6-v2) - no external API calls
- No vector DB needed - cosine similarity on small datasets is fast enough
- Expected savings: 50-70% token reduction
2026-03-22 21:32:08 -04:00

169 lines
4.5 KiB
Python

"""
Semantic cache - matches similar prompts, not just exact hashes.
Uses embeddings to find similar questions and return cached responses.
"""
import hashlib
import json
from datetime import datetime, timedelta
from typing import Optional, Dict, List

import numpy as np
from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from compression import count_tokens
from rag import embed_text, cosine_similarity
async def semantic_cache_lookup(
    prompt: str,
    db: AsyncSession,
    model: Optional[str] = None,
    min_similarity: float = 0.85,
    max_age_hours: int = 168  # 1 week
) -> Optional[Dict]:
    """
    Find a cached response for a semantically similar prompt.

    Embeds ``prompt`` and scores it against every live cache entry with
    cosine similarity, returning the best match at or above
    ``min_similarity``.

    Args:
        prompt: Incoming prompt to look up.
        db: Async database session.
        model: If given, entries recorded for a *different* model are skipped
            (entries with no model are always eligible).
        min_similarity: Cosine-similarity threshold for a cache hit.
        max_age_hours: Ignore entries created more than this many hours ago.

    Returns:
        Dict with the cached response, similarity score, model, tokens saved
        and creation time, or None when no entry clears the threshold.
    """
    from models import Cache

    # Embed the incoming prompt once; every candidate is scored against it.
    query_embedding = embed_text(prompt)

    # Snapshot "now" once so the expiry and max-age filters agree; the
    # original called datetime.now() twice, which could straddle a boundary.
    now = datetime.now()  # NOTE(review): naive local time — confirm the DB stores the same
    expiry = now - timedelta(hours=max_age_hours)
    result = await db.execute(
        select(Cache)
        .where(
            (Cache.expires_at == None) | (Cache.expires_at > now)
        )
        .where(Cache.created_at > expiry)
    )
    cache_entries = result.scalars().all()
    if not cache_entries:
        return None

    # Linear scan: fine for small caches; revisit if entry count grows.
    best_match = None
    best_score = 0.0
    for entry in cache_entries:
        # Optional model filter — only skip when both sides specify a model.
        if model and entry.model and entry.model != model:
            continue
        # NOTE(review): this embeds the cached *response*, so a question
        # embedding is compared with an answer embedding. Matching against
        # the original prompt's embedding would be far more accurate, but the
        # Cache model would need to store the prompt (or its embedding).
        entry_embedding = embed_text(entry.response)
        score = cosine_similarity(query_embedding, entry_embedding)
        if score >= min_similarity and score > best_score:
            best_score = score
            best_match = entry

    if best_match is None:
        return None
    return {
        "response": best_match.response,
        "similarity": best_score,
        "model": best_match.model,
        "tokens_saved": (best_match.tokens_in or 0) + (best_match.tokens_out or 0),
        "cached_at": best_match.created_at
    }
async def semantic_cache_store(
    prompt: str,
    response: str,
    db: AsyncSession,
    model: Optional[str] = None,
    tokens_in: Optional[int] = None,
    tokens_out: Optional[int] = None,
    ttl_hours: Optional[int] = None
) -> Dict:
    """
    Persist a prompt/response pair for later cache lookups.

    Entries are keyed by a SHA-256 digest of the (prompt, model) pair, so an
    identical request is never stored twice. Note that no embedding is
    persisted here — lookups recompute embeddings on the fly.

    Args:
        prompt: Prompt text the response was generated for.
        response: Model output to cache.
        db: Async database session.
        model: Model identifier, folded into the dedup key.
        tokens_in: Prompt token count, if known.
        tokens_out: Response token count, if known.
        ttl_hours: Entry lifetime in hours; None means it never expires.

    Returns:
        ``{"status": "exists", "hash": ...}`` for a duplicate, otherwise
        ``{"status": "stored", "hash": ..., "tokens_stored": ...}``.
    """
    from models import Cache

    # Deterministic dedup key over the (prompt, model) pair.
    key_material = json.dumps({"prompt": prompt, "model": model}, sort_keys=True)
    digest = hashlib.sha256(key_material.encode()).hexdigest()

    # Bail out early when this exact request is already cached.
    lookup = await db.execute(select(Cache).where(Cache.hash == digest))
    if lookup.scalar_one_or_none() is not None:
        return {"status": "exists", "hash": digest}

    expires_at = (
        datetime.now() + timedelta(hours=ttl_hours) if ttl_hours else None
    )
    db.add(
        Cache(
            hash=digest,
            response=response,
            model=model,
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            expires_at=expires_at,
        )
    )
    await db.commit()

    return {
        "status": "stored",
        "hash": digest,
        "tokens_stored": (tokens_in or 0) + (tokens_out or 0),
    }
async def get_cache_stats(db: AsyncSession) -> Dict:
    """
    Summarize the cache: entry counts, tokens held, and models seen.

    Args:
        db: Async database session.

    Returns:
        Dict with total/valid entry counts, total tokens across valid
        entries, and the distinct models present in any entry.
    """
    from models import Cache

    rows = (await db.execute(select(Cache))).scalars().all()
    cutoff = datetime.now()
    # "Valid" = no expiry set, or expiry still in the future.
    live = [row for row in rows if row.expires_at is None or row.expires_at > cutoff]
    token_total = sum((row.tokens_in or 0) + (row.tokens_out or 0) for row in live)
    return {
        "total_entries": len(rows),
        "valid_entries": len(live),
        "total_tokens_stored": token_total,
        "models_used": list({row.model for row in rows if row.model}),
    }
async def clear_old_cache(
    db: AsyncSession,
    older_than_hours: int = 168
) -> int:
    """
    Delete cache entries created before the age cutoff.

    Issues a single bulk DELETE instead of loading every stale row and
    deleting it one round trip at a time, as the original did.

    Args:
        db: Async database session.
        older_than_hours: Entries older than this many hours are removed.

    Returns:
        Number of rows deleted (0 when the driver does not report a count).
    """
    from models import Cache

    cutoff = datetime.now() - timedelta(hours=older_than_hours)
    result = await db.execute(delete(Cache).where(Cache.created_at < cutoff))
    await db.commit()
    # rowcount is populated for DELETE on the common DBAPI drivers; some
    # report -1 when unknown, so clamp to a non-negative count.
    return max(result.rowcount or 0, 0)