Files

108 lines
4.3 KiB
Python
Raw Permalink Normal View History

"""Application settings with environment variable support."""
from functools import lru_cache
from pathlib import Path
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """Application configuration with environment variable support.

    All settings can be overridden via environment variables prefixed with
    AUDIOCRAFT_ (example: AUDIOCRAFT_API_PORT=8080). Values are also loaded
    from a ``.env`` file (UTF-8); keys that do not match a field are ignored.
    """

    model_config = SettingsConfigDict(
        env_prefix="AUDIOCRAFT_",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # Server Configuration
    host: str = Field(default="0.0.0.0", description="Server bind host")
    gradio_port: int = Field(default=7860, description="Gradio UI port")
    api_port: int = Field(default=8000, description="FastAPI port")
    root_path: Optional[str] = Field(
        default=None,
        description="External URL for reverse proxy (e.g., https://example.com)",
    )

    # Paths (relative defaults are resolved against the process working directory)
    data_dir: Path = Field(default=Path("./data"), description="Data directory")
    output_dir: Path = Field(default=Path("./outputs"), description="Generated audio output")
    cache_dir: Path = Field(default=Path("./cache"), description="Model cache directory")
    models_config: Path = Field(
        default=Path("./config/models.yaml"), description="Model registry config"
    )

    # VRAM Management
    comfyui_reserve_gb: float = Field(
        default=0.0,
        description="VRAM reserved for ComfyUI (GB). Set via AUDIOCRAFT_COMFYUI_RESERVE_GB if running with ComfyUI.",
    )
    safety_buffer_gb: float = Field(
        default=1.0, description="Safety buffer to prevent OOM (GB)"
    )
    idle_unload_minutes: int = Field(
        default=15, description="Unload models after idle time (minutes)"
    )
    max_cached_models: int = Field(
        default=2, description="Maximum number of models to keep loaded"
    )

    # API Configuration
    api_enabled: bool = Field(default=True, description="Enable REST API")
    api_key: Optional[str] = Field(default=None, description="API key for authentication")
    # NOTE(review): nothing here checks that api_key is set when this is True.
    api_key_required: bool = Field(default=False, description="Require API key for requests")
    api_rate_limit: int = Field(default=10, description="API rate limit per minute")
    cors_origins: list[str] = Field(
        default=["*"], description="Allowed CORS origins"
    )

    # Audio Output
    default_format: str = Field(default="wav", description="Default audio format (wav, mp3, flac)")
    sample_rate: int = Field(default=32000, description="Output sample rate")
    normalize_audio: bool = Field(default=True, description="Normalize audio output")

    # Generation Defaults
    default_duration: float = Field(default=10.0, description="Default generation duration")
    max_duration: float = Field(default=300.0, description="Maximum generation duration")
    default_batch_size: int = Field(default=1, description="Default batch size")
    max_batch_size: int = Field(default=8, description="Maximum batch size")
    max_queue_size: int = Field(default=100, description="Maximum generation queue size")

    # Database
    database_url: str = Field(
        default="sqlite+aiosqlite:///./data/audiocraft.db",
        description="Database connection URL",
    )

    # Logging
    log_level: str = Field(default="INFO", description="Logging level")
    debug: bool = Field(default=False, description="Enable debug mode")

    # Model Management
    # NOTE(review): appears to overlap with max_cached_models (same default,
    # same intent) — confirm which one the model manager actually reads.
    max_loaded_models: int = Field(default=2, description="Maximum models loaded simultaneously")

    def ensure_directories(self) -> None:
        """Create the data, output, cache and ``data_dir/presets`` directories.

        Missing parents are created too; existing directories are left as-is.
        """
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        (self.data_dir / "presets").mkdir(parents=True, exist_ok=True)

    @property
    def database_path(self) -> Path:
        """Return the SQLite database file path encoded in ``database_url``.

        Both ``sqlite:///`` and driver-qualified forms such as
        ``sqlite+aiosqlite:///`` are accepted. The path is the text after the
        last ``///``, so ``sqlite:////abs/app.db`` yields ``/abs/app.db``.
        A trailing ``?query`` (SQLAlchemy driver options) is not part of the
        filename and is stripped.

        Raises:
            ValueError: If ``database_url`` is not a SQLite URL.
        """
        if not self.database_url.startswith("sqlite"):
            raise ValueError("Only SQLite databases are supported")
        # NOTE(review): a URL with no "///" (e.g. in-memory "sqlite://") falls
        # through unchanged here and yields a meaningless path.
        path = self.database_url.split("///")[-1]
        # e.g. "./data/app.db?check_same_thread=false" -> "./data/app.db"
        path, _, _ = path.partition("?")
        return Path(path)
@lru_cache()
def get_settings() -> Settings:
    """Return the shared Settings object, building it on first use.

    The first call constructs Settings, which reads the environment and the
    .env file. Every later call returns that same instance. Use
    ``get_settings.cache_clear()`` to force a fresh load.
    """
    settings = Settings()
    return settings