"""HTTP API exposing skills, snippets, conventions, cache, memory and context.

Every handler that touches the database receives an ``AsyncSession`` through
the ``get_db`` dependency. ORM classes come from ``models``; the Pydantic
classes from ``schemas`` are imported under ``*Schema`` aliases wherever the
two modules share a name, so that the ORM class is never shadowed.
"""

import hashlib
import json
import os
from typing import Optional, List

from fastapi import FastAPI, HTTPException, Depends, Query
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import select, func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from compression import compress_conversation, count_tokens
from database import get_db, init_db
from models import Skill, Snippet, Convention, Cache, Memory
from rag import build_context_bundle
# The schema classes are aliased: importing them under their bare names would
# rebind Skill/Snippet/Convention and hide the ORM models imported above.
from schemas import (
    SkillBase,
    Skill as SkillSchema,
    SnippetBase,
    Snippet as SnippetSchema,
    ConventionBase,
    Convention as ConventionSchema,
    CacheStore,
    Cache as CacheSchema,
    MemoryBase,
    Memory as MemorySchema,
    ContextBundle,
    CacheLookup,
)
from semantic_cache import (
    semantic_cache_lookup,
    semantic_cache_store,
    get_cache_stats,
    clear_old_cache,
)

app = FastAPI(
    title="AI Skills API",
    description="Local infrastructure for AI context management",
)

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is acceptable for the intended deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup():
    """Run ``init_db`` once when the application starts."""
    await init_db()


# ============== SHARED HELPERS ==============

