Initial implementation of AudioCraft Studio

Complete web interface for Meta's AudioCraft AI audio generation:

- Gradio UI with tabs for all 5 model families (MusicGen, AudioGen,
  MAGNeT, MusicGen Style, JASCO)
- REST API with FastAPI, OpenAPI docs, and API key auth
- VRAM management with ComfyUI coexistence support
- SQLite database for project/generation history
- Batch processing queue for async generation
- Docker deployment optimized for RunPod with RTX 4090

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
commit ffbf02b12c (2025-11-25 19:34:27 +01:00)
67 changed files with 12032 additions and 0 deletions

.env.example
@@ -0,0 +1,42 @@
# AudioCraft Studio Configuration
# Copy this file to .env and customize as needed
# Server Configuration
AUDIOCRAFT_HOST=0.0.0.0
AUDIOCRAFT_GRADIO_PORT=7860
AUDIOCRAFT_API_PORT=8000
# Paths (relative to project root)
AUDIOCRAFT_DATA_DIR=./data
AUDIOCRAFT_OUTPUT_DIR=./outputs
AUDIOCRAFT_CACHE_DIR=./cache
# VRAM Management
# Reserve this much VRAM for ComfyUI (GB)
AUDIOCRAFT_COMFYUI_RESERVE_GB=10
# Safety buffer to prevent OOM (GB)
AUDIOCRAFT_SAFETY_BUFFER_GB=1
# Unload idle models after this many minutes
AUDIOCRAFT_IDLE_UNLOAD_MINUTES=15
# Maximum number of models to keep loaded
AUDIOCRAFT_MAX_CACHED_MODELS=2
# API Authentication
# Generate a secure random key for production
AUDIOCRAFT_API_KEY=your-secret-api-key-here
# Generation Defaults
AUDIOCRAFT_DEFAULT_DURATION=10.0
AUDIOCRAFT_MAX_DURATION=300.0
AUDIOCRAFT_DEFAULT_BATCH_SIZE=1
AUDIOCRAFT_MAX_BATCH_SIZE=8
AUDIOCRAFT_MAX_QUEUE_SIZE=100
# Database
AUDIOCRAFT_DATABASE_URL=sqlite+aiosqlite:///./data/audiocraft.db
# Logging
AUDIOCRAFT_LOG_LEVEL=INFO
# PyTorch Optimization (recommended)
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

.gitignore
@@ -0,0 +1,76 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# Virtual environments
.venv/
venv/
ENV/
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# Testing
.pytest_cache/
.coverage
htmlcov/
.tox/
.nox/
# Type checking
.mypy_cache/
# Project specific
data/
outputs/
cache/
*.db
*.sqlite
*.sqlite3
# Logs
*.log
logs/
# Environment
.env
.env.local
.env.*.local
# Model weights (downloaded from HuggingFace)
*.bin
*.safetensors
*.pt
*.pth
# Audio files (generated)
*.wav
*.mp3
*.flac
*.ogg
# Temp files
/tmp/
*.tmp

Dockerfile
@@ -0,0 +1,83 @@
# AudioCraft Studio Dockerfile for RunPod
# Optimized for NVIDIA RTX 4090 (24GB VRAM)
FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04
# Set environment variables
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
ENV PIP_NO_CACHE_DIR=1
ENV PIP_DISABLE_PIP_VERSION_CHECK=1
# CUDA settings
ENV CUDA_HOME=/usr/local/cuda
ENV PATH="${CUDA_HOME}/bin:${PATH}"
ENV LD_LIBRARY_PATH="${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}"
# AudioCraft settings
ENV AUDIOCRAFT_OUTPUT_DIR=/workspace/outputs
ENV AUDIOCRAFT_DATA_DIR=/workspace/data
ENV AUDIOCRAFT_MODEL_CACHE=/workspace/models
ENV AUDIOCRAFT_HOST=0.0.0.0
ENV AUDIOCRAFT_GRADIO_PORT=7860
ENV AUDIOCRAFT_API_PORT=8000
# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
git \
curl \
wget \
ffmpeg \
libsndfile1 \
libsox-dev \
sox \
build-essential \
python3.10 \
python3.10-venv \
python3.10-dev \
python3-pip \
&& rm -rf /var/lib/apt/lists/*
# Set Python 3.10 as default
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1
# Upgrade pip
RUN pip install --upgrade pip setuptools wheel
# Create workspace directory
WORKDIR /workspace
# Create necessary directories
RUN mkdir -p /workspace/outputs /workspace/data /workspace/models /workspace/app
# Copy requirements first for caching
COPY requirements.txt /workspace/app/
WORKDIR /workspace/app
# Install PyTorch with CUDA support
RUN pip install torch==2.1.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
# Install other requirements
RUN pip install -r requirements.txt
# Install AudioCraft from source for latest features
RUN pip install git+https://github.com/facebookresearch/audiocraft.git
# Copy application code
COPY . /workspace/app/
# Create non-root user for security (optional, RunPod often uses root)
# RUN useradd -m -u 1000 audiocraft && chown -R audiocraft:audiocraft /workspace
# USER audiocraft
# Expose ports
EXPOSE 7860 8000
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
CMD curl -f http://localhost:7860/ || exit 1
# Default command
CMD ["python", "main.py"]

README.md
@@ -0,0 +1,197 @@
# AudioCraft Studio
A comprehensive web interface for Meta's AudioCraft AI audio generation models, optimized for RunPod deployment with NVIDIA RTX 4090 GPUs.
## Features
### Models Supported
- **MusicGen** - Text-to-music generation with melody conditioning
- **AudioGen** - Text-to-sound effects and environmental audio
- **MAGNeT** - Fast non-autoregressive music generation
- **MusicGen Style** - Style-conditioned music from reference audio
- **JASCO** - Chord and drum-conditioned music generation
### Core Capabilities
- **Gradio Web UI** - Intuitive interface with real-time generation
- **REST API** - Full-featured API with OpenAPI documentation
- **Batch Processing** - Queue system for multiple generations
- **Project Management** - Organize and browse generation history
- **VRAM Management** - Smart model loading/unloading, ComfyUI coexistence
- **Waveform Visualization** - Visual audio feedback
## Quick Start
### Local Development
```bash
# Clone repository
git clone https://github.com/your-username/audiocraft-ui.git
cd audiocraft-ui
# Create virtual environment
python -m venv venv
source venv/bin/activate # Linux/Mac
# or: venv\Scripts\activate # Windows
# Install dependencies
pip install -r requirements.txt
# Run application
python main.py
```
Access the UI at `http://localhost:7860`
### Docker
```bash
# Build and run
docker-compose up --build
# Or build manually
docker build -t audiocraft-studio .
docker run --gpus all -p 7860:7860 -p 8000:8000 audiocraft-studio
```
### RunPod Deployment
1. Build and push Docker image:
```bash
docker build -t your-dockerhub/audiocraft-studio:latest .
docker push your-dockerhub/audiocraft-studio:latest
```
2. Create RunPod template using `runpod.yaml` as reference
3. Deploy with RTX 4090 or equivalent GPU
## Configuration
Configuration via environment variables:
| Variable | Default | Description |
|----------|---------|-------------|
| `AUDIOCRAFT_HOST` | `0.0.0.0` | Server bind address |
| `AUDIOCRAFT_GRADIO_PORT` | `7860` | Gradio UI port |
| `AUDIOCRAFT_API_PORT` | `8000` | REST API port |
| `AUDIOCRAFT_OUTPUT_DIR` | `./outputs` | Generated audio output |
| `AUDIOCRAFT_DATA_DIR` | `./data` | Database and config |
| `AUDIOCRAFT_COMFYUI_RESERVE_GB` | `10` | VRAM reserved for ComfyUI |
| `AUDIOCRAFT_MAX_CACHED_MODELS` | `2` | Max models in memory |
| `AUDIOCRAFT_IDLE_UNLOAD_MINUTES` | `15` | Auto-unload idle models |
See `.env.example` for full configuration options.
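The resolution order is the standard pydantic-settings one: an `AUDIOCRAFT_`-prefixed environment variable overrides the default declared in `config/settings.py`. A minimal stdlib sketch of that convention (the helper and the defaults dict here are illustrative, not the app's actual code):

```python
import os

# Built-in defaults, mirroring a few entries from the table above.
DEFAULTS = {"gradio_port": 7860, "api_port": 8000, "comfyui_reserve_gb": 10.0}

def resolve_setting(name: str, cast=str):
    """Return AUDIOCRAFT_<NAME> from the environment if set, else the default."""
    raw = os.environ.get(f"AUDIOCRAFT_{name.upper()}")
    if raw is not None:
        return cast(raw)
    return DEFAULTS[name]

os.environ["AUDIOCRAFT_API_PORT"] = "8080"   # e.g. exported in the shell or .env
print(resolve_setting("api_port", int))      # environment wins: 8080
print(resolve_setting("gradio_port", int))   # falls back to the default: 7860
```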
## API Usage
### Authentication
```bash
# Get API key from Settings page or generate via CLI
curl -X POST http://localhost:8000/api/v1/system/api-key/regenerate \
-H "X-API-Key: YOUR_CURRENT_KEY"
```
### Generate Audio
```bash
# Synchronous generation
curl -X POST http://localhost:8000/api/v1/generate \
-H "X-API-Key: YOUR_API_KEY" \
-H "Content-Type: application/json" \
-d '{
"model": "musicgen",
"variant": "medium",
"prompts": ["upbeat electronic dance music with synth leads"],
"duration": 10
}'
# Async (queue) generation
curl -X POST http://localhost:8000/api/v1/generate/async \
-H "X-API-Key: YOUR_API_KEY" \
-H "Content-Type: application/json" \
-d '{
"request": {
"model": "musicgen",
"prompts": ["ambient soundscape"],
"duration": 30
},
"priority": 5
}'
```
### Check Job Status
```bash
curl http://localhost:8000/api/v1/generate/jobs/{job_id} \
-H "X-API-Key: YOUR_API_KEY"
```
Full API documentation is available at `http://localhost:8000/api/docs`.
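The same endpoints can be called from Python with nothing beyond the standard library; this sketch mirrors the synchronous curl example above (the `build_generate_request` helper is illustrative, not part of the project):

```python
import json
import urllib.request

API_BASE = "http://localhost:8000/api/v1"  # adjust host/port for your deployment

def build_generate_request(api_key: str, prompt: str, duration: int = 10,
                           model: str = "musicgen", variant: str = "medium"):
    """Build a POST /api/v1/generate request matching the curl example."""
    body = json.dumps({
        "model": model,
        "variant": variant,
        "prompts": [prompt],
        "duration": duration,
    }).encode()
    return urllib.request.Request(
        f"{API_BASE}/generate",
        data=body,
        headers={"X-API-Key": api_key, "Content-Type": "application/json"},
        method="POST",
    )

req = build_generate_request("YOUR_API_KEY", "ambient soundscape", duration=30)
# urllib.request.urlopen(req) would submit the generation to a running server
print(req.full_url, req.get_method())
```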
## Architecture
```
audiocraft-ui/
├── config/
│ ├── settings.py # Pydantic settings
│ └── models.yaml # Model registry
├── src/
│ ├── core/
│ │ ├── base_model.py # Abstract model interface
│ │ ├── gpu_manager.py # VRAM management
│ │ ├── model_registry.py # Model loading/caching
│ │ └── oom_handler.py # OOM recovery
│ ├── models/
│ │ ├── musicgen/ # MusicGen adapter
│ │ ├── audiogen/ # AudioGen adapter
│ │ ├── magnet/ # MAGNeT adapter
│ │ ├── musicgen_style/ # Style adapter
│ │ └── jasco/ # JASCO adapter
│ ├── services/
│ │ ├── generation_service.py
│ │ ├── batch_processor.py
│ │ └── project_service.py
│ ├── storage/
│ │ └── database.py # SQLite storage
│ ├── api/
│ │ ├── app.py # FastAPI app
│ │ └── routes/ # API endpoints
│ └── ui/
│ ├── app.py # Gradio app
│ ├── components/ # Reusable UI components
│ ├── tabs/ # Model generation tabs
│ └── pages/ # Projects, Settings
├── main.py # Entry point
├── Dockerfile
└── docker-compose.yml
```
## ComfyUI Coexistence
AudioCraft Studio is designed to run alongside ComfyUI on the same GPU:
1. Set `AUDIOCRAFT_COMFYUI_RESERVE_GB` to reserve VRAM for ComfyUI
2. Models are automatically unloaded when idle
3. Coordination file at `/tmp/audiocraft_comfyui_coord.json` prevents conflicts
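The two settings combine into a simple budget: total VRAM minus the ComfyUI reserve minus the safety buffer is what AudioCraft may spend on loaded models. A sketch of that arithmetic (the real accounting lives in `src/core/gpu_manager.py` and is not reproduced here):

```python
def audiocraft_budget_mb(total_vram_gb: float,
                         comfyui_reserve_gb: float = 10.0,
                         safety_buffer_gb: float = 1.0) -> int:
    """VRAM in MB left for AudioCraft models after subtracting the
    ComfyUI reserve and the OOM safety buffer (defaults as in .env.example)."""
    budget_gb = total_vram_gb - comfyui_reserve_gb - safety_buffer_gb
    return max(0, int(budget_gb * 1024))

# RTX 4090 (24 GB) with the defaults: 24 - 10 - 1 = 13 GB for AudioCraft,
# enough for e.g. musicgen-large (10000 MB per config/models.yaml).
print(audiocraft_budget_mb(24.0))  # 13312
```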
## Development
```bash
# Install dev dependencies
pip install -r requirements-dev.txt
# Run tests
pytest
# Format code
black src/ config/
ruff check src/ config/
# Type checking
mypy src/
```
## License
This project uses Meta's AudioCraft library. See [AudioCraft License](https://github.com/facebookresearch/audiocraft/blob/main/LICENSE).

config/__init__.py
@@ -0,0 +1,5 @@
"""Configuration module for AudioCraft Studio."""
from config.settings import Settings, get_settings
__all__ = ["Settings", "get_settings"]

config/models.yaml
@@ -0,0 +1,151 @@
# AudioCraft Model Registry Configuration
# This file defines all available models and their configurations
models:
musicgen:
enabled: true
display_name: "MusicGen"
description: "Text-to-music generation with optional melody conditioning"
default_variant: medium
variants:
small:
hf_id: facebook/musicgen-small
vram_mb: 1500
max_duration: 30
description: "Fast, lightweight model (300M params)"
medium:
hf_id: facebook/musicgen-medium
vram_mb: 5000
max_duration: 30
description: "Balanced quality and speed (1.5B params)"
large:
hf_id: facebook/musicgen-large
vram_mb: 10000
max_duration: 30
description: "Highest quality, slower (3.3B params)"
melody:
hf_id: facebook/musicgen-melody
vram_mb: 5000
max_duration: 30
conditioning:
- melody
description: "Melody-conditioned generation (1.5B params)"
stereo-small:
hf_id: facebook/musicgen-stereo-small
vram_mb: 1800
max_duration: 30
channels: 2
description: "Stereo output, fast (300M params)"
stereo-medium:
hf_id: facebook/musicgen-stereo-medium
vram_mb: 6000
max_duration: 30
channels: 2
description: "Stereo output, balanced (1.5B params)"
stereo-large:
hf_id: facebook/musicgen-stereo-large
vram_mb: 12000
max_duration: 30
channels: 2
description: "Stereo output, highest quality (3.3B params)"
stereo-melody:
hf_id: facebook/musicgen-stereo-melody
vram_mb: 6000
max_duration: 30
channels: 2
conditioning:
- melody
description: "Stereo melody-conditioned (1.5B params)"
audiogen:
enabled: true
display_name: "AudioGen"
description: "Text-to-sound effects generation"
default_variant: medium
variants:
medium:
hf_id: facebook/audiogen-medium
vram_mb: 5000
max_duration: 10
description: "Sound effects generator (1.5B params)"
magnet:
enabled: true
display_name: "MAGNeT"
description: "Fast non-autoregressive music generation"
default_variant: medium-10secs
variants:
small-10secs:
hf_id: facebook/magnet-small-10secs
vram_mb: 1500
max_duration: 10
description: "Fast 10-second clips (300M params)"
medium-10secs:
hf_id: facebook/magnet-medium-10secs
vram_mb: 5000
max_duration: 10
description: "Quality 10-second clips (1.5B params)"
small-30secs:
hf_id: facebook/magnet-small-30secs
vram_mb: 1800
max_duration: 30
description: "Fast 30-second clips (300M params)"
medium-30secs:
hf_id: facebook/magnet-medium-30secs
vram_mb: 6000
max_duration: 30
description: "Quality 30-second clips (1.5B params)"
musicgen-style:
enabled: true
display_name: "MusicGen Style"
description: "Style-conditioned music generation from reference audio"
default_variant: medium
variants:
medium:
hf_id: facebook/musicgen-style
vram_mb: 5000
max_duration: 30
conditioning:
- style
description: "Style transfer from reference audio (1.5B params)"
jasco:
enabled: true
display_name: "JASCO"
description: "Chord and drum-conditioned music generation"
default_variant: chords-drums-400M
variants:
chords-drums-400M:
hf_id: facebook/jasco-chords-drums-400M
vram_mb: 2000
max_duration: 10
conditioning:
- chords
- drums
description: "Chord/drum control, fast (400M params)"
chords-drums-1B:
hf_id: facebook/jasco-chords-drums-1B
vram_mb: 4000
max_duration: 10
conditioning:
- chords
- drums
description: "Chord/drum control, higher quality (1B params)"
# Default generation parameters
defaults:
generation:
duration: 10
temperature: 1.0
top_k: 250
top_p: 0.0
cfg_coef: 3.0
# VRAM thresholds for warnings
vram:
warning_threshold: 0.85 # 85% utilization warning
critical_threshold: 0.95 # 95% utilization critical
# Presets are loaded from data/presets/*.yaml
presets_dir: "./data/presets"

config/settings.py
@@ -0,0 +1,94 @@
"""Application settings with environment variable support."""
from functools import lru_cache
from pathlib import Path
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""Application configuration with environment variable support.
All settings can be overridden via environment variables prefixed with AUDIOCRAFT_.
Example: AUDIOCRAFT_API_PORT=8080
"""
model_config = SettingsConfigDict(
env_prefix="AUDIOCRAFT_",
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
)
# Server Configuration
host: str = Field(default="0.0.0.0", description="Server bind host")
gradio_port: int = Field(default=7860, description="Gradio UI port")
api_port: int = Field(default=8000, description="FastAPI port")
# Paths
data_dir: Path = Field(default=Path("./data"), description="Data directory")
output_dir: Path = Field(default=Path("./outputs"), description="Generated audio output")
cache_dir: Path = Field(default=Path("./cache"), description="Model cache directory")
models_config: Path = Field(
default=Path("./config/models.yaml"), description="Model registry config"
)
# VRAM Management
comfyui_reserve_gb: float = Field(
default=10.0, description="VRAM reserved for ComfyUI (GB)"
)
safety_buffer_gb: float = Field(
default=1.0, description="Safety buffer to prevent OOM (GB)"
)
idle_unload_minutes: int = Field(
default=15, description="Unload models after idle time (minutes)"
)
max_cached_models: int = Field(
default=2, description="Maximum number of models to keep loaded"
)
# API Authentication
    api_key: Optional[str] = Field(default=None, description="API key for authentication")
    api_enabled: bool = Field(default=True, description="Expose the REST API and OpenAPI docs")
cors_origins: list[str] = Field(
default=["*"], description="Allowed CORS origins"
)
# Generation Defaults
default_duration: float = Field(default=10.0, description="Default generation duration")
max_duration: float = Field(default=300.0, description="Maximum generation duration")
default_batch_size: int = Field(default=1, description="Default batch size")
max_batch_size: int = Field(default=8, description="Maximum batch size")
max_queue_size: int = Field(default=100, description="Maximum generation queue size")
# Database
database_url: str = Field(
default="sqlite+aiosqlite:///./data/audiocraft.db",
description="Database connection URL",
)
# Logging
log_level: str = Field(default="INFO", description="Logging level")
def ensure_directories(self) -> None:
"""Create required directories if they don't exist."""
self.data_dir.mkdir(parents=True, exist_ok=True)
self.output_dir.mkdir(parents=True, exist_ok=True)
self.cache_dir.mkdir(parents=True, exist_ok=True)
(self.data_dir / "presets").mkdir(parents=True, exist_ok=True)
@property
def database_path(self) -> Path:
"""Extract database file path from URL."""
if self.database_url.startswith("sqlite"):
# Handle both sqlite:/// and sqlite+aiosqlite:///
path = self.database_url.split("///")[-1]
return Path(path)
raise ValueError("Only SQLite databases are supported")
@lru_cache
def get_settings() -> Settings:
"""Get cached settings instance."""
return Settings()

docker-compose.yml
@@ -0,0 +1,64 @@
# Docker Compose for local development and testing
# For RunPod deployment, use the Dockerfile directly
version: '3.8'
services:
audiocraft:
build:
context: .
dockerfile: Dockerfile
container_name: audiocraft-studio
ports:
- "7860:7860" # Gradio UI
- "8000:8000" # REST API
volumes:
# Persistent storage
- audiocraft-outputs:/workspace/outputs
- audiocraft-data:/workspace/data
- audiocraft-models:/workspace/models
# Development: mount source code
- ./src:/workspace/app/src:ro
- ./config:/workspace/app/config:ro
environment:
- AUDIOCRAFT_HOST=0.0.0.0
- AUDIOCRAFT_GRADIO_PORT=7860
- AUDIOCRAFT_API_PORT=8000
- AUDIOCRAFT_DEBUG=false
- AUDIOCRAFT_COMFYUI_RESERVE_GB=0 # No ComfyUI in this compose
- NVIDIA_VISIBLE_DEVICES=all
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
restart: unless-stopped
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:7860/"]
interval: 30s
timeout: 10s
retries: 3
start_period: 60s
# Optional: Run alongside ComfyUI
# comfyui:
# image: your-comfyui-image
# container_name: comfyui
# ports:
# - "8188:8188"
# volumes:
# - comfyui-data:/workspace
# deploy:
# resources:
# reservations:
# devices:
# - driver: nvidia
# count: 1
# capabilities: [gpu]
volumes:
audiocraft-outputs:
audiocraft-data:
audiocraft-models:

main.py
@@ -0,0 +1,147 @@
#!/usr/bin/env python3
"""Main entry point for AudioCraft Studio."""
import asyncio
import logging
import sys
from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent))
from config.settings import get_settings
from src.core.gpu_manager import GPUMemoryManager
from src.core.model_registry import ModelRegistry
from src.services.generation_service import GenerationService
from src.services.batch_processor import BatchProcessor
from src.services.project_service import ProjectService
from src.storage.database import Database
from src.ui.app import create_app
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[
logging.StreamHandler(),
logging.FileHandler("audiocraft.log"),
],
)
logger = logging.getLogger(__name__)
async def initialize_services():
"""Initialize all application services."""
settings = get_settings()
# Initialize database
logger.info("Initializing database...")
db = Database(settings.database_path)
await db.initialize()
# Initialize GPU manager
logger.info("Initializing GPU manager...")
gpu_manager = GPUMemoryManager(
device_id=0,
comfyui_reserve_bytes=int(settings.comfyui_reserve_gb * 1024**3),
)
# Initialize model registry
logger.info("Initializing model registry...")
model_registry = ModelRegistry(
gpu_manager=gpu_manager,
        max_loaded=settings.max_cached_models,
idle_timeout_seconds=settings.idle_unload_minutes * 60,
)
# Initialize services
logger.info("Initializing services...")
generation_service = GenerationService(
model_registry=model_registry,
gpu_manager=gpu_manager,
output_dir=settings.output_dir,
)
batch_processor = BatchProcessor(
generation_service=generation_service,
max_queue_size=settings.max_queue_size,
)
project_service = ProjectService(
db=db,
output_dir=settings.output_dir,
)
return {
"db": db,
"gpu_manager": gpu_manager,
"model_registry": model_registry,
"generation_service": generation_service,
"batch_processor": batch_processor,
"project_service": project_service,
}
def main():
"""Main entry point."""
settings = get_settings()
logger.info("=" * 60)
logger.info("AudioCraft Studio")
logger.info("=" * 60)
logger.info(f"Host: {settings.host}")
logger.info(f"Gradio Port: {settings.gradio_port}")
logger.info(f"API Port: {settings.api_port}")
logger.info(f"Output Dir: {settings.output_dir}")
logger.info("=" * 60)
# Initialize services
logger.info("Initializing services...")
try:
services = asyncio.run(initialize_services())
except Exception as e:
logger.error(f"Failed to initialize services: {e}")
logger.warning("Starting in demo mode without backend services")
services = {}
# Create and launch app
logger.info("Creating Gradio application...")
app = create_app(
generation_service=services.get("generation_service"),
batch_processor=services.get("batch_processor"),
project_service=services.get("project_service"),
gpu_manager=services.get("gpu_manager"),
model_registry=services.get("model_registry"),
)
# Start batch processor if available
batch_processor = services.get("batch_processor")
if batch_processor:
logger.info("Starting batch processor...")
asyncio.run(batch_processor.start())
# Launch the app
logger.info("Launching application...")
try:
app.launch(
server_name=settings.host,
server_port=settings.gradio_port,
share=False,
show_api=settings.api_enabled,
)
except KeyboardInterrupt:
logger.info("Shutting down...")
finally:
# Cleanup
if batch_processor:
asyncio.run(batch_processor.stop())
if "db" in services:
asyncio.run(services["db"].close())
logger.info("Shutdown complete")
if __name__ == "__main__":
main()

pyproject.toml
@@ -0,0 +1,89 @@
[project]
name = "audiocraft-ui"
version = "0.1.0"
description = "Web interface for Meta's AudioCraft AI audio generation models"
readme = "README.md"
license = { text = "MIT" }
requires-python = ">=3.10"
authors = [{ name = "AudioCraft UI Team" }]
keywords = ["audio", "music", "generation", "ai", "audiocraft", "gradio"]
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Topic :: Multimedia :: Sound/Audio",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [
# Core ML
"torch>=2.1.0",
"torchaudio>=2.1.0",
"audiocraft>=1.3.0",
"xformers>=0.0.22",
# UI
"gradio>=4.0.0",
# API
"fastapi>=0.104.0",
"uvicorn[standard]>=0.24.0",
"python-multipart>=0.0.6",
# GPU Monitoring
"pynvml>=11.5.0",
# Storage
"aiosqlite>=0.19.0",
# Configuration
"pydantic>=2.5.0",
"pydantic-settings>=2.1.0",
"pyyaml>=6.0",
# Audio Processing
"numpy>=1.24.0",
"scipy>=1.11.0",
"librosa>=0.10.0",
"soundfile>=0.12.0",
]
[project.optional-dependencies]
dev = [
"pytest>=7.4.0",
"pytest-asyncio>=0.21.0",
"pytest-cov>=4.1.0",
"ruff>=0.1.0",
"mypy>=1.6.0",
]
[project.scripts]
audiocraft-ui = "main:main"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src"]
[tool.ruff]
line-length = 100
target-version = "py310"
[tool.ruff.lint]
select = ["E", "F", "I", "N", "W", "UP"]
ignore = ["E501"]
[tool.mypy]
python_version = "3.10"
warn_return_any = true
warn_unused_configs = true
ignore_missing_imports = true
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]

requirements.txt
@@ -0,0 +1,30 @@
# Core ML
torch>=2.1.0
torchaudio>=2.1.0
audiocraft>=1.3.0
xformers>=0.0.22
# UI
gradio>=4.0.0
# API
fastapi>=0.104.0
uvicorn[standard]>=0.24.0
python-multipart>=0.0.6
# GPU Monitoring
pynvml>=11.5.0
# Storage
aiosqlite>=0.19.0
# Configuration
pydantic>=2.5.0
pydantic-settings>=2.1.0
pyyaml>=6.0
# Audio Processing
numpy>=1.24.0
scipy>=1.11.0
librosa>=0.10.0
soundfile>=0.12.0

runpod.yaml
@@ -0,0 +1,77 @@
# RunPod Template Configuration
# Use this as reference when creating a RunPod template
name: AudioCraft Studio
description: AI-powered music and sound generation using Meta's AudioCraft
# Container settings
container:
image: your-dockerhub-username/audiocraft-studio:latest
# Or build from GitHub
# dockerfile: Dockerfile
# context: https://github.com/your-username/audiocraft-ui.git
# GPU requirements
gpu:
type: RTX 4090 # Recommended: RTX 4090, RTX 3090, A100
count: 1
minVram: 24 # GB
# Resource limits
resources:
cpu: 8
memory: 32 # GB
disk: 100 # GB (for model cache and outputs)
# Port mappings
ports:
- name: Gradio UI
internal: 7860
external: 7860
protocol: http
- name: REST API
internal: 8000
external: 8000
protocol: http
# Volume mounts
volumes:
- name: outputs
mountPath: /workspace/outputs
size: 50 # GB
- name: models
mountPath: /workspace/models
size: 30 # GB (model cache)
- name: data
mountPath: /workspace/data
size: 10 # GB
# Environment variables
env:
- name: AUDIOCRAFT_HOST
value: "0.0.0.0"
- name: AUDIOCRAFT_GRADIO_PORT
value: "7860"
- name: AUDIOCRAFT_API_PORT
value: "8000"
- name: AUDIOCRAFT_COMFYUI_RESERVE_GB
value: "10" # Reserve VRAM for ComfyUI if running alongside
  - name: AUDIOCRAFT_MAX_CACHED_MODELS
value: "2"
- name: AUDIOCRAFT_IDLE_UNLOAD_MINUTES
value: "15"
- name: HF_HOME
value: "/workspace/models/huggingface"
# Startup command
command: ["python", "main.py"]
# Health check
healthCheck:
path: /
port: 7860
initialDelaySeconds: 120
periodSeconds: 30
timeoutSeconds: 10
failureThreshold: 3

scripts/download_models.py (executable)
@@ -0,0 +1,116 @@
#!/usr/bin/env python3
"""Pre-download AudioCraft models for faster startup."""
import argparse
import os
from pathlib import Path
def download_musicgen_models(variants: list[str] | None = None):
"""Download MusicGen models."""
from audiocraft.models import MusicGen
variants = variants or ["small", "medium", "large", "melody"]
for variant in variants:
print(f"Downloading MusicGen {variant}...")
try:
model = MusicGen.get_pretrained(f"facebook/musicgen-{variant}")
del model
print(f" ✓ MusicGen {variant} downloaded")
except Exception as e:
print(f" ✗ Failed to download MusicGen {variant}: {e}")
def download_audiogen_models():
"""Download AudioGen models."""
from audiocraft.models import AudioGen
print("Downloading AudioGen medium...")
try:
model = AudioGen.get_pretrained("facebook/audiogen-medium")
del model
print(" ✓ AudioGen medium downloaded")
except Exception as e:
print(f" ✗ Failed to download AudioGen: {e}")
def download_magnet_models(variants: list[str] | None = None):
    """Download MAGNeT models."""
    from audiocraft.models import MAGNeT
    variants = variants or ["small-10secs", "medium-10secs", "small-30secs", "medium-30secs"]
for variant in variants:
print(f"Downloading MAGNeT {variant}...")
try:
model = MAGNeT.get_pretrained(f"facebook/magnet-{variant}")
del model
print(f" ✓ MAGNeT {variant} downloaded")
except Exception as e:
print(f" ✗ Failed to download MAGNeT {variant}: {e}")
def main():
parser = argparse.ArgumentParser(description="Pre-download AudioCraft models")
parser.add_argument(
"--models",
nargs="+",
choices=["musicgen", "audiogen", "magnet", "all"],
default=["all"],
help="Models to download",
)
parser.add_argument(
"--musicgen-variants",
nargs="+",
default=["small", "medium"],
help="MusicGen variants to download",
)
parser.add_argument(
"--magnet-variants",
nargs="+",
        default=["small-10secs", "medium-10secs"],
help="MAGNeT variants to download",
)
parser.add_argument(
"--cache-dir",
type=str,
default=None,
help="Model cache directory",
)
args = parser.parse_args()
# Set cache directory
if args.cache_dir:
os.environ["HF_HOME"] = args.cache_dir
os.environ["TORCH_HOME"] = args.cache_dir
Path(args.cache_dir).mkdir(parents=True, exist_ok=True)
models = args.models
if "all" in models:
models = ["musicgen", "audiogen", "magnet"]
print("=" * 50)
print("AudioCraft Model Downloader")
print("=" * 50)
print(f"Cache directory: {os.environ.get('HF_HOME', 'default')}")
print(f"Models to download: {models}")
print("=" * 50)
if "musicgen" in models:
download_musicgen_models(args.musicgen_variants)
if "audiogen" in models:
download_audiogen_models()
if "magnet" in models:
download_magnet_models(args.magnet_variants)
print("=" * 50)
print("Download complete!")
print("=" * 50)
if __name__ == "__main__":
main()

scripts/start.sh (executable)
@@ -0,0 +1,55 @@
#!/bin/bash
# Startup script for AudioCraft Studio
# Used in Docker container and RunPod
set -e
echo "=========================================="
echo " AudioCraft Studio"
echo "=========================================="
# Create directories if they don't exist
mkdir -p "${AUDIOCRAFT_OUTPUT_DIR:-/workspace/outputs}"
mkdir -p "${AUDIOCRAFT_DATA_DIR:-/workspace/data}"
mkdir -p "${AUDIOCRAFT_MODEL_CACHE:-/workspace/models}"
# Check GPU availability
echo "Checking GPU..."
if command -v nvidia-smi &> /dev/null; then
nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv
else
echo "Warning: nvidia-smi not found"
fi
# Check Python and dependencies
echo "Python version:"
python --version
echo "PyTorch version:"
python -c "import torch; print(f'PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}')"
# Check AudioCraft installation
echo "AudioCraft version:"
python -c "import audiocraft; print(audiocraft.__version__)" 2>/dev/null || echo "AudioCraft installed from source"
# Generate API key if not exists
if [ ! -f "${AUDIOCRAFT_DATA_DIR:-/workspace/data}/.api_key" ]; then
echo "Generating API key..."
python -c "
from src.api.auth import get_key_manager
km = get_key_manager()
if not km.has_key():
key = km.generate_new_key()
print(f'Generated API key: {key}')
print('Store this key securely - it will not be shown again!')
"
fi
# Start the application
echo "Starting AudioCraft Studio..."
echo "Gradio UI: http://0.0.0.0:${AUDIOCRAFT_GRADIO_PORT:-7860}"
echo "REST API: http://0.0.0.0:${AUDIOCRAFT_API_PORT:-8000}"
echo "API Docs: http://0.0.0.0:${AUDIOCRAFT_API_PORT:-8000}/api/docs"
echo "=========================================="
exec python main.py "$@"

src/__init__.py
@@ -0,0 +1,3 @@
"""AudioCraft Studio - AI Audio Generation Web Application."""
__version__ = "0.1.0"

src/api/__init__.py
@@ -0,0 +1,5 @@
"""REST API for AudioCraft Studio."""
from src.api.app import create_api_app
__all__ = ["create_api_app"]

src/api/app.py
@@ -0,0 +1,150 @@
"""FastAPI application for AudioCraft Studio REST API."""
from typing import Any, Optional
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import time
from config.settings import get_settings
from src.api.routes import (
generation_router,
projects_router,
models_router,
system_router,
)
from src.api.routes.generation import set_services as set_generation_services
from src.api.routes.projects import set_services as set_project_services
from src.api.routes.models import set_services as set_model_services
from src.api.routes.system import set_services as set_system_services
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan handler."""
# Startup
yield
# Shutdown
def create_api_app(
generation_service: Any = None,
batch_processor: Any = None,
project_service: Any = None,
gpu_manager: Any = None,
model_registry: Any = None,
) -> FastAPI:
"""Create and configure the FastAPI application.
Args:
generation_service: Service for handling generations
batch_processor: Service for batch/queue processing
project_service: Service for project management
gpu_manager: GPU memory manager
model_registry: Model registry for loading/unloading
Returns:
Configured FastAPI application
"""
settings = get_settings()
app = FastAPI(
title="AudioCraft Studio API",
description="REST API for AI-powered music and sound generation",
version="1.0.0",
docs_url="/api/docs" if settings.api_enabled else None,
redoc_url="/api/redoc" if settings.api_enabled else None,
openapi_url="/api/openapi.json" if settings.api_enabled else None,
lifespan=lifespan,
)
# CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Request timing middleware
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
start_time = time.time()
response = await call_next(request)
process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time)
return response
# Global exception handler
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
return JSONResponse(
status_code=500,
content={
"error": "Internal server error",
"detail": str(exc) if settings.debug else "An unexpected error occurred",
},
)
# Inject service dependencies
set_generation_services(generation_service, batch_processor)
set_project_services(project_service)
set_model_services(model_registry)
set_system_services(gpu_manager, batch_processor, model_registry)
# Register routers
app.include_router(generation_router, prefix="/api/v1")
app.include_router(projects_router, prefix="/api/v1")
app.include_router(models_router, prefix="/api/v1")
app.include_router(system_router, prefix="/api/v1")
# Root endpoint
@app.get("/")
async def root():
return {
"name": "AudioCraft Studio API",
"version": "1.0.0",
"docs": "/api/docs",
}
# API info endpoint
@app.get("/api/v1")
async def api_info():
return {
"version": "1.0.0",
"endpoints": {
"generation": "/api/v1/generate",
"projects": "/api/v1/projects",
"models": "/api/v1/models",
"system": "/api/v1/system",
},
}
return app
def run_api_server(
app: FastAPI,
host: Optional[str] = None,
port: Optional[int] = None,
) -> None:
"""Run the API server.
Args:
app: FastAPI application
host: Server hostname
port: Server port
"""
import uvicorn
settings = get_settings()
uvicorn.run(
app,
host=host or settings.host,
port=port or settings.api_port,
log_level="info",
)
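The `add_process_time_header` middleware above can be exercised in isolation. This sketch swaps Starlette's `Request`/`Response` objects for plain dicts (an assumption for illustration only); the timing logic is the same.

```python
import asyncio
import time

async def add_process_time_header(request, call_next):
    # Same pattern as the app middleware: time the downstream handler
    # and expose the elapsed seconds as an X-Process-Time header.
    start_time = time.time()
    response = await call_next(request)
    response["headers"]["X-Process-Time"] = str(time.time() - start_time)
    return response

async def fake_handler(request):
    # Stand-in for call_next: pretend the route took ~10 ms.
    await asyncio.sleep(0.01)
    return {"headers": {}, "body": "ok"}

response = asyncio.run(add_process_time_header({"path": "/"}, fake_handler))
print(response["headers"]["X-Process-Time"])
```

In the real app, Starlette runs this wrapper around every request, so clients can read per-request latency from the response headers.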

133
src/api/auth.py Normal file
View File

@@ -0,0 +1,133 @@
"""API authentication middleware."""
import secrets
import hashlib
from typing import Optional
from pathlib import Path
from fastapi import HTTPException, Security, status
from fastapi.security import APIKeyHeader
from config.settings import get_settings
# API key header
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
def generate_api_key() -> str:
"""Generate a new API key."""
return secrets.token_urlsafe(32)
def hash_api_key(key: str) -> str:
"""Hash an API key for storage."""
return hashlib.sha256(key.encode()).hexdigest()
def verify_api_key(key: str, hashed: str) -> bool:
"""Verify an API key against its hash."""
return secrets.compare_digest(hash_api_key(key), hashed)
class APIKeyManager:
"""Manage API keys for authentication."""
def __init__(self, key_file: Optional[Path] = None):
"""Initialize the key manager.
Args:
key_file: Path to store API key hash
"""
self.settings = get_settings()
self.key_file = key_file or Path(self.settings.data_dir) / ".api_key"
self._key_hash: Optional[str] = None
self._load_key()
def _load_key(self) -> None:
"""Load API key hash from file."""
if self.key_file.exists():
self._key_hash = self.key_file.read_text().strip()
def _save_key(self, key_hash: str) -> None:
"""Save API key hash to file."""
self.key_file.parent.mkdir(parents=True, exist_ok=True)
self.key_file.write_text(key_hash)
self._key_hash = key_hash
def generate_new_key(self) -> str:
"""Generate and store a new API key.
Returns:
The new API key (only shown once)
"""
key = generate_api_key()
self._save_key(hash_api_key(key))
return key
def verify(self, key: str) -> bool:
"""Verify an API key.
Args:
key: API key to verify
Returns:
True if valid, False otherwise
"""
if not self._key_hash:
return False
return verify_api_key(key, self._key_hash)
def has_key(self) -> bool:
"""Check if an API key has been generated."""
return self._key_hash is not None
# Global key manager instance
_key_manager: Optional[APIKeyManager] = None
def get_key_manager() -> APIKeyManager:
"""Get the global key manager instance."""
global _key_manager
if _key_manager is None:
_key_manager = APIKeyManager()
return _key_manager
async def verify_api_key_dependency(
api_key: Optional[str] = Security(api_key_header),
) -> str:
"""FastAPI dependency to verify API key.
Args:
api_key: API key from header
Returns:
The verified API key
Raises:
HTTPException: If key is missing or invalid
"""
settings = get_settings()
# Skip auth if disabled
if not settings.api_key_required:
return "anonymous"
if api_key is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="API key required",
headers={"WWW-Authenticate": "ApiKey"},
)
key_manager = get_key_manager()
if not key_manager.verify(api_key):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid API key",
)
return api_key
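The key lifecycle in `auth.py` — generate once, persist only the SHA-256 hash, verify in constant time — can be exercised standalone with the same stdlib calls:

```python
import hashlib
import secrets

def generate_api_key() -> str:
    # 32 random bytes, URL-safe; shown to the user exactly once.
    return secrets.token_urlsafe(32)

def hash_api_key(key: str) -> str:
    # Only this digest is persisted (the .api_key file stores it).
    return hashlib.sha256(key.encode()).hexdigest()

def verify_api_key(key: str, hashed: str) -> bool:
    # compare_digest avoids timing side channels on the comparison.
    return secrets.compare_digest(hash_api_key(key), hashed)

key = generate_api_key()
stored = hash_api_key(key)
print(verify_api_key(key, stored))          # True
print(verify_api_key("wrong-key", stored))  # False
```

Because only the hash is stored, a leaked `.api_key` file does not reveal the key itself; the plaintext exists only in the response from `/system/api-key/regenerate`.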

166
src/api/models.py Normal file
View File

@@ -0,0 +1,166 @@
"""Pydantic models for API requests and responses."""
from datetime import datetime
from typing import Any, Optional
from enum import Enum
from pydantic import BaseModel, Field
class ModelFamily(str, Enum):
"""Available model families."""
MUSICGEN = "musicgen"
AUDIOGEN = "audiogen"
MAGNET = "magnet"
MUSICGEN_STYLE = "musicgen-style"
JASCO = "jasco"
class JobStatus(str, Enum):
"""Generation job status."""
PENDING = "pending"
RUNNING = "running"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
# Generation requests
class GenerationRequest(BaseModel):
"""Request to generate audio."""
model: ModelFamily = Field(..., description="Model family to use")
variant: str = Field("medium", description="Model variant")
prompts: list[str] = Field(..., min_length=1, max_length=10, description="Text prompts")
duration: float = Field(10.0, ge=1, le=30, description="Duration in seconds")
temperature: float = Field(1.0, ge=0, le=2, description="Sampling temperature")
top_k: int = Field(250, ge=0, le=500, description="Top-K sampling")
top_p: float = Field(0.0, ge=0, le=1, description="Top-P (nucleus) sampling")
cfg_coef: float = Field(3.0, ge=1, le=10, description="CFG coefficient")
seed: Optional[int] = Field(None, description="Random seed for reproducibility")
conditioning: Optional[dict[str, Any]] = Field(None, description="Model-specific conditioning")
project_id: Optional[str] = Field(None, description="Project to save to")
class BatchGenerationRequest(BaseModel):
"""Request to add generation to queue."""
request: GenerationRequest
priority: int = Field(0, ge=0, le=10, description="Job priority (higher = sooner)")
# Generation responses
class GenerationResult(BaseModel):
"""Result of a completed generation."""
id: str = Field(..., description="Generation ID")
audio_url: str = Field(..., description="URL to download audio")
waveform_url: Optional[str] = Field(None, description="URL to waveform image")
duration: float = Field(..., description="Actual duration in seconds")
seed: int = Field(..., description="Seed used for generation")
model: str = Field(..., description="Model used")
variant: str = Field(..., description="Variant used")
prompt: str = Field(..., description="Prompt used")
created_at: datetime = Field(..., description="Creation timestamp")
class JobResponse(BaseModel):
"""Response for a queued job."""
job_id: str = Field(..., description="Job ID for tracking")
status: JobStatus = Field(..., description="Current status")
position: Optional[int] = Field(None, description="Queue position if pending")
progress: Optional[float] = Field(None, description="Progress 0-1 if running")
result: Optional[GenerationResult] = Field(None, description="Result if completed")
error: Optional[str] = Field(None, description="Error message if failed")
# Project models
class ProjectCreate(BaseModel):
"""Request to create a project."""
name: str = Field(..., min_length=1, max_length=100)
description: Optional[str] = Field(None, max_length=500)
class ProjectResponse(BaseModel):
"""Project information."""
id: str
name: str
description: Optional[str]
generation_count: int
created_at: datetime
updated_at: datetime
class GenerationResponse(BaseModel):
"""Generation record from database."""
id: str
project_id: str
model: str
variant: str
prompt: str
duration_seconds: float
seed: int
audio_path: str
waveform_path: Optional[str]
parameters: dict[str, Any]
created_at: datetime
# Model info
class ModelVariantInfo(BaseModel):
"""Information about a model variant."""
id: str
name: str
vram_mb: int
description: str
capabilities: list[str]
class ModelInfo(BaseModel):
"""Information about a model family."""
id: str
name: str
description: str
variants: list[ModelVariantInfo]
loaded: bool
current_variant: Optional[str]
# System info
class GPUStatus(BaseModel):
"""GPU memory status."""
device_name: str
total_gb: float
used_gb: float
available_gb: float
utilization_percent: float
temperature_c: Optional[float]
class QueueStatus(BaseModel):
"""Generation queue status."""
queue_size: int
active_jobs: int
completed_today: int
failed_today: int
class SystemStatus(BaseModel):
"""Overall system status."""
gpu: GPUStatus
queue: QueueStatus
loaded_models: list[str]
uptime_seconds: float
# Pagination
class PaginatedResponse(BaseModel):
"""Paginated list response."""
items: list[Any]
total: int
page: int
page_size: int
pages: int
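The `Field` bounds on `GenerationRequest` mean invalid parameters are rejected at the API edge, before any model is touched. A minimal stand-in (redeclared here, since this sketch does not import the app) shows the behavior, assuming Pydantic is installed:

```python
from pydantic import BaseModel, Field, ValidationError

class GenerationRequestSketch(BaseModel):
    # Mirrors a subset of GenerationRequest's constraints.
    prompts: list[str] = Field(..., min_length=1, max_length=10)
    duration: float = Field(10.0, ge=1, le=30)
    temperature: float = Field(1.0, ge=0, le=2)

ok = GenerationRequestSketch(prompts=["lofi hip hop beat"], duration=15.0)
print(ok.duration)  # 15.0

try:
    GenerationRequestSketch(prompts=["x"], duration=120.0)  # violates le=30
except ValidationError as exc:
    print("rejected:", exc.errors()[0]["loc"])
```

FastAPI performs exactly this validation automatically and turns the `ValidationError` into a 422 response with the offending field locations.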

13
src/api/routes/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
"""API route modules."""
from src.api.routes.generation import router as generation_router
from src.api.routes.projects import router as projects_router
from src.api.routes.models import router as models_router
from src.api.routes.system import router as system_router
__all__ = [
"generation_router",
"projects_router",
"models_router",
"system_router",
]

234
src/api/routes/generation.py Normal file
View File

@@ -0,0 +1,234 @@
"""Generation API endpoints."""
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status
from src.api.auth import verify_api_key_dependency
from src.api.models import (
GenerationRequest,
BatchGenerationRequest,
GenerationResult,
JobResponse,
JobStatus,
)
router = APIRouter(prefix="/generate", tags=["generation"])
# Service dependencies (injected at app startup)
_generation_service = None
_batch_processor = None
def set_services(generation_service: Any, batch_processor: Any) -> None:
"""Set service dependencies."""
global _generation_service, _batch_processor
_generation_service = generation_service
_batch_processor = batch_processor
@router.post(
"/",
response_model=GenerationResult,
summary="Generate audio synchronously",
description="Generate audio and wait for completion. For long generations, consider using the async endpoint.",
)
async def generate_sync(
request: GenerationRequest,
api_key: str = Depends(verify_api_key_dependency),
) -> GenerationResult:
"""Generate audio synchronously."""
if _generation_service is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Generation service not available",
)
try:
result, generation = await _generation_service.generate(
model_id=request.model.value,
variant=request.variant,
prompts=request.prompts,
duration=request.duration,
temperature=request.temperature,
top_k=request.top_k,
top_p=request.top_p,
cfg_coef=request.cfg_coef,
seed=request.seed,
conditioning=request.conditioning,
project_id=request.project_id,
)
return GenerationResult(
id=generation.id,
            audio_url=f"/api/v1/system/audio/{generation.id}",
            waveform_url=f"/api/v1/system/audio/{generation.id}/waveform" if generation.waveform_path else None,
duration=result.duration,
seed=result.seed,
model=request.model.value,
variant=request.variant,
prompt=request.prompts[0],
created_at=generation.created_at,
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
@router.post(
"/async",
response_model=JobResponse,
summary="Queue generation job",
description="Add a generation to the queue for async processing.",
)
async def generate_async(
request: BatchGenerationRequest,
api_key: str = Depends(verify_api_key_dependency),
) -> JobResponse:
"""Add generation to queue."""
if _batch_processor is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Batch processor not available",
)
try:
job = _batch_processor.add_job(
model_id=request.request.model.value,
variant=request.request.variant,
prompts=request.request.prompts,
duration=request.request.duration,
temperature=request.request.temperature,
top_k=request.request.top_k,
top_p=request.request.top_p,
cfg_coef=request.request.cfg_coef,
seed=request.request.seed,
conditioning=request.request.conditioning,
project_id=request.request.project_id,
priority=request.priority,
)
return JobResponse(
job_id=job.id,
status=JobStatus.PENDING,
position=_batch_processor.get_position(job.id),
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
@router.get(
"/jobs/{job_id}",
response_model=JobResponse,
summary="Get job status",
description="Check the status of a queued generation job.",
)
async def get_job_status(
job_id: str,
api_key: str = Depends(verify_api_key_dependency),
) -> JobResponse:
"""Get status of a queued job."""
if _batch_processor is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Batch processor not available",
)
job = _batch_processor.get_job(job_id)
if job is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Job {job_id} not found",
)
response = JobResponse(
job_id=job.id,
status=JobStatus(job.status.value),
)
if job.status.value == "pending":
response.position = _batch_processor.get_position(job_id)
elif job.status.value == "running":
response.progress = job.progress
elif job.status.value == "completed" and job.result:
response.result = GenerationResult(
id=job.result.id,
            audio_url=f"/api/v1/system/audio/{job.result.id}",
            waveform_url=f"/api/v1/system/audio/{job.result.id}/waveform",
duration=job.result.duration,
seed=job.result.seed,
model=job.model_id,
variant=job.variant,
prompt=job.prompts[0],
created_at=job.completed_at,
)
elif job.status.value == "failed":
response.error = job.error
return response
@router.delete(
"/jobs/{job_id}",
summary="Cancel job",
description="Cancel a pending or running job.",
)
async def cancel_job(
job_id: str,
api_key: str = Depends(verify_api_key_dependency),
) -> dict:
"""Cancel a queued job."""
if _batch_processor is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Batch processor not available",
)
success = _batch_processor.cancel_job(job_id)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Job {job_id} not found or cannot be cancelled",
)
return {"message": f"Job {job_id} cancelled"}
@router.get(
"/jobs",
response_model=list[JobResponse],
summary="List jobs",
description="List all jobs in the queue.",
)
async def list_jobs(
    status_filter: str | None = None,
limit: int = 50,
api_key: str = Depends(verify_api_key_dependency),
) -> list[JobResponse]:
"""List queued jobs."""
if _batch_processor is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Batch processor not available",
)
jobs = _batch_processor.list_jobs(status_filter=status_filter, limit=limit)
return [
JobResponse(
job_id=job.id,
status=JobStatus(job.status.value),
position=_batch_processor.get_position(job.id) if job.status.value == "pending" else None,
progress=job.progress if job.status.value == "running" else None,
)
for job in jobs
]
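A typical client of the queue submits to `POST /api/v1/generate/async` and then polls `GET /api/v1/generate/jobs/{job_id}` until a terminal status. The polling loop can be sketched with an injected `fetch` callable standing in for the HTTP client (the names here are illustrative, not part of the API):

```python
import time
from typing import Callable

def poll_job(fetch: Callable[[str], dict], job_id: str,
             interval: float = 0.0, max_polls: int = 100) -> dict:
    # fetch(job_id) stands in for GET /api/v1/generate/jobs/{job_id}
    # and returns the JobResponse as a dict.
    for _ in range(max_polls):
        job = fetch(job_id)
        if job["status"] in ("completed", "failed", "cancelled"):
            return job
        time.sleep(interval)
    raise TimeoutError(f"job {job_id} did not finish after {max_polls} polls")

# Fake server responses: pending -> running -> completed.
states = iter([
    {"status": "pending", "position": 1},
    {"status": "running", "progress": 0.5},
    {"status": "completed", "result": {"audio_url": "/api/v1/audio/abc123"}},
])
final = poll_job(lambda job_id: next(states), "abc123")
print(final["status"])  # completed
```

A real client would add a non-zero `interval` and back off while `position` or `progress` is reported.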

228
src/api/routes/models.py Normal file
View File

@@ -0,0 +1,228 @@
"""Models API endpoints."""
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status
from src.api.auth import verify_api_key_dependency
from src.api.models import ModelInfo, ModelVariantInfo
router = APIRouter(prefix="/models", tags=["models"])
# Service dependency (injected at app startup)
_model_registry = None
def set_services(model_registry: Any) -> None:
"""Set service dependencies."""
global _model_registry
_model_registry = model_registry
# Static model information
MODEL_CATALOG = {
"musicgen": {
"id": "musicgen",
"name": "MusicGen",
"description": "Text-to-music generation with optional melody conditioning",
"variants": [
{"id": "small", "name": "Small", "vram_mb": 1500, "description": "Fast, 300M params", "capabilities": ["text"]},
{"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "Balanced, 1.5B params", "capabilities": ["text"]},
{"id": "large", "name": "Large", "vram_mb": 10000, "description": "Best quality, 3.3B params", "capabilities": ["text"]},
{"id": "melody", "name": "Melody", "vram_mb": 5000, "description": "With melody conditioning", "capabilities": ["text", "melody"]},
{"id": "stereo-small", "name": "Stereo Small", "vram_mb": 1800, "description": "Stereo, 300M params", "capabilities": ["text", "stereo"]},
{"id": "stereo-medium", "name": "Stereo Medium", "vram_mb": 6000, "description": "Stereo, 1.5B params", "capabilities": ["text", "stereo"]},
{"id": "stereo-large", "name": "Stereo Large", "vram_mb": 12000, "description": "Stereo, 3.3B params", "capabilities": ["text", "stereo"]},
{"id": "stereo-melody", "name": "Stereo Melody", "vram_mb": 6000, "description": "Stereo with melody", "capabilities": ["text", "melody", "stereo"]},
],
},
"audiogen": {
"id": "audiogen",
"name": "AudioGen",
"description": "Text-to-sound effects and environmental audio",
"variants": [
{"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "1.5B params", "capabilities": ["text", "sfx"]},
],
},
"magnet": {
"id": "magnet",
"name": "MAGNeT",
"description": "Fast non-autoregressive music generation",
"variants": [
{"id": "small", "name": "Small Music", "vram_mb": 2000, "description": "Fast music, 300M params", "capabilities": ["text", "music"]},
{"id": "medium", "name": "Medium Music", "vram_mb": 5000, "description": "Balanced music, 1.5B params", "capabilities": ["text", "music"]},
{"id": "audio-small", "name": "Small Audio", "vram_mb": 2000, "description": "Fast sound effects", "capabilities": ["text", "sfx"]},
{"id": "audio-medium", "name": "Medium Audio", "vram_mb": 5000, "description": "Balanced sound effects", "capabilities": ["text", "sfx"]},
],
},
"musicgen-style": {
"id": "musicgen-style",
"name": "MusicGen Style",
"description": "Style-conditioned music from reference audio",
"variants": [
{"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "1.5B params, style conditioning", "capabilities": ["text", "style"]},
],
},
"jasco": {
"id": "jasco",
"name": "JASCO",
"description": "Chord and drum-conditioned music generation",
"variants": [
{"id": "chords", "name": "Chords", "vram_mb": 5000, "description": "Chord-conditioned generation", "capabilities": ["text", "chords"]},
{"id": "chords-drums", "name": "Chords + Drums", "vram_mb": 5500, "description": "Full symbolic conditioning", "capabilities": ["text", "chords", "drums"]},
],
},
}
@router.get(
"/",
response_model=list[ModelInfo],
summary="List models",
description="Get information about all available models.",
)
async def list_models(
api_key: str = Depends(verify_api_key_dependency),
) -> list[ModelInfo]:
"""List all available models."""
models = []
for model_id, info in MODEL_CATALOG.items():
loaded = False
current_variant = None
if _model_registry:
loaded = _model_registry.is_loaded(model_id)
if loaded:
current_variant = _model_registry.get_current_variant(model_id)
models.append(
ModelInfo(
id=info["id"],
name=info["name"],
description=info["description"],
variants=[ModelVariantInfo(**v) for v in info["variants"]],
loaded=loaded,
current_variant=current_variant,
)
)
return models
# NOTE: this fixed path must be registered before the /{model_id} route,
# otherwise FastAPI would match "loaded" as a model_id and return 404.
@router.get(
    "/loaded",
    response_model=list[str],
    summary="List loaded models",
    description="Get list of currently loaded models.",
)
async def list_loaded_models(
    api_key: str = Depends(verify_api_key_dependency),
) -> list[str]:
    """List currently loaded models."""
    if _model_registry is None:
        return []
    return _model_registry.get_loaded_models()
@router.get(
    "/{model_id}",
    response_model=ModelInfo,
    summary="Get model info",
    description="Get detailed information about a specific model.",
)
async def get_model(
    model_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> ModelInfo:
    """Get model information by ID."""
    if model_id not in MODEL_CATALOG:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model {model_id} not found",
        )
    info = MODEL_CATALOG[model_id]
    loaded = False
    current_variant = None
    if _model_registry:
        loaded = _model_registry.is_loaded(model_id)
        if loaded:
            current_variant = _model_registry.get_current_variant(model_id)
    return ModelInfo(
        id=info["id"],
        name=info["name"],
        description=info["description"],
        variants=[ModelVariantInfo(**v) for v in info["variants"]],
        loaded=loaded,
        current_variant=current_variant,
    )
@router.post(
    "/{model_id}/load",
    summary="Load model",
    description="Load a model into GPU memory.",
)
async def load_model(
    model_id: str,
    variant: str = "medium",
    api_key: str = Depends(verify_api_key_dependency),
) -> dict:
    """Load a model into memory."""
    if model_id not in MODEL_CATALOG:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model {model_id} not found",
        )
    if _model_registry is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model registry not available",
        )
    try:
        await _model_registry.load_model(model_id, variant)
        return {"message": f"Model {model_id} ({variant}) loaded successfully"}
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e),
        )
@router.post(
    "/{model_id}/unload",
    summary="Unload model",
    description="Unload a model from GPU memory.",
)
async def unload_model(
    model_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> dict:
    """Unload a model from memory."""
    if model_id not in MODEL_CATALOG:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model {model_id} not found",
        )
    if _model_registry is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model registry not available",
        )
    try:
        await _model_registry.unload_model(model_id)
        return {"message": f"Model {model_id} unloaded successfully"}
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e),
        )

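The `vram_mb` figures in the catalog make VRAM-aware variant selection possible. The registry's actual policy is not part of this file; one plausible strategy — pick the largest variant that fits the budget left after the ComfyUI reserve and safety buffer — sketches as:

```python
from typing import Optional

# A few MusicGen variants with the catalog's vram_mb figures.
MUSICGEN_VARIANTS = [
    {"id": "small", "vram_mb": 1500},
    {"id": "medium", "vram_mb": 5000},
    {"id": "large", "vram_mb": 10000},
]

def pick_variant(variants: list[dict], available_vram_mb: int,
                 safety_buffer_mb: int = 1024) -> Optional[str]:
    # Largest variant that fits the budget, or None if nothing fits.
    budget = available_vram_mb - safety_buffer_mb
    fitting = [v for v in variants if v["vram_mb"] <= budget]
    if not fitting:
        return None
    return max(fitting, key=lambda v: v["vram_mb"])["id"]

# A 24 GB card with 10 GB reserved for ComfyUI leaves roughly 14 GB.
print(pick_variant(MUSICGEN_VARIANTS, 14_000))  # large
print(pick_variant(MUSICGEN_VARIANTS, 4_000))   # small
```

The buffer default mirrors `AUDIOCRAFT_SAFETY_BUFFER_GB=1` from the example configuration.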
250
src/api/routes/projects.py Normal file
View File

@@ -0,0 +1,250 @@
"""Projects API endpoints."""
from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from fastapi.responses import FileResponse
from src.api.auth import verify_api_key_dependency
from src.api.models import (
ProjectCreate,
ProjectResponse,
GenerationResponse,
PaginatedResponse,
)
router = APIRouter(prefix="/projects", tags=["projects"])
# Service dependency (injected at app startup)
_project_service = None
def set_services(project_service: Any) -> None:
"""Set service dependencies."""
global _project_service
_project_service = project_service
@router.post(
"/",
response_model=ProjectResponse,
status_code=status.HTTP_201_CREATED,
summary="Create project",
description="Create a new project for organizing generations.",
)
async def create_project(
request: ProjectCreate,
api_key: str = Depends(verify_api_key_dependency),
) -> ProjectResponse:
"""Create a new project."""
if _project_service is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Project service not available",
)
try:
project = await _project_service.create_project(
name=request.name,
description=request.description,
)
return ProjectResponse(**project)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
@router.get(
"/",
response_model=list[ProjectResponse],
summary="List projects",
description="Get all projects.",
)
async def list_projects(
api_key: str = Depends(verify_api_key_dependency),
) -> list[ProjectResponse]:
"""List all projects."""
if _project_service is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Project service not available",
)
try:
projects = await _project_service.list_projects()
return [ProjectResponse(**p) for p in projects]
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
@router.get(
"/{project_id}",
response_model=ProjectResponse,
summary="Get project",
description="Get project details by ID.",
)
async def get_project(
project_id: str,
api_key: str = Depends(verify_api_key_dependency),
) -> ProjectResponse:
"""Get a project by ID."""
if _project_service is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Project service not available",
)
try:
project = await _project_service.get_project(project_id)
if project is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Project {project_id} not found",
)
return ProjectResponse(**project)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
@router.delete(
"/{project_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete project",
description="Delete a project and all its generations.",
)
async def delete_project(
project_id: str,
api_key: str = Depends(verify_api_key_dependency),
) -> None:
"""Delete a project."""
if _project_service is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Project service not available",
)
try:
await _project_service.delete_project(project_id)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
@router.get(
"/{project_id}/generations",
response_model=PaginatedResponse,
summary="List generations",
description="Get generations for a project with pagination.",
)
async def list_generations(
project_id: str,
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
model: Optional[str] = Query(None, description="Filter by model"),
api_key: str = Depends(verify_api_key_dependency),
) -> PaginatedResponse:
"""List generations for a project."""
if _project_service is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Project service not available",
)
try:
offset = (page - 1) * page_size
generations = await _project_service.list_generations(
project_id=project_id,
limit=page_size + 1, # +1 to check if more pages
offset=offset,
model_filter=model,
)
has_more = len(generations) > page_size
generations = generations[:page_size]
# Estimate total (could be improved with actual count query)
total = offset + len(generations) + (1 if has_more else 0)
pages = (total + page_size - 1) // page_size
return PaginatedResponse(
items=[GenerationResponse(**g) for g in generations],
total=total,
page=page,
page_size=page_size,
pages=pages,
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
@router.get(
"/{project_id}/export",
summary="Export project",
description="Export project as ZIP file with all audio and metadata.",
)
async def export_project(
project_id: str,
api_key: str = Depends(verify_api_key_dependency),
) -> FileResponse:
"""Export project as ZIP."""
if _project_service is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Project service not available",
)
try:
zip_path = await _project_service.export_project_zip(project_id)
return FileResponse(
path=zip_path,
filename=f"project_{project_id}.zip",
media_type="application/zip",
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
@router.delete(
"/{project_id}/generations/{generation_id}",
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete generation",
description="Delete a specific generation.",
)
async def delete_generation(
project_id: str,
generation_id: str,
api_key: str = Depends(verify_api_key_dependency),
) -> None:
"""Delete a generation."""
if _project_service is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Project service not available",
)
try:
await _project_service.delete_generation(generation_id)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
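`list_generations` fetches `page_size + 1` rows so it can tell whether another page exists without issuing a second COUNT query; the `total` it reports is therefore only a lower-bound estimate, as the inline comment notes. The arithmetic in isolation:

```python
def paginate(rows: list, page: int, page_size: int) -> dict:
    # Fetch-one-extra pagination, mirroring list_generations. `rows`
    # stands in for the database; a real query would use
    # LIMIT page_size + 1 OFFSET (page - 1) * page_size.
    offset = (page - 1) * page_size
    window = rows[offset : offset + page_size + 1]  # one extra row
    has_more = len(window) > page_size
    items = window[:page_size]
    # Total is an estimate: everything seen so far, plus one if another
    # page exists. An exact value needs a COUNT query.
    total = offset + len(items) + (1 if has_more else 0)
    return {"items": items, "total": total, "has_more": has_more}

page = paginate(list(range(45)), page=2, page_size=20)
print(len(page["items"]), page["total"], page["has_more"])  # 20 41 True
```

The trade-off is one cheap over-fetch per request instead of a full table scan, which matters once a project accumulates thousands of generations.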

263
src/api/routes/system.py Normal file
View File

@@ -0,0 +1,263 @@
"""System API endpoints."""
import time
from typing import Any
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import FileResponse
from src.api.auth import verify_api_key_dependency, get_key_manager
from src.api.models import GPUStatus, QueueStatus, SystemStatus
router = APIRouter(prefix="/system", tags=["system"])
# Service dependencies (injected at app startup)
_gpu_manager = None
_batch_processor = None
_model_registry = None
_start_time = time.time()
def set_services(
gpu_manager: Any,
batch_processor: Any,
model_registry: Any,
) -> None:
"""Set service dependencies."""
global _gpu_manager, _batch_processor, _model_registry
_gpu_manager = gpu_manager
_batch_processor = batch_processor
_model_registry = model_registry
@router.get(
"/status",
response_model=SystemStatus,
summary="System status",
description="Get overall system status including GPU, queue, and loaded models.",
)
async def get_status(
api_key: str = Depends(verify_api_key_dependency),
) -> SystemStatus:
"""Get system status."""
# GPU status
if _gpu_manager:
gpu = GPUStatus(
device_name=_gpu_manager.device_name,
total_gb=_gpu_manager.total_memory / 1024**3,
used_gb=_gpu_manager.get_used_memory() / 1024**3,
available_gb=_gpu_manager.get_available_memory() / 1024**3,
utilization_percent=_gpu_manager.get_utilization(),
temperature_c=_gpu_manager.get_temperature(),
)
else:
gpu = GPUStatus(
device_name="Unknown",
total_gb=0,
used_gb=0,
available_gb=0,
utilization_percent=0,
temperature_c=None,
)
# Queue status
if _batch_processor:
queue = QueueStatus(
queue_size=len(_batch_processor.queue),
active_jobs=_batch_processor.active_count,
completed_today=_batch_processor.completed_count,
failed_today=_batch_processor.failed_count,
)
else:
queue = QueueStatus(
queue_size=0,
active_jobs=0,
completed_today=0,
failed_today=0,
)
# Loaded models
loaded_models = []
if _model_registry:
loaded_models = _model_registry.get_loaded_models()
return SystemStatus(
gpu=gpu,
queue=queue,
loaded_models=loaded_models,
uptime_seconds=time.time() - _start_time,
)
@router.get(
"/gpu",
response_model=GPUStatus,
summary="GPU status",
description="Get detailed GPU memory and utilization status.",
)
async def get_gpu_status(
api_key: str = Depends(verify_api_key_dependency),
) -> GPUStatus:
"""Get GPU status."""
if _gpu_manager is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="GPU manager not available",
)
return GPUStatus(
device_name=_gpu_manager.device_name,
total_gb=_gpu_manager.total_memory / 1024**3,
used_gb=_gpu_manager.get_used_memory() / 1024**3,
available_gb=_gpu_manager.get_available_memory() / 1024**3,
utilization_percent=_gpu_manager.get_utilization(),
temperature_c=_gpu_manager.get_temperature(),
)
@router.post(
"/clear-cache",
summary="Clear cache",
description="Clear model cache and free GPU memory.",
)
async def clear_cache(
api_key: str = Depends(verify_api_key_dependency),
) -> dict:
"""Clear model cache."""
if _model_registry is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Model registry not available",
)
try:
_model_registry.clear_cache()
return {"message": "Cache cleared successfully"}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
@router.post(
"/unload-all",
summary="Unload all models",
description="Unload all models from GPU memory.",
)
async def unload_all_models(
api_key: str = Depends(verify_api_key_dependency),
) -> dict:
"""Unload all models."""
if _model_registry is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Model registry not available",
)
try:
await _model_registry.unload_all()
return {"message": "All models unloaded successfully"}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
@router.get(
"/health",
summary="Health check",
description="Simple health check endpoint.",
)
async def health_check() -> dict:
"""Health check endpoint (no auth required)."""
return {
"status": "healthy",
"uptime_seconds": time.time() - _start_time,
}
@router.post(
"/api-key/regenerate",
summary="Regenerate API key",
description="Generate a new API key. The old key will be invalidated.",
)
async def regenerate_api_key(
api_key: str = Depends(verify_api_key_dependency),
) -> dict:
"""Regenerate API key."""
key_manager = get_key_manager()
new_key = key_manager.generate_new_key()
return {
"api_key": new_key,
"message": "New API key generated. Store it securely - it won't be shown again.",
}
@router.get(
"/audio/{generation_id}",
summary="Download audio",
description="Download generated audio file.",
)
async def download_audio(
generation_id: str,
api_key: str = Depends(verify_api_key_dependency),
) -> FileResponse:
"""Download audio file for a generation."""
# This would look up the actual file path from the database
# For now, construct expected path
from config.settings import get_settings
settings = get_settings()
# Find the audio file
audio_dir = Path(settings.output_dir)
possible_paths = [
audio_dir / f"{generation_id}.wav",
audio_dir / f"{generation_id}.mp3",
audio_dir / f"{generation_id}.flac",
]
for path in possible_paths:
if path.exists():
return FileResponse(
path=path,
filename=path.name,
media_type={".wav": "audio/wav", ".mp3": "audio/mpeg", ".flac": "audio/flac"}.get(path.suffix, "application/octet-stream"),
)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Audio file for generation {generation_id} not found",
)
@router.get(
"/audio/{generation_id}/waveform",
summary="Download waveform",
description="Download waveform visualization image.",
)
async def download_waveform(
generation_id: str,
api_key: str = Depends(verify_api_key_dependency),
) -> FileResponse:
"""Download waveform image for a generation."""
from config.settings import get_settings
settings = get_settings()
waveform_path = Path(settings.output_dir) / f"{generation_id}_waveform.png"
if not waveform_path.exists():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Waveform for generation {generation_id} not found",
)
return FileResponse(
path=waveform_path,
filename=waveform_path.name,
media_type="image/png",
)

24
src/core/__init__.py Normal file

@@ -0,0 +1,24 @@
"""Core infrastructure for AudioCraft Studio."""
from src.core.base_model import (
BaseAudioModel,
GenerationRequest,
GenerationResult,
ConditioningType,
)
from src.core.gpu_manager import GPUMemoryManager, VRAMBudget
from src.core.model_registry import ModelRegistry
from src.core.oom_handler import OOMHandler, OOMRecoveryError, oom_safe
__all__ = [
"BaseAudioModel",
"GenerationRequest",
"GenerationResult",
"ConditioningType",
"GPUMemoryManager",
"VRAMBudget",
"ModelRegistry",
"OOMHandler",
"OOMRecoveryError",
"oom_safe",
]

535
src/core/audio_utils.py Normal file

@@ -0,0 +1,535 @@
"""Audio utilities for processing, visualization, and export."""
import io
import logging
from pathlib import Path
from typing import Optional, Tuple, Union
import numpy as np
import torch
logger = logging.getLogger(__name__)
def normalize_audio(
audio: Union[torch.Tensor, np.ndarray],
target_db: float = -14.0,
peak_normalize: bool = False,
) -> np.ndarray:
"""Normalize audio to target loudness.
Args:
audio: Audio tensor or array [channels, samples] or [samples]
target_db: Target loudness in dB (LUFS-like)
peak_normalize: If True, normalize to peak instead of RMS
Returns:
Normalized audio as numpy array
"""
if isinstance(audio, torch.Tensor):
audio = audio.detach().cpu().numpy()
# Ensure float32
audio = audio.astype(np.float32)
# Handle batch dimension
if audio.ndim == 3:
audio = audio[0] # Take first sample if batched
if peak_normalize:
# Peak normalization
peak = np.abs(audio).max()
if peak > 0:
target_linear = 10 ** (target_db / 20)
audio = audio * (target_linear / peak)
else:
# RMS normalization (approximating LUFS)
rms = np.sqrt(np.mean(audio ** 2))
if rms > 0:
target_rms = 10 ** (target_db / 20)
audio = audio * (target_rms / rms)
# Clamp to [-1, 1] so the applied gain cannot push samples into digital clipping
audio = np.clip(audio, -1.0, 1.0)
return audio
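As a sanity check on the RMS branch above, a standalone NumPy-only sketch (hypothetical helper, no torch dependency):

```python
import numpy as np

def rms_normalize(audio: np.ndarray, target_db: float = -14.0) -> np.ndarray:
    """Scale audio so its RMS sits at target_db, then clamp to [-1, 1]."""
    audio = audio.astype(np.float32)
    rms = float(np.sqrt(np.mean(audio ** 2)))
    if rms == 0.0:
        return audio
    target_rms = 10.0 ** (target_db / 20.0)
    return np.clip(audio * (target_rms / rms), -1.0, 1.0)

# A quiet 440 Hz tone is boosted to roughly -14 dB RMS without clipping.
tone = 0.01 * np.sin(2 * np.pi * 440 * np.linspace(0, 1, 16000))
normalized = rms_normalize(tone)
```

Because the gain is a single scalar, the output RMS lands exactly on the target unless the clip kicks in.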
def convert_sample_rate(
audio: np.ndarray,
orig_sr: int,
target_sr: int,
) -> np.ndarray:
"""Convert audio sample rate.
Args:
audio: Audio array [channels, samples] or [samples]
orig_sr: Original sample rate
target_sr: Target sample rate
Returns:
Resampled audio
"""
if orig_sr == target_sr:
return audio
try:
import librosa
# Handle multi-channel
if audio.ndim == 2:
resampled = np.array([
librosa.resample(ch, orig_sr=orig_sr, target_sr=target_sr)
for ch in audio
])
else:
resampled = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
return resampled
except ImportError:
logger.warning("librosa not available, using scipy for resampling")
from scipy import signal
ratio = target_sr / orig_sr
new_length = int(audio.shape[-1] * ratio)
if audio.ndim == 2:
resampled = np.array([
signal.resample(ch, new_length) for ch in audio
])
else:
resampled = signal.resample(audio, new_length)
return resampled
def generate_waveform(
audio: Union[torch.Tensor, np.ndarray],
sample_rate: int,
width: int = 800,
height: int = 200,
color: str = "#3b82f6",
background: str = "#1f2937",
) -> bytes:
"""Generate waveform image as PNG bytes.
Args:
audio: Audio data [channels, samples] or [samples]
sample_rate: Sample rate in Hz
width: Image width in pixels
height: Image height in pixels
color: Waveform color (hex)
background: Background color (hex)
Returns:
PNG image as bytes
"""
try:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
except ImportError:
logger.warning("matplotlib not available for waveform generation")
return b""
if isinstance(audio, torch.Tensor):
audio = audio.detach().cpu().numpy()
# Handle dimensions
if audio.ndim == 3:
audio = audio[0]
if audio.ndim == 2:
audio = audio.mean(axis=0) # Mix to mono for visualization
# Downsample for visualization
samples_per_pixel = max(1, len(audio) // width)
num_chunks = len(audio) // samples_per_pixel
if num_chunks > 0:
audio_chunks = audio[:num_chunks * samples_per_pixel].reshape(
num_chunks, samples_per_pixel
)
# Get min/max for each chunk
mins = audio_chunks.min(axis=1)
maxs = audio_chunks.max(axis=1)
else:
mins = maxs = audio
# Create figure
fig, ax = plt.subplots(figsize=(width / 100, height / 100), dpi=100)
fig.patch.set_facecolor(background)
ax.set_facecolor(background)
# Plot waveform
x = np.arange(len(mins))
ax.fill_between(x, mins, maxs, color=color, alpha=0.7)
ax.axhline(y=0, color=color, alpha=0.3, linewidth=0.5)
# Style
ax.set_xlim(0, len(mins))
ax.set_ylim(-1, 1)
ax.axis('off')
plt.tight_layout(pad=0)
# Save to bytes
buf = io.BytesIO()
fig.savefig(buf, format='png', facecolor=background, edgecolor='none')
plt.close(fig)
buf.seek(0)
return buf.read()
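The per-pixel min/max reduction used above is the standard peak-decimation trick; a small NumPy sketch of just that step (names are illustrative):

```python
import numpy as np

def minmax_decimate(samples: np.ndarray, width: int) -> tuple[np.ndarray, np.ndarray]:
    """Reduce a 1-D signal to roughly `width` (min, max) pairs for plotting.

    Trailing samples that do not fill a whole chunk are dropped,
    matching the reshape-based approach above.
    """
    spp = max(1, len(samples) // width)  # samples per pixel
    n = len(samples) // spp
    chunks = samples[: n * spp].reshape(n, spp)
    return chunks.min(axis=1), chunks.max(axis=1)

mins, maxs = minmax_decimate(np.sin(np.linspace(0, 100, 48000)), width=800)
```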
def generate_spectrogram(
audio: Union[torch.Tensor, np.ndarray],
sample_rate: int,
width: int = 800,
height: int = 200,
colormap: str = "magma",
) -> bytes:
"""Generate spectrogram image as PNG bytes.
Args:
audio: Audio data
sample_rate: Sample rate in Hz
width: Image width
height: Image height
colormap: Matplotlib colormap name
Returns:
PNG image as bytes
"""
try:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import librosa
import librosa.display
except ImportError:
logger.warning("matplotlib/librosa not available for spectrogram")
return b""
if isinstance(audio, torch.Tensor):
audio = audio.detach().cpu().numpy()
# Handle dimensions
if audio.ndim == 3:
audio = audio[0]
if audio.ndim == 2:
audio = audio.mean(axis=0)
# Compute mel spectrogram
S = librosa.feature.melspectrogram(
y=audio,
sr=sample_rate,
n_mels=128,
fmax=sample_rate // 2,
)
S_db = librosa.power_to_db(S, ref=np.max)
# Create figure
fig, ax = plt.subplots(figsize=(width / 100, height / 100), dpi=100)
librosa.display.specshow(
S_db,
sr=sample_rate,
x_axis='time',
y_axis='mel',
cmap=colormap,
ax=ax,
)
ax.axis('off')
plt.tight_layout(pad=0)
# Save to bytes
buf = io.BytesIO()
fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
plt.close(fig)
buf.seek(0)
return buf.read()
def save_audio(
audio: Union[torch.Tensor, np.ndarray],
sample_rate: int,
path: Path,
format: str = "wav",
normalize: bool = True,
target_db: float = -14.0,
) -> Path:
"""Save audio to file with optional normalization.
Args:
audio: Audio data
sample_rate: Sample rate
path: Output path (extension will be added if needed)
format: Output format (wav, mp3, flac, ogg)
normalize: Whether to normalize audio
target_db: Normalization target
Returns:
Path to saved file
"""
import soundfile as sf
if isinstance(audio, torch.Tensor):
audio = audio.detach().cpu().numpy()
# Handle batch dimension
if audio.ndim == 3:
audio = audio[0]
# Normalize if requested
if normalize:
audio = normalize_audio(audio, target_db=target_db)
# Transpose for soundfile [samples, channels]
if audio.ndim == 2:
audio = audio.T
# Ensure correct extension
path = Path(path)
if not path.suffix:
path = path.with_suffix(f".{format}")
# Save based on format
if format in ("wav", "flac"):
sf.write(path, audio, sample_rate)
elif format == "mp3":
# Write a temporary WAV with soundfile, then convert with pydub if available
try:
from pydub import AudioSegment
# Save as WAV first
wav_path = path.with_suffix(".wav")
sf.write(wav_path, audio, sample_rate)
# Convert to MP3
sound = AudioSegment.from_wav(wav_path)
sound.export(path, format="mp3", bitrate="320k")
# Remove temp WAV
wav_path.unlink()
except ImportError:
logger.warning("pydub not available, saving as WAV instead")
path = path.with_suffix(".wav")
sf.write(path, audio, sample_rate)
elif format == "ogg":
sf.write(path, audio, sample_rate, format="ogg", subtype="vorbis")
else:
# Default to WAV
path = path.with_suffix(".wav")
sf.write(path, audio, sample_rate)
return path
def load_audio(
path: Path,
target_sr: Optional[int] = None,
mono: bool = False,
) -> Tuple[np.ndarray, int]:
"""Load audio from file.
Args:
path: Path to audio file
target_sr: Target sample rate (None to keep original)
mono: Convert to mono
Returns:
Tuple of (audio_array, sample_rate)
"""
import soundfile as sf
audio, sr = sf.read(path)
# Convert to [channels, samples] format
if audio.ndim == 1:
audio = audio[np.newaxis, :]
else:
audio = audio.T
# Convert to mono
if mono and audio.shape[0] > 1:
audio = audio.mean(axis=0, keepdims=True)
# Resample if needed
if target_sr and target_sr != sr:
audio = convert_sample_rate(audio, sr, target_sr)
sr = target_sr
return audio, sr
def get_audio_info(path: Path) -> dict:
"""Get audio file information.
Args:
path: Path to audio file
Returns:
Dictionary with audio info
"""
import soundfile as sf
info = sf.info(path)
return {
"path": str(path),
"duration": info.duration,
"sample_rate": info.samplerate,
"channels": info.channels,
"format": info.format,
"subtype": info.subtype,
"frames": info.frames,
}
def trim_silence(
audio: np.ndarray,
sample_rate: int,
threshold_db: float = -40.0,
min_silence_ms: int = 100,
) -> np.ndarray:
"""Trim silence from start and end of audio.
Args:
audio: Audio array
sample_rate: Sample rate
threshold_db: Silence threshold in dB
min_silence_ms: Minimum silence duration to trim
Returns:
Trimmed audio
"""
try:
import librosa
if audio.ndim == 2:
# Process mono for trimming
mono = audio.mean(axis=0)
else:
mono = audio
# Get non-silent intervals
intervals = librosa.effects.split(
mono,
top_db=abs(threshold_db),
frame_length=int(sample_rate * min_silence_ms / 1000),
)
if len(intervals) == 0:
return audio
start = intervals[0][0]
end = intervals[-1][1]
if audio.ndim == 2:
return audio[:, start:end]
return audio[start:end]
except ImportError:
logger.warning("librosa not available for silence trimming")
return audio
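trim_silence's threshold_db compares against linear amplitude via the usual 10**(dB/20) conversion; a minimal sketch of that relationship (helper names are illustrative):

```python
import numpy as np

def db_to_amplitude(db: float) -> float:
    """Convert decibels to linear amplitude (0 dB -> 1.0, -20 dB -> 0.1)."""
    return 10.0 ** (db / 20.0)

def is_silent(frame: np.ndarray, threshold_db: float = -40.0) -> bool:
    """A frame counts as silence when its RMS falls below the linear threshold."""
    rms = float(np.sqrt(np.mean(frame ** 2)))
    return rms < db_to_amplitude(threshold_db)
```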
def apply_fade(
audio: np.ndarray,
sample_rate: int,
fade_in_ms: float = 0,
fade_out_ms: float = 0,
) -> np.ndarray:
"""Apply fade in/out to audio.
Args:
audio: Audio array [channels, samples] or [samples]
sample_rate: Sample rate
fade_in_ms: Fade in duration in milliseconds
fade_out_ms: Fade out duration in milliseconds
Returns:
Audio with fades applied
"""
audio = audio.copy()
if fade_in_ms > 0:
fade_in_samples = int(sample_rate * fade_in_ms / 1000)
fade_in_samples = min(fade_in_samples, audio.shape[-1])
fade_in_curve = np.linspace(0, 1, fade_in_samples)
if audio.ndim == 2:
audio[:, :fade_in_samples] *= fade_in_curve
else:
audio[:fade_in_samples] *= fade_in_curve
if fade_out_ms > 0:
fade_out_samples = int(sample_rate * fade_out_ms / 1000)
fade_out_samples = min(fade_out_samples, audio.shape[-1])
fade_out_curve = np.linspace(1, 0, fade_out_samples)
if audio.ndim == 2:
audio[:, -fade_out_samples:] *= fade_out_curve
else:
audio[-fade_out_samples:] *= fade_out_curve
return audio
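The ramp logic above can be exercised in isolation; a mono-only sketch with the same linear fade curves:

```python
import numpy as np

def linear_fade(audio: np.ndarray, sr: int, fade_in_ms: float = 0, fade_out_ms: float = 0) -> np.ndarray:
    """Apply linear fade-in/fade-out ramps to a mono signal."""
    out = audio.astype(np.float64).copy()
    n_in = min(int(sr * fade_in_ms / 1000), len(out))
    n_out = min(int(sr * fade_out_ms / 1000), len(out))
    if n_in > 0:
        out[:n_in] *= np.linspace(0, 1, n_in)
    if n_out > 0:
        out[-n_out:] *= np.linspace(1, 0, n_out)
    return out

faded = linear_fade(np.ones(1000), sr=1000, fade_in_ms=100, fade_out_ms=100)
```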
def concatenate_audio(
audio_list: list[np.ndarray],
sample_rate: int,
crossfade_ms: float = 0,
) -> np.ndarray:
"""Concatenate multiple audio segments.
Args:
audio_list: List of audio arrays
sample_rate: Sample rate (must be same for all)
crossfade_ms: Crossfade duration between segments
Returns:
Concatenated audio
"""
if not audio_list:
return np.array([])
if len(audio_list) == 1:
return audio_list[0]
crossfade_samples = int(sample_rate * crossfade_ms / 1000)
result = audio_list[0]
for audio in audio_list[1:]:
if crossfade_samples > 0 and crossfade_samples < min(
result.shape[-1], audio.shape[-1]
):
# Apply crossfade
fade_out = np.linspace(1, 0, crossfade_samples)
fade_in = np.linspace(0, 1, crossfade_samples)
if result.ndim == 2:
# Overlap region
result[:, -crossfade_samples:] *= fade_out
overlap = result[:, -crossfade_samples:] + audio[:, :crossfade_samples] * fade_in
result = np.concatenate([
result[:, :-crossfade_samples],
overlap,
audio[:, crossfade_samples:]
], axis=1)
else:
result[-crossfade_samples:] *= fade_out
overlap = result[-crossfade_samples:] + audio[:crossfade_samples] * fade_in
result = np.concatenate([
result[:-crossfade_samples],
overlap,
audio[crossfade_samples:]
])
else:
# Simple concatenation
result = np.concatenate([result, audio], axis=-1)
return result
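concatenate_audio's crossfade is an equal-gain linear overlap-add: the fade-out and fade-in curves sum to 1, so joining identical material stays at constant level, and the result is len(a) + len(b) - overlap samples. A mono sketch:

```python
import numpy as np

def crossfade_concat(a: np.ndarray, b: np.ndarray, overlap: int) -> np.ndarray:
    """Join two mono clips with a linear equal-gain crossfade of `overlap` samples."""
    if overlap <= 0 or overlap >= min(len(a), len(b)):
        return np.concatenate([a, b])  # fall back to a hard cut
    mixed = (a[-overlap:] * np.linspace(1, 0, overlap)
             + b[:overlap] * np.linspace(0, 1, overlap))
    return np.concatenate([a[:-overlap], mixed, b[overlap:]])

joined = crossfade_concat(np.ones(100), np.ones(100), overlap=10)
```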

247
src/core/base_model.py Normal file

@@ -0,0 +1,247 @@
"""Abstract base classes for AudioCraft model adapters."""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Optional
import torch
class ConditioningType(str, Enum):
"""Types of conditioning supported by models."""
TEXT = "text"
MELODY = "melody"
STYLE = "style"
CHORDS = "chords"
DRUMS = "drums"
@dataclass
class GenerationRequest:
"""Request parameters for audio generation.
Attributes:
prompts: List of text prompts for generation
duration: Target duration in seconds
temperature: Sampling temperature (higher = more random)
top_k: Top-k sampling parameter
top_p: Nucleus sampling parameter (0 = disabled)
cfg_coef: Classifier-free guidance coefficient
batch_size: Number of samples to generate per prompt
seed: Random seed for reproducibility
conditioning: Optional conditioning data
"""
prompts: list[str]
duration: float = 10.0
temperature: float = 1.0
top_k: int = 250
top_p: float = 0.0
cfg_coef: float = 3.0
batch_size: int = 1
seed: Optional[int] = None
conditioning: dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
"""Validate request parameters."""
if not self.prompts:
raise ValueError("At least one prompt is required")
if self.duration <= 0:
raise ValueError("Duration must be positive")
if self.temperature < 0:
raise ValueError("Temperature must be non-negative")
if self.top_k < 0:
raise ValueError("top_k must be non-negative")
if not 0 <= self.top_p <= 1:
raise ValueError("top_p must be between 0 and 1")
if self.cfg_coef < 1:
raise ValueError("cfg_coef must be >= 1")
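GenerationRequest validates eagerly in __post_init__, so bad parameters fail at construction rather than deep inside generation. A minimal standalone illustration of the same pattern (hypothetical class, torch-free):

```python
from dataclasses import dataclass

@dataclass
class SamplingParams:
    """Miniature version of GenerationRequest's validated sampling fields."""
    temperature: float = 1.0
    top_k: int = 250
    top_p: float = 0.0

    def __post_init__(self) -> None:
        # Fail fast: invalid values never reach the model.
        if self.temperature < 0:
            raise ValueError("temperature must be non-negative")
        if self.top_k < 0:
            raise ValueError("top_k must be non-negative")
        if not 0 <= self.top_p <= 1:
            raise ValueError("top_p must be between 0 and 1")
```

dataclasses call __post_init__ automatically after the generated __init__, which is what makes this validate-on-construction pattern cheap to adopt.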
@dataclass
class GenerationResult:
"""Result of audio generation.
Attributes:
audio: Generated audio tensor (shape: [batch, channels, samples])
sample_rate: Audio sample rate in Hz
duration: Actual duration in seconds
model_id: ID of the model used
variant: Model variant used
parameters: Generation parameters used
seed: Actual seed used (for reproducibility)
"""
audio: torch.Tensor
sample_rate: int
duration: float
model_id: str
variant: str
parameters: dict[str, Any]
seed: int
@property
def num_samples(self) -> int:
"""Number of generated clips in the batch."""
return self.audio.shape[0]
@property
def num_channels(self) -> int:
"""Number of audio channels."""
return self.audio.shape[1]
@property
def num_frames(self) -> int:
"""Number of audio frames."""
return self.audio.shape[2]
class BaseAudioModel(ABC):
"""Abstract base class for AudioCraft model adapters.
All model adapters must implement this interface to integrate with
the model registry and generation service.
"""
@property
@abstractmethod
def model_id(self) -> str:
"""Unique identifier for this model family (e.g., 'musicgen')."""
...
@property
@abstractmethod
def variant(self) -> str:
"""Current model variant (e.g., 'medium', 'large')."""
...
@property
@abstractmethod
def display_name(self) -> str:
"""Human-readable name for UI display."""
...
@property
@abstractmethod
def description(self) -> str:
"""Brief description of the model's capabilities."""
...
@property
@abstractmethod
def vram_estimate_mb(self) -> int:
"""Estimated VRAM usage when loaded (in megabytes)."""
...
@property
@abstractmethod
def max_duration(self) -> float:
"""Maximum supported generation duration in seconds."""
...
@property
@abstractmethod
def sample_rate(self) -> int:
"""Output audio sample rate in Hz."""
...
@property
@abstractmethod
def supports_conditioning(self) -> list[ConditioningType]:
"""List of conditioning types supported by this model."""
...
@property
@abstractmethod
def is_loaded(self) -> bool:
"""Whether the model is currently loaded in memory."""
...
@property
def device(self) -> Optional[torch.device]:
"""Device the model is loaded on, or None if not loaded."""
return None
@abstractmethod
def load(self, device: str = "cuda") -> None:
"""Load the model into memory.
Args:
device: Target device ('cuda', 'cuda:0', 'cpu', etc.)
Raises:
RuntimeError: If loading fails
"""
...
@abstractmethod
def unload(self) -> None:
"""Unload the model and free memory.
Should be idempotent - safe to call even if not loaded.
"""
...
@abstractmethod
def generate(self, request: GenerationRequest) -> GenerationResult:
"""Generate audio based on the request.
Args:
request: Generation parameters and prompts
Returns:
GenerationResult containing audio and metadata
Raises:
RuntimeError: If model is not loaded
ValueError: If request parameters are invalid for this model
"""
...
@abstractmethod
def get_default_params(self) -> dict[str, Any]:
"""Get default generation parameters for this model.
Returns:
Dictionary of parameter names to default values
"""
...
def validate_request(self, request: GenerationRequest) -> None:
"""Validate a generation request for this model.
Args:
request: Request to validate
Raises:
ValueError: If request is invalid for this model
RuntimeError: If the model is not loaded
"""
if not self.is_loaded:
raise RuntimeError(f"Model {self.model_id}/{self.variant} is not loaded")
if request.duration > self.max_duration:
raise ValueError(
f"Duration {request.duration}s exceeds maximum {self.max_duration}s "
f"for {self.model_id}/{self.variant}"
)
# Check conditioning requirements
for cond_type, cond_data in request.conditioning.items():
if cond_data is not None:
try:
cond_enum = ConditioningType(cond_type)
except ValueError:
raise ValueError(f"Unknown conditioning type: {cond_type}")
if cond_enum not in self.supports_conditioning:
raise ValueError(
f"Model {self.model_id}/{self.variant} does not support "
f"{cond_type} conditioning"
)
def __repr__(self) -> str:
"""String representation."""
loaded = "loaded" if self.is_loaded else "not loaded"
return f"<{self.__class__.__name__} {self.model_id}/{self.variant} ({loaded})>"

433
src/core/gpu_manager.py Normal file

@@ -0,0 +1,433 @@
"""GPU memory management for AudioCraft models."""
import gc
import json
import logging
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Optional
import torch
logger = logging.getLogger(__name__)
@dataclass
class VRAMBudget:
"""VRAM budget allocation information.
Attributes:
total_mb: Total VRAM in megabytes
used_mb: Currently used VRAM
free_mb: Free VRAM
reserved_comfyui_mb: VRAM reserved for ComfyUI
safety_buffer_mb: Safety buffer to prevent OOM
available_mb: VRAM available for AudioCraft models
"""
total_mb: int
used_mb: int
free_mb: int
reserved_comfyui_mb: int
safety_buffer_mb: int
available_mb: int
@property
def utilization(self) -> float:
"""Current VRAM utilization as a fraction (0-1)."""
return self.used_mb / self.total_mb if self.total_mb > 0 else 0.0
@dataclass
class GPUState:
"""State information for inter-service coordination."""
timestamp: float
service: str # "audiocraft" or "comfyui"
vram_used_mb: int
vram_requested_mb: int
status: str # "idle", "working", "requesting_priority", "yielded"
class GPUMemoryManager:
"""Manages GPU memory allocation and coordination with ComfyUI.
Uses pynvml for accurate system-wide VRAM tracking and file-based
IPC for coordination with ComfyUI running on the same system.
"""
COORDINATION_FILE = Path("/tmp/audiocraft_comfyui_coord.json")
LOCK_FILE = Path("/tmp/audiocraft_comfyui_coord.lock")
STALE_THRESHOLD = 30.0 # seconds
def __init__(
self,
device_id: int = 0,
comfyui_reserve_gb: float = 10.0,
safety_buffer_gb: float = 1.0,
):
"""Initialize GPU memory manager.
Args:
device_id: CUDA device index
comfyui_reserve_gb: VRAM to reserve for ComfyUI (gigabytes)
safety_buffer_gb: Safety buffer to prevent OOM (gigabytes)
"""
self.device_id = device_id
self.device = torch.device(f"cuda:{device_id}")
self.comfyui_reserve_mb = int(comfyui_reserve_gb * 1024)
self.safety_buffer_mb = int(safety_buffer_gb * 1024)
# Initialize NVML for direct GPU monitoring
self._nvml_initialized = False
self._nvml_handle = None
self._init_nvml()
# Threading
self._lock = threading.RLock()
# Callbacks for memory events
self._low_memory_callbacks: list[Callable[[VRAMBudget], None]] = []
self._oom_callbacks: list[Callable[[], None]] = []
# Initialize coordination file
self._ensure_coordination_file()
def _init_nvml(self) -> None:
"""Initialize NVML for GPU monitoring."""
try:
import pynvml
pynvml.nvmlInit()
self._nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(self.device_id)
self._nvml_initialized = True
logger.info("NVML initialized successfully")
except ImportError:
logger.warning("pynvml not available, falling back to torch.cuda")
except Exception as e:
logger.warning(f"Failed to initialize NVML: {e}, falling back to torch.cuda")
def _ensure_coordination_file(self) -> None:
"""Create coordination file if it doesn't exist."""
if not self.COORDINATION_FILE.exists():
initial_state = {
"audiocraft": None,
"comfyui": None,
"priority": None,
"last_update": time.time(),
}
self._write_coordination_state(initial_state)
def get_memory_info(self) -> dict[str, int]:
"""Get current GPU memory status.
Returns:
Dictionary with memory values in megabytes:
- total: Total VRAM
- used: Used VRAM (system-wide)
- free: Free VRAM
- torch_allocated: PyTorch allocated memory
- torch_reserved: PyTorch reserved memory
- torch_cached: PyTorch cached memory
"""
with self._lock:
if self._nvml_initialized:
return self._get_memory_info_nvml()
return self._get_memory_info_torch()
def _get_memory_info_nvml(self) -> dict[str, int]:
"""Get memory info using NVML (more accurate)."""
import pynvml
info = pynvml.nvmlDeviceGetMemoryInfo(self._nvml_handle)
torch_allocated = torch.cuda.memory_allocated(self.device)
torch_reserved = torch.cuda.memory_reserved(self.device)
return {
"total": info.total // (1024 * 1024),
"used": info.used // (1024 * 1024),
"free": info.free // (1024 * 1024),
"torch_allocated": torch_allocated // (1024 * 1024),
"torch_reserved": torch_reserved // (1024 * 1024),
"torch_cached": (torch_reserved - torch_allocated) // (1024 * 1024),
}
def _get_memory_info_torch(self) -> dict[str, int]:
"""Get memory info using torch.cuda (fallback)."""
props = torch.cuda.get_device_properties(self.device)
allocated = torch.cuda.memory_allocated(self.device)
reserved = torch.cuda.memory_reserved(self.device)
# Note: This is less accurate for system-wide usage
return {
"total": props.total_memory // (1024 * 1024),
"used": reserved // (1024 * 1024),
"free": (props.total_memory - reserved) // (1024 * 1024),
"torch_allocated": allocated // (1024 * 1024),
"torch_reserved": reserved // (1024 * 1024),
"torch_cached": (reserved - allocated) // (1024 * 1024),
}
def get_available_budget(self) -> VRAMBudget:
"""Calculate available VRAM budget considering ComfyUI.
Returns:
VRAMBudget with current allocation information
"""
mem = self.get_memory_info()
# Check ComfyUI's actual usage via coordination file
comfyui_state = self.get_comfyui_status()
if comfyui_state and comfyui_state.status != "yielded":
# Use actual ComfyUI usage + buffer, or reserve, whichever is higher
effective_comfyui_reserve = max(
self.comfyui_reserve_mb,
comfyui_state.vram_used_mb + 2048, # 2GB headroom
)
else:
effective_comfyui_reserve = self.comfyui_reserve_mb
available = max(
0,
mem["total"]
- mem["used"]
+ mem["torch_allocated"] # Our own usage doesn't count against us
- effective_comfyui_reserve
- self.safety_buffer_mb,
)
return VRAMBudget(
total_mb=mem["total"],
used_mb=mem["used"],
free_mb=mem["free"],
reserved_comfyui_mb=effective_comfyui_reserve,
safety_buffer_mb=self.safety_buffer_mb,
available_mb=available,
)
def can_load_model(self, vram_required_mb: int) -> tuple[bool, str]:
"""Check if a model can fit in available VRAM.
Args:
vram_required_mb: VRAM needed by the model
Returns:
Tuple of (can_load, reason_message)
"""
budget = self.get_available_budget()
if vram_required_mb <= budget.available_mb:
return True, "Sufficient VRAM available"
deficit = vram_required_mb - budget.available_mb
return False, (
f"Insufficient VRAM: need {vram_required_mb}MB, "
f"available {budget.available_mb}MB (deficit: {deficit}MB)"
)
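The budget arithmetic in get_available_budget adds PyTorch's own allocation back because NVML's `used` figure is system-wide and already includes it. A worked example with RTX 4090-like numbers (all values in MB, purely illustrative):

```python
def available_budget_mb(total: int, used: int, ours: int,
                        comfyui_reserve: int, safety: int) -> int:
    """Mirror get_available_budget: system-wide `used` includes our own
    allocation, so add it back before subtracting reserves."""
    return max(0, total - used + ours - comfyui_reserve - safety)

# 24 GB card, 12 GB in use system-wide (4 GB of that is ours),
# 10 GB reserved for ComfyUI, 1 GB safety buffer -> 5 GB of headroom.
budget = available_budget_mb(24576, 12288, 4096, 10240, 1024)
```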
def force_cleanup(self) -> int:
"""Force GPU memory cleanup.
Returns:
Freed memory in megabytes (approximate)
"""
with self._lock:
before = self.get_memory_info()
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize(self.device)
after = self.get_memory_info()
freed = before["torch_reserved"] - after["torch_reserved"]
if freed > 0:
logger.info(f"Freed {freed}MB of GPU memory")
return freed
def get_status(self) -> dict[str, Any]:
"""Get detailed GPU status for UI display.
Returns:
Dictionary with status information
"""
mem = self.get_memory_info()
budget = self.get_available_budget()
return {
"device": str(self.device),
"total_gb": round(mem["total"] / 1024, 2),
"used_gb": round(mem["used"] / 1024, 2),
"free_gb": round(mem["free"] / 1024, 2),
"utilization_percent": round(budget.utilization * 100, 1),
"available_for_models_gb": round(budget.available_mb / 1024, 2),
"comfyui_reserve_gb": round(budget.reserved_comfyui_mb / 1024, 2),
"torch_allocated_gb": round(mem["torch_allocated"] / 1024, 2),
"torch_cached_gb": round(mem["torch_cached"] / 1024, 2),
}
# ComfyUI Coordination Methods
def _read_coordination_state(self) -> dict[str, Any]:
"""Read coordination state from file."""
try:
if self.COORDINATION_FILE.exists():
return json.loads(self.COORDINATION_FILE.read_text())
except (json.JSONDecodeError, IOError) as e:
logger.warning(f"Failed to read coordination file: {e}")
return {}
def _write_coordination_state(self, state: dict[str, Any]) -> None:
"""Write coordination state to file with locking."""
import fcntl
try:
self.LOCK_FILE.parent.mkdir(parents=True, exist_ok=True)
with open(self.LOCK_FILE, "w") as lock:
fcntl.flock(lock, fcntl.LOCK_EX)
try:
self.COORDINATION_FILE.write_text(json.dumps(state, indent=2))
finally:
fcntl.flock(lock, fcntl.LOCK_UN)
except IOError as e:
logger.warning(f"Failed to write coordination file: {e}")
def update_status(
self,
vram_used_mb: int,
vram_requested_mb: int = 0,
status: str = "idle",
) -> None:
"""Update AudioCraft's status in coordination file.
Args:
vram_used_mb: Current VRAM usage
vram_requested_mb: VRAM needed for pending operation
status: Current status ("idle", "working", "requesting_priority")
"""
state = self._read_coordination_state()
state["audiocraft"] = {
"timestamp": time.time(),
"service": "audiocraft",
"vram_used_mb": vram_used_mb,
"vram_requested_mb": vram_requested_mb,
"status": status,
}
state["last_update"] = time.time()
self._write_coordination_state(state)
def get_comfyui_status(self) -> Optional[GPUState]:
"""Get ComfyUI's current status.
Returns:
GPUState if ComfyUI is active and status is fresh, None otherwise
"""
state = self._read_coordination_state()
comfyui_data = state.get("comfyui")
if not comfyui_data:
return None
# Check if stale
if time.time() - comfyui_data.get("timestamp", 0) > self.STALE_THRESHOLD:
return None
return GPUState(
timestamp=comfyui_data["timestamp"],
service="comfyui",
vram_used_mb=comfyui_data.get("vram_used_mb", 0),
vram_requested_mb=comfyui_data.get("vram_requested_mb", 0),
status=comfyui_data.get("status", "unknown"),
)
def request_priority(self, vram_needed_mb: int, timeout: float = 30.0) -> bool:
"""Request VRAM priority from ComfyUI.
Signals ComfyUI to release VRAM if possible.
Args:
vram_needed_mb: Amount of VRAM needed
timeout: Seconds to wait for ComfyUI to yield
Returns:
True if ComfyUI acknowledged and yielded, False otherwise
"""
state = self._read_coordination_state()
state["priority"] = {
"requester": "audiocraft",
"vram_needed_mb": vram_needed_mb,
"timestamp": time.time(),
}
self._write_coordination_state(state)
logger.info(f"Requesting {vram_needed_mb}MB VRAM from ComfyUI...")
# Wait for ComfyUI to respond
start = time.time()
while time.time() - start < timeout:
comfyui = self.get_comfyui_status()
if comfyui and comfyui.status == "yielded":
logger.info("ComfyUI yielded VRAM")
return True
time.sleep(0.5)
logger.warning("ComfyUI did not yield VRAM within timeout")
return False
def is_comfyui_busy(self) -> bool:
"""Check if ComfyUI is actively processing.
Returns:
True if ComfyUI is working, False otherwise
"""
status = self.get_comfyui_status()
return status is not None and status.status == "working"
# Callback Registration
def on_low_memory(self, callback: Callable[[VRAMBudget], None]) -> None:
"""Register callback for low memory warnings.
Args:
callback: Function to call with budget info when memory is low
"""
self._low_memory_callbacks.append(callback)
def on_oom(self, callback: Callable[[], None]) -> None:
"""Register callback for OOM events.
Args:
callback: Function to call when OOM occurs
"""
self._oom_callbacks.append(callback)
def check_memory_pressure(self, warning_threshold: float = 0.85) -> None:
"""Check memory pressure and trigger callbacks if needed.
Args:
warning_threshold: Utilization threshold for warnings (0-1)
"""
budget = self.get_available_budget()
if budget.utilization >= warning_threshold:
logger.warning(
f"High GPU memory pressure: {budget.utilization*100:.1f}% utilized"
)
for callback in self._low_memory_callbacks:
try:
callback(budget)
except Exception as e:
logger.error(f"Low memory callback failed: {e}")
def __del__(self) -> None:
"""Cleanup NVML on destruction."""
if self._nvml_initialized:
try:
import pynvml
pynvml.nvmlShutdown()
except Exception:
pass

487
src/core/model_registry.py Normal file

@@ -0,0 +1,487 @@
"""Model registry for discovering and managing AudioCraft model adapters."""
import asyncio
import logging
import threading
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Generator, Optional, Type
import yaml
from src.core.base_model import BaseAudioModel, ConditioningType
from src.core.gpu_manager import GPUMemoryManager
logger = logging.getLogger(__name__)
@dataclass
class ModelVariantConfig:
"""Configuration for a model variant."""
hf_id: str
vram_mb: int
max_duration: float = 30.0
channels: int = 1
conditioning: list[str] = field(default_factory=list)
description: str = ""
@dataclass
class ModelFamilyConfig:
"""Configuration for a model family."""
enabled: bool
display_name: str
description: str
default_variant: str
variants: dict[str, ModelVariantConfig]
@dataclass
class ModelHandle:
"""Handle for a loaded model with reference counting."""
model: BaseAudioModel
model_id: str
variant: str
loaded_at: float
last_accessed: float
ref_count: int = 0
def touch(self) -> None:
"""Update last accessed time."""
self.last_accessed = time.time()
class ModelRegistry:
"""Central registry for discovering and managing model adapters.
Handles:
- Loading model configurations from YAML
- Lazy loading models on demand
- LRU eviction when VRAM is constrained
- Reference counting to prevent unloading during use
- Automatic idle timeout for unused models
"""
def __init__(
self,
config_path: Path,
gpu_manager: GPUMemoryManager,
max_cached_models: int = 2,
idle_timeout_minutes: int = 15,
):
"""Initialize the model registry.
Args:
config_path: Path to models.yaml configuration
gpu_manager: GPU memory manager instance
max_cached_models: Maximum models to keep loaded
idle_timeout_minutes: Unload models after this idle time
"""
self.config_path = config_path
self.gpu_manager = gpu_manager
self.max_cached_models = max_cached_models
self.idle_timeout_seconds = idle_timeout_minutes * 60
# Model configurations
self._model_configs: dict[str, ModelFamilyConfig] = {}
self._default_params: dict[str, Any] = {}
# Loaded model handles
self._handles: dict[str, ModelHandle] = {} # Key: "model_id/variant"
self._access_order: list[str] = [] # LRU tracking
# Registered adapter classes
self._adapter_classes: dict[str, Type[BaseAudioModel]] = {}
# Threading
self._lock = threading.RLock()
self._cleanup_thread: Optional[threading.Thread] = None
self._stop_cleanup = threading.Event()
# Load configuration
self._load_config()
def _load_config(self) -> None:
"""Load model configurations from YAML file."""
if not self.config_path.exists():
logger.warning(f"Model config not found: {self.config_path}")
return
with open(self.config_path) as f:
config = yaml.safe_load(f)
# Parse model families
for model_id, model_config in config.get("models", {}).items():
if not model_config.get("enabled", True):
continue
variants = {}
for variant_name, variant_config in model_config.get("variants", {}).items():
variants[variant_name] = ModelVariantConfig(
hf_id=variant_config["hf_id"],
vram_mb=variant_config["vram_mb"],
max_duration=variant_config.get("max_duration", 30.0),
channels=variant_config.get("channels", 1),
conditioning=variant_config.get("conditioning", []),
description=variant_config.get("description", ""),
)
self._model_configs[model_id] = ModelFamilyConfig(
enabled=model_config.get("enabled", True),
display_name=model_config.get("display_name", model_id),
description=model_config.get("description", ""),
default_variant=model_config.get("default_variant", "medium"),
variants=variants,
)
# Parse default generation parameters
self._default_params = config.get("defaults", {}).get("generation", {})
logger.info(f"Loaded {len(self._model_configs)} model families from config")
def register_adapter(
self, model_id: str, adapter_class: Type[BaseAudioModel]
) -> None:
"""Register a model adapter class.
Args:
model_id: Model family ID (e.g., 'musicgen')
adapter_class: Adapter class implementing BaseAudioModel
"""
self._adapter_classes[model_id] = adapter_class
logger.debug(f"Registered adapter for {model_id}: {adapter_class.__name__}")
def list_models(self) -> list[dict[str, Any]]:
"""List all available models with their configurations.
Returns:
List of model information dictionaries
"""
models = []
for model_id, config in self._model_configs.items():
for variant_name, variant in config.variants.items():
key = f"{model_id}/{variant_name}"
handle = self._handles.get(key)
can_load, reason = self.gpu_manager.can_load_model(variant.vram_mb)
models.append({
"model_id": model_id,
"variant": variant_name,
"display_name": config.display_name,
"description": variant.description or config.description,
"hf_id": variant.hf_id,
"vram_mb": variant.vram_mb,
"max_duration": variant.max_duration,
"channels": variant.channels,
"conditioning": variant.conditioning,
"is_default": variant_name == config.default_variant,
"is_loaded": handle is not None,
"can_load": can_load,
"load_reason": reason,
"has_adapter": model_id in self._adapter_classes,
})
return models
def get_model_config(
self, model_id: str, variant: Optional[str] = None
) -> tuple[ModelFamilyConfig, ModelVariantConfig]:
"""Get configuration for a model.
Args:
model_id: Model family ID
variant: Specific variant, or None for default
Returns:
Tuple of (family_config, variant_config)
Raises:
ValueError: If model or variant not found
"""
if model_id not in self._model_configs:
raise ValueError(f"Unknown model: {model_id}")
family = self._model_configs[model_id]
variant = variant or family.default_variant
if variant not in family.variants:
raise ValueError(f"Unknown variant {variant} for {model_id}")
return family, family.variants[variant]
def get_loaded_models(self) -> list[dict[str, Any]]:
"""Get information about currently loaded models.
Returns:
List of loaded model information
"""
with self._lock:
return [
{
"model_id": handle.model_id,
"variant": handle.variant,
"loaded_at": handle.loaded_at,
"last_accessed": handle.last_accessed,
"ref_count": handle.ref_count,
"idle_seconds": time.time() - handle.last_accessed,
}
for handle in self._handles.values()
]
@contextmanager
def get_model(
self, model_id: str, variant: Optional[str] = None
) -> Generator[BaseAudioModel, None, None]:
"""Get a model, loading it if necessary.
Context manager that handles reference counting to prevent
unloading during use.
Args:
model_id: Model family ID
variant: Specific variant, or None for default
Yields:
Loaded model instance
Raises:
ValueError: If model not found or cannot be loaded
RuntimeError: If VRAM insufficient
"""
family, variant_config = self.get_model_config(model_id, variant)
variant = variant or family.default_variant
key = f"{model_id}/{variant}"
with self._lock:
# Get or load model
if key not in self._handles:
self._load_model(model_id, variant)
handle = self._handles[key]
handle.ref_count += 1
handle.touch()
# Update LRU order
if key in self._access_order:
self._access_order.remove(key)
self._access_order.append(key)
try:
yield handle.model
finally:
with self._lock:
handle.ref_count -= 1
def _load_model(self, model_id: str, variant: str) -> None:
"""Load a model into memory.
Must be called with self._lock held.
Args:
model_id: Model family ID
variant: Variant to load
Raises:
ValueError: If no adapter registered
RuntimeError: If VRAM insufficient
"""
key = f"{model_id}/{variant}"
family, variant_config = self.get_model_config(model_id, variant)
# Check for adapter
if model_id not in self._adapter_classes:
raise ValueError(f"No adapter registered for {model_id}")
# Check VRAM
can_load, reason = self.gpu_manager.can_load_model(variant_config.vram_mb)
if not can_load:
# Try to free memory by evicting models
self._evict_for_space(variant_config.vram_mb)
can_load, reason = self.gpu_manager.can_load_model(variant_config.vram_mb)
if not can_load:
raise RuntimeError(reason)
# Create and load model
logger.info(f"Loading model {key}...")
adapter_class = self._adapter_classes[model_id]
model = adapter_class(variant=variant)
model.load()
# Register handle
self._handles[key] = ModelHandle(
model=model,
model_id=model_id,
variant=variant,
loaded_at=time.time(),
last_accessed=time.time(),
)
self._access_order.append(key)
# Update GPU status
mem = self.gpu_manager.get_memory_info()
self.gpu_manager.update_status(mem["torch_allocated"], status="working")
logger.info(f"Model {key} loaded successfully")
def _evict_for_space(self, needed_mb: int) -> bool:
"""Evict models to free up VRAM.
Must be called with self._lock held.
Args:
needed_mb: VRAM needed
Returns:
True if enough space was freed
"""
freed = 0
budget = self.gpu_manager.get_available_budget()
deficit = needed_mb - budget.available_mb
if deficit <= 0:
return True
# Evict LRU models that have no active references
for key in list(self._access_order):
if deficit <= 0:
break
handle = self._handles.get(key)
if handle and handle.ref_count == 0:
_, variant_config = self.get_model_config(
handle.model_id, handle.variant
)
logger.info(f"Evicting {key} to free {variant_config.vram_mb}MB")
self._unload_model(key)
freed += variant_config.vram_mb
deficit -= variant_config.vram_mb
self.gpu_manager.force_cleanup()
return deficit <= 0
def _unload_model(self, key: str) -> None:
"""Unload a model from memory.
Must be called with self._lock held.
Args:
key: Model key (model_id/variant)
"""
if key not in self._handles:
return
handle = self._handles[key]
if handle.ref_count > 0:
logger.warning(f"Cannot unload {key}: {handle.ref_count} active references")
return
logger.info(f"Unloading model {key}")
handle.model.unload()
del self._handles[key]
if key in self._access_order:
self._access_order.remove(key)
self.gpu_manager.force_cleanup()
def unload_model(self, model_id: str, variant: Optional[str] = None) -> bool:
"""Manually unload a model.
Args:
model_id: Model family ID
variant: Variant to unload, or None for all variants
Returns:
True if model was unloaded
"""
with self._lock:
if variant:
key = f"{model_id}/{variant}"
if key in self._handles:
self._unload_model(key)
return True
else:
# Unload all variants of this model
keys = [k for k in self._handles if k.startswith(f"{model_id}/")]
for key in keys:
self._unload_model(key)
return bool(keys)
return False
def preload_model(self, model_id: str, variant: Optional[str] = None) -> bool:
"""Preload a model into memory.
Args:
model_id: Model family ID
variant: Variant to load
Returns:
True if model was loaded successfully
"""
family, _ = self.get_model_config(model_id, variant)
variant = variant or family.default_variant
key = f"{model_id}/{variant}"
with self._lock:
if key in self._handles:
return True # Already loaded
try:
self._load_model(model_id, variant)
return True
except Exception as e:
logger.error(f"Failed to preload {key}: {e}")
return False
def start_cleanup_thread(self) -> None:
"""Start background thread for idle model cleanup."""
if self._cleanup_thread is not None:
return
def cleanup_loop():
while not self._stop_cleanup.is_set():
self._cleanup_idle_models()
self._stop_cleanup.wait(60) # Check every minute
self._cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True)
self._cleanup_thread.start()
logger.info("Started model cleanup thread")
def stop_cleanup_thread(self) -> None:
"""Stop the background cleanup thread."""
if self._cleanup_thread is not None:
self._stop_cleanup.set()
self._cleanup_thread.join(timeout=5)
self._cleanup_thread = None
self._stop_cleanup.clear()
def _cleanup_idle_models(self) -> None:
"""Unload models that have been idle too long."""
with self._lock:
now = time.time()
for key, handle in list(self._handles.items()):
idle_time = now - handle.last_accessed
if idle_time > self.idle_timeout_seconds and handle.ref_count == 0:
logger.info(
f"Unloading idle model {key} (idle for {idle_time/60:.1f} min)"
)
self._unload_model(key)
def get_default_params(self) -> dict[str, Any]:
"""Get default generation parameters.
Returns:
Dictionary of default parameter values
"""
return self._default_params.copy()
def __del__(self) -> None:
"""Cleanup on destruction."""
self.stop_cleanup_thread()
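The eviction path above combines two simple pieces of bookkeeping: LRU order maintenance on every access (move the key to the back of `_access_order`), and a front-to-back walk that evicts only handles with no live references until the VRAM deficit is covered. A standalone sketch with toy keys and sizes (not the real variant configs):

```python
def touch(access_order: list[str], key: str) -> None:
    # Most recently used keys live at the back of the list.
    if key in access_order:
        access_order.remove(key)
    access_order.append(key)


def evict_for_space(access_order, ref_counts, sizes_mb, deficit_mb):
    # Walk from least- to most-recently used, skipping models in use.
    evicted = []
    for key in list(access_order):
        if deficit_mb <= 0:
            break
        if ref_counts.get(key, 0) == 0:
            access_order.remove(key)
            evicted.append(key)
            deficit_mb -= sizes_mb[key]
    return evicted, deficit_mb


order = ["musicgen/small", "audiogen/medium", "musicgen/medium"]
touch(order, "musicgen/small")  # reuse moves it to the back
refs = {"musicgen/medium": 1}   # active reference -> not evictable
sizes = {"musicgen/small": 2500, "audiogen/medium": 5000, "musicgen/medium": 5500}
evicted, remaining = evict_for_space(order, refs, sizes, deficit_mb=6000)
print(evicted)    # ['audiogen/medium', 'musicgen/small']
print(remaining)  # -1500, i.e. enough space was freed
```

The in-use `musicgen/medium` survives even though it is larger than either evicted model, which is exactly what the `ref_count` guard is for.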
src/core/oom_handler.py (new file, 297 lines)
@@ -0,0 +1,297 @@
"""OOM (Out of Memory) handling and recovery strategies."""
import functools
import gc
import logging
import time
from typing import Any, Callable, Optional, ParamSpec, TypeVar
import torch
from src.core.gpu_manager import GPUMemoryManager
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
class OOMRecoveryError(Exception):
"""Raised when OOM recovery fails after all strategies exhausted."""
pass
class OOMHandler:
"""Handles CUDA Out of Memory errors with multi-level recovery strategies.
Recovery levels:
1. Clear PyTorch CUDA cache
2. Evict unused models from registry
3. Request ComfyUI to yield VRAM
"""
def __init__(
self,
gpu_manager: GPUMemoryManager,
model_registry: Optional[Any] = None, # Avoid circular import
max_retries: int = 3,
retry_delay: float = 0.5,
):
"""Initialize OOM handler.
Args:
gpu_manager: GPU memory manager instance
model_registry: Optional model registry for eviction
max_retries: Maximum recovery attempts
retry_delay: Delay between retries in seconds
"""
self.gpu_manager = gpu_manager
self.model_registry = model_registry
self.max_retries = max_retries
self.retry_delay = retry_delay
# Track OOM events for monitoring
self._oom_count = 0
self._last_oom_time: Optional[float] = None
@property
def oom_count(self) -> int:
"""Number of OOM events handled."""
return self._oom_count
def set_model_registry(self, registry: Any) -> None:
"""Set model registry (to avoid circular import at init time)."""
self.model_registry = registry
def with_oom_recovery(self, func: Callable[P, R]) -> Callable[P, R]:
"""Decorator that wraps function with OOM recovery logic.
Usage:
@oom_handler.with_oom_recovery
def generate_audio(...):
...
Args:
func: Function to wrap
Returns:
Wrapped function with OOM recovery
"""
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
last_exception = None
for attempt in range(self.max_retries + 1):
try:
if attempt > 0:
logger.info(f"Retry attempt {attempt}/{self.max_retries}")
time.sleep(self.retry_delay)
return func(*args, **kwargs)
except torch.cuda.OutOfMemoryError as e:
last_exception = e
self._oom_count += 1
self._last_oom_time = time.time()
logger.warning(f"CUDA OOM detected (attempt {attempt + 1}): {e}")
if attempt < self.max_retries:
self._execute_recovery_strategy(attempt)
else:
logger.error(
f"OOM recovery failed after {self.max_retries} attempts"
)
raise OOMRecoveryError(
f"OOM recovery failed after {self.max_retries} attempts"
) from last_exception
return wrapper
def _execute_recovery_strategy(self, level: int) -> None:
"""Execute recovery strategy based on severity level.
Args:
level: Recovery level (0-2)
"""
strategies = [
self._strategy_clear_cache,
self._strategy_evict_models,
self._strategy_request_comfyui_yield,
]
# Execute all strategies up to and including current level
for i in range(min(level + 1, len(strategies))):
logger.info(f"Executing recovery strategy {i + 1}: {strategies[i].__name__}")
strategies[i]()
def _strategy_clear_cache(self) -> None:
"""Level 1: Clear PyTorch CUDA cache.
This is the fastest and least disruptive recovery strategy.
Clears cached memory that PyTorch holds for future allocations.
"""
logger.info("Clearing CUDA cache...")
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Reset peak memory stats for monitoring
torch.cuda.reset_peak_memory_stats()
freed = self.gpu_manager.force_cleanup()
logger.info(f"Cache cleared, freed approximately {freed}MB")
def _strategy_evict_models(self) -> None:
"""Level 2: Evict non-essential models from registry.
Unloads all models that don't have active references,
freeing their VRAM for the current operation.
"""
if self.model_registry is None:
logger.warning("No model registry available for eviction")
self._strategy_clear_cache()
return
logger.info("Evicting unused models...")
# Get list of loaded models
loaded = self.model_registry.get_loaded_models()
evicted = []
for model_info in loaded:
# Only evict models with no active references
if model_info["ref_count"] == 0:
model_id = model_info["model_id"]
variant = model_info["variant"]
logger.info(f"Evicting {model_id}/{variant}")
self.model_registry.unload_model(model_id, variant)
evicted.append(f"{model_id}/{variant}")
# Clear cache after eviction
self._strategy_clear_cache()
logger.info(f"Evicted {len(evicted)} model(s): {evicted}")
def _strategy_request_comfyui_yield(self) -> None:
"""Level 3: Request ComfyUI to yield VRAM.
Uses the coordination protocol to ask ComfyUI to
temporarily release GPU memory.
"""
logger.info("Requesting ComfyUI to yield VRAM...")
# First, evict our own models
self._strategy_evict_models()
# Calculate how much VRAM we need
budget = self.gpu_manager.get_available_budget()
needed = max(4096, budget.total_mb // 4) # Request at least 4GB or 25% of total
# Request priority from ComfyUI
success = self.gpu_manager.request_priority(needed, timeout=15.0)
if success:
logger.info("ComfyUI yielded VRAM successfully")
else:
logger.warning("ComfyUI did not yield VRAM within timeout")
# Final cache clear
self._strategy_clear_cache()
def recover_from_oom(self, level: int = 0) -> bool:
"""Manually trigger OOM recovery.
Args:
level: Recovery level to execute (0-2)
Returns:
True if recovery was successful (memory was freed)
"""
before = self.gpu_manager.get_memory_info()
self._execute_recovery_strategy(level)
after = self.gpu_manager.get_memory_info()
freed = before["used"] - after["used"]
logger.info(f"Manual recovery freed {freed}MB")
return freed > 0
def check_memory_for_operation(self, required_mb: int) -> bool:
"""Check if there's enough memory for an operation.
If not enough, attempts recovery strategies.
Args:
required_mb: Memory required in megabytes
Returns:
True if enough memory is available (possibly after recovery)
"""
budget = self.gpu_manager.get_available_budget()
if budget.available_mb >= required_mb:
return True
logger.info(
f"Need {required_mb}MB but only {budget.available_mb}MB available. "
"Attempting recovery..."
)
# Try progressively more aggressive recovery
for level in range(3):
self._execute_recovery_strategy(level)
budget = self.gpu_manager.get_available_budget()
if budget.available_mb >= required_mb:
logger.info(f"Recovery successful at level {level + 1}")
return True
logger.error(
f"Could not free enough memory. Need {required_mb}MB, "
f"have {budget.available_mb}MB"
)
return False
def get_stats(self) -> dict[str, Any]:
"""Get OOM handling statistics.
Returns:
Dictionary with OOM stats
"""
return {
"oom_count": self._oom_count,
"last_oom_time": self._last_oom_time,
"max_retries": self.max_retries,
"has_registry": self.model_registry is not None,
}
# Module-level convenience function
def oom_safe(
gpu_manager: GPUMemoryManager,
model_registry: Optional[Any] = None,
max_retries: int = 3,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""Decorator factory for OOM-safe functions.
Usage:
@oom_safe(gpu_manager, model_registry)
def generate_audio(...):
...
Args:
gpu_manager: GPU memory manager
model_registry: Optional model registry for eviction
max_retries: Maximum recovery attempts
Returns:
Decorator function
"""
handler = OOMHandler(gpu_manager, model_registry, max_retries)
return handler.with_oom_recovery
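Across retries, `_execute_recovery_strategy` is cumulative: attempt N re-runs every strategy up to level N, so each retry is strictly more aggressive than the last. A minimal standalone illustration of that escalation (stub strategies that just record their names):

```python
def execute_recovery(level: int, strategies, log: list[str]) -> None:
    # Mirrors _execute_recovery_strategy: run all strategies up to `level`.
    for i in range(min(level + 1, len(strategies))):
        strategies[i](log)


def clear_cache(log): log.append("clear_cache")
def evict_models(log): log.append("evict_models")
def comfyui_yield(log): log.append("comfyui_yield")


STRATEGIES = [clear_cache, evict_models, comfyui_yield]

log: list[str] = []
for attempt in range(3):  # what with_oom_recovery does across failed attempts
    execute_recovery(attempt, STRATEGIES, log)
print(log)
# ['clear_cache',
#  'clear_cache', 'evict_models',
#  'clear_cache', 'evict_models', 'comfyui_yield']
```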
src/main.py (new file, 84 lines)
@@ -0,0 +1,84 @@
"""AudioCraft Studio - Main Application Entry Point."""
import asyncio
import logging
import sys
from pathlib import Path
# Add project root to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
from config.settings import get_settings
from src.core.gpu_manager import GPUMemoryManager
from src.core.model_registry import ModelRegistry
from src.storage.database import Database
logger = logging.getLogger(__name__)
async def init_app():
"""Initialize application components."""
settings = get_settings()
# Configure logging
logging.basicConfig(
level=getattr(logging, settings.log_level),
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
# Ensure directories exist
settings.ensure_directories()
# Initialize GPU manager
gpu_manager = GPUMemoryManager(
comfyui_reserve_gb=settings.comfyui_reserve_gb,
safety_buffer_gb=settings.safety_buffer_gb,
)
# Initialize model registry
registry = ModelRegistry(
config_path=settings.models_config,
gpu_manager=gpu_manager,
max_cached_models=settings.max_cached_models,
idle_timeout_minutes=settings.idle_unload_minutes,
)
# Initialize database
db = Database(settings.database_path)
await db.connect()
logger.info("AudioCraft Studio initialized")
logger.info(f"GPU Status: {gpu_manager.get_status()}")
logger.info(f"Available models: {len(registry.list_models())}")
return {
"settings": settings,
"gpu_manager": gpu_manager,
"registry": registry,
"database": db,
}
def main():
"""Main entry point."""
print("AudioCraft Studio - Starting...")
print("Phase 1 core infrastructure is complete.")
print("\nTo continue implementation:")
print(" - Phase 2: Model adapters (musicgen, audiogen, magnet, style, jasco)")
print(" - Phase 3: Services layer (generation, batch, project)")
print(" - Phase 4: Gradio UI")
print(" - Phase 5: REST API")
print(" - Phase 6: Deployment")
# Quick initialization test
async def test_init():
components = await init_app()
print(f"\nDatabase path: {components['settings'].database_path}")
print(f"GPU status: {components['gpu_manager'].get_status()}")
await components["database"].close()
asyncio.run(test_init())
if __name__ == "__main__":
main()
src/models/__init__.py (new file, 32 lines)
@@ -0,0 +1,32 @@
"""AudioCraft model adapters.
This module contains adapters that wrap AudioCraft's models with a
consistent interface for the application.
"""
from src.models.musicgen.adapter import MusicGenAdapter
from src.models.audiogen.adapter import AudioGenAdapter
from src.models.magnet.adapter import MAGNeTAdapter
from src.models.musicgen_style.adapter import MusicGenStyleAdapter
from src.models.jasco.adapter import JASCOAdapter
__all__ = [
"MusicGenAdapter",
"AudioGenAdapter",
"MAGNeTAdapter",
"MusicGenStyleAdapter",
"JASCOAdapter",
]
def register_all_adapters(registry) -> None:
"""Register all model adapters with the registry.
Args:
registry: ModelRegistry instance to register adapters with
"""
registry.register_adapter("musicgen", MusicGenAdapter)
registry.register_adapter("audiogen", AudioGenAdapter)
registry.register_adapter("magnet", MAGNeTAdapter)
registry.register_adapter("musicgen-style", MusicGenStyleAdapter)
registry.register_adapter("jasco", JASCOAdapter)
@@ -0,0 +1,5 @@
"""AudioGen model adapter."""
from src.models.audiogen.adapter import AudioGenAdapter
__all__ = ["AudioGenAdapter"]
@@ -0,0 +1,203 @@
"""AudioGen model adapter for text-to-sound effects generation."""
import gc
import logging
import random
from typing import Any, Optional
import torch
from src.core.base_model import (
BaseAudioModel,
ConditioningType,
GenerationRequest,
GenerationResult,
)
logger = logging.getLogger(__name__)
class AudioGenAdapter(BaseAudioModel):
"""Adapter for Facebook's AudioGen model.
Generates sound effects and environmental audio from text descriptions.
Optimized for non-musical audio like sound effects, ambiences, and foley.
"""
VARIANTS = {
"medium": {
"hf_id": "facebook/audiogen-medium",
"vram_mb": 5000,
"max_duration": 10,
"channels": 1,
},
}
def __init__(self, variant: str = "medium"):
"""Initialize AudioGen adapter.
Args:
variant: Model variant (currently only 'medium' available)
"""
if variant not in self.VARIANTS:
raise ValueError(
f"Unknown AudioGen variant: {variant}. "
f"Available: {list(self.VARIANTS.keys())}"
)
self._variant = variant
self._config = self.VARIANTS[variant]
self._model = None
self._device: Optional[torch.device] = None
@property
def model_id(self) -> str:
return "audiogen"
@property
def variant(self) -> str:
return self._variant
@property
def display_name(self) -> str:
return f"AudioGen ({self._variant})"
@property
def description(self) -> str:
return "Text-to-sound effects generation"
@property
def vram_estimate_mb(self) -> int:
return self._config["vram_mb"]
@property
def max_duration(self) -> float:
return self._config["max_duration"]
@property
def sample_rate(self) -> int:
if self._model is not None:
return self._model.sample_rate
return 16000 # AudioGen default sample rate
@property
def supports_conditioning(self) -> list[ConditioningType]:
return [ConditioningType.TEXT]
@property
def is_loaded(self) -> bool:
return self._model is not None
@property
def device(self) -> Optional[torch.device]:
return self._device
def load(self, device: str = "cuda") -> None:
"""Load the AudioGen model."""
if self._model is not None:
logger.warning(f"AudioGen {self._variant} already loaded")
return
logger.info(f"Loading AudioGen {self._variant} from {self._config['hf_id']}...")
try:
from audiocraft.models import AudioGen
self._device = torch.device(device)
self._model = AudioGen.get_pretrained(self._config["hf_id"])
self._model.to(self._device)
logger.info(
f"AudioGen {self._variant} loaded successfully "
f"(sample_rate={self._model.sample_rate})"
)
except Exception as e:
self._model = None
self._device = None
logger.error(f"Failed to load AudioGen {self._variant}: {e}")
raise RuntimeError(f"Failed to load AudioGen: {e}") from e
def unload(self) -> None:
"""Unload the model and free memory."""
if self._model is None:
return
logger.info(f"Unloading AudioGen {self._variant}...")
del self._model
self._model = None
self._device = None
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def generate(self, request: GenerationRequest) -> GenerationResult:
"""Generate sound effects from text prompts.
Args:
request: Generation parameters including prompts
Returns:
GenerationResult with audio tensor and metadata
"""
self.validate_request(request)
# Set random seed
seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
# Configure generation
self._model.set_generation_params(
duration=request.duration,
temperature=request.temperature,
top_k=request.top_k,
top_p=request.top_p,
cfg_coef=request.cfg_coef,
)
logger.info(
f"Generating {len(request.prompts)} sound effect(s) with AudioGen "
f"(duration={request.duration}s)"
)
# Generate audio
with torch.inference_mode():
audio = self._model.generate(request.prompts)
actual_duration = audio.shape[-1] / self.sample_rate
logger.info(
f"Generated {audio.shape[0]} sample(s), "
f"duration={actual_duration:.2f}s"
)
return GenerationResult(
audio=audio.cpu(),
sample_rate=self.sample_rate,
duration=actual_duration,
model_id=self.model_id,
variant=self._variant,
parameters={
"duration": request.duration,
"temperature": request.temperature,
"top_k": request.top_k,
"top_p": request.top_p,
"cfg_coef": request.cfg_coef,
"prompts": request.prompts,
},
seed=seed,
)
def get_default_params(self) -> dict[str, Any]:
"""Get default generation parameters."""
return {
"duration": 5.0,
"temperature": 1.0,
"top_k": 250,
"top_p": 0.0,
"cfg_coef": 3.0,
}
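Note that the adapter reports the result's `duration` from the generated tensor rather than echoing the requested value: `duration_s = num_samples / sample_rate`. A quick check of that arithmetic at AudioGen's 16 kHz default rate:

```python
SAMPLE_RATE = 16000  # AudioGen's default sample rate, per the adapter above


def actual_duration(num_samples: int, sample_rate: int = SAMPLE_RATE) -> float:
    # What the adapter computes from audio.shape[-1] after generation.
    return num_samples / sample_rate


print(actual_duration(80000))   # 5.0  -> a 5 s clip at 16 kHz
print(actual_duration(160000))  # 10.0 -> the medium variant's max_duration
```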
@@ -0,0 +1,5 @@
"""JASCO model adapter."""
from src.models.jasco.adapter import JASCOAdapter
__all__ = ["JASCOAdapter"]
src/models/jasco/adapter.py (new file, 348 lines)
@@ -0,0 +1,348 @@
"""JASCO model adapter for chord and drum-conditioned music generation."""
import gc
import logging
import random
from typing import Any, Optional
import torch
from src.core.base_model import (
BaseAudioModel,
ConditioningType,
GenerationRequest,
GenerationResult,
)
logger = logging.getLogger(__name__)
class JASCOAdapter(BaseAudioModel):
"""Adapter for Facebook's JASCO model.
JASCO (Joint Audio and Symbolic Conditioning) enables music generation
with control over chord progressions and drum patterns alongside text.
"""
VARIANTS = {
"chords-drums-400M": {
"hf_id": "facebook/jasco-chords-drums-400M",
"vram_mb": 2000,
"max_duration": 10,
"channels": 1,
},
"chords-drums-1B": {
"hf_id": "facebook/jasco-chords-drums-1B",
"vram_mb": 4000,
"max_duration": 10,
"channels": 1,
},
}
# Common chord types for validation
VALID_CHORD_TYPES = [
"maj", "min", "dim", "aug", "7", "maj7", "min7", "dim7",
"sus2", "sus4", "add9", "6", "min6", "9", "min9", "maj9",
]
def __init__(self, variant: str = "chords-drums-400M"):
"""Initialize JASCO adapter.
Args:
variant: Model variant to use
"""
if variant not in self.VARIANTS:
raise ValueError(
f"Unknown JASCO variant: {variant}. "
f"Available: {list(self.VARIANTS.keys())}"
)
self._variant = variant
self._config = self.VARIANTS[variant]
self._model = None
self._device: Optional[torch.device] = None
@property
def model_id(self) -> str:
return "jasco"
@property
def variant(self) -> str:
return self._variant
@property
def display_name(self) -> str:
return f"JASCO ({self._variant})"
@property
def description(self) -> str:
return "Chord and drum-conditioned music generation"
@property
def vram_estimate_mb(self) -> int:
return self._config["vram_mb"]
@property
def max_duration(self) -> float:
return self._config["max_duration"]
@property
def sample_rate(self) -> int:
if self._model is not None:
return self._model.sample_rate
return 32000
@property
def supports_conditioning(self) -> list[ConditioningType]:
return [ConditioningType.TEXT, ConditioningType.CHORDS, ConditioningType.DRUMS]
@property
def is_loaded(self) -> bool:
return self._model is not None
@property
def device(self) -> Optional[torch.device]:
return self._device
def load(self, device: str = "cuda") -> None:
"""Load the JASCO model."""
if self._model is not None:
logger.warning(f"JASCO {self._variant} already loaded")
return
logger.info(f"Loading JASCO {self._variant} from {self._config['hf_id']}...")
try:
from audiocraft.models import JASCO
self._device = torch.device(device)
self._model = JASCO.get_pretrained(self._config["hf_id"])
self._model.to(self._device)
logger.info(
f"JASCO {self._variant} loaded successfully "
f"(sample_rate={self._model.sample_rate})"
)
except Exception as e:
self._model = None
self._device = None
logger.error(f"Failed to load JASCO {self._variant}: {e}")
raise RuntimeError(f"Failed to load JASCO: {e}") from e
def unload(self) -> None:
"""Unload the model and free memory."""
if self._model is None:
return
logger.info(f"Unloading JASCO {self._variant}...")
del self._model
self._model = None
self._device = None
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
@staticmethod
def parse_chord_progression(
chords: list[dict[str, Any]], duration: float
) -> list[tuple[float, float, str]]:
"""Parse chord progression from user input format.
Args:
chords: List of chord dictionaries with keys:
- time: Start time in seconds
- chord: Chord name (e.g., "C", "Am", "G7")
duration: Total duration for calculating end times
Returns:
List of (start_time, end_time, chord_name) tuples
Example input:
[
{"time": 0.0, "chord": "C"},
{"time": 2.0, "chord": "Am"},
{"time": 4.0, "chord": "F"},
{"time": 6.0, "chord": "G"},
]
"""
if not chords:
return []
# Sort by time
sorted_chords = sorted(chords, key=lambda x: x["time"])
# Build (start, end, chord) tuples
result = []
for i, chord_info in enumerate(sorted_chords):
start = chord_info["time"]
# End time is either next chord's start or total duration
if i + 1 < len(sorted_chords):
end = sorted_chords[i + 1]["time"]
else:
end = duration
result.append((start, end, chord_info["chord"]))
return result
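The method above is pure Python, so its docstring example can be exercised directly: each chord runs until the next chord's start time, and the last chord runs to the total duration. A standalone copy of the logic (stripped of the class context) run on that example:

```python
def parse_chord_progression(chords, duration):
    # Standalone copy of JASCOAdapter.parse_chord_progression.
    if not chords:
        return []
    sorted_chords = sorted(chords, key=lambda x: x["time"])
    result = []
    for i, info in enumerate(sorted_chords):
        start = info["time"]
        # End time is the next chord's start, or the total duration.
        end = sorted_chords[i + 1]["time"] if i + 1 < len(sorted_chords) else duration
        result.append((start, end, info["chord"]))
    return result


progression = [
    {"time": 0.0, "chord": "C"},
    {"time": 2.0, "chord": "Am"},
    {"time": 4.0, "chord": "F"},
    {"time": 6.0, "chord": "G"},
]
print(parse_chord_progression(progression, duration=8.0))
# [(0.0, 2.0, 'C'), (2.0, 4.0, 'Am'), (4.0, 6.0, 'F'), (6.0, 8.0, 'G')]
```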
@staticmethod
def create_drum_pattern(
pattern: str, duration: float, bpm: float = 120.0
) -> list[tuple[float, str]]:
"""Create drum events from a pattern string.
Args:
pattern: Pattern string (e.g., "kick,snare,kick,snare")
or "4/4" for common time signature
duration: Total duration in seconds
bpm: Beats per minute
Returns:
List of (time, drum_type) tuples
"""
beat_duration = 60.0 / bpm
events = []
if pattern in ["4/4", "common"]:
# Standard 4/4 rock pattern
time = 0.0
beat = 0
while time < duration:
if beat % 4 == 0:
events.append((time, "kick"))
elif beat % 4 == 2:
events.append((time, "snare"))
if beat % 2 == 0:
events.append((time, "hihat"))
time += beat_duration / 2
beat += 1
else:
# Parse comma-separated pattern
drum_types = pattern.split(",")
time = 0.0
idx = 0
while time < duration:
drum = drum_types[idx % len(drum_types)].strip()
if drum:
events.append((time, drum))
time += beat_duration
idx += 1
return events
def generate(self, request: GenerationRequest) -> GenerationResult:
"""Generate music with chord and drum conditioning.
Args:
request: Generation parameters with optional conditioning:
- chords: List of {"time": float, "chord": str} dicts
- drums: Drum pattern string or list of (time, drum_type)
- bpm: Beats per minute for drum pattern
Returns:
GenerationResult with audio tensor and metadata
"""
self.validate_request(request)
# Set random seed
seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
# Configure generation parameters
self._model.set_generation_params(
duration=request.duration,
temperature=request.temperature,
top_k=request.top_k,
top_p=request.top_p,
cfg_coef=request.cfg_coef,
)
# Process chord conditioning
chords_input = request.conditioning.get("chords")
chords_formatted = None
if chords_input:
if isinstance(chords_input, list) and len(chords_input) > 0:
if isinstance(chords_input[0], dict):
chords_formatted = self.parse_chord_progression(
chords_input, request.duration
)
else:
# Already in (start, end, chord) format
chords_formatted = chords_input
# Process drum conditioning
drums_input = request.conditioning.get("drums")
bpm = request.conditioning.get("bpm", 120.0)
drums_formatted = None
if drums_input:
if isinstance(drums_input, str):
drums_formatted = self.create_drum_pattern(
drums_input, request.duration, bpm
)
else:
drums_formatted = drums_input
logger.info(
f"Generating {len(request.prompts)} sample(s) with JASCO "
f"(duration={request.duration}s, chords={chords_formatted is not None}, "
f"drums={drums_formatted is not None})"
)
with torch.inference_mode():
# Build conditioning dict for JASCO
conditioning = {}
if chords_formatted:
conditioning["chords"] = chords_formatted
if drums_formatted:
conditioning["drums"] = drums_formatted
if conditioning:
audio = self._model.generate(
descriptions=request.prompts,
**conditioning,
)
else:
# Generate without symbolic conditioning
audio = self._model.generate(request.prompts)
actual_duration = audio.shape[-1] / self.sample_rate
logger.info(
f"Generated {audio.shape[0]} sample(s), "
f"duration={actual_duration:.2f}s"
)
return GenerationResult(
audio=audio.cpu(),
sample_rate=self.sample_rate,
duration=actual_duration,
model_id=self.model_id,
variant=self._variant,
parameters={
"duration": request.duration,
"temperature": request.temperature,
"top_k": request.top_k,
"top_p": request.top_p,
"cfg_coef": request.cfg_coef,
"prompts": request.prompts,
"chords": chords_formatted,
"drums": drums_formatted,
"bpm": bpm,
},
seed=seed,
)
def get_default_params(self) -> dict[str, Any]:
"""Get default generation parameters for JASCO."""
return {
"duration": 10.0,
"temperature": 1.0,
"top_k": 250,
"top_p": 0.0,
"cfg_coef": 3.0,
"bpm": 120.0,
}
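The two static helpers above are pure functions, so their behavior can be checked without loading any model. Below is a standalone sketch of the chord-progression logic, re-implemented here so it runs without the adapter class or audiocraft installed:

```python
def parse_chord_progression(chords, duration):
    """Turn [{'time': t, 'chord': name}, ...] into (start, end, name) spans."""
    if not chords:
        return []
    ordered = sorted(chords, key=lambda c: c["time"])
    spans = []
    for i, info in enumerate(ordered):
        start = info["time"]
        # Each chord holds until the next onset; the last runs to `duration`.
        end = ordered[i + 1]["time"] if i + 1 < len(ordered) else duration
        spans.append((start, end, info["chord"]))
    return spans

progression = [
    {"time": 0.0, "chord": "C"},
    {"time": 2.0, "chord": "Am"},
    {"time": 4.0, "chord": "F"},
    {"time": 6.0, "chord": "G"},
]
print(parse_chord_progression(progression, duration=8.0))
# → [(0.0, 2.0, 'C'), (2.0, 4.0, 'Am'), (4.0, 6.0, 'F'), (6.0, 8.0, 'G')]
```

The open-interval convention (each chord holds until the next onset) is the format the adapter forwards as the `chords` conditioning value.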


@@ -0,0 +1,5 @@
"""MAGNeT model adapter."""
from src.models.magnet.adapter import MAGNeTAdapter
__all__ = ["MAGNeTAdapter"]


@@ -0,0 +1,253 @@
"""MAGNeT model adapter for fast non-autoregressive audio generation."""
import gc
import logging
import random
from typing import Any, Optional
import torch
from src.core.base_model import (
BaseAudioModel,
ConditioningType,
GenerationRequest,
GenerationResult,
)
logger = logging.getLogger(__name__)
class MAGNeTAdapter(BaseAudioModel):
"""Adapter for Facebook's MAGNeT model.
MAGNeT (Masked Audio Generation using Non-autoregressive Transformers)
provides faster generation than autoregressive models like MusicGen.
Supports both music and sound effect generation.
"""
VARIANTS = {
"small-10secs": {
"hf_id": "facebook/magnet-small-10secs",
"vram_mb": 1500,
"max_duration": 10,
"channels": 1,
"audio_type": "music",
},
"medium-10secs": {
"hf_id": "facebook/magnet-medium-10secs",
"vram_mb": 5000,
"max_duration": 10,
"channels": 1,
"audio_type": "music",
},
"small-30secs": {
"hf_id": "facebook/magnet-small-30secs",
"vram_mb": 1800,
"max_duration": 30,
"channels": 1,
"audio_type": "music",
},
"medium-30secs": {
"hf_id": "facebook/magnet-medium-30secs",
"vram_mb": 6000,
"max_duration": 30,
"channels": 1,
"audio_type": "music",
},
"audio-small-10secs": {
"hf_id": "facebook/audio-magnet-small",
"vram_mb": 1500,
"max_duration": 10,
"channels": 1,
"audio_type": "sound",
},
"audio-medium-10secs": {
"hf_id": "facebook/audio-magnet-medium",
"vram_mb": 5000,
"max_duration": 10,
"channels": 1,
"audio_type": "sound",
},
}
def __init__(self, variant: str = "medium-10secs"):
"""Initialize MAGNeT adapter.
Args:
variant: Model variant to use
"""
if variant not in self.VARIANTS:
raise ValueError(
f"Unknown MAGNeT variant: {variant}. "
f"Available: {list(self.VARIANTS.keys())}"
)
self._variant = variant
self._config = self.VARIANTS[variant]
self._model = None
self._device: Optional[torch.device] = None
@property
def model_id(self) -> str:
return "magnet"
@property
def variant(self) -> str:
return self._variant
@property
def display_name(self) -> str:
return f"MAGNeT ({self._variant})"
@property
def description(self) -> str:
audio_type = self._config.get("audio_type", "music")
return f"Fast non-autoregressive {audio_type} generation"
@property
def vram_estimate_mb(self) -> int:
return self._config["vram_mb"]
@property
def max_duration(self) -> float:
return self._config["max_duration"]
@property
def sample_rate(self) -> int:
if self._model is not None:
return self._model.sample_rate
return 32000
@property
def supports_conditioning(self) -> list[ConditioningType]:
return [ConditioningType.TEXT]
@property
def is_loaded(self) -> bool:
return self._model is not None
@property
def device(self) -> Optional[torch.device]:
return self._device
def load(self, device: str = "cuda") -> None:
"""Load the MAGNeT model."""
if self._model is not None:
logger.warning(f"MAGNeT {self._variant} already loaded")
return
logger.info(f"Loading MAGNeT {self._variant} from {self._config['hf_id']}...")
try:
from audiocraft.models import MAGNeT
self._device = torch.device(device)
self._model = MAGNeT.get_pretrained(self._config["hf_id"])
self._model.to(self._device)
logger.info(
f"MAGNeT {self._variant} loaded successfully "
f"(sample_rate={self._model.sample_rate})"
)
except Exception as e:
self._model = None
self._device = None
logger.error(f"Failed to load MAGNeT {self._variant}: {e}")
raise RuntimeError(f"Failed to load MAGNeT: {e}") from e
def unload(self) -> None:
"""Unload the model and free memory."""
if self._model is None:
return
logger.info(f"Unloading MAGNeT {self._variant}...")
del self._model
self._model = None
self._device = None
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def generate(self, request: GenerationRequest) -> GenerationResult:
"""Generate audio from text prompts using MAGNeT.
MAGNeT uses a non-autoregressive approach with iterative decoding,
which is significantly faster than autoregressive models.
Args:
request: Generation parameters including prompts
Returns:
GenerationResult with audio tensor and metadata
"""
self.validate_request(request)
# Set random seed
seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
# Configure generation parameters.
# MAGNeT's set_generation_params() differs from MusicGen's: duration is
# fixed per checkpoint, and classifier-free guidance is annealed between
# max_cfg_coef and min_cfg_coef rather than set by a single cfg_coef.
self._model.set_generation_params(
temperature=request.temperature,
top_k=request.top_k,
top_p=request.top_p,
max_cfg_coef=request.cfg_coef,
# MAGNeT-specific parameters
decoding_steps=[
int(request.conditioning.get("decoding_steps_1", 20)),
int(request.conditioning.get("decoding_steps_2", 10)),
int(request.conditioning.get("decoding_steps_3", 10)),
int(request.conditioning.get("decoding_steps_4", 10)),
],
span_arrangement=request.conditioning.get("span_arrangement", "nonoverlap"),
)
logger.info(
f"Generating {len(request.prompts)} sample(s) with MAGNeT {self._variant} "
f"(duration={request.duration}s)"
)
# Generate audio
with torch.inference_mode():
audio = self._model.generate(request.prompts)
actual_duration = audio.shape[-1] / self.sample_rate
logger.info(
f"Generated {audio.shape[0]} sample(s), "
f"duration={actual_duration:.2f}s"
)
return GenerationResult(
audio=audio.cpu(),
sample_rate=self.sample_rate,
duration=actual_duration,
model_id=self.model_id,
variant=self._variant,
parameters={
"duration": request.duration,
"temperature": request.temperature,
"top_k": request.top_k,
"top_p": request.top_p,
"cfg_coef": request.cfg_coef,
"prompts": request.prompts,
},
seed=seed,
)
def get_default_params(self) -> dict[str, Any]:
"""Get default generation parameters for MAGNeT."""
return {
"duration": 10.0,
"temperature": 3.0, # MAGNeT works better with higher temperature
"top_k": 0, # Use top_p instead for MAGNeT
"top_p": 0.9,
"cfg_coef": 3.0,
}
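Variant choice is effectively a VRAM/quality trade-off. The sketch below is a hypothetical helper (`pick_variant` is not part of this codebase) that selects the largest variant fitting a memory budget, using the `vram_mb` figures from the VARIANTS table above:

```python
from typing import Optional

# Mirrors the vram_mb estimates in MAGNeTAdapter.VARIANTS above.
VARIANT_VRAM_MB = {
    "small-10secs": 1500,
    "medium-10secs": 5000,
    "small-30secs": 1800,
    "medium-30secs": 6000,
    "audio-small-10secs": 1500,
    "audio-medium-10secs": 5000,
}

def pick_variant(budget_mb: int, prefix: str = "") -> Optional[str]:
    """Best-fitting variant under budget_mb, optionally filtered by name prefix."""
    candidates = [
        (vram, name)
        for name, vram in VARIANT_VRAM_MB.items()
        if vram <= budget_mb and name.startswith(prefix)
    ]
    # Largest VRAM estimate under budget wins (ties broken by name).
    return max(candidates)[1] if candidates else None

print(pick_variant(4000))   # → small-30secs
print(pick_variant(1000))   # → None
```

A scheduler that also reserves headroom for ComfyUI (as the VRAM manager does) would subtract that reservation from the budget before calling this.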


@@ -0,0 +1,5 @@
"""MusicGen model adapter."""
from src.models.musicgen.adapter import MusicGenAdapter
__all__ = ["MusicGenAdapter"]


@@ -0,0 +1,290 @@
"""MusicGen model adapter for text-to-music generation."""
import gc
import logging
import random
from typing import Any, Optional
import torch
from src.core.base_model import (
BaseAudioModel,
ConditioningType,
GenerationRequest,
GenerationResult,
)
logger = logging.getLogger(__name__)
class MusicGenAdapter(BaseAudioModel):
"""Adapter for Facebook's MusicGen model.
Supports text-to-music generation with optional melody conditioning.
Available variants: small, medium, large, melody, and stereo versions.
"""
# Variant configurations
VARIANTS = {
"small": {
"hf_id": "facebook/musicgen-small",
"vram_mb": 1500,
"max_duration": 30,
"channels": 1,
"conditioning": [],
},
"medium": {
"hf_id": "facebook/musicgen-medium",
"vram_mb": 5000,
"max_duration": 30,
"channels": 1,
"conditioning": [],
},
"large": {
"hf_id": "facebook/musicgen-large",
"vram_mb": 10000,
"max_duration": 30,
"channels": 1,
"conditioning": [],
},
"melody": {
"hf_id": "facebook/musicgen-melody",
"vram_mb": 5000,
"max_duration": 30,
"channels": 1,
"conditioning": [ConditioningType.MELODY],
},
"stereo-small": {
"hf_id": "facebook/musicgen-stereo-small",
"vram_mb": 1800,
"max_duration": 30,
"channels": 2,
"conditioning": [],
},
"stereo-medium": {
"hf_id": "facebook/musicgen-stereo-medium",
"vram_mb": 6000,
"max_duration": 30,
"channels": 2,
"conditioning": [],
},
"stereo-large": {
"hf_id": "facebook/musicgen-stereo-large",
"vram_mb": 12000,
"max_duration": 30,
"channels": 2,
"conditioning": [],
},
"stereo-melody": {
"hf_id": "facebook/musicgen-stereo-melody",
"vram_mb": 6000,
"max_duration": 30,
"channels": 2,
"conditioning": [ConditioningType.MELODY],
},
}
def __init__(self, variant: str = "medium"):
"""Initialize MusicGen adapter.
Args:
variant: Model variant to use (small, medium, large, melody, etc.)
Raises:
ValueError: If variant is not recognized
"""
if variant not in self.VARIANTS:
raise ValueError(
f"Unknown MusicGen variant: {variant}. "
f"Available: {list(self.VARIANTS.keys())}"
)
self._variant = variant
self._config = self.VARIANTS[variant]
self._model = None
self._device: Optional[torch.device] = None
@property
def model_id(self) -> str:
return "musicgen"
@property
def variant(self) -> str:
return self._variant
@property
def display_name(self) -> str:
return f"MusicGen ({self._variant})"
@property
def description(self) -> str:
if "melody" in self._variant:
return "Text-to-music with melody conditioning"
elif "stereo" in self._variant:
return "Stereo text-to-music generation"
return "Text-to-music generation"
@property
def vram_estimate_mb(self) -> int:
return self._config["vram_mb"]
@property
def max_duration(self) -> float:
return self._config["max_duration"]
@property
def sample_rate(self) -> int:
if self._model is not None:
return self._model.sample_rate
return 32000 # Default MusicGen sample rate
@property
def supports_conditioning(self) -> list[ConditioningType]:
return [ConditioningType.TEXT] + self._config["conditioning"]
@property
def is_loaded(self) -> bool:
return self._model is not None
@property
def device(self) -> Optional[torch.device]:
return self._device
def load(self, device: str = "cuda") -> None:
"""Load the MusicGen model.
Args:
device: Target device ('cuda', 'cuda:0', 'cpu', etc.)
"""
if self._model is not None:
logger.warning(f"MusicGen {self._variant} already loaded")
return
logger.info(f"Loading MusicGen {self._variant} from {self._config['hf_id']}...")
try:
from audiocraft.models import MusicGen
self._device = torch.device(device)
self._model = MusicGen.get_pretrained(self._config["hf_id"])
self._model.to(self._device)
logger.info(
f"MusicGen {self._variant} loaded successfully "
f"(sample_rate={self._model.sample_rate})"
)
except Exception as e:
self._model = None
self._device = None
logger.error(f"Failed to load MusicGen {self._variant}: {e}")
raise RuntimeError(f"Failed to load MusicGen: {e}") from e
def unload(self) -> None:
"""Unload the model and free memory."""
if self._model is None:
return
logger.info(f"Unloading MusicGen {self._variant}...")
del self._model
self._model = None
self._device = None
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def generate(self, request: GenerationRequest) -> GenerationResult:
"""Generate music from text prompts.
Args:
request: Generation parameters including prompts
Returns:
GenerationResult with audio tensor and metadata
Raises:
RuntimeError: If model not loaded
ValueError: If request is invalid
"""
self.validate_request(request)
# Set random seed for reproducibility
seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
# Configure generation parameters
self._model.set_generation_params(
duration=request.duration,
temperature=request.temperature,
top_k=request.top_k,
top_p=request.top_p,
cfg_coef=request.cfg_coef,
)
logger.info(
f"Generating {len(request.prompts)} sample(s) with MusicGen {self._variant} "
f"(duration={request.duration}s, temp={request.temperature})"
)
# Generate audio
with torch.inference_mode():
melody_audio = request.conditioning.get("melody")
melody_sr = request.conditioning.get("melody_sr", self.sample_rate)
if melody_audio is not None and ConditioningType.MELODY in self.supports_conditioning:
# Melody-conditioned generation
if isinstance(melody_audio, str):
# Load from file path
import torchaudio
melody_tensor, melody_sr = torchaudio.load(melody_audio)
melody_tensor = melody_tensor.to(self._device)
else:
melody_tensor = torch.tensor(melody_audio).to(self._device)
audio = self._model.generate_with_chroma(
descriptions=request.prompts,
melody_wavs=melody_tensor.unsqueeze(0) if melody_tensor.dim() == 1 else melody_tensor,
melody_sample_rate=melody_sr,
)
else:
# Standard text-to-music generation
audio = self._model.generate(request.prompts)
# audio shape: [batch, channels, samples]
actual_duration = audio.shape[-1] / self.sample_rate
logger.info(
f"Generated {audio.shape[0]} sample(s), "
f"duration={actual_duration:.2f}s, shape={audio.shape}"
)
return GenerationResult(
audio=audio.cpu(),
sample_rate=self.sample_rate,
duration=actual_duration,
model_id=self.model_id,
variant=self._variant,
parameters={
"duration": request.duration,
"temperature": request.temperature,
"top_k": request.top_k,
"top_p": request.top_p,
"cfg_coef": request.cfg_coef,
"prompts": request.prompts,
},
seed=seed,
)
def get_default_params(self) -> dict[str, Any]:
"""Get default generation parameters."""
return {
"duration": 10.0,
"temperature": 1.0,
"top_k": 250,
"top_p": 0.0,
"cfg_coef": 3.0,
}
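The melody branch above normalizes input shape before calling `generate_with_chroma`, which expects `melody_wavs` shaped `[batch, channels, samples]`. A torch-free sketch of that wrapping rule, with nested lists standing in for tensors so it runs anywhere:

```python
def dims(x):
    """Nesting depth of a list-of-lists 'tensor'."""
    d = 0
    while isinstance(x, list):
        d += 1
        x = x[0]
    return d

def to_melody_batch(wav):
    if dims(wav) == 1:    # [samples] -> [1, 1, samples]
        return [[wav]]
    if dims(wav) == 2:    # [channels, samples] -> [1, channels, samples]
        return [wav]
    return wav            # already [batch, channels, samples]

mono = [0.0, 0.1, 0.2, 0.1]
print(dims(to_melody_batch(mono)))   # → 3
stereo = [[0.0, 0.1], [0.0, -0.1]]
print(dims(to_melody_batch(stereo))) # → 3
```

In the adapter itself the same effect comes from `unsqueeze(0)` on a 1-D tensor; this sketch just makes the expected rank explicit.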


@@ -0,0 +1,5 @@
"""MusicGen Style model adapter."""
from src.models.musicgen_style.adapter import MusicGenStyleAdapter
__all__ = ["MusicGenStyleAdapter"]


@@ -0,0 +1,277 @@
"""MusicGen Style model adapter for style-conditioned music generation."""
import gc
import logging
import random
from typing import Any, Optional
import torch
import torchaudio
from src.core.base_model import (
BaseAudioModel,
ConditioningType,
GenerationRequest,
GenerationResult,
)
logger = logging.getLogger(__name__)
class MusicGenStyleAdapter(BaseAudioModel):
"""Adapter for Facebook's MusicGen Style model.
Generates music conditioned on both text and a style reference audio.
Extracts style features from the reference and applies them to new generations.
"""
VARIANTS = {
"medium": {
"hf_id": "facebook/musicgen-style",
"vram_mb": 5000,
"max_duration": 30,
"channels": 1,
},
}
def __init__(self, variant: str = "medium"):
"""Initialize MusicGen Style adapter.
Args:
variant: Model variant (currently only 'medium' available)
"""
if variant not in self.VARIANTS:
raise ValueError(
f"Unknown MusicGen Style variant: {variant}. "
f"Available: {list(self.VARIANTS.keys())}"
)
self._variant = variant
self._config = self.VARIANTS[variant]
self._model = None
self._device: Optional[torch.device] = None
@property
def model_id(self) -> str:
return "musicgen-style"
@property
def variant(self) -> str:
return self._variant
@property
def display_name(self) -> str:
return f"MusicGen Style ({self._variant})"
@property
def description(self) -> str:
return "Style-conditioned music generation from reference audio"
@property
def vram_estimate_mb(self) -> int:
return self._config["vram_mb"]
@property
def max_duration(self) -> float:
return self._config["max_duration"]
@property
def sample_rate(self) -> int:
if self._model is not None:
return self._model.sample_rate
return 32000
@property
def supports_conditioning(self) -> list[ConditioningType]:
return [ConditioningType.TEXT, ConditioningType.STYLE]
@property
def is_loaded(self) -> bool:
return self._model is not None
@property
def device(self) -> Optional[torch.device]:
return self._device
def load(self, device: str = "cuda") -> None:
"""Load the MusicGen Style model."""
if self._model is not None:
logger.warning(f"MusicGen Style {self._variant} already loaded")
return
logger.info(f"Loading MusicGen Style {self._variant}...")
try:
from audiocraft.models import MusicGen
self._device = torch.device(device)
self._model = MusicGen.get_pretrained(self._config["hf_id"])
self._model.to(self._device)
logger.info(
f"MusicGen Style {self._variant} loaded successfully "
f"(sample_rate={self._model.sample_rate})"
)
except Exception as e:
self._model = None
self._device = None
logger.error(f"Failed to load MusicGen Style {self._variant}: {e}")
raise RuntimeError(f"Failed to load MusicGen Style: {e}") from e
def unload(self) -> None:
"""Unload the model and free memory."""
if self._model is None:
return
logger.info(f"Unloading MusicGen Style {self._variant}...")
del self._model
self._model = None
self._device = None
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def _load_style_audio(
self, style_input: Any, target_sr: int
) -> tuple[torch.Tensor, int]:
"""Load and prepare style reference audio.
Args:
style_input: File path, tensor, or numpy array
target_sr: Target sample rate
Returns:
Tuple of (audio_tensor, sample_rate)
"""
if isinstance(style_input, str):
# Load from file
audio, sr = torchaudio.load(style_input)
if sr != target_sr:
audio = torchaudio.functional.resample(audio, sr, target_sr)
return audio.to(self._device), target_sr
elif isinstance(style_input, torch.Tensor):
return style_input.to(self._device), target_sr
else:
# Assume numpy array
return torch.tensor(style_input).to(self._device), target_sr
def generate(self, request: GenerationRequest) -> GenerationResult:
"""Generate music conditioned on text and style reference.
Args:
request: Generation parameters including prompts and style conditioning
Returns:
GenerationResult with audio tensor and metadata
Note:
Style conditioning requires 'style' in request.conditioning with either:
- File path to audio
- Audio tensor
- Numpy array
"""
self.validate_request(request)
# Set random seed
seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
# Get style conditioning parameters
style_audio = request.conditioning.get("style")
eval_q = request.conditioning.get("eval_q", 3)
excerpt_length = request.conditioning.get("excerpt_length", 3.0)
# Configure generation parameters
self._model.set_generation_params(
duration=request.duration,
temperature=request.temperature,
top_k=request.top_k,
top_p=request.top_p,
cfg_coef=request.cfg_coef,
)
logger.info(
f"Generating {len(request.prompts)} sample(s) with MusicGen Style "
f"(duration={request.duration}s, style_conditioned={style_audio is not None})"
)
with torch.inference_mode():
if style_audio is not None:
# Load style reference
style_tensor, style_sr = self._load_style_audio(
style_audio, self.sample_rate
)
# Ensure proper shape [batch, channels, samples]
if style_tensor.dim() == 1:
style_tensor = style_tensor.unsqueeze(0).unsqueeze(0)
elif style_tensor.dim() == 2:
style_tensor = style_tensor.unsqueeze(0)
# Set style conditioner parameters
if hasattr(self._model, 'set_style_conditioner_params'):
self._model.set_style_conditioner_params(
eval_q=eval_q,
excerpt_length=excerpt_length,
)
# Generate with style conditioning
# Expand style to match number of prompts if needed
if style_tensor.shape[0] == 1 and len(request.prompts) > 1:
style_tensor = style_tensor.expand(len(request.prompts), -1, -1)
audio = self._model.generate_with_chroma(
descriptions=request.prompts,
melody_wavs=style_tensor,
melody_sample_rate=style_sr,
)
else:
# Generate without style (falls back to standard MusicGen behavior)
logger.warning(
"No style reference provided, generating without style conditioning"
)
audio = self._model.generate(request.prompts)
actual_duration = audio.shape[-1] / self.sample_rate
logger.info(
f"Generated {audio.shape[0]} sample(s), "
f"duration={actual_duration:.2f}s"
)
return GenerationResult(
audio=audio.cpu(),
sample_rate=self.sample_rate,
duration=actual_duration,
model_id=self.model_id,
variant=self._variant,
parameters={
"duration": request.duration,
"temperature": request.temperature,
"top_k": request.top_k,
"top_p": request.top_p,
"cfg_coef": request.cfg_coef,
"prompts": request.prompts,
"style_conditioned": style_audio is not None,
"eval_q": eval_q,
"excerpt_length": excerpt_length,
},
seed=seed,
)
def get_default_params(self) -> dict[str, Any]:
"""Get default generation parameters for MusicGen Style."""
return {
"duration": 10.0,
"temperature": 1.0,
"top_k": 250,
"top_p": 0.0,
"cfg_coef": 3.0,
"eval_q": 3,
"excerpt_length": 3.0,
}
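When a single style excerpt conditions several prompts, the adapter expands the style batch to the prompt count via `tensor.expand(len(prompts), -1, -1)`. The broadcasting rule, sketched with plain lists (`broadcast_style` is illustrative, not part of the adapter):

```python
def broadcast_style(style_batch, prompts):
    """Replicate a batch-of-one style across prompts; reject mismatched batches."""
    if len(style_batch) == 1 and len(prompts) > 1:
        return style_batch * len(prompts)  # one copy per prompt
    if len(style_batch) not in (1, len(prompts)):
        raise ValueError("style batch must be 1 or match the prompt count")
    return style_batch

styles = [["ref-audio"]]  # batch of one style excerpt
prompts = ["lofi beat", "synthwave", "jazz trio"]
print(len(broadcast_style(styles, prompts)))  # → 3
```

Note that torch's `expand` creates views rather than copies, so the real adapter avoids duplicating the style tensor in memory.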

13
src/services/__init__.py Normal file

@@ -0,0 +1,13 @@
"""Services layer for AudioCraft Studio."""
from src.services.generation_service import GenerationService
from src.services.batch_processor import BatchProcessor, GenerationJob, JobStatus
from src.services.project_service import ProjectService
__all__ = [
"GenerationService",
"BatchProcessor",
"GenerationJob",
"JobStatus",
"ProjectService",
]


@@ -0,0 +1,397 @@
"""Batch processor for queued audio generation jobs."""
import asyncio
import logging
import time
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Optional
logger = logging.getLogger(__name__)
class JobStatus(str, Enum):
"""Status of a generation job."""
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
@dataclass
class GenerationJob:
"""A queued generation job."""
id: str
model_id: str
variant: Optional[str]
prompts: list[str]
parameters: dict[str, Any]
conditioning: dict[str, Any]
project_id: Optional[str]
preset_used: Optional[str]
tags: list[str]
# Status tracking
status: JobStatus = JobStatus.PENDING
progress: float = 0.0
progress_message: str = ""
created_at: datetime = field(default_factory=datetime.utcnow)
started_at: Optional[datetime] = None
completed_at: Optional[datetime] = None
# Results
result_id: Optional[str] = None # Generation ID if completed
audio_path: Optional[str] = None
error: Optional[str] = None
@classmethod
def create(
cls,
model_id: str,
variant: Optional[str],
prompts: list[str],
duration: float = 10.0,
temperature: float = 1.0,
top_k: int = 250,
top_p: float = 0.0,
cfg_coef: float = 3.0,
seed: Optional[int] = None,
conditioning: Optional[dict[str, Any]] = None,
project_id: Optional[str] = None,
preset_used: Optional[str] = None,
tags: Optional[list[str]] = None,
) -> "GenerationJob":
"""Create a new generation job."""
return cls(
id=f"job_{uuid.uuid4().hex[:12]}",
model_id=model_id,
variant=variant,
prompts=prompts,
parameters={
"duration": duration,
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
"cfg_coef": cfg_coef,
"seed": seed,
},
conditioning=conditioning or {},
project_id=project_id,
preset_used=preset_used,
tags=tags or [],
)
def to_dict(self) -> dict[str, Any]:
"""Convert job to dictionary for API responses."""
return {
"id": self.id,
"model_id": self.model_id,
"variant": self.variant,
"prompts": self.prompts,
"parameters": self.parameters,
"status": self.status.value,
"progress": self.progress,
"progress_message": self.progress_message,
"created_at": self.created_at.isoformat(),
"started_at": self.started_at.isoformat() if self.started_at else None,
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
"result_id": self.result_id,
"audio_path": self.audio_path,
"error": self.error,
}
class BatchProcessor:
"""Manages a queue of generation jobs.
Features:
- Async job queue with configurable concurrency
- Progress tracking and callbacks
- Job cancellation
- Priority support (future enhancement)
"""
def __init__(
self,
generation_service: Any, # Avoid circular import
max_queue_size: int = 100,
max_concurrent: int = 1, # GPU operations should be serialized
):
"""Initialize batch processor.
Args:
generation_service: GenerationService instance
max_queue_size: Maximum jobs in queue
max_concurrent: Maximum concurrent generations (usually 1 for GPU)
"""
self.generation_service = generation_service
self.max_queue_size = max_queue_size
self.max_concurrent = max_concurrent
# Job tracking
self._jobs: dict[str, GenerationJob] = {}
self._queue: asyncio.Queue[str] = asyncio.Queue(maxsize=max_queue_size)
# Processing control
self._workers: list[asyncio.Task] = []
self._running = False
self._lock = asyncio.Lock()
# Callbacks
self._on_job_complete: list[Callable[[GenerationJob], None]] = []
self._on_job_failed: list[Callable[[GenerationJob], None]] = []
self._on_progress: list[Callable[[GenerationJob], None]] = []
async def start(self) -> None:
"""Start the batch processor workers."""
if self._running:
return
self._running = True
# Start worker tasks
for i in range(self.max_concurrent):
worker = asyncio.create_task(self._worker_loop(i))
self._workers.append(worker)
logger.info(f"Batch processor started with {self.max_concurrent} worker(s)")
async def stop(self) -> None:
"""Stop the batch processor and wait for pending jobs."""
if not self._running:
return
self._running = False
# Cancel workers
for worker in self._workers:
worker.cancel()
# Wait for workers to finish
await asyncio.gather(*self._workers, return_exceptions=True)
self._workers.clear()
logger.info("Batch processor stopped")
async def submit(self, job: GenerationJob) -> GenerationJob:
"""Submit a job to the queue.
Args:
job: Job to submit
Returns:
The submitted job with ID
Raises:
RuntimeError: If queue is full
"""
async with self._lock:
# Count only active jobs; completed/failed jobs linger in _jobs until
# cleanup_completed() runs and should not block new submissions.
active = sum(
1 for j in self._jobs.values()
if j.status in (JobStatus.PENDING, JobStatus.PROCESSING)
)
if active >= self.max_queue_size:
raise RuntimeError(
f"Queue full (max {self.max_queue_size} jobs). "
"Please wait for jobs to complete."
)
self._jobs[job.id] = job
await self._queue.put(job.id)
logger.info(f"Job {job.id} submitted to queue (position: {self._queue.qsize()})")
return job
async def cancel(self, job_id: str) -> bool:
"""Cancel a pending job.
Args:
job_id: ID of job to cancel
Returns:
True if job was cancelled, False if not found or already processing
"""
async with self._lock:
job = self._jobs.get(job_id)
if job is None:
return False
if job.status != JobStatus.PENDING:
logger.warning(f"Cannot cancel job {job_id} with status {job.status}")
return False
job.status = JobStatus.CANCELLED
job.completed_at = datetime.utcnow()
logger.info(f"Job {job_id} cancelled")
return True
def get_job(self, job_id: str) -> Optional[GenerationJob]:
"""Get a job by ID."""
return self._jobs.get(job_id)
def get_queue_status(self) -> dict[str, Any]:
"""Get current queue status."""
jobs_by_status = {}
for job in self._jobs.values():
status = job.status.value
jobs_by_status[status] = jobs_by_status.get(status, 0) + 1
return {
"queue_size": self._queue.qsize(),
"total_jobs": len(self._jobs),
"jobs_by_status": jobs_by_status,
"running": self._running,
"max_queue_size": self.max_queue_size,
}
def list_jobs(
self,
status: Optional[JobStatus] = None,
limit: int = 50,
) -> list[GenerationJob]:
"""List jobs with optional status filter.
Args:
status: Filter by status
limit: Maximum jobs to return
Returns:
List of jobs ordered by creation time (newest first)
"""
jobs = list(self._jobs.values())
if status:
jobs = [j for j in jobs if j.status == status]
# Sort by created_at descending
jobs.sort(key=lambda j: j.created_at, reverse=True)
return jobs[:limit]
def cleanup_completed(self, max_age_hours: float = 24.0) -> int:
"""Remove old completed/failed jobs from memory.
Args:
max_age_hours: Remove jobs older than this
Returns:
Number of jobs removed
"""
cutoff = datetime.utcnow().timestamp() - (max_age_hours * 3600)
removed = 0
for job_id, job in list(self._jobs.items()):
if job.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED):
if job.completed_at and job.completed_at.timestamp() < cutoff:
del self._jobs[job_id]
removed += 1
if removed:
logger.info(f"Cleaned up {removed} old jobs")
return removed
async def _worker_loop(self, worker_id: int) -> None:
"""Worker loop that processes jobs from queue."""
logger.debug(f"Worker {worker_id} started")
while self._running:
try:
# Wait for job with timeout
try:
job_id = await asyncio.wait_for(
self._queue.get(), timeout=1.0
)
except asyncio.TimeoutError:
continue
job = self._jobs.get(job_id)
if job is None or job.status == JobStatus.CANCELLED:
continue
await self._process_job(job)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Worker {worker_id} error: {e}")
logger.debug(f"Worker {worker_id} stopped")
async def _process_job(self, job: GenerationJob) -> None:
"""Process a single generation job."""
logger.info(f"Processing job {job.id}: {job.model_id}/{job.variant}")
job.status = JobStatus.PROCESSING
job.started_at = datetime.utcnow()
def progress_callback(progress: float, message: str) -> None:
job.progress = progress
job.progress_message = message
for callback in self._on_progress:
try:
callback(job)
except Exception as e:
logger.error(f"Progress callback error: {e}")
try:
result, generation = await self.generation_service.generate(
model_id=job.model_id,
variant=job.variant,
prompts=job.prompts,
duration=job.parameters.get("duration", 10.0),
temperature=job.parameters.get("temperature", 1.0),
top_k=job.parameters.get("top_k", 250),
top_p=job.parameters.get("top_p", 0.0),
cfg_coef=job.parameters.get("cfg_coef", 3.0),
seed=job.parameters.get("seed"),
conditioning=job.conditioning,
project_id=job.project_id,
preset_used=job.preset_used,
tags=job.tags,
progress_callback=progress_callback,
)
job.status = JobStatus.COMPLETED
job.result_id = generation.id
job.audio_path = generation.audio_path
job.completed_at = datetime.utcnow()
job.progress = 1.0
job.progress_message = "Complete"
logger.info(f"Job {job.id} completed: {generation.id}")
for callback in self._on_job_complete:
try:
callback(job)
except Exception as e:
logger.error(f"Completion callback error: {e}")
except Exception as e:
job.status = JobStatus.FAILED
job.error = str(e)
job.completed_at = datetime.utcnow()
logger.error(f"Job {job.id} failed: {e}")
for callback in self._on_job_failed:
try:
callback(job)
except Exception as e2:
logger.error(f"Failure callback error: {e2}")
# Callback registration
def on_job_complete(self, callback: Callable[[GenerationJob], None]) -> None:
"""Register callback for job completion."""
self._on_job_complete.append(callback)
def on_job_failed(self, callback: Callable[[GenerationJob], None]) -> None:
"""Register callback for job failure."""
self._on_job_failed.append(callback)
def on_progress(self, callback: Callable[[GenerationJob], None]) -> None:
"""Register callback for progress updates."""
self._on_progress.append(callback)
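The worker loop above (poll the queue with a short timeout, re-check a running flag, swallow `TimeoutError`) is a standard asyncio shutdown pattern. A minimal self-contained sketch of that pattern, with a plain list standing in for job processing (all names here are illustrative, not the classes above):

```python
import asyncio

async def worker(queue: asyncio.Queue, results: list, running: dict) -> None:
    # Poll with a timeout so the loop can notice shutdown between jobs.
    while running["flag"]:
        try:
            job_id = await asyncio.wait_for(queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            continue
        results.append(f"processed {job_id}")

async def main() -> list:
    queue: asyncio.Queue = asyncio.Queue()
    results: list = []
    running = {"flag": True}
    task = asyncio.create_task(worker(queue, results, running))
    for i in range(3):
        await queue.put(i)
    await asyncio.sleep(0.3)   # let the worker drain the queue
    running["flag"] = False    # signal shutdown; loop exits on its next timeout
    await task
    return results

print(asyncio.run(main()))  # ['processed 0', 'processed 1', 'processed 2']
```

The timeout trades a little idle polling for a worker that can always be stopped cleanly, without cancelling mid-job.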


@@ -0,0 +1,322 @@
"""Generation service for orchestrating audio generation."""
import logging
import time
from pathlib import Path
from typing import Any, Callable, Optional
import soundfile as sf
import torch
from src.core.base_model import GenerationRequest, GenerationResult
from src.core.gpu_manager import GPUMemoryManager
from src.core.model_registry import ModelRegistry
from src.core.oom_handler import OOMHandler
from src.storage.database import Database, Generation
logger = logging.getLogger(__name__)
class GenerationService:
"""Orchestrates audio generation across all models.
Handles:
- Model selection and loading
- Generation execution with OOM recovery
- Result saving and database recording
- Progress callbacks for UI updates
"""
def __init__(
self,
registry: ModelRegistry,
gpu_manager: GPUMemoryManager,
database: Database,
output_dir: Path,
):
"""Initialize generation service.
Args:
registry: Model registry for model access
gpu_manager: GPU memory manager
database: Database for storing generation records
output_dir: Directory for saving generated audio
"""
self.registry = registry
self.gpu_manager = gpu_manager
self.database = database
self.output_dir = Path(output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
# OOM handler
self.oom_handler = OOMHandler(gpu_manager, registry)
# Statistics
self._generation_count = 0
self._total_duration_generated = 0.0
async def generate(
self,
model_id: str,
variant: Optional[str],
prompts: list[str],
duration: float = 10.0,
temperature: float = 1.0,
top_k: int = 250,
top_p: float = 0.0,
cfg_coef: float = 3.0,
seed: Optional[int] = None,
conditioning: Optional[dict[str, Any]] = None,
project_id: Optional[str] = None,
preset_used: Optional[str] = None,
tags: Optional[list[str]] = None,
progress_callback: Optional[Callable[[float, str], None]] = None,
) -> tuple[GenerationResult, Generation]:
"""Generate audio and save to database.
Args:
model_id: Model family to use
variant: Model variant (None for default)
prompts: Text prompts for generation
duration: Target duration in seconds
temperature: Sampling temperature
top_k: Top-k sampling parameter
top_p: Nucleus sampling parameter
cfg_coef: Classifier-free guidance coefficient
seed: Random seed for reproducibility
conditioning: Optional conditioning data (melody, style, chords, etc.)
project_id: Optional project to associate with
preset_used: Name of preset used (for metadata)
tags: Optional tags for organization
progress_callback: Optional callback for progress updates
Returns:
Tuple of (GenerationResult, Generation database record)
Raises:
ValueError: If model not found or parameters invalid
RuntimeError: If generation fails
"""
start_time = time.time()
# Report progress
if progress_callback:
progress_callback(0.0, "Preparing generation...")
# Build generation request
request = GenerationRequest(
prompts=prompts,
duration=duration,
temperature=temperature,
top_k=top_k,
top_p=top_p,
cfg_coef=cfg_coef,
seed=seed,
conditioning=conditioning or {},
)
# Get model configuration
family_config, variant_config = self.registry.get_model_config(model_id, variant)
actual_variant = variant or family_config.default_variant
# Check VRAM availability
if progress_callback:
progress_callback(0.1, "Checking GPU memory...")
can_load, reason = self.gpu_manager.can_load_model(variant_config.vram_mb)
if not can_load:
# Try OOM recovery
if not self.oom_handler.check_memory_for_operation(variant_config.vram_mb):
raise RuntimeError(f"Insufficient GPU memory: {reason}")
# Generate with OOM recovery wrapper
if progress_callback:
progress_callback(0.2, f"Loading {model_id}/{actual_variant}...")
@self.oom_handler.with_oom_recovery
def do_generation() -> GenerationResult:
with self.registry.get_model(model_id, actual_variant) as model:
if progress_callback:
progress_callback(0.4, "Generating audio...")
return model.generate(request)
result = do_generation()
if progress_callback:
progress_callback(0.8, "Saving audio...")
# Save audio file
audio_path = self._save_audio(result)
# Create database record
generation = Generation.create(
model=model_id,
variant=actual_variant,
prompt=prompts[0] if len(prompts) == 1 else "\n".join(prompts),
parameters={
"duration": duration,
"temperature": temperature,
"top_k": top_k,
"top_p": top_p,
"cfg_coef": cfg_coef,
},
project_id=project_id,
preset_used=preset_used,
conditioning=conditioning,
audio_path=str(audio_path),
duration_seconds=result.duration,
sample_rate=result.sample_rate,
tags=tags or [],
seed=result.seed,
)
# Save to database
await self.database.create_generation(generation)
# Update statistics
self._generation_count += 1
self._total_duration_generated += result.duration
elapsed = time.time() - start_time
logger.info(
f"Generation complete: {model_id}/{actual_variant}, "
f"duration={result.duration:.1f}s, elapsed={elapsed:.1f}s"
)
if progress_callback:
progress_callback(1.0, "Complete!")
return result, generation
def _save_audio(self, result: GenerationResult) -> Path:
"""Save generated audio to file.
Args:
result: Generation result with audio tensor
Returns:
Path to saved audio file
"""
# Generate unique filename
timestamp = int(time.time() * 1000)
filename = f"{result.model_id}_{result.variant}_{timestamp}.wav"
filepath = self.output_dir / filename
        # Convert tensor to numpy and save (detach and move off the GPU first;
        # .numpy() raises on a CUDA tensor)
        audio = result.audio.detach().cpu().numpy()
# Handle batch dimension - save first sample if batched
if audio.ndim == 3:
audio = audio[0] # [channels, samples]
# Transpose to [samples, channels] for soundfile
if audio.ndim == 2:
audio = audio.T
sf.write(filepath, audio, result.sample_rate)
logger.debug(f"Saved audio to {filepath}")
return filepath
async def regenerate(
self,
generation_id: str,
new_seed: Optional[int] = None,
progress_callback: Optional[Callable[[float, str], None]] = None,
) -> tuple[GenerationResult, Generation]:
"""Regenerate audio using parameters from existing generation.
Args:
generation_id: ID of generation to regenerate
new_seed: Optional new seed (uses original if None)
progress_callback: Optional progress callback
Returns:
Tuple of (GenerationResult, new Generation record)
Raises:
ValueError: If generation not found
"""
# Load original generation
original = await self.database.get_generation(generation_id)
if original is None:
raise ValueError(f"Generation not found: {generation_id}")
        # Parse prompts (split() already yields a single-item list when there is no newline)
        prompts = original.prompt.split("\n")
# Regenerate with same or new seed
return await self.generate(
model_id=original.model,
variant=original.variant,
prompts=prompts,
duration=original.parameters.get("duration", 10.0),
temperature=original.parameters.get("temperature", 1.0),
top_k=original.parameters.get("top_k", 250),
top_p=original.parameters.get("top_p", 0.0),
cfg_coef=original.parameters.get("cfg_coef", 3.0),
seed=new_seed if new_seed is not None else original.seed,
conditioning=original.conditioning,
project_id=original.project_id,
preset_used=original.preset_used,
tags=original.tags,
progress_callback=progress_callback,
)
def get_stats(self) -> dict[str, Any]:
"""Get generation statistics.
Returns:
Dictionary with generation stats
"""
return {
"generation_count": self._generation_count,
"total_duration_generated": self._total_duration_generated,
"oom_stats": self.oom_handler.get_stats(),
}
def estimate_generation_time(
self, model_id: str, variant: Optional[str], duration: float
) -> float:
"""Estimate generation time for given parameters.
Args:
model_id: Model family
variant: Model variant
duration: Target audio duration
Returns:
Estimated generation time in seconds
"""
# Rough estimates based on model type and RTX 4090
# These are approximations and vary based on many factors
estimates = {
"musicgen": {
"small": 0.8, # seconds per second of audio
"medium": 1.5,
"large": 3.0,
"melody": 1.8,
},
"audiogen": {
"medium": 1.5,
},
"magnet": {
"small-10secs": 0.3, # Non-autoregressive is faster
"medium-10secs": 0.5,
"small-30secs": 0.3,
"medium-30secs": 0.5,
},
"musicgen-style": {
"medium": 1.8,
},
"jasco": {
"chords-drums-400M": 1.0,
"chords-drums-1B": 1.5,
},
}
family_config, _ = self.registry.get_model_config(model_id, variant)
actual_variant = variant or family_config.default_variant
ratio = estimates.get(model_id, {}).get(actual_variant, 2.0)
return duration * ratio + 5.0 # Add 5s for model loading overhead
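As a worked example of `estimate_generation_time`: with the table's illustrative ratio of 1.5 s of compute per second of audio for `musicgen/medium`, plus the flat 5 s loading overhead, a 30 s clip is estimated at 30 × 1.5 + 5 = 50 s. A trimmed-down sketch of the same lookup:

```python
# Seconds of compute per second of audio (illustrative numbers from the table above)
ESTIMATES = {"musicgen": {"small": 0.8, "medium": 1.5, "large": 3.0}}

def estimate_seconds(model_id: str, variant: str, duration: float) -> float:
    # Unknown models/variants fall back to a conservative 2.0 ratio;
    # ~5 s is added for model loading, as in the service above.
    ratio = ESTIMATES.get(model_id, {}).get(variant, 2.0)
    return duration * ratio + 5.0

print(estimate_seconds("musicgen", "medium", 30.0))  # 50.0
```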


@@ -0,0 +1,395 @@
"""Project service for managing projects and generations."""
import logging
import shutil
from pathlib import Path
from typing import Any, Optional
from src.storage.database import Database, Generation, Project, Preset
logger = logging.getLogger(__name__)
class ProjectService:
"""Service for managing projects, generations, and presets.
Provides a high-level API for project organization and
generation history management.
"""
def __init__(self, database: Database, output_dir: Path):
"""Initialize project service.
Args:
database: Database instance
output_dir: Directory where audio files are stored
"""
self.database = database
self.output_dir = Path(output_dir)
# Project Operations
async def create_project(
self, name: str, description: str = ""
) -> Project:
"""Create a new project.
Args:
name: Project name
description: Optional description
Returns:
Created project
"""
project = Project.create(name, description)
await self.database.create_project(project)
logger.info(f"Created project: {project.id} ({name})")
return project
async def get_project(self, project_id: str) -> Optional[Project]:
"""Get a project by ID."""
return await self.database.get_project(project_id)
async def list_projects(
self, limit: int = 100, offset: int = 0
) -> list[Project]:
"""List all projects."""
return await self.database.list_projects(limit, offset)
async def update_project(
self,
project_id: str,
name: Optional[str] = None,
description: Optional[str] = None,
) -> Optional[Project]:
"""Update a project.
Args:
project_id: Project ID
name: New name (None to keep current)
description: New description (None to keep current)
Returns:
Updated project, or None if not found
"""
project = await self.database.get_project(project_id)
if project is None:
return None
if name is not None:
project.name = name
if description is not None:
project.description = description
await self.database.update_project(project)
logger.info(f"Updated project: {project_id}")
return project
async def delete_project(
self, project_id: str, delete_files: bool = False
) -> bool:
"""Delete a project.
Args:
project_id: Project ID
delete_files: If True, also delete associated audio files
Returns:
True if deleted
"""
if delete_files:
            # Get all generations (not just the default first page) and delete their files
            generations = await self.database.list_generations(
                project_id=project_id, limit=10000
            )
for gen in generations:
if gen.audio_path:
try:
Path(gen.audio_path).unlink(missing_ok=True)
except Exception as e:
logger.warning(f"Failed to delete {gen.audio_path}: {e}")
result = await self.database.delete_project(project_id)
if result:
logger.info(f"Deleted project: {project_id}")
return result
async def get_project_stats(self, project_id: str) -> dict[str, Any]:
"""Get statistics for a project.
Args:
project_id: Project ID
Returns:
Dictionary with project statistics
"""
generations = await self.database.list_generations(
project_id=project_id, limit=10000
)
total_duration = sum(g.duration_seconds or 0 for g in generations)
models_used = {}
for gen in generations:
key = f"{gen.model}/{gen.variant}"
models_used[key] = models_used.get(key, 0) + 1
return {
"generation_count": len(generations),
"total_duration_seconds": total_duration,
"models_used": models_used,
}
# Generation Operations
async def get_generation(self, generation_id: str) -> Optional[Generation]:
"""Get a generation by ID."""
return await self.database.get_generation(generation_id)
async def list_generations(
self,
project_id: Optional[str] = None,
model: Optional[str] = None,
search: Optional[str] = None,
limit: int = 100,
offset: int = 0,
) -> list[Generation]:
"""List generations with optional filters.
Args:
project_id: Filter by project
model: Filter by model family
search: Search in prompts, names, and tags
limit: Maximum results
offset: Pagination offset
Returns:
List of generations
"""
return await self.database.list_generations(
project_id=project_id,
model=model,
search=search,
limit=limit,
offset=offset,
)
async def update_generation(
self,
generation_id: str,
name: Optional[str] = None,
tags: Optional[list[str]] = None,
notes: Optional[str] = None,
project_id: Optional[str] = None,
) -> Optional[Generation]:
"""Update a generation's metadata.
Args:
generation_id: Generation ID
name: New name
tags: New tags (replaces existing)
notes: New notes
project_id: Move to different project
Returns:
Updated generation, or None if not found
"""
generation = await self.database.get_generation(generation_id)
if generation is None:
return None
if name is not None:
generation.name = name
if tags is not None:
generation.tags = tags
if notes is not None:
generation.notes = notes
if project_id is not None:
generation.project_id = project_id
await self.database.update_generation(generation)
logger.info(f"Updated generation: {generation_id}")
return generation
async def delete_generation(
self, generation_id: str, delete_file: bool = True
) -> bool:
"""Delete a generation.
Args:
generation_id: Generation ID
delete_file: If True, also delete audio file
Returns:
True if deleted
"""
if delete_file:
generation = await self.database.get_generation(generation_id)
if generation and generation.audio_path:
try:
Path(generation.audio_path).unlink(missing_ok=True)
except Exception as e:
logger.warning(f"Failed to delete audio file: {e}")
result = await self.database.delete_generation(generation_id)
if result:
logger.info(f"Deleted generation: {generation_id}")
return result
async def move_generations_to_project(
self, generation_ids: list[str], project_id: Optional[str]
) -> int:
"""Move multiple generations to a project.
Args:
generation_ids: List of generation IDs
project_id: Target project ID (None to unlink)
Returns:
Number of generations moved
"""
moved = 0
for gen_id in generation_ids:
result = await self.update_generation(gen_id, project_id=project_id)
if result:
moved += 1
logger.info(f"Moved {moved} generations to project {project_id}")
return moved
# Preset Operations
async def create_preset(
self,
model: str,
name: str,
parameters: dict[str, Any],
description: str = "",
) -> Preset:
"""Create a custom preset.
Args:
model: Model family this preset is for
name: Preset name
parameters: Generation parameters
description: Optional description
Returns:
Created preset
"""
preset = Preset.create(model, name, parameters, description)
await self.database.create_preset(preset)
logger.info(f"Created preset: {preset.id} ({name}) for {model}")
return preset
async def list_presets(
self, model: Optional[str] = None, include_builtin: bool = True
) -> list[Preset]:
"""List presets with optional model filter.
Args:
model: Filter by model family
include_builtin: Include built-in presets
Returns:
List of presets
"""
return await self.database.list_presets(model, include_builtin)
async def get_preset(self, preset_id: str) -> Optional[Preset]:
"""Get a preset by ID."""
return await self.database.get_preset(preset_id)
async def delete_preset(self, preset_id: str) -> bool:
"""Delete a custom preset.
Note: Built-in presets cannot be deleted.
Args:
preset_id: Preset ID
Returns:
True if deleted
"""
result = await self.database.delete_preset(preset_id)
if result:
logger.info(f"Deleted preset: {preset_id}")
return result
# Export Operations
async def export_project(
self, project_id: str, output_path: Path, include_metadata: bool = True
) -> Path:
"""Export a project as a ZIP archive.
Args:
project_id: Project ID
output_path: Output ZIP file path
include_metadata: Include JSON metadata file
Returns:
Path to created ZIP file
"""
import json
import tempfile
import zipfile
project = await self.database.get_project(project_id)
if project is None:
raise ValueError(f"Project not found: {project_id}")
generations = await self.database.list_generations(
project_id=project_id, limit=10000
)
with tempfile.TemporaryDirectory() as tmpdir:
tmppath = Path(tmpdir)
# Copy audio files
for gen in generations:
if gen.audio_path and Path(gen.audio_path).exists():
src = Path(gen.audio_path)
dst = tmppath / src.name
shutil.copy2(src, dst)
# Create metadata file
if include_metadata:
metadata = {
"project": {
"id": project.id,
"name": project.name,
"description": project.description,
"created_at": project.created_at.isoformat(),
},
"generations": [
{
"id": g.id,
"model": g.model,
"variant": g.variant,
"prompt": g.prompt,
"parameters": g.parameters,
"duration": g.duration_seconds,
"audio_file": Path(g.audio_path).name if g.audio_path else None,
"created_at": g.created_at.isoformat(),
"tags": g.tags,
"seed": g.seed,
}
for g in generations
],
}
metadata_path = tmppath / "metadata.json"
metadata_path.write_text(json.dumps(metadata, indent=2))
# Create ZIP
output_path = Path(output_path)
with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zf:
for file in tmppath.iterdir():
zf.write(file, file.name)
logger.info(f"Exported project {project_id} to {output_path}")
return output_path
# Statistics
async def get_stats(self) -> dict[str, Any]:
"""Get overall statistics."""
return await self.database.get_stats()
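The export flow in `export_project` (stage copies in a temporary directory, write `metadata.json`, zip everything flat) can be exercised standalone. A sketch with in-memory bytes instead of real generations (`export_bundle` is a hypothetical helper, not part of the service):

```python
import json
import tempfile
import zipfile
from pathlib import Path

def export_bundle(files: dict[str, bytes], metadata: dict, output_path: Path) -> Path:
    """Stage files plus a metadata.json in a temp dir, then zip them flat."""
    with tempfile.TemporaryDirectory() as tmpdir:
        tmppath = Path(tmpdir)
        for name, data in files.items():
            (tmppath / name).write_bytes(data)
        (tmppath / "metadata.json").write_text(json.dumps(metadata, indent=2))
        with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zf:
            for file in tmppath.iterdir():
                zf.write(file, file.name)  # flat archive: member name only, no path
    return output_path

out = export_bundle(
    {"a.wav": b"RIFF"},
    {"project": "demo"},
    Path(tempfile.gettempdir()) / "demo_export.zip",
)
print(sorted(zipfile.ZipFile(out).namelist()))  # ['a.wav', 'metadata.json']
```

Staging through a temporary directory keeps the ZIP write atomic with respect to the source files and guarantees cleanup even if zipping fails.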

5
src/storage/__init__.py Normal file

@@ -0,0 +1,5 @@
"""Storage module for AudioCraft Studio."""
from src.storage.database import Database, Generation, Project, Preset
__all__ = ["Database", "Generation", "Project", "Preset"]

550
src/storage/database.py Normal file

@@ -0,0 +1,550 @@
"""SQLite database for projects, generations, and presets."""
import json
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Optional
import aiosqlite
logger = logging.getLogger(__name__)
@dataclass
class Project:
"""Project entity for organizing generations."""
id: str
name: str
created_at: datetime
updated_at: datetime
description: str = ""
@classmethod
def create(cls, name: str, description: str = "") -> "Project":
"""Create a new project with generated ID."""
now = datetime.utcnow()
return cls(
id=f"proj_{uuid.uuid4().hex[:12]}",
name=name,
created_at=now,
updated_at=now,
description=description,
)
@dataclass
class Generation:
"""Audio generation record."""
id: str
project_id: Optional[str]
model: str
variant: str
prompt: str
parameters: dict[str, Any]
created_at: datetime
audio_path: Optional[str] = None
duration_seconds: Optional[float] = None
sample_rate: Optional[int] = None
preset_used: Optional[str] = None
conditioning: dict[str, Any] = field(default_factory=dict)
name: Optional[str] = None
tags: list[str] = field(default_factory=list)
notes: Optional[str] = None
seed: Optional[int] = None
@classmethod
def create(
cls,
model: str,
variant: str,
prompt: str,
parameters: dict[str, Any],
project_id: Optional[str] = None,
**kwargs,
) -> "Generation":
"""Create a new generation record."""
return cls(
id=f"gen_{uuid.uuid4().hex[:12]}",
project_id=project_id,
model=model,
variant=variant,
prompt=prompt,
parameters=parameters,
created_at=datetime.utcnow(),
**kwargs,
)
@dataclass
class Preset:
"""Generation parameter preset."""
id: str
model: str
name: str
parameters: dict[str, Any]
created_at: datetime
description: str = ""
is_builtin: bool = False
@classmethod
def create(
cls,
model: str,
name: str,
parameters: dict[str, Any],
description: str = "",
) -> "Preset":
"""Create a new custom preset."""
return cls(
id=f"preset_{uuid.uuid4().hex[:12]}",
model=model,
name=name,
parameters=parameters,
created_at=datetime.utcnow(),
description=description,
is_builtin=False,
)
class Database:
"""Async SQLite database for AudioCraft Studio.
Handles storage of projects, generations, and presets.
"""
SCHEMA = """
CREATE TABLE IF NOT EXISTS projects (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
description TEXT DEFAULT '',
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS generations (
id TEXT PRIMARY KEY,
project_id TEXT REFERENCES projects(id) ON DELETE SET NULL,
model TEXT NOT NULL,
variant TEXT NOT NULL,
prompt TEXT NOT NULL,
parameters JSON NOT NULL,
preset_used TEXT,
conditioning JSON,
audio_path TEXT,
duration_seconds REAL,
sample_rate INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
name TEXT,
tags JSON,
notes TEXT,
seed INTEGER
);
CREATE TABLE IF NOT EXISTS presets (
id TEXT PRIMARY KEY,
model TEXT NOT NULL,
name TEXT NOT NULL,
description TEXT DEFAULT '',
parameters JSON NOT NULL,
is_builtin BOOLEAN DEFAULT FALSE,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_generations_project ON generations(project_id);
CREATE INDEX IF NOT EXISTS idx_generations_created ON generations(created_at DESC);
CREATE INDEX IF NOT EXISTS idx_generations_model ON generations(model);
CREATE INDEX IF NOT EXISTS idx_presets_model ON presets(model);
"""
def __init__(self, db_path: Path):
"""Initialize database.
Args:
db_path: Path to SQLite database file
"""
self.db_path = db_path
self._connection: Optional[aiosqlite.Connection] = None
async def connect(self) -> None:
"""Open database connection and initialize schema."""
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self._connection = await aiosqlite.connect(self.db_path)
        self._connection.row_factory = aiosqlite.Row
        # SQLite leaves foreign key enforcement off by default; enable it so
        # ON DELETE SET NULL on generations.project_id actually fires
        await self._connection.execute("PRAGMA foreign_keys = ON")
# Initialize schema
await self._connection.executescript(self.SCHEMA)
await self._connection.commit()
logger.info(f"Database connected: {self.db_path}")
async def close(self) -> None:
"""Close database connection."""
if self._connection:
await self._connection.close()
self._connection = None
@property
def conn(self) -> aiosqlite.Connection:
"""Get active connection."""
if not self._connection:
raise RuntimeError("Database not connected")
return self._connection
# Project Methods
async def create_project(self, project: Project) -> Project:
"""Create a new project."""
await self.conn.execute(
"""
INSERT INTO projects (id, name, description, created_at, updated_at)
VALUES (?, ?, ?, ?, ?)
""",
(
project.id,
project.name,
project.description,
project.created_at.isoformat(),
project.updated_at.isoformat(),
),
)
await self.conn.commit()
return project
async def get_project(self, project_id: str) -> Optional[Project]:
"""Get a project by ID."""
async with self.conn.execute(
"SELECT * FROM projects WHERE id = ?", (project_id,)
) as cursor:
row = await cursor.fetchone()
if row:
return Project(
id=row["id"],
name=row["name"],
description=row["description"] or "",
created_at=datetime.fromisoformat(row["created_at"]),
updated_at=datetime.fromisoformat(row["updated_at"]),
)
return None
async def list_projects(
self, limit: int = 100, offset: int = 0
) -> list[Project]:
"""List all projects, ordered by last update."""
async with self.conn.execute(
"""
SELECT * FROM projects
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
""",
(limit, offset),
) as cursor:
rows = await cursor.fetchall()
return [
Project(
id=row["id"],
name=row["name"],
description=row["description"] or "",
created_at=datetime.fromisoformat(row["created_at"]),
updated_at=datetime.fromisoformat(row["updated_at"]),
)
for row in rows
]
async def update_project(self, project: Project) -> None:
"""Update a project."""
project.updated_at = datetime.utcnow()
await self.conn.execute(
"""
UPDATE projects SET name = ?, description = ?, updated_at = ?
WHERE id = ?
""",
(project.name, project.description, project.updated_at.isoformat(), project.id),
)
await self.conn.commit()
async def delete_project(self, project_id: str) -> bool:
"""Delete a project (generations are kept but unlinked)."""
result = await self.conn.execute(
"DELETE FROM projects WHERE id = ?", (project_id,)
)
await self.conn.commit()
return result.rowcount > 0
# Generation Methods
async def create_generation(self, generation: Generation) -> Generation:
"""Create a new generation record."""
await self.conn.execute(
"""
INSERT INTO generations (
id, project_id, model, variant, prompt, parameters,
preset_used, conditioning, audio_path, duration_seconds,
sample_rate, created_at, name, tags, notes, seed
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
generation.id,
generation.project_id,
generation.model,
generation.variant,
generation.prompt,
json.dumps(generation.parameters),
generation.preset_used,
json.dumps(generation.conditioning),
generation.audio_path,
generation.duration_seconds,
generation.sample_rate,
generation.created_at.isoformat(),
generation.name,
json.dumps(generation.tags),
generation.notes,
generation.seed,
),
)
await self.conn.commit()
# Update project's updated_at if linked
if generation.project_id:
await self.conn.execute(
"UPDATE projects SET updated_at = ? WHERE id = ?",
(datetime.utcnow().isoformat(), generation.project_id),
)
await self.conn.commit()
return generation
async def get_generation(self, generation_id: str) -> Optional[Generation]:
"""Get a generation by ID."""
async with self.conn.execute(
"SELECT * FROM generations WHERE id = ?", (generation_id,)
) as cursor:
row = await cursor.fetchone()
if row:
return self._row_to_generation(row)
return None
async def list_generations(
self,
project_id: Optional[str] = None,
model: Optional[str] = None,
limit: int = 100,
offset: int = 0,
search: Optional[str] = None,
) -> list[Generation]:
"""List generations with optional filters."""
conditions = []
params = []
if project_id:
conditions.append("project_id = ?")
params.append(project_id)
if model:
conditions.append("model = ?")
params.append(model)
if search:
conditions.append("(prompt LIKE ? OR name LIKE ? OR tags LIKE ?)")
search_pattern = f"%{search}%"
params.extend([search_pattern, search_pattern, search_pattern])
where_clause = " AND ".join(conditions) if conditions else "1=1"
async with self.conn.execute(
f"""
SELECT * FROM generations
WHERE {where_clause}
ORDER BY created_at DESC
LIMIT ? OFFSET ?
""",
(*params, limit, offset),
) as cursor:
rows = await cursor.fetchall()
return [self._row_to_generation(row) for row in rows]
async def update_generation(self, generation: Generation) -> None:
"""Update a generation record."""
await self.conn.execute(
"""
UPDATE generations SET
project_id = ?, name = ?, tags = ?, notes = ?,
audio_path = ?, duration_seconds = ?, sample_rate = ?
WHERE id = ?
""",
(
generation.project_id,
generation.name,
json.dumps(generation.tags),
generation.notes,
generation.audio_path,
generation.duration_seconds,
generation.sample_rate,
generation.id,
),
)
await self.conn.commit()
async def delete_generation(self, generation_id: str) -> bool:
"""Delete a generation record."""
result = await self.conn.execute(
"DELETE FROM generations WHERE id = ?", (generation_id,)
)
await self.conn.commit()
return result.rowcount > 0
async def count_generations(
self, project_id: Optional[str] = None, model: Optional[str] = None
) -> int:
"""Count generations with optional filters."""
conditions = []
params = []
if project_id:
conditions.append("project_id = ?")
params.append(project_id)
if model:
conditions.append("model = ?")
params.append(model)
where_clause = " AND ".join(conditions) if conditions else "1=1"
async with self.conn.execute(
f"SELECT COUNT(*) FROM generations WHERE {where_clause}",
params,
) as cursor:
row = await cursor.fetchone()
return row[0] if row else 0
def _row_to_generation(self, row: aiosqlite.Row) -> Generation:
"""Convert database row to Generation object."""
return Generation(
id=row["id"],
project_id=row["project_id"],
model=row["model"],
variant=row["variant"],
prompt=row["prompt"],
parameters=json.loads(row["parameters"]),
preset_used=row["preset_used"],
conditioning=json.loads(row["conditioning"]) if row["conditioning"] else {},
audio_path=row["audio_path"],
duration_seconds=row["duration_seconds"],
sample_rate=row["sample_rate"],
created_at=datetime.fromisoformat(row["created_at"]),
name=row["name"],
tags=json.loads(row["tags"]) if row["tags"] else [],
notes=row["notes"],
seed=row["seed"],
)
# Preset Methods
async def create_preset(self, preset: Preset) -> Preset:
"""Create a new preset."""
await self.conn.execute(
"""
INSERT INTO presets (id, model, name, description, parameters, is_builtin, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(
preset.id,
preset.model,
preset.name,
preset.description,
json.dumps(preset.parameters),
preset.is_builtin,
preset.created_at.isoformat(),
),
)
await self.conn.commit()
return preset
async def get_preset(self, preset_id: str) -> Optional[Preset]:
"""Get a preset by ID."""
async with self.conn.execute(
"SELECT * FROM presets WHERE id = ?", (preset_id,)
) as cursor:
row = await cursor.fetchone()
if row:
return self._row_to_preset(row)
return None
async def list_presets(
self, model: Optional[str] = None, include_builtin: bool = True
) -> list[Preset]:
"""List presets with optional model filter."""
conditions = []
params = []
if model:
conditions.append("model = ?")
params.append(model)
if not include_builtin:
conditions.append("is_builtin = FALSE")
where_clause = " AND ".join(conditions) if conditions else "1=1"
async with self.conn.execute(
f"""
SELECT * FROM presets
WHERE {where_clause}
ORDER BY is_builtin DESC, name ASC
""",
params,
) as cursor:
rows = await cursor.fetchall()
return [self._row_to_preset(row) for row in rows]
async def delete_preset(self, preset_id: str) -> bool:
"""Delete a preset (only custom presets can be deleted)."""
result = await self.conn.execute(
"DELETE FROM presets WHERE id = ? AND is_builtin = FALSE",
(preset_id,),
)
await self.conn.commit()
return result.rowcount > 0
def _row_to_preset(self, row: aiosqlite.Row) -> Preset:
"""Convert database row to Preset object."""
return Preset(
id=row["id"],
model=row["model"],
name=row["name"],
description=row["description"] or "",
parameters=json.loads(row["parameters"]),
is_builtin=bool(row["is_builtin"]),
created_at=datetime.fromisoformat(row["created_at"]),
)
# Utility Methods
async def get_stats(self) -> dict[str, Any]:
"""Get database statistics."""
stats = {}
async with self.conn.execute("SELECT COUNT(*) FROM projects") as cursor:
row = await cursor.fetchone()
stats["projects"] = row[0] if row else 0
async with self.conn.execute("SELECT COUNT(*) FROM generations") as cursor:
row = await cursor.fetchone()
stats["generations"] = row[0] if row else 0
async with self.conn.execute("SELECT COUNT(*) FROM presets") as cursor:
row = await cursor.fetchone()
stats["presets"] = row[0] if row else 0
async with self.conn.execute(
"SELECT model, COUNT(*) as count FROM generations GROUP BY model"
) as cursor:
rows = await cursor.fetchall()
stats["generations_by_model"] = {row["model"]: row["count"] for row in rows}
return stats
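The storage pattern above — JSON-encoded columns decoded on read, plus a dynamically assembled `WHERE` clause — is independent of the async driver. A sketch of the same pattern against the stdlib `sqlite3` module (synchronous and in-memory; the table and data are illustrative):

```python
import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row  # rows behave like mappings, as with aiosqlite.Row
conn.execute("CREATE TABLE generations (id TEXT PRIMARY KEY, model TEXT, tags JSON)")
conn.execute(
    "INSERT INTO generations VALUES (?, ?, ?)",
    ("gen_1", "musicgen", json.dumps(["demo", "lofi"])),
)

def list_generations(model=None, search=None):
    conditions, params = [], []
    if model:
        conditions.append("model = ?")
        params.append(model)
    if search:
        # LIKE over the JSON text is a cheap substring search, as in the service
        conditions.append("tags LIKE ?")
        params.append(f"%{search}%")
    where = " AND ".join(conditions) if conditions else "1=1"
    rows = conn.execute(f"SELECT * FROM generations WHERE {where}", params).fetchall()
    # Decode the JSON column back into Python lists on the way out
    return [{**dict(r), "tags": json.loads(r["tags"])} for r in rows]

print(list_generations(model="musicgen", search="lofi"))
```

Only the column names are interpolated into the SQL; all values go through `?` placeholders, which is what keeps the dynamic clause safe.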

5
src/ui/__init__.py Normal file

@@ -0,0 +1,5 @@
"""Gradio UI for AudioCraft Studio."""
from src.ui.app import create_app
__all__ = ["create_app"]

355
src/ui/app.py Normal file

@@ -0,0 +1,355 @@
"""Main Gradio application for AudioCraft Studio."""
import asyncio
import gradio as gr
from typing import Any, Optional
from pathlib import Path
from src.ui.theme import create_audiocraft_theme, get_custom_css
from src.ui.state import UIState, DEFAULT_PRESETS, PROMPT_SUGGESTIONS
from src.ui.components.vram_monitor import create_vram_monitor
from src.ui.tabs import (
create_dashboard_tab,
create_musicgen_tab,
create_audiogen_tab,
create_magnet_tab,
create_style_tab,
create_jasco_tab,
)
from src.ui.pages import create_projects_page, create_settings_page
from config.settings import get_settings
class AudioCraftApp:
"""Main AudioCraft Studio Gradio application."""
def __init__(
self,
generation_service: Any = None,
batch_processor: Any = None,
project_service: Any = None,
gpu_manager: Any = None,
model_registry: Any = None,
):
"""Initialize the application.
Args:
generation_service: Service for handling generations
batch_processor: Service for batch/queue processing
project_service: Service for project management
gpu_manager: GPU memory manager
model_registry: Model registry for loading/unloading
"""
self.settings = get_settings()
self.generation_service = generation_service
self.batch_processor = batch_processor
self.project_service = project_service
self.gpu_manager = gpu_manager
self.model_registry = model_registry
self.ui_state = UIState()
self.app: Optional[gr.Blocks] = None
def _get_queue_status(self) -> dict[str, Any]:
"""Get current queue status."""
if self.batch_processor:
return {
"queue_size": len(self.batch_processor.queue),
"active_jobs": self.batch_processor.active_count,
"completed_today": self.batch_processor.completed_count,
}
return {"queue_size": 0, "active_jobs": 0, "completed_today": 0}
def _get_recent_generations(self, limit: int = 5) -> list[dict[str, Any]]:
"""Get recent generations."""
if self.project_service:
try:
return asyncio.run(self.project_service.get_recent_generations(limit))
except Exception:
pass
return []
def _get_gpu_status(self) -> dict[str, Any]:
"""Get GPU memory status."""
if self.gpu_manager:
return {
"used_gb": self.gpu_manager.get_used_memory() / 1024**3,
"total_gb": self.gpu_manager.total_memory / 1024**3,
"utilization_percent": self.gpu_manager.get_utilization(),
"available_gb": self.gpu_manager.get_available_memory() / 1024**3,
}
return {"used_gb": 0, "total_gb": 24, "utilization_percent": 0, "available_gb": 24}
async def _generate(self, **kwargs) -> tuple[Any, Any]:
"""Generate audio using the generation service."""
if self.generation_service:
return await self.generation_service.generate(**kwargs)
raise RuntimeError("Generation service not configured")
def _add_to_queue(self, **kwargs) -> Any:
"""Add generation job to queue."""
if self.batch_processor:
return self.batch_processor.add_job(**kwargs)
raise RuntimeError("Batch processor not configured")
def _get_projects(self) -> list[dict]:
"""Get all projects."""
if self.project_service:
try:
return asyncio.run(self.project_service.list_projects())
except Exception:
pass
return []
def _get_generations(self, project_id: str, limit: int, offset: int) -> list[dict]:
"""Get generations for a project."""
if self.project_service:
try:
return asyncio.run(
self.project_service.list_generations(project_id, limit, offset)
)
except Exception:
pass
return []
def _delete_generation(self, generation_id: str) -> bool:
"""Delete a generation."""
if self.project_service:
try:
asyncio.run(self.project_service.delete_generation(generation_id))
return True
except Exception:
pass
return False
def _export_project(self, project_id: str) -> str:
"""Export project as ZIP."""
if self.project_service:
return asyncio.run(self.project_service.export_project_zip(project_id))
raise RuntimeError("Project service not configured")
def _create_project(self, name: str, description: str) -> dict:
"""Create a new project."""
if self.project_service:
return asyncio.run(self.project_service.create_project(name, description))
raise RuntimeError("Project service not configured")
def _get_app_settings(self) -> dict:
"""Get application settings."""
return {
"output_dir": str(self.settings.output_dir),
"default_format": self.settings.default_format,
"sample_rate": self.settings.sample_rate,
"normalize_audio": self.settings.normalize_audio,
"theme_mode": "Dark",
"show_advanced": False,
"auto_play": True,
"comfyui_reserve_gb": self.settings.comfyui_reserve_gb,
"idle_timeout_minutes": self.settings.idle_unload_minutes,
"max_loaded_models": self.settings.max_loaded_models,
"musicgen_variant": "medium",
"musicgen_duration": 10,
"audiogen_duration": 5,
"magnet_variant": "medium",
"magnet_decoding_steps": 20,
"api_enabled": self.settings.api_enabled,
"api_port": self.settings.api_port,
"rate_limit": self.settings.api_rate_limit,
"max_batch_size": self.settings.max_batch_size,
"max_queue_size": self.settings.max_queue_size,
"max_workers": 1,
"priority_queue": False,
}
def _update_app_settings(self, settings: dict) -> bool:
"""Update application settings."""
# In a real implementation, this would persist settings
# For now, just return success
return True
def _clear_cache(self) -> bool:
"""Clear model cache."""
if self.model_registry:
try:
self.model_registry.clear_cache()
return True
except Exception:
pass
return False
def _unload_all_models(self) -> bool:
"""Unload all models from memory."""
if self.model_registry:
try:
asyncio.run(self.model_registry.unload_all())
return True
except Exception:
pass
return False
def build(self) -> gr.Blocks:
"""Build the Gradio application."""
theme = create_audiocraft_theme()
css = get_custom_css()
with gr.Blocks(
theme=theme,
css=css,
title="AudioCraft Studio",
analytics_enabled=False,
) as app:
# Header with VRAM monitor
with gr.Row():
with gr.Column(scale=4):
gr.Markdown("# AudioCraft Studio")
with gr.Column(scale=1):
                    vram_monitor = create_vram_monitor(
                        get_gpu_status=self._get_gpu_status,
                        # Model list/load/unload wiring is stubbed until the
                        # registry exposes these operations to the UI layer
                        get_loaded_models=lambda: [],
                        unload_model=lambda model_id, variant: False,
                        load_model=lambda model_id, variant: False,
                    )
# Main tabs
with gr.Tabs() as main_tabs:
# Dashboard
with gr.TabItem("Dashboard", id="dashboard"):
dashboard = create_dashboard_tab(
get_queue_status=self._get_queue_status,
get_recent_generations=self._get_recent_generations,
get_gpu_status=self._get_gpu_status,
)
# Model tabs
with gr.TabItem("MusicGen", id="musicgen"):
musicgen = create_musicgen_tab(
generate_fn=self._generate,
add_to_queue_fn=self._add_to_queue,
)
with gr.TabItem("AudioGen", id="audiogen"):
audiogen = create_audiogen_tab(
generate_fn=self._generate,
add_to_queue_fn=self._add_to_queue,
)
with gr.TabItem("MAGNeT", id="magnet"):
magnet = create_magnet_tab(
generate_fn=self._generate,
add_to_queue_fn=self._add_to_queue,
)
with gr.TabItem("Style", id="style"):
style = create_style_tab(
generate_fn=self._generate,
add_to_queue_fn=self._add_to_queue,
)
with gr.TabItem("JASCO", id="jasco"):
jasco = create_jasco_tab(
generate_fn=self._generate,
add_to_queue_fn=self._add_to_queue,
)
# Projects
with gr.TabItem("Projects", id="projects"):
projects = create_projects_page(
get_projects=self._get_projects,
get_generations=self._get_generations,
delete_generation=self._delete_generation,
export_project=self._export_project,
create_project=self._create_project,
)
# Settings
with gr.TabItem("Settings", id="settings"):
settings = create_settings_page(
get_settings=self._get_app_settings,
update_settings=self._update_app_settings,
get_gpu_info=self._get_gpu_status,
clear_cache=self._clear_cache,
unload_all_models=self._unload_all_models,
)
# Footer
gr.Markdown("---")
gr.Markdown(
"AudioCraft Studio | "
"[Documentation](https://github.com/facebookresearch/audiocraft) | "
"Powered by Meta AudioCraft"
)
# Store component references
self.components = {
"vram_monitor": vram_monitor,
"dashboard": dashboard,
"musicgen": musicgen,
"audiogen": audiogen,
"magnet": magnet,
"style": style,
"jasco": jasco,
"projects": projects,
"settings": settings,
}
self.app = app
return app
def launch(
self,
server_name: Optional[str] = None,
server_port: Optional[int] = None,
share: bool = False,
**kwargs,
) -> None:
"""Launch the Gradio application.
Args:
server_name: Server hostname
server_port: Server port
share: Whether to create a public share link
**kwargs: Additional arguments for gr.Blocks.launch()
"""
if self.app is None:
self.build()
self.app.launch(
server_name=server_name or self.settings.host,
server_port=server_port or self.settings.gradio_port,
share=share,
show_error=True,
**kwargs,
)
def create_app(
generation_service: Any = None,
batch_processor: Any = None,
project_service: Any = None,
gpu_manager: Any = None,
model_registry: Any = None,
) -> AudioCraftApp:
"""Create and return the AudioCraft application.
Args:
generation_service: Service for handling generations
batch_processor: Service for batch/queue processing
project_service: Service for project management
gpu_manager: GPU memory manager
model_registry: Model registry for loading/unloading
Returns:
Configured AudioCraftApp instance
"""
return AudioCraftApp(
generation_service=generation_service,
batch_processor=batch_processor,
project_service=project_service,
gpu_manager=gpu_manager,
model_registry=model_registry,
)
# Standalone launch for development/testing
if __name__ == "__main__":
app = create_app()
app.launch()
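Several of the private helpers above bridge Gradio's synchronous callbacks to async services via `asyncio.run`; a minimal sketch of that pattern, using a hypothetical stand-in service:

```python
import asyncio

class FakeProjectService:
    """Hypothetical stand-in for the async project service."""
    async def list_projects(self):
        return [{"id": "p1", "name": "Demo"}]

def get_projects(service):
    # Synchronous wrapper, as in AudioCraftApp._get_projects().
    # Note: asyncio.run() raises RuntimeError if an event loop is
    # already running in the calling thread, so failures fall back
    # to an empty default rather than crashing the UI callback.
    try:
        return asyncio.run(service.list_projects())
    except Exception:
        return []

projects = get_projects(FakeProjectService())
```

That loop caveat is why the helpers above swallow exceptions and return empty defaults instead of propagating errors into the Gradio event handlers.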

13
src/ui/components/__init__.py Normal file
View File

@@ -0,0 +1,13 @@
"""Reusable UI components for AudioCraft Studio."""
from src.ui.components.vram_monitor import create_vram_monitor
from src.ui.components.audio_player import create_audio_player
from src.ui.components.preset_selector import create_preset_selector
from src.ui.components.generation_params import create_generation_params
__all__ = [
"create_vram_monitor",
"create_audio_player",
"create_preset_selector",
"create_generation_params",
]

178
src/ui/components/audio_player.py Normal file
View File

@@ -0,0 +1,178 @@
"""Audio player component with waveform visualization."""
import gradio as gr
from pathlib import Path
from typing import Any, Optional, Callable
def create_audio_player(
label: str = "Generated Audio",
show_waveform: bool = True,
show_download: bool = True,
show_info: bool = True,
) -> dict[str, Any]:
"""Create audio player component with optional waveform.
Args:
label: Label for the audio component
show_waveform: Show waveform image
show_download: Show download buttons
show_info: Show audio info (duration, sample rate)
Returns:
Dictionary with component references
"""
with gr.Group():
# Audio player
audio_output = gr.Audio(
label=label,
type="filepath",
interactive=False,
show_download_button=show_download,
)
# Waveform visualization
if show_waveform:
waveform_image = gr.Image(
label="Waveform",
type="filepath",
interactive=False,
height=100,
visible=False,
)
else:
waveform_image = None
# Audio info
if show_info:
with gr.Row():
duration_text = gr.Textbox(
label="Duration",
value="",
interactive=False,
scale=1,
)
sample_rate_text = gr.Textbox(
label="Sample Rate",
value="",
interactive=False,
scale=1,
)
seed_text = gr.Textbox(
label="Seed",
value="",
interactive=False,
scale=1,
)
else:
duration_text = None
sample_rate_text = None
seed_text = None
# Download buttons
if show_download:
with gr.Row():
download_wav = gr.Button("Download WAV", size="sm")
download_mp3 = gr.Button("Download MP3", size="sm")
download_flac = gr.Button("Download FLAC", size="sm")
else:
download_wav = download_mp3 = download_flac = None
return {
"audio": audio_output,
"waveform": waveform_image,
"duration": duration_text,
"sample_rate": sample_rate_text,
"seed": seed_text,
"download_wav": download_wav,
"download_mp3": download_mp3,
"download_flac": download_flac,
}
def update_audio_player(
audio_path: Optional[str],
duration: Optional[float] = None,
sample_rate: Optional[int] = None,
seed: Optional[int] = None,
waveform_path: Optional[str] = None,
) -> tuple:
"""Update audio player with new audio.
Args:
audio_path: Path to audio file
duration: Audio duration in seconds
sample_rate: Audio sample rate
seed: Generation seed
waveform_path: Path to waveform image
Returns:
Tuple of update values for components
"""
duration_str = f"{duration:.2f}s" if duration else ""
sample_rate_str = f"{sample_rate} Hz" if sample_rate else ""
seed_str = str(seed) if seed is not None else ""
waveform_update = gr.update(value=waveform_path, visible=waveform_path is not None)
return (
audio_path,
waveform_update,
duration_str,
sample_rate_str,
seed_str,
)
def create_generation_output() -> dict[str, Any]:
"""Create generation output section with audio player and metadata.
Returns:
Dictionary with component references
"""
with gr.Group():
gr.Markdown("### Output")
# Status/progress
with gr.Row():
status_text = gr.Markdown("Ready to generate")
progress_bar = gr.Slider(
minimum=0,
maximum=100,
value=0,
label="Progress",
interactive=False,
visible=False,
)
# Audio player
player = create_audio_player(
label="Generated Audio",
show_waveform=True,
show_download=True,
show_info=True,
)
# Generation metadata
with gr.Accordion("Generation Details", open=False):
generation_info = gr.JSON(
label="Parameters",
value={},
)
# Actions
with gr.Row():
save_btn = gr.Button("Save to Project", variant="secondary")
regenerate_btn = gr.Button("Regenerate", variant="secondary")
add_queue_btn = gr.Button("Add to Queue", variant="secondary")
return {
"status": status_text,
"progress": progress_bar,
"player": player,
"info": generation_info,
"save_btn": save_btn,
"regenerate_btn": regenerate_btn,
"add_queue_btn": add_queue_btn,
}
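The display strings assembled by `update_audio_player` can be isolated as a pure helper; a small sketch (the function name is illustrative, not part of the module):

```python
def format_audio_info(duration=None, sample_rate=None, seed=None):
    # Mirrors the formatting in update_audio_player(): empty strings
    # stand in for missing metadata, and a seed of 0 is still displayed
    duration_str = f"{duration:.2f}s" if duration else ""
    sample_rate_str = f"{sample_rate} Hz" if sample_rate else ""
    seed_str = str(seed) if seed is not None else ""
    return duration_str, sample_rate_str, seed_str

print(format_audio_info(10.0, 32000, 42))  # → ('10.00s', '32000 Hz', '42')
```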

199
src/ui/components/generation_params.py Normal file
View File

@@ -0,0 +1,199 @@
"""Generation parameters component."""
import gradio as gr
from typing import Any, Optional
def create_generation_params(
model_id: str,
show_advanced: bool = False,
max_duration: float = 30.0,
) -> dict[str, Any]:
"""Create generation parameters panel.
Args:
model_id: Model family for customizing available options
show_advanced: Whether to show advanced parameters by default
max_duration: Maximum allowed duration
Returns:
Dictionary with component references
"""
# Model-specific defaults
defaults = {
"musicgen": {"duration": 10, "temperature": 1.0, "top_k": 250, "top_p": 0.0, "cfg_coef": 3.0},
"audiogen": {"duration": 5, "temperature": 1.0, "top_k": 250, "top_p": 0.0, "cfg_coef": 3.0},
"magnet": {"duration": 10, "temperature": 3.0, "top_k": 0, "top_p": 0.9, "cfg_coef": 3.0},
"musicgen-style": {"duration": 10, "temperature": 1.0, "top_k": 250, "top_p": 0.0, "cfg_coef": 3.0},
"jasco": {"duration": 10, "temperature": 1.0, "top_k": 250, "top_p": 0.0, "cfg_coef": 3.0},
}
d = defaults.get(model_id, defaults["musicgen"])
with gr.Group():
# Basic parameters (always visible)
duration_slider = gr.Slider(
minimum=1,
maximum=max_duration,
value=d["duration"],
step=1,
label="Duration (seconds)",
info="Length of audio to generate",
)
# Advanced parameters (expandable)
with gr.Accordion("Advanced Parameters", open=show_advanced):
with gr.Row():
temperature_slider = gr.Slider(
minimum=0.0,
maximum=2.0,
value=d["temperature"],
step=0.05,
label="Temperature",
info="Higher = more random, lower = more deterministic",
)
cfg_slider = gr.Slider(
minimum=1.0,
maximum=10.0,
value=d["cfg_coef"],
step=0.5,
label="CFG Coefficient",
info="Classifier-free guidance strength",
)
with gr.Row():
top_k_slider = gr.Slider(
minimum=0,
maximum=500,
value=d["top_k"],
step=10,
label="Top-K",
info="Token selection limit (0 = disabled)",
)
top_p_slider = gr.Slider(
minimum=0.0,
maximum=1.0,
value=d["top_p"],
step=0.05,
label="Top-P (Nucleus)",
info="Cumulative probability threshold (0 = disabled)",
)
with gr.Row():
seed_input = gr.Number(
value=None,
label="Seed",
info="Random seed for reproducibility (leave empty for random)",
precision=0,
)
use_random_seed = gr.Checkbox(
value=True,
label="Random Seed",
)
# Reset button
reset_btn = gr.Button("Reset to Defaults", size="sm", variant="secondary")
def reset_params():
"""Reset all parameters to defaults."""
return (
d["duration"],
d["temperature"],
d["cfg_coef"],
d["top_k"],
d["top_p"],
None,
True,
)
reset_btn.click(
fn=reset_params,
outputs=[
duration_slider,
temperature_slider,
cfg_slider,
top_k_slider,
top_p_slider,
seed_input,
use_random_seed,
],
)
# Link random seed checkbox to seed input
def toggle_seed(use_random: bool, current_seed: Optional[int]):
if use_random:
return gr.update(value=None, interactive=False)
return gr.update(interactive=True)
use_random_seed.change(
fn=toggle_seed,
inputs=[use_random_seed, seed_input],
outputs=[seed_input],
)
return {
"duration": duration_slider,
"temperature": temperature_slider,
"cfg_coef": cfg_slider,
"top_k": top_k_slider,
"top_p": top_p_slider,
"seed": seed_input,
"use_random_seed": use_random_seed,
"reset_btn": reset_btn,
}
def create_model_variant_selector(
model_id: str,
variants: list[dict[str, Any]],
default_variant: str = "medium",
) -> dict[str, Any]:
"""Create model variant selector.
Args:
model_id: Model family ID
variants: List of variant configurations
default_variant: Default variant to select
Returns:
Dictionary with component references
"""
# Build choices with descriptions
choices = []
for v in variants:
name = v.get("name", v.get("id", "unknown"))
vram = v.get("vram_mb", 0)
desc = v.get("description", "")
label = f"{name} ({vram/1024:.1f}GB)"
choices.append((label, name))
with gr.Group():
variant_dropdown = gr.Dropdown(
label="Model Variant",
choices=choices,
value=default_variant,
interactive=True,
)
variant_info = gr.Markdown(
value="",
visible=True,
)
def update_info(variant_name: str):
for v in variants:
if v.get("name", v.get("id")) == variant_name:
return v.get("description", "")
return ""
variant_dropdown.change(
fn=update_info,
inputs=[variant_dropdown],
outputs=[variant_info],
)
return {
"dropdown": variant_dropdown,
"info": variant_info,
"variants": variants,
}
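The per-model defaults table above falls back to the MusicGen row for unknown families, and the variant selector formats VRAM from megabytes; both behaviors in a condensed sketch (the helper names are illustrative):

```python
DEFAULTS = {
    "musicgen": {"duration": 10, "temperature": 1.0, "cfg_coef": 3.0},
    "magnet": {"duration": 10, "temperature": 3.0, "cfg_coef": 3.0},
}

def params_for(model_id):
    # Unknown model families inherit the MusicGen defaults
    return DEFAULTS.get(model_id, DEFAULTS["musicgen"])

def variant_label(name, vram_mb):
    # Same label format as create_model_variant_selector
    return f"{name} ({vram_mb / 1024:.1f}GB)"

print(variant_label("medium", 6144))  # → medium (6.0GB)
```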

103
src/ui/components/preset_selector.py Normal file
View File

@@ -0,0 +1,103 @@
"""Preset selector component."""
import gradio as gr
from typing import Any, Callable, Optional
from src.ui.state import DEFAULT_PRESETS
def create_preset_selector(
model_id: str,
on_preset_select: Optional[Callable[[dict], None]] = None,
) -> dict[str, Any]:
"""Create preset selector component for a model.
Args:
model_id: Model family ID
on_preset_select: Callback when preset is selected
Returns:
Dictionary with component references
"""
presets = DEFAULT_PRESETS.get(model_id, [])
# Create preset choices
choices = [(p["name"], p["id"]) for p in presets]
choices.append(("Custom", "custom"))
def get_preset_by_id(preset_id: str) -> Optional[dict]:
"""Get preset data by ID."""
for p in presets:
if p["id"] == preset_id:
return p
return None
def on_change(preset_id: str):
"""Handle preset selection change."""
if preset_id == "custom":
return gr.update(visible=True), {}
preset = get_preset_by_id(preset_id)
if preset:
return gr.update(visible=False), preset.get("parameters", {})
return gr.update(visible=True), {}
with gr.Group():
preset_dropdown = gr.Dropdown(
label="Preset",
choices=choices,
value=presets[0]["id"] if presets else "custom",
interactive=True,
)
        preset_description = gr.Markdown(
            value=presets[0].get("description", "") if presets else "",
            visible=True,
        )
return {
"dropdown": preset_dropdown,
"description": preset_description,
"presets": presets,
"get_preset": get_preset_by_id,
"on_change": on_change,
}
def create_preset_chips(
model_id: str,
on_select: Callable[[str], None],
) -> dict[str, Any]:
"""Create preset selector as clickable chips/buttons.
Args:
model_id: Model family ID
on_select: Callback when preset is clicked
Returns:
Dictionary with component references
"""
presets = DEFAULT_PRESETS.get(model_id, [])
with gr.Row():
buttons = []
for preset in presets:
btn = gr.Button(
preset["name"],
size="sm",
variant="secondary",
)
buttons.append((btn, preset))
custom_btn = gr.Button(
"Custom",
size="sm",
variant="secondary",
)
return {
"buttons": buttons,
"custom_btn": custom_btn,
"presets": presets,
}
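The dropdown's `on_change` logic reduces to a lookup that either applies a preset's parameters or reveals the manual controls; a pure-Python sketch with a hypothetical preset list:

```python
PRESETS = [
    {"id": "lofi", "name": "Lo-Fi", "parameters": {"temperature": 0.9}},
]

def resolve_preset(preset_id):
    # Mirrors on_change(): a known id hides the custom panel and returns
    # its parameters; "custom" (or an unknown id) shows the panel instead
    for p in PRESETS:
        if p["id"] == preset_id:
            return False, p.get("parameters", {})
    return True, {}
```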

151
src/ui/components/vram_monitor.py Normal file
View File

@@ -0,0 +1,151 @@
"""VRAM monitor component for GPU memory tracking."""
import gradio as gr
from typing import Any, Callable, Optional
def create_vram_monitor(
get_gpu_status: Callable[[], dict[str, Any]],
get_loaded_models: Callable[[], list[dict[str, Any]]],
unload_model: Callable[[str, str], bool],
load_model: Callable[[str, str], bool],
) -> dict[str, Any]:
"""Create VRAM monitor component.
Args:
get_gpu_status: Function to get GPU status dict
get_loaded_models: Function to get list of loaded models
unload_model: Function to unload a model (model_id, variant)
load_model: Function to load a model (model_id, variant)
Returns:
Dictionary with component references
"""
def refresh_status():
"""Refresh GPU status display."""
status = get_gpu_status()
loaded = get_loaded_models()
# Format VRAM bar
used_gb = status.get("used_gb", 0)
total_gb = status.get("total_gb", 24)
util_pct = status.get("utilization_percent", 0)
vram_text = f"{used_gb:.1f} / {total_gb:.1f} GB ({util_pct:.0f}%)"
# Format loaded models list
if loaded:
models_text = "\n".join([
f"{m['model_id']}/{m['variant']} "
f"(idle: {m['idle_seconds']:.0f}s)"
for m in loaded
])
else:
models_text = "No models loaded"
# Determine status color
if util_pct > 90:
status_color = "🔴"
elif util_pct > 75:
status_color = "🟡"
else:
status_color = "🟢"
status_text = f"{status_color} GPU: {status.get('device', 'N/A')}"
return vram_text, util_pct, models_text, status_text
def handle_unload(model_selection: str):
"""Handle model unload."""
if not model_selection or "/" not in model_selection:
return "Select a model to unload", *refresh_status()
parts = model_selection.split("/")
model_id, variant = parts[0], parts[1]
success = unload_model(model_id, variant)
if success:
msg = f"Unloaded {model_id}/{variant}"
else:
msg = f"Failed to unload {model_id}/{variant}"
return msg, *refresh_status()
with gr.Group():
gr.Markdown("### GPU Memory")
status_text = gr.Markdown("🟢 GPU: Checking...")
with gr.Row():
vram_display = gr.Textbox(
label="VRAM Usage",
value="Loading...",
interactive=False,
scale=3,
)
refresh_btn = gr.Button("🔄", scale=1, min_width=50)
vram_slider = gr.Slider(
minimum=0,
maximum=100,
value=0,
label="",
interactive=False,
visible=True,
)
gr.Markdown("### Loaded Models")
models_display = gr.Textbox(
label="",
value="No models loaded",
interactive=False,
lines=4,
max_lines=6,
)
with gr.Row():
model_selector = gr.Dropdown(
label="Select Model",
choices=[],
interactive=True,
scale=3,
)
unload_btn = gr.Button("Unload", variant="secondary", scale=1)
unload_status = gr.Markdown("")
# Event handlers
def update_model_choices():
loaded = get_loaded_models()
choices = [f"{m['model_id']}/{m['variant']}" for m in loaded]
return gr.update(choices=choices, value=None)
refresh_btn.click(
fn=refresh_status,
outputs=[vram_display, vram_slider, models_display, status_text],
).then(
fn=update_model_choices,
outputs=[model_selector],
)
unload_btn.click(
fn=handle_unload,
inputs=[model_selector],
outputs=[unload_status, vram_display, vram_slider, models_display, status_text],
).then(
fn=update_model_choices,
outputs=[model_selector],
)
return {
"vram_display": vram_display,
"vram_slider": vram_slider,
"models_display": models_display,
"status_text": status_text,
"model_selector": model_selector,
"refresh_btn": refresh_btn,
"unload_btn": unload_btn,
"refresh_fn": refresh_status,
}
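The status line in `refresh_status` encodes utilization as a traffic-light emoji at the 75% and 90% thresholds; the formatting isolated as a pure helper (the name is illustrative):

```python
def vram_summary(used_gb, total_gb):
    # Same thresholds as refresh_status(): red above 90%, yellow above 75%
    util = used_gb / total_gb * 100
    if util > 90:
        color = "🔴"
    elif util > 75:
        color = "🟡"
    else:
        color = "🟢"
    return f"{used_gb:.1f} / {total_gb:.1f} GB ({util:.0f}%)", color

print(vram_summary(18.0, 24.0))  # 75% exactly is still green
```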

9
src/ui/pages/__init__.py Normal file
View File

@@ -0,0 +1,9 @@
"""UI pages for AudioCraft Studio."""
from src.ui.pages.projects_page import create_projects_page
from src.ui.pages.settings_page import create_settings_page
__all__ = [
"create_projects_page",
"create_settings_page",
]

374
src/ui/pages/projects_page.py Normal file
View File

@@ -0,0 +1,374 @@
"""Projects page for managing generations and history."""
import gradio as gr
from typing import Any, Callable, Optional
from datetime import datetime
def create_projects_page(
get_projects: Callable[[], list[dict]],
get_generations: Callable[[str, int, int], list[dict]],
delete_generation: Callable[[str], bool],
export_project: Callable[[str], str],
create_project: Callable[[str, str], dict],
) -> dict[str, Any]:
"""Create projects management page.
Args:
get_projects: Function to get all projects
get_generations: Function to get generations (project_id, limit, offset)
delete_generation: Function to delete a generation
export_project: Function to export project as ZIP
create_project: Function to create new project
Returns:
Dictionary with component references
"""
with gr.Column():
gr.Markdown("# Projects")
gr.Markdown("Browse and manage your generations")
with gr.Row():
# Left sidebar - project list
with gr.Column(scale=1):
gr.Markdown("### Projects")
with gr.Row():
new_project_name = gr.Textbox(
placeholder="New project name...",
show_label=False,
scale=3,
)
new_project_btn = gr.Button("+", size="sm", scale=1)
project_list = gr.Dataframe(
headers=["ID", "Name", "Count"],
datatype=["str", "str", "number"],
col_count=(3, "fixed"),
interactive=False,
height=400,
)
refresh_projects_btn = gr.Button("Refresh Projects", size="sm")
# Main content - generations
with gr.Column(scale=3):
# Selected project info
selected_project_id = gr.State(value=None)
selected_project_name = gr.Markdown("### Select a project")
# Filters
with gr.Row():
model_filter = gr.Dropdown(
label="Model",
choices=[
("All", "all"),
("MusicGen", "musicgen"),
("AudioGen", "audiogen"),
("MAGNeT", "magnet"),
("Style", "musicgen-style"),
("JASCO", "jasco"),
],
value="all",
scale=1,
)
sort_by = gr.Dropdown(
label="Sort By",
choices=[
("Newest First", "newest"),
("Oldest First", "oldest"),
("Duration (Long)", "duration_desc"),
("Duration (Short)", "duration_asc"),
],
value="newest",
scale=1,
)
search_input = gr.Textbox(
label="Search Prompts",
placeholder="Search...",
scale=2,
)
# Generations grid
generations_gallery = gr.Gallery(
label="Generations",
columns=3,
rows=3,
height=400,
object_fit="contain",
show_label=False,
)
# Pagination
with gr.Row():
prev_page_btn = gr.Button("← Previous", size="sm")
page_info = gr.Markdown("Page 1 of 1")
next_page_btn = gr.Button("Next →", size="sm")
current_page = gr.State(value=1)
total_pages = gr.State(value=1)
# Selected generation details
gr.Markdown("---")
gr.Markdown("### Generation Details")
with gr.Row():
with gr.Column(scale=2):
selected_audio = gr.Audio(
label="Audio",
interactive=False,
)
with gr.Column(scale=2):
selected_prompt = gr.Textbox(
label="Prompt",
interactive=False,
lines=2,
)
with gr.Row():
selected_model = gr.Textbox(
label="Model",
interactive=False,
)
selected_duration = gr.Textbox(
label="Duration",
interactive=False,
)
with gr.Row():
selected_seed = gr.Textbox(
label="Seed",
interactive=False,
)
selected_date = gr.Textbox(
label="Created",
interactive=False,
)
# Action buttons
with gr.Row():
regenerate_btn = gr.Button("Regenerate", variant="secondary")
download_btn = gr.Button("Download", variant="secondary")
delete_btn = gr.Button("Delete", variant="stop")
export_project_btn = gr.Button("Export Project", variant="secondary")
# Event handlers
def load_projects():
"""Load all projects into the list."""
projects = get_projects()
data = []
for p in projects:
data.append([
p.get("id", ""),
p.get("name", "Untitled"),
p.get("generation_count", 0),
])
return data
        def on_project_select(evt: gr.SelectData, df):
            """Handle project selection from dataframe."""
            if evt.index is None:
                return None, "### Select a project"
            row = evt.index[0]
            # Gradio may deliver the table value as a pandas DataFrame or a
            # list of lists; normalize before row indexing
            rows = df.values.tolist() if hasattr(df, "values") else df
            if row < len(rows):
                project_id = rows[row][0]
                project_name = rows[row][1]
                return project_id, f"### {project_name}"
            return None, "### Select a project"
def load_generations(project_id, page, model, sort, search):
"""Load generations for selected project."""
if not project_id:
return [], "Page 0 of 0", 1, 1
limit = 9 # 3x3 grid
offset = (page - 1) * limit
gens = get_generations(project_id, limit + 1, offset)
# Check if there are more pages
has_more = len(gens) > limit
gens = gens[:limit]
            # Filter by model if needed (note: filters apply only to the
            # rows fetched for the current page)
if model != "all":
gens = [g for g in gens if g.get("model") == model]
# Filter by search
if search:
search_lower = search.lower()
gens = [g for g in gens if search_lower in g.get("prompt", "").lower()]
# Sort
if sort == "oldest":
gens = sorted(gens, key=lambda x: x.get("created_at", ""))
elif sort == "duration_desc":
gens = sorted(gens, key=lambda x: x.get("duration_seconds", 0), reverse=True)
elif sort == "duration_asc":
gens = sorted(gens, key=lambda x: x.get("duration_seconds", 0))
# Default is newest first (already sorted from DB)
# Build gallery items (using waveform images if available)
gallery_items = []
for g in gens:
waveform = g.get("waveform_path")
if waveform:
gallery_items.append((waveform, g.get("prompt", "")[:50]))
else:
# Placeholder
gallery_items.append((None, g.get("prompt", "")[:50]))
# Calculate total pages (estimate)
total = offset + len(gens) + (1 if has_more else 0)
total_p = max(1, (total + limit - 1) // limit)
return gallery_items, f"Page {page} of {total_p}", page, total_p
def on_generation_select(evt: gr.SelectData, project_id):
"""Handle generation selection from gallery."""
if evt.index is None or not project_id:
return None, "", "", "", "", ""
# Get generations again to find the selected one
gens = get_generations(project_id, 100, 0)
if evt.index < len(gens):
gen = gens[evt.index]
return (
gen.get("audio_path"),
gen.get("prompt", ""),
gen.get("model", ""),
f"{gen.get('duration_seconds', 0):.1f}s",
str(gen.get("seed", "")),
gen.get("created_at", "")[:19] if gen.get("created_at") else "",
)
return None, "", "", "", "", ""
def do_create_project(name):
"""Create a new project."""
if not name.strip():
return gr.update(), "Please enter a project name"
project = create_project(name.strip(), "")
projects_data = load_projects()
return projects_data, f"Created project: {name}"
def do_delete_generation(project_id, audio_path):
"""Delete selected generation."""
if not audio_path:
return "No generation selected"
# Find generation by audio path
gens = get_generations(project_id, 100, 0)
for g in gens:
if g.get("audio_path") == audio_path:
if delete_generation(g.get("id")):
return "Generation deleted"
else:
return "Failed to delete"
return "Generation not found"
def do_export_project(project_id):
"""Export project as ZIP."""
if not project_id:
return "No project selected"
try:
zip_path = export_project(project_id)
return f"Exported to: {zip_path}"
except Exception as e:
return f"Export failed: {str(e)}"
# Wire up events
refresh_projects_btn.click(
fn=load_projects,
outputs=[project_list],
)
project_list.select(
fn=on_project_select,
inputs=[project_list],
outputs=[selected_project_id, selected_project_name],
).then(
fn=load_generations,
inputs=[selected_project_id, current_page, model_filter, sort_by, search_input],
outputs=[generations_gallery, page_info, current_page, total_pages],
)
# Filter changes reload generations
for component in [model_filter, sort_by, search_input]:
component.change(
fn=load_generations,
inputs=[selected_project_id, current_page, model_filter, sort_by, search_input],
outputs=[generations_gallery, page_info, current_page, total_pages],
)
# Pagination
def go_prev(page, total):
return max(1, page - 1)
def go_next(page, total):
return min(total, page + 1)
prev_page_btn.click(
fn=go_prev,
inputs=[current_page, total_pages],
outputs=[current_page],
).then(
fn=load_generations,
inputs=[selected_project_id, current_page, model_filter, sort_by, search_input],
outputs=[generations_gallery, page_info, current_page, total_pages],
)
next_page_btn.click(
fn=go_next,
inputs=[current_page, total_pages],
outputs=[current_page],
).then(
fn=load_generations,
inputs=[selected_project_id, current_page, model_filter, sort_by, search_input],
outputs=[generations_gallery, page_info, current_page, total_pages],
)
# Generation selection
generations_gallery.select(
fn=on_generation_select,
inputs=[selected_project_id],
outputs=[selected_audio, selected_prompt, selected_model, selected_duration, selected_seed, selected_date],
)
# Actions
new_project_btn.click(
fn=do_create_project,
inputs=[new_project_name],
outputs=[project_list, selected_project_name],
)
delete_btn.click(
fn=do_delete_generation,
inputs=[selected_project_id, selected_audio],
outputs=[selected_project_name],
).then(
fn=load_generations,
inputs=[selected_project_id, current_page, model_filter, sort_by, search_input],
outputs=[generations_gallery, page_info, current_page, total_pages],
)
export_project_btn.click(
fn=do_export_project,
inputs=[selected_project_id],
outputs=[selected_project_name],
)
return {
"project_list": project_list,
"generations_gallery": generations_gallery,
"selected_audio": selected_audio,
"selected_project_id": selected_project_id,
"refresh_fn": load_projects,
}
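The pagination in `load_generations` is 1-indexed over a 3x3 grid and fetches one extra row to detect whether another page exists; the arithmetic as a standalone sketch (helper names are illustrative):

```python
def page_slice(page, limit=9):
    # Offset math used by load_generations(): page is 1-indexed
    offset = (page - 1) * limit
    return offset, limit + 1   # fetch one extra row to detect another page

def total_pages(offset, fetched, limit=9):
    # Page-count estimate mirroring load_generations(): when the extra
    # row came back, one phantom item pushes the total up a page
    has_more = fetched > limit
    shown = min(fetched, limit)
    total = offset + shown + (1 if has_more else 0)
    return max(1, (total + limit - 1) // limit)

print(page_slice(3))  # → (18, 10)
```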

397
src/ui/pages/settings_page.py Normal file
View File

@@ -0,0 +1,397 @@
"""Settings page for application configuration."""
import gradio as gr
from typing import Any, Callable, Optional
from pathlib import Path
def create_settings_page(
get_settings: Callable[[], dict],
update_settings: Callable[[dict], bool],
get_gpu_info: Callable[[], dict],
clear_cache: Callable[[], bool],
unload_all_models: Callable[[], bool],
) -> dict[str, Any]:
"""Create settings management page.
Args:
get_settings: Function to get current settings
update_settings: Function to update settings
get_gpu_info: Function to get GPU information
clear_cache: Function to clear model cache
unload_all_models: Function to unload all models
Returns:
Dictionary with component references
"""
with gr.Column():
gr.Markdown("# Settings")
gr.Markdown("Configure AudioCraft Studio")
with gr.Tabs():
# General Settings Tab
with gr.TabItem("General"):
with gr.Group():
gr.Markdown("### Output Settings")
output_dir = gr.Textbox(
label="Output Directory",
placeholder="/path/to/output",
info="Where generated audio files are saved",
)
with gr.Row():
default_format = gr.Dropdown(
label="Default Audio Format",
choices=[("WAV", "wav"), ("MP3", "mp3"), ("FLAC", "flac"), ("OGG", "ogg")],
value="wav",
)
sample_rate = gr.Dropdown(
label="Sample Rate",
choices=[
("32000 Hz (AudioCraft default)", 32000),
("44100 Hz (CD quality)", 44100),
("48000 Hz (Video standard)", 48000),
],
value=32000,
)
normalize_audio = gr.Checkbox(
label="Normalize audio output",
value=True,
info="Normalize audio levels to prevent clipping",
)
with gr.Group():
gr.Markdown("### Interface Settings")
theme_mode = gr.Radio(
label="Theme",
choices=["Dark", "Light", "System"],
value="Dark",
)
show_advanced = gr.Checkbox(
label="Show advanced parameters by default",
value=False,
)
auto_play = gr.Checkbox(
label="Auto-play generated audio",
value=True,
)
# GPU & Memory Tab
with gr.TabItem("GPU & Memory"):
with gr.Group():
gr.Markdown("### GPU Information")
gpu_info_display = gr.JSON(
label="GPU Status",
value={},
)
refresh_gpu_btn = gr.Button("Refresh GPU Info", size="sm")
with gr.Group():
gr.Markdown("### Memory Management")
comfyui_reserve = gr.Slider(
minimum=0,
maximum=16,
value=10,
step=0.5,
label="ComfyUI VRAM Reserve (GB)",
info="VRAM to reserve for ComfyUI when running alongside",
)
idle_timeout = gr.Slider(
minimum=1,
maximum=60,
value=15,
step=1,
label="Idle Model Timeout (minutes)",
info="Unload models after this period of inactivity",
)
max_loaded = gr.Slider(
minimum=1,
maximum=5,
value=2,
step=1,
label="Maximum Loaded Models",
info="Maximum number of models to keep in memory",
)
with gr.Group():
gr.Markdown("### Cache Management")
with gr.Row():
clear_cache_btn = gr.Button("Clear Model Cache", variant="secondary")
unload_models_btn = gr.Button("Unload All Models", variant="stop")
cache_status = gr.Markdown("Cache status: Ready")
# Model Defaults Tab
with gr.TabItem("Model Defaults"):
with gr.Group():
gr.Markdown("### MusicGen Defaults")
with gr.Row():
musicgen_variant = gr.Dropdown(
label="Default Variant",
choices=[
("Small", "small"),
("Medium", "medium"),
("Large", "large"),
("Melody", "melody"),
],
value="medium",
)
musicgen_duration = gr.Slider(
minimum=1,
maximum=30,
value=10,
step=1,
label="Default Duration (s)",
)
with gr.Group():
gr.Markdown("### AudioGen Defaults")
audiogen_duration = gr.Slider(
minimum=1,
maximum=10,
value=5,
step=1,
label="Default Duration (s)",
)
with gr.Group():
gr.Markdown("### MAGNeT Defaults")
with gr.Row():
magnet_variant = gr.Dropdown(
label="Default Variant",
choices=[
("Small Music", "small"),
("Medium Music", "medium"),
("Small Audio", "audio-small"),
("Medium Audio", "audio-medium"),
],
value="medium",
)
magnet_decoding_steps = gr.Slider(
minimum=10,
maximum=100,
value=20,
step=5,
label="Decoding Steps",
)
# API Settings Tab
with gr.TabItem("API"):
with gr.Group():
gr.Markdown("### REST API Configuration")
api_enabled = gr.Checkbox(
label="Enable REST API",
value=True,
info="Enable FastAPI endpoints for programmatic access",
)
api_port = gr.Number(
value=8000,
label="API Port",
precision=0,
)
with gr.Row():
api_key_display = gr.Textbox(
label="API Key",
value="••••••••",
interactive=False,
)
regenerate_key_btn = gr.Button("Regenerate", size="sm")
with gr.Group():
gr.Markdown("### Rate Limiting")
rate_limit = gr.Slider(
minimum=1,
maximum=100,
value=10,
step=1,
label="Requests per minute",
)
max_batch_size = gr.Slider(
minimum=1,
maximum=10,
value=4,
step=1,
label="Maximum batch size",
)
# Queue Settings Tab
with gr.TabItem("Queue"):
with gr.Group():
gr.Markdown("### Batch Processing")
max_queue_size = gr.Slider(
minimum=10,
maximum=500,
value=100,
step=10,
label="Maximum Queue Size",
)
max_workers = gr.Slider(
minimum=1,
maximum=4,
value=1,
step=1,
label="Concurrent Workers",
info="Number of parallel generation workers",
)
priority_queue = gr.Checkbox(
label="Enable priority queue",
value=False,
info="Allow high-priority jobs to skip the queue",
)
# Save button
gr.Markdown("---")
with gr.Row():
save_btn = gr.Button("Save Settings", variant="primary", scale=2)
reset_btn = gr.Button("Reset to Defaults", variant="secondary", scale=1)
settings_status = gr.Markdown("")
# Event handlers
def load_settings():
"""Load current settings into form."""
settings = get_settings()
return (
settings.get("output_dir", ""),
settings.get("default_format", "wav"),
settings.get("sample_rate", 32000),
settings.get("normalize_audio", True),
settings.get("theme_mode", "Dark"),
settings.get("show_advanced", False),
settings.get("auto_play", True),
settings.get("comfyui_reserve_gb", 10),
settings.get("idle_timeout_minutes", 15),
settings.get("max_loaded_models", 2),
settings.get("musicgen_variant", "medium"),
settings.get("musicgen_duration", 10),
settings.get("audiogen_duration", 5),
settings.get("magnet_variant", "medium"),
settings.get("magnet_decoding_steps", 20),
settings.get("api_enabled", True),
settings.get("api_port", 8000),
settings.get("rate_limit", 10),
settings.get("max_batch_size", 4),
settings.get("max_queue_size", 100),
settings.get("max_workers", 1),
settings.get("priority_queue", False),
)
def save_settings(
out_dir, fmt, sr, norm, theme, adv, play,
comfyui_res, idle_to, max_load,
mg_var, mg_dur, ag_dur, mn_var, mn_steps,
api_en, api_p, rate, batch, queue_sz, workers, priority
):
"""Save settings from form."""
settings = {
"output_dir": out_dir,
"default_format": fmt,
"sample_rate": sr,
"normalize_audio": norm,
"theme_mode": theme,
"show_advanced": adv,
"auto_play": play,
"comfyui_reserve_gb": comfyui_res,
"idle_timeout_minutes": idle_to,
"max_loaded_models": max_load,
"musicgen_variant": mg_var,
"musicgen_duration": mg_dur,
"audiogen_duration": ag_dur,
"magnet_variant": mn_var,
"magnet_decoding_steps": mn_steps,
"api_enabled": api_en,
"api_port": int(api_p),
"rate_limit": rate,
"max_batch_size": batch,
"max_queue_size": queue_sz,
"max_workers": workers,
"priority_queue": priority,
}
if update_settings(settings):
return "✅ Settings saved successfully"
else:
return "❌ Failed to save settings"
def do_refresh_gpu():
"""Refresh GPU info display."""
return get_gpu_info()
def do_clear_cache():
"""Clear model cache."""
if clear_cache():
return "✅ Cache cleared"
return "❌ Failed to clear cache"
def do_unload_models():
"""Unload all models."""
if unload_all_models():
return "✅ All models unloaded"
return "❌ Failed to unload models"
# Wire up events
refresh_gpu_btn.click(
fn=do_refresh_gpu,
outputs=[gpu_info_display],
)
clear_cache_btn.click(
fn=do_clear_cache,
outputs=[cache_status],
)
unload_models_btn.click(
fn=do_unload_models,
outputs=[cache_status],
)
save_btn.click(
fn=save_settings,
inputs=[
output_dir, default_format, sample_rate, normalize_audio,
theme_mode, show_advanced, auto_play,
comfyui_reserve, idle_timeout, max_loaded,
musicgen_variant, musicgen_duration, audiogen_duration,
magnet_variant, magnet_decoding_steps,
api_enabled, api_port, rate_limit, max_batch_size,
max_queue_size, max_workers, priority_queue,
],
outputs=[settings_status],
)
return {
"output_dir": output_dir,
"default_format": default_format,
"sample_rate": sample_rate,
"comfyui_reserve": comfyui_reserve,
"idle_timeout": idle_timeout,
"api_enabled": api_enabled,
"save_btn": save_btn,
"settings_status": settings_status,
"load_fn": load_settings,
}
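The `get_settings` / `update_settings` callbacks passed into `create_settings_page` are opaque in this file; a minimal stdlib sketch of a JSON-file-backed pair (the path and the defaults shown are assumptions of this sketch, not part of the commit):

```python
import json
from pathlib import Path

SETTINGS_PATH = Path("./data/settings.json")  # hypothetical location
DEFAULTS = {"default_format": "wav", "sample_rate": 32000, "max_loaded_models": 2}

def get_settings() -> dict:
    # A missing file or corrupt JSON falls back to the defaults.
    try:
        stored = json.loads(SETTINGS_PATH.read_text())
    except (FileNotFoundError, json.JSONDecodeError):
        stored = {}
    return {**DEFAULTS, **stored}

def update_settings(settings: dict) -> bool:
    # Returns False on I/O failure, matching the bool contract above.
    try:
        SETTINGS_PATH.parent.mkdir(parents=True, exist_ok=True)
        SETTINGS_PATH.write_text(json.dumps(settings, indent=2))
        return True
    except OSError:
        return False
```

Merging stored values over defaults means new settings keys added in later versions pick up their defaults automatically.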

src/ui/state.py (new file)

@@ -0,0 +1,294 @@
"""State management for Gradio UI."""
from dataclasses import dataclass, field
from typing import Any, Optional
@dataclass
class UIState:
"""Global UI state container."""
# Current view
current_tab: str = "dashboard"
# Generation state
is_generating: bool = False
current_job_id: Optional[str] = None
# Selected items
selected_project_id: Optional[str] = None
selected_generation_id: Optional[str] = None
selected_preset_id: Optional[str] = None
# Model state
selected_model: str = "musicgen"
selected_variant: str = "medium"
# Generation parameters (current values)
prompt: str = ""
duration: float = 10.0
temperature: float = 1.0
top_k: int = 250
top_p: float = 0.0
cfg_coef: float = 3.0
seed: Optional[int] = None
# Conditioning
melody_audio: Optional[str] = None
style_audio: Optional[str] = None
chords: list[dict[str, Any]] = field(default_factory=list)
drums_pattern: str = ""
bpm: float = 120.0
# UI preferences
show_advanced: bool = False
auto_play: bool = True
def reset_generation_params(self) -> None:
"""Reset generation parameters to defaults."""
self.prompt = ""
self.duration = 10.0
self.temperature = 1.0
self.top_k = 250
self.top_p = 0.0
self.cfg_coef = 3.0
self.seed = None
self.melody_audio = None
self.style_audio = None
self.chords = []
self.drums_pattern = ""
def apply_preset(self, preset: dict[str, Any]) -> None:
"""Apply preset parameters."""
params = preset.get("parameters", {})
self.duration = params.get("duration", self.duration)
self.temperature = params.get("temperature", self.temperature)
self.top_k = params.get("top_k", self.top_k)
self.top_p = params.get("top_p", self.top_p)
self.cfg_coef = params.get("cfg_coef", self.cfg_coef)
def to_generation_params(self) -> dict[str, Any]:
"""Convert current state to generation parameters."""
return {
"duration": self.duration,
"temperature": self.temperature,
"top_k": self.top_k,
"top_p": self.top_p,
"cfg_coef": self.cfg_coef,
"seed": self.seed,
}
# Default presets for each model
DEFAULT_PRESETS = {
"musicgen": [
{
"id": "cinematic",
"name": "Cinematic",
"description": "Epic orchestral soundscapes",
"parameters": {
"duration": 30,
"temperature": 1.0,
"top_k": 250,
"cfg_coef": 3.0,
},
},
{
"id": "electronic",
"name": "Electronic",
"description": "Synthesizers and beats",
"parameters": {
"duration": 15,
"temperature": 1.1,
"top_k": 200,
"cfg_coef": 3.5,
},
},
{
"id": "ambient",
"name": "Ambient",
"description": "Atmospheric and calm",
"parameters": {
"duration": 30,
"temperature": 0.9,
"top_k": 300,
"cfg_coef": 2.5,
},
},
{
"id": "rock",
"name": "Rock",
"description": "Guitar-driven energy",
"parameters": {
"duration": 20,
"temperature": 1.0,
"top_k": 250,
"cfg_coef": 3.0,
},
},
{
"id": "jazz",
"name": "Jazz",
"description": "Smooth and improvisational",
"parameters": {
"duration": 20,
"temperature": 1.2,
"top_k": 200,
"cfg_coef": 2.5,
},
},
],
"audiogen": [
{
"id": "nature",
"name": "Nature",
"description": "Natural environments",
"parameters": {
"duration": 10,
"temperature": 1.0,
"top_k": 250,
"cfg_coef": 3.0,
},
},
{
"id": "urban",
"name": "Urban",
"description": "City sounds",
"parameters": {
"duration": 10,
"temperature": 1.0,
"top_k": 250,
"cfg_coef": 3.0,
},
},
{
"id": "mechanical",
"name": "Mechanical",
"description": "Machines and tools",
"parameters": {
"duration": 5,
"temperature": 0.9,
"top_k": 200,
"cfg_coef": 3.5,
},
},
{
"id": "weather",
"name": "Weather",
"description": "Rain, thunder, wind",
"parameters": {
"duration": 10,
"temperature": 1.0,
"top_k": 250,
"cfg_coef": 3.0,
},
},
],
"magnet": [
{
"id": "fast",
"name": "Fast",
"description": "Quick generation",
"parameters": {
"duration": 10,
"temperature": 3.0,
"top_p": 0.9,
"cfg_coef": 3.0,
},
},
{
"id": "quality",
"name": "Quality",
"description": "Higher quality output",
"parameters": {
"duration": 10,
"temperature": 2.5,
"top_p": 0.85,
"cfg_coef": 4.0,
},
},
],
"musicgen-style": [
{
"id": "style_transfer",
"name": "Style Transfer",
"description": "Copy style from reference",
"parameters": {
"duration": 15,
"temperature": 1.0,
"top_k": 250,
"cfg_coef": 3.0,
"eval_q": 3,
"excerpt_length": 3.0,
},
},
],
"jasco": [
{
"id": "pop",
"name": "Pop",
"description": "Pop chord progressions",
"parameters": {
"duration": 10,
"temperature": 1.0,
"top_k": 250,
"cfg_coef": 3.0,
"bpm": 120,
},
},
{
"id": "blues",
"name": "Blues",
"description": "12-bar blues",
"parameters": {
"duration": 10,
"temperature": 1.0,
"top_k": 250,
"cfg_coef": 3.0,
"bpm": 100,
},
},
],
}
# Prompt suggestions for each model
PROMPT_SUGGESTIONS = {
"musicgen": [
"Epic orchestral music with dramatic strings and powerful brass",
"Upbeat electronic dance music with synthesizers and heavy bass",
"Calm acoustic guitar melody with soft piano accompaniment",
"Energetic rock song with electric guitars and driving drums",
"Smooth jazz with saxophone solo and walking bass",
"Ambient soundscape with ethereal pads and gentle textures",
"Cinematic trailer music building to an epic climax",
"Lo-fi hip hop beats with vinyl crackle and mellow keys",
],
"audiogen": [
"Thunder and heavy rain with occasional lightning strikes",
"Busy city street with traffic, horns, and distant sirens",
"Forest ambience with birds singing and wind in trees",
"Ocean waves crashing on a rocky shore",
"Crackling fireplace with wood popping",
"Coffee shop atmosphere with murmuring voices and clinking cups",
"Construction site with hammering and machinery",
"Spaceship engine humming with occasional beeps",
],
"magnet": [
"Energetic pop music with catchy melody",
"Dark electronic music with deep bass",
"Cheerful ukulele tune with whistling",
"Dramatic piano piece with building intensity",
],
"musicgen-style": [
"Generate music in the style of the uploaded reference",
"Create a variation with similar instrumentation",
"Compose a piece matching the mood of the reference",
],
"jasco": [
"Upbeat pop song with the specified chord progression",
"Mellow jazz piece following the chord changes",
"Rock anthem with powerful drum pattern",
"Electronic track with syncopated rhythms",
],
}
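`UIState.apply_preset` merges only the keys a preset supplies, leaving the rest of the state untouched. A trimmed, self-contained sketch of that semantics (fields reduced for brevity):

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class MiniState:
    duration: float = 10.0
    temperature: float = 1.0
    cfg_coef: float = 3.0

    def apply_preset(self, preset: dict[str, Any]) -> None:
        # dict.get with the current value as fallback: absent keys are no-ops
        params = preset.get("parameters", {})
        self.duration = params.get("duration", self.duration)
        self.temperature = params.get("temperature", self.temperature)
        self.cfg_coef = params.get("cfg_coef", self.cfg_coef)

state = MiniState()
state.apply_preset({"parameters": {"duration": 30, "cfg_coef": 2.5}})
print(state.duration, state.temperature, state.cfg_coef)  # 30 1.0 2.5
```

This is why the "Fast"/"Quality" MAGNeT presets can set `top_p` without clobbering `top_k`: unspecified keys simply keep their current values.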

src/ui/tabs/__init__.py (new file)

@@ -0,0 +1,17 @@
"""Model tabs for AudioCraft Studio."""
from src.ui.tabs.dashboard_tab import create_dashboard_tab
from src.ui.tabs.musicgen_tab import create_musicgen_tab
from src.ui.tabs.audiogen_tab import create_audiogen_tab
from src.ui.tabs.magnet_tab import create_magnet_tab
from src.ui.tabs.style_tab import create_style_tab
from src.ui.tabs.jasco_tab import create_jasco_tab
__all__ = [
"create_dashboard_tab",
"create_musicgen_tab",
"create_audiogen_tab",
"create_magnet_tab",
"create_style_tab",
"create_jasco_tab",
]

src/ui/tabs/audiogen_tab.py (new file)

@@ -0,0 +1,283 @@
"""AudioGen tab for text-to-sound generation."""
import gradio as gr
from typing import Any, Callable, Optional
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
from src.ui.components.audio_player import create_generation_output
AUDIOGEN_VARIANTS = [
{"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "1.5B params, balanced quality/speed"},
]
def create_audiogen_tab(
generate_fn: Callable[..., Any],
add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
"""Create AudioGen generation tab.
Args:
generate_fn: Function to call for generation
add_to_queue_fn: Function to add to queue
Returns:
Dictionary with component references
"""
presets = DEFAULT_PRESETS.get("audiogen", [])
suggestions = PROMPT_SUGGESTIONS.get("audiogen", [])
with gr.Column():
gr.Markdown("## 🔊 AudioGen")
gr.Markdown("Generate sound effects and environmental audio from text")
with gr.Row():
# Left column - inputs
with gr.Column(scale=2):
# Preset selector
preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
preset_dropdown = gr.Dropdown(
label="Preset",
choices=preset_choices,
value=presets[0]["id"] if presets else "custom",
)
# Model variant (AudioGen only has medium)
variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in AUDIOGEN_VARIANTS]
variant_dropdown = gr.Dropdown(
label="Model Variant",
choices=variant_choices,
value="medium",
)
# Prompt input
prompt_input = gr.Textbox(
label="Prompt",
placeholder="Describe the sound you want to generate...",
lines=3,
max_lines=5,
)
# Prompt suggestions
with gr.Accordion("Prompt Suggestions", open=False):
suggestion_btns = []
for suggestion in suggestions[:6]:
# Only append an ellipsis when the label is actually truncated
label = suggestion if len(suggestion) <= 50 else suggestion[:50] + "..."
btn = gr.Button(label, size="sm", variant="secondary")
suggestion_btns.append((btn, suggestion))
# Parameters
gr.Markdown("### Parameters")
duration_slider = gr.Slider(
minimum=1,
maximum=10,
value=5,
step=1,
label="Duration (seconds)",
info="AudioGen works best with shorter clips",
)
with gr.Accordion("Advanced Parameters", open=False):
with gr.Row():
temperature_slider = gr.Slider(
minimum=0.0,
maximum=2.0,
value=1.0,
step=0.05,
label="Temperature",
)
cfg_slider = gr.Slider(
minimum=1.0,
maximum=10.0,
value=3.0,
step=0.5,
label="CFG Coefficient",
)
with gr.Row():
top_k_slider = gr.Slider(
minimum=0,
maximum=500,
value=250,
step=10,
label="Top-K",
)
top_p_slider = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.0,
step=0.05,
label="Top-P",
)
with gr.Row():
seed_input = gr.Number(
value=None,
label="Seed (empty = random)",
precision=0,
)
# Generate buttons
with gr.Row():
generate_btn = gr.Button("🔊 Generate", variant="primary", scale=2)
queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)
# Right column - output
with gr.Column(scale=3):
output = create_generation_output()
# Event handlers
# Preset change
def apply_preset(preset_id: str):
for p in presets:
if p["id"] == preset_id:
params = p["parameters"]
return (
params.get("duration", 5),
params.get("temperature", 1.0),
params.get("cfg_coef", 3.0),
params.get("top_k", 250),
params.get("top_p", 0.0),
)
return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
preset_dropdown.change(
fn=apply_preset,
inputs=[preset_dropdown],
outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
)
# Prompt suggestions
for btn, suggestion in suggestion_btns:
btn.click(
fn=lambda s=suggestion: s,
outputs=[prompt_input],
)
# Generate
async def do_generate(
prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed
):
if not prompt:
# 'return <value>' is a SyntaxError inside an async generator;
# yield the error state, then exit with a bare return.
yield (
gr.update(value="Please enter a prompt"),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
return
yield (
gr.update(value="🔄 Generating..."),
gr.update(visible=True, value=0),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
try:
result, generation = await generate_fn(
model_id="audiogen",
variant=variant,
prompts=[prompt],
duration=duration,
temperature=temperature,
top_k=int(top_k),
top_p=top_p,
cfg_coef=cfg_coef,
seed=int(seed) if seed is not None else None,  # 0 is a valid seed
)
yield (
gr.update(value="✅ Generation complete!"),
gr.update(visible=False),
gr.update(value=generation.audio_path),
gr.update(),
gr.update(value=f"{result.duration:.2f}s"),
gr.update(value=str(result.seed)),
)
except Exception as e:
yield (
gr.update(value=f"❌ Error: {str(e)}"),
gr.update(visible=False),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
generate_btn.click(
fn=do_generate,
inputs=[
prompt_input,
variant_dropdown,
duration_slider,
temperature_slider,
cfg_slider,
top_k_slider,
top_p_slider,
seed_input,
],
outputs=[
output["status"],
output["progress"],
output["player"]["audio"],
output["player"]["waveform"],
output["player"]["duration"],
output["player"]["seed"],
],
)
# Add to queue
def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed):
if not prompt:
return "Please enter a prompt"
job = add_to_queue_fn(
model_id="audiogen",
variant=variant,
prompts=[prompt],
duration=duration,
temperature=temperature,
top_k=int(top_k),
top_p=top_p,
cfg_coef=cfg_coef,
seed=int(seed) if seed is not None else None,  # 0 is a valid seed
)
return f"✅ Added to queue: {job.id}"
queue_btn.click(
fn=do_add_queue,
inputs=[
prompt_input,
variant_dropdown,
duration_slider,
temperature_slider,
cfg_slider,
top_k_slider,
top_p_slider,
seed_input,
],
outputs=[output["status"]],
)
return {
"preset": preset_dropdown,
"variant": variant_dropdown,
"prompt": prompt_input,
"duration": duration_slider,
"temperature": temperature_slider,
"cfg_coef": cfg_slider,
"top_k": top_k_slider,
"top_p": top_p_slider,
"seed": seed_input,
"generate_btn": generate_btn,
"queue_btn": queue_btn,
"output": output,
}
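`do_generate` above is an async generator: Gradio streams each yielded tuple to the listed outputs, and a bare `return` ends the stream (a `return <value>` inside an async generator is a SyntaxError, so validation failures must be yielded too). A standalone sketch of the pattern, with a sleep standing in for the model call:

```python
import asyncio

async def stream_status(prompt: str):
    # Validation failures are yielded, then the generator exits with a bare return.
    if not prompt:
        yield "Please enter a prompt"
        return
    yield "Generating..."
    await asyncio.sleep(0)  # stand-in for the real model call
    yield "Done"

async def collect(prompt: str) -> list[str]:
    # Drain the async generator into a list, as Gradio would drain it into UI updates.
    return [status async for status in stream_status(prompt)]

print(asyncio.run(collect("")))      # ['Please enter a prompt']
print(asyncio.run(collect("rain")))  # ['Generating...', 'Done']
```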

src/ui/tabs/dashboard_tab.py (new file)

@@ -0,0 +1,166 @@
"""Dashboard tab - home page with model overview and quick actions."""
import gradio as gr
from typing import Any, Callable, Optional
MODEL_INFO = {
"musicgen": {
"name": "MusicGen",
"icon": "🎵",
"description": "Text-to-music generation with optional melody conditioning",
"capabilities": ["Text prompts", "Melody conditioning", "Stereo output"],
},
"audiogen": {
"name": "AudioGen",
"icon": "🔊",
"description": "Text-to-sound effects and environmental audio",
"capabilities": ["Sound effects", "Ambiences", "Foley"],
},
"magnet": {
"name": "MAGNeT",
"icon": "",
"description": "Fast non-autoregressive music generation",
"capabilities": ["Fast generation", "Music", "Sound effects"],
},
"musicgen-style": {
"name": "MusicGen Style",
"icon": "🎨",
"description": "Style-conditioned music from reference audio",
"capabilities": ["Style transfer", "Reference audio", "Text prompts"],
},
"jasco": {
"name": "JASCO",
"icon": "🎹",
"description": "Chord and drum-conditioned music generation",
"capabilities": ["Chord control", "Drum patterns", "Symbolic conditioning"],
},
}
def create_dashboard_tab(
get_queue_status: Callable[[], dict[str, Any]],
get_recent_generations: Callable[[int], list[dict[str, Any]]],
get_gpu_status: Callable[[], dict[str, Any]],
) -> dict[str, Any]:
"""Create dashboard tab with model overview and status.
Args:
get_queue_status: Function to get generation queue status
get_recent_generations: Function to get recent generations
get_gpu_status: Function to get GPU status
Returns:
Dictionary with component references
"""
def refresh_dashboard():
"""Refresh all dashboard data."""
queue = get_queue_status()
recent = get_recent_generations(5)
gpu = get_gpu_status()
# Format queue status
queue_size = queue.get("queue_size", 0)
queue_text = f"**Queue:** {queue_size} job(s) pending"
# Format recent generations
if recent:
recent_items = []
for gen in recent[:5]:
model = gen.get("model", "unknown")
prompt = gen.get("prompt", "")
if len(prompt) > 50:
prompt = prompt[:50] + "..."
duration = gen.get("duration_seconds", 0)
recent_items.append(f"• **{model}** ({duration:.0f}s): {prompt}")
recent_text = "\n".join(recent_items)
else:
recent_text = "No recent generations"
# Format GPU status
used_gb = gpu.get("used_gb", 0)
total_gb = gpu.get("total_gb", 24)
util = gpu.get("utilization_percent", 0)
gpu_text = f"**GPU:** {used_gb:.1f}/{total_gb:.1f} GB ({util:.0f}%)"
return queue_text, recent_text, gpu_text
with gr.Column():
# Header
gr.Markdown("# AudioCraft Studio")
gr.Markdown("AI-powered music and sound generation")
# Status bar
with gr.Row():
queue_status = gr.Markdown("**Queue:** Loading...")
gpu_status = gr.Markdown("**GPU:** Loading...")
refresh_btn = gr.Button("🔄 Refresh", size="sm")
gr.Markdown("---")
# Model cards
gr.Markdown("## Models")
with gr.Row():
# First row of cards
for model_id in ["musicgen", "audiogen", "magnet"]:
info = MODEL_INFO[model_id]
with gr.Column(scale=1):
with gr.Group():
gr.Markdown(f"### {info['icon']} {info['name']}")
gr.Markdown(info["description"])
gr.Markdown("**Features:** " + ", ".join(info["capabilities"]))
gr.Button(
f"Open {info['name']}",
variant="primary",
size="sm",
elem_id=f"btn_{model_id}",
)
with gr.Row():
# Second row of cards
for model_id in ["musicgen-style", "jasco"]:
info = MODEL_INFO[model_id]
with gr.Column(scale=1):
with gr.Group():
gr.Markdown(f"### {info['icon']} {info['name']}")
gr.Markdown(info["description"])
gr.Markdown("**Features:** " + ", ".join(info["capabilities"]))
gr.Button(
f"Open {info['name']}",
variant="primary",
size="sm",
elem_id=f"btn_{model_id}",
)
# Empty column for balance
with gr.Column(scale=1):
pass
gr.Markdown("---")
# Recent generations and queue
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("## Recent Generations")
recent_list = gr.Markdown("Loading...")
with gr.Column(scale=1):
gr.Markdown("## Quick Actions")
with gr.Group():
gr.Button("📁 Browse Projects", variant="secondary")
gr.Button("⚙️ Settings", variant="secondary")
gr.Button("📖 API Documentation", variant="secondary")
# Refresh handler
refresh_btn.click(
fn=refresh_dashboard,
outputs=[queue_status, recent_list, gpu_status],
)
return {
"queue_status": queue_status,
"gpu_status": gpu_status,
"recent_list": recent_list,
"refresh_btn": refresh_btn,
"refresh_fn": refresh_dashboard,
}

src/ui/tabs/jasco_tab.py (new file)

@@ -0,0 +1,364 @@
"""JASCO tab for chord and drum-conditioned generation."""
import gradio as gr
from typing import Any, Callable, Optional
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
from src.ui.components.audio_player import create_generation_output
JASCO_VARIANTS = [
{"id": "chords", "name": "Chords", "vram_mb": 5000, "description": "Chord-conditioned generation"},
{"id": "chords-drums", "name": "Chords + Drums", "vram_mb": 5500, "description": "Full symbolic conditioning"},
]
# Common chord progressions
CHORD_PRESETS = [
{"name": "Pop I-V-vi-IV", "chords": "C G Am F"},
{"name": "Jazz ii-V-I", "chords": "Dm7 G7 Cmaj7"},
{"name": "Blues I-IV-V", "chords": "A7 D7 E7"},
{"name": "Rock I-bVII-IV", "chords": "E D A"},
{"name": "Minor i-VI-III-VII", "chords": "Am F C G"},
]
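How JASCO consumes the space-separated chord string is not shown in this file; one plausible sketch, assuming the backend wants `(chord, start_time)` pairs with one chord per bar in 4/4 at the given BPM (both the pair format and the one-chord-per-bar mapping are assumptions of this example, not the actual conditioning format):

```python
def chords_to_timeline(chords: str, bpm: float, beats_per_bar: int = 4) -> list[tuple[str, float]]:
    # One chord per bar; bar length in seconds = beats_per_bar * 60 / bpm.
    bar_seconds = beats_per_bar * 60.0 / bpm
    return [(symbol, i * bar_seconds) for i, symbol in enumerate(chords.split())]

print(chords_to_timeline("C G Am F", bpm=120))
# [('C', 0.0), ('G', 2.0), ('Am', 4.0), ('F', 6.0)]
```

Since `chords.split()` accepts any whitespace-separated symbols, the same sketch handles extended chords like "Dm7 G7 Cmaj7" unchanged.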
def create_jasco_tab(
generate_fn: Callable[..., Any],
add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
"""Create JASCO generation tab.
Args:
generate_fn: Function to call for generation
add_to_queue_fn: Function to add to queue
Returns:
Dictionary with component references
"""
presets = DEFAULT_PRESETS.get("jasco", [])
suggestions = PROMPT_SUGGESTIONS.get("jasco", [])
with gr.Column():
gr.Markdown("## 🎹 JASCO")
gr.Markdown("Generate music conditioned on chords and drum patterns")
with gr.Row():
# Left column - inputs
with gr.Column(scale=2):
# Preset selector
preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
preset_dropdown = gr.Dropdown(
label="Preset",
choices=preset_choices,
value=presets[0]["id"] if presets else "custom",
)
# Model variant
variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in JASCO_VARIANTS]
variant_dropdown = gr.Dropdown(
label="Model Variant",
choices=variant_choices,
value="chords-drums",
)
# Prompt input
prompt_input = gr.Textbox(
label="Text Prompt",
placeholder="Describe the music style, mood, instruments...",
lines=2,
max_lines=4,
)
# Chord conditioning
gr.Markdown("### Chord Progression")
chord_input = gr.Textbox(
label="Chords",
placeholder="C G Am F or Cmaj7 Dm7 G7 Cmaj7",
lines=1,
info="Space-separated chord symbols",
)
# Chord presets
with gr.Accordion("Chord Presets", open=False):
chord_preset_btns = []
with gr.Row():
for cp in CHORD_PRESETS[:3]:
btn = gr.Button(cp["name"], size="sm", variant="secondary")
chord_preset_btns.append((btn, cp["chords"]))
with gr.Row():
for cp in CHORD_PRESETS[3:]:
btn = gr.Button(cp["name"], size="sm", variant="secondary")
chord_preset_btns.append((btn, cp["chords"]))
# Drum conditioning (for chords-drums variant)
with gr.Group(visible=True) as drum_section:
gr.Markdown("### Drum Pattern")
drum_input = gr.Audio(
label="Drum Reference",
type="filepath",
sources=["upload"],
)
gr.Markdown("*Upload a drum loop to condition the rhythm*")
# Parameters
gr.Markdown("### Parameters")
duration_slider = gr.Slider(
minimum=1,
maximum=30,
value=10,
step=1,
label="Duration (seconds)",
)
bpm_slider = gr.Slider(
minimum=60,
maximum=180,
value=120,
step=1,
label="BPM",
info="Tempo for chord timing",
)
with gr.Accordion("Advanced Parameters", open=False):
with gr.Row():
temperature_slider = gr.Slider(
minimum=0.0,
maximum=2.0,
value=1.0,
step=0.05,
label="Temperature",
)
cfg_slider = gr.Slider(
minimum=1.0,
maximum=10.0,
value=3.0,
step=0.5,
label="CFG Coefficient",
)
with gr.Row():
top_k_slider = gr.Slider(
minimum=0,
maximum=500,
value=250,
step=10,
label="Top-K",
)
top_p_slider = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.0,
step=0.05,
label="Top-P",
)
with gr.Row():
seed_input = gr.Number(
value=None,
label="Seed (empty = random)",
precision=0,
)
# Generate buttons
with gr.Row():
generate_btn = gr.Button("🎹 Generate", variant="primary", scale=2)
queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)
# Right column - output
with gr.Column(scale=3):
output = create_generation_output()
# Event handlers
# Preset change
def apply_preset(preset_id: str):
for p in presets:
if p["id"] == preset_id:
params = p["parameters"]
return (
params.get("duration", 10),
params.get("bpm", 120),
params.get("temperature", 1.0),
params.get("cfg_coef", 3.0),
params.get("top_k", 250),
params.get("top_p", 0.0),
)
return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
preset_dropdown.change(
fn=apply_preset,
inputs=[preset_dropdown],
outputs=[duration_slider, bpm_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
)
# Variant change - show/hide drum section
def on_variant_change(variant: str):
show_drums = "drums" in variant.lower()
return gr.update(visible=show_drums)
variant_dropdown.change(
fn=on_variant_change,
inputs=[variant_dropdown],
outputs=[drum_section],
)
# Chord presets
for btn, chords in chord_preset_btns:
btn.click(
fn=lambda c=chords: c,
outputs=[chord_input],
)
# Generate
async def do_generate(
prompt, variant, chords, drums, duration, bpm, temperature, cfg_coef, top_k, top_p, seed
):
if not chords:
# 'return <value>' is a SyntaxError inside an async generator;
# yield the error state, then exit with a bare return.
yield (
gr.update(value="Please enter a chord progression"),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
return
yield (
gr.update(value="🔄 Generating..."),
gr.update(visible=True, value=0),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
try:
conditioning = {
"chords": chords,
"bpm": bpm,
}
if drums and "drums" in variant.lower():
conditioning["drums"] = drums
result, generation = await generate_fn(
model_id="jasco",
variant=variant,
prompts=[prompt] if prompt else [""],
duration=duration,
temperature=temperature,
top_k=int(top_k),
top_p=top_p,
cfg_coef=cfg_coef,
seed=int(seed) if seed is not None else None,  # 0 is a valid seed
conditioning=conditioning,
)
yield (
gr.update(value="✅ Generation complete!"),
gr.update(visible=False),
gr.update(value=generation.audio_path),
gr.update(),
gr.update(value=f"{result.duration:.2f}s"),
gr.update(value=str(result.seed)),
)
except Exception as e:
yield (
gr.update(value=f"❌ Error: {str(e)}"),
gr.update(visible=False),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
generate_btn.click(
fn=do_generate,
inputs=[
prompt_input,
variant_dropdown,
chord_input,
drum_input,
duration_slider,
bpm_slider,
temperature_slider,
cfg_slider,
top_k_slider,
top_p_slider,
seed_input,
],
outputs=[
output["status"],
output["progress"],
output["player"]["audio"],
output["player"]["waveform"],
output["player"]["duration"],
output["player"]["seed"],
],
)
# Add to queue
def do_add_queue(prompt, variant, chords, drums, duration, bpm, temperature, cfg_coef, top_k, top_p, seed):
if not chords:
return "Please enter a chord progression"
conditioning = {
"chords": chords,
"bpm": bpm,
}
if drums and "drums" in variant.lower():
conditioning["drums"] = drums
job = add_to_queue_fn(
model_id="jasco",
variant=variant,
prompts=[prompt] if prompt else [""],
duration=duration,
temperature=temperature,
top_k=int(top_k),
top_p=top_p,
cfg_coef=cfg_coef,
seed=int(seed) if seed is not None else None,  # 0 is a valid seed
conditioning=conditioning,
)
return f"✅ Added to queue: {job.id}"
queue_btn.click(
fn=do_add_queue,
inputs=[
prompt_input,
variant_dropdown,
chord_input,
drum_input,
duration_slider,
bpm_slider,
temperature_slider,
cfg_slider,
top_k_slider,
top_p_slider,
seed_input,
],
outputs=[output["status"]],
)
return {
"preset": preset_dropdown,
"variant": variant_dropdown,
"prompt": prompt_input,
"chords": chord_input,
"drums": drum_input,
"duration": duration_slider,
"bpm": bpm_slider,
"temperature": temperature_slider,
"cfg_coef": cfg_slider,
"top_k": top_k_slider,
"top_p": top_p_slider,
"seed": seed_input,
"generate_btn": generate_btn,
"queue_btn": queue_btn,
"output": output,
}
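Each tab builds its variant dropdown labels the same way, formatting the `vram_mb` field as gigabytes. A stdlib-only sketch of that formatting (the helper name `format_variant_choices` is illustrative, not part of the codebase):

```python
def format_variant_choices(variants: list[dict]) -> list[tuple[str, str]]:
    """Build (label, value) pairs the way the tab dropdowns do: name plus VRAM in GB."""
    return [(f"{v['name']} ({v['vram_mb'] / 1024:.1f}GB)", v["id"]) for v in variants]

choices = format_variant_choices([
    {"id": "small", "name": "Small", "vram_mb": 1500},
    {"id": "large", "name": "Large", "vram_mb": 10000},
])
print(choices)  # [('Small (1.5GB)', 'small'), ('Large (9.8GB)', 'large')]
```

Keeping the label/value split means the human-readable VRAM hint never leaks into the `variant` string passed to the backend.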

316
src/ui/tabs/magnet_tab.py Normal file
View File

@@ -0,0 +1,316 @@
"""MAGNeT tab for fast non-autoregressive generation."""
import gradio as gr
from typing import Any, Callable
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
from src.ui.components.audio_player import create_generation_output
MAGNET_VARIANTS = [
{"id": "small", "name": "Small Music", "vram_mb": 2000, "description": "Fast music, 300M params"},
{"id": "medium", "name": "Medium Music", "vram_mb": 5000, "description": "Balanced music, 1.5B params"},
{"id": "audio-small", "name": "Small Audio", "vram_mb": 2000, "description": "Fast sound effects"},
{"id": "audio-medium", "name": "Medium Audio", "vram_mb": 5000, "description": "Balanced sound effects"},
]
def create_magnet_tab(
generate_fn: Callable[..., Any],
add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
"""Create MAGNeT generation tab.
Args:
generate_fn: Function to call for generation
add_to_queue_fn: Function to add to queue
Returns:
Dictionary with component references
"""
presets = DEFAULT_PRESETS.get("magnet", [])
suggestions = PROMPT_SUGGESTIONS.get("musicgen", []) # Reuse music suggestions
with gr.Column():
gr.Markdown("## ⚡ MAGNeT")
gr.Markdown("Fast non-autoregressive music and sound generation")
with gr.Row():
# Left column - inputs
with gr.Column(scale=2):
# Preset selector
preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
preset_dropdown = gr.Dropdown(
label="Preset",
choices=preset_choices,
value=presets[0]["id"] if presets else "custom",
)
# Model variant
variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in MAGNET_VARIANTS]
variant_dropdown = gr.Dropdown(
label="Model Variant",
choices=variant_choices,
value="medium",
)
# Prompt input
prompt_input = gr.Textbox(
label="Prompt",
placeholder="Describe the music or sound you want to generate...",
lines=3,
max_lines=5,
)
# Prompt suggestions
with gr.Accordion("Prompt Suggestions", open=False):
suggestion_btns = []
for i, suggestion in enumerate(suggestions[:4]):
btn = gr.Button(suggestion[:60] + ("..." if len(suggestion) > 60 else ""), size="sm", variant="secondary")
suggestion_btns.append((btn, suggestion))
# Parameters
gr.Markdown("### Parameters")
duration_slider = gr.Slider(
minimum=1,
maximum=30,
value=10,
step=1,
label="Duration (seconds)",
)
with gr.Accordion("Advanced Parameters", open=False):
gr.Markdown("*MAGNeT uses different sampling compared to MusicGen*")
with gr.Row():
temperature_slider = gr.Slider(
minimum=1.0,
maximum=5.0,
value=3.0,
step=0.1,
label="Temperature",
info="Higher values recommended (3.0 default)",
)
cfg_slider = gr.Slider(
minimum=1.0,
maximum=10.0,
value=3.0,
step=0.5,
label="CFG Coefficient",
)
with gr.Row():
top_k_slider = gr.Slider(
minimum=0,
maximum=500,
value=0,
step=10,
label="Top-K",
info="0 recommended for MAGNeT",
)
top_p_slider = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.9,
step=0.05,
label="Top-P",
info="0.9 recommended for MAGNeT",
)
with gr.Row():
decoding_steps_slider = gr.Slider(
minimum=10,
maximum=100,
value=20,
step=5,
label="Decoding Steps",
info="More steps = better quality, slower",
)
span_arrangement = gr.Dropdown(
label="Span Arrangement",
choices=[("No Overlap", "nonoverlap"), ("Overlap", "stride1")],
value="nonoverlap",
)
with gr.Row():
seed_input = gr.Number(
value=None,
label="Seed (empty = random)",
precision=0,
)
# Generate buttons
with gr.Row():
generate_btn = gr.Button("⚡ Generate", variant="primary", scale=2)
queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)
# Right column - output
with gr.Column(scale=3):
output = create_generation_output()
# Event handlers
# Preset change
def apply_preset(preset_id: str):
for p in presets:
if p["id"] == preset_id:
params = p["parameters"]
return (
params.get("duration", 10),
params.get("temperature", 3.0),
params.get("cfg_coef", 3.0),
params.get("top_k", 0),
params.get("top_p", 0.9),
params.get("decoding_steps", 20),
)
return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
preset_dropdown.change(
fn=apply_preset,
inputs=[preset_dropdown],
outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider, decoding_steps_slider],
)
# Prompt suggestions
for btn, suggestion in suggestion_btns:
btn.click(
fn=lambda s=suggestion: s,
outputs=[prompt_input],
)
# Generate
async def do_generate(
prompt, variant, duration, temperature, cfg_coef, top_k, top_p, decoding_steps, span_arr, seed
):
if not prompt:
yield (
gr.update(value="Please enter a prompt"),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
return
yield (
gr.update(value="🔄 Generating..."),
gr.update(visible=True, value=0),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
try:
result, generation = await generate_fn(
model_id="magnet",
variant=variant,
prompts=[prompt],
duration=duration,
temperature=temperature,
top_k=int(top_k),
top_p=top_p,
cfg_coef=cfg_coef,
decoding_steps=int(decoding_steps),
span_arrangement=span_arr,
seed=int(seed) if seed is not None else None,
)
yield (
gr.update(value="✅ Generation complete!"),
gr.update(visible=False),
gr.update(value=generation.audio_path),
gr.update(),
gr.update(value=f"{result.duration:.2f}s"),
gr.update(value=str(result.seed)),
)
except Exception as e:
yield (
gr.update(value=f"❌ Error: {str(e)}"),
gr.update(visible=False),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
generate_btn.click(
fn=do_generate,
inputs=[
prompt_input,
variant_dropdown,
duration_slider,
temperature_slider,
cfg_slider,
top_k_slider,
top_p_slider,
decoding_steps_slider,
span_arrangement,
seed_input,
],
outputs=[
output["status"],
output["progress"],
output["player"]["audio"],
output["player"]["waveform"],
output["player"]["duration"],
output["player"]["seed"],
],
)
# Add to queue
def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, decoding_steps, span_arr, seed):
if not prompt:
return "Please enter a prompt"
job = add_to_queue_fn(
model_id="magnet",
variant=variant,
prompts=[prompt],
duration=duration,
temperature=temperature,
top_k=int(top_k),
top_p=top_p,
cfg_coef=cfg_coef,
decoding_steps=int(decoding_steps),
span_arrangement=span_arr,
seed=int(seed) if seed is not None else None,
)
return f"✅ Added to queue: {job.id}"
queue_btn.click(
fn=do_add_queue,
inputs=[
prompt_input,
variant_dropdown,
duration_slider,
temperature_slider,
cfg_slider,
top_k_slider,
top_p_slider,
decoding_steps_slider,
span_arrangement,
seed_input,
],
outputs=[output["status"]],
)
return {
"preset": preset_dropdown,
"variant": variant_dropdown,
"prompt": prompt_input,
"duration": duration_slider,
"temperature": temperature_slider,
"cfg_coef": cfg_slider,
"top_k": top_k_slider,
"top_p": top_p_slider,
"decoding_steps": decoding_steps_slider,
"span_arrangement": span_arrangement,
"seed": seed_input,
"generate_btn": generate_btn,
"queue_btn": queue_btn,
"output": output,
}
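All of these handlers convert the optional seed from a `gr.Number` before passing it on; since `gr.Number` yields `None` when empty, the check must be against `None` so that 0 remains a valid seed. A stdlib-only sketch of that normalization (the helper name is illustrative):

```python
from typing import Optional

def normalize_seed(seed: Optional[float]) -> Optional[int]:
    """Map a gr.Number value (None when empty, else a float) to an int seed.

    Comparing against None rather than truthiness keeps 0 as a valid,
    reproducible seed instead of silently turning it into a random one.
    """
    return int(seed) if seed is not None else None

print(normalize_seed(None), normalize_seed(0.0), normalize_seed(42.0))  # None 0 42
```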

325
src/ui/tabs/musicgen_tab.py Normal file
View File

@@ -0,0 +1,325 @@
"""MusicGen tab for text-to-music generation."""
import gradio as gr
from typing import Any, Callable
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
from src.ui.components.audio_player import create_generation_output
MUSICGEN_VARIANTS = [
{"id": "small", "name": "Small", "vram_mb": 1500, "description": "Fast, 300M params"},
{"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "Balanced, 1.5B params"},
{"id": "large", "name": "Large", "vram_mb": 10000, "description": "Best quality, 3.3B params"},
{"id": "melody", "name": "Melody", "vram_mb": 5000, "description": "With melody conditioning"},
{"id": "stereo-small", "name": "Stereo Small", "vram_mb": 1800, "description": "Stereo, 300M params"},
{"id": "stereo-medium", "name": "Stereo Medium", "vram_mb": 6000, "description": "Stereo, 1.5B params"},
{"id": "stereo-large", "name": "Stereo Large", "vram_mb": 12000, "description": "Stereo, 3.3B params"},
{"id": "stereo-melody", "name": "Stereo Melody", "vram_mb": 6000, "description": "Stereo with melody"},
]
def create_musicgen_tab(
generate_fn: Callable[..., Any],
add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
"""Create MusicGen generation tab.
Args:
generate_fn: Function to call for generation
add_to_queue_fn: Function to add to queue
Returns:
Dictionary with component references
"""
presets = DEFAULT_PRESETS.get("musicgen", [])
suggestions = PROMPT_SUGGESTIONS.get("musicgen", [])
with gr.Column():
gr.Markdown("## 🎵 MusicGen")
gr.Markdown("Generate music from text descriptions")
with gr.Row():
# Left column - inputs
with gr.Column(scale=2):
# Preset selector
preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
preset_dropdown = gr.Dropdown(
label="Preset",
choices=preset_choices,
value=presets[0]["id"] if presets else "custom",
)
# Model variant
variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in MUSICGEN_VARIANTS]
variant_dropdown = gr.Dropdown(
label="Model Variant",
choices=variant_choices,
value="medium",
)
# Prompt input
prompt_input = gr.Textbox(
label="Prompt",
placeholder="Describe the music you want to generate...",
lines=3,
max_lines=5,
)
# Prompt suggestions
with gr.Accordion("Prompt Suggestions", open=False):
suggestion_btns = []
for i, suggestion in enumerate(suggestions[:4]):
btn = gr.Button(suggestion[:60] + ("..." if len(suggestion) > 60 else ""), size="sm", variant="secondary")
suggestion_btns.append((btn, suggestion))
# Melody conditioning (for melody variants)
with gr.Group(visible=False) as melody_section:
gr.Markdown("### Melody Conditioning")
melody_input = gr.Audio(
label="Reference Melody",
type="filepath",
sources=["upload", "microphone"],
)
gr.Markdown("*Upload audio to condition generation on its melody*")
# Parameters
gr.Markdown("### Parameters")
duration_slider = gr.Slider(
minimum=1,
maximum=30,
value=10,
step=1,
label="Duration (seconds)",
)
with gr.Accordion("Advanced Parameters", open=False):
with gr.Row():
temperature_slider = gr.Slider(
minimum=0.0,
maximum=2.0,
value=1.0,
step=0.05,
label="Temperature",
)
cfg_slider = gr.Slider(
minimum=1.0,
maximum=10.0,
value=3.0,
step=0.5,
label="CFG Coefficient",
)
with gr.Row():
top_k_slider = gr.Slider(
minimum=0,
maximum=500,
value=250,
step=10,
label="Top-K",
)
top_p_slider = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.0,
step=0.05,
label="Top-P",
)
with gr.Row():
seed_input = gr.Number(
value=None,
label="Seed (empty = random)",
precision=0,
)
# Generate buttons
with gr.Row():
generate_btn = gr.Button("🎵 Generate", variant="primary", scale=2)
queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)
# Right column - output
with gr.Column(scale=3):
output = create_generation_output()
# Event handlers
# Preset change
def apply_preset(preset_id: str):
for p in presets:
if p["id"] == preset_id:
params = p["parameters"]
return (
params.get("duration", 10),
params.get("temperature", 1.0),
params.get("cfg_coef", 3.0),
params.get("top_k", 250),
params.get("top_p", 0.0),
)
# Custom preset - don't change values
return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
preset_dropdown.change(
fn=apply_preset,
inputs=[preset_dropdown],
outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
)
# Variant change - show/hide melody section
def on_variant_change(variant: str):
show_melody = "melody" in variant.lower()
return gr.update(visible=show_melody)
variant_dropdown.change(
fn=on_variant_change,
inputs=[variant_dropdown],
outputs=[melody_section],
)
# Prompt suggestions
for btn, suggestion in suggestion_btns:
btn.click(
fn=lambda s=suggestion: s,
outputs=[prompt_input],
)
# Generate
async def do_generate(
prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody
):
if not prompt:
yield (
gr.update(value="Please enter a prompt"),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
return
# Update status
yield (
gr.update(value="🔄 Generating..."),
gr.update(visible=True, value=0),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
try:
conditioning = {}
if melody:
conditioning["melody"] = melody
result, generation = await generate_fn(
model_id="musicgen",
variant=variant,
prompts=[prompt],
duration=duration,
temperature=temperature,
top_k=int(top_k),
top_p=top_p,
cfg_coef=cfg_coef,
seed=int(seed) if seed is not None else None,
conditioning=conditioning,
)
yield (
gr.update(value="✅ Generation complete!"),
gr.update(visible=False),
gr.update(value=generation.audio_path),
gr.update(),
gr.update(value=f"{result.duration:.2f}s"),
gr.update(value=str(result.seed)),
)
except Exception as e:
yield (
gr.update(value=f"❌ Error: {str(e)}"),
gr.update(visible=False),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
generate_btn.click(
fn=do_generate,
inputs=[
prompt_input,
variant_dropdown,
duration_slider,
temperature_slider,
cfg_slider,
top_k_slider,
top_p_slider,
seed_input,
melody_input,
],
outputs=[
output["status"],
output["progress"],
output["player"]["audio"],
output["player"]["waveform"],
output["player"]["duration"],
output["player"]["seed"],
],
)
# Add to queue
def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody):
if not prompt:
return "Please enter a prompt"
conditioning = {}
if melody:
conditioning["melody"] = melody
job = add_to_queue_fn(
model_id="musicgen",
variant=variant,
prompts=[prompt],
duration=duration,
temperature=temperature,
top_k=int(top_k),
top_p=top_p,
cfg_coef=cfg_coef,
seed=int(seed) if seed is not None else None,
conditioning=conditioning,
)
return f"✅ Added to queue: {job.id}"
queue_btn.click(
fn=do_add_queue,
inputs=[
prompt_input,
variant_dropdown,
duration_slider,
temperature_slider,
cfg_slider,
top_k_slider,
top_p_slider,
seed_input,
melody_input,
],
outputs=[output["status"]],
)
return {
"preset": preset_dropdown,
"variant": variant_dropdown,
"prompt": prompt_input,
"melody": melody_input,
"duration": duration_slider,
"temperature": temperature_slider,
"cfg_coef": cfg_slider,
"top_k": top_k_slider,
"top_p": top_p_slider,
"seed": seed_input,
"generate_btn": generate_btn,
"queue_btn": queue_btn,
"output": output,
}
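The `apply_preset` handlers in each tab walk the preset list and fall back to leaving the sliders untouched for unknown ids such as "custom". The same lookup can be expressed as a pure function over dicts, which is easier to unit-test than the tuple-returning Gradio callback; a sketch under that assumption (names are illustrative):

```python
def resolve_preset(presets: list[dict], preset_id: str, defaults: dict) -> dict:
    """Return the chosen preset's parameters merged over defaults.

    Unknown ids (e.g. "custom") simply yield the defaults unchanged,
    mirroring the gr.update() no-op branch in the tab callbacks.
    """
    for p in presets:
        if p["id"] == preset_id:
            return {**defaults, **p["parameters"]}
    return dict(defaults)

defaults = {"duration": 10, "temperature": 1.0, "cfg_coef": 3.0, "top_k": 250, "top_p": 0.0}
presets = [{"id": "lofi", "name": "Lo-fi", "parameters": {"temperature": 0.9, "duration": 20}}]
print(resolve_preset(presets, "lofi", defaults)["duration"])   # 20
print(resolve_preset(presets, "custom", defaults) == defaults)  # True
```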

292
src/ui/tabs/style_tab.py Normal file
View File

@@ -0,0 +1,292 @@
"""MusicGen Style tab for style-conditioned generation."""
import gradio as gr
from typing import Any, Callable
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
from src.ui.components.audio_player import create_generation_output
STYLE_VARIANTS = [
{"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "1.5B params, style conditioning"},
]
def create_style_tab(
generate_fn: Callable[..., Any],
add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
"""Create MusicGen Style generation tab.
Args:
generate_fn: Function to call for generation
add_to_queue_fn: Function to add to queue
Returns:
Dictionary with component references
"""
presets = DEFAULT_PRESETS.get("musicgen-style", [])
suggestions = PROMPT_SUGGESTIONS.get("musicgen", [])
with gr.Column():
gr.Markdown("## 🎨 MusicGen Style")
gr.Markdown("Generate music conditioned on the style of reference audio")
with gr.Row():
# Left column - inputs
with gr.Column(scale=2):
# Preset selector
preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
preset_dropdown = gr.Dropdown(
label="Preset",
choices=preset_choices,
value=presets[0]["id"] if presets else "custom",
)
# Model variant
variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in STYLE_VARIANTS]
variant_dropdown = gr.Dropdown(
label="Model Variant",
choices=variant_choices,
value="medium",
)
# Prompt input
prompt_input = gr.Textbox(
label="Text Prompt",
placeholder="Describe additional characteristics for the music...",
lines=3,
max_lines=5,
info="Optional: combine with style conditioning",
)
# Style conditioning (required)
gr.Markdown("### Style Conditioning")
gr.Markdown("*Upload reference audio to extract musical style*")
style_input = gr.Audio(
label="Style Reference",
type="filepath",
sources=["upload", "microphone"],
)
gr.Markdown(
"*The model will learn the style (instrumentation, tempo, mood) from this audio*"
)
# Parameters
gr.Markdown("### Parameters")
duration_slider = gr.Slider(
minimum=1,
maximum=30,
value=10,
step=1,
label="Duration (seconds)",
)
with gr.Accordion("Advanced Parameters", open=False):
with gr.Row():
temperature_slider = gr.Slider(
minimum=0.0,
maximum=2.0,
value=1.0,
step=0.05,
label="Temperature",
)
cfg_slider = gr.Slider(
minimum=1.0,
maximum=10.0,
value=3.0,
step=0.5,
label="CFG Coefficient",
)
with gr.Row():
top_k_slider = gr.Slider(
minimum=0,
maximum=500,
value=250,
step=10,
label="Top-K",
)
top_p_slider = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.0,
step=0.05,
label="Top-P",
)
with gr.Row():
seed_input = gr.Number(
value=None,
label="Seed (empty = random)",
precision=0,
)
# Generate buttons
with gr.Row():
generate_btn = gr.Button("🎨 Generate", variant="primary", scale=2)
queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)
# Right column - output
with gr.Column(scale=3):
output = create_generation_output()
# Event handlers
# Preset change
def apply_preset(preset_id: str):
for p in presets:
if p["id"] == preset_id:
params = p["parameters"]
return (
params.get("duration", 10),
params.get("temperature", 1.0),
params.get("cfg_coef", 3.0),
params.get("top_k", 250),
params.get("top_p", 0.0),
)
return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
preset_dropdown.change(
fn=apply_preset,
inputs=[preset_dropdown],
outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
)
# Generate
async def do_generate(
prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed
):
if not style_audio:
yield (
gr.update(value="Please upload a style reference audio"),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
return
yield (
gr.update(value="🔄 Generating..."),
gr.update(visible=True, value=0),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
try:
conditioning = {"style": style_audio}
result, generation = await generate_fn(
model_id="musicgen-style",
variant=variant,
prompts=[prompt] if prompt else [""],
duration=duration,
temperature=temperature,
top_k=int(top_k),
top_p=top_p,
cfg_coef=cfg_coef,
seed=int(seed) if seed is not None else None,
conditioning=conditioning,
)
yield (
gr.update(value="✅ Generation complete!"),
gr.update(visible=False),
gr.update(value=generation.audio_path),
gr.update(),
gr.update(value=f"{result.duration:.2f}s"),
gr.update(value=str(result.seed)),
)
except Exception as e:
yield (
gr.update(value=f"❌ Error: {str(e)}"),
gr.update(visible=False),
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
generate_btn.click(
fn=do_generate,
inputs=[
prompt_input,
variant_dropdown,
style_input,
duration_slider,
temperature_slider,
cfg_slider,
top_k_slider,
top_p_slider,
seed_input,
],
outputs=[
output["status"],
output["progress"],
output["player"]["audio"],
output["player"]["waveform"],
output["player"]["duration"],
output["player"]["seed"],
],
)
# Add to queue
def do_add_queue(prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed):
if not style_audio:
return "Please upload a style reference audio"
conditioning = {"style": style_audio}
job = add_to_queue_fn(
model_id="musicgen-style",
variant=variant,
prompts=[prompt] if prompt else [""],
duration=duration,
temperature=temperature,
top_k=int(top_k),
top_p=top_p,
cfg_coef=cfg_coef,
seed=int(seed) if seed is not None else None,
conditioning=conditioning,
)
return f"✅ Added to queue: {job.id}"
queue_btn.click(
fn=do_add_queue,
inputs=[
prompt_input,
variant_dropdown,
style_input,
duration_slider,
temperature_slider,
cfg_slider,
top_k_slider,
top_p_slider,
seed_input,
],
outputs=[output["status"]],
)
return {
"preset": preset_dropdown,
"variant": variant_dropdown,
"prompt": prompt_input,
"style": style_input,
"duration": duration_slider,
"temperature": temperature_slider,
"cfg_coef": cfg_slider,
"top_k": top_k_slider,
"top_p": top_p_slider,
"seed": seed_input,
"generate_btn": generate_btn,
"queue_btn": queue_btn,
"output": output,
}
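Each tab assembles a `conditioning` dict containing only the signals the user actually provided (melody for MusicGen, style here, chords/bpm/drums for JASCO). A small stdlib-only sketch of that shared pattern (the helper name and parameter set are illustrative, not part of the codebase):

```python
def build_conditioning(melody=None, style=None, chords=None, drums=None, bpm=None) -> dict:
    """Collect only the conditioning signals that were actually provided.

    Keeping absent inputs out of the dict (rather than passing None values)
    matches how the tab handlers hand conditioning to generate_fn.
    """
    cond = {}
    if melody:
        cond["melody"] = melody
    if style:
        cond["style"] = style
    if chords:
        cond["chords"] = chords
        if bpm is not None:
            cond["bpm"] = bpm
    if drums:
        cond["drums"] = drums
    return cond

print(build_conditioning(chords="C G Am F", bpm=120))  # {'chords': 'C G Am F', 'bpm': 120}
```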

303
src/ui/theme.py Normal file
View File

@@ -0,0 +1,303 @@
"""Custom Gradio theme for AudioCraft Studio."""
import gradio as gr
def create_theme() -> gr.themes.Base:
"""Create custom theme for AudioCraft Studio.
Returns:
Gradio theme instance
"""
return gr.themes.Soft(
primary_hue=gr.themes.colors.blue,
secondary_hue=gr.themes.colors.slate,
neutral_hue=gr.themes.colors.gray,
font=[
gr.themes.GoogleFont("Inter"),
"ui-sans-serif",
"system-ui",
"sans-serif",
],
font_mono=[
gr.themes.GoogleFont("JetBrains Mono"),
"ui-monospace",
"monospace",
],
).set(
# Colors
body_background_fill="#0f172a",
body_background_fill_dark="#0f172a",
background_fill_primary="#1e293b",
background_fill_primary_dark="#1e293b",
background_fill_secondary="#334155",
background_fill_secondary_dark="#334155",
border_color_primary="#475569",
border_color_primary_dark="#475569",
# Text
body_text_color="#e2e8f0",
body_text_color_dark="#e2e8f0",
body_text_color_subdued="#94a3b8",
body_text_color_subdued_dark="#94a3b8",
# Buttons
button_primary_background_fill="#3b82f6",
button_primary_background_fill_dark="#3b82f6",
button_primary_background_fill_hover="#2563eb",
button_primary_background_fill_hover_dark="#2563eb",
button_primary_text_color="#ffffff",
button_primary_text_color_dark="#ffffff",
button_secondary_background_fill="#475569",
button_secondary_background_fill_dark="#475569",
button_secondary_background_fill_hover="#64748b",
button_secondary_background_fill_hover_dark="#64748b",
# Inputs
input_background_fill="#1e293b",
input_background_fill_dark="#1e293b",
input_border_color="#475569",
input_border_color_dark="#475569",
input_border_color_focus="#3b82f6",
input_border_color_focus_dark="#3b82f6",
# Blocks
block_background_fill="#1e293b",
block_background_fill_dark="#1e293b",
block_border_color="#334155",
block_border_color_dark="#334155",
block_label_background_fill="#334155",
block_label_background_fill_dark="#334155",
block_label_text_color="#e2e8f0",
block_label_text_color_dark="#e2e8f0",
block_title_text_color="#f1f5f9",
block_title_text_color_dark="#f1f5f9",
# Tabs
tab_nav_background_fill="#1e293b",
# Sliders
slider_color="#3b82f6",
slider_color_dark="#3b82f6",
# Shadows
shadow_spread="4px",
block_shadow="0 4px 6px -1px rgba(0, 0, 0, 0.3)",
# Spacing
layout_gap="16px",
block_padding="16px",
panel_border_width="1px",
# Radius
radius_sm="6px",
radius_md="8px",
radius_lg="12px",
)
# CSS overrides for additional customization
CUSTOM_CSS = """
/* Global styles */
.gradio-container {
max-width: 100% !important;
}
/* Header styling */
.header-title {
font-size: 1.5rem;
font-weight: 700;
color: #f1f5f9;
}
/* Sidebar styling */
.sidebar {
background: #1e293b;
border-right: 1px solid #334155;
padding: 1rem;
}
.sidebar-nav-btn {
width: 100%;
justify-content: flex-start;
margin-bottom: 0.5rem;
}
/* Model cards */
.model-card {
background: #334155;
border-radius: 12px;
padding: 1rem;
transition: transform 0.2s, box-shadow 0.2s;
}
.model-card:hover {
transform: translateY(-2px);
box-shadow: 0 8px 25px rgba(0, 0, 0, 0.3);
}
/* Audio player */
.audio-player {
background: #1e293b;
border-radius: 8px;
padding: 1rem;
}
/* Progress bar */
.progress-bar {
background: #334155;
border-radius: 4px;
overflow: hidden;
}
.progress-fill {
background: linear-gradient(90deg, #3b82f6, #8b5cf6);
height: 100%;
transition: width 0.3s ease;
}
/* VRAM monitor */
.vram-bar {
background: #334155;
border-radius: 4px;
height: 24px;
position: relative;
overflow: hidden;
}
.vram-fill {
position: absolute;
left: 0;
top: 0;
height: 100%;
background: linear-gradient(90deg, #22c55e, #eab308, #ef4444);
transition: width 0.5s ease;
}
.vram-text {
position: absolute;
width: 100%;
text-align: center;
line-height: 24px;
font-size: 0.875rem;
font-weight: 500;
color: white;
text-shadow: 0 1px 2px rgba(0, 0, 0, 0.5);
}
/* Queue badge */
.queue-badge {
background: #3b82f6;
color: white;
padding: 0.25rem 0.75rem;
border-radius: 9999px;
font-size: 0.875rem;
font-weight: 500;
}
/* Generation card */
.generation-card {
background: #334155;
border-radius: 8px;
padding: 1rem;
margin-bottom: 0.5rem;
}
/* Preset chips */
.preset-chip {
display: inline-block;
background: #475569;
color: #e2e8f0;
padding: 0.25rem 0.75rem;
border-radius: 9999px;
font-size: 0.875rem;
margin: 0.25rem;
cursor: pointer;
transition: background 0.2s;
}
.preset-chip:hover {
background: #3b82f6;
}
.preset-chip.active {
background: #3b82f6;
}
/* Tag input */
.tag {
display: inline-flex;
align-items: center;
background: #475569;
color: #e2e8f0;
padding: 0.25rem 0.5rem;
border-radius: 4px;
font-size: 0.75rem;
margin: 0.125rem;
}
/* Accordion tweaks */
.accordion-header {
font-weight: 600;
color: #f1f5f9;
}
/* Status indicators */
.status-dot {
width: 8px;
height: 8px;
border-radius: 50%;
display: inline-block;
margin-right: 0.5rem;
}
.status-dot.loaded {
background: #22c55e;
}
.status-dot.unloaded {
background: #64748b;
}
.status-dot.loading {
background: #eab308;
animation: pulse 1s infinite;
}
@keyframes pulse {
0%, 100% { opacity: 1; }
50% { opacity: 0.5; }
}
/* Tooltip */
.tooltip {
position: relative;
}
.tooltip:hover::after {
content: attr(data-tooltip);
position: absolute;
bottom: 100%;
left: 50%;
transform: translateX(-50%);
background: #1e293b;
color: #e2e8f0;
padding: 0.5rem;
border-radius: 4px;
font-size: 0.75rem;
white-space: nowrap;
z-index: 100;
}
/* Responsive adjustments */
@media (max-width: 768px) {
.sidebar {
display: none;
}
.mobile-nav {
display: flex !important;
}
}
"""