ai-skills-api/config.py

144 lines
5.5 KiB
Python

"""
Configuration management for AI Skills API.
Supports YAML config file with sensible defaults.
Priority: env vars > config file > defaults
"""
import os
from dataclasses import dataclass, field
from typing import List, Optional
import yaml
@dataclass
class RAGConfig:
    """Settings for the RAG section of the configuration.

    The defaults here match the fallback values that ``load_config`` uses
    for the ``rag`` section of the YAML file.
    """

    # Count limits (presumably per-query retrieval caps; the consumers of
    # these values live outside this module -- confirm there).
    max_skills: int = 3
    max_conventions: int = 2
    max_snippets: int = 2

    # Score thresholds (presumably minimum similarity needed to keep a
    # result -- confirm against the retrieval code).
    min_skill_score: float = 0.3
    min_snippet_score: float = 0.25

    # Name of the embedding model.
    embedding_model: str = "all-MiniLM-L6-v2"
@dataclass
class CompressionConfig:
    """Settings for the compression section of the configuration.

    The defaults here match the fallback values that ``load_config`` uses
    for the ``compression`` section of the YAML file.
    """

    # Master switch for compression.
    enabled: bool = True

    # One of "extractive", "ollama", or "none".
    strategy: str = "extractive"

    # Presumably the number of most recent items left uncompressed and the
    # token budget -- confirm against the compression code.
    keep_last_n: int = 3
    max_tokens: int = 2000

    # Ollama model name and server URL (presumably only consulted when
    # strategy is "ollama" -- confirm against the compression code).
    ollama_model: str = "phi3:mini"
    ollama_url: str = "http://localhost:11434"
@dataclass
class AuthConfig:
    """Settings for the authentication section of the configuration.

    The defaults here match the fallback values that ``load_config`` uses
    for the ``auth`` section of the YAML file.
    """

    # Set to True to require API keys; off by default.
    enabled: bool = False

    # Single shared key for simplicity. The default is a placeholder and
    # should be replaced in any real deployment.
    api_key: str = "change-me-in-production"

    # Header name associated with the API key (presumably the request
    # header the key is read from -- confirm against the auth code).
    header_name: str = "X-API-Key"
@dataclass
class LoggingConfig:
    """Settings for the logging section of the configuration.

    The defaults here match the fallback values that ``load_config`` uses
    for the ``logging`` section of the YAML file.
    """

    # Log level name.
    level: str = "INFO"

    # Output format: "json" or "text".
    format: str = "json"
@dataclass
class Config:
    """Root configuration object for the application.

    Holds the top-level settings plus one nested section per concern.
    Every nested section and the CORS list use ``default_factory``, so each
    ``Config`` instance gets its own objects rather than sharing one
    mutable default.
    """

    # Host and port (presumably what the API server binds to -- confirm
    # against the server start-up code).
    host: str = "0.0.0.0"
    port: int = 8675

    # Database connection URL.
    database_url: str = "sqlite+aiosqlite:///./ai.db"

    # Nested sections.
    rag: RAGConfig = field(default_factory=RAGConfig)
    compression: CompressionConfig = field(default_factory=CompressionConfig)
    auth: AuthConfig = field(default_factory=AuthConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)

    # CORS origins; the default is the single wildcard entry.
    cors_origins: List[str] = field(default_factory=lambda: ["*"])
def _env_override(prefix: str, key: str, default):
    """Return the env var ``<PREFIX>_<KEY>`` coerced to the type of *default*.

    The variable name is ``f"{prefix}_{key}"`` upper-cased. When the variable
    is not set, *default* is returned unchanged. Otherwise the raw string is
    converted according to the type of *default*:

    * ``bool``  -- True for ``true``/``1``/``yes``/``on`` (case-insensitive,
      surrounding whitespace ignored); anything else is False.
    * ``int`` / ``float`` -- parsed with ``int()`` / ``float()``.
    * ``list``  -- split on commas; items are stripped and empty items
      are dropped.
    * anything else (including ``None``) -- the raw string.

    Args:
        prefix: Variable name prefix, e.g. ``"RAG"``.
        key: Setting name, e.g. ``"MAX_SKILLS"``.
        default: Fallback value; its type also selects the conversion.

    Returns:
        The converted environment value, or *default* when unset.

    Raises:
        ValueError: If the variable is set but cannot be parsed as the
            numeric type of *default*. The message names the variable.
    """
    env_key = f"{prefix}_{key}".upper()
    value = os.getenv(env_key)
    if value is None:
        return default

    # bool must be checked before int because bool is a subclass of int.
    if isinstance(default, bool):
        return value.strip().lower() in ('true', '1', 'yes', 'on')

    for number_type in (int, float):
        if isinstance(default, number_type):
            try:
                return number_type(value)
            except ValueError as err:
                # Name the variable so a bad deployment setting is easy to find.
                raise ValueError(
                    f"Environment variable {env_key}={value!r} is not a valid "
                    f"{number_type.__name__}"
                ) from err

    if isinstance(default, list):
        # "a, b" should yield ["a", "b"], and an empty variable an empty list.
        return [item.strip() for item in value.split(',') if item.strip()]

    return value
def _config_section(file_data: dict, name: str) -> dict:
    """Return the mapping stored under *name* in the parsed YAML, or ``{}``."""
    section = file_data.get(name)
    if section is None:
        # Covers both a missing key and a key with a null value (bare "rag:").
        return {}
    if not isinstance(section, dict):
        raise ValueError(
            f"Config section '{name}' must be a mapping, got {type(section).__name__}"
        )
    return section


