ai-skills-api/main.py
Lukas Parsons 82fd963577 Add token-saving patterns: semantic cache, RAG, compression
- semantic_cache.py: Semantic similarity matching for cache hits
- rag.py: RAG-based context selection with local embeddings
- compression.py: Conversation history summarization
- New endpoints: /cache/semantic-lookup, /cache/semantic-store, /context/rag, /compress
- Uses sentence-transformers (all-MiniLM-L6-v2) - no external API calls
- No vector DB needed - cosine similarity on small datasets is fast enough
- Expected savings: 50-70% token reduction
2026-03-22 21:32:08 -04:00

494 lines
16 KiB
Python

from fastapi import FastAPI, HTTPException, Depends, Query
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, func
from sqlalchemy.exc import IntegrityError
import hashlib
import json
import os
from typing import Optional, List
from database import get_db, init_db
from models import Skill, Snippet, Convention, Cache, Memory
from schemas import (
SkillBase, Skill,
SnippetBase, Snippet,
ConventionBase, Convention,
CacheStore, Cache as CacheSchema,
MemoryBase, Memory as MemorySchema,
ContextBundle, CacheLookup
)
from semantic_cache import (
semantic_cache_lookup,
semantic_cache_store,
get_cache_stats,
clear_old_cache
)
from rag import build_context_bundle
from compression import compress_conversation, count_tokens
# Application object: local infrastructure service for AI context management
# (skills, snippets, conventions, caching, memory, RAG context, compression).
app = FastAPI(title="AI Skills API", description="Local infrastructure for AI context management")

# Wide-open CORS — any origin, method, and header. Acceptable for a
# localhost-only tool; tighten before exposing beyond the local machine.
# NOTE(review): browsers reject allow_credentials=True combined with
# allow_origins=["*"] (wildcard is invalid with credentials) — confirm intent.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.on_event("startup")
async def startup():
    """Initialize the database (create tables / open engine) before serving.

    NOTE(review): on_event-style hooks are deprecated in recent FastAPI in
    favor of lifespan context managers — confirm the pinned FastAPI version.
    """
    await init_db()
