From ffbf02b12c5eb1a578b174ba0366d28308f98691 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Kr=C3=BCger?= Date: Tue, 25 Nov 2025 19:34:27 +0100 Subject: [PATCH] Initial implementation of AudioCraft Studio MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Complete web interface for Meta's AudioCraft AI audio generation: - Gradio UI with tabs for all 5 model families (MusicGen, AudioGen, MAGNeT, MusicGen Style, JASCO) - REST API with FastAPI, OpenAPI docs, and API key auth - VRAM management with ComfyUI coexistence support - SQLite database for project/generation history - Batch processing queue for async generation - Docker deployment optimized for RunPod with RTX 4090 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- .env.example | 42 ++ .gitignore | 76 ++++ Dockerfile | 83 ++++ README.md | 197 +++++++++ config/__init__.py | 5 + config/models.yaml | 151 +++++++ config/settings.py | 94 +++++ docker-compose.yml | 64 +++ main.py | 147 +++++++ pyproject.toml | 89 ++++ requirements.txt | 30 ++ runpod.yaml | 77 ++++ scripts/download_models.py | 116 ++++++ scripts/start.sh | 55 +++ src/__init__.py | 3 + src/api/__init__.py | 5 + src/api/app.py | 150 +++++++ src/api/auth.py | 133 ++++++ src/api/models.py | 166 ++++++++ src/api/routes/__init__.py | 13 + src/api/routes/generation.py | 234 +++++++++++ src/api/routes/models.py | 228 ++++++++++ src/api/routes/projects.py | 250 +++++++++++ src/api/routes/system.py | 263 ++++++++++++ src/core/__init__.py | 24 ++ src/core/audio_utils.py | 535 ++++++++++++++++++++++++ src/core/base_model.py | 247 +++++++++++ src/core/gpu_manager.py | 433 +++++++++++++++++++ src/core/model_registry.py | 487 ++++++++++++++++++++++ src/core/oom_handler.py | 297 +++++++++++++ src/main.py | 84 ++++ src/models/__init__.py | 32 ++ src/models/audiogen/__init__.py | 5 + src/models/audiogen/adapter.py | 203 +++++++++ src/models/jasco/__init__.py | 5 + src/models/jasco/adapter.py | 348 ++++++++++++++++ src/models/magnet/__init__.py | 5 + src/models/magnet/adapter.py | 253 ++++++++++++ src/models/musicgen/__init__.py | 5 + src/models/musicgen/adapter.py | 290 +++++++++++++ src/models/musicgen_style/__init__.py | 5 + src/models/musicgen_style/adapter.py | 277 +++++++++++++ src/services/__init__.py | 13 + src/services/batch_processor.py | 397 ++++++++++++++++++ src/services/generation_service.py | 322 +++++++++++++++ src/services/project_service.py | 395 ++++++++++++++++++ src/storage/__init__.py | 5 + src/storage/database.py | 550 +++++++++++++++++++++++++ src/ui/__init__.py | 5 + src/ui/app.py | 355 ++++++++++++++++ src/ui/components/__init__.py | 13 + src/ui/components/audio_player.py | 178 ++++++++ src/ui/components/generation_params.py | 199 +++++++++ src/ui/components/preset_selector.py | 103 +++++ src/ui/components/vram_monitor.py | 151 +++++++ src/ui/pages/__init__.py | 9 + src/ui/pages/projects_page.py | 374 +++++++++++++++++ src/ui/pages/settings_page.py | 397 ++++++++++++++++++ src/ui/state.py | 294 +++++++++++++ src/ui/tabs/__init__.py | 17 + src/ui/tabs/audiogen_tab.py | 283 +++++++++++++ src/ui/tabs/dashboard_tab.py | 166 ++++++++ src/ui/tabs/jasco_tab.py | 364 ++++++++++++++++ src/ui/tabs/magnet_tab.py | 316 ++++++++++++++ src/ui/tabs/musicgen_tab.py | 325 +++++++++++++++ src/ui/tabs/style_tab.py | 292 +++++++++++++ src/ui/theme.py | 303 ++++++++++++++ 67 files changed, 12032 insertions(+) create mode 100644 .env.example create mode 100644 .gitignore create mode 100644 Dockerfile create 
mode 100644 README.md create mode 100644 config/__init__.py create mode 100644 config/models.yaml create mode 100644 config/settings.py create mode 100644 docker-compose.yml create mode 100644 main.py create mode 100644 pyproject.toml create mode 100644 requirements.txt create mode 100644 runpod.yaml create mode 100755 scripts/download_models.py create mode 100755 scripts/start.sh create mode 100644 src/__init__.py create mode 100644 src/api/__init__.py create mode 100644 src/api/app.py create mode 100644 src/api/auth.py create mode 100644 src/api/models.py create mode 100644 src/api/routes/__init__.py create mode 100644 src/api/routes/generation.py create mode 100644 src/api/routes/models.py create mode 100644 src/api/routes/projects.py create mode 100644 src/api/routes/system.py create mode 100644 src/core/__init__.py create mode 100644 src/core/audio_utils.py create mode 100644 src/core/base_model.py create mode 100644 src/core/gpu_manager.py create mode 100644 src/core/model_registry.py create mode 100644 src/core/oom_handler.py create mode 100644 src/main.py create mode 100644 src/models/__init__.py create mode 100644 src/models/audiogen/__init__.py create mode 100644 src/models/audiogen/adapter.py create mode 100644 src/models/jasco/__init__.py create mode 100644 src/models/jasco/adapter.py create mode 100644 src/models/magnet/__init__.py create mode 100644 src/models/magnet/adapter.py create mode 100644 src/models/musicgen/__init__.py create mode 100644 src/models/musicgen/adapter.py create mode 100644 src/models/musicgen_style/__init__.py create mode 100644 src/models/musicgen_style/adapter.py create mode 100644 src/services/__init__.py create mode 100644 src/services/batch_processor.py create mode 100644 src/services/generation_service.py create mode 100644 src/services/project_service.py create mode 100644 src/storage/__init__.py create mode 100644 src/storage/database.py create mode 100644 src/ui/__init__.py create mode 100644 src/ui/app.py create mode 100644 src/ui/components/__init__.py create mode 100644 src/ui/components/audio_player.py create mode 100644 src/ui/components/generation_params.py create mode 100644 src/ui/components/preset_selector.py create mode 100644 src/ui/components/vram_monitor.py create mode 100644 src/ui/pages/__init__.py create mode 100644 src/ui/pages/projects_page.py create mode 100644 src/ui/pages/settings_page.py create mode 100644 src/ui/state.py create mode 100644 src/ui/tabs/__init__.py create mode 100644 src/ui/tabs/audiogen_tab.py create mode 100644 src/ui/tabs/dashboard_tab.py create mode 100644 src/ui/tabs/jasco_tab.py create mode 100644 src/ui/tabs/magnet_tab.py create mode 100644 src/ui/tabs/musicgen_tab.py create mode 100644 src/ui/tabs/style_tab.py create mode 100644 src/ui/theme.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..4768144 --- /dev/null +++ b/.env.example @@ -0,0 +1,42 @@ +# AudioCraft Studio Configuration +# Copy this file to .env and customize as needed + +# Server Configuration +AUDIOCRAFT_HOST=0.0.0.0 +AUDIOCRAFT_GRADIO_PORT=7860 +AUDIOCRAFT_API_PORT=8000 + +# Paths (relative to project root) +AUDIOCRAFT_DATA_DIR=./data +AUDIOCRAFT_OUTPUT_DIR=./outputs +AUDIOCRAFT_CACHE_DIR=./cache + +# VRAM Management +# Reserve this much VRAM for ComfyUI (GB) +AUDIOCRAFT_COMFYUI_RESERVE_GB=10 +# Safety buffer to prevent OOM (GB) +AUDIOCRAFT_SAFETY_BUFFER_GB=1 +# Unload idle models after this many minutes +AUDIOCRAFT_IDLE_UNLOAD_MINUTES=15 +# Maximum number of models to keep loaded 
+AUDIOCRAFT_MAX_CACHED_MODELS=2 + +# API Authentication +# Generate a secure random key for production +AUDIOCRAFT_API_KEY=your-secret-api-key-here + +# Generation Defaults +AUDIOCRAFT_DEFAULT_DURATION=10.0 +AUDIOCRAFT_MAX_DURATION=300.0 +AUDIOCRAFT_DEFAULT_BATCH_SIZE=1 +AUDIOCRAFT_MAX_BATCH_SIZE=8 +AUDIOCRAFT_MAX_QUEUE_SIZE=100 + +# Database +AUDIOCRAFT_DATABASE_URL=sqlite+aiosqlite:///./data/audiocraft.db + +# Logging +AUDIOCRAFT_LOG_LEVEL=INFO + +# PyTorch Optimization (recommended) +PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f64cfb8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,76 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Virtual environments +.venv/ +venv/ +ENV/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# Testing +.pytest_cache/ +.coverage +htmlcov/ +.tox/ +.nox/ + +# Type checking +.mypy_cache/ + +# Project specific +data/ +outputs/ +cache/ +*.db +*.sqlite +*.sqlite3 + +# Logs +*.log +logs/ + +# Environment +.env +.env.local +.env.*.local + +# Model weights (downloaded from HuggingFace) +*.bin +*.safetensors +*.pt +*.pth + +# Audio files (generated) +*.wav +*.mp3 +*.flac +*.ogg + +# Temp files +/tmp/ +*.tmp diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..c126c0a --- /dev/null +++ b/Dockerfile @@ -0,0 +1,83 @@ +# AudioCraft Studio Dockerfile for RunPod +# Optimized for NVIDIA RTX 4090 (24GB VRAM) + +FROM nvidia/cuda:12.1-cudnn8-runtime-ubuntu22.04 + +# Set environment variables +ENV DEBIAN_FRONTEND=noninteractive +ENV PYTHONUNBUFFERED=1 +ENV PYTHONDONTWRITEBYTECODE=1 +ENV PIP_NO_CACHE_DIR=1 +ENV PIP_DISABLE_PIP_VERSION_CHECK=1 + +# CUDA settings +ENV CUDA_HOME=/usr/local/cuda +ENV PATH="${CUDA_HOME}/bin:${PATH}" +ENV LD_LIBRARY_PATH="${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}" + +# AudioCraft settings +ENV AUDIOCRAFT_OUTPUT_DIR=/workspace/outputs +ENV AUDIOCRAFT_DATA_DIR=/workspace/data +ENV AUDIOCRAFT_MODEL_CACHE=/workspace/models +ENV AUDIOCRAFT_HOST=0.0.0.0 +ENV AUDIOCRAFT_GRADIO_PORT=7860 +ENV AUDIOCRAFT_API_PORT=8000 + +# Install system dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + git \ + curl \ + wget \ + ffmpeg \ + libsndfile1 \ + libsox-dev \ + sox \ + build-essential \ + python3.10 \ + python3.10-venv \ + python3.10-dev \ + python3-pip \ + && rm -rf /var/lib/apt/lists/* + +# Set Python 3.10 as default +RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 \ + && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 + +# Upgrade pip +RUN pip install --upgrade pip setuptools wheel + +# Create workspace directory +WORKDIR /workspace + +# Create necessary directories +RUN mkdir -p /workspace/outputs /workspace/data /workspace/models /workspace/app + +# Copy requirements first for caching +COPY requirements.txt /workspace/app/ +WORKDIR /workspace/app + +# Install PyTorch with CUDA support +RUN pip install torch==2.1.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121 + +# Install other requirements +RUN pip install -r requirements.txt + +# Install AudioCraft from source for latest features +RUN pip install git+https://github.com/facebookresearch/audiocraft.git + +# Copy application code +COPY . 
/workspace/app/
+
+# Create non-root user for security (optional, RunPod often uses root)
+# RUN useradd -m -u 1000 audiocraft && chown -R audiocraft:audiocraft /workspace
+# USER audiocraft
+
+# Expose ports
+EXPOSE 7860 8000
+
+# Health check
+HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
+    CMD curl -f http://localhost:7860/ || exit 1
+
+# Default command
+CMD ["python", "main.py"]
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..3e2c99d
--- /dev/null
+++ b/README.md
@@ -0,0 +1,197 @@
+# AudioCraft Studio
+
+A comprehensive web interface for Meta's AudioCraft AI audio generation models, optimized for RunPod deployment on NVIDIA RTX 4090 GPUs.
+
+## Features
+
+### Supported Models
+- **MusicGen** - Text-to-music generation with optional melody conditioning
+- **AudioGen** - Text-to-sound-effects and environmental audio
+- **MAGNeT** - Fast non-autoregressive music generation
+- **MusicGen Style** - Style-conditioned music from reference audio
+- **JASCO** - Chord- and drum-conditioned music generation
+
+### Core Capabilities
+- **Gradio Web UI** - Intuitive interface with real-time generation
+- **REST API** - Full-featured API with OpenAPI documentation
+- **Batch Processing** - Queue system for multiple generations
+- **Project Management** - Organize and browse generation history
+- **VRAM Management** - Smart model loading/unloading and ComfyUI coexistence
+- **Waveform Visualization** - Visual audio feedback
+
+## Quick Start
+
+### Local Development
+
+```bash
+# Clone repository
+git clone https://github.com/your-username/audiocraft-ui.git
+cd audiocraft-ui
+
+# Create virtual environment
+python -m venv venv
+source venv/bin/activate  # Linux/Mac
+# or: venv\Scripts\activate  # Windows
+
+# Install dependencies
+pip install -r requirements.txt
+
+# Run application
+python main.py
+```
+
+Access the UI at `http://localhost:7860`.
+
+### Docker
+
+```bash
+# Build and run
+docker-compose up --build
+
+# Or build manually
+docker build -t audiocraft-studio .
+docker run --gpus all -p 7860:7860 -p 8000:8000 audiocraft-studio
+```
+
+### RunPod Deployment
+
+1. Build and push the Docker image:
+```bash
+docker build -t your-dockerhub/audiocraft-studio:latest .
+docker push your-dockerhub/audiocraft-studio:latest
+```
+
+2. Create a RunPod template using `runpod.yaml` as a reference
+
+3. Deploy on an RTX 4090 or equivalent GPU
+
+## Configuration
+
+Configuration is done via environment variables:
+
+| Variable | Default | Description |
+|----------|---------|-------------|
+| `AUDIOCRAFT_HOST` | `0.0.0.0` | Server bind address |
+| `AUDIOCRAFT_GRADIO_PORT` | `7860` | Gradio UI port |
+| `AUDIOCRAFT_API_PORT` | `8000` | REST API port |
+| `AUDIOCRAFT_OUTPUT_DIR` | `./outputs` | Generated audio output |
+| `AUDIOCRAFT_DATA_DIR` | `./data` | Database and config |
+| `AUDIOCRAFT_COMFYUI_RESERVE_GB` | `10` | VRAM reserved for ComfyUI (GB) |
+| `AUDIOCRAFT_MAX_CACHED_MODELS` | `2` | Max models kept in memory |
+| `AUDIOCRAFT_IDLE_UNLOAD_MINUTES` | `15` | Auto-unload idle models (minutes) |
+
+See `.env.example` for the full set of options.
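+For example, a pod sharing its GPU with ComfyUI might start from a minimal `.env` like this (values are illustrative, not defaults; all variables appear in `.env.example`):
+
+```bash
+# Leave ~10 GB of VRAM to ComfyUI and keep at most one model resident
+AUDIOCRAFT_COMFYUI_RESERVE_GB=10
+AUDIOCRAFT_SAFETY_BUFFER_GB=1
+AUDIOCRAFT_MAX_CACHED_MODELS=1
+AUDIOCRAFT_IDLE_UNLOAD_MINUTES=5
+
+# Recommended PyTorch allocator setting (see .env.example)
+PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+```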
+
+## API Usage
+
+### Authentication
+
+```bash
+# Get API key from Settings page or generate via CLI
+curl -X POST http://localhost:8000/api/v1/system/api-key/regenerate \
+  -H "X-API-Key: YOUR_CURRENT_KEY"
+```
+
+### Generate Audio
+
+```bash
+# Synchronous generation
+curl -X POST http://localhost:8000/api/v1/generate \
+  -H "X-API-Key: YOUR_API_KEY" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "musicgen",
+    "variant": "medium",
+    "prompts": ["upbeat electronic dance music with synth leads"],
+    "duration": 10
+  }'
+
+# Async (queue) generation
+curl -X POST http://localhost:8000/api/v1/generate/async \
+  -H "X-API-Key: YOUR_API_KEY" \
+  -H "Content-Type: application/json" \
+  -d '{
+    "request": {
+      "model": "musicgen",
+      "prompts": ["ambient soundscape"],
+      "duration": 30
+    },
+    "priority": 5
+  }'
+```
+
+### Check Job Status
+
+```bash
+curl http://localhost:8000/api/v1/generate/jobs/{job_id} \
+  -H "X-API-Key: YOUR_API_KEY"
+```
+
+Full API documentation is available at `http://localhost:8000/api/docs`.
+
+## Architecture
+
+```
+audiocraft-ui/
+├── config/
+│   ├── settings.py            # Pydantic settings
+│   └── models.yaml            # Model registry
+├── src/
+│   ├── core/
+│   │   ├── base_model.py      # Abstract model interface
+│   │   ├── gpu_manager.py     # VRAM management
+│   │   ├── model_registry.py  # Model loading/caching
+│   │   └── oom_handler.py     # OOM recovery
+│   ├── models/
+│   │   ├── musicgen/          # MusicGen adapter
+│   │   ├── audiogen/          # AudioGen adapter
+│   │   ├── magnet/            # MAGNeT adapter
+│   │   ├── musicgen_style/    # Style adapter
+│   │   └── jasco/             # JASCO adapter
+│   ├── services/
+│   │   ├── generation_service.py
+│   │   ├── batch_processor.py
+│   │   └── project_service.py
+│   ├── storage/
+│   │   └── database.py        # SQLite storage
+│   ├── api/
+│   │   ├── app.py             # FastAPI app
+│   │   └── routes/            # API endpoints
+│   └── ui/
+│       ├── app.py             # Gradio app
+│       ├── components/        # Reusable UI components
+│       ├── tabs/              # Model generation tabs
+│       └── pages/             # Projects, Settings
+├── main.py                    # Entry point
+├── Dockerfile
+└── docker-compose.yml
+```
+
+## ComfyUI Coexistence
+
+AudioCraft Studio is designed to run alongside ComfyUI on the same GPU:
+
+1. Set `AUDIOCRAFT_COMFYUI_RESERVE_GB` to reserve VRAM for ComfyUI
+2. Models are automatically unloaded when idle
+3. A coordination file at `/tmp/audiocraft_comfyui_coord.json` prevents conflicts
+
+## Development
+
+```bash
+# Install with dev dependencies (defined in pyproject.toml)
+pip install -e ".[dev]"
+
+# Run tests
+pytest
+
+# Format and lint
+ruff format src/ config/
+ruff check src/ config/
+
+# Type checking
+mypy src/
+```
+
+## License
+
+The application code is MIT-licensed. It depends on Meta's AudioCraft library; see the [AudioCraft License](https://github.com/facebookresearch/audiocraft/blob/main/LICENSE).
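The coexistence scheme above hinges on a VRAM budget. As a rough sketch, such a budget could be computed with `pynvml` (already in the dependency list); the helper name and exact accounting here are illustrative, not the implementation in `src/core/gpu_manager.py`:

```python
# Illustrative only: a VRAM budget respecting AUDIOCRAFT_COMFYUI_RESERVE_GB.
import pynvml


def available_budget_bytes(comfyui_reserve_gb: float = 10.0, safety_buffer_gb: float = 1.0) -> int:
    """Free VRAM on device 0, minus the ComfyUI reserve and safety buffer."""
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
    reserved = int((comfyui_reserve_gb + safety_buffer_gb) * 1024**3)
    return max(0, mem.free - reserved)


# A variant from config/models.yaml fits if its vram_mb fits the budget:
# fits = variant["vram_mb"] * 1024**2 <= available_budget_bytes()
```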
diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000..f67e698 --- /dev/null +++ b/config/__init__.py @@ -0,0 +1,5 @@ +"""Configuration module for AudioCraft Studio.""" + +from config.settings import Settings, get_settings + +__all__ = ["Settings", "get_settings"] diff --git a/config/models.yaml b/config/models.yaml new file mode 100644 index 0000000..43cc44e --- /dev/null +++ b/config/models.yaml @@ -0,0 +1,151 @@ +# AudioCraft Model Registry Configuration +# This file defines all available models and their configurations + +models: + musicgen: + enabled: true + display_name: "MusicGen" + description: "Text-to-music generation with optional melody conditioning" + default_variant: medium + variants: + small: + hf_id: facebook/musicgen-small + vram_mb: 1500 + max_duration: 30 + description: "Fast, lightweight model (300M params)" + medium: + hf_id: facebook/musicgen-medium + vram_mb: 5000 + max_duration: 30 + description: "Balanced quality and speed (1.5B params)" + large: + hf_id: facebook/musicgen-large + vram_mb: 10000 + max_duration: 30 + description: "Highest quality, slower (3.3B params)" + melody: + hf_id: facebook/musicgen-melody + vram_mb: 5000 + max_duration: 30 + conditioning: + - melody + description: "Melody-conditioned generation (1.5B params)" + stereo-small: + hf_id: facebook/musicgen-stereo-small + vram_mb: 1800 + max_duration: 30 + channels: 2 + description: "Stereo output, fast (300M params)" + stereo-medium: + hf_id: facebook/musicgen-stereo-medium + vram_mb: 6000 + max_duration: 30 + channels: 2 + description: "Stereo output, balanced (1.5B params)" + stereo-large: + hf_id: facebook/musicgen-stereo-large + vram_mb: 12000 + max_duration: 30 + channels: 2 + description: "Stereo output, highest quality (3.3B params)" + stereo-melody: + hf_id: facebook/musicgen-stereo-melody + vram_mb: 6000 + max_duration: 30 + channels: 2 + conditioning: + - melody + description: "Stereo melody-conditioned (1.5B params)" + + audiogen: + enabled: true + display_name: "AudioGen" + description: "Text-to-sound effects generation" + default_variant: medium + variants: + medium: + hf_id: facebook/audiogen-medium + vram_mb: 5000 + max_duration: 10 + description: "Sound effects generator (1.5B params)" + + magnet: + enabled: true + display_name: "MAGNeT" + description: "Fast non-autoregressive music generation" + default_variant: medium-10secs + variants: + small-10secs: + hf_id: facebook/magnet-small-10secs + vram_mb: 1500 + max_duration: 10 + description: "Fast 10-second clips (300M params)" + medium-10secs: + hf_id: facebook/magnet-medium-10secs + vram_mb: 5000 + max_duration: 10 + description: "Quality 10-second clips (1.5B params)" + small-30secs: + hf_id: facebook/magnet-small-30secs + vram_mb: 1800 + max_duration: 30 + description: "Fast 30-second clips (300M params)" + medium-30secs: + hf_id: facebook/magnet-medium-30secs + vram_mb: 6000 + max_duration: 30 + description: "Quality 30-second clips (1.5B params)" + + musicgen-style: + enabled: true + display_name: "MusicGen Style" + description: "Style-conditioned music generation from reference audio" + default_variant: medium + variants: + medium: + hf_id: facebook/musicgen-style + vram_mb: 5000 + max_duration: 30 + conditioning: + - style + description: "Style transfer from reference audio (1.5B params)" + + jasco: + enabled: true + display_name: "JASCO" + description: "Chord and drum-conditioned music generation" + default_variant: chords-drums-400M + variants: + chords-drums-400M: + hf_id: 
facebook/jasco-chords-drums-400M + vram_mb: 2000 + max_duration: 10 + conditioning: + - chords + - drums + description: "Chord/drum control, fast (400M params)" + chords-drums-1B: + hf_id: facebook/jasco-chords-drums-1B + vram_mb: 4000 + max_duration: 10 + conditioning: + - chords + - drums + description: "Chord/drum control, higher quality (1B params)" + +# Default generation parameters +defaults: + generation: + duration: 10 + temperature: 1.0 + top_k: 250 + top_p: 0.0 + cfg_coef: 3.0 + +# VRAM thresholds for warnings +vram: + warning_threshold: 0.85 # 85% utilization warning + critical_threshold: 0.95 # 95% utilization critical + +# Presets are loaded from data/presets/*.yaml +presets_dir: "./data/presets" diff --git a/config/settings.py b/config/settings.py new file mode 100644 index 0000000..debd2a4 --- /dev/null +++ b/config/settings.py @@ -0,0 +1,94 @@ +"""Application settings with environment variable support.""" + +from functools import lru_cache +from pathlib import Path +from typing import Optional + +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + """Application configuration with environment variable support. + + All settings can be overridden via environment variables prefixed with AUDIOCRAFT_. + Example: AUDIOCRAFT_API_PORT=8080 + """ + + model_config = SettingsConfigDict( + env_prefix="AUDIOCRAFT_", + env_file=".env", + env_file_encoding="utf-8", + extra="ignore", + ) + + # Server Configuration + host: str = Field(default="0.0.0.0", description="Server bind host") + gradio_port: int = Field(default=7860, description="Gradio UI port") + api_port: int = Field(default=8000, description="FastAPI port") + + # Paths + data_dir: Path = Field(default=Path("./data"), description="Data directory") + output_dir: Path = Field(default=Path("./outputs"), description="Generated audio output") + cache_dir: Path = Field(default=Path("./cache"), description="Model cache directory") + models_config: Path = Field( + default=Path("./config/models.yaml"), description="Model registry config" + ) + + # VRAM Management + comfyui_reserve_gb: float = Field( + default=10.0, description="VRAM reserved for ComfyUI (GB)" + ) + safety_buffer_gb: float = Field( + default=1.0, description="Safety buffer to prevent OOM (GB)" + ) + idle_unload_minutes: int = Field( + default=15, description="Unload models after idle time (minutes)" + ) + max_cached_models: int = Field( + default=2, description="Maximum number of models to keep loaded" + ) + + # API Authentication + api_key: Optional[str] = Field(default=None, description="API key for authentication") + cors_origins: list[str] = Field( + default=["*"], description="Allowed CORS origins" + ) + + # Generation Defaults + default_duration: float = Field(default=10.0, description="Default generation duration") + max_duration: float = Field(default=300.0, description="Maximum generation duration") + default_batch_size: int = Field(default=1, description="Default batch size") + max_batch_size: int = Field(default=8, description="Maximum batch size") + max_queue_size: int = Field(default=100, description="Maximum generation queue size") + + # Database + database_url: str = Field( + default="sqlite+aiosqlite:///./data/audiocraft.db", + description="Database connection URL", + ) + + # Logging + log_level: str = Field(default="INFO", description="Logging level") + + def ensure_directories(self) -> None: + """Create required directories if they don't exist.""" + 
self.data_dir.mkdir(parents=True, exist_ok=True) + self.output_dir.mkdir(parents=True, exist_ok=True) + self.cache_dir.mkdir(parents=True, exist_ok=True) + (self.data_dir / "presets").mkdir(parents=True, exist_ok=True) + + @property + def database_path(self) -> Path: + """Extract database file path from URL.""" + if self.database_url.startswith("sqlite"): + # Handle both sqlite:/// and sqlite+aiosqlite:/// + path = self.database_url.split("///")[-1] + return Path(path) + raise ValueError("Only SQLite databases are supported") + + +@lru_cache +def get_settings() -> Settings: + """Get cached settings instance.""" + return Settings() diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..60a3a27 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,64 @@ +# Docker Compose for local development and testing +# For RunPod deployment, use the Dockerfile directly + +version: '3.8' + +services: + audiocraft: + build: + context: . + dockerfile: Dockerfile + container_name: audiocraft-studio + ports: + - "7860:7860" # Gradio UI + - "8000:8000" # REST API + volumes: + # Persistent storage + - audiocraft-outputs:/workspace/outputs + - audiocraft-data:/workspace/data + - audiocraft-models:/workspace/models + # Development: mount source code + - ./src:/workspace/app/src:ro + - ./config:/workspace/app/config:ro + environment: + - AUDIOCRAFT_HOST=0.0.0.0 + - AUDIOCRAFT_GRADIO_PORT=7860 + - AUDIOCRAFT_API_PORT=8000 + - AUDIOCRAFT_DEBUG=false + - AUDIOCRAFT_COMFYUI_RESERVE_GB=0 # No ComfyUI in this compose + - NVIDIA_VISIBLE_DEVICES=all + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:7860/"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 60s + + # Optional: Run alongside ComfyUI + # comfyui: + # image: your-comfyui-image + # container_name: comfyui + # ports: + # - "8188:8188" + # volumes: + # - comfyui-data:/workspace + # deploy: + # resources: + # reservations: + # devices: + # - driver: nvidia + # count: 1 + # capabilities: [gpu] + +volumes: + audiocraft-outputs: + audiocraft-data: + audiocraft-models: diff --git a/main.py b/main.py new file mode 100644 index 0000000..9bd6a4f --- /dev/null +++ b/main.py @@ -0,0 +1,147 @@ +#!/usr/bin/env python3 +"""Main entry point for AudioCraft Studio.""" + +import asyncio +import logging +import sys +from pathlib import Path + +# Add project root to path +sys.path.insert(0, str(Path(__file__).parent)) + +from config.settings import get_settings +from src.core.gpu_manager import GPUMemoryManager +from src.core.model_registry import ModelRegistry +from src.services.generation_service import GenerationService +from src.services.batch_processor import BatchProcessor +from src.services.project_service import ProjectService +from src.storage.database import Database +from src.ui.app import create_app + + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[ + logging.StreamHandler(), + logging.FileHandler("audiocraft.log"), + ], +) +logger = logging.getLogger(__name__) + + +async def initialize_services(): + """Initialize all application services.""" + settings = get_settings() + + # Initialize database + logger.info("Initializing database...") + db = Database(settings.database_path) + await db.initialize() + + # Initialize GPU manager + logger.info("Initializing GPU manager...") + gpu_manager = 
GPUMemoryManager(
+        device_id=0,
+        comfyui_reserve_bytes=int(settings.comfyui_reserve_gb * 1024**3),
+    )
+
+    # Initialize model registry
+    logger.info("Initializing model registry...")
+    model_registry = ModelRegistry(
+        gpu_manager=gpu_manager,
+        max_loaded=settings.max_cached_models,
+        idle_timeout_seconds=settings.idle_unload_minutes * 60,
+    )
+
+    # Initialize services
+    logger.info("Initializing services...")
+    generation_service = GenerationService(
+        model_registry=model_registry,
+        gpu_manager=gpu_manager,
+        output_dir=settings.output_dir,
+    )
+
+    batch_processor = BatchProcessor(
+        generation_service=generation_service,
+        max_queue_size=settings.max_queue_size,
+    )
+
+    project_service = ProjectService(
+        db=db,
+        output_dir=settings.output_dir,
+    )
+
+    return {
+        "db": db,
+        "gpu_manager": gpu_manager,
+        "model_registry": model_registry,
+        "generation_service": generation_service,
+        "batch_processor": batch_processor,
+        "project_service": project_service,
+    }
+
+
+def main():
+    """Main entry point."""
+    settings = get_settings()
+
+    logger.info("=" * 60)
+    logger.info("AudioCraft Studio")
+    logger.info("=" * 60)
+    logger.info(f"Host: {settings.host}")
+    logger.info(f"Gradio Port: {settings.gradio_port}")
+    logger.info(f"API Port: {settings.api_port}")
+    logger.info(f"Output Dir: {settings.output_dir}")
+    logger.info("=" * 60)
+
+    # Initialize services
+    logger.info("Initializing services...")
+
+    try:
+        services = asyncio.run(initialize_services())
+    except Exception as e:
+        logger.error(f"Failed to initialize services: {e}")
+        logger.warning("Starting in demo mode without backend services")
+        services = {}
+
+    # Create and launch app
+    logger.info("Creating Gradio application...")
+    app = create_app(
+        generation_service=services.get("generation_service"),
+        batch_processor=services.get("batch_processor"),
+        project_service=services.get("project_service"),
+        gpu_manager=services.get("gpu_manager"),
+        model_registry=services.get("model_registry"),
+    )
+
+    # Start batch processor if available
+    batch_processor = services.get("batch_processor")
+    if batch_processor:
+        logger.info("Starting batch processor...")
+        asyncio.run(batch_processor.start())
+
+    # Launch the app
+    logger.info("Launching application...")
+    try:
+        app.launch(
+            server_name=settings.host,
+            server_port=settings.gradio_port,
+            share=False,
+        )
+    except KeyboardInterrupt:
+        logger.info("Shutting down...")
+    finally:
+        # Cleanup
+        if batch_processor:
+            asyncio.run(batch_processor.stop())
+        if "db" in services:
+            asyncio.run(services["db"].close())
+
+    logger.info("Shutdown complete")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..bf5e527
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,89 @@
+[project]
+name = "audiocraft-ui"
+version = "0.1.0"
+description = "Sophisticated AI audio web application based on Meta's AudioCraft"
+readme = "README.md"
+license = { text = "MIT" }
+requires-python = ">=3.10"
+authors = [{ name = "AudioCraft UI Team" }]
+keywords = ["audio", "music", "generation", "ai", "audiocraft", "gradio"]
+classifiers = [
+    "Development Status :: 3 - Alpha",
+    "Intended Audience :: Developers",
+    "License :: OSI Approved :: MIT License",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Topic :: Multimedia :: Sound/Audio",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+]
+
+dependencies = [
+    # 
Core ML + "torch>=2.1.0", + "torchaudio>=2.1.0", + "audiocraft>=1.3.0", + "xformers>=0.0.22", + + # UI + "gradio>=4.0.0", + + # API + "fastapi>=0.104.0", + "uvicorn[standard]>=0.24.0", + "python-multipart>=0.0.6", + + # GPU Monitoring + "pynvml>=11.5.0", + + # Storage + "aiosqlite>=0.19.0", + + # Configuration + "pydantic>=2.5.0", + "pydantic-settings>=2.1.0", + "pyyaml>=6.0", + + # Audio Processing + "numpy>=1.24.0", + "scipy>=1.11.0", + "librosa>=0.10.0", + "soundfile>=0.12.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.4.0", + "pytest-asyncio>=0.21.0", + "pytest-cov>=4.1.0", + "ruff>=0.1.0", + "mypy>=1.6.0", +] + +[project.scripts] +audiocraft-ui = "src.main:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src"] + +[tool.ruff] +line-length = 100 +target-version = "py310" + +[tool.ruff.lint] +select = ["E", "F", "I", "N", "W", "UP"] +ignore = ["E501"] + +[tool.mypy] +python_version = "3.10" +warn_return_any = true +warn_unused_configs = true +ignore_missing_imports = true + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..bed6304 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,30 @@ +# Core ML +torch>=2.1.0 +torchaudio>=2.1.0 +audiocraft>=1.3.0 +xformers>=0.0.22 + +# UI +gradio>=4.0.0 + +# API +fastapi>=0.104.0 +uvicorn[standard]>=0.24.0 +python-multipart>=0.0.6 + +# GPU Monitoring +pynvml>=11.5.0 + +# Storage +aiosqlite>=0.19.0 + +# Configuration +pydantic>=2.5.0 +pydantic-settings>=2.1.0 +pyyaml>=6.0 + +# Audio Processing +numpy>=1.24.0 +scipy>=1.11.0 +librosa>=0.10.0 +soundfile>=0.12.0 diff --git a/runpod.yaml b/runpod.yaml new file mode 100644 index 0000000..7e79960 --- /dev/null +++ b/runpod.yaml @@ -0,0 +1,77 @@ +# RunPod Template Configuration +# Use this as reference when creating a RunPod template + +name: AudioCraft Studio +description: AI-powered music and sound generation using Meta's AudioCraft + +# Container settings +container: + image: your-dockerhub-username/audiocraft-studio:latest + + # Or build from GitHub + # dockerfile: Dockerfile + # context: https://github.com/your-username/audiocraft-ui.git + +# GPU requirements +gpu: + type: RTX 4090 # Recommended: RTX 4090, RTX 3090, A100 + count: 1 + minVram: 24 # GB + +# Resource limits +resources: + cpu: 8 + memory: 32 # GB + disk: 100 # GB (for model cache and outputs) + +# Port mappings +ports: + - name: Gradio UI + internal: 7860 + external: 7860 + protocol: http + - name: REST API + internal: 8000 + external: 8000 + protocol: http + +# Volume mounts +volumes: + - name: outputs + mountPath: /workspace/outputs + size: 50 # GB + - name: models + mountPath: /workspace/models + size: 30 # GB (model cache) + - name: data + mountPath: /workspace/data + size: 10 # GB + +# Environment variables +env: + - name: AUDIOCRAFT_HOST + value: "0.0.0.0" + - name: AUDIOCRAFT_GRADIO_PORT + value: "7860" + - name: AUDIOCRAFT_API_PORT + value: "8000" + - name: AUDIOCRAFT_COMFYUI_RESERVE_GB + value: "10" # Reserve VRAM for ComfyUI if running alongside + - name: AUDIOCRAFT_MAX_LOADED_MODELS + value: "2" + - name: AUDIOCRAFT_IDLE_UNLOAD_MINUTES + value: "15" + - name: HF_HOME + value: "/workspace/models/huggingface" + +# Startup command +command: ["python", "main.py"] + +# Health check +healthCheck: + path: / + port: 7860 + initialDelaySeconds: 120 + periodSeconds: 30 + timeoutSeconds: 10 + failureThreshold: 3 diff --git 
a/scripts/download_models.py b/scripts/download_models.py
new file mode 100755
index 0000000..97c5c86
--- /dev/null
+++ b/scripts/download_models.py
@@ -0,0 +1,116 @@
+#!/usr/bin/env python3
+"""Pre-download AudioCraft models for faster startup."""
+
+import argparse
+import os
+from pathlib import Path
+
+
+def download_musicgen_models(variants: list[str] | None = None):
+    """Download MusicGen models."""
+    from audiocraft.models import MusicGen
+
+    variants = variants or ["small", "medium", "large", "melody"]
+
+    for variant in variants:
+        print(f"Downloading MusicGen {variant}...")
+        try:
+            model = MusicGen.get_pretrained(f"facebook/musicgen-{variant}")
+            del model
+            print(f"  ✓ MusicGen {variant} downloaded")
+        except Exception as e:
+            print(f"  ✗ Failed to download MusicGen {variant}: {e}")
+
+
+def download_audiogen_models():
+    """Download AudioGen models."""
+    from audiocraft.models import AudioGen
+
+    print("Downloading AudioGen medium...")
+    try:
+        model = AudioGen.get_pretrained("facebook/audiogen-medium")
+        del model
+        print("  ✓ AudioGen medium downloaded")
+    except Exception as e:
+        print(f"  ✗ Failed to download AudioGen: {e}")
+
+
+def download_magnet_models(variants: list[str] | None = None):
+    """Download MAGNeT models.
+
+    Music variants follow facebook/magnet-<size>-<len> (e.g. small-10secs);
+    sound-effect variants follow facebook/audio-magnet-<size>.
+    """
+    from audiocraft.models import MAGNeT
+
+    variants = variants or ["small-10secs", "medium-10secs", "audio-small", "audio-medium"]
+
+    for variant in variants:
+        if variant.startswith("audio-"):
+            hf_id = f"facebook/audio-magnet-{variant.removeprefix('audio-')}"
+        else:
+            hf_id = f"facebook/magnet-{variant}"
+        print(f"Downloading MAGNeT {variant}...")
+        try:
+            model = MAGNeT.get_pretrained(hf_id)
+            del model
+            print(f"  ✓ MAGNeT {variant} downloaded")
+        except Exception as e:
+            print(f"  ✗ Failed to download MAGNeT {variant}: {e}")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Pre-download AudioCraft models")
+    parser.add_argument(
+        "--models",
+        nargs="+",
+        choices=["musicgen", "audiogen", "magnet", "all"],
+        default=["all"],
+        help="Models to download",
+    )
+    parser.add_argument(
+        "--musicgen-variants",
+        nargs="+",
+        default=["small", "medium"],
+        help="MusicGen variants to download",
+    )
+    parser.add_argument(
+        "--magnet-variants",
+        nargs="+",
+        default=["small-10secs", "medium-10secs"],
+        help="MAGNeT variants to download",
+    )
+    parser.add_argument(
+        "--cache-dir",
+        type=str,
+        default=None,
+        help="Model cache directory",
+    )
+
+    args = parser.parse_args()
+
+    # Set cache directory
+    if args.cache_dir:
+        os.environ["HF_HOME"] = args.cache_dir
+        os.environ["TORCH_HOME"] = args.cache_dir
+        Path(args.cache_dir).mkdir(parents=True, exist_ok=True)
+
+    models = args.models
+    if "all" in models:
+        models = ["musicgen", "audiogen", "magnet"]
+
+    print("=" * 50)
+    print("AudioCraft Model Downloader")
+    print("=" * 50)
+    print(f"Cache directory: {os.environ.get('HF_HOME', 'default')}")
+    print(f"Models to download: {models}")
+    print("=" * 50)
+
+    if "musicgen" in models:
+        download_musicgen_models(args.musicgen_variants)
+
+    if "audiogen" in models:
+        download_audiogen_models()
+
+    if "magnet" in models:
+        download_magnet_models(args.magnet_variants)
+
+    print("=" * 50)
+    print("Download complete!")
+    print("=" * 50)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/start.sh b/scripts/start.sh
new file mode 100755
index 0000000..2272947
--- /dev/null
+++ b/scripts/start.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+# Startup script for AudioCraft Studio
+# Used in the Docker container and on RunPod
+
+set -e
+
+echo "=========================================="
+echo " AudioCraft Studio"
+echo "=========================================="
+
+# Create directories if they don't exist
+mkdir -p 
"${AUDIOCRAFT_OUTPUT_DIR:-/workspace/outputs}" +mkdir -p "${AUDIOCRAFT_DATA_DIR:-/workspace/data}" +mkdir -p "${AUDIOCRAFT_MODEL_CACHE:-/workspace/models}" + +# Check GPU availability +echo "Checking GPU..." +if command -v nvidia-smi &> /dev/null; then + nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv +else + echo "Warning: nvidia-smi not found" +fi + +# Check Python and dependencies +echo "Python version:" +python --version + +echo "PyTorch version:" +python -c "import torch; print(f'PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}')" + +# Check AudioCraft installation +echo "AudioCraft version:" +python -c "import audiocraft; print(audiocraft.__version__)" 2>/dev/null || echo "AudioCraft installed from source" + +# Generate API key if not exists +if [ ! -f "${AUDIOCRAFT_DATA_DIR:-/workspace/data}/.api_key" ]; then + echo "Generating API key..." + python -c " +from src.api.auth import get_key_manager +km = get_key_manager() +if not km.has_key(): + key = km.generate_new_key() + print(f'Generated API key: {key}') + print('Store this key securely - it will not be shown again!') +" +fi + +# Start the application +echo "Starting AudioCraft Studio..." +echo "Gradio UI: http://0.0.0.0:${AUDIOCRAFT_GRADIO_PORT:-7860}" +echo "REST API: http://0.0.0.0:${AUDIOCRAFT_API_PORT:-8000}" +echo "API Docs: http://0.0.0.0:${AUDIOCRAFT_API_PORT:-8000}/api/docs" +echo "==========================================" + +exec python main.py "$@" diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..6be9d4f --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,3 @@ +"""AudioCraft Studio - AI Audio Generation Web Application.""" + +__version__ = "0.1.0" diff --git a/src/api/__init__.py b/src/api/__init__.py new file mode 100644 index 0000000..0d33c8d --- /dev/null +++ b/src/api/__init__.py @@ -0,0 +1,5 @@ +"""REST API for AudioCraft Studio.""" + +from src.api.app import create_api_app + +__all__ = ["create_api_app"] diff --git a/src/api/app.py b/src/api/app.py new file mode 100644 index 0000000..0453528 --- /dev/null +++ b/src/api/app.py @@ -0,0 +1,150 @@ +"""FastAPI application for AudioCraft Studio REST API.""" + +from typing import Any, Optional +from contextlib import asynccontextmanager + +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse +import time + +from config.settings import get_settings +from src.api.routes import ( + generation_router, + projects_router, + models_router, + system_router, +) +from src.api.routes.generation import set_services as set_generation_services +from src.api.routes.projects import set_services as set_project_services +from src.api.routes.models import set_services as set_model_services +from src.api.routes.system import set_services as set_system_services + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Application lifespan handler.""" + # Startup + yield + # Shutdown + + +def create_api_app( + generation_service: Any = None, + batch_processor: Any = None, + project_service: Any = None, + gpu_manager: Any = None, + model_registry: Any = None, +) -> FastAPI: + """Create and configure the FastAPI application. 
+ + Args: + generation_service: Service for handling generations + batch_processor: Service for batch/queue processing + project_service: Service for project management + gpu_manager: GPU memory manager + model_registry: Model registry for loading/unloading + + Returns: + Configured FastAPI application + """ + settings = get_settings() + + app = FastAPI( + title="AudioCraft Studio API", + description="REST API for AI-powered music and sound generation", + version="1.0.0", + docs_url="/api/docs" if settings.api_enabled else None, + redoc_url="/api/redoc" if settings.api_enabled else None, + openapi_url="/api/openapi.json" if settings.api_enabled else None, + lifespan=lifespan, + ) + + # CORS middleware + app.add_middleware( + CORSMiddleware, + allow_origins=settings.cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # Request timing middleware + @app.middleware("http") + async def add_process_time_header(request: Request, call_next): + start_time = time.time() + response = await call_next(request) + process_time = time.time() - start_time + response.headers["X-Process-Time"] = str(process_time) + return response + + # Global exception handler + @app.exception_handler(Exception) + async def global_exception_handler(request: Request, exc: Exception): + return JSONResponse( + status_code=500, + content={ + "error": "Internal server error", + "detail": str(exc) if settings.debug else "An unexpected error occurred", + }, + ) + + # Inject service dependencies + set_generation_services(generation_service, batch_processor) + set_project_services(project_service) + set_model_services(model_registry) + set_system_services(gpu_manager, batch_processor, model_registry) + + # Register routers + app.include_router(generation_router, prefix="/api/v1") + app.include_router(projects_router, prefix="/api/v1") + app.include_router(models_router, prefix="/api/v1") + app.include_router(system_router, prefix="/api/v1") + + # Root endpoint + @app.get("/") + async def root(): + return { + "name": "AudioCraft Studio API", + "version": "1.0.0", + "docs": "/api/docs", + } + + # API info endpoint + @app.get("/api/v1") + async def api_info(): + return { + "version": "1.0.0", + "endpoints": { + "generation": "/api/v1/generate", + "projects": "/api/v1/projects", + "models": "/api/v1/models", + "system": "/api/v1/system", + }, + } + + return app + + +def run_api_server( + app: FastAPI, + host: Optional[str] = None, + port: Optional[int] = None, +) -> None: + """Run the API server. 
+ + Args: + app: FastAPI application + host: Server hostname + port: Server port + """ + import uvicorn + + settings = get_settings() + + uvicorn.run( + app, + host=host or settings.host, + port=port or settings.api_port, + log_level="info", + ) diff --git a/src/api/auth.py b/src/api/auth.py new file mode 100644 index 0000000..f870082 --- /dev/null +++ b/src/api/auth.py @@ -0,0 +1,133 @@ +"""API authentication middleware.""" + +import secrets +import hashlib +from typing import Optional +from pathlib import Path + +from fastapi import HTTPException, Security, status +from fastapi.security import APIKeyHeader + +from config.settings import get_settings + + +# API key header +api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) + + +def generate_api_key() -> str: + """Generate a new API key.""" + return secrets.token_urlsafe(32) + + +def hash_api_key(key: str) -> str: + """Hash an API key for storage.""" + return hashlib.sha256(key.encode()).hexdigest() + + +def verify_api_key(key: str, hashed: str) -> bool: + """Verify an API key against its hash.""" + return secrets.compare_digest(hash_api_key(key), hashed) + + +class APIKeyManager: + """Manage API keys for authentication.""" + + def __init__(self, key_file: Optional[Path] = None): + """Initialize the key manager. + + Args: + key_file: Path to store API key hash + """ + self.settings = get_settings() + self.key_file = key_file or Path(self.settings.data_dir) / ".api_key" + self._key_hash: Optional[str] = None + self._load_key() + + def _load_key(self) -> None: + """Load API key hash from file.""" + if self.key_file.exists(): + self._key_hash = self.key_file.read_text().strip() + + def _save_key(self, key_hash: str) -> None: + """Save API key hash to file.""" + self.key_file.parent.mkdir(parents=True, exist_ok=True) + self.key_file.write_text(key_hash) + self._key_hash = key_hash + + def generate_new_key(self) -> str: + """Generate and store a new API key. + + Returns: + The new API key (only shown once) + """ + key = generate_api_key() + self._save_key(hash_api_key(key)) + return key + + def verify(self, key: str) -> bool: + """Verify an API key. + + Args: + key: API key to verify + + Returns: + True if valid, False otherwise + """ + if not self._key_hash: + return False + return verify_api_key(key, self._key_hash) + + def has_key(self) -> bool: + """Check if an API key has been generated.""" + return self._key_hash is not None + + +# Global key manager instance +_key_manager: Optional[APIKeyManager] = None + + +def get_key_manager() -> APIKeyManager: + """Get the global key manager instance.""" + global _key_manager + if _key_manager is None: + _key_manager = APIKeyManager() + return _key_manager + + +async def verify_api_key_dependency( + api_key: Optional[str] = Security(api_key_header), +) -> str: + """FastAPI dependency to verify API key. 
+ + Args: + api_key: API key from header + + Returns: + The verified API key + + Raises: + HTTPException: If key is missing or invalid + """ + settings = get_settings() + + # Skip auth if disabled + if not settings.api_key_required: + return "anonymous" + + if api_key is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="API key required", + headers={"WWW-Authenticate": "ApiKey"}, + ) + + key_manager = get_key_manager() + + if not key_manager.verify(api_key): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Invalid API key", + ) + + return api_key diff --git a/src/api/models.py b/src/api/models.py new file mode 100644 index 0000000..5b1fb05 --- /dev/null +++ b/src/api/models.py @@ -0,0 +1,166 @@ +"""Pydantic models for API requests and responses.""" + +from datetime import datetime +from typing import Any, Optional +from enum import Enum + +from pydantic import BaseModel, Field + + +class ModelFamily(str, Enum): + """Available model families.""" + MUSICGEN = "musicgen" + AUDIOGEN = "audiogen" + MAGNET = "magnet" + MUSICGEN_STYLE = "musicgen-style" + JASCO = "jasco" + + +class JobStatus(str, Enum): + """Generation job status.""" + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +# Generation requests + +class GenerationRequest(BaseModel): + """Request to generate audio.""" + model: ModelFamily = Field(..., description="Model family to use") + variant: str = Field("medium", description="Model variant") + prompts: list[str] = Field(..., min_length=1, max_length=10, description="Text prompts") + duration: float = Field(10.0, ge=1, le=30, description="Duration in seconds") + temperature: float = Field(1.0, ge=0, le=2, description="Sampling temperature") + top_k: int = Field(250, ge=0, le=500, description="Top-K sampling") + top_p: float = Field(0.0, ge=0, le=1, description="Top-P (nucleus) sampling") + cfg_coef: float = Field(3.0, ge=1, le=10, description="CFG coefficient") + seed: Optional[int] = Field(None, description="Random seed for reproducibility") + conditioning: Optional[dict[str, Any]] = Field(None, description="Model-specific conditioning") + project_id: Optional[str] = Field(None, description="Project to save to") + + +class BatchGenerationRequest(BaseModel): + """Request to add generation to queue.""" + request: GenerationRequest + priority: int = Field(0, ge=0, le=10, description="Job priority (higher = sooner)") + + +# Generation responses + +class GenerationResult(BaseModel): + """Result of a completed generation.""" + id: str = Field(..., description="Generation ID") + audio_url: str = Field(..., description="URL to download audio") + waveform_url: Optional[str] = Field(None, description="URL to waveform image") + duration: float = Field(..., description="Actual duration in seconds") + seed: int = Field(..., description="Seed used for generation") + model: str = Field(..., description="Model used") + variant: str = Field(..., description="Variant used") + prompt: str = Field(..., description="Prompt used") + created_at: datetime = Field(..., description="Creation timestamp") + + +class JobResponse(BaseModel): + """Response for a queued job.""" + job_id: str = Field(..., description="Job ID for tracking") + status: JobStatus = Field(..., description="Current status") + position: Optional[int] = Field(None, description="Queue position if pending") + progress: Optional[float] = Field(None, description="Progress 0-1 if running") + result: 
Optional[GenerationResult] = Field(None, description="Result if completed") + error: Optional[str] = Field(None, description="Error message if failed") + + +# Project models + +class ProjectCreate(BaseModel): + """Request to create a project.""" + name: str = Field(..., min_length=1, max_length=100) + description: Optional[str] = Field(None, max_length=500) + + +class ProjectResponse(BaseModel): + """Project information.""" + id: str + name: str + description: Optional[str] + generation_count: int + created_at: datetime + updated_at: datetime + + +class GenerationResponse(BaseModel): + """Generation record from database.""" + id: str + project_id: str + model: str + variant: str + prompt: str + duration_seconds: float + seed: int + audio_path: str + waveform_path: Optional[str] + parameters: dict[str, Any] + created_at: datetime + + +# Model info + +class ModelVariantInfo(BaseModel): + """Information about a model variant.""" + id: str + name: str + vram_mb: int + description: str + capabilities: list[str] + + +class ModelInfo(BaseModel): + """Information about a model family.""" + id: str + name: str + description: str + variants: list[ModelVariantInfo] + loaded: bool + current_variant: Optional[str] + + +# System info + +class GPUStatus(BaseModel): + """GPU memory status.""" + device_name: str + total_gb: float + used_gb: float + available_gb: float + utilization_percent: float + temperature_c: Optional[float] + + +class QueueStatus(BaseModel): + """Generation queue status.""" + queue_size: int + active_jobs: int + completed_today: int + failed_today: int + + +class SystemStatus(BaseModel): + """Overall system status.""" + gpu: GPUStatus + queue: QueueStatus + loaded_models: list[str] + uptime_seconds: float + + +# Pagination + +class PaginatedResponse(BaseModel): + """Paginated list response.""" + items: list[Any] + total: int + page: int + page_size: int + pages: int diff --git a/src/api/routes/__init__.py b/src/api/routes/__init__.py new file mode 100644 index 0000000..f136a8f --- /dev/null +++ b/src/api/routes/__init__.py @@ -0,0 +1,13 @@ +"""API route modules.""" + +from src.api.routes.generation import router as generation_router +from src.api.routes.projects import router as projects_router +from src.api.routes.models import router as models_router +from src.api.routes.system import router as system_router + +__all__ = [ + "generation_router", + "projects_router", + "models_router", + "system_router", +] diff --git a/src/api/routes/generation.py b/src/api/routes/generation.py new file mode 100644 index 0000000..dc1bab7 --- /dev/null +++ b/src/api/routes/generation.py @@ -0,0 +1,234 @@ +"""Generation API endpoints.""" + +from typing import Any +from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, status +from fastapi.responses import FileResponse + +from src.api.auth import verify_api_key_dependency +from src.api.models import ( + GenerationRequest, + BatchGenerationRequest, + GenerationResult, + JobResponse, + JobStatus, +) + + +router = APIRouter(prefix="/generate", tags=["generation"]) + + +# Service dependencies (injected at app startup) +_generation_service = None +_batch_processor = None + + +def set_services(generation_service: Any, batch_processor: Any) -> None: + """Set service dependencies.""" + global _generation_service, _batch_processor + _generation_service = generation_service + _batch_processor = batch_processor + + +@router.post( + "/", + response_model=GenerationResult, + summary="Generate audio synchronously", + description="Generate audio and 
wait for completion. For long generations, consider using the async endpoint.", +) +async def generate_sync( + request: GenerationRequest, + api_key: str = Depends(verify_api_key_dependency), +) -> GenerationResult: + """Generate audio synchronously.""" + if _generation_service is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Generation service not available", + ) + + try: + result, generation = await _generation_service.generate( + model_id=request.model.value, + variant=request.variant, + prompts=request.prompts, + duration=request.duration, + temperature=request.temperature, + top_k=request.top_k, + top_p=request.top_p, + cfg_coef=request.cfg_coef, + seed=request.seed, + conditioning=request.conditioning, + project_id=request.project_id, + ) + + return GenerationResult( + id=generation.id, + audio_url=f"/api/v1/audio/{generation.id}", + waveform_url=f"/api/v1/audio/{generation.id}/waveform" if generation.waveform_path else None, + duration=result.duration, + seed=result.seed, + model=request.model.value, + variant=request.variant, + prompt=request.prompts[0], + created_at=generation.created_at, + ) + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) + + +@router.post( + "/async", + response_model=JobResponse, + summary="Queue generation job", + description="Add a generation to the queue for async processing.", +) +async def generate_async( + request: BatchGenerationRequest, + api_key: str = Depends(verify_api_key_dependency), +) -> JobResponse: + """Add generation to queue.""" + if _batch_processor is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Batch processor not available", + ) + + try: + job = _batch_processor.add_job( + model_id=request.request.model.value, + variant=request.request.variant, + prompts=request.request.prompts, + duration=request.request.duration, + temperature=request.request.temperature, + top_k=request.request.top_k, + top_p=request.request.top_p, + cfg_coef=request.request.cfg_coef, + seed=request.request.seed, + conditioning=request.request.conditioning, + project_id=request.request.project_id, + priority=request.priority, + ) + + return JobResponse( + job_id=job.id, + status=JobStatus.PENDING, + position=_batch_processor.get_position(job.id), + ) + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) + + +@router.get( + "/jobs/{job_id}", + response_model=JobResponse, + summary="Get job status", + description="Check the status of a queued generation job.", +) +async def get_job_status( + job_id: str, + api_key: str = Depends(verify_api_key_dependency), +) -> JobResponse: + """Get status of a queued job.""" + if _batch_processor is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Batch processor not available", + ) + + job = _batch_processor.get_job(job_id) + if job is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Job {job_id} not found", + ) + + response = JobResponse( + job_id=job.id, + status=JobStatus(job.status.value), + ) + + if job.status.value == "pending": + response.position = _batch_processor.get_position(job_id) + elif job.status.value == "running": + response.progress = job.progress + elif job.status.value == "completed" and job.result: + response.result = GenerationResult( + id=job.result.id, + audio_url=f"/api/v1/audio/{job.result.id}", + 
waveform_url=f"/api/v1/audio/{job.result.id}/waveform", + duration=job.result.duration, + seed=job.result.seed, + model=job.model_id, + variant=job.variant, + prompt=job.prompts[0], + created_at=job.completed_at, + ) + elif job.status.value == "failed": + response.error = job.error + + return response + + +@router.delete( + "/jobs/{job_id}", + summary="Cancel job", + description="Cancel a pending or running job.", +) +async def cancel_job( + job_id: str, + api_key: str = Depends(verify_api_key_dependency), +) -> dict: + """Cancel a queued job.""" + if _batch_processor is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Batch processor not available", + ) + + success = _batch_processor.cancel_job(job_id) + if not success: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Job {job_id} not found or cannot be cancelled", + ) + + return {"message": f"Job {job_id} cancelled"} + + +@router.get( + "/jobs", + response_model=list[JobResponse], + summary="List jobs", + description="List all jobs in the queue.", +) +async def list_jobs( + status_filter: str = None, + limit: int = 50, + api_key: str = Depends(verify_api_key_dependency), +) -> list[JobResponse]: + """List queued jobs.""" + if _batch_processor is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Batch processor not available", + ) + + jobs = _batch_processor.list_jobs(status_filter=status_filter, limit=limit) + + return [ + JobResponse( + job_id=job.id, + status=JobStatus(job.status.value), + position=_batch_processor.get_position(job.id) if job.status.value == "pending" else None, + progress=job.progress if job.status.value == "running" else None, + ) + for job in jobs + ] diff --git a/src/api/routes/models.py b/src/api/routes/models.py new file mode 100644 index 0000000..8ba5a9d --- /dev/null +++ b/src/api/routes/models.py @@ -0,0 +1,228 @@ +"""Models API endpoints.""" + +from typing import Any +from fastapi import APIRouter, Depends, HTTPException, status + +from src.api.auth import verify_api_key_dependency +from src.api.models import ModelInfo, ModelVariantInfo + + +router = APIRouter(prefix="/models", tags=["models"]) + + +# Service dependency (injected at app startup) +_model_registry = None + + +def set_services(model_registry: Any) -> None: + """Set service dependencies.""" + global _model_registry + _model_registry = model_registry + + +# Static model information +MODEL_CATALOG = { + "musicgen": { + "id": "musicgen", + "name": "MusicGen", + "description": "Text-to-music generation with optional melody conditioning", + "variants": [ + {"id": "small", "name": "Small", "vram_mb": 1500, "description": "Fast, 300M params", "capabilities": ["text"]}, + {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "Balanced, 1.5B params", "capabilities": ["text"]}, + {"id": "large", "name": "Large", "vram_mb": 10000, "description": "Best quality, 3.3B params", "capabilities": ["text"]}, + {"id": "melody", "name": "Melody", "vram_mb": 5000, "description": "With melody conditioning", "capabilities": ["text", "melody"]}, + {"id": "stereo-small", "name": "Stereo Small", "vram_mb": 1800, "description": "Stereo, 300M params", "capabilities": ["text", "stereo"]}, + {"id": "stereo-medium", "name": "Stereo Medium", "vram_mb": 6000, "description": "Stereo, 1.5B params", "capabilities": ["text", "stereo"]}, + {"id": "stereo-large", "name": "Stereo Large", "vram_mb": 12000, "description": "Stereo, 3.3B params", "capabilities": ["text", 
"stereo"]}, + {"id": "stereo-melody", "name": "Stereo Melody", "vram_mb": 6000, "description": "Stereo with melody", "capabilities": ["text", "melody", "stereo"]}, + ], + }, + "audiogen": { + "id": "audiogen", + "name": "AudioGen", + "description": "Text-to-sound effects and environmental audio", + "variants": [ + {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "1.5B params", "capabilities": ["text", "sfx"]}, + ], + }, + "magnet": { + "id": "magnet", + "name": "MAGNeT", + "description": "Fast non-autoregressive music generation", + "variants": [ + {"id": "small", "name": "Small Music", "vram_mb": 2000, "description": "Fast music, 300M params", "capabilities": ["text", "music"]}, + {"id": "medium", "name": "Medium Music", "vram_mb": 5000, "description": "Balanced music, 1.5B params", "capabilities": ["text", "music"]}, + {"id": "audio-small", "name": "Small Audio", "vram_mb": 2000, "description": "Fast sound effects", "capabilities": ["text", "sfx"]}, + {"id": "audio-medium", "name": "Medium Audio", "vram_mb": 5000, "description": "Balanced sound effects", "capabilities": ["text", "sfx"]}, + ], + }, + "musicgen-style": { + "id": "musicgen-style", + "name": "MusicGen Style", + "description": "Style-conditioned music from reference audio", + "variants": [ + {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "1.5B params, style conditioning", "capabilities": ["text", "style"]}, + ], + }, + "jasco": { + "id": "jasco", + "name": "JASCO", + "description": "Chord and drum-conditioned music generation", + "variants": [ + {"id": "chords", "name": "Chords", "vram_mb": 5000, "description": "Chord-conditioned generation", "capabilities": ["text", "chords"]}, + {"id": "chords-drums", "name": "Chords + Drums", "vram_mb": 5500, "description": "Full symbolic conditioning", "capabilities": ["text", "chords", "drums"]}, + ], + }, +} + + +@router.get( + "/", + response_model=list[ModelInfo], + summary="List models", + description="Get information about all available models.", +) +async def list_models( + api_key: str = Depends(verify_api_key_dependency), +) -> list[ModelInfo]: + """List all available models.""" + models = [] + + for model_id, info in MODEL_CATALOG.items(): + loaded = False + current_variant = None + + if _model_registry: + loaded = _model_registry.is_loaded(model_id) + if loaded: + current_variant = _model_registry.get_current_variant(model_id) + + models.append( + ModelInfo( + id=info["id"], + name=info["name"], + description=info["description"], + variants=[ModelVariantInfo(**v) for v in info["variants"]], + loaded=loaded, + current_variant=current_variant, + ) + ) + + return models + + +@router.get( + "/{model_id}", + response_model=ModelInfo, + summary="Get model info", + description="Get detailed information about a specific model.", +) +async def get_model( + model_id: str, + api_key: str = Depends(verify_api_key_dependency), +) -> ModelInfo: + """Get model information by ID.""" + if model_id not in MODEL_CATALOG: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Model {model_id} not found", + ) + + info = MODEL_CATALOG[model_id] + loaded = False + current_variant = None + + if _model_registry: + loaded = _model_registry.is_loaded(model_id) + if loaded: + current_variant = _model_registry.get_current_variant(model_id) + + return ModelInfo( + id=info["id"], + name=info["name"], + description=info["description"], + variants=[ModelVariantInfo(**v) for v in info["variants"]], + loaded=loaded, + current_variant=current_variant, + ) 
+ + +@router.post( + "/{model_id}/load", + summary="Load model", + description="Load a model into GPU memory.", +) +async def load_model( + model_id: str, + variant: str = "medium", + api_key: str = Depends(verify_api_key_dependency), +) -> dict: + """Load a model into memory.""" + if model_id not in MODEL_CATALOG: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Model {model_id} not found", + ) + + if _model_registry is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Model registry not available", + ) + + try: + await _model_registry.load_model(model_id, variant) + return {"message": f"Model {model_id} ({variant}) loaded successfully"} + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) + + +@router.post( + "/{model_id}/unload", + summary="Unload model", + description="Unload a model from GPU memory.", +) +async def unload_model( + model_id: str, + api_key: str = Depends(verify_api_key_dependency), +) -> dict: + """Unload a model from memory.""" + if model_id not in MODEL_CATALOG: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Model {model_id} not found", + ) + + if _model_registry is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Model registry not available", + ) + + try: + await _model_registry.unload_model(model_id) + return {"message": f"Model {model_id} unloaded successfully"} + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) + + +@router.get( + "/loaded", + response_model=list[str], + summary="List loaded models", + description="Get list of currently loaded models.", +) +async def list_loaded_models( + api_key: str = Depends(verify_api_key_dependency), +) -> list[str]: + """List currently loaded models.""" + if _model_registry is None: + return [] + + return _model_registry.get_loaded_models() diff --git a/src/api/routes/projects.py b/src/api/routes/projects.py new file mode 100644 index 0000000..2afc425 --- /dev/null +++ b/src/api/routes/projects.py @@ -0,0 +1,250 @@ +"""Projects API endpoints.""" + +from typing import Any, Optional +from fastapi import APIRouter, Depends, HTTPException, Query, status +from fastapi.responses import FileResponse + +from src.api.auth import verify_api_key_dependency +from src.api.models import ( + ProjectCreate, + ProjectResponse, + GenerationResponse, + PaginatedResponse, +) + + +router = APIRouter(prefix="/projects", tags=["projects"]) + + +# Service dependency (injected at app startup) +_project_service = None + + +def set_services(project_service: Any) -> None: + """Set service dependencies.""" + global _project_service + _project_service = project_service + + +@router.post( + "/", + response_model=ProjectResponse, + status_code=status.HTTP_201_CREATED, + summary="Create project", + description="Create a new project for organizing generations.", +) +async def create_project( + request: ProjectCreate, + api_key: str = Depends(verify_api_key_dependency), +) -> ProjectResponse: + """Create a new project.""" + if _project_service is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Project service not available", + ) + + try: + project = await _project_service.create_project( + name=request.name, + description=request.description, + ) + return ProjectResponse(**project) + except Exception as e: + raise HTTPException( + 
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) + + +@router.get( + "/", + response_model=list[ProjectResponse], + summary="List projects", + description="Get all projects.", +) +async def list_projects( + api_key: str = Depends(verify_api_key_dependency), +) -> list[ProjectResponse]: + """List all projects.""" + if _project_service is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Project service not available", + ) + + try: + projects = await _project_service.list_projects() + return [ProjectResponse(**p) for p in projects] + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) + + +@router.get( + "/{project_id}", + response_model=ProjectResponse, + summary="Get project", + description="Get project details by ID.", +) +async def get_project( + project_id: str, + api_key: str = Depends(verify_api_key_dependency), +) -> ProjectResponse: + """Get a project by ID.""" + if _project_service is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Project service not available", + ) + + try: + project = await _project_service.get_project(project_id) + if project is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Project {project_id} not found", + ) + return ProjectResponse(**project) + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) + + +@router.delete( + "/{project_id}", + status_code=status.HTTP_204_NO_CONTENT, + summary="Delete project", + description="Delete a project and all its generations.", +) +async def delete_project( + project_id: str, + api_key: str = Depends(verify_api_key_dependency), +) -> None: + """Delete a project.""" + if _project_service is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Project service not available", + ) + + try: + await _project_service.delete_project(project_id) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) + + +@router.get( + "/{project_id}/generations", + response_model=PaginatedResponse, + summary="List generations", + description="Get generations for a project with pagination.", +) +async def list_generations( + project_id: str, + page: int = Query(1, ge=1), + page_size: int = Query(20, ge=1, le=100), + model: Optional[str] = Query(None, description="Filter by model"), + api_key: str = Depends(verify_api_key_dependency), +) -> PaginatedResponse: + """List generations for a project.""" + if _project_service is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Project service not available", + ) + + try: + offset = (page - 1) * page_size + generations = await _project_service.list_generations( + project_id=project_id, + limit=page_size + 1, # +1 to check if more pages + offset=offset, + model_filter=model, + ) + + has_more = len(generations) > page_size + generations = generations[:page_size] + + # Estimate total (could be improved with actual count query) + total = offset + len(generations) + (1 if has_more else 0) + pages = (total + page_size - 1) // page_size + + return PaginatedResponse( + items=[GenerationResponse(**g) for g in generations], + total=total, + page=page, + page_size=page_size, + pages=pages, + ) + except Exception as e: + raise HTTPException( + 
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) + + +@router.get( + "/{project_id}/export", + summary="Export project", + description="Export project as ZIP file with all audio and metadata.", +) +async def export_project( + project_id: str, + api_key: str = Depends(verify_api_key_dependency), +) -> FileResponse: + """Export project as ZIP.""" + if _project_service is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Project service not available", + ) + + try: + zip_path = await _project_service.export_project_zip(project_id) + return FileResponse( + path=zip_path, + filename=f"project_{project_id}.zip", + media_type="application/zip", + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) + + +@router.delete( + "/{project_id}/generations/{generation_id}", + status_code=status.HTTP_204_NO_CONTENT, + summary="Delete generation", + description="Delete a specific generation.", +) +async def delete_generation( + project_id: str, + generation_id: str, + api_key: str = Depends(verify_api_key_dependency), +) -> None: + """Delete a generation.""" + if _project_service is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Project service not available", + ) + + try: + await _project_service.delete_generation(generation_id) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) diff --git a/src/api/routes/system.py b/src/api/routes/system.py new file mode 100644 index 0000000..02018d0 --- /dev/null +++ b/src/api/routes/system.py @@ -0,0 +1,263 @@ +"""System API endpoints.""" + +import time +from typing import Any +from pathlib import Path +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.responses import FileResponse + +from src.api.auth import verify_api_key_dependency, get_key_manager +from src.api.models import GPUStatus, QueueStatus, SystemStatus + + +router = APIRouter(prefix="/system", tags=["system"]) + + +# Service dependencies (injected at app startup) +_gpu_manager = None +_batch_processor = None +_model_registry = None +_start_time = time.time() + + +def set_services( + gpu_manager: Any, + batch_processor: Any, + model_registry: Any, +) -> None: + """Set service dependencies.""" + global _gpu_manager, _batch_processor, _model_registry + _gpu_manager = gpu_manager + _batch_processor = batch_processor + _model_registry = model_registry + + +@router.get( + "/status", + response_model=SystemStatus, + summary="System status", + description="Get overall system status including GPU, queue, and loaded models.", +) +async def get_status( + api_key: str = Depends(verify_api_key_dependency), +) -> SystemStatus: + """Get system status.""" + # GPU status + if _gpu_manager: + gpu = GPUStatus( + device_name=_gpu_manager.device_name, + total_gb=_gpu_manager.total_memory / 1024**3, + used_gb=_gpu_manager.get_used_memory() / 1024**3, + available_gb=_gpu_manager.get_available_memory() / 1024**3, + utilization_percent=_gpu_manager.get_utilization(), + temperature_c=_gpu_manager.get_temperature(), + ) + else: + gpu = GPUStatus( + device_name="Unknown", + total_gb=0, + used_gb=0, + available_gb=0, + utilization_percent=0, + temperature_c=None, + ) + + # Queue status + if _batch_processor: + queue = QueueStatus( + queue_size=len(_batch_processor.queue), + active_jobs=_batch_processor.active_count, + 
completed_today=_batch_processor.completed_count, + failed_today=_batch_processor.failed_count, + ) + else: + queue = QueueStatus( + queue_size=0, + active_jobs=0, + completed_today=0, + failed_today=0, + ) + + # Loaded models + loaded_models = [] + if _model_registry: + loaded_models = _model_registry.get_loaded_models() + + return SystemStatus( + gpu=gpu, + queue=queue, + loaded_models=loaded_models, + uptime_seconds=time.time() - _start_time, + ) + + +@router.get( + "/gpu", + response_model=GPUStatus, + summary="GPU status", + description="Get detailed GPU memory and utilization status.", +) +async def get_gpu_status( + api_key: str = Depends(verify_api_key_dependency), +) -> GPUStatus: + """Get GPU status.""" + if _gpu_manager is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="GPU manager not available", + ) + + return GPUStatus( + device_name=_gpu_manager.device_name, + total_gb=_gpu_manager.total_memory / 1024**3, + used_gb=_gpu_manager.get_used_memory() / 1024**3, + available_gb=_gpu_manager.get_available_memory() / 1024**3, + utilization_percent=_gpu_manager.get_utilization(), + temperature_c=_gpu_manager.get_temperature(), + ) + + +@router.post( + "/clear-cache", + summary="Clear cache", + description="Clear model cache and free GPU memory.", +) +async def clear_cache( + api_key: str = Depends(verify_api_key_dependency), +) -> dict: + """Clear model cache.""" + if _model_registry is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Model registry not available", + ) + + try: + _model_registry.clear_cache() + return {"message": "Cache cleared successfully"} + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) + + +@router.post( + "/unload-all", + summary="Unload all models", + description="Unload all models from GPU memory.", +) +async def unload_all_models( + api_key: str = Depends(verify_api_key_dependency), +) -> dict: + """Unload all models.""" + if _model_registry is None: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Model registry not available", + ) + + try: + await _model_registry.unload_all() + return {"message": "All models unloaded successfully"} + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e), + ) + + +@router.get( + "/health", + summary="Health check", + description="Simple health check endpoint.", +) +async def health_check() -> dict: + """Health check endpoint (no auth required).""" + return { + "status": "healthy", + "uptime_seconds": time.time() - _start_time, + } + + +@router.post( + "/api-key/regenerate", + summary="Regenerate API key", + description="Generate a new API key. The old key will be invalidated.", +) +async def regenerate_api_key( + api_key: str = Depends(verify_api_key_dependency), +) -> dict: + """Regenerate API key.""" + key_manager = get_key_manager() + new_key = key_manager.generate_new_key() + + return { + "api_key": new_key, + "message": "New API key generated. 
Store it securely - it won't be shown again.", + } + + +@router.get( + "/audio/{generation_id}", + summary="Download audio", + description="Download generated audio file.", +) +async def download_audio( + generation_id: str, + api_key: str = Depends(verify_api_key_dependency), +) -> FileResponse: + """Download audio file for a generation.""" + # This would look up the actual file path from the database + # For now, construct expected path + from config.settings import get_settings + settings = get_settings() + + # Find the audio file + audio_dir = Path(settings.output_dir) + possible_paths = [ + audio_dir / f"{generation_id}.wav", + audio_dir / f"{generation_id}.mp3", + audio_dir / f"{generation_id}.flac", + ] + + for path in possible_paths: + if path.exists(): + return FileResponse( + path=path, + filename=path.name, + media_type="audio/wav" if path.suffix == ".wav" else f"audio/{path.suffix[1:]}", + ) + + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Audio file for generation {generation_id} not found", + ) + + +@router.get( + "/audio/{generation_id}/waveform", + summary="Download waveform", + description="Download waveform visualization image.", +) +async def download_waveform( + generation_id: str, + api_key: str = Depends(verify_api_key_dependency), +) -> FileResponse: + """Download waveform image for a generation.""" + from config.settings import get_settings + settings = get_settings() + + waveform_path = Path(settings.output_dir) / f"{generation_id}_waveform.png" + + if not waveform_path.exists(): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Waveform for generation {generation_id} not found", + ) + + return FileResponse( + path=waveform_path, + filename=waveform_path.name, + media_type="image/png", + ) diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000..650e771 --- /dev/null +++ b/src/core/__init__.py @@ -0,0 +1,24 @@ +"""Core infrastructure for AudioCraft Studio.""" + +from src.core.base_model import ( + BaseAudioModel, + GenerationRequest, + GenerationResult, + ConditioningType, +) +from src.core.gpu_manager import GPUMemoryManager, VRAMBudget +from src.core.model_registry import ModelRegistry +from src.core.oom_handler import OOMHandler, OOMRecoveryError, oom_safe + +__all__ = [ + "BaseAudioModel", + "GenerationRequest", + "GenerationResult", + "ConditioningType", + "GPUMemoryManager", + "VRAMBudget", + "ModelRegistry", + "OOMHandler", + "OOMRecoveryError", + "oom_safe", +] diff --git a/src/core/audio_utils.py b/src/core/audio_utils.py new file mode 100644 index 0000000..4b65bb7 --- /dev/null +++ b/src/core/audio_utils.py @@ -0,0 +1,535 @@ +"""Audio utilities for processing, visualization, and export.""" + +import io +import logging +from pathlib import Path +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +logger = logging.getLogger(__name__) + + +def normalize_audio( + audio: Union[torch.Tensor, np.ndarray], + target_db: float = -14.0, + peak_normalize: bool = False, +) -> np.ndarray: + """Normalize audio to target loudness. 
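+
+    For example, with a mono float32 signal (a minimal doctest-style
+    sketch; the tone and target are illustrative)::
+
+        >>> import numpy as np
+        >>> t = np.linspace(0.0, 1.0, 32000, dtype=np.float32)
+        >>> out = normalize_audio(np.sin(2 * np.pi * 440.0 * t), target_db=-14.0)
+        >>> bool(np.abs(out).max() <= 1.0)
+        True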
+
+    Args:
+        audio: Audio tensor or array [channels, samples] or [samples]
+        target_db: Target loudness in dB (LUFS-like)
+        peak_normalize: If True, normalize to peak instead of RMS
+
+    Returns:
+        Normalized audio as numpy array
+    """
+    if isinstance(audio, torch.Tensor):
+        # Move off the GPU before converting; .numpy() raises on CUDA tensors.
+        audio = audio.detach().cpu().numpy()
+
+    # Ensure float32
+    audio = audio.astype(np.float32)
+
+    # Handle batch dimension
+    if audio.ndim == 3:
+        audio = audio[0]  # Take first sample if batched
+
+    if peak_normalize:
+        # Peak normalization
+        peak = np.abs(audio).max()
+        if peak > 0:
+            target_linear = 10 ** (target_db / 20)
+            audio = audio * (target_linear / peak)
+    else:
+        # RMS normalization (approximating LUFS)
+        rms = np.sqrt(np.mean(audio ** 2))
+        if rms > 0:
+            target_rms = 10 ** (target_db / 20)
+            audio = audio * (target_rms / rms)
+
+    # Clip to [-1, 1] so the applied gain cannot push samples out of range
+    audio = np.clip(audio, -1.0, 1.0)
+
+    return audio
+
+
+def convert_sample_rate(
+    audio: np.ndarray,
+    orig_sr: int,
+    target_sr: int,
+) -> np.ndarray:
+    """Convert audio sample rate.
+
+    Args:
+        audio: Audio array [channels, samples] or [samples]
+        orig_sr: Original sample rate
+        target_sr: Target sample rate
+
+    Returns:
+        Resampled audio
+    """
+    if orig_sr == target_sr:
+        return audio
+
+    try:
+        import librosa
+
+        # Handle multi-channel
+        if audio.ndim == 2:
+            resampled = np.array([
+                librosa.resample(ch, orig_sr=orig_sr, target_sr=target_sr)
+                for ch in audio
+            ])
+        else:
+            resampled = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
+
+        return resampled
+
+    except ImportError:
+        logger.warning("librosa not available, using scipy for resampling")
+        from scipy import signal
+
+        ratio = target_sr / orig_sr
+        new_length = int(audio.shape[-1] * ratio)
+
+        if audio.ndim == 2:
+            resampled = np.array([
+                signal.resample(ch, new_length) for ch in audio
+            ])
+        else:
+            resampled = signal.resample(audio, new_length)
+
+        return resampled
+
+
+def generate_waveform(
+    audio: Union[torch.Tensor, np.ndarray],
+    sample_rate: int,
+    width: int = 800,
+    height: int = 200,
+    color: str = "#3b82f6",
+    background: str = "#1f2937",
+) -> bytes:
+    """Generate waveform image as PNG bytes.
+
+    Args:
+        audio: Audio data [channels, samples] or [samples]
+        sample_rate: Sample rate in Hz
+        width: Image width in pixels
+        height: Image height in pixels
+        color: Waveform color (hex)
+        background: Background color (hex)
+
+    Returns:
+        PNG image as bytes
+    """
+    try:
+        import matplotlib
+        matplotlib.use('Agg')
+        import matplotlib.pyplot as plt
+    except ImportError:
+        logger.warning("matplotlib not available for waveform generation")
+        return b""
+
+    if isinstance(audio, torch.Tensor):
+        # Move off the GPU before converting; .numpy() raises on CUDA tensors.
+        audio = audio.detach().cpu().numpy()
+
+    # Handle dimensions
+    if audio.ndim == 3:
+        audio = audio[0]
+    if audio.ndim == 2:
+        audio = audio.mean(axis=0)  # Mix to mono for visualization
+
+    # Downsample for visualization
+    samples_per_pixel = max(1, len(audio) // width)
+    num_chunks = len(audio) // samples_per_pixel
+
+    if num_chunks > 0:
+        audio_chunks = audio[:num_chunks * samples_per_pixel].reshape(
+            num_chunks, samples_per_pixel
+        )
+        # Get min/max for each chunk
+        mins = audio_chunks.min(axis=1)
+        maxs = audio_chunks.max(axis=1)
+    else:
+        mins = maxs = audio
+
+    # Create figure
+    fig, ax = plt.subplots(figsize=(width / 100, height / 100), dpi=100)
+    fig.patch.set_facecolor(background)
+    ax.set_facecolor(background)
+
+    # Plot waveform
+    x = np.arange(len(mins))
+    ax.fill_between(x, mins, maxs, color=color, alpha=0.7)
+    ax.axhline(y=0, color=color, alpha=0.3, linewidth=0.5)
+
+    # Style
+    ax.set_xlim(0, len(mins))
+    ax.set_ylim(-1, 1)
+    ax.axis('off')
+    plt.tight_layout(pad=0)
+
+    # Save to bytes
+    buf = io.BytesIO()
+    fig.savefig(buf, format='png', facecolor=background, edgecolor='none')
+    plt.close(fig)
+    buf.seek(0)
+
+    return buf.read()
+
+
+def generate_spectrogram(
+    audio: Union[torch.Tensor, np.ndarray],
+    sample_rate: int,
+    width: int = 800,
+    height: int = 200,
+    colormap: str = "magma",
+) -> bytes:
+    """Generate spectrogram image as PNG bytes.
+
+    Args:
+        audio: Audio data
+        sample_rate: Sample rate in Hz
+        width: Image width
+        height: Image height
+        colormap: Matplotlib colormap name
+
+    Returns:
+        PNG image as bytes
+    """
+    try:
+        import matplotlib
+        matplotlib.use('Agg')
+        import matplotlib.pyplot as plt
+        import librosa
+        import librosa.display
+    except ImportError:
+        logger.warning("matplotlib/librosa not available for spectrogram")
+        return b""
+
+    if isinstance(audio, torch.Tensor):
+        # Move off the GPU before converting; .numpy() raises on CUDA tensors.
+        audio = audio.detach().cpu().numpy()
+
+    # Handle dimensions
+    if audio.ndim == 3:
+        audio = audio[0]
+    if audio.ndim == 2:
+        audio = audio.mean(axis=0)
+
+    # Compute mel spectrogram
+    S = librosa.feature.melspectrogram(
+        y=audio,
+        sr=sample_rate,
+        n_mels=128,
+        fmax=sample_rate // 2,
+    )
+    S_db = librosa.power_to_db(S, ref=np.max)
+
+    # Create figure
+    fig, ax = plt.subplots(figsize=(width / 100, height / 100), dpi=100)
+
+    librosa.display.specshow(
+        S_db,
+        sr=sample_rate,
+        x_axis='time',
+        y_axis='mel',
+        cmap=colormap,
+        ax=ax,
+    )
+
+    ax.axis('off')
+    plt.tight_layout(pad=0)
+
+    # Save to bytes
+    buf = io.BytesIO()
+    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
+    plt.close(fig)
+    buf.seek(0)
+
+    return buf.read()
+
+
+def save_audio(
+    audio: Union[torch.Tensor, np.ndarray],
+    sample_rate: int,
+    path: Path,
+    format: str = "wav",
+    normalize: bool = True,
+    target_db: float = -14.0,
+) -> Path:
+    """Save audio to file with optional normalization.
+
+    Args:
+        audio: Audio data
+        sample_rate: Sample rate
+        path: Output path (extension will be added if needed)
+        format: Output format (wav, mp3, flac, ogg)
+        normalize: Whether to normalize audio
+        target_db: Normalization target
+
+    Returns:
+        Path to saved file
+    """
+    import soundfile as sf
+
+    if isinstance(audio, torch.Tensor):
+        # Move off the GPU before converting; .numpy() raises on CUDA tensors.
+        audio = audio.detach().cpu().numpy()
+
+    # Handle batch dimension
+    if audio.ndim == 3:
+        audio = audio[0]
+
+    # Normalize if requested
+    if normalize:
+        audio = normalize_audio(audio, target_db=target_db)
+
+    # Transpose for soundfile [samples, channels]
+    if audio.ndim == 2:
+        audio = audio.T
+
+    # Ensure correct extension
+    path = Path(path)
+    if not path.suffix:
+        path = path.with_suffix(f".{format}")
+
+    # Save based on format
+    if format in ("wav", "flac"):
+        sf.write(path, audio, sample_rate)
+    elif format == "mp3":
+        # Write a temporary WAV with soundfile, then convert with pydub if available
+        try:
+            from pydub import AudioSegment
+
+            # Save as WAV first
+            wav_path = path.with_suffix(".wav")
+            sf.write(wav_path, audio, sample_rate)
+
+            # Convert to MP3
+            sound = AudioSegment.from_wav(wav_path)
+            sound.export(path, format="mp3", bitrate="320k")
+
+            # Remove temp WAV
+            wav_path.unlink()
+        except ImportError:
+            logger.warning("pydub not available, saving as WAV instead")
+            path = path.with_suffix(".wav")
+            sf.write(path, audio, sample_rate)
+    elif format == "ogg":
+        sf.write(path, audio, sample_rate, format="ogg", subtype="vorbis")
+    else:
+        # Default to WAV
+        path = path.with_suffix(".wav")
+        sf.write(path, audio, sample_rate)
+
+    return path
+
+
+def load_audio(
+    path: Path,
+    target_sr: Optional[int] = None,
+    mono: bool = False,
+) -> Tuple[np.ndarray, int]:
+    """Load audio from file.
+
+    Args:
+        path: Path to audio file
+        target_sr: Target sample rate (None to keep original)
+        mono: Convert to mono
+
+    Returns:
+        Tuple of (audio_array, sample_rate)
+    """
+    import soundfile as sf
+
+    audio, sr = sf.read(path)
+
+    # Convert to [channels, samples] format
+    if audio.ndim == 1:
+        audio = audio[np.newaxis, :]
+    else:
+        audio = audio.T
+
+    # Convert to mono
+    if mono and audio.shape[0] > 1:
+        audio = audio.mean(axis=0, keepdims=True)
+
+    # Resample if needed
+    if target_sr and target_sr != sr:
+        audio = convert_sample_rate(audio, sr, target_sr)
+        sr = target_sr
+
+    return audio, sr
+
+
+def get_audio_info(path: Path) -> dict:
+    """Get audio file information.
+
+    Args:
+        path: Path to audio file
+
+    Returns:
+        Dictionary with audio info
+    """
+    import soundfile as sf
+
+    info = sf.info(path)
+
+    return {
+        "path": str(path),
+        "duration": info.duration,
+        "sample_rate": info.samplerate,
+        "channels": info.channels,
+        "format": info.format,
+        "subtype": info.subtype,
+        "frames": info.frames,
+    }
+
+
+def trim_silence(
+    audio: np.ndarray,
+    sample_rate: int,
+    threshold_db: float = -40.0,
+    min_silence_ms: int = 100,
+) -> np.ndarray:
+    """Trim silence from start and end of audio.
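+
+    A usage sketch (synthetic input; the function falls back to returning
+    the input unchanged when librosa is missing)::
+
+        >>> import numpy as np
+        >>> pad = np.zeros(8000, dtype=np.float32)
+        >>> tone = 0.5 * np.sin(np.linspace(0.0, 800.0, 16000, dtype=np.float32))
+        >>> clip = np.concatenate([pad, tone, pad])
+        >>> trimmed = trim_silence(clip, sample_rate=16000)
+        >>> bool(trimmed.shape[-1] <= clip.shape[-1])
+        True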
+ + Args: + audio: Audio array + sample_rate: Sample rate + threshold_db: Silence threshold in dB + min_silence_ms: Minimum silence duration to trim + + Returns: + Trimmed audio + """ + try: + import librosa + + if audio.ndim == 2: + # Process mono for trimming + mono = audio.mean(axis=0) + else: + mono = audio + + # Get non-silent intervals + intervals = librosa.effects.split( + mono, + top_db=abs(threshold_db), + frame_length=int(sample_rate * min_silence_ms / 1000), + ) + + if len(intervals) == 0: + return audio + + start = intervals[0][0] + end = intervals[-1][1] + + if audio.ndim == 2: + return audio[:, start:end] + return audio[start:end] + + except ImportError: + logger.warning("librosa not available for silence trimming") + return audio + + +def apply_fade( + audio: np.ndarray, + sample_rate: int, + fade_in_ms: float = 0, + fade_out_ms: float = 0, +) -> np.ndarray: + """Apply fade in/out to audio. + + Args: + audio: Audio array [channels, samples] or [samples] + sample_rate: Sample rate + fade_in_ms: Fade in duration in milliseconds + fade_out_ms: Fade out duration in milliseconds + + Returns: + Audio with fades applied + """ + audio = audio.copy() + + if fade_in_ms > 0: + fade_in_samples = int(sample_rate * fade_in_ms / 1000) + fade_in_samples = min(fade_in_samples, audio.shape[-1]) + fade_in_curve = np.linspace(0, 1, fade_in_samples) + + if audio.ndim == 2: + audio[:, :fade_in_samples] *= fade_in_curve + else: + audio[:fade_in_samples] *= fade_in_curve + + if fade_out_ms > 0: + fade_out_samples = int(sample_rate * fade_out_ms / 1000) + fade_out_samples = min(fade_out_samples, audio.shape[-1]) + fade_out_curve = np.linspace(1, 0, fade_out_samples) + + if audio.ndim == 2: + audio[:, -fade_out_samples:] *= fade_out_curve + else: + audio[-fade_out_samples:] *= fade_out_curve + + return audio + + +def concatenate_audio( + audio_list: list[np.ndarray], + sample_rate: int, + crossfade_ms: float = 0, +) -> np.ndarray: + """Concatenate multiple audio segments. 
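+
+    For example, joining two one-second mono clips with a 50 ms crossfade
+    (a minimal sketch; at 16 kHz the 800 overlapping samples are shared,
+    so 16000 + 16000 - 800 = 31200 samples remain)::
+
+        >>> import numpy as np
+        >>> a = np.ones(16000, dtype=np.float32)
+        >>> b = np.ones(16000, dtype=np.float32)
+        >>> out = concatenate_audio([a, b], sample_rate=16000, crossfade_ms=50)
+        >>> out.shape[-1]
+        31200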
+ + Args: + audio_list: List of audio arrays + sample_rate: Sample rate (must be same for all) + crossfade_ms: Crossfade duration between segments + + Returns: + Concatenated audio + """ + if not audio_list: + return np.array([]) + + if len(audio_list) == 1: + return audio_list[0] + + crossfade_samples = int(sample_rate * crossfade_ms / 1000) + + result = audio_list[0] + + for audio in audio_list[1:]: + if crossfade_samples > 0 and crossfade_samples < min( + result.shape[-1], audio.shape[-1] + ): + # Apply crossfade + fade_out = np.linspace(1, 0, crossfade_samples) + fade_in = np.linspace(0, 1, crossfade_samples) + + if result.ndim == 2: + # Overlap region + result[:, -crossfade_samples:] *= fade_out + overlap = result[:, -crossfade_samples:] + audio[:, :crossfade_samples] * fade_in + result = np.concatenate([ + result[:, :-crossfade_samples], + overlap, + audio[:, crossfade_samples:] + ], axis=1) + else: + result[-crossfade_samples:] *= fade_out + overlap = result[-crossfade_samples:] + audio[:crossfade_samples] * fade_in + result = np.concatenate([ + result[:-crossfade_samples], + overlap, + audio[crossfade_samples:] + ]) + else: + # Simple concatenation + result = np.concatenate([result, audio], axis=-1) + + return result diff --git a/src/core/base_model.py b/src/core/base_model.py new file mode 100644 index 0000000..42cdcfc --- /dev/null +++ b/src/core/base_model.py @@ -0,0 +1,247 @@ +"""Abstract base classes for AudioCraft model adapters.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Optional + +import torch + + +class ConditioningType(str, Enum): + """Types of conditioning supported by models.""" + + TEXT = "text" + MELODY = "melody" + STYLE = "style" + CHORDS = "chords" + DRUMS = "drums" + + +@dataclass +class GenerationRequest: + """Request parameters for audio generation. + + Attributes: + prompts: List of text prompts for generation + duration: Target duration in seconds + temperature: Sampling temperature (higher = more random) + top_k: Top-k sampling parameter + top_p: Nucleus sampling parameter (0 = disabled) + cfg_coef: Classifier-free guidance coefficient + batch_size: Number of samples to generate per prompt + seed: Random seed for reproducibility + conditioning: Optional conditioning data + """ + + prompts: list[str] + duration: float = 10.0 + temperature: float = 1.0 + top_k: int = 250 + top_p: float = 0.0 + cfg_coef: float = 3.0 + batch_size: int = 1 + seed: Optional[int] = None + conditioning: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + """Validate request parameters.""" + if not self.prompts: + raise ValueError("At least one prompt is required") + if self.duration <= 0: + raise ValueError("Duration must be positive") + if self.temperature < 0: + raise ValueError("Temperature must be non-negative") + if self.top_k < 0: + raise ValueError("top_k must be non-negative") + if not 0 <= self.top_p <= 1: + raise ValueError("top_p must be between 0 and 1") + if self.cfg_coef < 1: + raise ValueError("cfg_coef must be >= 1") + + +@dataclass +class GenerationResult: + """Result of audio generation. 
+ + Attributes: + audio: Generated audio tensor (shape: [batch, channels, samples]) + sample_rate: Audio sample rate in Hz + duration: Actual duration in seconds + model_id: ID of the model used + variant: Model variant used + parameters: Generation parameters used + seed: Actual seed used (for reproducibility) + """ + + audio: torch.Tensor + sample_rate: int + duration: float + model_id: str + variant: str + parameters: dict[str, Any] + seed: int + + @property + def num_samples(self) -> int: + """Number of audio samples generated.""" + return self.audio.shape[0] + + @property + def num_channels(self) -> int: + """Number of audio channels.""" + return self.audio.shape[1] + + @property + def num_frames(self) -> int: + """Number of audio frames.""" + return self.audio.shape[2] + + +class BaseAudioModel(ABC): + """Abstract base class for AudioCraft model adapters. + + All model adapters must implement this interface to integrate with + the model registry and generation service. + """ + + @property + @abstractmethod + def model_id(self) -> str: + """Unique identifier for this model family (e.g., 'musicgen').""" + ... + + @property + @abstractmethod + def variant(self) -> str: + """Current model variant (e.g., 'medium', 'large').""" + ... + + @property + @abstractmethod + def display_name(self) -> str: + """Human-readable name for UI display.""" + ... + + @property + @abstractmethod + def description(self) -> str: + """Brief description of the model's capabilities.""" + ... + + @property + @abstractmethod + def vram_estimate_mb(self) -> int: + """Estimated VRAM usage when loaded (in megabytes).""" + ... + + @property + @abstractmethod + def max_duration(self) -> float: + """Maximum supported generation duration in seconds.""" + ... + + @property + @abstractmethod + def sample_rate(self) -> int: + """Output audio sample rate in Hz.""" + ... + + @property + @abstractmethod + def supports_conditioning(self) -> list[ConditioningType]: + """List of conditioning types supported by this model.""" + ... + + @property + @abstractmethod + def is_loaded(self) -> bool: + """Whether the model is currently loaded in memory.""" + ... + + @property + def device(self) -> Optional[torch.device]: + """Device the model is loaded on, or None if not loaded.""" + return None + + @abstractmethod + def load(self, device: str = "cuda") -> None: + """Load the model into memory. + + Args: + device: Target device ('cuda', 'cuda:0', 'cpu', etc.) + + Raises: + RuntimeError: If loading fails + """ + ... + + @abstractmethod + def unload(self) -> None: + """Unload the model and free memory. + + Should be idempotent - safe to call even if not loaded. + """ + ... + + @abstractmethod + def generate(self, request: GenerationRequest) -> GenerationResult: + """Generate audio based on the request. + + Args: + request: Generation parameters and prompts + + Returns: + GenerationResult containing audio and metadata + + Raises: + RuntimeError: If model is not loaded + ValueError: If request parameters are invalid for this model + """ + ... + + @abstractmethod + def get_default_params(self) -> dict[str, Any]: + """Get default generation parameters for this model. + + Returns: + Dictionary of parameter names to default values + """ + ... + + def validate_request(self, request: GenerationRequest) -> None: + """Validate a generation request for this model. 
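+
+        For example (a hedged sketch; ``adapter`` stands in for any
+        concrete subclass instance)::
+
+            >>> req = GenerationRequest(prompts=["ambient pad"], duration=120.0)
+            >>> adapter.validate_request(req)  # ValueError if 120 s > max_duration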
+ + Args: + request: Request to validate + + Raises: + ValueError: If request is invalid for this model + """ + if not self.is_loaded: + raise RuntimeError(f"Model {self.model_id}/{self.variant} is not loaded") + + if request.duration > self.max_duration: + raise ValueError( + f"Duration {request.duration}s exceeds maximum {self.max_duration}s " + f"for {self.model_id}/{self.variant}" + ) + + # Check conditioning requirements + for cond_type, cond_data in request.conditioning.items(): + if cond_data is not None: + try: + cond_enum = ConditioningType(cond_type) + except ValueError: + raise ValueError(f"Unknown conditioning type: {cond_type}") + + if cond_enum not in self.supports_conditioning: + raise ValueError( + f"Model {self.model_id}/{self.variant} does not support " + f"{cond_type} conditioning" + ) + + def __repr__(self) -> str: + """String representation.""" + loaded = "loaded" if self.is_loaded else "not loaded" + return f"<{self.__class__.__name__} {self.model_id}/{self.variant} ({loaded})>" diff --git a/src/core/gpu_manager.py b/src/core/gpu_manager.py new file mode 100644 index 0000000..03595ae --- /dev/null +++ b/src/core/gpu_manager.py @@ -0,0 +1,433 @@ +"""GPU memory management for AudioCraft models.""" + +import gc +import json +import logging +import threading +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Optional + +import torch + +logger = logging.getLogger(__name__) + + +@dataclass +class VRAMBudget: + """VRAM budget allocation information. + + Attributes: + total_mb: Total VRAM in megabytes + used_mb: Currently used VRAM + free_mb: Free VRAM + reserved_comfyui_mb: VRAM reserved for ComfyUI + safety_buffer_mb: Safety buffer to prevent OOM + available_mb: VRAM available for AudioCraft models + """ + + total_mb: int + used_mb: int + free_mb: int + reserved_comfyui_mb: int + safety_buffer_mb: int + available_mb: int + + @property + def utilization(self) -> float: + """Current VRAM utilization as a fraction (0-1).""" + return self.used_mb / self.total_mb if self.total_mb > 0 else 0.0 + + +@dataclass +class GPUState: + """State information for inter-service coordination.""" + + timestamp: float + service: str # "audiocraft" or "comfyui" + vram_used_mb: int + vram_requested_mb: int + status: str # "idle", "working", "requesting_priority", "yielded" + + +class GPUMemoryManager: + """Manages GPU memory allocation and coordination with ComfyUI. + + Uses pynvml for accurate system-wide VRAM tracking and file-based + IPC for coordination with ComfyUI running on the same system. + """ + + COORDINATION_FILE = Path("/tmp/audiocraft_comfyui_coord.json") + LOCK_FILE = Path("/tmp/audiocraft_comfyui_coord.lock") + STALE_THRESHOLD = 30.0 # seconds + + def __init__( + self, + device_id: int = 0, + comfyui_reserve_gb: float = 10.0, + safety_buffer_gb: float = 1.0, + ): + """Initialize GPU memory manager. 
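+
+        A construction sketch for the RunPod RTX 4090 target (the reserve
+        and buffer values are illustrative, matching the defaults)::
+
+            >>> manager = GPUMemoryManager(
+            ...     device_id=0,
+            ...     comfyui_reserve_gb=10.0,
+            ...     safety_buffer_gb=1.0,
+            ... )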
+ + Args: + device_id: CUDA device index + comfyui_reserve_gb: VRAM to reserve for ComfyUI (gigabytes) + safety_buffer_gb: Safety buffer to prevent OOM (gigabytes) + """ + self.device_id = device_id + self.device = torch.device(f"cuda:{device_id}") + self.comfyui_reserve_mb = int(comfyui_reserve_gb * 1024) + self.safety_buffer_mb = int(safety_buffer_gb * 1024) + + # Initialize NVML for direct GPU monitoring + self._nvml_initialized = False + self._nvml_handle = None + self._init_nvml() + + # Threading + self._lock = threading.RLock() + + # Callbacks for memory events + self._low_memory_callbacks: list[Callable[[VRAMBudget], None]] = [] + self._oom_callbacks: list[Callable[[], None]] = [] + + # Initialize coordination file + self._ensure_coordination_file() + + def _init_nvml(self) -> None: + """Initialize NVML for GPU monitoring.""" + try: + import pynvml + + pynvml.nvmlInit() + self._nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(self.device_id) + self._nvml_initialized = True + logger.info("NVML initialized successfully") + except ImportError: + logger.warning("pynvml not available, falling back to torch.cuda") + except Exception as e: + logger.warning(f"Failed to initialize NVML: {e}, falling back to torch.cuda") + + def _ensure_coordination_file(self) -> None: + """Create coordination file if it doesn't exist.""" + if not self.COORDINATION_FILE.exists(): + initial_state = { + "audiocraft": None, + "comfyui": None, + "priority": None, + "last_update": time.time(), + } + self._write_coordination_state(initial_state) + + def get_memory_info(self) -> dict[str, int]: + """Get current GPU memory status. + + Returns: + Dictionary with memory values in megabytes: + - total: Total VRAM + - used: Used VRAM (system-wide) + - free: Free VRAM + - torch_allocated: PyTorch allocated memory + - torch_reserved: PyTorch reserved memory + - torch_cached: PyTorch cached memory + """ + with self._lock: + if self._nvml_initialized: + return self._get_memory_info_nvml() + return self._get_memory_info_torch() + + def _get_memory_info_nvml(self) -> dict[str, int]: + """Get memory info using NVML (more accurate).""" + import pynvml + + info = pynvml.nvmlDeviceGetMemoryInfo(self._nvml_handle) + torch_allocated = torch.cuda.memory_allocated(self.device) + torch_reserved = torch.cuda.memory_reserved(self.device) + + return { + "total": info.total // (1024 * 1024), + "used": info.used // (1024 * 1024), + "free": info.free // (1024 * 1024), + "torch_allocated": torch_allocated // (1024 * 1024), + "torch_reserved": torch_reserved // (1024 * 1024), + "torch_cached": (torch_reserved - torch_allocated) // (1024 * 1024), + } + + def _get_memory_info_torch(self) -> dict[str, int]: + """Get memory info using torch.cuda (fallback).""" + props = torch.cuda.get_device_properties(self.device) + allocated = torch.cuda.memory_allocated(self.device) + reserved = torch.cuda.memory_reserved(self.device) + + # Note: This is less accurate for system-wide usage + return { + "total": props.total_memory // (1024 * 1024), + "used": reserved // (1024 * 1024), + "free": (props.total_memory - reserved) // (1024 * 1024), + "torch_allocated": allocated // (1024 * 1024), + "torch_reserved": reserved // (1024 * 1024), + "torch_cached": (reserved - allocated) // (1024 * 1024), + } + + def get_available_budget(self) -> VRAMBudget: + """Calculate available VRAM budget considering ComfyUI. 
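+
+        As a worked example with illustrative numbers on a 24 GB card:
+        total=24576, used=12000 (system-wide), torch_allocated=6000 (our
+        own usage, reclaimable by eviction), a 10240 MB ComfyUI reserve
+        and a 1024 MB safety buffer give
+        24576 - 12000 + 6000 - 10240 - 1024 = 7312 MB available.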
+ + Returns: + VRAMBudget with current allocation information + """ + mem = self.get_memory_info() + + # Check ComfyUI's actual usage via coordination file + comfyui_state = self.get_comfyui_status() + if comfyui_state and comfyui_state.status != "yielded": + # Use actual ComfyUI usage + buffer, or reserve, whichever is higher + effective_comfyui_reserve = max( + self.comfyui_reserve_mb, + comfyui_state.vram_used_mb + 2048, # 2GB headroom + ) + else: + effective_comfyui_reserve = self.comfyui_reserve_mb + + available = max( + 0, + mem["total"] + - mem["used"] + + mem["torch_allocated"] # Our own usage doesn't count against us + - effective_comfyui_reserve + - self.safety_buffer_mb, + ) + + return VRAMBudget( + total_mb=mem["total"], + used_mb=mem["used"], + free_mb=mem["free"], + reserved_comfyui_mb=effective_comfyui_reserve, + safety_buffer_mb=self.safety_buffer_mb, + available_mb=available, + ) + + def can_load_model(self, vram_required_mb: int) -> tuple[bool, str]: + """Check if a model can fit in available VRAM. + + Args: + vram_required_mb: VRAM needed by the model + + Returns: + Tuple of (can_load, reason_message) + """ + budget = self.get_available_budget() + + if vram_required_mb <= budget.available_mb: + return True, "Sufficient VRAM available" + + deficit = vram_required_mb - budget.available_mb + return False, ( + f"Insufficient VRAM: need {vram_required_mb}MB, " + f"available {budget.available_mb}MB (deficit: {deficit}MB)" + ) + + def force_cleanup(self) -> int: + """Force GPU memory cleanup. + + Returns: + Freed memory in megabytes (approximate) + """ + with self._lock: + before = self.get_memory_info() + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize(self.device) + + after = self.get_memory_info() + freed = before["torch_reserved"] - after["torch_reserved"] + + if freed > 0: + logger.info(f"Freed {freed}MB of GPU memory") + + return freed + + def get_status(self) -> dict[str, Any]: + """Get detailed GPU status for UI display. 
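+
+        An illustrative return value (numbers invented for shape only)::
+
+            {"device": "cuda:0", "total_gb": 24.0, "used_gb": 11.5,
+             "free_gb": 12.5, "utilization_percent": 47.9,
+             "available_for_models_gb": 6.2, "comfyui_reserve_gb": 10.0,
+             "torch_allocated_gb": 4.1, "torch_cached_gb": 0.7}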
+ + Returns: + Dictionary with status information + """ + mem = self.get_memory_info() + budget = self.get_available_budget() + + return { + "device": str(self.device), + "total_gb": round(mem["total"] / 1024, 2), + "used_gb": round(mem["used"] / 1024, 2), + "free_gb": round(mem["free"] / 1024, 2), + "utilization_percent": round(budget.utilization * 100, 1), + "available_for_models_gb": round(budget.available_mb / 1024, 2), + "comfyui_reserve_gb": round(budget.reserved_comfyui_mb / 1024, 2), + "torch_allocated_gb": round(mem["torch_allocated"] / 1024, 2), + "torch_cached_gb": round(mem["torch_cached"] / 1024, 2), + } + + # ComfyUI Coordination Methods + + def _read_coordination_state(self) -> dict[str, Any]: + """Read coordination state from file.""" + try: + if self.COORDINATION_FILE.exists(): + return json.loads(self.COORDINATION_FILE.read_text()) + except (json.JSONDecodeError, IOError) as e: + logger.warning(f"Failed to read coordination file: {e}") + return {} + + def _write_coordination_state(self, state: dict[str, Any]) -> None: + """Write coordination state to file with locking.""" + import fcntl + + try: + self.LOCK_FILE.parent.mkdir(parents=True, exist_ok=True) + with open(self.LOCK_FILE, "w") as lock: + fcntl.flock(lock, fcntl.LOCK_EX) + try: + self.COORDINATION_FILE.write_text(json.dumps(state, indent=2)) + finally: + fcntl.flock(lock, fcntl.LOCK_UN) + except IOError as e: + logger.warning(f"Failed to write coordination file: {e}") + + def update_status( + self, + vram_used_mb: int, + vram_requested_mb: int = 0, + status: str = "idle", + ) -> None: + """Update AudioCraft's status in coordination file. + + Args: + vram_used_mb: Current VRAM usage + vram_requested_mb: VRAM needed for pending operation + status: Current status ("idle", "working", "requesting_priority") + """ + state = self._read_coordination_state() + state["audiocraft"] = { + "timestamp": time.time(), + "service": "audiocraft", + "vram_used_mb": vram_used_mb, + "vram_requested_mb": vram_requested_mb, + "status": status, + } + state["last_update"] = time.time() + self._write_coordination_state(state) + + def get_comfyui_status(self) -> Optional[GPUState]: + """Get ComfyUI's current status. + + Returns: + GPUState if ComfyUI is active and status is fresh, None otherwise + """ + state = self._read_coordination_state() + comfyui_data = state.get("comfyui") + + if not comfyui_data: + return None + + # Check if stale + if time.time() - comfyui_data.get("timestamp", 0) > self.STALE_THRESHOLD: + return None + + return GPUState( + timestamp=comfyui_data["timestamp"], + service="comfyui", + vram_used_mb=comfyui_data.get("vram_used_mb", 0), + vram_requested_mb=comfyui_data.get("vram_requested_mb", 0), + status=comfyui_data.get("status", "unknown"), + ) + + def request_priority(self, vram_needed_mb: int, timeout: float = 30.0) -> bool: + """Request VRAM priority from ComfyUI. + + Signals ComfyUI to release VRAM if possible. 
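+
+        The request is published through the shared coordination file as a
+        JSON entry of the form (timestamp illustrative)::
+
+            {"priority": {"requester": "audiocraft",
+                          "vram_needed_mb": 8000,
+                          "timestamp": 1732560000.0}}
+
+        after which this method polls ComfyUI's entry until its status
+        reads "yielded" or the timeout elapses.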
+ + Args: + vram_needed_mb: Amount of VRAM needed + timeout: Seconds to wait for ComfyUI to yield + + Returns: + True if ComfyUI acknowledged and yielded, False otherwise + """ + state = self._read_coordination_state() + state["priority"] = { + "requester": "audiocraft", + "vram_needed_mb": vram_needed_mb, + "timestamp": time.time(), + } + self._write_coordination_state(state) + + logger.info(f"Requesting {vram_needed_mb}MB VRAM from ComfyUI...") + + # Wait for ComfyUI to respond + start = time.time() + while time.time() - start < timeout: + comfyui = self.get_comfyui_status() + if comfyui and comfyui.status == "yielded": + logger.info("ComfyUI yielded VRAM") + return True + time.sleep(0.5) + + logger.warning("ComfyUI did not yield VRAM within timeout") + return False + + def is_comfyui_busy(self) -> bool: + """Check if ComfyUI is actively processing. + + Returns: + True if ComfyUI is working, False otherwise + """ + status = self.get_comfyui_status() + return status is not None and status.status == "working" + + # Callback Registration + + def on_low_memory(self, callback: Callable[[VRAMBudget], None]) -> None: + """Register callback for low memory warnings. + + Args: + callback: Function to call with budget info when memory is low + """ + self._low_memory_callbacks.append(callback) + + def on_oom(self, callback: Callable[[], None]) -> None: + """Register callback for OOM events. + + Args: + callback: Function to call when OOM occurs + """ + self._oom_callbacks.append(callback) + + def check_memory_pressure(self, warning_threshold: float = 0.85) -> None: + """Check memory pressure and trigger callbacks if needed. + + Args: + warning_threshold: Utilization threshold for warnings (0-1) + """ + budget = self.get_available_budget() + + if budget.utilization >= warning_threshold: + logger.warning( + f"High GPU memory pressure: {budget.utilization*100:.1f}% utilized" + ) + for callback in self._low_memory_callbacks: + try: + callback(budget) + except Exception as e: + logger.error(f"Low memory callback failed: {e}") + + def __del__(self) -> None: + """Cleanup NVML on destruction.""" + if self._nvml_initialized: + try: + import pynvml + + pynvml.nvmlShutdown() + except Exception: + pass diff --git a/src/core/model_registry.py b/src/core/model_registry.py new file mode 100644 index 0000000..113f825 --- /dev/null +++ b/src/core/model_registry.py @@ -0,0 +1,487 @@ +"""Model registry for discovering and managing AudioCraft model adapters.""" + +import asyncio +import logging +import threading +import time +from contextlib import contextmanager +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Generator, Optional, Type + +import yaml + +from src.core.base_model import BaseAudioModel, ConditioningType +from src.core.gpu_manager import GPUMemoryManager + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelVariantConfig: + """Configuration for a model variant.""" + + hf_id: str + vram_mb: int + max_duration: float = 30.0 + channels: int = 1 + conditioning: list[str] = field(default_factory=list) + description: str = "" + + +@dataclass +class ModelFamilyConfig: + """Configuration for a model family.""" + + enabled: bool + display_name: str + description: str + default_variant: str + variants: dict[str, ModelVariantConfig] + + +@dataclass +class ModelHandle: + """Handle for a loaded model with reference counting.""" + + model: BaseAudioModel + model_id: str + variant: str + loaded_at: float + last_accessed: float + ref_count: int = 0 + + def 
touch(self) -> None: + """Update last accessed time.""" + self.last_accessed = time.time() + + +class ModelRegistry: + """Central registry for discovering and managing model adapters. + + Handles: + - Loading model configurations from YAML + - Lazy loading models on demand + - LRU eviction when VRAM is constrained + - Reference counting to prevent unloading during use + - Automatic idle timeout for unused models + """ + + def __init__( + self, + config_path: Path, + gpu_manager: GPUMemoryManager, + max_cached_models: int = 2, + idle_timeout_minutes: int = 15, + ): + """Initialize the model registry. + + Args: + config_path: Path to models.yaml configuration + gpu_manager: GPU memory manager instance + max_cached_models: Maximum models to keep loaded + idle_timeout_minutes: Unload models after this idle time + """ + self.config_path = config_path + self.gpu_manager = gpu_manager + self.max_cached_models = max_cached_models + self.idle_timeout_seconds = idle_timeout_minutes * 60 + + # Model configurations + self._model_configs: dict[str, ModelFamilyConfig] = {} + self._default_params: dict[str, Any] = {} + + # Loaded model handles + self._handles: dict[str, ModelHandle] = {} # Key: "model_id/variant" + self._access_order: list[str] = [] # LRU tracking + + # Registered adapter classes + self._adapter_classes: dict[str, Type[BaseAudioModel]] = {} + + # Threading + self._lock = threading.RLock() + self._cleanup_thread: Optional[threading.Thread] = None + self._stop_cleanup = threading.Event() + + # Load configuration + self._load_config() + + def _load_config(self) -> None: + """Load model configurations from YAML file.""" + if not self.config_path.exists(): + logger.warning(f"Model config not found: {self.config_path}") + return + + with open(self.config_path) as f: + config = yaml.safe_load(f) + + # Parse model families + for model_id, model_config in config.get("models", {}).items(): + if not model_config.get("enabled", True): + continue + + variants = {} + for variant_name, variant_config in model_config.get("variants", {}).items(): + variants[variant_name] = ModelVariantConfig( + hf_id=variant_config["hf_id"], + vram_mb=variant_config["vram_mb"], + max_duration=variant_config.get("max_duration", 30.0), + channels=variant_config.get("channels", 1), + conditioning=variant_config.get("conditioning", []), + description=variant_config.get("description", ""), + ) + + self._model_configs[model_id] = ModelFamilyConfig( + enabled=model_config.get("enabled", True), + display_name=model_config.get("display_name", model_id), + description=model_config.get("description", ""), + default_variant=model_config.get("default_variant", "medium"), + variants=variants, + ) + + # Parse default generation parameters + self._default_params = config.get("defaults", {}).get("generation", {}) + + logger.info(f"Loaded {len(self._model_configs)} model families from config") + + def register_adapter( + self, model_id: str, adapter_class: Type[BaseAudioModel] + ) -> None: + """Register a model adapter class. + + Args: + model_id: Model family ID (e.g., 'musicgen') + adapter_class: Adapter class implementing BaseAudioModel + """ + self._adapter_classes[model_id] = adapter_class + logger.debug(f"Registered adapter for {model_id}: {adapter_class.__name__}") + + def list_models(self) -> list[dict[str, Any]]: + """List all available models with their configurations. 
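+
+        Each entry mirrors the YAML config plus live state, for example
+        (an abbreviated, illustrative sketch)::
+
+            {"model_id": "musicgen", "variant": "small",
+             "display_name": "MusicGen", "vram_mb": 1500,
+             "max_duration": 30.0, "is_loaded": False,
+             "can_load": True, "has_adapter": True, ...}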
+ + Returns: + List of model information dictionaries + """ + models = [] + + for model_id, config in self._model_configs.items(): + for variant_name, variant in config.variants.items(): + key = f"{model_id}/{variant_name}" + handle = self._handles.get(key) + + can_load, reason = self.gpu_manager.can_load_model(variant.vram_mb) + + models.append({ + "model_id": model_id, + "variant": variant_name, + "display_name": config.display_name, + "description": variant.description or config.description, + "hf_id": variant.hf_id, + "vram_mb": variant.vram_mb, + "max_duration": variant.max_duration, + "channels": variant.channels, + "conditioning": variant.conditioning, + "is_default": variant_name == config.default_variant, + "is_loaded": handle is not None, + "can_load": can_load, + "load_reason": reason, + "has_adapter": model_id in self._adapter_classes, + }) + + return models + + def get_model_config( + self, model_id: str, variant: Optional[str] = None + ) -> tuple[ModelFamilyConfig, ModelVariantConfig]: + """Get configuration for a model. + + Args: + model_id: Model family ID + variant: Specific variant, or None for default + + Returns: + Tuple of (family_config, variant_config) + + Raises: + ValueError: If model or variant not found + """ + if model_id not in self._model_configs: + raise ValueError(f"Unknown model: {model_id}") + + family = self._model_configs[model_id] + variant = variant or family.default_variant + + if variant not in family.variants: + raise ValueError(f"Unknown variant {variant} for {model_id}") + + return family, family.variants[variant] + + def get_loaded_models(self) -> list[dict[str, Any]]: + """Get information about currently loaded models. + + Returns: + List of loaded model information + """ + with self._lock: + return [ + { + "model_id": handle.model_id, + "variant": handle.variant, + "loaded_at": handle.loaded_at, + "last_accessed": handle.last_accessed, + "ref_count": handle.ref_count, + "idle_seconds": time.time() - handle.last_accessed, + } + for handle in self._handles.values() + ] + + @contextmanager + def get_model( + self, model_id: str, variant: Optional[str] = None + ) -> Generator[BaseAudioModel, None, None]: + """Get a model, loading it if necessary. + + Context manager that handles reference counting to prevent + unloading during use. + + Args: + model_id: Model family ID + variant: Specific variant, or None for default + + Yields: + Loaded model instance + + Raises: + ValueError: If model not found or cannot be loaded + RuntimeError: If VRAM insufficient + """ + family, variant_config = self.get_model_config(model_id, variant) + variant = variant or family.default_variant + key = f"{model_id}/{variant}" + + with self._lock: + # Get or load model + if key not in self._handles: + self._load_model(model_id, variant) + + handle = self._handles[key] + handle.ref_count += 1 + handle.touch() + + # Update LRU order + if key in self._access_order: + self._access_order.remove(key) + self._access_order.append(key) + + try: + yield handle.model + finally: + with self._lock: + handle.ref_count -= 1 + + def _load_model(self, model_id: str, variant: str) -> None: + """Load a model into memory. + + Must be called with self._lock held. 
+ + Args: + model_id: Model family ID + variant: Variant to load + + Raises: + ValueError: If no adapter registered + RuntimeError: If VRAM insufficient + """ + key = f"{model_id}/{variant}" + family, variant_config = self.get_model_config(model_id, variant) + + # Check for adapter + if model_id not in self._adapter_classes: + raise ValueError(f"No adapter registered for {model_id}") + + # Check VRAM + can_load, reason = self.gpu_manager.can_load_model(variant_config.vram_mb) + if not can_load: + # Try to free memory by evicting models + self._evict_for_space(variant_config.vram_mb) + can_load, reason = self.gpu_manager.can_load_model(variant_config.vram_mb) + if not can_load: + raise RuntimeError(reason) + + # Create and load model + logger.info(f"Loading model {key}...") + adapter_class = self._adapter_classes[model_id] + model = adapter_class(variant=variant) + model.load() + + # Register handle + self._handles[key] = ModelHandle( + model=model, + model_id=model_id, + variant=variant, + loaded_at=time.time(), + last_accessed=time.time(), + ) + self._access_order.append(key) + + # Update GPU status + mem = self.gpu_manager.get_memory_info() + self.gpu_manager.update_status(mem["torch_allocated"], status="working") + + logger.info(f"Model {key} loaded successfully") + + def _evict_for_space(self, needed_mb: int) -> bool: + """Evict models to free up VRAM. + + Must be called with self._lock held. + + Args: + needed_mb: VRAM needed + + Returns: + True if enough space was freed + """ + freed = 0 + budget = self.gpu_manager.get_available_budget() + deficit = needed_mb - budget.available_mb + + if deficit <= 0: + return True + + # Evict LRU models that have no active references + for key in list(self._access_order): + if deficit <= 0: + break + + handle = self._handles.get(key) + if handle and handle.ref_count == 0: + _, variant_config = self.get_model_config( + handle.model_id, handle.variant + ) + logger.info(f"Evicting {key} to free {variant_config.vram_mb}MB") + self._unload_model(key) + freed += variant_config.vram_mb + deficit -= variant_config.vram_mb + + self.gpu_manager.force_cleanup() + return deficit <= 0 + + def _unload_model(self, key: str) -> None: + """Unload a model from memory. + + Must be called with self._lock held. + + Args: + key: Model key (model_id/variant) + """ + if key not in self._handles: + return + + handle = self._handles[key] + if handle.ref_count > 0: + logger.warning(f"Cannot unload {key}: {handle.ref_count} active references") + return + + logger.info(f"Unloading model {key}") + handle.model.unload() + del self._handles[key] + + if key in self._access_order: + self._access_order.remove(key) + + self.gpu_manager.force_cleanup() + + def unload_model(self, model_id: str, variant: Optional[str] = None) -> bool: + """Manually unload a model. + + Args: + model_id: Model family ID + variant: Variant to unload, or None for all variants + + Returns: + True if model was unloaded + """ + with self._lock: + if variant: + key = f"{model_id}/{variant}" + if key in self._handles: + self._unload_model(key) + return True + else: + # Unload all variants of this model + keys = [k for k in self._handles if k.startswith(f"{model_id}/")] + for key in keys: + self._unload_model(key) + return bool(keys) + return False + + def preload_model(self, model_id: str, variant: Optional[str] = None) -> bool: + """Preload a model into memory. 
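+
+        Useful for warming a model at startup so the first request does not
+        pay the load cost, e.g.:
+
+            registry.preload_model("musicgen")  # default variant
+            registry.preload_model("magnet", "small-10secs")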
+ + Args: + model_id: Model family ID + variant: Variant to load + + Returns: + True if model was loaded successfully + """ + family, _ = self.get_model_config(model_id, variant) + variant = variant or family.default_variant + key = f"{model_id}/{variant}" + + with self._lock: + if key in self._handles: + return True # Already loaded + + try: + self._load_model(model_id, variant) + return True + except Exception as e: + logger.error(f"Failed to preload {key}: {e}") + return False + + def start_cleanup_thread(self) -> None: + """Start background thread for idle model cleanup.""" + if self._cleanup_thread is not None: + return + + def cleanup_loop(): + while not self._stop_cleanup.is_set(): + self._cleanup_idle_models() + self._stop_cleanup.wait(60) # Check every minute + + self._cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True) + self._cleanup_thread.start() + logger.info("Started model cleanup thread") + + def stop_cleanup_thread(self) -> None: + """Stop the background cleanup thread.""" + if self._cleanup_thread is not None: + self._stop_cleanup.set() + self._cleanup_thread.join(timeout=5) + self._cleanup_thread = None + self._stop_cleanup.clear() + + def _cleanup_idle_models(self) -> None: + """Unload models that have been idle too long.""" + with self._lock: + now = time.time() + for key, handle in list(self._handles.items()): + idle_time = now - handle.last_accessed + if idle_time > self.idle_timeout_seconds and handle.ref_count == 0: + logger.info( + f"Unloading idle model {key} (idle for {idle_time/60:.1f} min)" + ) + self._unload_model(key) + + def get_default_params(self) -> dict[str, Any]: + """Get default generation parameters. + + Returns: + Dictionary of default parameter values + """ + return self._default_params.copy() + + def __del__(self) -> None: + """Cleanup on destruction.""" + self.stop_cleanup_thread() diff --git a/src/core/oom_handler.py b/src/core/oom_handler.py new file mode 100644 index 0000000..e85a1a4 --- /dev/null +++ b/src/core/oom_handler.py @@ -0,0 +1,297 @@ +"""OOM (Out of Memory) handling and recovery strategies.""" + +import functools +import gc +import logging +import time +from typing import Any, Callable, Optional, ParamSpec, TypeVar + +import torch + +from src.core.gpu_manager import GPUMemoryManager + +logger = logging.getLogger(__name__) + +P = ParamSpec("P") +R = TypeVar("R") + + +class OOMRecoveryError(Exception): + """Raised when OOM recovery fails after all strategies exhausted.""" + + pass + + +class OOMHandler: + """Handles CUDA Out of Memory errors with multi-level recovery strategies. + + Recovery levels: + 1. Clear PyTorch CUDA cache + 2. Evict unused models from registry + 3. Request ComfyUI to yield VRAM + """ + + def __init__( + self, + gpu_manager: GPUMemoryManager, + model_registry: Optional[Any] = None, # Avoid circular import + max_retries: int = 3, + retry_delay: float = 0.5, + ): + """Initialize OOM handler. 
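+
+        Typical wiring (sketch; the registry is attached after construction
+        to avoid a circular import):
+
+            handler = OOMHandler(gpu_manager)
+            handler.set_model_registry(registry)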
+ + Args: + gpu_manager: GPU memory manager instance + model_registry: Optional model registry for eviction + max_retries: Maximum recovery attempts + retry_delay: Delay between retries in seconds + """ + self.gpu_manager = gpu_manager + self.model_registry = model_registry + self.max_retries = max_retries + self.retry_delay = retry_delay + + # Track OOM events for monitoring + self._oom_count = 0 + self._last_oom_time: Optional[float] = None + + @property + def oom_count(self) -> int: + """Number of OOM events handled.""" + return self._oom_count + + def set_model_registry(self, registry: Any) -> None: + """Set model registry (to avoid circular import at init time).""" + self.model_registry = registry + + def with_oom_recovery(self, func: Callable[P, R]) -> Callable[P, R]: + """Decorator that wraps function with OOM recovery logic. + + Usage: + @oom_handler.with_oom_recovery + def generate_audio(...): + ... + + Args: + func: Function to wrap + + Returns: + Wrapped function with OOM recovery + """ + + @functools.wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + last_exception = None + + for attempt in range(self.max_retries + 1): + try: + if attempt > 0: + logger.info(f"Retry attempt {attempt}/{self.max_retries}") + time.sleep(self.retry_delay) + + return func(*args, **kwargs) + + except torch.cuda.OutOfMemoryError as e: + last_exception = e + self._oom_count += 1 + self._last_oom_time = time.time() + + logger.warning(f"CUDA OOM detected (attempt {attempt + 1}): {e}") + + if attempt < self.max_retries: + self._execute_recovery_strategy(attempt) + else: + logger.error( + f"OOM recovery failed after {self.max_retries} attempts" + ) + + raise OOMRecoveryError( + f"OOM recovery failed after {self.max_retries} attempts" + ) from last_exception + + return wrapper + + def _execute_recovery_strategy(self, level: int) -> None: + """Execute recovery strategy based on severity level. + + Args: + level: Recovery level (0-2) + """ + strategies = [ + self._strategy_clear_cache, + self._strategy_evict_models, + self._strategy_request_comfyui_yield, + ] + + # Execute all strategies up to and including current level + for i in range(min(level + 1, len(strategies))): + logger.info(f"Executing recovery strategy {i + 1}: {strategies[i].__name__}") + strategies[i]() + + def _strategy_clear_cache(self) -> None: + """Level 1: Clear PyTorch CUDA cache. + + This is the fastest and least disruptive recovery strategy. + Clears cached memory that PyTorch holds for future allocations. + """ + logger.info("Clearing CUDA cache...") + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + + # Reset peak memory stats for monitoring + torch.cuda.reset_peak_memory_stats() + + freed = self.gpu_manager.force_cleanup() + logger.info(f"Cache cleared, freed approximately {freed}MB") + + def _strategy_evict_models(self) -> None: + """Level 2: Evict non-essential models from registry. + + Unloads all models that don't have active references, + freeing their VRAM for the current operation. 
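+
+        Models with a non-zero ref_count are skipped, so a generation that is
+        currently running never loses the model it is using.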
+ """ + if self.model_registry is None: + logger.warning("No model registry available for eviction") + self._strategy_clear_cache() + return + + logger.info("Evicting unused models...") + + # Get list of loaded models + loaded = self.model_registry.get_loaded_models() + evicted = [] + + for model_info in loaded: + # Only evict models with no active references + if model_info["ref_count"] == 0: + model_id = model_info["model_id"] + variant = model_info["variant"] + logger.info(f"Evicting {model_id}/{variant}") + self.model_registry.unload_model(model_id, variant) + evicted.append(f"{model_id}/{variant}") + + # Clear cache after eviction + self._strategy_clear_cache() + + logger.info(f"Evicted {len(evicted)} model(s): {evicted}") + + def _strategy_request_comfyui_yield(self) -> None: + """Level 3: Request ComfyUI to yield VRAM. + + Uses the coordination protocol to ask ComfyUI to + temporarily release GPU memory. + """ + logger.info("Requesting ComfyUI to yield VRAM...") + + # First, evict our own models + self._strategy_evict_models() + + # Calculate how much VRAM we need + budget = self.gpu_manager.get_available_budget() + needed = max(4096, budget.total_mb // 4) # Request at least 4GB or 25% of total + + # Request priority from ComfyUI + success = self.gpu_manager.request_priority(needed, timeout=15.0) + + if success: + logger.info("ComfyUI yielded VRAM successfully") + else: + logger.warning("ComfyUI did not yield VRAM within timeout") + + # Final cache clear + self._strategy_clear_cache() + + def recover_from_oom(self, level: int = 0) -> bool: + """Manually trigger OOM recovery. + + Args: + level: Recovery level to execute (0-2) + + Returns: + True if recovery was successful (memory was freed) + """ + before = self.gpu_manager.get_memory_info() + + self._execute_recovery_strategy(level) + + after = self.gpu_manager.get_memory_info() + freed = before["used"] - after["used"] + + logger.info(f"Manual recovery freed {freed}MB") + return freed > 0 + + def check_memory_for_operation(self, required_mb: int) -> bool: + """Check if there's enough memory for an operation. + + If not enough, attempts recovery strategies. + + Args: + required_mb: Memory required in megabytes + + Returns: + True if enough memory is available (possibly after recovery) + """ + budget = self.gpu_manager.get_available_budget() + + if budget.available_mb >= required_mb: + return True + + logger.info( + f"Need {required_mb}MB but only {budget.available_mb}MB available. " + "Attempting recovery..." + ) + + # Try progressively more aggressive recovery + for level in range(3): + self._execute_recovery_strategy(level) + budget = self.gpu_manager.get_available_budget() + + if budget.available_mb >= required_mb: + logger.info(f"Recovery successful at level {level + 1}") + return True + + logger.error( + f"Could not free enough memory. Need {required_mb}MB, " + f"have {budget.available_mb}MB" + ) + return False + + def get_stats(self) -> dict[str, Any]: + """Get OOM handling statistics. + + Returns: + Dictionary with OOM stats + """ + return { + "oom_count": self._oom_count, + "last_oom_time": self._last_oom_time, + "max_retries": self.max_retries, + "has_registry": self.model_registry is not None, + } + + +# Module-level convenience function +def oom_safe( + gpu_manager: GPUMemoryManager, + model_registry: Optional[Any] = None, + max_retries: int = 3, +) -> Callable[[Callable[P, R]], Callable[P, R]]: + """Decorator factory for OOM-safe functions. 
+
+    Usage:
+        @oom_safe(gpu_manager, model_registry)
+        def generate_audio(...):
+            ...
+
+    Args:
+        gpu_manager: GPU memory manager
+        model_registry: Optional model registry for eviction
+        max_retries: Maximum recovery attempts
+
+    Returns:
+        Decorator function
+    """
+    handler = OOMHandler(gpu_manager, model_registry, max_retries)
+    return handler.with_oom_recovery
diff --git a/src/main.py b/src/main.py
new file mode 100644
index 0000000..6575eaf
--- /dev/null
+++ b/src/main.py
@@ -0,0 +1,84 @@
+"""AudioCraft Studio - Main Application Entry Point."""
+
+import asyncio
+import logging
+import sys
+from pathlib import Path
+
+# Add project root to path for imports
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from config.settings import get_settings
+from src.core.gpu_manager import GPUMemoryManager
+from src.core.model_registry import ModelRegistry
+from src.storage.database import Database
+
+logger = logging.getLogger(__name__)
+
+
+async def init_app():
+    """Initialize application components."""
+    settings = get_settings()
+
+    # Configure logging
+    logging.basicConfig(
+        level=getattr(logging, settings.log_level),
+        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+    )
+
+    # Ensure directories exist
+    settings.ensure_directories()
+
+    # Initialize GPU manager
+    gpu_manager = GPUMemoryManager(
+        comfyui_reserve_gb=settings.comfyui_reserve_gb,
+        safety_buffer_gb=settings.safety_buffer_gb,
+    )
+
+    # Initialize model registry
+    registry = ModelRegistry(
+        config_path=settings.models_config,
+        gpu_manager=gpu_manager,
+        max_cached_models=settings.max_cached_models,
+        idle_timeout_minutes=settings.idle_unload_minutes,
+    )
+
+    # Initialize database
+    db = Database(settings.database_path)
+    await db.connect()
+
+    logger.info("AudioCraft Studio initialized")
+    logger.info(f"GPU Status: {gpu_manager.get_status()}")
+    logger.info(f"Available model variants: {len(registry.list_models())}")
+
+    return {
+        "settings": settings,
+        "gpu_manager": gpu_manager,
+        "registry": registry,
+        "database": db,
+    }
+
+
+def main():
+    """Main entry point: run a quick smoke test of the core components."""
+    print("AudioCraft Studio - Starting...")
+    print("Running an initialization check of the core infrastructure.")
+    print("\nModules covered by this build:")
+    print(" - Model adapters (musicgen, audiogen, magnet, style, jasco)")
+    print(" - Services layer (generation, batch, project)")
+    print(" - Gradio UI")
+    print(" - REST API")
+    print(" - Deployment (Docker/RunPod)")
+
+    # Quick initialization test
+    async def test_init():
+        components = await init_app()
+        print(f"\nDatabase path: {components['settings'].database_path}")
+        print(f"GPU status: {components['gpu_manager'].get_status()}")
+        await components["database"].close()
+
+    asyncio.run(test_init())
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/models/__init__.py b/src/models/__init__.py
new file mode 100644
index 0000000..f92b0c8
--- /dev/null
+++ b/src/models/__init__.py
@@ -0,0 +1,32 @@
+"""AudioCraft model adapters.
+
+This module contains adapters that wrap AudioCraft's models with a
+consistent interface for the application.
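+
+Call register_all_adapters() once at startup to wire every family into a
+ModelRegistry, e.g.:
+
+    from src.models import register_all_adapters
+    register_all_adapters(registry)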
+""" + +from src.models.musicgen.adapter import MusicGenAdapter +from src.models.audiogen.adapter import AudioGenAdapter +from src.models.magnet.adapter import MAGNeTAdapter +from src.models.musicgen_style.adapter import MusicGenStyleAdapter +from src.models.jasco.adapter import JASCOAdapter + +__all__ = [ + "MusicGenAdapter", + "AudioGenAdapter", + "MAGNeTAdapter", + "MusicGenStyleAdapter", + "JASCOAdapter", +] + + +def register_all_adapters(registry) -> None: + """Register all model adapters with the registry. + + Args: + registry: ModelRegistry instance to register adapters with + """ + registry.register_adapter("musicgen", MusicGenAdapter) + registry.register_adapter("audiogen", AudioGenAdapter) + registry.register_adapter("magnet", MAGNeTAdapter) + registry.register_adapter("musicgen-style", MusicGenStyleAdapter) + registry.register_adapter("jasco", JASCOAdapter) diff --git a/src/models/audiogen/__init__.py b/src/models/audiogen/__init__.py new file mode 100644 index 0000000..7ae1261 --- /dev/null +++ b/src/models/audiogen/__init__.py @@ -0,0 +1,5 @@ +"""AudioGen model adapter.""" + +from src.models.audiogen.adapter import AudioGenAdapter + +__all__ = ["AudioGenAdapter"] diff --git a/src/models/audiogen/adapter.py b/src/models/audiogen/adapter.py new file mode 100644 index 0000000..0a88e50 --- /dev/null +++ b/src/models/audiogen/adapter.py @@ -0,0 +1,203 @@ +"""AudioGen model adapter for text-to-sound effects generation.""" + +import gc +import logging +import random +from typing import Any, Optional + +import torch + +from src.core.base_model import ( + BaseAudioModel, + ConditioningType, + GenerationRequest, + GenerationResult, +) + +logger = logging.getLogger(__name__) + + +class AudioGenAdapter(BaseAudioModel): + """Adapter for Facebook's AudioGen model. + + Generates sound effects and environmental audio from text descriptions. + Optimized for non-musical audio like sound effects, ambiences, and foley. + """ + + VARIANTS = { + "medium": { + "hf_id": "facebook/audiogen-medium", + "vram_mb": 5000, + "max_duration": 10, + "channels": 1, + }, + } + + def __init__(self, variant: str = "medium"): + """Initialize AudioGen adapter. + + Args: + variant: Model variant (currently only 'medium' available) + """ + if variant not in self.VARIANTS: + raise ValueError( + f"Unknown AudioGen variant: {variant}. 
" + f"Available: {list(self.VARIANTS.keys())}" + ) + + self._variant = variant + self._config = self.VARIANTS[variant] + self._model = None + self._device: Optional[torch.device] = None + + @property + def model_id(self) -> str: + return "audiogen" + + @property + def variant(self) -> str: + return self._variant + + @property + def display_name(self) -> str: + return f"AudioGen ({self._variant})" + + @property + def description(self) -> str: + return "Text-to-sound effects generation" + + @property + def vram_estimate_mb(self) -> int: + return self._config["vram_mb"] + + @property + def max_duration(self) -> float: + return self._config["max_duration"] + + @property + def sample_rate(self) -> int: + if self._model is not None: + return self._model.sample_rate + return 16000 # AudioGen default sample rate + + @property + def supports_conditioning(self) -> list[ConditioningType]: + return [ConditioningType.TEXT] + + @property + def is_loaded(self) -> bool: + return self._model is not None + + @property + def device(self) -> Optional[torch.device]: + return self._device + + def load(self, device: str = "cuda") -> None: + """Load the AudioGen model.""" + if self._model is not None: + logger.warning(f"AudioGen {self._variant} already loaded") + return + + logger.info(f"Loading AudioGen {self._variant} from {self._config['hf_id']}...") + + try: + from audiocraft.models import AudioGen + + self._device = torch.device(device) + self._model = AudioGen.get_pretrained(self._config["hf_id"]) + self._model.to(self._device) + + logger.info( + f"AudioGen {self._variant} loaded successfully " + f"(sample_rate={self._model.sample_rate})" + ) + + except Exception as e: + self._model = None + self._device = None + logger.error(f"Failed to load AudioGen {self._variant}: {e}") + raise RuntimeError(f"Failed to load AudioGen: {e}") from e + + def unload(self) -> None: + """Unload the model and free memory.""" + if self._model is None: + return + + logger.info(f"Unloading AudioGen {self._variant}...") + + del self._model + self._model = None + self._device = None + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def generate(self, request: GenerationRequest) -> GenerationResult: + """Generate sound effects from text prompts. 
+ + Args: + request: Generation parameters including prompts + + Returns: + GenerationResult with audio tensor and metadata + """ + self.validate_request(request) + + # Set random seed + seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + # Configure generation + self._model.set_generation_params( + duration=request.duration, + temperature=request.temperature, + top_k=request.top_k, + top_p=request.top_p, + cfg_coef=request.cfg_coef, + ) + + logger.info( + f"Generating {len(request.prompts)} sound effect(s) with AudioGen " + f"(duration={request.duration}s)" + ) + + # Generate audio + with torch.inference_mode(): + audio = self._model.generate(request.prompts) + + actual_duration = audio.shape[-1] / self.sample_rate + + logger.info( + f"Generated {audio.shape[0]} sample(s), " + f"duration={actual_duration:.2f}s" + ) + + return GenerationResult( + audio=audio.cpu(), + sample_rate=self.sample_rate, + duration=actual_duration, + model_id=self.model_id, + variant=self._variant, + parameters={ + "duration": request.duration, + "temperature": request.temperature, + "top_k": request.top_k, + "top_p": request.top_p, + "cfg_coef": request.cfg_coef, + "prompts": request.prompts, + }, + seed=seed, + ) + + def get_default_params(self) -> dict[str, Any]: + """Get default generation parameters.""" + return { + "duration": 5.0, + "temperature": 1.0, + "top_k": 250, + "top_p": 0.0, + "cfg_coef": 3.0, + } diff --git a/src/models/jasco/__init__.py b/src/models/jasco/__init__.py new file mode 100644 index 0000000..58a286a --- /dev/null +++ b/src/models/jasco/__init__.py @@ -0,0 +1,5 @@ +"""JASCO model adapter.""" + +from src.models.jasco.adapter import JASCOAdapter + +__all__ = ["JASCOAdapter"] diff --git a/src/models/jasco/adapter.py b/src/models/jasco/adapter.py new file mode 100644 index 0000000..1e739ca --- /dev/null +++ b/src/models/jasco/adapter.py @@ -0,0 +1,348 @@ +"""JASCO model adapter for chord and drum-conditioned music generation.""" + +import gc +import logging +import random +from typing import Any, Optional + +import torch + +from src.core.base_model import ( + BaseAudioModel, + ConditioningType, + GenerationRequest, + GenerationResult, +) + +logger = logging.getLogger(__name__) + + +class JASCOAdapter(BaseAudioModel): + """Adapter for Facebook's JASCO model. + + JASCO (Joint Audio and Symbolic Conditioning) enables music generation + with control over chord progressions and drum patterns alongside text. + """ + + VARIANTS = { + "chords-drums-400M": { + "hf_id": "facebook/jasco-chords-drums-400M", + "vram_mb": 2000, + "max_duration": 10, + "channels": 1, + }, + "chords-drums-1B": { + "hf_id": "facebook/jasco-chords-drums-1B", + "vram_mb": 4000, + "max_duration": 10, + "channels": 1, + }, + } + + # Common chord types for validation + VALID_CHORD_TYPES = [ + "maj", "min", "dim", "aug", "7", "maj7", "min7", "dim7", + "sus2", "sus4", "add9", "6", "min6", "9", "min9", "maj9", + ] + + def __init__(self, variant: str = "chords-drums-400M"): + """Initialize JASCO adapter. + + Args: + variant: Model variant to use + """ + if variant not in self.VARIANTS: + raise ValueError( + f"Unknown JASCO variant: {variant}. 
" + f"Available: {list(self.VARIANTS.keys())}" + ) + + self._variant = variant + self._config = self.VARIANTS[variant] + self._model = None + self._device: Optional[torch.device] = None + + @property + def model_id(self) -> str: + return "jasco" + + @property + def variant(self) -> str: + return self._variant + + @property + def display_name(self) -> str: + return f"JASCO ({self._variant})" + + @property + def description(self) -> str: + return "Chord and drum-conditioned music generation" + + @property + def vram_estimate_mb(self) -> int: + return self._config["vram_mb"] + + @property + def max_duration(self) -> float: + return self._config["max_duration"] + + @property + def sample_rate(self) -> int: + if self._model is not None: + return self._model.sample_rate + return 32000 + + @property + def supports_conditioning(self) -> list[ConditioningType]: + return [ConditioningType.TEXT, ConditioningType.CHORDS, ConditioningType.DRUMS] + + @property + def is_loaded(self) -> bool: + return self._model is not None + + @property + def device(self) -> Optional[torch.device]: + return self._device + + def load(self, device: str = "cuda") -> None: + """Load the JASCO model.""" + if self._model is not None: + logger.warning(f"JASCO {self._variant} already loaded") + return + + logger.info(f"Loading JASCO {self._variant} from {self._config['hf_id']}...") + + try: + from audiocraft.models import JASCO + + self._device = torch.device(device) + self._model = JASCO.get_pretrained(self._config["hf_id"]) + self._model.to(self._device) + + logger.info( + f"JASCO {self._variant} loaded successfully " + f"(sample_rate={self._model.sample_rate})" + ) + + except Exception as e: + self._model = None + self._device = None + logger.error(f"Failed to load JASCO {self._variant}: {e}") + raise RuntimeError(f"Failed to load JASCO: {e}") from e + + def unload(self) -> None: + """Unload the model and free memory.""" + if self._model is None: + return + + logger.info(f"Unloading JASCO {self._variant}...") + + del self._model + self._model = None + self._device = None + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + @staticmethod + def parse_chord_progression( + chords: list[dict[str, Any]], duration: float + ) -> list[tuple[float, float, str]]: + """Parse chord progression from user input format. + + Args: + chords: List of chord dictionaries with keys: + - time: Start time in seconds + - chord: Chord name (e.g., "C", "Am", "G7") + duration: Total duration for calculating end times + + Returns: + List of (start_time, end_time, chord_name) tuples + + Example input: + [ + {"time": 0.0, "chord": "C"}, + {"time": 2.0, "chord": "Am"}, + {"time": 4.0, "chord": "F"}, + {"time": 6.0, "chord": "G"}, + ] + """ + if not chords: + return [] + + # Sort by time + sorted_chords = sorted(chords, key=lambda x: x["time"]) + + # Build (start, end, chord) tuples + result = [] + for i, chord_info in enumerate(sorted_chords): + start = chord_info["time"] + # End time is either next chord's start or total duration + if i + 1 < len(sorted_chords): + end = sorted_chords[i + 1]["time"] + else: + end = duration + result.append((start, end, chord_info["chord"])) + + return result + + @staticmethod + def create_drum_pattern( + pattern: str, duration: float, bpm: float = 120.0 + ) -> list[tuple[float, str]]: + """Create drum events from a pattern string. 
+ + Args: + pattern: Pattern string (e.g., "kick,snare,kick,snare") + or "4/4" for common time signature + duration: Total duration in seconds + bpm: Beats per minute + + Returns: + List of (time, drum_type) tuples + """ + beat_duration = 60.0 / bpm + events = [] + + if pattern in ["4/4", "common"]: + # Standard 4/4 rock pattern + time = 0.0 + beat = 0 + while time < duration: + if beat % 4 == 0: + events.append((time, "kick")) + elif beat % 4 == 2: + events.append((time, "snare")) + if beat % 2 == 0: + events.append((time, "hihat")) + time += beat_duration / 2 + beat += 1 + else: + # Parse comma-separated pattern + drum_types = pattern.split(",") + time = 0.0 + idx = 0 + while time < duration: + drum = drum_types[idx % len(drum_types)].strip() + if drum: + events.append((time, drum)) + time += beat_duration + idx += 1 + + return events + + def generate(self, request: GenerationRequest) -> GenerationResult: + """Generate music with chord and drum conditioning. + + Args: + request: Generation parameters with optional conditioning: + - chords: List of {"time": float, "chord": str} dicts + - drums: Drum pattern string or list of (time, drum_type) + - bpm: Beats per minute for drum pattern + + Returns: + GenerationResult with audio tensor and metadata + """ + self.validate_request(request) + + # Set random seed + seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + # Configure generation parameters + self._model.set_generation_params( + duration=request.duration, + temperature=request.temperature, + top_k=request.top_k, + top_p=request.top_p, + cfg_coef=request.cfg_coef, + ) + + # Process chord conditioning + chords_input = request.conditioning.get("chords") + chords_formatted = None + if chords_input: + if isinstance(chords_input, list) and len(chords_input) > 0: + if isinstance(chords_input[0], dict): + chords_formatted = self.parse_chord_progression( + chords_input, request.duration + ) + else: + # Already in (start, end, chord) format + chords_formatted = chords_input + + # Process drum conditioning + drums_input = request.conditioning.get("drums") + bpm = request.conditioning.get("bpm", 120.0) + drums_formatted = None + if drums_input: + if isinstance(drums_input, str): + drums_formatted = self.create_drum_pattern( + drums_input, request.duration, bpm + ) + else: + drums_formatted = drums_input + + logger.info( + f"Generating {len(request.prompts)} sample(s) with JASCO " + f"(duration={request.duration}s, chords={chords_formatted is not None}, " + f"drums={drums_formatted is not None})" + ) + + with torch.inference_mode(): + # Build conditioning dict for JASCO + conditioning = {} + if chords_formatted: + conditioning["chords"] = chords_formatted + if drums_formatted: + conditioning["drums"] = drums_formatted + + if conditioning: + audio = self._model.generate( + descriptions=request.prompts, + **conditioning, + ) + else: + # Generate without symbolic conditioning + audio = self._model.generate(request.prompts) + + actual_duration = audio.shape[-1] / self.sample_rate + + logger.info( + f"Generated {audio.shape[0]} sample(s), " + f"duration={actual_duration:.2f}s" + ) + + return GenerationResult( + audio=audio.cpu(), + sample_rate=self.sample_rate, + duration=actual_duration, + model_id=self.model_id, + variant=self._variant, + parameters={ + "duration": request.duration, + "temperature": request.temperature, + "top_k": request.top_k, + "top_p": request.top_p, + 
"cfg_coef": request.cfg_coef, + "prompts": request.prompts, + "chords": chords_formatted, + "drums": drums_formatted, + "bpm": bpm, + }, + seed=seed, + ) + + def get_default_params(self) -> dict[str, Any]: + """Get default generation parameters for JASCO.""" + return { + "duration": 10.0, + "temperature": 1.0, + "top_k": 250, + "top_p": 0.0, + "cfg_coef": 3.0, + "bpm": 120.0, + } diff --git a/src/models/magnet/__init__.py b/src/models/magnet/__init__.py new file mode 100644 index 0000000..99906d3 --- /dev/null +++ b/src/models/magnet/__init__.py @@ -0,0 +1,5 @@ +"""MAGNeT model adapter.""" + +from src.models.magnet.adapter import MAGNeTAdapter + +__all__ = ["MAGNeTAdapter"] diff --git a/src/models/magnet/adapter.py b/src/models/magnet/adapter.py new file mode 100644 index 0000000..98f2f2e --- /dev/null +++ b/src/models/magnet/adapter.py @@ -0,0 +1,253 @@ +"""MAGNeT model adapter for fast non-autoregressive audio generation.""" + +import gc +import logging +import random +from typing import Any, Optional + +import torch + +from src.core.base_model import ( + BaseAudioModel, + ConditioningType, + GenerationRequest, + GenerationResult, +) + +logger = logging.getLogger(__name__) + + +class MAGNeTAdapter(BaseAudioModel): + """Adapter for Facebook's MAGNeT model. + + MAGNeT (Masked Audio Generation using Non-autoregressive Transformers) + provides faster generation than autoregressive models like MusicGen. + Supports both music and sound effect generation. + """ + + VARIANTS = { + "small-10secs": { + "hf_id": "facebook/magnet-small-10secs", + "vram_mb": 1500, + "max_duration": 10, + "channels": 1, + "audio_type": "music", + }, + "medium-10secs": { + "hf_id": "facebook/magnet-medium-10secs", + "vram_mb": 5000, + "max_duration": 10, + "channels": 1, + "audio_type": "music", + }, + "small-30secs": { + "hf_id": "facebook/magnet-small-30secs", + "vram_mb": 1800, + "max_duration": 30, + "channels": 1, + "audio_type": "music", + }, + "medium-30secs": { + "hf_id": "facebook/magnet-medium-30secs", + "vram_mb": 6000, + "max_duration": 30, + "channels": 1, + "audio_type": "music", + }, + "audio-small-10secs": { + "hf_id": "facebook/audio-magnet-small", + "vram_mb": 1500, + "max_duration": 10, + "channels": 1, + "audio_type": "sound", + }, + "audio-medium-10secs": { + "hf_id": "facebook/audio-magnet-medium", + "vram_mb": 5000, + "max_duration": 10, + "channels": 1, + "audio_type": "sound", + }, + } + + def __init__(self, variant: str = "medium-10secs"): + """Initialize MAGNeT adapter. + + Args: + variant: Model variant to use + """ + if variant not in self.VARIANTS: + raise ValueError( + f"Unknown MAGNeT variant: {variant}. 
" + f"Available: {list(self.VARIANTS.keys())}" + ) + + self._variant = variant + self._config = self.VARIANTS[variant] + self._model = None + self._device: Optional[torch.device] = None + + @property + def model_id(self) -> str: + return "magnet" + + @property + def variant(self) -> str: + return self._variant + + @property + def display_name(self) -> str: + return f"MAGNeT ({self._variant})" + + @property + def description(self) -> str: + audio_type = self._config.get("audio_type", "music") + return f"Fast non-autoregressive {audio_type} generation" + + @property + def vram_estimate_mb(self) -> int: + return self._config["vram_mb"] + + @property + def max_duration(self) -> float: + return self._config["max_duration"] + + @property + def sample_rate(self) -> int: + if self._model is not None: + return self._model.sample_rate + return 32000 + + @property + def supports_conditioning(self) -> list[ConditioningType]: + return [ConditioningType.TEXT] + + @property + def is_loaded(self) -> bool: + return self._model is not None + + @property + def device(self) -> Optional[torch.device]: + return self._device + + def load(self, device: str = "cuda") -> None: + """Load the MAGNeT model.""" + if self._model is not None: + logger.warning(f"MAGNeT {self._variant} already loaded") + return + + logger.info(f"Loading MAGNeT {self._variant} from {self._config['hf_id']}...") + + try: + from audiocraft.models import MAGNeT + + self._device = torch.device(device) + self._model = MAGNeT.get_pretrained(self._config["hf_id"]) + self._model.to(self._device) + + logger.info( + f"MAGNeT {self._variant} loaded successfully " + f"(sample_rate={self._model.sample_rate})" + ) + + except Exception as e: + self._model = None + self._device = None + logger.error(f"Failed to load MAGNeT {self._variant}: {e}") + raise RuntimeError(f"Failed to load MAGNeT: {e}") from e + + def unload(self) -> None: + """Unload the model and free memory.""" + if self._model is None: + return + + logger.info(f"Unloading MAGNeT {self._variant}...") + + del self._model + self._model = None + self._device = None + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def generate(self, request: GenerationRequest) -> GenerationResult: + """Generate audio from text prompts using MAGNeT. + + MAGNeT uses a non-autoregressive approach with iterative decoding, + which is significantly faster than autoregressive models. 
+ + Args: + request: Generation parameters including prompts + + Returns: + GenerationResult with audio tensor and metadata + """ + self.validate_request(request) + + # Set random seed + seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + # Configure generation parameters + # MAGNeT has different parameters than MusicGen + self._model.set_generation_params( + duration=request.duration, + temperature=request.temperature, + top_k=request.top_k, + top_p=request.top_p, + cfg_coef=request.cfg_coef, + # MAGNeT-specific parameters + decoding_steps=[ + int(request.conditioning.get("decoding_steps_1", 20)), + int(request.conditioning.get("decoding_steps_2", 10)), + int(request.conditioning.get("decoding_steps_3", 10)), + int(request.conditioning.get("decoding_steps_4", 10)), + ], + span_arrangement=request.conditioning.get("span_arrangement", "nonoverlap"), + ) + + logger.info( + f"Generating {len(request.prompts)} sample(s) with MAGNeT {self._variant} " + f"(duration={request.duration}s)" + ) + + # Generate audio + with torch.inference_mode(): + audio = self._model.generate(request.prompts) + + actual_duration = audio.shape[-1] / self.sample_rate + + logger.info( + f"Generated {audio.shape[0]} sample(s), " + f"duration={actual_duration:.2f}s" + ) + + return GenerationResult( + audio=audio.cpu(), + sample_rate=self.sample_rate, + duration=actual_duration, + model_id=self.model_id, + variant=self._variant, + parameters={ + "duration": request.duration, + "temperature": request.temperature, + "top_k": request.top_k, + "top_p": request.top_p, + "cfg_coef": request.cfg_coef, + "prompts": request.prompts, + }, + seed=seed, + ) + + def get_default_params(self) -> dict[str, Any]: + """Get default generation parameters for MAGNeT.""" + return { + "duration": 10.0, + "temperature": 3.0, # MAGNeT works better with higher temperature + "top_k": 0, # Use top_p instead for MAGNeT + "top_p": 0.9, + "cfg_coef": 3.0, + } diff --git a/src/models/musicgen/__init__.py b/src/models/musicgen/__init__.py new file mode 100644 index 0000000..0c1176a --- /dev/null +++ b/src/models/musicgen/__init__.py @@ -0,0 +1,5 @@ +"""MusicGen model adapter.""" + +from src.models.musicgen.adapter import MusicGenAdapter + +__all__ = ["MusicGenAdapter"] diff --git a/src/models/musicgen/adapter.py b/src/models/musicgen/adapter.py new file mode 100644 index 0000000..4998613 --- /dev/null +++ b/src/models/musicgen/adapter.py @@ -0,0 +1,290 @@ +"""MusicGen model adapter for text-to-music generation.""" + +import gc +import logging +import random +from typing import Any, Optional + +import torch + +from src.core.base_model import ( + BaseAudioModel, + ConditioningType, + GenerationRequest, + GenerationResult, +) + +logger = logging.getLogger(__name__) + + +class MusicGenAdapter(BaseAudioModel): + """Adapter for Facebook's MusicGen model. + + Supports text-to-music generation with optional melody conditioning. + Available variants: small, medium, large, melody, and stereo versions. 
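+
+    Typical adapter lifecycle (sketch; request is a GenerationRequest):
+
+        adapter = MusicGenAdapter(variant="small")
+        adapter.load(device="cuda")
+        result = adapter.generate(request)
+        adapter.unload()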
+ """ + + # Variant configurations + VARIANTS = { + "small": { + "hf_id": "facebook/musicgen-small", + "vram_mb": 1500, + "max_duration": 30, + "channels": 1, + "conditioning": [], + }, + "medium": { + "hf_id": "facebook/musicgen-medium", + "vram_mb": 5000, + "max_duration": 30, + "channels": 1, + "conditioning": [], + }, + "large": { + "hf_id": "facebook/musicgen-large", + "vram_mb": 10000, + "max_duration": 30, + "channels": 1, + "conditioning": [], + }, + "melody": { + "hf_id": "facebook/musicgen-melody", + "vram_mb": 5000, + "max_duration": 30, + "channels": 1, + "conditioning": [ConditioningType.MELODY], + }, + "stereo-small": { + "hf_id": "facebook/musicgen-stereo-small", + "vram_mb": 1800, + "max_duration": 30, + "channels": 2, + "conditioning": [], + }, + "stereo-medium": { + "hf_id": "facebook/musicgen-stereo-medium", + "vram_mb": 6000, + "max_duration": 30, + "channels": 2, + "conditioning": [], + }, + "stereo-large": { + "hf_id": "facebook/musicgen-stereo-large", + "vram_mb": 12000, + "max_duration": 30, + "channels": 2, + "conditioning": [], + }, + "stereo-melody": { + "hf_id": "facebook/musicgen-stereo-melody", + "vram_mb": 6000, + "max_duration": 30, + "channels": 2, + "conditioning": [ConditioningType.MELODY], + }, + } + + def __init__(self, variant: str = "medium"): + """Initialize MusicGen adapter. + + Args: + variant: Model variant to use (small, medium, large, melody, etc.) + + Raises: + ValueError: If variant is not recognized + """ + if variant not in self.VARIANTS: + raise ValueError( + f"Unknown MusicGen variant: {variant}. " + f"Available: {list(self.VARIANTS.keys())}" + ) + + self._variant = variant + self._config = self.VARIANTS[variant] + self._model = None + self._device: Optional[torch.device] = None + + @property + def model_id(self) -> str: + return "musicgen" + + @property + def variant(self) -> str: + return self._variant + + @property + def display_name(self) -> str: + return f"MusicGen ({self._variant})" + + @property + def description(self) -> str: + if "melody" in self._variant: + return "Text-to-music with melody conditioning" + elif "stereo" in self._variant: + return "Stereo text-to-music generation" + return "Text-to-music generation" + + @property + def vram_estimate_mb(self) -> int: + return self._config["vram_mb"] + + @property + def max_duration(self) -> float: + return self._config["max_duration"] + + @property + def sample_rate(self) -> int: + if self._model is not None: + return self._model.sample_rate + return 32000 # Default MusicGen sample rate + + @property + def supports_conditioning(self) -> list[ConditioningType]: + return [ConditioningType.TEXT] + self._config["conditioning"] + + @property + def is_loaded(self) -> bool: + return self._model is not None + + @property + def device(self) -> Optional[torch.device]: + return self._device + + def load(self, device: str = "cuda") -> None: + """Load the MusicGen model. + + Args: + device: Target device ('cuda', 'cuda:0', 'cpu', etc.) 
+ """ + if self._model is not None: + logger.warning(f"MusicGen {self._variant} already loaded") + return + + logger.info(f"Loading MusicGen {self._variant} from {self._config['hf_id']}...") + + try: + from audiocraft.models import MusicGen + + self._device = torch.device(device) + self._model = MusicGen.get_pretrained(self._config["hf_id"]) + self._model.to(self._device) + + logger.info( + f"MusicGen {self._variant} loaded successfully " + f"(sample_rate={self._model.sample_rate})" + ) + + except Exception as e: + self._model = None + self._device = None + logger.error(f"Failed to load MusicGen {self._variant}: {e}") + raise RuntimeError(f"Failed to load MusicGen: {e}") from e + + def unload(self) -> None: + """Unload the model and free memory.""" + if self._model is None: + return + + logger.info(f"Unloading MusicGen {self._variant}...") + + del self._model + self._model = None + self._device = None + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def generate(self, request: GenerationRequest) -> GenerationResult: + """Generate music from text prompts. + + Args: + request: Generation parameters including prompts + + Returns: + GenerationResult with audio tensor and metadata + + Raises: + RuntimeError: If model not loaded + ValueError: If request is invalid + """ + self.validate_request(request) + + # Set random seed for reproducibility + seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + # Configure generation parameters + self._model.set_generation_params( + duration=request.duration, + temperature=request.temperature, + top_k=request.top_k, + top_p=request.top_p, + cfg_coef=request.cfg_coef, + ) + + logger.info( + f"Generating {len(request.prompts)} sample(s) with MusicGen {self._variant} " + f"(duration={request.duration}s, temp={request.temperature})" + ) + + # Generate audio + with torch.inference_mode(): + melody_audio = request.conditioning.get("melody") + melody_sr = request.conditioning.get("melody_sr", self.sample_rate) + + if melody_audio is not None and ConditioningType.MELODY in self.supports_conditioning: + # Melody-conditioned generation + if isinstance(melody_audio, str): + # Load from file path + import torchaudio + melody_tensor, melody_sr = torchaudio.load(melody_audio) + melody_tensor = melody_tensor.to(self._device) + else: + melody_tensor = torch.tensor(melody_audio).to(self._device) + + audio = self._model.generate_with_chroma( + descriptions=request.prompts, + melody_wavs=melody_tensor.unsqueeze(0) if melody_tensor.dim() == 1 else melody_tensor, + melody_sample_rate=melody_sr, + ) + else: + # Standard text-to-music generation + audio = self._model.generate(request.prompts) + + # audio shape: [batch, channels, samples] + actual_duration = audio.shape[-1] / self.sample_rate + + logger.info( + f"Generated {audio.shape[0]} sample(s), " + f"duration={actual_duration:.2f}s, shape={audio.shape}" + ) + + return GenerationResult( + audio=audio.cpu(), + sample_rate=self.sample_rate, + duration=actual_duration, + model_id=self.model_id, + variant=self._variant, + parameters={ + "duration": request.duration, + "temperature": request.temperature, + "top_k": request.top_k, + "top_p": request.top_p, + "cfg_coef": request.cfg_coef, + "prompts": request.prompts, + }, + seed=seed, + ) + + def get_default_params(self) -> dict[str, Any]: + """Get default generation parameters.""" + return { + "duration": 10.0, + "temperature": 1.0, 
+ "top_k": 250, + "top_p": 0.0, + "cfg_coef": 3.0, + } diff --git a/src/models/musicgen_style/__init__.py b/src/models/musicgen_style/__init__.py new file mode 100644 index 0000000..e772bba --- /dev/null +++ b/src/models/musicgen_style/__init__.py @@ -0,0 +1,5 @@ +"""MusicGen Style model adapter.""" + +from src.models.musicgen_style.adapter import MusicGenStyleAdapter + +__all__ = ["MusicGenStyleAdapter"] diff --git a/src/models/musicgen_style/adapter.py b/src/models/musicgen_style/adapter.py new file mode 100644 index 0000000..7926ec2 --- /dev/null +++ b/src/models/musicgen_style/adapter.py @@ -0,0 +1,277 @@ +"""MusicGen Style model adapter for style-conditioned music generation.""" + +import gc +import logging +import random +from typing import Any, Optional + +import torch +import torchaudio + +from src.core.base_model import ( + BaseAudioModel, + ConditioningType, + GenerationRequest, + GenerationResult, +) + +logger = logging.getLogger(__name__) + + +class MusicGenStyleAdapter(BaseAudioModel): + """Adapter for Facebook's MusicGen Style model. + + Generates music conditioned on both text and a style reference audio. + Extracts style features from the reference and applies them to new generations. + """ + + VARIANTS = { + "medium": { + "hf_id": "facebook/musicgen-style", + "vram_mb": 5000, + "max_duration": 30, + "channels": 1, + }, + } + + def __init__(self, variant: str = "medium"): + """Initialize MusicGen Style adapter. + + Args: + variant: Model variant (currently only 'medium' available) + """ + if variant not in self.VARIANTS: + raise ValueError( + f"Unknown MusicGen Style variant: {variant}. " + f"Available: {list(self.VARIANTS.keys())}" + ) + + self._variant = variant + self._config = self.VARIANTS[variant] + self._model = None + self._device: Optional[torch.device] = None + + @property + def model_id(self) -> str: + return "musicgen-style" + + @property + def variant(self) -> str: + return self._variant + + @property + def display_name(self) -> str: + return f"MusicGen Style ({self._variant})" + + @property + def description(self) -> str: + return "Style-conditioned music generation from reference audio" + + @property + def vram_estimate_mb(self) -> int: + return self._config["vram_mb"] + + @property + def max_duration(self) -> float: + return self._config["max_duration"] + + @property + def sample_rate(self) -> int: + if self._model is not None: + return self._model.sample_rate + return 32000 + + @property + def supports_conditioning(self) -> list[ConditioningType]: + return [ConditioningType.TEXT, ConditioningType.STYLE] + + @property + def is_loaded(self) -> bool: + return self._model is not None + + @property + def device(self) -> Optional[torch.device]: + return self._device + + def load(self, device: str = "cuda") -> None: + """Load the MusicGen Style model.""" + if self._model is not None: + logger.warning(f"MusicGen Style {self._variant} already loaded") + return + + logger.info(f"Loading MusicGen Style {self._variant}...") + + try: + from audiocraft.models import MusicGen + + self._device = torch.device(device) + self._model = MusicGen.get_pretrained(self._config["hf_id"]) + self._model.to(self._device) + + logger.info( + f"MusicGen Style {self._variant} loaded successfully " + f"(sample_rate={self._model.sample_rate})" + ) + + except Exception as e: + self._model = None + self._device = None + logger.error(f"Failed to load MusicGen Style {self._variant}: {e}") + raise RuntimeError(f"Failed to load MusicGen Style: {e}") from e + + def unload(self) -> None: + """Unload 
the model and free memory.""" + if self._model is None: + return + + logger.info(f"Unloading MusicGen Style {self._variant}...") + + del self._model + self._model = None + self._device = None + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def _load_style_audio( + self, style_input: Any, target_sr: int + ) -> tuple[torch.Tensor, int]: + """Load and prepare style reference audio. + + Args: + style_input: File path, tensor, or numpy array + target_sr: Target sample rate + + Returns: + Tuple of (audio_tensor, sample_rate) + """ + if isinstance(style_input, str): + # Load from file + audio, sr = torchaudio.load(style_input) + if sr != target_sr: + audio = torchaudio.functional.resample(audio, sr, target_sr) + return audio.to(self._device), target_sr + elif isinstance(style_input, torch.Tensor): + return style_input.to(self._device), target_sr + else: + # Assume numpy array + return torch.tensor(style_input).to(self._device), target_sr + + def generate(self, request: GenerationRequest) -> GenerationResult: + """Generate music conditioned on text and style reference. + + Args: + request: Generation parameters including prompts and style conditioning + + Returns: + GenerationResult with audio tensor and metadata + + Note: + Style conditioning requires 'style' in request.conditioning with either: + - File path to audio + - Audio tensor + - Numpy array + """ + self.validate_request(request) + + # Set random seed + seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + # Get style conditioning parameters + style_audio = request.conditioning.get("style") + eval_q = request.conditioning.get("eval_q", 3) + excerpt_length = request.conditioning.get("excerpt_length", 3.0) + + # Configure generation parameters + self._model.set_generation_params( + duration=request.duration, + temperature=request.temperature, + top_k=request.top_k, + top_p=request.top_p, + cfg_coef=request.cfg_coef, + ) + + logger.info( + f"Generating {len(request.prompts)} sample(s) with MusicGen Style " + f"(duration={request.duration}s, style_conditioned={style_audio is not None})" + ) + + with torch.inference_mode(): + if style_audio is not None: + # Load style reference + style_tensor, style_sr = self._load_style_audio( + style_audio, self.sample_rate + ) + + # Ensure proper shape [batch, channels, samples] + if style_tensor.dim() == 1: + style_tensor = style_tensor.unsqueeze(0).unsqueeze(0) + elif style_tensor.dim() == 2: + style_tensor = style_tensor.unsqueeze(0) + + # Set style conditioner parameters + if hasattr(self._model, 'set_style_conditioner_params'): + self._model.set_style_conditioner_params( + eval_q=eval_q, + excerpt_length=excerpt_length, + ) + + # Generate with style conditioning + # Expand style to match number of prompts if needed + if style_tensor.shape[0] == 1 and len(request.prompts) > 1: + style_tensor = style_tensor.expand(len(request.prompts), -1, -1) + + audio = self._model.generate_with_chroma( + descriptions=request.prompts, + melody_wavs=style_tensor, + melody_sample_rate=style_sr, + ) + else: + # Generate without style (falls back to standard MusicGen behavior) + logger.warning( + "No style reference provided, generating without style conditioning" + ) + audio = self._model.generate(request.prompts) + + actual_duration = audio.shape[-1] / self.sample_rate + + logger.info( + f"Generated {audio.shape[0]} sample(s), " + f"duration={actual_duration:.2f}s" + ) 
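+        # The seed used for this run is stored on the result, so a caller can
+        # pass it back via request.seed to re-run the same request.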
+ + return GenerationResult( + audio=audio.cpu(), + sample_rate=self.sample_rate, + duration=actual_duration, + model_id=self.model_id, + variant=self._variant, + parameters={ + "duration": request.duration, + "temperature": request.temperature, + "top_k": request.top_k, + "top_p": request.top_p, + "cfg_coef": request.cfg_coef, + "prompts": request.prompts, + "style_conditioned": style_audio is not None, + "eval_q": eval_q, + "excerpt_length": excerpt_length, + }, + seed=seed, + ) + + def get_default_params(self) -> dict[str, Any]: + """Get default generation parameters for MusicGen Style.""" + return { + "duration": 10.0, + "temperature": 1.0, + "top_k": 250, + "top_p": 0.0, + "cfg_coef": 3.0, + "eval_q": 3, + "excerpt_length": 3.0, + } diff --git a/src/services/__init__.py b/src/services/__init__.py new file mode 100644 index 0000000..6fc2b56 --- /dev/null +++ b/src/services/__init__.py @@ -0,0 +1,13 @@ +"""Services layer for AudioCraft Studio.""" + +from src.services.generation_service import GenerationService +from src.services.batch_processor import BatchProcessor, GenerationJob, JobStatus +from src.services.project_service import ProjectService + +__all__ = [ + "GenerationService", + "BatchProcessor", + "GenerationJob", + "JobStatus", + "ProjectService", +] diff --git a/src/services/batch_processor.py b/src/services/batch_processor.py new file mode 100644 index 0000000..db37a10 --- /dev/null +++ b/src/services/batch_processor.py @@ -0,0 +1,397 @@ +"""Batch processor for queued audio generation jobs.""" + +import asyncio +import logging +import time +import uuid +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any, Callable, Optional + +logger = logging.getLogger(__name__) + + +class JobStatus(str, Enum): + """Status of a generation job.""" + + PENDING = "pending" + PROCESSING = "processing" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +@dataclass +class GenerationJob: + """A queued generation job.""" + + id: str + model_id: str + variant: Optional[str] + prompts: list[str] + parameters: dict[str, Any] + conditioning: dict[str, Any] + project_id: Optional[str] + preset_used: Optional[str] + tags: list[str] + + # Status tracking + status: JobStatus = JobStatus.PENDING + progress: float = 0.0 + progress_message: str = "" + created_at: datetime = field(default_factory=datetime.utcnow) + started_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + + # Results + result_id: Optional[str] = None # Generation ID if completed + audio_path: Optional[str] = None + error: Optional[str] = None + + @classmethod + def create( + cls, + model_id: str, + variant: Optional[str], + prompts: list[str], + duration: float = 10.0, + temperature: float = 1.0, + top_k: int = 250, + top_p: float = 0.0, + cfg_coef: float = 3.0, + seed: Optional[int] = None, + conditioning: Optional[dict[str, Any]] = None, + project_id: Optional[str] = None, + preset_used: Optional[str] = None, + tags: Optional[list[str]] = None, + ) -> "GenerationJob": + """Create a new generation job.""" + return cls( + id=f"job_{uuid.uuid4().hex[:12]}", + model_id=model_id, + variant=variant, + prompts=prompts, + parameters={ + "duration": duration, + "temperature": temperature, + "top_k": top_k, + "top_p": top_p, + "cfg_coef": cfg_coef, + "seed": seed, + }, + conditioning=conditioning or {}, + project_id=project_id, + preset_used=preset_used, + tags=tags or [], + ) + + def to_dict(self) -> dict[str, Any]: + """Convert job to 
dictionary for API responses.""" + return { + "id": self.id, + "model_id": self.model_id, + "variant": self.variant, + "prompts": self.prompts, + "parameters": self.parameters, + "status": self.status.value, + "progress": self.progress, + "progress_message": self.progress_message, + "created_at": self.created_at.isoformat(), + "started_at": self.started_at.isoformat() if self.started_at else None, + "completed_at": self.completed_at.isoformat() if self.completed_at else None, + "result_id": self.result_id, + "audio_path": self.audio_path, + "error": self.error, + } + + +class BatchProcessor: + """Manages a queue of generation jobs. + + Features: + - Async job queue with configurable concurrency + - Progress tracking and callbacks + - Job cancellation + - Priority support (future enhancement) + """ + + def __init__( + self, + generation_service: Any, # Avoid circular import + max_queue_size: int = 100, + max_concurrent: int = 1, # GPU operations should be serialized + ): + """Initialize batch processor. + + Args: + generation_service: GenerationService instance + max_queue_size: Maximum jobs in queue + max_concurrent: Maximum concurrent generations (usually 1 for GPU) + """ + self.generation_service = generation_service + self.max_queue_size = max_queue_size + self.max_concurrent = max_concurrent + + # Job tracking + self._jobs: dict[str, GenerationJob] = {} + self._queue: asyncio.Queue[str] = asyncio.Queue(maxsize=max_queue_size) + + # Processing control + self._workers: list[asyncio.Task] = [] + self._running = False + self._lock = asyncio.Lock() + + # Callbacks + self._on_job_complete: list[Callable[[GenerationJob], None]] = [] + self._on_job_failed: list[Callable[[GenerationJob], None]] = [] + self._on_progress: list[Callable[[GenerationJob], None]] = [] + + async def start(self) -> None: + """Start the batch processor workers.""" + if self._running: + return + + self._running = True + + # Start worker tasks + for i in range(self.max_concurrent): + worker = asyncio.create_task(self._worker_loop(i)) + self._workers.append(worker) + + logger.info(f"Batch processor started with {self.max_concurrent} worker(s)") + + async def stop(self) -> None: + """Stop the batch processor and wait for pending jobs.""" + if not self._running: + return + + self._running = False + + # Cancel workers + for worker in self._workers: + worker.cancel() + + # Wait for workers to finish + await asyncio.gather(*self._workers, return_exceptions=True) + self._workers.clear() + + logger.info("Batch processor stopped") + + async def submit(self, job: GenerationJob) -> GenerationJob: + """Submit a job to the queue. + + Args: + job: Job to submit + + Returns: + The submitted job with ID + + Raises: + RuntimeError: If queue is full + """ + async with self._lock: + if len(self._jobs) >= self.max_queue_size: + raise RuntimeError( + f"Queue full (max {self.max_queue_size} jobs). " + "Please wait for jobs to complete." + ) + + self._jobs[job.id] = job + await self._queue.put(job.id) + + logger.info(f"Job {job.id} submitted to queue (position: {self._queue.qsize()})") + return job + + async def cancel(self, job_id: str) -> bool: + """Cancel a pending job. 
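+
+        Only jobs still in the PENDING state can be cancelled; once a worker
+        picks a job up it runs to completion. Sketch:
+
+            job = await processor.submit(
+                GenerationJob.create("musicgen", None, ["lofi beat"])
+            )
+            ok = await processor.cancel(job.id)  # True only while still queued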
+ + Args: + job_id: ID of job to cancel + + Returns: + True if job was cancelled, False if not found or already processing + """ + async with self._lock: + job = self._jobs.get(job_id) + if job is None: + return False + + if job.status != JobStatus.PENDING: + logger.warning(f"Cannot cancel job {job_id} with status {job.status}") + return False + + job.status = JobStatus.CANCELLED + job.completed_at = datetime.utcnow() + + logger.info(f"Job {job_id} cancelled") + return True + + def get_job(self, job_id: str) -> Optional[GenerationJob]: + """Get a job by ID.""" + return self._jobs.get(job_id) + + def get_queue_status(self) -> dict[str, Any]: + """Get current queue status.""" + jobs_by_status = {} + for job in self._jobs.values(): + status = job.status.value + jobs_by_status[status] = jobs_by_status.get(status, 0) + 1 + + return { + "queue_size": self._queue.qsize(), + "total_jobs": len(self._jobs), + "jobs_by_status": jobs_by_status, + "running": self._running, + "max_queue_size": self.max_queue_size, + } + + def list_jobs( + self, + status: Optional[JobStatus] = None, + limit: int = 50, + ) -> list[GenerationJob]: + """List jobs with optional status filter. + + Args: + status: Filter by status + limit: Maximum jobs to return + + Returns: + List of jobs ordered by creation time (newest first) + """ + jobs = list(self._jobs.values()) + + if status: + jobs = [j for j in jobs if j.status == status] + + # Sort by created_at descending + jobs.sort(key=lambda j: j.created_at, reverse=True) + + return jobs[:limit] + + def cleanup_completed(self, max_age_hours: float = 24.0) -> int: + """Remove old completed/failed jobs from memory. + + Args: + max_age_hours: Remove jobs older than this + + Returns: + Number of jobs removed + """ + cutoff = datetime.utcnow().timestamp() - (max_age_hours * 3600) + removed = 0 + + for job_id, job in list(self._jobs.items()): + if job.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED): + if job.completed_at and job.completed_at.timestamp() < cutoff: + del self._jobs[job_id] + removed += 1 + + if removed: + logger.info(f"Cleaned up {removed} old jobs") + + return removed + + async def _worker_loop(self, worker_id: int) -> None: + """Worker loop that processes jobs from queue.""" + logger.debug(f"Worker {worker_id} started") + + while self._running: + try: + # Wait for job with timeout + try: + job_id = await asyncio.wait_for( + self._queue.get(), timeout=1.0 + ) + except asyncio.TimeoutError: + continue + + job = self._jobs.get(job_id) + if job is None or job.status == JobStatus.CANCELLED: + continue + + await self._process_job(job) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Worker {worker_id} error: {e}") + + logger.debug(f"Worker {worker_id} stopped") + + async def _process_job(self, job: GenerationJob) -> None: + """Process a single generation job.""" + logger.info(f"Processing job {job.id}: {job.model_id}/{job.variant}") + + job.status = JobStatus.PROCESSING + job.started_at = datetime.utcnow() + + def progress_callback(progress: float, message: str) -> None: + job.progress = progress + job.progress_message = message + for callback in self._on_progress: + try: + callback(job) + except Exception as e: + logger.error(f"Progress callback error: {e}") + + try: + result, generation = await self.generation_service.generate( + model_id=job.model_id, + variant=job.variant, + prompts=job.prompts, + duration=job.parameters.get("duration", 10.0), + temperature=job.parameters.get("temperature", 1.0), + 
top_k=job.parameters.get("top_k", 250), + top_p=job.parameters.get("top_p", 0.0), + cfg_coef=job.parameters.get("cfg_coef", 3.0), + seed=job.parameters.get("seed"), + conditioning=job.conditioning, + project_id=job.project_id, + preset_used=job.preset_used, + tags=job.tags, + progress_callback=progress_callback, + ) + + job.status = JobStatus.COMPLETED + job.result_id = generation.id + job.audio_path = generation.audio_path + job.completed_at = datetime.utcnow() + job.progress = 1.0 + job.progress_message = "Complete" + + logger.info(f"Job {job.id} completed: {generation.id}") + + for callback in self._on_job_complete: + try: + callback(job) + except Exception as e: + logger.error(f"Completion callback error: {e}") + + except Exception as e: + job.status = JobStatus.FAILED + job.error = str(e) + job.completed_at = datetime.utcnow() + + logger.error(f"Job {job.id} failed: {e}") + + for callback in self._on_job_failed: + try: + callback(job) + except Exception as e2: + logger.error(f"Failure callback error: {e2}") + + # Callback registration + + def on_job_complete(self, callback: Callable[[GenerationJob], None]) -> None: + """Register callback for job completion.""" + self._on_job_complete.append(callback) + + def on_job_failed(self, callback: Callable[[GenerationJob], None]) -> None: + """Register callback for job failure.""" + self._on_job_failed.append(callback) + + def on_progress(self, callback: Callable[[GenerationJob], None]) -> None: + """Register callback for progress updates.""" + self._on_progress.append(callback) diff --git a/src/services/generation_service.py b/src/services/generation_service.py new file mode 100644 index 0000000..228ab09 --- /dev/null +++ b/src/services/generation_service.py @@ -0,0 +1,322 @@ +"""Generation service for orchestrating audio generation.""" + +import logging +import time +from pathlib import Path +from typing import Any, Callable, Optional + +import soundfile as sf +import torch + +from src.core.base_model import GenerationRequest, GenerationResult +from src.core.gpu_manager import GPUMemoryManager +from src.core.model_registry import ModelRegistry +from src.core.oom_handler import OOMHandler +from src.storage.database import Database, Generation + +logger = logging.getLogger(__name__) + + +class GenerationService: + """Orchestrates audio generation across all models. + + Handles: + - Model selection and loading + - Generation execution with OOM recovery + - Result saving and database recording + - Progress callbacks for UI updates + """ + + def __init__( + self, + registry: ModelRegistry, + gpu_manager: GPUMemoryManager, + database: Database, + output_dir: Path, + ): + """Initialize generation service. 
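+
+        Typical wiring (illustrative sketch; assumes an async context, all
+        names are defined elsewhere in this patch)::
+
+            service = GenerationService(registry, gpu_manager, db, Path("outputs"))
+            result, record = await service.generate(
+                model_id="musicgen", variant=None, prompts=["warm lo-fi piano"]
+            )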
+ + Args: + registry: Model registry for model access + gpu_manager: GPU memory manager + database: Database for storing generation records + output_dir: Directory for saving generated audio + """ + self.registry = registry + self.gpu_manager = gpu_manager + self.database = database + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + + # OOM handler + self.oom_handler = OOMHandler(gpu_manager, registry) + + # Statistics + self._generation_count = 0 + self._total_duration_generated = 0.0 + + async def generate( + self, + model_id: str, + variant: Optional[str], + prompts: list[str], + duration: float = 10.0, + temperature: float = 1.0, + top_k: int = 250, + top_p: float = 0.0, + cfg_coef: float = 3.0, + seed: Optional[int] = None, + conditioning: Optional[dict[str, Any]] = None, + project_id: Optional[str] = None, + preset_used: Optional[str] = None, + tags: Optional[list[str]] = None, + progress_callback: Optional[Callable[[float, str], None]] = None, + ) -> tuple[GenerationResult, Generation]: + """Generate audio and save to database. + + Args: + model_id: Model family to use + variant: Model variant (None for default) + prompts: Text prompts for generation + duration: Target duration in seconds + temperature: Sampling temperature + top_k: Top-k sampling parameter + top_p: Nucleus sampling parameter + cfg_coef: Classifier-free guidance coefficient + seed: Random seed for reproducibility + conditioning: Optional conditioning data (melody, style, chords, etc.) + project_id: Optional project to associate with + preset_used: Name of preset used (for metadata) + tags: Optional tags for organization + progress_callback: Optional callback for progress updates + + Returns: + Tuple of (GenerationResult, Generation database record) + + Raises: + ValueError: If model not found or parameters invalid + RuntimeError: If generation fails + """ + start_time = time.time() + + # Report progress + if progress_callback: + progress_callback(0.0, "Preparing generation...") + + # Build generation request + request = GenerationRequest( + prompts=prompts, + duration=duration, + temperature=temperature, + top_k=top_k, + top_p=top_p, + cfg_coef=cfg_coef, + seed=seed, + conditioning=conditioning or {}, + ) + + # Get model configuration + family_config, variant_config = self.registry.get_model_config(model_id, variant) + actual_variant = variant or family_config.default_variant + + # Check VRAM availability + if progress_callback: + progress_callback(0.1, "Checking GPU memory...") + + can_load, reason = self.gpu_manager.can_load_model(variant_config.vram_mb) + if not can_load: + # Try OOM recovery + if not self.oom_handler.check_memory_for_operation(variant_config.vram_mb): + raise RuntimeError(f"Insufficient GPU memory: {reason}") + + # Generate with OOM recovery wrapper + if progress_callback: + progress_callback(0.2, f"Loading {model_id}/{actual_variant}...") + + @self.oom_handler.with_oom_recovery + def do_generation() -> GenerationResult: + with self.registry.get_model(model_id, actual_variant) as model: + if progress_callback: + progress_callback(0.4, "Generating audio...") + return model.generate(request) + + result = do_generation() + + if progress_callback: + progress_callback(0.8, "Saving audio...") + + # Save audio file + audio_path = self._save_audio(result) + + # Create database record + generation = Generation.create( + model=model_id, + variant=actual_variant, + prompt=prompts[0] if len(prompts) == 1 else "\n".join(prompts), + parameters={ + "duration": duration, 
+ "temperature": temperature, + "top_k": top_k, + "top_p": top_p, + "cfg_coef": cfg_coef, + }, + project_id=project_id, + preset_used=preset_used, + conditioning=conditioning, + audio_path=str(audio_path), + duration_seconds=result.duration, + sample_rate=result.sample_rate, + tags=tags or [], + seed=result.seed, + ) + + # Save to database + await self.database.create_generation(generation) + + # Update statistics + self._generation_count += 1 + self._total_duration_generated += result.duration + + elapsed = time.time() - start_time + logger.info( + f"Generation complete: {model_id}/{actual_variant}, " + f"duration={result.duration:.1f}s, elapsed={elapsed:.1f}s" + ) + + if progress_callback: + progress_callback(1.0, "Complete!") + + return result, generation + + def _save_audio(self, result: GenerationResult) -> Path: + """Save generated audio to file. + + Args: + result: Generation result with audio tensor + + Returns: + Path to saved audio file + """ + # Generate unique filename + timestamp = int(time.time() * 1000) + filename = f"{result.model_id}_{result.variant}_{timestamp}.wav" + filepath = self.output_dir / filename + + # Convert tensor to numpy and save + audio = result.audio.numpy() + + # Handle batch dimension - save first sample if batched + if audio.ndim == 3: + audio = audio[0] # [channels, samples] + + # Transpose to [samples, channels] for soundfile + if audio.ndim == 2: + audio = audio.T + + sf.write(filepath, audio, result.sample_rate) + + logger.debug(f"Saved audio to {filepath}") + return filepath + + async def regenerate( + self, + generation_id: str, + new_seed: Optional[int] = None, + progress_callback: Optional[Callable[[float, str], None]] = None, + ) -> tuple[GenerationResult, Generation]: + """Regenerate audio using parameters from existing generation. + + Args: + generation_id: ID of generation to regenerate + new_seed: Optional new seed (uses original if None) + progress_callback: Optional progress callback + + Returns: + Tuple of (GenerationResult, new Generation record) + + Raises: + ValueError: If generation not found + """ + # Load original generation + original = await self.database.get_generation(generation_id) + if original is None: + raise ValueError(f"Generation not found: {generation_id}") + + # Parse prompts + prompts = original.prompt.split("\n") if "\n" in original.prompt else [original.prompt] + + # Regenerate with same or new seed + return await self.generate( + model_id=original.model, + variant=original.variant, + prompts=prompts, + duration=original.parameters.get("duration", 10.0), + temperature=original.parameters.get("temperature", 1.0), + top_k=original.parameters.get("top_k", 250), + top_p=original.parameters.get("top_p", 0.0), + cfg_coef=original.parameters.get("cfg_coef", 3.0), + seed=new_seed if new_seed is not None else original.seed, + conditioning=original.conditioning, + project_id=original.project_id, + preset_used=original.preset_used, + tags=original.tags, + progress_callback=progress_callback, + ) + + def get_stats(self) -> dict[str, Any]: + """Get generation statistics. + + Returns: + Dictionary with generation stats + """ + return { + "generation_count": self._generation_count, + "total_duration_generated": self._total_duration_generated, + "oom_stats": self.oom_handler.get_stats(), + } + + def estimate_generation_time( + self, model_id: str, variant: Optional[str], duration: float + ) -> float: + """Estimate generation time for given parameters. 
+ + Args: + model_id: Model family + variant: Model variant + duration: Target audio duration + + Returns: + Estimated generation time in seconds + """ + # Rough estimates based on model type and RTX 4090 + # These are approximations and vary based on many factors + estimates = { + "musicgen": { + "small": 0.8, # seconds per second of audio + "medium": 1.5, + "large": 3.0, + "melody": 1.8, + }, + "audiogen": { + "medium": 1.5, + }, + "magnet": { + "small-10secs": 0.3, # Non-autoregressive is faster + "medium-10secs": 0.5, + "small-30secs": 0.3, + "medium-30secs": 0.5, + }, + "musicgen-style": { + "medium": 1.8, + }, + "jasco": { + "chords-drums-400M": 1.0, + "chords-drums-1B": 1.5, + }, + } + + family_config, _ = self.registry.get_model_config(model_id, variant) + actual_variant = variant or family_config.default_variant + + ratio = estimates.get(model_id, {}).get(actual_variant, 2.0) + return duration * ratio + 5.0 # Add 5s for model loading overhead diff --git a/src/services/project_service.py b/src/services/project_service.py new file mode 100644 index 0000000..d7caec2 --- /dev/null +++ b/src/services/project_service.py @@ -0,0 +1,395 @@ +"""Project service for managing projects and generations.""" + +import logging +import shutil +from pathlib import Path +from typing import Any, Optional + +from src.storage.database import Database, Generation, Project, Preset + +logger = logging.getLogger(__name__) + + +class ProjectService: + """Service for managing projects, generations, and presets. + + Provides a high-level API for project organization and + generation history management. + """ + + def __init__(self, database: Database, output_dir: Path): + """Initialize project service. + + Args: + database: Database instance + output_dir: Directory where audio files are stored + """ + self.database = database + self.output_dir = Path(output_dir) + + # Project Operations + + async def create_project( + self, name: str, description: str = "" + ) -> Project: + """Create a new project. + + Args: + name: Project name + description: Optional description + + Returns: + Created project + """ + project = Project.create(name, description) + await self.database.create_project(project) + logger.info(f"Created project: {project.id} ({name})") + return project + + async def get_project(self, project_id: str) -> Optional[Project]: + """Get a project by ID.""" + return await self.database.get_project(project_id) + + async def list_projects( + self, limit: int = 100, offset: int = 0 + ) -> list[Project]: + """List all projects.""" + return await self.database.list_projects(limit, offset) + + async def update_project( + self, + project_id: str, + name: Optional[str] = None, + description: Optional[str] = None, + ) -> Optional[Project]: + """Update a project. + + Args: + project_id: Project ID + name: New name (None to keep current) + description: New description (None to keep current) + + Returns: + Updated project, or None if not found + """ + project = await self.database.get_project(project_id) + if project is None: + return None + + if name is not None: + project.name = name + if description is not None: + project.description = description + + await self.database.update_project(project) + logger.info(f"Updated project: {project_id}") + return project + + async def delete_project( + self, project_id: str, delete_files: bool = False + ) -> bool: + """Delete a project. 
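+
+        Deleting a project never cascades to generation rows: the database
+        layer keeps them and clears their project_id (see
+        Database.delete_project).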
+ + Args: + project_id: Project ID + delete_files: If True, also delete associated audio files + + Returns: + True if deleted + """ + if delete_files: + # Get all generations and delete their files + generations = await self.database.list_generations(project_id=project_id) + for gen in generations: + if gen.audio_path: + try: + Path(gen.audio_path).unlink(missing_ok=True) + except Exception as e: + logger.warning(f"Failed to delete {gen.audio_path}: {e}") + + result = await self.database.delete_project(project_id) + if result: + logger.info(f"Deleted project: {project_id}") + return result + + async def get_project_stats(self, project_id: str) -> dict[str, Any]: + """Get statistics for a project. + + Args: + project_id: Project ID + + Returns: + Dictionary with project statistics + """ + generations = await self.database.list_generations( + project_id=project_id, limit=10000 + ) + + total_duration = sum(g.duration_seconds or 0 for g in generations) + models_used = {} + for gen in generations: + key = f"{gen.model}/{gen.variant}" + models_used[key] = models_used.get(key, 0) + 1 + + return { + "generation_count": len(generations), + "total_duration_seconds": total_duration, + "models_used": models_used, + } + + # Generation Operations + + async def get_generation(self, generation_id: str) -> Optional[Generation]: + """Get a generation by ID.""" + return await self.database.get_generation(generation_id) + + async def list_generations( + self, + project_id: Optional[str] = None, + model: Optional[str] = None, + search: Optional[str] = None, + limit: int = 100, + offset: int = 0, + ) -> list[Generation]: + """List generations with optional filters. + + Args: + project_id: Filter by project + model: Filter by model family + search: Search in prompts, names, and tags + limit: Maximum results + offset: Pagination offset + + Returns: + List of generations + """ + return await self.database.list_generations( + project_id=project_id, + model=model, + search=search, + limit=limit, + offset=offset, + ) + + async def update_generation( + self, + generation_id: str, + name: Optional[str] = None, + tags: Optional[list[str]] = None, + notes: Optional[str] = None, + project_id: Optional[str] = None, + ) -> Optional[Generation]: + """Update a generation's metadata. + + Args: + generation_id: Generation ID + name: New name + tags: New tags (replaces existing) + notes: New notes + project_id: Move to different project + + Returns: + Updated generation, or None if not found + """ + generation = await self.database.get_generation(generation_id) + if generation is None: + return None + + if name is not None: + generation.name = name + if tags is not None: + generation.tags = tags + if notes is not None: + generation.notes = notes + if project_id is not None: + generation.project_id = project_id + + await self.database.update_generation(generation) + logger.info(f"Updated generation: {generation_id}") + return generation + + async def delete_generation( + self, generation_id: str, delete_file: bool = True + ) -> bool: + """Delete a generation. 
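+
+        File deletion is best-effort; a failed unlink is logged as a warning
+        and does not abort the database delete.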
+ + Args: + generation_id: Generation ID + delete_file: If True, also delete audio file + + Returns: + True if deleted + """ + if delete_file: + generation = await self.database.get_generation(generation_id) + if generation and generation.audio_path: + try: + Path(generation.audio_path).unlink(missing_ok=True) + except Exception as e: + logger.warning(f"Failed to delete audio file: {e}") + + result = await self.database.delete_generation(generation_id) + if result: + logger.info(f"Deleted generation: {generation_id}") + return result + + async def move_generations_to_project( + self, generation_ids: list[str], project_id: Optional[str] + ) -> int: + """Move multiple generations to a project. + + Args: + generation_ids: List of generation IDs + project_id: Target project ID (None to unlink) + + Returns: + Number of generations moved + """ + moved = 0 + for gen_id in generation_ids: + result = await self.update_generation(gen_id, project_id=project_id) + if result: + moved += 1 + + logger.info(f"Moved {moved} generations to project {project_id}") + return moved + + # Preset Operations + + async def create_preset( + self, + model: str, + name: str, + parameters: dict[str, Any], + description: str = "", + ) -> Preset: + """Create a custom preset. + + Args: + model: Model family this preset is for + name: Preset name + parameters: Generation parameters + description: Optional description + + Returns: + Created preset + """ + preset = Preset.create(model, name, parameters, description) + await self.database.create_preset(preset) + logger.info(f"Created preset: {preset.id} ({name}) for {model}") + return preset + + async def list_presets( + self, model: Optional[str] = None, include_builtin: bool = True + ) -> list[Preset]: + """List presets with optional model filter. + + Args: + model: Filter by model family + include_builtin: Include built-in presets + + Returns: + List of presets + """ + return await self.database.list_presets(model, include_builtin) + + async def get_preset(self, preset_id: str) -> Optional[Preset]: + """Get a preset by ID.""" + return await self.database.get_preset(preset_id) + + async def delete_preset(self, preset_id: str) -> bool: + """Delete a custom preset. + + Note: Built-in presets cannot be deleted. + + Args: + preset_id: Preset ID + + Returns: + True if deleted + """ + result = await self.database.delete_preset(preset_id) + if result: + logger.info(f"Deleted preset: {preset_id}") + return result + + # Export Operations + + async def export_project( + self, project_id: str, output_path: Path, include_metadata: bool = True + ) -> Path: + """Export a project as a ZIP archive. 
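+
+        The archive is flat: one entry per generation audio file, plus an
+        optional metadata.json describing the project and each generation.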
+ + Args: + project_id: Project ID + output_path: Output ZIP file path + include_metadata: Include JSON metadata file + + Returns: + Path to created ZIP file + """ + import json + import tempfile + import zipfile + + project = await self.database.get_project(project_id) + if project is None: + raise ValueError(f"Project not found: {project_id}") + + generations = await self.database.list_generations( + project_id=project_id, limit=10000 + ) + + with tempfile.TemporaryDirectory() as tmpdir: + tmppath = Path(tmpdir) + + # Copy audio files + for gen in generations: + if gen.audio_path and Path(gen.audio_path).exists(): + src = Path(gen.audio_path) + dst = tmppath / src.name + shutil.copy2(src, dst) + + # Create metadata file + if include_metadata: + metadata = { + "project": { + "id": project.id, + "name": project.name, + "description": project.description, + "created_at": project.created_at.isoformat(), + }, + "generations": [ + { + "id": g.id, + "model": g.model, + "variant": g.variant, + "prompt": g.prompt, + "parameters": g.parameters, + "duration": g.duration_seconds, + "audio_file": Path(g.audio_path).name if g.audio_path else None, + "created_at": g.created_at.isoformat(), + "tags": g.tags, + "seed": g.seed, + } + for g in generations + ], + } + + metadata_path = tmppath / "metadata.json" + metadata_path.write_text(json.dumps(metadata, indent=2)) + + # Create ZIP + output_path = Path(output_path) + with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zf: + for file in tmppath.iterdir(): + zf.write(file, file.name) + + logger.info(f"Exported project {project_id} to {output_path}") + return output_path + + # Statistics + + async def get_stats(self) -> dict[str, Any]: + """Get overall statistics.""" + return await self.database.get_stats() diff --git a/src/storage/__init__.py b/src/storage/__init__.py new file mode 100644 index 0000000..ea0aa46 --- /dev/null +++ b/src/storage/__init__.py @@ -0,0 +1,5 @@ +"""Storage module for AudioCraft Studio.""" + +from src.storage.database import Database, Generation, Project, Preset + +__all__ = ["Database", "Generation", "Project", "Preset"] diff --git a/src/storage/database.py b/src/storage/database.py new file mode 100644 index 0000000..cd9e620 --- /dev/null +++ b/src/storage/database.py @@ -0,0 +1,550 @@ +"""SQLite database for projects, generations, and presets.""" + +import json +import logging +import uuid +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Any, Optional + +import aiosqlite + +logger = logging.getLogger(__name__) + + +@dataclass +class Project: + """Project entity for organizing generations.""" + + id: str + name: str + created_at: datetime + updated_at: datetime + description: str = "" + + @classmethod + def create(cls, name: str, description: str = "") -> "Project": + """Create a new project with generated ID.""" + now = datetime.utcnow() + return cls( + id=f"proj_{uuid.uuid4().hex[:12]}", + name=name, + created_at=now, + updated_at=now, + description=description, + ) + + +@dataclass +class Generation: + """Audio generation record.""" + + id: str + project_id: Optional[str] + model: str + variant: str + prompt: str + parameters: dict[str, Any] + created_at: datetime + audio_path: Optional[str] = None + duration_seconds: Optional[float] = None + sample_rate: Optional[int] = None + preset_used: Optional[str] = None + conditioning: dict[str, Any] = field(default_factory=dict) + name: Optional[str] = None + tags: list[str] = field(default_factory=list) 
+ notes: Optional[str] = None + seed: Optional[int] = None + + @classmethod + def create( + cls, + model: str, + variant: str, + prompt: str, + parameters: dict[str, Any], + project_id: Optional[str] = None, + **kwargs, + ) -> "Generation": + """Create a new generation record.""" + return cls( + id=f"gen_{uuid.uuid4().hex[:12]}", + project_id=project_id, + model=model, + variant=variant, + prompt=prompt, + parameters=parameters, + created_at=datetime.utcnow(), + **kwargs, + ) + + +@dataclass +class Preset: + """Generation parameter preset.""" + + id: str + model: str + name: str + parameters: dict[str, Any] + created_at: datetime + description: str = "" + is_builtin: bool = False + + @classmethod + def create( + cls, + model: str, + name: str, + parameters: dict[str, Any], + description: str = "", + ) -> "Preset": + """Create a new custom preset.""" + return cls( + id=f"preset_{uuid.uuid4().hex[:12]}", + model=model, + name=name, + parameters=parameters, + created_at=datetime.utcnow(), + description=description, + is_builtin=False, + ) + + +class Database: + """Async SQLite database for AudioCraft Studio. + + Handles storage of projects, generations, and presets. + """ + + SCHEMA = """ + CREATE TABLE IF NOT EXISTS projects ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT DEFAULT '', + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + + CREATE TABLE IF NOT EXISTS generations ( + id TEXT PRIMARY KEY, + project_id TEXT REFERENCES projects(id) ON DELETE SET NULL, + model TEXT NOT NULL, + variant TEXT NOT NULL, + prompt TEXT NOT NULL, + parameters JSON NOT NULL, + preset_used TEXT, + conditioning JSON, + audio_path TEXT, + duration_seconds REAL, + sample_rate INTEGER, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + name TEXT, + tags JSON, + notes TEXT, + seed INTEGER + ); + + CREATE TABLE IF NOT EXISTS presets ( + id TEXT PRIMARY KEY, + model TEXT NOT NULL, + name TEXT NOT NULL, + description TEXT DEFAULT '', + parameters JSON NOT NULL, + is_builtin BOOLEAN DEFAULT FALSE, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_generations_project ON generations(project_id); + CREATE INDEX IF NOT EXISTS idx_generations_created ON generations(created_at DESC); + CREATE INDEX IF NOT EXISTS idx_generations_model ON generations(model); + CREATE INDEX IF NOT EXISTS idx_presets_model ON presets(model); + """ + + def __init__(self, db_path: Path): + """Initialize database. 
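+
+        Typical lifecycle (illustrative sketch using classes from this
+        module)::
+
+            db = Database(Path("data/studio.db"))
+            await db.connect()
+            project = await db.create_project(Project.create("Demo"))
+            await db.close()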
+
+        Args:
+            db_path: Path to SQLite database file
+        """
+        self.db_path = db_path
+        self._connection: Optional[aiosqlite.Connection] = None
+
+    async def connect(self) -> None:
+        """Open database connection and initialize schema."""
+        self.db_path.parent.mkdir(parents=True, exist_ok=True)
+        self._connection = await aiosqlite.connect(self.db_path)
+        self._connection.row_factory = aiosqlite.Row
+
+        # SQLite only honours REFERENCES ... ON DELETE SET NULL when
+        # foreign-key enforcement is enabled on the connection.
+        await self._connection.execute("PRAGMA foreign_keys = ON")
+
+        # Initialize schema
+        await self._connection.executescript(self.SCHEMA)
+        await self._connection.commit()
+
+        logger.info(f"Database connected: {self.db_path}")
+
+    async def close(self) -> None:
+        """Close database connection."""
+        if self._connection:
+            await self._connection.close()
+            self._connection = None
+
+    @property
+    def conn(self) -> aiosqlite.Connection:
+        """Get active connection."""
+        if not self._connection:
+            raise RuntimeError("Database not connected")
+        return self._connection
+
+    # Project Methods
+
+    async def create_project(self, project: Project) -> Project:
+        """Create a new project."""
+        await self.conn.execute(
+            """
+            INSERT INTO projects (id, name, description, created_at, updated_at)
+            VALUES (?, ?, ?, ?, ?)
+            """,
+            (
+                project.id,
+                project.name,
+                project.description,
+                project.created_at.isoformat(),
+                project.updated_at.isoformat(),
+            ),
+        )
+        await self.conn.commit()
+        return project
+
+    async def get_project(self, project_id: str) -> Optional[Project]:
+        """Get a project by ID."""
+        async with self.conn.execute(
+            "SELECT * FROM projects WHERE id = ?", (project_id,)
+        ) as cursor:
+            row = await cursor.fetchone()
+            if row:
+                return Project(
+                    id=row["id"],
+                    name=row["name"],
+                    description=row["description"] or "",
+                    created_at=datetime.fromisoformat(row["created_at"]),
+                    updated_at=datetime.fromisoformat(row["updated_at"]),
+                )
+            return None
+
+    async def list_projects(
+        self, limit: int = 100, offset: int = 0
+    ) -> list[Project]:
+        """List all projects, ordered by last update."""
+        async with self.conn.execute(
+            """
+            SELECT * FROM projects
+            ORDER BY updated_at DESC
+            LIMIT ? OFFSET ?
+            """,
+            (limit, offset),
+        ) as cursor:
+            rows = await cursor.fetchall()
+            return [
+                Project(
+                    id=row["id"],
+                    name=row["name"],
+                    description=row["description"] or "",
+                    created_at=datetime.fromisoformat(row["created_at"]),
+                    updated_at=datetime.fromisoformat(row["updated_at"]),
+                )
+                for row in rows
+            ]
+
+    async def update_project(self, project: Project) -> None:
+        """Update a project."""
+        project.updated_at = datetime.utcnow()
+        await self.conn.execute(
+            """
+            UPDATE projects SET name = ?, description = ?, updated_at = ?
+            WHERE id = ?
+            """,
+            (project.name, project.description, project.updated_at.isoformat(), project.id),
+        )
+        await self.conn.commit()
+
+    async def delete_project(self, project_id: str) -> bool:
+        """Delete a project (generations are kept but unlinked)."""
+        result = await self.conn.execute(
+            "DELETE FROM projects WHERE id = ?", (project_id,)
+        )
+        await self.conn.commit()
+        return result.rowcount > 0
+
+    # Generation Methods
+
+    async def create_generation(self, generation: Generation) -> Generation:
+        """Create a new generation record."""
+        await self.conn.execute(
+            """
+            INSERT INTO generations (
+                id, project_id, model, variant, prompt, parameters,
+                preset_used, conditioning, audio_path, duration_seconds,
+                sample_rate, created_at, name, tags, notes, seed
+            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+ """, + ( + generation.id, + generation.project_id, + generation.model, + generation.variant, + generation.prompt, + json.dumps(generation.parameters), + generation.preset_used, + json.dumps(generation.conditioning), + generation.audio_path, + generation.duration_seconds, + generation.sample_rate, + generation.created_at.isoformat(), + generation.name, + json.dumps(generation.tags), + generation.notes, + generation.seed, + ), + ) + await self.conn.commit() + + # Update project's updated_at if linked + if generation.project_id: + await self.conn.execute( + "UPDATE projects SET updated_at = ? WHERE id = ?", + (datetime.utcnow().isoformat(), generation.project_id), + ) + await self.conn.commit() + + return generation + + async def get_generation(self, generation_id: str) -> Optional[Generation]: + """Get a generation by ID.""" + async with self.conn.execute( + "SELECT * FROM generations WHERE id = ?", (generation_id,) + ) as cursor: + row = await cursor.fetchone() + if row: + return self._row_to_generation(row) + return None + + async def list_generations( + self, + project_id: Optional[str] = None, + model: Optional[str] = None, + limit: int = 100, + offset: int = 0, + search: Optional[str] = None, + ) -> list[Generation]: + """List generations with optional filters.""" + conditions = [] + params = [] + + if project_id: + conditions.append("project_id = ?") + params.append(project_id) + + if model: + conditions.append("model = ?") + params.append(model) + + if search: + conditions.append("(prompt LIKE ? OR name LIKE ? OR tags LIKE ?)") + search_pattern = f"%{search}%" + params.extend([search_pattern, search_pattern, search_pattern]) + + where_clause = " AND ".join(conditions) if conditions else "1=1" + + async with self.conn.execute( + f""" + SELECT * FROM generations + WHERE {where_clause} + ORDER BY created_at DESC + LIMIT ? OFFSET ? + """, + (*params, limit, offset), + ) as cursor: + rows = await cursor.fetchall() + return [self._row_to_generation(row) for row in rows] + + async def update_generation(self, generation: Generation) -> None: + """Update a generation record.""" + await self.conn.execute( + """ + UPDATE generations SET + project_id = ?, name = ?, tags = ?, notes = ?, + audio_path = ?, duration_seconds = ?, sample_rate = ? + WHERE id = ? 
+ """, + ( + generation.project_id, + generation.name, + json.dumps(generation.tags), + generation.notes, + generation.audio_path, + generation.duration_seconds, + generation.sample_rate, + generation.id, + ), + ) + await self.conn.commit() + + async def delete_generation(self, generation_id: str) -> bool: + """Delete a generation record.""" + result = await self.conn.execute( + "DELETE FROM generations WHERE id = ?", (generation_id,) + ) + await self.conn.commit() + return result.rowcount > 0 + + async def count_generations( + self, project_id: Optional[str] = None, model: Optional[str] = None + ) -> int: + """Count generations with optional filters.""" + conditions = [] + params = [] + + if project_id: + conditions.append("project_id = ?") + params.append(project_id) + + if model: + conditions.append("model = ?") + params.append(model) + + where_clause = " AND ".join(conditions) if conditions else "1=1" + + async with self.conn.execute( + f"SELECT COUNT(*) FROM generations WHERE {where_clause}", + params, + ) as cursor: + row = await cursor.fetchone() + return row[0] if row else 0 + + def _row_to_generation(self, row: aiosqlite.Row) -> Generation: + """Convert database row to Generation object.""" + return Generation( + id=row["id"], + project_id=row["project_id"], + model=row["model"], + variant=row["variant"], + prompt=row["prompt"], + parameters=json.loads(row["parameters"]), + preset_used=row["preset_used"], + conditioning=json.loads(row["conditioning"]) if row["conditioning"] else {}, + audio_path=row["audio_path"], + duration_seconds=row["duration_seconds"], + sample_rate=row["sample_rate"], + created_at=datetime.fromisoformat(row["created_at"]), + name=row["name"], + tags=json.loads(row["tags"]) if row["tags"] else [], + notes=row["notes"], + seed=row["seed"], + ) + + # Preset Methods + + async def create_preset(self, preset: Preset) -> Preset: + """Create a new preset.""" + await self.conn.execute( + """ + INSERT INTO presets (id, model, name, description, parameters, is_builtin, created_at) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + preset.id, + preset.model, + preset.name, + preset.description, + json.dumps(preset.parameters), + preset.is_builtin, + preset.created_at.isoformat(), + ), + ) + await self.conn.commit() + return preset + + async def get_preset(self, preset_id: str) -> Optional[Preset]: + """Get a preset by ID.""" + async with self.conn.execute( + "SELECT * FROM presets WHERE id = ?", (preset_id,) + ) as cursor: + row = await cursor.fetchone() + if row: + return self._row_to_preset(row) + return None + + async def list_presets( + self, model: Optional[str] = None, include_builtin: bool = True + ) -> list[Preset]: + """List presets with optional model filter.""" + conditions = [] + params = [] + + if model: + conditions.append("model = ?") + params.append(model) + + if not include_builtin: + conditions.append("is_builtin = FALSE") + + where_clause = " AND ".join(conditions) if conditions else "1=1" + + async with self.conn.execute( + f""" + SELECT * FROM presets + WHERE {where_clause} + ORDER BY is_builtin DESC, name ASC + """, + params, + ) as cursor: + rows = await cursor.fetchall() + return [self._row_to_preset(row) for row in rows] + + async def delete_preset(self, preset_id: str) -> bool: + """Delete a preset (only custom presets can be deleted).""" + result = await self.conn.execute( + "DELETE FROM presets WHERE id = ? 
AND is_builtin = FALSE", + (preset_id,), + ) + await self.conn.commit() + return result.rowcount > 0 + + def _row_to_preset(self, row: aiosqlite.Row) -> Preset: + """Convert database row to Preset object.""" + return Preset( + id=row["id"], + model=row["model"], + name=row["name"], + description=row["description"] or "", + parameters=json.loads(row["parameters"]), + is_builtin=bool(row["is_builtin"]), + created_at=datetime.fromisoformat(row["created_at"]), + ) + + # Utility Methods + + async def get_stats(self) -> dict[str, Any]: + """Get database statistics.""" + stats = {} + + async with self.conn.execute("SELECT COUNT(*) FROM projects") as cursor: + row = await cursor.fetchone() + stats["projects"] = row[0] if row else 0 + + async with self.conn.execute("SELECT COUNT(*) FROM generations") as cursor: + row = await cursor.fetchone() + stats["generations"] = row[0] if row else 0 + + async with self.conn.execute("SELECT COUNT(*) FROM presets") as cursor: + row = await cursor.fetchone() + stats["presets"] = row[0] if row else 0 + + async with self.conn.execute( + "SELECT model, COUNT(*) as count FROM generations GROUP BY model" + ) as cursor: + rows = await cursor.fetchall() + stats["generations_by_model"] = {row["model"]: row["count"] for row in rows} + + return stats diff --git a/src/ui/__init__.py b/src/ui/__init__.py new file mode 100644 index 0000000..fec1183 --- /dev/null +++ b/src/ui/__init__.py @@ -0,0 +1,5 @@ +"""Gradio UI for AudioCraft Studio.""" + +from src.ui.app import create_app + +__all__ = ["create_app"] diff --git a/src/ui/app.py b/src/ui/app.py new file mode 100644 index 0000000..9f0d6ca --- /dev/null +++ b/src/ui/app.py @@ -0,0 +1,355 @@ +"""Main Gradio application for AudioCraft Studio.""" + +import asyncio +import gradio as gr +from typing import Any, Optional +from pathlib import Path + +from src.ui.theme import create_audiocraft_theme, get_custom_css +from src.ui.state import UIState, DEFAULT_PRESETS, PROMPT_SUGGESTIONS +from src.ui.components.vram_monitor import create_vram_monitor +from src.ui.tabs import ( + create_dashboard_tab, + create_musicgen_tab, + create_audiogen_tab, + create_magnet_tab, + create_style_tab, + create_jasco_tab, +) +from src.ui.pages import create_projects_page, create_settings_page + +from config.settings import get_settings + + +class AudioCraftApp: + """Main AudioCraft Studio Gradio application.""" + + def __init__( + self, + generation_service: Any = None, + batch_processor: Any = None, + project_service: Any = None, + gpu_manager: Any = None, + model_registry: Any = None, + ): + """Initialize the application. 
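+
+        All services are optional: when one is missing, the corresponding UI
+        action returns empty data or raises RuntimeError, which keeps the app
+        buildable standalone (see the __main__ block at the end of this file).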
+
+        Args:
+            generation_service: Service for handling generations
+            batch_processor: Service for batch/queue processing
+            project_service: Service for project management
+            gpu_manager: GPU memory manager
+            model_registry: Model registry for loading/unloading
+        """
+        self.settings = get_settings()
+        self.generation_service = generation_service
+        self.batch_processor = batch_processor
+        self.project_service = project_service
+        self.gpu_manager = gpu_manager
+        self.model_registry = model_registry
+
+        self.ui_state = UIState()
+        self.app: Optional[gr.Blocks] = None
+
+    def _get_queue_status(self) -> dict[str, Any]:
+        """Get current queue status."""
+        if self.batch_processor:
+            # BatchProcessor exposes get_queue_status(); it has no public
+            # queue/active_count/completed_count attributes.
+            status = self.batch_processor.get_queue_status()
+            by_status = status.get("jobs_by_status", {})
+            return {
+                "queue_size": status.get("queue_size", 0),
+                "active_jobs": by_status.get("processing", 0),
+                # All-time completed count; per-day tracking is not stored.
+                "completed_today": by_status.get("completed", 0),
+            }
+        return {"queue_size": 0, "active_jobs": 0, "completed_today": 0}
+
+    def _get_recent_generations(self, limit: int = 5) -> list[dict[str, Any]]:
+        """Get recent generations."""
+        if self.project_service:
+            try:
+                # ProjectService has no get_recent_generations(); its
+                # list_generations() already returns newest-first records.
+                gens = asyncio.run(self.project_service.list_generations(limit=limit))
+                return [vars(g) for g in gens]
+            except Exception:
+                pass
+        return []
+
+    def _get_gpu_status(self) -> dict[str, Any]:
+        """Get GPU memory status."""
+        if self.gpu_manager:
+            return {
+                "used_gb": self.gpu_manager.get_used_memory() / 1024**3,
+                "total_gb": self.gpu_manager.total_memory / 1024**3,
+                "utilization_percent": self.gpu_manager.get_utilization(),
+                "available_gb": self.gpu_manager.get_available_memory() / 1024**3,
+            }
+        return {"used_gb": 0, "total_gb": 24, "utilization_percent": 0, "available_gb": 24}
+
+    async def _generate(self, **kwargs) -> tuple[Any, Any]:
+        """Generate audio using the generation service."""
+        if self.generation_service:
+            return await self.generation_service.generate(**kwargs)
+        raise RuntimeError("Generation service not configured")
+
+    def _add_to_queue(self, **kwargs) -> Any:
+        """Add generation job to queue."""
+        if self.batch_processor:
+            # BatchProcessor accepts GenerationJob objects via the async
+            # submit(); build the job here instead of calling a nonexistent
+            # add_job(). Imported locally to avoid a circular import.
+            from src.services.batch_processor import GenerationJob
+
+            job = GenerationJob.create(**kwargs)
+            return asyncio.run(self.batch_processor.submit(job))
+        raise RuntimeError("Batch processor not configured")
+
+    def _get_projects(self) -> list[dict]:
+        """Get all projects."""
+        if self.project_service:
+            try:
+                projects = asyncio.run(self.project_service.list_projects())
+                return [vars(p) for p in projects]
+            except Exception:
+                pass
+        return []
+
+    def _get_generations(self, project_id: str, limit: int, offset: int) -> list[dict]:
+        """Get generations for a project."""
+        if self.project_service:
+            try:
+                # Keyword arguments matter here: positionally, limit and
+                # offset would bind to the model and search filters.
+                gens = asyncio.run(
+                    self.project_service.list_generations(
+                        project_id=project_id, limit=limit, offset=offset
+                    )
+                )
+                return [vars(g) for g in gens]
+            except Exception:
+                pass
+        return []
+
+    def _delete_generation(self, generation_id: str) -> bool:
+        """Delete a generation."""
+        if self.project_service:
+            try:
+                asyncio.run(self.project_service.delete_generation(generation_id))
+                return True
+            except Exception:
+                pass
+        return False
+
+    def _export_project(self, project_id: str) -> str:
+        """Export project as ZIP."""
+        if self.project_service:
+            # ProjectService.export_project() requires an explicit output path.
+            output_path = Path(self.settings.output_dir) / f"{project_id}.zip"
+            return str(
+                asyncio.run(self.project_service.export_project(project_id, output_path))
+            )
+        raise RuntimeError("Project service not configured")
+
+    def _create_project(self, name: str, description: str) -> dict:
+        """Create a new project."""
+        if self.project_service:
+            project = asyncio.run(self.project_service.create_project(name, description))
+            return vars(project)
+        raise RuntimeError("Project service not configured")
+
+    def _get_app_settings(self) -> dict:
+        """Get application settings."""
+        return {
+            "output_dir": str(self.settings.output_dir),
+            "default_format": self.settings.default_format,
+            "sample_rate": self.settings.sample_rate,
+            "normalize_audio": self.settings.normalize_audio,
+            "theme_mode": "Dark",
+            "show_advanced": False,
+            "auto_play": True,
+            "comfyui_reserve_gb": self.settings.comfyui_reserve_gb,
+            "idle_timeout_minutes": self.settings.idle_unload_minutes,
+            "max_loaded_models": self.settings.max_loaded_models,
+            "musicgen_variant": "medium",
+            "musicgen_duration": 10,
+            "audiogen_duration": 5,
+            "magnet_variant": "medium",
+            "magnet_decoding_steps": 20,
+            "api_enabled": self.settings.api_enabled,
+            "api_port": self.settings.api_port,
+            "rate_limit": self.settings.api_rate_limit,
+            "max_batch_size": self.settings.max_batch_size,
+            "max_queue_size": self.settings.max_queue_size,
+            "max_workers": 1,
+            "priority_queue": False,
+        }
+
+    def _update_app_settings(self, settings: dict) -> bool:
+        """Update application settings."""
+        # In a real implementation, this would persist settings.
+        # For now, just return success.
+        return True
+
+    def _clear_cache(self) -> bool:
+        """Clear model cache."""
+        if self.model_registry:
+            try:
+                self.model_registry.clear_cache()
+                return True
+            except Exception:
+                pass
+        return False
+
+    def _unload_all_models(self) -> bool:
+        """Unload all models from memory."""
+        if self.model_registry:
+            try:
+                asyncio.run(self.model_registry.unload_all())
+                return True
+            except Exception:
+                pass
+        return False
+
+    def build(self) -> gr.Blocks:
+        """Build the Gradio application."""
+        theme = create_audiocraft_theme()
+        css = get_custom_css()
+
+        with gr.Blocks(
+            theme=theme,
+            css=css,
+            title="AudioCraft Studio",
+            analytics_enabled=False,
+        ) as app:
+            # Header with VRAM monitor
+            with gr.Row():
+                with gr.Column(scale=4):
+                    gr.Markdown("# AudioCraft Studio")
+                with gr.Column(scale=1):
+                    # create_vram_monitor() takes four callables (see
+                    # src/ui/components/vram_monitor.py); model listing and
+                    # load/unload are stubbed until the registry is wired in.
+                    vram_monitor = create_vram_monitor(
+                        get_gpu_status=self._get_gpu_status,
+                        get_loaded_models=lambda: [],
+                        unload_model=lambda model_id, variant: False,
+                        load_model=lambda model_id, variant: False,
+                    )
+
+            # Main tabs
+            with gr.Tabs() as main_tabs:
+                # Dashboard
+                with gr.TabItem("Dashboard", id="dashboard"):
+                    dashboard = create_dashboard_tab(
+                        get_queue_status=self._get_queue_status,
+                        get_recent_generations=self._get_recent_generations,
+                        get_gpu_status=self._get_gpu_status,
+                    )
+
+                # Model tabs
+                with gr.TabItem("MusicGen", id="musicgen"):
+                    musicgen = create_musicgen_tab(
+                        generate_fn=self._generate,
+                        add_to_queue_fn=self._add_to_queue,
+                    )
+
+                with gr.TabItem("AudioGen", id="audiogen"):
+                    audiogen = create_audiogen_tab(
+                        generate_fn=self._generate,
+                        add_to_queue_fn=self._add_to_queue,
+                    )
+
+                with gr.TabItem("MAGNeT", id="magnet"):
+                    magnet = create_magnet_tab(
+                        generate_fn=self._generate,
+                        add_to_queue_fn=self._add_to_queue,
+                    )
+
+                with gr.TabItem("Style", id="style"):
+                    style = create_style_tab(
+                        generate_fn=self._generate,
+                        add_to_queue_fn=self._add_to_queue,
+                    )
+
+                with gr.TabItem("JASCO", id="jasco"):
+                    jasco = create_jasco_tab(
+                        generate_fn=self._generate,
+                        add_to_queue_fn=self._add_to_queue,
+                    )
+
+                # Projects
+                with gr.TabItem("Projects", id="projects"):
+                    projects = create_projects_page(
+                        get_projects=self._get_projects,
+                        get_generations=self._get_generations,
+                        delete_generation=self._delete_generation,
+                        export_project=self._export_project,
+                        create_project=self._create_project,
+                    )
+
+                # Settings
+                with gr.TabItem("Settings", id="settings"):
+                    settings = create_settings_page(
+                        get_settings=self._get_app_settings,
+                        update_settings=self._update_app_settings,
+                        get_gpu_info=self._get_gpu_status,
+                        clear_cache=self._clear_cache,
+                        unload_all_models=self._unload_all_models,
+                    )
+
+            # Footer
+            gr.Markdown("---")
+            gr.Markdown(
+                "AudioCraft Studio | "
+ "[Documentation](https://github.com/facebookresearch/audiocraft) | " + "Powered by Meta AudioCraft" + ) + + # Store component references + self.components = { + "vram_monitor": vram_monitor, + "dashboard": dashboard, + "musicgen": musicgen, + "audiogen": audiogen, + "magnet": magnet, + "style": style, + "jasco": jasco, + "projects": projects, + "settings": settings, + } + + self.app = app + return app + + def launch( + self, + server_name: Optional[str] = None, + server_port: Optional[int] = None, + share: bool = False, + **kwargs, + ) -> None: + """Launch the Gradio application. + + Args: + server_name: Server hostname + server_port: Server port + share: Whether to create a public share link + **kwargs: Additional arguments for gr.Blocks.launch() + """ + if self.app is None: + self.build() + + self.app.launch( + server_name=server_name or self.settings.host, + server_port=server_port or self.settings.gradio_port, + share=share, + show_error=True, + **kwargs, + ) + + +def create_app( + generation_service: Any = None, + batch_processor: Any = None, + project_service: Any = None, + gpu_manager: Any = None, + model_registry: Any = None, +) -> AudioCraftApp: + """Create and return the AudioCraft application. + + Args: + generation_service: Service for handling generations + batch_processor: Service for batch/queue processing + project_service: Service for project management + gpu_manager: GPU memory manager + model_registry: Model registry for loading/unloading + + Returns: + Configured AudioCraftApp instance + """ + return AudioCraftApp( + generation_service=generation_service, + batch_processor=batch_processor, + project_service=project_service, + gpu_manager=gpu_manager, + model_registry=model_registry, + ) + + +# Standalone launch for development/testing +if __name__ == "__main__": + app = create_app() + app.launch() diff --git a/src/ui/components/__init__.py b/src/ui/components/__init__.py new file mode 100644 index 0000000..d43bd81 --- /dev/null +++ b/src/ui/components/__init__.py @@ -0,0 +1,13 @@ +"""Reusable UI components for AudioCraft Studio.""" + +from src.ui.components.vram_monitor import create_vram_monitor +from src.ui.components.audio_player import create_audio_player +from src.ui.components.preset_selector import create_preset_selector +from src.ui.components.generation_params import create_generation_params + +__all__ = [ + "create_vram_monitor", + "create_audio_player", + "create_preset_selector", + "create_generation_params", +] diff --git a/src/ui/components/audio_player.py b/src/ui/components/audio_player.py new file mode 100644 index 0000000..19e4f7f --- /dev/null +++ b/src/ui/components/audio_player.py @@ -0,0 +1,178 @@ +"""Audio player component with waveform visualization.""" + +import gradio as gr +from pathlib import Path +from typing import Any, Optional, Callable + + +def create_audio_player( + label: str = "Generated Audio", + show_waveform: bool = True, + show_download: bool = True, + show_info: bool = True, +) -> dict[str, Any]: + """Create audio player component with optional waveform. 
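+
+    Components disabled via the show_* flags come back as None in the
+    returned dictionary, so callers should guard before wiring events.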
+ + Args: + label: Label for the audio component + show_waveform: Show waveform image + show_download: Show download buttons + show_info: Show audio info (duration, sample rate) + + Returns: + Dictionary with component references + """ + + with gr.Group(): + # Audio player + audio_output = gr.Audio( + label=label, + type="filepath", + interactive=False, + show_download_button=show_download, + ) + + # Waveform visualization + if show_waveform: + waveform_image = gr.Image( + label="Waveform", + type="filepath", + interactive=False, + height=100, + visible=False, + ) + else: + waveform_image = None + + # Audio info + if show_info: + with gr.Row(): + duration_text = gr.Textbox( + label="Duration", + value="", + interactive=False, + scale=1, + ) + sample_rate_text = gr.Textbox( + label="Sample Rate", + value="", + interactive=False, + scale=1, + ) + seed_text = gr.Textbox( + label="Seed", + value="", + interactive=False, + scale=1, + ) + else: + duration_text = None + sample_rate_text = None + seed_text = None + + # Download buttons + if show_download: + with gr.Row(): + download_wav = gr.Button("Download WAV", size="sm") + download_mp3 = gr.Button("Download MP3", size="sm") + download_flac = gr.Button("Download FLAC", size="sm") + else: + download_wav = download_mp3 = download_flac = None + + return { + "audio": audio_output, + "waveform": waveform_image, + "duration": duration_text, + "sample_rate": sample_rate_text, + "seed": seed_text, + "download_wav": download_wav, + "download_mp3": download_mp3, + "download_flac": download_flac, + } + + +def update_audio_player( + audio_path: Optional[str], + duration: Optional[float] = None, + sample_rate: Optional[int] = None, + seed: Optional[int] = None, + waveform_path: Optional[str] = None, +) -> tuple: + """Update audio player with new audio. + + Args: + audio_path: Path to audio file + duration: Audio duration in seconds + sample_rate: Audio sample rate + seed: Generation seed + waveform_path: Path to waveform image + + Returns: + Tuple of update values for components + """ + duration_str = f"{duration:.2f}s" if duration else "" + sample_rate_str = f"{sample_rate} Hz" if sample_rate else "" + seed_str = str(seed) if seed is not None else "" + + waveform_update = gr.update(value=waveform_path, visible=waveform_path is not None) + + return ( + audio_path, + waveform_update, + duration_str, + sample_rate_str, + seed_str, + ) + + +def create_generation_output() -> dict[str, Any]: + """Create generation output section with audio player and metadata. 
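+
+    Composes create_audio_player() with status text, a progress slider, a
+    parameters JSON panel, and action buttons; event wiring is left to the
+    calling tab.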
+ + Returns: + Dictionary with component references + """ + with gr.Group(): + gr.Markdown("### Output") + + # Status/progress + with gr.Row(): + status_text = gr.Markdown("Ready to generate") + progress_bar = gr.Slider( + minimum=0, + maximum=100, + value=0, + label="Progress", + interactive=False, + visible=False, + ) + + # Audio player + player = create_audio_player( + label="Generated Audio", + show_waveform=True, + show_download=True, + show_info=True, + ) + + # Generation metadata + with gr.Accordion("Generation Details", open=False): + generation_info = gr.JSON( + label="Parameters", + value={}, + ) + + # Actions + with gr.Row(): + save_btn = gr.Button("Save to Project", variant="secondary") + regenerate_btn = gr.Button("Regenerate", variant="secondary") + add_queue_btn = gr.Button("Add to Queue", variant="secondary") + + return { + "status": status_text, + "progress": progress_bar, + "player": player, + "info": generation_info, + "save_btn": save_btn, + "regenerate_btn": regenerate_btn, + "add_queue_btn": add_queue_btn, + } diff --git a/src/ui/components/generation_params.py b/src/ui/components/generation_params.py new file mode 100644 index 0000000..1b26526 --- /dev/null +++ b/src/ui/components/generation_params.py @@ -0,0 +1,199 @@ +"""Generation parameters component.""" + +import gradio as gr +from typing import Any, Optional + + +def create_generation_params( + model_id: str, + show_advanced: bool = False, + max_duration: float = 30.0, +) -> dict[str, Any]: + """Create generation parameters panel. + + Args: + model_id: Model family for customizing available options + show_advanced: Whether to show advanced parameters by default + max_duration: Maximum allowed duration + + Returns: + Dictionary with component references + """ + # Model-specific defaults + defaults = { + "musicgen": {"duration": 10, "temperature": 1.0, "top_k": 250, "top_p": 0.0, "cfg_coef": 3.0}, + "audiogen": {"duration": 5, "temperature": 1.0, "top_k": 250, "top_p": 0.0, "cfg_coef": 3.0}, + "magnet": {"duration": 10, "temperature": 3.0, "top_k": 0, "top_p": 0.9, "cfg_coef": 3.0}, + "musicgen-style": {"duration": 10, "temperature": 1.0, "top_k": 250, "top_p": 0.0, "cfg_coef": 3.0}, + "jasco": {"duration": 10, "temperature": 1.0, "top_k": 250, "top_p": 0.0, "cfg_coef": 3.0}, + } + + d = defaults.get(model_id, defaults["musicgen"]) + + with gr.Group(): + # Basic parameters (always visible) + duration_slider = gr.Slider( + minimum=1, + maximum=max_duration, + value=d["duration"], + step=1, + label="Duration (seconds)", + info="Length of audio to generate", + ) + + # Advanced parameters (expandable) + with gr.Accordion("Advanced Parameters", open=show_advanced): + with gr.Row(): + temperature_slider = gr.Slider( + minimum=0.0, + maximum=2.0, + value=d["temperature"], + step=0.05, + label="Temperature", + info="Higher = more random, lower = more deterministic", + ) + cfg_slider = gr.Slider( + minimum=1.0, + maximum=10.0, + value=d["cfg_coef"], + step=0.5, + label="CFG Coefficient", + info="Classifier-free guidance strength", + ) + + with gr.Row(): + top_k_slider = gr.Slider( + minimum=0, + maximum=500, + value=d["top_k"], + step=10, + label="Top-K", + info="Token selection limit (0 = disabled)", + ) + top_p_slider = gr.Slider( + minimum=0.0, + maximum=1.0, + value=d["top_p"], + step=0.05, + label="Top-P (Nucleus)", + info="Cumulative probability threshold (0 = disabled)", + ) + + with gr.Row(): + seed_input = gr.Number( + value=None, + label="Seed", + info="Random seed for reproducibility (leave empty for random)", 
+ precision=0, + ) + use_random_seed = gr.Checkbox( + value=True, + label="Random Seed", + ) + + # Reset button + reset_btn = gr.Button("Reset to Defaults", size="sm", variant="secondary") + + def reset_params(): + """Reset all parameters to defaults.""" + return ( + d["duration"], + d["temperature"], + d["cfg_coef"], + d["top_k"], + d["top_p"], + None, + True, + ) + + reset_btn.click( + fn=reset_params, + outputs=[ + duration_slider, + temperature_slider, + cfg_slider, + top_k_slider, + top_p_slider, + seed_input, + use_random_seed, + ], + ) + + # Link random seed checkbox to seed input + def toggle_seed(use_random: bool, current_seed: Optional[int]): + if use_random: + return gr.update(value=None, interactive=False) + return gr.update(interactive=True) + + use_random_seed.change( + fn=toggle_seed, + inputs=[use_random_seed, seed_input], + outputs=[seed_input], + ) + + return { + "duration": duration_slider, + "temperature": temperature_slider, + "cfg_coef": cfg_slider, + "top_k": top_k_slider, + "top_p": top_p_slider, + "seed": seed_input, + "use_random_seed": use_random_seed, + "reset_btn": reset_btn, + } + + +def create_model_variant_selector( + model_id: str, + variants: list[dict[str, Any]], + default_variant: str = "medium", +) -> dict[str, Any]: + """Create model variant selector. + + Args: + model_id: Model family ID + variants: List of variant configurations + default_variant: Default variant to select + + Returns: + Dictionary with component references + """ + # Build choices with descriptions + choices = [] + for v in variants: + name = v.get("name", v.get("id", "unknown")) + vram = v.get("vram_mb", 0) + desc = v.get("description", "") + label = f"{name} ({vram/1024:.1f}GB)" + choices.append((label, name)) + + with gr.Group(): + variant_dropdown = gr.Dropdown( + label="Model Variant", + choices=choices, + value=default_variant, + interactive=True, + ) + + variant_info = gr.Markdown( + value="", + visible=True, + ) + + def update_info(variant_name: str): + for v in variants: + if v.get("name", v.get("id")) == variant_name: + return v.get("description", "") + return "" + + variant_dropdown.change( + fn=update_info, + inputs=[variant_dropdown], + outputs=[variant_info], + ) + + return { + "dropdown": variant_dropdown, + "info": variant_info, + "variants": variants, + } diff --git a/src/ui/components/preset_selector.py b/src/ui/components/preset_selector.py new file mode 100644 index 0000000..5ddee73 --- /dev/null +++ b/src/ui/components/preset_selector.py @@ -0,0 +1,103 @@ +"""Preset selector component.""" + +import gradio as gr +from typing import Any, Callable, Optional + +from src.ui.state import DEFAULT_PRESETS + + +def create_preset_selector( + model_id: str, + on_preset_select: Optional[Callable[[dict], None]] = None, +) -> dict[str, Any]: + """Create preset selector component for a model. 
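+
+    Selecting "Custom" reveals the manual parameter panel; choosing a named
+    preset hides it and hands that preset's parameters to the caller via the
+    returned on_change handler.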
+ + Args: + model_id: Model family ID + on_preset_select: Callback when preset is selected + + Returns: + Dictionary with component references + """ + presets = DEFAULT_PRESETS.get(model_id, []) + + # Create preset choices + choices = [(p["name"], p["id"]) for p in presets] + choices.append(("Custom", "custom")) + + def get_preset_by_id(preset_id: str) -> Optional[dict]: + """Get preset data by ID.""" + for p in presets: + if p["id"] == preset_id: + return p + return None + + def on_change(preset_id: str): + """Handle preset selection change.""" + if preset_id == "custom": + return gr.update(visible=True), {} + + preset = get_preset_by_id(preset_id) + if preset: + return gr.update(visible=False), preset.get("parameters", {}) + + return gr.update(visible=True), {} + + with gr.Group(): + preset_dropdown = gr.Dropdown( + label="Preset", + choices=choices, + value=presets[0]["id"] if presets else "custom", + interactive=True, + ) + + preset_description = gr.Markdown( + value=presets[0]["description"] if presets else "", + visible=True, + ) + + return { + "dropdown": preset_dropdown, + "description": preset_description, + "presets": presets, + "get_preset": get_preset_by_id, + "on_change": on_change, + } + + +def create_preset_chips( + model_id: str, + on_select: Callable[[str], None], +) -> dict[str, Any]: + """Create preset selector as clickable chips/buttons. + + Args: + model_id: Model family ID + on_select: Callback when preset is clicked + + Returns: + Dictionary with component references + """ + presets = DEFAULT_PRESETS.get(model_id, []) + + with gr.Row(): + buttons = [] + for preset in presets: + btn = gr.Button( + preset["name"], + size="sm", + variant="secondary", + ) + buttons.append((btn, preset)) + + custom_btn = gr.Button( + "Custom", + size="sm", + variant="secondary", + ) + + return { + "buttons": buttons, + "custom_btn": custom_btn, + "presets": presets, + } diff --git a/src/ui/components/vram_monitor.py b/src/ui/components/vram_monitor.py new file mode 100644 index 0000000..1d46ee9 --- /dev/null +++ b/src/ui/components/vram_monitor.py @@ -0,0 +1,151 @@ +"""VRAM monitor component for GPU memory tracking.""" + +import gradio as gr +from typing import Any, Callable, Optional + + +def create_vram_monitor( + get_gpu_status: Callable[[], dict[str, Any]], + get_loaded_models: Callable[[], list[dict[str, Any]]], + unload_model: Callable[[str, str], bool], + load_model: Callable[[str, str], bool], +) -> dict[str, Any]: + """Create VRAM monitor component. 
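+
+    Example (sketch; ``gpu_manager`` and ``registry`` stand in for whatever
+    backends the app wires in, they are not defined in this module):
+
+        monitor = create_vram_monitor(
+            get_gpu_status=gpu_manager.get_status,
+            get_loaded_models=registry.list_loaded,
+            unload_model=registry.unload,
+            load_model=registry.load,
+        )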
+ + Args: + get_gpu_status: Function to get GPU status dict + get_loaded_models: Function to get list of loaded models + unload_model: Function to unload a model (model_id, variant) + load_model: Function to load a model (model_id, variant) + + Returns: + Dictionary with component references + """ + + def refresh_status(): + """Refresh GPU status display.""" + status = get_gpu_status() + loaded = get_loaded_models() + + # Format VRAM bar + used_gb = status.get("used_gb", 0) + total_gb = status.get("total_gb", 24) + util_pct = status.get("utilization_percent", 0) + + vram_text = f"{used_gb:.1f} / {total_gb:.1f} GB ({util_pct:.0f}%)" + + # Format loaded models list + if loaded: + models_text = "\n".join([ + f"• {m['model_id']}/{m['variant']} " + f"(idle: {m['idle_seconds']:.0f}s)" + for m in loaded + ]) + else: + models_text = "No models loaded" + + # Determine status color + if util_pct > 90: + status_color = "🔴" + elif util_pct > 75: + status_color = "🟡" + else: + status_color = "🟢" + + status_text = f"{status_color} GPU: {status.get('device', 'N/A')}" + + return vram_text, util_pct, models_text, status_text + + def handle_unload(model_selection: str): + """Handle model unload.""" + if not model_selection or "/" not in model_selection: + return "Select a model to unload", *refresh_status() + + parts = model_selection.split("/") + model_id, variant = parts[0], parts[1] + + success = unload_model(model_id, variant) + if success: + msg = f"Unloaded {model_id}/{variant}" + else: + msg = f"Failed to unload {model_id}/{variant}" + + return msg, *refresh_status() + + with gr.Group(): + gr.Markdown("### GPU Memory") + + status_text = gr.Markdown("🟢 GPU: Checking...") + + with gr.Row(): + vram_display = gr.Textbox( + label="VRAM Usage", + value="Loading...", + interactive=False, + scale=3, + ) + refresh_btn = gr.Button("🔄", scale=1, min_width=50) + + vram_slider = gr.Slider( + minimum=0, + maximum=100, + value=0, + label="", + interactive=False, + visible=True, + ) + + gr.Markdown("### Loaded Models") + + models_display = gr.Textbox( + label="", + value="No models loaded", + interactive=False, + lines=4, + max_lines=6, + ) + + with gr.Row(): + model_selector = gr.Dropdown( + label="Select Model", + choices=[], + interactive=True, + scale=3, + ) + unload_btn = gr.Button("Unload", variant="secondary", scale=1) + + unload_status = gr.Markdown("") + + # Event handlers + def update_model_choices(): + loaded = get_loaded_models() + choices = [f"{m['model_id']}/{m['variant']}" for m in loaded] + return gr.update(choices=choices, value=None) + + refresh_btn.click( + fn=refresh_status, + outputs=[vram_display, vram_slider, models_display, status_text], + ).then( + fn=update_model_choices, + outputs=[model_selector], + ) + + unload_btn.click( + fn=handle_unload, + inputs=[model_selector], + outputs=[unload_status, vram_display, vram_slider, models_display, status_text], + ).then( + fn=update_model_choices, + outputs=[model_selector], + ) + + return { + "vram_display": vram_display, + "vram_slider": vram_slider, + "models_display": models_display, + "status_text": status_text, + "model_selector": model_selector, + "refresh_btn": refresh_btn, + "unload_btn": unload_btn, + "refresh_fn": refresh_status, + } diff --git a/src/ui/pages/__init__.py b/src/ui/pages/__init__.py new file mode 100644 index 0000000..6538e68 --- /dev/null +++ b/src/ui/pages/__init__.py @@ -0,0 +1,9 @@ +"""UI pages for AudioCraft Studio.""" + +from src.ui.pages.projects_page import create_projects_page +from src.ui.pages.settings_page 
import create_settings_page + +__all__ = [ + "create_projects_page", + "create_settings_page", +] diff --git a/src/ui/pages/projects_page.py b/src/ui/pages/projects_page.py new file mode 100644 index 0000000..d8ade14 --- /dev/null +++ b/src/ui/pages/projects_page.py @@ -0,0 +1,374 @@ +"""Projects page for managing generations and history.""" + +import gradio as gr +from typing import Any, Callable, Optional +from datetime import datetime + + +def create_projects_page( + get_projects: Callable[[], list[dict]], + get_generations: Callable[[str, int, int], list[dict]], + delete_generation: Callable[[str], bool], + export_project: Callable[[str], str], + create_project: Callable[[str, str], dict], +) -> dict[str, Any]: + """Create projects management page. + + Args: + get_projects: Function to get all projects + get_generations: Function to get generations (project_id, limit, offset) + delete_generation: Function to delete a generation + export_project: Function to export project as ZIP + create_project: Function to create new project + + Returns: + Dictionary with component references + """ + + with gr.Column(): + gr.Markdown("# Projects") + gr.Markdown("Browse and manage your generations") + + with gr.Row(): + # Left sidebar - project list + with gr.Column(scale=1): + gr.Markdown("### Projects") + + with gr.Row(): + new_project_name = gr.Textbox( + placeholder="New project name...", + show_label=False, + scale=3, + ) + new_project_btn = gr.Button("+", size="sm", scale=1) + + project_list = gr.Dataframe( + headers=["ID", "Name", "Count"], + datatype=["str", "str", "number"], + col_count=(3, "fixed"), + interactive=False, + height=400, + ) + + refresh_projects_btn = gr.Button("Refresh Projects", size="sm") + + # Main content - generations + with gr.Column(scale=3): + # Selected project info + selected_project_id = gr.State(value=None) + selected_project_name = gr.Markdown("### Select a project") + + # Filters + with gr.Row(): + model_filter = gr.Dropdown( + label="Model", + choices=[ + ("All", "all"), + ("MusicGen", "musicgen"), + ("AudioGen", "audiogen"), + ("MAGNeT", "magnet"), + ("Style", "musicgen-style"), + ("JASCO", "jasco"), + ], + value="all", + scale=1, + ) + sort_by = gr.Dropdown( + label="Sort By", + choices=[ + ("Newest First", "newest"), + ("Oldest First", "oldest"), + ("Duration (Long)", "duration_desc"), + ("Duration (Short)", "duration_asc"), + ], + value="newest", + scale=1, + ) + search_input = gr.Textbox( + label="Search Prompts", + placeholder="Search...", + scale=2, + ) + + # Generations grid + generations_gallery = gr.Gallery( + label="Generations", + columns=3, + rows=3, + height=400, + object_fit="contain", + show_label=False, + ) + + # Pagination + with gr.Row(): + prev_page_btn = gr.Button("← Previous", size="sm") + page_info = gr.Markdown("Page 1 of 1") + next_page_btn = gr.Button("Next →", size="sm") + + current_page = gr.State(value=1) + total_pages = gr.State(value=1) + + # Selected generation details + gr.Markdown("---") + gr.Markdown("### Generation Details") + + with gr.Row(): + with gr.Column(scale=2): + selected_audio = gr.Audio( + label="Audio", + interactive=False, + ) + + with gr.Column(scale=2): + selected_prompt = gr.Textbox( + label="Prompt", + interactive=False, + lines=2, + ) + with gr.Row(): + selected_model = gr.Textbox( + label="Model", + interactive=False, + ) + selected_duration = gr.Textbox( + label="Duration", + interactive=False, + ) + with gr.Row(): + selected_seed = gr.Textbox( + label="Seed", + interactive=False, + ) + selected_date = 
gr.Textbox( + label="Created", + interactive=False, + ) + + # Action buttons + with gr.Row(): + regenerate_btn = gr.Button("Regenerate", variant="secondary") + download_btn = gr.Button("Download", variant="secondary") + delete_btn = gr.Button("Delete", variant="stop") + export_project_btn = gr.Button("Export Project", variant="secondary") + + # Event handlers + + def load_projects(): + """Load all projects into the list.""" + projects = get_projects() + data = [] + for p in projects: + data.append([ + p.get("id", ""), + p.get("name", "Untitled"), + p.get("generation_count", 0), + ]) + return data + + def on_project_select(evt: gr.SelectData, df): + """Handle project selection from dataframe.""" + if evt.index is None: + return None, "### Select a project" + + row = evt.index[0] + if row < len(df): + project_id = df[row][0] + project_name = df[row][1] + return project_id, f"### {project_name}" + + return None, "### Select a project" + + def load_generations(project_id, page, model, sort, search): + """Load generations for selected project.""" + if not project_id: + return [], "Page 0 of 0", 1, 1 + + limit = 9 # 3x3 grid + offset = (page - 1) * limit + + gens = get_generations(project_id, limit + 1, offset) + + # Check if there are more pages + has_more = len(gens) > limit + gens = gens[:limit] + + # Filter by model if needed + if model != "all": + gens = [g for g in gens if g.get("model") == model] + + # Filter by search + if search: + search_lower = search.lower() + gens = [g for g in gens if search_lower in g.get("prompt", "").lower()] + + # Sort + if sort == "oldest": + gens = sorted(gens, key=lambda x: x.get("created_at", "")) + elif sort == "duration_desc": + gens = sorted(gens, key=lambda x: x.get("duration_seconds", 0), reverse=True) + elif sort == "duration_asc": + gens = sorted(gens, key=lambda x: x.get("duration_seconds", 0)) + # Default is newest first (already sorted from DB) + + # Build gallery items (using waveform images if available) + gallery_items = [] + for g in gens: + waveform = g.get("waveform_path") + if waveform: + gallery_items.append((waveform, g.get("prompt", "")[:50])) + else: + # Placeholder + gallery_items.append((None, g.get("prompt", "")[:50])) + + # Calculate total pages (estimate) + total = offset + len(gens) + (1 if has_more else 0) + total_p = max(1, (total + limit - 1) // limit) + + return gallery_items, f"Page {page} of {total_p}", page, total_p + + def on_generation_select(evt: gr.SelectData, project_id): + """Handle generation selection from gallery.""" + if evt.index is None or not project_id: + return None, "", "", "", "", "" + + # Get generations again to find the selected one + gens = get_generations(project_id, 100, 0) + if evt.index < len(gens): + gen = gens[evt.index] + return ( + gen.get("audio_path"), + gen.get("prompt", ""), + gen.get("model", ""), + f"{gen.get('duration_seconds', 0):.1f}s", + str(gen.get("seed", "")), + gen.get("created_at", "")[:19] if gen.get("created_at") else "", + ) + + return None, "", "", "", "", "" + + def do_create_project(name): + """Create a new project.""" + if not name.strip(): + return gr.update(), "Please enter a project name" + + project = create_project(name.strip(), "") + projects_data = load_projects() + return projects_data, f"Created project: {name}" + + def do_delete_generation(project_id, audio_path): + """Delete selected generation.""" + if not audio_path: + return "No generation selected" + + # Find generation by audio path + gens = get_generations(project_id, 100, 0) + for g in gens: + if 
g.get("audio_path") == audio_path: + if delete_generation(g.get("id")): + return "Generation deleted" + else: + return "Failed to delete" + + return "Generation not found" + + def do_export_project(project_id): + """Export project as ZIP.""" + if not project_id: + return "No project selected" + + try: + zip_path = export_project(project_id) + return f"Exported to: {zip_path}" + except Exception as e: + return f"Export failed: {str(e)}" + + # Wire up events + + refresh_projects_btn.click( + fn=load_projects, + outputs=[project_list], + ) + + project_list.select( + fn=on_project_select, + inputs=[project_list], + outputs=[selected_project_id, selected_project_name], + ).then( + fn=load_generations, + inputs=[selected_project_id, current_page, model_filter, sort_by, search_input], + outputs=[generations_gallery, page_info, current_page, total_pages], + ) + + # Filter changes reload generations + for component in [model_filter, sort_by, search_input]: + component.change( + fn=load_generations, + inputs=[selected_project_id, current_page, model_filter, sort_by, search_input], + outputs=[generations_gallery, page_info, current_page, total_pages], + ) + + # Pagination + def go_prev(page, total): + return max(1, page - 1) + + def go_next(page, total): + return min(total, page + 1) + + prev_page_btn.click( + fn=go_prev, + inputs=[current_page, total_pages], + outputs=[current_page], + ).then( + fn=load_generations, + inputs=[selected_project_id, current_page, model_filter, sort_by, search_input], + outputs=[generations_gallery, page_info, current_page, total_pages], + ) + + next_page_btn.click( + fn=go_next, + inputs=[current_page, total_pages], + outputs=[current_page], + ).then( + fn=load_generations, + inputs=[selected_project_id, current_page, model_filter, sort_by, search_input], + outputs=[generations_gallery, page_info, current_page, total_pages], + ) + + # Generation selection + generations_gallery.select( + fn=on_generation_select, + inputs=[selected_project_id], + outputs=[selected_audio, selected_prompt, selected_model, selected_duration, selected_seed, selected_date], + ) + + # Actions + new_project_btn.click( + fn=do_create_project, + inputs=[new_project_name], + outputs=[project_list, selected_project_name], + ) + + delete_btn.click( + fn=do_delete_generation, + inputs=[selected_project_id, selected_audio], + outputs=[selected_project_name], + ).then( + fn=load_generations, + inputs=[selected_project_id, current_page, model_filter, sort_by, search_input], + outputs=[generations_gallery, page_info, current_page, total_pages], + ) + + export_project_btn.click( + fn=do_export_project, + inputs=[selected_project_id], + outputs=[selected_project_name], + ) + + return { + "project_list": project_list, + "generations_gallery": generations_gallery, + "selected_audio": selected_audio, + "selected_project_id": selected_project_id, + "refresh_fn": load_projects, + } diff --git a/src/ui/pages/settings_page.py b/src/ui/pages/settings_page.py new file mode 100644 index 0000000..d66b49a --- /dev/null +++ b/src/ui/pages/settings_page.py @@ -0,0 +1,397 @@ +"""Settings page for application configuration.""" + +import gradio as gr +from typing import Any, Callable, Optional +from pathlib import Path + + +def create_settings_page( + get_settings: Callable[[], dict], + update_settings: Callable[[dict], bool], + get_gpu_info: Callable[[], dict], + clear_cache: Callable[[], bool], + unload_all_models: Callable[[], bool], +) -> dict[str, Any]: + """Create settings management page. 
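+
+    Example (sketch; ``settings_service``, ``gpu_manager``, and ``registry``
+    are hypothetical backends matching these call signatures):
+
+        page = create_settings_page(
+            get_settings=settings_service.load,
+            update_settings=settings_service.save,
+            get_gpu_info=gpu_manager.get_info,
+            clear_cache=registry.clear_cache,
+            unload_all_models=registry.unload_all,
+        )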
+ + Args: + get_settings: Function to get current settings + update_settings: Function to update settings + get_gpu_info: Function to get GPU information + clear_cache: Function to clear model cache + unload_all_models: Function to unload all models + + Returns: + Dictionary with component references + """ + + with gr.Column(): + gr.Markdown("# Settings") + gr.Markdown("Configure AudioCraft Studio") + + with gr.Tabs(): + # General Settings Tab + with gr.TabItem("General"): + with gr.Group(): + gr.Markdown("### Output Settings") + + output_dir = gr.Textbox( + label="Output Directory", + placeholder="/path/to/output", + info="Where generated audio files are saved", + ) + + with gr.Row(): + default_format = gr.Dropdown( + label="Default Audio Format", + choices=[("WAV", "wav"), ("MP3", "mp3"), ("FLAC", "flac"), ("OGG", "ogg")], + value="wav", + ) + sample_rate = gr.Dropdown( + label="Sample Rate", + choices=[ + ("32000 Hz (AudioCraft default)", 32000), + ("44100 Hz (CD quality)", 44100), + ("48000 Hz (Video standard)", 48000), + ], + value=32000, + ) + + normalize_audio = gr.Checkbox( + label="Normalize audio output", + value=True, + info="Normalize audio levels to prevent clipping", + ) + + with gr.Group(): + gr.Markdown("### Interface Settings") + + theme_mode = gr.Radio( + label="Theme", + choices=["Dark", "Light", "System"], + value="Dark", + ) + + show_advanced = gr.Checkbox( + label="Show advanced parameters by default", + value=False, + ) + + auto_play = gr.Checkbox( + label="Auto-play generated audio", + value=True, + ) + + # GPU & Memory Tab + with gr.TabItem("GPU & Memory"): + with gr.Group(): + gr.Markdown("### GPU Information") + + gpu_info_display = gr.JSON( + label="GPU Status", + value={}, + ) + + refresh_gpu_btn = gr.Button("Refresh GPU Info", size="sm") + + with gr.Group(): + gr.Markdown("### Memory Management") + + comfyui_reserve = gr.Slider( + minimum=0, + maximum=16, + value=10, + step=0.5, + label="ComfyUI VRAM Reserve (GB)", + info="VRAM to reserve for ComfyUI when running alongside", + ) + + idle_timeout = gr.Slider( + minimum=1, + maximum=60, + value=15, + step=1, + label="Idle Model Timeout (minutes)", + info="Unload models after this period of inactivity", + ) + + max_loaded = gr.Slider( + minimum=1, + maximum=5, + value=2, + step=1, + label="Maximum Loaded Models", + info="Maximum number of models to keep in memory", + ) + + with gr.Group(): + gr.Markdown("### Cache Management") + + with gr.Row(): + clear_cache_btn = gr.Button("Clear Model Cache", variant="secondary") + unload_models_btn = gr.Button("Unload All Models", variant="stop") + + cache_status = gr.Markdown("Cache status: Ready") + + # Model Defaults Tab + with gr.TabItem("Model Defaults"): + with gr.Group(): + gr.Markdown("### MusicGen Defaults") + + with gr.Row(): + musicgen_variant = gr.Dropdown( + label="Default Variant", + choices=[ + ("Small", "small"), + ("Medium", "medium"), + ("Large", "large"), + ("Melody", "melody"), + ], + value="medium", + ) + musicgen_duration = gr.Slider( + minimum=1, + maximum=30, + value=10, + step=1, + label="Default Duration (s)", + ) + + with gr.Group(): + gr.Markdown("### AudioGen Defaults") + + audiogen_duration = gr.Slider( + minimum=1, + maximum=10, + value=5, + step=1, + label="Default Duration (s)", + ) + + with gr.Group(): + gr.Markdown("### MAGNeT Defaults") + + with gr.Row(): + magnet_variant = gr.Dropdown( + label="Default Variant", + choices=[ + ("Small Music", "small"), + ("Medium Music", "medium"), + ("Small Audio", "audio-small"), + ("Medium Audio", 
"audio-medium"), + ], + value="medium", + ) + magnet_decoding_steps = gr.Slider( + minimum=10, + maximum=100, + value=20, + step=5, + label="Decoding Steps", + ) + + # API Settings Tab + with gr.TabItem("API"): + with gr.Group(): + gr.Markdown("### REST API Configuration") + + api_enabled = gr.Checkbox( + label="Enable REST API", + value=True, + info="Enable FastAPI endpoints for programmatic access", + ) + + api_port = gr.Number( + value=8000, + label="API Port", + precision=0, + ) + + with gr.Row(): + api_key_display = gr.Textbox( + label="API Key", + value="••••••••", + interactive=False, + ) + regenerate_key_btn = gr.Button("Regenerate", size="sm") + + with gr.Group(): + gr.Markdown("### Rate Limiting") + + rate_limit = gr.Slider( + minimum=1, + maximum=100, + value=10, + step=1, + label="Requests per minute", + ) + + max_batch_size = gr.Slider( + minimum=1, + maximum=10, + value=4, + step=1, + label="Maximum batch size", + ) + + # Queue Settings Tab + with gr.TabItem("Queue"): + with gr.Group(): + gr.Markdown("### Batch Processing") + + max_queue_size = gr.Slider( + minimum=10, + maximum=500, + value=100, + step=10, + label="Maximum Queue Size", + ) + + max_workers = gr.Slider( + minimum=1, + maximum=4, + value=1, + step=1, + label="Concurrent Workers", + info="Number of parallel generation workers", + ) + + priority_queue = gr.Checkbox( + label="Enable priority queue", + value=False, + info="Allow high-priority jobs to skip the queue", + ) + + # Save button + gr.Markdown("---") + with gr.Row(): + save_btn = gr.Button("Save Settings", variant="primary", scale=2) + reset_btn = gr.Button("Reset to Defaults", variant="secondary", scale=1) + + settings_status = gr.Markdown("") + + # Event handlers + + def load_settings(): + """Load current settings into form.""" + settings = get_settings() + return ( + settings.get("output_dir", ""), + settings.get("default_format", "wav"), + settings.get("sample_rate", 32000), + settings.get("normalize_audio", True), + settings.get("theme_mode", "Dark"), + settings.get("show_advanced", False), + settings.get("auto_play", True), + settings.get("comfyui_reserve_gb", 10), + settings.get("idle_timeout_minutes", 15), + settings.get("max_loaded_models", 2), + settings.get("musicgen_variant", "medium"), + settings.get("musicgen_duration", 10), + settings.get("audiogen_duration", 5), + settings.get("magnet_variant", "medium"), + settings.get("magnet_decoding_steps", 20), + settings.get("api_enabled", True), + settings.get("api_port", 8000), + settings.get("rate_limit", 10), + settings.get("max_batch_size", 4), + settings.get("max_queue_size", 100), + settings.get("max_workers", 1), + settings.get("priority_queue", False), + ) + + def save_settings( + out_dir, fmt, sr, norm, theme, adv, play, + comfyui_res, idle_to, max_load, + mg_var, mg_dur, ag_dur, mn_var, mn_steps, + api_en, api_p, rate, batch, queue_sz, workers, priority + ): + """Save settings from form.""" + settings = { + "output_dir": out_dir, + "default_format": fmt, + "sample_rate": sr, + "normalize_audio": norm, + "theme_mode": theme, + "show_advanced": adv, + "auto_play": play, + "comfyui_reserve_gb": comfyui_res, + "idle_timeout_minutes": idle_to, + "max_loaded_models": max_load, + "musicgen_variant": mg_var, + "musicgen_duration": mg_dur, + "audiogen_duration": ag_dur, + "magnet_variant": mn_var, + "magnet_decoding_steps": mn_steps, + "api_enabled": api_en, + "api_port": int(api_p), + "rate_limit": rate, + "max_batch_size": batch, + "max_queue_size": queue_sz, + "max_workers": workers, + 
"priority_queue": priority, + } + + if update_settings(settings): + return "✅ Settings saved successfully" + else: + return "❌ Failed to save settings" + + def do_refresh_gpu(): + """Refresh GPU info display.""" + return get_gpu_info() + + def do_clear_cache(): + """Clear model cache.""" + if clear_cache(): + return "✅ Cache cleared" + return "❌ Failed to clear cache" + + def do_unload_models(): + """Unload all models.""" + if unload_all_models(): + return "✅ All models unloaded" + return "❌ Failed to unload models" + + # Wire up events + + refresh_gpu_btn.click( + fn=do_refresh_gpu, + outputs=[gpu_info_display], + ) + + clear_cache_btn.click( + fn=do_clear_cache, + outputs=[cache_status], + ) + + unload_models_btn.click( + fn=do_unload_models, + outputs=[cache_status], + ) + + save_btn.click( + fn=save_settings, + inputs=[ + output_dir, default_format, sample_rate, normalize_audio, + theme_mode, show_advanced, auto_play, + comfyui_reserve, idle_timeout, max_loaded, + musicgen_variant, musicgen_duration, audiogen_duration, + magnet_variant, magnet_decoding_steps, + api_enabled, api_port, rate_limit, max_batch_size, + max_queue_size, max_workers, priority_queue, + ], + outputs=[settings_status], + ) + + return { + "output_dir": output_dir, + "default_format": default_format, + "sample_rate": sample_rate, + "comfyui_reserve": comfyui_reserve, + "idle_timeout": idle_timeout, + "api_enabled": api_enabled, + "save_btn": save_btn, + "settings_status": settings_status, + "load_fn": load_settings, + } diff --git a/src/ui/state.py b/src/ui/state.py new file mode 100644 index 0000000..d690d03 --- /dev/null +++ b/src/ui/state.py @@ -0,0 +1,294 @@ +"""State management for Gradio UI.""" + +from dataclasses import dataclass, field +from typing import Any, Optional + + +@dataclass +class UIState: + """Global UI state container.""" + + # Current view + current_tab: str = "dashboard" + + # Generation state + is_generating: bool = False + current_job_id: Optional[str] = None + + # Selected items + selected_project_id: Optional[str] = None + selected_generation_id: Optional[str] = None + selected_preset_id: Optional[str] = None + + # Model state + selected_model: str = "musicgen" + selected_variant: str = "medium" + + # Generation parameters (current values) + prompt: str = "" + duration: float = 10.0 + temperature: float = 1.0 + top_k: int = 250 + top_p: float = 0.0 + cfg_coef: float = 3.0 + seed: Optional[int] = None + + # Conditioning + melody_audio: Optional[str] = None + style_audio: Optional[str] = None + chords: list[dict[str, Any]] = field(default_factory=list) + drums_pattern: str = "" + bpm: float = 120.0 + + # UI preferences + show_advanced: bool = False + auto_play: bool = True + + def reset_generation_params(self) -> None: + """Reset generation parameters to defaults.""" + self.prompt = "" + self.duration = 10.0 + self.temperature = 1.0 + self.top_k = 250 + self.top_p = 0.0 + self.cfg_coef = 3.0 + self.seed = None + self.melody_audio = None + self.style_audio = None + self.chords = [] + self.drums_pattern = "" + + def apply_preset(self, preset: dict[str, Any]) -> None: + """Apply preset parameters.""" + params = preset.get("parameters", {}) + self.duration = params.get("duration", self.duration) + self.temperature = params.get("temperature", self.temperature) + self.top_k = params.get("top_k", self.top_k) + self.top_p = params.get("top_p", self.top_p) + self.cfg_coef = params.get("cfg_coef", self.cfg_coef) + + def to_generation_params(self) -> dict[str, Any]: + """Convert current state to 
generation parameters.""" + return { + "duration": self.duration, + "temperature": self.temperature, + "top_k": self.top_k, + "top_p": self.top_p, + "cfg_coef": self.cfg_coef, + "seed": self.seed, + } + + +# Default presets for each model +DEFAULT_PRESETS = { + "musicgen": [ + { + "id": "cinematic", + "name": "Cinematic", + "description": "Epic orchestral soundscapes", + "parameters": { + "duration": 30, + "temperature": 1.0, + "top_k": 250, + "cfg_coef": 3.0, + }, + }, + { + "id": "electronic", + "name": "Electronic", + "description": "Synthesizers and beats", + "parameters": { + "duration": 15, + "temperature": 1.1, + "top_k": 200, + "cfg_coef": 3.5, + }, + }, + { + "id": "ambient", + "name": "Ambient", + "description": "Atmospheric and calm", + "parameters": { + "duration": 30, + "temperature": 0.9, + "top_k": 300, + "cfg_coef": 2.5, + }, + }, + { + "id": "rock", + "name": "Rock", + "description": "Guitar-driven energy", + "parameters": { + "duration": 20, + "temperature": 1.0, + "top_k": 250, + "cfg_coef": 3.0, + }, + }, + { + "id": "jazz", + "name": "Jazz", + "description": "Smooth and improvisational", + "parameters": { + "duration": 20, + "temperature": 1.2, + "top_k": 200, + "cfg_coef": 2.5, + }, + }, + ], + "audiogen": [ + { + "id": "nature", + "name": "Nature", + "description": "Natural environments", + "parameters": { + "duration": 10, + "temperature": 1.0, + "top_k": 250, + "cfg_coef": 3.0, + }, + }, + { + "id": "urban", + "name": "Urban", + "description": "City sounds", + "parameters": { + "duration": 10, + "temperature": 1.0, + "top_k": 250, + "cfg_coef": 3.0, + }, + }, + { + "id": "mechanical", + "name": "Mechanical", + "description": "Machines and tools", + "parameters": { + "duration": 5, + "temperature": 0.9, + "top_k": 200, + "cfg_coef": 3.5, + }, + }, + { + "id": "weather", + "name": "Weather", + "description": "Rain, thunder, wind", + "parameters": { + "duration": 10, + "temperature": 1.0, + "top_k": 250, + "cfg_coef": 3.0, + }, + }, + ], + "magnet": [ + { + "id": "fast", + "name": "Fast", + "description": "Quick generation", + "parameters": { + "duration": 10, + "temperature": 3.0, + "top_p": 0.9, + "cfg_coef": 3.0, + }, + }, + { + "id": "quality", + "name": "Quality", + "description": "Higher quality output", + "parameters": { + "duration": 10, + "temperature": 2.5, + "top_p": 0.85, + "cfg_coef": 4.0, + }, + }, + ], + "musicgen-style": [ + { + "id": "style_transfer", + "name": "Style Transfer", + "description": "Copy style from reference", + "parameters": { + "duration": 15, + "temperature": 1.0, + "top_k": 250, + "cfg_coef": 3.0, + "eval_q": 3, + "excerpt_length": 3.0, + }, + }, + ], + "jasco": [ + { + "id": "pop", + "name": "Pop", + "description": "Pop chord progressions", + "parameters": { + "duration": 10, + "temperature": 1.0, + "top_k": 250, + "cfg_coef": 3.0, + "bpm": 120, + }, + }, + { + "id": "blues", + "name": "Blues", + "description": "12-bar blues", + "parameters": { + "duration": 10, + "temperature": 1.0, + "top_k": 250, + "cfg_coef": 3.0, + "bpm": 100, + }, + }, + ], +} + + +# Prompt suggestions for each model +PROMPT_SUGGESTIONS = { + "musicgen": [ + "Epic orchestral music with dramatic strings and powerful brass", + "Upbeat electronic dance music with synthesizers and heavy bass", + "Calm acoustic guitar melody with soft piano accompaniment", + "Energetic rock song with electric guitars and driving drums", + "Smooth jazz with saxophone solo and walking bass", + "Ambient soundscape with ethereal pads and gentle textures", + "Cinematic trailer music 
building to an epic climax", + "Lo-fi hip hop beats with vinyl crackle and mellow keys", + ], + "audiogen": [ + "Thunder and heavy rain with occasional lightning strikes", + "Busy city street with traffic, horns, and distant sirens", + "Forest ambience with birds singing and wind in trees", + "Ocean waves crashing on a rocky shore", + "Crackling fireplace with wood popping", + "Coffee shop atmosphere with murmuring voices and clinking cups", + "Construction site with hammering and machinery", + "Spaceship engine humming with occasional beeps", + ], + "magnet": [ + "Energetic pop music with catchy melody", + "Dark electronic music with deep bass", + "Cheerful ukulele tune with whistling", + "Dramatic piano piece with building intensity", + ], + "musicgen-style": [ + "Generate music in the style of the uploaded reference", + "Create a variation with similar instrumentation", + "Compose a piece matching the mood of the reference", + ], + "jasco": [ + "Upbeat pop song with the specified chord progression", + "Mellow jazz piece following the chord changes", + "Rock anthem with powerful drum pattern", + "Electronic track with syncopated rhythms", + ], +} diff --git a/src/ui/tabs/__init__.py b/src/ui/tabs/__init__.py new file mode 100644 index 0000000..dfcf85b --- /dev/null +++ b/src/ui/tabs/__init__.py @@ -0,0 +1,17 @@ +"""Model tabs for AudioCraft Studio.""" + +from src.ui.tabs.dashboard_tab import create_dashboard_tab +from src.ui.tabs.musicgen_tab import create_musicgen_tab +from src.ui.tabs.audiogen_tab import create_audiogen_tab +from src.ui.tabs.magnet_tab import create_magnet_tab +from src.ui.tabs.style_tab import create_style_tab +from src.ui.tabs.jasco_tab import create_jasco_tab + +__all__ = [ + "create_dashboard_tab", + "create_musicgen_tab", + "create_audiogen_tab", + "create_magnet_tab", + "create_style_tab", + "create_jasco_tab", +] diff --git a/src/ui/tabs/audiogen_tab.py b/src/ui/tabs/audiogen_tab.py new file mode 100644 index 0000000..78e1acd --- /dev/null +++ b/src/ui/tabs/audiogen_tab.py @@ -0,0 +1,283 @@ +"""AudioGen tab for text-to-sound generation.""" + +import gradio as gr +from typing import Any, Callable, Optional + +from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS +from src.ui.components.audio_player import create_generation_output + + +AUDIOGEN_VARIANTS = [ + {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "1.5B params, balanced quality/speed"}, +] + + +def create_audiogen_tab( + generate_fn: Callable[..., Any], + add_to_queue_fn: Callable[..., Any], +) -> dict[str, Any]: + """Create AudioGen generation tab. 
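+
+    Example (sketch; assumes an async generation service returning a
+    ``(result, generation)`` pair and a queue whose jobs expose ``.id``,
+    neither of which is defined in this module):
+
+        tab = create_audiogen_tab(
+            generate_fn=generation_service.generate,
+            add_to_queue_fn=batch_processor.add_job,
+        )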
+
+    Args:
+        generate_fn: Function to call for generation
+        add_to_queue_fn: Function to add to queue
+
+    Returns:
+        Dictionary with component references
+    """
+    presets = DEFAULT_PRESETS.get("audiogen", [])
+    suggestions = PROMPT_SUGGESTIONS.get("audiogen", [])
+
+    with gr.Column():
+        gr.Markdown("## 🔊 AudioGen")
+        gr.Markdown("Generate sound effects and environmental audio from text")
+
+        with gr.Row():
+            # Left column - inputs
+            with gr.Column(scale=2):
+                # Preset selector
+                preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
+                preset_dropdown = gr.Dropdown(
+                    label="Preset",
+                    choices=preset_choices,
+                    value=presets[0]["id"] if presets else "custom",
+                )
+
+                # Model variant (AudioGen only has medium)
+                variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in AUDIOGEN_VARIANTS]
+                variant_dropdown = gr.Dropdown(
+                    label="Model Variant",
+                    choices=variant_choices,
+                    value="medium",
+                )
+
+                # Prompt input
+                prompt_input = gr.Textbox(
+                    label="Prompt",
+                    placeholder="Describe the sound you want to generate...",
+                    lines=3,
+                    max_lines=5,
+                )
+
+                # Prompt suggestions
+                with gr.Accordion("Prompt Suggestions", open=False):
+                    suggestion_btns = []
+                    for i, suggestion in enumerate(suggestions[:6]):
+                        btn = gr.Button(suggestion[:50] + "...", size="sm", variant="secondary")
+                        suggestion_btns.append((btn, suggestion))
+
+                # Parameters
+                gr.Markdown("### Parameters")
+
+                duration_slider = gr.Slider(
+                    minimum=1,
+                    maximum=10,
+                    value=5,
+                    step=1,
+                    label="Duration (seconds)",
+                    info="AudioGen works best with shorter clips",
+                )
+
+                with gr.Accordion("Advanced Parameters", open=False):
+                    with gr.Row():
+                        temperature_slider = gr.Slider(
+                            minimum=0.0,
+                            maximum=2.0,
+                            value=1.0,
+                            step=0.05,
+                            label="Temperature",
+                        )
+                        cfg_slider = gr.Slider(
+                            minimum=1.0,
+                            maximum=10.0,
+                            value=3.0,
+                            step=0.5,
+                            label="CFG Coefficient",
+                        )
+
+                    with gr.Row():
+                        top_k_slider = gr.Slider(
+                            minimum=0,
+                            maximum=500,
+                            value=250,
+                            step=10,
+                            label="Top-K",
+                        )
+                        top_p_slider = gr.Slider(
+                            minimum=0.0,
+                            maximum=1.0,
+                            value=0.0,
+                            step=0.05,
+                            label="Top-P",
+                        )
+
+                    with gr.Row():
+                        seed_input = gr.Number(
+                            value=None,
+                            label="Seed (empty = random)",
+                            precision=0,
+                        )
+
+                # Generate buttons
+                with gr.Row():
+                    generate_btn = gr.Button("🔊 Generate", variant="primary", scale=2)
+                    queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)
+
+            # Right column - output
+            with gr.Column(scale=3):
+                output = create_generation_output()
+
+    # Event handlers
+
+    # Preset change
+    def apply_preset(preset_id: str):
+        for p in presets:
+            if p["id"] == preset_id:
+                params = p["parameters"]
+                return (
+                    params.get("duration", 5),
+                    params.get("temperature", 1.0),
+                    params.get("cfg_coef", 3.0),
+                    params.get("top_k", 250),
+                    params.get("top_p", 0.0),
+                )
+        return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
+
+    preset_dropdown.change(
+        fn=apply_preset,
+        inputs=[preset_dropdown],
+        outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
+    )
+
+    # Prompt suggestions
+    for btn, suggestion in suggestion_btns:
+        btn.click(
+            fn=lambda s=suggestion: s,
+            outputs=[prompt_input],
+        )
+
+    # Generate
+    async def do_generate(
+        prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed
+    ):
+        if not prompt:
+            # An async generator cannot `return` a value (SyntaxError), so
+            # yield the validation message and stop instead.
+            yield (
+                gr.update(value="Please enter a prompt"),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+            )
+            return
+
+        yield (
+            gr.update(value="🔄 Generating..."),
+            gr.update(visible=True,
value=0), + gr.update(), + gr.update(), + gr.update(), + gr.update(), + ) + + try: + result, generation = await generate_fn( + model_id="audiogen", + variant=variant, + prompts=[prompt], + duration=duration, + temperature=temperature, + top_k=int(top_k), + top_p=top_p, + cfg_coef=cfg_coef, + seed=int(seed) if seed else None, + ) + + yield ( + gr.update(value="✅ Generation complete!"), + gr.update(visible=False), + gr.update(value=generation.audio_path), + gr.update(), + gr.update(value=f"{result.duration:.2f}s"), + gr.update(value=str(result.seed)), + ) + + except Exception as e: + yield ( + gr.update(value=f"❌ Error: {str(e)}"), + gr.update(visible=False), + gr.update(), + gr.update(), + gr.update(), + gr.update(), + ) + + generate_btn.click( + fn=do_generate, + inputs=[ + prompt_input, + variant_dropdown, + duration_slider, + temperature_slider, + cfg_slider, + top_k_slider, + top_p_slider, + seed_input, + ], + outputs=[ + output["status"], + output["progress"], + output["player"]["audio"], + output["player"]["waveform"], + output["player"]["duration"], + output["player"]["seed"], + ], + ) + + # Add to queue + def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed): + if not prompt: + return "Please enter a prompt" + + job = add_to_queue_fn( + model_id="audiogen", + variant=variant, + prompts=[prompt], + duration=duration, + temperature=temperature, + top_k=int(top_k), + top_p=top_p, + cfg_coef=cfg_coef, + seed=int(seed) if seed else None, + ) + + return f"✅ Added to queue: {job.id}" + + queue_btn.click( + fn=do_add_queue, + inputs=[ + prompt_input, + variant_dropdown, + duration_slider, + temperature_slider, + cfg_slider, + top_k_slider, + top_p_slider, + seed_input, + ], + outputs=[output["status"]], + ) + + return { + "preset": preset_dropdown, + "variant": variant_dropdown, + "prompt": prompt_input, + "duration": duration_slider, + "temperature": temperature_slider, + "cfg_coef": cfg_slider, + "top_k": top_k_slider, + "top_p": top_p_slider, + "seed": seed_input, + "generate_btn": generate_btn, + "queue_btn": queue_btn, + "output": output, + } diff --git a/src/ui/tabs/dashboard_tab.py b/src/ui/tabs/dashboard_tab.py new file mode 100644 index 0000000..b7f9eb4 --- /dev/null +++ b/src/ui/tabs/dashboard_tab.py @@ -0,0 +1,166 @@ +"""Dashboard tab - home page with model overview and quick actions.""" + +import gradio as gr +from typing import Any, Callable, Optional + + +MODEL_INFO = { + "musicgen": { + "name": "MusicGen", + "icon": "🎵", + "description": "Text-to-music generation with optional melody conditioning", + "capabilities": ["Text prompts", "Melody conditioning", "Stereo output"], + }, + "audiogen": { + "name": "AudioGen", + "icon": "🔊", + "description": "Text-to-sound effects and environmental audio", + "capabilities": ["Sound effects", "Ambiences", "Foley"], + }, + "magnet": { + "name": "MAGNeT", + "icon": "⚡", + "description": "Fast non-autoregressive music generation", + "capabilities": ["Fast generation", "Music", "Sound effects"], + }, + "musicgen-style": { + "name": "MusicGen Style", + "icon": "🎨", + "description": "Style-conditioned music from reference audio", + "capabilities": ["Style transfer", "Reference audio", "Text prompts"], + }, + "jasco": { + "name": "JASCO", + "icon": "🎹", + "description": "Chord and drum-conditioned music generation", + "capabilities": ["Chord control", "Drum patterns", "Symbolic conditioning"], + }, +} + + +def create_dashboard_tab( + get_queue_status: Callable[[], dict[str, Any]], + get_recent_generations: 
Callable[[int], list[dict[str, Any]]], + get_gpu_status: Callable[[], dict[str, Any]], +) -> dict[str, Any]: + """Create dashboard tab with model overview and status. + + Args: + get_queue_status: Function to get generation queue status + get_recent_generations: Function to get recent generations + get_gpu_status: Function to get GPU status + + Returns: + Dictionary with component references + """ + + def refresh_dashboard(): + """Refresh all dashboard data.""" + queue = get_queue_status() + recent = get_recent_generations(5) + gpu = get_gpu_status() + + # Format queue status + queue_size = queue.get("queue_size", 0) + queue_text = f"**Queue:** {queue_size} job(s) pending" + + # Format recent generations + if recent: + recent_items = [] + for gen in recent[:5]: + model = gen.get("model", "unknown") + prompt = gen.get("prompt", "")[:50] + duration = gen.get("duration_seconds", 0) + recent_items.append(f"• **{model}** ({duration:.0f}s): {prompt}...") + recent_text = "\n".join(recent_items) + else: + recent_text = "No recent generations" + + # Format GPU status + used_gb = gpu.get("used_gb", 0) + total_gb = gpu.get("total_gb", 24) + util = gpu.get("utilization_percent", 0) + gpu_text = f"**GPU:** {used_gb:.1f}/{total_gb:.1f} GB ({util:.0f}%)" + + return queue_text, recent_text, gpu_text + + with gr.Column(): + # Header + gr.Markdown("# AudioCraft Studio") + gr.Markdown("AI-powered music and sound generation") + + # Status bar + with gr.Row(): + queue_status = gr.Markdown("**Queue:** Loading...") + gpu_status = gr.Markdown("**GPU:** Loading...") + refresh_btn = gr.Button("🔄 Refresh", size="sm") + + gr.Markdown("---") + + # Model cards + gr.Markdown("## Models") + + with gr.Row(): + # First row of cards + for model_id in ["musicgen", "audiogen", "magnet"]: + info = MODEL_INFO[model_id] + with gr.Column(scale=1): + with gr.Group(): + gr.Markdown(f"### {info['icon']} {info['name']}") + gr.Markdown(info["description"]) + gr.Markdown("**Features:** " + ", ".join(info["capabilities"])) + gr.Button( + f"Open {info['name']}", + variant="primary", + size="sm", + elem_id=f"btn_{model_id}", + ) + + with gr.Row(): + # Second row of cards + for model_id in ["musicgen-style", "jasco"]: + info = MODEL_INFO[model_id] + with gr.Column(scale=1): + with gr.Group(): + gr.Markdown(f"### {info['icon']} {info['name']}") + gr.Markdown(info["description"]) + gr.Markdown("**Features:** " + ", ".join(info["capabilities"])) + gr.Button( + f"Open {info['name']}", + variant="primary", + size="sm", + elem_id=f"btn_{model_id}", + ) + + # Empty column for balance + with gr.Column(scale=1): + pass + + gr.Markdown("---") + + # Recent generations and queue + with gr.Row(): + with gr.Column(scale=1): + gr.Markdown("## Recent Generations") + recent_list = gr.Markdown("Loading...") + + with gr.Column(scale=1): + gr.Markdown("## Quick Actions") + with gr.Group(): + gr.Button("📁 Browse Projects", variant="secondary") + gr.Button("⚙️ Settings", variant="secondary") + gr.Button("📖 API Documentation", variant="secondary") + + # Refresh handler + refresh_btn.click( + fn=refresh_dashboard, + outputs=[queue_status, recent_list, gpu_status], + ) + + return { + "queue_status": queue_status, + "gpu_status": gpu_status, + "recent_list": recent_list, + "refresh_btn": refresh_btn, + "refresh_fn": refresh_dashboard, + } diff --git a/src/ui/tabs/jasco_tab.py b/src/ui/tabs/jasco_tab.py new file mode 100644 index 0000000..bfa74b8 --- /dev/null +++ b/src/ui/tabs/jasco_tab.py @@ -0,0 +1,364 @@ +"""JASCO tab for chord and drum-conditioned 
generation.""" + +import gradio as gr +from typing import Any, Callable, Optional + +from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS +from src.ui.components.audio_player import create_generation_output + + +JASCO_VARIANTS = [ + {"id": "chords", "name": "Chords", "vram_mb": 5000, "description": "Chord-conditioned generation"}, + {"id": "chords-drums", "name": "Chords + Drums", "vram_mb": 5500, "description": "Full symbolic conditioning"}, +] + +# Common chord progressions +CHORD_PRESETS = [ + {"name": "Pop I-V-vi-IV", "chords": "C G Am F"}, + {"name": "Jazz ii-V-I", "chords": "Dm7 G7 Cmaj7"}, + {"name": "Blues I-IV-V", "chords": "A7 D7 E7"}, + {"name": "Rock I-bVII-IV", "chords": "E D A"}, + {"name": "Minor i-VI-III-VII", "chords": "Am F C G"}, +] + + +def create_jasco_tab( + generate_fn: Callable[..., Any], + add_to_queue_fn: Callable[..., Any], +) -> dict[str, Any]: + """Create JASCO generation tab. + + Args: + generate_fn: Function to call for generation + add_to_queue_fn: Function to add to queue + + Returns: + Dictionary with component references + """ + presets = DEFAULT_PRESETS.get("jasco", []) + suggestions = PROMPT_SUGGESTIONS.get("musicgen", []) + + with gr.Column(): + gr.Markdown("## 🎹 JASCO") + gr.Markdown("Generate music conditioned on chords and drum patterns") + + with gr.Row(): + # Left column - inputs + with gr.Column(scale=2): + # Preset selector + preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")] + preset_dropdown = gr.Dropdown( + label="Preset", + choices=preset_choices, + value=presets[0]["id"] if presets else "custom", + ) + + # Model variant + variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in JASCO_VARIANTS] + variant_dropdown = gr.Dropdown( + label="Model Variant", + choices=variant_choices, + value="chords-drums", + ) + + # Prompt input + prompt_input = gr.Textbox( + label="Text Prompt", + placeholder="Describe the music style, mood, instruments...", + lines=2, + max_lines=4, + ) + + # Chord conditioning + gr.Markdown("### Chord Progression") + + chord_input = gr.Textbox( + label="Chords", + placeholder="C G Am F or Cmaj7 Dm7 G7 Cmaj7", + lines=1, + info="Space-separated chord symbols", + ) + + # Chord presets + with gr.Accordion("Chord Presets", open=False): + chord_preset_btns = [] + with gr.Row(): + for cp in CHORD_PRESETS[:3]: + btn = gr.Button(cp["name"], size="sm", variant="secondary") + chord_preset_btns.append((btn, cp["chords"])) + with gr.Row(): + for cp in CHORD_PRESETS[3:]: + btn = gr.Button(cp["name"], size="sm", variant="secondary") + chord_preset_btns.append((btn, cp["chords"])) + + # Drum conditioning (for chords-drums variant) + with gr.Group(visible=True) as drum_section: + gr.Markdown("### Drum Pattern") + + drum_input = gr.Audio( + label="Drum Reference", + type="filepath", + sources=["upload"], + ) + gr.Markdown("*Upload a drum loop to condition the rhythm*") + + # Parameters + gr.Markdown("### Parameters") + + duration_slider = gr.Slider( + minimum=1, + maximum=30, + value=10, + step=1, + label="Duration (seconds)", + ) + + bpm_slider = gr.Slider( + minimum=60, + maximum=180, + value=120, + step=1, + label="BPM", + info="Tempo for chord timing", + ) + + with gr.Accordion("Advanced Parameters", open=False): + with gr.Row(): + temperature_slider = gr.Slider( + minimum=0.0, + maximum=2.0, + value=1.0, + step=0.05, + label="Temperature", + ) + cfg_slider = gr.Slider( + minimum=1.0, + maximum=10.0, + value=3.0, + step=0.5, + label="CFG Coefficient", + ) + + with gr.Row(): + 
top_k_slider = gr.Slider(
+                            minimum=0,
+                            maximum=500,
+                            value=250,
+                            step=10,
+                            label="Top-K",
+                        )
+                        top_p_slider = gr.Slider(
+                            minimum=0.0,
+                            maximum=1.0,
+                            value=0.0,
+                            step=0.05,
+                            label="Top-P",
+                        )
+
+                    with gr.Row():
+                        seed_input = gr.Number(
+                            value=None,
+                            label="Seed (empty = random)",
+                            precision=0,
+                        )
+
+                # Generate buttons
+                with gr.Row():
+                    generate_btn = gr.Button("🎹 Generate", variant="primary", scale=2)
+                    queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)
+
+            # Right column - output
+            with gr.Column(scale=3):
+                output = create_generation_output()
+
+    # Event handlers
+
+    # Preset change
+    def apply_preset(preset_id: str):
+        for p in presets:
+            if p["id"] == preset_id:
+                params = p["parameters"]
+                return (
+                    params.get("duration", 10),
+                    params.get("bpm", 120),
+                    params.get("temperature", 1.0),
+                    params.get("cfg_coef", 3.0),
+                    params.get("top_k", 250),
+                    params.get("top_p", 0.0),
+                )
+        return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
+
+    preset_dropdown.change(
+        fn=apply_preset,
+        inputs=[preset_dropdown],
+        outputs=[duration_slider, bpm_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
+    )
+
+    # Variant change - show/hide drum section
+    def on_variant_change(variant: str):
+        show_drums = "drums" in variant.lower()
+        return gr.update(visible=show_drums)
+
+    variant_dropdown.change(
+        fn=on_variant_change,
+        inputs=[variant_dropdown],
+        outputs=[drum_section],
+    )
+
+    # Chord presets
+    for btn, chords in chord_preset_btns:
+        btn.click(
+            fn=lambda c=chords: c,
+            outputs=[chord_input],
+        )
+
+    # Generate
+    async def do_generate(
+        prompt, variant, chords, drums, duration, bpm, temperature, cfg_coef, top_k, top_p, seed
+    ):
+        if not chords:
+            # An async generator cannot `return` a value (SyntaxError), so
+            # yield the validation message and stop instead.
+            yield (
+                gr.update(value="Please enter a chord progression"),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+            )
+            return
+
+        yield (
+            gr.update(value="🔄 Generating..."),
+            gr.update(visible=True, value=0),
+            gr.update(),
+            gr.update(),
+            gr.update(),
+            gr.update(),
+        )
+
+        try:
+            conditioning = {
+                "chords": chords,
+                "bpm": bpm,
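+                # (assumption) keys consumed by the JASCO adapter for symbolic
+                # conditioning; a "drums" key is appended right after this dict
+                # when a reference loop is provided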
} + if drums and "drums" in variant.lower(): + conditioning["drums"] = drums + + job = add_to_queue_fn( + model_id="jasco", + variant=variant, + prompts=[prompt] if prompt else [""], + duration=duration, + temperature=temperature, + top_k=int(top_k), + top_p=top_p, + cfg_coef=cfg_coef, + seed=int(seed) if seed else None, + conditioning=conditioning, + ) + + return f"✅ Added to queue: {job.id}" + + queue_btn.click( + fn=do_add_queue, + inputs=[ + prompt_input, + variant_dropdown, + chord_input, + drum_input, + duration_slider, + bpm_slider, + temperature_slider, + cfg_slider, + top_k_slider, + top_p_slider, + seed_input, + ], + outputs=[output["status"]], + ) + + return { + "preset": preset_dropdown, + "variant": variant_dropdown, + "prompt": prompt_input, + "chords": chord_input, + "drums": drum_input, + "duration": duration_slider, + "bpm": bpm_slider, + "temperature": temperature_slider, + "cfg_coef": cfg_slider, + "top_k": top_k_slider, + "top_p": top_p_slider, + "seed": seed_input, + "generate_btn": generate_btn, + "queue_btn": queue_btn, + "output": output, + } diff --git a/src/ui/tabs/magnet_tab.py b/src/ui/tabs/magnet_tab.py new file mode 100644 index 0000000..f4312b9 --- /dev/null +++ b/src/ui/tabs/magnet_tab.py @@ -0,0 +1,316 @@ +"""MAGNeT tab for fast non-autoregressive generation.""" + +import gradio as gr +from typing import Any, Callable, Optional + +from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS +from src.ui.components.audio_player import create_generation_output + + +MAGNET_VARIANTS = [ + {"id": "small", "name": "Small Music", "vram_mb": 2000, "description": "Fast music, 300M params"}, + {"id": "medium", "name": "Medium Music", "vram_mb": 5000, "description": "Balanced music, 1.5B params"}, + {"id": "audio-small", "name": "Small Audio", "vram_mb": 2000, "description": "Fast sound effects"}, + {"id": "audio-medium", "name": "Medium Audio", "vram_mb": 5000, "description": "Balanced sound effects"}, +] + + +def create_magnet_tab( + generate_fn: Callable[..., Any], + add_to_queue_fn: Callable[..., Any], +) -> dict[str, Any]: + """Create MAGNeT generation tab. 
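+
+    Example (sketch; same hypothetical service hooks as the other tab
+    factories in this package):
+
+        tab = create_magnet_tab(
+            generate_fn=generation_service.generate,
+            add_to_queue_fn=batch_processor.add_job,
+        )
+        tab["decoding_steps"]  # MAGNeT-specific control (10-100 steps)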
+ + Args: + generate_fn: Function to call for generation + add_to_queue_fn: Function to add to queue + + Returns: + Dictionary with component references + """ + presets = DEFAULT_PRESETS.get("magnet", []) + suggestions = PROMPT_SUGGESTIONS.get("musicgen", []) # Reuse music suggestions + + with gr.Column(): + gr.Markdown("## ⚡ MAGNeT") + gr.Markdown("Fast non-autoregressive music and sound generation") + + with gr.Row(): + # Left column - inputs + with gr.Column(scale=2): + # Preset selector + preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")] + preset_dropdown = gr.Dropdown( + label="Preset", + choices=preset_choices, + value=presets[0]["id"] if presets else "custom", + ) + + # Model variant + variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in MAGNET_VARIANTS] + variant_dropdown = gr.Dropdown( + label="Model Variant", + choices=variant_choices, + value="medium", + ) + + # Prompt input + prompt_input = gr.Textbox( + label="Prompt", + placeholder="Describe the music or sound you want to generate...", + lines=3, + max_lines=5, + ) + + # Prompt suggestions + with gr.Accordion("Prompt Suggestions", open=False): + suggestion_btns = [] + for i, suggestion in enumerate(suggestions[:4]): + btn = gr.Button(suggestion[:60] + "...", size="sm", variant="secondary") + suggestion_btns.append((btn, suggestion)) + + # Parameters + gr.Markdown("### Parameters") + + duration_slider = gr.Slider( + minimum=1, + maximum=30, + value=10, + step=1, + label="Duration (seconds)", + ) + + with gr.Accordion("Advanced Parameters", open=False): + gr.Markdown("*MAGNeT uses different sampling compared to MusicGen*") + + with gr.Row(): + temperature_slider = gr.Slider( + minimum=1.0, + maximum=5.0, + value=3.0, + step=0.1, + label="Temperature", + info="Higher values recommended (3.0 default)", + ) + cfg_slider = gr.Slider( + minimum=1.0, + maximum=10.0, + value=3.0, + step=0.5, + label="CFG Coefficient", + ) + + with gr.Row(): + top_k_slider = gr.Slider( + minimum=0, + maximum=500, + value=0, + step=10, + label="Top-K", + info="0 recommended for MAGNeT", + ) + top_p_slider = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.9, + step=0.05, + label="Top-P", + info="0.9 recommended for MAGNeT", + ) + + with gr.Row(): + decoding_steps_slider = gr.Slider( + minimum=10, + maximum=100, + value=20, + step=5, + label="Decoding Steps", + info="More steps = better quality, slower", + ) + span_arrangement = gr.Dropdown( + label="Span Arrangement", + choices=[("No Overlap", "nonoverlap"), ("Overlap", "stride1")], + value="nonoverlap", + ) + + with gr.Row(): + seed_input = gr.Number( + value=None, + label="Seed (empty = random)", + precision=0, + ) + + # Generate buttons + with gr.Row(): + generate_btn = gr.Button("⚡ Generate", variant="primary", scale=2) + queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1) + + # Right column - output + with gr.Column(scale=3): + output = create_generation_output() + + # Event handlers + + # Preset change + def apply_preset(preset_id: str): + for p in presets: + if p["id"] == preset_id: + params = p["parameters"] + return ( + params.get("duration", 10), + params.get("temperature", 3.0), + params.get("cfg_coef", 3.0), + params.get("top_k", 0), + params.get("top_p", 0.9), + params.get("decoding_steps", 20), + ) + return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update() + + preset_dropdown.change( + fn=apply_preset, + inputs=[preset_dropdown], + outputs=[duration_slider, temperature_slider, 
+
+    # Prompt suggestions
+    for btn, suggestion in suggestion_btns:
+        btn.click(
+            fn=lambda s=suggestion: s,
+            outputs=[prompt_input],
+        )
+
+    # Generate (async generator: each yielded tuple is a UI update)
+    async def do_generate(
+        prompt, variant, duration, temperature, cfg_coef, top_k, top_p, decoding_steps, span_arr, seed
+    ):
+        if not prompt:
+            # Async generators cannot `return` a value: yield the error state
+            # and stop.
+            yield (
+                gr.update(value="Please enter a prompt"),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+            )
+            return
+
+        yield (
+            gr.update(value="🔄 Generating..."),
+            gr.update(visible=True, value=0),
+            gr.update(),
+            gr.update(),
+            gr.update(),
+            gr.update(),
+        )
+
+        try:
+            result, generation = await generate_fn(
+                model_id="magnet",
+                variant=variant,
+                prompts=[prompt],
+                duration=duration,
+                temperature=temperature,
+                top_k=int(top_k),
+                top_p=top_p,
+                cfg_coef=cfg_coef,
+                decoding_steps=int(decoding_steps),
+                span_arrangement=span_arr,
+                seed=int(seed) if seed is not None else None,  # 0 is a valid seed
+            )
+
+            yield (
+                gr.update(value="✅ Generation complete!"),
+                gr.update(visible=False),
+                gr.update(value=generation.audio_path),
+                gr.update(),
+                gr.update(value=f"{result.duration:.2f}s"),
+                gr.update(value=str(result.seed)),
+            )
+
+        except Exception as e:
+            yield (
+                gr.update(value=f"❌ Error: {str(e)}"),
+                gr.update(visible=False),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+            )
+
+    generate_btn.click(
+        fn=do_generate,
+        inputs=[
+            prompt_input,
+            variant_dropdown,
+            duration_slider,
+            temperature_slider,
+            cfg_slider,
+            top_k_slider,
+            top_p_slider,
+            decoding_steps_slider,
+            span_arrangement,
+            seed_input,
+        ],
+        outputs=[
+            output["status"],
+            output["progress"],
+            output["player"]["audio"],
+            output["player"]["waveform"],
+            output["player"]["duration"],
+            output["player"]["seed"],
+        ],
+    )
+
+    # Add to queue
+    def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, decoding_steps, span_arr, seed):
+        if not prompt:
+            return "Please enter a prompt"
+
+        job = add_to_queue_fn(
+            model_id="magnet",
+            variant=variant,
+            prompts=[prompt],
+            duration=duration,
+            temperature=temperature,
+            top_k=int(top_k),
+            top_p=top_p,
+            cfg_coef=cfg_coef,
+            decoding_steps=int(decoding_steps),
+            span_arrangement=span_arr,
+            seed=int(seed) if seed is not None else None,
+        )
+
+        return f"✅ Added to queue: {job.id}"
+
+    queue_btn.click(
+        fn=do_add_queue,
+        inputs=[
+            prompt_input,
+            variant_dropdown,
+            duration_slider,
+            temperature_slider,
+            cfg_slider,
+            top_k_slider,
+            top_p_slider,
+            decoding_steps_slider,
+            span_arrangement,
+            seed_input,
+        ],
+        outputs=[output["status"]],
+    )
+
+    return {
+        "preset": preset_dropdown,
+        "variant": variant_dropdown,
+        "prompt": prompt_input,
+        "duration": duration_slider,
+        "temperature": temperature_slider,
+        "cfg_coef": cfg_slider,
+        "top_k": top_k_slider,
+        "top_p": top_p_slider,
+        "decoding_steps": decoding_steps_slider,
+        "span_arrangement": span_arrangement,
+        "seed": seed_input,
+        "generate_btn": generate_btn,
+        "queue_btn": queue_btn,
+        "output": output,
+    }
diff --git a/src/ui/tabs/musicgen_tab.py b/src/ui/tabs/musicgen_tab.py
new file mode 100644
index 0000000..e6a26ce
--- /dev/null
+++ b/src/ui/tabs/musicgen_tab.py
@@ -0,0 +1,325 @@
+"""MusicGen tab for text-to-music generation."""
+
+import gradio as gr
+from typing import Any, Callable
+
+from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
+from src.ui.components.audio_player import create_generation_output
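+
+# The callback contract below is inferred from how this module uses it, not
+# defined here: generate_fn is awaited and returns a (result, generation)
+# pair, where result exposes .duration and .seed and generation exposes
+# .audio_path; add_to_queue_fn returns a job object exposing .id.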
+
+
+MUSICGEN_VARIANTS = [
+    {"id": "small", "name": "Small", "vram_mb": 1500, "description": "Fast, 300M params"},
+    {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "Balanced, 1.5B params"},
+    {"id": "large", "name": "Large", "vram_mb": 10000, "description": "Best quality, 3.3B params"},
+    {"id": "melody", "name": "Melody", "vram_mb": 5000, "description": "With melody conditioning"},
+    {"id": "stereo-small", "name": "Stereo Small", "vram_mb": 1800, "description": "Stereo, 300M params"},
+    {"id": "stereo-medium", "name": "Stereo Medium", "vram_mb": 6000, "description": "Stereo, 1.5B params"},
+    {"id": "stereo-large", "name": "Stereo Large", "vram_mb": 12000, "description": "Stereo, 3.3B params"},
+    {"id": "stereo-melody", "name": "Stereo Melody", "vram_mb": 6000, "description": "Stereo with melody"},
+]
+
+
+def create_musicgen_tab(
+    generate_fn: Callable[..., Any],
+    add_to_queue_fn: Callable[..., Any],
+) -> dict[str, Any]:
+    """Create MusicGen generation tab.
+
+    Args:
+        generate_fn: Async callable that performs generation and returns a
+            (result, generation) pair.
+        add_to_queue_fn: Callable that enqueues a batch job and returns a
+            job object exposing .id.
+
+    Returns:
+        Dictionary with component references
+    """
+    presets = DEFAULT_PRESETS.get("musicgen", [])
+    suggestions = PROMPT_SUGGESTIONS.get("musicgen", [])
+
+    with gr.Column():
+        gr.Markdown("## 🎵 MusicGen")
+        gr.Markdown("Generate music from text descriptions")
+
+        with gr.Row():
+            # Left column - inputs
+            with gr.Column(scale=2):
+                # Preset selector
+                preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
+                preset_dropdown = gr.Dropdown(
+                    label="Preset",
+                    choices=preset_choices,
+                    value=presets[0]["id"] if presets else "custom",
+                )
+
+                # Model variant
+                variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in MUSICGEN_VARIANTS]
+                variant_dropdown = gr.Dropdown(
+                    label="Model Variant",
+                    choices=variant_choices,
+                    value="medium",
+                )
+
+                # Prompt input
+                prompt_input = gr.Textbox(
+                    label="Prompt",
+                    placeholder="Describe the music you want to generate...",
+                    lines=3,
+                    max_lines=5,
+                )
+
+                # Prompt suggestions
+                with gr.Accordion("Prompt Suggestions", open=False):
+                    suggestion_btns = []
+                    for suggestion in suggestions[:4]:
+                        # Truncate long suggestions for the button label only.
+                        label = suggestion if len(suggestion) <= 60 else suggestion[:60] + "..."
+                        btn = gr.Button(label, size="sm", variant="secondary")
+                        suggestion_btns.append((btn, suggestion))
+
+                # Melody conditioning (for melody variants)
+                with gr.Group(visible=False) as melody_section:
+                    gr.Markdown("### Melody Conditioning")
+                    melody_input = gr.Audio(
+                        label="Reference Melody",
+                        type="filepath",
+                        sources=["upload", "microphone"],
+                    )
+                    gr.Markdown("*Upload audio to condition generation on its melody*")
+
+                # Parameters
+                gr.Markdown("### Parameters")
+
+                duration_slider = gr.Slider(
+                    minimum=1,
+                    maximum=30,
+                    value=10,
+                    step=1,
+                    label="Duration (seconds)",
+                )
+
+                with gr.Accordion("Advanced Parameters", open=False):
+                    with gr.Row():
+                        temperature_slider = gr.Slider(
+                            minimum=0.0,
+                            maximum=2.0,
+                            value=1.0,
+                            step=0.05,
+                            label="Temperature",
+                        )
+                        cfg_slider = gr.Slider(
+                            minimum=1.0,
+                            maximum=10.0,
+                            value=3.0,
+                            step=0.5,
+                            label="CFG Coefficient",
+                        )
+
+                    with gr.Row():
+                        top_k_slider = gr.Slider(
+                            minimum=0,
+                            maximum=500,
+                            value=250,
+                            step=10,
+                            label="Top-K",
+                        )
+                        top_p_slider = gr.Slider(
+                            minimum=0.0,
+                            maximum=1.0,
+                            value=0.0,
+                            step=0.05,
+                            label="Top-P",
+                        )
+
+                    with gr.Row():
+                        seed_input = gr.Number(
+                            value=None,
+                            label="Seed (empty = random)",
+                            precision=0,
+                        )
+
+                # Generate buttons
+                with gr.Row():
+                    generate_btn = gr.Button("🎵 Generate", variant="primary", scale=2)
+                    queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)
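+            # (The column below renders the shared status/progress/player
+            # output that every tab builds via create_generation_output().)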
+            # Right column - output
+            with gr.Column(scale=3):
+                output = create_generation_output()
+
+    # Event handlers
+
+    # Preset change
+    def apply_preset(preset_id: str):
+        for p in presets:
+            if p["id"] == preset_id:
+                params = p["parameters"]
+                return (
+                    params.get("duration", 10),
+                    params.get("temperature", 1.0),
+                    params.get("cfg_coef", 3.0),
+                    params.get("top_k", 250),
+                    params.get("top_p", 0.0),
+                )
+        # Custom preset - don't change values
+        return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
+
+    preset_dropdown.change(
+        fn=apply_preset,
+        inputs=[preset_dropdown],
+        outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
+    )
+
+    # Variant change - show/hide melody section
+    def on_variant_change(variant: str):
+        show_melody = "melody" in variant.lower()
+        return gr.update(visible=show_melody)
+
+    variant_dropdown.change(
+        fn=on_variant_change,
+        inputs=[variant_dropdown],
+        outputs=[melody_section],
+    )
+
+    # Prompt suggestions
+    for btn, suggestion in suggestion_btns:
+        btn.click(
+            fn=lambda s=suggestion: s,
+            outputs=[prompt_input],
+        )
+
+    # Generate (async generator: each yielded tuple is a UI update)
+    async def do_generate(
+        prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody
+    ):
+        if not prompt:
+            # Async generators cannot `return` a value: yield the error state
+            # and stop.
+            yield (
+                gr.update(value="Please enter a prompt"),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+            )
+            return
+
+        # Update status
+        yield (
+            gr.update(value="🔄 Generating..."),
+            gr.update(visible=True, value=0),
+            gr.update(),
+            gr.update(),
+            gr.update(),
+            gr.update(),
+        )
+
+        try:
+            conditioning = {}
+            if melody:
+                conditioning["melody"] = melody
+
+            result, generation = await generate_fn(
+                model_id="musicgen",
+                variant=variant,
+                prompts=[prompt],
+                duration=duration,
+                temperature=temperature,
+                top_k=int(top_k),
+                top_p=top_p,
+                cfg_coef=cfg_coef,
+                seed=int(seed) if seed is not None else None,  # 0 is a valid seed
+                conditioning=conditioning,
+            )
+
+            yield (
+                gr.update(value="✅ Generation complete!"),
+                gr.update(visible=False),
+                gr.update(value=generation.audio_path),
+                gr.update(),
+                gr.update(value=f"{result.duration:.2f}s"),
+                gr.update(value=str(result.seed)),
+            )
+
+        except Exception as e:
+            yield (
+                gr.update(value=f"❌ Error: {str(e)}"),
+                gr.update(visible=False),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+            )
+
+    generate_btn.click(
+        fn=do_generate,
+        inputs=[
+            prompt_input,
+            variant_dropdown,
+            duration_slider,
+            temperature_slider,
+            cfg_slider,
+            top_k_slider,
+            top_p_slider,
+            seed_input,
+            melody_input,
+        ],
+        outputs=[
+            output["status"],
+            output["progress"],
+            output["player"]["audio"],
+            output["player"]["waveform"],
+            output["player"]["duration"],
+            output["player"]["seed"],
+        ],
+    )
+
+    # Add to queue
+    def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody):
+        if not prompt:
+            return "Please enter a prompt"
+
+        conditioning = {}
+        if melody:
+            conditioning["melody"] = melody
+
+        job = add_to_queue_fn(
+            model_id="musicgen",
+            variant=variant,
+            prompts=[prompt],
+            duration=duration,
+            temperature=temperature,
+            top_k=int(top_k),
+            top_p=top_p,
+            cfg_coef=cfg_coef,
+            seed=int(seed) if seed is not None else None,
+            conditioning=conditioning,
+        )
+
+        return f"✅ Added to queue: {job.id}"
+
+    queue_btn.click(
+        fn=do_add_queue,
+        inputs=[
+            prompt_input,
+            variant_dropdown,
+            duration_slider,
+            temperature_slider,
+            cfg_slider,
+            top_k_slider,
+            top_p_slider,
+            seed_input,
+            melody_input,
+        ],
+        outputs=[output["status"]],
+    )
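+
+    # Return the tab's components keyed by name; the host UI (presumably
+    # src/ui/app.py) uses these references to wire cross-tab behaviour.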
"prompt": prompt_input, + "melody": melody_input, + "duration": duration_slider, + "temperature": temperature_slider, + "cfg_coef": cfg_slider, + "top_k": top_k_slider, + "top_p": top_p_slider, + "seed": seed_input, + "generate_btn": generate_btn, + "queue_btn": queue_btn, + "output": output, + } diff --git a/src/ui/tabs/style_tab.py b/src/ui/tabs/style_tab.py new file mode 100644 index 0000000..621d6d2 --- /dev/null +++ b/src/ui/tabs/style_tab.py @@ -0,0 +1,292 @@ +"""MusicGen Style tab for style-conditioned generation.""" + +import gradio as gr +from typing import Any, Callable, Optional + +from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS +from src.ui.components.audio_player import create_generation_output + + +STYLE_VARIANTS = [ + {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "1.5B params, style conditioning"}, +] + + +def create_style_tab( + generate_fn: Callable[..., Any], + add_to_queue_fn: Callable[..., Any], +) -> dict[str, Any]: + """Create MusicGen Style generation tab. + + Args: + generate_fn: Function to call for generation + add_to_queue_fn: Function to add to queue + + Returns: + Dictionary with component references + """ + presets = DEFAULT_PRESETS.get("musicgen-style", []) + suggestions = PROMPT_SUGGESTIONS.get("musicgen", []) + + with gr.Column(): + gr.Markdown("## 🎨 MusicGen Style") + gr.Markdown("Generate music conditioned on the style of reference audio") + + with gr.Row(): + # Left column - inputs + with gr.Column(scale=2): + # Preset selector + preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")] + preset_dropdown = gr.Dropdown( + label="Preset", + choices=preset_choices, + value=presets[0]["id"] if presets else "custom", + ) + + # Model variant + variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in STYLE_VARIANTS] + variant_dropdown = gr.Dropdown( + label="Model Variant", + choices=variant_choices, + value="medium", + ) + + # Prompt input + prompt_input = gr.Textbox( + label="Text Prompt", + placeholder="Describe additional characteristics for the music...", + lines=3, + max_lines=5, + info="Optional: combine with style conditioning", + ) + + # Style conditioning (required) + gr.Markdown("### Style Conditioning") + gr.Markdown("*Upload reference audio to extract musical style*") + + style_input = gr.Audio( + label="Style Reference", + type="filepath", + sources=["upload", "microphone"], + ) + + style_info = gr.Markdown( + "*The model will learn the style (instrumentation, tempo, mood) from this audio*" + ) + + # Parameters + gr.Markdown("### Parameters") + + duration_slider = gr.Slider( + minimum=1, + maximum=30, + value=10, + step=1, + label="Duration (seconds)", + ) + + with gr.Accordion("Advanced Parameters", open=False): + with gr.Row(): + temperature_slider = gr.Slider( + minimum=0.0, + maximum=2.0, + value=1.0, + step=0.05, + label="Temperature", + ) + cfg_slider = gr.Slider( + minimum=1.0, + maximum=10.0, + value=3.0, + step=0.5, + label="CFG Coefficient", + ) + + with gr.Row(): + top_k_slider = gr.Slider( + minimum=0, + maximum=500, + value=250, + step=10, + label="Top-K", + ) + top_p_slider = gr.Slider( + minimum=0.0, + maximum=1.0, + value=0.0, + step=0.05, + label="Top-P", + ) + + with gr.Row(): + seed_input = gr.Number( + value=None, + label="Seed (empty = random)", + precision=0, + ) + + # Generate buttons + with gr.Row(): + generate_btn = gr.Button("🎨 Generate", variant="primary", scale=2) + queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1) + + # 
+            # Right column - output
+            with gr.Column(scale=3):
+                output = create_generation_output()
+
+    # Event handlers
+
+    # Preset change
+    def apply_preset(preset_id: str):
+        for p in presets:
+            if p["id"] == preset_id:
+                params = p["parameters"]
+                return (
+                    params.get("duration", 10),
+                    params.get("temperature", 1.0),
+                    params.get("cfg_coef", 3.0),
+                    params.get("top_k", 250),
+                    params.get("top_p", 0.0),
+                )
+        return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
+
+    preset_dropdown.change(
+        fn=apply_preset,
+        inputs=[preset_dropdown],
+        outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
+    )
+
+    # Generate (async generator: each yielded tuple is a UI update)
+    async def do_generate(
+        prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed
+    ):
+        if not style_audio:
+            # Async generators cannot `return` a value: yield the error state
+            # and stop.
+            yield (
+                gr.update(value="Please upload a style reference audio"),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+            )
+            return
+
+        yield (
+            gr.update(value="🔄 Generating..."),
+            gr.update(visible=True, value=0),
+            gr.update(),
+            gr.update(),
+            gr.update(),
+            gr.update(),
+        )
+
+        try:
+            conditioning = {"style": style_audio}
+
+            result, generation = await generate_fn(
+                model_id="musicgen-style",
+                variant=variant,
+                prompts=[prompt] if prompt else [""],
+                duration=duration,
+                temperature=temperature,
+                top_k=int(top_k),
+                top_p=top_p,
+                cfg_coef=cfg_coef,
+                seed=int(seed) if seed is not None else None,  # 0 is a valid seed
+                conditioning=conditioning,
+            )
+
+            yield (
+                gr.update(value="✅ Generation complete!"),
+                gr.update(visible=False),
+                gr.update(value=generation.audio_path),
+                gr.update(),
+                gr.update(value=f"{result.duration:.2f}s"),
+                gr.update(value=str(result.seed)),
+            )
+
+        except Exception as e:
+            yield (
+                gr.update(value=f"❌ Error: {str(e)}"),
+                gr.update(visible=False),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+                gr.update(),
+            )
+
+    generate_btn.click(
+        fn=do_generate,
+        inputs=[
+            prompt_input,
+            variant_dropdown,
+            style_input,
+            duration_slider,
+            temperature_slider,
+            cfg_slider,
+            top_k_slider,
+            top_p_slider,
+            seed_input,
+        ],
+        outputs=[
+            output["status"],
+            output["progress"],
+            output["player"]["audio"],
+            output["player"]["waveform"],
+            output["player"]["duration"],
+            output["player"]["seed"],
+        ],
+    )
+
+    # Add to queue
+    def do_add_queue(prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed):
+        if not style_audio:
+            return "Please upload a style reference audio"
+
+        conditioning = {"style": style_audio}
+
+        job = add_to_queue_fn(
+            model_id="musicgen-style",
+            variant=variant,
+            prompts=[prompt] if prompt else [""],
+            duration=duration,
+            temperature=temperature,
+            top_k=int(top_k),
+            top_p=top_p,
+            cfg_coef=cfg_coef,
+            seed=int(seed) if seed is not None else None,
+            conditioning=conditioning,
+        )
+
+        return f"✅ Added to queue: {job.id}"
+
+    queue_btn.click(
+        fn=do_add_queue,
+        inputs=[
+            prompt_input,
+            variant_dropdown,
+            style_input,
+            duration_slider,
+            temperature_slider,
+            cfg_slider,
+            top_k_slider,
+            top_p_slider,
+            seed_input,
+        ],
+        outputs=[output["status"]],
+    )
+
+    return {
+        "preset": preset_dropdown,
+        "variant": variant_dropdown,
+        "prompt": prompt_input,
+        "style": style_input,
+        "duration": duration_slider,
+        "temperature": temperature_slider,
+        "cfg_coef": cfg_slider,
+        "top_k": top_k_slider,
+        "top_p": top_p_slider,
+        "seed": seed_input,
+        "generate_btn": generate_btn,
+        "queue_btn": queue_btn,
+        "output": output,
+    }
diff --git a/src/ui/theme.py b/src/ui/theme.py
new file mode 100644
index 0000000..5b87905
--- /dev/null
+++ b/src/ui/theme.py
@@ -0,0 +1,303 @@
+"""Custom Gradio theme for AudioCraft Studio."""
+
+import gradio as gr
+
+
+def create_theme() -> gr.themes.Base:
+    """Create custom theme for AudioCraft Studio.
+
+    Returns:
+        Gradio theme instance
+    """
+    return gr.themes.Soft(
+        primary_hue=gr.themes.colors.blue,
+        secondary_hue=gr.themes.colors.slate,
+        neutral_hue=gr.themes.colors.gray,
+        font=[
+            gr.themes.GoogleFont("Inter"),
+            "ui-sans-serif",
+            "system-ui",
+            "sans-serif",
+        ],
+        font_mono=[
+            gr.themes.GoogleFont("JetBrains Mono"),
+            "ui-monospace",
+            "monospace",
+        ],
+    ).set(
+        # Colors
+        body_background_fill="#0f172a",
+        body_background_fill_dark="#0f172a",
+        background_fill_primary="#1e293b",
+        background_fill_primary_dark="#1e293b",
+        background_fill_secondary="#334155",
+        background_fill_secondary_dark="#334155",
+        border_color_primary="#475569",
+        border_color_primary_dark="#475569",
+
+        # Text
+        body_text_color="#e2e8f0",
+        body_text_color_dark="#e2e8f0",
+        body_text_color_subdued="#94a3b8",
+        body_text_color_subdued_dark="#94a3b8",
+
+        # Buttons
+        button_primary_background_fill="#3b82f6",
+        button_primary_background_fill_dark="#3b82f6",
+        button_primary_background_fill_hover="#2563eb",
+        button_primary_background_fill_hover_dark="#2563eb",
+        button_primary_text_color="#ffffff",
+        button_primary_text_color_dark="#ffffff",
+
+        button_secondary_background_fill="#475569",
+        button_secondary_background_fill_dark="#475569",
+        button_secondary_background_fill_hover="#64748b",
+        button_secondary_background_fill_hover_dark="#64748b",
+
+        # Inputs
+        input_background_fill="#1e293b",
+        input_background_fill_dark="#1e293b",
+        input_border_color="#475569",
+        input_border_color_dark="#475569",
+        input_border_color_focus="#3b82f6",
+        input_border_color_focus_dark="#3b82f6",
+
+        # Blocks
+        block_background_fill="#1e293b",
+        block_background_fill_dark="#1e293b",
+        block_border_color="#334155",
+        block_border_color_dark="#334155",
+        block_label_background_fill="#334155",
+        block_label_background_fill_dark="#334155",
+        block_label_text_color="#e2e8f0",
+        block_label_text_color_dark="#e2e8f0",
+        block_title_text_color="#f1f5f9",
+        block_title_text_color_dark="#f1f5f9",
+
+        # Tabs
+        tab_nav_background_fill="#1e293b",
+
+        # Sliders
+        slider_color="#3b82f6",
+        slider_color_dark="#3b82f6",
+
+        # Shadows
+        shadow_spread="4px",
+        block_shadow="0 4px 6px -1px rgba(0, 0, 0, 0.3)",
+
+        # Spacing
+        layout_gap="16px",
+        block_padding="16px",
+        panel_border_width="1px",
+
+        # Radius
+        radius_sm="6px",
+        radius_md="8px",
+        radius_lg="12px",
+    )
+
+
+# CSS overrides for additional customization
+CUSTOM_CSS = """
+/* Global styles */
+.gradio-container {
+    max-width: 100% !important;
+}
+
+/* Header styling */
+.header-title {
+    font-size: 1.5rem;
+    font-weight: 700;
+    color: #f1f5f9;
+}
+
+/* Sidebar styling */
+.sidebar {
+    background: #1e293b;
+    border-right: 1px solid #334155;
+    padding: 1rem;
+}
+
+.sidebar-nav-btn {
+    width: 100%;
+    justify-content: flex-start;
+    margin-bottom: 0.5rem;
+}
+
+/* Model cards */
+.model-card {
+    background: #334155;
+    border-radius: 12px;
+    padding: 1rem;
+    transition: transform 0.2s, box-shadow 0.2s;
+}
+
+.model-card:hover {
+    transform: translateY(-2px);
+    box-shadow: 0 8px 25px rgba(0, 0, 0, 0.3);
+}
+
+/* Audio player */
+.audio-player {
+    background: #1e293b;
+    border-radius: 8px;
+    padding: 1rem;
+}
+
+/* Progress bar */
+.progress-bar {
+    background: #334155;
+    border-radius: 4px;
+    overflow: hidden;
+}
+
+.progress-fill {
+    background: linear-gradient(90deg, #3b82f6, #8b5cf6);
+    height: 100%;
+    transition: width 0.3s ease;
+}
+
+/* VRAM monitor */
+.vram-bar {
+    background: #334155;
+    border-radius: 4px;
+    height: 24px;
+    position: relative;
+    overflow: hidden;
+}
+
+.vram-fill {
+    position: absolute;
+    left: 0;
+    top: 0;
+    height: 100%;
+    background: linear-gradient(90deg, #22c55e, #eab308, #ef4444);
+    transition: width 0.5s ease;
+}
+
+.vram-text {
+    position: absolute;
+    width: 100%;
+    text-align: center;
+    line-height: 24px;
+    font-size: 0.875rem;
+    font-weight: 500;
+    color: white;
+    text-shadow: 0 1px 2px rgba(0, 0, 0, 0.5);
+}
+
+/* Queue badge */
+.queue-badge {
+    background: #3b82f6;
+    color: white;
+    padding: 0.25rem 0.75rem;
+    border-radius: 9999px;
+    font-size: 0.875rem;
+    font-weight: 500;
+}
+
+/* Generation card */
+.generation-card {
+    background: #334155;
+    border-radius: 8px;
+    padding: 1rem;
+    margin-bottom: 0.5rem;
+}
+
+/* Preset chips */
+.preset-chip {
+    display: inline-block;
+    background: #475569;
+    color: #e2e8f0;
+    padding: 0.25rem 0.75rem;
+    border-radius: 9999px;
+    font-size: 0.875rem;
+    margin: 0.25rem;
+    cursor: pointer;
+    transition: background 0.2s;
+}
+
+.preset-chip:hover {
+    background: #3b82f6;
+}
+
+.preset-chip.active {
+    background: #3b82f6;
+}
+
+/* Tag input */
+.tag {
+    display: inline-flex;
+    align-items: center;
+    background: #475569;
+    color: #e2e8f0;
+    padding: 0.25rem 0.5rem;
+    border-radius: 4px;
+    font-size: 0.75rem;
+    margin: 0.125rem;
+}
+
+/* Accordion tweaks */
+.accordion-header {
+    font-weight: 600;
+    color: #f1f5f9;
+}
+
+/* Status indicators */
+.status-dot {
+    width: 8px;
+    height: 8px;
+    border-radius: 50%;
+    display: inline-block;
+    margin-right: 0.5rem;
+}
+
+.status-dot.loaded {
+    background: #22c55e;
+}
+
+.status-dot.unloaded {
+    background: #64748b;
+}
+
+.status-dot.loading {
+    background: #eab308;
+    animation: pulse 1s infinite;
+}
+
+@keyframes pulse {
+    0%, 100% { opacity: 1; }
+    50% { opacity: 0.5; }
+}
+
+/* Tooltip */
+.tooltip {
+    position: relative;
+}
+
+.tooltip:hover::after {
+    content: attr(data-tooltip);
+    position: absolute;
+    bottom: 100%;
+    left: 50%;
+    transform: translateX(-50%);
+    background: #1e293b;
+    color: #e2e8f0;
+    padding: 0.5rem;
+    border-radius: 4px;
+    font-size: 0.75rem;
+    white-space: nowrap;
+    z-index: 100;
+}
+
+/* Responsive adjustments */
+@media (max-width: 768px) {
+    .sidebar {
+        display: none;
+    }
+
+    .mobile-nav {
+        display: flex !important;
+    }
+}
+"""
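+
+
+# Usage sketch (illustrative only; the actual wiring presumably lives in
+# src/ui/app.py):
+#
+#     import gradio as gr
+#     from src.ui.theme import create_theme, CUSTOM_CSS
+#
+#     with gr.Blocks(theme=create_theme(), css=CUSTOM_CSS) as demo:
+#         ...  # build tabs here
+#     demo.launch()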