def load_config(config_path: str = "/app/config.yaml") -> Config:
    """
    Load configuration from a YAML file with environment variable overrides.

    Priority: env vars > config file > defaults.

    The first existing path among ``config_path``, ``/app/config.yaml`` and
    ``./config.yaml`` is read; if none exists, only defaults and environment
    variables apply.

    Environment variables consulted:
        * ``APP_HOST``, ``APP_PORT``
        * ``DATABASE_URL``
        * ``CORS_ORIGINS`` (comma-separated; items are stripped and empty
          items dropped)
        * ``RAG_<FIELD>``, ``COMPRESSION_<FIELD>``, ``LOGGING_<FIELD>`` for
          every field of the matching section
        * ``AUTH_ENABLED``, ``AUTH_HEADER_NAME`` and ``API_KEY`` (the key
          uses the unprefixed name)

    Args:
        config_path: Preferred location of the YAML config file.

    Returns:
        A ``Config`` whose ``rag``, ``compression``, ``auth`` and ``logging``
        attributes are instances of the corresponding dataclasses, so
        attribute access such as ``config.rag.max_skills`` works.

    Raises:
        ValueError: If the YAML document or one of its sections is not a
            mapping, or a numeric environment variable cannot be parsed.
        TypeError: If the file contains a top-level key that ``Config``
            does not define.
    """
    # Load from the first config file that exists.
    locations = [config_path, "/app/config.yaml", "./config.yaml"]
    file_data = {}
    for path in locations:
        if os.path.exists(path):
            with open(path, 'r', encoding='utf-8') as f:
                file_data = yaml.safe_load(f) or {}
            break

    if not isinstance(file_data, dict):
        raise ValueError(
            f"Config file must contain a YAML mapping at the top level, "
            f"got {type(file_data).__name__}"
        )

    # File values form the base; env vars and rebuilt sections overwrite them.
    config_dict = dict(file_data)

    # Top-level settings
    config_dict['host'] = _env_override('APP', 'HOST', config_dict.get('host', '0.0.0.0'))
    config_dict['port'] = _env_override('APP', 'PORT', config_dict.get('port', 8675))
    config_dict['database_url'] = os.getenv(
        'DATABASE_URL',
        config_dict.get('database_url', 'sqlite+aiosqlite:///./ai.db'),
    )

    cors_env = os.getenv('CORS_ORIGINS')
    if cors_env is not None:
        # Strip items so "a, b" does not produce an origin with a leading space.
        config_dict['cors_origins'] = [
            origin.strip() for origin in cors_env.split(',') if origin.strip()
        ]
    else:
        config_dict['cors_origins'] = config_dict.get('cors_origins', ["*"])

    # Nested sections are built as dataclass instances, not plain dicts,
    # so the result matches the field types declared on Config.
    rag_dict = _config_section(file_data, 'rag')
    config_dict['rag'] = RAGConfig(
        max_skills=_env_override('RAG', 'MAX_SKILLS', rag_dict.get('max_skills', 3)),
        max_conventions=_env_override('RAG', 'MAX_CONVENTIONS', rag_dict.get('max_conventions', 2)),
        max_snippets=_env_override('RAG', 'MAX_SNIPPETS', rag_dict.get('max_snippets', 2)),
        min_skill_score=_env_override('RAG', 'MIN_SKILL_SCORE', rag_dict.get('min_skill_score', 0.3)),
        min_snippet_score=_env_override('RAG', 'MIN_SNIPPET_SCORE', rag_dict.get('min_snippet_score', 0.25)),
        embedding_model=_env_override('RAG', 'EMBEDDING_MODEL', rag_dict.get('embedding_model', 'all-MiniLM-L6-v2')),
    )

    compression_dict = _config_section(file_data, 'compression')
    config_dict['compression'] = CompressionConfig(
        enabled=_env_override('COMPRESSION', 'ENABLED', compression_dict.get('enabled', True)),
        strategy=_env_override('COMPRESSION', 'STRATEGY', compression_dict.get('strategy', 'extractive')),
        keep_last_n=_env_override('COMPRESSION', 'KEEP_LAST_N', compression_dict.get('keep_last_n', 3)),
        max_tokens=_env_override('COMPRESSION', 'MAX_TOKENS', compression_dict.get('max_tokens', 2000)),
        ollama_model=_env_override('COMPRESSION', 'OLLAMA_MODEL', compression_dict.get('ollama_model', 'phi3:mini')),
        ollama_url=_env_override('COMPRESSION', 'OLLAMA_URL', compression_dict.get('ollama_url', 'http://localhost:11434')),
    )

    auth_dict = _config_section(file_data, 'auth')
    config_dict['auth'] = AuthConfig(
        enabled=_env_override('AUTH', 'ENABLED', auth_dict.get('enabled', False)),
        # The key is read from API_KEY, not AUTH_API_KEY.
        api_key=os.getenv('API_KEY', auth_dict.get('api_key', 'change-me-in-production')),
        header_name=_env_override('AUTH', 'HEADER_NAME', auth_dict.get('header_name', 'X-API-Key')),
    )

    logging_dict = _config_section(file_data, 'logging')
    config_dict['logging'] = LoggingConfig(
        level=_env_override('LOGGING', 'LEVEL', logging_dict.get('level', 'INFO')),
        format=_env_override('LOGGING', 'FORMAT', logging_dict.get('format', 'json')),
    )

    return Config(**config_dict)