Major refactor: remove semantic cache, add config, auth, improve RAG performance, fix tags JSON

Lukas Parsons 2026-03-22 22:32:44 -04:00
parent 62c875c9a6
commit b8edf40010
11 changed files with 533 additions and 443 deletions

Dockerfile

@@ -6,6 +6,8 @@ COPY requirements.txt .
 RUN pip install --no-cache-dir -r requirements.txt
 COPY . .
+RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app
+USER appuser
 EXPOSE 8080

compression.py

@@ -1,10 +1,18 @@
 """
-Prompt compression - summarizes conversation history to reduce tokens.
-Uses a small local model (no API calls) to compress old turns.
+Conversation compression - summarizes old turns to save tokens.
+Supports multiple strategies: extractive summarization or Ollama LLM.
 """
 from typing import List, Dict
+import logging
 import tiktoken
+import httpx
+from sumy.parsers.plaintext import PlaintextParser
+from sumy.nlp.tokenizers import Tokenizer
+from sumy.summarizers.lsa import LsaSummarizer
+import asyncio
+
+logger = logging.getLogger(__name__)
 
 ENCODING = tiktoken.get_encoding("cl100k_base")
@@ -14,18 +22,85 @@ def count_tokens(text: str) -> int:
     return len(ENCODING.encode(text))
 
-def compress_conversation(
+def truncate_tool_output(output: str, max_tokens: int = 200) -> str:
+    """Truncate tool outputs to save tokens"""
+    tokens = ENCODING.encode(output)
+    if len(tokens) <= max_tokens:
+        return output
+    truncated = ENCODING.decode(tokens[:max_tokens])
+    return f"{truncated}... [truncated, {len(tokens) - max_tokens} tokens omitted]"
+
+def extractive_summarize(text: str, sentences_count: int = 3) -> str:
+    """
+    Simple extractive summarization using the LSA algorithm.
+    Picks the most important sentences from the text.
+    No external API calls, fast and deterministic.
+    """
+    try:
+        parser = PlaintextParser.from_string(text, Tokenizer("english"))
+        summarizer = LsaSummarizer()
+        summary_sentences = summarizer(parser.document, sentences_count)
+        return " ".join(str(sentence) for sentence in summary_sentences)
+    except Exception:
+        # Fallback: truncate to first few sentences
+        sentences = text.split('. ')[:3]
+        return '. '.join(sentences) + '.'
+
+async def ollama_summarize(text: str, model: str = "phi3:mini", url: str = "http://localhost:11434") -> str:
+    """
+    Summarize using the Ollama API.
+    Requires Ollama running with the specified model pulled.
+    """
+    try:
+        async with httpx.AsyncClient(timeout=30.0) as client:
+            response = await client.post(
+                f"{url}/api/generate",
+                json={
+                    "model": model,
+                    "prompt": f"Summarize the following conversation in 2-3 sentences, focusing on key decisions and conclusions:\n\n{text}",
+                    "stream": False,
+                    "options": {
+                        "num_predict": 200
+                    }
+                }
+            )
+            response.raise_for_status()
+            result = response.json()
+            return result.get("response", "").strip()
+    except Exception:
+        # Fallback to extractive on any error
+        return extractive_summarize(text, sentences_count=3)
+
+async def compress_conversation(
     messages: List[Dict],
     max_tokens: int = 2000,
-    keep_last_n: int = 3
+    keep_last_n: int = 3,
+    strategy: str = "extractive",
+    ollama_model: str = "phi3:mini",
+    ollama_url: str = "http://localhost:11434"
 ) -> List[Dict]:
     """
     Compress conversation history:
     - Keep last N exchanges in full
-    - Summarize everything before into a single system message
+    - Summarize everything before using the configured strategy
+
+    Args:
+        messages: Full conversation history
+        max_tokens: Target token budget
+        keep_last_n: Number of recent exchanges to keep uncompressed
+        strategy: "extractive", "ollama", or "none"
+        ollama_model: Model to use if strategy is "ollama"
+        ollama_url: Ollama API endpoint
 
     Returns compressed message list.
     """
+    if strategy == "none":
+        return messages
+
     if len(messages) <= keep_last_n * 2:  # *2 for user/assistant pairs
         return messages
@@ -41,72 +116,42 @@ def compress_conversation(
     recent = convo_messages[-keep_last_n * 2:]
     old = convo_messages[:-keep_last_n * 2]
 
-    # Summarize old conversation
-    summary = _summarize_turns(old)
+    # Create text to summarize from old turns
+    old_text = "\n".join([f"{m['role']}: {m['content']}" for m in old])
+
+    # Summarize using selected strategy
+    summary = None
+    if strategy == "ollama":
+        try:
+            summary = await ollama_summarize(old_text, ollama_model, ollama_url)
+        except Exception as e:
+            logger.warning(f"Ollama summarization failed: {e}, falling back to extractive")
+            summary = extractive_summarize(old_text, sentences_count=3)
+    else:
+        # Extractive is synchronous but fast; run in thread pool to avoid blocking
+        loop = asyncio.get_event_loop()
+        summary = await loop.run_in_executor(None, lambda: extractive_summarize(old_text, 3))
 
     # Build compressed messages
     compressed = []
     if system_msg:
         compressed.append(system_msg)
 
-    # Add summary as a user message with context
+    # Add summary as a user message with clear demarcation
     compressed.append({
         "role": "user",
-        "content": f"[PREVIOUS CONVERSATION SUMMARY]\n{summary}\n[/PREVIOUS CONVERSATION SUMMARY]\n\n---\n\nConversation continues below:"
+        "content": f"[CONVERSATION SUMMARY]\n{summary}\n[/CONVERSATION SUMMARY]\n\n---\n\nRecent conversation (most relevant):"
     })
     compressed.extend(recent)
 
-    # Verify we're under limit
+    # Verify we're under limit, if not, drop old more aggressively
     total_tokens = sum(count_tokens(m.get("content", "")) for m in compressed)
-    if total_tokens > max_tokens:
-        # Aggressive compression - keep only last exchange
-        compressed = compressed[-2:]
+    if total_tokens > max_tokens and len(compressed) > 2:
+        # Keep only system + last exchange
+        if system_msg:
+            compressed = [system_msg] + recent[-2:]
+        else:
+            compressed = recent[-2:]
 
     return compressed
 
-def _summarize_turns(messages: List[Dict]) -> str:
-    """
-    Create a brief summary of conversation turns.
-    In production, call a small local model here.
-    For now, extract key decisions and topics.
-    """
-    topics = []
-    decisions = []
-
-    for msg in messages:
-        content = msg.get("content", "")
-
-        # Extract topics from user messages
-        if msg.get("role") == "user":
-            # Simple keyword extraction (replace with LLM summary)
-            if "docker" in content.lower():
-                topics.append("Docker configuration")
-            if "server" in content.lower():
-                topics.append("Server setup")
-            if "config" in content.lower():
-                topics.append("Configuration")
-
-        # Extract decisions from assistant messages
-        if msg.get("role") == "assistant":
-            if "we decided" in content.lower() or "I'll use" in content.lower():
-                decisions.append(content[:200])
-
-    summary_parts = []
-    if topics:
-        summary_parts.append(f"Topics discussed: {', '.join(set(topics))}")
-    if decisions:
-        summary_parts.append(f"Decisions made: {'; '.join(decisions[:3])}")
-
-    return "\n".join(summary_parts) if summary_parts else "Previous conversation covered various topics."
-
-def truncate_tool_output(output: str, max_tokens: int = 200) -> str:
-    """Truncate tool outputs to save tokens"""
-    tokens = ENCODING.encode(output)
-    if len(tokens) <= max_tokens:
-        return output
-    truncated = ENCODING.decode(tokens[:max_tokens])
-    return f"{truncated}... [truncated, {len(tokens) - max_tokens} tokens omitted]"

config.py Normal file

@@ -0,0 +1,144 @@
+"""
+Configuration management for AI Skills API.
+Supports a YAML config file with sensible defaults.
+Priority: env vars > config file > defaults
+"""
+import os
+from dataclasses import dataclass, field
+from typing import List, Optional
+import yaml
+
+@dataclass
+class RAGConfig:
+    """RAG-specific configuration"""
+    max_skills: int = 3
+    max_conventions: int = 2
+    max_snippets: int = 2
+    min_skill_score: float = 0.3
+    min_snippet_score: float = 0.25
+    embedding_model: str = "all-MiniLM-L6-v2"
+
+@dataclass
+class CompressionConfig:
+    """Compression configuration"""
+    enabled: bool = True
+    strategy: str = "extractive"  # "extractive", "ollama", or "none"
+    keep_last_n: int = 3
+    max_tokens: int = 2000
+    ollama_model: str = "phi3:mini"
+    ollama_url: str = "http://localhost:11434"
+
+@dataclass
+class AuthConfig:
+    """Authentication configuration"""
+    enabled: bool = False  # Set to True to require API keys
+    api_key: str = "change-me-in-production"  # Single shared key for simplicity
+    header_name: str = "X-API-Key"
+
+@dataclass
+class LoggingConfig:
+    """Logging configuration"""
+    level: str = "INFO"
+    format: str = "json"  # "json" or "text"
+
+@dataclass
+class Config:
+    """Main configuration"""
+    host: str = "0.0.0.0"
+    port: int = 8675
+    database_url: str = "sqlite+aiosqlite:///./ai.db"
+    rag: RAGConfig = field(default_factory=RAGConfig)
+    compression: CompressionConfig = field(default_factory=CompressionConfig)
+    auth: AuthConfig = field(default_factory=AuthConfig)
+    logging: LoggingConfig = field(default_factory=LoggingConfig)
+    cors_origins: List[str] = field(default_factory=lambda: ["*"])
+
+def _env_override(prefix: str, key: str, default):
+    """Check for an environment variable named {PREFIX}_{KEY}, coercing it to the default's type"""
+    env_key = f"{prefix}_{key}".upper()
+    value = os.getenv(env_key)
+    if value is not None:
+        # Convert to appropriate type
+        if isinstance(default, bool):
+            return value.lower() in ('true', '1', 'yes', 'on')
+        elif isinstance(default, int):
+            return int(value)
+        elif isinstance(default, float):
+            return float(value)
+        elif isinstance(default, list):
+            return value.split(',')
+        else:
+            return value
+    return default
+
+def load_config(config_path: str = "/app/config.yaml") -> Config:
+    """
+    Load configuration from YAML file with environment variable overrides.
+    Priority: env vars > config file > defaults.
+    """
+    # Load from the first config file that exists
+    locations = [config_path, "/app/config.yaml", "./config.yaml"]
+    file_data = {}
+    for path in locations:
+        if os.path.exists(path):
+            with open(path, 'r') as f:
+                file_data = yaml.safe_load(f) or {}
+            break
+
+    config_dict = {}
+
+    # Top-level settings: env vars > file values > defaults
+    config_dict['host'] = _env_override('APP', 'HOST', file_data.get('host', '0.0.0.0'))
+    config_dict['port'] = _env_override('APP', 'PORT', file_data.get('port', 8675))
+    config_dict['database_url'] = os.getenv('DATABASE_URL', file_data.get('database_url', 'sqlite+aiosqlite:///./ai.db'))
+    cors_env = os.getenv('CORS_ORIGINS')
+    config_dict['cors_origins'] = cors_env.split(',') if cors_env else file_data.get('cors_origins', ["*"])
+
+    # Nested configs, built as dataclass instances so attribute access works downstream
+    rag_dict = file_data.get('rag', {}) or {}
+    config_dict['rag'] = RAGConfig(
+        max_skills=_env_override('RAG', 'MAX_SKILLS', rag_dict.get('max_skills', 3)),
+        max_conventions=_env_override('RAG', 'MAX_CONVENTIONS', rag_dict.get('max_conventions', 2)),
+        max_snippets=_env_override('RAG', 'MAX_SNIPPETS', rag_dict.get('max_snippets', 2)),
+        min_skill_score=_env_override('RAG', 'MIN_SKILL_SCORE', rag_dict.get('min_skill_score', 0.3)),
+        min_snippet_score=_env_override('RAG', 'MIN_SNIPPET_SCORE', rag_dict.get('min_snippet_score', 0.25)),
+        embedding_model=_env_override('RAG', 'EMBEDDING_MODEL', rag_dict.get('embedding_model', 'all-MiniLM-L6-v2')),
+    )
+
+    compression_dict = file_data.get('compression', {}) or {}
+    config_dict['compression'] = CompressionConfig(
+        enabled=_env_override('COMPRESSION', 'ENABLED', compression_dict.get('enabled', True)),
+        strategy=_env_override('COMPRESSION', 'STRATEGY', compression_dict.get('strategy', 'extractive')),
+        keep_last_n=_env_override('COMPRESSION', 'KEEP_LAST_N', compression_dict.get('keep_last_n', 3)),
+        max_tokens=_env_override('COMPRESSION', 'MAX_TOKENS', compression_dict.get('max_tokens', 2000)),
+        ollama_model=_env_override('COMPRESSION', 'OLLAMA_MODEL', compression_dict.get('ollama_model', 'phi3:mini')),
+        ollama_url=_env_override('COMPRESSION', 'OLLAMA_URL', compression_dict.get('ollama_url', 'http://localhost:11434')),
+    )
+
+    auth_dict = file_data.get('auth', {}) or {}
+    config_dict['auth'] = AuthConfig(
+        enabled=_env_override('AUTH', 'ENABLED', auth_dict.get('enabled', False)),
+        api_key=os.getenv('API_KEY', auth_dict.get('api_key', 'change-me-in-production')),
+        header_name=_env_override('AUTH', 'HEADER_NAME', auth_dict.get('header_name', 'X-API-Key')),
+    )
+
+    logging_dict = file_data.get('logging', {}) or {}
+    config_dict['logging'] = LoggingConfig(
+        level=_env_override('LOGGING', 'LEVEL', logging_dict.get('level', 'INFO')),
+        format=_env_override('LOGGING', 'FORMAT', logging_dict.get('format', 'json')),
+    )
+
+    return Config(**config_dict)
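A quick sketch of the precedence rules in load_config (the values here are illustrative):

import os
from config import load_config

os.environ["RAG_MAX_SKILLS"] = "5"            # env var beats config.yaml
os.environ["COMPRESSION_STRATEGY"] = "ollama"

cfg = load_config("./config.yaml")            # falls back to defaults if no file exists
print(cfg.rag.max_skills)                     # -> 5
print(cfg.compression.strategy)               # -> "ollama"
print(cfg.auth.enabled)                       # -> False unless AUTH_ENABLED=true is set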

config.yaml Normal file

@@ -0,0 +1,38 @@
+# AI Skills API Configuration
+
+# Server settings
+host: "0.0.0.0"
+port: 8675
+database_url: "sqlite+aiosqlite:///./ai.db"
+
+# CORS origins (restrict in production)
+cors_origins: ["*"]
+
+# RAG (Retrieval Augmented Generation) settings
+rag:
+  max_skills: 3            # Number of skills to include in context
+  max_conventions: 2       # Number of conventions to include
+  max_snippets: 2          # Number of code snippets to include
+  min_skill_score: 0.3     # Minimum similarity threshold for skills (0-1)
+  min_snippet_score: 0.25  # Minimum similarity for snippets (0-1)
+  embedding_model: "all-MiniLM-L6-v2"  # Sentence transformer model
+
+# Compression settings
+compression:
+  enabled: true
+  strategy: "extractive"   # "extractive" (sumy), "ollama" (phi-3-mini), or "none"
+  keep_last_n: 3           # Number of recent exchanges to keep uncompressed
+  max_tokens: 2000         # Target token budget for conversation history
+  ollama_model: "phi3:mini"             # Only used if strategy is "ollama"
+  ollama_url: "http://localhost:11434"  # Ollama API endpoint
+
+# Authentication (set and forget - simple API key)
+auth:
+  enabled: false           # Set to true to require API key on all endpoints
+  api_key: "change-me-in-production"    # Change this if enabling auth
+  header_name: "X-API-Key"
+
+# Logging configuration
+logging:
+  level: "INFO"
+  format: "json"           # "json" for structured logs, "text" for human readable

docker-compose.yml

@@ -7,6 +7,7 @@ services:
       - DATABASE_URL=sqlite+aiosqlite:///./ai.db
     volumes:
       - ./data:/app/data
+      - ./config.yaml:/app/config.yaml:ro
     restart: unless-stopped
    healthcheck:
       test: ["CMD", "curl", "-f", "http://localhost:8080/health"]

main.py

@@ -1,37 +1,81 @@
-from fastapi import FastAPI, HTTPException, Depends, Query
+import logging
+from fastapi import FastAPI, HTTPException, Depends, Query, Request, Header
 from fastapi.middleware.cors import CORSMiddleware
-from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
-from sqlalchemy import select, func
+from sqlalchemy import select
 from sqlalchemy.exc import IntegrityError
-import hashlib
-import json
-import os
 from typing import Optional, List
+import sys
 
-from database import get_db, init_db
-from models import Skill, Snippet, Convention, Cache, Memory
+from config import load_config, Config
+from database import init_db as default_init_db
+from models import Skill, Snippet, Convention, Memory, Base
 from schemas import (
-    SkillBase, Skill,
-    SnippetBase, Snippet,
-    ConventionBase, Convention,
-    CacheStore, Cache as CacheSchema,
-    MemoryBase, Memory as MemorySchema,
-    ContextBundle, CacheLookup
+    SkillBase,
+    Skill as SkillSchema,
+    SnippetBase,
+    Snippet as SnippetSchema,
+    ConventionBase,
+    Convention as ConventionSchema,
+    MemoryBase,
+    Memory as MemorySchema,
+    ContextBundle
 )
-from semantic_cache import (
-    semantic_cache_lookup,
-    semantic_cache_store,
-    get_cache_stats,
-    clear_old_cache
-)
-from rag import build_context_bundle
+from rag import build_context_bundle, clear_cache as clear_rag_cache
 from compression import compress_conversation, count_tokens
 
-app = FastAPI(title="AI Skills API", description="Local infrastructure for AI context management")
+# Load configuration
+config = load_config()
+
+# Setup logging
+logging.basicConfig(
+    level=getattr(logging, config.logging.level.upper()),
+    format='%(message)s' if config.logging.format == "json" else '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[logging.StreamHandler(sys.stdout)]
+)
+logger = logging.getLogger("ai-skills-api")
+
+# API Key authentication dependency
+async def verify_api_key(
+    request: Request,
+    x_api_key: Optional[str] = Header(None, alias=config.auth.header_name)
+) -> None:
+    """Verify API key if auth is enabled"""
+    if config.auth.enabled:
+        if not x_api_key:
+            raise HTTPException(status_code=401, detail="Missing API key")
+        if x_api_key != config.auth.api_key:
+            raise HTTPException(status_code=403, detail="Invalid API key")
+
+# Database setup based on config
+if config.database_url != "sqlite+aiosqlite:///./ai.db":
+    # Custom DB URL - create new engine
+    engine = create_async_engine(config.database_url, echo=False)
+    AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
+
+    async def get_db() -> AsyncSession:
+        async with AsyncSessionLocal() as session:
+            yield session
+
+    async def init_db():
+        async with engine.begin() as conn:
+            await conn.run_sync(Base.metadata.create_all)
+else:
+    # Use default database setup from database.py
+    from database import get_db, init_db
+
+app = FastAPI(
+    title="AI Skills API",
+    description="Local infrastructure for AI context management"
+)
+
+# CORS configuration
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],
+    allow_origins=config.cors_origins,
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
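For reference, a minimal client-side sketch of an authenticated request against verify_api_key, assuming auth.enabled is true, the config.yaml default port, and the placeholder key shipped in the sample config:

import httpx

resp = httpx.get(
    "http://localhost:8675/skills",
    headers={"X-API-Key": "change-me-in-production"},  # default AuthConfig.header_name and placeholder key
)
resp.raise_for_status()  # 401 if the key is missing, 403 if it is wrong
print(resp.json())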
@@ -41,11 +85,15 @@ app.add_middleware(
 @app.on_event("startup")
 async def startup():
     await init_db()
+    logger.info("Application startup complete")
+    logger.info("RAG cache will initialize on first request")
+    if config.auth.enabled:
+        logger.info("API key authentication enabled")
 # ============== SKILLS ==============
 
-@app.get("/skills", response_model=list[Skill])
+@app.get("/skills", response_model=list[SkillSchema], dependencies=[Depends(verify_api_key)])
 async def list_skills(
     category: Optional[str] = None,
     db: AsyncSession = Depends(get_db)
@@ -57,7 +105,7 @@ async def list_skills(
     return result.scalars().all()
 
-@app.get("/skills/search")
+@app.get("/skills/search", dependencies=[Depends(verify_api_key)])
 async def search_skills(
     q: str,
     category: Optional[str] = None,
@@ -66,7 +114,7 @@ async def search_skills(
     query = select(Skill).where(
         (Skill.name.ilike(f"%{q}%")) |
         (Skill.content.ilike(f"%{q}%")) |
-        (Skill.tags.ilike(f"%{q}%"))
+        (Skill.tags.astext.ilike(f"%{q}%"))
     )
     if category:
         query = query.where(Skill.category == category)
@@ -74,7 +122,7 @@ async def search_skills(
     return result.scalars().all()
 
-@app.get("/skills/{skill_id}", response_model=Skill)
+@app.get("/skills/{skill_id}", response_model=SkillSchema, dependencies=[Depends(verify_api_key)])
 async def get_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
     result = await db.execute(select(Skill).where(Skill.id == skill_id))
     skill = result.scalar_one_or_none()
@@ -86,7 +134,7 @@ async def get_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
     return skill
 
-@app.post("/skills", response_model=Skill)
+@app.post("/skills", response_model=SkillSchema, dependencies=[Depends(verify_api_key)])
 async def create_skill(skill: SkillBase, db: AsyncSession = Depends(get_db)):
     db_skill = Skill(**skill.model_dump())
     db.add(db_skill)
@@ -99,7 +147,7 @@ async def create_skill(skill: SkillBase, db: AsyncSession = Depends(get_db)):
     return db_skill
 
-@app.put("/skills/{skill_id}", response_model=Skill)
+@app.put("/skills/{skill_id}", response_model=SkillSchema, dependencies=[Depends(verify_api_key)])
 async def update_skill(skill_id: str, skill: SkillBase, db: AsyncSession = Depends(get_db)):
     result = await db.execute(select(Skill).where(Skill.id == skill_id))
     db_skill = result.scalar_one_or_none()
@@ -114,7 +162,7 @@ async def update_skill(skill_id: str, skill: SkillBase, db: AsyncSession = Depen
     return db_skill
 
-@app.delete("/skills/{skill_id}")
+@app.delete("/skills/{skill_id}", dependencies=[Depends(verify_api_key)])
 async def delete_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
     result = await db.execute(select(Skill).where(Skill.id == skill_id))
     skill = result.scalar_one_or_none()
@@ -128,7 +176,7 @@ async def delete_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
 # ============== SNIPPETS ==============
 
-@app.get("/snippets", response_model=list[Snippet])
+@app.get("/snippets", response_model=list[SnippetSchema], dependencies=[Depends(verify_api_key)])
 async def list_snippets(
     category: Optional[str] = None,
     language: Optional[str] = None,
@@ -143,7 +191,7 @@ async def list_snippets(
     return result.scalars().all()
 
-@app.get("/snippets/{snippet_id}", response_model=Snippet)
+@app.get("/snippets/{snippet_id}", response_model=SnippetSchema, dependencies=[Depends(verify_api_key)])
 async def get_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
     result = await db.execute(select(Snippet).where(Snippet.id == snippet_id))
     snippet = result.scalar_one_or_none()
@@ -152,7 +200,7 @@ async def get_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
     return snippet
 
-@app.post("/snippets", response_model=Snippet)
+@app.post("/snippets", response_model=SnippetSchema, dependencies=[Depends(verify_api_key)])
 async def create_snippet(snippet: SnippetBase, db: AsyncSession = Depends(get_db)):
     db_snippet = Snippet(**snippet.model_dump())
     db.add(db_snippet)
@@ -165,7 +213,7 @@ async def create_snippet(snippet: SnippetBase, db: AsyncSession = Depends(get_db
     return db_snippet
 
-@app.delete("/snippets/{snippet_id}")
+@app.delete("/snippets/{snippet_id}", dependencies=[Depends(verify_api_key)])
 async def delete_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
     result = await db.execute(select(Snippet).where(Snippet.id == snippet_id))
     snippet = result.scalar_one_or_none()
@@ -179,7 +227,7 @@ async def delete_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
 # ============== CONVENTIONS ==============
 
-@app.get("/conventions", response_model=list[Convention])
+@app.get("/conventions", response_model=list[ConventionSchema], dependencies=[Depends(verify_api_key)])
 async def list_conventions(
     project: Optional[str] = None,
     db: AsyncSession = Depends(get_db)
@@ -191,7 +239,7 @@ async def list_conventions(
     return result.scalars().all()
 
-@app.get("/conventions/{convention_id}", response_model=Convention)
+@app.get("/conventions/{convention_id}", response_model=ConventionSchema, dependencies=[Depends(verify_api_key)])
 async def get_convention(convention_id: str, db: AsyncSession = Depends(get_db)):
     result = await db.execute(select(Convention).where(Convention.id == convention_id))
     convention = result.scalar_one_or_none()
@@ -200,7 +248,7 @@ async def get_convention(convention_id: str, db: AsyncSession = Depends(get_db))
     return convention
 
-@app.post("/conventions", response_model=Convention)
+@app.post("/conventions", response_model=ConventionSchema, dependencies=[Depends(verify_api_key)])
 async def create_convention(convention: ConventionBase, db: AsyncSession = Depends(get_db)):
     db_convention = Convention(**convention.model_dump())
     db.add(db_convention)
@@ -213,7 +261,7 @@ async def create_convention(convention: ConventionBase, db: AsyncSession = Depen
     return db_convention
 
-@app.put("/conventions/{convention_id}", response_model=Convention)
+@app.put("/conventions/{convention_id}", response_model=ConventionSchema, dependencies=[Depends(verify_api_key)])
 async def update_convention(convention_id: str, convention: ConventionBase, db: AsyncSession = Depends(get_db)):
     result = await db.execute(select(Convention).where(Convention.id == convention_id))
     db_convention = result.scalar_one_or_none()
@@ -228,7 +276,7 @@ async def update_convention(convention_id: str, convention: ConventionBase, db:
     return db_convention
 
-@app.delete("/conventions/{convention_id}")
+@app.delete("/conventions/{convention_id}", dependencies=[Depends(verify_api_key)])
 async def delete_convention(convention_id: str, db: AsyncSession = Depends(get_db)):
     result = await db.execute(select(Convention).where(Convention.id == convention_id))
     convention = result.scalar_one_or_none()
@@ -240,107 +288,9 @@ async def delete_convention(convention_id: str, db: AsyncSession = Depends(get_d
     return {"deleted": convention_id}
 
-# ============== CACHE ==============
-
-@app.post("/cache/lookup", response_model=Optional[CacheSchema])
-async def lookup_cache(lookup: CacheLookup, db: AsyncSession = Depends(get_db)):
-    """Exact hash-based cache lookup"""
-    prompt_hash = hashlib.sha256(
-        json.dumps({"prompt": lookup.prompt, "model": lookup.model}, sort_keys=True).encode()
-    ).hexdigest()
-
-    result = await db.execute(
-        select(Cache).where(
-            (Cache.hash == prompt_hash) &
-            ((Cache.expires_at == None) | (Cache.expires_at > func.now()))
-        )
-    )
-    return result.scalar_one_or_none()
-
-@app.post("/cache/semantic-lookup", response_model=dict)
-async def semantic_lookup(
-    prompt: str,
-    model: Optional[str] = None,
-    min_similarity: float = 0.85,
-    db: AsyncSession = Depends(get_db)
-):
-    """Semantic cache lookup - finds similar prompts"""
-    result = await semantic_cache_lookup(
-        prompt, db, model=model, min_similarity=min_similarity
-    )
-    if result:
-        return {"hit": True, **result}
-    return {"hit": False}
-
-@app.post("/cache/store", response_model=CacheSchema)
-async def store_cache(cache: CacheStore, db: AsyncSession = Depends(get_db)):
-    """Store in exact-match cache"""
-    prompt_hash = hashlib.sha256(
-        json.dumps({"prompt": cache.response, "model": cache.model}, sort_keys=True).encode()
-    ).hexdigest()
-
-    db_cache = Cache(
-        hash=prompt_hash,
-        response=cache.response,
-        model=cache.model,
-        tokens_in=cache.tokens_in,
-        tokens_out=cache.tokens_out,
-        expires_at=cache.expires_at
-    )
-    db.add(db_cache)
-    await db.commit()
-    await db.refresh(db_cache)
-    return db_cache
-
-@app.post("/cache/semantic-store", response_model=dict)
-async def semantic_store(
-    prompt: str,
-    response: str,
-    model: Optional[str] = None,
-    tokens_in: Optional[int] = None,
-    tokens_out: Optional[int] = None,
-    db: AsyncSession = Depends(get_db)
-):
-    """Store in semantic cache"""
-    return await semantic_cache_store(
-        prompt, response, db, model, tokens_in, tokens_out
-    )
-
-@app.delete("/cache/{cache_hash}")
-async def delete_cache(cache_hash: str, db: AsyncSession = Depends(get_db)):
-    result = await db.execute(select(Cache).where(Cache.hash == cache_hash))
-    cache = result.scalar_one_or_none()
-    if not cache:
-        raise HTTPException(status_code=404, detail="Cache entry not found")
-    await db.delete(cache)
-    await db.commit()
-    return {"deleted": cache_hash}
-
-@app.get("/cache/stats")
-async def cache_stats_endpoint(db: AsyncSession = Depends(get_db)):
-    """Get cache statistics"""
-    return await get_cache_stats(db)
-
-@app.post("/cache/clear-old")
-async def clear_old(
-    older_than_hours: int = 168,
-    db: AsyncSession = Depends(get_db)
-):
-    """Clear cache entries older than threshold"""
-    deleted = await clear_old_cache(db, older_than_hours)
-    return {"deleted": deleted}
-
 # ============== MEMORY ==============
 
-@app.get("/memory", response_model=list[MemorySchema])
+@app.get("/memory", response_model=list[MemorySchema], dependencies=[Depends(verify_api_key)])
 async def list_memory(
     project: Optional[str] = None,
     db: AsyncSession = Depends(get_db)
@@ -352,7 +302,7 @@ async def list_memory(
     return result.scalars().all()
 
-@app.get("/memory/{memory_id}", response_model=MemorySchema)
+@app.get("/memory/{memory_id}", response_model=MemorySchema, dependencies=[Depends(verify_api_key)])
 async def get_memory(memory_id: str, db: AsyncSession = Depends(get_db)):
     result = await db.execute(select(Memory).where(Memory.id == memory_id))
     memory = result.scalar_one_or_none()
@@ -361,7 +311,7 @@ async def get_memory(memory_id: str, db: AsyncSession = Depends(get_db)):
     return memory
 
-@app.post("/memory", response_model=MemorySchema)
+@app.post("/memory", response_model=MemorySchema, dependencies=[Depends(verify_api_key)])
 async def create_memory(memory: MemoryBase, db: AsyncSession = Depends(get_db)):
     db_memory = Memory(**memory.model_dump())
     db.add(db_memory)
@@ -374,7 +324,7 @@ async def create_memory(memory: MemoryBase, db: AsyncSession = Depends(get_db)):
     return db_memory
 
-@app.put("/memory/{memory_id}", response_model=MemorySchema)
+@app.put("/memory/{memory_id}", response_model=MemorySchema, dependencies=[Depends(verify_api_key)])
 async def update_memory(memory_id: str, memory: MemoryBase, db: AsyncSession = Depends(get_db)):
     result = await db.execute(select(Memory).where(Memory.id == memory_id))
     db_memory = result.scalar_one_or_none()
@@ -389,7 +339,7 @@ async def update_memory(memory_id: str, memory: MemoryBase, db: AsyncSession = D
     return db_memory
 
-@app.delete("/memory/{memory_id}")
+@app.delete("/memory/{memory_id}", dependencies=[Depends(verify_api_key)])
 async def delete_memory(memory_id: str, db: AsyncSession = Depends(get_db)):
     result = await db.execute(select(Memory).where(Memory.id == memory_id))
     memory = result.scalar_one_or_none()
@@ -403,7 +353,7 @@ async def delete_memory(memory_id: str, db: AsyncSession = Depends(get_db)):
 # ============== CONTEXT BUNDLE ==============
 
-@app.get("/context", response_model=ContextBundle)
+@app.get("/context", response_model=ContextBundle, dependencies=[Depends(verify_api_key)])
 async def get_context(
     project: Optional[str] = None,
     skills: Optional[str] = Query(None, description="Comma-separated skill IDs to include"),
@@ -438,19 +388,29 @@ async def get_context(
     )
 
-@app.get("/context/rag")
+@app.get("/context/rag", dependencies=[Depends(verify_api_key)])
 async def get_context_rag(
     query: str,
     project: Optional[str] = None,
-    max_skills: int = 3,
-    max_conventions: int = 2,
-    max_snippets: int = 2,
+    max_skills: Optional[int] = None,
+    max_conventions: Optional[int] = None,
+    max_snippets: Optional[int] = None,
     db: AsyncSession = Depends(get_db)
 ):
     """
     RAG-based context selection - returns ONLY relevant items.
     Uses semantic search to find top K most relevant skills/snippets.
+    Uses config defaults if parameters not provided.
     """
+    # Use config defaults if not specified
+    if max_skills is None:
+        max_skills = config.rag.max_skills
+    if max_conventions is None:
+        max_conventions = config.rag.max_conventions
+    if max_snippets is None:
+        max_snippets = config.rag.max_snippets
+
     bundle = await build_context_bundle(
         query, db, project,
         max_skills=max_skills,
@@ -460,17 +420,28 @@ async def get_context_rag(
     return bundle
@app.post("/compress") @app.post("/compress", dependencies=[Depends(verify_api_key)])
async def compress_messages( async def compress_messages_endpoint(
messages: List[dict], messages: List[dict],
keep_last_n: int = 3, keep_last_n: Optional[int] = None,
max_tokens: int = 2000 max_tokens: Optional[int] = None
): ):
""" """
Compress conversation history. Compress conversation history.
Keeps last N exchanges in full, summarizes everything before. Uses config strategy (extractive or ollama).
""" """
compressed = compress_conversation(messages, max_tokens, keep_last_n) # Use config defaults if not specified
keep_last_n = keep_last_n or config.compression.keep_last_n
max_tokens = max_tokens or config.compression.max_tokens
compressed = await compress_conversation(
messages,
max_tokens=max_tokens,
keep_last_n=keep_last_n,
strategy=config.compression.strategy,
ollama_model=config.compression.ollama_model,
ollama_url=config.compression.ollama_url
)
original_tokens = sum(count_tokens(m.get("content", "")) for m in messages) original_tokens = sum(count_tokens(m.get("content", "")) for m in messages)
compressed_tokens = sum(count_tokens(m.get("content", "")) for m in compressed) compressed_tokens = sum(count_tokens(m.get("content", "")) for m in compressed)
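A hypothetical call to the reworked endpoint; keep_last_n and max_tokens fall back to the config values when omitted:

import httpx

history = [
    {"role": "user", "content": "Long conversation history goes here..."},
    {"role": "assistant", "content": "...and the replies."},
]
resp = httpx.post("http://localhost:8675/compress", json=history)
print(resp.json())  # reports original vs. compressed token counts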
@@ -483,7 +454,7 @@ async def compress_messages(
     }
 
-@app.get("/tokens/count")
+@app.get("/tokens/count", dependencies=[Depends(verify_api_key)])
 async def count_tokens_endpoint(text: str):
     """Count tokens in text"""
     return {"tokens": count_tokens(text)}
@@ -492,3 +463,36 @@ async def count_tokens_endpoint(text: str):
 
 @app.get("/health")
 async def health():
     return {"status": "healthy"}
+
+@app.get("/config")
+async def get_config():
+    """Return current configuration (non-sensitive)"""
+    return {
+        "rag": {
+            "max_skills": config.rag.max_skills,
+            "max_conventions": config.rag.max_conventions,
+            "max_snippets": config.rag.max_snippets,
+            "min_skill_score": config.rag.min_skill_score,
+            "min_snippet_score": config.rag.min_snippet_score,
+            "embedding_model": config.rag.embedding_model
+        },
+        "compression": {
+            "enabled": config.compression.enabled,
+            "strategy": config.compression.strategy,
+            "keep_last_n": config.compression.keep_last_n,
+            "max_tokens": config.compression.max_tokens,
+            "ollama_model": config.compression.ollama_model if config.compression.strategy == "ollama" else None
+        },
+        "auth": {
+            "enabled": config.auth.enabled
+        }
+    }
+
+@app.post("/admin/clear-cache", dependencies=[Depends(verify_api_key)])
+async def clear_cache():
+    """Clear RAG embedding cache (forces reload on next request)"""
+    clear_rag_cache()
+    return {"status": "cache cleared"}

models.py

@ -1,4 +1,4 @@
from sqlalchemy import Column, String, Text, DateTime, Boolean, Integer, ForeignKey from sqlalchemy import Column, String, Text, DateTime, Boolean, Integer, ForeignKey, JSON
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from sqlalchemy.sql import func from sqlalchemy.sql import func
from database import Base from database import Base
@ -12,7 +12,7 @@ class Skill(Base):
description = Column(Text) description = Column(Text)
category = Column(String) category = Column(String)
content = Column(Text, nullable=False) content = Column(Text, nullable=False)
tags = Column(String) tags = Column(JSON, default=list)
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
usage_count = Column(Integer, default=0) usage_count = Column(Integer, default=0)
@ -26,7 +26,7 @@ class Snippet(Base):
language = Column(String) language = Column(String)
content = Column(Text, nullable=False) content = Column(Text, nullable=False)
category = Column(String) category = Column(String)
tags = Column(String) tags = Column(JSON, default=list)
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
@ -41,17 +41,6 @@ class Convention(Base):
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
class Cache(Base):
__tablename__ = "cache"
hash = Column(String, primary_key=True)
response = Column(Text, nullable=False)
model = Column(String)
tokens_in = Column(Integer)
tokens_out = Column(Integer)
created_at = Column(DateTime(timezone=True), server_default=func.now())
expires_at = Column(DateTime(timezone=True))
class Memory(Base): class Memory(Base):
__tablename__ = "memory" __tablename__ = "memory"

rag.py

@@ -1,6 +1,6 @@
 """
 RAG-based context selection using local embeddings.
-No external API calls - runs entirely on your home server.
+Optimized with in-memory embedding cache to avoid recomputation.
 """
 
 import numpy as np
@@ -8,12 +8,21 @@ from sentence_transformers import SentenceTransformer
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 from typing import List, Dict, Optional
-import os
+import asyncio
+from threading import Lock
 
 # Small, fast model - ~100MB, runs on CPU
 MODEL_NAME = "all-MiniLM-L6-v2"
 _model: Optional[SentenceTransformer] = None
 
+# In-memory embedding cache
+_embedding_cache: Dict[str, Dict] = {
+    "skills": {},
+    "snippets": {},
+    "initialized": False
+}
+_cache_lock = Lock()
 def get_model() -> SentenceTransformer:
     """Lazy-load the embedding model"""
@@ -34,6 +43,51 @@ def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
     return float(np.dot(a, b))
 
+def _get_skill_text(skill) -> str:
+    """Generate searchable text for a skill"""
+    return f"{skill.name} {skill.description or ''} {skill.content[:500]}"
+
+def _get_snippet_text(snippet) -> str:
+    """Generate searchable text for a snippet"""
+    return f"{snippet.name} {snippet.content}"
+
+async def ensure_cache_initialized(db: AsyncSession):
+    """Load all skills and snippets into memory with their embeddings"""
+    from models import Skill, Snippet  # local import, matching the pattern used elsewhere in this module
+    global _embedding_cache
+
+    with _cache_lock:
+        if _embedding_cache["initialized"]:
+            return
+
+        # Load skills
+        result = await db.execute(select(Skill))
+        skills = result.scalars().all()
+        for skill in skills:
+            text = _get_skill_text(skill)
+            embedding = embed_text(text)
+            _embedding_cache["skills"][skill.id] = {
+                "embedding": embedding,
+                "skill": skill
+            }
+
+        # Load snippets
+        result = await db.execute(select(Snippet))
+        snippets = result.scalars().all()
+        for snippet in snippets:
+            text = _get_snippet_text(snippet)
+            embedding = embed_text(text)
+            _embedding_cache["snippets"][snippet.id] = {
+                "embedding": embedding,
+                "snippet": snippet
+            }
+
+        _embedding_cache["initialized"] = True
 async def select_relevant_skills(
     query: str,
     db: AsyncSession,
@@ -42,27 +96,24 @@ async def select_relevant_skills(
 ) -> List[Dict]:
     """
     Find most relevant skills using semantic search.
-    Only returns skills above minimum similarity threshold.
+    Uses cached embeddings for O(1) retrieval after initial load.
     """
     from models import Skill
 
-    # Get all skills (for small datasets, load all - fine for <1000 items)
-    result = await db.execute(select(Skill))
-    skills = result.scalars().all()
-
-    if not skills:
+    await ensure_cache_initialized(db)
+
+    if not _embedding_cache["skills"]:
         return []
 
     # Generate query embedding
     query_embedding = embed_text(query)
 
-    # Score each skill
+    # Score each cached skill
     scored = []
-    for skill in skills:
-        # Use cached embedding if available, else compute
-        skill_text = f"{skill.name} {skill.description or ''} {skill.content[:500]}"
-        skill_embedding = embed_text(skill_text)
-        score = cosine_similarity(query_embedding, skill_embedding)
+    for skill_id, cached in _embedding_cache["skills"].items():
+        skill = cached["skill"]
+        embedding = cached["embedding"]
+        score = cosine_similarity(query_embedding, embedding)
         if score >= min_score:
             scored.append((score, skill))
@@ -90,6 +141,7 @@ async def select_relevant_conventions(
     """
     Get conventions for a project.
     Exact match on project_path, plus fuzzy match on parent paths.
+    (No embeddings - small dataset, exact matching is fine)
     """
     from models import Convention
@@ -128,25 +180,24 @@ async def select_relevant_snippets(
     top_k: int = 2,
     language: Optional[str] = None
 ) -> List[Dict]:
-    """Find relevant code snippets"""
+    """Find relevant code snippets using cached embeddings"""
     from models import Snippet
 
-    result = await db.execute(select(Snippet))
-    snippets = result.scalars().all()
-
-    if not snippets:
+    await ensure_cache_initialized(db)
+
+    if not _embedding_cache["snippets"]:
         return []
 
     query_embedding = embed_text(query)
 
     scored = []
-    for snippet in snippets:
+    for snippet_id, cached in _embedding_cache["snippets"].items():
+        snippet = cached["snippet"]
         if language and snippet.language != language:
             continue
-        snippet_text = f"{snippet.name} {snippet.content}"
-        snippet_embedding = embed_text(snippet_text)
-        score = cosine_similarity(query_embedding, snippet_embedding)
+        embedding = cached["embedding"]
+        score = cosine_similarity(query_embedding, embedding)
         if score >= 0.25:  # Lower threshold for snippets
             scored.append((score, snippet))
@@ -202,4 +253,11 @@ async def build_context_bundle(
     }
 
-import asyncio
+def clear_cache():
+    """Clear embedding cache (useful for testing)"""
+    global _embedding_cache
+    _embedding_cache = {
+        "skills": {},
+        "snippets": {},
+        "initialized": False
+    }
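One practical consequence of the cache: rows added after startup stay invisible to semantic search until the cache is rebuilt, which is what the /admin/clear-cache endpoint in main.py is for. A hypothetical in-process refresh helper built from the functions above:

from rag import clear_cache, ensure_cache_initialized

async def refresh_embeddings(db):
    """Drop stale embeddings and recompute them for all current rows."""
    clear_cache()
    await ensure_cache_initialized(db)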

requirements.txt

@ -7,3 +7,6 @@ aiosqlite==0.19.0
sentence-transformers==2.3.1 sentence-transformers==2.3.1
numpy==1.26.3 numpy==1.26.3
tiktoken==0.5.2 tiktoken==0.5.2
pyyaml==6.0
httpx==0.27.0
sumy==0.11.0

schemas.py

@ -52,26 +52,6 @@ class Convention(ConventionBase):
from_attributes = True from_attributes = True
class CacheBase(BaseModel):
hash: str
response: str
model: Optional[str] = None
tokens_in: Optional[int] = None
tokens_out: Optional[int] = None
expires_at: Optional[datetime] = None
class CacheStore(CacheBase):
pass
class Cache(CacheBase):
created_at: datetime
class Config:
from_attributes = True
class MemoryBase(BaseModel): class MemoryBase(BaseModel):
id: str id: str
project: str project: str
@ -92,8 +72,3 @@ class ContextBundle(BaseModel):
snippets: List[Snippet] snippets: List[Snippet]
conventions: List[Convention] conventions: List[Convention]
memories: List[Memory] memories: List[Memory]
class CacheLookup(BaseModel):
prompt: str
model: Optional[str] = None

semantic_cache.py

@ -1,169 +0,0 @@
"""
Semantic cache - matches similar prompts, not just exact hashes.
Uses embeddings to find similar questions and return cached responses.
"""
import numpy as np
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import datetime, timedelta
from typing import Optional, Dict, List
import json
import hashlib
from rag import embed_text, cosine_similarity
from compression import count_tokens
async def semantic_cache_lookup(
prompt: str,
db: AsyncSession,
model: Optional[str] = None,
min_similarity: float = 0.85,
max_age_hours: int = 168 # 1 week
) -> Optional[Dict]:
"""
Find cached responses for semantically similar prompts.
Returns cached response if similarity >= threshold and not expired.
"""
from models import Cache
# Generate embedding for the query
query_embedding = embed_text(prompt)
# Get non-expired cache entries
expiry = datetime.now() - timedelta(hours=max_age_hours)
result = await db.execute(
select(Cache)
.where(
(Cache.expires_at == None) | (Cache.expires_at > datetime.now())
)
.where(Cache.created_at > expiry)
)
cache_entries = result.scalars().all()
if not cache_entries:
return None
# Score each cache entry
best_match = None
best_score = 0
for entry in cache_entries:
# Skip if model doesn't match (optional)
if model and entry.model and entry.model != model:
continue
# Compute similarity
entry_embedding = embed_text(entry.response) # Or store prompt embedding
score = cosine_similarity(query_embedding, entry_embedding)
if score >= min_similarity and score > best_score:
best_score = score
best_match = entry
if best_match:
return {
"response": best_match.response,
"similarity": best_score,
"model": best_match.model,
"tokens_saved": (best_match.tokens_in or 0) + (best_match.tokens_out or 0),
"cached_at": best_match.created_at
}
return None
async def semantic_cache_store(
prompt: str,
response: str,
db: AsyncSession,
model: Optional[str] = None,
tokens_in: Optional[int] = None,
tokens_out: Optional[int] = None,
ttl_hours: Optional[int] = None
) -> Dict:
"""
Store response in cache with embedding for semantic matching.
"""
from models import Cache
# Generate hash for deduplication
prompt_hash = hashlib.sha256(
json.dumps({"prompt": prompt, "model": model}, sort_keys=True).encode()
).hexdigest()
# Check if exact match already exists
existing = await db.execute(
select(Cache).where(Cache.hash == prompt_hash)
)
if existing.scalar_one_or_none():
return {"status": "exists", "hash": prompt_hash}
# Create new entry
expires_at = None
if ttl_hours:
expires_at = datetime.now() + timedelta(hours=ttl_hours)
new_entry = Cache(
hash=prompt_hash,
response=response,
model=model,
tokens_in=tokens_in,
tokens_out=tokens_out,
expires_at=expires_at
)
db.add(new_entry)
await db.commit()
return {
"status": "stored",
"hash": prompt_hash,
"tokens_stored": (tokens_in or 0) + (tokens_out or 0)
}
async def get_cache_stats(db: AsyncSession) -> Dict:
"""Get cache statistics"""
from models import Cache
result = await db.execute(select(Cache))
entries = result.scalars().all()
now = datetime.now()
valid_entries = [
e for e in entries
if e.expires_at is None or e.expires_at > now
]
return {
"total_entries": len(entries),
"valid_entries": len(valid_entries),
"total_tokens_stored": sum(
(e.tokens_in or 0) + (e.tokens_out or 0) for e in valid_entries
),
"models_used": list(set(e.model for e in entries if e.model))
}
async def clear_old_cache(
db: AsyncSession,
older_than_hours: int = 168
) -> int:
"""Delete cache entries older than threshold"""
from models import Cache
cutoff = datetime.now() - timedelta(hours=older_than_hours)
result = await db.execute(
select(Cache).where(Cache.created_at < cutoff)
)
old_entries = result.scalars().all()
for entry in old_entries:
await db.delete(entry)
await db.commit()
return len(old_entries)