""" Configuration management for AI Skills API. Supports YAML config file with environment variable overrides. """ import os from dataclasses import dataclass, field from typing import List import yaml @dataclass class RAGConfig: max_skills: int = 3 max_conventions: int = 2 max_snippets: int = 2 min_skill_score: float = 0.3 min_snippet_score: float = 0.25 embedding_model: str = "all-MiniLM-L6-v2" @dataclass class CompressionConfig: enabled: bool = True strategy: str = "extractive" keep_last_n: int = 3 max_tokens: int = 2000 ollama_model: str = "phi3:mini" ollama_url: str = "http://localhost:11434" @dataclass class AuthConfig: enabled: bool = False api_key: str = "change-me-in-production" header_name: str = "X-API-Key" @dataclass class LoggingConfig: level: str = "INFO" format: str = "json" @dataclass class Config: host: str = "0.0.0.0" port: int = 8675 database_url: str = "sqlite+aiosqlite:///./ai.db" rag: RAGConfig = field(default_factory=RAGConfig) compression: CompressionConfig = field(default_factory=CompressionConfig) auth: AuthConfig = field(default_factory=AuthConfig) logging: LoggingConfig = field(default_factory=LoggingConfig) cors_origins: List[str] = field(default_factory=lambda: ["*"]) def _coerce_value(value, cast_type): """Coerce string to type""" if cast_type is bool: return str(value).lower() in ('true', '1', 'yes', 'on') if cast_type is int: return int(value) if cast_type is float: return float(value) if cast_type is list: return str(value).split(',') if value else [] return str(value) if value is not None else None def load_config(config_path: str = "/app/config.yaml") -> Config: """Load configuration from YAML file with environment overrides""" # Find config file locations = [config_path, "/app/config.yaml", "./config.yaml"] file_data = {} for path in locations: if os.path.exists(path): with open(path, 'r') as f: file_data = yaml.safe_load(f) or {} break # Helper to get value: env > file > default def get(prefix, key, default, cast=None): env_key = f"{prefix}_{key}".upper() if prefix else key.upper() env_val = os.getenv(env_key) if env_val is not None: return _coerce_value(env_val, cast if cast else type(default)) if key in file_data: return file_data[key] return default # Build nested configs rag_file = file_data.get('rag', {}) rag = RAGConfig( max_skills=get('RAG', 'MAX_SKILLS', rag_file.get('max_skills', 3), int), max_conventions=get('RAG', 'MAX_CONVENTIONS', rag_file.get('max_conventions', 2), int), max_snippets=get('RAG', 'MAX_SNIPPETS', rag_file.get('max_snippets', 2), int), min_skill_score=get('RAG', 'MIN_SKILL_SCORE', rag_file.get('min_skill_score', 0.3), float), min_snippet_score=get('RAG', 'MIN_SNIPPET_SCORE', rag_file.get('min_snippet_score', 0.25), float), embedding_model=get('RAG', 'EMBEDDING_MODEL', rag_file.get('embedding_model', 'all-MiniLM-L6-v2'), str), ) comp_file = file_data.get('compression', {}) compression = CompressionConfig( enabled=get('COMPRESSION', 'ENABLED', comp_file.get('enabled', True), bool), strategy=get('COMPRESSION', 'STRATEGY', comp_file.get('strategy', 'extractive'), str), keep_last_n=get('COMPRESSION', 'KEEP_LAST_N', comp_file.get('keep_last_n', 3), int), max_tokens=get('COMPRESSION', 'MAX_TOKENS', comp_file.get('max_tokens', 2000), int), ollama_model=get('COMPRESSION', 'OLLAMA_MODEL', comp_file.get('ollama_model', 'phi3:mini'), str), ollama_url=get('COMPRESSION', 'OLLAMA_URL', comp_file.get('ollama_url', 'http://localhost:11434'), str), ) auth_file = file_data.get('auth', {}) auth = AuthConfig( enabled=get('AUTH', 'ENABLED', auth_file.get('enabled', False), bool), api_key=os.getenv('API_KEY', auth_file.get('api_key', 'change-me-in-production')), header_name=get('AUTH', 'HEADER_NAME', auth_file.get('header_name', 'X-API-Key'), str), ) log_file = file_data.get('logging', {}) logging = LoggingConfig( level=get('LOGGING', 'LEVEL', log_file.get('level', 'INFO'), str), format=get('LOGGING', 'FORMAT', log_file.get('format', 'json'), str), ) return Config( host=get('APP', 'HOST', file_data.get('host', '0.0.0.0'), str), port=get('APP', 'PORT', file_data.get('port', 8675), int), database_url=os.getenv('DATABASE_URL', file_data.get('database_url', 'sqlite+aiosqlite:///./ai.db')), rag=rag, compression=compression, auth=auth, logging=logging, cors_origins=file_data.get('cors_origins', ["*"]) )