ai-skills-api/config.py

135 lines
4.8 KiB
Python

"""
Configuration management for AI Skills API.
Supports YAML config file with environment variable overrides.
"""
import os
from dataclasses import dataclass, field
from typing import List
import yaml
@dataclass
class RAGConfig:
    """Settings for the ``rag`` section of the configuration.

    Every field has a default, so ``RAGConfig()`` yields a usable object.
    Field meanings below are inferred from their names; confirm against the
    code that consumes this object.
    """

    # Presumably caps on how many items of each kind are retrieved.
    max_skills: int = 3
    max_conventions: int = 2
    max_snippets: int = 2

    # Presumably minimum similarity scores for a result to be kept.
    min_skill_score: float = 0.3
    min_snippet_score: float = 0.25

    # Name of the embedding model (looks like a sentence-transformers id).
    embedding_model: str = "all-MiniLM-L6-v2"
@dataclass
class CompressionConfig:
    """Settings for the ``compression`` section of the configuration.

    Every field has a default, so ``CompressionConfig()`` is usable as-is.
    Field meanings below are inferred from their names; confirm against the
    code that consumes this object.
    """

    # Master switch for the compression feature.
    enabled: bool = True

    # Name of the compression strategy; valid values are not defined here.
    strategy: str = "extractive"

    # Presumably the number of most recent items left uncompressed.
    keep_last_n: int = 3

    # Presumably a token budget for the compressed output.
    max_tokens: int = 2000

    # Ollama model name and server base URL.
    ollama_model: str = "phi3:mini"
    ollama_url: str = "http://localhost:11434"
@dataclass
class AuthConfig:
    """Settings for the ``auth`` section of the configuration.

    Authentication is disabled by default. The default ``api_key`` is a
    placeholder and should be replaced before enabling auth.
    """

    # Whether API-key authentication is turned on.
    enabled: bool = False

    # Shared secret; the default is only a placeholder value.
    api_key: str = "change-me-in-production"

    # Presumably the HTTP header that carries the key.
    header_name: str = "X-API-Key"
@dataclass
class LoggingConfig:
    """Settings for the ``logging`` section of the configuration."""

    # Log level name, e.g. "INFO".
    level: str = "INFO"

    # Output format identifier; valid values are not defined in this file.
    format: str = "json"
@dataclass
class Config:
    """Top-level application configuration.

    Holds the server settings plus one nested dataclass per config section.
    Nested sections and ``cors_origins`` use ``default_factory`` so each
    ``Config`` instance gets its own objects rather than sharing mutable
    defaults.
    """

    # Server settings.
    host: str = "0.0.0.0"
    port: int = 8675
    database_url: str = "sqlite+aiosqlite:///./ai.db"

    # Nested per-section settings.
    rag: RAGConfig = field(default_factory=RAGConfig)
    compression: CompressionConfig = field(default_factory=CompressionConfig)
    auth: AuthConfig = field(default_factory=AuthConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)

    # CORS origins; the default is the wildcard "*".
    cors_origins: List[str] = field(default_factory=lambda: ["*"])
