Initial implementation of AudioCraft Studio
Complete web interface for Meta's AudioCraft AI audio generation: - Gradio UI with tabs for all 5 model families (MusicGen, AudioGen, MAGNeT, MusicGen Style, JASCO) - REST API with FastAPI, OpenAPI docs, and API key auth - VRAM management with ComfyUI coexistence support - SQLite database for project/generation history - Batch processing queue for async generation - Docker deployment optimized for RunPod with RTX 4090 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
42
.env.example
Normal file
42
.env.example
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# AudioCraft Studio Configuration
|
||||||
|
# Copy this file to .env and customize as needed
|
||||||
|
|
||||||
|
# Server Configuration
|
||||||
|
AUDIOCRAFT_HOST=0.0.0.0
|
||||||
|
AUDIOCRAFT_GRADIO_PORT=7860
|
||||||
|
AUDIOCRAFT_API_PORT=8000
|
||||||
|
|
||||||
|
# Paths (relative to project root)
|
||||||
|
AUDIOCRAFT_DATA_DIR=./data
|
||||||
|
AUDIOCRAFT_OUTPUT_DIR=./outputs
|
||||||
|
AUDIOCRAFT_CACHE_DIR=./cache
|
||||||
|
|
||||||
|
# VRAM Management
|
||||||
|
# Reserve this much VRAM for ComfyUI (GB)
|
||||||
|
AUDIOCRAFT_COMFYUI_RESERVE_GB=10
|
||||||
|
# Safety buffer to prevent OOM (GB)
|
||||||
|
AUDIOCRAFT_SAFETY_BUFFER_GB=1
|
||||||
|
# Unload idle models after this many minutes
|
||||||
|
AUDIOCRAFT_IDLE_UNLOAD_MINUTES=15
|
||||||
|
# Maximum number of models to keep loaded
|
||||||
|
AUDIOCRAFT_MAX_CACHED_MODELS=2
|
||||||
|
|
||||||
|
# API Authentication
|
||||||
|
# Generate a secure random key for production
|
||||||
|
AUDIOCRAFT_API_KEY=your-secret-api-key-here
|
||||||
|
|
||||||
|
# Generation Defaults
|
||||||
|
AUDIOCRAFT_DEFAULT_DURATION=10.0
|
||||||
|
AUDIOCRAFT_MAX_DURATION=300.0
|
||||||
|
AUDIOCRAFT_DEFAULT_BATCH_SIZE=1
|
||||||
|
AUDIOCRAFT_MAX_BATCH_SIZE=8
|
||||||
|
AUDIOCRAFT_MAX_QUEUE_SIZE=100
|
||||||
|
|
||||||
|
# Database
|
||||||
|
AUDIOCRAFT_DATABASE_URL=sqlite+aiosqlite:///./data/audiocraft.db
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
AUDIOCRAFT_LOG_LEVEL=INFO
|
||||||
|
|
||||||
|
# PyTorch Optimization (recommended)
|
||||||
|
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
||||||
76
.gitignore
vendored
Normal file
76
.gitignore
vendored
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
# Testing
|
||||||
|
.pytest_cache/
|
||||||
|
.coverage
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
|
||||||
|
# Type checking
|
||||||
|
.mypy_cache/
|
||||||
|
|
||||||
|
# Project specific
|
||||||
|
data/
|
||||||
|
outputs/
|
||||||
|
cache/
|
||||||
|
*.db
|
||||||
|
*.sqlite
|
||||||
|
*.sqlite3
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
*.log
|
||||||
|
logs/
|
||||||
|
|
||||||
|
# Environment
|
||||||
|
.env
|
||||||
|
.env.local
|
||||||
|
.env.*.local
|
||||||
|
|
||||||
|
# Model weights (downloaded from HuggingFace)
|
||||||
|
*.bin
|
||||||
|
*.safetensors
|
||||||
|
*.pt
|
||||||
|
*.pth
|
||||||
|
|
||||||
|
# Audio files (generated)
|
||||||
|
*.wav
|
||||||
|
*.mp3
|
||||||
|
*.flac
|
||||||
|
*.ogg
|
||||||
|
|
||||||
|
# Temp files
|
||||||
|
/tmp/
|
||||||
|
*.tmp
|
||||||
83
Dockerfile
Normal file
83
Dockerfile
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
# AudioCraft Studio Dockerfile for RunPod
|
||||||
|
# Optimized for NVIDIA RTX 4090 (24GB VRAM)
|
||||||
|
|
||||||
|
FROM nvidia/cuda:12.1-cudnn8-runtime-ubuntu22.04
|
||||||
|
|
||||||
|
# Set environment variables
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
ENV PYTHONDONTWRITEBYTECODE=1
|
||||||
|
ENV PIP_NO_CACHE_DIR=1
|
||||||
|
ENV PIP_DISABLE_PIP_VERSION_CHECK=1
|
||||||
|
|
||||||
|
# CUDA settings
|
||||||
|
ENV CUDA_HOME=/usr/local/cuda
|
||||||
|
ENV PATH="${CUDA_HOME}/bin:${PATH}"
|
||||||
|
ENV LD_LIBRARY_PATH="${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}"
|
||||||
|
|
||||||
|
# AudioCraft settings
|
||||||
|
ENV AUDIOCRAFT_OUTPUT_DIR=/workspace/outputs
|
||||||
|
ENV AUDIOCRAFT_DATA_DIR=/workspace/data
|
||||||
|
ENV AUDIOCRAFT_MODEL_CACHE=/workspace/models
|
||||||
|
ENV AUDIOCRAFT_HOST=0.0.0.0
|
||||||
|
ENV AUDIOCRAFT_GRADIO_PORT=7860
|
||||||
|
ENV AUDIOCRAFT_API_PORT=8000
|
||||||
|
|
||||||
|
# Install system dependencies
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
git \
|
||||||
|
curl \
|
||||||
|
wget \
|
||||||
|
ffmpeg \
|
||||||
|
libsndfile1 \
|
||||||
|
libsox-dev \
|
||||||
|
sox \
|
||||||
|
build-essential \
|
||||||
|
python3.10 \
|
||||||
|
python3.10-venv \
|
||||||
|
python3.10-dev \
|
||||||
|
python3-pip \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Set Python 3.10 as default
|
||||||
|
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 \
|
||||||
|
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1
|
||||||
|
|
||||||
|
# Upgrade pip
|
||||||
|
RUN pip install --upgrade pip setuptools wheel
|
||||||
|
|
||||||
|
# Create workspace directory
|
||||||
|
WORKDIR /workspace
|
||||||
|
|
||||||
|
# Create necessary directories
|
||||||
|
RUN mkdir -p /workspace/outputs /workspace/data /workspace/models /workspace/app
|
||||||
|
|
||||||
|
# Copy requirements first for caching
|
||||||
|
COPY requirements.txt /workspace/app/
|
||||||
|
WORKDIR /workspace/app
|
||||||
|
|
||||||
|
# Install PyTorch with CUDA support
|
||||||
|
RUN pip install torch==2.1.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
|
||||||
|
|
||||||
|
# Install other requirements
|
||||||
|
RUN pip install -r requirements.txt
|
||||||
|
|
||||||
|
# Install AudioCraft from source for latest features
|
||||||
|
RUN pip install git+https://github.com/facebookresearch/audiocraft.git
|
||||||
|
|
||||||
|
# Copy application code
|
||||||
|
COPY . /workspace/app/
|
||||||
|
|
||||||
|
# Create non-root user for security (optional, RunPod often uses root)
|
||||||
|
# RUN useradd -m -u 1000 audiocraft && chown -R audiocraft:audiocraft /workspace
|
||||||
|
# USER audiocraft
|
||||||
|
|
||||||
|
# Expose ports
|
||||||
|
EXPOSE 7860 8000
|
||||||
|
|
||||||
|
# Health check
|
||||||
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
|
||||||
|
CMD curl -f http://localhost:7860/ || exit 1
|
||||||
|
|
||||||
|
# Default command
|
||||||
|
CMD ["python", "main.py"]
|
||||||
197
README.md
Normal file
197
README.md
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
# AudioCraft Studio
|
||||||
|
|
||||||
|
A comprehensive web interface for Meta's AudioCraft AI audio generation models, optimized for RunPod deployment with NVIDIA RTX 4090 GPUs.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
### Models Supported
|
||||||
|
- **MusicGen** - Text-to-music generation with melody conditioning
|
||||||
|
- **AudioGen** - Text-to-sound effects and environmental audio
|
||||||
|
- **MAGNeT** - Fast non-autoregressive music generation
|
||||||
|
- **MusicGen Style** - Style-conditioned music from reference audio
|
||||||
|
- **JASCO** - Chord and drum-conditioned music generation
|
||||||
|
|
||||||
|
### Core Capabilities
|
||||||
|
- **Gradio Web UI** - Intuitive interface with real-time generation
|
||||||
|
- **REST API** - Full-featured API with OpenAPI documentation
|
||||||
|
- **Batch Processing** - Queue system for multiple generations
|
||||||
|
- **Project Management** - Organize and browse generation history
|
||||||
|
- **VRAM Management** - Smart model loading/unloading, ComfyUI coexistence
|
||||||
|
- **Waveform Visualization** - Visual audio feedback
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Local Development
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Clone repository
|
||||||
|
git clone https://github.com/your-username/audiocraft-ui.git
|
||||||
|
cd audiocraft-ui
|
||||||
|
|
||||||
|
# Create virtual environment
|
||||||
|
python -m venv venv
|
||||||
|
source venv/bin/activate # Linux/Mac
|
||||||
|
# or: venv\Scripts\activate # Windows
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
pip install -r requirements.txt
|
||||||
|
|
||||||
|
# Run application
|
||||||
|
python main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Access the UI at `http://localhost:7860`
|
||||||
|
|
||||||
|
### Docker
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build and run
|
||||||
|
docker-compose up --build
|
||||||
|
|
||||||
|
# Or build manually
|
||||||
|
docker build -t audiocraft-studio .
|
||||||
|
docker run --gpus all -p 7860:7860 -p 8000:8000 audiocraft-studio
|
||||||
|
```
|
||||||
|
|
||||||
|
### RunPod Deployment
|
||||||
|
|
||||||
|
1. Build and push Docker image:
|
||||||
|
```bash
|
||||||
|
docker build -t your-dockerhub/audiocraft-studio:latest .
|
||||||
|
docker push your-dockerhub/audiocraft-studio:latest
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Create RunPod template using `runpod.yaml` as reference
|
||||||
|
|
||||||
|
3. Deploy with RTX 4090 or equivalent GPU
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Configuration via environment variables:
|
||||||
|
|
||||||
|
| Variable | Default | Description |
|
||||||
|
|----------|---------|-------------|
|
||||||
|
| `AUDIOCRAFT_HOST` | `0.0.0.0` | Server bind address |
|
||||||
|
| `AUDIOCRAFT_GRADIO_PORT` | `7860` | Gradio UI port |
|
||||||
|
| `AUDIOCRAFT_API_PORT` | `8000` | REST API port |
|
||||||
|
| `AUDIOCRAFT_OUTPUT_DIR` | `./outputs` | Generated audio output |
|
||||||
|
| `AUDIOCRAFT_DATA_DIR` | `./data` | Database and config |
|
||||||
|
| `AUDIOCRAFT_COMFYUI_RESERVE_GB` | `10` | VRAM reserved for ComfyUI |
|
||||||
|
| `AUDIOCRAFT_MAX_LOADED_MODELS` | `2` | Max models in memory |
|
||||||
|
| `AUDIOCRAFT_IDLE_UNLOAD_MINUTES` | `15` | Auto-unload idle models |
|
||||||
|
|
||||||
|
See `.env.example` for full configuration options.
|
||||||
|
|
||||||
|
## API Usage
|
||||||
|
|
||||||
|
### Authentication
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Get API key from Settings page or generate via CLI
|
||||||
|
curl -X POST http://localhost:8000/api/v1/system/api-key/regenerate \
|
||||||
|
-H "X-API-Key: YOUR_CURRENT_KEY"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Generate Audio
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Synchronous generation
|
||||||
|
curl -X POST http://localhost:8000/api/v1/generate \
|
||||||
|
-H "X-API-Key: YOUR_API_KEY" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "musicgen",
|
||||||
|
"variant": "medium",
|
||||||
|
"prompts": ["upbeat electronic dance music with synth leads"],
|
||||||
|
"duration": 10
|
||||||
|
}'
|
||||||
|
|
||||||
|
# Async (queue) generation
|
||||||
|
curl -X POST http://localhost:8000/api/v1/generate/async \
|
||||||
|
-H "X-API-Key: YOUR_API_KEY" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"request": {
|
||||||
|
"model": "musicgen",
|
||||||
|
"prompts": ["ambient soundscape"],
|
||||||
|
"duration": 30
|
||||||
|
},
|
||||||
|
"priority": 5
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Check Job Status
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://localhost:8000/api/v1/generate/jobs/{job_id} \
|
||||||
|
-H "X-API-Key: YOUR_API_KEY"
|
||||||
|
```
|
||||||
|
|
||||||
|
Full API documentation available at `http://localhost:8000/api/docs`
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
audiocraft-ui/
|
||||||
|
├── config/
|
||||||
|
│ ├── settings.py # Pydantic settings
|
||||||
|
│ └── models.yaml # Model registry
|
||||||
|
├── src/
|
||||||
|
│ ├── core/
|
||||||
|
│ │ ├── base_model.py # Abstract model interface
|
||||||
|
│ │ ├── gpu_manager.py # VRAM management
|
||||||
|
│ │ ├── model_registry.py # Model loading/caching
|
||||||
|
│ │ └── oom_handler.py # OOM recovery
|
||||||
|
│ ├── models/
|
||||||
|
│ │ ├── musicgen/ # MusicGen adapter
|
||||||
|
│ │ ├── audiogen/ # AudioGen adapter
|
||||||
|
│ │ ├── magnet/ # MAGNeT adapter
|
||||||
|
│ │ ├── musicgen_style/ # Style adapter
|
||||||
|
│ │ └── jasco/ # JASCO adapter
|
||||||
|
│ ├── services/
|
||||||
|
│ │ ├── generation_service.py
|
||||||
|
│ │ ├── batch_processor.py
|
||||||
|
│ │ └── project_service.py
|
||||||
|
│ ├── storage/
|
||||||
|
│ │ └── database.py # SQLite storage
|
||||||
|
│ ├── api/
|
||||||
|
│ │ ├── app.py # FastAPI app
|
||||||
|
│ │ └── routes/ # API endpoints
|
||||||
|
│ └── ui/
|
||||||
|
│ ├── app.py # Gradio app
|
||||||
|
│ ├── components/ # Reusable UI components
|
||||||
|
│ ├── tabs/ # Model generation tabs
|
||||||
|
│ └── pages/ # Projects, Settings
|
||||||
|
├── main.py # Entry point
|
||||||
|
├── Dockerfile
|
||||||
|
└── docker-compose.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
## ComfyUI Coexistence
|
||||||
|
|
||||||
|
AudioCraft Studio is designed to run alongside ComfyUI on the same GPU:
|
||||||
|
|
||||||
|
1. Set `AUDIOCRAFT_COMFYUI_RESERVE_GB` to reserve VRAM for ComfyUI
|
||||||
|
2. Models are automatically unloaded when idle
|
||||||
|
3. Coordination file at `/tmp/audiocraft_comfyui_coord.json` prevents conflicts
|
||||||
|
|
||||||
|
## Development
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install dev dependencies
|
||||||
|
pip install -r requirements-dev.txt
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
pytest
|
||||||
|
|
||||||
|
# Format code
|
||||||
|
black src/ config/
|
||||||
|
ruff check src/ config/
|
||||||
|
|
||||||
|
# Type checking
|
||||||
|
mypy src/
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This project uses Meta's AudioCraft library. See [AudioCraft License](https://github.com/facebookresearch/audiocraft/blob/main/LICENSE).
|
||||||
5
config/__init__.py
Normal file
5
config/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Configuration module for AudioCraft Studio."""
|
||||||
|
|
||||||
|
from config.settings import Settings, get_settings
|
||||||
|
|
||||||
|
__all__ = ["Settings", "get_settings"]
|
||||||
151
config/models.yaml
Normal file
151
config/models.yaml
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
# AudioCraft Model Registry Configuration
|
||||||
|
# This file defines all available models and their configurations
|
||||||
|
|
||||||
|
models:
|
||||||
|
musicgen:
|
||||||
|
enabled: true
|
||||||
|
display_name: "MusicGen"
|
||||||
|
description: "Text-to-music generation with optional melody conditioning"
|
||||||
|
default_variant: medium
|
||||||
|
variants:
|
||||||
|
small:
|
||||||
|
hf_id: facebook/musicgen-small
|
||||||
|
vram_mb: 1500
|
||||||
|
max_duration: 30
|
||||||
|
description: "Fast, lightweight model (300M params)"
|
||||||
|
medium:
|
||||||
|
hf_id: facebook/musicgen-medium
|
||||||
|
vram_mb: 5000
|
||||||
|
max_duration: 30
|
||||||
|
description: "Balanced quality and speed (1.5B params)"
|
||||||
|
large:
|
||||||
|
hf_id: facebook/musicgen-large
|
||||||
|
vram_mb: 10000
|
||||||
|
max_duration: 30
|
||||||
|
description: "Highest quality, slower (3.3B params)"
|
||||||
|
melody:
|
||||||
|
hf_id: facebook/musicgen-melody
|
||||||
|
vram_mb: 5000
|
||||||
|
max_duration: 30
|
||||||
|
conditioning:
|
||||||
|
- melody
|
||||||
|
description: "Melody-conditioned generation (1.5B params)"
|
||||||
|
stereo-small:
|
||||||
|
hf_id: facebook/musicgen-stereo-small
|
||||||
|
vram_mb: 1800
|
||||||
|
max_duration: 30
|
||||||
|
channels: 2
|
||||||
|
description: "Stereo output, fast (300M params)"
|
||||||
|
stereo-medium:
|
||||||
|
hf_id: facebook/musicgen-stereo-medium
|
||||||
|
vram_mb: 6000
|
||||||
|
max_duration: 30
|
||||||
|
channels: 2
|
||||||
|
description: "Stereo output, balanced (1.5B params)"
|
||||||
|
stereo-large:
|
||||||
|
hf_id: facebook/musicgen-stereo-large
|
||||||
|
vram_mb: 12000
|
||||||
|
max_duration: 30
|
||||||
|
channels: 2
|
||||||
|
description: "Stereo output, highest quality (3.3B params)"
|
||||||
|
stereo-melody:
|
||||||
|
hf_id: facebook/musicgen-stereo-melody
|
||||||
|
vram_mb: 6000
|
||||||
|
max_duration: 30
|
||||||
|
channels: 2
|
||||||
|
conditioning:
|
||||||
|
- melody
|
||||||
|
description: "Stereo melody-conditioned (1.5B params)"
|
||||||
|
|
||||||
|
audiogen:
|
||||||
|
enabled: true
|
||||||
|
display_name: "AudioGen"
|
||||||
|
description: "Text-to-sound effects generation"
|
||||||
|
default_variant: medium
|
||||||
|
variants:
|
||||||
|
medium:
|
||||||
|
hf_id: facebook/audiogen-medium
|
||||||
|
vram_mb: 5000
|
||||||
|
max_duration: 10
|
||||||
|
description: "Sound effects generator (1.5B params)"
|
||||||
|
|
||||||
|
magnet:
|
||||||
|
enabled: true
|
||||||
|
display_name: "MAGNeT"
|
||||||
|
description: "Fast non-autoregressive music generation"
|
||||||
|
default_variant: medium-10secs
|
||||||
|
variants:
|
||||||
|
small-10secs:
|
||||||
|
hf_id: facebook/magnet-small-10secs
|
||||||
|
vram_mb: 1500
|
||||||
|
max_duration: 10
|
||||||
|
description: "Fast 10-second clips (300M params)"
|
||||||
|
medium-10secs:
|
||||||
|
hf_id: facebook/magnet-medium-10secs
|
||||||
|
vram_mb: 5000
|
||||||
|
max_duration: 10
|
||||||
|
description: "Quality 10-second clips (1.5B params)"
|
||||||
|
small-30secs:
|
||||||
|
hf_id: facebook/magnet-small-30secs
|
||||||
|
vram_mb: 1800
|
||||||
|
max_duration: 30
|
||||||
|
description: "Fast 30-second clips (300M params)"
|
||||||
|
medium-30secs:
|
||||||
|
hf_id: facebook/magnet-medium-30secs
|
||||||
|
vram_mb: 6000
|
||||||
|
max_duration: 30
|
||||||
|
description: "Quality 30-second clips (1.5B params)"
|
||||||
|
|
||||||
|
musicgen-style:
|
||||||
|
enabled: true
|
||||||
|
display_name: "MusicGen Style"
|
||||||
|
description: "Style-conditioned music generation from reference audio"
|
||||||
|
default_variant: medium
|
||||||
|
variants:
|
||||||
|
medium:
|
||||||
|
hf_id: facebook/musicgen-style
|
||||||
|
vram_mb: 5000
|
||||||
|
max_duration: 30
|
||||||
|
conditioning:
|
||||||
|
- style
|
||||||
|
description: "Style transfer from reference audio (1.5B params)"
|
||||||
|
|
||||||
|
jasco:
|
||||||
|
enabled: true
|
||||||
|
display_name: "JASCO"
|
||||||
|
description: "Chord and drum-conditioned music generation"
|
||||||
|
default_variant: chords-drums-400M
|
||||||
|
variants:
|
||||||
|
chords-drums-400M:
|
||||||
|
hf_id: facebook/jasco-chords-drums-400M
|
||||||
|
vram_mb: 2000
|
||||||
|
max_duration: 10
|
||||||
|
conditioning:
|
||||||
|
- chords
|
||||||
|
- drums
|
||||||
|
description: "Chord/drum control, fast (400M params)"
|
||||||
|
chords-drums-1B:
|
||||||
|
hf_id: facebook/jasco-chords-drums-1B
|
||||||
|
vram_mb: 4000
|
||||||
|
max_duration: 10
|
||||||
|
conditioning:
|
||||||
|
- chords
|
||||||
|
- drums
|
||||||
|
description: "Chord/drum control, higher quality (1B params)"
|
||||||
|
|
||||||
|
# Default generation parameters
|
||||||
|
defaults:
|
||||||
|
generation:
|
||||||
|
duration: 10
|
||||||
|
temperature: 1.0
|
||||||
|
top_k: 250
|
||||||
|
top_p: 0.0
|
||||||
|
cfg_coef: 3.0
|
||||||
|
|
||||||
|
# VRAM thresholds for warnings
|
||||||
|
vram:
|
||||||
|
warning_threshold: 0.85 # 85% utilization warning
|
||||||
|
critical_threshold: 0.95 # 95% utilization critical
|
||||||
|
|
||||||
|
# Presets are loaded from data/presets/*.yaml
|
||||||
|
presets_dir: "./data/presets"
|
||||||
94
config/settings.py
Normal file
94
config/settings.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
"""Application settings with environment variable support."""
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
"""Application configuration with environment variable support.
|
||||||
|
|
||||||
|
All settings can be overridden via environment variables prefixed with AUDIOCRAFT_.
|
||||||
|
Example: AUDIOCRAFT_API_PORT=8080
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(
|
||||||
|
env_prefix="AUDIOCRAFT_",
|
||||||
|
env_file=".env",
|
||||||
|
env_file_encoding="utf-8",
|
||||||
|
extra="ignore",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Server Configuration
|
||||||
|
host: str = Field(default="0.0.0.0", description="Server bind host")
|
||||||
|
gradio_port: int = Field(default=7860, description="Gradio UI port")
|
||||||
|
api_port: int = Field(default=8000, description="FastAPI port")
|
||||||
|
|
||||||
|
# Paths
|
||||||
|
data_dir: Path = Field(default=Path("./data"), description="Data directory")
|
||||||
|
output_dir: Path = Field(default=Path("./outputs"), description="Generated audio output")
|
||||||
|
cache_dir: Path = Field(default=Path("./cache"), description="Model cache directory")
|
||||||
|
models_config: Path = Field(
|
||||||
|
default=Path("./config/models.yaml"), description="Model registry config"
|
||||||
|
)
|
||||||
|
|
||||||
|
# VRAM Management
|
||||||
|
comfyui_reserve_gb: float = Field(
|
||||||
|
default=10.0, description="VRAM reserved for ComfyUI (GB)"
|
||||||
|
)
|
||||||
|
safety_buffer_gb: float = Field(
|
||||||
|
default=1.0, description="Safety buffer to prevent OOM (GB)"
|
||||||
|
)
|
||||||
|
idle_unload_minutes: int = Field(
|
||||||
|
default=15, description="Unload models after idle time (minutes)"
|
||||||
|
)
|
||||||
|
max_cached_models: int = Field(
|
||||||
|
default=2, description="Maximum number of models to keep loaded"
|
||||||
|
)
|
||||||
|
|
||||||
|
# API Authentication
|
||||||
|
api_key: Optional[str] = Field(default=None, description="API key for authentication")
|
||||||
|
cors_origins: list[str] = Field(
|
||||||
|
default=["*"], description="Allowed CORS origins"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generation Defaults
|
||||||
|
default_duration: float = Field(default=10.0, description="Default generation duration")
|
||||||
|
max_duration: float = Field(default=300.0, description="Maximum generation duration")
|
||||||
|
default_batch_size: int = Field(default=1, description="Default batch size")
|
||||||
|
max_batch_size: int = Field(default=8, description="Maximum batch size")
|
||||||
|
max_queue_size: int = Field(default=100, description="Maximum generation queue size")
|
||||||
|
|
||||||
|
# Database
|
||||||
|
database_url: str = Field(
|
||||||
|
default="sqlite+aiosqlite:///./data/audiocraft.db",
|
||||||
|
description="Database connection URL",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
log_level: str = Field(default="INFO", description="Logging level")
|
||||||
|
|
||||||
|
def ensure_directories(self) -> None:
|
||||||
|
"""Create required directories if they don't exist."""
|
||||||
|
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
(self.data_dir / "presets").mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def database_path(self) -> Path:
|
||||||
|
"""Extract database file path from URL."""
|
||||||
|
if self.database_url.startswith("sqlite"):
|
||||||
|
# Handle both sqlite:/// and sqlite+aiosqlite:///
|
||||||
|
path = self.database_url.split("///")[-1]
|
||||||
|
return Path(path)
|
||||||
|
raise ValueError("Only SQLite databases are supported")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
"""Get cached settings instance."""
|
||||||
|
return Settings()
|
||||||
64
docker-compose.yml
Normal file
64
docker-compose.yml
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
# Docker Compose for local development and testing
|
||||||
|
# For RunPod deployment, use the Dockerfile directly
|
||||||
|
|
||||||
|
version: '3.8'
|
||||||
|
|
||||||
|
services:
|
||||||
|
audiocraft:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
container_name: audiocraft-studio
|
||||||
|
ports:
|
||||||
|
- "7860:7860" # Gradio UI
|
||||||
|
- "8000:8000" # REST API
|
||||||
|
volumes:
|
||||||
|
# Persistent storage
|
||||||
|
- audiocraft-outputs:/workspace/outputs
|
||||||
|
- audiocraft-data:/workspace/data
|
||||||
|
- audiocraft-models:/workspace/models
|
||||||
|
# Development: mount source code
|
||||||
|
- ./src:/workspace/app/src:ro
|
||||||
|
- ./config:/workspace/app/config:ro
|
||||||
|
environment:
|
||||||
|
- AUDIOCRAFT_HOST=0.0.0.0
|
||||||
|
- AUDIOCRAFT_GRADIO_PORT=7860
|
||||||
|
- AUDIOCRAFT_API_PORT=8000
|
||||||
|
- AUDIOCRAFT_DEBUG=false
|
||||||
|
- AUDIOCRAFT_COMFYUI_RESERVE_GB=0 # No ComfyUI in this compose
|
||||||
|
- NVIDIA_VISIBLE_DEVICES=all
|
||||||
|
deploy:
|
||||||
|
resources:
|
||||||
|
reservations:
|
||||||
|
devices:
|
||||||
|
- driver: nvidia
|
||||||
|
count: 1
|
||||||
|
capabilities: [gpu]
|
||||||
|
restart: unless-stopped
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-f", "http://localhost:7860/"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 10s
|
||||||
|
retries: 3
|
||||||
|
start_period: 60s
|
||||||
|
|
||||||
|
# Optional: Run alongside ComfyUI
|
||||||
|
# comfyui:
|
||||||
|
# image: your-comfyui-image
|
||||||
|
# container_name: comfyui
|
||||||
|
# ports:
|
||||||
|
# - "8188:8188"
|
||||||
|
# volumes:
|
||||||
|
# - comfyui-data:/workspace
|
||||||
|
# deploy:
|
||||||
|
# resources:
|
||||||
|
# reservations:
|
||||||
|
# devices:
|
||||||
|
# - driver: nvidia
|
||||||
|
# count: 1
|
||||||
|
# capabilities: [gpu]
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
audiocraft-outputs:
|
||||||
|
audiocraft-data:
|
||||||
|
audiocraft-models:
|
||||||
147
main.py
Normal file
147
main.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Main entry point for AudioCraft Studio."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add project root to path
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
|
||||||
|
from config.settings import get_settings
|
||||||
|
from src.core.gpu_manager import GPUMemoryManager
|
||||||
|
from src.core.model_registry import ModelRegistry
|
||||||
|
from src.services.generation_service import GenerationService
|
||||||
|
from src.services.batch_processor import BatchProcessor
|
||||||
|
from src.services.project_service import ProjectService
|
||||||
|
from src.storage.database import Database
|
||||||
|
from src.ui.app import create_app
|
||||||
|
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
|
handlers=[
|
||||||
|
logging.StreamHandler(),
|
||||||
|
logging.FileHandler("audiocraft.log"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def initialize_services():
|
||||||
|
"""Initialize all application services."""
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
# Initialize database
|
||||||
|
logger.info("Initializing database...")
|
||||||
|
db = Database(settings.database_path)
|
||||||
|
await db.initialize()
|
||||||
|
|
||||||
|
# Initialize GPU manager
|
||||||
|
logger.info("Initializing GPU manager...")
|
||||||
|
gpu_manager = GPUMemoryManager(
|
||||||
|
device_id=0,
|
||||||
|
comfyui_reserve_bytes=int(settings.comfyui_reserve_gb * 1024**3),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize model registry
|
||||||
|
logger.info("Initializing model registry...")
|
||||||
|
model_registry = ModelRegistry(
|
||||||
|
gpu_manager=gpu_manager,
|
||||||
|
max_loaded=settings.max_loaded_models,
|
||||||
|
idle_timeout_seconds=settings.idle_unload_minutes * 60,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize services
|
||||||
|
logger.info("Initializing services...")
|
||||||
|
generation_service = GenerationService(
|
||||||
|
model_registry=model_registry,
|
||||||
|
gpu_manager=gpu_manager,
|
||||||
|
output_dir=settings.output_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_processor = BatchProcessor(
|
||||||
|
generation_service=generation_service,
|
||||||
|
max_queue_size=settings.max_queue_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
project_service = ProjectService(
|
||||||
|
db=db,
|
||||||
|
output_dir=settings.output_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"db": db,
|
||||||
|
"gpu_manager": gpu_manager,
|
||||||
|
"model_registry": model_registry,
|
||||||
|
"generation_service": generation_service,
|
||||||
|
"batch_processor": batch_processor,
|
||||||
|
"project_service": project_service,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main entry point."""
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info("AudioCraft Studio")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info(f"Host: {settings.host}")
|
||||||
|
logger.info(f"Gradio Port: {settings.gradio_port}")
|
||||||
|
logger.info(f"API Port: {settings.api_port}")
|
||||||
|
logger.info(f"Output Dir: {settings.output_dir}")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
# Initialize services
|
||||||
|
logger.info("Initializing services...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
services = asyncio.run(initialize_services())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize services: {e}")
|
||||||
|
logger.warning("Starting in demo mode without backend services")
|
||||||
|
services = {}
|
||||||
|
|
||||||
|
# Create and launch app
|
||||||
|
logger.info("Creating Gradio application...")
|
||||||
|
app = create_app(
|
||||||
|
generation_service=services.get("generation_service"),
|
||||||
|
batch_processor=services.get("batch_processor"),
|
||||||
|
project_service=services.get("project_service"),
|
||||||
|
gpu_manager=services.get("gpu_manager"),
|
||||||
|
model_registry=services.get("model_registry"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start batch processor if available
|
||||||
|
batch_processor = services.get("batch_processor")
|
||||||
|
if batch_processor:
|
||||||
|
logger.info("Starting batch processor...")
|
||||||
|
asyncio.run(batch_processor.start())
|
||||||
|
|
||||||
|
# Launch the app
|
||||||
|
logger.info("Launching application...")
|
||||||
|
try:
|
||||||
|
app.launch(
|
||||||
|
server_name=settings.host,
|
||||||
|
server_port=settings.gradio_port,
|
||||||
|
share=False,
|
||||||
|
show_api=settings.api_enabled,
|
||||||
|
)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Shutting down...")
|
||||||
|
finally:
|
||||||
|
# Cleanup
|
||||||
|
if batch_processor:
|
||||||
|
asyncio.run(batch_processor.stop())
|
||||||
|
if "db" in services:
|
||||||
|
asyncio.run(services["db"].close())
|
||||||
|
|
||||||
|
logger.info("Shutdown complete")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
89
pyproject.toml
Normal file
89
pyproject.toml
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
[project]
|
||||||
|
name = "audiocraft-ui"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Sophisticated AI audio web application based on Facebook's AudioCraft"
|
||||||
|
readme = "README.md"
|
||||||
|
license = { text = "MIT" }
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
authors = [{ name = "AudioCraft UI Team" }]
|
||||||
|
keywords = ["audio", "music", "generation", "ai", "audiocraft", "gradio"]
|
||||||
|
classifiers = [
|
||||||
|
"Development Status :: 3 - Alpha",
|
||||||
|
"Intended Audience :: Developers",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.10",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
|
"Topic :: Multimedia :: Sound/Audio",
|
||||||
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
|
]
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
# Core ML
|
||||||
|
"torch>=2.1.0",
|
||||||
|
"torchaudio>=2.1.0",
|
||||||
|
"audiocraft>=1.3.0",
|
||||||
|
"xformers>=0.0.22",
|
||||||
|
|
||||||
|
# UI
|
||||||
|
"gradio>=4.0.0",
|
||||||
|
|
||||||
|
# API
|
||||||
|
"fastapi>=0.104.0",
|
||||||
|
"uvicorn[standard]>=0.24.0",
|
||||||
|
"python-multipart>=0.0.6",
|
||||||
|
|
||||||
|
# GPU Monitoring
|
||||||
|
"pynvml>=11.5.0",
|
||||||
|
|
||||||
|
# Storage
|
||||||
|
"aiosqlite>=0.19.0",
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
"pydantic>=2.5.0",
|
||||||
|
"pydantic-settings>=2.1.0",
|
||||||
|
"pyyaml>=6.0",
|
||||||
|
|
||||||
|
# Audio Processing
|
||||||
|
"numpy>=1.24.0",
|
||||||
|
"scipy>=1.11.0",
|
||||||
|
"librosa>=0.10.0",
|
||||||
|
"soundfile>=0.12.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=7.4.0",
|
||||||
|
"pytest-asyncio>=0.21.0",
|
||||||
|
"pytest-cov>=4.1.0",
|
||||||
|
"ruff>=0.1.0",
|
||||||
|
"mypy>=1.6.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
audiocraft-ui = "src.main:main"
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["src"]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 100
|
||||||
|
target-version = "py310"
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = ["E", "F", "I", "N", "W", "UP"]
|
||||||
|
ignore = ["E501"]
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
python_version = "3.10"
|
||||||
|
warn_return_any = true
|
||||||
|
warn_unused_configs = true
|
||||||
|
ignore_missing_imports = true
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
testpaths = ["tests"]
|
||||||
30
requirements.txt
Normal file
30
requirements.txt
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# Core ML
|
||||||
|
torch>=2.1.0
|
||||||
|
torchaudio>=2.1.0
|
||||||
|
audiocraft>=1.3.0
|
||||||
|
xformers>=0.0.22
|
||||||
|
|
||||||
|
# UI
|
||||||
|
gradio>=4.0.0
|
||||||
|
|
||||||
|
# API
|
||||||
|
fastapi>=0.104.0
|
||||||
|
uvicorn[standard]>=0.24.0
|
||||||
|
python-multipart>=0.0.6
|
||||||
|
|
||||||
|
# GPU Monitoring
|
||||||
|
pynvml>=11.5.0
|
||||||
|
|
||||||
|
# Storage
|
||||||
|
aiosqlite>=0.19.0
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
pydantic>=2.5.0
|
||||||
|
pydantic-settings>=2.1.0
|
||||||
|
pyyaml>=6.0
|
||||||
|
|
||||||
|
# Audio Processing
|
||||||
|
numpy>=1.24.0
|
||||||
|
scipy>=1.11.0
|
||||||
|
librosa>=0.10.0
|
||||||
|
soundfile>=0.12.0
|
||||||
77
runpod.yaml
Normal file
77
runpod.yaml
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
# RunPod Template Configuration
|
||||||
|
# Use this as reference when creating a RunPod template
|
||||||
|
|
||||||
|
name: AudioCraft Studio
|
||||||
|
description: AI-powered music and sound generation using Meta's AudioCraft
|
||||||
|
|
||||||
|
# Container settings
|
||||||
|
container:
|
||||||
|
image: your-dockerhub-username/audiocraft-studio:latest
|
||||||
|
|
||||||
|
# Or build from GitHub
|
||||||
|
# dockerfile: Dockerfile
|
||||||
|
# context: https://github.com/your-username/audiocraft-ui.git
|
||||||
|
|
||||||
|
# GPU requirements
|
||||||
|
gpu:
|
||||||
|
type: RTX 4090 # Recommended: RTX 4090, RTX 3090, A100
|
||||||
|
count: 1
|
||||||
|
minVram: 24 # GB
|
||||||
|
|
||||||
|
# Resource limits
|
||||||
|
resources:
|
||||||
|
cpu: 8
|
||||||
|
memory: 32 # GB
|
||||||
|
disk: 100 # GB (for model cache and outputs)
|
||||||
|
|
||||||
|
# Port mappings
|
||||||
|
ports:
|
||||||
|
- name: Gradio UI
|
||||||
|
internal: 7860
|
||||||
|
external: 7860
|
||||||
|
protocol: http
|
||||||
|
- name: REST API
|
||||||
|
internal: 8000
|
||||||
|
external: 8000
|
||||||
|
protocol: http
|
||||||
|
|
||||||
|
# Volume mounts
|
||||||
|
volumes:
|
||||||
|
- name: outputs
|
||||||
|
mountPath: /workspace/outputs
|
||||||
|
size: 50 # GB
|
||||||
|
- name: models
|
||||||
|
mountPath: /workspace/models
|
||||||
|
size: 30 # GB (model cache)
|
||||||
|
- name: data
|
||||||
|
mountPath: /workspace/data
|
||||||
|
size: 10 # GB
|
||||||
|
|
||||||
|
# Environment variables
|
||||||
|
env:
|
||||||
|
- name: AUDIOCRAFT_HOST
|
||||||
|
value: "0.0.0.0"
|
||||||
|
- name: AUDIOCRAFT_GRADIO_PORT
|
||||||
|
value: "7860"
|
||||||
|
- name: AUDIOCRAFT_API_PORT
|
||||||
|
value: "8000"
|
||||||
|
- name: AUDIOCRAFT_COMFYUI_RESERVE_GB
|
||||||
|
value: "10" # Reserve VRAM for ComfyUI if running alongside
|
||||||
|
  - name: AUDIOCRAFT_MAX_CACHED_MODELS  # was AUDIOCRAFT_MAX_LOADED_MODELS; .env.example defines AUDIOCRAFT_MAX_CACHED_MODELS
|
||||||
|
value: "2"
|
||||||
|
- name: AUDIOCRAFT_IDLE_UNLOAD_MINUTES
|
||||||
|
value: "15"
|
||||||
|
- name: HF_HOME
|
||||||
|
value: "/workspace/models/huggingface"
|
||||||
|
|
||||||
|
# Startup command
|
||||||
|
command: ["python", "main.py"]
|
||||||
|
|
||||||
|
# Health check
|
||||||
|
healthCheck:
|
||||||
|
path: /
|
||||||
|
port: 7860
|
||||||
|
initialDelaySeconds: 120
|
||||||
|
periodSeconds: 30
|
||||||
|
timeoutSeconds: 10
|
||||||
|
failureThreshold: 3
|
||||||
116
scripts/download_models.py
Executable file
116
scripts/download_models.py
Executable file
@@ -0,0 +1,116 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Pre-download AudioCraft models for faster startup."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def download_musicgen_models(variants: list[str] | None = None):
    """Download MusicGen checkpoints into the local model cache.

    Args:
        variants: MusicGen variants to fetch. Defaults to all four
            ("small", "medium", "large", "melody") when None.
            (Fix: the default of None contradicted the previous
            ``list[str]`` annotation; widened to ``list[str] | None``.)
    """
    from audiocraft.models import MusicGen

    variants = variants or ["small", "medium", "large", "melody"]

    for variant in variants:
        print(f"Downloading MusicGen {variant}...")
        try:
            # Instantiating the model triggers the checkpoint download;
            # the reference is dropped immediately to release memory.
            model = MusicGen.get_pretrained(f"facebook/musicgen-{variant}")
            del model
            print(f" ✓ MusicGen {variant} downloaded")
        except Exception as e:
            # Best-effort: report and continue with the next variant.
            print(f" ✗ Failed to download MusicGen {variant}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def download_audiogen_models():
    """Fetch the AudioGen "medium" checkpoint into the local cache."""
    from audiocraft.models import AudioGen

    print("Downloading AudioGen medium...")
    try:
        # Loading the model forces the weights download; the instance
        # is discarded right after.
        model = AudioGen.get_pretrained("facebook/audiogen-medium")
        del model
        print(" ✓ AudioGen medium downloaded")
    except Exception as e:
        # Best-effort: a failed download is reported, not fatal.
        print(f" ✗ Failed to download AudioGen: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def download_magnet_models(variants: list[str] | None = None):
    """Download MAGNeT checkpoints into the local model cache.

    Args:
        variants: MAGNeT variants to fetch. Defaults to the two music and
            two sound-effect variants when None.
            (Fix: the default of None contradicted the previous
            ``list[str]`` annotation; widened to ``list[str] | None``.)
    """
    from audiocraft.models import MAGNeT

    variants = variants or ["small", "medium", "audio-small-10secs", "audio-medium-10secs"]

    for variant in variants:
        print(f"Downloading MAGNeT {variant}...")
        try:
            # Instantiating the model triggers the checkpoint download;
            # the reference is dropped immediately to release memory.
            model = MAGNeT.get_pretrained(f"facebook/magnet-{variant}")
            del model
            print(f" ✓ MAGNeT {variant} downloaded")
        except Exception as e:
            # Best-effort: report and continue with the next variant.
            print(f" ✗ Failed to download MAGNeT {variant}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """CLI entry point: parse options, set up the cache, run downloads."""
    parser = argparse.ArgumentParser(description="Pre-download AudioCraft models")
    parser.add_argument(
        "--models",
        nargs="+",
        choices=["musicgen", "audiogen", "magnet", "all"],
        default=["all"],
        help="Models to download",
    )
    parser.add_argument(
        "--musicgen-variants",
        nargs="+",
        default=["small", "medium"],
        help="MusicGen variants to download",
    )
    parser.add_argument(
        "--magnet-variants",
        nargs="+",
        default=["small", "medium"],
        help="MAGNeT variants to download",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Model cache directory",
    )
    args = parser.parse_args()

    if args.cache_dir:
        # Point both the HuggingFace and torch-hub caches at the same dir.
        os.environ["HF_HOME"] = args.cache_dir
        os.environ["TORCH_HOME"] = args.cache_dir
        Path(args.cache_dir).mkdir(parents=True, exist_ok=True)

    selected = args.models
    if "all" in selected:
        selected = ["musicgen", "audiogen", "magnet"]

    banner = "=" * 50
    print(banner)
    print("AudioCraft Model Downloader")
    print(banner)
    print(f"Cache directory: {os.environ.get('HF_HOME', 'default')}")
    print(f"Models to download: {selected}")
    print(banner)

    if "musicgen" in selected:
        download_musicgen_models(args.musicgen_variants)

    if "audiogen" in selected:
        download_audiogen_models()

    if "magnet" in selected:
        download_magnet_models(args.magnet_variants)

    print(banner)
    print("Download complete!")
    print(banner)


if __name__ == "__main__":
    main()
|
||||||
55
scripts/start.sh
Executable file
55
scripts/start.sh
Executable file
@@ -0,0 +1,55 @@
|
|||||||
|
#!/bin/bash
# Startup script for AudioCraft Studio.
# Used in the Docker container and on RunPod; prints environment sanity
# checks, provisions an API key on first boot, then execs the app.

set -e

echo "=========================================="
echo " AudioCraft Studio"
echo "=========================================="

# Create runtime directories if they don't exist.
mkdir -p "${AUDIOCRAFT_OUTPUT_DIR:-/workspace/outputs}"
mkdir -p "${AUDIOCRAFT_DATA_DIR:-/workspace/data}"
# Fix: .env.example documents AUDIOCRAFT_CACHE_DIR, while this script read
# only AUDIOCRAFT_MODEL_CACHE. Accept both (legacy name wins) so either
# configuration style works.
mkdir -p "${AUDIOCRAFT_MODEL_CACHE:-${AUDIOCRAFT_CACHE_DIR:-/workspace/models}}"

# Check GPU availability (informational only).
echo "Checking GPU..."
if command -v nvidia-smi &> /dev/null; then
    nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv
else
    echo "Warning: nvidia-smi not found"
fi

# Check Python and dependencies.
echo "Python version:"
python --version

echo "PyTorch version:"
python -c "import torch; print(f'PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}')"

# Check AudioCraft installation (source installs may lack __version__).
echo "AudioCraft version:"
python -c "import audiocraft; print(audiocraft.__version__)" 2>/dev/null || echo "AudioCraft installed from source"

# Generate an API key on first boot. NOTE(review): the plaintext key is
# printed to stdout and will end up in container logs — confirm acceptable.
if [ ! -f "${AUDIOCRAFT_DATA_DIR:-/workspace/data}/.api_key" ]; then
    echo "Generating API key..."
    python -c "
from src.api.auth import get_key_manager
km = get_key_manager()
if not km.has_key():
    key = km.generate_new_key()
    print(f'Generated API key: {key}')
    print('Store this key securely - it will not be shown again!')
"
fi

# Start the application. exec replaces the shell so the app becomes PID 1
# and receives SIGTERM directly from the container runtime.
echo "Starting AudioCraft Studio..."
echo "Gradio UI: http://0.0.0.0:${AUDIOCRAFT_GRADIO_PORT:-7860}"
echo "REST API: http://0.0.0.0:${AUDIOCRAFT_API_PORT:-8000}"
echo "API Docs: http://0.0.0.0:${AUDIOCRAFT_API_PORT:-8000}/api/docs"
echo "=========================================="

exec python main.py "$@"
|
||||||
3
src/__init__.py
Normal file
3
src/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""AudioCraft Studio - AI Audio Generation Web Application."""
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
5
src/api/__init__.py
Normal file
5
src/api/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""REST API for AudioCraft Studio."""
|
||||||
|
|
||||||
|
from src.api.app import create_api_app
|
||||||
|
|
||||||
|
__all__ = ["create_api_app"]
|
||||||
150
src/api/app.py
Normal file
150
src/api/app.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
"""FastAPI application for AudioCraft Studio REST API."""
|
||||||
|
|
||||||
|
from typing import Any, Optional
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
import time
|
||||||
|
|
||||||
|
from config.settings import get_settings
|
||||||
|
from src.api.routes import (
|
||||||
|
generation_router,
|
||||||
|
projects_router,
|
||||||
|
models_router,
|
||||||
|
system_router,
|
||||||
|
)
|
||||||
|
from src.api.routes.generation import set_services as set_generation_services
|
||||||
|
from src.api.routes.projects import set_services as set_project_services
|
||||||
|
from src.api.routes.models import set_services as set_model_services
|
||||||
|
from src.api.routes.system import set_services as set_system_services
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler.

    Currently a placeholder: nothing is initialized before serving and
    nothing is torn down afterwards.
    """
    # Startup (no-op for now)
    yield
    # Shutdown (no-op for now)
|
||||||
|
|
||||||
|
|
||||||
|
def create_api_app(
    generation_service: Any = None,
    batch_processor: Any = None,
    project_service: Any = None,
    gpu_manager: Any = None,
    model_registry: Any = None,
) -> FastAPI:
    """Build and configure the FastAPI application.

    Args:
        generation_service: Service for handling generations
        batch_processor: Service for batch/queue processing
        project_service: Service for project management
        gpu_manager: GPU memory manager
        model_registry: Model registry for loading/unloading

    Returns:
        Configured FastAPI application
    """
    settings = get_settings()

    # Docs/OpenAPI endpoints are only exposed while the API is enabled.
    docs_on = settings.api_enabled
    app = FastAPI(
        title="AudioCraft Studio API",
        description="REST API for AI-powered music and sound generation",
        version="1.0.0",
        docs_url="/api/docs" if docs_on else None,
        redoc_url="/api/redoc" if docs_on else None,
        openapi_url="/api/openapi.json" if docs_on else None,
        lifespan=lifespan,
    )

    # Allow the configured browser origins to call the API.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.cors_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.middleware("http")
    async def add_process_time_header(request: Request, call_next):
        # Expose server-side processing time to callers for profiling.
        started = time.time()
        response = await call_next(request)
        response.headers["X-Process-Time"] = str(time.time() - started)
        return response

    @app.exception_handler(Exception)
    async def global_exception_handler(request: Request, exc: Exception):
        # Hide exception details unless debug mode is enabled.
        return JSONResponse(
            status_code=500,
            content={
                "error": "Internal server error",
                "detail": str(exc) if settings.debug else "An unexpected error occurred",
            },
        )

    # Hand the shared service instances to each route module.
    set_generation_services(generation_service, batch_processor)
    set_project_services(project_service)
    set_model_services(model_registry)
    set_system_services(gpu_manager, batch_processor, model_registry)

    # Mount every versioned router under the common /api/v1 prefix.
    for api_router in (generation_router, projects_router, models_router, system_router):
        app.include_router(api_router, prefix="/api/v1")

    @app.get("/")
    async def root():
        # Service identity / landing payload.
        return {
            "name": "AudioCraft Studio API",
            "version": "1.0.0",
            "docs": "/api/docs",
        }

    @app.get("/api/v1")
    async def api_info():
        # Directory of available endpoint groups.
        return {
            "version": "1.0.0",
            "endpoints": {
                "generation": "/api/v1/generate",
                "projects": "/api/v1/projects",
                "models": "/api/v1/models",
                "system": "/api/v1/system",
            },
        }

    return app
|
||||||
|
|
||||||
|
|
||||||
|
def run_api_server(
    app: FastAPI,
    host: Optional[str] = None,
    port: Optional[int] = None,
) -> None:
    """Serve the given FastAPI app with uvicorn.

    Args:
        app: FastAPI application
        host: Server hostname; falls back to the configured host
        port: Server port; falls back to the configured API port
    """
    # Imported lazily so merely importing this module stays cheap.
    import uvicorn

    settings = get_settings()

    uvicorn.run(
        app,
        host=host or settings.host,
        port=port or settings.api_port,
        log_level="info",
    )
|
||||||
133
src/api/auth.py
Normal file
133
src/api/auth.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
"""API authentication middleware."""
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
import hashlib
|
||||||
|
from typing import Optional
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Security, status
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
|
|
||||||
|
from config.settings import get_settings
|
||||||
|
|
||||||
|
|
||||||
|
# API key header
|
||||||
|
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_api_key() -> str:
    """Create a fresh, URL-safe API key (32 random bytes)."""
    token = secrets.token_urlsafe(32)
    return token
|
||||||
|
|
||||||
|
|
||||||
|
def hash_api_key(key: str) -> str:
    """Return the hex SHA-256 digest of an API key, for at-rest storage."""
    digest = hashlib.sha256(key.encode())
    return digest.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def verify_api_key(key: str, hashed: str) -> bool:
    """Check a plaintext key against a stored hash in constant time."""
    candidate = hash_api_key(key)
    return secrets.compare_digest(candidate, hashed)
|
||||||
|
|
||||||
|
|
||||||
|
class APIKeyManager:
    """Manage API keys for authentication.

    A single key is supported; only its SHA-256 hash is persisted to disk
    and cached in memory for verification.
    """

    def __init__(self, key_file: Optional[Path] = None):
        """Initialize the key manager.

        Args:
            key_file: Path to store API key hash
        """
        self.settings = get_settings()
        # Default location lives under the configured data directory.
        self.key_file = key_file or Path(self.settings.data_dir) / ".api_key"
        self._key_hash: Optional[str] = None
        self._load_key()

    def _load_key(self) -> None:
        """Populate the in-memory hash from disk, if a key file exists."""
        if self.key_file.exists():
            self._key_hash = self.key_file.read_text().strip()

    def _save_key(self, key_hash: str) -> None:
        """Persist the hash to disk and refresh the in-memory copy."""
        self.key_file.parent.mkdir(parents=True, exist_ok=True)
        self.key_file.write_text(key_hash)
        self._key_hash = key_hash

    def generate_new_key(self) -> str:
        """Generate and store a new API key.

        Returns:
            The new API key (only shown once)
        """
        fresh = generate_api_key()
        self._save_key(hash_api_key(fresh))
        return fresh

    def verify(self, key: str) -> bool:
        """Verify an API key.

        Args:
            key: API key to verify

        Returns:
            True if valid, False otherwise
        """
        # No stored hash means nothing can ever match.
        if not self._key_hash:
            return False
        return verify_api_key(key, self._key_hash)

    def has_key(self) -> bool:
        """Check if an API key has been generated."""
        return self._key_hash is not None
|
||||||
|
|
||||||
|
|
||||||
|
# Process-wide singleton; created lazily on first access.
_key_manager: Optional[APIKeyManager] = None


def get_key_manager() -> APIKeyManager:
    """Return the global APIKeyManager, constructing it on first use."""
    global _key_manager
    if _key_manager is None:
        _key_manager = APIKeyManager()
    return _key_manager
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_api_key_dependency(
    api_key: Optional[str] = Security(api_key_header),
) -> str:
    """FastAPI dependency that enforces X-API-Key authentication.

    Args:
        api_key: API key from header

    Returns:
        The verified API key

    Raises:
        HTTPException: If key is missing or invalid
    """
    settings = get_settings()

    # Authentication can be disabled entirely via configuration.
    if not settings.api_key_required:
        return "anonymous"

    # A missing header is a 401 (unauthenticated).
    if api_key is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="API key required",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    # A present-but-wrong key is a 403 (forbidden).
    if not get_key_manager().verify(api_key):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API key",
        )

    return api_key
|
||||||
166
src/api/models.py
Normal file
166
src/api/models.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
"""Pydantic models for API requests and responses."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Optional
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class ModelFamily(str, Enum):
    """Available model families."""

    # The five families exposed by the UI and API; values double as the
    # wire format in requests and as model ids.
    MUSICGEN = "musicgen"
    AUDIOGEN = "audiogen"
    MAGNET = "magnet"
    MUSICGEN_STYLE = "musicgen-style"
    JASCO = "jasco"
|
||||||
|
|
||||||
|
|
||||||
|
class JobStatus(str, Enum):
    """Generation job status."""

    # String-valued so the status serializes directly in JSON responses.
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
||||||
|
|
||||||
|
|
||||||
|
# Generation requests


class GenerationRequest(BaseModel):
    """Request to generate audio."""

    model: ModelFamily = Field(..., description="Model family to use")
    variant: str = Field("medium", description="Model variant")
    prompts: list[str] = Field(..., min_length=1, max_length=10, description="Text prompts")
    # NOTE(review): le=30 caps API durations at 30s while the configured
    # AUDIOCRAFT_MAX_DURATION default is 300 — confirm this is intended.
    duration: float = Field(10.0, ge=1, le=30, description="Duration in seconds")
    temperature: float = Field(1.0, ge=0, le=2, description="Sampling temperature")
    top_k: int = Field(250, ge=0, le=500, description="Top-K sampling")
    top_p: float = Field(0.0, ge=0, le=1, description="Top-P (nucleus) sampling")
    cfg_coef: float = Field(3.0, ge=1, le=10, description="CFG coefficient")
    seed: Optional[int] = Field(None, description="Random seed for reproducibility")
    conditioning: Optional[dict[str, Any]] = Field(None, description="Model-specific conditioning")
    project_id: Optional[str] = Field(None, description="Project to save to")


class BatchGenerationRequest(BaseModel):
    """Request to add generation to queue."""

    request: GenerationRequest
    priority: int = Field(0, ge=0, le=10, description="Job priority (higher = sooner)")


# Generation responses


class GenerationResult(BaseModel):
    """Result of a completed generation."""

    id: str = Field(..., description="Generation ID")
    audio_url: str = Field(..., description="URL to download audio")
    waveform_url: Optional[str] = Field(None, description="URL to waveform image")
    duration: float = Field(..., description="Actual duration in seconds")
    seed: int = Field(..., description="Seed used for generation")
    model: str = Field(..., description="Model used")
    variant: str = Field(..., description="Variant used")
    prompt: str = Field(..., description="Prompt used")
    created_at: datetime = Field(..., description="Creation timestamp")


class JobResponse(BaseModel):
    """Response for a queued job."""

    job_id: str = Field(..., description="Job ID for tracking")
    status: JobStatus = Field(..., description="Current status")
    position: Optional[int] = Field(None, description="Queue position if pending")
    progress: Optional[float] = Field(None, description="Progress 0-1 if running")
    result: Optional[GenerationResult] = Field(None, description="Result if completed")
    error: Optional[str] = Field(None, description="Error message if failed")


# Project models


class ProjectCreate(BaseModel):
    """Request to create a project."""

    name: str = Field(..., min_length=1, max_length=100)
    description: Optional[str] = Field(None, max_length=500)


class ProjectResponse(BaseModel):
    """Project information."""

    id: str
    name: str
    description: Optional[str]
    generation_count: int
    created_at: datetime
    updated_at: datetime


class GenerationResponse(BaseModel):
    """Generation record from database."""

    id: str
    project_id: str
    model: str
    variant: str
    prompt: str
    duration_seconds: float
    seed: int
    audio_path: str
    waveform_path: Optional[str]
    parameters: dict[str, Any]
    created_at: datetime


# Model info


class ModelVariantInfo(BaseModel):
    """Information about a model variant."""

    id: str
    name: str
    vram_mb: int
    description: str
    capabilities: list[str]


class ModelInfo(BaseModel):
    """Information about a model family."""

    id: str
    name: str
    description: str
    variants: list[ModelVariantInfo]
    loaded: bool
    current_variant: Optional[str]


# System info


class GPUStatus(BaseModel):
    """GPU memory status."""

    device_name: str
    total_gb: float
    used_gb: float
    available_gb: float
    utilization_percent: float
    temperature_c: Optional[float]


class QueueStatus(BaseModel):
    """Generation queue status."""

    queue_size: int
    active_jobs: int
    completed_today: int
    failed_today: int


class SystemStatus(BaseModel):
    """Overall system status."""

    gpu: GPUStatus
    queue: QueueStatus
    loaded_models: list[str]
    uptime_seconds: float


# Pagination


class PaginatedResponse(BaseModel):
    """Paginated list response."""

    items: list[Any]
    total: int
    page: int
    page_size: int
    pages: int
|
||||||
13
src/api/routes/__init__.py
Normal file
13
src/api/routes/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
"""API route modules."""
|
||||||
|
|
||||||
|
from src.api.routes.generation import router as generation_router
|
||||||
|
from src.api.routes.projects import router as projects_router
|
||||||
|
from src.api.routes.models import router as models_router
|
||||||
|
from src.api.routes.system import router as system_router
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"generation_router",
|
||||||
|
"projects_router",
|
||||||
|
"models_router",
|
||||||
|
"system_router",
|
||||||
|
]
|
||||||
234
src/api/routes/generation.py
Normal file
234
src/api/routes/generation.py
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
"""Generation API endpoints."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, status
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
|
||||||
|
from src.api.auth import verify_api_key_dependency
|
||||||
|
from src.api.models import (
|
||||||
|
GenerationRequest,
|
||||||
|
BatchGenerationRequest,
|
||||||
|
GenerationResult,
|
||||||
|
JobResponse,
|
||||||
|
JobStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/generate", tags=["generation"])
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level service handles, injected once at app startup.
_generation_service = None
_batch_processor = None


def set_services(generation_service: Any, batch_processor: Any) -> None:
    """Inject the generation and batch-processing services used by the
    route handlers in this module."""
    global _generation_service, _batch_processor
    _generation_service = generation_service
    _batch_processor = batch_processor
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/",
    response_model=GenerationResult,
    summary="Generate audio synchronously",
    description="Generate audio and wait for completion. For long generations, consider using the async endpoint.",
)
async def generate_sync(
    request: GenerationRequest,
    api_key: str = Depends(verify_api_key_dependency),
) -> GenerationResult:
    """Run a generation to completion and return the finished result."""
    # Without an injected generation service there is nothing to run.
    if _generation_service is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Generation service not available",
        )

    try:
        # Forward every tunable straight from the validated request body.
        result, generation = await _generation_service.generate(
            model_id=request.model.value,
            variant=request.variant,
            prompts=request.prompts,
            duration=request.duration,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            cfg_coef=request.cfg_coef,
            seed=request.seed,
            conditioning=request.conditioning,
            project_id=request.project_id,
        )

        has_waveform = bool(generation.waveform_path)
        return GenerationResult(
            id=generation.id,
            audio_url=f"/api/v1/audio/{generation.id}",
            waveform_url=f"/api/v1/audio/{generation.id}/waveform" if has_waveform else None,
            duration=result.duration,
            seed=result.seed,
            model=request.model.value,
            variant=request.variant,
            prompt=request.prompts[0],
            created_at=generation.created_at,
        )

    except Exception as e:
        # Surface the underlying failure as a 500 with its message.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e),
        )
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
    "/async",
    response_model=JobResponse,
    summary="Queue generation job",
    description="Add a generation to the queue for async processing.",
)
async def generate_async(
    request: BatchGenerationRequest,
    api_key: str = Depends(verify_api_key_dependency),
) -> JobResponse:
    """Enqueue a generation for background processing.

    Returns immediately with the new job id and its queue position;
    poll /jobs/{job_id} for progress and the final result.
    """
    if _batch_processor is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Batch processor not available",
        )

    params = request.request  # the wrapped GenerationRequest
    try:
        job = _batch_processor.add_job(
            model_id=params.model.value,
            variant=params.variant,
            prompts=params.prompts,
            duration=params.duration,
            temperature=params.temperature,
            top_k=params.top_k,
            top_p=params.top_p,
            cfg_coef=params.cfg_coef,
            seed=params.seed,
            conditioning=params.conditioning,
            project_id=params.project_id,
            priority=request.priority,
        )

        return JobResponse(
            job_id=job.id,
            status=JobStatus.PENDING,
            position=_batch_processor.get_position(job.id),
        )

    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        )
||||||
|
|
||||||
|
@router.get(
    "/jobs/{job_id}",
    response_model=JobResponse,
    summary="Get job status",
    description="Check the status of a queued generation job.",
)
async def get_job_status(
    job_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> JobResponse:
    """Report the current state of a queued job.

    The response is enriched per state: queue position while pending,
    progress while running, a full GenerationResult once completed, or
    the error message if the job failed.
    """
    if _batch_processor is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Batch processor not available",
        )

    job = _batch_processor.get_job(job_id)
    if job is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Job {job_id} not found",
        )

    response = JobResponse(
        job_id=job.id,
        status=JobStatus(job.status.value),
    )

    state = job.status.value
    if state == "pending":
        response.position = _batch_processor.get_position(job_id)
    elif state == "running":
        response.progress = job.progress
    elif state == "completed" and job.result:
        # NOTE(review): unlike the sync endpoint, waveform_url is built
        # unconditionally here — confirm completed jobs always have one.
        response.result = GenerationResult(
            id=job.result.id,
            audio_url=f"/api/v1/audio/{job.result.id}",
            waveform_url=f"/api/v1/audio/{job.result.id}/waveform",
            duration=job.result.duration,
            seed=job.result.seed,
            model=job.model_id,
            variant=job.variant,
            prompt=job.prompts[0],
            created_at=job.completed_at,
        )
    elif state == "failed":
        response.error = job.error

    return response
||||||
|
|
||||||
|
@router.delete(
    "/jobs/{job_id}",
    summary="Cancel job",
    description="Cancel a pending or running job.",
)
async def cancel_job(
    job_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> dict:
    """Cancel a queued job; 404 when unknown or no longer cancellable."""
    if _batch_processor is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Batch processor not available",
        )

    if not _batch_processor.cancel_job(job_id):
        # The processor could not cancel: either the id is unknown or the
        # job already finished.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Job {job_id} not found or cannot be cancelled",
        )

    return {"message": f"Job {job_id} cancelled"}
||||||
|
|
||||||
|
@router.get(
    "/jobs",
    response_model=list[JobResponse],
    summary="List jobs",
    description="List all jobs in the queue.",
)
async def list_jobs(
    status_filter: str = None,
    limit: int = 50,
    api_key: str = Depends(verify_api_key_dependency),
) -> list[JobResponse]:
    """List queued jobs, optionally filtered by status.

    Queue position is reported only for pending jobs and progress only
    for running ones; both are None otherwise.
    """
    if _batch_processor is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Batch processor not available",
        )

    jobs = _batch_processor.list_jobs(status_filter=status_filter, limit=limit)

    responses: list[JobResponse] = []
    for job in jobs:
        state = job.status.value
        responses.append(
            JobResponse(
                job_id=job.id,
                status=JobStatus(state),
                position=_batch_processor.get_position(job.id) if state == "pending" else None,
                progress=job.progress if state == "running" else None,
            )
        )
    return responses
|
||||||
228
src/api/routes/models.py
Normal file
228
src/api/routes/models.py
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
"""Models API endpoints."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
|
||||||
|
from src.api.auth import verify_api_key_dependency
|
||||||
|
from src.api.models import ModelInfo, ModelVariantInfo
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/models", tags=["models"])
|
||||||
|
|
||||||
|
|
||||||
|
# Service dependency (injected at app startup)
_model_registry = None


def set_services(model_registry: Any) -> None:
    """Inject the model registry used by the /models endpoints.

    Called once during application startup; until then the module-level
    `_model_registry` stays None and routes degrade accordingly.
    """
    global _model_registry
    _model_registry = model_registry
||||||
|
|
||||||
|
# Static model information served by the /models endpoints.
# Keys double as route model_ids; insertion order defines listing order.
# vram_mb values are approximate working-set sizes used for planning.
MODEL_CATALOG = {
    "musicgen": {
        "id": "musicgen",
        "name": "MusicGen",
        "description": "Text-to-music generation with optional melody conditioning",
        "variants": [
            {"id": "small", "name": "Small", "vram_mb": 1500, "description": "Fast, 300M params", "capabilities": ["text"]},
            {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "Balanced, 1.5B params", "capabilities": ["text"]},
            {"id": "large", "name": "Large", "vram_mb": 10000, "description": "Best quality, 3.3B params", "capabilities": ["text"]},
            {"id": "melody", "name": "Melody", "vram_mb": 5000, "description": "With melody conditioning", "capabilities": ["text", "melody"]},
            {"id": "stereo-small", "name": "Stereo Small", "vram_mb": 1800, "description": "Stereo, 300M params", "capabilities": ["text", "stereo"]},
            {"id": "stereo-medium", "name": "Stereo Medium", "vram_mb": 6000, "description": "Stereo, 1.5B params", "capabilities": ["text", "stereo"]},
            {"id": "stereo-large", "name": "Stereo Large", "vram_mb": 12000, "description": "Stereo, 3.3B params", "capabilities": ["text", "stereo"]},
            {"id": "stereo-melody", "name": "Stereo Melody", "vram_mb": 6000, "description": "Stereo with melody", "capabilities": ["text", "melody", "stereo"]},
        ],
    },
    "audiogen": {
        "id": "audiogen",
        "name": "AudioGen",
        "description": "Text-to-sound effects and environmental audio",
        "variants": [
            {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "1.5B params", "capabilities": ["text", "sfx"]},
        ],
    },
    "magnet": {
        "id": "magnet",
        "name": "MAGNeT",
        "description": "Fast non-autoregressive music generation",
        "variants": [
            {"id": "small", "name": "Small Music", "vram_mb": 2000, "description": "Fast music, 300M params", "capabilities": ["text", "music"]},
            {"id": "medium", "name": "Medium Music", "vram_mb": 5000, "description": "Balanced music, 1.5B params", "capabilities": ["text", "music"]},
            {"id": "audio-small", "name": "Small Audio", "vram_mb": 2000, "description": "Fast sound effects", "capabilities": ["text", "sfx"]},
            {"id": "audio-medium", "name": "Medium Audio", "vram_mb": 5000, "description": "Balanced sound effects", "capabilities": ["text", "sfx"]},
        ],
    },
    "musicgen-style": {
        "id": "musicgen-style",
        "name": "MusicGen Style",
        "description": "Style-conditioned music from reference audio",
        "variants": [
            {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "1.5B params, style conditioning", "capabilities": ["text", "style"]},
        ],
    },
    "jasco": {
        "id": "jasco",
        "name": "JASCO",
        "description": "Chord and drum-conditioned music generation",
        "variants": [
            {"id": "chords", "name": "Chords", "vram_mb": 5000, "description": "Chord-conditioned generation", "capabilities": ["text", "chords"]},
            {"id": "chords-drums", "name": "Chords + Drums", "vram_mb": 5500, "description": "Full symbolic conditioning", "capabilities": ["text", "chords", "drums"]},
        ],
    },
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
    "/",
    response_model=list[ModelInfo],
    summary="List models",
    description="Get information about all available models.",
)
async def list_models(
    api_key: str = Depends(verify_api_key_dependency),
) -> list[ModelInfo]:
    """List every model in the static catalog, annotated with load state.

    When no registry has been injected, every model is reported as
    unloaded with no current variant.
    """
    models: list[ModelInfo] = []

    for model_id, info in MODEL_CATALOG.items():
        loaded = _model_registry.is_loaded(model_id) if _model_registry else False
        current_variant = (
            _model_registry.get_current_variant(model_id) if loaded else None
        )

        models.append(
            ModelInfo(
                id=info["id"],
                name=info["name"],
                description=info["description"],
                variants=[ModelVariantInfo(**v) for v in info["variants"]],
                loaded=loaded,
                current_variant=current_variant,
            )
        )

    return models
||||||
|
|
||||||
|
# NOTE: the fixed-path "/loaded" route MUST be registered before the
# parameterized "/{model_id}" route.  FastAPI matches routes in
# registration order, so registering "/{model_id}" first would capture
# GET /models/loaded with model_id="loaded" and return a spurious 404.
@router.get(
    "/loaded",
    response_model=list[str],
    summary="List loaded models",
    description="Get list of currently loaded models.",
)
async def list_loaded_models(
    api_key: str = Depends(verify_api_key_dependency),
) -> list[str]:
    """List currently loaded models (empty when no registry is wired)."""
    if _model_registry is None:
        return []

    return _model_registry.get_loaded_models()


@router.get(
    "/{model_id}",
    response_model=ModelInfo,
    summary="Get model info",
    description="Get detailed information about a specific model.",
)
async def get_model(
    model_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> ModelInfo:
    """Get catalog information plus live load state for one model.

    Raises:
        HTTPException 404: unknown model id.
    """
    if model_id not in MODEL_CATALOG:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model {model_id} not found",
        )

    info = MODEL_CATALOG[model_id]
    loaded = _model_registry.is_loaded(model_id) if _model_registry else False
    current_variant = (
        _model_registry.get_current_variant(model_id) if loaded else None
    )

    return ModelInfo(
        id=info["id"],
        name=info["name"],
        description=info["description"],
        variants=[ModelVariantInfo(**v) for v in info["variants"]],
        loaded=loaded,
        current_variant=current_variant,
    )


@router.post(
    "/{model_id}/load",
    summary="Load model",
    description="Load a model into GPU memory.",
)
async def load_model(
    model_id: str,
    variant: str = "medium",
    api_key: str = Depends(verify_api_key_dependency),
) -> dict:
    """Load a model variant into GPU memory.

    NOTE(review): `variant` is not validated against the catalog's
    variant list here — presumably the registry rejects unknown
    variants; confirm.

    Raises:
        HTTPException 404: unknown model id.
        HTTPException 503: registry not injected yet.
        HTTPException 500: loading failed (e.g. out of memory).
    """
    if model_id not in MODEL_CATALOG:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model {model_id} not found",
        )

    if _model_registry is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model registry not available",
        )

    try:
        await _model_registry.load_model(model_id, variant)
        return {"message": f"Model {model_id} ({variant}) loaded successfully"}
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        )


@router.post(
    "/{model_id}/unload",
    summary="Unload model",
    description="Unload a model from GPU memory.",
)
async def unload_model(
    model_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> dict:
    """Unload a model from GPU memory.

    Raises:
        HTTPException 404: unknown model id.
        HTTPException 503: registry not injected yet.
        HTTPException 500: unload failed.
    """
    if model_id not in MODEL_CATALOG:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model {model_id} not found",
        )

    if _model_registry is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model registry not available",
        )

    try:
        await _model_registry.unload_model(model_id)
        return {"message": f"Model {model_id} unloaded successfully"}
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        )
|
||||||
250
src/api/routes/projects.py
Normal file
250
src/api/routes/projects.py
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
"""Projects API endpoints."""
|
||||||
|
|
||||||
|
from typing import Any, Optional
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
|
||||||
|
from src.api.auth import verify_api_key_dependency
|
||||||
|
from src.api.models import (
|
||||||
|
ProjectCreate,
|
||||||
|
ProjectResponse,
|
||||||
|
GenerationResponse,
|
||||||
|
PaginatedResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/projects", tags=["projects"])
|
||||||
|
|
||||||
|
|
||||||
|
# Service dependency (injected at app startup)
_project_service = None


def set_services(project_service: Any) -> None:
    """Inject the project service used by the /projects endpoints.

    Called once during application startup; until then the module-level
    `_project_service` stays None and routes answer 503.
    """
    global _project_service
    _project_service = project_service
||||||
|
|
||||||
|
@router.post(
    "/",
    response_model=ProjectResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Create project",
    description="Create a new project for organizing generations.",
)
async def create_project(
    request: ProjectCreate,
    api_key: str = Depends(verify_api_key_dependency),
) -> ProjectResponse:
    """Create a project and return the stored record with a 201."""
    if _project_service is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Project service not available",
        )

    try:
        created = await _project_service.create_project(
            name=request.name,
            description=request.description,
        )
        return ProjectResponse(**created)
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        )
||||||
|
|
||||||
|
@router.get(
    "/",
    response_model=list[ProjectResponse],
    summary="List projects",
    description="Get all projects.",
)
async def list_projects(
    api_key: str = Depends(verify_api_key_dependency),
) -> list[ProjectResponse]:
    """Return every project known to the project service."""
    if _project_service is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Project service not available",
        )

    try:
        rows = await _project_service.list_projects()
        return [ProjectResponse(**row) for row in rows]
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        )
||||||
|
|
||||||
|
@router.get(
    "/{project_id}",
    response_model=ProjectResponse,
    summary="Get project",
    description="Get project details by ID.",
)
async def get_project(
    project_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> ProjectResponse:
    """Fetch a single project; 404 when the id is unknown."""
    if _project_service is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Project service not available",
        )

    try:
        project = await _project_service.get_project(project_id)
        if project is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Project {project_id} not found",
            )
        return ProjectResponse(**project)
    except HTTPException:
        # Re-raise our own 404 untouched instead of wrapping it in a 500.
        raise
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        )
||||||
|
|
||||||
|
@router.delete(
    "/{project_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete project",
    description="Delete a project and all its generations.",
)
async def delete_project(
    project_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> None:
    """Delete a project (204 on success).

    NOTE(review): the service result is not inspected, so deleting an
    unknown project id presumably also yields 204 — confirm whether a
    404 is wanted here.
    """
    if _project_service is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Project service not available",
        )

    try:
        await _project_service.delete_project(project_id)
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        )
||||||
|
|
||||||
|
@router.get(
    "/{project_id}/generations",
    response_model=PaginatedResponse,
    summary="List generations",
    description="Get generations for a project with pagination.",
)
async def list_generations(
    project_id: str,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    model: Optional[str] = Query(None, description="Filter by model"),
    api_key: str = Depends(verify_api_key_dependency),
) -> PaginatedResponse:
    """Paginate a project's generations, optionally filtered by model.

    The reported `total` is an estimate: exact up to the current page,
    plus one when another page exists (a real COUNT query would be
    precise).
    """
    if _project_service is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Project service not available",
        )

    try:
        offset = (page - 1) * page_size
        # Fetch one extra row so we can tell whether another page exists.
        rows = await _project_service.list_generations(
            project_id=project_id,
            limit=page_size + 1,
            offset=offset,
            model_filter=model,
        )

        has_more = len(rows) > page_size
        rows = rows[:page_size]

        total = offset + len(rows) + (1 if has_more else 0)
        pages = (total + page_size - 1) // page_size  # ceiling division

        return PaginatedResponse(
            items=[GenerationResponse(**row) for row in rows],
            total=total,
            page=page,
            page_size=page_size,
            pages=pages,
        )
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        )
||||||
|
|
||||||
|
@router.get(
    "/{project_id}/export",
    summary="Export project",
    description="Export project as ZIP file with all audio and metadata.",
)
async def export_project(
    project_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> FileResponse:
    """Build and stream a ZIP archive of the project's outputs."""
    if _project_service is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Project service not available",
        )

    try:
        zip_path = await _project_service.export_project_zip(project_id)
        return FileResponse(
            path=zip_path,
            filename=f"project_{project_id}.zip",
            media_type="application/zip",
        )
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        )
||||||
|
|
||||||
|
@router.delete(
    "/{project_id}/generations/{generation_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete generation",
    description="Delete a specific generation.",
)
async def delete_generation(
    project_id: str,
    generation_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> None:
    """Delete one generation (204 on success).

    NOTE(review): project_id is accepted but not checked against the
    generation's owning project — confirm whether ownership should be
    validated before deleting.
    """
    if _project_service is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Project service not available",
        )

    try:
        await _project_service.delete_generation(generation_id)
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        )
|
||||||
263
src/api/routes/system.py
Normal file
263
src/api/routes/system.py
Normal file
@@ -0,0 +1,263 @@
|
|||||||
|
"""System API endpoints."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
from pathlib import Path
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
|
||||||
|
from src.api.auth import verify_api_key_dependency, get_key_manager
|
||||||
|
from src.api.models import GPUStatus, QueueStatus, SystemStatus
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/system", tags=["system"])
|
||||||
|
|
||||||
|
|
||||||
|
# Service dependencies (injected at app startup)
_gpu_manager = None
_batch_processor = None
_model_registry = None
# Process start time, used for uptime reporting in /status and /health.
_start_time = time.time()


def set_services(
    gpu_manager: Any,
    batch_processor: Any,
    model_registry: Any,
) -> None:
    """Inject the services used by the /system endpoints.

    Called once during application startup; until then the endpoints
    fall back to placeholder values or 503s.
    """
    global _gpu_manager, _batch_processor, _model_registry
    _gpu_manager = gpu_manager
    _batch_processor = batch_processor
    _model_registry = model_registry
||||||
|
|
||||||
|
@router.get(
    "/status",
    response_model=SystemStatus,
    summary="System status",
    description="Get overall system status including GPU, queue, and loaded models.",
)
async def get_status(
    api_key: str = Depends(verify_api_key_dependency),
) -> SystemStatus:
    """Aggregate GPU, queue, and model state into one status snapshot.

    Each section degrades to zeroed placeholder values when the
    corresponding service has not been injected yet.
    """
    gib = 1024**3  # bytes per GiB; memory figures are reported in GB

    if _gpu_manager:
        gpu = GPUStatus(
            device_name=_gpu_manager.device_name,
            total_gb=_gpu_manager.total_memory / gib,
            used_gb=_gpu_manager.get_used_memory() / gib,
            available_gb=_gpu_manager.get_available_memory() / gib,
            utilization_percent=_gpu_manager.get_utilization(),
            temperature_c=_gpu_manager.get_temperature(),
        )
    else:
        gpu = GPUStatus(
            device_name="Unknown",
            total_gb=0,
            used_gb=0,
            available_gb=0,
            utilization_percent=0,
            temperature_c=None,
        )

    if _batch_processor:
        queue = QueueStatus(
            queue_size=len(_batch_processor.queue),
            active_jobs=_batch_processor.active_count,
            completed_today=_batch_processor.completed_count,
            failed_today=_batch_processor.failed_count,
        )
    else:
        queue = QueueStatus(
            queue_size=0,
            active_jobs=0,
            completed_today=0,
            failed_today=0,
        )

    loaded_models = _model_registry.get_loaded_models() if _model_registry else []

    return SystemStatus(
        gpu=gpu,
        queue=queue,
        loaded_models=loaded_models,
        uptime_seconds=time.time() - _start_time,
    )
||||||
|
|
||||||
|
@router.get(
    "/gpu",
    response_model=GPUStatus,
    summary="GPU status",
    description="Get detailed GPU memory and utilization status.",
)
async def get_gpu_status(
    api_key: str = Depends(verify_api_key_dependency),
) -> GPUStatus:
    """Report current GPU memory and utilization (503 without a manager)."""
    if _gpu_manager is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="GPU manager not available",
        )

    gib = 1024**3  # bytes per GiB
    return GPUStatus(
        device_name=_gpu_manager.device_name,
        total_gb=_gpu_manager.total_memory / gib,
        used_gb=_gpu_manager.get_used_memory() / gib,
        available_gb=_gpu_manager.get_available_memory() / gib,
        utilization_percent=_gpu_manager.get_utilization(),
        temperature_c=_gpu_manager.get_temperature(),
    )
||||||
|
|
||||||
|
@router.post(
    "/clear-cache",
    summary="Clear cache",
    description="Clear model cache and free GPU memory.",
)
async def clear_cache(
    api_key: str = Depends(verify_api_key_dependency),
) -> dict:
    """Clear the model registry's cache.

    NOTE(review): clear_cache() is called synchronously here while the
    unload endpoints await the registry — confirm the registry exposes
    both styles intentionally.
    """
    if _model_registry is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model registry not available",
        )

    try:
        _model_registry.clear_cache()
        return {"message": "Cache cleared successfully"}
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        )
||||||
|
|
||||||
|
@router.post(
    "/unload-all",
    summary="Unload all models",
    description="Unload all models from GPU memory.",
)
async def unload_all_models(
    api_key: str = Depends(verify_api_key_dependency),
) -> dict:
    """Unload every loaded model from GPU memory."""
    if _model_registry is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model registry not available",
        )

    try:
        await _model_registry.unload_all()
        return {"message": "All models unloaded successfully"}
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        )
||||||
|
|
||||||
|
@router.get(
    "/health",
    summary="Health check",
    description="Simple health check endpoint.",
)
async def health_check() -> dict:
    """Liveness probe; intentionally the only unauthenticated route here."""
    return {
        "status": "healthy",
        "uptime_seconds": time.time() - _start_time,
    }
||||||
|
|
||||||
|
@router.post(
    "/api-key/regenerate",
    summary="Regenerate API key",
    description="Generate a new API key. The old key will be invalidated.",
)
async def regenerate_api_key(
    api_key: str = Depends(verify_api_key_dependency),
) -> dict:
    """Rotate the API key; the caller's current key stops working."""
    new_key = get_key_manager().generate_new_key()

    return {
        "api_key": new_key,
        "message": "New API key generated. Store it securely - it won't be shown again.",
    }
||||||
|
|
||||||
|
@router.get(
    "/audio/{generation_id}",
    summary="Download audio",
    description="Download generated audio file.",
)
async def download_audio(
    generation_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> FileResponse:
    """Stream the audio file for a generation.

    TODO: look up the actual file path from the database instead of
    probing the output directory for known extensions.

    NOTE(review): this route lives under the /system router prefix,
    while the generation endpoints advertise /api/v1/audio/{id} —
    confirm the app mounts these paths consistently.

    Raises:
        HTTPException 404: no audio file found for the generation id.
    """
    from config.settings import get_settings
    settings = get_settings()

    # Explicit MIME types for the probed extensions.  Fix: the original
    # f"audio/{suffix}" produced the non-standard "audio/mp3"; the
    # registered type for MP3 is "audio/mpeg".
    media_types = {
        ".wav": "audio/wav",
        ".mp3": "audio/mpeg",
        ".flac": "audio/flac",
    }

    audio_dir = Path(settings.output_dir)
    for suffix, media_type in media_types.items():
        path = audio_dir / f"{generation_id}{suffix}"
        if path.exists():
            return FileResponse(
                path=path,
                filename=path.name,
                media_type=media_type,
            )

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Audio file for generation {generation_id} not found",
    )
||||||
|
|
||||||
|
@router.get(
    "/audio/{generation_id}/waveform",
    summary="Download waveform",
    description="Download waveform visualization image.",
)
async def download_waveform(
    generation_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> FileResponse:
    """Stream the PNG waveform image for a generation (404 if absent)."""
    from config.settings import get_settings
    settings = get_settings()

    waveform_path = Path(settings.output_dir) / f"{generation_id}_waveform.png"

    if not waveform_path.exists():
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Waveform for generation {generation_id} not found",
        )

    return FileResponse(
        path=waveform_path,
        filename=waveform_path.name,
        media_type="image/png",
    )
|
||||||
24
src/core/__init__.py
Normal file
24
src/core/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
"""Core infrastructure for AudioCraft Studio."""
|
||||||
|
|
||||||
|
from src.core.base_model import (
|
||||||
|
BaseAudioModel,
|
||||||
|
GenerationRequest,
|
||||||
|
GenerationResult,
|
||||||
|
ConditioningType,
|
||||||
|
)
|
||||||
|
from src.core.gpu_manager import GPUMemoryManager, VRAMBudget
|
||||||
|
from src.core.model_registry import ModelRegistry
|
||||||
|
from src.core.oom_handler import OOMHandler, OOMRecoveryError, oom_safe
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseAudioModel",
|
||||||
|
"GenerationRequest",
|
||||||
|
"GenerationResult",
|
||||||
|
"ConditioningType",
|
||||||
|
"GPUMemoryManager",
|
||||||
|
"VRAMBudget",
|
||||||
|
"ModelRegistry",
|
||||||
|
"OOMHandler",
|
||||||
|
"OOMRecoveryError",
|
||||||
|
"oom_safe",
|
||||||
|
]
|
||||||
535
src/core/audio_utils.py
Normal file
535
src/core/audio_utils.py
Normal file
@@ -0,0 +1,535 @@
|
|||||||
|
"""Audio utilities for processing, visualization, and export."""
|
||||||
|
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_audio(
    audio: Union[torch.Tensor, np.ndarray],
    target_db: float = -14.0,
    peak_normalize: bool = False,
) -> np.ndarray:
    """Normalize audio to target loudness.

    Args:
        audio: Audio tensor or array [channels, samples] or [samples];
            a batched [batch, channels, samples] input uses the first item.
        target_db: Target loudness in dB (LUFS-like)
        peak_normalize: If True, normalize to peak instead of RMS

    Returns:
        Normalized float32 audio as numpy array, clipped to [-1, 1].
    """
    if isinstance(audio, torch.Tensor):
        # Fix: .numpy() raises for CUDA tensors and for tensors that
        # require grad; detach and move to CPU first.
        audio = audio.detach().cpu().numpy()

    # Ensure float32
    audio = audio.astype(np.float32)

    # Handle batch dimension
    if audio.ndim == 3:
        audio = audio[0]  # Take first sample if batched

    if peak_normalize:
        # Peak normalization: scale so the absolute peak hits target_db.
        peak = np.abs(audio).max()
        if peak > 0:
            target_linear = 10 ** (target_db / 20)
            audio = audio * (target_linear / peak)
    else:
        # RMS normalization (approximating LUFS): scale so the RMS
        # level hits target_db.
        rms = np.sqrt(np.mean(audio ** 2))
        if rms > 0:
            target_rms = 10 ** (target_db / 20)
            audio = audio * (target_rms / rms)

    # Hard-limit to the valid range; RMS scaling can push peaks past 1.0.
    audio = np.clip(audio, -1.0, 1.0)

    return audio
|
||||||
|
|
||||||
|
|
||||||
|
def convert_sample_rate(
    audio: np.ndarray,
    orig_sr: int,
    target_sr: int,
) -> np.ndarray:
    """Convert audio sample rate.

    Args:
        audio: Audio array [channels, samples] or [samples]
        orig_sr: Original sample rate
        target_sr: Target sample rate

    Returns:
        Resampled audio (the input itself when rates already match).
    """
    if orig_sr == target_sr:
        return audio

    try:
        import librosa
    except ImportError:
        # Fallback: FFT-based resampling from scipy.
        logger.warning("librosa not available, using scipy for resampling")
        from scipy import signal

        out_len = int(audio.shape[-1] * (target_sr / orig_sr))
        if audio.ndim == 2:
            return np.array([signal.resample(ch, out_len) for ch in audio])
        return signal.resample(audio, out_len)

    # Preferred path: librosa resampling, channel by channel.
    if audio.ndim == 2:
        return np.array([
            librosa.resample(ch, orig_sr=orig_sr, target_sr=target_sr)
            for ch in audio
        ])
    return librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_waveform(
    audio: Union[torch.Tensor, np.ndarray],
    sample_rate: int,
    width: int = 800,
    height: int = 200,
    color: str = "#3b82f6",
    background: str = "#1f2937",
) -> bytes:
    """Generate waveform image as PNG bytes.

    Renders a per-pixel min/max envelope of the (mono-mixed) signal with
    matplotlib's Agg backend. Returns b"" when matplotlib is missing.

    Args:
        audio: Audio data [channels, samples] or [samples]
        sample_rate: Sample rate in Hz. NOTE(review): currently unused in
            the body — kept for interface symmetry with generate_spectrogram.
        width: Image width in pixels
        height: Image height in pixels
        color: Waveform color (hex)
        background: Background color (hex)

    Returns:
        PNG image as bytes (empty bytes when matplotlib is unavailable)
    """
    try:
        import matplotlib
        # Non-interactive backend; safe in headless/server processes.
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
    except ImportError:
        # Best-effort: visualization is optional, so degrade gracefully.
        logger.warning("matplotlib not available for waveform generation")
        return b""

    if isinstance(audio, torch.Tensor):
        # NOTE(review): .numpy() fails for CUDA or grad-tracking tensors;
        # assumes CPU tensors reach this point — confirm callers.
        audio = audio.numpy()

    # Handle dimensions: drop batch dim, then mix channels down.
    if audio.ndim == 3:
        audio = audio[0]
    if audio.ndim == 2:
        audio = audio.mean(axis=0)  # Mix to mono for visualization

    # Downsample for visualization: one min/max pair per output pixel.
    samples_per_pixel = max(1, len(audio) // width)
    num_chunks = len(audio) // samples_per_pixel

    if num_chunks > 0:
        # Trim the tail so the buffer reshapes evenly into chunks.
        audio_chunks = audio[:num_chunks * samples_per_pixel].reshape(
            num_chunks, samples_per_pixel
        )
        # Get min/max for each chunk (envelope extremes per pixel column).
        mins = audio_chunks.min(axis=1)
        maxs = audio_chunks.max(axis=1)
    else:
        # Degenerate case (shorter than one pixel's worth): plot raw samples.
        mins = maxs = audio

    # Create figure sized in inches so width/height map to pixels at dpi=100.
    fig, ax = plt.subplots(figsize=(width / 100, height / 100), dpi=100)
    fig.patch.set_facecolor(background)
    ax.set_facecolor(background)

    # Plot waveform as a filled min/max envelope plus a faint zero line.
    x = np.arange(len(mins))
    ax.fill_between(x, mins, maxs, color=color, alpha=0.7)
    ax.axhline(y=0, color=color, alpha=0.3, linewidth=0.5)

    # Style: fixed [-1, 1] amplitude range, no axes chrome.
    ax.set_xlim(0, len(mins))
    ax.set_ylim(-1, 1)
    ax.axis('off')
    plt.tight_layout(pad=0)

    # Save to bytes and release the figure to avoid leaking Agg canvases.
    buf = io.BytesIO()
    fig.savefig(buf, format='png', facecolor=background, edgecolor='none')
    plt.close(fig)
    buf.seek(0)

    return buf.read()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_spectrogram(
    audio: Union[torch.Tensor, np.ndarray],
    sample_rate: int,
    width: int = 800,
    height: int = 200,
    colormap: str = "magma",
) -> bytes:
    """Generate spectrogram image as PNG bytes.

    Computes a 128-bin mel spectrogram (librosa) of the mono-mixed signal
    and renders it with matplotlib. Returns b"" when either dependency is
    missing.

    Args:
        audio: Audio data
        sample_rate: Sample rate in Hz
        width: Image width
        height: Image height
        colormap: Matplotlib colormap name

    Returns:
        PNG image as bytes (empty bytes when matplotlib/librosa are missing)
    """
    try:
        import matplotlib
        # Non-interactive backend; safe in headless/server processes.
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
        import librosa
        import librosa.display
    except ImportError:
        # Best-effort: visualization is optional, so degrade gracefully.
        logger.warning("matplotlib/librosa not available for spectrogram")
        return b""

    if isinstance(audio, torch.Tensor):
        # NOTE(review): .numpy() fails for CUDA or grad-tracking tensors;
        # assumes CPU tensors reach this point — confirm callers.
        audio = audio.numpy()

    # Handle dimensions: drop batch dim, then mix channels down.
    if audio.ndim == 3:
        audio = audio[0]
    if audio.ndim == 2:
        audio = audio.mean(axis=0)

    # Compute mel spectrogram up to the Nyquist frequency.
    S = librosa.feature.melspectrogram(
        y=audio,
        sr=sample_rate,
        n_mels=128,
        fmax=sample_rate // 2,
    )
    # Convert power to dB relative to the loudest bin.
    S_db = librosa.power_to_db(S, ref=np.max)

    # Create figure sized in inches so width/height map to pixels at dpi=100.
    fig, ax = plt.subplots(figsize=(width / 100, height / 100), dpi=100)

    librosa.display.specshow(
        S_db,
        sr=sample_rate,
        x_axis='time',
        y_axis='mel',
        cmap=colormap,
        ax=ax,
    )

    ax.axis('off')
    plt.tight_layout(pad=0)

    # Save to bytes and release the figure to avoid leaking Agg canvases.
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    buf.seek(0)

    return buf.read()
|
||||||
|
|
||||||
|
|
||||||
|
def save_audio(
    audio: Union[torch.Tensor, np.ndarray],
    sample_rate: int,
    path: Path,
    format: str = "wav",
    normalize: bool = True,
    target_db: float = -14.0,
) -> Path:
    """Save audio to file with optional normalization.

    Args:
        audio: Audio data [channels, samples], [samples], or batched
            [batch, channels, samples] (first item is used)
        sample_rate: Sample rate
        path: Output path (extension will be added if needed)
        format: Output format (wav, mp3, flac, ogg); unknown formats
            fall back to WAV
        normalize: Whether to normalize audio
        target_db: Normalization target

    Returns:
        Path to saved file (may differ from `path` when the format
        fell back to WAV)
    """
    import soundfile as sf

    if isinstance(audio, torch.Tensor):
        # Fix: .numpy() raises for CUDA tensors and for tensors that
        # require grad; detach and move to CPU first.
        audio = audio.detach().cpu().numpy()

    # Handle batch dimension
    if audio.ndim == 3:
        audio = audio[0]

    # Normalize if requested
    if normalize:
        audio = normalize_audio(audio, target_db=target_db)

    # Transpose for soundfile [samples, channels]
    if audio.ndim == 2:
        audio = audio.T

    # Ensure correct extension
    path = Path(path)
    if not path.suffix:
        path = path.with_suffix(f".{format}")

    # Save based on format
    if format in ("wav", "flac"):
        sf.write(path, audio, sample_rate)
    elif format == "mp3":
        # soundfile cannot write MP3; go via a temp WAV + pydub when possible.
        try:
            from pydub import AudioSegment

            # Save as WAV first
            wav_path = path.with_suffix(".wav")
            sf.write(wav_path, audio, sample_rate)

            # Convert to MP3
            sound = AudioSegment.from_wav(wav_path)
            sound.export(path, format="mp3", bitrate="320k")

            # Remove the temp WAV — but only if it is not the output itself.
            # Fix: when the caller passed a ".wav" path with format="mp3",
            # wav_path == path and the old code deleted the freshly
            # exported file.
            if wav_path != path:
                wav_path.unlink()
        except ImportError:
            logger.warning("pydub not available, saving as WAV instead")
            path = path.with_suffix(".wav")
            sf.write(path, audio, sample_rate)
    elif format == "ogg":
        sf.write(path, audio, sample_rate, format="ogg", subtype="vorbis")
    else:
        # Default to WAV for any unrecognized format string.
        path = path.with_suffix(".wav")
        sf.write(path, audio, sample_rate)

    return path
|
||||||
|
|
||||||
|
|
||||||
|
def load_audio(
    path: Path,
    target_sr: Optional[int] = None,
    mono: bool = False,
) -> Tuple[np.ndarray, int]:
    """Load audio from file.

    Args:
        path: Path to audio file
        target_sr: Target sample rate (None to keep original)
        mono: Convert to mono

    Returns:
        Tuple of (audio_array in [channels, samples] layout, sample_rate)
    """
    import soundfile as sf

    data, native_sr = sf.read(path)

    # Normalize layout to [channels, samples].
    data = data[np.newaxis, :] if data.ndim == 1 else data.T

    # Optional mono downmix, keeping the channel axis.
    if mono and data.shape[0] > 1:
        data = data.mean(axis=0, keepdims=True)

    # Resample only when a different target rate was requested.
    if target_sr and target_sr != native_sr:
        data = convert_sample_rate(data, native_sr, target_sr)
        return data, target_sr

    return data, native_sr
|
||||||
|
|
||||||
|
|
||||||
|
def get_audio_info(path: Path) -> dict:
    """Get audio file information.

    Args:
        path: Path to audio file

    Returns:
        Dictionary with path, duration, sample_rate, channels, format,
        subtype, and frames as reported by soundfile.
    """
    import soundfile as sf

    meta = sf.info(path)

    return dict(
        path=str(path),
        duration=meta.duration,
        sample_rate=meta.samplerate,
        channels=meta.channels,
        format=meta.format,
        subtype=meta.subtype,
        frames=meta.frames,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def trim_silence(
    audio: np.ndarray,
    sample_rate: int,
    threshold_db: float = -40.0,
    min_silence_ms: int = 100,
) -> np.ndarray:
    """Trim silence from start and end of audio.

    Args:
        audio: Audio array
        sample_rate: Sample rate
        threshold_db: Silence threshold in dB
        min_silence_ms: Minimum silence duration to trim

    Returns:
        Trimmed audio (unchanged when librosa is unavailable or no
        non-silent region is found).
    """
    try:
        import librosa
    except ImportError:
        # Best-effort: trimming is optional when librosa is missing.
        logger.warning("librosa not available for silence trimming")
        return audio

    # Silence detection operates on a mono mixdown.
    mono = audio.mean(axis=0) if audio.ndim == 2 else audio

    # Get non-silent intervals.
    intervals = librosa.effects.split(
        mono,
        top_db=abs(threshold_db),
        frame_length=int(sample_rate * min_silence_ms / 1000),
    )

    if len(intervals) == 0:
        return audio

    # Keep everything between the first and last non-silent regions.
    start, end = intervals[0][0], intervals[-1][1]
    return audio[:, start:end] if audio.ndim == 2 else audio[start:end]
|
||||||
|
|
||||||
|
|
||||||
|
def apply_fade(
    audio: np.ndarray,
    sample_rate: int,
    fade_in_ms: float = 0,
    fade_out_ms: float = 0,
) -> np.ndarray:
    """Apply fade in/out to audio.

    Args:
        audio: Audio array [channels, samples] or [samples]
        sample_rate: Sample rate
        fade_in_ms: Fade in duration in milliseconds
        fade_out_ms: Fade out duration in milliseconds

    Returns:
        New array with linear fades applied; the input is left untouched.
    """
    out = audio.copy()
    total = out.shape[-1]

    if fade_in_ms > 0:
        # Linear ramp 0 -> 1 over the fade-in window, capped at the clip length.
        n_in = min(int(sample_rate * fade_in_ms / 1000), total)
        out[..., :n_in] *= np.linspace(0, 1, n_in)

    if fade_out_ms > 0:
        # Linear ramp 1 -> 0 over the fade-out window, capped at the clip length.
        n_out = min(int(sample_rate * fade_out_ms / 1000), total)
        out[..., -n_out:] *= np.linspace(1, 0, n_out)

    return out
|
||||||
|
|
||||||
|
|
||||||
|
def concatenate_audio(
    audio_list: list[np.ndarray],
    sample_rate: int,
    crossfade_ms: float = 0,
) -> np.ndarray:
    """Concatenate multiple audio segments.

    Args:
        audio_list: List of audio arrays, each [channels, samples] or
            [samples] with matching channel layouts
        sample_rate: Sample rate (must be same for all)
        crossfade_ms: Crossfade duration between segments; 0 (or a fade
            longer than either neighbor) means plain concatenation

    Returns:
        Concatenated audio. Input arrays are never modified.
    """
    if not audio_list:
        return np.array([])

    if len(audio_list) == 1:
        return audio_list[0]

    crossfade_samples = int(sample_rate * crossfade_ms / 1000)

    result = audio_list[0]

    for seg in audio_list[1:]:
        if 0 < crossfade_samples < min(result.shape[-1], seg.shape[-1]):
            # Linear equal-gain crossfade over the overlap region.
            fade_out = np.linspace(1, 0, crossfade_samples)
            fade_in = np.linspace(0, 1, crossfade_samples)

            # Fix: compute the faded tail out-of-place. The old code did
            # `result[..., -n:] *= fade_out` while `result` still aliased
            # the caller's first array, mutating the caller's data.
            overlap = (
                result[..., -crossfade_samples:] * fade_out
                + seg[..., :crossfade_samples] * fade_in
            )
            result = np.concatenate(
                [result[..., :-crossfade_samples], overlap, seg[..., crossfade_samples:]],
                axis=-1,
            )
        else:
            # Simple concatenation along the sample axis.
            result = np.concatenate([result, seg], axis=-1)

    return result
|
||||||
247
src/core/base_model.py
Normal file
247
src/core/base_model.py
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
"""Abstract base classes for AudioCraft model adapters."""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class ConditioningType(str, Enum):
    """Types of conditioning supported by models.

    A str-valued Enum so members compare equal to their raw string form
    (e.g. ConditioningType.TEXT == "text") and round-trip through
    ConditioningType(value).
    """

    TEXT = "text"      # text-prompt conditioning
    MELODY = "melody"  # melody/audio reference conditioning
    STYLE = "style"    # style reference conditioning
    CHORDS = "chords"  # chord-progression conditioning
    DRUMS = "drums"    # drum-track conditioning
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class GenerationRequest:
    """Request parameters for audio generation.

    Attributes:
        prompts: List of text prompts for generation
        duration: Target duration in seconds
        temperature: Sampling temperature (higher = more random)
        top_k: Top-k sampling parameter
        top_p: Nucleus sampling parameter (0 = disabled)
        cfg_coef: Classifier-free guidance coefficient
        batch_size: Number of samples to generate per prompt
        seed: Random seed for reproducibility
        conditioning: Optional conditioning data
    """

    prompts: list[str]
    duration: float = 10.0
    temperature: float = 1.0
    top_k: int = 250
    top_p: float = 0.0
    cfg_coef: float = 3.0
    batch_size: int = 1
    seed: Optional[int] = None
    conditioning: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate request parameters.

        Raises:
            ValueError: If any parameter is outside its valid range.
        """
        if not self.prompts:
            raise ValueError("At least one prompt is required")
        if self.duration <= 0:
            raise ValueError("Duration must be positive")
        if self.temperature < 0:
            raise ValueError("Temperature must be non-negative")
        if self.top_k < 0:
            raise ValueError("top_k must be non-negative")
        if not 0 <= self.top_p <= 1:
            raise ValueError("top_p must be between 0 and 1")
        if self.cfg_coef < 1:
            raise ValueError("cfg_coef must be >= 1")
        # Fix: batch_size was documented but never validated; a zero or
        # negative value would otherwise pass silently.
        if self.batch_size < 1:
            raise ValueError("batch_size must be >= 1")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class GenerationResult:
    """Result of audio generation.

    Attributes:
        audio: Generated audio tensor (shape: [batch, channels, samples])
        sample_rate: Audio sample rate in Hz
        duration: Actual duration in seconds
        model_id: ID of the model used
        variant: Model variant used
        parameters: Generation parameters used
        seed: Actual seed used (for reproducibility)
    """

    audio: torch.Tensor
    sample_rate: int
    duration: float
    model_id: str
    variant: str
    parameters: dict[str, Any]
    seed: int

    @property
    def num_samples(self) -> int:
        """Number of audio samples generated (batch dimension)."""
        return self.audio.size(0)

    @property
    def num_channels(self) -> int:
        """Number of audio channels."""
        return self.audio.size(1)

    @property
    def num_frames(self) -> int:
        """Number of audio frames (per-channel sample count)."""
        return self.audio.size(2)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAudioModel(ABC):
    """Abstract base class for AudioCraft model adapters.

    All model adapters must implement this interface to integrate with
    the model registry and generation service. Metadata (ids, VRAM
    estimate, conditioning support) is exposed as properties so the
    registry can inspect adapters without loading weights.
    """

    @property
    @abstractmethod
    def model_id(self) -> str:
        """Unique identifier for this model family (e.g., 'musicgen')."""
        ...

    @property
    @abstractmethod
    def variant(self) -> str:
        """Current model variant (e.g., 'medium', 'large')."""
        ...

    @property
    @abstractmethod
    def display_name(self) -> str:
        """Human-readable name for UI display."""
        ...

    @property
    @abstractmethod
    def description(self) -> str:
        """Brief description of the model's capabilities."""
        ...

    @property
    @abstractmethod
    def vram_estimate_mb(self) -> int:
        """Estimated VRAM usage when loaded (in megabytes)."""
        ...

    @property
    @abstractmethod
    def max_duration(self) -> float:
        """Maximum supported generation duration in seconds."""
        ...

    @property
    @abstractmethod
    def sample_rate(self) -> int:
        """Output audio sample rate in Hz."""
        ...

    @property
    @abstractmethod
    def supports_conditioning(self) -> list[ConditioningType]:
        """List of conditioning types supported by this model."""
        ...

    @property
    @abstractmethod
    def is_loaded(self) -> bool:
        """Whether the model is currently loaded in memory."""
        ...

    @property
    def device(self) -> Optional[torch.device]:
        """Device the model is loaded on, or None if not loaded.

        Default implementation always returns None; concrete adapters
        should override when they track their device.
        """
        return None

    @abstractmethod
    def load(self, device: str = "cuda") -> None:
        """Load the model into memory.

        Args:
            device: Target device ('cuda', 'cuda:0', 'cpu', etc.)

        Raises:
            RuntimeError: If loading fails
        """
        ...

    @abstractmethod
    def unload(self) -> None:
        """Unload the model and free memory.

        Should be idempotent - safe to call even if not loaded.
        """
        ...

    @abstractmethod
    def generate(self, request: GenerationRequest) -> GenerationResult:
        """Generate audio based on the request.

        Args:
            request: Generation parameters and prompts

        Returns:
            GenerationResult containing audio and metadata

        Raises:
            RuntimeError: If model is not loaded
            ValueError: If request parameters are invalid for this model
        """
        ...

    @abstractmethod
    def get_default_params(self) -> dict[str, Any]:
        """Get default generation parameters for this model.

        Returns:
            Dictionary of parameter names to default values
        """
        ...

    def validate_request(self, request: GenerationRequest) -> None:
        """Validate a generation request for this model.

        Args:
            request: Request to validate

        Raises:
            RuntimeError: If the model is not currently loaded
            ValueError: If request is invalid for this model (duration
                exceeds the maximum, or unsupported/unknown conditioning)
        """
        if not self.is_loaded:
            raise RuntimeError(f"Model {self.model_id}/{self.variant} is not loaded")

        if request.duration > self.max_duration:
            raise ValueError(
                f"Duration {request.duration}s exceeds maximum {self.max_duration}s "
                f"for {self.model_id}/{self.variant}"
            )

        # Check conditioning requirements: each non-None conditioning key
        # must name a known ConditioningType that this model supports.
        for cond_type, cond_data in request.conditioning.items():
            if cond_data is not None:
                try:
                    cond_enum = ConditioningType(cond_type)
                except ValueError:
                    raise ValueError(f"Unknown conditioning type: {cond_type}")

                if cond_enum not in self.supports_conditioning:
                    raise ValueError(
                        f"Model {self.model_id}/{self.variant} does not support "
                        f"{cond_type} conditioning"
                    )

    def __repr__(self) -> str:
        """String representation."""
        loaded = "loaded" if self.is_loaded else "not loaded"
        return f"<{self.__class__.__name__} {self.model_id}/{self.variant} ({loaded})>"
|
||||||
433
src/core/gpu_manager.py
Normal file
433
src/core/gpu_manager.py
Normal file
@@ -0,0 +1,433 @@
|
|||||||
|
"""GPU memory management for AudioCraft models."""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class VRAMBudget:
    """VRAM budget allocation information.

    Attributes:
        total_mb: Total VRAM in megabytes
        used_mb: Currently used VRAM
        free_mb: Free VRAM
        reserved_comfyui_mb: VRAM reserved for ComfyUI
        safety_buffer_mb: Safety buffer to prevent OOM
        available_mb: VRAM available for AudioCraft models
    """

    total_mb: int
    used_mb: int
    free_mb: int
    reserved_comfyui_mb: int
    safety_buffer_mb: int
    available_mb: int

    @property
    def utilization(self) -> float:
        """Current VRAM utilization as a fraction (0-1)."""
        # Guard against a zero-capacity report to avoid division by zero.
        if self.total_mb <= 0:
            return 0.0
        return self.used_mb / self.total_mb
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class GPUState:
    """State information for inter-service coordination.

    One snapshot of a single service's VRAM situation, written to the
    shared coordination file so AudioCraft and ComfyUI can see each
    other's demands.
    """

    timestamp: float  # wall-clock time the snapshot was taken (time.time())
    service: str  # "audiocraft" or "comfyui"
    vram_used_mb: int  # VRAM currently held by the service, in MB
    vram_requested_mb: int  # VRAM the service wants to acquire, in MB
    status: str  # "idle", "working", "requesting_priority", "yielded"
|
||||||
|
|
||||||
|
|
||||||
|
class GPUMemoryManager:
|
||||||
|
"""Manages GPU memory allocation and coordination with ComfyUI.
|
||||||
|
|
||||||
|
Uses pynvml for accurate system-wide VRAM tracking and file-based
|
||||||
|
IPC for coordination with ComfyUI running on the same system.
|
||||||
|
"""
|
||||||
|
|
||||||
|
COORDINATION_FILE = Path("/tmp/audiocraft_comfyui_coord.json")
|
||||||
|
LOCK_FILE = Path("/tmp/audiocraft_comfyui_coord.lock")
|
||||||
|
STALE_THRESHOLD = 30.0 # seconds
|
||||||
|
|
||||||
|
def __init__(
    self,
    device_id: int = 0,
    comfyui_reserve_gb: float = 10.0,
    safety_buffer_gb: float = 1.0,
):
    """Initialize GPU memory manager.

    Args:
        device_id: CUDA device index
        comfyui_reserve_gb: VRAM to reserve for ComfyUI (gigabytes)
        safety_buffer_gb: Safety buffer to prevent OOM (gigabytes)
    """
    self.device_id = device_id
    self.device = torch.device(f"cuda:{device_id}")
    # Store reservations in MB; all internal accounting is MB-based.
    self.comfyui_reserve_mb = int(comfyui_reserve_gb * 1024)
    self.safety_buffer_mb = int(safety_buffer_gb * 1024)

    # Initialize NVML for direct GPU monitoring (optional; falls back
    # to torch.cuda accounting when unavailable).
    self._nvml_initialized = False
    self._nvml_handle = None
    self._init_nvml()

    # Reentrant lock guarding memory queries/updates across threads.
    self._lock = threading.RLock()

    # Callbacks for memory events (invoked elsewhere in this class).
    self._low_memory_callbacks: list[Callable[[VRAMBudget], None]] = []
    self._oom_callbacks: list[Callable[[], None]] = []

    # Create the file-based IPC state shared with ComfyUI if missing.
    self._ensure_coordination_file()
|
||||||
|
|
||||||
|
def _init_nvml(self) -> None:
    """Initialize NVML for GPU monitoring.

    Best-effort: on any failure (missing pynvml package or NVML init
    error) the manager keeps _nvml_initialized False and memory queries
    fall back to the torch.cuda path.
    """
    try:
        import pynvml

        pynvml.nvmlInit()
        self._nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(self.device_id)
        # Only set after both init and handle acquisition succeed.
        self._nvml_initialized = True
        logger.info("NVML initialized successfully")
    except ImportError:
        logger.warning("pynvml not available, falling back to torch.cuda")
    except Exception as e:
        logger.warning(f"Failed to initialize NVML: {e}, falling back to torch.cuda")
|
||||||
|
|
||||||
|
def _ensure_coordination_file(self) -> None:
    """Create coordination file if it doesn't exist.

    Seeds the shared IPC file with an empty state: no per-service
    snapshots and no priority holder.

    NOTE(review): not atomic across processes — two services racing
    here could both see the file as missing; confirm whether the lock
    file is used around this call elsewhere.
    """
    if not self.COORDINATION_FILE.exists():
        initial_state = {
            "audiocraft": None,
            "comfyui": None,
            "priority": None,
            "last_update": time.time(),
        }
        # _write_coordination_state is defined elsewhere in this class
        # (outside this view); presumably serializes state to
        # COORDINATION_FILE as JSON — confirm.
        self._write_coordination_state(initial_state)
|
||||||
|
|
||||||
|
def get_memory_info(self) -> dict[str, int]:
    """Get current GPU memory status.

    Prefers NVML (system-wide figures) and falls back to torch.cuda
    (process-local figures) when NVML is unavailable.

    Returns:
        Dictionary with memory values in megabytes:
        - total: Total VRAM
        - used: Used VRAM (system-wide)
        - free: Free VRAM
        - torch_allocated: PyTorch allocated memory
        - torch_reserved: PyTorch reserved memory
        - torch_cached: PyTorch cached memory
    """
    with self._lock:
        reader = (
            self._get_memory_info_nvml
            if self._nvml_initialized
            else self._get_memory_info_torch
        )
        return reader()
|
||||||
|
|
||||||
|
def _get_memory_info_nvml(self) -> dict[str, int]:
    """Get memory info using NVML (more accurate).

    NVML reports device-wide totals (all processes, including ComfyUI),
    while the torch counters cover only this process. All values are
    converted from bytes to megabytes.
    """
    import pynvml

    info = pynvml.nvmlDeviceGetMemoryInfo(self._nvml_handle)
    torch_allocated = torch.cuda.memory_allocated(self.device)
    torch_reserved = torch.cuda.memory_reserved(self.device)

    return {
        "total": info.total // (1024 * 1024),
        "used": info.used // (1024 * 1024),
        "free": info.free // (1024 * 1024),
        "torch_allocated": torch_allocated // (1024 * 1024),
        "torch_reserved": torch_reserved // (1024 * 1024),
        # Reserved-but-unallocated memory torch could return via empty_cache().
        "torch_cached": (torch_reserved - torch_allocated) // (1024 * 1024),
    }
|
||||||
|
|
||||||
|
def _get_memory_info_torch(self) -> dict[str, int]:
    """Get memory info using torch.cuda (fallback).

    Unlike NVML, torch only sees this process's allocations, so "used"
    and "free" here ignore other processes (e.g. ComfyUI) on the device.
    All values are converted from bytes to megabytes.
    """
    props = torch.cuda.get_device_properties(self.device)
    allocated = torch.cuda.memory_allocated(self.device)
    reserved = torch.cuda.memory_reserved(self.device)

    # Note: This is less accurate for system-wide usage
    return {
        "total": props.total_memory // (1024 * 1024),
        # "used"/"free" are derived from this process's reserved pool only.
        "used": reserved // (1024 * 1024),
        "free": (props.total_memory - reserved) // (1024 * 1024),
        "torch_allocated": allocated // (1024 * 1024),
        "torch_reserved": reserved // (1024 * 1024),
        "torch_cached": (reserved - allocated) // (1024 * 1024),
    }
|
||||||
|
|
||||||
|
def get_available_budget(self) -> VRAMBudget:
    """Calculate available VRAM budget considering ComfyUI.

    Reserve policy: if ComfyUI has a fresh heartbeat and has not yielded,
    reserve the larger of the configured static reserve and ComfyUI's
    reported usage plus 2GB headroom; otherwise just the static reserve.

    Returns:
        VRAMBudget with current allocation information
    """
    mem = self.get_memory_info()

    # Check ComfyUI's actual usage via coordination file
    comfyui_state = self.get_comfyui_status()
    if comfyui_state and comfyui_state.status != "yielded":
        # Use actual ComfyUI usage + buffer, or reserve, whichever is higher
        effective_comfyui_reserve = max(
            self.comfyui_reserve_mb,
            comfyui_state.vram_used_mb + 2048,  # 2GB headroom
        )
    else:
        effective_comfyui_reserve = self.comfyui_reserve_mb

    # Available = total - system-wide used, with our own torch allocations
    # added back (we may replace our own models in place), minus the
    # ComfyUI reserve and the OOM safety buffer. Clamped at zero.
    available = max(
        0,
        mem["total"]
        - mem["used"]
        + mem["torch_allocated"]  # Our own usage doesn't count against us
        - effective_comfyui_reserve
        - self.safety_buffer_mb,
    )

    return VRAMBudget(
        total_mb=mem["total"],
        used_mb=mem["used"],
        free_mb=mem["free"],
        reserved_comfyui_mb=effective_comfyui_reserve,
        safety_buffer_mb=self.safety_buffer_mb,
        available_mb=available,
    )
|
||||||
|
|
||||||
|
def can_load_model(self, vram_required_mb: int) -> tuple[bool, str]:
    """Check whether a model of the given size fits in the current budget.

    Args:
        vram_required_mb: VRAM needed by the model

    Returns:
        Tuple of (can_load, reason_message)
    """
    available = self.get_available_budget().available_mb
    shortfall = vram_required_mb - available
    if shortfall <= 0:
        return True, "Sufficient VRAM available"

    return False, (
        f"Insufficient VRAM: need {vram_required_mb}MB, "
        f"available {available}MB (deficit: {shortfall}MB)"
    )
|
||||||
|
|
||||||
|
def force_cleanup(self) -> int:
    """Force GPU memory cleanup.

    Runs a Python GC pass (to drop lingering tensor references), releases
    torch's cached blocks back to the driver, and synchronizes the device
    so the release has taken effect before re-measuring.

    Returns:
        Freed memory in megabytes (approximate — measured as the drop in
        torch's reserved pool, so it may be 0 if nothing was releasable)
    """
    with self._lock:
        before = self.get_memory_info()

        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize(self.device)

        after = self.get_memory_info()
        freed = before["torch_reserved"] - after["torch_reserved"]

        if freed > 0:
            logger.info(f"Freed {freed}MB of GPU memory")

        return freed
|
||||||
|
|
||||||
|
def get_status(self) -> dict[str, Any]:
    """Get detailed GPU status for UI display.

    Returns:
        Dictionary of human-friendly values (GB, rounded to 2 decimals;
        utilization as a percentage rounded to 1 decimal).
    """
    mem = self.get_memory_info()
    budget = self.get_available_budget()

    def to_gb(mb: int) -> float:
        # Convert megabytes to gigabytes for display.
        return round(mb / 1024, 2)

    status = {"device": str(self.device)}
    status["total_gb"] = to_gb(mem["total"])
    status["used_gb"] = to_gb(mem["used"])
    status["free_gb"] = to_gb(mem["free"])
    status["utilization_percent"] = round(budget.utilization * 100, 1)
    status["available_for_models_gb"] = to_gb(budget.available_mb)
    status["comfyui_reserve_gb"] = to_gb(budget.reserved_comfyui_mb)
    status["torch_allocated_gb"] = to_gb(mem["torch_allocated"])
    status["torch_cached_gb"] = to_gb(mem["torch_cached"])
    return status
|
||||||
|
|
||||||
|
# ComfyUI Coordination Methods
|
||||||
|
|
||||||
|
def _read_coordination_state(self) -> dict[str, Any]:
    """Read the shared coordination state; empty dict on any failure.

    A missing file, unreadable file, or corrupt JSON all degrade to an
    empty state rather than raising — coordination is best-effort.
    """
    path = self.COORDINATION_FILE
    try:
        if not path.exists():
            return {}
        return json.loads(path.read_text())
    except (json.JSONDecodeError, IOError) as e:
        logger.warning(f"Failed to read coordination file: {e}")
        return {}
|
||||||
|
|
||||||
|
def _write_coordination_state(self, state: dict[str, Any]) -> None:
    """Write coordination state to file with locking.

    Serializes ``state`` to COORDINATION_FILE while holding an exclusive
    flock on a separate LOCK_FILE so concurrent writers (this process and
    ComfyUI) do not interleave writes. POSIX-only due to fcntl. Failures
    are logged, not raised — coordination is best-effort.
    """
    import fcntl  # POSIX-only; imported lazily so the module imports elsewhere

    try:
        self.LOCK_FILE.parent.mkdir(parents=True, exist_ok=True)
        with open(self.LOCK_FILE, "w") as lock:
            fcntl.flock(lock, fcntl.LOCK_EX)
            try:
                self.COORDINATION_FILE.write_text(json.dumps(state, indent=2))
            finally:
                # Unlock explicitly even though closing the file would also
                # release the lock — keeps the critical section obvious.
                fcntl.flock(lock, fcntl.LOCK_UN)
    except IOError as e:
        logger.warning(f"Failed to write coordination file: {e}")
|
||||||
|
|
||||||
|
def update_status(
    self,
    vram_used_mb: int,
    vram_requested_mb: int = 0,
    status: str = "idle",
) -> None:
    """Publish AudioCraft's side of the coordination handshake.

    Args:
        vram_used_mb: Current VRAM usage in MB.
        vram_requested_mb: VRAM needed for a pending operation, in MB.
        status: Current status ("idle", "working", "requesting_priority").
    """
    state = self._read_coordination_state()
    entry = {
        "timestamp": time.time(),
        "service": "audiocraft",
        "vram_used_mb": vram_used_mb,
        "vram_requested_mb": vram_requested_mb,
        "status": status,
    }
    state["audiocraft"] = entry
    state["last_update"] = time.time()
    self._write_coordination_state(state)
|
||||||
|
|
||||||
|
def get_comfyui_status(self) -> Optional[GPUState]:
    """Get ComfyUI's current status.

    Returns:
        GPUState if ComfyUI is active and status is fresh, None otherwise
        (no entry, or heartbeat older than STALE_THRESHOLD seconds).
    """
    state = self._read_coordination_state()
    comfyui_data = state.get("comfyui")

    if not comfyui_data:
        return None

    # Check if stale — a crashed/stopped ComfyUI must not pin its reserve.
    if time.time() - comfyui_data.get("timestamp", 0) > self.STALE_THRESHOLD:
        return None

    return GPUState(
        timestamp=comfyui_data["timestamp"],
        service="comfyui",
        vram_used_mb=comfyui_data.get("vram_used_mb", 0),
        vram_requested_mb=comfyui_data.get("vram_requested_mb", 0),
        status=comfyui_data.get("status", "unknown"),
    )
|
||||||
|
|
||||||
|
def request_priority(self, vram_needed_mb: int, timeout: float = 30.0) -> bool:
    """Request VRAM priority from ComfyUI.

    Signals ComfyUI to release VRAM if possible.

    Writes a priority request into the coordination file, then polls
    every 0.5s (up to ``timeout``) for ComfyUI to report "yielded".
    Blocking — do not call from a latency-sensitive thread.

    Args:
        vram_needed_mb: Amount of VRAM needed
        timeout: Seconds to wait for ComfyUI to yield

    Returns:
        True if ComfyUI acknowledged and yielded, False otherwise
    """
    state = self._read_coordination_state()
    state["priority"] = {
        "requester": "audiocraft",
        "vram_needed_mb": vram_needed_mb,
        "timestamp": time.time(),
    }
    self._write_coordination_state(state)

    logger.info(f"Requesting {vram_needed_mb}MB VRAM from ComfyUI...")

    # Wait for ComfyUI to respond
    start = time.time()
    while time.time() - start < timeout:
        comfyui = self.get_comfyui_status()
        if comfyui and comfyui.status == "yielded":
            logger.info("ComfyUI yielded VRAM")
            return True
        time.sleep(0.5)

    logger.warning("ComfyUI did not yield VRAM within timeout")
    return False
|
||||||
|
|
||||||
|
def is_comfyui_busy(self) -> bool:
    """Check if ComfyUI is actively processing.

    Returns:
        True if ComfyUI has a fresh heartbeat with status "working".
    """
    peer = self.get_comfyui_status()
    if peer is None:
        return False
    return peer.status == "working"
|
||||||
|
|
||||||
|
# Callback Registration
|
||||||
|
|
||||||
|
def on_low_memory(self, callback: Callable[[VRAMBudget], None]) -> None:
    """Register callback for low memory warnings.

    Callbacks are invoked in registration order by check_memory_pressure
    when utilization crosses its warning threshold.

    Args:
        callback: Function to call with budget info when memory is low
    """
    self._low_memory_callbacks.append(callback)
|
||||||
|
|
||||||
|
def on_oom(self, callback: Callable[[], None]) -> None:
    """Register callback for OOM events.

    NOTE(review): nothing in this file's visible code iterates
    _oom_callbacks — confirm the OOM path actually fires these.

    Args:
        callback: Function to call when OOM occurs
    """
    self._oom_callbacks.append(callback)
|
||||||
|
|
||||||
|
def check_memory_pressure(self, warning_threshold: float = 0.85) -> None:
    """Check memory pressure and trigger low-memory callbacks if needed.

    A failing callback is logged and skipped so one bad subscriber
    cannot block the others.

    Args:
        warning_threshold: Utilization threshold for warnings (0-1)
    """
    budget = self.get_available_budget()
    if budget.utilization < warning_threshold:
        return

    logger.warning(
        f"High GPU memory pressure: {budget.utilization*100:.1f}% utilized"
    )
    for cb in self._low_memory_callbacks:
        try:
            cb(budget)
        except Exception as e:
            logger.error(f"Low memory callback failed: {e}")
|
||||||
|
|
||||||
|
def __del__(self) -> None:
    """Cleanup NVML on destruction.

    Uses getattr because __del__ can run on a partially constructed
    instance: __init__ assigns _nvml_initialized only after creating the
    torch device, so if that raised, a plain attribute access here would
    itself raise AttributeError inside the finalizer.
    """
    if getattr(self, "_nvml_initialized", False):
        try:
            import pynvml

            pynvml.nvmlShutdown()
        except Exception:
            # Best-effort teardown; errors during finalization are ignored.
            pass
|
||||||
487
src/core/model_registry.py
Normal file
487
src/core/model_registry.py
Normal file
@@ -0,0 +1,487 @@
|
|||||||
|
"""Model registry for discovering and managing AudioCraft model adapters."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Generator, Optional, Type
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from src.core.base_model import BaseAudioModel, ConditioningType
|
||||||
|
from src.core.gpu_manager import GPUMemoryManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ModelVariantConfig:
    """Configuration for a model variant (one size/checkpoint of a family)."""

    hf_id: str  # Hugging Face model identifier for this checkpoint
    vram_mb: int  # Approximate VRAM footprint; used for admission/eviction planning
    max_duration: float = 30.0  # Longest generation this variant supports (seconds)
    channels: int = 1  # Output channel count
    conditioning: list[str] = field(default_factory=list)  # Supported conditioning input names
    description: str = ""  # Human-readable blurb for UI display
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ModelFamilyConfig:
    """Configuration for a model family (one top-level entry in the models YAML)."""

    enabled: bool  # Disabled families are skipped at config-load time
    display_name: str  # Name shown in UI listings
    description: str  # Family-level description; a variant's own description takes precedence
    default_variant: str  # Variant used when callers don't specify one
    variants: dict[str, ModelVariantConfig]  # Variant name -> its configuration
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ModelHandle:
    """Handle for a loaded model with reference counting.

    Tracked by ModelRegistry: ref_count > 0 marks the model as in use and
    blocks unloading; last_accessed drives LRU eviction and idle timeout.
    """

    model: BaseAudioModel  # The loaded adapter instance
    model_id: str  # Model family ID
    variant: str  # Variant name within the family
    loaded_at: float  # Epoch seconds when the model finished loading
    last_accessed: float  # Epoch seconds of most recent use
    ref_count: int = 0  # Number of callers currently inside get_model()

    def touch(self) -> None:
        """Update last accessed time."""
        self.last_accessed = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRegistry:
|
||||||
|
"""Central registry for discovering and managing model adapters.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
- Loading model configurations from YAML
|
||||||
|
- Lazy loading models on demand
|
||||||
|
- LRU eviction when VRAM is constrained
|
||||||
|
- Reference counting to prevent unloading during use
|
||||||
|
- Automatic idle timeout for unused models
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
    self,
    config_path: Path,
    gpu_manager: GPUMemoryManager,
    max_cached_models: int = 2,
    idle_timeout_minutes: int = 15,
):
    """Initialize the model registry.

    Configuration is loaded eagerly at the end of construction; the
    cleanup thread must be started separately via start_cleanup_thread().

    Args:
        config_path: Path to models.yaml configuration
        gpu_manager: GPU memory manager instance
        max_cached_models: Maximum models to keep loaded
        idle_timeout_minutes: Unload models after this idle time
    """
    self.config_path = config_path
    self.gpu_manager = gpu_manager
    self.max_cached_models = max_cached_models
    self.idle_timeout_seconds = idle_timeout_minutes * 60

    # Model configurations
    self._model_configs: dict[str, ModelFamilyConfig] = {}
    self._default_params: dict[str, Any] = {}

    # Loaded model handles
    self._handles: dict[str, ModelHandle] = {}  # Key: "model_id/variant"
    # Front of the list = least recently used (eviction scans from index 0).
    self._access_order: list[str] = []  # LRU tracking

    # Registered adapter classes
    self._adapter_classes: dict[str, Type[BaseAudioModel]] = {}

    # Threading
    self._lock = threading.RLock()
    self._cleanup_thread: Optional[threading.Thread] = None
    self._stop_cleanup = threading.Event()

    # Load configuration
    self._load_config()
|
||||||
|
|
||||||
|
def _load_config(self) -> None:
    """Load model configurations from YAML file.

    A missing config file is non-fatal (the registry starts empty).
    Families with enabled: false are skipped entirely. Each variant must
    define hf_id and vram_mb — a malformed entry raises KeyError here.
    """
    if not self.config_path.exists():
        logger.warning(f"Model config not found: {self.config_path}")
        return

    with open(self.config_path) as f:
        config = yaml.safe_load(f)

    # Parse model families
    for model_id, model_config in config.get("models", {}).items():
        if not model_config.get("enabled", True):
            continue

        variants = {}
        for variant_name, variant_config in model_config.get("variants", {}).items():
            variants[variant_name] = ModelVariantConfig(
                hf_id=variant_config["hf_id"],
                vram_mb=variant_config["vram_mb"],
                max_duration=variant_config.get("max_duration", 30.0),
                channels=variant_config.get("channels", 1),
                conditioning=variant_config.get("conditioning", []),
                description=variant_config.get("description", ""),
            )

        self._model_configs[model_id] = ModelFamilyConfig(
            enabled=model_config.get("enabled", True),
            display_name=model_config.get("display_name", model_id),
            description=model_config.get("description", ""),
            default_variant=model_config.get("default_variant", "medium"),
            variants=variants,
        )

    # Parse default generation parameters
    self._default_params = config.get("defaults", {}).get("generation", {})

    logger.info(f"Loaded {len(self._model_configs)} model families from config")
|
||||||
|
|
||||||
|
def register_adapter(
    self, model_id: str, adapter_class: Type[BaseAudioModel]
) -> None:
    """Register a model adapter class.

    Registration is what makes a configured family loadable — _load_model
    raises ValueError for families with no registered adapter.
    Re-registering a model_id overwrites the previous class.

    Args:
        model_id: Model family ID (e.g., 'musicgen')
        adapter_class: Adapter class implementing BaseAudioModel
    """
    self._adapter_classes[model_id] = adapter_class
    logger.debug(f"Registered adapter for {model_id}: {adapter_class.__name__}")
|
||||||
|
|
||||||
|
def list_models(self) -> list[dict[str, Any]]:
    """List all available models with their configurations.

    One entry per (family, variant) pair, including live state such as
    whether the variant is currently loaded and whether it would fit in
    the present VRAM budget.

    NOTE(review): reads self._handles without holding self._lock —
    confirm the benign-race is acceptable for UI listing.

    Returns:
        List of model information dictionaries
    """
    models = []

    for model_id, config in self._model_configs.items():
        for variant_name, variant in config.variants.items():
            key = f"{model_id}/{variant_name}"
            handle = self._handles.get(key)

            # Snapshot of whether this variant would fit right now.
            can_load, reason = self.gpu_manager.can_load_model(variant.vram_mb)

            models.append({
                "model_id": model_id,
                "variant": variant_name,
                "display_name": config.display_name,
                # Variant description wins over the family description.
                "description": variant.description or config.description,
                "hf_id": variant.hf_id,
                "vram_mb": variant.vram_mb,
                "max_duration": variant.max_duration,
                "channels": variant.channels,
                "conditioning": variant.conditioning,
                "is_default": variant_name == config.default_variant,
                "is_loaded": handle is not None,
                "can_load": can_load,
                "load_reason": reason,
                "has_adapter": model_id in self._adapter_classes,
            })

    return models
|
||||||
|
|
||||||
|
def get_model_config(
    self, model_id: str, variant: Optional[str] = None
) -> tuple[ModelFamilyConfig, ModelVariantConfig]:
    """Look up the configuration for a model family and one of its variants.

    Args:
        model_id: Model family ID
        variant: Specific variant, or None for the family's default

    Returns:
        Tuple of (family_config, variant_config)

    Raises:
        ValueError: If model or variant not found
    """
    family = self._model_configs.get(model_id)
    if family is None:
        raise ValueError(f"Unknown model: {model_id}")

    chosen = variant if variant else family.default_variant
    variant_config = family.variants.get(chosen)
    if variant_config is None:
        raise ValueError(f"Unknown variant {chosen} for {model_id}")

    return family, variant_config
|
||||||
|
|
||||||
|
def get_loaded_models(self) -> list[dict[str, Any]]:
    """Report metadata for every model currently resident in memory.

    Returns:
        List of loaded model information dictionaries.
    """
    with self._lock:
        report = []
        for handle in self._handles.values():
            report.append({
                "model_id": handle.model_id,
                "variant": handle.variant,
                "loaded_at": handle.loaded_at,
                "last_accessed": handle.last_accessed,
                "ref_count": handle.ref_count,
                "idle_seconds": time.time() - handle.last_accessed,
            })
        return report
|
||||||
|
|
||||||
|
@contextmanager
def get_model(
    self, model_id: str, variant: Optional[str] = None
) -> Generator[BaseAudioModel, None, None]:
    """Get a model, loading it if necessary.

    Context manager that handles reference counting to prevent
    unloading during use. The lock is held only for the bookkeeping at
    entry and exit — NOT while the caller uses the model — so long
    generations do not serialize registry access.

    Args:
        model_id: Model family ID
        variant: Specific variant, or None for default

    Yields:
        Loaded model instance

    Raises:
        ValueError: If model not found or cannot be loaded
        RuntimeError: If VRAM insufficient
    """
    family, variant_config = self.get_model_config(model_id, variant)
    variant = variant or family.default_variant
    key = f"{model_id}/{variant}"

    with self._lock:
        # Get or load model
        if key not in self._handles:
            self._load_model(model_id, variant)

        handle = self._handles[key]
        # Pin the model: nonzero ref_count blocks eviction and idle unload.
        handle.ref_count += 1
        handle.touch()

        # Update LRU order (move key to the most-recently-used end).
        if key in self._access_order:
            self._access_order.remove(key)
        self._access_order.append(key)

    try:
        yield handle.model
    finally:
        # Re-acquire the lock to unpin; runs even if the caller raised.
        with self._lock:
            handle.ref_count -= 1
|
||||||
|
|
||||||
|
def _load_model(self, model_id: str, variant: str) -> None:
    """Load a model into memory.

    Must be called with self._lock held.

    If the VRAM budget is insufficient, idle models are evicted once and
    the budget is re-checked before giving up.

    Args:
        model_id: Model family ID
        variant: Variant to load

    Raises:
        ValueError: If no adapter registered
        RuntimeError: If VRAM insufficient even after eviction
    """
    key = f"{model_id}/{variant}"
    family, variant_config = self.get_model_config(model_id, variant)

    # Check for adapter
    if model_id not in self._adapter_classes:
        raise ValueError(f"No adapter registered for {model_id}")

    # Check VRAM
    can_load, reason = self.gpu_manager.can_load_model(variant_config.vram_mb)
    if not can_load:
        # Try to free memory by evicting models
        self._evict_for_space(variant_config.vram_mb)
        can_load, reason = self.gpu_manager.can_load_model(variant_config.vram_mb)
        if not can_load:
            raise RuntimeError(reason)

    # Create and load model
    logger.info(f"Loading model {key}...")
    adapter_class = self._adapter_classes[model_id]
    model = adapter_class(variant=variant)
    model.load()

    # Register handle
    self._handles[key] = ModelHandle(
        model=model,
        model_id=model_id,
        variant=variant,
        loaded_at=time.time(),
        last_accessed=time.time(),
    )
    self._access_order.append(key)

    # Update GPU status so the ComfyUI coordination file reflects our usage.
    mem = self.gpu_manager.get_memory_info()
    self.gpu_manager.update_status(mem["torch_allocated"], status="working")

    logger.info(f"Model {key} loaded successfully")
|
||||||
|
|
||||||
|
def _evict_for_space(self, needed_mb: int) -> bool:
    """Evict idle models (LRU order) until ``needed_mb`` fits in VRAM.

    Must be called with self._lock held.

    Args:
        needed_mb: VRAM needed, in megabytes

    Returns:
        True if enough space is available after eviction
    """
    budget = self.gpu_manager.get_available_budget()
    deficit = needed_mb - budget.available_mb
    if deficit <= 0:
        return True  # Already enough room; nothing to evict.

    # Walk models least-recently-used first; only models with no active
    # references are safe to drop. `deficit` tracks each variant's
    # *configured* VRAM footprint as a running estimate of what eviction
    # will free.
    for key in list(self._access_order):
        if deficit <= 0:
            break

        handle = self._handles.get(key)
        if handle and handle.ref_count == 0:
            _, variant_config = self.get_model_config(
                handle.model_id, handle.variant
            )
            logger.info(f"Evicting {key} to free {variant_config.vram_mb}MB")
            self._unload_model(key)
            deficit -= variant_config.vram_mb

    # Verify against the *actual* budget after cleanup rather than the
    # config-based estimate: configured vram_mb values are approximate,
    # so the estimate can over- or under-state what eviction really freed.
    # (Also removes the original's `freed` accumulator, which was never read.)
    self.gpu_manager.force_cleanup()
    return self.gpu_manager.get_available_budget().available_mb >= needed_mb
|
||||||
|
|
||||||
|
def _unload_model(self, key: str) -> None:
    """Unload one model and drop its bookkeeping.

    Must be called with self._lock held. Silently no-ops for unknown
    keys; refuses (with a warning) while the model is still referenced.

    Args:
        key: Model key (model_id/variant)
    """
    handle = self._handles.get(key)
    if handle is None:
        return

    if handle.ref_count > 0:
        logger.warning(f"Cannot unload {key}: {handle.ref_count} active references")
        return

    logger.info(f"Unloading model {key}")
    handle.model.unload()
    del self._handles[key]

    if key in self._access_order:
        self._access_order.remove(key)

    # Hand the freed memory back to the driver right away.
    self.gpu_manager.force_cleanup()
|
||||||
|
|
||||||
|
def unload_model(self, model_id: str, variant: Optional[str] = None) -> bool:
    """Manually unload a model.

    Args:
        model_id: Model family ID
        variant: Variant to unload, or None for all variants of the family

    Returns:
        True if at least one model was unloaded
    """
    with self._lock:
        if variant:
            key = f"{model_id}/{variant}"
            if key not in self._handles:
                return False
            self._unload_model(key)
            return True

        # No variant given: unload every loaded variant of this family.
        prefix = f"{model_id}/"
        matches = [k for k in self._handles if k.startswith(prefix)]
        for key in matches:
            self._unload_model(key)
        return bool(matches)
|
||||||
|
|
||||||
|
def preload_model(self, model_id: str, variant: Optional[str] = None) -> bool:
    """Preload a model into memory ahead of first use.

    Args:
        model_id: Model family ID
        variant: Variant to load, or None for the family's default

    Returns:
        True if the model is loaded (already resident or freshly loaded)
    """
    family, _ = self.get_model_config(model_id, variant)
    chosen = variant or family.default_variant
    key = f"{model_id}/{chosen}"

    with self._lock:
        if key in self._handles:
            return True  # Already loaded

        try:
            self._load_model(model_id, chosen)
        except Exception as e:
            logger.error(f"Failed to preload {key}: {e}")
            return False
        return True
|
||||||
|
|
||||||
|
def start_cleanup_thread(self) -> None:
    """Start background thread for idle model cleanup.

    Idempotent: a second call while the thread exists is a no-op. The
    thread is a daemon so it never blocks interpreter shutdown.
    """
    if self._cleanup_thread is not None:
        return

    def cleanup_loop():
        # Event.wait doubles as an interruptible sleep so that
        # stop_cleanup_thread() can wake the loop immediately.
        while not self._stop_cleanup.is_set():
            self._cleanup_idle_models()
            self._stop_cleanup.wait(60)  # Check every minute

    self._cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True)
    self._cleanup_thread.start()
    logger.info("Started model cleanup thread")
|
||||||
|
|
||||||
|
def stop_cleanup_thread(self) -> None:
    """Stop the background cleanup thread, waiting up to 5s for it to exit."""
    worker = self._cleanup_thread
    if worker is None:
        return

    self._stop_cleanup.set()
    worker.join(timeout=5)
    self._cleanup_thread = None
    # Reset the event so a later start_cleanup_thread() call works again.
    self._stop_cleanup.clear()
|
||||||
|
|
||||||
|
def _cleanup_idle_models(self) -> None:
    """Unload models that have been idle too long.

    Models with active references are skipped. Iterates over a list()
    snapshot because _unload_model mutates self._handles mid-loop.
    """
    with self._lock:
        now = time.time()
        for key, handle in list(self._handles.items()):
            idle_time = now - handle.last_accessed
            if idle_time > self.idle_timeout_seconds and handle.ref_count == 0:
                logger.info(
                    f"Unloading idle model {key} (idle for {idle_time/60:.1f} min)"
                )
                self._unload_model(key)
|
||||||
|
|
||||||
|
def get_default_params(self) -> dict[str, Any]:
    """Get default generation parameters.

    Returns:
        A shallow copy of the defaults, so callers can mutate it freely.
    """
    return dict(self._default_params)
|
||||||
|
|
||||||
|
def __del__(self) -> None:
    """Best-effort cleanup on destruction.

    __del__ can run on a partially constructed instance (if __init__
    raised before the threading attributes were assigned), in which case
    stop_cleanup_thread() would raise AttributeError; exceptions escaping
    a finalizer are only printed to stderr, so swallow them instead.
    """
    try:
        self.stop_cleanup_thread()
    except Exception:
        pass
|
||||||
297
src/core/oom_handler.py
Normal file
297
src/core/oom_handler.py
Normal file
@@ -0,0 +1,297 @@
|
|||||||
|
"""OOM (Out of Memory) handling and recovery strategies."""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import gc
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any, Callable, Optional, ParamSpec, TypeVar
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from src.core.gpu_manager import GPUMemoryManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
|
class OOMRecoveryError(Exception):
    """Raised when OOM recovery fails after all recovery strategies are exhausted."""
|
||||||
|
|
||||||
|
|
||||||
|
class OOMHandler:
|
||||||
|
"""Handles CUDA Out of Memory errors with multi-level recovery strategies.
|
||||||
|
|
||||||
|
Recovery levels:
|
||||||
|
1. Clear PyTorch CUDA cache
|
||||||
|
2. Evict unused models from registry
|
||||||
|
3. Request ComfyUI to yield VRAM
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
    self,
    gpu_manager: GPUMemoryManager,
    model_registry: Optional[Any] = None,  # Avoid circular import
    max_retries: int = 3,
    retry_delay: float = 0.5,
):
    """Initialize OOM handler.

    The registry is optional here so gpu/oom modules can be constructed
    before the registry exists; wire it up later via set_model_registry().

    Args:
        gpu_manager: GPU memory manager instance
        model_registry: Optional model registry for eviction
        max_retries: Maximum recovery attempts
        retry_delay: Delay between retries in seconds
    """
    self.gpu_manager = gpu_manager
    self.model_registry = model_registry
    self.max_retries = max_retries
    self.retry_delay = retry_delay

    # Track OOM events for monitoring
    self._oom_count = 0
    self._last_oom_time: Optional[float] = None  # Epoch seconds of last OOM, if any
|
||||||
|
|
||||||
|
@property
def oom_count(self) -> int:
    """Number of CUDA OOM events this handler has caught since construction."""
    return self._oom_count
|
||||||
|
|
||||||
|
def set_model_registry(self, registry: Any) -> None:
    """Set model registry (to avoid circular import at init time).

    Enables the model-eviction recovery strategy once the registry exists.
    """
    self.model_registry = registry
|
||||||
|
|
||||||
|
def with_oom_recovery(self, func: Callable[P, R]) -> Callable[P, R]:
    """Decorator that wraps function with OOM recovery logic.

    The wrapped function is attempted up to max_retries + 1 times. Each
    CUDA OOM increments the handler's counters and triggers an escalating
    recovery strategy before the next attempt. Non-OOM exceptions
    propagate unchanged.

    Usage:
        @oom_handler.with_oom_recovery
        def generate_audio(...):
            ...

    Args:
        func: Function to wrap

    Returns:
        Wrapped function with OOM recovery

    Raises:
        OOMRecoveryError: (from the wrapper) when every attempt OOMs;
            the last OutOfMemoryError is attached as __cause__.
    """

    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        last_exception = None

        for attempt in range(self.max_retries + 1):
            try:
                if attempt > 0:
                    logger.info(f"Retry attempt {attempt}/{self.max_retries}")
                    # Brief pause so freed memory settles before retrying.
                    time.sleep(self.retry_delay)

                return func(*args, **kwargs)

            except torch.cuda.OutOfMemoryError as e:
                last_exception = e
                self._oom_count += 1
                self._last_oom_time = time.time()

                logger.warning(f"CUDA OOM detected (attempt {attempt + 1}): {e}")

                if attempt < self.max_retries:
                    # Escalate: attempt index selects how far down the
                    # strategy list to go.
                    self._execute_recovery_strategy(attempt)
                else:
                    logger.error(
                        f"OOM recovery failed after {self.max_retries} attempts"
                    )

        # Only reached when every attempt raised OutOfMemoryError.
        raise OOMRecoveryError(
            f"OOM recovery failed after {self.max_retries} attempts"
        ) from last_exception

    return wrapper
|
||||||
|
|
||||||
|
def _execute_recovery_strategy(self, level: int) -> None:
|
||||||
|
"""Execute recovery strategy based on severity level.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: Recovery level (0-2)
|
||||||
|
"""
|
||||||
|
strategies = [
|
||||||
|
self._strategy_clear_cache,
|
||||||
|
self._strategy_evict_models,
|
||||||
|
self._strategy_request_comfyui_yield,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Execute all strategies up to and including current level
|
||||||
|
for i in range(min(level + 1, len(strategies))):
|
||||||
|
logger.info(f"Executing recovery strategy {i + 1}: {strategies[i].__name__}")
|
||||||
|
strategies[i]()
|
||||||
|
|
||||||
|
def _strategy_clear_cache(self) -> None:
|
||||||
|
"""Level 1: Clear PyTorch CUDA cache.
|
||||||
|
|
||||||
|
This is the fastest and least disruptive recovery strategy.
|
||||||
|
Clears cached memory that PyTorch holds for future allocations.
|
||||||
|
"""
|
||||||
|
logger.info("Clearing CUDA cache...")
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Reset peak memory stats for monitoring
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
|
freed = self.gpu_manager.force_cleanup()
|
||||||
|
logger.info(f"Cache cleared, freed approximately {freed}MB")
|
||||||
|
|
||||||
|
def _strategy_evict_models(self) -> None:
|
||||||
|
"""Level 2: Evict non-essential models from registry.
|
||||||
|
|
||||||
|
Unloads all models that don't have active references,
|
||||||
|
freeing their VRAM for the current operation.
|
||||||
|
"""
|
||||||
|
if self.model_registry is None:
|
||||||
|
logger.warning("No model registry available for eviction")
|
||||||
|
self._strategy_clear_cache()
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Evicting unused models...")
|
||||||
|
|
||||||
|
# Get list of loaded models
|
||||||
|
loaded = self.model_registry.get_loaded_models()
|
||||||
|
evicted = []
|
||||||
|
|
||||||
|
for model_info in loaded:
|
||||||
|
# Only evict models with no active references
|
||||||
|
if model_info["ref_count"] == 0:
|
||||||
|
model_id = model_info["model_id"]
|
||||||
|
variant = model_info["variant"]
|
||||||
|
logger.info(f"Evicting {model_id}/{variant}")
|
||||||
|
self.model_registry.unload_model(model_id, variant)
|
||||||
|
evicted.append(f"{model_id}/{variant}")
|
||||||
|
|
||||||
|
# Clear cache after eviction
|
||||||
|
self._strategy_clear_cache()
|
||||||
|
|
||||||
|
logger.info(f"Evicted {len(evicted)} model(s): {evicted}")
|
||||||
|
|
||||||
|
def _strategy_request_comfyui_yield(self) -> None:
|
||||||
|
"""Level 3: Request ComfyUI to yield VRAM.
|
||||||
|
|
||||||
|
Uses the coordination protocol to ask ComfyUI to
|
||||||
|
temporarily release GPU memory.
|
||||||
|
"""
|
||||||
|
logger.info("Requesting ComfyUI to yield VRAM...")
|
||||||
|
|
||||||
|
# First, evict our own models
|
||||||
|
self._strategy_evict_models()
|
||||||
|
|
||||||
|
# Calculate how much VRAM we need
|
||||||
|
budget = self.gpu_manager.get_available_budget()
|
||||||
|
needed = max(4096, budget.total_mb // 4) # Request at least 4GB or 25% of total
|
||||||
|
|
||||||
|
# Request priority from ComfyUI
|
||||||
|
success = self.gpu_manager.request_priority(needed, timeout=15.0)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
logger.info("ComfyUI yielded VRAM successfully")
|
||||||
|
else:
|
||||||
|
logger.warning("ComfyUI did not yield VRAM within timeout")
|
||||||
|
|
||||||
|
# Final cache clear
|
||||||
|
self._strategy_clear_cache()
|
||||||
|
|
||||||
|
def recover_from_oom(self, level: int = 0) -> bool:
|
||||||
|
"""Manually trigger OOM recovery.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
level: Recovery level to execute (0-2)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if recovery was successful (memory was freed)
|
||||||
|
"""
|
||||||
|
before = self.gpu_manager.get_memory_info()
|
||||||
|
|
||||||
|
self._execute_recovery_strategy(level)
|
||||||
|
|
||||||
|
after = self.gpu_manager.get_memory_info()
|
||||||
|
freed = before["used"] - after["used"]
|
||||||
|
|
||||||
|
logger.info(f"Manual recovery freed {freed}MB")
|
||||||
|
return freed > 0
|
||||||
|
|
||||||
|
def check_memory_for_operation(self, required_mb: int) -> bool:
|
||||||
|
"""Check if there's enough memory for an operation.
|
||||||
|
|
||||||
|
If not enough, attempts recovery strategies.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
required_mb: Memory required in megabytes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if enough memory is available (possibly after recovery)
|
||||||
|
"""
|
||||||
|
budget = self.gpu_manager.get_available_budget()
|
||||||
|
|
||||||
|
if budget.available_mb >= required_mb:
|
||||||
|
return True
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Need {required_mb}MB but only {budget.available_mb}MB available. "
|
||||||
|
"Attempting recovery..."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try progressively more aggressive recovery
|
||||||
|
for level in range(3):
|
||||||
|
self._execute_recovery_strategy(level)
|
||||||
|
budget = self.gpu_manager.get_available_budget()
|
||||||
|
|
||||||
|
if budget.available_mb >= required_mb:
|
||||||
|
logger.info(f"Recovery successful at level {level + 1}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
logger.error(
|
||||||
|
f"Could not free enough memory. Need {required_mb}MB, "
|
||||||
|
f"have {budget.available_mb}MB"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_stats(self) -> dict[str, Any]:
|
||||||
|
"""Get OOM handling statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with OOM stats
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"oom_count": self._oom_count,
|
||||||
|
"last_oom_time": self._last_oom_time,
|
||||||
|
"max_retries": self.max_retries,
|
||||||
|
"has_registry": self.model_registry is not None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level convenience function
|
||||||
|
def oom_safe(
    gpu_manager: GPUMemoryManager,
    model_registry: Optional[Any] = None,
    max_retries: int = 3,
    retry_delay: float = 0.5,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Decorator factory for OOM-safe functions.

    Usage:
        @oom_safe(gpu_manager, model_registry)
        def generate_audio(...):
            ...

    Args:
        gpu_manager: GPU memory manager
        model_registry: Optional model registry for eviction
        max_retries: Maximum recovery attempts
        retry_delay: Delay between retries in seconds. Forwarded to
            OOMHandler; the default matches OOMHandler's own default, so
            existing callers are unaffected.

    Returns:
        Decorator function
    """
    # Fix: previously retry_delay could not be configured through this
    # convenience factory even though OOMHandler supports it.
    handler = OOMHandler(gpu_manager, model_registry, max_retries, retry_delay)
    return handler.with_oom_recovery
|
||||||
84
src/main.py
Normal file
84
src/main.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""AudioCraft Studio - Main Application Entry Point."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add project root to path for imports
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
from config.settings import get_settings
|
||||||
|
from src.core.gpu_manager import GPUMemoryManager
|
||||||
|
from src.core.model_registry import ModelRegistry
|
||||||
|
from src.storage.database import Database
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def init_app():
    """Initialize application components.

    Configures logging, creates required directories, then builds the GPU
    manager, model registry, and database connection in dependency order.

    Returns:
        Dict with keys "settings", "gpu_manager", "registry", "database".
        The caller is responsible for closing the database when done.
    """
    settings = get_settings()

    # Configure logging
    # NOTE(review): assumes settings.log_level is a valid logging level name
    # (e.g. "INFO"); an invalid value raises AttributeError here — confirm
    # settings validation upstream.
    logging.basicConfig(
        level=getattr(logging, settings.log_level),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Ensure directories exist
    settings.ensure_directories()

    # Initialize GPU manager (reserves headroom for a co-located ComfyUI)
    gpu_manager = GPUMemoryManager(
        comfyui_reserve_gb=settings.comfyui_reserve_gb,
        safety_buffer_gb=settings.safety_buffer_gb,
    )

    # Initialize model registry (registry borrows the GPU manager for
    # VRAM-aware load/unload decisions)
    registry = ModelRegistry(
        config_path=settings.models_config,
        gpu_manager=gpu_manager,
        max_cached_models=settings.max_cached_models,
        idle_timeout_minutes=settings.idle_unload_minutes,
    )

    # Initialize database
    db = Database(settings.database_path)
    await db.connect()

    logger.info("AudioCraft Studio initialized")
    logger.info(f"GPU Status: {gpu_manager.get_status()}")
    logger.info(f"Available models: {len(registry.list_models())}")

    return {
        "settings": settings,
        "gpu_manager": gpu_manager,
        "registry": registry,
        "database": db,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Main entry point: print the roadmap banner and smoke-test init_app."""
    print("AudioCraft Studio - Starting...")
    print("Phase 1 core infrastructure is complete.")

    print("\nTo continue implementation:")
    remaining_phases = (
        " - Phase 2: Model adapters (musicgen, audiogen, magnet, style, jasco)",
        " - Phase 3: Services layer (generation, batch, project)",
        " - Phase 4: Gradio UI",
        " - Phase 5: REST API",
        " - Phase 6: Deployment",
    )
    for phase in remaining_phases:
        print(phase)

    # Quick initialization test: bring components up, report, tear down.
    async def test_init():
        components = await init_app()
        print(f"\nDatabase path: {components['settings'].database_path}")
        print(f"GPU status: {components['gpu_manager'].get_status()}")
        await components["database"].close()

    asyncio.run(test_init())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
32
src/models/__init__.py
Normal file
32
src/models/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""AudioCraft model adapters.
|
||||||
|
|
||||||
|
This module contains adapters that wrap AudioCraft's models with a
|
||||||
|
consistent interface for the application.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.models.musicgen.adapter import MusicGenAdapter
|
||||||
|
from src.models.audiogen.adapter import AudioGenAdapter
|
||||||
|
from src.models.magnet.adapter import MAGNeTAdapter
|
||||||
|
from src.models.musicgen_style.adapter import MusicGenStyleAdapter
|
||||||
|
from src.models.jasco.adapter import JASCOAdapter
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MusicGenAdapter",
|
||||||
|
"AudioGenAdapter",
|
||||||
|
"MAGNeTAdapter",
|
||||||
|
"MusicGenStyleAdapter",
|
||||||
|
"JASCOAdapter",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def register_all_adapters(registry) -> None:
    """Register all model adapters with the registry.

    Args:
        registry: ModelRegistry instance to register adapters with
    """
    # Mapping of registry id -> adapter class; insertion order preserves
    # the original registration order.
    adapters = {
        "musicgen": MusicGenAdapter,
        "audiogen": AudioGenAdapter,
        "magnet": MAGNeTAdapter,
        "musicgen-style": MusicGenStyleAdapter,
        "jasco": JASCOAdapter,
    }
    for model_id, adapter_cls in adapters.items():
        registry.register_adapter(model_id, adapter_cls)
|
||||||
5
src/models/audiogen/__init__.py
Normal file
5
src/models/audiogen/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""AudioGen model adapter."""
|
||||||
|
|
||||||
|
from src.models.audiogen.adapter import AudioGenAdapter
|
||||||
|
|
||||||
|
__all__ = ["AudioGenAdapter"]
|
||||||
203
src/models/audiogen/adapter.py
Normal file
203
src/models/audiogen/adapter.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
"""AudioGen model adapter for text-to-sound effects generation."""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from src.core.base_model import (
|
||||||
|
BaseAudioModel,
|
||||||
|
ConditioningType,
|
||||||
|
GenerationRequest,
|
||||||
|
GenerationResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioGenAdapter(BaseAudioModel):
    """Adapter for Facebook's AudioGen model.

    Generates sound effects and environmental audio from text descriptions.
    Optimized for non-musical audio like sound effects, ambiences, and foley.
    """

    # Per-variant metadata: HF checkpoint id, VRAM estimate, duration cap,
    # and channel count, used for validation and memory budgeting.
    VARIANTS = {
        "medium": {
            "hf_id": "facebook/audiogen-medium",
            "vram_mb": 5000,
            "max_duration": 10,
            "channels": 1,
        },
    }

    def __init__(self, variant: str = "medium"):
        """Initialize AudioGen adapter.

        Args:
            variant: Model variant (currently only 'medium' available)

        Raises:
            ValueError: If the variant name is not in VARIANTS.
        """
        if variant not in self.VARIANTS:
            raise ValueError(
                f"Unknown AudioGen variant: {variant}. "
                f"Available: {list(self.VARIANTS.keys())}"
            )

        self._variant = variant
        self._config = self.VARIANTS[variant]
        # Model and device stay None until load() succeeds.
        self._model = None
        self._device: Optional[torch.device] = None

    @property
    def model_id(self) -> str:
        """Stable adapter identifier."""
        return "audiogen"

    @property
    def variant(self) -> str:
        """Variant name selected at construction time."""
        return self._variant

    @property
    def display_name(self) -> str:
        """Human-readable name for UI display."""
        return f"AudioGen ({self._variant})"

    @property
    def description(self) -> str:
        """Short description of what this model does."""
        return "Text-to-sound effects generation"

    @property
    def vram_estimate_mb(self) -> int:
        """Estimated VRAM usage in MB for this variant."""
        return self._config["vram_mb"]

    @property
    def max_duration(self) -> float:
        """Maximum generation duration in seconds for this variant."""
        return self._config["max_duration"]

    @property
    def sample_rate(self) -> int:
        """Output sample rate in Hz; reads the live model when loaded."""
        if self._model is not None:
            return self._model.sample_rate
        return 16000  # AudioGen default sample rate

    @property
    def supports_conditioning(self) -> list[ConditioningType]:
        """Conditioning types this adapter accepts (text only)."""
        return [ConditioningType.TEXT]

    @property
    def is_loaded(self) -> bool:
        """Whether the underlying model is currently in memory."""
        return self._model is not None

    @property
    def device(self) -> Optional[torch.device]:
        """Device the model is loaded on, or None when unloaded."""
        return self._device

    def load(self, device: str = "cuda") -> None:
        """Load the AudioGen model.

        No-op (with a warning) if already loaded. On failure, state is reset
        to unloaded and a RuntimeError is raised.
        """
        if self._model is not None:
            logger.warning(f"AudioGen {self._variant} already loaded")
            return

        logger.info(f"Loading AudioGen {self._variant} from {self._config['hf_id']}...")

        try:
            # Imported lazily so the adapter can be constructed without
            # audiocraft installed.
            from audiocraft.models import AudioGen

            self._device = torch.device(device)
            self._model = AudioGen.get_pretrained(self._config["hf_id"])
            self._model.to(self._device)

            logger.info(
                f"AudioGen {self._variant} loaded successfully "
                f"(sample_rate={self._model.sample_rate})"
            )

        except Exception as e:
            # Roll back to a clean unloaded state before re-raising.
            self._model = None
            self._device = None
            logger.error(f"Failed to load AudioGen {self._variant}: {e}")
            raise RuntimeError(f"Failed to load AudioGen: {e}") from e

    def unload(self) -> None:
        """Unload the model and free memory."""
        if self._model is None:
            return

        logger.info(f"Unloading AudioGen {self._variant}...")

        del self._model
        self._model = None
        self._device = None

        # Drop Python references first, then release cached CUDA blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def generate(self, request: GenerationRequest) -> GenerationResult:
        """Generate sound effects from text prompts.

        Requires a prior successful load(); calling while unloaded will fail
        with AttributeError on self._model.

        Args:
            request: Generation parameters including prompts

        Returns:
            GenerationResult with audio tensor (on CPU) and metadata,
            including the seed actually used.
        """
        self.validate_request(request)

        # Set random seed (generate one when the caller did not pin it, so
        # the result is still reproducible from the returned metadata).
        seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

        # Configure generation
        self._model.set_generation_params(
            duration=request.duration,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            cfg_coef=request.cfg_coef,
        )

        logger.info(
            f"Generating {len(request.prompts)} sound effect(s) with AudioGen "
            f"(duration={request.duration}s)"
        )

        # Generate audio (inference_mode avoids autograd bookkeeping)
        with torch.inference_mode():
            audio = self._model.generate(request.prompts)

        # Actual duration may differ slightly from the requested one.
        actual_duration = audio.shape[-1] / self.sample_rate

        logger.info(
            f"Generated {audio.shape[0]} sample(s), "
            f"duration={actual_duration:.2f}s"
        )

        return GenerationResult(
            audio=audio.cpu(),
            sample_rate=self.sample_rate,
            duration=actual_duration,
            model_id=self.model_id,
            variant=self._variant,
            parameters={
                "duration": request.duration,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p,
                "cfg_coef": request.cfg_coef,
                "prompts": request.prompts,
            },
            seed=seed,
        )

    def get_default_params(self) -> dict[str, Any]:
        """Get default generation parameters."""
        return {
            "duration": 5.0,
            "temperature": 1.0,
            "top_k": 250,
            "top_p": 0.0,
            "cfg_coef": 3.0,
        }
|
||||||
5
src/models/jasco/__init__.py
Normal file
5
src/models/jasco/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""JASCO model adapter."""
|
||||||
|
|
||||||
|
from src.models.jasco.adapter import JASCOAdapter
|
||||||
|
|
||||||
|
__all__ = ["JASCOAdapter"]
|
||||||
348
src/models/jasco/adapter.py
Normal file
348
src/models/jasco/adapter.py
Normal file
@@ -0,0 +1,348 @@
|
|||||||
|
"""JASCO model adapter for chord and drum-conditioned music generation."""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from src.core.base_model import (
|
||||||
|
BaseAudioModel,
|
||||||
|
ConditioningType,
|
||||||
|
GenerationRequest,
|
||||||
|
GenerationResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class JASCOAdapter(BaseAudioModel):
    """Adapter for Facebook's JASCO model.

    JASCO (Joint Audio and Symbolic Conditioning) enables music generation
    with control over chord progressions and drum patterns alongside text.
    """

    # Per-variant metadata: HF checkpoint id, VRAM estimate, duration cap,
    # and channel count.
    VARIANTS = {
        "chords-drums-400M": {
            "hf_id": "facebook/jasco-chords-drums-400M",
            "vram_mb": 2000,
            "max_duration": 10,
            "channels": 1,
        },
        "chords-drums-1B": {
            "hf_id": "facebook/jasco-chords-drums-1B",
            "vram_mb": 4000,
            "max_duration": 10,
            "channels": 1,
        },
    }

    # Common chord types for validation
    # NOTE(review): not referenced anywhere in this class's visible code —
    # either wire it into chord validation or remove it.
    VALID_CHORD_TYPES = [
        "maj", "min", "dim", "aug", "7", "maj7", "min7", "dim7",
        "sus2", "sus4", "add9", "6", "min6", "9", "min9", "maj9",
    ]

    def __init__(self, variant: str = "chords-drums-400M"):
        """Initialize JASCO adapter.

        Args:
            variant: Model variant to use

        Raises:
            ValueError: If the variant name is not in VARIANTS.
        """
        if variant not in self.VARIANTS:
            raise ValueError(
                f"Unknown JASCO variant: {variant}. "
                f"Available: {list(self.VARIANTS.keys())}"
            )

        self._variant = variant
        self._config = self.VARIANTS[variant]
        # Model and device stay None until load() succeeds.
        self._model = None
        self._device: Optional[torch.device] = None

    @property
    def model_id(self) -> str:
        """Stable adapter identifier."""
        return "jasco"

    @property
    def variant(self) -> str:
        """Variant name selected at construction time."""
        return self._variant

    @property
    def display_name(self) -> str:
        """Human-readable name for UI display."""
        return f"JASCO ({self._variant})"

    @property
    def description(self) -> str:
        """Short description of what this model does."""
        return "Chord and drum-conditioned music generation"

    @property
    def vram_estimate_mb(self) -> int:
        """Estimated VRAM usage in MB for this variant."""
        return self._config["vram_mb"]

    @property
    def max_duration(self) -> float:
        """Maximum generation duration in seconds for this variant."""
        return self._config["max_duration"]

    @property
    def sample_rate(self) -> int:
        """Output sample rate in Hz; reads the live model when loaded."""
        if self._model is not None:
            return self._model.sample_rate
        # Fallback used before load(); presumably JASCO's native rate —
        # confirm against the loaded model.
        return 32000

    @property
    def supports_conditioning(self) -> list[ConditioningType]:
        """Conditioning types this adapter accepts."""
        return [ConditioningType.TEXT, ConditioningType.CHORDS, ConditioningType.DRUMS]

    @property
    def is_loaded(self) -> bool:
        """Whether the underlying model is currently in memory."""
        return self._model is not None

    @property
    def device(self) -> Optional[torch.device]:
        """Device the model is loaded on, or None when unloaded."""
        return self._device

    def load(self, device: str = "cuda") -> None:
        """Load the JASCO model.

        No-op (with a warning) if already loaded. On failure, state is reset
        to unloaded and a RuntimeError is raised.
        """
        if self._model is not None:
            logger.warning(f"JASCO {self._variant} already loaded")
            return

        logger.info(f"Loading JASCO {self._variant} from {self._config['hf_id']}...")

        try:
            # Imported lazily so the adapter can be constructed without
            # audiocraft installed.
            from audiocraft.models import JASCO

            self._device = torch.device(device)
            self._model = JASCO.get_pretrained(self._config["hf_id"])
            self._model.to(self._device)

            logger.info(
                f"JASCO {self._variant} loaded successfully "
                f"(sample_rate={self._model.sample_rate})"
            )

        except Exception as e:
            # Roll back to a clean unloaded state before re-raising.
            self._model = None
            self._device = None
            logger.error(f"Failed to load JASCO {self._variant}: {e}")
            raise RuntimeError(f"Failed to load JASCO: {e}") from e

    def unload(self) -> None:
        """Unload the model and free memory."""
        if self._model is None:
            return

        logger.info(f"Unloading JASCO {self._variant}...")

        del self._model
        self._model = None
        self._device = None

        # Drop Python references first, then release cached CUDA blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def parse_chord_progression(
        chords: list[dict[str, Any]], duration: float
    ) -> list[tuple[float, float, str]]:
        """Parse chord progression from user input format.

        Args:
            chords: List of chord dictionaries with keys:
                - time: Start time in seconds
                - chord: Chord name (e.g., "C", "Am", "G7")
            duration: Total duration for calculating end times

        Returns:
            List of (start_time, end_time, chord_name) tuples; the last
            chord extends to `duration`. Empty input yields an empty list.

        Example input:
            [
                {"time": 0.0, "chord": "C"},
                {"time": 2.0, "chord": "Am"},
                {"time": 4.0, "chord": "F"},
                {"time": 6.0, "chord": "G"},
            ]
        """
        if not chords:
            return []

        # Sort by time (input order is not trusted)
        sorted_chords = sorted(chords, key=lambda x: x["time"])

        # Build (start, end, chord) tuples
        result = []
        for i, chord_info in enumerate(sorted_chords):
            start = chord_info["time"]
            # End time is either next chord's start or total duration
            if i + 1 < len(sorted_chords):
                end = sorted_chords[i + 1]["time"]
            else:
                end = duration
            result.append((start, end, chord_info["chord"]))

        return result

    @staticmethod
    def create_drum_pattern(
        pattern: str, duration: float, bpm: float = 120.0
    ) -> list[tuple[float, str]]:
        """Create drum events from a pattern string.

        Args:
            pattern: Pattern string (e.g., "kick,snare,kick,snare")
                or "4/4" for common time signature
            duration: Total duration in seconds
            bpm: Beats per minute

        Returns:
            List of (time, drum_type) tuples covering [0, duration).
        """
        beat_duration = 60.0 / bpm
        events = []

        if pattern in ["4/4", "common"]:
            # Standard 4/4 rock pattern: kick on 1, snare on 3, hihat on
            # every eighth-note pair; steps advance by half a beat.
            time = 0.0
            beat = 0
            while time < duration:
                if beat % 4 == 0:
                    events.append((time, "kick"))
                elif beat % 4 == 2:
                    events.append((time, "snare"))
                if beat % 2 == 0:
                    events.append((time, "hihat"))
                time += beat_duration / 2
                beat += 1
        else:
            # Parse comma-separated pattern, cycling through the listed
            # drums one per beat; blank entries are skipped.
            drum_types = pattern.split(",")
            time = 0.0
            idx = 0
            while time < duration:
                drum = drum_types[idx % len(drum_types)].strip()
                if drum:
                    events.append((time, drum))
                time += beat_duration
                idx += 1

        return events

    def generate(self, request: GenerationRequest) -> GenerationResult:
        """Generate music with chord and drum conditioning.

        Requires a prior successful load(); calling while unloaded will fail
        with AttributeError on self._model.

        Args:
            request: Generation parameters with optional conditioning:
                - chords: List of {"time": float, "chord": str} dicts
                - drums: Drum pattern string or list of (time, drum_type)
                - bpm: Beats per minute for drum pattern

        Returns:
            GenerationResult with audio tensor (on CPU) and metadata,
            including the seed actually used.
        """
        self.validate_request(request)

        # Set random seed (generate one when the caller did not pin it, so
        # the result is still reproducible from the returned metadata).
        seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

        # Configure generation parameters
        self._model.set_generation_params(
            duration=request.duration,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            cfg_coef=request.cfg_coef,
        )

        # Process chord conditioning: accept either the user dict format or
        # pre-built (start, end, chord) tuples.
        chords_input = request.conditioning.get("chords")
        chords_formatted = None
        if chords_input:
            if isinstance(chords_input, list) and len(chords_input) > 0:
                if isinstance(chords_input[0], dict):
                    chords_formatted = self.parse_chord_progression(
                        chords_input, request.duration
                    )
                else:
                    # Already in (start, end, chord) format
                    chords_formatted = chords_input

        # Process drum conditioning: a string is expanded into timed events;
        # anything else is passed through as-is.
        drums_input = request.conditioning.get("drums")
        bpm = request.conditioning.get("bpm", 120.0)
        drums_formatted = None
        if drums_input:
            if isinstance(drums_input, str):
                drums_formatted = self.create_drum_pattern(
                    drums_input, request.duration, bpm
                )
            else:
                drums_formatted = drums_input

        logger.info(
            f"Generating {len(request.prompts)} sample(s) with JASCO "
            f"(duration={request.duration}s, chords={chords_formatted is not None}, "
            f"drums={drums_formatted is not None})"
        )

        with torch.inference_mode():
            # Build conditioning dict for JASCO
            # NOTE(review): assumes the model's generate() accepts `chords`
            # and `drums` keyword arguments — confirm against the installed
            # audiocraft JASCO API.
            conditioning = {}
            if chords_formatted:
                conditioning["chords"] = chords_formatted
            if drums_formatted:
                conditioning["drums"] = drums_formatted

            if conditioning:
                audio = self._model.generate(
                    descriptions=request.prompts,
                    **conditioning,
                )
            else:
                # Generate without symbolic conditioning
                audio = self._model.generate(request.prompts)

        # Actual duration may differ slightly from the requested one.
        actual_duration = audio.shape[-1] / self.sample_rate

        logger.info(
            f"Generated {audio.shape[0]} sample(s), "
            f"duration={actual_duration:.2f}s"
        )

        return GenerationResult(
            audio=audio.cpu(),
            sample_rate=self.sample_rate,
            duration=actual_duration,
            model_id=self.model_id,
            variant=self._variant,
            parameters={
                "duration": request.duration,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p,
                "cfg_coef": request.cfg_coef,
                "prompts": request.prompts,
                "chords": chords_formatted,
                "drums": drums_formatted,
                "bpm": bpm,
            },
            seed=seed,
        )

    def get_default_params(self) -> dict[str, Any]:
        """Get default generation parameters for JASCO."""
        return {
            "duration": 10.0,
            "temperature": 1.0,
            "top_k": 250,
            "top_p": 0.0,
            "cfg_coef": 3.0,
            "bpm": 120.0,
        }
|
||||||
5
src/models/magnet/__init__.py
Normal file
5
src/models/magnet/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""MAGNeT model adapter."""
|
||||||
|
|
||||||
|
from src.models.magnet.adapter import MAGNeTAdapter
|
||||||
|
|
||||||
|
__all__ = ["MAGNeTAdapter"]
|
||||||
253
src/models/magnet/adapter.py
Normal file
253
src/models/magnet/adapter.py
Normal file
@@ -0,0 +1,253 @@
|
|||||||
|
"""MAGNeT model adapter for fast non-autoregressive audio generation."""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from src.core.base_model import (
|
||||||
|
BaseAudioModel,
|
||||||
|
ConditioningType,
|
||||||
|
GenerationRequest,
|
||||||
|
GenerationResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MAGNeTAdapter(BaseAudioModel):
    """Adapter for Facebook's MAGNeT model.

    MAGNeT (Masked Audio Generation using Non-autoregressive Transformers)
    provides faster generation than autoregressive models like MusicGen.
    Supports both music and sound effect generation.
    """

    # Known variants mapped to their Hugging Face ids and resource estimates.
    # "audio-*" variants target sound effects; the rest target music.
    VARIANTS = {
        "small-10secs": {
            "hf_id": "facebook/magnet-small-10secs",
            "vram_mb": 1500,
            "max_duration": 10,
            "channels": 1,
            "audio_type": "music",
        },
        "medium-10secs": {
            "hf_id": "facebook/magnet-medium-10secs",
            "vram_mb": 5000,
            "max_duration": 10,
            "channels": 1,
            "audio_type": "music",
        },
        "small-30secs": {
            "hf_id": "facebook/magnet-small-30secs",
            "vram_mb": 1800,
            "max_duration": 30,
            "channels": 1,
            "audio_type": "music",
        },
        "medium-30secs": {
            "hf_id": "facebook/magnet-medium-30secs",
            "vram_mb": 6000,
            "max_duration": 30,
            "channels": 1,
            "audio_type": "music",
        },
        "audio-small-10secs": {
            "hf_id": "facebook/audio-magnet-small",
            "vram_mb": 1500,
            "max_duration": 10,
            "channels": 1,
            "audio_type": "sound",
        },
        "audio-medium-10secs": {
            "hf_id": "facebook/audio-magnet-medium",
            "vram_mb": 5000,
            "max_duration": 10,
            "channels": 1,
            "audio_type": "sound",
        },
    }

    def __init__(self, variant: str = "medium-10secs"):
        """Initialize MAGNeT adapter.

        Args:
            variant: Model variant to use

        Raises:
            ValueError: If variant is not one of VARIANTS.
        """
        if variant not in self.VARIANTS:
            raise ValueError(
                f"Unknown MAGNeT variant: {variant}. "
                f"Available: {list(self.VARIANTS.keys())}"
            )

        self._variant = variant
        self._config = self.VARIANTS[variant]
        self._model = None  # lazily created in load()
        self._device: Optional[torch.device] = None

    @property
    def model_id(self) -> str:
        return "magnet"

    @property
    def variant(self) -> str:
        return self._variant

    @property
    def display_name(self) -> str:
        return f"MAGNeT ({self._variant})"

    @property
    def description(self) -> str:
        audio_type = self._config.get("audio_type", "music")
        return f"Fast non-autoregressive {audio_type} generation"

    @property
    def vram_estimate_mb(self) -> int:
        return self._config["vram_mb"]

    @property
    def max_duration(self) -> float:
        return self._config["max_duration"]

    @property
    def sample_rate(self) -> int:
        # Report the real rate once loaded; 32 kHz is the fallback before that.
        if self._model is not None:
            return self._model.sample_rate
        return 32000

    @property
    def supports_conditioning(self) -> list[ConditioningType]:
        return [ConditioningType.TEXT]

    @property
    def is_loaded(self) -> bool:
        return self._model is not None

    @property
    def device(self) -> Optional[torch.device]:
        return self._device

    def load(self, device: str = "cuda") -> None:
        """Load the MAGNeT model.

        Args:
            device: Target device ('cuda', 'cuda:0', 'cpu', etc.)

        Raises:
            RuntimeError: If the model fails to load.
        """
        if self._model is not None:
            logger.warning(f"MAGNeT {self._variant} already loaded")
            return

        logger.info(f"Loading MAGNeT {self._variant} from {self._config['hf_id']}...")

        try:
            from audiocraft.models import MAGNeT

            self._device = torch.device(device)
            # Pass the device to get_pretrained: audiocraft's generation
            # wrappers are not nn.Modules and do not implement `.to()`,
            # so `get_pretrained(...).to(device)` would raise AttributeError.
            self._model = MAGNeT.get_pretrained(
                self._config["hf_id"], device=device
            )

            logger.info(
                f"MAGNeT {self._variant} loaded successfully "
                f"(sample_rate={self._model.sample_rate})"
            )

        except Exception as e:
            # Reset state so a failed load leaves the adapter unloaded.
            self._model = None
            self._device = None
            logger.error(f"Failed to load MAGNeT {self._variant}: {e}")
            raise RuntimeError(f"Failed to load MAGNeT: {e}") from e

    def unload(self) -> None:
        """Unload the model and free memory."""
        if self._model is None:
            return

        logger.info(f"Unloading MAGNeT {self._variant}...")

        del self._model
        self._model = None
        self._device = None

        # Release Python references first, then return cached CUDA blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def generate(self, request: GenerationRequest) -> GenerationResult:
        """Generate audio from text prompts using MAGNeT.

        MAGNeT uses a non-autoregressive approach with iterative decoding,
        which is significantly faster than autoregressive models.

        Args:
            request: Generation parameters including prompts

        Returns:
            GenerationResult with audio tensor and metadata
        """
        self.validate_request(request)

        # Set random seed for reproducibility.
        seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            # Seed every visible GPU, not just the current one.
            torch.cuda.manual_seed_all(seed)

        # Configure generation parameters.
        # MAGNeT has different parameters than MusicGen: it decodes in four
        # masked-refinement stages, each with its own step count.
        self._model.set_generation_params(
            duration=request.duration,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            cfg_coef=request.cfg_coef,
            # MAGNeT-specific parameters
            decoding_steps=[
                int(request.conditioning.get("decoding_steps_1", 20)),
                int(request.conditioning.get("decoding_steps_2", 10)),
                int(request.conditioning.get("decoding_steps_3", 10)),
                int(request.conditioning.get("decoding_steps_4", 10)),
            ],
            span_arrangement=request.conditioning.get("span_arrangement", "nonoverlap"),
        )

        logger.info(
            f"Generating {len(request.prompts)} sample(s) with MAGNeT {self._variant} "
            f"(duration={request.duration}s)"
        )

        # Generate audio (no grads needed for pure inference).
        with torch.inference_mode():
            audio = self._model.generate(request.prompts)

        actual_duration = audio.shape[-1] / self.sample_rate

        logger.info(
            f"Generated {audio.shape[0]} sample(s), "
            f"duration={actual_duration:.2f}s"
        )

        return GenerationResult(
            audio=audio.cpu(),
            sample_rate=self.sample_rate,
            duration=actual_duration,
            model_id=self.model_id,
            variant=self._variant,
            parameters={
                "duration": request.duration,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p,
                "cfg_coef": request.cfg_coef,
                "prompts": request.prompts,
            },
            seed=seed,
        )

    def get_default_params(self) -> dict[str, Any]:
        """Get default generation parameters for MAGNeT."""
        return {
            "duration": 10.0,
            "temperature": 3.0,  # MAGNeT works better with higher temperature
            "top_k": 0,  # Use top_p instead for MAGNeT
            "top_p": 0.9,
            "cfg_coef": 3.0,
        }
|
||||||
5
src/models/musicgen/__init__.py
Normal file
5
src/models/musicgen/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""MusicGen model adapter."""
|
||||||
|
|
||||||
|
from src.models.musicgen.adapter import MusicGenAdapter
|
||||||
|
|
||||||
|
__all__ = ["MusicGenAdapter"]
|
||||||
290
src/models/musicgen/adapter.py
Normal file
290
src/models/musicgen/adapter.py
Normal file
@@ -0,0 +1,290 @@
|
|||||||
|
"""MusicGen model adapter for text-to-music generation."""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from src.core.base_model import (
|
||||||
|
BaseAudioModel,
|
||||||
|
ConditioningType,
|
||||||
|
GenerationRequest,
|
||||||
|
GenerationResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MusicGenAdapter(BaseAudioModel):
    """Adapter for Facebook's MusicGen model.

    Supports text-to-music generation with optional melody conditioning.
    Available variants: small, medium, large, melody, and stereo versions.
    """

    # Variant configurations: Hugging Face id, VRAM estimate, duration cap,
    # channel count, and any extra conditioning modes beyond text.
    VARIANTS = {
        "small": {
            "hf_id": "facebook/musicgen-small",
            "vram_mb": 1500,
            "max_duration": 30,
            "channels": 1,
            "conditioning": [],
        },
        "medium": {
            "hf_id": "facebook/musicgen-medium",
            "vram_mb": 5000,
            "max_duration": 30,
            "channels": 1,
            "conditioning": [],
        },
        "large": {
            "hf_id": "facebook/musicgen-large",
            "vram_mb": 10000,
            "max_duration": 30,
            "channels": 1,
            "conditioning": [],
        },
        "melody": {
            "hf_id": "facebook/musicgen-melody",
            "vram_mb": 5000,
            "max_duration": 30,
            "channels": 1,
            "conditioning": [ConditioningType.MELODY],
        },
        "stereo-small": {
            "hf_id": "facebook/musicgen-stereo-small",
            "vram_mb": 1800,
            "max_duration": 30,
            "channels": 2,
            "conditioning": [],
        },
        "stereo-medium": {
            "hf_id": "facebook/musicgen-stereo-medium",
            "vram_mb": 6000,
            "max_duration": 30,
            "channels": 2,
            "conditioning": [],
        },
        "stereo-large": {
            "hf_id": "facebook/musicgen-stereo-large",
            "vram_mb": 12000,
            "max_duration": 30,
            "channels": 2,
            "conditioning": [],
        },
        "stereo-melody": {
            "hf_id": "facebook/musicgen-stereo-melody",
            "vram_mb": 6000,
            "max_duration": 30,
            "channels": 2,
            "conditioning": [ConditioningType.MELODY],
        },
    }

    def __init__(self, variant: str = "medium"):
        """Initialize MusicGen adapter.

        Args:
            variant: Model variant to use (small, medium, large, melody, etc.)

        Raises:
            ValueError: If variant is not recognized
        """
        if variant not in self.VARIANTS:
            raise ValueError(
                f"Unknown MusicGen variant: {variant}. "
                f"Available: {list(self.VARIANTS.keys())}"
            )

        self._variant = variant
        self._config = self.VARIANTS[variant]
        self._model = None  # lazily created in load()
        self._device: Optional[torch.device] = None

    @property
    def model_id(self) -> str:
        return "musicgen"

    @property
    def variant(self) -> str:
        return self._variant

    @property
    def display_name(self) -> str:
        return f"MusicGen ({self._variant})"

    @property
    def description(self) -> str:
        if "melody" in self._variant:
            return "Text-to-music with melody conditioning"
        elif "stereo" in self._variant:
            return "Stereo text-to-music generation"
        return "Text-to-music generation"

    @property
    def vram_estimate_mb(self) -> int:
        return self._config["vram_mb"]

    @property
    def max_duration(self) -> float:
        return self._config["max_duration"]

    @property
    def sample_rate(self) -> int:
        if self._model is not None:
            return self._model.sample_rate
        return 32000  # Default MusicGen sample rate

    @property
    def supports_conditioning(self) -> list[ConditioningType]:
        # Text is always supported; melody variants add MELODY.
        return [ConditioningType.TEXT] + self._config["conditioning"]

    @property
    def is_loaded(self) -> bool:
        return self._model is not None

    @property
    def device(self) -> Optional[torch.device]:
        return self._device

    def load(self, device: str = "cuda") -> None:
        """Load the MusicGen model.

        Args:
            device: Target device ('cuda', 'cuda:0', 'cpu', etc.)

        Raises:
            RuntimeError: If the model fails to load.
        """
        if self._model is not None:
            logger.warning(f"MusicGen {self._variant} already loaded")
            return

        logger.info(f"Loading MusicGen {self._variant} from {self._config['hf_id']}...")

        try:
            from audiocraft.models import MusicGen

            self._device = torch.device(device)
            # Pass the device to get_pretrained: audiocraft's MusicGen wrapper
            # is not an nn.Module and does not implement `.to()`, so
            # `get_pretrained(...).to(device)` would raise AttributeError.
            self._model = MusicGen.get_pretrained(
                self._config["hf_id"], device=device
            )

            logger.info(
                f"MusicGen {self._variant} loaded successfully "
                f"(sample_rate={self._model.sample_rate})"
            )

        except Exception as e:
            # Reset state so a failed load leaves the adapter unloaded.
            self._model = None
            self._device = None
            logger.error(f"Failed to load MusicGen {self._variant}: {e}")
            raise RuntimeError(f"Failed to load MusicGen: {e}") from e

    def unload(self) -> None:
        """Unload the model and free memory."""
        if self._model is None:
            return

        logger.info(f"Unloading MusicGen {self._variant}...")

        del self._model
        self._model = None
        self._device = None

        # Release Python references first, then return cached CUDA blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def generate(self, request: GenerationRequest) -> GenerationResult:
        """Generate music from text prompts.

        Args:
            request: Generation parameters including prompts

        Returns:
            GenerationResult with audio tensor and metadata

        Raises:
            RuntimeError: If model not loaded
            ValueError: If request is invalid
        """
        self.validate_request(request)

        # Set random seed for reproducibility
        seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            # Seed every visible GPU, not just the current one.
            torch.cuda.manual_seed_all(seed)

        # Configure generation parameters
        self._model.set_generation_params(
            duration=request.duration,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            cfg_coef=request.cfg_coef,
        )

        logger.info(
            f"Generating {len(request.prompts)} sample(s) with MusicGen {self._variant} "
            f"(duration={request.duration}s, temp={request.temperature})"
        )

        # Generate audio (no grads needed for pure inference).
        with torch.inference_mode():
            melody_audio = request.conditioning.get("melody")
            melody_sr = request.conditioning.get("melody_sr", self.sample_rate)

            if melody_audio is not None and ConditioningType.MELODY in self.supports_conditioning:
                # Melody-conditioned generation
                if isinstance(melody_audio, str):
                    # Load from file path
                    import torchaudio

                    melody_tensor, melody_sr = torchaudio.load(melody_audio)
                    melody_tensor = melody_tensor.to(self._device)
                else:
                    # as_tensor avoids the copy (and the UserWarning) that
                    # torch.tensor() emits when the input is already a tensor.
                    melody_tensor = torch.as_tensor(melody_audio).to(self._device)

                audio = self._model.generate_with_chroma(
                    descriptions=request.prompts,
                    # generate_with_chroma expects a batch dimension.
                    melody_wavs=melody_tensor.unsqueeze(0) if melody_tensor.dim() == 1 else melody_tensor,
                    melody_sample_rate=melody_sr,
                )
            else:
                # Standard text-to-music generation
                audio = self._model.generate(request.prompts)

        # audio shape: [batch, channels, samples]
        actual_duration = audio.shape[-1] / self.sample_rate

        logger.info(
            f"Generated {audio.shape[0]} sample(s), "
            f"duration={actual_duration:.2f}s, shape={audio.shape}"
        )

        return GenerationResult(
            audio=audio.cpu(),
            sample_rate=self.sample_rate,
            duration=actual_duration,
            model_id=self.model_id,
            variant=self._variant,
            parameters={
                "duration": request.duration,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p,
                "cfg_coef": request.cfg_coef,
                "prompts": request.prompts,
            },
            seed=seed,
        )

    def get_default_params(self) -> dict[str, Any]:
        """Get default generation parameters."""
        return {
            "duration": 10.0,
            "temperature": 1.0,
            "top_k": 250,
            "top_p": 0.0,
            "cfg_coef": 3.0,
        }
|
||||||
5
src/models/musicgen_style/__init__.py
Normal file
5
src/models/musicgen_style/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""MusicGen Style model adapter."""
|
||||||
|
|
||||||
|
from src.models.musicgen_style.adapter import MusicGenStyleAdapter
|
||||||
|
|
||||||
|
__all__ = ["MusicGenStyleAdapter"]
|
||||||
277
src/models/musicgen_style/adapter.py
Normal file
277
src/models/musicgen_style/adapter.py
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
"""MusicGen Style model adapter for style-conditioned music generation."""
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
|
from src.core.base_model import (
|
||||||
|
BaseAudioModel,
|
||||||
|
ConditioningType,
|
||||||
|
GenerationRequest,
|
||||||
|
GenerationResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MusicGenStyleAdapter(BaseAudioModel):
    """Adapter for Facebook's MusicGen Style model.

    Generates music conditioned on both text and a style reference audio.
    Extracts style features from the reference and applies them to new generations.
    """

    # Only one public checkpoint exists for the style model.
    VARIANTS = {
        "medium": {
            "hf_id": "facebook/musicgen-style",
            "vram_mb": 5000,
            "max_duration": 30,
            "channels": 1,
        },
    }

    def __init__(self, variant: str = "medium"):
        """Initialize MusicGen Style adapter.

        Args:
            variant: Model variant (currently only 'medium' available)

        Raises:
            ValueError: If variant is not recognized.
        """
        if variant not in self.VARIANTS:
            raise ValueError(
                f"Unknown MusicGen Style variant: {variant}. "
                f"Available: {list(self.VARIANTS.keys())}"
            )

        self._variant = variant
        self._config = self.VARIANTS[variant]
        self._model = None  # lazily created in load()
        self._device: Optional[torch.device] = None

    @property
    def model_id(self) -> str:
        return "musicgen-style"

    @property
    def variant(self) -> str:
        return self._variant

    @property
    def display_name(self) -> str:
        return f"MusicGen Style ({self._variant})"

    @property
    def description(self) -> str:
        return "Style-conditioned music generation from reference audio"

    @property
    def vram_estimate_mb(self) -> int:
        return self._config["vram_mb"]

    @property
    def max_duration(self) -> float:
        return self._config["max_duration"]

    @property
    def sample_rate(self) -> int:
        if self._model is not None:
            return self._model.sample_rate
        return 32000

    @property
    def supports_conditioning(self) -> list[ConditioningType]:
        return [ConditioningType.TEXT, ConditioningType.STYLE]

    @property
    def is_loaded(self) -> bool:
        return self._model is not None

    @property
    def device(self) -> Optional[torch.device]:
        return self._device

    def load(self, device: str = "cuda") -> None:
        """Load the MusicGen Style model.

        Args:
            device: Target device ('cuda', 'cuda:0', 'cpu', etc.)

        Raises:
            RuntimeError: If the model fails to load.
        """
        if self._model is not None:
            logger.warning(f"MusicGen Style {self._variant} already loaded")
            return

        logger.info(f"Loading MusicGen Style {self._variant}...")

        try:
            from audiocraft.models import MusicGen

            self._device = torch.device(device)
            # Pass the device to get_pretrained: audiocraft's MusicGen wrapper
            # is not an nn.Module and does not implement `.to()`, so
            # `get_pretrained(...).to(device)` would raise AttributeError.
            self._model = MusicGen.get_pretrained(
                self._config["hf_id"], device=device
            )

            logger.info(
                f"MusicGen Style {self._variant} loaded successfully "
                f"(sample_rate={self._model.sample_rate})"
            )

        except Exception as e:
            # Reset state so a failed load leaves the adapter unloaded.
            self._model = None
            self._device = None
            logger.error(f"Failed to load MusicGen Style {self._variant}: {e}")
            raise RuntimeError(f"Failed to load MusicGen Style: {e}") from e

    def unload(self) -> None:
        """Unload the model and free memory."""
        if self._model is None:
            return

        logger.info(f"Unloading MusicGen Style {self._variant}...")

        del self._model
        self._model = None
        self._device = None

        # Release Python references first, then return cached CUDA blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _load_style_audio(
        self, style_input: Any, target_sr: int
    ) -> tuple[torch.Tensor, int]:
        """Load and prepare style reference audio.

        Args:
            style_input: File path, tensor, or numpy array
            target_sr: Target sample rate

        Returns:
            Tuple of (audio_tensor, sample_rate)
        """
        if isinstance(style_input, str):
            # Load from file and resample to the model rate if needed.
            audio, sr = torchaudio.load(style_input)
            if sr != target_sr:
                audio = torchaudio.functional.resample(audio, sr, target_sr)
            return audio.to(self._device), target_sr
        elif isinstance(style_input, torch.Tensor):
            # NOTE(review): tensor/array inputs are assumed to already be at
            # target_sr — no resampling is possible without a declared rate.
            return style_input.to(self._device), target_sr
        else:
            # Assume numpy array
            return torch.as_tensor(style_input).to(self._device), target_sr

    def generate(self, request: GenerationRequest) -> GenerationResult:
        """Generate music conditioned on text and style reference.

        Args:
            request: Generation parameters including prompts and style conditioning

        Returns:
            GenerationResult with audio tensor and metadata

        Note:
            Style conditioning requires 'style' in request.conditioning with either:
            - File path to audio
            - Audio tensor
            - Numpy array
        """
        self.validate_request(request)

        # Set random seed for reproducibility.
        seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            # Seed every visible GPU, not just the current one.
            torch.cuda.manual_seed_all(seed)

        # Get style conditioning parameters
        style_audio = request.conditioning.get("style")
        eval_q = request.conditioning.get("eval_q", 3)
        excerpt_length = request.conditioning.get("excerpt_length", 3.0)

        # Configure generation parameters
        self._model.set_generation_params(
            duration=request.duration,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            cfg_coef=request.cfg_coef,
        )

        logger.info(
            f"Generating {len(request.prompts)} sample(s) with MusicGen Style "
            f"(duration={request.duration}s, style_conditioned={style_audio is not None})"
        )

        with torch.inference_mode():
            if style_audio is not None:
                # Load style reference
                style_tensor, style_sr = self._load_style_audio(
                    style_audio, self.sample_rate
                )

                # Ensure proper shape [batch, channels, samples]
                if style_tensor.dim() == 1:
                    style_tensor = style_tensor.unsqueeze(0).unsqueeze(0)
                elif style_tensor.dim() == 2:
                    style_tensor = style_tensor.unsqueeze(0)

                # Set style conditioner parameters (only on checkpoints that
                # expose the style conditioner API).
                if hasattr(self._model, 'set_style_conditioner_params'):
                    self._model.set_style_conditioner_params(
                        eval_q=eval_q,
                        excerpt_length=excerpt_length,
                    )

                # Generate with style conditioning.
                # Expand style to match number of prompts if needed.
                if style_tensor.shape[0] == 1 and len(request.prompts) > 1:
                    style_tensor = style_tensor.expand(len(request.prompts), -1, -1)

                audio = self._model.generate_with_chroma(
                    descriptions=request.prompts,
                    melody_wavs=style_tensor,
                    melody_sample_rate=style_sr,
                )
            else:
                # Generate without style (falls back to standard MusicGen behavior)
                logger.warning(
                    "No style reference provided, generating without style conditioning"
                )
                audio = self._model.generate(request.prompts)

        actual_duration = audio.shape[-1] / self.sample_rate

        logger.info(
            f"Generated {audio.shape[0]} sample(s), "
            f"duration={actual_duration:.2f}s"
        )

        return GenerationResult(
            audio=audio.cpu(),
            sample_rate=self.sample_rate,
            duration=actual_duration,
            model_id=self.model_id,
            variant=self._variant,
            parameters={
                "duration": request.duration,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p,
                "cfg_coef": request.cfg_coef,
                "prompts": request.prompts,
                "style_conditioned": style_audio is not None,
                "eval_q": eval_q,
                "excerpt_length": excerpt_length,
            },
            seed=seed,
        )

    def get_default_params(self) -> dict[str, Any]:
        """Get default generation parameters for MusicGen Style."""
        return {
            "duration": 10.0,
            "temperature": 1.0,
            "top_k": 250,
            "top_p": 0.0,
            "cfg_coef": 3.0,
            "eval_q": 3,
            "excerpt_length": 3.0,
        }
|
||||||
13
src/services/__init__.py
Normal file
13
src/services/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
"""Services layer for AudioCraft Studio."""
|
||||||
|
|
||||||
|
from src.services.generation_service import GenerationService
|
||||||
|
from src.services.batch_processor import BatchProcessor, GenerationJob, JobStatus
|
||||||
|
from src.services.project_service import ProjectService
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"GenerationService",
|
||||||
|
"BatchProcessor",
|
||||||
|
"GenerationJob",
|
||||||
|
"JobStatus",
|
||||||
|
"ProjectService",
|
||||||
|
]
|
||||||
397
src/services/batch_processor.py
Normal file
397
src/services/batch_processor.py
Normal file
@@ -0,0 +1,397 @@
|
|||||||
|
"""Batch processor for queued audio generation jobs."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class JobStatus(str, Enum):
    """Status of a generation job.

    Subclasses ``str`` so member values compare/serialize as plain strings
    (``to_dict`` emits ``self.status.value`` in API responses).
    """

    PENDING = "pending"        # queued, not yet picked up
    PROCESSING = "processing"  # currently being generated
    COMPLETED = "completed"    # finished successfully
    FAILED = "failed"          # ended with an error (see GenerationJob.error)
    CANCELLED = "cancelled"    # cancelled before completion
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class GenerationJob:
    """A queued generation job.

    Immutable request fields come first (no defaults); mutable status and
    result fields follow with defaults so the batch processor can update
    them in place as the job advances.
    """

    # Request (set once at creation)
    id: str                       # "job_<12 hex chars>", assigned by create()
    model_id: str                 # adapter id, e.g. "musicgen"
    variant: Optional[str]        # adapter variant, e.g. "medium"
    prompts: list[str]
    parameters: dict[str, Any]    # duration/temperature/top_k/top_p/cfg_coef/seed
    conditioning: dict[str, Any]  # model-specific extras (melody, style, ...)
    project_id: Optional[str]
    preset_used: Optional[str]
    tags: list[str]

    # Status tracking
    status: JobStatus = JobStatus.PENDING
    progress: float = 0.0         # 0.0 .. 1.0 fraction — TODO confirm scale against processor
    progress_message: str = ""
    # NOTE(review): datetime.utcnow is deprecated since Python 3.12; switching
    # to datetime.now(timezone.utc) would make isoformat() emit an offset,
    # which may affect API consumers — confirm before changing.
    created_at: datetime = field(default_factory=datetime.utcnow)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    # Results
    result_id: Optional[str] = None  # Generation ID if completed
    audio_path: Optional[str] = None
    error: Optional[str] = None      # failure reason when status is FAILED

    @classmethod
    def create(
        cls,
        model_id: str,
        variant: Optional[str],
        prompts: list[str],
        duration: float = 10.0,
        temperature: float = 1.0,
        top_k: int = 250,
        top_p: float = 0.0,
        cfg_coef: float = 3.0,
        seed: Optional[int] = None,
        conditioning: Optional[dict[str, Any]] = None,
        project_id: Optional[str] = None,
        preset_used: Optional[str] = None,
        tags: Optional[list[str]] = None,
    ) -> "GenerationJob":
        """Create a new generation job.

        Packs the flat generation arguments into the `parameters` dict and
        assigns a fresh random job id; all status fields keep their defaults.
        """
        return cls(
            id=f"job_{uuid.uuid4().hex[:12]}",
            model_id=model_id,
            variant=variant,
            prompts=prompts,
            parameters={
                "duration": duration,
                "temperature": temperature,
                "top_k": top_k,
                "top_p": top_p,
                "cfg_coef": cfg_coef,
                "seed": seed,
            },
            conditioning=conditioning or {},
            project_id=project_id,
            preset_used=preset_used,
            tags=tags or [],
        )

    def to_dict(self) -> dict[str, Any]:
        """Convert job to dictionary for API responses.

        Datetimes are ISO-8601 strings (or None); status is its string value.
        Note: `conditioning`, `project_id`, `preset_used`, and `tags` are
        intentionally not included in the response payload.
        """
        return {
            "id": self.id,
            "model_id": self.model_id,
            "variant": self.variant,
            "prompts": self.prompts,
            "parameters": self.parameters,
            "status": self.status.value,
            "progress": self.progress,
            "progress_message": self.progress_message,
            "created_at": self.created_at.isoformat(),
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
            "result_id": self.result_id,
            "audio_path": self.audio_path,
            "error": self.error,
        }
|
||||||
|
|
||||||
|
|
||||||
|
class BatchProcessor:
    """Manages a queue of generation jobs.

    Features:
    - Async job queue with configurable concurrency
    - Progress tracking and callbacks
    - Job cancellation
    - Priority support (future enhancement)
    """

    def __init__(
        self,
        generation_service: Any,  # Avoid circular import
        max_queue_size: int = 100,
        max_concurrent: int = 1,  # GPU operations should be serialized
    ):
        """Initialize batch processor.

        Args:
            generation_service: GenerationService instance
            max_queue_size: Maximum jobs in queue
            max_concurrent: Maximum concurrent generations (usually 1 for GPU)
        """
        self.generation_service = generation_service
        self.max_queue_size = max_queue_size
        self.max_concurrent = max_concurrent

        # Job tracking: _jobs holds every submitted job (including finished
        # ones until cleanup_completed runs); _queue holds only pending IDs.
        self._jobs: dict[str, GenerationJob] = {}
        self._queue: asyncio.Queue[str] = asyncio.Queue(maxsize=max_queue_size)

        # Processing control
        self._workers: list[asyncio.Task] = []
        self._running = False
        self._lock = asyncio.Lock()

        # Callbacks (fan-out lists; each callback is isolated from the others
        # by a per-callback try/except in _process_job)
        self._on_job_complete: list[Callable[[GenerationJob], None]] = []
        self._on_job_failed: list[Callable[[GenerationJob], None]] = []
        self._on_progress: list[Callable[[GenerationJob], None]] = []

    async def start(self) -> None:
        """Start the batch processor workers. Idempotent: a second call is a no-op."""
        if self._running:
            return

        self._running = True

        # Start worker tasks
        for i in range(self.max_concurrent):
            worker = asyncio.create_task(self._worker_loop(i))
            self._workers.append(worker)

        logger.info(f"Batch processor started with {self.max_concurrent} worker(s)")

    async def stop(self) -> None:
        """Stop the batch processor.

        Cancels the worker tasks and waits for them to exit. Jobs still in
        the queue are NOT drained first; they remain PENDING in ``_jobs``
        (a job mid-generation is interrupted by the task cancellation).
        """
        if not self._running:
            return

        self._running = False

        # Cancel workers
        for worker in self._workers:
            worker.cancel()

        # Wait for workers to finish; return_exceptions swallows the
        # CancelledError each cancelled task raises.
        await asyncio.gather(*self._workers, return_exceptions=True)
        self._workers.clear()

        logger.info("Batch processor stopped")

    async def submit(self, job: GenerationJob) -> GenerationJob:
        """Submit a job to the queue.

        Args:
            job: Job to submit

        Returns:
            The submitted job with ID

        Raises:
            RuntimeError: If queue is full

        Note: capacity is checked against len(_jobs), which also counts
        completed/failed jobs until cleanup_completed() prunes them — so a
        "full" queue may contain finished work. This same check guarantees
        the bounded _queue.put below cannot block while the lock is held.
        """
        async with self._lock:
            if len(self._jobs) >= self.max_queue_size:
                raise RuntimeError(
                    f"Queue full (max {self.max_queue_size} jobs). "
                    "Please wait for jobs to complete."
                )

            self._jobs[job.id] = job
            await self._queue.put(job.id)

        logger.info(f"Job {job.id} submitted to queue (position: {self._queue.qsize()})")
        return job

    async def cancel(self, job_id: str) -> bool:
        """Cancel a pending job.

        Args:
            job_id: ID of job to cancel

        Returns:
            True if job was cancelled, False if not found or already processing

        The job's ID stays in the queue; the worker loop skips CANCELLED
        jobs when it dequeues them.
        """
        async with self._lock:
            job = self._jobs.get(job_id)
            if job is None:
                return False

            if job.status != JobStatus.PENDING:
                logger.warning(f"Cannot cancel job {job_id} with status {job.status}")
                return False

            job.status = JobStatus.CANCELLED
            # NOTE(review): naive UTC timestamp via utcnow() — deprecated in
            # Python 3.12; kept for consistency with the rest of the module.
            job.completed_at = datetime.utcnow()

        logger.info(f"Job {job_id} cancelled")
        return True

    def get_job(self, job_id: str) -> Optional[GenerationJob]:
        """Get a job by ID."""
        return self._jobs.get(job_id)

    def get_queue_status(self) -> dict[str, Any]:
        """Get current queue status: queue depth, totals, and a per-status histogram."""
        jobs_by_status = {}
        for job in self._jobs.values():
            status = job.status.value
            jobs_by_status[status] = jobs_by_status.get(status, 0) + 1

        return {
            "queue_size": self._queue.qsize(),
            "total_jobs": len(self._jobs),
            "jobs_by_status": jobs_by_status,
            "running": self._running,
            "max_queue_size": self.max_queue_size,
        }

    def list_jobs(
        self,
        status: Optional[JobStatus] = None,
        limit: int = 50,
    ) -> list[GenerationJob]:
        """List jobs with optional status filter.

        Args:
            status: Filter by status
            limit: Maximum jobs to return

        Returns:
            List of jobs ordered by creation time (newest first)
        """
        jobs = list(self._jobs.values())

        if status:
            jobs = [j for j in jobs if j.status == status]

        # Sort by created_at descending
        jobs.sort(key=lambda j: j.created_at, reverse=True)

        return jobs[:limit]

    def cleanup_completed(self, max_age_hours: float = 24.0) -> int:
        """Remove old completed/failed jobs from memory.

        Args:
            max_age_hours: Remove jobs older than this

        Returns:
            Number of jobs removed
        """
        # Compare epoch seconds rather than datetimes to keep the math simple.
        cutoff = datetime.utcnow().timestamp() - (max_age_hours * 3600)
        removed = 0

        # list() snapshot allows deletion while iterating.
        for job_id, job in list(self._jobs.items()):
            if job.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED):
                if job.completed_at and job.completed_at.timestamp() < cutoff:
                    del self._jobs[job_id]
                    removed += 1

        if removed:
            logger.info(f"Cleaned up {removed} old jobs")

        return removed

    async def _worker_loop(self, worker_id: int) -> None:
        """Worker loop that processes jobs from queue.

        Polls with a 1s timeout so the loop re-checks ``self._running``
        regularly and can exit cleanly even when the queue is idle.
        """
        logger.debug(f"Worker {worker_id} started")

        while self._running:
            try:
                # Wait for job with timeout
                try:
                    job_id = await asyncio.wait_for(
                        self._queue.get(), timeout=1.0
                    )
                except asyncio.TimeoutError:
                    continue

                # Skip jobs that were cancelled (or pruned) after being queued.
                job = self._jobs.get(job_id)
                if job is None or job.status == JobStatus.CANCELLED:
                    continue

                await self._process_job(job)

            except asyncio.CancelledError:
                break
            except Exception as e:
                # Never let a single bad job kill the worker.
                logger.error(f"Worker {worker_id} error: {e}")

        logger.debug(f"Worker {worker_id} stopped")

    async def _process_job(self, job: GenerationJob) -> None:
        """Process a single generation job.

        Transitions the job PENDING -> PROCESSING -> COMPLETED/FAILED,
        relays progress to registered callbacks, and records the result
        (or error) on the job object itself.
        """
        logger.info(f"Processing job {job.id}: {job.model_id}/{job.variant}")

        job.status = JobStatus.PROCESSING
        job.started_at = datetime.utcnow()

        def progress_callback(progress: float, message: str) -> None:
            # Mirror progress onto the job, then fan out; a failing callback
            # must not break the others or the generation itself.
            job.progress = progress
            job.progress_message = message
            for callback in self._on_progress:
                try:
                    callback(job)
                except Exception as e:
                    logger.error(f"Progress callback error: {e}")

        try:
            result, generation = await self.generation_service.generate(
                model_id=job.model_id,
                variant=job.variant,
                prompts=job.prompts,
                duration=job.parameters.get("duration", 10.0),
                temperature=job.parameters.get("temperature", 1.0),
                top_k=job.parameters.get("top_k", 250),
                top_p=job.parameters.get("top_p", 0.0),
                cfg_coef=job.parameters.get("cfg_coef", 3.0),
                seed=job.parameters.get("seed"),
                conditioning=job.conditioning,
                project_id=job.project_id,
                preset_used=job.preset_used,
                tags=job.tags,
                progress_callback=progress_callback,
            )

            job.status = JobStatus.COMPLETED
            job.result_id = generation.id
            job.audio_path = generation.audio_path
            job.completed_at = datetime.utcnow()
            job.progress = 1.0
            job.progress_message = "Complete"

            logger.info(f"Job {job.id} completed: {generation.id}")

            for callback in self._on_job_complete:
                try:
                    callback(job)
                except Exception as e:
                    logger.error(f"Completion callback error: {e}")

        except Exception as e:
            job.status = JobStatus.FAILED
            job.error = str(e)
            job.completed_at = datetime.utcnow()

            logger.error(f"Job {job.id} failed: {e}")

            for callback in self._on_job_failed:
                try:
                    callback(job)
                except Exception as e2:
                    logger.error(f"Failure callback error: {e2}")

    # Callback registration

    def on_job_complete(self, callback: Callable[[GenerationJob], None]) -> None:
        """Register callback for job completion."""
        self._on_job_complete.append(callback)

    def on_job_failed(self, callback: Callable[[GenerationJob], None]) -> None:
        """Register callback for job failure."""
        self._on_job_failed.append(callback)

    def on_progress(self, callback: Callable[[GenerationJob], None]) -> None:
        """Register callback for progress updates."""
        self._on_progress.append(callback)
|
||||||
322
src/services/generation_service.py
Normal file
322
src/services/generation_service.py
Normal file
@@ -0,0 +1,322 @@
|
|||||||
|
"""Generation service for orchestrating audio generation."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from src.core.base_model import GenerationRequest, GenerationResult
|
||||||
|
from src.core.gpu_manager import GPUMemoryManager
|
||||||
|
from src.core.model_registry import ModelRegistry
|
||||||
|
from src.core.oom_handler import OOMHandler
|
||||||
|
from src.storage.database import Database, Generation
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GenerationService:
    """Orchestrates audio generation across all models.

    Handles:
    - Model selection and loading
    - Generation execution with OOM recovery
    - Result saving and database recording
    - Progress callbacks for UI updates
    """

    def __init__(
        self,
        registry: ModelRegistry,
        gpu_manager: GPUMemoryManager,
        database: Database,
        output_dir: Path,
    ):
        """Initialize generation service.

        Args:
            registry: Model registry for model access
            gpu_manager: GPU memory manager
            database: Database for storing generation records
            output_dir: Directory for saving generated audio
        """
        self.registry = registry
        self.gpu_manager = gpu_manager
        self.database = database
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # OOM handler
        self.oom_handler = OOMHandler(gpu_manager, registry)

        # Statistics (in-memory only; reset on process restart)
        self._generation_count = 0
        self._total_duration_generated = 0.0

    async def generate(
        self,
        model_id: str,
        variant: Optional[str],
        prompts: list[str],
        duration: float = 10.0,
        temperature: float = 1.0,
        top_k: int = 250,
        top_p: float = 0.0,
        cfg_coef: float = 3.0,
        seed: Optional[int] = None,
        conditioning: Optional[dict[str, Any]] = None,
        project_id: Optional[str] = None,
        preset_used: Optional[str] = None,
        tags: Optional[list[str]] = None,
        progress_callback: Optional[Callable[[float, str], None]] = None,
    ) -> tuple[GenerationResult, Generation]:
        """Generate audio and save to database.

        Args:
            model_id: Model family to use
            variant: Model variant (None for default)
            prompts: Text prompts for generation
            duration: Target duration in seconds
            temperature: Sampling temperature
            top_k: Top-k sampling parameter
            top_p: Nucleus sampling parameter
            cfg_coef: Classifier-free guidance coefficient
            seed: Random seed for reproducibility
            conditioning: Optional conditioning data (melody, style, chords, etc.)
            project_id: Optional project to associate with
            preset_used: Name of preset used (for metadata)
            tags: Optional tags for organization
            progress_callback: Optional callback for progress updates

        Returns:
            Tuple of (GenerationResult, Generation database record)

        Raises:
            ValueError: If model not found or parameters invalid
            RuntimeError: If generation fails
        """
        start_time = time.time()

        # Report progress
        if progress_callback:
            progress_callback(0.0, "Preparing generation...")

        # Build generation request
        request = GenerationRequest(
            prompts=prompts,
            duration=duration,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            cfg_coef=cfg_coef,
            seed=seed,
            conditioning=conditioning or {},
        )

        # Get model configuration
        family_config, variant_config = self.registry.get_model_config(model_id, variant)
        actual_variant = variant or family_config.default_variant

        # Check VRAM availability
        if progress_callback:
            progress_callback(0.1, "Checking GPU memory...")

        can_load, reason = self.gpu_manager.can_load_model(variant_config.vram_mb)
        if not can_load:
            # Try OOM recovery: check_memory_for_operation gets a chance to
            # free memory before we give up.
            if not self.oom_handler.check_memory_for_operation(variant_config.vram_mb):
                raise RuntimeError(f"Insufficient GPU memory: {reason}")

        # Generate with OOM recovery wrapper
        if progress_callback:
            progress_callback(0.2, f"Loading {model_id}/{actual_variant}...")

        # NOTE(review): do_generation is synchronous and called directly from
        # this coroutine, so the event loop is blocked for the duration of the
        # model run — confirm this is acceptable (e.g. single worker) or move
        # it to a thread executor.
        @self.oom_handler.with_oom_recovery
        def do_generation() -> GenerationResult:
            with self.registry.get_model(model_id, actual_variant) as model:
                if progress_callback:
                    progress_callback(0.4, "Generating audio...")
                return model.generate(request)

        result = do_generation()

        if progress_callback:
            progress_callback(0.8, "Saving audio...")

        # Save audio file
        audio_path = self._save_audio(result)

        # Create database record; multiple prompts are joined with newlines
        # (regenerate() splits them back apart on "\n").
        generation = Generation.create(
            model=model_id,
            variant=actual_variant,
            prompt=prompts[0] if len(prompts) == 1 else "\n".join(prompts),
            parameters={
                "duration": duration,
                "temperature": temperature,
                "top_k": top_k,
                "top_p": top_p,
                "cfg_coef": cfg_coef,
            },
            project_id=project_id,
            preset_used=preset_used,
            conditioning=conditioning,
            audio_path=str(audio_path),
            duration_seconds=result.duration,
            sample_rate=result.sample_rate,
            tags=tags or [],
            seed=result.seed,
        )

        # Save to database
        await self.database.create_generation(generation)

        # Update statistics
        self._generation_count += 1
        self._total_duration_generated += result.duration

        elapsed = time.time() - start_time
        logger.info(
            f"Generation complete: {model_id}/{actual_variant}, "
            f"duration={result.duration:.1f}s, elapsed={elapsed:.1f}s"
        )

        if progress_callback:
            progress_callback(1.0, "Complete!")

        return result, generation

    def _save_audio(self, result: GenerationResult) -> Path:
        """Save generated audio to file.

        Args:
            result: Generation result with audio tensor

        Returns:
            Path to saved audio file
        """
        # Generate unique filename (millisecond timestamp keeps names unique
        # for sequential saves)
        timestamp = int(time.time() * 1000)
        filename = f"{result.model_id}_{result.variant}_{timestamp}.wav"
        filepath = self.output_dir / filename

        # Convert tensor to numpy and save
        # NOTE(review): .numpy() requires a CPU tensor — torch raises for CUDA
        # tensors; confirm the model/result pipeline moves audio to CPU first.
        audio = result.audio.numpy()

        # Handle batch dimension - save first sample if batched
        # (remaining batch items are silently dropped here)
        if audio.ndim == 3:
            audio = audio[0]  # [channels, samples]

        # Transpose to [samples, channels] for soundfile
        if audio.ndim == 2:
            audio = audio.T

        sf.write(filepath, audio, result.sample_rate)

        logger.debug(f"Saved audio to {filepath}")
        return filepath

    async def regenerate(
        self,
        generation_id: str,
        new_seed: Optional[int] = None,
        progress_callback: Optional[Callable[[float, str], None]] = None,
    ) -> tuple[GenerationResult, Generation]:
        """Regenerate audio using parameters from existing generation.

        Args:
            generation_id: ID of generation to regenerate
            new_seed: Optional new seed (uses original if None)
            progress_callback: Optional progress callback

        Returns:
            Tuple of (GenerationResult, new Generation record)

        Raises:
            ValueError: If generation not found
        """
        # Load original generation
        original = await self.database.get_generation(generation_id)
        if original is None:
            raise ValueError(f"Generation not found: {generation_id}")

        # Parse prompts (inverse of the "\n".join done in generate())
        prompts = original.prompt.split("\n") if "\n" in original.prompt else [original.prompt]

        # Regenerate with same or new seed
        return await self.generate(
            model_id=original.model,
            variant=original.variant,
            prompts=prompts,
            duration=original.parameters.get("duration", 10.0),
            temperature=original.parameters.get("temperature", 1.0),
            top_k=original.parameters.get("top_k", 250),
            top_p=original.parameters.get("top_p", 0.0),
            cfg_coef=original.parameters.get("cfg_coef", 3.0),
            seed=new_seed if new_seed is not None else original.seed,
            conditioning=original.conditioning,
            project_id=original.project_id,
            preset_used=original.preset_used,
            tags=original.tags,
            progress_callback=progress_callback,
        )

    def get_stats(self) -> dict[str, Any]:
        """Get generation statistics.

        Returns:
            Dictionary with generation stats
        """
        return {
            "generation_count": self._generation_count,
            "total_duration_generated": self._total_duration_generated,
            "oom_stats": self.oom_handler.get_stats(),
        }

    def estimate_generation_time(
        self, model_id: str, variant: Optional[str], duration: float
    ) -> float:
        """Estimate generation time for given parameters.

        Args:
            model_id: Model family
            variant: Model variant
            duration: Target audio duration

        Returns:
            Estimated generation time in seconds
        """
        # Rough estimates based on model type and RTX 4090
        # These are approximations and vary based on many factors
        estimates = {
            "musicgen": {
                "small": 0.8,  # seconds per second of audio
                "medium": 1.5,
                "large": 3.0,
                "melody": 1.8,
            },
            "audiogen": {
                "medium": 1.5,
            },
            "magnet": {
                "small-10secs": 0.3,  # Non-autoregressive is faster
                "medium-10secs": 0.5,
                "small-30secs": 0.3,
                "medium-30secs": 0.5,
            },
            "musicgen-style": {
                "medium": 1.8,
            },
            "jasco": {
                "chords-drums-400M": 1.0,
                "chords-drums-1B": 1.5,
            },
        }

        family_config, _ = self.registry.get_model_config(model_id, variant)
        actual_variant = variant or family_config.default_variant

        # Unknown model/variant pairs fall back to a conservative 2.0x ratio.
        ratio = estimates.get(model_id, {}).get(actual_variant, 2.0)
        return duration * ratio + 5.0  # Add 5s for model loading overhead
|
||||||
395
src/services/project_service.py
Normal file
395
src/services/project_service.py
Normal file
@@ -0,0 +1,395 @@
|
|||||||
|
"""Project service for managing projects and generations."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from src.storage.database import Database, Generation, Project, Preset
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectService:
|
||||||
|
"""Service for managing projects, generations, and presets.
|
||||||
|
|
||||||
|
Provides a high-level API for project organization and
|
||||||
|
generation history management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, database: Database, output_dir: Path):
|
||||||
|
"""Initialize project service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database: Database instance
|
||||||
|
output_dir: Directory where audio files are stored
|
||||||
|
"""
|
||||||
|
self.database = database
|
||||||
|
self.output_dir = Path(output_dir)
|
||||||
|
|
||||||
|
# Project Operations
|
||||||
|
|
||||||
|
async def create_project(
|
||||||
|
self, name: str, description: str = ""
|
||||||
|
) -> Project:
|
||||||
|
"""Create a new project.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Project name
|
||||||
|
description: Optional description
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created project
|
||||||
|
"""
|
||||||
|
project = Project.create(name, description)
|
||||||
|
await self.database.create_project(project)
|
||||||
|
logger.info(f"Created project: {project.id} ({name})")
|
||||||
|
return project
|
||||||
|
|
||||||
|
async def get_project(self, project_id: str) -> Optional[Project]:
|
||||||
|
"""Get a project by ID."""
|
||||||
|
return await self.database.get_project(project_id)
|
||||||
|
|
||||||
|
async def list_projects(
|
||||||
|
self, limit: int = 100, offset: int = 0
|
||||||
|
) -> list[Project]:
|
||||||
|
"""List all projects."""
|
||||||
|
return await self.database.list_projects(limit, offset)
|
||||||
|
|
||||||
|
async def update_project(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
) -> Optional[Project]:
|
||||||
|
"""Update a project.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project ID
|
||||||
|
name: New name (None to keep current)
|
||||||
|
description: New description (None to keep current)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated project, or None if not found
|
||||||
|
"""
|
||||||
|
project = await self.database.get_project(project_id)
|
||||||
|
if project is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if name is not None:
|
||||||
|
project.name = name
|
||||||
|
if description is not None:
|
||||||
|
project.description = description
|
||||||
|
|
||||||
|
await self.database.update_project(project)
|
||||||
|
logger.info(f"Updated project: {project_id}")
|
||||||
|
return project
|
||||||
|
|
||||||
|
async def delete_project(
|
||||||
|
self, project_id: str, delete_files: bool = False
|
||||||
|
) -> bool:
|
||||||
|
"""Delete a project.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project ID
|
||||||
|
delete_files: If True, also delete associated audio files
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deleted
|
||||||
|
"""
|
||||||
|
if delete_files:
|
||||||
|
# Get all generations and delete their files
|
||||||
|
generations = await self.database.list_generations(project_id=project_id)
|
||||||
|
for gen in generations:
|
||||||
|
if gen.audio_path:
|
||||||
|
try:
|
||||||
|
Path(gen.audio_path).unlink(missing_ok=True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to delete {gen.audio_path}: {e}")
|
||||||
|
|
||||||
|
result = await self.database.delete_project(project_id)
|
||||||
|
if result:
|
||||||
|
logger.info(f"Deleted project: {project_id}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def get_project_stats(self, project_id: str) -> dict[str, Any]:
|
||||||
|
"""Get statistics for a project.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with project statistics
|
||||||
|
"""
|
||||||
|
generations = await self.database.list_generations(
|
||||||
|
project_id=project_id, limit=10000
|
||||||
|
)
|
||||||
|
|
||||||
|
total_duration = sum(g.duration_seconds or 0 for g in generations)
|
||||||
|
models_used = {}
|
||||||
|
for gen in generations:
|
||||||
|
key = f"{gen.model}/{gen.variant}"
|
||||||
|
models_used[key] = models_used.get(key, 0) + 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
"generation_count": len(generations),
|
||||||
|
"total_duration_seconds": total_duration,
|
||||||
|
"models_used": models_used,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Generation Operations
|
||||||
|
|
||||||
|
async def get_generation(self, generation_id: str) -> Optional[Generation]:
|
||||||
|
"""Get a generation by ID."""
|
||||||
|
return await self.database.get_generation(generation_id)
|
||||||
|
|
||||||
|
async def list_generations(
|
||||||
|
self,
|
||||||
|
project_id: Optional[str] = None,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
search: Optional[str] = None,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> list[Generation]:
|
||||||
|
"""List generations with optional filters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Filter by project
|
||||||
|
model: Filter by model family
|
||||||
|
search: Search in prompts, names, and tags
|
||||||
|
limit: Maximum results
|
||||||
|
offset: Pagination offset
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of generations
|
||||||
|
"""
|
||||||
|
return await self.database.list_generations(
|
||||||
|
project_id=project_id,
|
||||||
|
model=model,
|
||||||
|
search=search,
|
||||||
|
limit=limit,
|
||||||
|
offset=offset,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def update_generation(
|
||||||
|
self,
|
||||||
|
generation_id: str,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
tags: Optional[list[str]] = None,
|
||||||
|
notes: Optional[str] = None,
|
||||||
|
project_id: Optional[str] = None,
|
||||||
|
) -> Optional[Generation]:
|
||||||
|
"""Update a generation's metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generation_id: Generation ID
|
||||||
|
name: New name
|
||||||
|
tags: New tags (replaces existing)
|
||||||
|
notes: New notes
|
||||||
|
project_id: Move to different project
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated generation, or None if not found
|
||||||
|
"""
|
||||||
|
generation = await self.database.get_generation(generation_id)
|
||||||
|
if generation is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if name is not None:
|
||||||
|
generation.name = name
|
||||||
|
if tags is not None:
|
||||||
|
generation.tags = tags
|
||||||
|
if notes is not None:
|
||||||
|
generation.notes = notes
|
||||||
|
if project_id is not None:
|
||||||
|
generation.project_id = project_id
|
||||||
|
|
||||||
|
await self.database.update_generation(generation)
|
||||||
|
logger.info(f"Updated generation: {generation_id}")
|
||||||
|
return generation
|
||||||
|
|
||||||
|
async def delete_generation(
|
||||||
|
self, generation_id: str, delete_file: bool = True
|
||||||
|
) -> bool:
|
||||||
|
"""Delete a generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generation_id: Generation ID
|
||||||
|
delete_file: If True, also delete audio file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deleted
|
||||||
|
"""
|
||||||
|
if delete_file:
|
||||||
|
generation = await self.database.get_generation(generation_id)
|
||||||
|
if generation and generation.audio_path:
|
||||||
|
try:
|
||||||
|
Path(generation.audio_path).unlink(missing_ok=True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to delete audio file: {e}")
|
||||||
|
|
||||||
|
result = await self.database.delete_generation(generation_id)
|
||||||
|
if result:
|
||||||
|
logger.info(f"Deleted generation: {generation_id}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def move_generations_to_project(
|
||||||
|
self, generation_ids: list[str], project_id: Optional[str]
|
||||||
|
) -> int:
|
||||||
|
"""Move multiple generations to a project.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generation_ids: List of generation IDs
|
||||||
|
project_id: Target project ID (None to unlink)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of generations moved
|
||||||
|
"""
|
||||||
|
moved = 0
|
||||||
|
for gen_id in generation_ids:
|
||||||
|
result = await self.update_generation(gen_id, project_id=project_id)
|
||||||
|
if result:
|
||||||
|
moved += 1
|
||||||
|
|
||||||
|
logger.info(f"Moved {moved} generations to project {project_id}")
|
||||||
|
return moved
|
||||||
|
|
||||||
|
# Preset Operations
|
||||||
|
|
||||||
|
async def create_preset(
    self,
    model: str,
    name: str,
    parameters: dict[str, Any],
    description: str = "",
) -> Preset:
    """Persist a new user-defined (non-builtin) preset.

    Args:
        model: Model family this preset is for
        name: Preset name
        parameters: Generation parameters
        description: Optional description

    Returns:
        Created preset
    """
    new_preset = Preset.create(model, name, parameters, description)
    await self.database.create_preset(new_preset)
    logger.info(f"Created preset: {new_preset.id} ({name}) for {model}")
    return new_preset
|
||||||
|
|
||||||
|
async def list_presets(
    self, model: Optional[str] = None, include_builtin: bool = True
) -> list[Preset]:
    """Fetch presets, optionally restricted to one model family.

    Args:
        model: Filter by model family
        include_builtin: Include built-in presets

    Returns:
        List of presets
    """
    # Pure delegation; filtering happens in the database layer.
    return await self.database.list_presets(model, include_builtin)


async def get_preset(self, preset_id: str) -> Optional[Preset]:
    """Look up a single preset by its ID (None when absent)."""
    return await self.database.get_preset(preset_id)
|
||||||
|
|
||||||
|
async def delete_preset(self, preset_id: str) -> bool:
    """Delete a custom preset.

    Note: Built-in presets cannot be deleted (enforced by the database
    layer's WHERE clause).

    Args:
        preset_id: Preset ID

    Returns:
        True if deleted
    """
    removed = await self.database.delete_preset(preset_id)
    if removed:
        logger.info(f"Deleted preset: {preset_id}")
    return removed
|
||||||
|
|
||||||
|
# Export Operations
|
||||||
|
|
||||||
|
async def export_project(
    self, project_id: str, output_path: Path, include_metadata: bool = True
) -> Path:
    """Export a project as a ZIP archive.

    The archive contains every existing audio file of the project's
    generations, plus (optionally) a metadata.json describing the project
    and each generation.

    Args:
        project_id: Project ID
        output_path: Output ZIP file path
        include_metadata: Include JSON metadata file

    Returns:
        Path to created ZIP file

    Raises:
        ValueError: If the project does not exist.
    """
    import json
    import shutil
    import tempfile
    import zipfile

    project = await self.database.get_project(project_id)
    if project is None:
        raise ValueError(f"Project not found: {project_id}")

    generations = await self.database.list_generations(
        project_id=project_id, limit=10000
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        tmppath = Path(tmpdir)

        # Copy audio files into the staging dir. Basenames are de-duplicated:
        # previously two generations whose audio files shared a filename
        # silently overwrote each other in the archive.
        archive_names: dict[str, str] = {}  # generation id -> name in the ZIP
        used_names: set[str] = set()
        for gen in generations:
            if gen.audio_path and Path(gen.audio_path).exists():
                src = Path(gen.audio_path)
                candidate = src.name
                suffix = 1
                while candidate in used_names:
                    candidate = f"{src.stem}_{suffix}{src.suffix}"
                    suffix += 1
                used_names.add(candidate)
                archive_names[gen.id] = candidate
                shutil.copy2(src, tmppath / candidate)

        # Optional JSON manifest describing project + generations.
        if include_metadata:
            metadata = {
                "project": {
                    "id": project.id,
                    "name": project.name,
                    "description": project.description,
                    "created_at": project.created_at.isoformat(),
                },
                "generations": [
                    {
                        "id": g.id,
                        "model": g.model,
                        "variant": g.variant,
                        "prompt": g.prompt,
                        "parameters": g.parameters,
                        "duration": g.duration_seconds,
                        # Name of the copied file inside the archive; falls
                        # back to the original basename for missing files.
                        "audio_file": archive_names.get(
                            g.id,
                            Path(g.audio_path).name if g.audio_path else None,
                        ),
                        "created_at": g.created_at.isoformat(),
                        "tags": g.tags,
                        "seed": g.seed,
                    }
                    for g in generations
                ],
            }

            metadata_path = tmppath / "metadata.json"
            metadata_path.write_text(json.dumps(metadata, indent=2))

        # Bundle the staging dir into the requested ZIP.
        output_path = Path(output_path)
        with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zf:
            for file in tmppath.iterdir():
                zf.write(file, file.name)

    logger.info(f"Exported project {project_id} to {output_path}")
    return output_path
|
||||||
|
|
||||||
|
# Statistics
|
||||||
|
|
||||||
|
async def get_stats(self) -> dict[str, Any]:
    """Return aggregate statistics computed by the backing database."""
    return await self.database.get_stats()
|
||||||
5
src/storage/__init__.py
Normal file
5
src/storage/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Storage module for AudioCraft Studio."""
|
||||||
|
|
||||||
|
from src.storage.database import Database, Generation, Project, Preset
|
||||||
|
|
||||||
|
__all__ = ["Database", "Generation", "Project", "Preset"]
|
||||||
550
src/storage/database.py
Normal file
550
src/storage/database.py
Normal file
@@ -0,0 +1,550 @@
|
|||||||
|
"""SQLite database for projects, generations, and presets."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Project:
    """Project entity for organizing generations."""

    id: str  # "proj_" + 12 hex chars, assigned by create()
    name: str
    created_at: datetime
    updated_at: datetime
    description: str = ""

    @classmethod
    def create(cls, name: str, description: str = "") -> "Project":
        """Build a fresh project with a random ID and matching timestamps."""
        timestamp = datetime.utcnow()
        return cls(
            id=f"proj_{uuid.uuid4().hex[:12]}",
            name=name,
            description=description,
            created_at=timestamp,
            updated_at=timestamp,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Generation:
    """Audio generation record."""

    # Identity / provenance
    id: str  # "gen_" + 12 hex chars, assigned by create()
    project_id: Optional[str]  # owning project, or None when unlinked
    model: str
    variant: str
    prompt: str
    parameters: dict[str, Any]
    created_at: datetime
    # Output metadata (filled once audio has been produced)
    audio_path: Optional[str] = None
    duration_seconds: Optional[float] = None
    sample_rate: Optional[int] = None
    # Optional extras
    preset_used: Optional[str] = None
    conditioning: dict[str, Any] = field(default_factory=dict)
    name: Optional[str] = None
    tags: list[str] = field(default_factory=list)
    notes: Optional[str] = None
    seed: Optional[int] = None

    @classmethod
    def create(
        cls,
        model: str,
        variant: str,
        prompt: str,
        parameters: dict[str, Any],
        project_id: Optional[str] = None,
        **kwargs,
    ) -> "Generation":
        """Create a new generation record with a random ID and current UTC time."""
        return cls(
            id=f"gen_{uuid.uuid4().hex[:12]}",
            created_at=datetime.utcnow(),
            model=model,
            variant=variant,
            prompt=prompt,
            parameters=parameters,
            project_id=project_id,
            **kwargs,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Preset:
    """Generation parameter preset."""

    id: str  # "preset_" + 12 hex chars for custom presets
    model: str  # model family the preset applies to
    parameters: dict[str, Any]
    name: str
    created_at: datetime
    description: str = ""
    is_builtin: bool = False  # built-ins cannot be deleted

    @classmethod
    def create(
        cls,
        model: str,
        name: str,
        parameters: dict[str, Any],
        description: str = "",
    ) -> "Preset":
        """Create a new custom (non-builtin) preset with a random ID."""
        return cls(
            id=f"preset_{uuid.uuid4().hex[:12]}",
            created_at=datetime.utcnow(),
            model=model,
            name=name,
            parameters=parameters,
            description=description,
            is_builtin=False,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class Database:
    """Async SQLite database for AudioCraft Studio.

    Handles storage of projects, generations, and presets. A single
    aiosqlite connection is opened by connect() and shared by all methods.
    """

    # Idempotent schema (IF NOT EXISTS throughout), executed on every connect().
    SCHEMA = """
    CREATE TABLE IF NOT EXISTS projects (
        id TEXT PRIMARY KEY,
        name TEXT NOT NULL,
        description TEXT DEFAULT '',
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
    );

    CREATE TABLE IF NOT EXISTS generations (
        id TEXT PRIMARY KEY,
        project_id TEXT REFERENCES projects(id) ON DELETE SET NULL,
        model TEXT NOT NULL,
        variant TEXT NOT NULL,
        prompt TEXT NOT NULL,
        parameters JSON NOT NULL,
        preset_used TEXT,
        conditioning JSON,
        audio_path TEXT,
        duration_seconds REAL,
        sample_rate INTEGER,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        name TEXT,
        tags JSON,
        notes TEXT,
        seed INTEGER
    );

    CREATE TABLE IF NOT EXISTS presets (
        id TEXT PRIMARY KEY,
        model TEXT NOT NULL,
        name TEXT NOT NULL,
        description TEXT DEFAULT '',
        parameters JSON NOT NULL,
        is_builtin BOOLEAN DEFAULT FALSE,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
    );

    CREATE INDEX IF NOT EXISTS idx_generations_project ON generations(project_id);
    CREATE INDEX IF NOT EXISTS idx_generations_created ON generations(created_at DESC);
    CREATE INDEX IF NOT EXISTS idx_generations_model ON generations(model);
    CREATE INDEX IF NOT EXISTS idx_presets_model ON presets(model);
    """

    def __init__(self, db_path: Path):
        """Initialize database.

        Args:
            db_path: Path to SQLite database file (parent dirs are created
                lazily by connect()).
        """
        self.db_path = db_path
        self._connection: Optional[aiosqlite.Connection] = None

    async def connect(self) -> None:
        """Open the connection, enable FK enforcement, and apply the schema."""
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._connection = await aiosqlite.connect(self.db_path)
        self._connection.row_factory = aiosqlite.Row

        # SQLite disables foreign-key enforcement per connection by default;
        # without this pragma the ON DELETE SET NULL action declared on
        # generations.project_id never fires, leaving dangling project ids
        # behind after a project is deleted.
        await self._connection.execute("PRAGMA foreign_keys = ON")

        # Initialize schema
        await self._connection.executescript(self.SCHEMA)
        await self._connection.commit()

        logger.info(f"Database connected: {self.db_path}")

    async def close(self) -> None:
        """Close database connection (safe to call when already closed)."""
        if self._connection:
            await self._connection.close()
            self._connection = None

    @property
    def conn(self) -> aiosqlite.Connection:
        """Active connection; raises RuntimeError if connect() was not called."""
        if not self._connection:
            raise RuntimeError("Database not connected")
        return self._connection
|
||||||
|
|
||||||
|
# Project Methods
|
||||||
|
|
||||||
|
async def create_project(self, project: Project) -> Project:
    """Insert a new project row; returns the same object for chaining."""
    await self.conn.execute(
        """
        INSERT INTO projects (id, name, description, created_at, updated_at)
        VALUES (?, ?, ?, ?, ?)
        """,
        (
            project.id,
            project.name,
            project.description,
            project.created_at.isoformat(),
            project.updated_at.isoformat(),
        ),
    )
    await self.conn.commit()
    return project


async def get_project(self, project_id: str) -> Optional[Project]:
    """Get a project by ID, or None when it does not exist."""
    async with self.conn.execute(
        "SELECT * FROM projects WHERE id = ?", (project_id,)
    ) as cursor:
        record = await cursor.fetchone()
    if record is None:
        return None
    return Project(
        id=record["id"],
        name=record["name"],
        description=record["description"] or "",
        created_at=datetime.fromisoformat(record["created_at"]),
        updated_at=datetime.fromisoformat(record["updated_at"]),
    )


async def list_projects(
    self, limit: int = 100, offset: int = 0
) -> list[Project]:
    """List all projects, most recently updated first."""
    async with self.conn.execute(
        """
        SELECT * FROM projects
        ORDER BY updated_at DESC
        LIMIT ? OFFSET ?
        """,
        (limit, offset),
    ) as cursor:
        records = await cursor.fetchall()
    projects: list[Project] = []
    for record in records:
        projects.append(
            Project(
                id=record["id"],
                name=record["name"],
                description=record["description"] or "",
                created_at=datetime.fromisoformat(record["created_at"]),
                updated_at=datetime.fromisoformat(record["updated_at"]),
            )
        )
    return projects


async def update_project(self, project: Project) -> None:
    """Persist name/description changes; bumps updated_at (mutates `project`)."""
    project.updated_at = datetime.utcnow()
    await self.conn.execute(
        """
        UPDATE projects SET name = ?, description = ?, updated_at = ?
        WHERE id = ?
        """,
        (project.name, project.description, project.updated_at.isoformat(), project.id),
    )
    await self.conn.commit()


async def delete_project(self, project_id: str) -> bool:
    """Delete a project (generations are kept but unlinked)."""
    cursor = await self.conn.execute(
        "DELETE FROM projects WHERE id = ?", (project_id,)
    )
    await self.conn.commit()
    return cursor.rowcount > 0
|
||||||
|
|
||||||
|
# Generation Methods
|
||||||
|
|
||||||
|
async def create_generation(self, generation: Generation) -> Generation:
    """Insert a new generation record.

    JSON-typed columns (parameters, conditioning, tags) are serialized here
    and deserialized again in _row_to_generation. When the generation is
    linked to a project, that project's updated_at is bumped in the same
    transaction so "recently active" project ordering stays correct.
    """
    await self.conn.execute(
        """
        INSERT INTO generations (
            id, project_id, model, variant, prompt, parameters,
            preset_used, conditioning, audio_path, duration_seconds,
            sample_rate, created_at, name, tags, notes, seed
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """,
        (
            generation.id,
            generation.project_id,
            generation.model,
            generation.variant,
            generation.prompt,
            json.dumps(generation.parameters),
            generation.preset_used,
            json.dumps(generation.conditioning),
            generation.audio_path,
            generation.duration_seconds,
            generation.sample_rate,
            generation.created_at.isoformat(),
            generation.name,
            json.dumps(generation.tags),
            generation.notes,
            generation.seed,
        ),
    )

    # Touch the owning project BEFORE committing so the insert and the
    # timestamp bump land atomically (the previous double-commit could leave
    # a generation inserted but its project's updated_at stale on failure).
    if generation.project_id:
        await self.conn.execute(
            "UPDATE projects SET updated_at = ? WHERE id = ?",
            (datetime.utcnow().isoformat(), generation.project_id),
        )
    await self.conn.commit()

    return generation


async def get_generation(self, generation_id: str) -> Optional[Generation]:
    """Get a generation by ID, or None when absent."""
    async with self.conn.execute(
        "SELECT * FROM generations WHERE id = ?", (generation_id,)
    ) as cursor:
        row = await cursor.fetchone()
        if row:
            return self._row_to_generation(row)
        return None


async def list_generations(
    self,
    project_id: Optional[str] = None,
    model: Optional[str] = None,
    limit: int = 100,
    offset: int = 0,
    search: Optional[str] = None,
) -> list[Generation]:
    """List generations, newest first, with optional filters.

    Args:
        project_id: Restrict to one project.
        model: Restrict to one model family.
        limit / offset: Pagination window.
        search: Case-insensitive substring match against prompt, name, or the
            raw JSON-encoded tags column (a LIKE over serialized JSON, so it
            may also match JSON punctuation — acceptable for a UI search box).
    """
    conditions = []
    params = []

    if project_id:
        conditions.append("project_id = ?")
        params.append(project_id)

    if model:
        conditions.append("model = ?")
        params.append(model)

    if search:
        conditions.append("(prompt LIKE ? OR name LIKE ? OR tags LIKE ?)")
        search_pattern = f"%{search}%"
        params.extend([search_pattern, search_pattern, search_pattern])

    where_clause = " AND ".join(conditions) if conditions else "1=1"

    # where_clause is built only from fixed fragments above; user input goes
    # exclusively through bound parameters, so this f-string is injection-safe.
    async with self.conn.execute(
        f"""
        SELECT * FROM generations
        WHERE {where_clause}
        ORDER BY created_at DESC
        LIMIT ? OFFSET ?
        """,
        (*params, limit, offset),
    ) as cursor:
        rows = await cursor.fetchall()
        return [self._row_to_generation(row) for row in rows]


async def update_generation(self, generation: Generation) -> None:
    """Update a generation record.

    Only the mutable columns (project_id, name, tags, notes, audio_path,
    duration_seconds, sample_rate) are written; model/variant/prompt/
    parameters are treated as immutable after insert.
    """
    await self.conn.execute(
        """
        UPDATE generations SET
            project_id = ?, name = ?, tags = ?, notes = ?,
            audio_path = ?, duration_seconds = ?, sample_rate = ?
        WHERE id = ?
        """,
        (
            generation.project_id,
            generation.name,
            json.dumps(generation.tags),
            generation.notes,
            generation.audio_path,
            generation.duration_seconds,
            generation.sample_rate,
            generation.id,
        ),
    )
    await self.conn.commit()


async def delete_generation(self, generation_id: str) -> bool:
    """Delete a generation record; True when a row was removed."""
    result = await self.conn.execute(
        "DELETE FROM generations WHERE id = ?", (generation_id,)
    )
    await self.conn.commit()
    return result.rowcount > 0


async def count_generations(
    self, project_id: Optional[str] = None, model: Optional[str] = None
) -> int:
    """Count generations matching the optional project/model filters."""
    conditions = []
    params = []

    if project_id:
        conditions.append("project_id = ?")
        params.append(project_id)

    if model:
        conditions.append("model = ?")
        params.append(model)

    where_clause = " AND ".join(conditions) if conditions else "1=1"

    async with self.conn.execute(
        f"SELECT COUNT(*) FROM generations WHERE {where_clause}",
        params,
    ) as cursor:
        row = await cursor.fetchone()
        return row[0] if row else 0
|
||||||
|
|
||||||
|
def _row_to_generation(self, row: aiosqlite.Row) -> Generation:
    """Convert database row to Generation object, decoding JSON columns."""
    # Nullable JSON columns decode to their empty-container defaults.
    raw_conditioning = row["conditioning"]
    raw_tags = row["tags"]
    return Generation(
        id=row["id"],
        project_id=row["project_id"],
        model=row["model"],
        variant=row["variant"],
        prompt=row["prompt"],
        parameters=json.loads(row["parameters"]),
        created_at=datetime.fromisoformat(row["created_at"]),
        preset_used=row["preset_used"],
        conditioning=json.loads(raw_conditioning) if raw_conditioning else {},
        audio_path=row["audio_path"],
        duration_seconds=row["duration_seconds"],
        sample_rate=row["sample_rate"],
        name=row["name"],
        tags=json.loads(raw_tags) if raw_tags else [],
        notes=row["notes"],
        seed=row["seed"],
    )
|
||||||
|
|
||||||
|
# Preset Methods
|
||||||
|
|
||||||
|
async def create_preset(self, preset: Preset) -> Preset:
    """Insert a preset row; parameters are stored as JSON text."""
    await self.conn.execute(
        """
        INSERT INTO presets (id, model, name, description, parameters, is_builtin, created_at)
        VALUES (?, ?, ?, ?, ?, ?, ?)
        """,
        (
            preset.id,
            preset.model,
            preset.name,
            preset.description,
            json.dumps(preset.parameters),
            preset.is_builtin,
            preset.created_at.isoformat(),
        ),
    )
    await self.conn.commit()
    return preset


async def get_preset(self, preset_id: str) -> Optional[Preset]:
    """Get a preset by ID, or None when absent."""
    async with self.conn.execute(
        "SELECT * FROM presets WHERE id = ?", (preset_id,)
    ) as cursor:
        record = await cursor.fetchone()
    return self._row_to_preset(record) if record else None


async def list_presets(
    self, model: Optional[str] = None, include_builtin: bool = True
) -> list[Preset]:
    """List presets: built-ins first, then alphabetical by name."""
    filters = []
    args: list = []

    if model:
        filters.append("model = ?")
        args.append(model)

    if not include_builtin:
        filters.append("is_builtin = FALSE")

    where_sql = " AND ".join(filters) if filters else "1=1"

    # where_sql only contains fixed fragments; values go through bindings.
    async with self.conn.execute(
        f"""
        SELECT * FROM presets
        WHERE {where_sql}
        ORDER BY is_builtin DESC, name ASC
        """,
        args,
    ) as cursor:
        records = await cursor.fetchall()
    return [self._row_to_preset(record) for record in records]


async def delete_preset(self, preset_id: str) -> bool:
    """Delete a preset; built-ins are protected by the WHERE clause."""
    cursor = await self.conn.execute(
        "DELETE FROM presets WHERE id = ? AND is_builtin = FALSE",
        (preset_id,),
    )
    await self.conn.commit()
    return cursor.rowcount > 0
|
||||||
|
|
||||||
|
def _row_to_preset(self, row: aiosqlite.Row) -> Preset:
    """Convert database row to Preset object, decoding the JSON parameters."""
    return Preset(
        id=row["id"],
        model=row["model"],
        name=row["name"],
        parameters=json.loads(row["parameters"]),
        description=row["description"] or "",
        created_at=datetime.fromisoformat(row["created_at"]),
        # SQLite stores booleans as 0/1 integers; normalize to bool.
        is_builtin=bool(row["is_builtin"]),
    )
|
||||||
|
|
||||||
|
# Utility Methods
|
||||||
|
|
||||||
|
async def get_stats(self) -> dict[str, Any]:
    """Get database statistics: row counts plus a per-model breakdown."""
    stats: dict[str, Any] = {}

    # Table names come from a fixed tuple, so the f-string query is safe.
    for table in ("projects", "generations", "presets"):
        async with self.conn.execute(f"SELECT COUNT(*) FROM {table}") as cursor:
            row = await cursor.fetchone()
            stats[table] = row[0] if row else 0

    async with self.conn.execute(
        "SELECT model, COUNT(*) as count FROM generations GROUP BY model"
    ) as cursor:
        rows = await cursor.fetchall()
        stats["generations_by_model"] = {row["model"]: row["count"] for row in rows}

    return stats
|
||||||
5
src/ui/__init__.py
Normal file
5
src/ui/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Gradio UI for AudioCraft Studio."""
|
||||||
|
|
||||||
|
from src.ui.app import create_app
|
||||||
|
|
||||||
|
__all__ = ["create_app"]
|
||||||
355
src/ui/app.py
Normal file
355
src/ui/app.py
Normal file
@@ -0,0 +1,355 @@
|
|||||||
|
"""Main Gradio application for AudioCraft Studio."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import gradio as gr
|
||||||
|
from typing import Any, Optional
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from src.ui.theme import create_audiocraft_theme, get_custom_css
|
||||||
|
from src.ui.state import UIState, DEFAULT_PRESETS, PROMPT_SUGGESTIONS
|
||||||
|
from src.ui.components.vram_monitor import create_vram_monitor
|
||||||
|
from src.ui.tabs import (
|
||||||
|
create_dashboard_tab,
|
||||||
|
create_musicgen_tab,
|
||||||
|
create_audiogen_tab,
|
||||||
|
create_magnet_tab,
|
||||||
|
create_style_tab,
|
||||||
|
create_jasco_tab,
|
||||||
|
)
|
||||||
|
from src.ui.pages import create_projects_page, create_settings_page
|
||||||
|
|
||||||
|
from config.settings import get_settings
|
||||||
|
|
||||||
|
|
||||||
|
class AudioCraftApp:
|
||||||
|
"""Main AudioCraft Studio Gradio application."""
|
||||||
|
|
||||||
|
def __init__(
    self,
    generation_service: Any = None,
    batch_processor: Any = None,
    project_service: Any = None,
    gpu_manager: Any = None,
    model_registry: Any = None,
):
    """Wire the UI up against its backing services.

    Args:
        generation_service: Service for handling generations
        batch_processor: Service for batch/queue processing
        project_service: Service for project management
        gpu_manager: GPU memory manager
        model_registry: Model registry for loading/unloading

    All services are optional; when one is missing, the corresponding UI
    feature degrades to a stub/default response.
    """
    self.settings = get_settings()

    # Backing services (any of these may be None).
    self.generation_service = generation_service
    self.batch_processor = batch_processor
    self.project_service = project_service
    self.gpu_manager = gpu_manager
    self.model_registry = model_registry

    # UI-local state; the Blocks app is assembled lazily by build().
    self.ui_state = UIState()
    self.app: Optional[gr.Blocks] = None
|
||||||
|
|
||||||
|
def _get_queue_status(self) -> dict[str, Any]:
|
||||||
|
"""Get current queue status."""
|
||||||
|
if self.batch_processor:
|
||||||
|
return {
|
||||||
|
"queue_size": len(self.batch_processor.queue),
|
||||||
|
"active_jobs": self.batch_processor.active_count,
|
||||||
|
"completed_today": self.batch_processor.completed_count,
|
||||||
|
}
|
||||||
|
return {"queue_size": 0, "active_jobs": 0, "completed_today": 0}
|
||||||
|
|
||||||
|
def _get_recent_generations(self, limit: int = 5) -> list[dict[str, Any]]:
|
||||||
|
"""Get recent generations."""
|
||||||
|
if self.project_service:
|
||||||
|
try:
|
||||||
|
return asyncio.run(self.project_service.get_recent_generations(limit))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _get_gpu_status(self) -> dict[str, Any]:
|
||||||
|
"""Get GPU memory status."""
|
||||||
|
if self.gpu_manager:
|
||||||
|
return {
|
||||||
|
"used_gb": self.gpu_manager.get_used_memory() / 1024**3,
|
||||||
|
"total_gb": self.gpu_manager.total_memory / 1024**3,
|
||||||
|
"utilization_percent": self.gpu_manager.get_utilization(),
|
||||||
|
"available_gb": self.gpu_manager.get_available_memory() / 1024**3,
|
||||||
|
}
|
||||||
|
return {"used_gb": 0, "total_gb": 24, "utilization_percent": 0, "available_gb": 24}
|
||||||
|
|
||||||
|
async def _generate(self, **kwargs) -> tuple[Any, Any]:
|
||||||
|
"""Generate audio using the generation service."""
|
||||||
|
if self.generation_service:
|
||||||
|
return await self.generation_service.generate(**kwargs)
|
||||||
|
raise RuntimeError("Generation service not configured")
|
||||||
|
|
||||||
|
def _add_to_queue(self, **kwargs) -> Any:
|
||||||
|
"""Add generation job to queue."""
|
||||||
|
if self.batch_processor:
|
||||||
|
return self.batch_processor.add_job(**kwargs)
|
||||||
|
raise RuntimeError("Batch processor not configured")
|
||||||
|
|
||||||
|
def _get_projects(self) -> list[dict]:
|
||||||
|
"""Get all projects."""
|
||||||
|
if self.project_service:
|
||||||
|
try:
|
||||||
|
return asyncio.run(self.project_service.list_projects())
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _get_generations(self, project_id: str, limit: int, offset: int) -> list[dict]:
|
||||||
|
"""Get generations for a project."""
|
||||||
|
if self.project_service:
|
||||||
|
try:
|
||||||
|
return asyncio.run(
|
||||||
|
self.project_service.list_generations(project_id, limit, offset)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _delete_generation(self, generation_id: str) -> bool:
|
||||||
|
"""Delete a generation."""
|
||||||
|
if self.project_service:
|
||||||
|
try:
|
||||||
|
asyncio.run(self.project_service.delete_generation(generation_id))
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _export_project(self, project_id: str) -> str:
|
||||||
|
"""Export project as ZIP."""
|
||||||
|
if self.project_service:
|
||||||
|
return asyncio.run(self.project_service.export_project_zip(project_id))
|
||||||
|
raise RuntimeError("Project service not configured")
|
||||||
|
|
||||||
|
def _create_project(self, name: str, description: str) -> dict:
|
||||||
|
"""Create a new project."""
|
||||||
|
if self.project_service:
|
||||||
|
return asyncio.run(self.project_service.create_project(name, description))
|
||||||
|
raise RuntimeError("Project service not configured")
|
||||||
|
|
||||||
|
def _get_app_settings(self) -> dict:
|
||||||
|
"""Get application settings."""
|
||||||
|
return {
|
||||||
|
"output_dir": str(self.settings.output_dir),
|
||||||
|
"default_format": self.settings.default_format,
|
||||||
|
"sample_rate": self.settings.sample_rate,
|
||||||
|
"normalize_audio": self.settings.normalize_audio,
|
||||||
|
"theme_mode": "Dark",
|
||||||
|
"show_advanced": False,
|
||||||
|
"auto_play": True,
|
||||||
|
"comfyui_reserve_gb": self.settings.comfyui_reserve_gb,
|
||||||
|
"idle_timeout_minutes": self.settings.idle_unload_minutes,
|
||||||
|
"max_loaded_models": self.settings.max_loaded_models,
|
||||||
|
"musicgen_variant": "medium",
|
||||||
|
"musicgen_duration": 10,
|
||||||
|
"audiogen_duration": 5,
|
||||||
|
"magnet_variant": "medium",
|
||||||
|
"magnet_decoding_steps": 20,
|
||||||
|
"api_enabled": self.settings.api_enabled,
|
||||||
|
"api_port": self.settings.api_port,
|
||||||
|
"rate_limit": self.settings.api_rate_limit,
|
||||||
|
"max_batch_size": self.settings.max_batch_size,
|
||||||
|
"max_queue_size": self.settings.max_queue_size,
|
||||||
|
"max_workers": 1,
|
||||||
|
"priority_queue": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _update_app_settings(self, settings: dict) -> bool:
|
||||||
|
"""Update application settings."""
|
||||||
|
# In a real implementation, this would persist settings
|
||||||
|
# For now, just return success
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _clear_cache(self) -> bool:
|
||||||
|
"""Clear model cache."""
|
||||||
|
if self.model_registry:
|
||||||
|
try:
|
||||||
|
self.model_registry.clear_cache()
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _unload_all_models(self) -> bool:
|
||||||
|
"""Unload all models from memory."""
|
||||||
|
if self.model_registry:
|
||||||
|
try:
|
||||||
|
asyncio.run(self.model_registry.unload_all())
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
|
def build(self) -> gr.Blocks:
    """Assemble the complete Gradio UI and return the Blocks app.

    Layout: a header row (title + VRAM monitor), one tab per feature
    (dashboard, the five model families, projects, settings), and a
    footer. All created sub-component dicts are stored in
    ``self.components`` and the built app in ``self.app``.
    """
    theme = create_audiocraft_theme()
    css = get_custom_css()

    with gr.Blocks(
        theme=theme,
        css=css,
        title="AudioCraft Studio",
        analytics_enabled=False,  # no usage telemetry
    ) as app:
        # Header with VRAM monitor
        with gr.Row():
            with gr.Column(scale=4):
                gr.Markdown("# AudioCraft Studio")
            with gr.Column(scale=1):
                # NOTE(review): create_vram_monitor (in
                # src/ui/components/vram_monitor.py) is defined with
                # parameters (get_gpu_status, get_loaded_models,
                # unload_model, load_model); the keywords used here
                # (get_status_fn, update_interval) do not match that
                # signature — confirm which helper is intended.
                vram_monitor = create_vram_monitor(
                    get_status_fn=self._get_gpu_status,
                    update_interval=5,
                )

        # Main tabs
        with gr.Tabs() as main_tabs:  # handle currently unused
            # Dashboard
            with gr.TabItem("Dashboard", id="dashboard"):
                dashboard = create_dashboard_tab(
                    get_queue_status=self._get_queue_status,
                    get_recent_generations=self._get_recent_generations,
                    get_gpu_status=self._get_gpu_status,
                )

            # Model tabs — each receives the same generate / enqueue hooks.
            with gr.TabItem("MusicGen", id="musicgen"):
                musicgen = create_musicgen_tab(
                    generate_fn=self._generate,
                    add_to_queue_fn=self._add_to_queue,
                )

            with gr.TabItem("AudioGen", id="audiogen"):
                audiogen = create_audiogen_tab(
                    generate_fn=self._generate,
                    add_to_queue_fn=self._add_to_queue,
                )

            with gr.TabItem("MAGNeT", id="magnet"):
                magnet = create_magnet_tab(
                    generate_fn=self._generate,
                    add_to_queue_fn=self._add_to_queue,
                )

            with gr.TabItem("Style", id="style"):
                style = create_style_tab(
                    generate_fn=self._generate,
                    add_to_queue_fn=self._add_to_queue,
                )

            with gr.TabItem("JASCO", id="jasco"):
                jasco = create_jasco_tab(
                    generate_fn=self._generate,
                    add_to_queue_fn=self._add_to_queue,
                )

            # Projects
            with gr.TabItem("Projects", id="projects"):
                projects = create_projects_page(
                    get_projects=self._get_projects,
                    get_generations=self._get_generations,
                    delete_generation=self._delete_generation,
                    export_project=self._export_project,
                    create_project=self._create_project,
                )

            # Settings
            with gr.TabItem("Settings", id="settings"):
                settings = create_settings_page(
                    get_settings=self._get_app_settings,
                    update_settings=self._update_app_settings,
                    get_gpu_info=self._get_gpu_status,
                    clear_cache=self._clear_cache,
                    unload_all_models=self._unload_all_models,
                )

        # Footer
        gr.Markdown("---")
        gr.Markdown(
            "AudioCraft Studio | "
            "[Documentation](https://github.com/facebookresearch/audiocraft) | "
            "Powered by Meta AudioCraft"
        )

        # Store component references so event wiring elsewhere can reach them.
        self.components = {
            "vram_monitor": vram_monitor,
            "dashboard": dashboard,
            "musicgen": musicgen,
            "audiogen": audiogen,
            "magnet": magnet,
            "style": style,
            "jasco": jasco,
            "projects": projects,
            "settings": settings,
        }

    self.app = app
    return app
def launch(
    self,
    server_name: Optional[str] = None,
    server_port: Optional[int] = None,
    share: bool = False,
    **kwargs,
) -> None:
    """Build the app if needed, then start the Gradio server.

    Args:
        server_name: Hostname to bind; falls back to ``settings.host``.
        server_port: Port to bind; falls back to ``settings.gradio_port``.
        share: Whether to create a public Gradio share link.
        **kwargs: Passed through to ``gr.Blocks.launch()``.
    """
    if self.app is None:
        self.build()

    # Resolve bind address, preferring explicit arguments.
    host = server_name if server_name else self.settings.host
    port = server_port if server_port else self.settings.gradio_port

    self.app.launch(
        server_name=host,
        server_port=port,
        share=share,
        show_error=True,  # surface handler exceptions in the UI
        **kwargs,
    )
def create_app(
    generation_service: Any = None,
    batch_processor: Any = None,
    project_service: Any = None,
    gpu_manager: Any = None,
    model_registry: Any = None,
) -> AudioCraftApp:
    """Factory for a fully wired (but not yet launched) AudioCraftApp.

    Args:
        generation_service: Service handling individual generations
        batch_processor: Service for batch/queue processing
        project_service: Service for project management
        gpu_manager: GPU memory manager
        model_registry: Model registry for loading/unloading

    Returns:
        Configured AudioCraftApp instance
    """
    services = dict(
        generation_service=generation_service,
        batch_processor=batch_processor,
        project_service=project_service,
        gpu_manager=gpu_manager,
        model_registry=model_registry,
    )
    return AudioCraftApp(**services)
# Standalone launch for development/testing
if __name__ == "__main__":
    # All services default to None here (see create_app), so this is a
    # UI-only smoke launch with no generation backend attached.
    app = create_app()
    app.launch()
13
src/ui/components/__init__.py
Normal file
13
src/ui/components/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
"""Reusable UI components for AudioCraft Studio."""
|
||||||
|
|
||||||
|
from src.ui.components.vram_monitor import create_vram_monitor
|
||||||
|
from src.ui.components.audio_player import create_audio_player
|
||||||
|
from src.ui.components.preset_selector import create_preset_selector
|
||||||
|
from src.ui.components.generation_params import create_generation_params
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"create_vram_monitor",
|
||||||
|
"create_audio_player",
|
||||||
|
"create_preset_selector",
|
||||||
|
"create_generation_params",
|
||||||
|
]
|
||||||
178
src/ui/components/audio_player.py
Normal file
178
src/ui/components/audio_player.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
"""Audio player component with waveform visualization."""
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional, Callable
|
||||||
|
|
||||||
|
|
||||||
|
def create_audio_player(
    label: str = "Generated Audio",
    show_waveform: bool = True,
    show_download: bool = True,
    show_info: bool = True,
) -> dict[str, Any]:
    """Create audio player component with optional waveform.

    Components disabled by the flags are returned as None, so callers
    must check for presence before wiring events.

    Args:
        label: Label for the audio component
        show_waveform: Show waveform image
        show_download: Show download buttons
        show_info: Show audio info (duration, sample rate, seed)

    Returns:
        Dictionary with component references (values may be None)
    """

    with gr.Group():
        # Audio player
        audio_output = gr.Audio(
            label=label,
            type="filepath",  # handlers exchange file paths, not arrays
            interactive=False,
            show_download_button=show_download,
        )

        # Waveform visualization — hidden until an image is supplied
        # (e.g. by update_audio_player).
        if show_waveform:
            waveform_image = gr.Image(
                label="Waveform",
                type="filepath",
                interactive=False,
                height=100,
                visible=False,
            )
        else:
            waveform_image = None

        # Audio info row (duration / sample rate / seed)
        if show_info:
            with gr.Row():
                duration_text = gr.Textbox(
                    label="Duration",
                    value="",
                    interactive=False,
                    scale=1,
                )
                sample_rate_text = gr.Textbox(
                    label="Sample Rate",
                    value="",
                    interactive=False,
                    scale=1,
                )
                seed_text = gr.Textbox(
                    label="Seed",
                    value="",
                    interactive=False,
                    scale=1,
                )
        else:
            duration_text = None
            sample_rate_text = None
            seed_text = None

        # Download buttons — no click handlers attached here; callers
        # wire the returned buttons themselves.
        if show_download:
            with gr.Row():
                download_wav = gr.Button("Download WAV", size="sm")
                download_mp3 = gr.Button("Download MP3", size="sm")
                download_flac = gr.Button("Download FLAC", size="sm")
        else:
            download_wav = download_mp3 = download_flac = None

    return {
        "audio": audio_output,
        "waveform": waveform_image,
        "duration": duration_text,
        "sample_rate": sample_rate_text,
        "seed": seed_text,
        "download_wav": download_wav,
        "download_mp3": download_mp3,
        "download_flac": download_flac,
    }
def update_audio_player(
    audio_path: Optional[str],
    duration: Optional[float] = None,
    sample_rate: Optional[int] = None,
    seed: Optional[int] = None,
    waveform_path: Optional[str] = None,
) -> tuple:
    """Build the update tuple for the audio-player components.

    Args:
        audio_path: Path to audio file (or None to clear)
        duration: Audio duration in seconds
        sample_rate: Audio sample rate in Hz
        seed: Generation seed
        waveform_path: Path to waveform image, if one was rendered

    Returns:
        Tuple of (audio, waveform update, duration text, sample-rate
        text, seed text), matching the component order produced by
        ``create_audio_player``.
    """
    # Use explicit None checks so legitimate zero values are still
    # displayed; the previous truthiness tests (`if duration`) hid a
    # 0.0 duration and a 0 sample rate (seed already did this right).
    duration_str = f"{duration:.2f}s" if duration is not None else ""
    sample_rate_str = f"{sample_rate} Hz" if sample_rate is not None else ""
    seed_str = str(seed) if seed is not None else ""

    # The waveform image is only made visible when a path was provided.
    waveform_update = gr.update(value=waveform_path, visible=waveform_path is not None)

    return (
        audio_path,
        waveform_update,
        duration_str,
        sample_rate_str,
        seed_str,
    )
def create_generation_output() -> dict[str, Any]:
    """Create the output section shared by the generation tabs.

    Combines a status line, a progress slider (hidden until used), the
    audio player, a collapsible JSON view of the generation parameters,
    and action buttons. No event handlers are attached here.

    Returns:
        Dictionary with component references
    """
    with gr.Group():
        gr.Markdown("### Output")

        # Status/progress
        with gr.Row():
            status_text = gr.Markdown("Ready to generate")
            progress_bar = gr.Slider(
                minimum=0,
                maximum=100,
                value=0,
                label="Progress",
                interactive=False,
                visible=False,  # revealed by generation handlers
            )

        # Audio player (all optional sub-components enabled)
        player = create_audio_player(
            label="Generated Audio",
            show_waveform=True,
            show_download=True,
            show_info=True,
        )

        # Generation metadata, collapsed by default
        with gr.Accordion("Generation Details", open=False):
            generation_info = gr.JSON(
                label="Parameters",
                value={},
            )

        # Actions — callers wire these buttons themselves
        with gr.Row():
            save_btn = gr.Button("Save to Project", variant="secondary")
            regenerate_btn = gr.Button("Regenerate", variant="secondary")
            add_queue_btn = gr.Button("Add to Queue", variant="secondary")

    return {
        "status": status_text,
        "progress": progress_bar,
        "player": player,
        "info": generation_info,
        "save_btn": save_btn,
        "regenerate_btn": regenerate_btn,
        "add_queue_btn": add_queue_btn,
    }
199
src/ui/components/generation_params.py
Normal file
199
src/ui/components/generation_params.py
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
"""Generation parameters component."""
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
def create_generation_params(
    model_id: str,
    show_advanced: bool = False,
    max_duration: float = 30.0,
) -> dict[str, Any]:
    """Create the shared generation-parameter panel for a model tab.

    Args:
        model_id: Model family used to pick per-model defaults
        show_advanced: Whether the advanced accordion starts open
        max_duration: Upper bound of the duration slider

    Returns:
        Dictionary with component references
    """
    # Per-model default sampling parameters. Unknown model ids fall
    # back to the musicgen defaults.
    defaults = {
        "musicgen": {"duration": 10, "temperature": 1.0, "top_k": 250, "top_p": 0.0, "cfg_coef": 3.0},
        "audiogen": {"duration": 5, "temperature": 1.0, "top_k": 250, "top_p": 0.0, "cfg_coef": 3.0},
        "magnet": {"duration": 10, "temperature": 3.0, "top_k": 0, "top_p": 0.9, "cfg_coef": 3.0},
        "musicgen-style": {"duration": 10, "temperature": 1.0, "top_k": 250, "top_p": 0.0, "cfg_coef": 3.0},
        "jasco": {"duration": 10, "temperature": 1.0, "top_k": 250, "top_p": 0.0, "cfg_coef": 3.0},
    }

    d = defaults.get(model_id, defaults["musicgen"])

    with gr.Group():
        # Basic parameters (always visible)
        duration_slider = gr.Slider(
            minimum=1,
            maximum=max_duration,
            value=d["duration"],
            step=1,
            label="Duration (seconds)",
            info="Length of audio to generate",
        )

        # Advanced parameters (expandable)
        with gr.Accordion("Advanced Parameters", open=show_advanced):
            with gr.Row():
                temperature_slider = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=d["temperature"],
                    step=0.05,
                    label="Temperature",
                    info="Higher = more random, lower = more deterministic",
                )
                cfg_slider = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=d["cfg_coef"],
                    step=0.5,
                    label="CFG Coefficient",
                    info="Classifier-free guidance strength",
                )

            with gr.Row():
                top_k_slider = gr.Slider(
                    minimum=0,
                    maximum=500,
                    value=d["top_k"],
                    step=10,
                    label="Top-K",
                    info="Token selection limit (0 = disabled)",
                )
                top_p_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=d["top_p"],
                    step=0.05,
                    label="Top-P (Nucleus)",
                    info="Cumulative probability threshold (0 = disabled)",
                )

            with gr.Row():
                seed_input = gr.Number(
                    value=None,
                    label="Seed",
                    info="Random seed for reproducibility (leave empty for random)",
                    precision=0,  # integer seeds only
                )
                use_random_seed = gr.Checkbox(
                    value=True,
                    label="Random Seed",
                )

        # Reset button
        reset_btn = gr.Button("Reset to Defaults", size="sm", variant="secondary")

    def reset_params():
        """Reset all parameters to defaults."""
        # Order must match the outputs list of reset_btn.click below.
        return (
            d["duration"],
            d["temperature"],
            d["cfg_coef"],
            d["top_k"],
            d["top_p"],
            None,
            True,
        )

    reset_btn.click(
        fn=reset_params,
        outputs=[
            duration_slider,
            temperature_slider,
            cfg_slider,
            top_k_slider,
            top_p_slider,
            seed_input,
            use_random_seed,
        ],
    )

    # Link random seed checkbox to seed input
    def toggle_seed(use_random: bool, current_seed: Optional[int]):
        # When randomizing, clear and lock the manual seed field;
        # otherwise unlock it (keeping whatever value it holds).
        if use_random:
            return gr.update(value=None, interactive=False)
        return gr.update(interactive=True)

    use_random_seed.change(
        fn=toggle_seed,
        inputs=[use_random_seed, seed_input],
        outputs=[seed_input],
    )

    return {
        "duration": duration_slider,
        "temperature": temperature_slider,
        "cfg_coef": cfg_slider,
        "top_k": top_k_slider,
        "top_p": top_p_slider,
        "seed": seed_input,
        "use_random_seed": use_random_seed,
        "reset_btn": reset_btn,
    }
def create_model_variant_selector(
    model_id: str,
    variants: list[dict[str, Any]],
    default_variant: str = "medium",
) -> dict[str, Any]:
    """Create a dropdown for choosing a model variant.

    Args:
        model_id: Model family ID (accepted for interface symmetry with
            the other component factories; not used in the layout)
        variants: Variant configs; each may carry "name"/"id",
            "vram_mb" and "description" keys
        default_variant: Variant pre-selected in the dropdown

    Returns:
        Dictionary with "dropdown", "info" and "variants" references
    """

    def describe(variant_name: str) -> str:
        """Return the description of the named variant ('' if unknown)."""
        for v in variants:
            if v.get("name", v.get("id")) == variant_name:
                return v.get("description", "")
        return ""

    # Build (label, value) choices; the label shows estimated VRAM in GB.
    # (Removed an unused `desc` local that was fetched but never shown.)
    choices = []
    for v in variants:
        name = v.get("name", v.get("id", "unknown"))
        vram = v.get("vram_mb", 0)
        label = f"{name} ({vram/1024:.1f}GB)"
        choices.append((label, name))

    with gr.Group():
        variant_dropdown = gr.Dropdown(
            label="Model Variant",
            choices=choices,
            value=default_variant,
            interactive=True,
        )

        # Show the default variant's description immediately instead of
        # starting blank (previously it only appeared after a change).
        variant_info = gr.Markdown(
            value=describe(default_variant),
            visible=True,
        )

    variant_dropdown.change(
        fn=describe,
        inputs=[variant_dropdown],
        outputs=[variant_info],
    )

    return {
        "dropdown": variant_dropdown,
        "info": variant_info,
        "variants": variants,
    }
103
src/ui/components/preset_selector.py
Normal file
103
src/ui/components/preset_selector.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
"""Preset selector component."""
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
from src.ui.state import DEFAULT_PRESETS
|
||||||
|
|
||||||
|
|
||||||
|
def create_preset_selector(
    model_id: str,
    on_preset_select: Optional[Callable[[dict], None]] = None,
) -> dict[str, Any]:
    """Create the preset dropdown for a model tab.

    Args:
        model_id: Model family ID used to look up DEFAULT_PRESETS
        on_preset_select: Callback when preset is selected.
            NOTE(review): accepted but never invoked in this function —
            confirm whether callers wire it themselves or it is dead.

    Returns:
        Dictionary with the dropdown, description, preset list and two
        helper callables ("get_preset", "on_change") for callers to wire.
    """
    presets = DEFAULT_PRESETS.get(model_id, [])

    # Create preset choices; a trailing "Custom" entry always exists.
    choices = [(p["name"], p["id"]) for p in presets]
    choices.append(("Custom", "custom"))

    def get_preset_by_id(preset_id: str) -> Optional[dict]:
        """Get preset data by ID (None if not found)."""
        for p in presets:
            if p["id"] == preset_id:
                return p
        return None

    def on_change(preset_id: str):
        """Handle preset selection change.

        Returns a visibility update (for a custom-params area) plus the
        selected preset's parameter dict ({} for "custom"/unknown ids).
        NOTE(review): exposed via the result dict but not attached to
        the dropdown here — callers must wire it.
        """
        if preset_id == "custom":
            return gr.update(visible=True), {}

        preset = get_preset_by_id(preset_id)
        if preset:
            return gr.update(visible=False), preset.get("parameters", {})

        # Unknown id: behave like "custom".
        return gr.update(visible=True), {}

    with gr.Group():
        preset_dropdown = gr.Dropdown(
            label="Preset",
            choices=choices,
            value=presets[0]["id"] if presets else "custom",
            interactive=True,
        )

        # Shows the first preset's description initially; not updated on
        # change within this function.
        preset_description = gr.Markdown(
            value=presets[0]["description"] if presets else "",
            visible=True,
        )

    return {
        "dropdown": preset_dropdown,
        "description": preset_description,
        "presets": presets,
        "get_preset": get_preset_by_id,
        "on_change": on_change,
    }
def create_preset_chips(
    model_id: str,
    on_select: Callable[[str], None],
) -> dict[str, Any]:
    """Create preset shortcuts as a row of small buttons ("chips").

    Args:
        model_id: Model family ID used to look up DEFAULT_PRESETS
        on_select: Callback when a preset chip is clicked.
            NOTE(review): accepted but no click handlers are attached
            here — callers must wire the returned buttons themselves.

    Returns:
        Dictionary with "buttons" (list of (button, preset) pairs),
        "custom_btn" and the raw "presets" list
    """
    presets = DEFAULT_PRESETS.get(model_id, [])

    with gr.Row():
        buttons = []
        for preset in presets:
            btn = gr.Button(
                preset["name"],
                size="sm",
                variant="secondary",
            )
            buttons.append((btn, preset))

        # Trailing "Custom" chip, always present.
        custom_btn = gr.Button(
            "Custom",
            size="sm",
            variant="secondary",
        )

    return {
        "buttons": buttons,
        "custom_btn": custom_btn,
        "presets": presets,
    }
151
src/ui/components/vram_monitor.py
Normal file
151
src/ui/components/vram_monitor.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
"""VRAM monitor component for GPU memory tracking."""
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
|
||||||
|
def create_vram_monitor(
    get_gpu_status: Callable[[], dict[str, Any]],
    get_loaded_models: Callable[[], list[dict[str, Any]]],
    unload_model: Callable[[str, str], bool],
    load_model: Callable[[str, str], bool],
) -> dict[str, Any]:
    """Create the GPU-memory / loaded-models monitor panel.

    Args:
        get_gpu_status: Returns a dict read here via the keys
            "used_gb", "total_gb", "utilization_percent" and "device"
        get_loaded_models: Returns dicts read via "model_id",
            "variant" and "idle_seconds"
        unload_model: Unloads (model_id, variant); returns success
        load_model: Loads (model_id, variant).
            NOTE(review): accepted but not wired to any control in this
            panel — confirm whether a "Load" button was intended.

    Returns:
        Dictionary with component references plus ``refresh_fn``.

    NOTE(review): the app builder calls this function with
    ``get_status_fn=...`` / ``update_interval=...`` keywords, which do
    not match this signature — confirm which interface is intended.
    """

    def refresh_status():
        """Recompute (vram text, utilization %, models text, status line)."""
        status = get_gpu_status()
        loaded = get_loaded_models()

        # Format VRAM bar
        used_gb = status.get("used_gb", 0)
        total_gb = status.get("total_gb", 24)
        util_pct = status.get("utilization_percent", 0)

        vram_text = f"{used_gb:.1f} / {total_gb:.1f} GB ({util_pct:.0f}%)"

        # Format loaded models list
        if loaded:
            models_text = "\n".join([
                f"• {m['model_id']}/{m['variant']} "
                f"(idle: {m['idle_seconds']:.0f}s)"
                for m in loaded
            ])
        else:
            models_text = "No models loaded"

        # Traffic-light indicator based on utilization thresholds.
        if util_pct > 90:
            status_color = "🔴"
        elif util_pct > 75:
            status_color = "🟡"
        else:
            status_color = "🟢"

        status_text = f"{status_color} GPU: {status.get('device', 'N/A')}"

        return vram_text, util_pct, models_text, status_text

    def handle_unload(model_selection: str):
        """Unload the "model_id/variant" selection; return message + refreshed display."""
        if not model_selection or "/" not in model_selection:
            return "Select a model to unload", *refresh_status()

        parts = model_selection.split("/")
        model_id, variant = parts[0], parts[1]

        success = unload_model(model_id, variant)
        if success:
            msg = f"Unloaded {model_id}/{variant}"
        else:
            msg = f"Failed to unload {model_id}/{variant}"

        return msg, *refresh_status()

    with gr.Group():
        gr.Markdown("### GPU Memory")

        status_text = gr.Markdown("🟢 GPU: Checking...")

        with gr.Row():
            vram_display = gr.Textbox(
                label="VRAM Usage",
                value="Loading...",
                interactive=False,
                scale=3,
            )
            refresh_btn = gr.Button("🔄", scale=1, min_width=50)

        # Non-interactive slider reused as a utilization bar (0-100).
        vram_slider = gr.Slider(
            minimum=0,
            maximum=100,
            value=0,
            label="",
            interactive=False,
            visible=True,
        )

        gr.Markdown("### Loaded Models")

        models_display = gr.Textbox(
            label="",
            value="No models loaded",
            interactive=False,
            lines=4,
            max_lines=6,
        )

        with gr.Row():
            model_selector = gr.Dropdown(
                label="Select Model",
                choices=[],  # filled by update_model_choices after refresh
                interactive=True,
                scale=3,
            )
            unload_btn = gr.Button("Unload", variant="secondary", scale=1)

        unload_status = gr.Markdown("")

    # Event handlers
    def update_model_choices():
        # Rebuild the dropdown from the currently loaded models.
        loaded = get_loaded_models()
        choices = [f"{m['model_id']}/{m['variant']}" for m in loaded]
        return gr.update(choices=choices, value=None)

    refresh_btn.click(
        fn=refresh_status,
        outputs=[vram_display, vram_slider, models_display, status_text],
    ).then(
        fn=update_model_choices,
        outputs=[model_selector],
    )

    unload_btn.click(
        fn=handle_unload,
        inputs=[model_selector],
        outputs=[unload_status, vram_display, vram_slider, models_display, status_text],
    ).then(
        fn=update_model_choices,
        outputs=[model_selector],
    )

    return {
        "vram_display": vram_display,
        "vram_slider": vram_slider,
        "models_display": models_display,
        "status_text": status_text,
        "model_selector": model_selector,
        "refresh_btn": refresh_btn,
        "unload_btn": unload_btn,
        "refresh_fn": refresh_status,
    }
9
src/ui/pages/__init__.py
Normal file
9
src/ui/pages/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
"""UI pages for AudioCraft Studio."""
|
||||||
|
|
||||||
|
from src.ui.pages.projects_page import create_projects_page
|
||||||
|
from src.ui.pages.settings_page import create_settings_page
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"create_projects_page",
|
||||||
|
"create_settings_page",
|
||||||
|
]
|
||||||
374
src/ui/pages/projects_page.py
Normal file
374
src/ui/pages/projects_page.py
Normal file
@@ -0,0 +1,374 @@
|
|||||||
|
"""Projects page for managing generations and history."""
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
def create_projects_page(
|
||||||
|
get_projects: Callable[[], list[dict]],
|
||||||
|
get_generations: Callable[[str, int, int], list[dict]],
|
||||||
|
delete_generation: Callable[[str], bool],
|
||||||
|
export_project: Callable[[str], str],
|
||||||
|
create_project: Callable[[str, str], dict],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Create projects management page.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
get_projects: Function to get all projects
|
||||||
|
get_generations: Function to get generations (project_id, limit, offset)
|
||||||
|
delete_generation: Function to delete a generation
|
||||||
|
export_project: Function to export project as ZIP
|
||||||
|
create_project: Function to create new project
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with component references
|
||||||
|
"""
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
gr.Markdown("# Projects")
|
||||||
|
gr.Markdown("Browse and manage your generations")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
# Left sidebar - project list
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
gr.Markdown("### Projects")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
new_project_name = gr.Textbox(
|
||||||
|
placeholder="New project name...",
|
||||||
|
show_label=False,
|
||||||
|
scale=3,
|
||||||
|
)
|
||||||
|
new_project_btn = gr.Button("+", size="sm", scale=1)
|
||||||
|
|
||||||
|
project_list = gr.Dataframe(
|
||||||
|
headers=["ID", "Name", "Count"],
|
||||||
|
datatype=["str", "str", "number"],
|
||||||
|
col_count=(3, "fixed"),
|
||||||
|
interactive=False,
|
||||||
|
height=400,
|
||||||
|
)
|
||||||
|
|
||||||
|
refresh_projects_btn = gr.Button("Refresh Projects", size="sm")
|
||||||
|
|
||||||
|
# Main content - generations
|
||||||
|
with gr.Column(scale=3):
|
||||||
|
# Selected project info
|
||||||
|
selected_project_id = gr.State(value=None)
|
||||||
|
selected_project_name = gr.Markdown("### Select a project")
|
||||||
|
|
||||||
|
# Filters
|
||||||
|
with gr.Row():
|
||||||
|
model_filter = gr.Dropdown(
|
||||||
|
label="Model",
|
||||||
|
choices=[
|
||||||
|
("All", "all"),
|
||||||
|
("MusicGen", "musicgen"),
|
||||||
|
("AudioGen", "audiogen"),
|
||||||
|
("MAGNeT", "magnet"),
|
||||||
|
("Style", "musicgen-style"),
|
||||||
|
("JASCO", "jasco"),
|
||||||
|
],
|
||||||
|
value="all",
|
||||||
|
scale=1,
|
||||||
|
)
|
||||||
|
sort_by = gr.Dropdown(
|
||||||
|
label="Sort By",
|
||||||
|
choices=[
|
||||||
|
("Newest First", "newest"),
|
||||||
|
("Oldest First", "oldest"),
|
||||||
|
("Duration (Long)", "duration_desc"),
|
||||||
|
("Duration (Short)", "duration_asc"),
|
||||||
|
],
|
||||||
|
value="newest",
|
||||||
|
scale=1,
|
||||||
|
)
|
||||||
|
search_input = gr.Textbox(
|
||||||
|
label="Search Prompts",
|
||||||
|
placeholder="Search...",
|
||||||
|
scale=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generations grid
|
||||||
|
generations_gallery = gr.Gallery(
|
||||||
|
label="Generations",
|
||||||
|
columns=3,
|
||||||
|
rows=3,
|
||||||
|
height=400,
|
||||||
|
object_fit="contain",
|
||||||
|
show_label=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pagination
|
||||||
|
with gr.Row():
|
||||||
|
prev_page_btn = gr.Button("← Previous", size="sm")
|
||||||
|
page_info = gr.Markdown("Page 1 of 1")
|
||||||
|
next_page_btn = gr.Button("Next →", size="sm")
|
||||||
|
|
||||||
|
current_page = gr.State(value=1)
|
||||||
|
total_pages = gr.State(value=1)
|
||||||
|
|
||||||
|
# Selected generation details
|
||||||
|
gr.Markdown("---")
|
||||||
|
gr.Markdown("### Generation Details")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=2):
|
||||||
|
selected_audio = gr.Audio(
|
||||||
|
label="Audio",
|
||||||
|
interactive=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Column(scale=2):
|
||||||
|
selected_prompt = gr.Textbox(
|
||||||
|
label="Prompt",
|
||||||
|
interactive=False,
|
||||||
|
lines=2,
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
selected_model = gr.Textbox(
|
||||||
|
label="Model",
|
||||||
|
interactive=False,
|
||||||
|
)
|
||||||
|
selected_duration = gr.Textbox(
|
||||||
|
label="Duration",
|
||||||
|
interactive=False,
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
selected_seed = gr.Textbox(
|
||||||
|
label="Seed",
|
||||||
|
interactive=False,
|
||||||
|
)
|
||||||
|
selected_date = gr.Textbox(
|
||||||
|
label="Created",
|
||||||
|
interactive=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Action buttons
|
||||||
|
with gr.Row():
|
||||||
|
regenerate_btn = gr.Button("Regenerate", variant="secondary")
|
||||||
|
download_btn = gr.Button("Download", variant="secondary")
|
||||||
|
delete_btn = gr.Button("Delete", variant="stop")
|
||||||
|
export_project_btn = gr.Button("Export Project", variant="secondary")
|
||||||
|
|
||||||
|
# Event handlers
|
||||||
|
|
||||||
|
def load_projects():
    """Fetch every project and shape it into dataframe rows.

    Returns:
        List of ``[id, name, generation_count]`` rows for the project list
        dataframe. Missing fields fall back to "", "Untitled", and 0.
    """
    return [
        [proj.get("id", ""), proj.get("name", "Untitled"), proj.get("generation_count", 0)]
        for proj in get_projects()
    ]
|
||||||
|
|
||||||
|
def on_project_select(evt: gr.SelectData, df):
    """Resolve a dataframe click to the selected project.

    Args:
        evt: Gradio select event; ``evt.index`` is (row, col) for dataframes.
        df: Current rows of the project dataframe ([id, name, count]).

    Returns:
        Tuple of (project_id or None, markdown heading for the project name).
    """
    unselected = (None, "### Select a project")
    if evt.index is None:
        return unselected
    row_idx = evt.index[0]
    if row_idx >= len(df):
        # Stale event pointing past the current data — treat as no selection.
        return unselected
    return df[row_idx][0], f"### {df[row_idx][1]}"
|
||||||
|
|
||||||
|
def load_generations(project_id, page, model, sort, search):
    """Load one page of generations for the selected project.

    Args:
        project_id: Selected project id, or falsy when nothing is selected.
        page: 1-based page number.
        model: Model filter value ("all" disables filtering).
        sort: Sort key ("oldest", "duration_desc", "duration_asc",
            anything else keeps the DB order).
        search: Case-insensitive substring match against the prompt.

    Returns:
        Tuple of (gallery items, page-info markdown, current page, total pages).
    """
    if not project_id:
        return [], "Page 0 of 0", 1, 1

    limit = 9  # 3x3 grid
    offset = (page - 1) * limit

    # Fetch one extra record so we can tell whether another page exists.
    gens = get_generations(project_id, limit + 1, offset)

    # Check if there are more pages
    has_more = len(gens) > limit
    gens = gens[:limit]

    # Filter by model if needed
    # NOTE(review): model/search filters run AFTER the page slice, so a
    # filtered page can show fewer than 9 items even when more matches
    # exist on later pages; filtering should ideally happen in the query.
    if model != "all":
        gens = [g for g in gens if g.get("model") == model]

    # Filter by search
    if search:
        search_lower = search.lower()
        gens = [g for g in gens if search_lower in g.get("prompt", "").lower()]

    # Sort (only within the current page, for the same reason as above)
    if sort == "oldest":
        gens = sorted(gens, key=lambda x: x.get("created_at", ""))
    elif sort == "duration_desc":
        gens = sorted(gens, key=lambda x: x.get("duration_seconds", 0), reverse=True)
    elif sort == "duration_asc":
        gens = sorted(gens, key=lambda x: x.get("duration_seconds", 0))
    # Default is newest first (presumably already sorted from DB — confirm)

    # Build gallery items (using waveform images if available)
    gallery_items = []
    for g in gens:
        waveform = g.get("waveform_path")
        if waveform:
            gallery_items.append((waveform, g.get("prompt", "")[:50]))
        else:
            # Placeholder: no waveform image rendered for this generation.
            gallery_items.append((None, g.get("prompt", "")[:50]))

    # Calculate total pages (estimate)
    # NOTE(review): this is only a lower bound — "+1" merely acknowledges
    # that at least one more record exists when has_more is True.
    total = offset + len(gens) + (1 if has_more else 0)
    total_p = max(1, (total + limit - 1) // limit)

    return gallery_items, f"Page {page} of {total_p}", page, total_p
|
||||||
|
|
||||||
|
def on_generation_select(evt: gr.SelectData, project_id):
    """Populate the detail panel when a gallery item is clicked.

    Args:
        evt: Gradio select event; for a gallery, ``evt.index`` is the flat
            item index within the displayed grid.
        project_id: Currently selected project id.

    Returns:
        Tuple of (audio_path, prompt, model, duration text, seed text,
        created-at text) — all empty/None when nothing can be resolved.
    """
    if evt.index is None or not project_id:
        return None, "", "", "", "", ""

    # Get generations again to find the selected one.
    # NOTE(review): this re-fetches up to 100 records in DB order, but the
    # gallery may have been paginated, filtered, and re-sorted by
    # load_generations — so evt.index can point at a different record than
    # the one displayed. Also silently caps lookup at 100 generations.
    gens = get_generations(project_id, 100, 0)
    if evt.index < len(gens):
        gen = gens[evt.index]
        return (
            gen.get("audio_path"),
            gen.get("prompt", ""),
            gen.get("model", ""),
            f"{gen.get('duration_seconds', 0):.1f}s",
            str(gen.get("seed", "")),
            # Trim ISO timestamp to "YYYY-MM-DDTHH:MM:SS" — assumes ISO-8601
            # strings from the DB; TODO confirm.
            gen.get("created_at", "")[:19] if gen.get("created_at") else "",
        )

    return None, "", "", "", "", ""
|
||||||
|
|
||||||
|
def do_create_project(name):
    """Create a new project and refresh the project list.

    Args:
        name: Project name from the textbox; surrounding whitespace is
            stripped before creation.

    Returns:
        Tuple of (project-list rows or gr.update() no-op, status message).
    """
    if not name.strip():
        # Leave the dataframe untouched when the name box is blank.
        return gr.update(), "Please enter a project name"

    # The created record itself is not needed here; the refreshed list is
    # re-read from the backend below. (Previously bound to an unused local.)
    create_project(name.strip(), "")
    projects_data = load_projects()
    return projects_data, f"Created project: {name}"
|
||||||
|
|
||||||
|
def do_delete_generation(project_id, audio_path):
    """Delete the generation whose audio file matches the current selection.

    The gallery only exposes the audio path, so the record is located by
    scanning the project's generations for a matching ``audio_path``.

    Returns:
        Status message string for the UI.
    """
    if not audio_path:
        return "No generation selected"

    for record in get_generations(project_id, 100, 0):
        if record.get("audio_path") != audio_path:
            continue
        deleted = delete_generation(record.get("id"))
        return "Generation deleted" if deleted else "Failed to delete"

    return "Generation not found"
|
||||||
|
|
||||||
|
def do_export_project(project_id):
    """Bundle the selected project into a ZIP archive.

    Returns:
        Status message containing the archive path, or an error description.
    """
    if not project_id:
        return "No project selected"

    try:
        return f"Exported to: {export_project(project_id)}"
    except Exception as exc:
        # Surface the failure in the UI rather than crashing the callback.
        return f"Export failed: {exc}"
|
||||||
|
|
||||||
|
# Wire up events
|
||||||
|
|
||||||
|
refresh_projects_btn.click(
|
||||||
|
fn=load_projects,
|
||||||
|
outputs=[project_list],
|
||||||
|
)
|
||||||
|
|
||||||
|
project_list.select(
|
||||||
|
fn=on_project_select,
|
||||||
|
inputs=[project_list],
|
||||||
|
outputs=[selected_project_id, selected_project_name],
|
||||||
|
).then(
|
||||||
|
fn=load_generations,
|
||||||
|
inputs=[selected_project_id, current_page, model_filter, sort_by, search_input],
|
||||||
|
outputs=[generations_gallery, page_info, current_page, total_pages],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter changes reload generations
|
||||||
|
for component in [model_filter, sort_by, search_input]:
|
||||||
|
component.change(
|
||||||
|
fn=load_generations,
|
||||||
|
inputs=[selected_project_id, current_page, model_filter, sort_by, search_input],
|
||||||
|
outputs=[generations_gallery, page_info, current_page, total_pages],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pagination
|
||||||
|
def go_prev(page, total):
    """Step one page back, clamped at page 1 (total unused, kept for symmetry)."""
    return page - 1 if page > 1 else 1
|
||||||
|
|
||||||
|
def go_next(page, total):
    """Step one page forward, clamped at the last page."""
    return page + 1 if page < total else total
|
||||||
|
|
||||||
|
prev_page_btn.click(
|
||||||
|
fn=go_prev,
|
||||||
|
inputs=[current_page, total_pages],
|
||||||
|
outputs=[current_page],
|
||||||
|
).then(
|
||||||
|
fn=load_generations,
|
||||||
|
inputs=[selected_project_id, current_page, model_filter, sort_by, search_input],
|
||||||
|
outputs=[generations_gallery, page_info, current_page, total_pages],
|
||||||
|
)
|
||||||
|
|
||||||
|
next_page_btn.click(
|
||||||
|
fn=go_next,
|
||||||
|
inputs=[current_page, total_pages],
|
||||||
|
outputs=[current_page],
|
||||||
|
).then(
|
||||||
|
fn=load_generations,
|
||||||
|
inputs=[selected_project_id, current_page, model_filter, sort_by, search_input],
|
||||||
|
outputs=[generations_gallery, page_info, current_page, total_pages],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generation selection
|
||||||
|
generations_gallery.select(
|
||||||
|
fn=on_generation_select,
|
||||||
|
inputs=[selected_project_id],
|
||||||
|
outputs=[selected_audio, selected_prompt, selected_model, selected_duration, selected_seed, selected_date],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Actions
|
||||||
|
new_project_btn.click(
|
||||||
|
fn=do_create_project,
|
||||||
|
inputs=[new_project_name],
|
||||||
|
outputs=[project_list, selected_project_name],
|
||||||
|
)
|
||||||
|
|
||||||
|
delete_btn.click(
|
||||||
|
fn=do_delete_generation,
|
||||||
|
inputs=[selected_project_id, selected_audio],
|
||||||
|
outputs=[selected_project_name],
|
||||||
|
).then(
|
||||||
|
fn=load_generations,
|
||||||
|
inputs=[selected_project_id, current_page, model_filter, sort_by, search_input],
|
||||||
|
outputs=[generations_gallery, page_info, current_page, total_pages],
|
||||||
|
)
|
||||||
|
|
||||||
|
export_project_btn.click(
|
||||||
|
fn=do_export_project,
|
||||||
|
inputs=[selected_project_id],
|
||||||
|
outputs=[selected_project_name],
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"project_list": project_list,
|
||||||
|
"generations_gallery": generations_gallery,
|
||||||
|
"selected_audio": selected_audio,
|
||||||
|
"selected_project_id": selected_project_id,
|
||||||
|
"refresh_fn": load_projects,
|
||||||
|
}
|
||||||
397
src/ui/pages/settings_page.py
Normal file
397
src/ui/pages/settings_page.py
Normal file
@@ -0,0 +1,397 @@
|
|||||||
|
"""Settings page for application configuration."""
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def create_settings_page(
|
||||||
|
get_settings: Callable[[], dict],
|
||||||
|
update_settings: Callable[[dict], bool],
|
||||||
|
get_gpu_info: Callable[[], dict],
|
||||||
|
clear_cache: Callable[[], bool],
|
||||||
|
unload_all_models: Callable[[], bool],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Create settings management page.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
get_settings: Function to get current settings
|
||||||
|
update_settings: Function to update settings
|
||||||
|
get_gpu_info: Function to get GPU information
|
||||||
|
clear_cache: Function to clear model cache
|
||||||
|
unload_all_models: Function to unload all models
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with component references
|
||||||
|
"""
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
gr.Markdown("# Settings")
|
||||||
|
gr.Markdown("Configure AudioCraft Studio")
|
||||||
|
|
||||||
|
with gr.Tabs():
|
||||||
|
# General Settings Tab
|
||||||
|
with gr.TabItem("General"):
|
||||||
|
with gr.Group():
|
||||||
|
gr.Markdown("### Output Settings")
|
||||||
|
|
||||||
|
output_dir = gr.Textbox(
|
||||||
|
label="Output Directory",
|
||||||
|
placeholder="/path/to/output",
|
||||||
|
info="Where generated audio files are saved",
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
default_format = gr.Dropdown(
|
||||||
|
label="Default Audio Format",
|
||||||
|
choices=[("WAV", "wav"), ("MP3", "mp3"), ("FLAC", "flac"), ("OGG", "ogg")],
|
||||||
|
value="wav",
|
||||||
|
)
|
||||||
|
sample_rate = gr.Dropdown(
|
||||||
|
label="Sample Rate",
|
||||||
|
choices=[
|
||||||
|
("32000 Hz (AudioCraft default)", 32000),
|
||||||
|
("44100 Hz (CD quality)", 44100),
|
||||||
|
("48000 Hz (Video standard)", 48000),
|
||||||
|
],
|
||||||
|
value=32000,
|
||||||
|
)
|
||||||
|
|
||||||
|
normalize_audio = gr.Checkbox(
|
||||||
|
label="Normalize audio output",
|
||||||
|
value=True,
|
||||||
|
info="Normalize audio levels to prevent clipping",
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Group():
|
||||||
|
gr.Markdown("### Interface Settings")
|
||||||
|
|
||||||
|
theme_mode = gr.Radio(
|
||||||
|
label="Theme",
|
||||||
|
choices=["Dark", "Light", "System"],
|
||||||
|
value="Dark",
|
||||||
|
)
|
||||||
|
|
||||||
|
show_advanced = gr.Checkbox(
|
||||||
|
label="Show advanced parameters by default",
|
||||||
|
value=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
auto_play = gr.Checkbox(
|
||||||
|
label="Auto-play generated audio",
|
||||||
|
value=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# GPU & Memory Tab
|
||||||
|
with gr.TabItem("GPU & Memory"):
|
||||||
|
with gr.Group():
|
||||||
|
gr.Markdown("### GPU Information")
|
||||||
|
|
||||||
|
gpu_info_display = gr.JSON(
|
||||||
|
label="GPU Status",
|
||||||
|
value={},
|
||||||
|
)
|
||||||
|
|
||||||
|
refresh_gpu_btn = gr.Button("Refresh GPU Info", size="sm")
|
||||||
|
|
||||||
|
with gr.Group():
|
||||||
|
gr.Markdown("### Memory Management")
|
||||||
|
|
||||||
|
comfyui_reserve = gr.Slider(
|
||||||
|
minimum=0,
|
||||||
|
maximum=16,
|
||||||
|
value=10,
|
||||||
|
step=0.5,
|
||||||
|
label="ComfyUI VRAM Reserve (GB)",
|
||||||
|
info="VRAM to reserve for ComfyUI when running alongside",
|
||||||
|
)
|
||||||
|
|
||||||
|
idle_timeout = gr.Slider(
|
||||||
|
minimum=1,
|
||||||
|
maximum=60,
|
||||||
|
value=15,
|
||||||
|
step=1,
|
||||||
|
label="Idle Model Timeout (minutes)",
|
||||||
|
info="Unload models after this period of inactivity",
|
||||||
|
)
|
||||||
|
|
||||||
|
max_loaded = gr.Slider(
|
||||||
|
minimum=1,
|
||||||
|
maximum=5,
|
||||||
|
value=2,
|
||||||
|
step=1,
|
||||||
|
label="Maximum Loaded Models",
|
||||||
|
info="Maximum number of models to keep in memory",
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Group():
|
||||||
|
gr.Markdown("### Cache Management")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
clear_cache_btn = gr.Button("Clear Model Cache", variant="secondary")
|
||||||
|
unload_models_btn = gr.Button("Unload All Models", variant="stop")
|
||||||
|
|
||||||
|
cache_status = gr.Markdown("Cache status: Ready")
|
||||||
|
|
||||||
|
# Model Defaults Tab
|
||||||
|
with gr.TabItem("Model Defaults"):
|
||||||
|
with gr.Group():
|
||||||
|
gr.Markdown("### MusicGen Defaults")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
musicgen_variant = gr.Dropdown(
|
||||||
|
label="Default Variant",
|
||||||
|
choices=[
|
||||||
|
("Small", "small"),
|
||||||
|
("Medium", "medium"),
|
||||||
|
("Large", "large"),
|
||||||
|
("Melody", "melody"),
|
||||||
|
],
|
||||||
|
value="medium",
|
||||||
|
)
|
||||||
|
musicgen_duration = gr.Slider(
|
||||||
|
minimum=1,
|
||||||
|
maximum=30,
|
||||||
|
value=10,
|
||||||
|
step=1,
|
||||||
|
label="Default Duration (s)",
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Group():
|
||||||
|
gr.Markdown("### AudioGen Defaults")
|
||||||
|
|
||||||
|
audiogen_duration = gr.Slider(
|
||||||
|
minimum=1,
|
||||||
|
maximum=10,
|
||||||
|
value=5,
|
||||||
|
step=1,
|
||||||
|
label="Default Duration (s)",
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Group():
|
||||||
|
gr.Markdown("### MAGNeT Defaults")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
magnet_variant = gr.Dropdown(
|
||||||
|
label="Default Variant",
|
||||||
|
choices=[
|
||||||
|
("Small Music", "small"),
|
||||||
|
("Medium Music", "medium"),
|
||||||
|
("Small Audio", "audio-small"),
|
||||||
|
("Medium Audio", "audio-medium"),
|
||||||
|
],
|
||||||
|
value="medium",
|
||||||
|
)
|
||||||
|
magnet_decoding_steps = gr.Slider(
|
||||||
|
minimum=10,
|
||||||
|
maximum=100,
|
||||||
|
value=20,
|
||||||
|
step=5,
|
||||||
|
label="Decoding Steps",
|
||||||
|
)
|
||||||
|
|
||||||
|
# API Settings Tab
|
||||||
|
with gr.TabItem("API"):
|
||||||
|
with gr.Group():
|
||||||
|
gr.Markdown("### REST API Configuration")
|
||||||
|
|
||||||
|
api_enabled = gr.Checkbox(
|
||||||
|
label="Enable REST API",
|
||||||
|
value=True,
|
||||||
|
info="Enable FastAPI endpoints for programmatic access",
|
||||||
|
)
|
||||||
|
|
||||||
|
api_port = gr.Number(
|
||||||
|
value=8000,
|
||||||
|
label="API Port",
|
||||||
|
precision=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
api_key_display = gr.Textbox(
|
||||||
|
label="API Key",
|
||||||
|
value="••••••••",
|
||||||
|
interactive=False,
|
||||||
|
)
|
||||||
|
regenerate_key_btn = gr.Button("Regenerate", size="sm")
|
||||||
|
|
||||||
|
with gr.Group():
|
||||||
|
gr.Markdown("### Rate Limiting")
|
||||||
|
|
||||||
|
rate_limit = gr.Slider(
|
||||||
|
minimum=1,
|
||||||
|
maximum=100,
|
||||||
|
value=10,
|
||||||
|
step=1,
|
||||||
|
label="Requests per minute",
|
||||||
|
)
|
||||||
|
|
||||||
|
max_batch_size = gr.Slider(
|
||||||
|
minimum=1,
|
||||||
|
maximum=10,
|
||||||
|
value=4,
|
||||||
|
step=1,
|
||||||
|
label="Maximum batch size",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Queue Settings Tab
|
||||||
|
with gr.TabItem("Queue"):
|
||||||
|
with gr.Group():
|
||||||
|
gr.Markdown("### Batch Processing")
|
||||||
|
|
||||||
|
max_queue_size = gr.Slider(
|
||||||
|
minimum=10,
|
||||||
|
maximum=500,
|
||||||
|
value=100,
|
||||||
|
step=10,
|
||||||
|
label="Maximum Queue Size",
|
||||||
|
)
|
||||||
|
|
||||||
|
max_workers = gr.Slider(
|
||||||
|
minimum=1,
|
||||||
|
maximum=4,
|
||||||
|
value=1,
|
||||||
|
step=1,
|
||||||
|
label="Concurrent Workers",
|
||||||
|
info="Number of parallel generation workers",
|
||||||
|
)
|
||||||
|
|
||||||
|
priority_queue = gr.Checkbox(
|
||||||
|
label="Enable priority queue",
|
||||||
|
value=False,
|
||||||
|
info="Allow high-priority jobs to skip the queue",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save button
|
||||||
|
gr.Markdown("---")
|
||||||
|
with gr.Row():
|
||||||
|
save_btn = gr.Button("Save Settings", variant="primary", scale=2)
|
||||||
|
reset_btn = gr.Button("Reset to Defaults", variant="secondary", scale=1)
|
||||||
|
|
||||||
|
settings_status = gr.Markdown("")
|
||||||
|
|
||||||
|
# Event handlers
|
||||||
|
|
||||||
|
def load_settings():
    """Read persisted settings and emit them in the form components' order.

    Returns:
        A 22-tuple whose order must match the settings form's output wiring
        (general → GPU/memory → model defaults → API → queue).
    """
    # (settings key, default) pairs, in form-component order.
    spec = (
        ("output_dir", ""),
        ("default_format", "wav"),
        ("sample_rate", 32000),
        ("normalize_audio", True),
        ("theme_mode", "Dark"),
        ("show_advanced", False),
        ("auto_play", True),
        ("comfyui_reserve_gb", 10),
        ("idle_timeout_minutes", 15),
        ("max_loaded_models", 2),
        ("musicgen_variant", "medium"),
        ("musicgen_duration", 10),
        ("audiogen_duration", 5),
        ("magnet_variant", "medium"),
        ("magnet_decoding_steps", 20),
        ("api_enabled", True),
        ("api_port", 8000),
        ("rate_limit", 10),
        ("max_batch_size", 4),
        ("max_queue_size", 100),
        ("max_workers", 1),
        ("priority_queue", False),
    )
    settings = get_settings()
    return tuple(settings.get(key, default) for key, default in spec)
|
||||||
|
|
||||||
|
def save_settings(
    out_dir, fmt, sr, norm, theme, adv, play,
    comfyui_res, idle_to, max_load,
    mg_var, mg_dur, ag_dur, mn_var, mn_steps,
    api_en, api_p, rate, batch, queue_sz, workers, priority
):
    """Persist the settings form values.

    Parameters arrive positionally from the Gradio ``save_btn.click`` wiring,
    in the same order as the form components.

    Returns:
        Status markdown string for the UI.
    """
    # Settings keys in the same order as the positional parameters above.
    keys = (
        "output_dir", "default_format", "sample_rate", "normalize_audio",
        "theme_mode", "show_advanced", "auto_play",
        "comfyui_reserve_gb", "idle_timeout_minutes", "max_loaded_models",
        "musicgen_variant", "musicgen_duration", "audiogen_duration",
        "magnet_variant", "magnet_decoding_steps",
        "api_enabled", "api_port", "rate_limit", "max_batch_size",
        "max_queue_size", "max_workers", "priority_queue",
    )
    values = (
        out_dir, fmt, sr, norm, theme, adv, play,
        comfyui_res, idle_to, max_load,
        mg_var, mg_dur, ag_dur, mn_var, mn_steps,
        # int() cast kept from the original — gr.Number may deliver a float.
        api_en, int(api_p), rate, batch, queue_sz, workers, priority,
    )
    if update_settings(dict(zip(keys, values))):
        return "✅ Settings saved successfully"
    return "❌ Failed to save settings"
|
||||||
|
|
||||||
|
def do_refresh_gpu():
    """Refresh the GPU status JSON display with the current GPU info dict."""
    return get_gpu_info()
|
||||||
|
|
||||||
|
def do_clear_cache():
    """Clear the model cache; returns a status string for the cache panel."""
    succeeded = clear_cache()
    return "✅ Cache cleared" if succeeded else "❌ Failed to clear cache"
|
||||||
|
|
||||||
|
def do_unload_models():
    """Unload every loaded model; returns a status string for the cache panel."""
    succeeded = unload_all_models()
    return "✅ All models unloaded" if succeeded else "❌ Failed to unload models"
|
||||||
|
|
||||||
|
# Wire up events
|
||||||
|
|
||||||
|
refresh_gpu_btn.click(
|
||||||
|
fn=do_refresh_gpu,
|
||||||
|
outputs=[gpu_info_display],
|
||||||
|
)
|
||||||
|
|
||||||
|
clear_cache_btn.click(
|
||||||
|
fn=do_clear_cache,
|
||||||
|
outputs=[cache_status],
|
||||||
|
)
|
||||||
|
|
||||||
|
unload_models_btn.click(
|
||||||
|
fn=do_unload_models,
|
||||||
|
outputs=[cache_status],
|
||||||
|
)
|
||||||
|
|
||||||
|
save_btn.click(
|
||||||
|
fn=save_settings,
|
||||||
|
inputs=[
|
||||||
|
output_dir, default_format, sample_rate, normalize_audio,
|
||||||
|
theme_mode, show_advanced, auto_play,
|
||||||
|
comfyui_reserve, idle_timeout, max_loaded,
|
||||||
|
musicgen_variant, musicgen_duration, audiogen_duration,
|
||||||
|
magnet_variant, magnet_decoding_steps,
|
||||||
|
api_enabled, api_port, rate_limit, max_batch_size,
|
||||||
|
max_queue_size, max_workers, priority_queue,
|
||||||
|
],
|
||||||
|
outputs=[settings_status],
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"output_dir": output_dir,
|
||||||
|
"default_format": default_format,
|
||||||
|
"sample_rate": sample_rate,
|
||||||
|
"comfyui_reserve": comfyui_reserve,
|
||||||
|
"idle_timeout": idle_timeout,
|
||||||
|
"api_enabled": api_enabled,
|
||||||
|
"save_btn": save_btn,
|
||||||
|
"settings_status": settings_status,
|
||||||
|
"load_fn": load_settings,
|
||||||
|
}
|
||||||
294
src/ui/state.py
Normal file
294
src/ui/state.py
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
"""State management for Gradio UI."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class UIState:
    """Global UI state shared across Gradio callbacks."""

    # --- Current view ---
    current_tab: str = "dashboard"

    # --- Generation state ---
    is_generating: bool = False
    current_job_id: Optional[str] = None

    # --- Selected items ---
    selected_project_id: Optional[str] = None
    selected_generation_id: Optional[str] = None
    selected_preset_id: Optional[str] = None

    # --- Model state ---
    selected_model: str = "musicgen"
    selected_variant: str = "medium"

    # --- Generation parameters (current values) ---
    prompt: str = ""
    duration: float = 10.0
    temperature: float = 1.0
    top_k: int = 250
    top_p: float = 0.0
    cfg_coef: float = 3.0
    seed: Optional[int] = None

    # --- Conditioning ---
    melody_audio: Optional[str] = None
    style_audio: Optional[str] = None
    chords: list[dict[str, Any]] = field(default_factory=list)
    drums_pattern: str = ""
    bpm: float = 120.0

    # --- UI preferences ---
    show_advanced: bool = False
    auto_play: bool = True

    def reset_generation_params(self) -> None:
        """Restore every per-generation parameter to its default.

        bpm is intentionally left alone to match the original behavior
        (it was never reset here).
        """
        defaults = (
            ("prompt", ""),
            ("duration", 10.0),
            ("temperature", 1.0),
            ("top_k", 250),
            ("top_p", 0.0),
            ("cfg_coef", 3.0),
            ("seed", None),
            ("melody_audio", None),
            ("style_audio", None),
            ("drums_pattern", ""),
        )
        for attr, value in defaults:
            setattr(self, attr, value)
        # Fresh list so callers holding the old reference are not mutated.
        self.chords = []

    def apply_preset(self, preset: dict[str, Any]) -> None:
        """Overlay a preset's parameters onto the state.

        Keys absent from ``preset["parameters"]`` keep their current values.
        """
        params = preset.get("parameters", {})
        for attr in ("duration", "temperature", "top_k", "top_p", "cfg_coef"):
            setattr(self, attr, params.get(attr, getattr(self, attr)))

    def to_generation_params(self) -> dict[str, Any]:
        """Snapshot the sampling parameters as a plain dict for a generation call."""
        return dict(
            duration=self.duration,
            temperature=self.temperature,
            top_k=self.top_k,
            top_p=self.top_p,
            cfg_coef=self.cfg_coef,
            seed=self.seed,
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Default presets for each model: curated parameter bundles offered in the UI.
def _preset(pid, name, description, **parameters):
    """Assemble one preset record (module-internal helper)."""
    return {"id": pid, "name": name, "description": description, "parameters": parameters}


DEFAULT_PRESETS = {
    "musicgen": [
        _preset("cinematic", "Cinematic", "Epic orchestral soundscapes",
                duration=30, temperature=1.0, top_k=250, cfg_coef=3.0),
        _preset("electronic", "Electronic", "Synthesizers and beats",
                duration=15, temperature=1.1, top_k=200, cfg_coef=3.5),
        _preset("ambient", "Ambient", "Atmospheric and calm",
                duration=30, temperature=0.9, top_k=300, cfg_coef=2.5),
        _preset("rock", "Rock", "Guitar-driven energy",
                duration=20, temperature=1.0, top_k=250, cfg_coef=3.0),
        _preset("jazz", "Jazz", "Smooth and improvisational",
                duration=20, temperature=1.2, top_k=200, cfg_coef=2.5),
    ],
    "audiogen": [
        _preset("nature", "Nature", "Natural environments",
                duration=10, temperature=1.0, top_k=250, cfg_coef=3.0),
        _preset("urban", "Urban", "City sounds",
                duration=10, temperature=1.0, top_k=250, cfg_coef=3.0),
        _preset("mechanical", "Mechanical", "Machines and tools",
                duration=5, temperature=0.9, top_k=200, cfg_coef=3.5),
        _preset("weather", "Weather", "Rain, thunder, wind",
                duration=10, temperature=1.0, top_k=250, cfg_coef=3.0),
    ],
    # MAGNeT samples with top_p (not top_k) and much higher temperatures.
    "magnet": [
        _preset("fast", "Fast", "Quick generation",
                duration=10, temperature=3.0, top_p=0.9, cfg_coef=3.0),
        _preset("quality", "Quality", "Higher quality output",
                duration=10, temperature=2.5, top_p=0.85, cfg_coef=4.0),
    ],
    "musicgen-style": [
        _preset("style_transfer", "Style Transfer", "Copy style from reference",
                duration=15, temperature=1.0, top_k=250, cfg_coef=3.0,
                eval_q=3, excerpt_length=3.0),
    ],
    # JASCO presets additionally carry a tempo (bpm).
    "jasco": [
        _preset("pop", "Pop", "Pop chord progressions",
                duration=10, temperature=1.0, top_k=250, cfg_coef=3.0, bpm=120),
        _preset("blues", "Blues", "12-bar blues",
                duration=10, temperature=1.0, top_k=250, cfg_coef=3.0, bpm=100),
    ],
}
|
||||||
|
|
||||||
|
|
||||||
|
# Prompt suggestions for each model, surfaced as quick-fill buttons in the UI.
PROMPT_SUGGESTIONS = {
    # Music generation: genre/mood-oriented prompts.
    "musicgen": [
        "Epic orchestral music with dramatic strings and powerful brass",
        "Upbeat electronic dance music with synthesizers and heavy bass",
        "Calm acoustic guitar melody with soft piano accompaniment",
        "Energetic rock song with electric guitars and driving drums",
        "Smooth jazz with saxophone solo and walking bass",
        "Ambient soundscape with ethereal pads and gentle textures",
        "Cinematic trailer music building to an epic climax",
        "Lo-fi hip hop beats with vinyl crackle and mellow keys",
    ],
    # Sound effects and environmental audio.
    "audiogen": [
        "Thunder and heavy rain with occasional lightning strikes",
        "Busy city street with traffic, horns, and distant sirens",
        "Forest ambience with birds singing and wind in trees",
        "Ocean waves crashing on a rocky shore",
        "Crackling fireplace with wood popping",
        "Coffee shop atmosphere with murmuring voices and clinking cups",
        "Construction site with hammering and machinery",
        "Spaceship engine humming with occasional beeps",
    ],
    "magnet": [
        "Energetic pop music with catchy melody",
        "Dark electronic music with deep bass",
        "Cheerful ukulele tune with whistling",
        "Dramatic piano piece with building intensity",
    ],
    # Style-conditioned generation: prompts refer to an uploaded reference clip.
    "musicgen-style": [
        "Generate music in the style of the uploaded reference",
        "Create a variation with similar instrumentation",
        "Compose a piece matching the mood of the reference",
    ],
    # JASCO: chord/drum-conditioned generation.
    "jasco": [
        "Upbeat pop song with the specified chord progression",
        "Mellow jazz piece following the chord changes",
        "Rock anthem with powerful drum pattern",
        "Electronic track with syncopated rhythms",
    ],
}
|
||||||
17
src/ui/tabs/__init__.py
Normal file
17
src/ui/tabs/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
"""Model tabs for AudioCraft Studio."""
|
||||||
|
|
||||||
|
from src.ui.tabs.dashboard_tab import create_dashboard_tab
|
||||||
|
from src.ui.tabs.musicgen_tab import create_musicgen_tab
|
||||||
|
from src.ui.tabs.audiogen_tab import create_audiogen_tab
|
||||||
|
from src.ui.tabs.magnet_tab import create_magnet_tab
|
||||||
|
from src.ui.tabs.style_tab import create_style_tab
|
||||||
|
from src.ui.tabs.jasco_tab import create_jasco_tab
|
||||||
|
|
||||||
|
# Public API of the tabs package: one builder per model family plus the dashboard.
__all__ = [
    "create_dashboard_tab",
    "create_musicgen_tab",
    "create_audiogen_tab",
    "create_magnet_tab",
    "create_style_tab",
    "create_jasco_tab",
]
|
||||||
283
src/ui/tabs/audiogen_tab.py
Normal file
283
src/ui/tabs/audiogen_tab.py
Normal file
@@ -0,0 +1,283 @@
|
|||||||
|
"""AudioGen tab for text-to-sound generation."""
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
|
||||||
|
from src.ui.components.audio_player import create_generation_output
|
||||||
|
|
||||||
|
|
||||||
|
# Variants offered in the AudioGen tab's model dropdown. Only "medium" is
# listed (see the "AudioGen only has medium" note where the dropdown is built).
# vram_mb feeds the "(X.XGB)" hint rendered next to the variant name.
AUDIOGEN_VARIANTS = [
    {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "1.5B params, balanced quality/speed"},
]
|
||||||
|
|
||||||
|
|
||||||
|
def create_audiogen_tab(
    generate_fn: Callable[..., Any],
    add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
    """Create AudioGen generation tab.

    Args:
        generate_fn: Awaitable called for immediate generation. It is awaited
            with model/sampling kwargs and must return a ``(result, generation)``
            pair, where ``result`` exposes ``.duration`` and ``.seed`` and
            ``generation`` exposes ``.audio_path``.
        add_to_queue_fn: Function that enqueues a generation job and returns
            an object exposing ``.id``.

    Returns:
        Dictionary with component references
    """
    presets = DEFAULT_PRESETS.get("audiogen", [])
    suggestions = PROMPT_SUGGESTIONS.get("audiogen", [])

    with gr.Column():
        gr.Markdown("## 🔊 AudioGen")
        gr.Markdown("Generate sound effects and environmental audio from text")

        with gr.Row():
            # Left column - inputs
            with gr.Column(scale=2):
                # Preset selector
                preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
                preset_dropdown = gr.Dropdown(
                    label="Preset",
                    choices=preset_choices,
                    value=presets[0]["id"] if presets else "custom",
                )

                # Model variant (AudioGen only has medium)
                variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in AUDIOGEN_VARIANTS]
                variant_dropdown = gr.Dropdown(
                    label="Model Variant",
                    choices=variant_choices,
                    value="medium",
                )

                # Prompt input
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the sound you want to generate...",
                    lines=3,
                    max_lines=5,
                )

                # Prompt suggestions (click-to-fill buttons)
                with gr.Accordion("Prompt Suggestions", open=False):
                    suggestion_btns = []
                    for i, suggestion in enumerate(suggestions[:6]):
                        btn = gr.Button(suggestion[:50] + "...", size="sm", variant="secondary")
                        suggestion_btns.append((btn, suggestion))

                # Parameters
                gr.Markdown("### Parameters")

                duration_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Duration (seconds)",
                    info="AudioGen works best with shorter clips",
                )

                with gr.Accordion("Advanced Parameters", open=False):
                    with gr.Row():
                        temperature_slider = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            value=1.0,
                            step=0.05,
                            label="Temperature",
                        )
                        cfg_slider = gr.Slider(
                            minimum=1.0,
                            maximum=10.0,
                            value=3.0,
                            step=0.5,
                            label="CFG Coefficient",
                        )

                    with gr.Row():
                        top_k_slider = gr.Slider(
                            minimum=0,
                            maximum=500,
                            value=250,
                            step=10,
                            label="Top-K",
                        )
                        top_p_slider = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.0,
                            step=0.05,
                            label="Top-P",
                        )

                    with gr.Row():
                        seed_input = gr.Number(
                            value=None,
                            label="Seed (empty = random)",
                            precision=0,
                        )

                # Generate buttons
                with gr.Row():
                    generate_btn = gr.Button("🔊 Generate", variant="primary", scale=2)
                    queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)

            # Right column - output
            with gr.Column(scale=3):
                output = create_generation_output()

    # Event handlers

    # Preset change: push the preset's saved parameters into the sliders,
    # or leave everything untouched for "custom" / unknown ids.
    def apply_preset(preset_id: str):
        for p in presets:
            if p["id"] == preset_id:
                params = p["parameters"]
                return (
                    params.get("duration", 5),
                    params.get("temperature", 1.0),
                    params.get("cfg_coef", 3.0),
                    params.get("top_k", 250),
                    params.get("top_p", 0.0),
                )
        return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()

    preset_dropdown.change(
        fn=apply_preset,
        inputs=[preset_dropdown],
        outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
    )

    # Prompt suggestions: each button writes its full suggestion text into
    # the prompt box (default arg binds the loop variable).
    for btn, suggestion in suggestion_btns:
        btn.click(
            fn=lambda s=suggestion: s,
            outputs=[prompt_input],
        )

    # Generate. This is an async *generator* (it yields progress updates),
    # so early exits must `yield` a final state and then `return` bare —
    # `return <value>` inside an async generator is a SyntaxError.
    async def do_generate(
        prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed
    ):
        if not prompt:
            # BUGFIX: was `return (<tuple>)`, which is illegal in an async
            # generator; yield the error state instead, then stop.
            yield (
                gr.update(value="Please enter a prompt"),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
            )
            return

        yield (
            gr.update(value="🔄 Generating..."),
            gr.update(visible=True, value=0),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
        )

        try:
            result, generation = await generate_fn(
                model_id="audiogen",
                variant=variant,
                prompts=[prompt],
                duration=duration,
                temperature=temperature,
                top_k=int(top_k),
                top_p=top_p,
                cfg_coef=cfg_coef,
                # `is not None` so an explicit seed of 0 is honored rather
                # than being treated as "random".
                seed=int(seed) if seed is not None else None,
            )

            yield (
                gr.update(value="✅ Generation complete!"),
                gr.update(visible=False),
                gr.update(value=generation.audio_path),
                gr.update(),
                gr.update(value=f"{result.duration:.2f}s"),
                gr.update(value=str(result.seed)),
            )

        except Exception as e:
            yield (
                gr.update(value=f"❌ Error: {str(e)}"),
                gr.update(visible=False),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
            )

    generate_btn.click(
        fn=do_generate,
        inputs=[
            prompt_input,
            variant_dropdown,
            duration_slider,
            temperature_slider,
            cfg_slider,
            top_k_slider,
            top_p_slider,
            seed_input,
        ],
        outputs=[
            output["status"],
            output["progress"],
            output["player"]["audio"],
            output["player"]["waveform"],
            output["player"]["duration"],
            output["player"]["seed"],
        ],
    )

    # Add to queue (synchronous; returns a status string for the status box)
    def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed):
        if not prompt:
            return "Please enter a prompt"

        job = add_to_queue_fn(
            model_id="audiogen",
            variant=variant,
            prompts=[prompt],
            duration=duration,
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
            cfg_coef=cfg_coef,
            seed=int(seed) if seed is not None else None,
        )

        return f"✅ Added to queue: {job.id}"

    queue_btn.click(
        fn=do_add_queue,
        inputs=[
            prompt_input,
            variant_dropdown,
            duration_slider,
            temperature_slider,
            cfg_slider,
            top_k_slider,
            top_p_slider,
            seed_input,
        ],
        outputs=[output["status"]],
    )

    return {
        "preset": preset_dropdown,
        "variant": variant_dropdown,
        "prompt": prompt_input,
        "duration": duration_slider,
        "temperature": temperature_slider,
        "cfg_coef": cfg_slider,
        "top_k": top_k_slider,
        "top_p": top_p_slider,
        "seed": seed_input,
        "generate_btn": generate_btn,
        "queue_btn": queue_btn,
        "output": output,
    }
|
||||||
166
src/ui/tabs/dashboard_tab.py
Normal file
166
src/ui/tabs/dashboard_tab.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
"""Dashboard tab - home page with model overview and quick actions."""
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
|
||||||
|
# Static metadata for the dashboard's model cards, keyed by model id.
# Each entry drives one card: icon + name in the heading, description as
# the body text, and capabilities joined into the "Features" line.
MODEL_INFO = {
    "musicgen": {
        "name": "MusicGen",
        "icon": "🎵",
        "description": "Text-to-music generation with optional melody conditioning",
        "capabilities": ["Text prompts", "Melody conditioning", "Stereo output"],
    },
    "audiogen": {
        "name": "AudioGen",
        "icon": "🔊",
        "description": "Text-to-sound effects and environmental audio",
        "capabilities": ["Sound effects", "Ambiences", "Foley"],
    },
    "magnet": {
        "name": "MAGNeT",
        "icon": "⚡",
        "description": "Fast non-autoregressive music generation",
        "capabilities": ["Fast generation", "Music", "Sound effects"],
    },
    "musicgen-style": {
        "name": "MusicGen Style",
        "icon": "🎨",
        "description": "Style-conditioned music from reference audio",
        "capabilities": ["Style transfer", "Reference audio", "Text prompts"],
    },
    "jasco": {
        "name": "JASCO",
        "icon": "🎹",
        "description": "Chord and drum-conditioned music generation",
        "capabilities": ["Chord control", "Drum patterns", "Symbolic conditioning"],
    },
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_dashboard_tab(
    get_queue_status: Callable[[], dict[str, Any]],
    get_recent_generations: Callable[[int], list[dict[str, Any]]],
    get_gpu_status: Callable[[], dict[str, Any]],
) -> dict[str, Any]:
    """Create dashboard tab with model overview and status.

    Args:
        get_queue_status: Function to get generation queue status
            (a dict; only the "queue_size" key is read here).
        get_recent_generations: Function to get recent generations
            (called with a limit; items are dicts read via .get()).
        get_gpu_status: Function to get GPU status
            (dict; "used_gb", "total_gb", "utilization_percent" are read).

    Returns:
        Dictionary with component references
    """

    def refresh_dashboard() -> tuple[str, str, str]:
        """Refresh all dashboard data.

        Returns markdown strings in the order (queue, recent, gpu) — this
        must match the `outputs` wiring of the refresh button below.
        """
        queue = get_queue_status()
        recent = get_recent_generations(5)
        gpu = get_gpu_status()

        # Format queue status
        queue_size = queue.get("queue_size", 0)
        queue_text = f"**Queue:** {queue_size} job(s) pending"

        # Format recent generations (bullet list, prompt truncated to 50 chars)
        if recent:
            recent_items = []
            for gen in recent[:5]:
                model = gen.get("model", "unknown")
                prompt = gen.get("prompt", "")[:50]
                duration = gen.get("duration_seconds", 0)
                recent_items.append(f"• **{model}** ({duration:.0f}s): {prompt}...")
            recent_text = "\n".join(recent_items)
        else:
            recent_text = "No recent generations"

        # Format GPU status (total defaults to 24 GB if the provider
        # omits it — presumably sized for an RTX 4090; confirm)
        used_gb = gpu.get("used_gb", 0)
        total_gb = gpu.get("total_gb", 24)
        util = gpu.get("utilization_percent", 0)
        gpu_text = f"**GPU:** {used_gb:.1f}/{total_gb:.1f} GB ({util:.0f}%)"

        return queue_text, recent_text, gpu_text

    with gr.Column():
        # Header
        gr.Markdown("# AudioCraft Studio")
        gr.Markdown("AI-powered music and sound generation")

        # Status bar (populated on refresh; starts as placeholders)
        with gr.Row():
            queue_status = gr.Markdown("**Queue:** Loading...")
            gpu_status = gr.Markdown("**GPU:** Loading...")
            refresh_btn = gr.Button("🔄 Refresh", size="sm")

        gr.Markdown("---")

        # Model cards
        gr.Markdown("## Models")

        with gr.Row():
            # First row of cards
            for model_id in ["musicgen", "audiogen", "magnet"]:
                info = MODEL_INFO[model_id]
                with gr.Column(scale=1):
                    with gr.Group():
                        gr.Markdown(f"### {info['icon']} {info['name']}")
                        gr.Markdown(info["description"])
                        gr.Markdown("**Features:** " + ", ".join(info["capabilities"]))
                        # elem_id lets external JS/CSS target the card's
                        # open button; no click handler is attached here.
                        gr.Button(
                            f"Open {info['name']}",
                            variant="primary",
                            size="sm",
                            elem_id=f"btn_{model_id}",
                        )

        with gr.Row():
            # Second row of cards
            for model_id in ["musicgen-style", "jasco"]:
                info = MODEL_INFO[model_id]
                with gr.Column(scale=1):
                    with gr.Group():
                        gr.Markdown(f"### {info['icon']} {info['name']}")
                        gr.Markdown(info["description"])
                        gr.Markdown("**Features:** " + ", ".join(info["capabilities"]))
                        gr.Button(
                            f"Open {info['name']}",
                            variant="primary",
                            size="sm",
                            elem_id=f"btn_{model_id}",
                        )

            # Empty column for balance
            with gr.Column(scale=1):
                pass

        gr.Markdown("---")

        # Recent generations and queue
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("## Recent Generations")
                recent_list = gr.Markdown("Loading...")

            with gr.Column(scale=1):
                gr.Markdown("## Quick Actions")
                # NOTE(review): these buttons have no handlers wired here —
                # presumably the app shell attaches navigation; confirm.
                with gr.Group():
                    gr.Button("📁 Browse Projects", variant="secondary")
                    gr.Button("⚙️ Settings", variant="secondary")
                    gr.Button("📖 API Documentation", variant="secondary")

        # Refresh handler (output order matches refresh_dashboard's return)
        refresh_btn.click(
            fn=refresh_dashboard,
            outputs=[queue_status, recent_list, gpu_status],
        )

    return {
        "queue_status": queue_status,
        "gpu_status": gpu_status,
        "recent_list": recent_list,
        "refresh_btn": refresh_btn,
        # Exposed so the shell can call the same refresh on tab load.
        "refresh_fn": refresh_dashboard,
    }
|
||||||
364
src/ui/tabs/jasco_tab.py
Normal file
364
src/ui/tabs/jasco_tab.py
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
"""JASCO tab for chord and drum-conditioned generation."""
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
|
||||||
|
from src.ui.components.audio_player import create_generation_output
|
||||||
|
|
||||||
|
|
||||||
|
# JASCO checkpoint variants shown in the variant dropdown; vram_mb is the
# approximate VRAM footprint used for the size hint in the label.
JASCO_VARIANTS = [
    {"id": "chords", "name": "Chords", "vram_mb": 5000, "description": "Chord-conditioned generation"},
    {"id": "chords-drums", "name": "Chords + Drums", "vram_mb": 5500, "description": "Full symbolic conditioning"},
]

# Common chord progressions offered as one-click fills for the chord box;
# "chords" is a space-separated chord-symbol string matching the input's
# expected format.
CHORD_PRESETS = [
    {"name": "Pop I-V-vi-IV", "chords": "C G Am F"},
    {"name": "Jazz ii-V-I", "chords": "Dm7 G7 Cmaj7"},
    {"name": "Blues I-IV-V", "chords": "A7 D7 E7"},
    {"name": "Rock I-bVII-IV", "chords": "E D A"},
    {"name": "Minor i-VI-III-VII", "chords": "Am F C G"},
]
|
||||||
|
|
||||||
|
|
||||||
|
def create_jasco_tab(
    generate_fn: Callable[..., Any],
    add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
    """Create JASCO generation tab.

    Args:
        generate_fn: Awaitable called for immediate generation. It is awaited
            with model/sampling kwargs plus a ``conditioning`` dict and must
            return a ``(result, generation)`` pair, where ``result`` exposes
            ``.duration`` and ``.seed`` and ``generation`` exposes
            ``.audio_path``.
        add_to_queue_fn: Function that enqueues a generation job and returns
            an object exposing ``.id``.

    Returns:
        Dictionary with component references
    """
    presets = DEFAULT_PRESETS.get("jasco", [])

    with gr.Column():
        gr.Markdown("## 🎹 JASCO")
        gr.Markdown("Generate music conditioned on chords and drum patterns")

        with gr.Row():
            # Left column - inputs
            with gr.Column(scale=2):
                # Preset selector
                preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
                preset_dropdown = gr.Dropdown(
                    label="Preset",
                    choices=preset_choices,
                    value=presets[0]["id"] if presets else "custom",
                )

                # Model variant
                variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in JASCO_VARIANTS]
                variant_dropdown = gr.Dropdown(
                    label="Model Variant",
                    choices=variant_choices,
                    value="chords-drums",
                )

                # Prompt input (optional — chords are the required input)
                prompt_input = gr.Textbox(
                    label="Text Prompt",
                    placeholder="Describe the music style, mood, instruments...",
                    lines=2,
                    max_lines=4,
                )

                # Chord conditioning
                gr.Markdown("### Chord Progression")

                chord_input = gr.Textbox(
                    label="Chords",
                    placeholder="C G Am F or Cmaj7 Dm7 G7 Cmaj7",
                    lines=1,
                    info="Space-separated chord symbols",
                )

                # Chord presets: buttons split across two rows (3 + 2)
                with gr.Accordion("Chord Presets", open=False):
                    chord_preset_btns = []
                    with gr.Row():
                        for cp in CHORD_PRESETS[:3]:
                            btn = gr.Button(cp["name"], size="sm", variant="secondary")
                            chord_preset_btns.append((btn, cp["chords"]))
                    with gr.Row():
                        for cp in CHORD_PRESETS[3:]:
                            btn = gr.Button(cp["name"], size="sm", variant="secondary")
                            chord_preset_btns.append((btn, cp["chords"]))

                # Drum conditioning (shown only for the chords-drums variant;
                # visibility toggled by on_variant_change below)
                with gr.Group(visible=True) as drum_section:
                    gr.Markdown("### Drum Pattern")

                    drum_input = gr.Audio(
                        label="Drum Reference",
                        type="filepath",
                        sources=["upload"],
                    )
                    gr.Markdown("*Upload a drum loop to condition the rhythm*")

                # Parameters
                gr.Markdown("### Parameters")

                duration_slider = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)",
                )

                bpm_slider = gr.Slider(
                    minimum=60,
                    maximum=180,
                    value=120,
                    step=1,
                    label="BPM",
                    info="Tempo for chord timing",
                )

                with gr.Accordion("Advanced Parameters", open=False):
                    with gr.Row():
                        temperature_slider = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            value=1.0,
                            step=0.05,
                            label="Temperature",
                        )
                        cfg_slider = gr.Slider(
                            minimum=1.0,
                            maximum=10.0,
                            value=3.0,
                            step=0.5,
                            label="CFG Coefficient",
                        )

                    with gr.Row():
                        top_k_slider = gr.Slider(
                            minimum=0,
                            maximum=500,
                            value=250,
                            step=10,
                            label="Top-K",
                        )
                        top_p_slider = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.0,
                            step=0.05,
                            label="Top-P",
                        )

                    with gr.Row():
                        seed_input = gr.Number(
                            value=None,
                            label="Seed (empty = random)",
                            precision=0,
                        )

                # Generate buttons
                with gr.Row():
                    generate_btn = gr.Button("🎹 Generate", variant="primary", scale=2)
                    queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)

            # Right column - output
            with gr.Column(scale=3):
                output = create_generation_output()

    # Event handlers

    # Preset change: load the preset's saved parameters, or leave the
    # sliders untouched for "custom" / unknown ids.
    def apply_preset(preset_id: str):
        for p in presets:
            if p["id"] == preset_id:
                params = p["parameters"]
                return (
                    params.get("duration", 10),
                    params.get("bpm", 120),
                    params.get("temperature", 1.0),
                    params.get("cfg_coef", 3.0),
                    params.get("top_k", 250),
                    params.get("top_p", 0.0),
                )
        return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()

    preset_dropdown.change(
        fn=apply_preset,
        inputs=[preset_dropdown],
        outputs=[duration_slider, bpm_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
    )

    # Variant change - show/hide drum section
    def on_variant_change(variant: str):
        show_drums = "drums" in variant.lower()
        return gr.update(visible=show_drums)

    variant_dropdown.change(
        fn=on_variant_change,
        inputs=[variant_dropdown],
        outputs=[drum_section],
    )

    # Chord presets: each button writes its progression into the chord box
    # (default arg binds the loop variable).
    for btn, chords in chord_preset_btns:
        btn.click(
            fn=lambda c=chords: c,
            outputs=[chord_input],
        )

    def _build_conditioning(chords, bpm, drums, variant) -> dict[str, Any]:
        """Assemble the symbolic-conditioning payload; drums are included
        only when a drum reference was uploaded AND the variant supports it."""
        conditioning: dict[str, Any] = {
            "chords": chords,
            "bpm": bpm,
        }
        if drums and "drums" in variant.lower():
            conditioning["drums"] = drums
        return conditioning

    # Generate. This is an async *generator* (it yields progress updates),
    # so early exits must `yield` a final state and then `return` bare —
    # `return <value>` inside an async generator is a SyntaxError.
    async def do_generate(
        prompt, variant, chords, drums, duration, bpm, temperature, cfg_coef, top_k, top_p, seed
    ):
        if not chords:
            # BUGFIX: was `return (<tuple>)`, which is illegal in an async
            # generator; yield the error state instead, then stop.
            yield (
                gr.update(value="Please enter a chord progression"),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
            )
            return

        yield (
            gr.update(value="🔄 Generating..."),
            gr.update(visible=True, value=0),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
        )

        try:
            result, generation = await generate_fn(
                model_id="jasco",
                variant=variant,
                prompts=[prompt] if prompt else [""],
                duration=duration,
                temperature=temperature,
                top_k=int(top_k),
                top_p=top_p,
                cfg_coef=cfg_coef,
                # `is not None` so an explicit seed of 0 is honored rather
                # than being treated as "random".
                seed=int(seed) if seed is not None else None,
                conditioning=_build_conditioning(chords, bpm, drums, variant),
            )

            yield (
                gr.update(value="✅ Generation complete!"),
                gr.update(visible=False),
                gr.update(value=generation.audio_path),
                gr.update(),
                gr.update(value=f"{result.duration:.2f}s"),
                gr.update(value=str(result.seed)),
            )

        except Exception as e:
            yield (
                gr.update(value=f"❌ Error: {str(e)}"),
                gr.update(visible=False),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
            )

    generate_btn.click(
        fn=do_generate,
        inputs=[
            prompt_input,
            variant_dropdown,
            chord_input,
            drum_input,
            duration_slider,
            bpm_slider,
            temperature_slider,
            cfg_slider,
            top_k_slider,
            top_p_slider,
            seed_input,
        ],
        outputs=[
            output["status"],
            output["progress"],
            output["player"]["audio"],
            output["player"]["waveform"],
            output["player"]["duration"],
            output["player"]["seed"],
        ],
    )

    # Add to queue (synchronous; returns a status string for the status box)
    def do_add_queue(prompt, variant, chords, drums, duration, bpm, temperature, cfg_coef, top_k, top_p, seed):
        if not chords:
            return "Please enter a chord progression"

        job = add_to_queue_fn(
            model_id="jasco",
            variant=variant,
            prompts=[prompt] if prompt else [""],
            duration=duration,
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
            cfg_coef=cfg_coef,
            seed=int(seed) if seed is not None else None,
            conditioning=_build_conditioning(chords, bpm, drums, variant),
        )

        return f"✅ Added to queue: {job.id}"

    queue_btn.click(
        fn=do_add_queue,
        inputs=[
            prompt_input,
            variant_dropdown,
            chord_input,
            drum_input,
            duration_slider,
            bpm_slider,
            temperature_slider,
            cfg_slider,
            top_k_slider,
            top_p_slider,
            seed_input,
        ],
        outputs=[output["status"]],
    )

    return {
        "preset": preset_dropdown,
        "variant": variant_dropdown,
        "prompt": prompt_input,
        "chords": chord_input,
        "drums": drum_input,
        "duration": duration_slider,
        "bpm": bpm_slider,
        "temperature": temperature_slider,
        "cfg_coef": cfg_slider,
        "top_k": top_k_slider,
        "top_p": top_p_slider,
        "seed": seed_input,
        "generate_btn": generate_btn,
        "queue_btn": queue_btn,
        "output": output,
    }
|
||||||
316
src/ui/tabs/magnet_tab.py
Normal file
316
src/ui/tabs/magnet_tab.py
Normal file
@@ -0,0 +1,316 @@
|
|||||||
|
"""MAGNeT tab for fast non-autoregressive generation."""
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
|
||||||
|
from src.ui.components.audio_player import create_generation_output
|
||||||
|
|
||||||
|
|
||||||
|
# MAGNeT checkpoint variants: music models ("small"/"medium") and their
# sound-effect counterparts ("audio-*"). vram_mb is the approximate VRAM
# footprint used for the size hint shown in the dropdown label.
MAGNET_VARIANTS = [
    {"id": "small", "name": "Small Music", "vram_mb": 2000, "description": "Fast music, 300M params"},
    {"id": "medium", "name": "Medium Music", "vram_mb": 5000, "description": "Balanced music, 1.5B params"},
    {"id": "audio-small", "name": "Small Audio", "vram_mb": 2000, "description": "Fast sound effects"},
    {"id": "audio-medium", "name": "Medium Audio", "vram_mb": 5000, "description": "Balanced sound effects"},
]
|
||||||
|
|
||||||
|
|
||||||
|
def create_magnet_tab(
|
||||||
|
generate_fn: Callable[..., Any],
|
||||||
|
add_to_queue_fn: Callable[..., Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Create MAGNeT generation tab.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
generate_fn: Function to call for generation
|
||||||
|
add_to_queue_fn: Function to add to queue
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with component references
|
||||||
|
"""
|
||||||
|
presets = DEFAULT_PRESETS.get("magnet", [])
|
||||||
|
suggestions = PROMPT_SUGGESTIONS.get("musicgen", []) # Reuse music suggestions
|
||||||
|
|
||||||
|
with gr.Column():
|
||||||
|
gr.Markdown("## ⚡ MAGNeT")
|
||||||
|
gr.Markdown("Fast non-autoregressive music and sound generation")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
# Left column - inputs
|
||||||
|
with gr.Column(scale=2):
|
||||||
|
# Preset selector
|
||||||
|
preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
|
||||||
|
preset_dropdown = gr.Dropdown(
|
||||||
|
label="Preset",
|
||||||
|
choices=preset_choices,
|
||||||
|
value=presets[0]["id"] if presets else "custom",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Model variant
|
||||||
|
variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in MAGNET_VARIANTS]
|
||||||
|
variant_dropdown = gr.Dropdown(
|
||||||
|
label="Model Variant",
|
||||||
|
choices=variant_choices,
|
||||||
|
value="medium",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prompt input
|
||||||
|
prompt_input = gr.Textbox(
|
||||||
|
label="Prompt",
|
||||||
|
placeholder="Describe the music or sound you want to generate...",
|
||||||
|
lines=3,
|
||||||
|
max_lines=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prompt suggestions
|
||||||
|
with gr.Accordion("Prompt Suggestions", open=False):
|
||||||
|
suggestion_btns = []
|
||||||
|
for i, suggestion in enumerate(suggestions[:4]):
|
||||||
|
btn = gr.Button(suggestion[:60] + "...", size="sm", variant="secondary")
|
||||||
|
suggestion_btns.append((btn, suggestion))
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
gr.Markdown("### Parameters")
|
||||||
|
|
||||||
|
duration_slider = gr.Slider(
|
||||||
|
minimum=1,
|
||||||
|
maximum=30,
|
||||||
|
value=10,
|
||||||
|
step=1,
|
||||||
|
label="Duration (seconds)",
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Accordion("Advanced Parameters", open=False):
|
||||||
|
gr.Markdown("*MAGNeT uses different sampling compared to MusicGen*")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
temperature_slider = gr.Slider(
|
||||||
|
minimum=1.0,
|
||||||
|
maximum=5.0,
|
||||||
|
value=3.0,
|
||||||
|
step=0.1,
|
||||||
|
label="Temperature",
|
||||||
|
info="Higher values recommended (3.0 default)",
|
||||||
|
)
|
||||||
|
cfg_slider = gr.Slider(
|
||||||
|
minimum=1.0,
|
||||||
|
maximum=10.0,
|
||||||
|
value=3.0,
|
||||||
|
step=0.5,
|
||||||
|
label="CFG Coefficient",
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
top_k_slider = gr.Slider(
|
||||||
|
minimum=0,
|
||||||
|
maximum=500,
|
||||||
|
value=0,
|
||||||
|
step=10,
|
||||||
|
label="Top-K",
|
||||||
|
info="0 recommended for MAGNeT",
|
||||||
|
)
|
||||||
|
top_p_slider = gr.Slider(
|
||||||
|
minimum=0.0,
|
||||||
|
maximum=1.0,
|
||||||
|
value=0.9,
|
||||||
|
step=0.05,
|
||||||
|
label="Top-P",
|
||||||
|
info="0.9 recommended for MAGNeT",
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
decoding_steps_slider = gr.Slider(
|
||||||
|
minimum=10,
|
||||||
|
maximum=100,
|
||||||
|
value=20,
|
||||||
|
step=5,
|
||||||
|
label="Decoding Steps",
|
||||||
|
info="More steps = better quality, slower",
|
||||||
|
)
|
||||||
|
span_arrangement = gr.Dropdown(
|
||||||
|
label="Span Arrangement",
|
||||||
|
choices=[("No Overlap", "nonoverlap"), ("Overlap", "stride1")],
|
||||||
|
value="nonoverlap",
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
seed_input = gr.Number(
|
||||||
|
value=None,
|
||||||
|
label="Seed (empty = random)",
|
||||||
|
precision=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate buttons
|
||||||
|
with gr.Row():
|
||||||
|
generate_btn = gr.Button("⚡ Generate", variant="primary", scale=2)
|
||||||
|
queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)
|
||||||
|
|
||||||
|
# Right column - output
|
||||||
|
with gr.Column(scale=3):
|
||||||
|
output = create_generation_output()
|
||||||
|
|
||||||
|
# Event handlers
|
||||||
|
|
||||||
|
# Preset change
|
||||||
|
def apply_preset(preset_id: str):
|
||||||
|
for p in presets:
|
||||||
|
if p["id"] == preset_id:
|
||||||
|
params = p["parameters"]
|
||||||
|
return (
|
||||||
|
params.get("duration", 10),
|
||||||
|
params.get("temperature", 3.0),
|
||||||
|
params.get("cfg_coef", 3.0),
|
||||||
|
params.get("top_k", 0),
|
||||||
|
params.get("top_p", 0.9),
|
||||||
|
params.get("decoding_steps", 20),
|
||||||
|
)
|
||||||
|
return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
|
||||||
|
|
||||||
|
preset_dropdown.change(
|
||||||
|
fn=apply_preset,
|
||||||
|
inputs=[preset_dropdown],
|
||||||
|
outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider, decoding_steps_slider],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prompt suggestions
|
||||||
|
for btn, suggestion in suggestion_btns:
|
||||||
|
btn.click(
|
||||||
|
fn=lambda s=suggestion: s,
|
||||||
|
outputs=[prompt_input],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate
|
||||||
|
async def do_generate(
    prompt, variant, duration, temperature, cfg_coef, top_k, top_p, decoding_steps, span_arr, seed
):
    """Run a MAGNeT generation and stream UI updates.

    Async generator yielding 6-tuples of updates for:
    (status, progress, audio player, waveform, duration label, seed label).
    """
    if not prompt:
        # BUG FIX: this function contains `yield`, making it an async
        # generator, and `return <value>` inside an async generator is a
        # SyntaxError in Python. The validation message must be yielded,
        # followed by a bare `return`.
        yield (
            gr.update(value="Please enter a prompt"),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
        )
        return

    # Show in-progress status and reveal the progress bar.
    yield (
        gr.update(value="🔄 Generating..."),
        gr.update(visible=True, value=0),
        gr.update(),
        gr.update(),
        gr.update(),
        gr.update(),
    )

    try:
        result, generation = await generate_fn(
            model_id="magnet",
            variant=variant,
            prompts=[prompt],
            duration=duration,
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
            cfg_coef=cfg_coef,
            decoding_steps=int(decoding_steps),
            span_arrangement=span_arr,
            # BUG FIX: `if seed` treated a user-entered seed of 0 as
            # "random"; only None (empty field) means random.
            seed=int(seed) if seed is not None else None,
        )

        yield (
            gr.update(value="✅ Generation complete!"),
            gr.update(visible=False),
            gr.update(value=generation.audio_path),
            gr.update(),
            gr.update(value=f"{result.duration:.2f}s"),
            gr.update(value=str(result.seed)),
        )

    except Exception as e:
        # Surface the failure in the status line and hide the progress bar.
        yield (
            gr.update(value=f"❌ Error: {str(e)}"),
            gr.update(visible=False),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
        )
|
||||||
|
|
||||||
|
generate_btn.click(
|
||||||
|
fn=do_generate,
|
||||||
|
inputs=[
|
||||||
|
prompt_input,
|
||||||
|
variant_dropdown,
|
||||||
|
duration_slider,
|
||||||
|
temperature_slider,
|
||||||
|
cfg_slider,
|
||||||
|
top_k_slider,
|
||||||
|
top_p_slider,
|
||||||
|
decoding_steps_slider,
|
||||||
|
span_arrangement,
|
||||||
|
seed_input,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
output["status"],
|
||||||
|
output["progress"],
|
||||||
|
output["player"]["audio"],
|
||||||
|
output["player"]["waveform"],
|
||||||
|
output["player"]["duration"],
|
||||||
|
output["player"]["seed"],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add to queue
|
||||||
|
def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, decoding_steps, span_arr, seed):
    """Enqueue a MAGNeT generation job; returns a status message string."""
    if not prompt:
        return "Please enter a prompt"

    job = add_to_queue_fn(
        model_id="magnet",
        variant=variant,
        prompts=[prompt],
        duration=duration,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        cfg_coef=cfg_coef,
        decoding_steps=int(decoding_steps),
        span_arrangement=span_arr,
        # BUG FIX: only None (empty field) means "random"; 0 is a valid seed.
        seed=int(seed) if seed is not None else None,
    )

    return f"✅ Added to queue: {job.id}"
|
||||||
|
|
||||||
|
queue_btn.click(
|
||||||
|
fn=do_add_queue,
|
||||||
|
inputs=[
|
||||||
|
prompt_input,
|
||||||
|
variant_dropdown,
|
||||||
|
duration_slider,
|
||||||
|
temperature_slider,
|
||||||
|
cfg_slider,
|
||||||
|
top_k_slider,
|
||||||
|
top_p_slider,
|
||||||
|
decoding_steps_slider,
|
||||||
|
span_arrangement,
|
||||||
|
seed_input,
|
||||||
|
],
|
||||||
|
outputs=[output["status"]],
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"preset": preset_dropdown,
|
||||||
|
"variant": variant_dropdown,
|
||||||
|
"prompt": prompt_input,
|
||||||
|
"duration": duration_slider,
|
||||||
|
"temperature": temperature_slider,
|
||||||
|
"cfg_coef": cfg_slider,
|
||||||
|
"top_k": top_k_slider,
|
||||||
|
"top_p": top_p_slider,
|
||||||
|
"decoding_steps": decoding_steps_slider,
|
||||||
|
"span_arrangement": span_arrangement,
|
||||||
|
"seed": seed_input,
|
||||||
|
"generate_btn": generate_btn,
|
||||||
|
"queue_btn": queue_btn,
|
||||||
|
"output": output,
|
||||||
|
}
|
||||||
325
src/ui/tabs/musicgen_tab.py
Normal file
325
src/ui/tabs/musicgen_tab.py
Normal file
@@ -0,0 +1,325 @@
|
|||||||
|
"""MusicGen tab for text-to-music generation."""
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
|
||||||
|
from src.ui.components.audio_player import create_generation_output
|
||||||
|
|
||||||
|
|
||||||
|
# Catalogue of MusicGen checkpoints offered in the variant dropdown.
# `vram_mb` is the approximate GPU memory footprint used to render the
# "(X.XGB)" hint next to each variant name in the UI.
MUSICGEN_VARIANTS = [
    {"id": "small", "name": "Small", "vram_mb": 1500, "description": "Fast, 300M params"},
    {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "Balanced, 1.5B params"},
    {"id": "large", "name": "Large", "vram_mb": 10000, "description": "Best quality, 3.3B params"},
    {"id": "melody", "name": "Melody", "vram_mb": 5000, "description": "With melody conditioning"},
    {"id": "stereo-small", "name": "Stereo Small", "vram_mb": 1800, "description": "Stereo, 300M params"},
    {"id": "stereo-medium", "name": "Stereo Medium", "vram_mb": 6000, "description": "Stereo, 1.5B params"},
    {"id": "stereo-large", "name": "Stereo Large", "vram_mb": 12000, "description": "Stereo, 3.3B params"},
    {"id": "stereo-melody", "name": "Stereo Melody", "vram_mb": 6000, "description": "Stereo with melody"},
]
|
||||||
|
|
||||||
|
|
||||||
|
def create_musicgen_tab(
    generate_fn: Callable[..., Any],
    add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
    """Create MusicGen generation tab.

    Builds the text-to-music UI (preset/variant selectors, prompt box with
    suggestion buttons, optional melody conditioning shown only for
    "melody" variants, sampling parameters, generate/queue buttons, output
    panel) and wires up all event handlers.

    Args:
        generate_fn: Async callable performing a generation; awaited with
            model/prompt/sampling kwargs and expected to return a
            ``(result, generation)`` pair.
        add_to_queue_fn: Callable that enqueues a generation job and
            returns the created job (exposing an ``id`` attribute).

    Returns:
        Dictionary with component references
    """
    presets = DEFAULT_PRESETS.get("musicgen", [])
    suggestions = PROMPT_SUGGESTIONS.get("musicgen", [])

    with gr.Column():
        gr.Markdown("## 🎵 MusicGen")
        gr.Markdown("Generate music from text descriptions")

        with gr.Row():
            # Left column - inputs
            with gr.Column(scale=2):
                # Preset selector
                preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
                preset_dropdown = gr.Dropdown(
                    label="Preset",
                    choices=preset_choices,
                    value=presets[0]["id"] if presets else "custom",
                )

                # Model variant (label shows approximate VRAM need)
                variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in MUSICGEN_VARIANTS]
                variant_dropdown = gr.Dropdown(
                    label="Model Variant",
                    choices=variant_choices,
                    value="medium",
                )

                # Prompt input
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the music you want to generate...",
                    lines=3,
                    max_lines=5,
                )

                # Prompt suggestions (first four, truncated for button labels)
                with gr.Accordion("Prompt Suggestions", open=False):
                    suggestion_btns = []
                    for i, suggestion in enumerate(suggestions[:4]):
                        btn = gr.Button(suggestion[:60] + "...", size="sm", variant="secondary")
                        suggestion_btns.append((btn, suggestion))

                # Melody conditioning (only revealed for melody variants)
                with gr.Group(visible=False) as melody_section:
                    gr.Markdown("### Melody Conditioning")
                    melody_input = gr.Audio(
                        label="Reference Melody",
                        type="filepath",
                        sources=["upload", "microphone"],
                    )
                    gr.Markdown("*Upload audio to condition generation on its melody*")

                # Parameters
                gr.Markdown("### Parameters")

                duration_slider = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)",
                )

                with gr.Accordion("Advanced Parameters", open=False):
                    with gr.Row():
                        temperature_slider = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            value=1.0,
                            step=0.05,
                            label="Temperature",
                        )
                        cfg_slider = gr.Slider(
                            minimum=1.0,
                            maximum=10.0,
                            value=3.0,
                            step=0.5,
                            label="CFG Coefficient",
                        )

                    with gr.Row():
                        top_k_slider = gr.Slider(
                            minimum=0,
                            maximum=500,
                            value=250,
                            step=10,
                            label="Top-K",
                        )
                        top_p_slider = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.0,
                            step=0.05,
                            label="Top-P",
                        )

                    with gr.Row():
                        seed_input = gr.Number(
                            value=None,
                            label="Seed (empty = random)",
                            precision=0,
                        )

                # Generate buttons
                with gr.Row():
                    generate_btn = gr.Button("🎵 Generate", variant="primary", scale=2)
                    queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)

            # Right column - output
            with gr.Column(scale=3):
                output = create_generation_output()

    # Event handlers

    # Preset change: push the preset's parameter values into the sliders.
    def apply_preset(preset_id: str):
        for p in presets:
            if p["id"] == preset_id:
                params = p["parameters"]
                return (
                    params.get("duration", 10),
                    params.get("temperature", 1.0),
                    params.get("cfg_coef", 3.0),
                    params.get("top_k", 250),
                    params.get("top_p", 0.0),
                )
        # Custom preset - don't change values
        return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()

    preset_dropdown.change(
        fn=apply_preset,
        inputs=[preset_dropdown],
        outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
    )

    # Variant change - show/hide melody section
    def on_variant_change(variant: str):
        show_melody = "melody" in variant.lower()
        return gr.update(visible=show_melody)

    variant_dropdown.change(
        fn=on_variant_change,
        inputs=[variant_dropdown],
        outputs=[melody_section],
    )

    # Prompt suggestions: clicking a button copies its full text into the
    # prompt box (default arg binds each suggestion at definition time).
    for btn, suggestion in suggestion_btns:
        btn.click(
            fn=lambda s=suggestion: s,
            outputs=[prompt_input],
        )

    # Generate: async generator streaming updates for
    # (status, progress, audio, waveform, duration label, seed label).
    async def do_generate(
        prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody
    ):
        if not prompt:
            # BUG FIX: `return <value>` is a SyntaxError inside an async
            # generator; the validation message must be yielded instead,
            # followed by a bare `return`.
            yield (
                gr.update(value="Please enter a prompt"),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
            )
            return

        # Update status
        yield (
            gr.update(value="🔄 Generating..."),
            gr.update(visible=True, value=0),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
        )

        try:
            conditioning = {}
            if melody:
                conditioning["melody"] = melody

            result, generation = await generate_fn(
                model_id="musicgen",
                variant=variant,
                prompts=[prompt],
                duration=duration,
                temperature=temperature,
                top_k=int(top_k),
                top_p=top_p,
                cfg_coef=cfg_coef,
                # BUG FIX: `if seed` treated a user-entered 0 as "random";
                # only None (empty field) means random.
                seed=int(seed) if seed is not None else None,
                conditioning=conditioning,
            )

            yield (
                gr.update(value="✅ Generation complete!"),
                gr.update(visible=False),
                gr.update(value=generation.audio_path),
                gr.update(),
                gr.update(value=f"{result.duration:.2f}s"),
                gr.update(value=str(result.seed)),
            )

        except Exception as e:
            yield (
                gr.update(value=f"❌ Error: {str(e)}"),
                gr.update(visible=False),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
            )

    generate_btn.click(
        fn=do_generate,
        inputs=[
            prompt_input,
            variant_dropdown,
            duration_slider,
            temperature_slider,
            cfg_slider,
            top_k_slider,
            top_p_slider,
            seed_input,
            melody_input,
        ],
        outputs=[
            output["status"],
            output["progress"],
            output["player"]["audio"],
            output["player"]["waveform"],
            output["player"]["duration"],
            output["player"]["seed"],
        ],
    )

    # Add to queue
    def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody):
        """Enqueue a MusicGen job; returns a status message string."""
        if not prompt:
            return "Please enter a prompt"

        conditioning = {}
        if melody:
            conditioning["melody"] = melody

        job = add_to_queue_fn(
            model_id="musicgen",
            variant=variant,
            prompts=[prompt],
            duration=duration,
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
            cfg_coef=cfg_coef,
            # BUG FIX: only None means "random"; 0 is a valid seed.
            seed=int(seed) if seed is not None else None,
            conditioning=conditioning,
        )

        return f"✅ Added to queue: {job.id}"

    queue_btn.click(
        fn=do_add_queue,
        inputs=[
            prompt_input,
            variant_dropdown,
            duration_slider,
            temperature_slider,
            cfg_slider,
            top_k_slider,
            top_p_slider,
            seed_input,
            melody_input,
        ],
        outputs=[output["status"]],
    )

    return {
        "preset": preset_dropdown,
        "variant": variant_dropdown,
        "prompt": prompt_input,
        "melody": melody_input,
        "duration": duration_slider,
        "temperature": temperature_slider,
        "cfg_coef": cfg_slider,
        "top_k": top_k_slider,
        "top_p": top_p_slider,
        "seed": seed_input,
        "generate_btn": generate_btn,
        "queue_btn": queue_btn,
        "output": output,
    }
|
||||||
292
src/ui/tabs/style_tab.py
Normal file
292
src/ui/tabs/style_tab.py
Normal file
@@ -0,0 +1,292 @@
|
|||||||
|
"""MusicGen Style tab for style-conditioned generation."""
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
|
||||||
|
from src.ui.components.audio_player import create_generation_output
|
||||||
|
|
||||||
|
|
||||||
|
# MusicGen-Style checkpoints offered in the variant dropdown; `vram_mb` is
# the approximate GPU memory footprint shown as a "(X.XGB)" hint in the UI.
STYLE_VARIANTS = [
    {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "1.5B params, style conditioning"},
]
|
||||||
|
|
||||||
|
|
||||||
|
def create_style_tab(
    generate_fn: Callable[..., Any],
    add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
    """Create MusicGen Style generation tab.

    Builds the style-conditioned music UI: preset/variant selectors, an
    optional text prompt, a required style-reference audio input, sampling
    parameters, and generate/queue buttons wired to the supplied callbacks.

    Args:
        generate_fn: Async callable performing a generation; awaited with
            model/prompt/sampling kwargs and expected to return a
            ``(result, generation)`` pair.
        add_to_queue_fn: Callable that enqueues a generation job and
            returns the created job (exposing an ``id`` attribute).

    Returns:
        Dictionary with component references
    """
    presets = DEFAULT_PRESETS.get("musicgen-style", [])

    with gr.Column():
        gr.Markdown("## 🎨 MusicGen Style")
        gr.Markdown("Generate music conditioned on the style of reference audio")

        with gr.Row():
            # Left column - inputs
            with gr.Column(scale=2):
                # Preset selector
                preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
                preset_dropdown = gr.Dropdown(
                    label="Preset",
                    choices=preset_choices,
                    value=presets[0]["id"] if presets else "custom",
                )

                # Model variant (label shows approximate VRAM need)
                variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in STYLE_VARIANTS]
                variant_dropdown = gr.Dropdown(
                    label="Model Variant",
                    choices=variant_choices,
                    value="medium",
                )

                # Prompt input (optional; style audio is the main signal)
                prompt_input = gr.Textbox(
                    label="Text Prompt",
                    placeholder="Describe additional characteristics for the music...",
                    lines=3,
                    max_lines=5,
                    info="Optional: combine with style conditioning",
                )

                # Style conditioning (required)
                gr.Markdown("### Style Conditioning")
                gr.Markdown("*Upload reference audio to extract musical style*")

                style_input = gr.Audio(
                    label="Style Reference",
                    type="filepath",
                    sources=["upload", "microphone"],
                )

                gr.Markdown(
                    "*The model will learn the style (instrumentation, tempo, mood) from this audio*"
                )

                # Parameters
                gr.Markdown("### Parameters")

                duration_slider = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)",
                )

                with gr.Accordion("Advanced Parameters", open=False):
                    with gr.Row():
                        temperature_slider = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            value=1.0,
                            step=0.05,
                            label="Temperature",
                        )
                        cfg_slider = gr.Slider(
                            minimum=1.0,
                            maximum=10.0,
                            value=3.0,
                            step=0.5,
                            label="CFG Coefficient",
                        )

                    with gr.Row():
                        top_k_slider = gr.Slider(
                            minimum=0,
                            maximum=500,
                            value=250,
                            step=10,
                            label="Top-K",
                        )
                        top_p_slider = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.0,
                            step=0.05,
                            label="Top-P",
                        )

                    with gr.Row():
                        seed_input = gr.Number(
                            value=None,
                            label="Seed (empty = random)",
                            precision=0,
                        )

                # Generate buttons
                with gr.Row():
                    generate_btn = gr.Button("🎨 Generate", variant="primary", scale=2)
                    queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)

            # Right column - output
            with gr.Column(scale=3):
                output = create_generation_output()

    # Event handlers

    # Preset change: push the preset's parameter values into the sliders.
    def apply_preset(preset_id: str):
        for p in presets:
            if p["id"] == preset_id:
                params = p["parameters"]
                return (
                    params.get("duration", 10),
                    params.get("temperature", 1.0),
                    params.get("cfg_coef", 3.0),
                    params.get("top_k", 250),
                    params.get("top_p", 0.0),
                )
        # Custom / unknown preset: leave controls unchanged.
        return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()

    preset_dropdown.change(
        fn=apply_preset,
        inputs=[preset_dropdown],
        outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
    )

    # Generate: async generator streaming updates for
    # (status, progress, audio, waveform, duration label, seed label).
    async def do_generate(
        prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed
    ):
        if not style_audio:
            # BUG FIX: `return <value>` is a SyntaxError inside an async
            # generator; the validation message must be yielded instead,
            # followed by a bare `return`.
            yield (
                gr.update(value="Please upload a style reference audio"),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
            )
            return

        yield (
            gr.update(value="🔄 Generating..."),
            gr.update(visible=True, value=0),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
        )

        try:
            conditioning = {"style": style_audio}

            result, generation = await generate_fn(
                model_id="musicgen-style",
                variant=variant,
                # Text prompt is optional; fall back to an empty prompt.
                prompts=[prompt] if prompt else [""],
                duration=duration,
                temperature=temperature,
                top_k=int(top_k),
                top_p=top_p,
                cfg_coef=cfg_coef,
                # BUG FIX: `if seed` treated a user-entered 0 as "random";
                # only None (empty field) means random.
                seed=int(seed) if seed is not None else None,
                conditioning=conditioning,
            )

            yield (
                gr.update(value="✅ Generation complete!"),
                gr.update(visible=False),
                gr.update(value=generation.audio_path),
                gr.update(),
                gr.update(value=f"{result.duration:.2f}s"),
                gr.update(value=str(result.seed)),
            )

        except Exception as e:
            yield (
                gr.update(value=f"❌ Error: {str(e)}"),
                gr.update(visible=False),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
            )

    generate_btn.click(
        fn=do_generate,
        inputs=[
            prompt_input,
            variant_dropdown,
            style_input,
            duration_slider,
            temperature_slider,
            cfg_slider,
            top_k_slider,
            top_p_slider,
            seed_input,
        ],
        outputs=[
            output["status"],
            output["progress"],
            output["player"]["audio"],
            output["player"]["waveform"],
            output["player"]["duration"],
            output["player"]["seed"],
        ],
    )

    # Add to queue
    def do_add_queue(prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed):
        """Enqueue a MusicGen Style job; returns a status message string."""
        if not style_audio:
            return "Please upload a style reference audio"

        conditioning = {"style": style_audio}

        job = add_to_queue_fn(
            model_id="musicgen-style",
            variant=variant,
            prompts=[prompt] if prompt else [""],
            duration=duration,
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
            cfg_coef=cfg_coef,
            # BUG FIX: only None means "random"; 0 is a valid seed.
            seed=int(seed) if seed is not None else None,
            conditioning=conditioning,
        )

        return f"✅ Added to queue: {job.id}"

    queue_btn.click(
        fn=do_add_queue,
        inputs=[
            prompt_input,
            variant_dropdown,
            style_input,
            duration_slider,
            temperature_slider,
            cfg_slider,
            top_k_slider,
            top_p_slider,
            seed_input,
        ],
        outputs=[output["status"]],
    )

    return {
        "preset": preset_dropdown,
        "variant": variant_dropdown,
        "prompt": prompt_input,
        "style": style_input,
        "duration": duration_slider,
        "temperature": temperature_slider,
        "cfg_coef": cfg_slider,
        "top_k": top_k_slider,
        "top_p": top_p_slider,
        "seed": seed_input,
        "generate_btn": generate_btn,
        "queue_btn": queue_btn,
        "output": output,
    }
|
||||||
303
src/ui/theme.py
Normal file
303
src/ui/theme.py
Normal file
@@ -0,0 +1,303 @@
|
|||||||
|
"""Custom Gradio theme for AudioCraft Studio."""
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
|
def create_theme() -> gr.themes.Base:
    """Create custom theme for AudioCraft Studio.

    Dark slate palette with blue accents; Inter for UI text and
    JetBrains Mono for monospaced text.

    Returns:
        Gradio theme instance
    """
    # Named palette so every color appears exactly once.
    bg_body = "#0f172a"
    bg_panel = "#1e293b"
    bg_raised = "#334155"
    border = "#475569"
    accent = "#3b82f6"
    accent_hover = "#2563eb"
    secondary_hover = "#64748b"
    text_main = "#e2e8f0"
    text_muted = "#94a3b8"
    text_title = "#f1f5f9"

    base = gr.themes.Soft(
        primary_hue=gr.themes.colors.blue,
        secondary_hue=gr.themes.colors.slate,
        neutral_hue=gr.themes.colors.gray,
        font=[
            gr.themes.GoogleFont("Inter"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ],
        font_mono=[
            gr.themes.GoogleFont("JetBrains Mono"),
            "ui-monospace",
            "monospace",
        ],
    )

    return base.set(
        # Colors
        body_background_fill=bg_body,
        body_background_fill_dark=bg_body,
        background_fill_primary=bg_panel,
        background_fill_primary_dark=bg_panel,
        background_fill_secondary=bg_raised,
        background_fill_secondary_dark=bg_raised,
        border_color_primary=border,
        border_color_primary_dark=border,

        # Text
        body_text_color=text_main,
        body_text_color_dark=text_main,
        body_text_color_subdued=text_muted,
        body_text_color_subdued_dark=text_muted,

        # Buttons
        button_primary_background_fill=accent,
        button_primary_background_fill_dark=accent,
        button_primary_background_fill_hover=accent_hover,
        button_primary_background_fill_hover_dark=accent_hover,
        button_primary_text_color="#ffffff",
        button_primary_text_color_dark="#ffffff",

        button_secondary_background_fill=border,
        button_secondary_background_fill_dark=border,
        button_secondary_background_fill_hover=secondary_hover,
        button_secondary_background_fill_hover_dark=secondary_hover,

        # Inputs
        input_background_fill=bg_panel,
        input_background_fill_dark=bg_panel,
        input_border_color=border,
        input_border_color_dark=border,
        input_border_color_focus=accent,
        input_border_color_focus_dark=accent,

        # Blocks
        block_background_fill=bg_panel,
        block_background_fill_dark=bg_panel,
        block_border_color=bg_raised,
        block_border_color_dark=bg_raised,
        block_label_background_fill=bg_raised,
        block_label_background_fill_dark=bg_raised,
        block_label_text_color=text_main,
        block_label_text_color_dark=text_main,
        block_title_text_color=text_title,
        block_title_text_color_dark=text_title,

        # Tabs
        tab_nav_background_fill=bg_panel,

        # Sliders
        slider_color=accent,
        slider_color_dark=accent,

        # Shadows
        shadow_spread="4px",
        block_shadow="0 4px 6px -1px rgba(0, 0, 0, 0.3)",

        # Spacing
        layout_gap="16px",
        block_padding="16px",
        panel_border_width="1px",

        # Radius
        radius_sm="6px",
        radius_md="8px",
        radius_lg="12px",
    )
|
||||||
|
|
||||||
|
|
||||||
|
# CSS overrides for additional customization beyond what the Gradio theme
# object exposes. Passed to gr.Blocks(css=CUSTOM_CSS); class names here
# (.sidebar, .model-card, .vram-bar, ...) must match elem_classes used
# when the corresponding components are created.
CUSTOM_CSS = """
/* Global styles */
.gradio-container {
    max-width: 100% !important;
}

/* Header styling */
.header-title {
    font-size: 1.5rem;
    font-weight: 700;
    color: #f1f5f9;
}

/* Sidebar styling */
.sidebar {
    background: #1e293b;
    border-right: 1px solid #334155;
    padding: 1rem;
}

.sidebar-nav-btn {
    width: 100%;
    justify-content: flex-start;
    margin-bottom: 0.5rem;
}

/* Model cards */
.model-card {
    background: #334155;
    border-radius: 12px;
    padding: 1rem;
    transition: transform 0.2s, box-shadow 0.2s;
}

.model-card:hover {
    transform: translateY(-2px);
    box-shadow: 0 8px 25px rgba(0, 0, 0, 0.3);
}

/* Audio player */
.audio-player {
    background: #1e293b;
    border-radius: 8px;
    padding: 1rem;
}

/* Progress bar */
.progress-bar {
    background: #334155;
    border-radius: 4px;
    overflow: hidden;
}

.progress-fill {
    background: linear-gradient(90deg, #3b82f6, #8b5cf6);
    height: 100%;
    transition: width 0.3s ease;
}

/* VRAM monitor */
.vram-bar {
    background: #334155;
    border-radius: 4px;
    height: 24px;
    position: relative;
    overflow: hidden;
}

.vram-fill {
    position: absolute;
    left: 0;
    top: 0;
    height: 100%;
    background: linear-gradient(90deg, #22c55e, #eab308, #ef4444);
    transition: width 0.5s ease;
}

.vram-text {
    position: absolute;
    width: 100%;
    text-align: center;
    line-height: 24px;
    font-size: 0.875rem;
    font-weight: 500;
    color: white;
    text-shadow: 0 1px 2px rgba(0, 0, 0, 0.5);
}

/* Queue badge */
.queue-badge {
    background: #3b82f6;
    color: white;
    padding: 0.25rem 0.75rem;
    border-radius: 9999px;
    font-size: 0.875rem;
    font-weight: 500;
}

/* Generation card */
.generation-card {
    background: #334155;
    border-radius: 8px;
    padding: 1rem;
    margin-bottom: 0.5rem;
}

/* Preset chips */
.preset-chip {
    display: inline-block;
    background: #475569;
    color: #e2e8f0;
    padding: 0.25rem 0.75rem;
    border-radius: 9999px;
    font-size: 0.875rem;
    margin: 0.25rem;
    cursor: pointer;
    transition: background 0.2s;
}

.preset-chip:hover {
    background: #3b82f6;
}

.preset-chip.active {
    background: #3b82f6;
}

/* Tag input */
.tag {
    display: inline-flex;
    align-items: center;
    background: #475569;
    color: #e2e8f0;
    padding: 0.25rem 0.5rem;
    border-radius: 4px;
    font-size: 0.75rem;
    margin: 0.125rem;
}

/* Accordion tweaks */
.accordion-header {
    font-weight: 600;
    color: #f1f5f9;
}

/* Status indicators */
.status-dot {
    width: 8px;
    height: 8px;
    border-radius: 50%;
    display: inline-block;
    margin-right: 0.5rem;
}

.status-dot.loaded {
    background: #22c55e;
}

.status-dot.unloaded {
    background: #64748b;
}

.status-dot.loading {
    background: #eab308;
    animation: pulse 1s infinite;
}

@keyframes pulse {
    0%, 100% { opacity: 1; }
    50% { opacity: 0.5; }
}

/* Tooltip */
.tooltip {
    position: relative;
}

.tooltip:hover::after {
    content: attr(data-tooltip);
    position: absolute;
    bottom: 100%;
    left: 50%;
    transform: translateX(-50%);
    background: #1e293b;
    color: #e2e8f0;
    padding: 0.5rem;
    border-radius: 4px;
    font-size: 0.75rem;
    white-space: nowrap;
    z-index: 100;
}

/* Responsive adjustments */
@media (max-width: 768px) {
    .sidebar {
        display: none;
    }

    .mobile-nav {
        display: flex !important;
    }
}
"""
|
||||||
Reference in New Issue
Block a user