Major refactor: remove semantic cache, add config, auth, improve RAG performance, fix tags JSON
This commit is contained in:
parent
62c875c9a6
commit
b8edf40010
11 changed files with 533 additions and 443 deletions
|
|
@ -6,6 +6,8 @@ COPY requirements.txt .
|
||||||
RUN pip install --no-cache-dir -r requirements.txt
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
COPY . .
|
COPY . .
|
||||||
|
RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app
|
||||||
|
USER appuser
|
||||||
|
|
||||||
EXPOSE 8080
|
EXPOSE 8080
|
||||||
|
|
||||||
|
|
|
||||||
163
compression.py
163
compression.py
|
|
@ -1,10 +1,18 @@
|
||||||
"""
|
"""
|
||||||
Prompt compression - summarizes conversation history to reduce tokens.
|
Conversation compression - summarizes old turns to save tokens.
|
||||||
Uses a small local model (no API calls) to compress old turns.
|
Supports multiple strategies: extractive summarization or Ollama LLM.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
|
import logging
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
import httpx
|
||||||
|
from sumy.parsers.plaintext import PlaintextParser
|
||||||
|
from sumy.nlp.tokenizers import Tokenizer
|
||||||
|
from sumy.summarizers.lsa import LsaSummarizer
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
ENCODING = tiktoken.get_encoding("cl100k_base")
|
ENCODING = tiktoken.get_encoding("cl100k_base")
|
||||||
|
|
||||||
|
|
def count_tokens(text: str) -> int:
    """Return how many cl100k_base tokens *text* encodes to."""
    encoded = ENCODING.encode(text)
    return len(encoded)
def truncate_tool_output(output: str, max_tokens: int = 200) -> str:
    """Clip *output* to at most *max_tokens* tokens.

    Returns the original string when it already fits; otherwise the decoded
    prefix plus a marker noting how many tokens were dropped.
    """
    encoded = ENCODING.encode(output)
    omitted = len(encoded) - max_tokens
    if omitted <= 0:
        return output
    truncated = ENCODING.decode(encoded[:max_tokens])
    return f"{truncated}... [truncated, {omitted} tokens omitted]"
|
def extractive_summarize(text: str, sentences_count: int = 3) -> str:
    """
    Simple extractive summarization using the LSA algorithm.

    Picks the most important sentences from the text.
    No external API calls, fast and deterministic.

    Args:
        text: Text to summarize.
        sentences_count: Number of sentences to keep (also honored by the
            fallback path, which previously hardcoded 3).

    Returns:
        The summary string; on any sumy failure, the first
        ``sentences_count`` naive sentences of *text*.
    """
    try:
        parser = PlaintextParser.from_string(text, Tokenizer("english"))
        summarizer = LsaSummarizer()
        summary_sentences = summarizer(parser.document, sentences_count)
        return " ".join(str(sentence) for sentence in summary_sentences)
    except Exception:
        # Fallback: naive truncation to the first sentences_count sentences.
        # (Bug fix: the fallback used a hardcoded 3, ignoring the parameter.)
        sentences = text.split('. ')[:sentences_count]
        return '. '.join(sentences) + '.'
||||||
|
async def ollama_summarize(text: str, model: str = "phi3:mini", url: str = "http://localhost:11434") -> str:
    """
    Summarize *text* via the Ollama generate API.

    Requires Ollama running with the specified model pulled; on any
    failure (network, HTTP, parsing) degrades to extractive summarization.
    """
    payload = {
        "model": model,
        "prompt": (
            "Summarize the following conversation in 2-3 sentences, "
            "focusing on key decisions and conclusions:\n\n" + text
        ),
        "stream": False,
        "options": {"num_predict": 200},
    }
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.post(f"{url}/api/generate", json=payload)
            response.raise_for_status()
            return response.json().get("response", "").strip()
    except Exception:
        # Fallback to the local extractive strategy on any error.
        return extractive_summarize(text, sentences_count=3)