def _coerce_value(value, cast_type):
"""Coerce string to type"""
if cast_type is bool:
return str(value).lower() in ('true', '1', 'yes', 'on')
if cast_type is int:
return int(value)
if cast_type is float:
return float(value)
if cast_type is list:
return str(value).split(',') if value else []
return str(value) if value is not None else None
def load_config(config_path: str = "/app/config.yaml") -> Config:
    """Load configuration from a YAML file with environment overrides.

    The first existing path among ``config_path``, ``/app/config.yaml`` and
    ``./config.yaml`` is parsed with ``yaml.safe_load``. If none exists, the
    built-in defaults are used. Precedence for each setting is:
    environment variable > file value > built-in default.

    Environment variable names are ``<SECTION>_<KEY>`` in upper case (for
    example ``RAG_MAX_SKILLS`` or ``AUTH_ENABLED``). Top-level ``host`` and
    ``port`` use the ``APP_`` prefix; the API key and database URL are read
    from ``API_KEY`` and ``DATABASE_URL``. ``cors_origins`` comes from the
    file only.

    Args:
        config_path: Preferred location of the YAML config file.

    Returns:
        A fully populated ``Config`` instance.

    Raises:
        ValueError: If the top level of the YAML document is not a mapping,
            or a numeric environment override cannot be parsed.
    """
    # Probe candidate locations in priority order; the first file found wins.
    candidates = [config_path, "/app/config.yaml", "./config.yaml"]
    file_data = {}
    for path in candidates:
        if os.path.exists(path):
            with open(path, 'r', encoding='utf-8') as f:
                # An empty document parses to None; treat it as no settings.
                file_data = yaml.safe_load(f) or {}
            break
    if not isinstance(file_data, dict):
        raise ValueError("Top-level YAML config must be a mapping of settings")

    def section(name):
        # A section declared without a body (e.g. ``rag:``) parses to None.
        return file_data.get(name) or {}

    def get(prefix, key, default, cast=None):
        # ``default`` already carries the file value or the built-in
        # fallback, so only the environment needs checking here. The
        # upper-case key is deliberately not looked up in the file: it would
        # match the top level of the document regardless of section.
        env_key = f"{prefix}_{key}".upper() if prefix else key.upper()
        env_val = os.getenv(env_key)
        if env_val is None:
            return default
        return _coerce_value(env_val, cast if cast else type(default))

    # Build nested configs
    rag_file = section('rag')
    rag = RAGConfig(
        max_skills=get('RAG', 'MAX_SKILLS', rag_file.get('max_skills', 3), int),
        max_conventions=get('RAG', 'MAX_CONVENTIONS', rag_file.get('max_conventions', 2), int),
        max_snippets=get('RAG', 'MAX_SNIPPETS', rag_file.get('max_snippets', 2), int),
        min_skill_score=get('RAG', 'MIN_SKILL_SCORE', rag_file.get('min_skill_score', 0.3), float),
        min_snippet_score=get('RAG', 'MIN_SNIPPET_SCORE', rag_file.get('min_snippet_score', 0.25), float),
        embedding_model=get('RAG', 'EMBEDDING_MODEL', rag_file.get('embedding_model', 'all-MiniLM-L6-v2'), str),
    )

    comp_file = section('compression')
    compression = CompressionConfig(
        enabled=get('COMPRESSION', 'ENABLED', comp_file.get('enabled', True), bool),
        strategy=get('COMPRESSION', 'STRATEGY', comp_file.get('strategy', 'extractive'), str),
        keep_last_n=get('COMPRESSION', 'KEEP_LAST_N', comp_file.get('keep_last_n', 3), int),
        max_tokens=get('COMPRESSION', 'MAX_TOKENS', comp_file.get('max_tokens', 2000), int),
        ollama_model=get('COMPRESSION', 'OLLAMA_MODEL', comp_file.get('ollama_model', 'phi3:mini'), str),
        ollama_url=get('COMPRESSION', 'OLLAMA_URL', comp_file.get('ollama_url', 'http://localhost:11434'), str),
    )

    auth_file = section('auth')
    auth = AuthConfig(
        enabled=get('AUTH', 'ENABLED', auth_file.get('enabled', False), bool),
        # The API key is read from the unprefixed API_KEY variable.
        api_key=os.getenv('API_KEY', auth_file.get('api_key', 'change-me-in-production')),
        header_name=get('AUTH', 'HEADER_NAME', auth_file.get('header_name', 'X-API-Key'), str),
    )

    log_file = section('logging')
    # Named ``logging_cfg`` to avoid shadowing the stdlib ``logging`` module.
    logging_cfg = LoggingConfig(
        level=get('LOGGING', 'LEVEL', log_file.get('level', 'INFO'), str),
        format=get('LOGGING', 'FORMAT', log_file.get('format', 'json'), str),
    )

    return Config(
        host=get('APP', 'HOST', file_data.get('host', '0.0.0.0'), str),
        port=get('APP', 'PORT', file_data.get('port', 8675), int),
        database_url=os.getenv('DATABASE_URL', file_data.get('database_url', 'sqlite+aiosqlite:///./ai.db')),
        rag=rag,
        compression=compression,
        auth=auth,
        logging=logging_cfg,
        cors_origins=file_data.get('cors_origins', ["*"]),
    )