# ============== SKILLS ==============
@app.get("/skills", response_model=list[Skill])
async def list_skills(
    category: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """Return every skill ordered by name, optionally filtered to one category.

    NOTE(review): `Skill` is imported from both models and schemas (the schemas
    import shadows the models one) — verify select() receives the ORM model.
    """
    stmt = select(Skill)
    if category:
        stmt = stmt.where(Skill.category == category)
    rows = await db.execute(stmt.order_by(Skill.name))
    return rows.scalars().all()
@app.get("/skills/search")
async def search_skills(
    q: str,
    category: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """Case-insensitive substring search over skill name, content, and tags."""
    pattern = f"%{q}%"
    stmt = select(Skill).where(
        Skill.name.ilike(pattern)
        | Skill.content.ilike(pattern)
        | Skill.tags.ilike(pattern)
    )
    if category:
        stmt = stmt.where(Skill.category == category)
    found = await db.execute(stmt)
    return found.scalars().all()
@app.get("/skills/{skill_id}", response_model=Skill)
async def get_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch one skill by ID, bumping its usage counter as a side effect.

    Raises:
        HTTPException 404: when no skill with this ID exists.
    """
    result = await db.execute(select(Skill).where(Skill.id == skill_id))
    skill = result.scalar_one_or_none()
    if not skill:
        raise HTTPException(status_code=404, detail="Skill not found")
    skill.usage_count += 1
    await db.commit()
    # Re-load after commit: with the default expire_on_commit, the instance's
    # attributes are expired and would be lazily re-fetched during response
    # serialization — which fails under async SQLAlchemy. The file's create/
    # update endpoints already refresh after commit; do the same here.
    await db.refresh(skill)
    return skill
@app.post("/skills", response_model=Skill)
async def create_skill(skill: SkillBase, db: AsyncSession = Depends(get_db)):
    """Insert a new skill; respond 400 when the ID is already taken."""
    record = Skill(**skill.model_dump())
    db.add(record)
    try:
        await db.commit()
        await db.refresh(record)
    except IntegrityError:
        await db.rollback()
        raise HTTPException(status_code=400, detail="Skill with this ID already exists")
    return record
@app.put("/skills/{skill_id}", response_model=Skill)
async def update_skill(skill_id: str, skill: SkillBase, db: AsyncSession = Depends(get_db)):
    """Overwrite all fields of an existing skill; 404 when it does not exist."""
    found = await db.execute(select(Skill).where(Skill.id == skill_id))
    target = found.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Skill not found")
    # Full replacement: every field from the payload is written back.
    for field, new_value in skill.model_dump().items():
        setattr(target, field, new_value)
    await db.commit()
    await db.refresh(target)
    return target
@app.delete("/skills/{skill_id}")
async def delete_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a skill by ID; 404 when absent. Echoes the deleted ID."""
    found = await db.execute(select(Skill).where(Skill.id == skill_id))
    target = found.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Skill not found")
    await db.delete(target)
    await db.commit()
    return {"deleted": skill_id}
# ============== SNIPPETS ==============
@app.get("/snippets", response_model=list[Snippet])
async def list_snippets(
    category: Optional[str] = None,
    language: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """List snippets ordered by name, with optional category/language filters."""
    stmt = select(Snippet)
    # Apply each provided (non-empty) filter.
    for column, wanted in ((Snippet.category, category), (Snippet.language, language)):
        if wanted:
            stmt = stmt.where(column == wanted)
    rows = await db.execute(stmt.order_by(Snippet.name))
    return rows.scalars().all()
@app.get("/snippets/{snippet_id}", response_model=Snippet)
async def get_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch one snippet by ID; 404 when it does not exist."""
    found = await db.execute(select(Snippet).where(Snippet.id == snippet_id))
    target = found.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Snippet not found")
    return target
@app.post("/snippets", response_model=Snippet)
async def create_snippet(snippet: SnippetBase, db: AsyncSession = Depends(get_db)):
    """Insert a new snippet; respond 400 when the ID is already taken."""
    record = Snippet(**snippet.model_dump())
    db.add(record)
    try:
        await db.commit()
        await db.refresh(record)
    except IntegrityError:
        await db.rollback()
        raise HTTPException(status_code=400, detail="Snippet with this ID already exists")
    return record
@app.delete("/snippets/{snippet_id}")
async def delete_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a snippet by ID; 404 when absent. Echoes the deleted ID."""
    found = await db.execute(select(Snippet).where(Snippet.id == snippet_id))
    target = found.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Snippet not found")
    await db.delete(target)
    await db.commit()
    return {"deleted": snippet_id}
# ============== CONVENTIONS ==============
@app.get("/conventions", response_model=list[Convention])
async def list_conventions(
    project: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """List conventions ordered by name, optionally scoped to a project path."""
    stmt = select(Convention)
    if project:
        stmt = stmt.where(Convention.project_path == project)
    rows = await db.execute(stmt.order_by(Convention.name))
    return rows.scalars().all()
@app.get("/conventions/{convention_id}", response_model=Convention)
async def get_convention(convention_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch one convention by ID; 404 when it does not exist."""
    found = await db.execute(select(Convention).where(Convention.id == convention_id))
    target = found.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Convention not found")
    return target
@app.post("/conventions", response_model=Convention)
async def create_convention(convention: ConventionBase, db: AsyncSession = Depends(get_db)):
    """Insert a new convention; respond 400 when the ID is already taken."""
    record = Convention(**convention.model_dump())
    db.add(record)
    try:
        await db.commit()
        await db.refresh(record)
    except IntegrityError:
        await db.rollback()
        raise HTTPException(status_code=400, detail="Convention with this ID already exists")
    return record
@app.put("/conventions/{convention_id}", response_model=Convention)
async def update_convention(convention_id: str, convention: ConventionBase, db: AsyncSession = Depends(get_db)):
    """Overwrite all fields of an existing convention; 404 when it does not exist."""
    found = await db.execute(select(Convention).where(Convention.id == convention_id))
    target = found.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Convention not found")
    # Full replacement: every field from the payload is written back.
    for field, new_value in convention.model_dump().items():
        setattr(target, field, new_value)
    await db.commit()
    await db.refresh(target)
    return target
@app.delete("/conventions/{convention_id}")
async def delete_convention(convention_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a convention by ID; 404 when absent. Echoes the deleted ID."""
    found = await db.execute(select(Convention).where(Convention.id == convention_id))
    target = found.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Convention not found")
    await db.delete(target)
    await db.commit()
    return {"deleted": convention_id}
# ============== CACHE ==============
@app.post("/cache/lookup", response_model=Optional[CacheSchema])
async def lookup_cache(lookup: CacheLookup, db: AsyncSession = Depends(get_db)):
    """Exact-match cache lookup keyed on sha256 of {prompt, model}.

    Returns the cached row, or None when there is no matching, unexpired entry.
    """
    key_material = json.dumps(
        {"prompt": lookup.prompt, "model": lookup.model}, sort_keys=True
    )
    digest = hashlib.sha256(key_material.encode()).hexdigest()
    # Entries with a NULL expires_at never expire.
    unexpired = (Cache.expires_at == None) | (Cache.expires_at > func.now())  # noqa: E711
    hit = await db.execute(select(Cache).where((Cache.hash == digest) & unexpired))
    return hit.scalar_one_or_none()
@app.post("/cache/semantic-lookup", response_model=dict)
async def semantic_lookup(
    prompt: str,
    model: Optional[str] = None,
    min_similarity: float = 0.85,
    db: AsyncSession = Depends(get_db)
):
    """Semantic cache lookup: find a cached response for a similar prompt.

    The response always carries a "hit" flag; on a hit, the match's fields
    are merged into the payload.
    """
    match = await semantic_cache_lookup(
        prompt, db, model=model, min_similarity=min_similarity
    )
    if not match:
        return {"hit": False}
    return {"hit": True, **match}
@app.post("/cache/store", response_model=CacheSchema)
async def store_cache(cache: CacheStore, db: AsyncSession = Depends(get_db)):
    """Store a response in the exact-match cache.

    The key is sha256 over {"prompt", "model"} — the SAME digest that
    /cache/lookup computes, so stored entries are actually retrievable.
    (Previously this hashed the *response* under the "prompt" key, which
    meant lookups could never hit.)

    NOTE(review): assumes CacheStore carries the original prompt as
    `cache.prompt` — confirm against the schema definition.
    """
    prompt_hash = hashlib.sha256(
        json.dumps({"prompt": cache.prompt, "model": cache.model}, sort_keys=True).encode()
    ).hexdigest()
    db_cache = Cache(
        hash=prompt_hash,
        response=cache.response,
        model=cache.model,
        tokens_in=cache.tokens_in,
        tokens_out=cache.tokens_out,
        expires_at=cache.expires_at
    )
    db.add(db_cache)
    await db.commit()
    await db.refresh(db_cache)
    return db_cache
@app.post("/cache/semantic-store", response_model=dict)
async def semantic_store(
    prompt: str,
    response: str,
    model: Optional[str] = None,
    tokens_in: Optional[int] = None,
    tokens_out: Optional[int] = None,
    db: AsyncSession = Depends(get_db)
):
    """Persist a prompt/response pair in the semantic (embedding-based) cache."""
    outcome = await semantic_cache_store(
        prompt, response, db, model, tokens_in, tokens_out
    )
    return outcome
@app.delete("/cache/{cache_hash}")
async def delete_cache(cache_hash: str, db: AsyncSession = Depends(get_db)):
    """Delete one exact-match cache entry by its hash; 404 when absent."""
    found = await db.execute(select(Cache).where(Cache.hash == cache_hash))
    entry = found.scalar_one_or_none()
    if entry is None:
        raise HTTPException(status_code=404, detail="Cache entry not found")
    await db.delete(entry)
    await db.commit()
    return {"deleted": cache_hash}
@app.get("/cache/stats")
async def cache_stats_endpoint(db: AsyncSession = Depends(get_db)):
    """Get cache statistics (delegates to semantic_cache.get_cache_stats)."""
    return await get_cache_stats(db)
@app.post("/cache/clear-old")
async def clear_old(
    older_than_hours: int = 168,  # default threshold: one week
    db: AsyncSession = Depends(get_db)
):
    """Purge cache entries older than the threshold; report how many were removed."""
    removed = await clear_old_cache(db, older_than_hours)
    return {"deleted": removed}
# ============== MEMORY ==============
@app.get("/memory", response_model=list[MemorySchema])
async def list_memory(
    project: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """List memory entries ordered by key, optionally scoped to one project."""
    stmt = select(Memory)
    if project:
        stmt = stmt.where(Memory.project == project)
    rows = await db.execute(stmt.order_by(Memory.key))
    return rows.scalars().all()
@app.get("/memory/{memory_id}", response_model=MemorySchema)
async def get_memory(memory_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch one memory entry by ID; 404 when it does not exist."""
    found = await db.execute(select(Memory).where(Memory.id == memory_id))
    target = found.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Memory not found")
    return target
@app.post("/memory", response_model=MemorySchema)
async def create_memory(memory: MemoryBase, db: AsyncSession = Depends(get_db)):
    """Insert a new memory entry; respond 400 when the ID is already taken."""
    record = Memory(**memory.model_dump())
    db.add(record)
    try:
        await db.commit()
        await db.refresh(record)
    except IntegrityError:
        await db.rollback()
        raise HTTPException(status_code=400, detail="Memory with this ID already exists")
    return record
@app.put("/memory/{memory_id}", response_model=MemorySchema)
async def update_memory(memory_id: str, memory: MemoryBase, db: AsyncSession = Depends(get_db)):
    """Overwrite all fields of an existing memory entry; 404 when it does not exist."""
    found = await db.execute(select(Memory).where(Memory.id == memory_id))
    target = found.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Memory not found")
    # Full replacement: every field from the payload is written back.
    for field, new_value in memory.model_dump().items():
        setattr(target, field, new_value)
    await db.commit()
    await db.refresh(target)
    return target
@app.delete("/memory/{memory_id}")
async def delete_memory(memory_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a memory entry by ID; 404 when absent. Echoes the deleted ID."""
    found = await db.execute(select(Memory).where(Memory.id == memory_id))
    target = found.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Memory not found")
    await db.delete(target)
    await db.commit()
    return {"deleted": memory_id}
# ============== CONTEXT BUNDLE ==============
@app.get("/context", response_model=ContextBundle)
async def get_context(
    project: Optional[str] = None,
    skills: Optional[str] = Query(None, description="Comma-separated skill IDs to include"),
    db: AsyncSession = Depends(get_db)
):
    """Legacy context bundle: returns ALL matching items, no relevance ranking.

    Skills are pulled by explicit ID list; conventions and memories by exact
    project match; snippets by category equal to the project's last path segment.
    """
    selected_skills: list = []
    matched_snippets: list = []
    matched_conventions: list = []
    matched_memories: list = []
    if skills:
        wanted_ids = [token.strip() for token in skills.split(",")]
        res = await db.execute(select(Skill).where(Skill.id.in_(wanted_ids)))
        selected_skills = res.scalars().all()
    if project:
        res = await db.execute(select(Convention).where(Convention.project_path == project))
        matched_conventions = res.scalars().all()
        res = await db.execute(select(Memory).where(Memory.project == project))
        matched_memories = res.scalars().all()
        # Snippet category is matched against the project directory name.
        res = await db.execute(select(Snippet).where(Snippet.category == project.split("/")[-1]))
        matched_snippets = res.scalars().all()
    return ContextBundle(
        skills=selected_skills,
        snippets=matched_snippets,
        conventions=matched_conventions,
        memories=matched_memories
    )
@app.get("/context/rag")
async def get_context_rag(
    query: str,
    project: Optional[str] = None,
    max_skills: int = 3,
    max_conventions: int = 2,
    max_snippets: int = 2,
    db: AsyncSession = Depends(get_db)
):
    """RAG-based context selection — returns ONLY relevant items.

    Uses semantic search (via rag.build_context_bundle) to pick the top-K
    most relevant skills, conventions, and snippets for the query.
    """
    return await build_context_bundle(
        query,
        db,
        project,
        max_skills=max_skills,
        max_conventions=max_conventions,
        max_snippets=max_snippets,
    )
@app.post("/compress")
async def compress_messages(
    messages: List[dict],
    keep_last_n: int = 3,
    max_tokens: int = 2000
):
    """Compress conversation history.

    Keeps the last N exchanges verbatim and summarizes everything earlier
    (via compression.compress_conversation), then reports token savings.
    """
    shrunk = compress_conversation(messages, max_tokens, keep_last_n)

    def _total_tokens(msgs):
        # Messages lacking a "content" key contribute zero tokens.
        return sum(count_tokens(m.get("content", "")) for m in msgs)

    before = _total_tokens(messages)
    after = _total_tokens(shrunk)
    # Guard against an empty/zero-token original when computing the ratio.
    pct = round((1 - after / before) * 100, 1) if before > 0 else 0
    return {
        "messages": shrunk,
        "original_tokens": before,
        "compressed_tokens": after,
        "tokens_saved": before - after,
        "reduction_percent": pct
    }
@app.get("/tokens/count")
async def count_tokens_endpoint(text: str):
    """Return the token count of `text` using the compression module's counter."""
    token_total = count_tokens(text)
    return {"tokens": token_total}
@app.get("/health")
async def health():
    """Liveness probe: always reports healthy while the process is running."""
    return {"status": "healthy"}