|
||||||
|
async def compress_conversation(
|
||||||
messages: List[Dict],
|
messages: List[Dict],
|
||||||
max_tokens: int = 2000,
|
max_tokens: int = 2000,
|
||||||
keep_last_n: int = 3
|
keep_last_n: int = 3,
|
||||||
|
strategy: str = "extractive",
|
||||||
|
ollama_model: str = "phi3:mini",
|
||||||
|
ollama_url: str = "http://localhost:11434"
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
Compress conversation history:
|
Compress conversation history:
|
||||||
- Keep last N exchanges in full
|
- Keep last N exchanges in full
|
||||||
- Summarize everything before into a single system message
|
- Summarize everything before using the configured strategy
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: Full conversation history
|
||||||
|
max_tokens: Target token budget
|
||||||
|
keep_last_n: Number of recent exchanges to keep uncompressed
|
||||||
|
strategy: "extractive", "ollama", or "none"
|
||||||
|
ollama_model: Model to use if strategy is "ollama"
|
||||||
|
ollama_url: Ollama API endpoint
|
||||||
|
|
||||||
Returns compressed message list.
|
Returns compressed message list.
|
||||||
"""
|
"""
|
||||||
|
if strategy == "none":
|
||||||
|
return messages
|
||||||
|
|
||||||
if len(messages) <= keep_last_n * 2: # *2 for user/assistant pairs
|
if len(messages) <= keep_last_n * 2: # *2 for user/assistant pairs
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
@ -41,72 +116,42 @@ def compress_conversation(
|
||||||
recent = convo_messages[-keep_last_n * 2:]
|
recent = convo_messages[-keep_last_n * 2:]
|
||||||
old = convo_messages[:-keep_last_n * 2]
|
old = convo_messages[:-keep_last_n * 2]
|
||||||
|
|
||||||
# Summarize old conversation
|
# Create text to summarize from old turns
|
||||||
summary = _summarize_turns(old)
|
old_text = "\n".join([f"{m['role']}: {m['content']}" for m in old])
|
||||||
|
|
||||||
|
# Summarize using selected strategy
|
||||||
|
summary = None
|
||||||
|
if strategy == "ollama":
|
||||||
|
try:
|
||||||
|
summary = await ollama_summarize(old_text, ollama_model, ollama_url)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Ollama summarization failed: {e}, falling back to extractive")
|
||||||
|
summary = extractive_summarize(old_text, sentences_count=3)
|
||||||
|
else:
|
||||||
|
# Extractive is synchronous but fast; run in thread pool to avoid blocking
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
summary = await loop.run_in_executor(None, lambda: extractive_summarize(old_text, 3))
|
||||||
|
|
||||||
# Build compressed messages
|
# Build compressed messages
|
||||||
compressed = []
|
compressed = []
|
||||||
if system_msg:
|
if system_msg:
|
||||||
compressed.append(system_msg)
|
compressed.append(system_msg)
|
||||||
|
|
||||||
# Add summary as a user message with context
|
# Add summary as a user message with clear demarcation
|
||||||
compressed.append({
|
compressed.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": f"[PREVIOUS CONVERSATION SUMMARY]\n{summary}\n[/PREVIOUS CONVERSATION SUMMARY]\n\n---\n\nConversation continues below:"
|
"content": f"[CONVERSATION SUMMARY]\n{summary}\n[/CONVERSATION SUMMARY]\n\n---\n\nRecent conversation (most relevant):"
|
||||||
})
|
})
|
||||||
|
|
||||||
compressed.extend(recent)
|
compressed.extend(recent)
|
||||||
|
|
||||||
# Verify we're under limit
|
# Verify we're under limit, if not, drop old more aggressively
|
||||||
total_tokens = sum(count_tokens(m.get("content", "")) for m in compressed)
|
total_tokens = sum(count_tokens(m.get("content", "")) for m in compressed)
|
||||||
if total_tokens > max_tokens:
|
if total_tokens > max_tokens and len(compressed) > 2:
|
||||||
# Aggressive compression - keep only last exchange
|
# Keep only system + last exchange
|
||||||
compressed = compressed[-2:]
|
if system_msg:
|
||||||
|
compressed = [system_msg, recent[-2]]
|
||||||
|
else:
|
||||||
|
compressed = recent[-2:]
|
||||||
|
|
||||||
return compressed
|
return compressed
|
||||||
|
|
||||||
|
|
||||||
def _summarize_turns(messages: List[Dict]) -> str:
|
|
||||||
"""
|
|
||||||
Create a brief summary of conversation turns.
|
|
||||||
In production, call a small local model here.
|
|
||||||
For now, extract key decisions and topics.
|
|
||||||
"""
|
|
||||||
topics = []
|
|
||||||
decisions = []
|
|
||||||
|
|
||||||
for msg in messages:
|
|
||||||
content = msg.get("content", "")
|
|
||||||
|
|
||||||
# Extract topics from user messages
|
|
||||||
if msg.get("role") == "user":
|
|
||||||
# Simple keyword extraction (replace with LLM summary)
|
|
||||||
if "docker" in content.lower():
|
|
||||||
topics.append("Docker configuration")
|
|
||||||
if "server" in content.lower():
|
|
||||||
topics.append("Server setup")
|
|
||||||
if "config" in content.lower():
|
|
||||||
topics.append("Configuration")
|
|
||||||
|
|
||||||
# Extract decisions from assistant messages
|
|
||||||
if msg.get("role") == "assistant":
|
|
||||||
if "we decided" in content.lower() or "I'll use" in content.lower():
|
|
||||||
decisions.append(content[:200])
|
|
||||||
|
|
||||||
summary_parts = []
|
|
||||||
if topics:
|
|
||||||
summary_parts.append(f"Topics discussed: {', '.join(set(topics))}")
|
|
||||||
if decisions:
|
|
||||||
summary_parts.append(f"Decisions made: {'; '.join(decisions[:3])}")
|
|
||||||
|
|
||||||
return "\n".join(summary_parts) if summary_parts else "Previous conversation covered various topics."
|
|
||||||
|
|
||||||
|
|
||||||
def truncate_tool_output(output: str, max_tokens: int = 200) -> str:
|
|
||||||
"""Truncate tool outputs to save tokens"""
|
|
||||||
tokens = ENCODING.encode(output)
|
|
||||||
if len(tokens) <= max_tokens:
|
|
||||||
return output
|
|
||||||
|
|
||||||
truncated = ENCODING.decode(tokens[:max_tokens])
|
|
||||||
return f"{truncated}... [truncated, {len(tokens) - max_tokens} tokens omitted]"
|
|
||||||
|
|
|
||||||
144
config.py
Normal file
144
config.py
Normal file
|
|
@ -0,0 +1,144 @@
|
||||||
|
"""
|
||||||
|
Configuration management for AI Skills API.
|
||||||
|
Supports YAML config file with sensible defaults.
|
||||||
|
Priority: env vars > config file > defaults
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Optional
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RAGConfig:
|
||||||
|
"""RAG-specific configuration"""
|
||||||
|
max_skills: int = 3
|
||||||
|
max_conventions: int = 2
|
||||||
|
max_snippets: int = 2
|
||||||
|
min_skill_score: float = 0.3
|
||||||
|
min_snippet_score: float = 0.25
|
||||||
|
embedding_model: str = "all-MiniLM-L6-v2"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CompressionConfig:
|
||||||
|
"""Compression configuration"""
|
||||||
|
enabled: bool = True
|
||||||
|
strategy: str = "extractive" # "extractive", "ollama", or "none"
|
||||||
|
keep_last_n: int = 3
|
||||||
|
max_tokens: int = 2000
|
||||||
|
ollama_model: str = "phi3:mini"
|
||||||
|
ollama_url: str = "http://localhost:11434"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AuthConfig:
|
||||||
|
"""Authentication configuration"""
|
||||||
|
enabled: bool = False # Set to True to require API keys
|
||||||
|
api_key: str = "change-me-in-production" # Single shared key for simplicity
|
||||||
|
header_name: str = "X-API-Key"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoggingConfig:
|
||||||
|
"""Logging configuration"""
|
||||||
|
level: str = "INFO"
|
||||||
|
format: str = "json" # "json" or "text"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Config:
|
||||||
|
"""Main configuration"""
|
||||||
|
host: str = "0.0.0.0"
|
||||||
|
port: int = 8675
|
||||||
|
database_url: str = "sqlite+aiosqlite:///./ai.db"
|
||||||
|
rag: RAGConfig = field(default_factory=RAGConfig)
|
||||||
|
compression: CompressionConfig = field(default_factory=CompressionConfig)
|
||||||
|
auth: AuthConfig = field(default_factory=AuthConfig)
|
||||||
|
logging: LoggingConfig = field(default_factory=LoggingConfig)
|
||||||
|
cors_origins: List[str] = field(default_factory=lambda: ["*"])
|
||||||
|
|
||||||
|
|
||||||
|
def _env_override(prefix: str, key: str, default):
|
||||||
|
"""Check for environment variable with prefix"""
|
||||||
|
env_key = f"{prefix}_{key}".upper()
|
||||||
|
value = os.getenv(env_key)
|
||||||
|
if value is not None:
|
||||||
|
# Convert to appropriate type
|
||||||
|
if isinstance(default, bool):
|
||||||
|
return value.lower() in ('true', '1', 'yes', 'on')
|
||||||
|
elif isinstance(default, int):
|
||||||
|
return int(value)
|
||||||
|
elif isinstance(default, float):
|
||||||
|
return float(value)
|
||||||
|
elif isinstance(default, list):
|
||||||
|
return value.split(',')
|
||||||
|
else:
|
||||||
|
return value
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(config_path: str = "/app/config.yaml") -> Config:
|
||||||
|
"""
|
||||||
|
Load configuration from YAML file with environment variable overrides.
|
||||||
|
Priority: env vars > config file > defaults.
|
||||||
|
"""
|
||||||
|
# Start with defaults
|
||||||
|
config_dict = {}
|
||||||
|
|
||||||
|
# Load from config file if exists
|
||||||
|
locations = [config_path, "/app/config.yaml", "./config.yaml"]
|
||||||
|
file_data = {}
|
||||||
|
|
||||||
|
for path in locations:
|
||||||
|
if os.path.exists(path):
|
||||||
|
with open(path, 'r') as f:
|
||||||
|
file_data = yaml.safe_load(f) or {}
|
||||||
|
break
|
||||||
|
|
||||||
|
# Build config with file values as base
|
||||||
|
config_dict.update(file_data)
|
||||||
|
|
||||||
|
# Override with environment variables
|
||||||
|
# Top-level settings
|
||||||
|
config_dict['host'] = _env_override('APP', 'HOST', config_dict.get('host', '0.0.0.0'))
|
||||||
|
config_dict['port'] = _env_override('APP', 'PORT', config_dict.get('port', 8675))
|
||||||
|
config_dict['database_url'] = os.getenv('DATABASE_URL', config_dict.get('database_url', 'sqlite+aiosqlite:///./ai.db'))
|
||||||
|
config_dict['cors_origins'] = os.getenv('CORS_ORIGINS', config_dict.get('cors_origins', ["*"])).split(',') if isinstance(os.getenv('CORS_ORIGINS'), str) else config_dict.get('cors_origins', ["*"])
|
||||||
|
|
||||||
|
# Nested configs
|
||||||
|
rag_dict = file_data.get('rag', {})
|
||||||
|
config_dict['rag'] = {
|
||||||
|
'max_skills': _env_override('RAG', 'MAX_SKILLS', rag_dict.get('max_skills', 3)),
|
||||||
|
'max_conventions': _env_override('RAG', 'MAX_CONVENTIONS', rag_dict.get('max_conventions', 2)),
|
||||||
|
'max_snippets': _env_override('RAG', 'MAX_SNIPPETS', rag_dict.get('max_snippets', 2)),
|
||||||
|
'min_skill_score': _env_override('RAG', 'MIN_SKILL_SCORE', rag_dict.get('min_skill_score', 0.3)),
|
||||||
|
'min_snippet_score': _env_override('RAG', 'MIN_SNIPPET_SCORE', rag_dict.get('min_snippet_score', 0.25)),
|
||||||
|
'embedding_model': _env_override('RAG', 'EMBEDDING_MODEL', rag_dict.get('embedding_model', 'all-MiniLM-L6-v2')),
|
||||||
|
}
|
||||||
|
|
||||||
|
compression_dict = file_data.get('compression', {})
|
||||||
|
config_dict['compression'] = {
|
||||||
|
'enabled': _env_override('COMPRESSION', 'ENABLED', compression_dict.get('enabled', True)),
|
||||||
|
'strategy': _env_override('COMPRESSION', 'STRATEGY', compression_dict.get('strategy', 'extractive')),
|
||||||
|
'keep_last_n': _env_override('COMPRESSION', 'KEEP_LAST_N', compression_dict.get('keep_last_n', 3)),
|
||||||
|
'max_tokens': _env_override('COMPRESSION', 'MAX_TOKENS', compression_dict.get('max_tokens', 2000)),
|
||||||
|
'ollama_model': _env_override('COMPRESSION', 'OLLAMA_MODEL', compression_dict.get('ollama_model', 'phi3:mini')),
|
||||||
|
'ollama_url': _env_override('COMPRESSION', 'OLLAMA_URL', compression_dict.get('ollama_url', 'http://localhost:11434')),
|
||||||
|
}
|
||||||
|
|
||||||
|
auth_dict = file_data.get('auth', {})
|
||||||
|
config_dict['auth'] = {
|
||||||
|
'enabled': _env_override('AUTH', 'ENABLED', auth_dict.get('enabled', False)),
|
||||||
|
'api_key': os.getenv('API_KEY', auth_dict.get('api_key', 'change-me-in-production')),
|
||||||
|
'header_name': _env_override('AUTH', 'HEADER_NAME', auth_dict.get('header_name', 'X-API-Key')),
|
||||||
|
}
|
||||||
|
|
||||||
|
logging_dict = file_data.get('logging', {})
|
||||||
|
config_dict['logging'] = {
|
||||||
|
'level': _env_override('LOGGING', 'LEVEL', logging_dict.get('level', 'INFO')),
|
||||||
|
'format': _env_override('LOGGING', 'FORMAT', logging_dict.get('format', 'json')),
|
||||||
|
}
|
||||||
|
|
||||||
|
return Config(**config_dict)
|
||||||
38
config.yaml
Normal file
38
config.yaml
Normal file
|
|
# AI Skills API Configuration

# Server settings
host: "0.0.0.0"
port: 8675
database_url: "sqlite+aiosqlite:///./ai.db"

# CORS origins (restrict in production)
cors_origins: ["*"]

# RAG (Retrieval Augmented Generation) settings
rag:
  max_skills: 3            # Number of skills to include in context
  max_conventions: 2       # Number of conventions to include
  max_snippets: 2          # Number of code snippets to include
  min_skill_score: 0.3     # Minimum similarity threshold for skills (0-1)
  min_snippet_score: 0.25  # Minimum similarity for snippets (0-1)
  embedding_model: "all-MiniLM-L6-v2"  # Sentence transformer model

# Compression settings
compression:
  enabled: true
  strategy: "extractive"  # "extractive" (sumy), "ollama" (phi-3-mini), or "none"
  keep_last_n: 3          # Number of recent exchanges to keep uncompressed
  max_tokens: 2000        # Target token budget for conversation history
  ollama_model: "phi3:mini"             # Only used if strategy is "ollama"
  ollama_url: "http://localhost:11434"  # Ollama API endpoint

# Authentication (set and forget - simple API key)
auth:
  enabled: false  # Set to true to require API key on all endpoints
  api_key: "change-me-in-production"  # Change this if enabling auth
  header_name: "X-API-Key"

# Logging configuration
logging:
  level: "INFO"
  format: "json"  # "json" for structured logs, "text" for human readable
||||||
|
|
@ -7,6 +7,7 @@ services:
|
||||||
- DATABASE_URL=sqlite+aiosqlite:///./ai.db
|
- DATABASE_URL=sqlite+aiosqlite:///./ai.db
|
||||||
volumes:
|
volumes:
|
||||||
- ./data:/app/data
|
- ./data:/app/data
|
||||||
|
- ./config.yaml:/app/config.yaml:ro
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
|
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
|
||||||
|
|
|
||||||
312
main.py
312
main.py
|
|
@ -1,37 +1,81 @@
|
||||||
from fastapi import FastAPI, HTTPException, Depends, Query
|
import logging
|
||||||
|
from fastapi import FastAPI, HTTPException, Depends, Query, Request, Header
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||||
from sqlalchemy import select, func
|
from sqlalchemy import select
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
import hashlib
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
import sys
|
||||||
|
|
||||||
from database import get_db, init_db
|
from config import load_config, Config
|
||||||
from models import Skill, Snippet, Convention, Cache, Memory
|
from database import init_db as default_init_db
|
||||||
|
from models import Skill, Snippet, Convention, Memory, Base
|
||||||
from schemas import (
|
from schemas import (
|
||||||
SkillBase, Skill,
|
SkillBase,
|
||||||
SnippetBase, Snippet,
|
Skill as SkillSchema,
|
||||||
ConventionBase, Convention,
|
SnippetBase,
|
||||||
CacheStore, Cache as CacheSchema,
|
Snippet as SnippetSchema,
|
||||||
MemoryBase, Memory as MemorySchema,
|
ConventionBase,
|
||||||
ContextBundle, CacheLookup
|
Convention as ConventionSchema,
|
||||||
|
MemoryBase,
|
||||||
|
Memory as MemorySchema,
|
||||||
|
ContextBundle
|
||||||
)
|
)
|
||||||
from semantic_cache import (
|
from rag import build_context_bundle, clear_cache as clear_rag_cache
|
||||||
semantic_cache_lookup,
|
|
||||||
semantic_cache_store,
|
|
||||||
get_cache_stats,
|
|
||||||
clear_old_cache
|
|
||||||
)
|
|
||||||
from rag import build_context_bundle
|
|
||||||
from compression import compress_conversation, count_tokens
|
from compression import compress_conversation, count_tokens
|
||||||
|
|
||||||
# Load configuration (env vars > config file > defaults)
config = load_config()

# Set up logging per config: bare messages for "json" format,
# timestamped text otherwise, always to stdout.
_log_format = (
    '%(message)s'
    if config.logging.format == "json"
    else '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logging.basicConfig(
    level=getattr(logging, config.logging.level.upper()),
    format=_log_format,
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("ai-skills-api")
||||||
|
|
||||||
|
# API Key authentication dependency
async def verify_api_key(
    request: Request,
    x_api_key: Optional[str] = Header(None, alias=config.auth.header_name)
) -> None:
    """Verify the API key header if auth is enabled.

    Raises:
        HTTPException: 401 when no key was supplied, 403 when it does not
            match the configured key.
    """
    if config.auth.enabled:
        if not x_api_key:
            raise HTTPException(status_code=401, detail="Missing API key")
        # Constant-time comparison — plain != leaks key bytes via timing.
        import hmac
        if not hmac.compare_digest(x_api_key.encode(), config.auth.api_key.encode()):
            raise HTTPException(status_code=403, detail="Invalid API key")
|
||||||
|
|
||||||
|
|
||||||
|
# Database setup based on config
|
||||||
|
if config.database_url != "sqlite+aiosqlite:///./ai.db":
|
||||||
|
# Custom DB URL - create new engine
|
||||||
|
engine = create_async_engine(config.database_url, echo=False)
|
||||||
|
AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
|
|
||||||
|
async def get_db() -> AsyncSession:
|
||||||
|
async with AsyncSessionLocal() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
async def init_db():
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
else:
|
||||||
|
# Use default database setup from database.py
|
||||||
|
from database import get_db, init_db
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="AI Skills API",
|
||||||
|
description="Local infrastructure for AI context management"
|
||||||
|
)
|
||||||
|
|
||||||
|
# CORS configuration
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"],
|
allow_origins=config.cors_origins,
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
|
|
@ -41,11 +85,15 @@ app.add_middleware(
|
||||||
@app.on_event("startup")
async def startup():
    """Initialize the database on app startup and log readiness."""
    await init_db()
    logger.info("Application startup complete")
    # Fixed: previously an f-string with no placeholders.
    logger.info("RAG cache will initialize on first request")
    if config.auth.enabled:
        logger.info("API key authentication enabled")
|
||||||
|
|
||||||
|
|
||||||
# ============== SKILLS ==============
|
# ============== SKILLS ==============
|
||||||
|
|
||||||
@app.get("/skills", response_model=list[Skill])
|
@app.get("/skills", response_model=list[SkillSchema], dependencies=[Depends(verify_api_key)])
|
||||||
async def list_skills(
|
async def list_skills(
|
||||||
category: Optional[str] = None,
|
category: Optional[str] = None,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
|
|
@ -57,7 +105,7 @@ async def list_skills(
|
||||||
return result.scalars().all()
|
return result.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
@app.get("/skills/search")
|
@app.get("/skills/search", dependencies=[Depends(verify_api_key)])
|
||||||
async def search_skills(
|
async def search_skills(
|
||||||
q: str,
|
q: str,
|
||||||
category: Optional[str] = None,
|
category: Optional[str] = None,
|
||||||
|
|
@ -66,7 +114,7 @@ async def search_skills(
|
||||||
query = select(Skill).where(
|
query = select(Skill).where(
|
||||||
(Skill.name.ilike(f"%{q}%")) |
|
(Skill.name.ilike(f"%{q}%")) |
|
||||||
(Skill.content.ilike(f"%{q}%")) |
|
(Skill.content.ilike(f"%{q}%")) |
|
||||||
(Skill.tags.ilike(f"%{q}%"))
|
(Skill.tags.astext.ilike(f"%{q}%"))
|
||||||
)
|
)
|
||||||
if category:
|
if category:
|
||||||
query = query.where(Skill.category == category)
|
query = query.where(Skill.category == category)
|
||||||
|
|
@ -74,7 +122,7 @@ async def search_skills(
|
||||||
return result.scalars().all()
|
return result.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
@app.get("/skills/{skill_id}", response_model=Skill)
|
@app.get("/skills/{skill_id}", response_model=SkillSchema, dependencies=[Depends(verify_api_key)])
|
||||||
async def get_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
|
async def get_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
result = await db.execute(select(Skill).where(Skill.id == skill_id))
|
result = await db.execute(select(Skill).where(Skill.id == skill_id))
|
||||||
skill = result.scalar_one_or_none()
|
skill = result.scalar_one_or_none()
|
||||||
|
|
@ -86,7 +134,7 @@ async def get_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
return skill
|
return skill
|
||||||
|
|
||||||
|
|
||||||
@app.post("/skills", response_model=Skill)
|
@app.post("/skills", response_model=SkillSchema, dependencies=[Depends(verify_api_key)])
|
||||||
async def create_skill(skill: SkillBase, db: AsyncSession = Depends(get_db)):
|
async def create_skill(skill: SkillBase, db: AsyncSession = Depends(get_db)):
|
||||||
db_skill = Skill(**skill.model_dump())
|
db_skill = Skill(**skill.model_dump())
|
||||||
db.add(db_skill)
|
db.add(db_skill)
|
||||||
|
|
@ -99,7 +147,7 @@ async def create_skill(skill: SkillBase, db: AsyncSession = Depends(get_db)):
|
||||||
return db_skill
|
return db_skill
|
||||||
|
|
||||||
|
|
||||||
@app.put("/skills/{skill_id}", response_model=Skill)
|
@app.put("/skills/{skill_id}", response_model=SkillSchema, dependencies=[Depends(verify_api_key)])
|
||||||
async def update_skill(skill_id: str, skill: SkillBase, db: AsyncSession = Depends(get_db)):
|
async def update_skill(skill_id: str, skill: SkillBase, db: AsyncSession = Depends(get_db)):
|
||||||
result = await db.execute(select(Skill).where(Skill.id == skill_id))
|
result = await db.execute(select(Skill).where(Skill.id == skill_id))
|
||||||
db_skill = result.scalar_one_or_none()
|
db_skill = result.scalar_one_or_none()
|
||||||
|
|
@ -114,7 +162,7 @@ async def update_skill(skill_id: str, skill: SkillBase, db: AsyncSession = Depen
|
||||||
return db_skill
|
return db_skill
|
||||||
|
|
||||||
|
|
||||||
@app.delete("/skills/{skill_id}")
|
@app.delete("/skills/{skill_id}", dependencies=[Depends(verify_api_key)])
|
||||||
async def delete_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
|
async def delete_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
result = await db.execute(select(Skill).where(Skill.id == skill_id))
|
result = await db.execute(select(Skill).where(Skill.id == skill_id))
|
||||||
skill = result.scalar_one_or_none()
|
skill = result.scalar_one_or_none()
|
||||||
|
|
@ -128,7 +176,7 @@ async def delete_skill(skill_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
|
|
||||||
# ============== SNIPPETS ==============
|
# ============== SNIPPETS ==============
|
||||||
|
|
||||||
@app.get("/snippets", response_model=list[Snippet])
|
@app.get("/snippets", response_model=list[SnippetSchema], dependencies=[Depends(verify_api_key)])
|
||||||
async def list_snippets(
|
async def list_snippets(
|
||||||
category: Optional[str] = None,
|
category: Optional[str] = None,
|
||||||
language: Optional[str] = None,
|
language: Optional[str] = None,
|
||||||
|
|
@ -143,7 +191,7 @@ async def list_snippets(
|
||||||
return result.scalars().all()
|
return result.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
@app.get("/snippets/{snippet_id}", response_model=Snippet)
|
@app.get("/snippets/{snippet_id}", response_model=SnippetSchema, dependencies=[Depends(verify_api_key)])
|
||||||
async def get_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
|
async def get_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
result = await db.execute(select(Snippet).where(Snippet.id == snippet_id))
|
result = await db.execute(select(Snippet).where(Snippet.id == snippet_id))
|
||||||
snippet = result.scalar_one_or_none()
|
snippet = result.scalar_one_or_none()
|
||||||
|
|
@ -152,7 +200,7 @@ async def get_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
return snippet
|
return snippet
|
||||||
|
|
||||||
|
|
||||||
@app.post("/snippets", response_model=Snippet)
|
@app.post("/snippets", response_model=SnippetSchema, dependencies=[Depends(verify_api_key)])
|
||||||
async def create_snippet(snippet: SnippetBase, db: AsyncSession = Depends(get_db)):
|
async def create_snippet(snippet: SnippetBase, db: AsyncSession = Depends(get_db)):
|
||||||
db_snippet = Snippet(**snippet.model_dump())
|
db_snippet = Snippet(**snippet.model_dump())
|
||||||
db.add(db_snippet)
|
db.add(db_snippet)
|
||||||
|
|
@ -165,7 +213,7 @@ async def create_snippet(snippet: SnippetBase, db: AsyncSession = Depends(get_db
|
||||||
return db_snippet
|
return db_snippet
|
||||||
|
|
||||||
|
|
||||||
@app.delete("/snippets/{snippet_id}")
|
@app.delete("/snippets/{snippet_id}", dependencies=[Depends(verify_api_key)])
|
||||||
async def delete_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
|
async def delete_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
result = await db.execute(select(Snippet).where(Snippet.id == snippet_id))
|
result = await db.execute(select(Snippet).where(Snippet.id == snippet_id))
|
||||||
snippet = result.scalar_one_or_none()
|
snippet = result.scalar_one_or_none()
|
||||||
|
|
@ -179,7 +227,7 @@ async def delete_snippet(snippet_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
|
|
||||||
# ============== CONVENTIONS ==============
|
# ============== CONVENTIONS ==============
|
||||||
|
|
||||||
@app.get("/conventions", response_model=list[Convention])
|
@app.get("/conventions", response_model=list[ConventionSchema], dependencies=[Depends(verify_api_key)])
|
||||||
async def list_conventions(
|
async def list_conventions(
|
||||||
project: Optional[str] = None,
|
project: Optional[str] = None,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
|
|
@ -191,7 +239,7 @@ async def list_conventions(
|
||||||
return result.scalars().all()
|
return result.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
@app.get("/conventions/{convention_id}", response_model=Convention)
|
@app.get("/conventions/{convention_id}", response_model=ConventionSchema, dependencies=[Depends(verify_api_key)])
|
||||||
async def get_convention(convention_id: str, db: AsyncSession = Depends(get_db)):
|
async def get_convention(convention_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
result = await db.execute(select(Convention).where(Convention.id == convention_id))
|
result = await db.execute(select(Convention).where(Convention.id == convention_id))
|
||||||
convention = result.scalar_one_or_none()
|
convention = result.scalar_one_or_none()
|
||||||
|
|
@ -200,7 +248,7 @@ async def get_convention(convention_id: str, db: AsyncSession = Depends(get_db))
|
||||||
return convention
|
return convention
|
||||||
|
|
||||||
|
|
||||||
@app.post("/conventions", response_model=Convention)
|
@app.post("/conventions", response_model=ConventionSchema, dependencies=[Depends(verify_api_key)])
|
||||||
async def create_convention(convention: ConventionBase, db: AsyncSession = Depends(get_db)):
|
async def create_convention(convention: ConventionBase, db: AsyncSession = Depends(get_db)):
|
||||||
db_convention = Convention(**convention.model_dump())
|
db_convention = Convention(**convention.model_dump())
|
||||||
db.add(db_convention)
|
db.add(db_convention)
|
||||||
|
|
@ -213,7 +261,7 @@ async def create_convention(convention: ConventionBase, db: AsyncSession = Depen
|
||||||
return db_convention
|
return db_convention
|
||||||
|
|
||||||
|
|
||||||
@app.put("/conventions/{convention_id}", response_model=Convention)
|
@app.put("/conventions/{convention_id}", response_model=ConventionSchema, dependencies=[Depends(verify_api_key)])
|
||||||
async def update_convention(convention_id: str, convention: ConventionBase, db: AsyncSession = Depends(get_db)):
|
async def update_convention(convention_id: str, convention: ConventionBase, db: AsyncSession = Depends(get_db)):
|
||||||
result = await db.execute(select(Convention).where(Convention.id == convention_id))
|
result = await db.execute(select(Convention).where(Convention.id == convention_id))
|
||||||
db_convention = result.scalar_one_or_none()
|
db_convention = result.scalar_one_or_none()
|
||||||
|
|
@ -228,7 +276,7 @@ async def update_convention(convention_id: str, convention: ConventionBase, db:
|
||||||
return db_convention
|
return db_convention
|
||||||
|
|
||||||
|
|
||||||
@app.delete("/conventions/{convention_id}")
|
@app.delete("/conventions/{convention_id}", dependencies=[Depends(verify_api_key)])
|
||||||
async def delete_convention(convention_id: str, db: AsyncSession = Depends(get_db)):
|
async def delete_convention(convention_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
result = await db.execute(select(Convention).where(Convention.id == convention_id))
|
result = await db.execute(select(Convention).where(Convention.id == convention_id))
|
||||||
convention = result.scalar_one_or_none()
|
convention = result.scalar_one_or_none()
|
||||||
|
|
@ -240,107 +288,9 @@ async def delete_convention(convention_id: str, db: AsyncSession = Depends(get_d
|
||||||
return {"deleted": convention_id}
|
return {"deleted": convention_id}
|
||||||
|
|
||||||
|
|
||||||
# ============== CACHE ==============
|
|
||||||
|
|
||||||
@app.post("/cache/lookup", response_model=Optional[CacheSchema])
|
|
||||||
async def lookup_cache(lookup: CacheLookup, db: AsyncSession = Depends(get_db)):
|
|
||||||
"""Exact hash-based cache lookup"""
|
|
||||||
prompt_hash = hashlib.sha256(
|
|
||||||
json.dumps({"prompt": lookup.prompt, "model": lookup.model}, sort_keys=True).encode()
|
|
||||||
).hexdigest()
|
|
||||||
|
|
||||||
result = await db.execute(
|
|
||||||
select(Cache).where(
|
|
||||||
(Cache.hash == prompt_hash) &
|
|
||||||
((Cache.expires_at == None) | (Cache.expires_at > func.now()))
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return result.scalar_one_or_none()
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/cache/semantic-lookup", response_model=dict)
|
|
||||||
async def semantic_lookup(
|
|
||||||
prompt: str,
|
|
||||||
model: Optional[str] = None,
|
|
||||||
min_similarity: float = 0.85,
|
|
||||||
db: AsyncSession = Depends(get_db)
|
|
||||||
):
|
|
||||||
"""Semantic cache lookup - finds similar prompts"""
|
|
||||||
result = await semantic_cache_lookup(
|
|
||||||
prompt, db, model=model, min_similarity=min_similarity
|
|
||||||
)
|
|
||||||
if result:
|
|
||||||
return {"hit": True, **result}
|
|
||||||
return {"hit": False}
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/cache/store", response_model=CacheSchema)
|
|
||||||
async def store_cache(cache: CacheStore, db: AsyncSession = Depends(get_db)):
|
|
||||||
"""Store in exact-match cache"""
|
|
||||||
prompt_hash = hashlib.sha256(
|
|
||||||
json.dumps({"prompt": cache.response, "model": cache.model}, sort_keys=True).encode()
|
|
||||||
).hexdigest()
|
|
||||||
|
|
||||||
db_cache = Cache(
|
|
||||||
hash=prompt_hash,
|
|
||||||
response=cache.response,
|
|
||||||
model=cache.model,
|
|
||||||
tokens_in=cache.tokens_in,
|
|
||||||
tokens_out=cache.tokens_out,
|
|
||||||
expires_at=cache.expires_at
|
|
||||||
)
|
|
||||||
db.add(db_cache)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(db_cache)
|
|
||||||
return db_cache
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/cache/semantic-store", response_model=dict)
|
|
||||||
async def semantic_store(
|
|
||||||
prompt: str,
|
|
||||||
response: str,
|
|
||||||
model: Optional[str] = None,
|
|
||||||
tokens_in: Optional[int] = None,
|
|
||||||
tokens_out: Optional[int] = None,
|
|
||||||
db: AsyncSession = Depends(get_db)
|
|
||||||
):
|
|
||||||
"""Store in semantic cache"""
|
|
||||||
return await semantic_cache_store(
|
|
||||||
prompt, response, db, model, tokens_in, tokens_out
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.delete("/cache/{cache_hash}")
|
|
||||||
async def delete_cache(cache_hash: str, db: AsyncSession = Depends(get_db)):
|
|
||||||
result = await db.execute(select(Cache).where(Cache.hash == cache_hash))
|
|
||||||
cache = result.scalar_one_or_none()
|
|
||||||
if not cache:
|
|
||||||
raise HTTPException(status_code=404, detail="Cache entry not found")
|
|
||||||
|
|
||||||
await db.delete(cache)
|
|
||||||
await db.commit()
|
|
||||||
return {"deleted": cache_hash}
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/cache/stats")
|
|
||||||
async def cache_stats_endpoint(db: AsyncSession = Depends(get_db)):
|
|
||||||
"""Get cache statistics"""
|
|
||||||
return await get_cache_stats(db)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/cache/clear-old")
|
|
||||||
async def clear_old(
|
|
||||||
older_than_hours: int = 168,
|
|
||||||
db: AsyncSession = Depends(get_db)
|
|
||||||
):
|
|
||||||
"""Clear cache entries older than threshold"""
|
|
||||||
deleted = await clear_old_cache(db, older_than_hours)
|
|
||||||
return {"deleted": deleted}
|
|
||||||
|
|
||||||
|
|
||||||
# ============== MEMORY ==============
|
# ============== MEMORY ==============
|
||||||
|
|
||||||
@app.get("/memory", response_model=list[MemorySchema])
|
@app.get("/memory", response_model=list[MemorySchema], dependencies=[Depends(verify_api_key)])
|
||||||
async def list_memory(
|
async def list_memory(
|
||||||
project: Optional[str] = None,
|
project: Optional[str] = None,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
|
|
@ -352,7 +302,7 @@ async def list_memory(
|
||||||
return result.scalars().all()
|
return result.scalars().all()
|
||||||
|
|
||||||
|
|
||||||
@app.get("/memory/{memory_id}", response_model=MemorySchema)
|
@app.get("/memory/{memory_id}", response_model=MemorySchema, dependencies=[Depends(verify_api_key)])
|
||||||
async def get_memory(memory_id: str, db: AsyncSession = Depends(get_db)):
|
async def get_memory(memory_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
result = await db.execute(select(Memory).where(Memory.id == memory_id))
|
result = await db.execute(select(Memory).where(Memory.id == memory_id))
|
||||||
memory = result.scalar_one_or_none()
|
memory = result.scalar_one_or_none()
|
||||||
|
|
@ -361,7 +311,7 @@ async def get_memory(memory_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
return memory
|
return memory
|
||||||
|
|
||||||
|
|
||||||
@app.post("/memory", response_model=MemorySchema)
|
@app.post("/memory", response_model=MemorySchema, dependencies=[Depends(verify_api_key)])
|
||||||
async def create_memory(memory: MemoryBase, db: AsyncSession = Depends(get_db)):
|
async def create_memory(memory: MemoryBase, db: AsyncSession = Depends(get_db)):
|
||||||
db_memory = Memory(**memory.model_dump())
|
db_memory = Memory(**memory.model_dump())
|
||||||
db.add(db_memory)
|
db.add(db_memory)
|
||||||
|
|
@ -374,7 +324,7 @@ async def create_memory(memory: MemoryBase, db: AsyncSession = Depends(get_db)):
|
||||||
return db_memory
|
return db_memory
|
||||||
|
|
||||||
|
|
||||||
@app.put("/memory/{memory_id}", response_model=MemorySchema)
|
@app.put("/memory/{memory_id}", response_model=MemorySchema, dependencies=[Depends(verify_api_key)])
|
||||||
async def update_memory(memory_id: str, memory: MemoryBase, db: AsyncSession = Depends(get_db)):
|
async def update_memory(memory_id: str, memory: MemoryBase, db: AsyncSession = Depends(get_db)):
|
||||||
result = await db.execute(select(Memory).where(Memory.id == memory_id))
|
result = await db.execute(select(Memory).where(Memory.id == memory_id))
|
||||||
db_memory = result.scalar_one_or_none()
|
db_memory = result.scalar_one_or_none()
|
||||||
|
|
@ -389,7 +339,7 @@ async def update_memory(memory_id: str, memory: MemoryBase, db: AsyncSession = D
|
||||||
return db_memory
|
return db_memory
|
||||||
|
|
||||||
|
|
||||||
@app.delete("/memory/{memory_id}")
|
@app.delete("/memory/{memory_id}", dependencies=[Depends(verify_api_key)])
|
||||||
async def delete_memory(memory_id: str, db: AsyncSession = Depends(get_db)):
|
async def delete_memory(memory_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
result = await db.execute(select(Memory).where(Memory.id == memory_id))
|
result = await db.execute(select(Memory).where(Memory.id == memory_id))
|
||||||
memory = result.scalar_one_or_none()
|
memory = result.scalar_one_or_none()
|
||||||
|
|
@ -403,7 +353,7 @@ async def delete_memory(memory_id: str, db: AsyncSession = Depends(get_db)):
|
||||||
|
|
||||||
# ============== CONTEXT BUNDLE ==============
|
# ============== CONTEXT BUNDLE ==============
|
||||||
|
|
||||||
@app.get("/context", response_model=ContextBundle)
|
@app.get("/context", response_model=ContextBundle, dependencies=[Depends(verify_api_key)])
|
||||||
async def get_context(
|
async def get_context(
|
||||||
project: Optional[str] = None,
|
project: Optional[str] = None,
|
||||||
skills: Optional[str] = Query(None, description="Comma-separated skill IDs to include"),
|
skills: Optional[str] = Query(None, description="Comma-separated skill IDs to include"),
|
||||||
|
|
@ -438,19 +388,29 @@ async def get_context(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/context/rag")
|
@app.get("/context/rag", dependencies=[Depends(verify_api_key)])
|
||||||
async def get_context_rag(
|
async def get_context_rag(
|
||||||
query: str,
|
query: str,
|
||||||
project: Optional[str] = None,
|
project: Optional[str] = None,
|
||||||
max_skills: int = 3,
|
max_skills: Optional[int] = None,
|
||||||
max_conventions: int = 2,
|
max_conventions: Optional[int] = None,
|
||||||
max_snippets: int = 2,
|
max_snippets: Optional[int] = None,
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
RAG-based context selection - returns ONLY relevant items.
|
RAG-based context selection - returns ONLY relevant items.
|
||||||
Uses semantic search to find top K most relevant skills/snippets.
|
Uses semantic search to find top K most relevant skills/snippets.
|
||||||
|
|
||||||
|
Uses config defaults if parameters not provided.
|
||||||
"""
|
"""
|
||||||
|
# Use config defaults if not specified
|
||||||
|
if max_skills is None:
|
||||||
|
max_skills = config.rag.max_skills
|
||||||
|
if max_conventions is None:
|
||||||
|
max_conventions = config.rag.max_conventions
|
||||||
|
if max_snippets is None:
|
||||||
|
max_snippets = config.rag.max_snippets
|
||||||
|
|
||||||
bundle = await build_context_bundle(
|
bundle = await build_context_bundle(
|
||||||
query, db, project,
|
query, db, project,
|
||||||
max_skills=max_skills,
|
max_skills=max_skills,
|
||||||
|
|
@ -460,17 +420,28 @@ async def get_context_rag(
|
||||||
return bundle
|
return bundle
|
||||||
|
|
||||||
|
|
||||||
@app.post("/compress")
|
@app.post("/compress", dependencies=[Depends(verify_api_key)])
|
||||||
async def compress_messages(
|
async def compress_messages_endpoint(
|
||||||
messages: List[dict],
|
messages: List[dict],
|
||||||
keep_last_n: int = 3,
|
keep_last_n: Optional[int] = None,
|
||||||
max_tokens: int = 2000
|
max_tokens: Optional[int] = None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Compress conversation history.
|
Compress conversation history.
|
||||||
Keeps last N exchanges in full, summarizes everything before.
|
Uses config strategy (extractive or ollama).
|
||||||
"""
|
"""
|
||||||
compressed = compress_conversation(messages, max_tokens, keep_last_n)
|
# Use config defaults if not specified
|
||||||
|
keep_last_n = keep_last_n or config.compression.keep_last_n
|
||||||
|
max_tokens = max_tokens or config.compression.max_tokens
|
||||||
|
|
||||||
|
compressed = await compress_conversation(
|
||||||
|
messages,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
keep_last_n=keep_last_n,
|
||||||
|
strategy=config.compression.strategy,
|
||||||
|
ollama_model=config.compression.ollama_model,
|
||||||
|
ollama_url=config.compression.ollama_url
|
||||||
|
)
|
||||||
original_tokens = sum(count_tokens(m.get("content", "")) for m in messages)
|
original_tokens = sum(count_tokens(m.get("content", "")) for m in messages)
|
||||||
compressed_tokens = sum(count_tokens(m.get("content", "")) for m in compressed)
|
compressed_tokens = sum(count_tokens(m.get("content", "")) for m in compressed)
|
||||||
|
|
||||||
|
|
@ -483,7 +454,7 @@ async def compress_messages(
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/tokens/count")
|
@app.get("/tokens/count", dependencies=[Depends(verify_api_key)])
|
||||||
async def count_tokens_endpoint(text: str):
|
async def count_tokens_endpoint(text: str):
|
||||||
"""Count tokens in text"""
|
"""Count tokens in text"""
|
||||||
return {"tokens": count_tokens(text)}
|
return {"tokens": count_tokens(text)}
|
||||||
|
|
@ -492,3 +463,36 @@ async def count_tokens_endpoint(text: str):
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health():
|
async def health():
|
||||||
return {"status": "healthy"}
|
return {"status": "healthy"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/config")
|
||||||
|
async def get_config():
|
||||||
|
"""Return current configuration (non-sensitive)"""
|
||||||
|
return {
|
||||||
|
"rag": {
|
||||||
|
"max_skills": config.rag.max_skills,
|
||||||
|
"max_conventions": config.rag.max_conventions,
|
||||||
|
"max_snippets": config.rag.max_snippets,
|
||||||
|
"min_skill_score": config.rag.min_skill_score,
|
||||||
|
"min_snippet_score": config.rag.min_snippet_score,
|
||||||
|
"embedding_model": config.rag.embedding_model
|
||||||
|
},
|
||||||
|
"compression": {
|
||||||
|
"enabled": config.compression.enabled,
|
||||||
|
"strategy": config.compression.strategy,
|
||||||
|
"keep_last_n": config.compression.keep_last_n,
|
||||||
|
"max_tokens": config.compression.max_tokens,
|
||||||
|
"ollama_model": config.compression.ollama_model if config.compression.strategy == "ollama" else None
|
||||||
|
},
|
||||||
|
"auth": {
|
||||||
|
"enabled": config.auth.enabled
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/admin/clear-cache", dependencies=[Depends(verify_api_key)])
|
||||||
|
async def clear_cache():
|
||||||
|
"""Clear RAG embedding cache (forces reload on next request)"""
|
||||||
|
clear_rag_cache()
|
||||||
|
return {"status": "cache cleared"}
|
||||||
|
|
||||||
|
|
|
||||||
17
models.py
17
models.py
|
|
@ -1,4 +1,4 @@
|
||||||
from sqlalchemy import Column, String, Text, DateTime, Boolean, Integer, ForeignKey
|
from sqlalchemy import Column, String, Text, DateTime, Boolean, Integer, ForeignKey, JSON
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
from database import Base
|
from database import Base
|
||||||
|
|
@ -12,7 +12,7 @@ class Skill(Base):
|
||||||
description = Column(Text)
|
description = Column(Text)
|
||||||
category = Column(String)
|
category = Column(String)
|
||||||
content = Column(Text, nullable=False)
|
content = Column(Text, nullable=False)
|
||||||
tags = Column(String)
|
tags = Column(JSON, default=list)
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
||||||
usage_count = Column(Integer, default=0)
|
usage_count = Column(Integer, default=0)
|
||||||
|
|
@ -26,7 +26,7 @@ class Snippet(Base):
|
||||||
language = Column(String)
|
language = Column(String)
|
||||||
content = Column(Text, nullable=False)
|
content = Column(Text, nullable=False)
|
||||||
category = Column(String)
|
category = Column(String)
|
||||||
tags = Column(String)
|
tags = Column(JSON, default=list)
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -41,17 +41,6 @@ class Convention(Base):
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
|
|
||||||
class Cache(Base):
|
|
||||||
__tablename__ = "cache"
|
|
||||||
|
|
||||||
hash = Column(String, primary_key=True)
|
|
||||||
response = Column(Text, nullable=False)
|
|
||||||
model = Column(String)
|
|
||||||
tokens_in = Column(Integer)
|
|
||||||
tokens_out = Column(Integer)
|
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
|
||||||
expires_at = Column(DateTime(timezone=True))
|
|
||||||
|
|
||||||
|
|
||||||
class Memory(Base):
|
class Memory(Base):
|
||||||
__tablename__ = "memory"
|
__tablename__ = "memory"
|
||||||
|
|
|
||||||
102
rag.py
102
rag.py
|
|
@ -1,6 +1,6 @@
|
||||||
"""
|
"""
|
||||||
RAG-based context selection using local embeddings.
|
RAG-based context selection using local embeddings.
|
||||||
No external API calls - runs entirely on your home server.
|
Optimized with in-memory embedding cache to avoid recomputation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -8,12 +8,21 @@ from sentence_transformers import SentenceTransformer
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from typing import List, Dict, Optional
|
from typing import List, Dict, Optional
|
||||||
import os
|
import asyncio
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
# Small, fast model - ~100MB, runs on CPU
|
# Small, fast model - ~100MB, runs on CPU
|
||||||
MODEL_NAME = "all-MiniLM-L6-v2"
|
MODEL_NAME = "all-MiniLM-L6-v2"
|
||||||
_model: Optional[SentenceTransformer] = None
|
_model: Optional[SentenceTransformer] = None
|
||||||
|
|
||||||
|
# In-memory embedding cache
|
||||||
|
_embedding_cache: Dict[str, Dict] = {
|
||||||
|
"skills": {},
|
||||||
|
"snippets": {},
|
||||||
|
"initialized": False
|
||||||
|
}
|
||||||
|
_cache_lock = Lock()
|
||||||
|
|
||||||
|
|
||||||
def get_model() -> SentenceTransformer:
|
def get_model() -> SentenceTransformer:
|
||||||
"""Lazy-load the embedding model"""
|
"""Lazy-load the embedding model"""
|
||||||
|
|
@ -34,6 +43,51 @@ def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
||||||
return float(np.dot(a, b))
|
return float(np.dot(a, b))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_skill_text(skill) -> str:
|
||||||
|
"""Generate searchable text for a skill"""
|
||||||
|
return f"{skill.name} {skill.description or ''} {skill.content[:500]}"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_snippet_text(snippet) -> str:
|
||||||
|
"""Generate searchable text for a snippet"""
|
||||||
|
return f"{snippet.name} {snippet.content}"
|
||||||
|
|
||||||
|
|
||||||
|
async def ensure_cache_initialized(db: AsyncSession):
|
||||||
|
"""Load all skills and snippets into memory with their embeddings"""
|
||||||
|
global _embedding_cache
|
||||||
|
|
||||||
|
with _cache_lock:
|
||||||
|
if _embedding_cache["initialized"]:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Load skills
|
||||||
|
result = await db.execute(select(Skill))
|
||||||
|
skills = result.scalars().all()
|
||||||
|
|
||||||
|
for skill in skills:
|
||||||
|
text = _get_skill_text(skill)
|
||||||
|
embedding = embed_text(text)
|
||||||
|
_embedding_cache["skills"][skill.id] = {
|
||||||
|
"embedding": embedding,
|
||||||
|
"skill": skill
|
||||||
|
}
|
||||||
|
|
||||||
|
# Load snippets
|
||||||
|
result = await db.execute(select(Snippet))
|
||||||
|
snippets = result.scalars().all()
|
||||||
|
|
||||||
|
for snippet in snippets:
|
||||||
|
text = _get_snippet_text(snippet)
|
||||||
|
embedding = embed_text(text)
|
||||||
|
_embedding_cache["snippets"][snippet.id] = {
|
||||||
|
"embedding": embedding,
|
||||||
|
"snippet": snippet
|
||||||
|
}
|
||||||
|
|
||||||
|
_embedding_cache["initialized"] = True
|
||||||
|
|
||||||
|
|
||||||
async def select_relevant_skills(
|
async def select_relevant_skills(
|
||||||
query: str,
|
query: str,
|
||||||
db: AsyncSession,
|
db: AsyncSession,
|
||||||
|
|
@ -42,27 +96,24 @@ async def select_relevant_skills(
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
Find most relevant skills using semantic search.
|
Find most relevant skills using semantic search.
|
||||||
Only returns skills above minimum similarity threshold.
|
Uses cached embeddings for O(1) retrieval after initial load.
|
||||||
"""
|
"""
|
||||||
from models import Skill
|
from models import Skill
|
||||||
|
|
||||||
# Get all skills (for small datasets, load all - fine for <1000 items)
|
await ensure_cache_initialized(db)
|
||||||
result = await db.execute(select(Skill))
|
|
||||||
skills = result.scalars().all()
|
|
||||||
|
|
||||||
if not skills:
|
if not _embedding_cache["skills"]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Generate query embedding
|
# Generate query embedding
|
||||||
query_embedding = embed_text(query)
|
query_embedding = embed_text(query)
|
||||||
|
|
||||||
# Score each skill
|
# Score each cached skill
|
||||||
scored = []
|
scored = []
|
||||||
for skill in skills:
|
for skill_id, cached in _embedding_cache["skills"].items():
|
||||||
# Use cached embedding if available, else compute
|
skill = cached["skill"]
|
||||||
skill_text = f"{skill.name} {skill.description or ''} {skill.content[:500]}"
|
embedding = cached["embedding"]
|
||||||
skill_embedding = embed_text(skill_text)
|
score = cosine_similarity(query_embedding, embedding)
|
||||||
score = cosine_similarity(query_embedding, skill_embedding)
|
|
||||||
|
|
||||||
if score >= min_score:
|
if score >= min_score:
|
||||||
scored.append((score, skill))
|
scored.append((score, skill))
|
||||||
|
|
@ -90,6 +141,7 @@ async def select_relevant_conventions(
|
||||||
"""
|
"""
|
||||||
Get conventions for a project.
|
Get conventions for a project.
|
||||||
Exact match on project_path, plus fuzzy match on parent paths.
|
Exact match on project_path, plus fuzzy match on parent paths.
|
||||||
|
(No embeddings - small dataset, exact matching is fine)
|
||||||
"""
|
"""
|
||||||
from models import Convention
|
from models import Convention
|
||||||
|
|
||||||
|
|
@ -128,25 +180,24 @@ async def select_relevant_snippets(
|
||||||
top_k: int = 2,
|
top_k: int = 2,
|
||||||
language: Optional[str] = None
|
language: Optional[str] = None
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""Find relevant code snippets"""
|
"""Find relevant code snippets using cached embeddings"""
|
||||||
from models import Snippet
|
from models import Snippet
|
||||||
|
|
||||||
result = await db.execute(select(Snippet))
|
await ensure_cache_initialized(db)
|
||||||
snippets = result.scalars().all()
|
|
||||||
|
|
||||||
if not snippets:
|
if not _embedding_cache["snippets"]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
query_embedding = embed_text(query)
|
query_embedding = embed_text(query)
|
||||||
|
|
||||||
scored = []
|
scored = []
|
||||||
for snippet in snippets:
|
for snippet_id, cached in _embedding_cache["snippets"].items():
|
||||||
|
snippet = cached["snippet"]
|
||||||
if language and snippet.language != language:
|
if language and snippet.language != language:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
snippet_text = f"{snippet.name} {snippet.content}"
|
embedding = cached["embedding"]
|
||||||
snippet_embedding = embed_text(snippet_text)
|
score = cosine_similarity(query_embedding, embedding)
|
||||||
score = cosine_similarity(query_embedding, snippet_embedding)
|
|
||||||
|
|
||||||
if score >= 0.25: # Lower threshold for snippets
|
if score >= 0.25: # Lower threshold for snippets
|
||||||
scored.append((score, snippet))
|
scored.append((score, snippet))
|
||||||
|
|
@ -202,4 +253,11 @@ async def build_context_bundle(
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
import asyncio
|
def clear_cache():
|
||||||
|
"""Clear embedding cache (useful for testing)"""
|
||||||
|
global _embedding_cache
|
||||||
|
_embedding_cache = {
|
||||||
|
"skills": {},
|
||||||
|
"snippets": {},
|
||||||
|
"initialized": False
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,3 +7,6 @@ aiosqlite==0.19.0
|
||||||
sentence-transformers==2.3.1
|
sentence-transformers==2.3.1
|
||||||
numpy==1.26.3
|
numpy==1.26.3
|
||||||
tiktoken==0.5.2
|
tiktoken==0.5.2
|
||||||
|
pyyaml==6.0
|
||||||
|
httpx==0.27.0
|
||||||
|
sumy==0.11.0
|
||||||
|
|
|
||||||
25
schemas.py
25
schemas.py
|
|
@ -52,26 +52,6 @@ class Convention(ConventionBase):
|
||||||
from_attributes = True
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
class CacheBase(BaseModel):
|
|
||||||
hash: str
|
|
||||||
response: str
|
|
||||||
model: Optional[str] = None
|
|
||||||
tokens_in: Optional[int] = None
|
|
||||||
tokens_out: Optional[int] = None
|
|
||||||
expires_at: Optional[datetime] = None
|
|
||||||
|
|
||||||
|
|
||||||
class CacheStore(CacheBase):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Cache(CacheBase):
|
|
||||||
created_at: datetime
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
from_attributes = True
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryBase(BaseModel):
|
class MemoryBase(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
project: str
|
project: str
|
||||||
|
|
@ -92,8 +72,3 @@ class ContextBundle(BaseModel):
|
||||||
snippets: List[Snippet]
|
snippets: List[Snippet]
|
||||||
conventions: List[Convention]
|
conventions: List[Convention]
|
||||||
memories: List[Memory]
|
memories: List[Memory]
|
||||||
|
|
||||||
|
|
||||||
class CacheLookup(BaseModel):
|
|
||||||
prompt: str
|
|
||||||
model: Optional[str] = None
|
|
||||||
|
|
|
||||||
|
|
@ -1,169 +0,0 @@
|
||||||
"""
|
|
||||||
Semantic cache - matches similar prompts, not just exact hashes.
|
|
||||||
Uses embeddings to find similar questions and return cached responses.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from sqlalchemy import select, func
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Optional, Dict, List
|
|
||||||
import json
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
from rag import embed_text, cosine_similarity
|
|
||||||
from compression import count_tokens
|
|
||||||
|
|
||||||
|
|
||||||
async def semantic_cache_lookup(
|
|
||||||
prompt: str,
|
|
||||||
db: AsyncSession,
|
|
||||||
model: Optional[str] = None,
|
|
||||||
min_similarity: float = 0.85,
|
|
||||||
max_age_hours: int = 168 # 1 week
|
|
||||||
) -> Optional[Dict]:
|
|
||||||
"""
|
|
||||||
Find cached responses for semantically similar prompts.
|
|
||||||
|
|
||||||
Returns cached response if similarity >= threshold and not expired.
|
|
||||||
"""
|
|
||||||
from models import Cache
|
|
||||||
|
|
||||||
# Generate embedding for the query
|
|
||||||
query_embedding = embed_text(prompt)
|
|
||||||
|
|
||||||
# Get non-expired cache entries
|
|
||||||
expiry = datetime.now() - timedelta(hours=max_age_hours)
|
|
||||||
result = await db.execute(
|
|
||||||
select(Cache)
|
|
||||||
.where(
|
|
||||||
(Cache.expires_at == None) | (Cache.expires_at > datetime.now())
|
|
||||||
)
|
|
||||||
.where(Cache.created_at > expiry)
|
|
||||||
)
|
|
||||||
cache_entries = result.scalars().all()
|
|
||||||
|
|
||||||
if not cache_entries:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Score each cache entry
|
|
||||||
best_match = None
|
|
||||||
best_score = 0
|
|
||||||
|
|
||||||
for entry in cache_entries:
|
|
||||||
# Skip if model doesn't match (optional)
|
|
||||||
if model and entry.model and entry.model != model:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Compute similarity
|
|
||||||
entry_embedding = embed_text(entry.response) # Or store prompt embedding
|
|
||||||
score = cosine_similarity(query_embedding, entry_embedding)
|
|
||||||
|
|
||||||
if score >= min_similarity and score > best_score:
|
|
||||||
best_score = score
|
|
||||||
best_match = entry
|
|
||||||
|
|
||||||
if best_match:
|
|
||||||
return {
|
|
||||||
"response": best_match.response,
|
|
||||||
"similarity": best_score,
|
|
||||||
"model": best_match.model,
|
|
||||||
"tokens_saved": (best_match.tokens_in or 0) + (best_match.tokens_out or 0),
|
|
||||||
"cached_at": best_match.created_at
|
|
||||||
}
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def semantic_cache_store(
|
|
||||||
prompt: str,
|
|
||||||
response: str,
|
|
||||||
db: AsyncSession,
|
|
||||||
model: Optional[str] = None,
|
|
||||||
tokens_in: Optional[int] = None,
|
|
||||||
tokens_out: Optional[int] = None,
|
|
||||||
ttl_hours: Optional[int] = None
|
|
||||||
) -> Dict:
|
|
||||||
"""
|
|
||||||
Store response in cache with embedding for semantic matching.
|
|
||||||
"""
|
|
||||||
from models import Cache
|
|
||||||
|
|
||||||
# Generate hash for deduplication
|
|
||||||
prompt_hash = hashlib.sha256(
|
|
||||||
json.dumps({"prompt": prompt, "model": model}, sort_keys=True).encode()
|
|
||||||
).hexdigest()
|
|
||||||
|
|
||||||
# Check if exact match already exists
|
|
||||||
existing = await db.execute(
|
|
||||||
select(Cache).where(Cache.hash == prompt_hash)
|
|
||||||
)
|
|
||||||
if existing.scalar_one_or_none():
|
|
||||||
return {"status": "exists", "hash": prompt_hash}
|
|
||||||
|
|
||||||
# Create new entry
|
|
||||||
expires_at = None
|
|
||||||
if ttl_hours:
|
|
||||||
expires_at = datetime.now() + timedelta(hours=ttl_hours)
|
|
||||||
|
|
||||||
new_entry = Cache(
|
|
||||||
hash=prompt_hash,
|
|
||||||
response=response,
|
|
||||||
model=model,
|
|
||||||
tokens_in=tokens_in,
|
|
||||||
tokens_out=tokens_out,
|
|
||||||
expires_at=expires_at
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add(new_entry)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "stored",
|
|
||||||
"hash": prompt_hash,
|
|
||||||
"tokens_stored": (tokens_in or 0) + (tokens_out or 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def get_cache_stats(db: AsyncSession) -> Dict:
|
|
||||||
"""Get cache statistics"""
|
|
||||||
from models import Cache
|
|
||||||
|
|
||||||
result = await db.execute(select(Cache))
|
|
||||||
entries = result.scalars().all()
|
|
||||||
|
|
||||||
now = datetime.now()
|
|
||||||
valid_entries = [
|
|
||||||
e for e in entries
|
|
||||||
if e.expires_at is None or e.expires_at > now
|
|
||||||
]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total_entries": len(entries),
|
|
||||||
"valid_entries": len(valid_entries),
|
|
||||||
"total_tokens_stored": sum(
|
|
||||||
(e.tokens_in or 0) + (e.tokens_out or 0) for e in valid_entries
|
|
||||||
),
|
|
||||||
"models_used": list(set(e.model for e in entries if e.model))
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def clear_old_cache(
|
|
||||||
db: AsyncSession,
|
|
||||||
older_than_hours: int = 168
|
|
||||||
) -> int:
|
|
||||||
"""Delete cache entries older than threshold"""
|
|
||||||
from models import Cache
|
|
||||||
|
|
||||||
cutoff = datetime.now() - timedelta(hours=older_than_hours)
|
|
||||||
result = await db.execute(
|
|
||||||
select(Cache).where(Cache.created_at < cutoff)
|
|
||||||
)
|
|
||||||
old_entries = result.scalars().all()
|
|
||||||
|
|
||||||
for entry in old_entries:
|
|
||||||
await db.delete(entry)
|
|
||||||
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
return len(old_entries)
|
|
||||||
Loading…
Add table
Reference in a new issue