import logging
import secrets
import sys
from typing import List, Optional

from fastapi import Depends, FastAPI, Header, HTTPException, Query, Request
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from compression import compress_conversation, count_tokens
from config import Config, load_config
from database import init_db as default_init_db
from models import Base, Convention, Memory, Skill, Snippet
from rag import build_context_bundle, clear_cache as clear_rag_cache
from schemas import (
    CompressionRequest,
    ContextBundle,
    Convention as ConventionSchema,
    ConventionBase,
    Memory as MemorySchema,
    MemoryBase,
    Skill as SkillSchema,
    SkillBase,
    Snippet as SnippetSchema,
    SnippetBase,
)
|
|
# Load configuration
config = load_config()

# Setup logging
logging.basicConfig(
    # Level name comes from config, e.g. "info" -> logging.INFO
    level=getattr(logging, config.logging.level.upper()),
    # "json" format emits only the bare message — presumably the message is
    # already structured upstream; TODO confirm against config docs
    format='%(message)s' if config.logging.format == "json" else '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    # Log to stdout so container/process supervisors can capture output
    handlers=[logging.StreamHandler(sys.stdout)]
)

logger = logging.getLogger("ai-skills-api")
|
|
|
|
|
|
# API Key authentication dependency
async def verify_api_key(
    request: Request,
    x_api_key: Optional[str] = Header(None, alias=config.auth.header_name)
) -> None:
    """Verify the API key header when authentication is enabled.

    The header name is taken from ``config.auth.header_name``. No-op when
    ``config.auth.enabled`` is false.

    Raises:
        HTTPException: 401 when the key is missing, 403 when it does not match.
    """
    if not config.auth.enabled:
        return
    if not x_api_key:
        raise HTTPException(status_code=401, detail="Missing API key")
    # Constant-time comparison: plain `!=` short-circuits on the first
    # differing byte and leaks key prefixes via response timing.
    expected = config.auth.api_key or ""
    if not secrets.compare_digest(x_api_key.encode("utf-8"), expected.encode("utf-8")):
        raise HTTPException(status_code=403, detail="Invalid API key")
|
|
|
|
|
|
# Database setup based on config.
# When a non-default URL is configured, build a dedicated engine/session
# factory here; otherwise fall back to the shared setup in database.py.
if config.database_url != "sqlite+aiosqlite:///./ai.db":
    # Custom DB URL - create new engine
    engine = create_async_engine(config.database_url, echo=False)
    # expire_on_commit=False keeps ORM objects usable after commit (needed
    # because handlers return ORM instances for response serialization).
    AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    async def get_db() -> AsyncSession:
        # FastAPI dependency: one session per request, closed automatically.
        async with AsyncSessionLocal() as session:
            yield session

    async def init_db():
        # Create all tables declared on the shared declarative Base.
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
else:
    # Use default database setup from database.py
    from database import get_db, init_db
|
|
|
|
|
|
# FastAPI application instance; all routes below register on it.
app = FastAPI(
    title="AI Skills API",
    description="Local infrastructure for AI context management"
)

