From 95805dfc860683b10cf75dadb8bb006012d91481 Mon Sep 17 00:00:00 2001 From: Lukas Parsons Date: Sun, 22 Mar 2026 23:21:31 -0400 Subject: [PATCH] Fix config loading to return proper dataclass objects instead of dicts --- config.py | 145 +++++++++++++++++++++++++----------------------------- 1 file changed, 68 insertions(+), 77 deletions(-) diff --git a/config.py b/config.py index 382bedd..76e731f 100644 --- a/config.py +++ b/config.py @@ -1,18 +1,16 @@ """ Configuration management for AI Skills API. -Supports YAML config file with sensible defaults. -Priority: env vars > config file > defaults +Supports YAML config file with environment variable overrides. """ import os from dataclasses import dataclass, field -from typing import List, Optional +from typing import List import yaml @dataclass class RAGConfig: - """RAG-specific configuration""" max_skills: int = 3 max_conventions: int = 2 max_snippets: int = 2 @@ -23,9 +21,8 @@ class RAGConfig: @dataclass class CompressionConfig: - """Compression configuration""" enabled: bool = True - strategy: str = "extractive" # "extractive", "ollama", or "none" + strategy: str = "extractive" keep_last_n: int = 3 max_tokens: int = 2000 ollama_model: str = "phi3:mini" @@ -34,22 +31,19 @@ class CompressionConfig: @dataclass class AuthConfig: - """Authentication configuration""" - enabled: bool = False # Set to True to require API keys - api_key: str = "change-me-in-production" # Single shared key for simplicity + enabled: bool = False + api_key: str = "change-me-in-production" header_name: str = "X-API-Key" @dataclass class LoggingConfig: - """Logging configuration""" level: str = "INFO" - format: str = "json" # "json" or "text" + format: str = "json" @dataclass class Config: - """Main configuration""" host: str = "0.0.0.0" port: int = 8675 database_url: str = "sqlite+aiosqlite:///./ai.db" @@ -60,34 +54,22 @@ class Config: cors_origins: List[str] = field(default_factory=lambda: ["*"]) -def _env_override(prefix: str, key: str, default): - """Check for environment variable with prefix""" - env_key = f"{prefix}_{key}".upper() - value = os.getenv(env_key) - if value is not None: - # Convert to appropriate type - if isinstance(default, bool): - return value.lower() in ('true', '1', 'yes', 'on') - elif isinstance(default, int): - return int(value) - elif isinstance(default, float): - return float(value) - elif isinstance(default, list): - return value.split(',') - else: - return value - return default +def _coerce_value(value, cast_type): + """Coerce string to type""" + if cast_type is bool: + return str(value).lower() in ('true', '1', 'yes', 'on') + if cast_type is int: + return int(value) + if cast_type is float: + return float(value) + if cast_type is list: + return str(value).split(',') if value else [] + return str(value) if value is not None else None def load_config(config_path: str = "/app/config.yaml") -> Config: - """ - Load configuration from YAML file with environment variable overrides. - Priority: env vars > config file > defaults. - """ - # Start with defaults - config_dict = {} - - # Load from config file if exists + """Load configuration from YAML file with environment overrides""" + # Find config file locations = [config_path, "/app/config.yaml", "./config.yaml"] file_data = {} @@ -97,48 +79,57 @@ def load_config(config_path: str = "/app/config.yaml") -> Config: file_data = yaml.safe_load(f) or {} break - # Build config with file values as base - config_dict.update(file_data) + # Helper to get value: env > file > default + def get(prefix, key, default, cast=None): + env_key = f"{prefix}_{key}".upper() if prefix else key.upper() + env_val = os.getenv(env_key) + if env_val is not None: + return _coerce_value(env_val, cast if cast else type(default)) + if key in file_data: + return file_data[key] + return default - # Override with environment variables - # Top-level settings - config_dict['host'] = _env_override('APP', 'HOST', config_dict.get('host', '0.0.0.0')) - config_dict['port'] = _env_override('APP', 'PORT', config_dict.get('port', 8675)) - config_dict['database_url'] = os.getenv('DATABASE_URL', config_dict.get('database_url', 'sqlite+aiosqlite:///./ai.db')) - config_dict['cors_origins'] = os.getenv('CORS_ORIGINS', config_dict.get('cors_origins', ["*"])).split(',') if isinstance(os.getenv('CORS_ORIGINS'), str) else config_dict.get('cors_origins', ["*"]) + # Build nested configs + rag_file = file_data.get('rag', {}) + rag = RAGConfig( + max_skills=get('RAG', 'MAX_SKILLS', rag_file.get('max_skills', 3), int), + max_conventions=get('RAG', 'MAX_CONVENTIONS', rag_file.get('max_conventions', 2), int), + max_snippets=get('RAG', 'MAX_SNIPPETS', rag_file.get('max_snippets', 2), int), + min_skill_score=get('RAG', 'MIN_SKILL_SCORE', rag_file.get('min_skill_score', 0.3), float), + min_snippet_score=get('RAG', 'MIN_SNIPPET_SCORE', rag_file.get('min_snippet_score', 0.25), float), + embedding_model=get('RAG', 'EMBEDDING_MODEL', rag_file.get('embedding_model', 'all-MiniLM-L6-v2'), str), + ) - # Nested configs - rag_dict = file_data.get('rag', {}) - config_dict['rag'] = { - 'max_skills': _env_override('RAG', 'MAX_SKILLS', rag_dict.get('max_skills', 3)), - 'max_conventions': _env_override('RAG', 'MAX_CONVENTIONS', rag_dict.get('max_conventions', 2)), - 'max_snippets': _env_override('RAG', 'MAX_SNIPPETS', rag_dict.get('max_snippets', 2)), - 'min_skill_score': _env_override('RAG', 'MIN_SKILL_SCORE', rag_dict.get('min_skill_score', 0.3)), - 'min_snippet_score': _env_override('RAG', 'MIN_SNIPPET_SCORE', rag_dict.get('min_snippet_score', 0.25)), - 'embedding_model': _env_override('RAG', 'EMBEDDING_MODEL', rag_dict.get('embedding_model', 'all-MiniLM-L6-v2')), - } + comp_file = file_data.get('compression', {}) + compression = CompressionConfig( + enabled=get('COMPRESSION', 'ENABLED', comp_file.get('enabled', True), bool), + strategy=get('COMPRESSION', 'STRATEGY', comp_file.get('strategy', 'extractive'), str), + keep_last_n=get('COMPRESSION', 'KEEP_LAST_N', comp_file.get('keep_last_n', 3), int), + max_tokens=get('COMPRESSION', 'MAX_TOKENS', comp_file.get('max_tokens', 2000), int), + ollama_model=get('COMPRESSION', 'OLLAMA_MODEL', comp_file.get('ollama_model', 'phi3:mini'), str), + ollama_url=get('COMPRESSION', 'OLLAMA_URL', comp_file.get('ollama_url', 'http://localhost:11434'), str), + ) - compression_dict = file_data.get('compression', {}) - config_dict['compression'] = { - 'enabled': _env_override('COMPRESSION', 'ENABLED', compression_dict.get('enabled', True)), - 'strategy': _env_override('COMPRESSION', 'STRATEGY', compression_dict.get('strategy', 'extractive')), - 'keep_last_n': _env_override('COMPRESSION', 'KEEP_LAST_N', compression_dict.get('keep_last_n', 3)), - 'max_tokens': _env_override('COMPRESSION', 'MAX_TOKENS', compression_dict.get('max_tokens', 2000)), - 'ollama_model': _env_override('COMPRESSION', 'OLLAMA_MODEL', compression_dict.get('ollama_model', 'phi3:mini')), - 'ollama_url': _env_override('COMPRESSION', 'OLLAMA_URL', compression_dict.get('ollama_url', 'http://localhost:11434')), - } + auth_file = file_data.get('auth', {}) + auth = AuthConfig( + enabled=get('AUTH', 'ENABLED', auth_file.get('enabled', False), bool), + api_key=os.getenv('API_KEY', auth_file.get('api_key', 'change-me-in-production')), + header_name=get('AUTH', 'HEADER_NAME', auth_file.get('header_name', 'X-API-Key'), str), + ) - auth_dict = file_data.get('auth', {}) - config_dict['auth'] = { - 'enabled': _env_override('AUTH', 'ENABLED', auth_dict.get('enabled', False)), - 'api_key': os.getenv('API_KEY', auth_dict.get('api_key', 'change-me-in-production')), - 'header_name': _env_override('AUTH', 'HEADER_NAME', auth_dict.get('header_name', 'X-API-Key')), - } + log_file = file_data.get('logging', {}) + logging = LoggingConfig( + level=get('LOGGING', 'LEVEL', log_file.get('level', 'INFO'), str), + format=get('LOGGING', 'FORMAT', log_file.get('format', 'json'), str), + ) - logging_dict = file_data.get('logging', {}) - config_dict['logging'] = { - 'level': _env_override('LOGGING', 'LEVEL', logging_dict.get('level', 'INFO')), - 'format': _env_override('LOGGING', 'FORMAT', logging_dict.get('format', 'json')), - } - - return Config(**config_dict) + return Config( + host=get('APP', 'HOST', file_data.get('host', '0.0.0.0'), str), + port=get('APP', 'PORT', file_data.get('port', 8675), int), + database_url=os.getenv('DATABASE_URL', file_data.get('database_url', 'sqlite+aiosqlite:///./ai.db')), + rag=rag, + compression=compression, + auth=auth, + logging=logging, + cors_origins=file_data.get('cors_origins', ["*"]) + )