# ai-skills-api/main.py
#
# Source-listing metadata carried over from the file viewer this was copied from:
# 498 lines, 17 KiB, Python.

import logging
import sys
from dataclasses import dataclass
from typing import AsyncIterator, List, Optional

from fastapi import FastAPI, HTTPException, Depends, Query, Request, Header
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import String, cast, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker

from config import load_config, Config
from database import init_db as default_init_db
from models import Skill, Snippet, Convention, Memory, Base
from schemas import (
    SkillBase,
    Skill as SkillSchema,
    SnippetBase,
    Snippet as SnippetSchema,
    ConventionBase,
    Convention as ConventionSchema,
    MemoryBase,
    Memory as MemorySchema,
    ContextBundle,
)
from rag import build_context_bundle, clear_cache as clear_rag_cache
from compression import compress_conversation, count_tokens
# Load configuration once at import time; the rest of the module reads it.
config = load_config()

# Logging goes to stdout. In "json" mode only the bare message is written;
# any other mode uses a timestamped, human-readable line.
_log_format = (
    '%(message)s'
    if config.logging.format == "json"
    else '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logging.basicConfig(
    level=getattr(logging, config.logging.level.upper()),
    format=_log_format,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("ai-skills-api")
# API Key authentication dependency
async def verify_api_key(
    request: Request,
    x_api_key: Optional[str] = Header(None, alias=config.auth.header_name)
) -> None:
    """Verify API key if auth is enabled.

    Does nothing when ``config.auth.enabled`` is false.

    Args:
        request: The incoming request (not read; part of the dependency signature).
        x_api_key: Value of the header named by ``config.auth.header_name``.

    Raises:
        HTTPException: 401 when the header is missing or empty; 403 when it
            does not match ``config.auth.api_key`` (including when no
            non-empty string key is configured).
    """
    import secrets

    if not config.auth.enabled:
        return
    if not x_api_key:
        raise HTTPException(status_code=401, detail="Missing API key")
    expected = config.auth.api_key
    # Constant-time comparison so response timing does not reveal how much of
    # a guessed key was correct. Bytes are compared because compare_digest
    # raises TypeError on non-ASCII str input, and header values may contain it.
    valid = (
        isinstance(expected, str)
        and bool(expected)
        and secrets.compare_digest(x_api_key.encode("utf-8"), expected.encode("utf-8"))
    )
    if not valid:
        raise HTTPException(status_code=403, detail="Invalid API key")
# Database setup based on config
if config.database_url != "sqlite+aiosqlite:///./ai.db":
    # A non-default URL gets its own engine and session factory.
    # expire_on_commit=False keeps loaded ORM attributes readable after commit.
    engine = create_async_engine(config.database_url, echo=False)
    AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    async def get_db() -> AsyncIterator[AsyncSession]:
        """Yield one AsyncSession per request; it is closed when the request ends."""
        async with AsyncSessionLocal() as session:
            yield session

    async def init_db() -> None:
        """Create any tables registered on ``Base.metadata`` that do not exist yet."""
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
else:
    # The default URL reuses the helpers defined in database.py.
    from database import get_db, init_db
app = FastAPI(
    title="AI Skills API",
    description="Local infrastructure for AI context management",
)

# CORS configuration: allowed origins come from config; credentials and all
# methods/headers are permitted for those origins.
_cors_settings = dict(
    allow_origins=config.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_settings)
@app.on_event("startup")
async def startup():
    """Initialise the database via ``init_db()`` and log startup status."""
    await init_db()
    logger.info("Application startup complete")
    # Plain string literal: the message has no placeholders (ruff F541).
    logger.info("RAG cache will initialize on first request")
    if config.auth.enabled:
        logger.info("API key authentication enabled")
# ============== SKILLS ==============
@app.get("/skills", response_model=list[SkillSchema], dependencies=[Depends(verify_api_key)])
async def list_skills(
    category: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """Return every skill sorted by name, restricted to ``category`` when given."""
    stmt = select(Skill).order_by(Skill.name)
    if category:
        stmt = stmt.where(Skill.category == category)
    rows = await db.execute(stmt)
    return rows.scalars().all()
@app.get("/skills/search", response_model=list[SkillSchema], dependencies=[Depends(verify_api_key)])
async def search_skills(
    q: str,
    category: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """Case-insensitive substring search over skill name, content and tags.

    Args:
        q: Text to search for. Percent, underscore and backslash characters
            are matched literally rather than as LIKE wildcards.
        category: When given, only skills in this category are returned.

    Returns:
        Matching skills ordered by name (same schema as ``GET /skills``).
    """
    # Escape LIKE metacharacters so user input is matched as literal text.
    escaped = q.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    pattern = f"%{escaped}%"
    # ``.astext`` only exists on PostgreSQL JSON comparators and fails on
    # other backends such as the default SQLite URL; a CAST to text is portable.
    query = select(Skill).where(
        Skill.name.ilike(pattern, escape="\\")
        | Skill.content.ilike(pattern, escape="\\")
        | cast(Skill.tags, String).ilike(pattern, escape="\\")
    )
    if category:
        query = query.where(Skill.category == category)
    result = await db.execute(query.order_by(Skill.name))
    return result.scalars().all()
@app.get("/skills/{skill_id}", response_model=SkillSchema, dependencies=[Depends(verify_api_key)])
async def get_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
    """Return one skill and increment its ``usage_count``.

    Raises:
        HTTPException: 404 if no skill has ``skill_id``.
    """
    result = await db.execute(select(Skill).where(Skill.id == skill_id))
    skill = result.scalar_one_or_none()
    if not skill:
        raise HTTPException(status_code=404, detail="Skill not found")
    # Assigning a SQL expression makes the flush emit
    # "usage_count = usage_count + 1" in the database, so two concurrent
    # requests cannot overwrite each other's increment.
    skill.usage_count = Skill.usage_count + 1
    await db.commit()
    # The attribute held an expression and is now expired; reload the row
    # so the response carries the stored counter value.
    await db.refresh(skill)
    return skill
@app.post("/skills", response_model=SkillSchema, dependencies=[Depends(verify_api_key)])
async def create_skill(skill: SkillBase, db: AsyncSession = Depends(get_db)):
    """Insert a new skill and return the stored row.

    Raises:
        HTTPException: 400 when the insert hits an integrity constraint
            (reported as a duplicate ID).
    """
    payload = skill.model_dump()
    record = Skill(**payload)
    db.add(record)
    try:
        await db.commit()
        await db.refresh(record)
    except IntegrityError:
        # Undo the failed transaction before reporting the conflict.
        await db.rollback()
        raise HTTPException(status_code=400, detail="Skill with this ID already exists")
    else:
        return record
@app.put("/skills/{skill_id}", response_model=SkillSchema, dependencies=[Depends(verify_api_key)])
async def update_skill(skill_id: str, skill: SkillBase, db: AsyncSession = Depends(get_db)):
    """Overwrite an existing skill with every field from the request body.

    Raises:
        HTTPException: 404 if no skill has ``skill_id``; 400 if the update
            violates a database integrity constraint.
    """
    result = await db.execute(select(Skill).where(Skill.id == skill_id))
    db_skill = result.scalar_one_or_none()
    if not db_skill:
        raise HTTPException(status_code=404, detail="Skill not found")
    # NOTE(review): all body fields are copied, including an ``id`` field if
    # SkillBase declares one — confirm that renaming via PUT is intended.
    for key, value in skill.model_dump().items():
        setattr(db_skill, key, value)
    try:
        await db.commit()
        await db.refresh(db_skill)
    except IntegrityError:
        # Same handling as create_skill: a client error, not an unhandled 500.
        await db.rollback()
        raise HTTPException(status_code=400, detail="Skill update violates a database constraint")
    return db_skill
@app.delete("/skills/{skill_id}", dependencies=[Depends(verify_api_key)])
async def delete_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
    """Remove a skill by ID and echo the ID back.

    Raises:
        HTTPException: 404 if no skill has ``skill_id``.
    """
    lookup = select(Skill).where(Skill.id == skill_id)
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="Skill not found")
    await db.delete(target)
    await db.commit()
    return {"deleted": skill_id}
# ============== SNIPPETS ==============
@app.get("/snippets", response_model=list[SnippetSchema], dependencies=[Depends(verify_api_key)])
async def list_snippets(
    category: Optional[str] = None,
    language: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """Return snippets sorted by name, filtered by category and/or language when given."""
    criteria = []
    if category:
        criteria.append(Snippet.category == category)
    if language:
        criteria.append(Snippet.language == language)
    stmt = select(Snippet).where(*criteria).order_by(Snippet.name)
    rows = await db.execute(stmt)
    return rows.scalars().all()
@app.get("/snippets/{snippet_id}", response_model=SnippetSchema, dependencies=[Depends(verify_api_key)])
async def get_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch a single snippet by ID.

    Raises:
        HTTPException: 404 if no snippet has ``snippet_id``.
    """
    stmt = select(Snippet).where(Snippet.id == snippet_id)
    found = (await db.execute(stmt)).scalar_one_or_none()
    if not found:
        raise HTTPException(status_code=404, detail="Snippet not found")
    return found
@app.post("/snippets", response_model=SnippetSchema, dependencies=[Depends(verify_api_key)])
async def create_snippet(snippet: SnippetBase, db: AsyncSession = Depends(get_db)):
    """Insert a new snippet and return the stored row.

    Raises:
        HTTPException: 400 when the insert hits an integrity constraint
            (reported as a duplicate ID).
    """
    record = Snippet(**snippet.model_dump())
    db.add(record)
    try:
        await db.commit()
        await db.refresh(record)
    except IntegrityError:
        # Undo the failed transaction before reporting the conflict.
        await db.rollback()
        raise HTTPException(status_code=400, detail="Snippet with this ID already exists")
    else:
        return record
@app.delete("/snippets/{snippet_id}", dependencies=[Depends(verify_api_key)])
async def delete_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
    """Remove a snippet by ID and echo the ID back.

    Raises:
        HTTPException: 404 if no snippet has ``snippet_id``.
    """
    lookup = select(Snippet).where(Snippet.id == snippet_id)
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="Snippet not found")
    await db.delete(target)
    await db.commit()
    return {"deleted": snippet_id}
# ============== CONVENTIONS ==============
@app.get("/conventions", response_model=list[ConventionSchema], dependencies=[Depends(verify_api_key)])
async def list_conventions(
    project: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """Return conventions sorted by name; ``project`` filters on exact ``project_path``."""
    stmt = select(Convention).order_by(Convention.name)
    if project:
        stmt = stmt.where(Convention.project_path == project)
    rows = await db.execute(stmt)
    return rows.scalars().all()
@app.get("/conventions/{convention_id}", response_model=ConventionSchema, dependencies=[Depends(verify_api_key)])
async def get_convention(convention_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch a single convention by ID.

    Raises:
        HTTPException: 404 if no convention has ``convention_id``.
    """
    stmt = select(Convention).where(Convention.id == convention_id)
    found = (await db.execute(stmt)).scalar_one_or_none()
    if not found:
        raise HTTPException(status_code=404, detail="Convention not found")
    return found
@app.post("/conventions", response_model=ConventionSchema, dependencies=[Depends(verify_api_key)])
async def create_convention(convention: ConventionBase, db: AsyncSession = Depends(get_db)):
    """Insert a new convention and return the stored row.

    Raises:
        HTTPException: 400 when the insert hits an integrity constraint
            (reported as a duplicate ID).
    """
    record = Convention(**convention.model_dump())
    db.add(record)
    try:
        await db.commit()
        await db.refresh(record)
    except IntegrityError:
        # Undo the failed transaction before reporting the conflict.
        await db.rollback()
        raise HTTPException(status_code=400, detail="Convention with this ID already exists")
    else:
        return record
@app.put("/conventions/{convention_id}", response_model=ConventionSchema, dependencies=[Depends(verify_api_key)])
async def update_convention(convention_id: str, convention: ConventionBase, db: AsyncSession = Depends(get_db)):
    """Overwrite an existing convention with every field from the request body.

    Raises:
        HTTPException: 404 if no convention has ``convention_id``; 400 if the
            update violates a database integrity constraint.
    """
    result = await db.execute(select(Convention).where(Convention.id == convention_id))
    db_convention = result.scalar_one_or_none()
    if not db_convention:
        raise HTTPException(status_code=404, detail="Convention not found")
    # NOTE(review): all body fields are copied, including an ``id`` field if
    # ConventionBase declares one — confirm that renaming via PUT is intended.
    for key, value in convention.model_dump().items():
        setattr(db_convention, key, value)
    try:
        await db.commit()
        await db.refresh(db_convention)
    except IntegrityError:
        # Same handling as create_convention: a client error, not a 500.
        await db.rollback()
        raise HTTPException(status_code=400, detail="Convention update violates a database constraint")
    return db_convention
@app.delete("/conventions/{convention_id}", dependencies=[Depends(verify_api_key)])
async def delete_convention(convention_id: str, db: AsyncSession = Depends(get_db)):
    """Remove a convention by ID and echo the ID back.

    Raises:
        HTTPException: 404 if no convention has ``convention_id``.
    """
    lookup = select(Convention).where(Convention.id == convention_id)
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="Convention not found")
    await db.delete(target)
    await db.commit()
    return {"deleted": convention_id}
# ============== MEMORY ==============
@app.get("/memory", response_model=list[MemorySchema], dependencies=[Depends(verify_api_key)])
async def list_memory(
    project: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """Return memory entries sorted by key; ``project`` filters on an exact match."""
    stmt = select(Memory).order_by(Memory.key)
    if project:
        stmt = stmt.where(Memory.project == project)
    rows = await db.execute(stmt)
    return rows.scalars().all()
@app.get("/memory/{memory_id}", response_model=MemorySchema, dependencies=[Depends(verify_api_key)])
async def get_memory(memory_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch a single memory entry by ID.

    Raises:
        HTTPException: 404 if no memory entry has ``memory_id``.
    """
    stmt = select(Memory).where(Memory.id == memory_id)
    found = (await db.execute(stmt)).scalar_one_or_none()
    if not found:
        raise HTTPException(status_code=404, detail="Memory not found")
    return found
@app.post("/memory", response_model=MemorySchema, dependencies=[Depends(verify_api_key)])
async def create_memory(memory: MemoryBase, db: AsyncSession = Depends(get_db)):
    """Insert a new memory entry and return the stored row.

    Raises:
        HTTPException: 400 when the insert hits an integrity constraint
            (reported as a duplicate ID).
    """
    record = Memory(**memory.model_dump())
    db.add(record)
    try:
        await db.commit()
        await db.refresh(record)
    except IntegrityError:
        # Undo the failed transaction before reporting the conflict.
        await db.rollback()
        raise HTTPException(status_code=400, detail="Memory with this ID already exists")
    else:
        return record
@app.put("/memory/{memory_id}", response_model=MemorySchema, dependencies=[Depends(verify_api_key)])
async def update_memory(memory_id: str, memory: MemoryBase, db: AsyncSession = Depends(get_db)):
    """Overwrite an existing memory entry with every field from the request body.

    Raises:
        HTTPException: 404 if no memory entry has ``memory_id``; 400 if the
            update violates a database integrity constraint.
    """
    result = await db.execute(select(Memory).where(Memory.id == memory_id))
    db_memory = result.scalar_one_or_none()
    if not db_memory:
        raise HTTPException(status_code=404, detail="Memory not found")
    # NOTE(review): all body fields are copied, including an ``id`` field if
    # MemoryBase declares one — confirm that renaming via PUT is intended.
    for key, value in memory.model_dump().items():
        setattr(db_memory, key, value)
    try:
        await db.commit()
        await db.refresh(db_memory)
    except IntegrityError:
        # Same handling as create_memory: a client error, not a 500.
        await db.rollback()
        raise HTTPException(status_code=400, detail="Memory update violates a database constraint")
    return db_memory
@app.delete("/memory/{memory_id}", dependencies=[Depends(verify_api_key)])
async def delete_memory(memory_id: str, db: AsyncSession = Depends(get_db)):
    """Remove a memory entry by ID and echo the ID back.

    Raises:
        HTTPException: 404 if no memory entry has ``memory_id``.
    """
    lookup = select(Memory).where(Memory.id == memory_id)
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=404, detail="Memory not found")
    await db.delete(target)
    await db.commit()
    return {"deleted": memory_id}
# ============== CONTEXT BUNDLE ==============
@app.get("/context", response_model=ContextBundle, dependencies=[Depends(verify_api_key)])
async def get_context(
    project: Optional[str] = None,
    skills: Optional[str] = Query(None, description="Comma-separated skill IDs to include"),
    db: AsyncSession = Depends(get_db)
):
    """Get context bundle - legacy endpoint, returns ALL matching items.

    Args:
        project: Project path. Selects conventions whose ``project_path``
            equals it, memories whose ``project`` equals it, and snippets
            whose ``category`` equals its last path component.
        skills: Comma-separated skill IDs; blank entries are ignored.

    Returns:
        A ContextBundle; sections whose filter was not supplied stay empty.
    """
    skill_list = []
    snippet_list = []
    convention_list = []
    memory_list = []

    if skills:
        # Drop blanks so "a,,b" or a trailing comma does not query for "".
        skill_ids = [s.strip() for s in skills.split(",") if s.strip()]
        if skill_ids:
            result = await db.execute(select(Skill).where(Skill.id.in_(skill_ids)))
            skill_list = result.scalars().all()

    if project:
        result = await db.execute(select(Convention).where(Convention.project_path == project))
        convention_list = result.scalars().all()
        result = await db.execute(select(Memory).where(Memory.project == project))
        memory_list = result.scalars().all()
        # Strip trailing slashes first: "repo/app/" must map to "app", not "",
        # which would otherwise match every snippet with an empty category.
        project_name = project.rstrip("/").rsplit("/", 1)[-1]
        if project_name:
            result = await db.execute(select(Snippet).where(Snippet.category == project_name))
            snippet_list = result.scalars().all()

    return ContextBundle(
        skills=skill_list,
        snippets=snippet_list,
        conventions=convention_list,
        memories=memory_list
    )
@app.get("/context/rag", dependencies=[Depends(verify_api_key)])
async def get_context_rag(
    query: str,
    project: Optional[str] = None,
    max_skills: Optional[int] = None,
    max_conventions: Optional[int] = None,
    max_snippets: Optional[int] = None,
    db: AsyncSession = Depends(get_db)
):
    """
    Build a RAG context bundle containing only items relevant to ``query``.

    Selection is delegated to ``build_context_bundle``. Each ``max_*``
    limit left as None falls back to the matching ``config.rag`` value.
    """
    rag_cfg = config.rag
    limits = {
        "max_skills": rag_cfg.max_skills if max_skills is None else max_skills,
        "max_conventions": rag_cfg.max_conventions if max_conventions is None else max_conventions,
        "max_snippets": rag_cfg.max_snippets if max_snippets is None else max_snippets,
    }
    return await build_context_bundle(query, db, project, **limits)
@dataclass
class CompressionRequest:
    """JSON body accepted by ``POST /compress``: ``{"messages": [...]}``.

    Each message is a dict. This module reads only its ``"content"`` entry,
    for token counting; the whole list is handed to ``compress_conversation``.
    """

    messages: List[dict]


@app.post("/compress", dependencies=[Depends(verify_api_key)])
async def compress_messages_endpoint(
    request: CompressionRequest,
    keep_last_n: Optional[int] = None,
    max_tokens: Optional[int] = None
):
    """
    Compress conversation history.
    Uses config strategy (extractive or ollama).

    Returns the compressed messages plus token counts before and after
    compression, the tokens saved, and the percentage reduction (0 when the
    input has no tokens).
    """
    # Falsy values (None or 0) fall back to the configured defaults.
    keep_last_n = keep_last_n or config.compression.keep_last_n
    max_tokens = max_tokens or config.compression.max_tokens
    messages = request.messages
    compressed = await compress_conversation(
        messages,
        max_tokens=max_tokens,
        keep_last_n=keep_last_n,
        strategy=config.compression.strategy,
        ollama_model=config.compression.ollama_model,
        ollama_url=config.compression.ollama_url
    )
    # ``or ""`` also covers an explicit None content, not just a missing key.
    original_tokens = sum(count_tokens(m.get("content") or "") for m in messages)
    compressed_tokens = sum(count_tokens(m.get("content") or "") for m in compressed)
    return {
        "messages": compressed,
        "original_tokens": original_tokens,
        "compressed_tokens": compressed_tokens,
        "tokens_saved": original_tokens - compressed_tokens,
        "reduction_percent": round((1 - compressed_tokens / original_tokens) * 100, 1) if original_tokens > 0 else 0
    }
@app.get("/tokens/count", dependencies=[Depends(verify_api_key)])
async def count_tokens_endpoint(text: str):
    """Return ``{"tokens": n}``, where n is ``count_tokens(text)``."""
    n_tokens = count_tokens(text)
    return {"tokens": n_tokens}
@app.get("/health")
async def health():
    """Liveness check: always reports healthy and requires no API key."""
    payload = {"status": "healthy"}
    return payload
@app.get("/config")
async def get_config():
    """Return current configuration (non-sensitive).

    Exposes the RAG limits, the compression settings (the Ollama model only
    when the "ollama" strategy is active) and whether auth is enabled. No API
    key is required for this route, and the key itself is never included.
    """
    rag = config.rag
    comp = config.compression
    rag_section = {
        "max_skills": rag.max_skills,
        "max_conventions": rag.max_conventions,
        "max_snippets": rag.max_snippets,
        "min_skill_score": rag.min_skill_score,
        "min_snippet_score": rag.min_snippet_score,
        "embedding_model": rag.embedding_model,
    }
    compression_section = {
        "enabled": comp.enabled,
        "strategy": comp.strategy,
        "keep_last_n": comp.keep_last_n,
        "max_tokens": comp.max_tokens,
        "ollama_model": comp.ollama_model if comp.strategy == "ollama" else None,
    }
    return {
        "rag": rag_section,
        "compression": compression_section,
        "auth": {"enabled": config.auth.enabled},
    }
@app.post("/admin/clear-cache", dependencies=[Depends(verify_api_key)])
async def clear_cache():
    """Drop the RAG embedding cache via ``rag.clear_cache``.

    The cache is expected to be rebuilt on the next RAG request.
    """
    clear_rag_cache()
    response = {"status": "cache cleared"}
    return response