"""AI Skills API: HTTP endpoints for skills, snippets, conventions and memory.

The module wires configuration, logging, optional API-key authentication and
the database session dependency, then exposes CRUD routes plus context-bundle,
RAG-selection, conversation-compression and token-counting routes.
"""

import logging
import secrets
import sys
from dataclasses import dataclass
from typing import AsyncIterator, List, Optional

from fastapi import FastAPI, HTTPException, Depends, Query, Request, Header
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import String, cast, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker

from config import load_config, Config
from database import init_db as default_init_db
from models import Skill, Snippet, Convention, Memory, Base
from schemas import (
    SkillBase, Skill as SkillSchema,
    SnippetBase, Snippet as SnippetSchema,
    ConventionBase, Convention as ConventionSchema,
    MemoryBase, Memory as MemorySchema,
    ContextBundle
)
from rag import build_context_bundle, clear_cache as clear_rag_cache
from compression import compress_conversation, count_tokens

# Load configuration
config = load_config()

# Setup logging: bare message when the configured format is "json",
# otherwise a timestamped human-readable line.
logging.basicConfig(
    level=getattr(logging, config.logging.level.upper()),
    format='%(message)s' if config.logging.format == "json" else '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger("ai-skills-api")


# API Key authentication dependency
async def verify_api_key(
    request: Request,
    x_api_key: Optional[str] = Header(None, alias=config.auth.header_name)
) -> None:
    """Verify API key if auth is enabled.

    Raises:
        HTTPException: 401 when auth is enabled and the header is missing or
            empty; 403 when the supplied key does not match the configured key.
    """
    if not config.auth.enabled:
        return
    if not x_api_key:
        raise HTTPException(status_code=401, detail="Missing API key")

    expected = config.auth.api_key
    # Constant-time comparison avoids leaking key prefixes through timing.
    # Both sides are encoded to bytes because compare_digest rejects non-ASCII
    # str. A non-string configured key can never equal the header value, so it
    # is rejected just as a plain inequality test would reject it.
    if not isinstance(expected, str) or not secrets.compare_digest(
        x_api_key.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Invalid API key")


# Database setup based on config
if config.database_url != "sqlite+aiosqlite:///./ai.db":
    # Custom DB URL - create new engine
    engine = create_async_engine(config.database_url, echo=False)
    AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    async def get_db() -> AsyncIterator[AsyncSession]:
        """Yield one session per request; the context manager closes it."""
        async with AsyncSessionLocal() as session:
            yield session

    async def init_db():
        """Create all tables declared on ``Base.metadata``."""
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
else:
    # Use default database setup from database.py
    from database import get_db, init_db

app = FastAPI(
    title="AI Skills API",
    description="Local infrastructure for AI context management"
)

# CORS configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=config.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup():
    """Initialise the database and log startup state."""
    await init_db()
    logger.info("Application startup complete")
    logger.info("RAG cache will initialize on first request")
    if config.auth.enabled:
        logger.info("API key authentication enabled")


# ============== SKILLS ==============

@app.get("/skills", response_model=list[SkillSchema], dependencies=[Depends(verify_api_key)])
async def list_skills(
    category: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """List skills ordered by name, optionally filtered by category."""
    query = select(Skill)
    if category:
        query = query.where(Skill.category == category)
    result = await db.execute(query.order_by(Skill.name))
    return result.scalars().all()


@app.get("/skills/search", dependencies=[Depends(verify_api_key)])
async def search_skills(
    q: str,
    category: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """Case-insensitive substring search over skill name, content and tags."""
    pattern = f"%{q}%"
    # ``.astext`` only exists for indexed elements of PostgreSQL JSON columns
    # and fails when applied to the column itself; casting the column to a
    # string type lets the same ILIKE run on any backend.
    query = select(Skill).where(
        Skill.name.ilike(pattern)
        | Skill.content.ilike(pattern)
        | cast(Skill.tags, String).ilike(pattern)
    )
    if category:
        query = query.where(Skill.category == category)
    result = await db.execute(query)
    return result.scalars().all()


@app.get("/skills/{skill_id}", response_model=SkillSchema, dependencies=[Depends(verify_api_key)])
async def get_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
    """Return one skill and increment its usage counter.

    Raises:
        HTTPException: 404 when no skill has the given ID.
    """
    result = await db.execute(select(Skill).where(Skill.id == skill_id))
    skill = result.scalar_one_or_none()
    if not skill:
        raise HTTPException(status_code=404, detail="Skill not found")
    skill.usage_count += 1
    await db.commit()
    # Reload after commit, as the other write endpoints do, so serialising the
    # response never touches expired attributes.
    await db.refresh(skill)
    return skill


@app.post("/skills", response_model=SkillSchema, dependencies=[Depends(verify_api_key)])
async def create_skill(skill: SkillBase, db: AsyncSession = Depends(get_db)):
    """Create a skill.

    Raises:
        HTTPException: 400 when the insert violates an integrity constraint.
    """
    db_skill = Skill(**skill.model_dump())
    db.add(db_skill)
    try:
        await db.commit()
        await db.refresh(db_skill)
    except IntegrityError:
        await db.rollback()
        raise HTTPException(status_code=400, detail="Skill with this ID already exists")
    return db_skill


@app.put("/skills/{skill_id}", response_model=SkillSchema, dependencies=[Depends(verify_api_key)])
async def update_skill(skill_id: str, skill: SkillBase, db: AsyncSession = Depends(get_db)):
    """Overwrite every field of an existing skill with the request body.

    Raises:
        HTTPException: 404 when no skill has the given ID.
    """
    result = await db.execute(select(Skill).where(Skill.id == skill_id))
    db_skill = result.scalar_one_or_none()
    if not db_skill:
        raise HTTPException(status_code=404, detail="Skill not found")
    for key, value in skill.model_dump().items():
        setattr(db_skill, key, value)
    await db.commit()
    await db.refresh(db_skill)
    return db_skill


@app.delete("/skills/{skill_id}", dependencies=[Depends(verify_api_key)])
async def delete_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a skill and echo its ID.

    Raises:
        HTTPException: 404 when no skill has the given ID.
    """
    result = await db.execute(select(Skill).where(Skill.id == skill_id))
    skill = result.scalar_one_or_none()
    if not skill:
        raise HTTPException(status_code=404, detail="Skill not found")
    await db.delete(skill)
    await db.commit()
    return {"deleted": skill_id}


# ============== SNIPPETS ==============

@app.get("/snippets", response_model=list[SnippetSchema], dependencies=[Depends(verify_api_key)])
async def list_snippets(
    category: Optional[str] = None,
    language: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """List snippets ordered by name, optionally filtered by category/language."""
    query = select(Snippet)
    if category:
        query = query.where(Snippet.category == category)
    if language:
        query = query.where(Snippet.language == language)
    result = await db.execute(query.order_by(Snippet.name))
    return result.scalars().all()


@app.get("/snippets/{snippet_id}", response_model=SnippetSchema, dependencies=[Depends(verify_api_key)])
async def get_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
    """Return one snippet.

    Raises:
        HTTPException: 404 when no snippet has the given ID.
    """
    result = await db.execute(select(Snippet).where(Snippet.id == snippet_id))
    snippet = result.scalar_one_or_none()
    if not snippet:
        raise HTTPException(status_code=404, detail="Snippet not found")
    return snippet


@app.post("/snippets", response_model=SnippetSchema, dependencies=[Depends(verify_api_key)])
async def create_snippet(snippet: SnippetBase, db: AsyncSession = Depends(get_db)):
    """Create a snippet.

    Raises:
        HTTPException: 400 when the insert violates an integrity constraint.
    """
    db_snippet = Snippet(**snippet.model_dump())
    db.add(db_snippet)
    try:
        await db.commit()
        await db.refresh(db_snippet)
    except IntegrityError:
        await db.rollback()
        raise HTTPException(status_code=400, detail="Snippet with this ID already exists")
    return db_snippet


@app.delete("/snippets/{snippet_id}", dependencies=[Depends(verify_api_key)])
async def delete_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a snippet and echo its ID.

    Raises:
        HTTPException: 404 when no snippet has the given ID.
    """
    result = await db.execute(select(Snippet).where(Snippet.id == snippet_id))
    snippet = result.scalar_one_or_none()
    if not snippet:
        raise HTTPException(status_code=404, detail="Snippet not found")
    await db.delete(snippet)
    await db.commit()
    return {"deleted": snippet_id}


# ============== CONVENTIONS ==============

@app.get("/conventions", response_model=list[ConventionSchema], dependencies=[Depends(verify_api_key)])
async def list_conventions(
    project: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """List conventions ordered by name, optionally filtered by project path."""
    query = select(Convention)
    if project:
        query = query.where(Convention.project_path == project)
    result = await db.execute(query.order_by(Convention.name))
    return result.scalars().all()


@app.get("/conventions/{convention_id}", response_model=ConventionSchema, dependencies=[Depends(verify_api_key)])
async def get_convention(convention_id: str, db: AsyncSession = Depends(get_db)):
    """Return one convention.

    Raises:
        HTTPException: 404 when no convention has the given ID.
    """
    result = await db.execute(select(Convention).where(Convention.id == convention_id))
    convention = result.scalar_one_or_none()
    if not convention:
        raise HTTPException(status_code=404, detail="Convention not found")
    return convention


@app.post("/conventions", response_model=ConventionSchema, dependencies=[Depends(verify_api_key)])
async def create_convention(convention: ConventionBase, db: AsyncSession = Depends(get_db)):
    """Create a convention.

    Raises:
        HTTPException: 400 when the insert violates an integrity constraint.
    """
    db_convention = Convention(**convention.model_dump())
    db.add(db_convention)
    try:
        await db.commit()
        await db.refresh(db_convention)
    except IntegrityError:
        await db.rollback()
        raise HTTPException(status_code=400, detail="Convention with this ID already exists")
    return db_convention


@app.put("/conventions/{convention_id}", response_model=ConventionSchema, dependencies=[Depends(verify_api_key)])
async def update_convention(convention_id: str, convention: ConventionBase, db: AsyncSession = Depends(get_db)):
    """Overwrite every field of an existing convention with the request body.

    Raises:
        HTTPException: 404 when no convention has the given ID.
    """
    result = await db.execute(select(Convention).where(Convention.id == convention_id))
    db_convention = result.scalar_one_or_none()
    if not db_convention:
        raise HTTPException(status_code=404, detail="Convention not found")
    for key, value in convention.model_dump().items():
        setattr(db_convention, key, value)
    await db.commit()
    await db.refresh(db_convention)
    return db_convention


@app.delete("/conventions/{convention_id}", dependencies=[Depends(verify_api_key)])
async def delete_convention(convention_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a convention and echo its ID.

    Raises:
        HTTPException: 404 when no convention has the given ID.
    """
    result = await db.execute(select(Convention).where(Convention.id == convention_id))
    convention = result.scalar_one_or_none()
    if not convention:
        raise HTTPException(status_code=404, detail="Convention not found")
    await db.delete(convention)
    await db.commit()
    return {"deleted": convention_id}


# ============== MEMORY ==============

@app.get("/memory", response_model=list[MemorySchema], dependencies=[Depends(verify_api_key)])
async def list_memory(
    project: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """List memory entries ordered by key, optionally filtered by project."""
    query = select(Memory)
    if project:
        query = query.where(Memory.project == project)
    result = await db.execute(query.order_by(Memory.key))
    return result.scalars().all()


@app.get("/memory/{memory_id}", response_model=MemorySchema, dependencies=[Depends(verify_api_key)])
async def get_memory(memory_id: str, db: AsyncSession = Depends(get_db)):
    """Return one memory entry.

    Raises:
        HTTPException: 404 when no memory entry has the given ID.
    """
    result = await db.execute(select(Memory).where(Memory.id == memory_id))
    memory = result.scalar_one_or_none()
    if not memory:
        raise HTTPException(status_code=404, detail="Memory not found")
    return memory


@app.post("/memory", response_model=MemorySchema, dependencies=[Depends(verify_api_key)])
async def create_memory(memory: MemoryBase, db: AsyncSession = Depends(get_db)):
    """Create a memory entry.

    Raises:
        HTTPException: 400 when the insert violates an integrity constraint.
    """
    db_memory = Memory(**memory.model_dump())
    db.add(db_memory)
    try:
        await db.commit()
        await db.refresh(db_memory)
    except IntegrityError:
        await db.rollback()
        raise HTTPException(status_code=400, detail="Memory with this ID already exists")
    return db_memory


@app.put("/memory/{memory_id}", response_model=MemorySchema, dependencies=[Depends(verify_api_key)])
async def update_memory(memory_id: str, memory: MemoryBase, db: AsyncSession = Depends(get_db)):
    """Overwrite every field of an existing memory entry with the request body.

    Raises:
        HTTPException: 404 when no memory entry has the given ID.
    """
    result = await db.execute(select(Memory).where(Memory.id == memory_id))
    db_memory = result.scalar_one_or_none()
    if not db_memory:
        raise HTTPException(status_code=404, detail="Memory not found")
    for key, value in memory.model_dump().items():
        setattr(db_memory, key, value)
    await db.commit()
    await db.refresh(db_memory)
    return db_memory


@app.delete("/memory/{memory_id}", dependencies=[Depends(verify_api_key)])
async def delete_memory(memory_id: str, db: AsyncSession = Depends(get_db)):
    """Delete a memory entry and echo its ID.

    Raises:
        HTTPException: 404 when no memory entry has the given ID.
    """
    result = await db.execute(select(Memory).where(Memory.id == memory_id))
    memory = result.scalar_one_or_none()
    if not memory:
        raise HTTPException(status_code=404, detail="Memory not found")
    await db.delete(memory)
    await db.commit()
    return {"deleted": memory_id}


# ============== CONTEXT BUNDLE ==============

@app.get("/context", response_model=ContextBundle, dependencies=[Depends(verify_api_key)])
async def get_context(
    project: Optional[str] = None,
    skills: Optional[str] = Query(None, description="Comma-separated skill IDs to include"),
    db: AsyncSession = Depends(get_db)
):
    """Get context bundle - legacy endpoint, returns ALL matching items"""
    skill_list = []
    snippet_list = []
    convention_list = []
    memory_list = []

    if skills:
        skill_ids = [s.strip() for s in skills.split(",")]
        result = await db.execute(select(Skill).where(Skill.id.in_(skill_ids)))
        skill_list = result.scalars().all()

    if project:
        result = await db.execute(select(Convention).where(Convention.project_path == project))
        convention_list = result.scalars().all()

        result = await db.execute(select(Memory).where(Memory.project == project))
        memory_list = result.scalars().all()

        # Snippets are matched on category equal to the last "/"-separated
        # segment of the project value.
        result = await db.execute(select(Snippet).where(Snippet.category == project.split("/")[-1]))
        snippet_list = result.scalars().all()

    return ContextBundle(
        skills=skill_list,
        snippets=snippet_list,
        conventions=convention_list,
        memories=memory_list
    )


@app.get("/context/rag", dependencies=[Depends(verify_api_key)])
async def get_context_rag(
    query: str,
    project: Optional[str] = None,
    max_skills: Optional[int] = None,
    max_conventions: Optional[int] = None,
    max_snippets: Optional[int] = None,
    db: AsyncSession = Depends(get_db)
):
    """
    RAG-based context selection - returns ONLY relevant items.
    Uses semantic search to find top K most relevant skills/snippets.
    Uses config defaults if parameters not provided.
    """
    # Use config defaults if not specified
    if max_skills is None:
        max_skills = config.rag.max_skills
    if max_conventions is None:
        max_conventions = config.rag.max_conventions
    if max_snippets is None:
        max_snippets = config.rag.max_snippets

    bundle = await build_context_bundle(
        query, db, project,
        max_skills=max_skills,
        max_conventions=max_conventions,
        max_snippets=max_snippets
    )
    return bundle


@dataclass
class CompressionRequest:
    """Request body for ``POST /compress``."""

    # Conversation messages; each is read with ``.get("content", "")`` when
    # counting tokens, so entries are expected to be JSON objects.
    messages: List[dict]


@app.post("/compress", dependencies=[Depends(verify_api_key)])
async def compress_messages_endpoint(
    request: CompressionRequest,
    keep_last_n: Optional[int] = None,
    max_tokens: Optional[int] = None
):
    """
    Compress conversation history.
    Uses config strategy (extractive or ollama).
    """
    # Use config defaults if not specified. A value of 0 is treated as
    # "not specified" and also falls back to the configured default.
    keep_last_n = keep_last_n or config.compression.keep_last_n
    max_tokens = max_tokens or config.compression.max_tokens

    messages = request.messages
    compressed = await compress_conversation(
        messages,
        max_tokens=max_tokens,
        keep_last_n=keep_last_n,
        strategy=config.compression.strategy,
        ollama_model=config.compression.ollama_model,
        ollama_url=config.compression.ollama_url
    )

    original_tokens = sum(count_tokens(m.get("content", "")) for m in messages)
    compressed_tokens = sum(count_tokens(m.get("content", "")) for m in compressed)

    return {
        "messages": compressed,
        "original_tokens": original_tokens,
        "compressed_tokens": compressed_tokens,
        "tokens_saved": original_tokens - compressed_tokens,
        # Guard against division by zero when the input has no tokens.
        "reduction_percent": round((1 - compressed_tokens / original_tokens) * 100, 1) if original_tokens > 0 else 0
    }


@app.get("/tokens/count", dependencies=[Depends(verify_api_key)])
async def count_tokens_endpoint(text: str):
    """Count tokens in text"""
    return {"tokens": count_tokens(text)}


@app.get("/health")
async def health():
    """Return a static healthy status; no authentication required."""
    return {"status": "healthy"}


@app.get("/config")
async def get_config():
    """Return current configuration (non-sensitive)"""
    return {
        "rag": {
            "max_skills": config.rag.max_skills,
            "max_conventions": config.rag.max_conventions,
            "max_snippets": config.rag.max_snippets,
            "min_skill_score": config.rag.min_skill_score,
            "min_snippet_score": config.rag.min_snippet_score,
            "embedding_model": config.rag.embedding_model
        },
        "compression": {
            "enabled": config.compression.enabled,
            "strategy": config.compression.strategy,
            "keep_last_n": config.compression.keep_last_n,
            "max_tokens": config.compression.max_tokens,
            # The model name is only reported when the ollama strategy is active.
            "ollama_model": config.compression.ollama_model if config.compression.strategy == "ollama" else None
        },
        "auth": {
            "enabled": config.auth.enabled
        }
    }


@app.post("/admin/clear-cache", dependencies=[Depends(verify_api_key)])
async def clear_cache():
    """Clear RAG embedding cache (forces reload on next request)"""
    clear_rag_cache()
    return {"status": "cache cleared"}