def _prompt_hash(prompt: str, model: Optional[str]) -> str:
    """Return the SHA-256 hex digest used as the exact-match cache key.

    Lookup and store must both derive the key through this function so that
    an entry stored for a prompt can be found again by the same prompt.
    """
    payload = json.dumps({"prompt": prompt, "model": model}, sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()


async def _fetch_or_404(db: AsyncSession, statement, detail: str):
    """Execute *statement* and return its single row, or raise HTTP 404."""
    result = await db.execute(statement)
    row = result.scalar_one_or_none()
    if row is None:
        raise HTTPException(status_code=404, detail=detail)
    return row


async def _insert_or_400(db: AsyncSession, row, detail: str):
    """Add *row*, commit and refresh it; an IntegrityError becomes HTTP 400."""
    db.add(row)
    try:
        await db.commit()
        await db.refresh(row)
    except IntegrityError:
        # Roll back so the session is usable again before reporting the error.
        await db.rollback()
        raise HTTPException(status_code=400, detail=detail)
    return row


async def _overwrite_fields(db: AsyncSession, row, values: dict):
    """Set every key of *values* as an attribute on *row*, commit, refresh."""
    for key, value in values.items():
        setattr(row, key, value)
    await db.commit()
    await db.refresh(row)
    return row


async def _delete_row(db: AsyncSession, row) -> None:
    """Delete *row* and commit."""
    await db.delete(row)
    await db.commit()


# ============== SKILLS ==============

@app.get("/skills", response_model=list[SkillSchema])
async def list_skills(
    category: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """List skills ordered by name, optionally filtered by exact category."""
    query = select(Skill)
    if category:
        query = query.where(Skill.category == category)
    result = await db.execute(query.order_by(Skill.name))
    return result.scalars().all()


@app.get("/skills/search")
async def search_skills(
    q: str,
    category: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """Return skills whose name, content or tags contain *q* (ILIKE match).

    When *category* is given the results are further restricted to it.
    """
    pattern = f"%{q}%"
    query = select(Skill).where(
        (Skill.name.ilike(pattern)) |
        (Skill.content.ilike(pattern)) |
        (Skill.tags.ilike(pattern))
    )
    if category:
        query = query.where(Skill.category == category)
    result = await db.execute(query)
    return result.scalars().all()


@app.get("/skills/{skill_id}", response_model=SkillSchema)
async def get_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
    """Return one skill and increment its ``usage_count``.

    Raises:
        HTTPException: 404 when no skill has this ID.
    """
    skill = await _fetch_or_404(
        db, select(Skill).where(Skill.id == skill_id), "Skill not found"
    )
    # Treat a NULL counter as zero instead of failing on None + 1.
    skill.usage_count = (skill.usage_count or 0) + 1
    await db.commit()
    # Reload after the commit, as the other write handlers do, so the
    # returned instance is fully loaded when the response is serialized.
    await db.refresh(skill)
    return skill


@app.post("/skills", response_model=SkillSchema)
async def create_skill(skill: SkillBase, db: AsyncSession = Depends(get_db)):
    """Create a skill from the request body.

    Raises:
        HTTPException: 400 when the insert violates an integrity constraint.
    """
    return await _insert_or_400(
        db, Skill(**skill.model_dump()), "Skill with this ID already exists"
    )


@app.put("/skills/{skill_id}", response_model=SkillSchema)
async def update_skill(skill_id: str, skill: SkillBase, db: AsyncSession = Depends(get_db)):
    """Overwrite every field of an existing skill with the request body.

    Raises:
        HTTPException: 404 when no skill has this ID.
    """
    db_skill = await _fetch_or_404(
        db, select(Skill).where(Skill.id == skill_id), "Skill not found"
    )
    return await _overwrite_fields(db, db_skill, skill.model_dump())


@app.delete("/skills/{skill_id}")
async def delete_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a skill and echo its ID; 404 when it does not exist."""
    skill = await _fetch_or_404(
        db, select(Skill).where(Skill.id == skill_id), "Skill not found"
    )
    await _delete_row(db, skill)
    return {"deleted": skill_id}


# ============== SNIPPETS ==============

@app.get("/snippets", response_model=list[SnippetSchema])
async def list_snippets(
    category: Optional[str] = None,
    language: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """List snippets ordered by name, optionally filtered by category/language."""
    query = select(Snippet)
    if category:
        query = query.where(Snippet.category == category)
    if language:
        query = query.where(Snippet.language == language)
    result = await db.execute(query.order_by(Snippet.name))
    return result.scalars().all()


@app.get("/snippets/{snippet_id}", response_model=SnippetSchema)
async def get_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
    """Return one snippet by ID; 404 when it does not exist."""
    return await _fetch_or_404(
        db, select(Snippet).where(Snippet.id == snippet_id), "Snippet not found"
    )


@app.post("/snippets", response_model=SnippetSchema)
async def create_snippet(snippet: SnippetBase, db: AsyncSession = Depends(get_db)):
    """Create a snippet; 400 when the insert violates an integrity constraint."""
    return await _insert_or_400(
        db, Snippet(**snippet.model_dump()), "Snippet with this ID already exists"
    )


@app.delete("/snippets/{snippet_id}")
async def delete_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a snippet and echo its ID; 404 when it does not exist."""
    snippet = await _fetch_or_404(
        db, select(Snippet).where(Snippet.id == snippet_id), "Snippet not found"
    )
    await _delete_row(db, snippet)
    return {"deleted": snippet_id}


# ============== CONVENTIONS ==============

@app.get("/conventions", response_model=list[ConventionSchema])
async def list_conventions(
    project: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """List conventions ordered by name, optionally filtered by project path."""
    query = select(Convention)
    if project:
        query = query.where(Convention.project_path == project)
    result = await db.execute(query.order_by(Convention.name))
    return result.scalars().all()


@app.get("/conventions/{convention_id}", response_model=ConventionSchema)
async def get_convention(convention_id: str, db: AsyncSession = Depends(get_db)):
    """Return one convention by ID; 404 when it does not exist."""
    return await _fetch_or_404(
        db,
        select(Convention).where(Convention.id == convention_id),
        "Convention not found",
    )


@app.post("/conventions", response_model=ConventionSchema)
async def create_convention(convention: ConventionBase, db: AsyncSession = Depends(get_db)):
    """Create a convention; 400 when the insert violates an integrity constraint."""
    return await _insert_or_400(
        db,
        Convention(**convention.model_dump()),
        "Convention with this ID already exists",
    )


@app.put("/conventions/{convention_id}", response_model=ConventionSchema)
async def update_convention(convention_id: str, convention: ConventionBase, db: AsyncSession = Depends(get_db)):
    """Overwrite every field of an existing convention; 404 when missing."""
    db_convention = await _fetch_or_404(
        db,
        select(Convention).where(Convention.id == convention_id),
        "Convention not found",
    )
    return await _overwrite_fields(db, db_convention, convention.model_dump())


@app.delete("/conventions/{convention_id}")
async def delete_convention(convention_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a convention and echo its ID; 404 when it does not exist."""
    convention = await _fetch_or_404(
        db,
        select(Convention).where(Convention.id == convention_id),
        "Convention not found",
    )
    await _delete_row(db, convention)
    return {"deleted": convention_id}


# ============== CACHE ==============

@app.post("/cache/lookup", response_model=Optional[CacheSchema])
async def lookup_cache(lookup: CacheLookup, db: AsyncSession = Depends(get_db)):
    """Exact hash-based cache lookup

    Returns the entry whose hash matches the prompt/model pair and whose
    ``expires_at`` is NULL or later than the database's current time;
    returns ``None`` when there is no such entry.
    """
    prompt_hash = _prompt_hash(lookup.prompt, lookup.model)
    not_expired = (Cache.expires_at.is_(None)) | (Cache.expires_at > func.now())
    result = await db.execute(
        select(Cache).where((Cache.hash == prompt_hash) & not_expired)
    )
    return result.scalar_one_or_none()


@app.post("/cache/semantic-lookup", response_model=dict)
async def semantic_lookup(
    prompt: str,
    model: Optional[str] = None,
    min_similarity: float = 0.85,
    db: AsyncSession = Depends(get_db)
):
    """Semantic cache lookup - finds similar prompts

    Delegates to ``semantic_cache_lookup``. A truthy result is merged into
    ``{"hit": True, ...}``; otherwise ``{"hit": False}`` is returned.
    """
    result = await semantic_cache_lookup(
        prompt, db, model=model, min_similarity=min_similarity
    )
    if result:
        return {"hit": True, **result}
    return {"hit": False}


@app.post("/cache/store", response_model=CacheSchema)
async def store_cache(cache: CacheStore, db: AsyncSession = Depends(get_db)):
    """Store in exact-match cache

    The entry is keyed by the hash of the prompt/model pair, i.e. the same
    key ``lookup_cache`` computes.

    Raises:
        HTTPException: 400 when the insert violates an integrity constraint.
    """
    # Key on the prompt, not the response: hashing the response produced a
    # key that a later lookup by prompt could never match.
    # NOTE(review): assumes CacheStore exposes a `prompt` field - confirm
    # against schemas.CacheStore.
    db_cache = Cache(
        hash=_prompt_hash(cache.prompt, cache.model),
        response=cache.response,
        model=cache.model,
        tokens_in=cache.tokens_in,
        tokens_out=cache.tokens_out,
        expires_at=cache.expires_at
    )
    return await _insert_or_400(
        db, db_cache, "Cache entry with this hash already exists"
    )


@app.post("/cache/semantic-store", response_model=dict)
async def semantic_store(
    prompt: str,
    response: str,
    model: Optional[str] = None,
    tokens_in: Optional[int] = None,
    tokens_out: Optional[int] = None,
    db: AsyncSession = Depends(get_db)
):
    """Store in semantic cache

    Forwards all arguments to ``semantic_cache_store`` and returns its result.
    """
    return await semantic_cache_store(
        prompt, response, db, model, tokens_in, tokens_out
    )


@app.delete("/cache/{cache_hash}")
async def delete_cache(cache_hash: str, db: AsyncSession = Depends(get_db)):
    """Delete the cache entry with this hash; 404 when it does not exist."""
    cache = await _fetch_or_404(
        db, select(Cache).where(Cache.hash == cache_hash), "Cache entry not found"
    )
    await _delete_row(db, cache)
    return {"deleted": cache_hash}


@app.get("/cache/stats")
async def cache_stats_endpoint(db: AsyncSession = Depends(get_db)):
    """Get cache statistics"""
    return await get_cache_stats(db)


@app.post("/cache/clear-old")
async def clear_old(
    older_than_hours: int = 168,
    db: AsyncSession = Depends(get_db)
):
    """Clear cache entries older than threshold

    Returns ``{"deleted": <value returned by clear_old_cache>}``.
    """
    deleted = await clear_old_cache(db, older_than_hours)
    return {"deleted": deleted}


# ============== MEMORY ==============

@app.get("/memory", response_model=list[MemorySchema])
async def list_memory(
    project: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """List memory entries ordered by key, optionally filtered by project."""
    query = select(Memory)
    if project:
        query = query.where(Memory.project == project)
    result = await db.execute(query.order_by(Memory.key))
    return result.scalars().all()


@app.get("/memory/{memory_id}", response_model=MemorySchema)
async def get_memory(memory_id: str, db: AsyncSession = Depends(get_db)):
    """Return one memory entry by ID; 404 when it does not exist."""
    return await _fetch_or_404(
        db, select(Memory).where(Memory.id == memory_id), "Memory not found"
    )


@app.post("/memory", response_model=MemorySchema)
async def create_memory(memory: MemoryBase, db: AsyncSession = Depends(get_db)):
    """Create a memory entry; 400 when the insert violates an integrity constraint."""
    return await _insert_or_400(
        db, Memory(**memory.model_dump()), "Memory with this ID already exists"
    )


@app.put("/memory/{memory_id}", response_model=MemorySchema)
async def update_memory(memory_id: str, memory: MemoryBase, db: AsyncSession = Depends(get_db)):
    """Overwrite every field of an existing memory entry; 404 when missing."""
    db_memory = await _fetch_or_404(
        db, select(Memory).where(Memory.id == memory_id), "Memory not found"
    )
    return await _overwrite_fields(db, db_memory, memory.model_dump())


@app.delete("/memory/{memory_id}")
async def delete_memory(memory_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a memory entry and echo its ID; 404 when it does not exist."""
    memory = await _fetch_or_404(
        db, select(Memory).where(Memory.id == memory_id), "Memory not found"
    )
    await _delete_row(db, memory)
    return {"deleted": memory_id}


# ============== CONTEXT BUNDLE ==============

@app.get("/context", response_model=ContextBundle)
async def get_context(
    project: Optional[str] = None,
    skills: Optional[str] = Query(None, description="Comma-separated skill IDs to include"),
    db: AsyncSession = Depends(get_db)
):
    """Get context bundle - legacy endpoint, returns ALL matching items

    Skills are selected by the IDs listed in *skills*. When *project* is
    given, conventions and memories are matched on the full project value and
    snippets on its last ``/``-separated component (compared to the snippet
    category). Any part that is not requested stays an empty list.
    """
    skill_list = []
    snippet_list = []
    convention_list = []
    memory_list = []

    if skills:
        # Skip blank entries produced by input such as "a,,b" or "a,".
        skill_ids = [part.strip() for part in skills.split(",") if part.strip()]
        if skill_ids:
            result = await db.execute(select(Skill).where(Skill.id.in_(skill_ids)))
            skill_list = result.scalars().all()

    if project:
        result = await db.execute(
            select(Convention).where(Convention.project_path == project)
        )
        convention_list = result.scalars().all()

        result = await db.execute(select(Memory).where(Memory.project == project))
        memory_list = result.scalars().all()

        result = await db.execute(
            select(Snippet).where(Snippet.category == project.split("/")[-1])
        )
        snippet_list = result.scalars().all()

    return ContextBundle(
        skills=skill_list,
        snippets=snippet_list,
        conventions=convention_list,
        memories=memory_list
    )


@app.get("/context/rag")
async def get_context_rag(
    query: str,
    project: Optional[str] = None,
    max_skills: int = 3,
    max_conventions: int = 2,
    max_snippets: int = 2,
    db: AsyncSession = Depends(get_db)
):
    """
    RAG-based context selection - returns ONLY relevant items.
    Uses semantic search to find top K most relevant skills/snippets.

    The selection itself is delegated to ``build_context_bundle``; its result
    is returned unchanged.
    """
    bundle = await build_context_bundle(
        query, db, project,
        max_skills=max_skills,
        max_conventions=max_conventions,
        max_snippets=max_snippets
    )
    return bundle


@app.post("/compress")
async def compress_messages(
    messages: List[dict],
    keep_last_n: int = 3,
    max_tokens: int = 2000
):
    """
    Compress conversation history.
    Keeps last N exchanges in full, summarizes everything before.

    Returns the compressed messages together with token counts before and
    after, the difference, and the percentage reduction (0 when the original
    token count is 0, which avoids dividing by zero).
    """
    compressed = compress_conversation(messages, max_tokens, keep_last_n)

    original_tokens = sum(count_tokens(m.get("content", "")) for m in messages)
    compressed_tokens = sum(count_tokens(m.get("content", "")) for m in compressed)

    if original_tokens > 0:
        reduction_percent = round((1 - compressed_tokens / original_tokens) * 100, 1)
    else:
        reduction_percent = 0

    return {
        "messages": compressed,
        "original_tokens": original_tokens,
        "compressed_tokens": compressed_tokens,
        "tokens_saved": original_tokens - compressed_tokens,
        "reduction_percent": reduction_percent
    }


@app.get("/tokens/count")
async def count_tokens_endpoint(text: str):
    """Count tokens in text"""
    return {"tokens": count_tokens(text)}


@app.get("/health")
async def health():
    """Liveness probe; always reports a healthy status."""
    return {"status": "healthy"}