Fix config loading to return proper dataclass objects instead of dicts
This commit is contained in:
parent
62637acb6f
commit
95805dfc86
1 changed files with 68 additions and 77 deletions
141
config.py
141
config.py
|
|
@ -1,18 +1,16 @@
|
||||||
"""
|
"""
|
||||||
Configuration management for AI Skills API.
|
Configuration management for AI Skills API.
|
||||||
Supports YAML config file with sensible defaults.
|
Supports YAML config file with environment variable overrides.
|
||||||
Priority: env vars > config file > defaults
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List, Optional
|
from typing import List
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RAGConfig:
|
class RAGConfig:
|
||||||
"""RAG-specific configuration"""
|
|
||||||
max_skills: int = 3
|
max_skills: int = 3
|
||||||
max_conventions: int = 2
|
max_conventions: int = 2
|
||||||
max_snippets: int = 2
|
max_snippets: int = 2
|
||||||
|
|
@ -23,9 +21,8 @@ class RAGConfig:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CompressionConfig:
|
class CompressionConfig:
|
||||||
"""Compression configuration"""
|
|
||||||
enabled: bool = True
|
enabled: bool = True
|
||||||
strategy: str = "extractive" # "extractive", "ollama", or "none"
|
strategy: str = "extractive"
|
||||||
keep_last_n: int = 3
|
keep_last_n: int = 3
|
||||||
max_tokens: int = 2000
|
max_tokens: int = 2000
|
||||||
ollama_model: str = "phi3:mini"
|
ollama_model: str = "phi3:mini"
|
||||||
|
|
@ -34,22 +31,19 @@ class CompressionConfig:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AuthConfig:
|
class AuthConfig:
|
||||||
"""Authentication configuration"""
|
enabled: bool = False
|
||||||
enabled: bool = False # Set to True to require API keys
|
api_key: str = "change-me-in-production"
|
||||||
api_key: str = "change-me-in-production" # Single shared key for simplicity
|
|
||||||
header_name: str = "X-API-Key"
|
header_name: str = "X-API-Key"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LoggingConfig:
|
class LoggingConfig:
|
||||||
"""Logging configuration"""
|
|
||||||
level: str = "INFO"
|
level: str = "INFO"
|
||||||
format: str = "json" # "json" or "text"
|
format: str = "json"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Config:
|
class Config:
|
||||||
"""Main configuration"""
|
|
||||||
host: str = "0.0.0.0"
|
host: str = "0.0.0.0"
|
||||||
port: int = 8675
|
port: int = 8675
|
||||||
database_url: str = "sqlite+aiosqlite:///./ai.db"
|
database_url: str = "sqlite+aiosqlite:///./ai.db"
|
||||||
|
|
@ -60,34 +54,22 @@ class Config:
|
||||||
cors_origins: List[str] = field(default_factory=lambda: ["*"])
|
cors_origins: List[str] = field(default_factory=lambda: ["*"])
|
||||||
|
|
||||||
|
|
||||||
def _env_override(prefix: str, key: str, default):
|
def _coerce_value(value, cast_type):
|
||||||
"""Check for environment variable with prefix"""
|
"""Coerce string to type"""
|
||||||
env_key = f"{prefix}_{key}".upper()
|
if cast_type is bool:
|
||||||
value = os.getenv(env_key)
|
return str(value).lower() in ('true', '1', 'yes', 'on')
|
||||||
if value is not None:
|
if cast_type is int:
|
||||||
# Convert to appropriate type
|
|
||||||
if isinstance(default, bool):
|
|
||||||
return value.lower() in ('true', '1', 'yes', 'on')
|
|
||||||
elif isinstance(default, int):
|
|
||||||
return int(value)
|
return int(value)
|
||||||
elif isinstance(default, float):
|
if cast_type is float:
|
||||||
return float(value)
|
return float(value)
|
||||||
elif isinstance(default, list):
|
if cast_type is list:
|
||||||
return value.split(',')
|
return str(value).split(',') if value else []
|
||||||
else:
|
return str(value) if value is not None else None
|
||||||
return value
|
|
||||||
return default
|
|
||||||
|
|
||||||
|
|
||||||
def load_config(config_path: str = "/app/config.yaml") -> Config:
|
def load_config(config_path: str = "/app/config.yaml") -> Config:
|
||||||
"""
|
"""Load configuration from YAML file with environment overrides"""
|
||||||
Load configuration from YAML file with environment variable overrides.
|
# Find config file
|
||||||
Priority: env vars > config file > defaults.
|
|
||||||
"""
|
|
||||||
# Start with defaults
|
|
||||||
config_dict = {}
|
|
||||||
|
|
||||||
# Load from config file if exists
|
|
||||||
locations = [config_path, "/app/config.yaml", "./config.yaml"]
|
locations = [config_path, "/app/config.yaml", "./config.yaml"]
|
||||||
file_data = {}
|
file_data = {}
|
||||||
|
|
||||||
|
|
@ -97,48 +79,57 @@ def load_config(config_path: str = "/app/config.yaml") -> Config:
|
||||||
file_data = yaml.safe_load(f) or {}
|
file_data = yaml.safe_load(f) or {}
|
||||||
break
|
break
|
||||||
|
|
||||||
# Build config with file values as base
|
# Helper to get value: env > file > default
|
||||||
config_dict.update(file_data)
|
def get(prefix, key, default, cast=None):
|
||||||
|
env_key = f"{prefix}_{key}".upper() if prefix else key.upper()
|
||||||
|
env_val = os.getenv(env_key)
|
||||||
|
if env_val is not None:
|
||||||
|
return _coerce_value(env_val, cast if cast else type(default))
|
||||||
|
if key in file_data:
|
||||||
|
return file_data[key]
|
||||||
|
return default
|
||||||
|
|
||||||
# Override with environment variables
|
# Build nested configs
|
||||||
# Top-level settings
|
rag_file = file_data.get('rag', {})
|
||||||
config_dict['host'] = _env_override('APP', 'HOST', config_dict.get('host', '0.0.0.0'))
|
rag = RAGConfig(
|
||||||
config_dict['port'] = _env_override('APP', 'PORT', config_dict.get('port', 8675))
|
max_skills=get('RAG', 'MAX_SKILLS', rag_file.get('max_skills', 3), int),
|
||||||
config_dict['database_url'] = os.getenv('DATABASE_URL', config_dict.get('database_url', 'sqlite+aiosqlite:///./ai.db'))
|
max_conventions=get('RAG', 'MAX_CONVENTIONS', rag_file.get('max_conventions', 2), int),
|
||||||
config_dict['cors_origins'] = os.getenv('CORS_ORIGINS', config_dict.get('cors_origins', ["*"])).split(',') if isinstance(os.getenv('CORS_ORIGINS'), str) else config_dict.get('cors_origins', ["*"])
|
max_snippets=get('RAG', 'MAX_SNIPPETS', rag_file.get('max_snippets', 2), int),
|
||||||
|
min_skill_score=get('RAG', 'MIN_SKILL_SCORE', rag_file.get('min_skill_score', 0.3), float),
|
||||||
|
min_snippet_score=get('RAG', 'MIN_SNIPPET_SCORE', rag_file.get('min_snippet_score', 0.25), float),
|
||||||
|
embedding_model=get('RAG', 'EMBEDDING_MODEL', rag_file.get('embedding_model', 'all-MiniLM-L6-v2'), str),
|
||||||
|
)
|
||||||
|
|
||||||
# Nested configs
|
comp_file = file_data.get('compression', {})
|
||||||
rag_dict = file_data.get('rag', {})
|
compression = CompressionConfig(
|
||||||
config_dict['rag'] = {
|
enabled=get('COMPRESSION', 'ENABLED', comp_file.get('enabled', True), bool),
|
||||||
'max_skills': _env_override('RAG', 'MAX_SKILLS', rag_dict.get('max_skills', 3)),
|
strategy=get('COMPRESSION', 'STRATEGY', comp_file.get('strategy', 'extractive'), str),
|
||||||
'max_conventions': _env_override('RAG', 'MAX_CONVENTIONS', rag_dict.get('max_conventions', 2)),
|
keep_last_n=get('COMPRESSION', 'KEEP_LAST_N', comp_file.get('keep_last_n', 3), int),
|
||||||
'max_snippets': _env_override('RAG', 'MAX_SNIPPETS', rag_dict.get('max_snippets', 2)),
|
max_tokens=get('COMPRESSION', 'MAX_TOKENS', comp_file.get('max_tokens', 2000), int),
|
||||||
'min_skill_score': _env_override('RAG', 'MIN_SKILL_SCORE', rag_dict.get('min_skill_score', 0.3)),
|
ollama_model=get('COMPRESSION', 'OLLAMA_MODEL', comp_file.get('ollama_model', 'phi3:mini'), str),
|
||||||
'min_snippet_score': _env_override('RAG', 'MIN_SNIPPET_SCORE', rag_dict.get('min_snippet_score', 0.25)),
|
ollama_url=get('COMPRESSION', 'OLLAMA_URL', comp_file.get('ollama_url', 'http://localhost:11434'), str),
|
||||||
'embedding_model': _env_override('RAG', 'EMBEDDING_MODEL', rag_dict.get('embedding_model', 'all-MiniLM-L6-v2')),
|
)
|
||||||
}
|
|
||||||
|
|
||||||
compression_dict = file_data.get('compression', {})
|
auth_file = file_data.get('auth', {})
|
||||||
config_dict['compression'] = {
|
auth = AuthConfig(
|
||||||
'enabled': _env_override('COMPRESSION', 'ENABLED', compression_dict.get('enabled', True)),
|
enabled=get('AUTH', 'ENABLED', auth_file.get('enabled', False), bool),
|
||||||
'strategy': _env_override('COMPRESSION', 'STRATEGY', compression_dict.get('strategy', 'extractive')),
|
api_key=os.getenv('API_KEY', auth_file.get('api_key', 'change-me-in-production')),
|
||||||
'keep_last_n': _env_override('COMPRESSION', 'KEEP_LAST_N', compression_dict.get('keep_last_n', 3)),
|
header_name=get('AUTH', 'HEADER_NAME', auth_file.get('header_name', 'X-API-Key'), str),
|
||||||
'max_tokens': _env_override('COMPRESSION', 'MAX_TOKENS', compression_dict.get('max_tokens', 2000)),
|
)
|
||||||
'ollama_model': _env_override('COMPRESSION', 'OLLAMA_MODEL', compression_dict.get('ollama_model', 'phi3:mini')),
|
|
||||||
'ollama_url': _env_override('COMPRESSION', 'OLLAMA_URL', compression_dict.get('ollama_url', 'http://localhost:11434')),
|
|
||||||
}
|
|
||||||
|
|
||||||
auth_dict = file_data.get('auth', {})
|
log_file = file_data.get('logging', {})
|
||||||
config_dict['auth'] = {
|
logging = LoggingConfig(
|
||||||
'enabled': _env_override('AUTH', 'ENABLED', auth_dict.get('enabled', False)),
|
level=get('LOGGING', 'LEVEL', log_file.get('level', 'INFO'), str),
|
||||||
'api_key': os.getenv('API_KEY', auth_dict.get('api_key', 'change-me-in-production')),
|
format=get('LOGGING', 'FORMAT', log_file.get('format', 'json'), str),
|
||||||
'header_name': _env_override('AUTH', 'HEADER_NAME', auth_dict.get('header_name', 'X-API-Key')),
|
)
|
||||||
}
|
|
||||||
|
|
||||||
logging_dict = file_data.get('logging', {})
|
return Config(
|
||||||
config_dict['logging'] = {
|
host=get('APP', 'HOST', file_data.get('host', '0.0.0.0'), str),
|
||||||
'level': _env_override('LOGGING', 'LEVEL', logging_dict.get('level', 'INFO')),
|
port=get('APP', 'PORT', file_data.get('port', 8675), int),
|
||||||
'format': _env_override('LOGGING', 'FORMAT', logging_dict.get('format', 'json')),
|
database_url=os.getenv('DATABASE_URL', file_data.get('database_url', 'sqlite+aiosqlite:///./ai.db')),
|
||||||
}
|
rag=rag,
|
||||||
|
compression=compression,
|
||||||
return Config(**config_dict)
|
auth=auth,
|
||||||
|
logging=logging,
|
||||||
|
cors_origins=file_data.get('cors_origins', ["*"])
|
||||||
|
)
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue