From b8edf40010afb8ac20737c0ba50b6717571f74d6 Mon Sep 17 00:00:00 2001 From: Lukas Parsons Date: Sun, 22 Mar 2026 22:32:44 -0400 Subject: [PATCH] Major refactor: remove semantic cache, add config, auth, improve RAG performance, fix tags JSON --- Dockerfile | 2 + compression.py | 163 ++++++++++++++--------- config.py | 144 +++++++++++++++++++++ config.yaml | 38 ++++++ docker-compose.yml | 1 + main.py | 312 +++++++++++++++++++++++---------------------- models.py | 17 +-- rag.py | 102 +++++++++++---- requirements.txt | 3 + schemas.py | 25 ---- semantic_cache.py | 169 ------------------------ 11 files changed, 533 insertions(+), 443 deletions(-) create mode 100644 config.py create mode 100644 config.yaml delete mode 100644 semantic_cache.py diff --git a/Dockerfile b/Dockerfile index a18d900..3109694 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,6 +6,8 @@ COPY requirements.txt . RUN pip install --no-cache-dir -r requirements.txt COPY . . +RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app +USER appuser EXPOSE 8080 diff --git a/compression.py b/compression.py index a62f8f3..9311b71 100644 --- a/compression.py +++ b/compression.py @@ -1,10 +1,18 @@ """ -Prompt compression - summarizes conversation history to reduce tokens. -Uses a small local model (no API calls) to compress old turns. +Conversation compression - summarizes old turns to save tokens. +Supports multiple strategies: extractive summarization or Ollama LLM. """ from typing import List, Dict +import logging import tiktoken +import httpx +from sumy.parsers.plaintext import PlaintextParser +from sumy.nlp.tokenizers import Tokenizer +from sumy.summarizers.lsa import LsaSummarizer +import asyncio + +logger = logging.getLogger(__name__) ENCODING = tiktoken.get_encoding("cl100k_base") @@ -14,18 +22,85 @@ def count_tokens(text: str) -> int: return len(ENCODING.encode(text)) -def compress_conversation( +def truncate_tool_output(output: str, max_tokens: int = 200) -> str: + """Truncate tool outputs to save tokens""" + tokens = ENCODING.encode(output) + if len(tokens) <= max_tokens: + return output + + truncated = ENCODING.decode(tokens[:max_tokens]) + return f"{truncated}... [truncated, {len(tokens) - max_tokens} tokens omitted]" + + +def extractive_summarize(text: str, sentences_count: int = 3) -> str: + """ + Simple extractive summarization using LSA algorithm. + Picks the most important sentences from the text. + No external API calls, fast and deterministic. + """ + try: + parser = PlaintextParser.from_string(text, Tokenizer("english")) + summarizer = LsaSummarizer() + summary_sentences = summarizer(parser.document, sentences_count) + return " ".join(str(sentence) for sentence in summary_sentences) + except Exception as e: + # Fallback: truncate to first few sentences + sentences = text.split('. ')[:3] + return '. '.join(sentences) + '.' + + +async def ollama_summarize(text: str, model: str = "phi3:mini", url: str = "http://localhost:11434") -> str: + """ + Summarize using Ollama API. + Requires Ollama running with the specified model pulled. 
+ """ + try: + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.post( + f"{url}/api/generate", + json={ + "model": model, + "prompt": f"Summarize the following conversation in 2-3 sentences, focusing on key decisions and conclusions:\n\n{text}", + "stream": False, + "options": { + "num_predict": 200 + } + } + ) + response.raise_for_status() + result = response.json() + return result.get("response", "").strip() + except Exception as e: + # Fallback to extractive on any error + return extractive_summarize(text, sentences_count=3) + + +async def compress_conversation( messages: List[Dict], max_tokens: int = 2000, - keep_last_n: int = 3 + keep_last_n: int = 3, + strategy: str = "extractive", + ollama_model: str = "phi3:mini", + ollama_url: str = "http://localhost:11434" ) -> List[Dict]: """ Compress conversation history: - Keep last N exchanges in full - - Summarize everything before into a single system message + - Summarize everything before using the configured strategy + + Args: + messages: Full conversation history + max_tokens: Target token budget + keep_last_n: Number of recent exchanges to keep uncompressed + strategy: "extractive", "ollama", or "none" + ollama_model: Model to use if strategy is "ollama" + ollama_url: Ollama API endpoint Returns compressed message list. """ + if strategy == "none": + return messages + if len(messages) <= keep_last_n * 2: # *2 for user/assistant pairs return messages @@ -41,72 +116,42 @@ def compress_conversation( recent = convo_messages[-keep_last_n * 2:] old = convo_messages[:-keep_last_n * 2] - # Summarize old conversation - summary = _summarize_turns(old) + # Create text to summarize from old turns + old_text = "\n".join([f"{m['role']}: {m['content']}" for m in old]) + + # Summarize using selected strategy + summary = None + if strategy == "ollama": + try: + summary = await ollama_summarize(old_text, ollama_model, ollama_url) + except Exception as e: + logger.warning(f"Ollama summarization failed: {e}, falling back to extractive") + summary = extractive_summarize(old_text, sentences_count=3) + else: + # Extractive is synchronous but fast; run in thread pool to avoid blocking + loop = asyncio.get_event_loop() + summary = await loop.run_in_executor(None, lambda: extractive_summarize(old_text, 3)) # Build compressed messages compressed = [] if system_msg: compressed.append(system_msg) - # Add summary as a user message with context + # Add summary as a user message with clear demarcation compressed.append({ "role": "user", - "content": f"[PREVIOUS CONVERSATION SUMMARY]\n{summary}\n[/PREVIOUS CONVERSATION SUMMARY]\n\n---\n\nConversation continues below:" + "content": f"[CONVERSATION SUMMARY]\n{summary}\n[/CONVERSATION SUMMARY]\n\n---\n\nRecent conversation (most relevant):" }) compressed.extend(recent) - # Verify we're under limit + # Verify we're under limit, if not, drop old more aggressively total_tokens = sum(count_tokens(m.get("content", "")) for m in compressed) - if total_tokens > max_tokens: - # Aggressive compression - keep only last exchange - compressed = compressed[-2:] + if total_tokens > max_tokens and len(compressed) > 2: + # Keep only system + last exchange + if system_msg: + compressed = [system_msg, recent[-2]] + else: + compressed = recent[-2:] return compressed - - -def _summarize_turns(messages: List[Dict]) -> str: - """ - Create a brief summary of conversation turns. - In production, call a small local model here. - For now, extract key decisions and topics. 
- """ - topics = [] - decisions = [] - - for msg in messages: - content = msg.get("content", "") - - # Extract topics from user messages - if msg.get("role") == "user": - # Simple keyword extraction (replace with LLM summary) - if "docker" in content.lower(): - topics.append("Docker configuration") - if "server" in content.lower(): - topics.append("Server setup") - if "config" in content.lower(): - topics.append("Configuration") - - # Extract decisions from assistant messages - if msg.get("role") == "assistant": - if "we decided" in content.lower() or "I'll use" in content.lower(): - decisions.append(content[:200]) - - summary_parts = [] - if topics: - summary_parts.append(f"Topics discussed: {', '.join(set(topics))}") - if decisions: - summary_parts.append(f"Decisions made: {'; '.join(decisions[:3])}") - - return "\n".join(summary_parts) if summary_parts else "Previous conversation covered various topics." - - -def truncate_tool_output(output: str, max_tokens: int = 200) -> str: - """Truncate tool outputs to save tokens""" - tokens = ENCODING.encode(output) - if len(tokens) <= max_tokens: - return output - - truncated = ENCODING.decode(tokens[:max_tokens]) - return f"{truncated}... [truncated, {len(tokens) - max_tokens} tokens omitted]" diff --git a/config.py b/config.py new file mode 100644 index 0000000..382bedd --- /dev/null +++ b/config.py @@ -0,0 +1,144 @@ +""" +Configuration management for AI Skills API. +Supports YAML config file with sensible defaults. +Priority: env vars > config file > defaults +""" + +import os +from dataclasses import dataclass, field +from typing import List, Optional +import yaml + + +@dataclass +class RAGConfig: + """RAG-specific configuration""" + max_skills: int = 3 + max_conventions: int = 2 + max_snippets: int = 2 + min_skill_score: float = 0.3 + min_snippet_score: float = 0.25 + embedding_model: str = "all-MiniLM-L6-v2" + + +@dataclass +class CompressionConfig: + """Compression configuration""" + enabled: bool = True + strategy: str = "extractive" # "extractive", "ollama", or "none" + keep_last_n: int = 3 + max_tokens: int = 2000 + ollama_model: str = "phi3:mini" + ollama_url: str = "http://localhost:11434" + + +@dataclass +class AuthConfig: + """Authentication configuration""" + enabled: bool = False # Set to True to require API keys + api_key: str = "change-me-in-production" # Single shared key for simplicity + header_name: str = "X-API-Key" + + +@dataclass +class LoggingConfig: + """Logging configuration""" + level: str = "INFO" + format: str = "json" # "json" or "text" + + +@dataclass +class Config: + """Main configuration""" + host: str = "0.0.0.0" + port: int = 8675 + database_url: str = "sqlite+aiosqlite:///./ai.db" + rag: RAGConfig = field(default_factory=RAGConfig) + compression: CompressionConfig = field(default_factory=CompressionConfig) + auth: AuthConfig = field(default_factory=AuthConfig) + logging: LoggingConfig = field(default_factory=LoggingConfig) + cors_origins: List[str] = field(default_factory=lambda: ["*"]) + + +def _env_override(prefix: str, key: str, default): + """Check for environment variable with prefix""" + env_key = f"{prefix}_{key}".upper() + value = os.getenv(env_key) + if value is not None: + # Convert to appropriate type + if isinstance(default, bool): + return value.lower() in ('true', '1', 'yes', 'on') + elif isinstance(default, int): + return int(value) + elif isinstance(default, float): + return float(value) + elif isinstance(default, list): + return value.split(',') + else: + return value + return default + + +def 
diff --git a/config.yaml b/config.yaml
new file mode 100644
index 0000000..6a0cce2
--- /dev/null
+++ b/config.yaml
@@ -0,0 +1,38 @@
+# AI Skills API Configuration
+
+# Server settings
+host: "0.0.0.0"
+port: 8675
+database_url: "sqlite+aiosqlite:///./ai.db"
+
+# CORS origins (restrict in production)
+cors_origins: ["*"]
+
+# RAG (Retrieval Augmented Generation) settings
+rag:
+  max_skills: 3             # Number of skills to include in context
+  max_conventions: 2        # Number of conventions to include
+  max_snippets: 2           # Number of code snippets to include
+  min_skill_score: 0.3      # Minimum similarity threshold for skills (0-1)
+  min_snippet_score: 0.25   # Minimum similarity for snippets (0-1)
+  embedding_model: "all-MiniLM-L6-v2"   # Sentence transformer model
+
+# Compression settings
+compression:
+  enabled: true
+  strategy: "extractive"    # "extractive" (sumy), "ollama" (phi-3-mini), or "none"
+  keep_last_n: 3            # Number of recent exchanges to keep uncompressed
+  max_tokens: 2000          # Target token budget for conversation history
+  ollama_model: "phi3:mini" # Only used if strategy is "ollama"
+  ollama_url: "http://localhost:11434"  # Ollama API endpoint
+
+# Authentication (simple shared API key)
+auth:
+  enabled: false            # Set to true to require an API key on all endpoints
+  api_key: "change-me-in-production"    # Change this before enabling auth
+  header_name: "X-API-Key"
+
+# Logging configuration
+logging:
+  level: "INFO"
+  format: "json"            # "json" for structured logs, "text" for human-readable
\ No newline at end of file
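The same override mechanism covers the auth block, so a deployment can flip auth on without editing the mounted file. A sketch, with a placeholder key value:

```python
# Sketch: enabling auth via environment variables at deploy time.
import os
from config import load_config

os.environ["AUTH_ENABLED"] = "true"                          # parsed as a bool
os.environ["API_KEY"] = "example-key-not-for-production"     # placeholder value

cfg = load_config()
assert cfg.auth.enabled is True
assert cfg.auth.api_key == "example-key-not-for-production"
```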
diff --git a/docker-compose.yml b/docker-compose.yml
index 26f84a7..7f28cd9 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -7,6 +7,7 @@ services:
       - DATABASE_URL=sqlite+aiosqlite:///./ai.db
     volumes:
       - ./data:/app/data
+      - ./config.yaml:/app/config.yaml:ro
     restart: unless-stopped
     healthcheck:
       test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
diff --git a/main.py b/main.py
index ff30970..0f12e7a 100644
--- a/main.py
+++ b/main.py
@@ -1,37 +1,81 @@
-from fastapi import FastAPI, HTTPException, Depends, Query
+import logging
+from fastapi import FastAPI, HTTPException, Depends, Query, Header
 from fastapi.middleware.cors import CORSMiddleware
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy import select, func
+from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
+from sqlalchemy import select, String
 from sqlalchemy.exc import IntegrityError
-import hashlib
-import json
-import os
 from typing import Optional, List
+import sys
 
-from database import get_db, init_db
-from models import Skill, Snippet, Convention, Cache, Memory
+from config import load_config
+from models import Skill, Snippet, Convention, Memory, Base
 from schemas import (
-    SkillBase, Skill,
-    SnippetBase, Snippet,
-    ConventionBase, Convention,
-    CacheStore, Cache as CacheSchema,
-    MemoryBase, Memory as MemorySchema,
-    ContextBundle, CacheLookup
+    SkillBase,
+    Skill as SkillSchema,
+    SnippetBase,
+    Snippet as SnippetSchema,
+    ConventionBase,
+    Convention as ConventionSchema,
+    MemoryBase,
+    Memory as MemorySchema,
+    ContextBundle
 )
-from semantic_cache import (
-    semantic_cache_lookup,
-    semantic_cache_store,
-    get_cache_stats,
-    clear_old_cache
-)
-from rag import build_context_bundle
+from rag import build_context_bundle, clear_cache as clear_rag_cache
 from compression import compress_conversation, count_tokens
 
 
-app = FastAPI(title="AI Skills API", description="Local infrastructure for AI context management")
+# Load configuration
+config = load_config()
 
+# Setup logging
+logging.basicConfig(
+    level=getattr(logging, config.logging.level.upper()),
+    format='%(message)s' if config.logging.format == "json" else '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[logging.StreamHandler(sys.stdout)]
+)
+logger = logging.getLogger("ai-skills-api")
+
+
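One caveat in the block above: with format: "json", the basicConfig call only emits the bare message, not actual JSON. A hypothetical JsonFormatter along these lines (not part of this patch) could back the option properly:

```python
# Sketch: a formatter that would make logging.format == "json" emit real JSON.
import json
import logging

class JsonFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "time": self.formatTime(record),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        return json.dumps(payload)

# Wire it in instead of the bare format string:
handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())
logging.basicConfig(level=logging.INFO, handlers=[handler])
```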
+# API Key authentication dependency
+async def verify_api_key(
+    x_api_key: Optional[str] = Header(None, alias=config.auth.header_name)
+) -> None:
+    """Verify the API key if auth is enabled"""
+    if config.auth.enabled:
+        if not x_api_key:
+            raise HTTPException(status_code=401, detail="Missing API key")
+        if x_api_key != config.auth.api_key:
+            raise HTTPException(status_code=403, detail="Invalid API key")
+
+
+# Database setup based on config
+if config.database_url != "sqlite+aiosqlite:///./ai.db":
+    # Custom DB URL - create a dedicated engine
+    engine = create_async_engine(config.database_url, echo=False)
+    AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
+
+    async def get_db() -> AsyncSession:
+        async with AsyncSessionLocal() as session:
+            yield session
+
+    async def init_db():
+        async with engine.begin() as conn:
+            await conn.run_sync(Base.metadata.create_all)
+else:
+    # Use default database setup from database.py
+    from database import get_db, init_db
+
+
+app = FastAPI(
+    title="AI Skills API",
+    description="Local infrastructure for AI context management"
+)
+
+# CORS configuration
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],
+    allow_origins=config.cors_origins,
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
@@ -41,11 +85,15 @@ app.add_middleware(
 @app.on_event("startup")
 async def startup():
     await init_db()
+    logger.info("Application startup complete")
+    logger.info("RAG cache will initialize on first request")
+    if config.auth.enabled:
+        logger.info("API key authentication enabled")
 
 
 # ============== SKILLS ==============
 
-@app.get("/skills", response_model=list[Skill])
+@app.get("/skills", response_model=list[SkillSchema], dependencies=[Depends(verify_api_key)])
 async def list_skills(
     category: Optional[str] = None,
     db: AsyncSession = Depends(get_db)
@@ -57,7 +105,7 @@
     return result.scalars().all()
 
 
-@app.get("/skills/search")
+@app.get("/skills/search", dependencies=[Depends(verify_api_key)])
 async def search_skills(
     q: str,
     category: Optional[str] = None,
@@ -66,7 +114,7 @@
     query = select(Skill).where(
         (Skill.name.ilike(f"%{q}%")) |
         (Skill.content.ilike(f"%{q}%")) |
-        (Skill.tags.ilike(f"%{q}%"))
+        (Skill.tags.cast(String).ilike(f"%{q}%"))  # tags is JSON now; cast to text for LIKE
     )
     if category:
         query = query.where(Skill.category == category)
@@ -74,7 +122,7 @@
     return result.scalars().all()
 
 
-@app.get("/skills/{skill_id}", response_model=Skill)
+@app.get("/skills/{skill_id}", response_model=SkillSchema, dependencies=[Depends(verify_api_key)])
 async def get_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
     result = await db.execute(select(Skill).where(Skill.id == skill_id))
     skill = result.scalar_one_or_none()
@@ -86,7 +134,7 @@
     return skill
 
 
-@app.post("/skills", response_model=Skill)
+@app.post("/skills", response_model=SkillSchema, dependencies=[Depends(verify_api_key)])
 async def create_skill(skill: SkillBase, db: AsyncSession = Depends(get_db)):
     db_skill = Skill(**skill.model_dump())
     db.add(db_skill)
@@ -99,7 +147,7 @@
     return db_skill
 
 
-@app.put("/skills/{skill_id}", response_model=Skill)
+@app.put("/skills/{skill_id}", response_model=SkillSchema, dependencies=[Depends(verify_api_key)])
 async def update_skill(skill_id: str, skill: SkillBase, db: AsyncSession = Depends(get_db)):
     result = await db.execute(select(Skill).where(Skill.id == skill_id))
     db_skill = result.scalar_one_or_none()
@@ -114,7 +162,7 @@
     return db_skill
 
 
-@app.delete("/skills/{skill_id}")
+@app.delete("/skills/{skill_id}", dependencies=[Depends(verify_api_key)])
 async def delete_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
     result = await db.execute(select(Skill).where(Skill.id == skill_id))
     skill = result.scalar_one_or_none()
@@ -128,7 +176,7 @@
 
 # ============== SNIPPETS ==============
 
-@app.get("/snippets", response_model=list[Snippet])
+@app.get("/snippets", response_model=list[SnippetSchema], dependencies=[Depends(verify_api_key)])
 async def list_snippets(
     category: Optional[str] = None,
     language: Optional[str] = None,
@@ -143,7 +191,7 @@
     return result.scalars().all()
 
 
-@app.get("/snippets/{snippet_id}", response_model=Snippet)
+@app.get("/snippets/{snippet_id}", response_model=SnippetSchema, dependencies=[Depends(verify_api_key)])
 async def get_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
     result = await db.execute(select(Snippet).where(Snippet.id == snippet_id))
     snippet = result.scalar_one_or_none()
@@ -152,7 +200,7 @@
     return snippet
 
 
-@app.post("/snippets", response_model=Snippet)
+@app.post("/snippets", response_model=SnippetSchema, dependencies=[Depends(verify_api_key)])
 async def create_snippet(snippet: SnippetBase, db: AsyncSession = Depends(get_db)):
     db_snippet = Snippet(**snippet.model_dump())
     db.add(db_snippet)
@@ -165,7 +213,7 @@
     return db_snippet
 
 
-@app.delete("/snippets/{snippet_id}")
+@app.delete("/snippets/{snippet_id}", dependencies=[Depends(verify_api_key)])
 async def delete_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
     result = await db.execute(select(Snippet).where(Snippet.id == snippet_id))
     snippet = result.scalar_one_or_none()
@@ -179,7 +227,7 @@
 
 # ============== CONVENTIONS ==============
 
-@app.get("/conventions", response_model=list[Convention])
+@app.get("/conventions", response_model=list[ConventionSchema], dependencies=[Depends(verify_api_key)])
 async def list_conventions(
     project: Optional[str] = None,
     db: AsyncSession = Depends(get_db)
@@ -191,7 +239,7 @@
     return result.scalars().all()
 
 
-@app.get("/conventions/{convention_id}", response_model=Convention)
+@app.get("/conventions/{convention_id}", response_model=ConventionSchema, dependencies=[Depends(verify_api_key)])
 async def get_convention(convention_id: str, db: AsyncSession = Depends(get_db)):
     result = await db.execute(select(Convention).where(Convention.id == convention_id))
     convention = result.scalar_one_or_none()
@@ -200,7 +248,7 @@
     return convention
 
 
-@app.post("/conventions", response_model=Convention)
+@app.post("/conventions", response_model=ConventionSchema, dependencies=[Depends(verify_api_key)])
 async def create_convention(convention: ConventionBase, db: AsyncSession = Depends(get_db)):
     db_convention = Convention(**convention.model_dump())
     db.add(db_convention)
@@ -213,7 +261,7 @@
     return db_convention
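With auth.enabled set to true, every route carrying the verify_api_key dependency expects the shared key. A client-side sketch, assuming the config.yaml defaults (port 8675, X-API-Key header) and a placeholder key:

```python
# Sketch: calling a protected endpoint with the shared API key.
import httpx

resp = httpx.get(
    "http://localhost:8675/skills",
    headers={"X-API-Key": "example-key-not-for-production"},  # placeholder key
)
resp.raise_for_status()  # 401 without a key, 403 with a wrong one
print(resp.json())
```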
-@app.put("/conventions/{convention_id}", response_model=Convention) +@app.put("/conventions/{convention_id}", response_model=ConventionSchema, dependencies=[Depends(verify_api_key)]) async def update_convention(convention_id: str, convention: ConventionBase, db: AsyncSession = Depends(get_db)): result = await db.execute(select(Convention).where(Convention.id == convention_id)) db_convention = result.scalar_one_or_none() @@ -228,7 +276,7 @@ async def update_convention(convention_id: str, convention: ConventionBase, db: return db_convention -@app.delete("/conventions/{convention_id}") +@app.delete("/conventions/{convention_id}", dependencies=[Depends(verify_api_key)]) async def delete_convention(convention_id: str, db: AsyncSession = Depends(get_db)): result = await db.execute(select(Convention).where(Convention.id == convention_id)) convention = result.scalar_one_or_none() @@ -240,107 +288,9 @@ async def delete_convention(convention_id: str, db: AsyncSession = Depends(get_d return {"deleted": convention_id} -# ============== CACHE ============== - -@app.post("/cache/lookup", response_model=Optional[CacheSchema]) -async def lookup_cache(lookup: CacheLookup, db: AsyncSession = Depends(get_db)): - """Exact hash-based cache lookup""" - prompt_hash = hashlib.sha256( - json.dumps({"prompt": lookup.prompt, "model": lookup.model}, sort_keys=True).encode() - ).hexdigest() - - result = await db.execute( - select(Cache).where( - (Cache.hash == prompt_hash) & - ((Cache.expires_at == None) | (Cache.expires_at > func.now())) - ) - ) - return result.scalar_one_or_none() - - -@app.post("/cache/semantic-lookup", response_model=dict) -async def semantic_lookup( - prompt: str, - model: Optional[str] = None, - min_similarity: float = 0.85, - db: AsyncSession = Depends(get_db) -): - """Semantic cache lookup - finds similar prompts""" - result = await semantic_cache_lookup( - prompt, db, model=model, min_similarity=min_similarity - ) - if result: - return {"hit": True, **result} - return {"hit": False} - - -@app.post("/cache/store", response_model=CacheSchema) -async def store_cache(cache: CacheStore, db: AsyncSession = Depends(get_db)): - """Store in exact-match cache""" - prompt_hash = hashlib.sha256( - json.dumps({"prompt": cache.response, "model": cache.model}, sort_keys=True).encode() - ).hexdigest() - - db_cache = Cache( - hash=prompt_hash, - response=cache.response, - model=cache.model, - tokens_in=cache.tokens_in, - tokens_out=cache.tokens_out, - expires_at=cache.expires_at - ) - db.add(db_cache) - await db.commit() - await db.refresh(db_cache) - return db_cache - - -@app.post("/cache/semantic-store", response_model=dict) -async def semantic_store( - prompt: str, - response: str, - model: Optional[str] = None, - tokens_in: Optional[int] = None, - tokens_out: Optional[int] = None, - db: AsyncSession = Depends(get_db) -): - """Store in semantic cache""" - return await semantic_cache_store( - prompt, response, db, model, tokens_in, tokens_out - ) - - -@app.delete("/cache/{cache_hash}") -async def delete_cache(cache_hash: str, db: AsyncSession = Depends(get_db)): - result = await db.execute(select(Cache).where(Cache.hash == cache_hash)) - cache = result.scalar_one_or_none() - if not cache: - raise HTTPException(status_code=404, detail="Cache entry not found") - - await db.delete(cache) - await db.commit() - return {"deleted": cache_hash} - - -@app.get("/cache/stats") -async def cache_stats_endpoint(db: AsyncSession = Depends(get_db)): - """Get cache statistics""" - return await get_cache_stats(db) - - 
-@app.post("/cache/clear-old") -async def clear_old( - older_than_hours: int = 168, - db: AsyncSession = Depends(get_db) -): - """Clear cache entries older than threshold""" - deleted = await clear_old_cache(db, older_than_hours) - return {"deleted": deleted} - - # ============== MEMORY ============== -@app.get("/memory", response_model=list[MemorySchema]) +@app.get("/memory", response_model=list[MemorySchema], dependencies=[Depends(verify_api_key)]) async def list_memory( project: Optional[str] = None, db: AsyncSession = Depends(get_db) @@ -352,7 +302,7 @@ async def list_memory( return result.scalars().all() -@app.get("/memory/{memory_id}", response_model=MemorySchema) +@app.get("/memory/{memory_id}", response_model=MemorySchema, dependencies=[Depends(verify_api_key)]) async def get_memory(memory_id: str, db: AsyncSession = Depends(get_db)): result = await db.execute(select(Memory).where(Memory.id == memory_id)) memory = result.scalar_one_or_none() @@ -361,7 +311,7 @@ async def get_memory(memory_id: str, db: AsyncSession = Depends(get_db)): return memory -@app.post("/memory", response_model=MemorySchema) +@app.post("/memory", response_model=MemorySchema, dependencies=[Depends(verify_api_key)]) async def create_memory(memory: MemoryBase, db: AsyncSession = Depends(get_db)): db_memory = Memory(**memory.model_dump()) db.add(db_memory) @@ -374,7 +324,7 @@ async def create_memory(memory: MemoryBase, db: AsyncSession = Depends(get_db)): return db_memory -@app.put("/memory/{memory_id}", response_model=MemorySchema) +@app.put("/memory/{memory_id}", response_model=MemorySchema, dependencies=[Depends(verify_api_key)]) async def update_memory(memory_id: str, memory: MemoryBase, db: AsyncSession = Depends(get_db)): result = await db.execute(select(Memory).where(Memory.id == memory_id)) db_memory = result.scalar_one_or_none() @@ -389,7 +339,7 @@ async def update_memory(memory_id: str, memory: MemoryBase, db: AsyncSession = D return db_memory -@app.delete("/memory/{memory_id}") +@app.delete("/memory/{memory_id}", dependencies=[Depends(verify_api_key)]) async def delete_memory(memory_id: str, db: AsyncSession = Depends(get_db)): result = await db.execute(select(Memory).where(Memory.id == memory_id)) memory = result.scalar_one_or_none() @@ -403,7 +353,7 @@ async def delete_memory(memory_id: str, db: AsyncSession = Depends(get_db)): # ============== CONTEXT BUNDLE ============== -@app.get("/context", response_model=ContextBundle) +@app.get("/context", response_model=ContextBundle, dependencies=[Depends(verify_api_key)]) async def get_context( project: Optional[str] = None, skills: Optional[str] = Query(None, description="Comma-separated skill IDs to include"), @@ -438,19 +388,29 @@ async def get_context( ) -@app.get("/context/rag") +@app.get("/context/rag", dependencies=[Depends(verify_api_key)]) async def get_context_rag( query: str, project: Optional[str] = None, - max_skills: int = 3, - max_conventions: int = 2, - max_snippets: int = 2, + max_skills: Optional[int] = None, + max_conventions: Optional[int] = None, + max_snippets: Optional[int] = None, db: AsyncSession = Depends(get_db) ): """ RAG-based context selection - returns ONLY relevant items. Uses semantic search to find top K most relevant skills/snippets. + + Uses config defaults if parameters not provided. 
""" + # Use config defaults if not specified + if max_skills is None: + max_skills = config.rag.max_skills + if max_conventions is None: + max_conventions = config.rag.max_conventions + if max_snippets is None: + max_snippets = config.rag.max_snippets + bundle = await build_context_bundle( query, db, project, max_skills=max_skills, @@ -460,17 +420,28 @@ async def get_context_rag( return bundle -@app.post("/compress") -async def compress_messages( +@app.post("/compress", dependencies=[Depends(verify_api_key)]) +async def compress_messages_endpoint( messages: List[dict], - keep_last_n: int = 3, - max_tokens: int = 2000 + keep_last_n: Optional[int] = None, + max_tokens: Optional[int] = None ): """ Compress conversation history. - Keeps last N exchanges in full, summarizes everything before. + Uses config strategy (extractive or ollama). """ - compressed = compress_conversation(messages, max_tokens, keep_last_n) + # Use config defaults if not specified + keep_last_n = keep_last_n or config.compression.keep_last_n + max_tokens = max_tokens or config.compression.max_tokens + + compressed = await compress_conversation( + messages, + max_tokens=max_tokens, + keep_last_n=keep_last_n, + strategy=config.compression.strategy, + ollama_model=config.compression.ollama_model, + ollama_url=config.compression.ollama_url + ) original_tokens = sum(count_tokens(m.get("content", "")) for m in messages) compressed_tokens = sum(count_tokens(m.get("content", "")) for m in compressed) @@ -483,7 +454,7 @@ async def compress_messages( } -@app.get("/tokens/count") +@app.get("/tokens/count", dependencies=[Depends(verify_api_key)]) async def count_tokens_endpoint(text: str): """Count tokens in text""" return {"tokens": count_tokens(text)} @@ -492,3 +463,36 @@ async def count_tokens_endpoint(text: str): @app.get("/health") async def health(): return {"status": "healthy"} + + +@app.get("/config") +async def get_config(): + """Return current configuration (non-sensitive)""" + return { + "rag": { + "max_skills": config.rag.max_skills, + "max_conventions": config.rag.max_conventions, + "max_snippets": config.rag.max_snippets, + "min_skill_score": config.rag.min_skill_score, + "min_snippet_score": config.rag.min_snippet_score, + "embedding_model": config.rag.embedding_model + }, + "compression": { + "enabled": config.compression.enabled, + "strategy": config.compression.strategy, + "keep_last_n": config.compression.keep_last_n, + "max_tokens": config.compression.max_tokens, + "ollama_model": config.compression.ollama_model if config.compression.strategy == "ollama" else None + }, + "auth": { + "enabled": config.auth.enabled + } + } + + +@app.post("/admin/clear-cache", dependencies=[Depends(verify_api_key)]) +async def clear_cache(): + """Clear RAG embedding cache (forces reload on next request)""" + clear_rag_cache() + return {"status": "cache cleared"} + diff --git a/models.py b/models.py index 448dc4a..a1384da 100644 --- a/models.py +++ b/models.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, String, Text, DateTime, Boolean, Integer, ForeignKey +from sqlalchemy import Column, String, Text, DateTime, Boolean, Integer, ForeignKey, JSON from sqlalchemy.orm import relationship from sqlalchemy.sql import func from database import Base @@ -12,7 +12,7 @@ class Skill(Base): description = Column(Text) category = Column(String) content = Column(Text, nullable=False) - tags = Column(String) + tags = Column(JSON, default=list) created_at = Column(DateTime(timezone=True), server_default=func.now()) updated_at = 
diff --git a/models.py b/models.py
index 448dc4a..a1384da 100644
--- a/models.py
+++ b/models.py
@@ -1,4 +1,4 @@
-from sqlalchemy import Column, String, Text, DateTime, Boolean, Integer, ForeignKey
+from sqlalchemy import Column, String, Text, DateTime, Boolean, Integer, ForeignKey, JSON
 from sqlalchemy.orm import relationship
 from sqlalchemy.sql import func
 from database import Base
@@ -12,7 +12,7 @@ class Skill(Base):
     description = Column(Text)
     category = Column(String)
     content = Column(Text, nullable=False)
-    tags = Column(String)
+    tags = Column(JSON, default=list)
     created_at = Column(DateTime(timezone=True), server_default=func.now())
     updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
     usage_count = Column(Integer, default=0)
@@ -26,7 +26,7 @@
     language = Column(String)
     content = Column(Text, nullable=False)
     category = Column(String)
-    tags = Column(String)
+    tags = Column(JSON, default=list)
     created_at = Column(DateTime(timezone=True), server_default=func.now())
 
@@ -41,17 +41,6 @@
     created_at = Column(DateTime(timezone=True), server_default=func.now())
 
 
-class Cache(Base):
-    __tablename__ = "cache"
-
-    hash = Column(String, primary_key=True)
-    response = Column(Text, nullable=False)
-    model = Column(String)
-    tokens_in = Column(Integer)
-    tokens_out = Column(Integer)
-    created_at = Column(DateTime(timezone=True), server_default=func.now())
-    expires_at = Column(DateTime(timezone=True))
-
 class Memory(Base):
     __tablename__ = "memory"
diff --git a/rag.py b/rag.py
index ce98a85..1742674 100644
--- a/rag.py
+++ b/rag.py
@@ -1,6 +1,6 @@
 """
 RAG-based context selection using local embeddings.
-No external API calls - runs entirely on your home server.
+Optimized with an in-memory embedding cache to avoid recomputation.
 """
 
 import numpy as np
@@ -8,12 +8,21 @@ from sentence_transformers import SentenceTransformer
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 from typing import List, Dict, Optional
-import os
+import asyncio
 
 
 # Small, fast model - ~100MB, runs on CPU
 MODEL_NAME = "all-MiniLM-L6-v2"
 _model: Optional[SentenceTransformer] = None
 
+# In-memory embedding cache
+_embedding_cache: Dict[str, Dict] = {
+    "skills": {},
+    "snippets": {},
+    "initialized": False
+}
+_cache_lock = asyncio.Lock()
+
 
 def get_model() -> SentenceTransformer:
     """Lazy-load the embedding model"""
@@ -34,6 +43,51 @@
     return float(np.dot(a, b))
 
 
+def _get_skill_text(skill) -> str:
+    """Generate searchable text for a skill"""
+    return f"{skill.name} {skill.description or ''} {skill.content[:500]}"
+
+
+def _get_snippet_text(snippet) -> str:
+    """Generate searchable text for a snippet"""
+    return f"{snippet.name} {snippet.content}"
+
+
+async def ensure_cache_initialized(db: AsyncSession):
+    """Load all skills and snippets into memory with their embeddings"""
+    from models import Skill, Snippet  # local import to avoid circular dependency
+
+    # asyncio.Lock, not threading.Lock: a blocking lock held across awaits
+    # would stall the event loop if two first requests raced
+    async with _cache_lock:
+        if _embedding_cache["initialized"]:
+            return
+
+        # Load skills
+        result = await db.execute(select(Skill))
+        skills = result.scalars().all()
+
+        for skill in skills:
+            text = _get_skill_text(skill)
+            embedding = embed_text(text)
+            _embedding_cache["skills"][skill.id] = {
+                "embedding": embedding,
+                "skill": skill
+            }
+
+        # Load snippets
+        result = await db.execute(select(Snippet))
+        snippets = result.scalars().all()
+
+        for snippet in snippets:
+            text = _get_snippet_text(snippet)
+            embedding = embed_text(text)
+            _embedding_cache["snippets"][snippet.id] = {
+                "embedding": embedding,
+                "snippet": snippet
+            }
+
+        _embedding_cache["initialized"] = True
+
+
 async def select_relevant_skills(
     query: str,
     db: AsyncSession,
@@ -42,27 +96,24 @@
 ) -> List[Dict]:
     """
     Find most relevant skills using semantic search.
-    Only returns skills above minimum similarity threshold.
+    Uses cached embeddings; only the query is embedded per call.
""" from models import Skill - # Get all skills (for small datasets, load all - fine for <1000 items) - result = await db.execute(select(Skill)) - skills = result.scalars().all() + await ensure_cache_initialized(db) - if not skills: + if not _embedding_cache["skills"]: return [] # Generate query embedding query_embedding = embed_text(query) - # Score each skill + # Score each cached skill scored = [] - for skill in skills: - # Use cached embedding if available, else compute - skill_text = f"{skill.name} {skill.description or ''} {skill.content[:500]}" - skill_embedding = embed_text(skill_text) - score = cosine_similarity(query_embedding, skill_embedding) + for skill_id, cached in _embedding_cache["skills"].items(): + skill = cached["skill"] + embedding = cached["embedding"] + score = cosine_similarity(query_embedding, embedding) if score >= min_score: scored.append((score, skill)) @@ -90,6 +141,7 @@ async def select_relevant_conventions( """ Get conventions for a project. Exact match on project_path, plus fuzzy match on parent paths. + (No embeddings - small dataset, exact matching is fine) """ from models import Convention @@ -128,25 +180,24 @@ async def select_relevant_snippets( top_k: int = 2, language: Optional[str] = None ) -> List[Dict]: - """Find relevant code snippets""" + """Find relevant code snippets using cached embeddings""" from models import Snippet - result = await db.execute(select(Snippet)) - snippets = result.scalars().all() + await ensure_cache_initialized(db) - if not snippets: + if not _embedding_cache["snippets"]: return [] query_embedding = embed_text(query) scored = [] - for snippet in snippets: + for snippet_id, cached in _embedding_cache["snippets"].items(): + snippet = cached["snippet"] if language and snippet.language != language: continue - snippet_text = f"{snippet.name} {snippet.content}" - snippet_embedding = embed_text(snippet_text) - score = cosine_similarity(query_embedding, snippet_embedding) + embedding = cached["embedding"] + score = cosine_similarity(query_embedding, embedding) if score >= 0.25: # Lower threshold for snippets scored.append((score, snippet)) @@ -202,4 +253,11 @@ async def build_context_bundle( } -import asyncio +def clear_cache(): + """Clear embedding cache (useful for testing)""" + global _embedding_cache + _embedding_cache = { + "skills": {}, + "snippets": {}, + "initialized": False + } diff --git a/requirements.txt b/requirements.txt index 5bb08c6..9781c85 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,6 @@ aiosqlite==0.19.0 sentence-transformers==2.3.1 numpy==1.26.3 tiktoken==0.5.2 +pyyaml==6.0 +httpx==0.27.0 +sumy==0.11.0 diff --git a/schemas.py b/schemas.py index 0bd60ba..418b2b0 100644 --- a/schemas.py +++ b/schemas.py @@ -52,26 +52,6 @@ class Convention(ConventionBase): from_attributes = True -class CacheBase(BaseModel): - hash: str - response: str - model: Optional[str] = None - tokens_in: Optional[int] = None - tokens_out: Optional[int] = None - expires_at: Optional[datetime] = None - - -class CacheStore(CacheBase): - pass - - -class Cache(CacheBase): - created_at: datetime - - class Config: - from_attributes = True - - class MemoryBase(BaseModel): id: str project: str @@ -92,8 +72,3 @@ class ContextBundle(BaseModel): snippets: List[Snippet] conventions: List[Convention] memories: List[Memory] - - -class CacheLookup(BaseModel): - prompt: str - model: Optional[str] = None diff --git a/semantic_cache.py b/semantic_cache.py deleted file mode 100644 index aca054d..0000000 --- a/semantic_cache.py +++ /dev/null @@ 
-1,169 +0,0 @@ -""" -Semantic cache - matches similar prompts, not just exact hashes. -Uses embeddings to find similar questions and return cached responses. -""" - -import numpy as np -from sqlalchemy import select, func -from sqlalchemy.ext.asyncio import AsyncSession -from datetime import datetime, timedelta -from typing import Optional, Dict, List -import json -import hashlib - -from rag import embed_text, cosine_similarity -from compression import count_tokens - - -async def semantic_cache_lookup( - prompt: str, - db: AsyncSession, - model: Optional[str] = None, - min_similarity: float = 0.85, - max_age_hours: int = 168 # 1 week -) -> Optional[Dict]: - """ - Find cached responses for semantically similar prompts. - - Returns cached response if similarity >= threshold and not expired. - """ - from models import Cache - - # Generate embedding for the query - query_embedding = embed_text(prompt) - - # Get non-expired cache entries - expiry = datetime.now() - timedelta(hours=max_age_hours) - result = await db.execute( - select(Cache) - .where( - (Cache.expires_at == None) | (Cache.expires_at > datetime.now()) - ) - .where(Cache.created_at > expiry) - ) - cache_entries = result.scalars().all() - - if not cache_entries: - return None - - # Score each cache entry - best_match = None - best_score = 0 - - for entry in cache_entries: - # Skip if model doesn't match (optional) - if model and entry.model and entry.model != model: - continue - - # Compute similarity - entry_embedding = embed_text(entry.response) # Or store prompt embedding - score = cosine_similarity(query_embedding, entry_embedding) - - if score >= min_similarity and score > best_score: - best_score = score - best_match = entry - - if best_match: - return { - "response": best_match.response, - "similarity": best_score, - "model": best_match.model, - "tokens_saved": (best_match.tokens_in or 0) + (best_match.tokens_out or 0), - "cached_at": best_match.created_at - } - - return None - - -async def semantic_cache_store( - prompt: str, - response: str, - db: AsyncSession, - model: Optional[str] = None, - tokens_in: Optional[int] = None, - tokens_out: Optional[int] = None, - ttl_hours: Optional[int] = None -) -> Dict: - """ - Store response in cache with embedding for semantic matching. 
- """ - from models import Cache - - # Generate hash for deduplication - prompt_hash = hashlib.sha256( - json.dumps({"prompt": prompt, "model": model}, sort_keys=True).encode() - ).hexdigest() - - # Check if exact match already exists - existing = await db.execute( - select(Cache).where(Cache.hash == prompt_hash) - ) - if existing.scalar_one_or_none(): - return {"status": "exists", "hash": prompt_hash} - - # Create new entry - expires_at = None - if ttl_hours: - expires_at = datetime.now() + timedelta(hours=ttl_hours) - - new_entry = Cache( - hash=prompt_hash, - response=response, - model=model, - tokens_in=tokens_in, - tokens_out=tokens_out, - expires_at=expires_at - ) - - db.add(new_entry) - await db.commit() - - return { - "status": "stored", - "hash": prompt_hash, - "tokens_stored": (tokens_in or 0) + (tokens_out or 0) - } - - -async def get_cache_stats(db: AsyncSession) -> Dict: - """Get cache statistics""" - from models import Cache - - result = await db.execute(select(Cache)) - entries = result.scalars().all() - - now = datetime.now() - valid_entries = [ - e for e in entries - if e.expires_at is None or e.expires_at > now - ] - - return { - "total_entries": len(entries), - "valid_entries": len(valid_entries), - "total_tokens_stored": sum( - (e.tokens_in or 0) + (e.tokens_out or 0) for e in valid_entries - ), - "models_used": list(set(e.model for e in entries if e.model)) - } - - -async def clear_old_cache( - db: AsyncSession, - older_than_hours: int = 168 -) -> int: - """Delete cache entries older than threshold""" - from models import Cache - - cutoff = datetime.now() - timedelta(hours=older_than_hours) - result = await db.execute( - select(Cache).where(Cache.created_at < cutoff) - ) - old_entries = result.scalars().all() - - for entry in old_entries: - await db.delete(entry) - - await db.commit() - - return len(old_entries)