# CORS configuration
app.add_middleware(
    CORSMiddleware,
    # Origins come from config; credentials plus wildcard methods/headers
    # are allowed for those configured origins only.
    allow_origins=config.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
@app.on_event("startup")
async def startup():
    """Create database tables and log the effective startup state."""
    await init_db()
    logger.info("Application startup complete")
    # Plain string literal — the original used an f-string with no placeholders.
    logger.info("RAG cache will initialize on first request")
    if config.auth.enabled:
        logger.info("API key authentication enabled")
|
|
|
|
|
|
# ============== SKILLS ==============
|
|
|
|
@app.get("/skills", response_model=list[SkillSchema], dependencies=[Depends(verify_api_key)])
async def list_skills(
    category: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """Return every skill, optionally filtered by category, ordered by name."""
    stmt = select(Skill)
    if category:
        stmt = stmt.where(Skill.category == category)
    rows = await db.execute(stmt.order_by(Skill.name))
    return rows.scalars().all()
|
|
|
|
|
|
@app.get("/skills/search", dependencies=[Depends(verify_api_key)])
async def search_skills(
    q: str,
    category: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """Case-insensitive substring search over skill name, content, and tags."""
    # NOTE(review): `q` is interpolated unescaped, so LIKE wildcards (% and _)
    # in the query string act as wildcards, not literals — confirm intended.
    # NOTE(review): `.astext` is a PostgreSQL JSON(B) accessor; this likely
    # fails on other backends (e.g. the default SQLite URL) — verify against
    # the deployed database.
    query = select(Skill).where(
        (Skill.name.ilike(f"%{q}%")) |
        (Skill.content.ilike(f"%{q}%")) |
        (Skill.tags.astext.ilike(f"%{q}%"))
    )
    if category:
        query = query.where(Skill.category == category)
    result = await db.execute(query)
    return result.scalars().all()
|
|
|
|
|
|
@app.get("/skills/{skill_id}", response_model=SkillSchema, dependencies=[Depends(verify_api_key)])
async def get_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch one skill by ID, bumping its usage counter; 404 when absent."""
    found = (await db.execute(select(Skill).where(Skill.id == skill_id))).scalar_one_or_none()
    if found is None:
        raise HTTPException(status_code=404, detail="Skill not found")

    # Reading a skill counts as a use; persist the incremented counter.
    found.usage_count += 1
    await db.commit()
    return found
|
|
|
|
|
|
@app.post("/skills", response_model=SkillSchema, dependencies=[Depends(verify_api_key)])
async def create_skill(skill: SkillBase, db: AsyncSession = Depends(get_db)):
    """Persist a new skill; 400 when the ID is already taken."""
    record = Skill(**skill.model_dump())
    db.add(record)
    try:
        await db.commit()
        await db.refresh(record)
    except IntegrityError:
        # Duplicate primary key — surface as a client error.
        await db.rollback()
        raise HTTPException(status_code=400, detail="Skill with this ID already exists")
    return record
|
|
|
|
|
|
@app.put("/skills/{skill_id}", response_model=SkillSchema, dependencies=[Depends(verify_api_key)])
async def update_skill(skill_id: str, skill: SkillBase, db: AsyncSession = Depends(get_db)):
    """Replace every field of an existing skill with the submitted values."""
    row = (await db.execute(select(Skill).where(Skill.id == skill_id))).scalar_one_or_none()
    if row is None:
        raise HTTPException(status_code=404, detail="Skill not found")

    # Full replace: copy each submitted field onto the ORM object.
    for field, new_value in skill.model_dump().items():
        setattr(row, field, new_value)

    await db.commit()
    await db.refresh(row)
    return row
|
|
|
|
|
|
@app.delete("/skills/{skill_id}", dependencies=[Depends(verify_api_key)])
async def delete_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
    """Remove a skill by ID; 404 when it does not exist."""
    target = (await db.execute(select(Skill).where(Skill.id == skill_id))).scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Skill not found")

    await db.delete(target)
    await db.commit()
    return {"deleted": skill_id}
|
|
|
|
|
|
# ============== SNIPPETS ==============
|
|
|
|
@app.get("/snippets", response_model=list[SnippetSchema], dependencies=[Depends(verify_api_key)])
async def list_snippets(
    category: Optional[str] = None,
    language: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """List snippets, optionally narrowed by category and/or language."""
    stmt = select(Snippet)
    if category:
        stmt = stmt.where(Snippet.category == category)
    if language:
        stmt = stmt.where(Snippet.language == language)
    rows = await db.execute(stmt.order_by(Snippet.name))
    return rows.scalars().all()
|
|
|
|
|
|
@app.get("/snippets/{snippet_id}", response_model=SnippetSchema, dependencies=[Depends(verify_api_key)])
async def get_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch a single snippet by ID; 404 when absent."""
    found = (await db.execute(select(Snippet).where(Snippet.id == snippet_id))).scalar_one_or_none()
    if found is None:
        raise HTTPException(status_code=404, detail="Snippet not found")
    return found
|
|
|
|
|
|
@app.post("/snippets", response_model=SnippetSchema, dependencies=[Depends(verify_api_key)])
async def create_snippet(snippet: SnippetBase, db: AsyncSession = Depends(get_db)):
    """Persist a new snippet; 400 when the ID is already taken."""
    record = Snippet(**snippet.model_dump())
    db.add(record)
    try:
        await db.commit()
        await db.refresh(record)
    except IntegrityError:
        # Duplicate primary key — surface as a client error.
        await db.rollback()
        raise HTTPException(status_code=400, detail="Snippet with this ID already exists")
    return record
|
|
|
|
|
|
@app.delete("/snippets/{snippet_id}", dependencies=[Depends(verify_api_key)])
async def delete_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
    """Remove a snippet by ID; 404 when it does not exist."""
    target = (await db.execute(select(Snippet).where(Snippet.id == snippet_id))).scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Snippet not found")

    await db.delete(target)
    await db.commit()
    return {"deleted": snippet_id}
|
|
|
|
|
|
# ============== CONVENTIONS ==============
|
|
|
|
@app.get("/conventions", response_model=list[ConventionSchema], dependencies=[Depends(verify_api_key)])
async def list_conventions(
    project: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """List conventions, optionally filtered by project path, ordered by name."""
    stmt = select(Convention)
    if project:
        stmt = stmt.where(Convention.project_path == project)
    rows = await db.execute(stmt.order_by(Convention.name))
    return rows.scalars().all()
|
|
|
|
|
|
@app.get("/conventions/{convention_id}", response_model=ConventionSchema, dependencies=[Depends(verify_api_key)])
async def get_convention(convention_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch a single convention by ID; 404 when absent."""
    found = (await db.execute(select(Convention).where(Convention.id == convention_id))).scalar_one_or_none()
    if found is None:
        raise HTTPException(status_code=404, detail="Convention not found")
    return found
|
|
|
|
|
|
@app.post("/conventions", response_model=ConventionSchema, dependencies=[Depends(verify_api_key)])
async def create_convention(convention: ConventionBase, db: AsyncSession = Depends(get_db)):
    """Persist a new convention; 400 when the ID is already taken."""
    record = Convention(**convention.model_dump())
    db.add(record)
    try:
        await db.commit()
        await db.refresh(record)
    except IntegrityError:
        # Duplicate primary key — surface as a client error.
        await db.rollback()
        raise HTTPException(status_code=400, detail="Convention with this ID already exists")
    return record
|
|
|
|
|
|
@app.put("/conventions/{convention_id}", response_model=ConventionSchema, dependencies=[Depends(verify_api_key)])
async def update_convention(convention_id: str, convention: ConventionBase, db: AsyncSession = Depends(get_db)):
    """Replace every field of an existing convention with the submitted values."""
    row = (await db.execute(select(Convention).where(Convention.id == convention_id))).scalar_one_or_none()
    if row is None:
        raise HTTPException(status_code=404, detail="Convention not found")

    # Full replace: copy each submitted field onto the ORM object.
    for field, new_value in convention.model_dump().items():
        setattr(row, field, new_value)

    await db.commit()
    await db.refresh(row)
    return row
|
|
|
|
|
|
@app.delete("/conventions/{convention_id}", dependencies=[Depends(verify_api_key)])
async def delete_convention(convention_id: str, db: AsyncSession = Depends(get_db)):
    """Remove a convention by ID; 404 when it does not exist."""
    target = (await db.execute(select(Convention).where(Convention.id == convention_id))).scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Convention not found")

    await db.delete(target)
    await db.commit()
    return {"deleted": convention_id}
|
|
|
|
|
|
# ============== MEMORY ==============
|
|
|
|
@app.get("/memory", response_model=list[MemorySchema], dependencies=[Depends(verify_api_key)])
async def list_memory(
    project: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """List memory entries, optionally filtered by project, ordered by key."""
    stmt = select(Memory)
    if project:
        stmt = stmt.where(Memory.project == project)
    rows = await db.execute(stmt.order_by(Memory.key))
    return rows.scalars().all()
|
|
|
|
|
|
@app.get("/memory/{memory_id}", response_model=MemorySchema, dependencies=[Depends(verify_api_key)])
async def get_memory(memory_id: str, db: AsyncSession = Depends(get_db)):
    """Fetch a single memory entry by ID; 404 when absent."""
    found = (await db.execute(select(Memory).where(Memory.id == memory_id))).scalar_one_or_none()
    if found is None:
        raise HTTPException(status_code=404, detail="Memory not found")
    return found
|
|
|
|
|
|
@app.post("/memory", response_model=MemorySchema, dependencies=[Depends(verify_api_key)])
async def create_memory(memory: MemoryBase, db: AsyncSession = Depends(get_db)):
    """Persist a new memory entry; 400 when the ID is already taken."""
    record = Memory(**memory.model_dump())
    db.add(record)
    try:
        await db.commit()
        await db.refresh(record)
    except IntegrityError:
        # Duplicate primary key — surface as a client error.
        await db.rollback()
        raise HTTPException(status_code=400, detail="Memory with this ID already exists")
    return record
|
|
|
|
|
|
@app.put("/memory/{memory_id}", response_model=MemorySchema, dependencies=[Depends(verify_api_key)])
async def update_memory(memory_id: str, memory: MemoryBase, db: AsyncSession = Depends(get_db)):
    """Replace every field of an existing memory entry with the submitted values."""
    row = (await db.execute(select(Memory).where(Memory.id == memory_id))).scalar_one_or_none()
    if row is None:
        raise HTTPException(status_code=404, detail="Memory not found")

    # Full replace: copy each submitted field onto the ORM object.
    for field, new_value in memory.model_dump().items():
        setattr(row, field, new_value)

    await db.commit()
    await db.refresh(row)
    return row
|
|
|
|
|
|
@app.delete("/memory/{memory_id}", dependencies=[Depends(verify_api_key)])
async def delete_memory(memory_id: str, db: AsyncSession = Depends(get_db)):
    """Remove a memory entry by ID; 404 when it does not exist."""
    target = (await db.execute(select(Memory).where(Memory.id == memory_id))).scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Memory not found")

    await db.delete(target)
    await db.commit()
    return {"deleted": memory_id}
|
|
|
|
|
|
# ============== CONTEXT BUNDLE ==============
|
|
|
|
@app.get("/context", response_model=ContextBundle, dependencies=[Depends(verify_api_key)])
async def get_context(
    project: Optional[str] = None,
    skills: Optional[str] = Query(None, description="Comma-separated skill IDs to include"),
    db: AsyncSession = Depends(get_db)
):
    """Get context bundle - legacy endpoint, returns ALL matching items"""
    bundle_skills = []
    bundle_snippets = []
    bundle_conventions = []
    bundle_memories = []

    if skills:
        # Explicitly requested skills, given as a comma-separated ID list.
        wanted_ids = [token.strip() for token in skills.split(",")]
        bundle_skills = (
            await db.execute(select(Skill).where(Skill.id.in_(wanted_ids)))
        ).scalars().all()

    if project:
        bundle_conventions = (
            await db.execute(select(Convention).where(Convention.project_path == project))
        ).scalars().all()

        bundle_memories = (
            await db.execute(select(Memory).where(Memory.project == project))
        ).scalars().all()

        # Snippets are matched by category == last path segment of the project.
        bundle_snippets = (
            await db.execute(select(Snippet).where(Snippet.category == project.split("/")[-1]))
        ).scalars().all()

    return ContextBundle(
        skills=bundle_skills,
        snippets=bundle_snippets,
        conventions=bundle_conventions,
        memories=bundle_memories
    )
|
|
|
|
|
|
@app.get("/context/rag", dependencies=[Depends(verify_api_key)])
async def get_context_rag(
    query: str,
    project: Optional[str] = None,
    max_skills: Optional[int] = None,
    max_conventions: Optional[int] = None,
    max_snippets: Optional[int] = None,
    db: AsyncSession = Depends(get_db)
):
    """
    RAG-based context selection - returns ONLY relevant items.
    Uses semantic search to find top K most relevant skills/snippets.

    Uses config defaults if parameters not provided.
    """
    # Any cap the caller omitted falls back to the configured limit.
    limits = {
        "max_skills": config.rag.max_skills if max_skills is None else max_skills,
        "max_conventions": config.rag.max_conventions if max_conventions is None else max_conventions,
        "max_snippets": config.rag.max_snippets if max_snippets is None else max_snippets,
    }
    return await build_context_bundle(query, db, project, **limits)
|
|
|
|
|
|
@app.post("/compress", dependencies=[Depends(verify_api_key)])
async def compress_messages_endpoint(
    request: CompressionRequest,
    keep_last_n: Optional[int] = None,
    max_tokens: Optional[int] = None
):
    """
    Compress conversation history.
    Uses config strategy (extractive or ollama).

    Query parameters fall back to config defaults only when omitted: the
    old `x or default` pattern silently replaced an explicit 0 with the
    default, and the defaults were then recomputed a second time inside
    the compress_conversation call.
    """
    if keep_last_n is None:
        keep_last_n = config.compression.keep_last_n
    if max_tokens is None:
        max_tokens = config.compression.max_tokens

    compressed = await compress_conversation(
        request.messages,
        max_tokens=max_tokens,
        keep_last_n=keep_last_n,
        strategy=config.compression.strategy,
        ollama_model=config.compression.ollama_model,
        ollama_url=config.compression.ollama_url
    )

    # Token accounting before/after, for the savings report below.
    original_tokens = sum(count_tokens(m.get("content", "")) for m in request.messages)
    compressed_tokens = sum(count_tokens(m.get("content", "")) for m in compressed)

    return {
        "messages": compressed,
        "original_tokens": original_tokens,
        "compressed_tokens": compressed_tokens,
        "tokens_saved": original_tokens - compressed_tokens,
        # Guard against division by zero on an empty conversation.
        "reduction_percent": round((1 - compressed_tokens / original_tokens) * 100, 1) if original_tokens > 0 else 0
    }
|
|
|
|
|
|
@app.get("/tokens/count", dependencies=[Depends(verify_api_key)])
async def count_tokens_endpoint(text: str):
    """Count tokens in text"""
    token_total = count_tokens(text)
    return {"tokens": token_total}
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe: reports healthy whenever the process is serving."""
    return {"status": "healthy"}
|
|
|
|
|
|
@app.get("/config")
async def get_config():
    """Return current configuration (non-sensitive)"""
    rag_cfg = config.rag
    comp_cfg = config.compression
    # Only advertise the ollama model when that strategy is actually active.
    active_ollama_model = comp_cfg.ollama_model if comp_cfg.strategy == "ollama" else None
    return {
        "rag": {
            "max_skills": rag_cfg.max_skills,
            "max_conventions": rag_cfg.max_conventions,
            "max_snippets": rag_cfg.max_snippets,
            "min_skill_score": rag_cfg.min_skill_score,
            "min_snippet_score": rag_cfg.min_snippet_score,
            "embedding_model": rag_cfg.embedding_model,
        },
        "compression": {
            "enabled": comp_cfg.enabled,
            "strategy": comp_cfg.strategy,
            "keep_last_n": comp_cfg.keep_last_n,
            "max_tokens": comp_cfg.max_tokens,
            "ollama_model": active_ollama_model,
        },
        "auth": {
            "enabled": config.auth.enabled,
        },
    }
|
|
|
|
|
|
@app.post("/admin/clear-cache", dependencies=[Depends(verify_api_key)])
async def clear_cache():
    """Clear RAG embedding cache (forces reload on next request)"""
    # Drop cached embeddings; the next /context/rag call rebuilds them.
    clear_rag_cache()
    return {"status": "cache cleared"}
|
|
|