Initial implementation of AudioCraft Studio
Complete web interface for Meta's AudioCraft AI audio generation: - Gradio UI with tabs for all 5 model families (MusicGen, AudioGen, MAGNeT, MusicGen Style, JASCO) - REST API with FastAPI, OpenAPI docs, and API key auth - VRAM management with ComfyUI coexistence support - SQLite database for project/generation history - Batch processing queue for async generation - Docker deployment optimized for RunPod with RTX 4090 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
42
.env.example
Normal file
42
.env.example
Normal file
@@ -0,0 +1,42 @@
|
||||
# AudioCraft Studio Configuration
|
||||
# Copy this file to .env and customize as needed
|
||||
|
||||
# Server Configuration
|
||||
AUDIOCRAFT_HOST=0.0.0.0
|
||||
AUDIOCRAFT_GRADIO_PORT=7860
|
||||
AUDIOCRAFT_API_PORT=8000
|
||||
|
||||
# Paths (relative to project root)
|
||||
AUDIOCRAFT_DATA_DIR=./data
|
||||
AUDIOCRAFT_OUTPUT_DIR=./outputs
|
||||
AUDIOCRAFT_CACHE_DIR=./cache
|
||||
|
||||
# VRAM Management
|
||||
# Reserve this much VRAM for ComfyUI (GB)
|
||||
AUDIOCRAFT_COMFYUI_RESERVE_GB=10
|
||||
# Safety buffer to prevent OOM (GB)
|
||||
AUDIOCRAFT_SAFETY_BUFFER_GB=1
|
||||
# Unload idle models after this many minutes
|
||||
AUDIOCRAFT_IDLE_UNLOAD_MINUTES=15
|
||||
# Maximum number of models to keep loaded
|
||||
AUDIOCRAFT_MAX_CACHED_MODELS=2
|
||||
|
||||
# API Authentication
|
||||
# Generate a secure random key for production
|
||||
AUDIOCRAFT_API_KEY=your-secret-api-key-here
|
||||
|
||||
# Generation Defaults
|
||||
AUDIOCRAFT_DEFAULT_DURATION=10.0
|
||||
AUDIOCRAFT_MAX_DURATION=300.0
|
||||
AUDIOCRAFT_DEFAULT_BATCH_SIZE=1
|
||||
AUDIOCRAFT_MAX_BATCH_SIZE=8
|
||||
AUDIOCRAFT_MAX_QUEUE_SIZE=100
|
||||
|
||||
# Database
|
||||
AUDIOCRAFT_DATABASE_URL=sqlite+aiosqlite:///./data/audiocraft.db
|
||||
|
||||
# Logging
|
||||
AUDIOCRAFT_LOG_LEVEL=INFO
|
||||
|
||||
# PyTorch Optimization (recommended)
|
||||
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
||||
76
.gitignore
vendored
Normal file
76
.gitignore
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# Virtual environments
|
||||
.venv/
|
||||
venv/
|
||||
ENV/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# Testing
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
|
||||
# Type checking
|
||||
.mypy_cache/
|
||||
|
||||
# Project specific
|
||||
data/
|
||||
outputs/
|
||||
cache/
|
||||
*.db
|
||||
*.sqlite
|
||||
*.sqlite3
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
logs/
|
||||
|
||||
# Environment
|
||||
.env
|
||||
.env.local
|
||||
.env.*.local
|
||||
|
||||
# Model weights (downloaded from HuggingFace)
|
||||
*.bin
|
||||
*.safetensors
|
||||
*.pt
|
||||
*.pth
|
||||
|
||||
# Audio files (generated)
|
||||
*.wav
|
||||
*.mp3
|
||||
*.flac
|
||||
*.ogg
|
||||
|
||||
# Temp files
|
||||
/tmp/
|
||||
*.tmp
|
||||
83
Dockerfile
Normal file
83
Dockerfile
Normal file
@@ -0,0 +1,83 @@
|
||||
# AudioCraft Studio Dockerfile for RunPod
# Optimized for NVIDIA RTX 4090 (24GB VRAM)

# BUG FIX: nvidia/cuda images are tagged with a full patch version; a bare
# "12.1" tag does not exist on Docker Hub, so pin 12.1.1 explicitly.
FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu22.04

# Build-time only: silence interactive apt prompts without baking the
# setting into the runtime environment (ARG, not ENV).
ARG DEBIAN_FRONTEND=noninteractive

# Python / pip behavior
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
ENV PIP_NO_CACHE_DIR=1
ENV PIP_DISABLE_PIP_VERSION_CHECK=1

# CUDA settings
ENV CUDA_HOME=/usr/local/cuda
ENV PATH="${CUDA_HOME}/bin:${PATH}"
ENV LD_LIBRARY_PATH="${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}"

# AudioCraft settings
ENV AUDIOCRAFT_OUTPUT_DIR=/workspace/outputs
ENV AUDIOCRAFT_DATA_DIR=/workspace/data
ENV AUDIOCRAFT_MODEL_CACHE=/workspace/models
ENV AUDIOCRAFT_HOST=0.0.0.0
ENV AUDIOCRAFT_GRADIO_PORT=7860
ENV AUDIOCRAFT_API_PORT=8000

# Install system dependencies (update + install + cleanup in ONE layer so
# the apt cache never persists in the image; packages sorted for diffability)
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    curl \
    ffmpeg \
    git \
    libsndfile1 \
    libsox-dev \
    python3.10 \
    python3.10-dev \
    python3.10-venv \
    python3-pip \
    sox \
    wget \
    && rm -rf /var/lib/apt/lists/*

# Set Python 3.10 as default
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1 \
    && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1

# Upgrade pip
RUN pip install --upgrade pip setuptools wheel

# Create workspace directory
WORKDIR /workspace

# Create necessary directories
RUN mkdir -p /workspace/outputs /workspace/data /workspace/models /workspace/app

# Copy requirements first so the dependency layers below stay cached when
# only application code changes.
COPY requirements.txt /workspace/app/
WORKDIR /workspace/app

# Install PyTorch with CUDA support (pinned for reproducibility)
RUN pip install torch==2.1.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121

# Install other requirements
RUN pip install -r requirements.txt

# Install AudioCraft from source for latest features
RUN pip install git+https://github.com/facebookresearch/audiocraft.git

# Copy application code
COPY . /workspace/app/

# Create non-root user for security (optional, RunPod often uses root)
# RUN useradd -m -u 1000 audiocraft && chown -R audiocraft:audiocraft /workspace
# USER audiocraft

# Expose ports (documentation only): Gradio UI and REST API
EXPOSE 7860 8000

# Health check: probe the Gradio root page (curl is installed above)
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
    CMD curl -f http://localhost:7860/ || exit 1

# Default command (exec form so the app runs as PID 1 and receives SIGTERM)
CMD ["python", "main.py"]
|
||||
197
README.md
Normal file
197
README.md
Normal file
@@ -0,0 +1,197 @@
|
||||
# AudioCraft Studio
|
||||
|
||||
A comprehensive web interface for Meta's AudioCraft AI audio generation models, optimized for RunPod deployment with NVIDIA RTX 4090 GPUs.
|
||||
|
||||
## Features
|
||||
|
||||
### Models Supported
|
||||
- **MusicGen** - Text-to-music generation with melody conditioning
|
||||
- **AudioGen** - Text-to-sound effects and environmental audio
|
||||
- **MAGNeT** - Fast non-autoregressive music generation
|
||||
- **MusicGen Style** - Style-conditioned music from reference audio
|
||||
- **JASCO** - Chord and drum-conditioned music generation
|
||||
|
||||
### Core Capabilities
|
||||
- **Gradio Web UI** - Intuitive interface with real-time generation
|
||||
- **REST API** - Full-featured API with OpenAPI documentation
|
||||
- **Batch Processing** - Queue system for multiple generations
|
||||
- **Project Management** - Organize and browse generation history
|
||||
- **VRAM Management** - Smart model loading/unloading, ComfyUI coexistence
|
||||
- **Waveform Visualization** - Visual audio feedback
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Local Development
|
||||
|
||||
```bash
|
||||
# Clone repository
|
||||
git clone https://github.com/your-username/audiocraft-ui.git
|
||||
cd audiocraft-ui
|
||||
|
||||
# Create virtual environment
|
||||
python -m venv venv
|
||||
source venv/bin/activate # Linux/Mac
|
||||
# or: venv\Scripts\activate # Windows
|
||||
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Run application
|
||||
python main.py
|
||||
```
|
||||
|
||||
Access the UI at `http://localhost:7860`
|
||||
|
||||
### Docker
|
||||
|
||||
```bash
|
||||
# Build and run
|
||||
docker-compose up --build
|
||||
|
||||
# Or build manually
|
||||
docker build -t audiocraft-studio .
|
||||
docker run --gpus all -p 7860:7860 -p 8000:8000 audiocraft-studio
|
||||
```
|
||||
|
||||
### RunPod Deployment
|
||||
|
||||
1. Build and push Docker image:
|
||||
```bash
|
||||
docker build -t your-dockerhub/audiocraft-studio:latest .
|
||||
docker push your-dockerhub/audiocraft-studio:latest
|
||||
```
|
||||
|
||||
2. Create RunPod template using `runpod.yaml` as reference
|
||||
|
||||
3. Deploy with RTX 4090 or equivalent GPU
|
||||
|
||||
## Configuration
|
||||
|
||||
Configuration via environment variables:
|
||||
|
||||
| Variable | Default | Description |
|
||||
|----------|---------|-------------|
|
||||
| `AUDIOCRAFT_HOST` | `0.0.0.0` | Server bind address |
|
||||
| `AUDIOCRAFT_GRADIO_PORT` | `7860` | Gradio UI port |
|
||||
| `AUDIOCRAFT_API_PORT` | `8000` | REST API port |
|
||||
| `AUDIOCRAFT_OUTPUT_DIR` | `./outputs` | Generated audio output |
|
||||
| `AUDIOCRAFT_DATA_DIR` | `./data` | Database and config |
|
||||
| `AUDIOCRAFT_COMFYUI_RESERVE_GB` | `10` | VRAM reserved for ComfyUI |
|
||||
| `AUDIOCRAFT_MAX_CACHED_MODELS` | `2` | Max models in memory |
|
||||
| `AUDIOCRAFT_IDLE_UNLOAD_MINUTES` | `15` | Auto-unload idle models |
|
||||
|
||||
See `.env.example` for full configuration options.
|
||||
|
||||
## API Usage
|
||||
|
||||
### Authentication
|
||||
|
||||
```bash
|
||||
# Get API key from Settings page or generate via CLI
|
||||
curl -X POST http://localhost:8000/api/v1/system/api-key/regenerate \
|
||||
-H "X-API-Key: YOUR_CURRENT_KEY"
|
||||
```
|
||||
|
||||
### Generate Audio
|
||||
|
||||
```bash
|
||||
# Synchronous generation
|
||||
curl -X POST http://localhost:8000/api/v1/generate \
|
||||
-H "X-API-Key: YOUR_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "musicgen",
|
||||
"variant": "medium",
|
||||
"prompts": ["upbeat electronic dance music with synth leads"],
|
||||
"duration": 10
|
||||
}'
|
||||
|
||||
# Async (queue) generation
|
||||
curl -X POST http://localhost:8000/api/v1/generate/async \
|
||||
-H "X-API-Key: YOUR_API_KEY" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"request": {
|
||||
"model": "musicgen",
|
||||
"prompts": ["ambient soundscape"],
|
||||
"duration": 30
|
||||
},
|
||||
"priority": 5
|
||||
}'
|
||||
```
|
||||
|
||||
### Check Job Status
|
||||
|
||||
```bash
|
||||
curl http://localhost:8000/api/v1/generate/jobs/{job_id} \
|
||||
-H "X-API-Key: YOUR_API_KEY"
|
||||
```
|
||||
|
||||
Full API documentation available at `http://localhost:8000/api/docs`
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
audiocraft-ui/
|
||||
├── config/
|
||||
│ ├── settings.py # Pydantic settings
|
||||
│ └── models.yaml # Model registry
|
||||
├── src/
|
||||
│ ├── core/
|
||||
│ │ ├── base_model.py # Abstract model interface
|
||||
│ │ ├── gpu_manager.py # VRAM management
|
||||
│ │ ├── model_registry.py # Model loading/caching
|
||||
│ │ └── oom_handler.py # OOM recovery
|
||||
│ ├── models/
|
||||
│ │ ├── musicgen/ # MusicGen adapter
|
||||
│ │ ├── audiogen/ # AudioGen adapter
|
||||
│ │ ├── magnet/ # MAGNeT adapter
|
||||
│ │ ├── musicgen_style/ # Style adapter
|
||||
│ │ └── jasco/ # JASCO adapter
|
||||
│ ├── services/
|
||||
│ │ ├── generation_service.py
|
||||
│ │ ├── batch_processor.py
|
||||
│ │ └── project_service.py
|
||||
│ ├── storage/
|
||||
│ │ └── database.py # SQLite storage
|
||||
│ ├── api/
|
||||
│ │ ├── app.py # FastAPI app
|
||||
│ │ └── routes/ # API endpoints
|
||||
│ └── ui/
|
||||
│ ├── app.py # Gradio app
|
||||
│ ├── components/ # Reusable UI components
|
||||
│ ├── tabs/ # Model generation tabs
|
||||
│ └── pages/ # Projects, Settings
|
||||
├── main.py # Entry point
|
||||
├── Dockerfile
|
||||
└── docker-compose.yml
|
||||
```
|
||||
|
||||
## ComfyUI Coexistence
|
||||
|
||||
AudioCraft Studio is designed to run alongside ComfyUI on the same GPU:
|
||||
|
||||
1. Set `AUDIOCRAFT_COMFYUI_RESERVE_GB` to reserve VRAM for ComfyUI
|
||||
2. Models are automatically unloaded when idle
|
||||
3. Coordination file at `/tmp/audiocraft_comfyui_coord.json` prevents conflicts
|
||||
|
||||
## Development
|
||||
|
||||
```bash
|
||||
# Install dev dependencies
|
||||
pip install -r requirements-dev.txt
|
||||
|
||||
# Run tests
|
||||
pytest
|
||||
|
||||
# Format code
|
||||
black src/ config/
|
||||
ruff check src/ config/
|
||||
|
||||
# Type checking
|
||||
mypy src/
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This project uses Meta's AudioCraft library. See [AudioCraft License](https://github.com/facebookresearch/audiocraft/blob/main/LICENSE).
|
||||
5
config/__init__.py
Normal file
5
config/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Configuration module for AudioCraft Studio."""
|
||||
|
||||
from config.settings import Settings, get_settings
|
||||
|
||||
__all__ = ["Settings", "get_settings"]
|
||||
151
config/models.yaml
Normal file
151
config/models.yaml
Normal file
@@ -0,0 +1,151 @@
|
||||
# AudioCraft Model Registry Configuration
|
||||
# This file defines all available models and their configurations
|
||||
|
||||
models:
|
||||
musicgen:
|
||||
enabled: true
|
||||
display_name: "MusicGen"
|
||||
description: "Text-to-music generation with optional melody conditioning"
|
||||
default_variant: medium
|
||||
variants:
|
||||
small:
|
||||
hf_id: facebook/musicgen-small
|
||||
vram_mb: 1500
|
||||
max_duration: 30
|
||||
description: "Fast, lightweight model (300M params)"
|
||||
medium:
|
||||
hf_id: facebook/musicgen-medium
|
||||
vram_mb: 5000
|
||||
max_duration: 30
|
||||
description: "Balanced quality and speed (1.5B params)"
|
||||
large:
|
||||
hf_id: facebook/musicgen-large
|
||||
vram_mb: 10000
|
||||
max_duration: 30
|
||||
description: "Highest quality, slower (3.3B params)"
|
||||
melody:
|
||||
hf_id: facebook/musicgen-melody
|
||||
vram_mb: 5000
|
||||
max_duration: 30
|
||||
conditioning:
|
||||
- melody
|
||||
description: "Melody-conditioned generation (1.5B params)"
|
||||
stereo-small:
|
||||
hf_id: facebook/musicgen-stereo-small
|
||||
vram_mb: 1800
|
||||
max_duration: 30
|
||||
channels: 2
|
||||
description: "Stereo output, fast (300M params)"
|
||||
stereo-medium:
|
||||
hf_id: facebook/musicgen-stereo-medium
|
||||
vram_mb: 6000
|
||||
max_duration: 30
|
||||
channels: 2
|
||||
description: "Stereo output, balanced (1.5B params)"
|
||||
stereo-large:
|
||||
hf_id: facebook/musicgen-stereo-large
|
||||
vram_mb: 12000
|
||||
max_duration: 30
|
||||
channels: 2
|
||||
description: "Stereo output, highest quality (3.3B params)"
|
||||
stereo-melody:
|
||||
hf_id: facebook/musicgen-stereo-melody
|
||||
vram_mb: 6000
|
||||
max_duration: 30
|
||||
channels: 2
|
||||
conditioning:
|
||||
- melody
|
||||
description: "Stereo melody-conditioned (1.5B params)"
|
||||
|
||||
audiogen:
|
||||
enabled: true
|
||||
display_name: "AudioGen"
|
||||
description: "Text-to-sound effects generation"
|
||||
default_variant: medium
|
||||
variants:
|
||||
medium:
|
||||
hf_id: facebook/audiogen-medium
|
||||
vram_mb: 5000
|
||||
max_duration: 10
|
||||
description: "Sound effects generator (1.5B params)"
|
||||
|
||||
magnet:
|
||||
enabled: true
|
||||
display_name: "MAGNeT"
|
||||
description: "Fast non-autoregressive music generation"
|
||||
default_variant: medium-10secs
|
||||
variants:
|
||||
small-10secs:
|
||||
hf_id: facebook/magnet-small-10secs
|
||||
vram_mb: 1500
|
||||
max_duration: 10
|
||||
description: "Fast 10-second clips (300M params)"
|
||||
medium-10secs:
|
||||
hf_id: facebook/magnet-medium-10secs
|
||||
vram_mb: 5000
|
||||
max_duration: 10
|
||||
description: "Quality 10-second clips (1.5B params)"
|
||||
small-30secs:
|
||||
hf_id: facebook/magnet-small-30secs
|
||||
vram_mb: 1800
|
||||
max_duration: 30
|
||||
description: "Fast 30-second clips (300M params)"
|
||||
medium-30secs:
|
||||
hf_id: facebook/magnet-medium-30secs
|
||||
vram_mb: 6000
|
||||
max_duration: 30
|
||||
description: "Quality 30-second clips (1.5B params)"
|
||||
|
||||
musicgen-style:
|
||||
enabled: true
|
||||
display_name: "MusicGen Style"
|
||||
description: "Style-conditioned music generation from reference audio"
|
||||
default_variant: medium
|
||||
variants:
|
||||
medium:
|
||||
hf_id: facebook/musicgen-style
|
||||
vram_mb: 5000
|
||||
max_duration: 30
|
||||
conditioning:
|
||||
- style
|
||||
description: "Style transfer from reference audio (1.5B params)"
|
||||
|
||||
jasco:
|
||||
enabled: true
|
||||
display_name: "JASCO"
|
||||
description: "Chord and drum-conditioned music generation"
|
||||
default_variant: chords-drums-400M
|
||||
variants:
|
||||
chords-drums-400M:
|
||||
hf_id: facebook/jasco-chords-drums-400M
|
||||
vram_mb: 2000
|
||||
max_duration: 10
|
||||
conditioning:
|
||||
- chords
|
||||
- drums
|
||||
description: "Chord/drum control, fast (400M params)"
|
||||
chords-drums-1B:
|
||||
hf_id: facebook/jasco-chords-drums-1B
|
||||
vram_mb: 4000
|
||||
max_duration: 10
|
||||
conditioning:
|
||||
- chords
|
||||
- drums
|
||||
description: "Chord/drum control, higher quality (1B params)"
|
||||
|
||||
# Default generation parameters
|
||||
defaults:
|
||||
generation:
|
||||
duration: 10
|
||||
temperature: 1.0
|
||||
top_k: 250
|
||||
top_p: 0.0
|
||||
cfg_coef: 3.0
|
||||
|
||||
# VRAM thresholds for warnings
|
||||
vram:
|
||||
warning_threshold: 0.85 # 85% utilization warning
|
||||
critical_threshold: 0.95 # 95% utilization critical
|
||||
|
||||
# Presets are loaded from data/presets/*.yaml
|
||||
presets_dir: "./data/presets"
|
||||
94
config/settings.py
Normal file
94
config/settings.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Application settings with environment variable support."""
|
||||
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application configuration with environment variable support.

    All settings can be overridden via environment variables prefixed with
    AUDIOCRAFT_. Example: AUDIOCRAFT_API_PORT=8080
    """

    model_config = SettingsConfigDict(
        env_prefix="AUDIOCRAFT_",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",  # tolerate unknown AUDIOCRAFT_* vars instead of erroring
    )

    # Server Configuration
    host: str = Field(default="0.0.0.0", description="Server bind host")
    gradio_port: int = Field(default=7860, description="Gradio UI port")
    api_port: int = Field(default=8000, description="FastAPI port")
    # BUG FIX: main.py reads settings.api_enabled when launching Gradio; this
    # field was previously missing, causing an AttributeError at startup.
    api_enabled: bool = Field(default=True, description="Expose the Gradio API schema")

    # Paths
    data_dir: Path = Field(default=Path("./data"), description="Data directory")
    output_dir: Path = Field(default=Path("./outputs"), description="Generated audio output")
    cache_dir: Path = Field(default=Path("./cache"), description="Model cache directory")
    models_config: Path = Field(
        default=Path("./config/models.yaml"), description="Model registry config"
    )

    # VRAM Management
    comfyui_reserve_gb: float = Field(
        default=10.0, description="VRAM reserved for ComfyUI (GB)"
    )
    safety_buffer_gb: float = Field(
        default=1.0, description="Safety buffer to prevent OOM (GB)"
    )
    idle_unload_minutes: int = Field(
        default=15, description="Unload models after idle time (minutes)"
    )
    max_cached_models: int = Field(
        default=2, description="Maximum number of models to keep loaded"
    )

    # API Authentication
    api_key: Optional[str] = Field(default=None, description="API key for authentication")
    cors_origins: list[str] = Field(
        default=["*"], description="Allowed CORS origins"
    )

    # Generation Defaults
    default_duration: float = Field(default=10.0, description="Default generation duration")
    max_duration: float = Field(default=300.0, description="Maximum generation duration")
    default_batch_size: int = Field(default=1, description="Default batch size")
    max_batch_size: int = Field(default=8, description="Maximum batch size")
    max_queue_size: int = Field(default=100, description="Maximum generation queue size")

    # Database
    database_url: str = Field(
        default="sqlite+aiosqlite:///./data/audiocraft.db",
        description="Database connection URL",
    )

    # Logging
    log_level: str = Field(default="INFO", description="Logging level")

    @property
    def max_loaded_models(self) -> int:
        """Backward-compatible alias for max_cached_models.

        main.py and the deployment docs (README, runpod.yaml) refer to this
        setting as "max loaded models"; expose both names for one value.
        """
        return self.max_cached_models

    def ensure_directories(self) -> None:
        """Create required directories if they don't exist."""
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        # Presets live under the data directory (see models.yaml presets_dir).
        (self.data_dir / "presets").mkdir(parents=True, exist_ok=True)

    @property
    def database_path(self) -> Path:
        """Extract the on-disk database file path from the SQLite URL.

        Raises:
            ValueError: if database_url is not a SQLite URL.
        """
        if self.database_url.startswith("sqlite"):
            # Handle both sqlite:/// and sqlite+aiosqlite:///
            path = self.database_url.split("///")[-1]
            return Path(path)
        raise ValueError("Only SQLite databases are supported")
|
||||
|
||||
|
||||
@lru_cache()
def get_settings() -> Settings:
    """Return the process-wide Settings instance (built once, then memoized)."""
    return Settings()
|
||||
64
docker-compose.yml
Normal file
64
docker-compose.yml
Normal file
@@ -0,0 +1,64 @@
|
||||
# Docker Compose for local development and testing
|
||||
# For RunPod deployment, use the Dockerfile directly
|
||||
|
||||
version: '3.8'  # NOTE: the top-level "version" key is obsolete under Compose v2 and is ignored
|
||||
|
||||
services:
|
||||
audiocraft:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
container_name: audiocraft-studio
|
||||
ports:
|
||||
- "7860:7860" # Gradio UI
|
||||
- "8000:8000" # REST API
|
||||
volumes:
|
||||
# Persistent storage
|
||||
- audiocraft-outputs:/workspace/outputs
|
||||
- audiocraft-data:/workspace/data
|
||||
- audiocraft-models:/workspace/models
|
||||
# Development: mount source code
|
||||
- ./src:/workspace/app/src:ro
|
||||
- ./config:/workspace/app/config:ro
|
||||
environment:
|
||||
- AUDIOCRAFT_HOST=0.0.0.0
|
||||
- AUDIOCRAFT_GRADIO_PORT=7860
|
||||
- AUDIOCRAFT_API_PORT=8000
|
||||
- AUDIOCRAFT_DEBUG=false
|
||||
- AUDIOCRAFT_COMFYUI_RESERVE_GB=0 # No ComfyUI in this compose
|
||||
- NVIDIA_VISIBLE_DEVICES=all
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:7860/"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 60s
|
||||
|
||||
# Optional: Run alongside ComfyUI
|
||||
# comfyui:
|
||||
# image: your-comfyui-image
|
||||
# container_name: comfyui
|
||||
# ports:
|
||||
# - "8188:8188"
|
||||
# volumes:
|
||||
# - comfyui-data:/workspace
|
||||
# deploy:
|
||||
# resources:
|
||||
# reservations:
|
||||
# devices:
|
||||
# - driver: nvidia
|
||||
# count: 1
|
||||
# capabilities: [gpu]
|
||||
|
||||
volumes:
|
||||
audiocraft-outputs:
|
||||
audiocraft-data:
|
||||
audiocraft-models:
|
||||
147
main.py
Normal file
147
main.py
Normal file
@@ -0,0 +1,147 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Main entry point for AudioCraft Studio."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
from config.settings import get_settings
|
||||
from src.core.gpu_manager import GPUMemoryManager
|
||||
from src.core.model_registry import ModelRegistry
|
||||
from src.services.generation_service import GenerationService
|
||||
from src.services.batch_processor import BatchProcessor
|
||||
from src.services.project_service import ProjectService
|
||||
from src.storage.database import Database
|
||||
from src.ui.app import create_app
|
||||
|
||||
|
||||
# Configure logging
|
||||
# Configure root logging once: stream to stderr and mirror to audiocraft.log.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
_LOG_HANDLERS = [
    logging.StreamHandler(),
    logging.FileHandler("audiocraft.log"),
]
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=_LOG_HANDLERS)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def initialize_services():
    """Construct and wire all backend services.

    Builds the database, GPU manager, model registry, and the generation /
    batch / project services, then returns them in a dict keyed by name so
    main() can pass them to the UI layer (missing keys mean demo mode).
    """
    settings = get_settings()

    # Initialize database
    logger.info("Initializing database...")
    db = Database(settings.database_path)
    await db.initialize()

    # Initialize GPU manager
    logger.info("Initializing GPU manager...")
    gpu_manager = GPUMemoryManager(
        device_id=0,
        # GB -> bytes for the reservation handed to the GPU manager
        comfyui_reserve_bytes=int(settings.comfyui_reserve_gb * 1024**3),
    )

    # Initialize model registry
    logger.info("Initializing model registry...")
    model_registry = ModelRegistry(
        gpu_manager=gpu_manager,
        # BUG FIX: the Settings field is max_cached_models; there is no
        # max_loaded_models attribute, so the old code raised AttributeError.
        max_loaded=settings.max_cached_models,
        idle_timeout_seconds=settings.idle_unload_minutes * 60,
    )

    # Initialize services
    logger.info("Initializing services...")
    generation_service = GenerationService(
        model_registry=model_registry,
        gpu_manager=gpu_manager,
        output_dir=settings.output_dir,
    )

    batch_processor = BatchProcessor(
        generation_service=generation_service,
        max_queue_size=settings.max_queue_size,
    )

    project_service = ProjectService(
        db=db,
        output_dir=settings.output_dir,
    )

    return {
        "db": db,
        "gpu_manager": gpu_manager,
        "model_registry": model_registry,
        "generation_service": generation_service,
        "batch_processor": batch_processor,
        "project_service": project_service,
    }
|
||||
|
||||
|
||||
def main():
    """Main entry point: initialize backend services and launch the Gradio app.

    Falls back to a UI-only "demo mode" (no backend services) if service
    initialization fails.
    """
    settings = get_settings()

    logger.info("=" * 60)
    logger.info("AudioCraft Studio")
    logger.info("=" * 60)
    logger.info(f"Host: {settings.host}")
    logger.info(f"Gradio Port: {settings.gradio_port}")
    logger.info(f"API Port: {settings.api_port}")
    logger.info(f"Output Dir: {settings.output_dir}")
    logger.info("=" * 60)

    # Initialize services
    logger.info("Initializing services...")
    try:
        services = asyncio.run(initialize_services())
    except Exception as e:
        logger.error(f"Failed to initialize services: {e}")
        logger.warning("Starting in demo mode without backend services")
        services = {}

    # Create and launch app; missing services arrive as None (demo mode).
    logger.info("Creating Gradio application...")
    app = create_app(
        generation_service=services.get("generation_service"),
        batch_processor=services.get("batch_processor"),
        project_service=services.get("project_service"),
        gpu_manager=services.get("gpu_manager"),
        model_registry=services.get("model_registry"),
    )

    # Start batch processor if available.
    # NOTE(review): every asyncio.run() call creates and closes its own event
    # loop, so background tasks scheduled inside start() will not outlive this
    # call — confirm BatchProcessor tolerates this, or run it inside a
    # persistent loop/thread.
    batch_processor = services.get("batch_processor")
    if batch_processor:
        logger.info("Starting batch processor...")
        asyncio.run(batch_processor.start())

    # Launch the app
    logger.info("Launching application...")
    try:
        app.launch(
            server_name=settings.host,
            server_port=settings.gradio_port,
            share=False,
            # BUG FIX: Settings defines no api_enabled field, so reading
            # settings.api_enabled raised AttributeError at launch; default to
            # True while staying compatible if the field is added later.
            show_api=getattr(settings, "api_enabled", True),
        )
    except KeyboardInterrupt:
        logger.info("Shutting down...")
    finally:
        # Cleanup (best-effort; see the event-loop note above)
        if batch_processor:
            asyncio.run(batch_processor.stop())
        if "db" in services:
            asyncio.run(services["db"].close())

        logger.info("Shutdown complete")


if __name__ == "__main__":
    main()
|
||||
89
pyproject.toml
Normal file
89
pyproject.toml
Normal file
@@ -0,0 +1,89 @@
|
||||
[project]
|
||||
name = "audiocraft-ui"
|
||||
version = "0.1.0"
|
||||
description = "Sophisticated AI audio web application based on Facebook's AudioCraft"
|
||||
readme = "README.md"
|
||||
license = { text = "MIT" }
|
||||
requires-python = ">=3.10"
|
||||
authors = [{ name = "AudioCraft UI Team" }]
|
||||
keywords = ["audio", "music", "generation", "ai", "audiocraft", "gradio"]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Topic :: Multimedia :: Sound/Audio",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
# Core ML
|
||||
"torch>=2.1.0",
|
||||
"torchaudio>=2.1.0",
|
||||
"audiocraft>=1.3.0",
|
||||
"xformers>=0.0.22",
|
||||
|
||||
# UI
|
||||
"gradio>=4.0.0",
|
||||
|
||||
# API
|
||||
"fastapi>=0.104.0",
|
||||
"uvicorn[standard]>=0.24.0",
|
||||
"python-multipart>=0.0.6",
|
||||
|
||||
# GPU Monitoring
|
||||
"pynvml>=11.5.0",
|
||||
|
||||
# Storage
|
||||
"aiosqlite>=0.19.0",
|
||||
|
||||
# Configuration
|
||||
"pydantic>=2.5.0",
|
||||
"pydantic-settings>=2.1.0",
|
||||
"pyyaml>=6.0",
|
||||
|
||||
# Audio Processing
|
||||
"numpy>=1.24.0",
|
||||
"scipy>=1.11.0",
|
||||
"librosa>=0.10.0",
|
||||
"soundfile>=0.12.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=7.4.0",
|
||||
"pytest-asyncio>=0.21.0",
|
||||
"pytest-cov>=4.1.0",
|
||||
"ruff>=0.1.0",
|
||||
"mypy>=1.6.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
audiocraft-ui = "src.main:main"  # NOTE(review): main() is defined in ./main.py, not src/main.py — confirm this entry point resolves in the built wheel
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
target-version = "py310"
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "N", "W", "UP"]
|
||||
ignore = ["E501"]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
ignore_missing_imports = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["tests"]
|
||||
30
requirements.txt
Normal file
30
requirements.txt
Normal file
@@ -0,0 +1,30 @@
|
||||
# Core ML
|
||||
torch>=2.1.0
|
||||
torchaudio>=2.1.0
|
||||
audiocraft>=1.3.0
|
||||
xformers>=0.0.22
|
||||
|
||||
# UI
|
||||
gradio>=4.0.0
|
||||
|
||||
# API
|
||||
fastapi>=0.104.0
|
||||
uvicorn[standard]>=0.24.0
|
||||
python-multipart>=0.0.6
|
||||
|
||||
# GPU Monitoring
|
||||
pynvml>=11.5.0
|
||||
|
||||
# Storage
|
||||
aiosqlite>=0.19.0
|
||||
|
||||
# Configuration
|
||||
pydantic>=2.5.0
|
||||
pydantic-settings>=2.1.0
|
||||
pyyaml>=6.0
|
||||
|
||||
# Audio Processing
|
||||
numpy>=1.24.0
|
||||
scipy>=1.11.0
|
||||
librosa>=0.10.0
|
||||
soundfile>=0.12.0
|
||||
77
runpod.yaml
Normal file
77
runpod.yaml
Normal file
@@ -0,0 +1,77 @@
|
||||
# RunPod Template Configuration
|
||||
# Use this as reference when creating a RunPod template
|
||||
|
||||
name: AudioCraft Studio
|
||||
description: AI-powered music and sound generation using Meta's AudioCraft
|
||||
|
||||
# Container settings
|
||||
container:
|
||||
image: your-dockerhub-username/audiocraft-studio:latest
|
||||
|
||||
# Or build from GitHub
|
||||
# dockerfile: Dockerfile
|
||||
# context: https://github.com/your-username/audiocraft-ui.git
|
||||
|
||||
# GPU requirements
|
||||
gpu:
|
||||
type: RTX 4090 # Recommended: RTX 4090, RTX 3090, A100
|
||||
count: 1
|
||||
minVram: 24 # GB
|
||||
|
||||
# Resource limits
|
||||
resources:
|
||||
cpu: 8
|
||||
memory: 32 # GB
|
||||
disk: 100 # GB (for model cache and outputs)
|
||||
|
||||
# Port mappings
|
||||
ports:
|
||||
- name: Gradio UI
|
||||
internal: 7860
|
||||
external: 7860
|
||||
protocol: http
|
||||
- name: REST API
|
||||
internal: 8000
|
||||
external: 8000
|
||||
protocol: http
|
||||
|
||||
# Volume mounts
|
||||
volumes:
|
||||
- name: outputs
|
||||
mountPath: /workspace/outputs
|
||||
size: 50 # GB
|
||||
- name: models
|
||||
mountPath: /workspace/models
|
||||
size: 30 # GB (model cache)
|
||||
- name: data
|
||||
mountPath: /workspace/data
|
||||
size: 10 # GB
|
||||
|
||||
# Environment variables
|
||||
env:
|
||||
- name: AUDIOCRAFT_HOST
|
||||
value: "0.0.0.0"
|
||||
- name: AUDIOCRAFT_GRADIO_PORT
|
||||
value: "7860"
|
||||
- name: AUDIOCRAFT_API_PORT
|
||||
value: "8000"
|
||||
- name: AUDIOCRAFT_COMFYUI_RESERVE_GB
|
||||
value: "10" # Reserve VRAM for ComfyUI if running alongside
|
||||
- name: AUDIOCRAFT_MAX_CACHED_MODELS  # renamed: .env.example defines AUDIOCRAFT_MAX_CACHED_MODELS (was MAX_LOADED_MODELS)
|
||||
value: "2"
|
||||
- name: AUDIOCRAFT_IDLE_UNLOAD_MINUTES
|
||||
value: "15"
|
||||
- name: HF_HOME
|
||||
value: "/workspace/models/huggingface"
|
||||
|
||||
# Startup command
|
||||
command: ["python", "main.py"]
|
||||
|
||||
# Health check
|
||||
healthCheck:
|
||||
path: /
|
||||
port: 7860
|
||||
initialDelaySeconds: 120
|
||||
periodSeconds: 30
|
||||
timeoutSeconds: 10
|
||||
failureThreshold: 3
|
||||
116
scripts/download_models.py
Executable file
116
scripts/download_models.py
Executable file
@@ -0,0 +1,116 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Pre-download AudioCraft models for faster startup."""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def download_musicgen_models(variants: list[str] | None = None):
    """Download MusicGen model checkpoints into the local cache.

    Args:
        variants: MusicGen variant names (e.g. "small", "melody"). When
            omitted (None), downloads small, medium, large and melody.
    """
    # Imported lazily so the CLI can run --help without audiocraft installed.
    from audiocraft.models import MusicGen

    variants = variants or ["small", "medium", "large", "melody"]

    for variant in variants:
        print(f"Downloading MusicGen {variant}...")
        try:
            # Instantiating the model triggers the weight download; the
            # instance itself is discarded immediately to free memory.
            model = MusicGen.get_pretrained(f"facebook/musicgen-{variant}")
            del model
            print(f"  ✓ MusicGen {variant} downloaded")
        except Exception as e:
            # Best-effort: report the failure and continue with the rest.
            print(f"  ✗ Failed to download MusicGen {variant}: {e}")
|
||||
|
||||
|
||||
def download_audiogen_models():
    """Fetch the AudioGen medium checkpoint into the local model cache."""
    # Lazy import keeps the CLI importable without audiocraft installed.
    from audiocraft.models import AudioGen

    print("Downloading AudioGen medium...")
    try:
        # Instantiation is only needed for its download side effect, so the
        # object is released right away.
        snapshot = AudioGen.get_pretrained("facebook/audiogen-medium")
        del snapshot
        print("  ✓ AudioGen medium downloaded")
    except Exception as e:
        print(f"  ✗ Failed to download AudioGen: {e}")
|
||||
|
||||
|
||||
def download_magnet_models(variants: list[str] | None = None):
    """Download MAGNeT model checkpoints into the local cache.

    Args:
        variants: MAGNeT variant names. When omitted (None), downloads
            small, medium, audio-small-10secs and audio-medium-10secs.
    """
    # Imported lazily so the CLI can run --help without audiocraft installed.
    from audiocraft.models import MAGNeT

    variants = variants or ["small", "medium", "audio-small-10secs", "audio-medium-10secs"]

    for variant in variants:
        print(f"Downloading MAGNeT {variant}...")
        try:
            # Instantiating the model triggers the weight download; the
            # instance itself is discarded immediately to free memory.
            model = MAGNeT.get_pretrained(f"facebook/magnet-{variant}")
            del model
            print(f"  ✓ MAGNeT {variant} downloaded")
        except Exception as e:
            # Best-effort: report the failure and continue with the rest.
            print(f"  ✗ Failed to download MAGNeT {variant}: {e}")
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments and download the selected models."""
    parser = argparse.ArgumentParser(description="Pre-download AudioCraft models")
    parser.add_argument(
        "--models",
        nargs="+",
        choices=["musicgen", "audiogen", "magnet", "all"],
        default=["all"],
        help="Models to download",
    )
    parser.add_argument(
        "--musicgen-variants",
        nargs="+",
        default=["small", "medium"],
        help="MusicGen variants to download",
    )
    parser.add_argument(
        "--magnet-variants",
        nargs="+",
        default=["small", "medium"],
        help="MAGNeT variants to download",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Model cache directory",
    )

    args = parser.parse_args()

    # Set cache directory
    # Both HF_HOME and TORCH_HOME are pointed at the same directory,
    # presumably so Hugging Face and torch-hub downloads share one cache.
    if args.cache_dir:
        os.environ["HF_HOME"] = args.cache_dir
        os.environ["TORCH_HOME"] = args.cache_dir
        Path(args.cache_dir).mkdir(parents=True, exist_ok=True)

    models = args.models
    # "all" expands to every supported family.
    if "all" in models:
        models = ["musicgen", "audiogen", "magnet"]

    print("=" * 50)
    print("AudioCraft Model Downloader")
    print("=" * 50)
    print(f"Cache directory: {os.environ.get('HF_HOME', 'default')}")
    print(f"Models to download: {models}")
    print("=" * 50)

    if "musicgen" in models:
        download_musicgen_models(args.musicgen_variants)

    if "audiogen" in models:
        download_audiogen_models()

    if "magnet" in models:
        download_magnet_models(args.magnet_variants)

    print("=" * 50)
    print("Download complete!")
    print("=" * 50)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
55
scripts/start.sh
Executable file
55
scripts/start.sh
Executable file
@@ -0,0 +1,55 @@
|
||||
#!/bin/bash
# Startup script for AudioCraft Studio
# Used in Docker container and RunPod

set -e

echo "=========================================="
echo " AudioCraft Studio"
echo "=========================================="

# Create directories if they don't exist
mkdir -p "${AUDIOCRAFT_OUTPUT_DIR:-/workspace/outputs}"
mkdir -p "${AUDIOCRAFT_DATA_DIR:-/workspace/data}"
mkdir -p "${AUDIOCRAFT_MODEL_CACHE:-/workspace/models}"

# Check GPU availability (informational only; startup continues without a GPU)
echo "Checking GPU..."
if command -v nvidia-smi &> /dev/null; then
    nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv
else
    echo "Warning: nvidia-smi not found"
fi

# Check Python and dependencies
echo "Python version:"
python --version

echo "PyTorch version:"
python -c "import torch; print(f'PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}')"

# Check AudioCraft installation
echo "AudioCraft version:"
python -c "import audiocraft; print(audiocraft.__version__)" 2>/dev/null || echo "AudioCraft installed from source"

# Generate API key if not exists
# The key file stores only a hash; the plaintext key is printed exactly once here.
# NOTE(review): assumes the working directory puts `src` on the import path — confirm.
if [ ! -f "${AUDIOCRAFT_DATA_DIR:-/workspace/data}/.api_key" ]; then
    echo "Generating API key..."
    python -c "
from src.api.auth import get_key_manager
km = get_key_manager()
if not km.has_key():
    key = km.generate_new_key()
    print(f'Generated API key: {key}')
    print('Store this key securely - it will not be shown again!')
"
fi

# Start the application
echo "Starting AudioCraft Studio..."
echo "Gradio UI: http://0.0.0.0:${AUDIOCRAFT_GRADIO_PORT:-7860}"
echo "REST API: http://0.0.0.0:${AUDIOCRAFT_API_PORT:-8000}"
echo "API Docs: http://0.0.0.0:${AUDIOCRAFT_API_PORT:-8000}/api/docs"
echo "=========================================="

# exec replaces the shell so the Python process becomes PID 1 and
# receives container stop signals directly.
exec python main.py "$@"
|
||||
3
src/__init__.py
Normal file
3
src/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""AudioCraft Studio - AI Audio Generation Web Application."""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
5
src/api/__init__.py
Normal file
5
src/api/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""REST API for AudioCraft Studio."""
|
||||
|
||||
from src.api.app import create_api_app
|
||||
|
||||
__all__ = ["create_api_app"]
|
||||
150
src/api/app.py
Normal file
150
src/api/app.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""FastAPI application for AudioCraft Studio REST API."""
|
||||
|
||||
from typing import Any, Optional
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
import time
|
||||
|
||||
from config.settings import get_settings
|
||||
from src.api.routes import (
|
||||
generation_router,
|
||||
projects_router,
|
||||
models_router,
|
||||
system_router,
|
||||
)
|
||||
from src.api.routes.generation import set_services as set_generation_services
|
||||
from src.api.routes.projects import set_services as set_project_services
|
||||
from src.api.routes.models import set_services as set_model_services
|
||||
from src.api.routes.system import set_services as set_system_services
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler.

    Currently a no-op placeholder: code placed before the ``yield`` runs at
    application startup, code after it at shutdown.
    """
    # Startup
    yield
    # Shutdown
|
||||
|
||||
|
||||
def create_api_app(
    generation_service: Any = None,
    batch_processor: Any = None,
    project_service: Any = None,
    gpu_manager: Any = None,
    model_registry: Any = None,
) -> FastAPI:
    """Create and configure the FastAPI application.

    Args:
        generation_service: Service for handling generations
        batch_processor: Service for batch/queue processing
        project_service: Service for project management
        gpu_manager: GPU memory manager
        model_registry: Model registry for loading/unloading

    Returns:
        Configured FastAPI application
    """
    settings = get_settings()

    # When the API is disabled, the docs/redoc/openapi URLs are set to None,
    # which hides the interactive documentation endpoints entirely.
    app = FastAPI(
        title="AudioCraft Studio API",
        description="REST API for AI-powered music and sound generation",
        version="1.0.0",
        docs_url="/api/docs" if settings.api_enabled else None,
        redoc_url="/api/redoc" if settings.api_enabled else None,
        openapi_url="/api/openapi.json" if settings.api_enabled else None,
        lifespan=lifespan,
    )

    # CORS middleware — allowed origins come from settings.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.cors_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Request timing middleware: exposes wall-clock handling time (seconds)
    # in the X-Process-Time response header.
    @app.middleware("http")
    async def add_process_time_header(request: Request, call_next):
        start_time = time.time()
        response = await call_next(request)
        process_time = time.time() - start_time
        response.headers["X-Process-Time"] = str(process_time)
        return response

    # Global exception handler — exception details are only returned to
    # clients when debug mode is on.
    @app.exception_handler(Exception)
    async def global_exception_handler(request: Request, exc: Exception):
        return JSONResponse(
            status_code=500,
            content={
                "error": "Internal server error",
                "detail": str(exc) if settings.debug else "An unexpected error occurred",
            },
        )

    # Inject service dependencies. Route modules hold their services as
    # module-level globals, so they are wired before routers are included.
    set_generation_services(generation_service, batch_processor)
    set_project_services(project_service)
    set_model_services(model_registry)
    set_system_services(gpu_manager, batch_processor, model_registry)

    # Register routers under a common versioned prefix.
    app.include_router(generation_router, prefix="/api/v1")
    app.include_router(projects_router, prefix="/api/v1")
    app.include_router(models_router, prefix="/api/v1")
    app.include_router(system_router, prefix="/api/v1")

    # Root endpoint
    @app.get("/")
    async def root():
        return {
            "name": "AudioCraft Studio API",
            "version": "1.0.0",
            "docs": "/api/docs",
        }

    # API info endpoint — lists the available endpoint groups.
    @app.get("/api/v1")
    async def api_info():
        return {
            "version": "1.0.0",
            "endpoints": {
                "generation": "/api/v1/generate",
                "projects": "/api/v1/projects",
                "models": "/api/v1/models",
                "system": "/api/v1/system",
            },
        }

    return app
|
||||
|
||||
|
||||
def run_api_server(
    app: FastAPI,
    host: Optional[str] = None,
    port: Optional[int] = None,
) -> None:
    """Serve the FastAPI application with uvicorn (blocking call).

    Args:
        app: FastAPI application to serve
        host: Server hostname; falls back to the configured host
        port: Server port; falls back to the configured API port
    """
    # Imported here so the module can be loaded without uvicorn installed.
    import uvicorn

    cfg = get_settings()
    bind_host = host or cfg.host
    bind_port = port or cfg.api_port

    uvicorn.run(app, host=bind_host, port=bind_port, log_level="info")
|
||||
133
src/api/auth.py
Normal file
133
src/api/auth.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""API authentication middleware."""
|
||||
|
||||
import secrets
|
||||
import hashlib
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import HTTPException, Security, status
|
||||
from fastapi.security import APIKeyHeader
|
||||
|
||||
from config.settings import get_settings
|
||||
|
||||
|
||||
# API key header
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
|
||||
def generate_api_key() -> str:
    """Create a fresh URL-safe API key (32 bytes of randomness)."""
    token = secrets.token_urlsafe(32)
    return token
|
||||
|
||||
|
||||
def hash_api_key(key: str) -> str:
    """Return the hex SHA-256 digest used to store a key at rest."""
    digest = hashlib.sha256(key.encode())
    return digest.hexdigest()
|
||||
|
||||
|
||||
def verify_api_key(key: str, hashed: str) -> bool:
    """Compare a candidate key against a stored hash in constant time."""
    candidate = hash_api_key(key)
    # compare_digest avoids leaking match length via timing.
    return secrets.compare_digest(candidate, hashed)
|
||||
|
||||
|
||||
class APIKeyManager:
    """Manage API keys for authentication.

    Only a SHA-256 hash of the key is persisted to disk; the plaintext key
    exists solely in the return value of :meth:`generate_new_key`.
    """

    def __init__(self, key_file: Optional[Path] = None):
        """Initialize the key manager.

        Args:
            key_file: Path to store API key hash; defaults to
                ``<data_dir>/.api_key`` from settings.
        """
        self.settings = get_settings()
        self.key_file = key_file or Path(self.settings.data_dir) / ".api_key"
        # In-memory copy of the stored hash (None until a key exists).
        self._key_hash: Optional[str] = None
        self._load_key()

    def _load_key(self) -> None:
        """Load API key hash from file, if one has been saved before."""
        if self.key_file.exists():
            self._key_hash = self.key_file.read_text().strip()

    def _save_key(self, key_hash: str) -> None:
        """Save API key hash to file and update the in-memory copy."""
        self.key_file.parent.mkdir(parents=True, exist_ok=True)
        self.key_file.write_text(key_hash)
        self._key_hash = key_hash

    def generate_new_key(self) -> str:
        """Generate and store a new API key.

        Returns:
            The new API key (only shown once; only its hash is persisted)
        """
        key = generate_api_key()
        self._save_key(hash_api_key(key))
        return key

    def verify(self, key: str) -> bool:
        """Verify an API key.

        Args:
            key: API key to verify

        Returns:
            True if valid, False otherwise (including when no key exists yet)
        """
        # No stored hash means no key was ever generated — nothing can match.
        if not self._key_hash:
            return False
        return verify_api_key(key, self._key_hash)

    def has_key(self) -> bool:
        """Check if an API key has been generated."""
        return self._key_hash is not None
|
||||
|
||||
|
||||
# Global key manager instance
|
||||
_key_manager: Optional[APIKeyManager] = None
|
||||
|
||||
|
||||
def get_key_manager() -> APIKeyManager:
    """Get the global key manager instance.

    The singleton is created lazily on first call, so importing this module
    performs no filesystem access.
    """
    global _key_manager
    if _key_manager is None:
        _key_manager = APIKeyManager()
    return _key_manager
|
||||
|
||||
|
||||
async def verify_api_key_dependency(
    api_key: Optional[str] = Security(api_key_header),
) -> str:
    """FastAPI dependency that authenticates a request via the X-API-Key header.

    Args:
        api_key: Value of the ``X-API-Key`` header, if present

    Returns:
        The verified API key, or ``"anonymous"`` when auth is disabled

    Raises:
        HTTPException: 401 when the header is missing, 403 when invalid
    """
    cfg = get_settings()

    # Deployments may opt out of authentication entirely.
    if not cfg.api_key_required:
        return "anonymous"

    if api_key is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="API key required",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    manager = get_key_manager()
    if not manager.verify(api_key):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API key",
        )

    return api_key
|
||||
166
src/api/models.py
Normal file
166
src/api/models.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Pydantic models for API requests and responses."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ModelFamily(str, Enum):
    """Available model families.

    Values are the ``model_id`` strings passed to the generation service
    (see the generation routes, which use ``request.model.value``).
    """
    MUSICGEN = "musicgen"
    AUDIOGEN = "audiogen"
    MAGNET = "magnet"
    MUSICGEN_STYLE = "musicgen-style"
    JASCO = "jasco"
|
||||
|
||||
|
||||
class JobStatus(str, Enum):
    """Generation job status.

    Values mirror the batch processor's job states; routes convert with
    ``JobStatus(job.status.value)``.
    """
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
||||
|
||||
|
||||
# Generation requests
|
||||
|
||||
class GenerationRequest(BaseModel):
    """Request to generate audio.

    ``conditioning`` is forwarded to the generation service unmodified; its
    expected keys depend on the chosen model family (not validated here).
    """
    model: ModelFamily = Field(..., description="Model family to use")
    variant: str = Field("medium", description="Model variant")
    prompts: list[str] = Field(..., min_length=1, max_length=10, description="Text prompts")
    duration: float = Field(10.0, ge=1, le=30, description="Duration in seconds")
    temperature: float = Field(1.0, ge=0, le=2, description="Sampling temperature")
    top_k: int = Field(250, ge=0, le=500, description="Top-K sampling")
    top_p: float = Field(0.0, ge=0, le=1, description="Top-P (nucleus) sampling")
    cfg_coef: float = Field(3.0, ge=1, le=10, description="CFG coefficient")
    seed: Optional[int] = Field(None, description="Random seed for reproducibility")
    conditioning: Optional[dict[str, Any]] = Field(None, description="Model-specific conditioning")
    project_id: Optional[str] = Field(None, description="Project to save to")
|
||||
|
||||
|
||||
class BatchGenerationRequest(BaseModel):
|
||||
"""Request to add generation to queue."""
|
||||
request: GenerationRequest
|
||||
priority: int = Field(0, ge=0, le=10, description="Job priority (higher = sooner)")
|
||||
|
||||
|
||||
# Generation responses
|
||||
|
||||
class GenerationResult(BaseModel):
    """Result of a completed generation.

    URLs are server-relative download endpoints (``/api/v1/audio/...``),
    not filesystem paths.
    """
    id: str = Field(..., description="Generation ID")
    audio_url: str = Field(..., description="URL to download audio")
    waveform_url: Optional[str] = Field(None, description="URL to waveform image")
    duration: float = Field(..., description="Actual duration in seconds")
    seed: int = Field(..., description="Seed used for generation")
    model: str = Field(..., description="Model used")
    variant: str = Field(..., description="Variant used")
    prompt: str = Field(..., description="Prompt used")
    created_at: datetime = Field(..., description="Creation timestamp")
|
||||
|
||||
|
||||
class JobResponse(BaseModel):
    """Response for a queued job.

    Optional fields are populated according to status: ``position`` for
    pending, ``progress`` for running, ``result`` for completed, ``error``
    for failed (see the generation routes).
    """
    job_id: str = Field(..., description="Job ID for tracking")
    status: JobStatus = Field(..., description="Current status")
    position: Optional[int] = Field(None, description="Queue position if pending")
    progress: Optional[float] = Field(None, description="Progress 0-1 if running")
    result: Optional[GenerationResult] = Field(None, description="Result if completed")
    error: Optional[str] = Field(None, description="Error message if failed")
|
||||
|
||||
|
||||
# Project models
|
||||
|
||||
class ProjectCreate(BaseModel):
|
||||
"""Request to create a project."""
|
||||
name: str = Field(..., min_length=1, max_length=100)
|
||||
description: Optional[str] = Field(None, max_length=500)
|
||||
|
||||
|
||||
class ProjectResponse(BaseModel):
|
||||
"""Project information."""
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str]
|
||||
generation_count: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class GenerationResponse(BaseModel):
|
||||
"""Generation record from database."""
|
||||
id: str
|
||||
project_id: str
|
||||
model: str
|
||||
variant: str
|
||||
prompt: str
|
||||
duration_seconds: float
|
||||
seed: int
|
||||
audio_path: str
|
||||
waveform_path: Optional[str]
|
||||
parameters: dict[str, Any]
|
||||
created_at: datetime
|
||||
|
||||
|
||||
# Model info
|
||||
|
||||
class ModelVariantInfo(BaseModel):
|
||||
"""Information about a model variant."""
|
||||
id: str
|
||||
name: str
|
||||
vram_mb: int
|
||||
description: str
|
||||
capabilities: list[str]
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""Information about a model family."""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
variants: list[ModelVariantInfo]
|
||||
loaded: bool
|
||||
current_variant: Optional[str]
|
||||
|
||||
|
||||
# System info
|
||||
|
||||
class GPUStatus(BaseModel):
|
||||
"""GPU memory status."""
|
||||
device_name: str
|
||||
total_gb: float
|
||||
used_gb: float
|
||||
available_gb: float
|
||||
utilization_percent: float
|
||||
temperature_c: Optional[float]
|
||||
|
||||
|
||||
class QueueStatus(BaseModel):
|
||||
"""Generation queue status."""
|
||||
queue_size: int
|
||||
active_jobs: int
|
||||
completed_today: int
|
||||
failed_today: int
|
||||
|
||||
|
||||
class SystemStatus(BaseModel):
|
||||
"""Overall system status."""
|
||||
gpu: GPUStatus
|
||||
queue: QueueStatus
|
||||
loaded_models: list[str]
|
||||
uptime_seconds: float
|
||||
|
||||
|
||||
# Pagination
|
||||
|
||||
class PaginatedResponse(BaseModel):
    """Paginated list response."""
    items: list[Any]        # page contents (shape depends on the endpoint)
    total: int              # total matching items across all pages
    page: int               # current page number
    page_size: int          # items per page
    pages: int              # total page count
|
||||
13
src/api/routes/__init__.py
Normal file
13
src/api/routes/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""API route modules."""
|
||||
|
||||
from src.api.routes.generation import router as generation_router
|
||||
from src.api.routes.projects import router as projects_router
|
||||
from src.api.routes.models import router as models_router
|
||||
from src.api.routes.system import router as system_router
|
||||
|
||||
__all__ = [
|
||||
"generation_router",
|
||||
"projects_router",
|
||||
"models_router",
|
||||
"system_router",
|
||||
]
|
||||
234
src/api/routes/generation.py
Normal file
234
src/api/routes/generation.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""Generation API endpoints."""
|
||||
|
||||
from typing import Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, status
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from src.api.auth import verify_api_key_dependency
|
||||
from src.api.models import (
|
||||
GenerationRequest,
|
||||
BatchGenerationRequest,
|
||||
GenerationResult,
|
||||
JobResponse,
|
||||
JobStatus,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/generate", tags=["generation"])
|
||||
|
||||
|
||||
# Service dependencies (injected at app startup)
|
||||
_generation_service = None
|
||||
_batch_processor = None
|
||||
|
||||
|
||||
def set_services(generation_service: Any, batch_processor: Any) -> None:
    """Set service dependencies.

    Route handlers read these module-level globals instead of using
    FastAPI's DI, so the app factory injects concrete services at startup.

    Args:
        generation_service: Service used by the synchronous generate endpoint
        batch_processor: Queue used by the async/job endpoints
    """
    global _generation_service, _batch_processor
    _generation_service = generation_service
    _batch_processor = batch_processor
|
||||
|
||||
|
||||
@router.post(
    "/",
    response_model=GenerationResult,
    summary="Generate audio synchronously",
    description="Generate audio and wait for completion. For long generations, consider using the async endpoint.",
)
async def generate_sync(
    request: GenerationRequest,
    api_key: str = Depends(verify_api_key_dependency),
) -> GenerationResult:
    """Generate audio synchronously and return the finished result."""
    # Service is injected via set_services() at app startup; 503 signals the
    # app was constructed without a generation backend.
    if _generation_service is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Generation service not available",
        )

    try:
        result, generation = await _generation_service.generate(
            model_id=request.model.value,
            variant=request.variant,
            prompts=request.prompts,
            duration=request.duration,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            cfg_coef=request.cfg_coef,
            seed=request.seed,
            conditioning=request.conditioning,
            project_id=request.project_id,
        )

        # The response exposes download URLs rather than filesystem paths.
        return GenerationResult(
            id=generation.id,
            audio_url=f"/api/v1/audio/{generation.id}",
            waveform_url=f"/api/v1/audio/{generation.id}/waveform" if generation.waveform_path else None,
            duration=result.duration,
            seed=result.seed,
            model=request.model.value,
            variant=request.variant,
            prompt=request.prompts[0],
            created_at=generation.created_at,
        )

    except Exception as e:
        # NOTE(review): every failure is mapped to a 500 — consider letting
        # HTTPException propagate unchanged if the service can raise it.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e),
        )
|
||||
|
||||
|
||||
@router.post(
    "/async",
    response_model=JobResponse,
    summary="Queue generation job",
    description="Add a generation to the queue for async processing.",
)
async def generate_async(
    request: BatchGenerationRequest,
    api_key: str = Depends(verify_api_key_dependency),
) -> JobResponse:
    """Add generation to queue.

    Returns a job handle immediately; poll /generate/jobs/{job_id} for status.
    """
    if _batch_processor is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Batch processor not available",
        )

    try:
        # Unpack the nested GenerationRequest into the processor's flat signature.
        job = _batch_processor.add_job(
            model_id=request.request.model.value,
            variant=request.request.variant,
            prompts=request.request.prompts,
            duration=request.request.duration,
            temperature=request.request.temperature,
            top_k=request.request.top_k,
            top_p=request.request.top_p,
            cfg_coef=request.request.cfg_coef,
            seed=request.request.seed,
            conditioning=request.request.conditioning,
            project_id=request.request.project_id,
            priority=request.priority,
        )

        return JobResponse(
            job_id=job.id,
            status=JobStatus.PENDING,
            position=_batch_processor.get_position(job.id),
        )

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e),
        )
|
||||
|
||||
|
||||
@router.get(
    "/jobs/{job_id}",
    response_model=JobResponse,
    summary="Get job status",
    description="Check the status of a queued generation job.",
)
async def get_job_status(
    job_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> JobResponse:
    """Get status of a queued job."""
    if _batch_processor is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Batch processor not available",
        )

    job = _batch_processor.get_job(job_id)
    if job is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Job {job_id} not found",
        )

    # Build the base response first, then fill in the state-specific
    # optional fields (position/progress/result/error) below.
    response = JobResponse(
        job_id=job.id,
        status=JobStatus(job.status.value),
    )

    if job.status.value == "pending":
        response.position = _batch_processor.get_position(job_id)
    elif job.status.value == "running":
        response.progress = job.progress
    elif job.status.value == "completed" and job.result:
        # The job's completion time is reported as the result's timestamp.
        response.result = GenerationResult(
            id=job.result.id,
            audio_url=f"/api/v1/audio/{job.result.id}",
            waveform_url=f"/api/v1/audio/{job.result.id}/waveform",
            duration=job.result.duration,
            seed=job.result.seed,
            model=job.model_id,
            variant=job.variant,
            prompt=job.prompts[0],
            created_at=job.completed_at,
        )
    elif job.status.value == "failed":
        response.error = job.error

    return response
|
||||
|
||||
|
||||
@router.delete(
    "/jobs/{job_id}",
    summary="Cancel job",
    description="Cancel a pending or running job.",
)
async def cancel_job(
    job_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> dict:
    """Cancel a pending or running generation job.

    Raises:
        HTTPException: 503 if the queue is unavailable; 404 if the job does
            not exist or can no longer be cancelled.
    """
    if _batch_processor is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Batch processor not available",
        )

    cancelled = _batch_processor.cancel_job(job_id)
    if not cancelled:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Job {job_id} not found or cannot be cancelled",
        )

    return {"message": f"Job {job_id} cancelled"}
|
||||
|
||||
|
||||
@router.get(
    "/jobs",
    response_model=list[JobResponse],
    summary="List jobs",
    description="List all jobs in the queue.",
)
async def list_jobs(
    status_filter: str | None = None,
    limit: int = 50,
    api_key: str = Depends(verify_api_key_dependency),
) -> list[JobResponse]:
    """List queued jobs.

    Args:
        status_filter: Only return jobs in this state (e.g. "pending");
            None returns all jobs. (Annotation fixed: the default was None
            but the parameter was annotated as plain ``str``.)
        limit: Maximum number of jobs to return.

    Returns:
        Job summaries; queue position is only set for pending jobs and
        progress only for running jobs.
    """
    if _batch_processor is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Batch processor not available",
        )

    jobs = _batch_processor.list_jobs(status_filter=status_filter, limit=limit)

    return [
        JobResponse(
            job_id=job.id,
            status=JobStatus(job.status.value),
            position=_batch_processor.get_position(job.id) if job.status.value == "pending" else None,
            progress=job.progress if job.status.value == "running" else None,
        )
        for job in jobs
    ]
|
||||
228
src/api/routes/models.py
Normal file
228
src/api/routes/models.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""Models API endpoints."""
|
||||
|
||||
from typing import Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from src.api.auth import verify_api_key_dependency
|
||||
from src.api.models import ModelInfo, ModelVariantInfo
|
||||
|
||||
|
||||
router = APIRouter(prefix="/models", tags=["models"])


# Service dependency (injected at app startup)
# Remains None until set_services() runs; handlers treat None as
# "registry unavailable" and either 503 or fall back to static data.
_model_registry = None


def set_services(model_registry: Any) -> None:
    """Set service dependencies.

    Called once during application startup to inject the model registry
    used by the route handlers in this module.
    """
    global _model_registry
    _model_registry = model_registry
|
||||
|
||||
|
||||
# Static model information
# Catalog of every model family exposed by the API, keyed by model id.
# Each entry mirrors the ModelInfo schema (id / name / description) and its
# "variants" dicts unpack directly into ModelVariantInfo
# (id, name, vram_mb, description, capabilities).
# NOTE(review): the vram_mb figures read like planning estimates — confirm
# against measured usage before relying on them for VRAM budgeting.
MODEL_CATALOG = {
    "musicgen": {
        "id": "musicgen",
        "name": "MusicGen",
        "description": "Text-to-music generation with optional melody conditioning",
        "variants": [
            {"id": "small", "name": "Small", "vram_mb": 1500, "description": "Fast, 300M params", "capabilities": ["text"]},
            {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "Balanced, 1.5B params", "capabilities": ["text"]},
            {"id": "large", "name": "Large", "vram_mb": 10000, "description": "Best quality, 3.3B params", "capabilities": ["text"]},
            {"id": "melody", "name": "Melody", "vram_mb": 5000, "description": "With melody conditioning", "capabilities": ["text", "melody"]},
            {"id": "stereo-small", "name": "Stereo Small", "vram_mb": 1800, "description": "Stereo, 300M params", "capabilities": ["text", "stereo"]},
            {"id": "stereo-medium", "name": "Stereo Medium", "vram_mb": 6000, "description": "Stereo, 1.5B params", "capabilities": ["text", "stereo"]},
            {"id": "stereo-large", "name": "Stereo Large", "vram_mb": 12000, "description": "Stereo, 3.3B params", "capabilities": ["text", "stereo"]},
            {"id": "stereo-melody", "name": "Stereo Melody", "vram_mb": 6000, "description": "Stereo with melody", "capabilities": ["text", "melody", "stereo"]},
        ],
    },
    "audiogen": {
        "id": "audiogen",
        "name": "AudioGen",
        "description": "Text-to-sound effects and environmental audio",
        "variants": [
            {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "1.5B params", "capabilities": ["text", "sfx"]},
        ],
    },
    "magnet": {
        "id": "magnet",
        "name": "MAGNeT",
        "description": "Fast non-autoregressive music generation",
        "variants": [
            {"id": "small", "name": "Small Music", "vram_mb": 2000, "description": "Fast music, 300M params", "capabilities": ["text", "music"]},
            {"id": "medium", "name": "Medium Music", "vram_mb": 5000, "description": "Balanced music, 1.5B params", "capabilities": ["text", "music"]},
            {"id": "audio-small", "name": "Small Audio", "vram_mb": 2000, "description": "Fast sound effects", "capabilities": ["text", "sfx"]},
            {"id": "audio-medium", "name": "Medium Audio", "vram_mb": 5000, "description": "Balanced sound effects", "capabilities": ["text", "sfx"]},
        ],
    },
    "musicgen-style": {
        "id": "musicgen-style",
        "name": "MusicGen Style",
        "description": "Style-conditioned music from reference audio",
        "variants": [
            {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "1.5B params, style conditioning", "capabilities": ["text", "style"]},
        ],
    },
    "jasco": {
        "id": "jasco",
        "name": "JASCO",
        "description": "Chord and drum-conditioned music generation",
        "variants": [
            {"id": "chords", "name": "Chords", "vram_mb": 5000, "description": "Chord-conditioned generation", "capabilities": ["text", "chords"]},
            {"id": "chords-drums", "name": "Chords + Drums", "vram_mb": 5500, "description": "Full symbolic conditioning", "capabilities": ["text", "chords", "drums"]},
        ],
    },
}
|
||||
|
||||
|
||||
@router.get(
    "/",
    response_model=list[ModelInfo],
    summary="List models",
    description="Get information about all available models.",
)
async def list_models(
    api_key: str = Depends(verify_api_key_dependency),
) -> list[ModelInfo]:
    """Return catalog info for every model, annotated with live load state.

    When the model registry has not been injected yet, every model is
    reported as not loaded rather than failing the request.
    """

    def _load_state(model_id: str):
        # Live state is optional: the registry may be absent at startup.
        if not _model_registry:
            return False, None
        is_loaded = _model_registry.is_loaded(model_id)
        variant = _model_registry.get_current_variant(model_id) if is_loaded else None
        return is_loaded, variant

    result: list[ModelInfo] = []
    for model_id, entry in MODEL_CATALOG.items():
        is_loaded, variant = _load_state(model_id)
        result.append(
            ModelInfo(
                id=entry["id"],
                name=entry["name"],
                description=entry["description"],
                variants=[ModelVariantInfo(**spec) for spec in entry["variants"]],
                loaded=is_loaded,
                current_variant=variant,
            )
        )
    return result
|
||||
|
||||
|
||||
@router.get(
    "/{model_id}",
    response_model=ModelInfo,
    summary="Get model info",
    description="Get detailed information about a specific model.",
)
async def get_model(
    model_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> ModelInfo:
    """Get model information by ID.

    Args:
        model_id: Catalog key (e.g. "musicgen").
        api_key: Validated by the auth dependency.

    Returns:
        Static catalog info plus live load state from the registry.

    Raises:
        HTTPException: 404 if the model id is not in MODEL_CATALOG.
    """
    # NOTE(review): this wildcard path is registered before the literal
    # "/loaded" route later in the module, so GET /models/loaded is captured
    # here and 404s — "/loaded" should be declared before "/{model_id}".
    if model_id not in MODEL_CATALOG:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model {model_id} not found",
        )

    info = MODEL_CATALOG[model_id]
    loaded = False
    current_variant = None

    # Live state is optional: report "not loaded" if the registry is absent.
    if _model_registry:
        loaded = _model_registry.is_loaded(model_id)
        if loaded:
            current_variant = _model_registry.get_current_variant(model_id)

    return ModelInfo(
        id=info["id"],
        name=info["name"],
        description=info["description"],
        variants=[ModelVariantInfo(**v) for v in info["variants"]],
        loaded=loaded,
        current_variant=current_variant,
    )
|
||||
|
||||
|
||||
@router.post(
    "/{model_id}/load",
    summary="Load model",
    description="Load a model into GPU memory.",
)
async def load_model(
    model_id: str,
    variant: str = "medium",
    api_key: str = Depends(verify_api_key_dependency),
) -> dict:
    """Load a model into memory.

    Args:
        model_id: Catalog key (e.g. "musicgen").
        variant: Variant id; must exist in the model's catalog entry.
        api_key: Validated by the auth dependency.

    Raises:
        HTTPException: 404 for an unknown model, 400 for an unknown
            variant, 503 when the registry is unavailable, 500 when the
            registry fails to load the model.
    """
    if model_id not in MODEL_CATALOG:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model {model_id} not found",
        )

    # Reject unknown variants up front with a clear 400 rather than letting
    # the registry fail later and surface as an opaque 500.
    valid_variants = {v["id"] for v in MODEL_CATALOG[model_id]["variants"]}
    if variant not in valid_variants:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unknown variant {variant} for model {model_id}; valid: {sorted(valid_variants)}",
        )

    if _model_registry is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model registry not available",
        )

    try:
        await _model_registry.load_model(model_id, variant)
        return {"message": f"Model {model_id} ({variant}) loaded successfully"}
    except Exception as e:
        # Chain the original exception so logging middleware sees the cause.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e),
        ) from e
|
||||
|
||||
|
||||
@router.post(
    "/{model_id}/unload",
    summary="Unload model",
    description="Unload a model from GPU memory.",
)
async def unload_model(
    model_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> dict:
    """Unload a model from memory.

    Answers 404 for unknown model ids, 503 when the registry has not been
    injected, and 500 (with the error text) when the registry call fails.
    """
    if model_id not in MODEL_CATALOG:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Model {model_id} not found",
        )
    if _model_registry is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model registry not available",
        )

    try:
        await _model_registry.unload_model(model_id)
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        )
    return {"message": f"Model {model_id} unloaded successfully"}
|
||||
|
||||
|
||||
@router.get(
    "/loaded",
    response_model=list[str],
    summary="List loaded models",
    description="Get list of currently loaded models.",
)
async def list_loaded_models(
    api_key: str = Depends(verify_api_key_dependency),
) -> list[str]:
    """List currently loaded models.

    Returns an empty list (rather than 503) when the registry has not been
    injected, so this endpoint is safe to poll during startup.

    NOTE(review): BUG — this route is registered AFTER "/{model_id}", so
    FastAPI matches GET /models/loaded against get_model first and returns
    404 ("Model loaded not found"). Declare this route before the wildcard
    one to make it reachable.
    """
    if _model_registry is None:
        return []

    return _model_registry.get_loaded_models()
|
||||
250
src/api/routes/projects.py
Normal file
250
src/api/routes/projects.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""Projects API endpoints."""
|
||||
|
||||
from typing import Any, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from src.api.auth import verify_api_key_dependency
|
||||
from src.api.models import (
|
||||
ProjectCreate,
|
||||
ProjectResponse,
|
||||
GenerationResponse,
|
||||
PaginatedResponse,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/projects", tags=["projects"])


# Service dependency (injected at app startup)
# None until set_services() runs; handlers answer 503 in the meantime.
_project_service = None


def set_services(project_service: Any) -> None:
    """Set service dependencies.

    Called once during application startup to inject the project service
    used by the route handlers in this module.
    """
    global _project_service
    _project_service = project_service
|
||||
|
||||
|
||||
@router.post(
    "/",
    response_model=ProjectResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Create project",
    description="Create a new project for organizing generations.",
)
async def create_project(
    request: ProjectCreate,
    api_key: str = Depends(verify_api_key_dependency),
) -> ProjectResponse:
    """Create a new project from the submitted name and description."""
    if _project_service is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Project service not available",
        )

    try:
        created = await _project_service.create_project(
            name=request.name,
            description=request.description,
        )
        return ProjectResponse(**created)
    except Exception as exc:
        # Surface service failures as a 500 with the error text as detail.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        )
|
||||
|
||||
|
||||
@router.get(
    "/",
    response_model=list[ProjectResponse],
    summary="List projects",
    description="Get all projects.",
)
async def list_projects(
    api_key: str = Depends(verify_api_key_dependency),
) -> list[ProjectResponse]:
    """Return every project known to the project service."""
    if _project_service is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Project service not available",
        )

    try:
        records = await _project_service.list_projects()
        return [ProjectResponse(**record) for record in records]
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        )
|
||||
|
||||
|
||||
@router.get(
    "/{project_id}",
    response_model=ProjectResponse,
    summary="Get project",
    description="Get project details by ID.",
)
async def get_project(
    project_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> ProjectResponse:
    """Get a project by ID.

    Raises:
        HTTPException: 503 when the service is not injected, 404 when the
            project does not exist, 500 for any other service failure.
    """
    if _project_service is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Project service not available",
        )

    try:
        project = await _project_service.get_project(project_id)
        if project is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Project {project_id} not found",
            )
        return ProjectResponse(**project)
    except HTTPException:
        # Re-raise the 404 above unchanged instead of wrapping it in a 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e),
        )
|
||||
|
||||
|
||||
@router.delete(
    "/{project_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete project",
    description="Delete a project and all its generations.",
)
async def delete_project(
    project_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> None:
    """Delete a project.

    NOTE(review): responds 204 regardless of whether the project existed,
    as long as the service call itself succeeds — confirm that is intended.
    """
    if _project_service is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Project service not available",
        )

    try:
        await _project_service.delete_project(project_id)
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        )
|
||||
|
||||
|
||||
@router.get(
    "/{project_id}/generations",
    response_model=PaginatedResponse,
    summary="List generations",
    description="Get generations for a project with pagination.",
)
async def list_generations(
    project_id: str,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    model: Optional[str] = Query(None, description="Filter by model"),
    api_key: str = Depends(verify_api_key_dependency),
) -> PaginatedResponse:
    """List generations for a project.

    Uses limit/offset pagination and fetches one extra row to detect
    whether a further page exists.

    Args:
        project_id: Project to list generations for.
        page: 1-based page number.
        page_size: Rows per page (capped at 100 by the Query validator).
        model: Optional model-id filter passed to the service.
    """
    if _project_service is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Project service not available",
        )

    try:
        offset = (page - 1) * page_size
        generations = await _project_service.list_generations(
            project_id=project_id,
            limit=page_size + 1,  # +1 to check if more pages
            offset=offset,
            model_filter=model,
        )

        # The extra row only signals "has more"; drop it from the output.
        has_more = len(generations) > page_size
        generations = generations[:page_size]

        # Estimate total (could be improved with actual count query)
        # NOTE(review): this undercounts whenever more than one page remains,
        # so both "total" and "pages" are lower bounds, not exact figures.
        total = offset + len(generations) + (1 if has_more else 0)
        pages = (total + page_size - 1) // page_size  # ceiling division

        return PaginatedResponse(
            items=[GenerationResponse(**g) for g in generations],
            total=total,
            page=page,
            page_size=page_size,
            pages=pages,
        )
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e),
        )
|
||||
|
||||
|
||||
@router.get(
    "/{project_id}/export",
    summary="Export project",
    description="Export project as ZIP file with all audio and metadata.",
)
async def export_project(
    project_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> FileResponse:
    """Stream a ZIP archive of the project's audio and metadata."""
    if _project_service is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Project service not available",
        )

    try:
        # The service builds the archive on disk and returns its path.
        archive = await _project_service.export_project_zip(project_id)
        return FileResponse(
            path=archive,
            filename=f"project_{project_id}.zip",
            media_type="application/zip",
        )
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        )
|
||||
|
||||
|
||||
@router.delete(
    "/{project_id}/generations/{generation_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete generation",
    description="Delete a specific generation.",
)
async def delete_generation(
    project_id: str,
    generation_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> None:
    """Delete a generation.

    NOTE(review): project_id is accepted for URL symmetry but the service
    call uses only generation_id — ownership is not verified here.
    """
    if _project_service is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Project service not available",
        )

    try:
        await _project_service.delete_generation(generation_id)
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        )
|
||||
263
src/api/routes/system.py
Normal file
263
src/api/routes/system.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""System API endpoints."""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from src.api.auth import verify_api_key_dependency, get_key_manager
|
||||
from src.api.models import GPUStatus, QueueStatus, SystemStatus
|
||||
|
||||
|
||||
router = APIRouter(prefix="/system", tags=["system"])


# Service dependencies (injected at app startup)
# All three stay None until set_services() runs; handlers either answer
# 503 or return zeroed placeholder data while they are missing.
_gpu_manager = None
_batch_processor = None
_model_registry = None
# Captured at module import time; used as the baseline for uptime reporting.
_start_time = time.time()


def set_services(
    gpu_manager: Any,
    batch_processor: Any,
    model_registry: Any,
) -> None:
    """Set service dependencies.

    Called once during application startup to inject the GPU manager,
    batch processor, and model registry used by the handlers below.
    """
    global _gpu_manager, _batch_processor, _model_registry
    _gpu_manager = gpu_manager
    _batch_processor = batch_processor
    _model_registry = model_registry
|
||||
|
||||
|
||||
@router.get(
    "/status",
    response_model=SystemStatus,
    summary="System status",
    description="Get overall system status including GPU, queue, and loaded models.",
)
async def get_status(
    api_key: str = Depends(verify_api_key_dependency),
) -> SystemStatus:
    """Get system status.

    Aggregates GPU memory/utilization, batch-queue counters, the list of
    loaded models, and process uptime. Every service is optional: when a
    dependency has not been injected yet, a zeroed placeholder section is
    returned instead of failing the request.
    """
    # GPU status
    if _gpu_manager:
        gpu = GPUStatus(
            device_name=_gpu_manager.device_name,
            # 1024**3 divisor: converts to GiB (assumes the manager reports bytes).
            total_gb=_gpu_manager.total_memory / 1024**3,
            used_gb=_gpu_manager.get_used_memory() / 1024**3,
            available_gb=_gpu_manager.get_available_memory() / 1024**3,
            utilization_percent=_gpu_manager.get_utilization(),
            temperature_c=_gpu_manager.get_temperature(),
        )
    else:
        # No GPU manager yet: report an empty placeholder.
        gpu = GPUStatus(
            device_name="Unknown",
            total_gb=0,
            used_gb=0,
            available_gb=0,
            utilization_percent=0,
            temperature_c=None,
        )

    # Queue status
    # NOTE(review): the processor's *_count attributes are mapped onto
    # *_today fields — confirm the processor actually resets them daily.
    if _batch_processor:
        queue = QueueStatus(
            queue_size=len(_batch_processor.queue),
            active_jobs=_batch_processor.active_count,
            completed_today=_batch_processor.completed_count,
            failed_today=_batch_processor.failed_count,
        )
    else:
        queue = QueueStatus(
            queue_size=0,
            active_jobs=0,
            completed_today=0,
            failed_today=0,
        )

    # Loaded models
    loaded_models = []
    if _model_registry:
        loaded_models = _model_registry.get_loaded_models()

    return SystemStatus(
        gpu=gpu,
        queue=queue,
        loaded_models=loaded_models,
        uptime_seconds=time.time() - _start_time,
    )
|
||||
|
||||
|
||||
@router.get(
    "/gpu",
    response_model=GPUStatus,
    summary="GPU status",
    description="Get detailed GPU memory and utilization status.",
)
async def get_gpu_status(
    api_key: str = Depends(verify_api_key_dependency),
) -> GPUStatus:
    """Report GPU memory and utilization from the injected GPU manager."""
    if _gpu_manager is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="GPU manager not available",
        )

    gib = 1024 ** 3  # divisor for byte -> GiB conversion
    return GPUStatus(
        device_name=_gpu_manager.device_name,
        total_gb=_gpu_manager.total_memory / gib,
        used_gb=_gpu_manager.get_used_memory() / gib,
        available_gb=_gpu_manager.get_available_memory() / gib,
        utilization_percent=_gpu_manager.get_utilization(),
        temperature_c=_gpu_manager.get_temperature(),
    )
|
||||
|
||||
|
||||
@router.post(
    "/clear-cache",
    summary="Clear cache",
    description="Clear model cache and free GPU memory.",
)
async def clear_cache(
    api_key: str = Depends(verify_api_key_dependency),
) -> dict:
    """Ask the model registry to drop its cache and free GPU memory."""
    if _model_registry is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model registry not available",
        )

    try:
        _model_registry.clear_cache()
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        )
    return {"message": "Cache cleared successfully"}
|
||||
|
||||
|
||||
@router.post(
    "/unload-all",
    summary="Unload all models",
    description="Unload all models from GPU memory.",
)
async def unload_all_models(
    api_key: str = Depends(verify_api_key_dependency),
) -> dict:
    """Unload every loaded model via the registry."""
    if _model_registry is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Model registry not available",
        )

    try:
        await _model_registry.unload_all()
    except Exception as exc:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(exc),
        )
    return {"message": "All models unloaded successfully"}
|
||||
|
||||
|
||||
@router.get(
    "/health",
    summary="Health check",
    description="Simple health check endpoint.",
)
async def health_check() -> dict:
    """Health check endpoint (no auth required)."""
    uptime = time.time() - _start_time
    return {"status": "healthy", "uptime_seconds": uptime}
|
||||
|
||||
|
||||
@router.post(
    "/api-key/regenerate",
    summary="Regenerate API key",
    description="Generate a new API key. The old key will be invalidated.",
)
async def regenerate_api_key(
    api_key: str = Depends(verify_api_key_dependency),
) -> dict:
    """Mint a replacement API key (caller must hold a currently-valid key)."""
    new_key = get_key_manager().generate_new_key()
    return {
        "api_key": new_key,
        "message": "New API key generated. Store it securely - it won't be shown again.",
    }
|
||||
|
||||
|
||||
@router.get(
    "/audio/{generation_id}",
    summary="Download audio",
    description="Download generated audio file.",
)
async def download_audio(
    generation_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> FileResponse:
    """Download the audio file for a generation.

    The file path is not yet looked up in the database; instead the
    expected output locations are probed in order of preference
    (wav, then mp3, then flac).

    Raises:
        HTTPException: 404 when no audio file exists for the id.
    """
    from config.settings import get_settings
    settings = get_settings()

    audio_dir = Path(settings.output_dir)

    # Correct IANA MIME types: .mp3 is "audio/mpeg" (the old f-string built
    # the non-standard "audio/mp3"). Dict order also fixes probe priority.
    media_types = {
        ".wav": "audio/wav",
        ".mp3": "audio/mpeg",
        ".flac": "audio/flac",
    }

    for suffix, media_type in media_types.items():
        path = audio_dir / f"{generation_id}{suffix}"
        if path.exists():
            return FileResponse(
                path=path,
                filename=path.name,
                media_type=media_type,
            )

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Audio file for generation {generation_id} not found",
    )
|
||||
|
||||
|
||||
@router.get(
    "/audio/{generation_id}/waveform",
    summary="Download waveform",
    description="Download waveform visualization image.",
)
async def download_waveform(
    generation_id: str,
    api_key: str = Depends(verify_api_key_dependency),
) -> FileResponse:
    """Serve the pre-rendered waveform PNG for a generation."""
    from config.settings import get_settings
    settings = get_settings()

    image_path = Path(settings.output_dir) / f"{generation_id}_waveform.png"

    if not image_path.exists():
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Waveform for generation {generation_id} not found",
        )

    return FileResponse(
        path=image_path,
        filename=image_path.name,
        media_type="image/png",
    )
|
||||
24
src/core/__init__.py
Normal file
24
src/core/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Core infrastructure for AudioCraft Studio."""
|
||||
|
||||
from src.core.base_model import (
|
||||
BaseAudioModel,
|
||||
GenerationRequest,
|
||||
GenerationResult,
|
||||
ConditioningType,
|
||||
)
|
||||
from src.core.gpu_manager import GPUMemoryManager, VRAMBudget
|
||||
from src.core.model_registry import ModelRegistry
|
||||
from src.core.oom_handler import OOMHandler, OOMRecoveryError, oom_safe
|
||||
|
||||
__all__ = [
|
||||
"BaseAudioModel",
|
||||
"GenerationRequest",
|
||||
"GenerationResult",
|
||||
"ConditioningType",
|
||||
"GPUMemoryManager",
|
||||
"VRAMBudget",
|
||||
"ModelRegistry",
|
||||
"OOMHandler",
|
||||
"OOMRecoveryError",
|
||||
"oom_safe",
|
||||
]
|
||||
535
src/core/audio_utils.py
Normal file
535
src/core/audio_utils.py
Normal file
@@ -0,0 +1,535 @@
|
||||
"""Audio utilities for processing, visualization, and export."""
|
||||
|
||||
import io
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def normalize_audio(
    audio: Union[torch.Tensor, np.ndarray],
    target_db: float = -14.0,
    peak_normalize: bool = False,
) -> np.ndarray:
    """Normalize audio to target loudness.

    Args:
        audio: Audio tensor or array [channels, samples] or [samples]
            (a leading batch dimension is collapsed to the first item).
        target_db: Target loudness in dB (LUFS-like)
        peak_normalize: If True, normalize to peak instead of RMS

    Returns:
        Normalized float32 audio as numpy array, hard-limited to [-1, 1].
    """
    if isinstance(audio, torch.Tensor):
        # .numpy() raises for CUDA or grad-tracking tensors; detach and
        # move to CPU first so any tensor is accepted.
        audio = audio.detach().cpu().numpy()

    # Ensure float32
    audio = audio.astype(np.float32)

    # Handle batch dimension
    if audio.ndim == 3:
        audio = audio[0]  # Take first sample if batched

    if peak_normalize:
        # Peak normalization: scale so the absolute peak hits target_db.
        peak = np.abs(audio).max()
        if peak > 0:
            target_linear = 10 ** (target_db / 20)
            audio = audio * (target_linear / peak)
    else:
        # RMS normalization (approximating LUFS): scale so the RMS level
        # hits target_db. Silent input (rms == 0) is returned untouched.
        rms = np.sqrt(np.mean(audio ** 2))
        if rms > 0:
            target_rms = 10 ** (target_db / 20)
            audio = audio * (target_rms / rms)

    # Hard-limit so the gain applied above cannot push samples past full scale.
    audio = np.clip(audio, -1.0, 1.0)

    return audio
|
||||
|
||||
|
||||
def convert_sample_rate(
    audio: np.ndarray,
    orig_sr: int,
    target_sr: int,
) -> np.ndarray:
    """Convert audio sample rate.

    Prefers librosa's resampler; falls back to scipy's FFT-based
    ``signal.resample`` when librosa is not installed.

    Args:
        audio: Audio array [channels, samples] or [samples]
        orig_sr: Original sample rate
        target_sr: Target sample rate

    Returns:
        Resampled audio (the input array is returned unchanged when the
        rates already match).
    """
    # Fast path: nothing to do.
    if orig_sr == target_sr:
        return audio

    try:
        import librosa

        # Handle multi-channel: resample each channel independently since
        # the call below operates on 1-D signals.
        if audio.ndim == 2:
            resampled = np.array([
                librosa.resample(ch, orig_sr=orig_sr, target_sr=target_sr)
                for ch in audio
            ])
        else:
            resampled = librosa.resample(audio, orig_sr=orig_sr, target_sr=target_sr)

        return resampled

    except ImportError:
        logger.warning("librosa not available, using scipy for resampling")
        from scipy import signal

        # NOTE(review): scipy.signal.resample is FFT-based and assumes a
        # periodic signal, so edge artifacts may differ from librosa output.
        ratio = target_sr / orig_sr
        new_length = int(audio.shape[-1] * ratio)

        if audio.ndim == 2:
            resampled = np.array([
                signal.resample(ch, new_length) for ch in audio
            ])
        else:
            resampled = signal.resample(audio, new_length)

        return resampled
|
||||
|
||||
|
||||
def generate_waveform(
    audio: Union[torch.Tensor, np.ndarray],
    sample_rate: int,
    width: int = 800,
    height: int = 200,
    color: str = "#3b82f6",
    background: str = "#1f2937",
) -> bytes:
    """Generate waveform image as PNG bytes.

    Renders a min/max amplitude envelope (one (min, max) pair per
    horizontal pixel) using matplotlib's non-interactive Agg backend.

    Args:
        audio: Audio data [channels, samples] or [samples]
        sample_rate: Sample rate in Hz (accepted for API symmetry; not
            used by the rendering below).
        width: Image width in pixels
        height: Image height in pixels
        color: Waveform color (hex)
        background: Background color (hex)

    Returns:
        PNG image as bytes; empty bytes when matplotlib is unavailable.
    """
    try:
        import matplotlib
        matplotlib.use('Agg')  # headless backend: safe without a display
        import matplotlib.pyplot as plt
    except ImportError:
        logger.warning("matplotlib not available for waveform generation")
        return b""

    if isinstance(audio, torch.Tensor):
        # NOTE(review): .numpy() fails for CUDA or grad-tracking tensors —
        # confirm callers always pass detached CPU tensors here.
        audio = audio.numpy()

    # Handle dimensions: batched input -> first item; multi-channel -> mono.
    if audio.ndim == 3:
        audio = audio[0]
    if audio.ndim == 2:
        audio = audio.mean(axis=0)  # Mix to mono for visualization

    # Downsample for visualization: chunk the signal so each chunk maps to
    # roughly one pixel column.
    samples_per_pixel = max(1, len(audio) // width)
    num_chunks = len(audio) // samples_per_pixel

    if num_chunks > 0:
        audio_chunks = audio[:num_chunks * samples_per_pixel].reshape(
            num_chunks, samples_per_pixel
        )
        # Get min/max for each chunk
        mins = audio_chunks.min(axis=1)
        maxs = audio_chunks.max(axis=1)
    else:
        # Degenerate case (fewer samples than pixels): plot raw samples.
        mins = maxs = audio

    # Create figure sized in inches for a width x height pixel canvas at 100 dpi.
    fig, ax = plt.subplots(figsize=(width / 100, height / 100), dpi=100)
    fig.patch.set_facecolor(background)
    ax.set_facecolor(background)

    # Plot waveform as a filled envelope plus a faint zero reference line.
    x = np.arange(len(mins))
    ax.fill_between(x, mins, maxs, color=color, alpha=0.7)
    ax.axhline(y=0, color=color, alpha=0.3, linewidth=0.5)

    # Style: fixed [-1, 1] amplitude range, no axis chrome.
    ax.set_xlim(0, len(mins))
    ax.set_ylim(-1, 1)
    ax.axis('off')
    plt.tight_layout(pad=0)

    # Save to bytes and close the figure to avoid leaking Agg canvases.
    buf = io.BytesIO()
    fig.savefig(buf, format='png', facecolor=background, edgecolor='none')
    plt.close(fig)
    buf.seek(0)

    return buf.read()
|
||||
|
||||
|
||||
def generate_spectrogram(
    audio: Union[torch.Tensor, np.ndarray],
    sample_rate: int,
    width: int = 800,
    height: int = 200,
    colormap: str = "magma",
) -> bytes:
    """Generate spectrogram image as PNG bytes.

    Computes a 128-band mel spectrogram with librosa and renders it via
    matplotlib's headless Agg backend.

    Args:
        audio: Audio data
        sample_rate: Sample rate in Hz
        width: Image width
        height: Image height
        colormap: Matplotlib colormap name

    Returns:
        PNG image as bytes; empty bytes when matplotlib/librosa are missing.
    """
    try:
        import matplotlib
        matplotlib.use('Agg')  # headless backend: safe without a display
        import matplotlib.pyplot as plt
        import librosa
        import librosa.display
    except ImportError:
        logger.warning("matplotlib/librosa not available for spectrogram")
        return b""

    if isinstance(audio, torch.Tensor):
        # NOTE(review): .numpy() fails for CUDA or grad-tracking tensors —
        # confirm callers always pass detached CPU tensors here.
        audio = audio.numpy()

    # Handle dimensions: batched input -> first item; multi-channel -> mono.
    if audio.ndim == 3:
        audio = audio[0]
    if audio.ndim == 2:
        audio = audio.mean(axis=0)

    # Compute mel spectrogram up to the Nyquist frequency.
    S = librosa.feature.melspectrogram(
        y=audio,
        sr=sample_rate,
        n_mels=128,
        fmax=sample_rate // 2,
    )
    # Convert power to dB relative to the loudest bin.
    S_db = librosa.power_to_db(S, ref=np.max)

    # Create figure sized in inches for a width x height pixel canvas at 100 dpi.
    fig, ax = plt.subplots(figsize=(width / 100, height / 100), dpi=100)

    librosa.display.specshow(
        S_db,
        sr=sample_rate,
        x_axis='time',
        y_axis='mel',
        cmap=colormap,
        ax=ax,
    )

    # No axis chrome — output is a bare image.
    ax.axis('off')
    plt.tight_layout(pad=0)

    # Save to bytes and close the figure to avoid leaking Agg canvases.
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    buf.seek(0)

    return buf.read()
|
||||
|
||||
|
||||
def save_audio(
    audio: Union[torch.Tensor, np.ndarray],
    sample_rate: int,
    path: Path,
    format: str = "wav",
    normalize: bool = True,
    target_db: float = -14.0,
) -> Path:
    """Save audio to file with optional loudness normalization.

    Args:
        audio: Audio data; a leading batch dimension, if present, is dropped
            (only the first item is saved)
        sample_rate: Sample rate
        path: Output path (extension will be added if needed)
        format: Output format (wav, mp3, flac, ogg); unknown values fall
            back to wav
        normalize: Whether to normalize audio
        target_db: Normalization target

    Returns:
        Path to saved file (may differ from the input path's suffix when a
        fallback format is used)
    """
    import soundfile as sf

    if isinstance(audio, torch.Tensor):
        # Bug fix: plain .numpy() raises on CUDA tensors and on tensors that
        # require grad; detach+cpu first so any tensor is accepted.
        audio = audio.detach().cpu().numpy()

    # Drop batch dimension: keep only the first item
    if audio.ndim == 3:
        audio = audio[0]

    if normalize:
        audio = normalize_audio(audio, target_db=target_db)

    # soundfile expects [samples, channels]
    if audio.ndim == 2:
        audio = audio.T

    path = Path(path)
    if not path.suffix:
        path = path.with_suffix(f".{format}")

    if format in ("wav", "flac"):
        sf.write(path, audio, sample_rate)
    elif format == "mp3":
        # soundfile cannot write MP3 directly; go via a temporary WAV and
        # convert with pydub (which shells out to ffmpeg)
        try:
            from pydub import AudioSegment

            wav_path = path.with_suffix(".wav")
            sf.write(wav_path, audio, sample_rate)

            try:
                sound = AudioSegment.from_wav(wav_path)
                sound.export(path, format="mp3", bitrate="320k")
            finally:
                # Bug fix: previously the temp WAV leaked if export raised;
                # always remove it
                wav_path.unlink(missing_ok=True)
        except ImportError:
            logger.warning("pydub not available, saving as WAV instead")
            path = path.with_suffix(".wav")
            sf.write(path, audio, sample_rate)
    elif format == "ogg":
        sf.write(path, audio, sample_rate, format="ogg", subtype="vorbis")
    else:
        # Unknown format: fall back to WAV
        path = path.with_suffix(".wav")
        sf.write(path, audio, sample_rate)

    return path
|
||||
|
||||
|
||||
def load_audio(
    path: Path,
    target_sr: Optional[int] = None,
    mono: bool = False,
) -> Tuple[np.ndarray, int]:
    """Read an audio file into a [channels, samples] array.

    Args:
        path: Path to audio file
        target_sr: Resample to this rate; None keeps the file's native rate
        mono: If True, average all channels into one

    Returns:
        Tuple of (audio_array, sample_rate)
    """
    import soundfile as sf

    data, sr = sf.read(path)

    # soundfile returns [samples] for mono or [samples, channels] otherwise;
    # normalize both shapes to [channels, samples]
    audio = data[np.newaxis, :] if data.ndim == 1 else data.T

    # Downmix by averaging channels while keeping the channel axis
    if mono and audio.shape[0] > 1:
        audio = audio.mean(axis=0, keepdims=True)

    # Resample only when a different target rate was requested
    if target_sr and target_sr != sr:
        audio = convert_sample_rate(audio, sr, target_sr)
        sr = target_sr

    return audio, sr
|
||||
|
||||
|
||||
def get_audio_info(path: Path) -> dict:
    """Read metadata for an audio file without loading its samples.

    Args:
        path: Path to audio file

    Returns:
        Dictionary with path, duration, sample_rate, channels, format,
        subtype, and frames
    """
    import soundfile as sf

    meta = sf.info(path)

    info = {"path": str(path)}
    info["duration"] = meta.duration
    info["sample_rate"] = meta.samplerate
    info["channels"] = meta.channels
    info["format"] = meta.format
    info["subtype"] = meta.subtype
    info["frames"] = meta.frames
    return info
|
||||
|
||||
|
||||
def trim_silence(
    audio: np.ndarray,
    sample_rate: int,
    threshold_db: float = -40.0,
    min_silence_ms: int = 100,
) -> np.ndarray:
    """Remove leading and trailing silence from audio.

    Args:
        audio: Audio array, [channels, samples] or [samples]
        sample_rate: Sample rate
        threshold_db: Silence threshold in dB (its magnitude is used)
        min_silence_ms: Minimum silence duration to trim; sizes the
            analysis frame

    Returns:
        Trimmed audio; returned unchanged when librosa is unavailable or
        when no non-silent region is found
    """
    try:
        import librosa
    except ImportError:
        logger.warning("librosa not available for silence trimming")
        return audio

    # Silence detection runs on a mono mixdown; the cut is then applied
    # to every channel so the output keeps the input layout.
    mono = audio.mean(axis=0) if audio.ndim == 2 else audio

    intervals = librosa.effects.split(
        mono,
        top_db=abs(threshold_db),
        frame_length=int(sample_rate * min_silence_ms / 1000),
    )

    if len(intervals) == 0:
        return audio

    # Keep everything between the first and last non-silent interval
    start, end = intervals[0][0], intervals[-1][1]

    return audio[:, start:end] if audio.ndim == 2 else audio[start:end]
|
||||
|
||||
|
||||
def apply_fade(
    audio: np.ndarray,
    sample_rate: int,
    fade_in_ms: float = 0,
    fade_out_ms: float = 0,
) -> np.ndarray:
    """Apply linear fade in/out to audio.

    Args:
        audio: Audio array [channels, samples] or [samples]; assumes a float
            dtype (fades multiply in place by a float ramp) -- TODO confirm
            callers never pass integer PCM
        sample_rate: Sample rate
        fade_in_ms: Fade in duration in milliseconds
        fade_out_ms: Fade out duration in milliseconds

    Returns:
        A copy of the audio with fades applied (the input is not modified)
    """
    audio = audio.copy()  # never mutate the caller's array

    fade_in_samples = int(sample_rate * fade_in_ms / 1000)
    if fade_in_ms > 0 and fade_in_samples > 0:
        fade_in_samples = min(fade_in_samples, audio.shape[-1])
        fade_in_curve = np.linspace(0, 1, fade_in_samples)

        if audio.ndim == 2:
            audio[:, :fade_in_samples] *= fade_in_curve
        else:
            audio[:fade_in_samples] *= fade_in_curve

    fade_out_samples = int(sample_rate * fade_out_ms / 1000)
    # Bug fix: the extra `> 0` guard matters. A sub-sample fade duration
    # (e.g. fade_out_ms=0.5 at sr=1000) truncates to 0 samples, and
    # audio[-0:] selects the WHOLE array, which was then multiplied by an
    # empty ramp -> broadcast error.
    if fade_out_ms > 0 and fade_out_samples > 0:
        fade_out_samples = min(fade_out_samples, audio.shape[-1])
        fade_out_curve = np.linspace(1, 0, fade_out_samples)

        if audio.ndim == 2:
            audio[:, -fade_out_samples:] *= fade_out_curve
        else:
            audio[-fade_out_samples:] *= fade_out_curve

    return audio
|
||||
|
||||
|
||||
def concatenate_audio(
    audio_list: list[np.ndarray],
    sample_rate: int,
    crossfade_ms: float = 0,
) -> np.ndarray:
    """Concatenate multiple audio segments, optionally with crossfades.

    Args:
        audio_list: List of audio arrays, each [channels, samples] or
            [samples] (all with the same layout)
        sample_rate: Sample rate (must be same for all)
        crossfade_ms: Linear crossfade duration between segments; 0 for a
            plain join

    Returns:
        Concatenated audio. Returns an empty array for an empty list and
        the single element itself for a one-item list. Input arrays are
        never modified.
    """
    if not audio_list:
        return np.array([])

    if len(audio_list) == 1:
        return audio_list[0]

    crossfade_samples = int(sample_rate * crossfade_ms / 1000)

    result = audio_list[0]

    for audio in audio_list[1:]:
        # Crossfade only when both segments are longer than the overlap;
        # otherwise fall back to a plain join
        if 0 < crossfade_samples < min(result.shape[-1], audio.shape[-1]):
            fade_out = np.linspace(1, 0, crossfade_samples)
            fade_in = np.linspace(0, 1, crossfade_samples)

            # Bug fix: the previous implementation applied the fade with
            # `result[...] *= fade_out`, mutating the caller's first input
            # array in place. Build the overlap without mutation instead.
            # `...` indexing handles both [samples] and [channels, samples].
            overlap = (
                result[..., -crossfade_samples:] * fade_out
                + audio[..., :crossfade_samples] * fade_in
            )
            result = np.concatenate(
                [
                    result[..., :-crossfade_samples],
                    overlap,
                    audio[..., crossfade_samples:],
                ],
                axis=-1,
            )
        else:
            # Simple concatenation along the sample axis
            result = np.concatenate([result, audio], axis=-1)

    return result
|
||||
247
src/core/base_model.py
Normal file
247
src/core/base_model.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""Abstract base classes for AudioCraft model adapters."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class ConditioningType(str, Enum):
    """Types of conditioning supported by models.

    Inherits from ``str`` so members compare equal to their raw string
    values, which keeps JSON/API payloads and enum members interchangeable.
    """

    TEXT = "text"      # text prompt
    MELODY = "melody"  # reference melody audio
    STYLE = "style"    # reference style audio
    CHORDS = "chords"  # chord progression
    DRUMS = "drums"    # drum track
|
||||
|
||||
|
||||
@dataclass
class GenerationRequest:
    """Request parameters for audio generation.

    Attributes:
        prompts: List of text prompts for generation
        duration: Target duration in seconds
        temperature: Sampling temperature (higher = more random)
        top_k: Top-k sampling parameter
        top_p: Nucleus sampling parameter (0 = disabled)
        cfg_coef: Classifier-free guidance coefficient
        batch_size: Number of samples to generate per prompt
        seed: Random seed for reproducibility
        conditioning: Optional conditioning data
    """

    prompts: list[str]
    duration: float = 10.0
    temperature: float = 1.0
    top_k: int = 250
    top_p: float = 0.0
    cfg_coef: float = 3.0
    batch_size: int = 1
    seed: Optional[int] = None
    conditioning: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate request parameters.

        Raises:
            ValueError: If any parameter is outside its valid range.
        """
        if not self.prompts:
            raise ValueError("At least one prompt is required")
        if self.duration <= 0:
            raise ValueError("Duration must be positive")
        if self.temperature < 0:
            raise ValueError("Temperature must be non-negative")
        if self.top_k < 0:
            raise ValueError("top_k must be non-negative")
        if not 0 <= self.top_p <= 1:
            raise ValueError("top_p must be between 0 and 1")
        if self.cfg_coef < 1:
            raise ValueError("cfg_coef must be >= 1")
        # Bug fix: batch_size was the only field never validated; a
        # non-positive value previously failed only deep inside the model
        if self.batch_size < 1:
            raise ValueError("batch_size must be >= 1")
|
||||
|
||||
|
||||
@dataclass
class GenerationResult:
    """Output bundle from a single audio-generation call.

    Attributes:
        audio: Generated audio tensor (shape: [batch, channels, samples])
        sample_rate: Audio sample rate in Hz
        duration: Actual duration in seconds
        model_id: ID of the model used
        variant: Model variant used
        parameters: Generation parameters used
        seed: Actual seed used (for reproducibility)
    """

    audio: torch.Tensor
    sample_rate: int
    duration: float
    model_id: str
    variant: str
    parameters: dict[str, Any]
    seed: int

    @property
    def num_samples(self) -> int:
        """Number of audio samples generated (batch dimension)."""
        return self.audio.size(0)

    @property
    def num_channels(self) -> int:
        """Number of audio channels per sample."""
        return self.audio.size(1)

    @property
    def num_frames(self) -> int:
        """Number of audio frames (time dimension)."""
        return self.audio.size(2)
|
||||
|
||||
|
||||
class BaseAudioModel(ABC):
    """Abstract base class for AudioCraft model adapters.

    All model adapters must implement this interface to integrate with
    the model registry and generation service.
    """

    @property
    @abstractmethod
    def model_id(self) -> str:
        """Unique identifier for this model family (e.g., 'musicgen')."""
        ...

    @property
    @abstractmethod
    def variant(self) -> str:
        """Current model variant (e.g., 'medium', 'large')."""
        ...

    @property
    @abstractmethod
    def display_name(self) -> str:
        """Human-readable name for UI display."""
        ...

    @property
    @abstractmethod
    def description(self) -> str:
        """Brief description of the model's capabilities."""
        ...

    @property
    @abstractmethod
    def vram_estimate_mb(self) -> int:
        """Estimated VRAM usage when loaded (in megabytes)."""
        ...

    @property
    @abstractmethod
    def max_duration(self) -> float:
        """Maximum supported generation duration in seconds."""
        ...

    @property
    @abstractmethod
    def sample_rate(self) -> int:
        """Output audio sample rate in Hz."""
        ...

    @property
    @abstractmethod
    def supports_conditioning(self) -> list[ConditioningType]:
        """List of conditioning types supported by this model."""
        ...

    @property
    @abstractmethod
    def is_loaded(self) -> bool:
        """Whether the model is currently loaded in memory."""
        ...

    @property
    def device(self) -> Optional[torch.device]:
        """Device the model is loaded on, or None if not loaded."""
        return None

    @abstractmethod
    def load(self, device: str = "cuda") -> None:
        """Load the model into memory.

        Args:
            device: Target device ('cuda', 'cuda:0', 'cpu', etc.)

        Raises:
            RuntimeError: If loading fails
        """
        ...

    @abstractmethod
    def unload(self) -> None:
        """Unload the model and free memory.

        Should be idempotent - safe to call even if not loaded.
        """
        ...

    @abstractmethod
    def generate(self, request: GenerationRequest) -> GenerationResult:
        """Generate audio based on the request.

        Args:
            request: Generation parameters and prompts

        Returns:
            GenerationResult containing audio and metadata

        Raises:
            RuntimeError: If model is not loaded
            ValueError: If request parameters are invalid for this model
        """
        ...

    @abstractmethod
    def get_default_params(self) -> dict[str, Any]:
        """Get default generation parameters for this model.

        Returns:
            Dictionary of parameter names to default values
        """
        ...

    def validate_request(self, request: GenerationRequest) -> None:
        """Validate a generation request for this model.

        Args:
            request: Request to validate

        Raises:
            RuntimeError: If the model is not loaded
            ValueError: If request is invalid for this model
        """
        if not self.is_loaded:
            raise RuntimeError(f"Model {self.model_id}/{self.variant} is not loaded")

        if request.duration > self.max_duration:
            raise ValueError(
                f"Duration {request.duration}s exceeds maximum {self.max_duration}s "
                f"for {self.model_id}/{self.variant}"
            )

        # Check conditioning requirements: every provided key must name a
        # known conditioning type AND be supported by this model
        for cond_type, cond_data in request.conditioning.items():
            if cond_data is not None:
                try:
                    cond_enum = ConditioningType(cond_type)
                except ValueError:
                    # `from None`: the original lookup failure adds no
                    # information; suppress the chained traceback (bug fix)
                    raise ValueError(f"Unknown conditioning type: {cond_type}") from None

                if cond_enum not in self.supports_conditioning:
                    raise ValueError(
                        f"Model {self.model_id}/{self.variant} does not support "
                        f"{cond_type} conditioning"
                    )

    def __repr__(self) -> str:
        """String representation."""
        loaded = "loaded" if self.is_loaded else "not loaded"
        return f"<{self.__class__.__name__} {self.model_id}/{self.variant} ({loaded})>"
|
||||
433
src/core/gpu_manager.py
Normal file
433
src/core/gpu_manager.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""GPU memory management for AudioCraft models."""
|
||||
|
||||
import gc
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class VRAMBudget:
    """Snapshot of how the GPU's VRAM is partitioned.

    Attributes:
        total_mb: Total VRAM in megabytes
        used_mb: Currently used VRAM
        free_mb: Free VRAM
        reserved_comfyui_mb: VRAM reserved for ComfyUI
        safety_buffer_mb: Safety buffer to prevent OOM
        available_mb: VRAM available for AudioCraft models
    """

    total_mb: int
    used_mb: int
    free_mb: int
    reserved_comfyui_mb: int
    safety_buffer_mb: int
    available_mb: int

    @property
    def utilization(self) -> float:
        """Fraction of total VRAM currently in use (0-1)."""
        # Guard against division by zero when totals are unknown
        if self.total_mb <= 0:
            return 0.0
        return self.used_mb / self.total_mb
|
||||
|
||||
|
||||
@dataclass
class GPUState:
    """State information for inter-service coordination."""

    timestamp: float  # unix time the status was written; used for staleness checks
    service: str  # "audiocraft" or "comfyui"
    vram_used_mb: int  # VRAM the service currently holds
    vram_requested_mb: int  # VRAM the service needs for a pending operation
    status: str  # "idle", "working", "requesting_priority", "yielded"
|
||||
|
||||
|
||||
class GPUMemoryManager:
    """Manages GPU memory allocation and coordination with ComfyUI.

    Uses pynvml for accurate system-wide VRAM tracking and file-based
    IPC for coordination with ComfyUI running on the same system.
    """

    # Shared JSON status file and its flock-based lock file; both services
    # must agree on these exact paths for coordination to work.
    COORDINATION_FILE = Path("/tmp/audiocraft_comfyui_coord.json")
    LOCK_FILE = Path("/tmp/audiocraft_comfyui_coord.lock")
    # A peer status older than this is treated as absent (service dead/hung)
    STALE_THRESHOLD = 30.0  # seconds

    def __init__(
        self,
        device_id: int = 0,
        comfyui_reserve_gb: float = 10.0,
        safety_buffer_gb: float = 1.0,
    ):
        """Initialize GPU memory manager.

        Args:
            device_id: CUDA device index
            comfyui_reserve_gb: VRAM to reserve for ComfyUI (gigabytes)
            safety_buffer_gb: Safety buffer to prevent OOM (gigabytes)
        """
        self.device_id = device_id
        self.device = torch.device(f"cuda:{device_id}")
        # Internal accounting is in MB; convert once here
        self.comfyui_reserve_mb = int(comfyui_reserve_gb * 1024)
        self.safety_buffer_mb = int(safety_buffer_gb * 1024)

        # Initialize NVML for direct GPU monitoring (falls back to
        # torch.cuda stats when pynvml is unavailable)
        self._nvml_initialized = False
        self._nvml_handle = None
        self._init_nvml()

        # Threading: RLock because public methods call other locked methods
        self._lock = threading.RLock()

        # Callbacks for memory events
        self._low_memory_callbacks: list[Callable[[VRAMBudget], None]] = []
        self._oom_callbacks: list[Callable[[], None]] = []

        # Initialize coordination file
        self._ensure_coordination_file()

    def _init_nvml(self) -> None:
        """Initialize NVML for GPU monitoring.

        Failure is non-fatal: memory queries fall back to torch.cuda,
        which only sees this process's allocations.
        """
        try:
            import pynvml

            pynvml.nvmlInit()
            self._nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(self.device_id)
            self._nvml_initialized = True
            logger.info("NVML initialized successfully")
        except ImportError:
            logger.warning("pynvml not available, falling back to torch.cuda")
        except Exception as e:
            logger.warning(f"Failed to initialize NVML: {e}, falling back to torch.cuda")

    def _ensure_coordination_file(self) -> None:
        """Create coordination file if it doesn't exist.

        NOTE(review): exists()+write is a check-then-act race if both
        services start simultaneously; the last writer wins with an empty
        initial state, which appears benign -- confirm.
        """
        if not self.COORDINATION_FILE.exists():
            initial_state = {
                "audiocraft": None,
                "comfyui": None,
                "priority": None,
                "last_update": time.time(),
            }
            self._write_coordination_state(initial_state)

    def get_memory_info(self) -> dict[str, int]:
        """Get current GPU memory status.

        Returns:
            Dictionary with memory values in megabytes:
            - total: Total VRAM
            - used: Used VRAM (system-wide)
            - free: Free VRAM
            - torch_allocated: PyTorch allocated memory
            - torch_reserved: PyTorch reserved memory
            - torch_cached: PyTorch cached memory
        """
        with self._lock:
            # Prefer NVML: it reports system-wide usage, not just ours
            if self._nvml_initialized:
                return self._get_memory_info_nvml()
            return self._get_memory_info_torch()

    def _get_memory_info_nvml(self) -> dict[str, int]:
        """Get memory info using NVML (more accurate)."""
        import pynvml

        info = pynvml.nvmlDeviceGetMemoryInfo(self._nvml_handle)
        torch_allocated = torch.cuda.memory_allocated(self.device)
        torch_reserved = torch.cuda.memory_reserved(self.device)

        # All values converted from bytes to MB
        return {
            "total": info.total // (1024 * 1024),
            "used": info.used // (1024 * 1024),
            "free": info.free // (1024 * 1024),
            "torch_allocated": torch_allocated // (1024 * 1024),
            "torch_reserved": torch_reserved // (1024 * 1024),
            "torch_cached": (torch_reserved - torch_allocated) // (1024 * 1024),
        }

    def _get_memory_info_torch(self) -> dict[str, int]:
        """Get memory info using torch.cuda (fallback)."""
        props = torch.cuda.get_device_properties(self.device)
        allocated = torch.cuda.memory_allocated(self.device)
        reserved = torch.cuda.memory_reserved(self.device)

        # Note: This is less accurate for system-wide usage -- "used" here
        # only reflects this process's reserved memory, so other processes
        # (e.g. ComfyUI) are invisible on this path.
        return {
            "total": props.total_memory // (1024 * 1024),
            "used": reserved // (1024 * 1024),
            "free": (props.total_memory - reserved) // (1024 * 1024),
            "torch_allocated": allocated // (1024 * 1024),
            "torch_reserved": reserved // (1024 * 1024),
            "torch_cached": (reserved - allocated) // (1024 * 1024),
        }

    def get_available_budget(self) -> VRAMBudget:
        """Calculate available VRAM budget considering ComfyUI.

        Returns:
            VRAMBudget with current allocation information
        """
        mem = self.get_memory_info()

        # Check ComfyUI's actual usage via coordination file
        comfyui_state = self.get_comfyui_status()
        if comfyui_state and comfyui_state.status != "yielded":
            # Use actual ComfyUI usage + buffer, or reserve, whichever is higher
            effective_comfyui_reserve = max(
                self.comfyui_reserve_mb,
                comfyui_state.vram_used_mb + 2048,  # 2GB headroom
            )
        else:
            effective_comfyui_reserve = self.comfyui_reserve_mb

        available = max(
            0,
            mem["total"]
            - mem["used"]
            + mem["torch_allocated"]  # Our own usage doesn't count against us
            - effective_comfyui_reserve
            - self.safety_buffer_mb,
        )

        return VRAMBudget(
            total_mb=mem["total"],
            used_mb=mem["used"],
            free_mb=mem["free"],
            reserved_comfyui_mb=effective_comfyui_reserve,
            safety_buffer_mb=self.safety_buffer_mb,
            available_mb=available,
        )

    def can_load_model(self, vram_required_mb: int) -> tuple[bool, str]:
        """Check if a model can fit in available VRAM.

        Args:
            vram_required_mb: VRAM needed by the model

        Returns:
            Tuple of (can_load, reason_message)
        """
        budget = self.get_available_budget()

        if vram_required_mb <= budget.available_mb:
            return True, "Sufficient VRAM available"

        deficit = vram_required_mb - budget.available_mb
        return False, (
            f"Insufficient VRAM: need {vram_required_mb}MB, "
            f"available {budget.available_mb}MB (deficit: {deficit}MB)"
        )

    def force_cleanup(self) -> int:
        """Force GPU memory cleanup.

        Returns:
            Freed memory in megabytes (approximate)
        """
        with self._lock:
            before = self.get_memory_info()

            # Drop Python garbage first so empty_cache can release the
            # backing CUDA blocks
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.synchronize(self.device)

            after = self.get_memory_info()
            # Measured on torch_reserved: that is what empty_cache releases
            freed = before["torch_reserved"] - after["torch_reserved"]

            if freed > 0:
                logger.info(f"Freed {freed}MB of GPU memory")

            return freed

    def get_status(self) -> dict[str, Any]:
        """Get detailed GPU status for UI display.

        Returns:
            Dictionary with status information
        """
        mem = self.get_memory_info()
        budget = self.get_available_budget()

        return {
            "device": str(self.device),
            "total_gb": round(mem["total"] / 1024, 2),
            "used_gb": round(mem["used"] / 1024, 2),
            "free_gb": round(mem["free"] / 1024, 2),
            "utilization_percent": round(budget.utilization * 100, 1),
            "available_for_models_gb": round(budget.available_mb / 1024, 2),
            "comfyui_reserve_gb": round(budget.reserved_comfyui_mb / 1024, 2),
            "torch_allocated_gb": round(mem["torch_allocated"] / 1024, 2),
            "torch_cached_gb": round(mem["torch_cached"] / 1024, 2),
        }

    # ComfyUI Coordination Methods

    def _read_coordination_state(self) -> dict[str, Any]:
        """Read coordination state from file.

        Returns an empty dict when the file is missing or unreadable so
        callers can proceed without special-casing errors.
        """
        try:
            if self.COORDINATION_FILE.exists():
                return json.loads(self.COORDINATION_FILE.read_text())
        except (json.JSONDecodeError, IOError) as e:
            logger.warning(f"Failed to read coordination file: {e}")
        return {}

    def _write_coordination_state(self, state: dict[str, Any]) -> None:
        """Write coordination state to file with locking.

        Uses fcntl.flock on a companion lock file to serialize writers
        across processes (POSIX-only).
        """
        import fcntl

        try:
            self.LOCK_FILE.parent.mkdir(parents=True, exist_ok=True)
            with open(self.LOCK_FILE, "w") as lock:
                fcntl.flock(lock, fcntl.LOCK_EX)
                try:
                    self.COORDINATION_FILE.write_text(json.dumps(state, indent=2))
                finally:
                    fcntl.flock(lock, fcntl.LOCK_UN)
        except IOError as e:
            # Coordination is best-effort; a failed write degrades to
            # reserve-based budgeting rather than crashing generation
            logger.warning(f"Failed to write coordination file: {e}")

    def update_status(
        self,
        vram_used_mb: int,
        vram_requested_mb: int = 0,
        status: str = "idle",
    ) -> None:
        """Update AudioCraft's status in coordination file.

        NOTE(review): the read below happens outside the write lock, so a
        concurrent ComfyUI update between read and write can be lost --
        confirm this read-modify-write race is acceptable.

        Args:
            vram_used_mb: Current VRAM usage
            vram_requested_mb: VRAM needed for pending operation
            status: Current status ("idle", "working", "requesting_priority")
        """
        state = self._read_coordination_state()
        state["audiocraft"] = {
            "timestamp": time.time(),
            "service": "audiocraft",
            "vram_used_mb": vram_used_mb,
            "vram_requested_mb": vram_requested_mb,
            "status": status,
        }
        state["last_update"] = time.time()
        self._write_coordination_state(state)

    def get_comfyui_status(self) -> Optional[GPUState]:
        """Get ComfyUI's current status.

        Returns:
            GPUState if ComfyUI is active and status is fresh, None otherwise
        """
        state = self._read_coordination_state()
        comfyui_data = state.get("comfyui")

        if not comfyui_data:
            return None

        # Check if stale: an old entry likely means ComfyUI died without
        # cleaning up its status
        if time.time() - comfyui_data.get("timestamp", 0) > self.STALE_THRESHOLD:
            return None

        return GPUState(
            timestamp=comfyui_data["timestamp"],
            service="comfyui",
            vram_used_mb=comfyui_data.get("vram_used_mb", 0),
            vram_requested_mb=comfyui_data.get("vram_requested_mb", 0),
            status=comfyui_data.get("status", "unknown"),
        )

    def request_priority(self, vram_needed_mb: int, timeout: float = 30.0) -> bool:
        """Request VRAM priority from ComfyUI.

        Signals ComfyUI to release VRAM if possible.

        Args:
            vram_needed_mb: Amount of VRAM needed
            timeout: Seconds to wait for ComfyUI to yield

        Returns:
            True if ComfyUI acknowledged and yielded, False otherwise
        """
        state = self._read_coordination_state()
        state["priority"] = {
            "requester": "audiocraft",
            "vram_needed_mb": vram_needed_mb,
            "timestamp": time.time(),
        }
        self._write_coordination_state(state)

        logger.info(f"Requesting {vram_needed_mb}MB VRAM from ComfyUI...")

        # Wait for ComfyUI to respond, polling the coordination file
        start = time.time()
        while time.time() - start < timeout:
            comfyui = self.get_comfyui_status()
            if comfyui and comfyui.status == "yielded":
                logger.info("ComfyUI yielded VRAM")
                return True
            time.sleep(0.5)

        logger.warning("ComfyUI did not yield VRAM within timeout")
        return False

    def is_comfyui_busy(self) -> bool:
        """Check if ComfyUI is actively processing.

        Returns:
            True if ComfyUI is working, False otherwise
        """
        status = self.get_comfyui_status()
        return status is not None and status.status == "working"

    # Callback Registration

    def on_low_memory(self, callback: Callable[[VRAMBudget], None]) -> None:
        """Register callback for low memory warnings.

        Args:
            callback: Function to call with budget info when memory is low
        """
        self._low_memory_callbacks.append(callback)

    def on_oom(self, callback: Callable[[], None]) -> None:
        """Register callback for OOM events.

        NOTE(review): nothing in this class invokes _oom_callbacks --
        presumably a caller elsewhere fires them; verify.

        Args:
            callback: Function to call when OOM occurs
        """
        self._oom_callbacks.append(callback)

    def check_memory_pressure(self, warning_threshold: float = 0.85) -> None:
        """Check memory pressure and trigger callbacks if needed.

        Args:
            warning_threshold: Utilization threshold for warnings (0-1)
        """
        budget = self.get_available_budget()

        if budget.utilization >= warning_threshold:
            logger.warning(
                f"High GPU memory pressure: {budget.utilization*100:.1f}% utilized"
            )
            # Callbacks are isolated: one failing callback must not stop
            # the others from being notified
            for callback in self._low_memory_callbacks:
                try:
                    callback(budget)
                except Exception as e:
                    logger.error(f"Low memory callback failed: {e}")

    def __del__(self) -> None:
        """Cleanup NVML on destruction.

        NOTE(review): assumes _nvml_initialized exists; if __init__ raised
        before setting it, this would AttributeError (suppressed by the
        interpreter during finalization) -- consider getattr.
        """
        if self._nvml_initialized:
            try:
                import pynvml

                pynvml.nvmlShutdown()
            except Exception:
                pass
|
||||
487
src/core/model_registry.py
Normal file
487
src/core/model_registry.py
Normal file
@@ -0,0 +1,487 @@
|
||||
"""Model registry for discovering and managing AudioCraft model adapters."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator, Optional, Type
|
||||
|
||||
import yaml
|
||||
|
||||
from src.core.base_model import BaseAudioModel, ConditioningType
|
||||
from src.core.gpu_manager import GPUMemoryManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ModelVariantConfig:
    """Configuration for a model variant."""

    hf_id: str  # Hugging Face model identifier used to fetch weights
    vram_mb: int  # estimated VRAM footprint when loaded
    max_duration: float = 30.0  # longest supported generation, in seconds
    channels: int = 1  # output channel count (1 = mono)
    conditioning: list[str] = field(default_factory=list)  # supported conditioning type names
    description: str = ""  # optional human-readable blurb for the UI
|
||||
|
||||
|
||||
@dataclass
class ModelFamilyConfig:
    """Configuration for a model family."""

    enabled: bool  # disabled families are skipped at registry load time
    display_name: str  # human-readable name for UI display
    description: str  # brief description of the family's capabilities
    default_variant: str  # variant key used when none is specified
    variants: dict[str, ModelVariantConfig]  # variant name -> its configuration
|
||||
|
||||
|
||||
@dataclass
class ModelHandle:
    """Handle for a loaded model with reference counting."""

    model: BaseAudioModel  # the loaded adapter instance
    model_id: str  # model family identifier
    variant: str  # variant name within the family
    loaded_at: float  # time.time() when the model was loaded
    last_accessed: float  # time.time() of most recent use; drives LRU/idle eviction
    ref_count: int = 0  # active users; a nonzero count should block unloading

    def touch(self) -> None:
        """Update last accessed time (marks the model as recently used)."""
        self.last_accessed = time.time()
|
||||
|
||||
|
||||
class ModelRegistry:
    """Central registry for discovering and managing model adapters.

    Handles:
    - Loading model configurations from YAML
    - Lazy loading models on demand
    - LRU eviction when VRAM is constrained
    - Reference counting to prevent unloading during use
    - Automatic idle timeout for unused models

    Thread-safety: the handle table (``_handles``) and LRU list
    (``_access_order``) are mutated only while holding ``self._lock``
    (an RLock, so private helpers may be called with the lock already held).
    """

    def __init__(
        self,
        config_path: Path,
        gpu_manager: GPUMemoryManager,
        max_cached_models: int = 2,
        idle_timeout_minutes: int = 15,
    ):
        """Initialize the model registry.

        Args:
            config_path: Path to models.yaml configuration
            gpu_manager: GPU memory manager instance
            max_cached_models: Maximum models to keep loaded
            idle_timeout_minutes: Unload models after this idle time
        """
        self.config_path = config_path
        self.gpu_manager = gpu_manager
        # NOTE(review): max_cached_models is stored but never consulted by the
        # load/eviction paths below -- eviction is driven purely by the VRAM
        # budget and the idle timeout. Confirm whether a count-based cap was
        # meant to be enforced in _load_model.
        self.max_cached_models = max_cached_models
        self.idle_timeout_seconds = idle_timeout_minutes * 60

        # Model configurations (parsed from YAML; disabled families excluded)
        self._model_configs: dict[str, ModelFamilyConfig] = {}
        self._default_params: dict[str, Any] = {}

        # Loaded model handles
        self._handles: dict[str, ModelHandle] = {}  # Key: "model_id/variant"
        self._access_order: list[str] = []  # LRU tracking: oldest first

        # Registered adapter classes (model family id -> adapter class)
        self._adapter_classes: dict[str, Type[BaseAudioModel]] = {}

        # Threading
        self._lock = threading.RLock()
        self._cleanup_thread: Optional[threading.Thread] = None
        self._stop_cleanup = threading.Event()

        # Load configuration eagerly so list_models works immediately
        self._load_config()

    def _load_config(self) -> None:
        """Load model configurations from YAML file.

        Missing config file is non-fatal: the registry starts empty and
        a warning is logged.
        """
        if not self.config_path.exists():
            logger.warning(f"Model config not found: {self.config_path}")
            return

        with open(self.config_path) as f:
            config = yaml.safe_load(f)

        # Parse model families; families explicitly disabled are skipped
        for model_id, model_config in config.get("models", {}).items():
            if not model_config.get("enabled", True):
                continue

            variants = {}
            for variant_name, variant_config in model_config.get("variants", {}).items():
                # hf_id and vram_mb are required; everything else has defaults
                variants[variant_name] = ModelVariantConfig(
                    hf_id=variant_config["hf_id"],
                    vram_mb=variant_config["vram_mb"],
                    max_duration=variant_config.get("max_duration", 30.0),
                    channels=variant_config.get("channels", 1),
                    conditioning=variant_config.get("conditioning", []),
                    description=variant_config.get("description", ""),
                )

            self._model_configs[model_id] = ModelFamilyConfig(
                enabled=model_config.get("enabled", True),
                display_name=model_config.get("display_name", model_id),
                description=model_config.get("description", ""),
                default_variant=model_config.get("default_variant", "medium"),
                variants=variants,
            )

        # Parse default generation parameters
        self._default_params = config.get("defaults", {}).get("generation", {})

        logger.info(f"Loaded {len(self._model_configs)} model families from config")

    def register_adapter(
        self, model_id: str, adapter_class: Type[BaseAudioModel]
    ) -> None:
        """Register a model adapter class.

        Args:
            model_id: Model family ID (e.g., 'musicgen')
            adapter_class: Adapter class implementing BaseAudioModel
        """
        self._adapter_classes[model_id] = adapter_class
        logger.debug(f"Registered adapter for {model_id}: {adapter_class.__name__}")

    def list_models(self) -> list[dict[str, Any]]:
        """List all available models with their configurations.

        Returns:
            List of model information dictionaries, one per (family, variant)
            pair, including load state and whether VRAM currently permits
            loading.
        """
        models = []

        for model_id, config in self._model_configs.items():
            for variant_name, variant in config.variants.items():
                key = f"{model_id}/{variant_name}"
                # NOTE(review): reads self._handles without holding self._lock;
                # worst case is a momentarily stale "is_loaded" flag.
                handle = self._handles.get(key)

                can_load, reason = self.gpu_manager.can_load_model(variant.vram_mb)

                models.append({
                    "model_id": model_id,
                    "variant": variant_name,
                    "display_name": config.display_name,
                    "description": variant.description or config.description,
                    "hf_id": variant.hf_id,
                    "vram_mb": variant.vram_mb,
                    "max_duration": variant.max_duration,
                    "channels": variant.channels,
                    "conditioning": variant.conditioning,
                    "is_default": variant_name == config.default_variant,
                    "is_loaded": handle is not None,
                    "can_load": can_load,
                    "load_reason": reason,
                    "has_adapter": model_id in self._adapter_classes,
                })

        return models

    def get_model_config(
        self, model_id: str, variant: Optional[str] = None
    ) -> tuple[ModelFamilyConfig, ModelVariantConfig]:
        """Get configuration for a model.

        Args:
            model_id: Model family ID
            variant: Specific variant, or None for default

        Returns:
            Tuple of (family_config, variant_config)

        Raises:
            ValueError: If model or variant not found
        """
        if model_id not in self._model_configs:
            raise ValueError(f"Unknown model: {model_id}")

        family = self._model_configs[model_id]
        variant = variant or family.default_variant

        if variant not in family.variants:
            raise ValueError(f"Unknown variant {variant} for {model_id}")

        return family, family.variants[variant]

    def get_loaded_models(self) -> list[dict[str, Any]]:
        """Get information about currently loaded models.

        Returns:
            List of loaded model information (id, variant, timestamps,
            ref_count, idle time in seconds)
        """
        with self._lock:
            return [
                {
                    "model_id": handle.model_id,
                    "variant": handle.variant,
                    "loaded_at": handle.loaded_at,
                    "last_accessed": handle.last_accessed,
                    "ref_count": handle.ref_count,
                    "idle_seconds": time.time() - handle.last_accessed,
                }
                for handle in self._handles.values()
            ]

    @contextmanager
    def get_model(
        self, model_id: str, variant: Optional[str] = None
    ) -> Generator[BaseAudioModel, None, None]:
        """Get a model, loading it if necessary.

        Context manager that handles reference counting to prevent
        unloading during use.

        Args:
            model_id: Model family ID
            variant: Specific variant, or None for default

        Yields:
            Loaded model instance

        Raises:
            ValueError: If model not found or cannot be loaded
            RuntimeError: If VRAM insufficient
        """
        family, variant_config = self.get_model_config(model_id, variant)
        variant = variant or family.default_variant
        key = f"{model_id}/{variant}"

        with self._lock:
            # Get or load model (load errors propagate before ref counting)
            if key not in self._handles:
                self._load_model(model_id, variant)

            handle = self._handles[key]
            handle.ref_count += 1
            handle.touch()

            # Update LRU order: move key to the most-recently-used end
            if key in self._access_order:
                self._access_order.remove(key)
            self._access_order.append(key)

        try:
            # The lock is NOT held while the caller uses the model; the
            # nonzero ref_count protects it from eviction in the meantime.
            yield handle.model
        finally:
            with self._lock:
                handle.ref_count -= 1

    def _load_model(self, model_id: str, variant: str) -> None:
        """Load a model into memory.

        Must be called with self._lock held.

        Args:
            model_id: Model family ID
            variant: Variant to load

        Raises:
            ValueError: If no adapter registered
            RuntimeError: If VRAM insufficient
        """
        key = f"{model_id}/{variant}"
        family, variant_config = self.get_model_config(model_id, variant)

        # Check for adapter
        if model_id not in self._adapter_classes:
            raise ValueError(f"No adapter registered for {model_id}")

        # Check VRAM against the variant's configured footprint estimate
        can_load, reason = self.gpu_manager.can_load_model(variant_config.vram_mb)
        if not can_load:
            # Try to free memory by evicting models, then re-check once
            self._evict_for_space(variant_config.vram_mb)
            can_load, reason = self.gpu_manager.can_load_model(variant_config.vram_mb)
            if not can_load:
                raise RuntimeError(reason)

        # Create and load model
        logger.info(f"Loading model {key}...")
        adapter_class = self._adapter_classes[model_id]
        model = adapter_class(variant=variant)
        model.load()

        # Register handle with fresh timestamps
        self._handles[key] = ModelHandle(
            model=model,
            model_id=model_id,
            variant=variant,
            loaded_at=time.time(),
            last_accessed=time.time(),
        )
        self._access_order.append(key)

        # Update GPU status with the post-load allocation figure
        mem = self.gpu_manager.get_memory_info()
        self.gpu_manager.update_status(mem["torch_allocated"], status="working")

        logger.info(f"Model {key} loaded successfully")

    def _evict_for_space(self, needed_mb: int) -> bool:
        """Evict models to free up VRAM.

        Must be called with self._lock held.

        Args:
            needed_mb: VRAM needed

        Returns:
            True if enough space was freed (estimated from configured
            per-variant vram_mb figures, not measured)
        """
        # NOTE(review): `freed` is accumulated but never used; only
        # `deficit` drives the return value.
        freed = 0
        budget = self.gpu_manager.get_available_budget()
        deficit = needed_mb - budget.available_mb

        if deficit <= 0:
            return True

        # Evict LRU models that have no active references; _access_order
        # is oldest-first, so iteration order is least-recently-used first.
        for key in list(self._access_order):
            if deficit <= 0:
                break

            handle = self._handles.get(key)
            if handle and handle.ref_count == 0:
                _, variant_config = self.get_model_config(
                    handle.model_id, handle.variant
                )
                logger.info(f"Evicting {key} to free {variant_config.vram_mb}MB")
                self._unload_model(key)
                freed += variant_config.vram_mb
                deficit -= variant_config.vram_mb

        self.gpu_manager.force_cleanup()
        return deficit <= 0

    def _unload_model(self, key: str) -> None:
        """Unload a model from memory.

        Must be called with self._lock held. Silently refuses (with a
        warning) if the handle still has active references.

        Args:
            key: Model key (model_id/variant)
        """
        if key not in self._handles:
            return

        handle = self._handles[key]
        if handle.ref_count > 0:
            logger.warning(f"Cannot unload {key}: {handle.ref_count} active references")
            return

        logger.info(f"Unloading model {key}")
        handle.model.unload()
        del self._handles[key]

        if key in self._access_order:
            self._access_order.remove(key)

        self.gpu_manager.force_cleanup()

    def unload_model(self, model_id: str, variant: Optional[str] = None) -> bool:
        """Manually unload a model.

        Args:
            model_id: Model family ID
            variant: Variant to unload, or None for all variants

        Returns:
            True if model was unloaded
        """
        with self._lock:
            if variant:
                key = f"{model_id}/{variant}"
                if key in self._handles:
                    self._unload_model(key)
                    return True
            else:
                # Unload all variants of this model
                keys = [k for k in self._handles if k.startswith(f"{model_id}/")]
                for key in keys:
                    self._unload_model(key)
                return bool(keys)
            return False

    def preload_model(self, model_id: str, variant: Optional[str] = None) -> bool:
        """Preload a model into memory.

        Unlike get_model, failures are swallowed and reported via the
        return value, making this safe for best-effort warmup.

        Args:
            model_id: Model family ID
            variant: Variant to load

        Returns:
            True if model was loaded successfully
        """
        family, _ = self.get_model_config(model_id, variant)
        variant = variant or family.default_variant
        key = f"{model_id}/{variant}"

        with self._lock:
            if key in self._handles:
                return True  # Already loaded

            try:
                self._load_model(model_id, variant)
                return True
            except Exception as e:
                logger.error(f"Failed to preload {key}: {e}")
                return False

    def start_cleanup_thread(self) -> None:
        """Start background thread for idle model cleanup.

        Idempotent: a second call while the thread is running is a no-op.
        """
        if self._cleanup_thread is not None:
            return

        def cleanup_loop():
            # Daemon loop: wake up once a minute (or immediately on stop)
            # and unload any idle, unreferenced models.
            while not self._stop_cleanup.is_set():
                self._cleanup_idle_models()
                self._stop_cleanup.wait(60)  # Check every minute

        self._cleanup_thread = threading.Thread(target=cleanup_loop, daemon=True)
        self._cleanup_thread.start()
        logger.info("Started model cleanup thread")

    def stop_cleanup_thread(self) -> None:
        """Stop the background cleanup thread (waits up to 5 seconds)."""
        if self._cleanup_thread is not None:
            self._stop_cleanup.set()
            self._cleanup_thread.join(timeout=5)
            self._cleanup_thread = None
            self._stop_cleanup.clear()

    def _cleanup_idle_models(self) -> None:
        """Unload models that have been idle too long.

        Models with active references are never unloaded here, regardless
        of idle time.
        """
        with self._lock:
            now = time.time()
            for key, handle in list(self._handles.items()):
                idle_time = now - handle.last_accessed
                if idle_time > self.idle_timeout_seconds and handle.ref_count == 0:
                    logger.info(
                        f"Unloading idle model {key} (idle for {idle_time/60:.1f} min)"
                    )
                    self._unload_model(key)

    def get_default_params(self) -> dict[str, Any]:
        """Get default generation parameters.

        Returns:
            Dictionary of default parameter values (a copy; safe to mutate)
        """
        return self._default_params.copy()

    def __del__(self) -> None:
        """Cleanup on destruction.

        Best-effort: during interpreter shutdown some attributes may
        already be torn down.
        """
        self.stop_cleanup_thread()
|
||||
297
src/core/oom_handler.py
Normal file
297
src/core/oom_handler.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""OOM (Out of Memory) handling and recovery strategies."""
|
||||
|
||||
import functools
|
||||
import gc
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Callable, Optional, ParamSpec, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from src.core.gpu_manager import GPUMemoryManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class OOMRecoveryError(Exception):
    """Raised when OOM recovery fails after all strategies exhausted.

    Chained (via ``raise ... from``) to the final torch.cuda.OutOfMemoryError
    so the original CUDA failure remains inspectable.
    """

    pass
|
||||
|
||||
|
||||
class OOMHandler:
    """Handles CUDA Out of Memory errors with multi-level recovery strategies.

    Recovery levels (executed cumulatively, mildest first):
    1. Clear PyTorch CUDA cache
    2. Evict unused models from registry
    3. Request ComfyUI to yield VRAM
    """

    def __init__(
        self,
        gpu_manager: GPUMemoryManager,
        model_registry: Optional[Any] = None,  # Avoid circular import
        max_retries: int = 3,
        retry_delay: float = 0.5,
    ):
        """Initialize OOM handler.

        Args:
            gpu_manager: GPU memory manager instance
            model_registry: Optional model registry for eviction
            max_retries: Maximum recovery attempts
            retry_delay: Delay between retries in seconds
        """
        self.gpu_manager = gpu_manager
        self.model_registry = model_registry
        self.max_retries = max_retries
        self.retry_delay = retry_delay

        # Track OOM events for monitoring
        self._oom_count = 0
        self._last_oom_time: Optional[float] = None

    @property
    def oom_count(self) -> int:
        """Number of OOM events handled."""
        return self._oom_count

    def set_model_registry(self, registry: Any) -> None:
        """Set model registry (to avoid circular import at init time)."""
        self.model_registry = registry

    def with_oom_recovery(self, func: Callable[P, R]) -> Callable[P, R]:
        """Decorator that wraps function with OOM recovery logic.

        Only torch.cuda.OutOfMemoryError triggers recovery/retry; every
        other exception propagates immediately.

        Usage:
            @oom_handler.with_oom_recovery
            def generate_audio(...):
                ...

        Args:
            func: Function to wrap

        Returns:
            Wrapped function with OOM recovery
        """

        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            last_exception = None

            # max_retries recovery attempts after the initial try
            for attempt in range(self.max_retries + 1):
                try:
                    if attempt > 0:
                        logger.info(f"Retry attempt {attempt}/{self.max_retries}")
                        time.sleep(self.retry_delay)

                    return func(*args, **kwargs)

                except torch.cuda.OutOfMemoryError as e:
                    last_exception = e
                    self._oom_count += 1
                    self._last_oom_time = time.time()

                    logger.warning(f"CUDA OOM detected (attempt {attempt + 1}): {e}")

                    if attempt < self.max_retries:
                        # Escalate: attempt index selects how aggressive
                        # the recovery is (see _execute_recovery_strategy)
                        self._execute_recovery_strategy(attempt)
                    else:
                        logger.error(
                            f"OOM recovery failed after {self.max_retries} attempts"
                        )

            raise OOMRecoveryError(
                f"OOM recovery failed after {self.max_retries} attempts"
            ) from last_exception

        return wrapper

    def _execute_recovery_strategy(self, level: int) -> None:
        """Execute recovery strategy based on severity level.

        Strategies are cumulative: level N runs strategies 0..N in order.

        Args:
            level: Recovery level (0-2)
        """
        strategies = [
            self._strategy_clear_cache,
            self._strategy_evict_models,
            self._strategy_request_comfyui_yield,
        ]

        # Execute all strategies up to and including current level
        for i in range(min(level + 1, len(strategies))):
            logger.info(f"Executing recovery strategy {i + 1}: {strategies[i].__name__}")
            strategies[i]()

    def _strategy_clear_cache(self) -> None:
        """Level 1: Clear PyTorch CUDA cache.

        This is the fastest and least disruptive recovery strategy.
        Clears cached memory that PyTorch holds for future allocations.

        NOTE(review): assumes a CUDA device is present and initialized --
        torch.cuda.synchronize() raises on CPU-only hosts. Confirm callers
        only run this on GPU machines.
        """
        logger.info("Clearing CUDA cache...")

        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

        # Reset peak memory stats for monitoring
        torch.cuda.reset_peak_memory_stats()

        freed = self.gpu_manager.force_cleanup()
        logger.info(f"Cache cleared, freed approximately {freed}MB")

    def _strategy_evict_models(self) -> None:
        """Level 2: Evict non-essential models from registry.

        Unloads all models that don't have active references,
        freeing their VRAM for the current operation. Falls back to a
        plain cache clear when no registry has been attached.
        """
        if self.model_registry is None:
            logger.warning("No model registry available for eviction")
            self._strategy_clear_cache()
            return

        logger.info("Evicting unused models...")

        # Get list of loaded models
        loaded = self.model_registry.get_loaded_models()
        evicted = []

        for model_info in loaded:
            # Only evict models with no active references
            if model_info["ref_count"] == 0:
                model_id = model_info["model_id"]
                variant = model_info["variant"]
                logger.info(f"Evicting {model_id}/{variant}")
                self.model_registry.unload_model(model_id, variant)
                evicted.append(f"{model_id}/{variant}")

        # Clear cache after eviction
        self._strategy_clear_cache()

        logger.info(f"Evicted {len(evicted)} model(s): {evicted}")

    def _strategy_request_comfyui_yield(self) -> None:
        """Level 3: Request ComfyUI to yield VRAM.

        Uses the coordination protocol to ask ComfyUI to
        temporarily release GPU memory. A timeout is non-fatal: we log
        and proceed with whatever memory we have.
        """
        logger.info("Requesting ComfyUI to yield VRAM...")

        # First, evict our own models
        self._strategy_evict_models()

        # Calculate how much VRAM we need
        budget = self.gpu_manager.get_available_budget()
        needed = max(4096, budget.total_mb // 4)  # Request at least 4GB or 25% of total

        # Request priority from ComfyUI
        success = self.gpu_manager.request_priority(needed, timeout=15.0)

        if success:
            logger.info("ComfyUI yielded VRAM successfully")
        else:
            logger.warning("ComfyUI did not yield VRAM within timeout")

        # Final cache clear
        self._strategy_clear_cache()

    def recover_from_oom(self, level: int = 0) -> bool:
        """Manually trigger OOM recovery.

        Args:
            level: Recovery level to execute (0-2)

        Returns:
            True if recovery was successful (memory was freed)
        """
        # assumes get_memory_info() exposes a "used" key (MB) -- TODO
        # confirm against GPUMemoryManager
        before = self.gpu_manager.get_memory_info()

        self._execute_recovery_strategy(level)

        after = self.gpu_manager.get_memory_info()
        freed = before["used"] - after["used"]

        logger.info(f"Manual recovery freed {freed}MB")
        return freed > 0

    def check_memory_for_operation(self, required_mb: int) -> bool:
        """Check if there's enough memory for an operation.

        If not enough, attempts recovery strategies, escalating from
        mildest (cache clear) to most aggressive (ComfyUI yield).

        Args:
            required_mb: Memory required in megabytes

        Returns:
            True if enough memory is available (possibly after recovery)
        """
        budget = self.gpu_manager.get_available_budget()

        if budget.available_mb >= required_mb:
            return True

        logger.info(
            f"Need {required_mb}MB but only {budget.available_mb}MB available. "
            "Attempting recovery..."
        )

        # Try progressively more aggressive recovery
        for level in range(3):
            self._execute_recovery_strategy(level)
            budget = self.gpu_manager.get_available_budget()

            if budget.available_mb >= required_mb:
                logger.info(f"Recovery successful at level {level + 1}")
                return True

        logger.error(
            f"Could not free enough memory. Need {required_mb}MB, "
            f"have {budget.available_mb}MB"
        )
        return False

    def get_stats(self) -> dict[str, Any]:
        """Get OOM handling statistics.

        Returns:
            Dictionary with OOM stats
        """
        return {
            "oom_count": self._oom_count,
            "last_oom_time": self._last_oom_time,
            "max_retries": self.max_retries,
            "has_registry": self.model_registry is not None,
        }
|
||||
|
||||
|
||||
# Module-level convenience function
|
||||
def oom_safe(
    gpu_manager: GPUMemoryManager,
    model_registry: Optional[Any] = None,
    max_retries: int = 3,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """Build a decorator that retries the wrapped callable on CUDA OOM.

    Convenience wrapper around OOMHandler for call sites that do not need
    to keep a handler instance around.

    Usage:
        @oom_safe(gpu_manager, model_registry)
        def generate_audio(...):
            ...

    Args:
        gpu_manager: GPU memory manager
        model_registry: Optional model registry for eviction
        max_retries: Maximum recovery attempts

    Returns:
        Decorator function
    """
    # Each call gets its own handler (and therefore its own OOM counters).
    return OOMHandler(gpu_manager, model_registry, max_retries).with_oom_recovery
|
||||
84
src/main.py
Normal file
84
src/main.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""AudioCraft Studio - Main Application Entry Point."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from config.settings import get_settings
|
||||
from src.core.gpu_manager import GPUMemoryManager
|
||||
from src.core.model_registry import ModelRegistry
|
||||
from src.storage.database import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def init_app():
    """Initialize application components.

    Wires together settings, GPU manager, model registry, and database,
    in dependency order. The caller owns the returned components and is
    responsible for closing the database.

    Returns:
        Dict with keys "settings", "gpu_manager", "registry", "database".
    """
    settings = get_settings()

    # Configure logging from the configured level name (e.g. "INFO")
    logging.basicConfig(
        level=getattr(logging, settings.log_level),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Ensure directories exist (data/output/cache paths from settings)
    settings.ensure_directories()

    # Initialize GPU manager with the ComfyUI-coexistence reservations
    gpu_manager = GPUMemoryManager(
        comfyui_reserve_gb=settings.comfyui_reserve_gb,
        safety_buffer_gb=settings.safety_buffer_gb,
    )

    # Initialize model registry (reads models.yaml; models load lazily)
    registry = ModelRegistry(
        config_path=settings.models_config,
        gpu_manager=gpu_manager,
        max_cached_models=settings.max_cached_models,
        idle_timeout_minutes=settings.idle_unload_minutes,
    )

    # Initialize database connection
    db = Database(settings.database_path)
    await db.connect()

    logger.info("AudioCraft Studio initialized")
    logger.info(f"GPU Status: {gpu_manager.get_status()}")
    logger.info(f"Available models: {len(registry.list_models())}")

    return {
        "settings": settings,
        "gpu_manager": gpu_manager,
        "registry": registry,
        "database": db,
    }
|
||||
|
||||
|
||||
def main():
    """Main entry point: print the roadmap banner, then smoke-test init."""
    banner = (
        "AudioCraft Studio - Starting...",
        "Phase 1 core infrastructure is complete.",
        "\nTo continue implementation:",
        "  - Phase 2: Model adapters (musicgen, audiogen, magnet, style, jasco)",
        "  - Phase 3: Services layer (generation, batch, project)",
        "  - Phase 4: Gradio UI",
        "  - Phase 5: REST API",
        "  - Phase 6: Deployment",
    )
    for line in banner:
        print(line)

    async def test_init():
        # Quick initialization test: bring components up, report, tear down
        components = await init_app()
        print(f"\nDatabase path: {components['settings'].database_path}")
        print(f"GPU status: {components['gpu_manager'].get_status()}")
        await components["database"].close()

    asyncio.run(test_init())
|
||||
|
||||
|
||||
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
||||
32
src/models/__init__.py
Normal file
32
src/models/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""AudioCraft model adapters.
|
||||
|
||||
This module contains adapters that wrap AudioCraft's models with a
|
||||
consistent interface for the application.
|
||||
"""
|
||||
|
||||
from src.models.musicgen.adapter import MusicGenAdapter
|
||||
from src.models.audiogen.adapter import AudioGenAdapter
|
||||
from src.models.magnet.adapter import MAGNeTAdapter
|
||||
from src.models.musicgen_style.adapter import MusicGenStyleAdapter
|
||||
from src.models.jasco.adapter import JASCOAdapter
|
||||
|
||||
# Public API of this package: one adapter class per supported model family.
__all__ = [
    "MusicGenAdapter",
    "AudioGenAdapter",
    "MAGNeTAdapter",
    "MusicGenStyleAdapter",
    "JASCOAdapter",
]
|
||||
|
||||
|
||||
def register_all_adapters(registry) -> None:
    """Register all model adapters with the registry.

    The model ids used here must match the family ids declared in
    models.yaml for the registry to be able to load them.

    Args:
        registry: ModelRegistry instance to register adapters with
    """
    builtin_adapters = {
        "musicgen": MusicGenAdapter,
        "audiogen": AudioGenAdapter,
        "magnet": MAGNeTAdapter,
        "musicgen-style": MusicGenStyleAdapter,
        "jasco": JASCOAdapter,
    }
    for family_id, adapter_cls in builtin_adapters.items():
        registry.register_adapter(family_id, adapter_cls)
|
||||
5
src/models/audiogen/__init__.py
Normal file
5
src/models/audiogen/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""AudioGen model adapter."""
|
||||
|
||||
from src.models.audiogen.adapter import AudioGenAdapter
|
||||
|
||||
__all__ = ["AudioGenAdapter"]
|
||||
203
src/models/audiogen/adapter.py
Normal file
203
src/models/audiogen/adapter.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""AudioGen model adapter for text-to-sound effects generation."""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import random
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from src.core.base_model import (
|
||||
BaseAudioModel,
|
||||
ConditioningType,
|
||||
GenerationRequest,
|
||||
GenerationResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AudioGenAdapter(BaseAudioModel):
    """Adapter for Facebook's AudioGen model.

    Generates sound effects and environmental audio from text descriptions.
    Optimized for non-musical audio like sound effects, ambiences, and foley.
    """

    # Known checkpoints and their resource characteristics; only "medium"
    # is published for AudioGen.
    VARIANTS = {
        "medium": {
            "hf_id": "facebook/audiogen-medium",
            "vram_mb": 5000,
            "max_duration": 10,
            "channels": 1,
        },
    }

    def __init__(self, variant: str = "medium"):
        """Initialize AudioGen adapter.

        Does not load any weights; call load() before generate().

        Args:
            variant: Model variant (currently only 'medium' available)

        Raises:
            ValueError: If the variant is not in VARIANTS.
        """
        if variant not in self.VARIANTS:
            raise ValueError(
                f"Unknown AudioGen variant: {variant}. "
                f"Available: {list(self.VARIANTS.keys())}"
            )

        self._variant = variant
        self._config = self.VARIANTS[variant]
        self._model = None  # set by load(); None means not loaded
        self._device: Optional[torch.device] = None

    @property
    def model_id(self) -> str:
        # Family id used by the registry / models.yaml
        return "audiogen"

    @property
    def variant(self) -> str:
        return self._variant

    @property
    def display_name(self) -> str:
        return f"AudioGen ({self._variant})"

    @property
    def description(self) -> str:
        return "Text-to-sound effects generation"

    @property
    def vram_estimate_mb(self) -> int:
        return self._config["vram_mb"]

    @property
    def max_duration(self) -> float:
        return self._config["max_duration"]

    @property
    def sample_rate(self) -> int:
        # Prefer the live model's reported rate; fall back to the known
        # default when the model is not loaded.
        if self._model is not None:
            return self._model.sample_rate
        return 16000  # AudioGen default sample rate

    @property
    def supports_conditioning(self) -> list[ConditioningType]:
        # AudioGen only supports plain text prompts
        return [ConditioningType.TEXT]

    @property
    def is_loaded(self) -> bool:
        return self._model is not None

    @property
    def device(self) -> Optional[torch.device]:
        return self._device

    def load(self, device: str = "cuda") -> None:
        """Load the AudioGen model.

        Idempotent: a second call while loaded logs a warning and returns.
        On failure the adapter is reset to the unloaded state and a
        RuntimeError (chaining the original error) is raised.
        """
        if self._model is not None:
            logger.warning(f"AudioGen {self._variant} already loaded")
            return

        logger.info(f"Loading AudioGen {self._variant} from {self._config['hf_id']}...")

        try:
            # Imported lazily so the adapter can be constructed without
            # audiocraft installed.
            from audiocraft.models import AudioGen

            self._device = torch.device(device)
            self._model = AudioGen.get_pretrained(self._config["hf_id"])
            self._model.to(self._device)

            logger.info(
                f"AudioGen {self._variant} loaded successfully "
                f"(sample_rate={self._model.sample_rate})"
            )

        except Exception as e:
            self._model = None
            self._device = None
            logger.error(f"Failed to load AudioGen {self._variant}: {e}")
            raise RuntimeError(f"Failed to load AudioGen: {e}") from e

    def unload(self) -> None:
        """Unload the model and free memory. No-op when not loaded."""
        if self._model is None:
            return

        logger.info(f"Unloading AudioGen {self._variant}...")

        del self._model
        self._model = None
        self._device = None

        # Release Python references, then return cached CUDA memory
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def generate(self, request: GenerationRequest) -> GenerationResult:
        """Generate sound effects from text prompts.

        NOTE(review): assumes load() has already been called -- self._model
        is None otherwise and this raises AttributeError. validate_request
        presumably enforces duration/prompt limits; confirm in BaseAudioModel.

        Args:
            request: Generation parameters including prompts

        Returns:
            GenerationResult with audio tensor (moved to CPU) and metadata
        """
        self.validate_request(request)

        # Set random seed; a random one is drawn when the request has none,
        # and is recorded in the result for reproducibility.
        seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            # seeds the current CUDA device only
            torch.cuda.manual_seed(seed)

        # Configure generation
        self._model.set_generation_params(
            duration=request.duration,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            cfg_coef=request.cfg_coef,
        )

        logger.info(
            f"Generating {len(request.prompts)} sound effect(s) with AudioGen "
            f"(duration={request.duration}s)"
        )

        # Generate audio (no autograd state needed)
        with torch.inference_mode():
            audio = self._model.generate(request.prompts)

        # Actual duration derived from sample count; may differ slightly
        # from the requested duration.
        actual_duration = audio.shape[-1] / self.sample_rate

        logger.info(
            f"Generated {audio.shape[0]} sample(s), "
            f"duration={actual_duration:.2f}s"
        )

        return GenerationResult(
            audio=audio.cpu(),
            sample_rate=self.sample_rate,
            duration=actual_duration,
            model_id=self.model_id,
            variant=self._variant,
            parameters={
                "duration": request.duration,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p,
                "cfg_coef": request.cfg_coef,
                "prompts": request.prompts,
            },
            seed=seed,
        )

    def get_default_params(self) -> dict[str, Any]:
        """Get default generation parameters for AudioGen."""
        return {
            "duration": 5.0,
            "temperature": 1.0,
            "top_k": 250,
            "top_p": 0.0,
            "cfg_coef": 3.0,
        }
|
||||
5
src/models/jasco/__init__.py
Normal file
5
src/models/jasco/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""JASCO model adapter."""
|
||||
|
||||
from src.models.jasco.adapter import JASCOAdapter
|
||||
|
||||
__all__ = ["JASCOAdapter"]
|
||||
348
src/models/jasco/adapter.py
Normal file
348
src/models/jasco/adapter.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""JASCO model adapter for chord and drum-conditioned music generation."""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import random
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from src.core.base_model import (
|
||||
BaseAudioModel,
|
||||
ConditioningType,
|
||||
GenerationRequest,
|
||||
GenerationResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JASCOAdapter(BaseAudioModel):
    """Adapter for Facebook's JASCO model.

    JASCO (Joint Audio and Symbolic Conditioning) enables music generation
    with control over chord progressions and drum patterns alongside text.
    """

    # Per-variant config: HF checkpoint id, rough VRAM need (MB),
    # generation cap (seconds), and output channel count.
    VARIANTS = {
        "chords-drums-400M": {
            "hf_id": "facebook/jasco-chords-drums-400M",
            "vram_mb": 2000,
            "max_duration": 10,
            "channels": 1,
        },
        "chords-drums-1B": {
            "hf_id": "facebook/jasco-chords-drums-1B",
            "vram_mb": 4000,
            "max_duration": 10,
            "channels": 1,
        },
    }

    # Common chord types for validation
    # NOTE(review): not referenced by any method in this class — presumably
    # consumed by UI/validation code elsewhere; confirm before removing.
    VALID_CHORD_TYPES = [
        "maj", "min", "dim", "aug", "7", "maj7", "min7", "dim7",
        "sus2", "sus4", "add9", "6", "min6", "9", "min9", "maj9",
    ]

    def __init__(self, variant: str = "chords-drums-400M"):
        """Initialize JASCO adapter.

        Args:
            variant: Model variant to use

        Raises:
            ValueError: If ``variant`` is not a key of ``VARIANTS``.
        """
        if variant not in self.VARIANTS:
            raise ValueError(
                f"Unknown JASCO variant: {variant}. "
                f"Available: {list(self.VARIANTS.keys())}"
            )

        self._variant = variant
        self._config = self.VARIANTS[variant]
        self._model = None  # populated lazily by load()
        self._device: Optional[torch.device] = None

    @property
    def model_id(self) -> str:
        """Stable identifier for this model family."""
        return "jasco"

    @property
    def variant(self) -> str:
        """The variant name this adapter was constructed with."""
        return self._variant

    @property
    def display_name(self) -> str:
        """Human-readable name for UI listings."""
        return f"JASCO ({self._variant})"

    @property
    def description(self) -> str:
        """One-line description for UI listings."""
        return "Chord and drum-conditioned music generation"

    @property
    def vram_estimate_mb(self) -> int:
        """Estimated VRAM footprint of this variant, in megabytes."""
        return self._config["vram_mb"]

    @property
    def max_duration(self) -> float:
        """Longest clip length (in seconds) this variant supports."""
        return self._config["max_duration"]

    @property
    def sample_rate(self) -> int:
        """Output sample rate in Hz (32 kHz fallback until the model loads)."""
        if self._model is not None:
            return self._model.sample_rate
        return 32000

    @property
    def supports_conditioning(self) -> list[ConditioningType]:
        """JASCO accepts text, chord, and drum conditioning."""
        return [ConditioningType.TEXT, ConditioningType.CHORDS, ConditioningType.DRUMS]

    @property
    def is_loaded(self) -> bool:
        """Whether the underlying model is currently resident in memory."""
        return self._model is not None

    @property
    def device(self) -> Optional[torch.device]:
        """Device the model lives on, or ``None`` when unloaded."""
        return self._device

    def load(self, device: str = "cuda") -> None:
        """Load the JASCO model.

        No-op (with a warning) when already loaded. On failure the adapter
        is left in the unloaded state and a RuntimeError is raised.
        """
        if self._model is not None:
            logger.warning(f"JASCO {self._variant} already loaded")
            return

        logger.info(f"Loading JASCO {self._variant} from {self._config['hf_id']}...")

        try:
            # Imported lazily so the adapter can be constructed without
            # audiocraft installed.
            from audiocraft.models import JASCO

            self._device = torch.device(device)
            self._model = JASCO.get_pretrained(self._config["hf_id"])
            self._model.to(self._device)

            logger.info(
                f"JASCO {self._variant} loaded successfully "
                f"(sample_rate={self._model.sample_rate})"
            )

        except Exception as e:
            # Reset to a clean "unloaded" state before propagating.
            self._model = None
            self._device = None
            logger.error(f"Failed to load JASCO {self._variant}: {e}")
            raise RuntimeError(f"Failed to load JASCO: {e}") from e

    def unload(self) -> None:
        """Unload the model and free memory."""
        if self._model is None:
            return

        logger.info(f"Unloading JASCO {self._variant}...")

        del self._model
        self._model = None
        self._device = None

        # Drop Python references, then ask CUDA to return cached blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def parse_chord_progression(
        chords: list[dict[str, Any]], duration: float
    ) -> list[tuple[float, float, str]]:
        """Parse chord progression from user input format.

        Args:
            chords: List of chord dictionaries with keys:
                - time: Start time in seconds
                - chord: Chord name (e.g., "C", "Am", "G7")
            duration: Total duration for calculating end times

        Returns:
            List of (start_time, end_time, chord_name) tuples

        Example input:
            [
                {"time": 0.0, "chord": "C"},
                {"time": 2.0, "chord": "Am"},
                {"time": 4.0, "chord": "F"},
                {"time": 6.0, "chord": "G"},
            ]
        """
        if not chords:
            return []

        # Sort by time
        sorted_chords = sorted(chords, key=lambda x: x["time"])

        # Build (start, end, chord) tuples
        result = []
        for i, chord_info in enumerate(sorted_chords):
            start = chord_info["time"]
            # End time is either next chord's start or total duration
            # NOTE(review): a chord whose "time" exceeds `duration` yields
            # end < start for the final entry — no clamping is done here.
            if i + 1 < len(sorted_chords):
                end = sorted_chords[i + 1]["time"]
            else:
                end = duration
            result.append((start, end, chord_info["chord"]))

        return result

    @staticmethod
    def create_drum_pattern(
        pattern: str, duration: float, bpm: float = 120.0
    ) -> list[tuple[float, str]]:
        """Create drum events from a pattern string.

        Args:
            pattern: Pattern string (e.g., "kick,snare,kick,snare")
                or "4/4" for common time signature
            duration: Total duration in seconds
            bpm: Beats per minute (must be > 0; bpm == 0 raises
                ZeroDivisionError below)

        Returns:
            List of (time, drum_type) tuples
        """
        beat_duration = 60.0 / bpm
        events = []

        if pattern in ["4/4", "common"]:
            # Standard 4/4 rock pattern: kick on 1, snare on 3,
            # hi-hat on every beat, stepping in eighth notes.
            time = 0.0
            beat = 0
            while time < duration:
                if beat % 4 == 0:
                    events.append((time, "kick"))
                elif beat % 4 == 2:
                    events.append((time, "snare"))
                if beat % 2 == 0:
                    events.append((time, "hihat"))
                time += beat_duration / 2
                beat += 1
        else:
            # Parse comma-separated pattern, cycling through the listed
            # drums one per beat; blank entries produce rests.
            drum_types = pattern.split(",")
            time = 0.0
            idx = 0
            while time < duration:
                drum = drum_types[idx % len(drum_types)].strip()
                if drum:
                    events.append((time, drum))
                time += beat_duration
                idx += 1

        return events

    def generate(self, request: GenerationRequest) -> GenerationResult:
        """Generate music with chord and drum conditioning.

        Args:
            request: Generation parameters with optional conditioning:
                - chords: List of {"time": float, "chord": str} dicts
                - drums: Drum pattern string or list of (time, drum_type)
                - bpm: Beats per minute for drum pattern

        Returns:
            GenerationResult with audio tensor and metadata
        """
        self.validate_request(request)

        # Set random seed (CPU and CUDA) so runs are reproducible.
        seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

        # Configure generation parameters
        self._model.set_generation_params(
            duration=request.duration,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            cfg_coef=request.cfg_coef,
        )

        # Process chord conditioning. Dict-style input is converted to
        # (start, end, chord) tuples; anything else is assumed to already
        # be in tuple form.
        chords_input = request.conditioning.get("chords")
        chords_formatted = None
        if chords_input:
            if isinstance(chords_input, list) and len(chords_input) > 0:
                if isinstance(chords_input[0], dict):
                    chords_formatted = self.parse_chord_progression(
                        chords_input, request.duration
                    )
                else:
                    # Already in (start, end, chord) format
                    chords_formatted = chords_input

        # Process drum conditioning: strings are expanded into timed
        # events; non-strings are passed through unchanged.
        drums_input = request.conditioning.get("drums")
        bpm = request.conditioning.get("bpm", 120.0)
        drums_formatted = None
        if drums_input:
            if isinstance(drums_input, str):
                drums_formatted = self.create_drum_pattern(
                    drums_input, request.duration, bpm
                )
            else:
                drums_formatted = drums_input

        logger.info(
            f"Generating {len(request.prompts)} sample(s) with JASCO "
            f"(duration={request.duration}s, chords={chords_formatted is not None}, "
            f"drums={drums_formatted is not None})"
        )

        with torch.inference_mode():
            # Build conditioning dict for JASCO. Note: empty parse results
            # (falsy lists) are treated the same as "no conditioning".
            # NOTE(review): kwargs `chords=`/`drums=` are forwarded to
            # JASCO.generate as-is — confirm against the audiocraft API.
            conditioning = {}
            if chords_formatted:
                conditioning["chords"] = chords_formatted
            if drums_formatted:
                conditioning["drums"] = drums_formatted

            if conditioning:
                audio = self._model.generate(
                    descriptions=request.prompts,
                    **conditioning,
                )
            else:
                # Generate without symbolic conditioning
                audio = self._model.generate(request.prompts)

        # Measured from the returned waveform; may differ slightly from
        # the requested duration.
        actual_duration = audio.shape[-1] / self.sample_rate

        logger.info(
            f"Generated {audio.shape[0]} sample(s), "
            f"duration={actual_duration:.2f}s"
        )

        return GenerationResult(
            audio=audio.cpu(),
            sample_rate=self.sample_rate,
            duration=actual_duration,
            model_id=self.model_id,
            variant=self._variant,
            parameters={
                "duration": request.duration,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p,
                "cfg_coef": request.cfg_coef,
                "prompts": request.prompts,
                "chords": chords_formatted,
                "drums": drums_formatted,
                "bpm": bpm,
            },
            seed=seed,
        )

    def get_default_params(self) -> dict[str, Any]:
        """Get default generation parameters for JASCO."""
        return {
            "duration": 10.0,
            "temperature": 1.0,
            "top_k": 250,
            "top_p": 0.0,
            "cfg_coef": 3.0,
            "bpm": 120.0,
        }
||||
5
src/models/magnet/__init__.py
Normal file
5
src/models/magnet/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""MAGNeT model adapter."""
|
||||
|
||||
from src.models.magnet.adapter import MAGNeTAdapter
|
||||
|
||||
__all__ = ["MAGNeTAdapter"]
|
||||
253
src/models/magnet/adapter.py
Normal file
253
src/models/magnet/adapter.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""MAGNeT model adapter for fast non-autoregressive audio generation."""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import random
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from src.core.base_model import (
|
||||
BaseAudioModel,
|
||||
ConditioningType,
|
||||
GenerationRequest,
|
||||
GenerationResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MAGNeTAdapter(BaseAudioModel):
    """Adapter for Facebook's MAGNeT model.

    MAGNeT (Masked Audio Generation using Non-autoregressive Transformers)
    provides faster generation than autoregressive models like MusicGen.
    Supports both music and sound effect generation.
    """

    # Per-variant config: HF checkpoint id, rough VRAM need (MB),
    # generation cap (seconds), channel count, and whether the variant
    # targets music or general sound effects.
    VARIANTS = {
        "small-10secs": {
            "hf_id": "facebook/magnet-small-10secs",
            "vram_mb": 1500,
            "max_duration": 10,
            "channels": 1,
            "audio_type": "music",
        },
        "medium-10secs": {
            "hf_id": "facebook/magnet-medium-10secs",
            "vram_mb": 5000,
            "max_duration": 10,
            "channels": 1,
            "audio_type": "music",
        },
        "small-30secs": {
            "hf_id": "facebook/magnet-small-30secs",
            "vram_mb": 1800,
            "max_duration": 30,
            "channels": 1,
            "audio_type": "music",
        },
        "medium-30secs": {
            "hf_id": "facebook/magnet-medium-30secs",
            "vram_mb": 6000,
            "max_duration": 30,
            "channels": 1,
            "audio_type": "music",
        },
        "audio-small-10secs": {
            "hf_id": "facebook/audio-magnet-small",
            "vram_mb": 1500,
            "max_duration": 10,
            "channels": 1,
            "audio_type": "sound",
        },
        "audio-medium-10secs": {
            "hf_id": "facebook/audio-magnet-medium",
            "vram_mb": 5000,
            "max_duration": 10,
            "channels": 1,
            "audio_type": "sound",
        },
    }

    def __init__(self, variant: str = "medium-10secs"):
        """Initialize MAGNeT adapter.

        Args:
            variant: Model variant to use

        Raises:
            ValueError: If ``variant`` is not a key of ``VARIANTS``.
        """
        if variant not in self.VARIANTS:
            raise ValueError(
                f"Unknown MAGNeT variant: {variant}. "
                f"Available: {list(self.VARIANTS.keys())}"
            )

        self._variant = variant
        self._config = self.VARIANTS[variant]
        self._model = None  # populated lazily by load()
        self._device: Optional[torch.device] = None

    @property
    def model_id(self) -> str:
        """Stable identifier for this model family."""
        return "magnet"

    @property
    def variant(self) -> str:
        """The variant name this adapter was constructed with."""
        return self._variant

    @property
    def display_name(self) -> str:
        """Human-readable name for UI listings."""
        return f"MAGNeT ({self._variant})"

    @property
    def description(self) -> str:
        """One-line description reflecting the variant's audio type."""
        audio_type = self._config.get("audio_type", "music")
        return f"Fast non-autoregressive {audio_type} generation"

    @property
    def vram_estimate_mb(self) -> int:
        """Estimated VRAM footprint of this variant, in megabytes."""
        return self._config["vram_mb"]

    @property
    def max_duration(self) -> float:
        """Longest clip length (in seconds) this variant supports."""
        return self._config["max_duration"]

    @property
    def sample_rate(self) -> int:
        """Output sample rate in Hz (32 kHz fallback until the model loads)."""
        if self._model is not None:
            return self._model.sample_rate
        return 32000

    @property
    def supports_conditioning(self) -> list[ConditioningType]:
        """MAGNeT accepts text conditioning only."""
        return [ConditioningType.TEXT]

    @property
    def is_loaded(self) -> bool:
        """Whether the underlying model is currently resident in memory."""
        return self._model is not None

    @property
    def device(self) -> Optional[torch.device]:
        """Device the model lives on, or ``None`` when unloaded."""
        return self._device

    def load(self, device: str = "cuda") -> None:
        """Load the MAGNeT model.

        No-op (with a warning) when already loaded. On failure the adapter
        is left in the unloaded state and a RuntimeError is raised.
        """
        if self._model is not None:
            logger.warning(f"MAGNeT {self._variant} already loaded")
            return

        logger.info(f"Loading MAGNeT {self._variant} from {self._config['hf_id']}...")

        try:
            # Imported lazily so the adapter can be constructed without
            # audiocraft installed.
            from audiocraft.models import MAGNeT

            self._device = torch.device(device)
            self._model = MAGNeT.get_pretrained(self._config["hf_id"])
            self._model.to(self._device)

            logger.info(
                f"MAGNeT {self._variant} loaded successfully "
                f"(sample_rate={self._model.sample_rate})"
            )

        except Exception as e:
            # Reset to a clean "unloaded" state before propagating.
            self._model = None
            self._device = None
            logger.error(f"Failed to load MAGNeT {self._variant}: {e}")
            raise RuntimeError(f"Failed to load MAGNeT: {e}") from e

    def unload(self) -> None:
        """Unload the model and free memory."""
        if self._model is None:
            return

        logger.info(f"Unloading MAGNeT {self._variant}...")

        del self._model
        self._model = None
        self._device = None

        # Drop Python references, then ask CUDA to return cached blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def generate(self, request: GenerationRequest) -> GenerationResult:
        """Generate audio from text prompts using MAGNeT.

        MAGNeT uses a non-autoregressive approach with iterative decoding,
        which is significantly faster than autoregressive models.

        Args:
            request: Generation parameters including prompts

        Returns:
            GenerationResult with audio tensor and metadata
        """
        self.validate_request(request)

        # Set random seed (CPU and CUDA) so runs are reproducible.
        seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

        # Configure generation parameters
        # MAGNeT has different parameters than MusicGen
        self._model.set_generation_params(
            duration=request.duration,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            cfg_coef=request.cfg_coef,
            # MAGNeT-specific parameters
            # NOTE(review): four entries, overridable via request.conditioning
            # keys decoding_steps_1..4 — presumably one per codebook stage;
            # confirm against the audiocraft MAGNeT API.
            decoding_steps=[
                int(request.conditioning.get("decoding_steps_1", 20)),
                int(request.conditioning.get("decoding_steps_2", 10)),
                int(request.conditioning.get("decoding_steps_3", 10)),
                int(request.conditioning.get("decoding_steps_4", 10)),
            ],
            span_arrangement=request.conditioning.get("span_arrangement", "nonoverlap"),
        )

        logger.info(
            f"Generating {len(request.prompts)} sample(s) with MAGNeT {self._variant} "
            f"(duration={request.duration}s)"
        )

        # Generate audio
        with torch.inference_mode():
            audio = self._model.generate(request.prompts)

        # Measured from the returned waveform; may differ slightly from
        # the requested duration.
        actual_duration = audio.shape[-1] / self.sample_rate

        logger.info(
            f"Generated {audio.shape[0]} sample(s), "
            f"duration={actual_duration:.2f}s"
        )

        return GenerationResult(
            audio=audio.cpu(),
            sample_rate=self.sample_rate,
            duration=actual_duration,
            model_id=self.model_id,
            variant=self._variant,
            parameters={
                "duration": request.duration,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p,
                "cfg_coef": request.cfg_coef,
                "prompts": request.prompts,
            },
            seed=seed,
        )

    def get_default_params(self) -> dict[str, Any]:
        """Get default generation parameters for MAGNeT."""
        return {
            "duration": 10.0,
            "temperature": 3.0,  # MAGNeT works better with higher temperature
            "top_k": 0,  # Use top_p instead for MAGNeT
            "top_p": 0.9,
            "cfg_coef": 3.0,
        }
||||
5
src/models/musicgen/__init__.py
Normal file
5
src/models/musicgen/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""MusicGen model adapter."""
|
||||
|
||||
from src.models.musicgen.adapter import MusicGenAdapter
|
||||
|
||||
__all__ = ["MusicGenAdapter"]
|
||||
290
src/models/musicgen/adapter.py
Normal file
290
src/models/musicgen/adapter.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""MusicGen model adapter for text-to-music generation."""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import random
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from src.core.base_model import (
|
||||
BaseAudioModel,
|
||||
ConditioningType,
|
||||
GenerationRequest,
|
||||
GenerationResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MusicGenAdapter(BaseAudioModel):
    """Adapter for Facebook's MusicGen model.

    Supports text-to-music generation with optional melody conditioning.
    Available variants: small, medium, large, melody, and stereo versions.
    """

    # Variant configurations: HF checkpoint id, rough VRAM need (MB),
    # generation cap (seconds), channel count, and extra conditioning
    # types the variant supports beyond text.
    VARIANTS = {
        "small": {
            "hf_id": "facebook/musicgen-small",
            "vram_mb": 1500,
            "max_duration": 30,
            "channels": 1,
            "conditioning": [],
        },
        "medium": {
            "hf_id": "facebook/musicgen-medium",
            "vram_mb": 5000,
            "max_duration": 30,
            "channels": 1,
            "conditioning": [],
        },
        "large": {
            "hf_id": "facebook/musicgen-large",
            "vram_mb": 10000,
            "max_duration": 30,
            "channels": 1,
            "conditioning": [],
        },
        "melody": {
            "hf_id": "facebook/musicgen-melody",
            "vram_mb": 5000,
            "max_duration": 30,
            "channels": 1,
            "conditioning": [ConditioningType.MELODY],
        },
        "stereo-small": {
            "hf_id": "facebook/musicgen-stereo-small",
            "vram_mb": 1800,
            "max_duration": 30,
            "channels": 2,
            "conditioning": [],
        },
        "stereo-medium": {
            "hf_id": "facebook/musicgen-stereo-medium",
            "vram_mb": 6000,
            "max_duration": 30,
            "channels": 2,
            "conditioning": [],
        },
        "stereo-large": {
            "hf_id": "facebook/musicgen-stereo-large",
            "vram_mb": 12000,
            "max_duration": 30,
            "channels": 2,
            "conditioning": [],
        },
        "stereo-melody": {
            "hf_id": "facebook/musicgen-stereo-melody",
            "vram_mb": 6000,
            "max_duration": 30,
            "channels": 2,
            "conditioning": [ConditioningType.MELODY],
        },
    }

    def __init__(self, variant: str = "medium"):
        """Initialize MusicGen adapter.

        Args:
            variant: Model variant to use (small, medium, large, melody, etc.)

        Raises:
            ValueError: If variant is not recognized
        """
        if variant not in self.VARIANTS:
            raise ValueError(
                f"Unknown MusicGen variant: {variant}. "
                f"Available: {list(self.VARIANTS.keys())}"
            )

        self._variant = variant
        self._config = self.VARIANTS[variant]
        self._model = None  # populated lazily by load()
        self._device: Optional[torch.device] = None

    @property
    def model_id(self) -> str:
        """Stable identifier for this model family."""
        return "musicgen"

    @property
    def variant(self) -> str:
        """The variant name this adapter was constructed with."""
        return self._variant

    @property
    def display_name(self) -> str:
        """Human-readable name for UI listings."""
        return f"MusicGen ({self._variant})"

    @property
    def description(self) -> str:
        """One-line description derived from the variant name."""
        if "melody" in self._variant:
            return "Text-to-music with melody conditioning"
        elif "stereo" in self._variant:
            return "Stereo text-to-music generation"
        return "Text-to-music generation"

    @property
    def vram_estimate_mb(self) -> int:
        """Estimated VRAM footprint of this variant, in megabytes."""
        return self._config["vram_mb"]

    @property
    def max_duration(self) -> float:
        """Longest clip length (in seconds) this variant supports."""
        return self._config["max_duration"]

    @property
    def sample_rate(self) -> int:
        """Output sample rate in Hz (32 kHz fallback until the model loads)."""
        if self._model is not None:
            return self._model.sample_rate
        return 32000  # Default MusicGen sample rate

    @property
    def supports_conditioning(self) -> list[ConditioningType]:
        """Text plus any variant-specific conditioning (e.g. melody)."""
        return [ConditioningType.TEXT] + self._config["conditioning"]

    @property
    def is_loaded(self) -> bool:
        """Whether the underlying model is currently resident in memory."""
        return self._model is not None

    @property
    def device(self) -> Optional[torch.device]:
        """Device the model lives on, or ``None`` when unloaded."""
        return self._device

    def load(self, device: str = "cuda") -> None:
        """Load the MusicGen model.

        Args:
            device: Target device ('cuda', 'cuda:0', 'cpu', etc.)

        Raises:
            RuntimeError: If download or initialization fails; internal
                state is reset to "unloaded" first.
        """
        if self._model is not None:
            logger.warning(f"MusicGen {self._variant} already loaded")
            return

        logger.info(f"Loading MusicGen {self._variant} from {self._config['hf_id']}...")

        try:
            # Imported lazily so the adapter can be constructed without
            # audiocraft installed.
            from audiocraft.models import MusicGen

            self._device = torch.device(device)
            self._model = MusicGen.get_pretrained(self._config["hf_id"])
            self._model.to(self._device)

            logger.info(
                f"MusicGen {self._variant} loaded successfully "
                f"(sample_rate={self._model.sample_rate})"
            )

        except Exception as e:
            # Reset to a clean "unloaded" state before propagating.
            self._model = None
            self._device = None
            logger.error(f"Failed to load MusicGen {self._variant}: {e}")
            raise RuntimeError(f"Failed to load MusicGen: {e}") from e

    def unload(self) -> None:
        """Unload the model and free memory."""
        if self._model is None:
            return

        logger.info(f"Unloading MusicGen {self._variant}...")

        del self._model
        self._model = None
        self._device = None

        # Drop Python references, then ask CUDA to return cached blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def generate(self, request: GenerationRequest) -> GenerationResult:
        """Generate music from text prompts.

        Args:
            request: Generation parameters including prompts

        Returns:
            GenerationResult with audio tensor and metadata

        Raises:
            RuntimeError: If model not loaded
            ValueError: If request is invalid
        """
        self.validate_request(request)

        # Set random seed for reproducibility
        seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

        # Configure generation parameters
        self._model.set_generation_params(
            duration=request.duration,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            cfg_coef=request.cfg_coef,
        )

        logger.info(
            f"Generating {len(request.prompts)} sample(s) with MusicGen {self._variant} "
            f"(duration={request.duration}s, temp={request.temperature})"
        )

        # Generate audio
        with torch.inference_mode():
            # Melody conditioning is honored only for melody-capable
            # variants; other variants silently ignore it.
            melody_audio = request.conditioning.get("melody")
            melody_sr = request.conditioning.get("melody_sr", self.sample_rate)

            if melody_audio is not None and ConditioningType.MELODY in self.supports_conditioning:
                # Melody-conditioned generation
                if isinstance(melody_audio, str):
                    # Load from file path; torchaudio returns
                    # (waveform, sample_rate), overriding melody_sr.
                    import torchaudio
                    melody_tensor, melody_sr = torchaudio.load(melody_audio)
                    melody_tensor = melody_tensor.to(self._device)
                else:
                    # NOTE(review): torch.tensor() copies (and warns on
                    # tensor input) — torch.as_tensor would avoid that;
                    # left as-is since behavior is otherwise identical.
                    melody_tensor = torch.tensor(melody_audio).to(self._device)

                # 1-D melody is promoted to a batch of one waveform.
                audio = self._model.generate_with_chroma(
                    descriptions=request.prompts,
                    melody_wavs=melody_tensor.unsqueeze(0) if melody_tensor.dim() == 1 else melody_tensor,
                    melody_sample_rate=melody_sr,
                )
            else:
                # Standard text-to-music generation
                audio = self._model.generate(request.prompts)

        # audio shape: [batch, channels, samples]
        actual_duration = audio.shape[-1] / self.sample_rate

        logger.info(
            f"Generated {audio.shape[0]} sample(s), "
            f"duration={actual_duration:.2f}s, shape={audio.shape}"
        )

        return GenerationResult(
            audio=audio.cpu(),
            sample_rate=self.sample_rate,
            duration=actual_duration,
            model_id=self.model_id,
            variant=self._variant,
            parameters={
                "duration": request.duration,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p,
                "cfg_coef": request.cfg_coef,
                "prompts": request.prompts,
            },
            seed=seed,
        )

    def get_default_params(self) -> dict[str, Any]:
        """Get default generation parameters."""
        return {
            "duration": 10.0,
            "temperature": 1.0,
            "top_k": 250,
            "top_p": 0.0,
            "cfg_coef": 3.0,
        }
||||
5
src/models/musicgen_style/__init__.py
Normal file
5
src/models/musicgen_style/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""MusicGen Style model adapter."""
|
||||
|
||||
from src.models.musicgen_style.adapter import MusicGenStyleAdapter
|
||||
|
||||
__all__ = ["MusicGenStyleAdapter"]
|
||||
277
src/models/musicgen_style/adapter.py
Normal file
277
src/models/musicgen_style/adapter.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""MusicGen Style model adapter for style-conditioned music generation."""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import random
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
from src.core.base_model import (
|
||||
BaseAudioModel,
|
||||
ConditioningType,
|
||||
GenerationRequest,
|
||||
GenerationResult,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MusicGenStyleAdapter(BaseAudioModel):
    """Adapter for Facebook's MusicGen Style model.

    Generates music conditioned on both text and a style reference audio.
    Extracts style features from the reference and applies them to new
    generations via audiocraft's chroma/style conditioning path.
    """

    # Variant table. Only one public checkpoint exists for musicgen-style.
    VARIANTS = {
        "medium": {
            "hf_id": "facebook/musicgen-style",
            "vram_mb": 5000,  # rough VRAM footprint used for load planning
            "max_duration": 30,  # seconds of audio per generation
            "channels": 1,
        },
    }

    def __init__(self, variant: str = "medium"):
        """Initialize MusicGen Style adapter.

        Args:
            variant: Model variant (currently only 'medium' available)

        Raises:
            ValueError: If the variant is not listed in VARIANTS.
        """
        if variant not in self.VARIANTS:
            raise ValueError(
                f"Unknown MusicGen Style variant: {variant}. "
                f"Available: {list(self.VARIANTS.keys())}"
            )

        self._variant = variant
        self._config = self.VARIANTS[variant]
        self._model = None  # set by load(), cleared by unload()
        self._device: Optional[torch.device] = None

    @property
    def model_id(self) -> str:
        """Stable identifier used by the registry and in saved filenames."""
        return "musicgen-style"

    @property
    def variant(self) -> str:
        """Currently selected variant name."""
        return self._variant

    @property
    def display_name(self) -> str:
        """Human-readable name for UI listings."""
        return f"MusicGen Style ({self._variant})"

    @property
    def description(self) -> str:
        """Short description for UI listings."""
        return "Style-conditioned music generation from reference audio"

    @property
    def vram_estimate_mb(self) -> int:
        """Estimated VRAM required to host this variant, in MB."""
        return self._config["vram_mb"]

    @property
    def max_duration(self) -> float:
        """Maximum supported generation length in seconds."""
        return self._config["max_duration"]

    @property
    def sample_rate(self) -> int:
        """Output sample rate; falls back to 32 kHz before the model loads."""
        if self._model is not None:
            return self._model.sample_rate
        return 32000

    @property
    def supports_conditioning(self) -> list[ConditioningType]:
        """Conditioning kinds this adapter accepts."""
        return [ConditioningType.TEXT, ConditioningType.STYLE]

    @property
    def is_loaded(self) -> bool:
        """Whether the underlying model is resident in memory."""
        return self._model is not None

    @property
    def device(self) -> Optional[torch.device]:
        """Device the model lives on, or None when unloaded."""
        return self._device

    def load(self, device: str = "cuda") -> None:
        """Load the MusicGen Style model onto the given device.

        Idempotent: a second call while loaded only logs a warning.

        Raises:
            RuntimeError: If the checkpoint cannot be loaded; internal state
                is reset to "unloaded" before re-raising.
        """
        if self._model is not None:
            logger.warning(f"MusicGen Style {self._variant} already loaded")
            return

        logger.info(f"Loading MusicGen Style {self._variant}...")

        try:
            # Imported lazily so the adapter can be constructed (e.g. for
            # registry listing) without audiocraft installed.
            from audiocraft.models import MusicGen

            self._device = torch.device(device)
            self._model = MusicGen.get_pretrained(self._config["hf_id"])
            self._model.to(self._device)

            logger.info(
                f"MusicGen Style {self._variant} loaded successfully "
                f"(sample_rate={self._model.sample_rate})"
            )

        except Exception as e:
            # Leave the adapter in a clean unloaded state on failure.
            self._model = None
            self._device = None
            logger.error(f"Failed to load MusicGen Style {self._variant}: {e}")
            raise RuntimeError(f"Failed to load MusicGen Style: {e}") from e

    def unload(self) -> None:
        """Unload the model and free host and GPU memory. No-op if unloaded."""
        if self._model is None:
            return

        logger.info(f"Unloading MusicGen Style {self._variant}...")

        del self._model
        self._model = None
        self._device = None

        # Reclaim host memory first, then release cached CUDA allocations.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _load_style_audio(
        self, style_input: Any, target_sr: int
    ) -> tuple[torch.Tensor, int]:
        """Load and prepare style reference audio.

        Args:
            style_input: File path, tensor, or numpy array
            target_sr: Target sample rate

        Returns:
            Tuple of (audio_tensor, sample_rate)
        """
        if isinstance(style_input, str):
            # Load from file and resample to the model's rate if needed.
            audio, sr = torchaudio.load(style_input)
            if sr != target_sr:
                audio = torchaudio.functional.resample(audio, sr, target_sr)
            return audio.to(self._device), target_sr
        elif isinstance(style_input, torch.Tensor):
            # NOTE(review): tensors are assumed to already be at target_sr —
            # no resampling is applied on this path; confirm callers honor this.
            return style_input.to(self._device), target_sr
        else:
            # Assume numpy array (same no-resample assumption as tensors).
            return torch.tensor(style_input).to(self._device), target_sr

    def generate(self, request: GenerationRequest) -> GenerationResult:
        """Generate music conditioned on text and style reference.

        Args:
            request: Generation parameters including prompts and style conditioning

        Returns:
            GenerationResult with audio tensor (moved to CPU) and metadata

        Raises:
            RuntimeError: If the model has not been loaded via load().

        Note:
            Style conditioning requires 'style' in request.conditioning with either:
            - File path to audio
            - Audio tensor
            - Numpy array
        """
        # Fail fast with a clear message instead of an AttributeError on
        # self._model below when generate() is called before load().
        if self._model is None:
            raise RuntimeError(
                "MusicGen Style model is not loaded; call load() before generate()"
            )

        self.validate_request(request)

        # Set random seed for reproducibility; record the one actually used.
        seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

        # Get style conditioning parameters (eval_q / excerpt_length tune the
        # style conditioner; defaults mirror get_default_params()).
        style_audio = request.conditioning.get("style")
        eval_q = request.conditioning.get("eval_q", 3)
        excerpt_length = request.conditioning.get("excerpt_length", 3.0)

        # Configure generation parameters on the underlying audiocraft model.
        self._model.set_generation_params(
            duration=request.duration,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            cfg_coef=request.cfg_coef,
        )

        logger.info(
            f"Generating {len(request.prompts)} sample(s) with MusicGen Style "
            f"(duration={request.duration}s, style_conditioned={style_audio is not None})"
        )

        with torch.inference_mode():
            if style_audio is not None:
                # Load style reference at the model's sample rate.
                style_tensor, style_sr = self._load_style_audio(
                    style_audio, self.sample_rate
                )

                # Ensure proper shape [batch, channels, samples]
                if style_tensor.dim() == 1:
                    style_tensor = style_tensor.unsqueeze(0).unsqueeze(0)
                elif style_tensor.dim() == 2:
                    style_tensor = style_tensor.unsqueeze(0)

                # Set style conditioner parameters when the checkpoint
                # exposes them (hasattr-guarded for older audiocraft builds).
                if hasattr(self._model, 'set_style_conditioner_params'):
                    self._model.set_style_conditioner_params(
                        eval_q=eval_q,
                        excerpt_length=excerpt_length,
                    )

                # Generate with style conditioning.
                # Expand a single style reference to match number of prompts.
                if style_tensor.shape[0] == 1 and len(request.prompts) > 1:
                    style_tensor = style_tensor.expand(len(request.prompts), -1, -1)

                audio = self._model.generate_with_chroma(
                    descriptions=request.prompts,
                    melody_wavs=style_tensor,
                    melody_sample_rate=style_sr,
                )
            else:
                # Generate without style (falls back to standard MusicGen behavior)
                logger.warning(
                    "No style reference provided, generating without style conditioning"
                )
                audio = self._model.generate(request.prompts)

        # Actual duration may differ slightly from the requested target.
        actual_duration = audio.shape[-1] / self.sample_rate

        logger.info(
            f"Generated {audio.shape[0]} sample(s), "
            f"duration={actual_duration:.2f}s"
        )

        return GenerationResult(
            audio=audio.cpu(),
            sample_rate=self.sample_rate,
            duration=actual_duration,
            model_id=self.model_id,
            variant=self._variant,
            parameters={
                "duration": request.duration,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p,
                "cfg_coef": request.cfg_coef,
                "prompts": request.prompts,
                "style_conditioned": style_audio is not None,
                "eval_q": eval_q,
                "excerpt_length": excerpt_length,
            },
            seed=seed,
        )

    def get_default_params(self) -> dict[str, Any]:
        """Get default generation parameters for MusicGen Style."""
        return {
            "duration": 10.0,
            "temperature": 1.0,
            "top_k": 250,
            "top_p": 0.0,
            "cfg_coef": 3.0,
            "eval_q": 3,
            "excerpt_length": 3.0,
        }
|
||||
13
src/services/__init__.py
Normal file
13
src/services/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Services layer for AudioCraft Studio."""
|
||||
|
||||
from src.services.generation_service import GenerationService
|
||||
from src.services.batch_processor import BatchProcessor, GenerationJob, JobStatus
|
||||
from src.services.project_service import ProjectService
|
||||
|
||||
__all__ = [
|
||||
"GenerationService",
|
||||
"BatchProcessor",
|
||||
"GenerationJob",
|
||||
"JobStatus",
|
||||
"ProjectService",
|
||||
]
|
||||
397
src/services/batch_processor.py
Normal file
397
src/services/batch_processor.py
Normal file
@@ -0,0 +1,397 @@
|
||||
"""Batch processor for queued audio generation jobs."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JobStatus(str, Enum):
    """Status of a generation job.

    Inherits from str so values serialize directly (e.g. into JSON
    responses via GenerationJob.to_dict) and compare equal to plain strings.
    """

    PENDING = "pending"        # queued, not yet picked up by a worker
    PROCESSING = "processing"  # a worker is currently generating
    COMPLETED = "completed"    # finished successfully
    FAILED = "failed"          # generation raised an error
    CANCELLED = "cancelled"    # cancelled while still pending
|
||||
|
||||
|
||||
@dataclass
class GenerationJob:
    """A queued generation job: request parameters plus status/result tracking."""

    id: str
    model_id: str
    variant: Optional[str]
    prompts: list[str]
    parameters: dict[str, Any]
    conditioning: dict[str, Any]
    project_id: Optional[str]
    preset_used: Optional[str]
    tags: list[str]

    # Status tracking
    status: JobStatus = JobStatus.PENDING
    progress: float = 0.0
    progress_message: str = ""
    created_at: datetime = field(default_factory=datetime.utcnow)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    # Results
    result_id: Optional[str] = None  # Generation ID if completed
    audio_path: Optional[str] = None
    error: Optional[str] = None

    @classmethod
    def create(
        cls,
        model_id: str,
        variant: Optional[str],
        prompts: list[str],
        duration: float = 10.0,
        temperature: float = 1.0,
        top_k: int = 250,
        top_p: float = 0.0,
        cfg_coef: float = 3.0,
        seed: Optional[int] = None,
        conditioning: Optional[dict[str, Any]] = None,
        project_id: Optional[str] = None,
        preset_used: Optional[str] = None,
        tags: Optional[list[str]] = None,
    ) -> "GenerationJob":
        """Build a new job with a fresh random ID and packed sampling parameters."""
        # Collect the sampling knobs into the parameters dict consumed later
        # by the batch processor / generation service.
        sampling = {
            "duration": duration,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
            "cfg_coef": cfg_coef,
            "seed": seed,
        }
        job_id = "job_" + uuid.uuid4().hex[:12]
        return cls(
            id=job_id,
            model_id=model_id,
            variant=variant,
            prompts=prompts,
            parameters=sampling,
            conditioning=conditioning or {},
            project_id=project_id,
            preset_used=preset_used,
            tags=tags or [],
        )

    def to_dict(self) -> dict[str, Any]:
        """Convert job to a JSON-friendly dictionary for API responses."""

        def _iso(moment: Optional[datetime]) -> Optional[str]:
            # started_at / completed_at may be unset for pending jobs.
            return moment.isoformat() if moment is not None else None

        return {
            "id": self.id,
            "model_id": self.model_id,
            "variant": self.variant,
            "prompts": self.prompts,
            "parameters": self.parameters,
            "status": self.status.value,
            "progress": self.progress,
            "progress_message": self.progress_message,
            "created_at": self.created_at.isoformat(),
            "started_at": _iso(self.started_at),
            "completed_at": _iso(self.completed_at),
            "result_id": self.result_id,
            "audio_path": self.audio_path,
            "error": self.error,
        }
|
||||
|
||||
|
||||
class BatchProcessor:
    """Manages a queue of generation jobs.

    Features:
    - Async job queue with configurable concurrency
    - Progress tracking and callbacks
    - Job cancellation
    - Priority support (future enhancement)

    Jobs are kept in an in-memory dict (no persistence across restarts);
    cleanup_completed() must be called periodically to bound memory use.
    """

    def __init__(
        self,
        generation_service: Any,  # Avoid circular import
        max_queue_size: int = 100,
        max_concurrent: int = 1,  # GPU operations should be serialized
    ):
        """Initialize batch processor.

        Args:
            generation_service: GenerationService instance
            max_queue_size: Maximum jobs in queue
            max_concurrent: Maximum concurrent generations (usually 1 for GPU)
        """
        self.generation_service = generation_service
        self.max_queue_size = max_queue_size
        self.max_concurrent = max_concurrent

        # Job tracking: id -> job for ALL jobs ever submitted (until cleanup),
        # plus a FIFO of pending job ids for the workers.
        self._jobs: dict[str, GenerationJob] = {}
        self._queue: asyncio.Queue[str] = asyncio.Queue(maxsize=max_queue_size)

        # Processing control
        self._workers: list[asyncio.Task] = []
        self._running = False
        self._lock = asyncio.Lock()  # guards submit/cancel job-state updates

        # Callbacks (synchronous; invoked from the worker coroutine)
        self._on_job_complete: list[Callable[[GenerationJob], None]] = []
        self._on_job_failed: list[Callable[[GenerationJob], None]] = []
        self._on_progress: list[Callable[[GenerationJob], None]] = []

    async def start(self) -> None:
        """Start the batch processor workers. Idempotent while running."""
        if self._running:
            return

        self._running = True

        # Start worker tasks; each polls the shared queue independently.
        for i in range(self.max_concurrent):
            worker = asyncio.create_task(self._worker_loop(i))
            self._workers.append(worker)

        logger.info(f"Batch processor started with {self.max_concurrent} worker(s)")

    async def stop(self) -> None:
        """Stop the batch processor, cancelling the worker tasks.

        Note: workers are cancelled rather than drained — jobs still queued
        stay PENDING in memory and are not processed after stop().
        """
        if not self._running:
            return

        self._running = False

        # Cancel workers
        for worker in self._workers:
            worker.cancel()

        # Wait for workers to finish (exceptions from cancellation swallowed)
        await asyncio.gather(*self._workers, return_exceptions=True)
        self._workers.clear()

        logger.info("Batch processor stopped")

    async def submit(self, job: GenerationJob) -> GenerationJob:
        """Submit a job to the queue.

        Args:
            job: Job to submit

        Returns:
            The submitted job with ID

        Raises:
            RuntimeError: If queue is full

        NOTE(review): the capacity check counts ALL tracked jobs, including
        completed/failed ones still in memory — call cleanup_completed()
        regularly or submissions may be rejected with an empty queue; confirm
        this is intended.
        """
        async with self._lock:
            if len(self._jobs) >= self.max_queue_size:
                raise RuntimeError(
                    f"Queue full (max {self.max_queue_size} jobs). "
                    "Please wait for jobs to complete."
                )

            # Register first, then enqueue the id for a worker to pick up.
            self._jobs[job.id] = job
            await self._queue.put(job.id)

        logger.info(f"Job {job.id} submitted to queue (position: {self._queue.qsize()})")
        return job

    async def cancel(self, job_id: str) -> bool:
        """Cancel a pending job.

        Args:
            job_id: ID of job to cancel

        Returns:
            True if job was cancelled, False if not found or already processing
        """
        async with self._lock:
            job = self._jobs.get(job_id)
            if job is None:
                return False

            # Only PENDING jobs can be cancelled; in-flight generations
            # are not interrupted.
            if job.status != JobStatus.PENDING:
                logger.warning(f"Cannot cancel job {job_id} with status {job.status}")
                return False

            # The id remains in the queue; the worker loop skips CANCELLED jobs.
            job.status = JobStatus.CANCELLED
            job.completed_at = datetime.utcnow()

        logger.info(f"Job {job_id} cancelled")
        return True

    def get_job(self, job_id: str) -> Optional[GenerationJob]:
        """Get a job by ID, or None if unknown."""
        return self._jobs.get(job_id)

    def get_queue_status(self) -> dict[str, Any]:
        """Get current queue status (counts per status plus queue metrics)."""
        jobs_by_status = {}
        for job in self._jobs.values():
            status = job.status.value
            jobs_by_status[status] = jobs_by_status.get(status, 0) + 1

        return {
            "queue_size": self._queue.qsize(),
            "total_jobs": len(self._jobs),
            "jobs_by_status": jobs_by_status,
            "running": self._running,
            "max_queue_size": self.max_queue_size,
        }

    def list_jobs(
        self,
        status: Optional[JobStatus] = None,
        limit: int = 50,
    ) -> list[GenerationJob]:
        """List jobs with optional status filter.

        Args:
            status: Filter by status
            limit: Maximum jobs to return

        Returns:
            List of jobs ordered by creation time (newest first)
        """
        jobs = list(self._jobs.values())

        if status:
            jobs = [j for j in jobs if j.status == status]

        # Sort by created_at descending
        jobs.sort(key=lambda j: j.created_at, reverse=True)

        return jobs[:limit]

    def cleanup_completed(self, max_age_hours: float = 24.0) -> int:
        """Remove old completed/failed/cancelled jobs from memory.

        Args:
            max_age_hours: Remove jobs older than this

        Returns:
            Number of jobs removed
        """
        # Compare naive-UTC completion timestamps against a cutoff epoch.
        cutoff = datetime.utcnow().timestamp() - (max_age_hours * 3600)
        removed = 0

        # Iterate over a snapshot since we delete from the dict while looping.
        for job_id, job in list(self._jobs.items()):
            if job.status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED):
                if job.completed_at and job.completed_at.timestamp() < cutoff:
                    del self._jobs[job_id]
                    removed += 1

        if removed:
            logger.info(f"Cleaned up {removed} old jobs")

        return removed

    async def _worker_loop(self, worker_id: int) -> None:
        """Worker loop that processes jobs from queue.

        Polls the queue with a 1s timeout so the loop notices _running
        flipping to False even when the queue stays empty.
        """
        logger.debug(f"Worker {worker_id} started")

        while self._running:
            try:
                # Wait for job with timeout
                try:
                    job_id = await asyncio.wait_for(
                        self._queue.get(), timeout=1.0
                    )
                except asyncio.TimeoutError:
                    continue

                # Skip ids whose job was cancelled (or somehow dropped)
                # between enqueue and pickup.
                job = self._jobs.get(job_id)
                if job is None or job.status == JobStatus.CANCELLED:
                    continue

                await self._process_job(job)

            except asyncio.CancelledError:
                # Raised by stop(); exit the loop cleanly.
                break
            except Exception as e:
                # Keep the worker alive on unexpected errors.
                logger.error(f"Worker {worker_id} error: {e}")

        logger.debug(f"Worker {worker_id} stopped")

    async def _process_job(self, job: GenerationJob) -> None:
        """Process a single generation job, updating its status/result fields.

        Never raises: failures are recorded on the job (status=FAILED, error
        message) and reported via the failure callbacks.
        """
        logger.info(f"Processing job {job.id}: {job.model_id}/{job.variant}")

        job.status = JobStatus.PROCESSING
        job.started_at = datetime.utcnow()

        def progress_callback(progress: float, message: str) -> None:
            # Mirror service progress onto the job, then fan out to listeners;
            # listener exceptions are logged, never propagated.
            job.progress = progress
            job.progress_message = message
            for callback in self._on_progress:
                try:
                    callback(job)
                except Exception as e:
                    logger.error(f"Progress callback error: {e}")

        try:
            result, generation = await self.generation_service.generate(
                model_id=job.model_id,
                variant=job.variant,
                prompts=job.prompts,
                duration=job.parameters.get("duration", 10.0),
                temperature=job.parameters.get("temperature", 1.0),
                top_k=job.parameters.get("top_k", 250),
                top_p=job.parameters.get("top_p", 0.0),
                cfg_coef=job.parameters.get("cfg_coef", 3.0),
                seed=job.parameters.get("seed"),
                conditioning=job.conditioning,
                project_id=job.project_id,
                preset_used=job.preset_used,
                tags=job.tags,
                progress_callback=progress_callback,
            )

            job.status = JobStatus.COMPLETED
            job.result_id = generation.id
            job.audio_path = generation.audio_path
            job.completed_at = datetime.utcnow()
            job.progress = 1.0
            job.progress_message = "Complete"

            logger.info(f"Job {job.id} completed: {generation.id}")

            for callback in self._on_job_complete:
                try:
                    callback(job)
                except Exception as e:
                    logger.error(f"Completion callback error: {e}")

        except Exception as e:
            job.status = JobStatus.FAILED
            job.error = str(e)
            job.completed_at = datetime.utcnow()

            logger.error(f"Job {job.id} failed: {e}")

            for callback in self._on_job_failed:
                try:
                    callback(job)
                except Exception as e2:
                    logger.error(f"Failure callback error: {e2}")

    # Callback registration

    def on_job_complete(self, callback: Callable[[GenerationJob], None]) -> None:
        """Register callback for job completion."""
        self._on_job_complete.append(callback)

    def on_job_failed(self, callback: Callable[[GenerationJob], None]) -> None:
        """Register callback for job failure."""
        self._on_job_failed.append(callback)

    def on_progress(self, callback: Callable[[GenerationJob], None]) -> None:
        """Register callback for progress updates."""
        self._on_progress.append(callback)
|
||||
322
src/services/generation_service.py
Normal file
322
src/services/generation_service.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""Generation service for orchestrating audio generation."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import soundfile as sf
|
||||
import torch
|
||||
|
||||
from src.core.base_model import GenerationRequest, GenerationResult
|
||||
from src.core.gpu_manager import GPUMemoryManager
|
||||
from src.core.model_registry import ModelRegistry
|
||||
from src.core.oom_handler import OOMHandler
|
||||
from src.storage.database import Database, Generation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GenerationService:
    """Orchestrates audio generation across all models.

    Handles:
    - Model selection and loading
    - Generation execution with OOM recovery
    - Result saving and database recording
    - Progress callbacks for UI updates
    """

    def __init__(
        self,
        registry: ModelRegistry,
        gpu_manager: GPUMemoryManager,
        database: Database,
        output_dir: Path,
    ):
        """Initialize generation service.

        Args:
            registry: Model registry for model access
            gpu_manager: GPU memory manager
            database: Database for storing generation records
            output_dir: Directory for saving generated audio
        """
        self.registry = registry
        self.gpu_manager = gpu_manager
        self.database = database
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # OOM handler wraps generation calls with retry/unload logic.
        self.oom_handler = OOMHandler(gpu_manager, registry)

        # In-process statistics (not persisted; reset on restart).
        self._generation_count = 0
        self._total_duration_generated = 0.0

    async def generate(
        self,
        model_id: str,
        variant: Optional[str],
        prompts: list[str],
        duration: float = 10.0,
        temperature: float = 1.0,
        top_k: int = 250,
        top_p: float = 0.0,
        cfg_coef: float = 3.0,
        seed: Optional[int] = None,
        conditioning: Optional[dict[str, Any]] = None,
        project_id: Optional[str] = None,
        preset_used: Optional[str] = None,
        tags: Optional[list[str]] = None,
        progress_callback: Optional[Callable[[float, str], None]] = None,
    ) -> tuple[GenerationResult, Generation]:
        """Generate audio and save to database.

        Args:
            model_id: Model family to use
            variant: Model variant (None for default)
            prompts: Text prompts for generation
            duration: Target duration in seconds
            temperature: Sampling temperature
            top_k: Top-k sampling parameter
            top_p: Nucleus sampling parameter
            cfg_coef: Classifier-free guidance coefficient
            seed: Random seed for reproducibility
            conditioning: Optional conditioning data (melody, style, chords, etc.)
            project_id: Optional project to associate with
            preset_used: Name of preset used (for metadata)
            tags: Optional tags for organization
            progress_callback: Optional callback for progress updates
                (progress in [0, 1], human-readable message)

        Returns:
            Tuple of (GenerationResult, Generation database record)

        Raises:
            ValueError: If model not found or parameters invalid
            RuntimeError: If generation fails

        NOTE(review): do_generation() below is a synchronous call executed
        directly on the event-loop thread — a long model run will block all
        other coroutines for its duration; confirm whether offloading to an
        executor is intended.
        """
        start_time = time.time()

        # Report progress
        if progress_callback:
            progress_callback(0.0, "Preparing generation...")

        # Build generation request
        request = GenerationRequest(
            prompts=prompts,
            duration=duration,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            cfg_coef=cfg_coef,
            seed=seed,
            conditioning=conditioning or {},
        )

        # Get model configuration (raises if model_id/variant unknown).
        family_config, variant_config = self.registry.get_model_config(model_id, variant)
        actual_variant = variant or family_config.default_variant

        # Check VRAM availability
        if progress_callback:
            progress_callback(0.1, "Checking GPU memory...")

        can_load, reason = self.gpu_manager.can_load_model(variant_config.vram_mb)
        if not can_load:
            # Try OOM recovery: the handler may free memory (e.g. unload idle
            # models); only fail if memory is still insufficient afterwards.
            if not self.oom_handler.check_memory_for_operation(variant_config.vram_mb):
                raise RuntimeError(f"Insufficient GPU memory: {reason}")

        # Generate with OOM recovery wrapper
        if progress_callback:
            progress_callback(0.2, f"Loading {model_id}/{actual_variant}...")

        @self.oom_handler.with_oom_recovery
        def do_generation() -> GenerationResult:
            # Context manager scopes the model checkout to this call.
            with self.registry.get_model(model_id, actual_variant) as model:
                if progress_callback:
                    progress_callback(0.4, "Generating audio...")
                return model.generate(request)

        result = do_generation()

        if progress_callback:
            progress_callback(0.8, "Saving audio...")

        # Save audio file
        audio_path = self._save_audio(result)

        # Create database record. Multiple prompts are stored newline-joined
        # in a single text column (see regenerate() for the inverse).
        generation = Generation.create(
            model=model_id,
            variant=actual_variant,
            prompt=prompts[0] if len(prompts) == 1 else "\n".join(prompts),
            parameters={
                "duration": duration,
                "temperature": temperature,
                "top_k": top_k,
                "top_p": top_p,
                "cfg_coef": cfg_coef,
            },
            project_id=project_id,
            preset_used=preset_used,
            conditioning=conditioning,
            audio_path=str(audio_path),
            duration_seconds=result.duration,
            sample_rate=result.sample_rate,
            tags=tags or [],
            seed=result.seed,
        )

        # Save to database
        await self.database.create_generation(generation)

        # Update statistics
        self._generation_count += 1
        self._total_duration_generated += result.duration

        elapsed = time.time() - start_time
        logger.info(
            f"Generation complete: {model_id}/{actual_variant}, "
            f"duration={result.duration:.1f}s, elapsed={elapsed:.1f}s"
        )

        if progress_callback:
            progress_callback(1.0, "Complete!")

        return result, generation

    def _save_audio(self, result: GenerationResult) -> Path:
        """Save generated audio to file.

        Only the FIRST sample of a batched result is written to disk;
        remaining samples stay only in the returned GenerationResult tensor.

        Args:
            result: Generation result with audio tensor

        Returns:
            Path to saved audio file (WAV)
        """
        # Generate unique filename (millisecond timestamp keeps names sortable).
        timestamp = int(time.time() * 1000)
        filename = f"{result.model_id}_{result.variant}_{timestamp}.wav"
        filepath = self.output_dir / filename

        # Convert tensor to numpy and save
        audio = result.audio.numpy()

        # Handle batch dimension - save first sample if batched
        if audio.ndim == 3:
            audio = audio[0]  # [channels, samples]

        # Transpose to [samples, channels] for soundfile
        if audio.ndim == 2:
            audio = audio.T

        sf.write(filepath, audio, result.sample_rate)

        logger.debug(f"Saved audio to {filepath}")
        return filepath

    async def regenerate(
        self,
        generation_id: str,
        new_seed: Optional[int] = None,
        progress_callback: Optional[Callable[[float, str], None]] = None,
    ) -> tuple[GenerationResult, Generation]:
        """Regenerate audio using parameters from existing generation.

        Args:
            generation_id: ID of generation to regenerate
            new_seed: Optional new seed (uses original if None)
            progress_callback: Optional progress callback

        Returns:
            Tuple of (GenerationResult, new Generation record)

        Raises:
            ValueError: If generation not found
        """
        # Load original generation
        original = await self.database.get_generation(generation_id)
        if original is None:
            raise ValueError(f"Generation not found: {generation_id}")

        # Parse prompts (inverse of the newline-join performed in generate();
        # a single prompt that itself contains newlines would be re-split).
        prompts = original.prompt.split("\n") if "\n" in original.prompt else [original.prompt]

        # Regenerate with same or new seed
        return await self.generate(
            model_id=original.model,
            variant=original.variant,
            prompts=prompts,
            duration=original.parameters.get("duration", 10.0),
            temperature=original.parameters.get("temperature", 1.0),
            top_k=original.parameters.get("top_k", 250),
            top_p=original.parameters.get("top_p", 0.0),
            cfg_coef=original.parameters.get("cfg_coef", 3.0),
            seed=new_seed if new_seed is not None else original.seed,
            conditioning=original.conditioning,
            project_id=original.project_id,
            preset_used=original.preset_used,
            tags=original.tags,
            progress_callback=progress_callback,
        )

    def get_stats(self) -> dict[str, Any]:
        """Get generation statistics.

        Returns:
            Dictionary with generation stats (counts since process start
            plus the OOM handler's own stats)
        """
        return {
            "generation_count": self._generation_count,
            "total_duration_generated": self._total_duration_generated,
            "oom_stats": self.oom_handler.get_stats(),
        }

    def estimate_generation_time(
        self, model_id: str, variant: Optional[str], duration: float
    ) -> float:
        """Estimate generation time for given parameters.

        Args:
            model_id: Model family
            variant: Model variant
            duration: Target audio duration

        Returns:
            Estimated generation time in seconds (heuristic only)
        """
        # Rough estimates based on model type and RTX 4090
        # These are approximations and vary based on many factors
        estimates = {
            "musicgen": {
                "small": 0.8,  # seconds per second of audio
                "medium": 1.5,
                "large": 3.0,
                "melody": 1.8,
            },
            "audiogen": {
                "medium": 1.5,
            },
            "magnet": {
                "small-10secs": 0.3,  # Non-autoregressive is faster
                "medium-10secs": 0.5,
                "small-30secs": 0.3,
                "medium-30secs": 0.5,
            },
            "musicgen-style": {
                "medium": 1.8,
            },
            "jasco": {
                "chords-drums-400M": 1.0,
                "chords-drums-1B": 1.5,
            },
        }

        family_config, _ = self.registry.get_model_config(model_id, variant)
        actual_variant = variant or family_config.default_variant

        # Unknown model/variant combos fall back to a conservative 2.0x ratio.
        ratio = estimates.get(model_id, {}).get(actual_variant, 2.0)
        return duration * ratio + 5.0  # Add 5s for model loading overhead
|
||||
395
src/services/project_service.py
Normal file
395
src/services/project_service.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""Project service for managing projects and generations."""
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from src.storage.database import Database, Generation, Project, Preset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProjectService:
|
||||
"""Service for managing projects, generations, and presets.
|
||||
|
||||
Provides a high-level API for project organization and
|
||||
generation history management.
|
||||
"""
|
||||
|
||||
    def __init__(self, database: Database, output_dir: Path):
        """Initialize project service.

        Args:
            database: Database instance used for all persistence operations
            output_dir: Directory where audio files are stored
        """
        self.database = database
        # Normalize so callers may pass either a str or a Path.
        self.output_dir = Path(output_dir)
|
||||
|
||||
# Project Operations
|
||||
|
||||
async def create_project(
|
||||
self, name: str, description: str = ""
|
||||
) -> Project:
|
||||
"""Create a new project.
|
||||
|
||||
Args:
|
||||
name: Project name
|
||||
description: Optional description
|
||||
|
||||
Returns:
|
||||
Created project
|
||||
"""
|
||||
project = Project.create(name, description)
|
||||
await self.database.create_project(project)
|
||||
logger.info(f"Created project: {project.id} ({name})")
|
||||
return project
|
||||
|
||||
async def get_project(self, project_id: str) -> Optional[Project]:
|
||||
"""Get a project by ID."""
|
||||
return await self.database.get_project(project_id)
|
||||
|
||||
async def list_projects(
|
||||
self, limit: int = 100, offset: int = 0
|
||||
) -> list[Project]:
|
||||
"""List all projects."""
|
||||
return await self.database.list_projects(limit, offset)
|
||||
|
||||
async def update_project(
|
||||
self,
|
||||
project_id: str,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
) -> Optional[Project]:
|
||||
"""Update a project.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
name: New name (None to keep current)
|
||||
description: New description (None to keep current)
|
||||
|
||||
Returns:
|
||||
Updated project, or None if not found
|
||||
"""
|
||||
project = await self.database.get_project(project_id)
|
||||
if project is None:
|
||||
return None
|
||||
|
||||
if name is not None:
|
||||
project.name = name
|
||||
if description is not None:
|
||||
project.description = description
|
||||
|
||||
await self.database.update_project(project)
|
||||
logger.info(f"Updated project: {project_id}")
|
||||
return project
|
||||
|
||||
async def delete_project(
|
||||
self, project_id: str, delete_files: bool = False
|
||||
) -> bool:
|
||||
"""Delete a project.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
delete_files: If True, also delete associated audio files
|
||||
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
if delete_files:
|
||||
# Get all generations and delete their files
|
||||
generations = await self.database.list_generations(project_id=project_id)
|
||||
for gen in generations:
|
||||
if gen.audio_path:
|
||||
try:
|
||||
Path(gen.audio_path).unlink(missing_ok=True)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete {gen.audio_path}: {e}")
|
||||
|
||||
result = await self.database.delete_project(project_id)
|
||||
if result:
|
||||
logger.info(f"Deleted project: {project_id}")
|
||||
return result
|
||||
|
||||
async def get_project_stats(self, project_id: str) -> dict[str, Any]:
|
||||
"""Get statistics for a project.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
|
||||
Returns:
|
||||
Dictionary with project statistics
|
||||
"""
|
||||
generations = await self.database.list_generations(
|
||||
project_id=project_id, limit=10000
|
||||
)
|
||||
|
||||
total_duration = sum(g.duration_seconds or 0 for g in generations)
|
||||
models_used = {}
|
||||
for gen in generations:
|
||||
key = f"{gen.model}/{gen.variant}"
|
||||
models_used[key] = models_used.get(key, 0) + 1
|
||||
|
||||
return {
|
||||
"generation_count": len(generations),
|
||||
"total_duration_seconds": total_duration,
|
||||
"models_used": models_used,
|
||||
}
|
||||
|
||||
# Generation Operations
|
||||
|
||||
async def get_generation(self, generation_id: str) -> Optional[Generation]:
|
||||
"""Get a generation by ID."""
|
||||
return await self.database.get_generation(generation_id)
|
||||
|
||||
async def list_generations(
|
||||
self,
|
||||
project_id: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
search: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[Generation]:
|
||||
"""List generations with optional filters.
|
||||
|
||||
Args:
|
||||
project_id: Filter by project
|
||||
model: Filter by model family
|
||||
search: Search in prompts, names, and tags
|
||||
limit: Maximum results
|
||||
offset: Pagination offset
|
||||
|
||||
Returns:
|
||||
List of generations
|
||||
"""
|
||||
return await self.database.list_generations(
|
||||
project_id=project_id,
|
||||
model=model,
|
||||
search=search,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
async def update_generation(
|
||||
self,
|
||||
generation_id: str,
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
notes: Optional[str] = None,
|
||||
project_id: Optional[str] = None,
|
||||
) -> Optional[Generation]:
|
||||
"""Update a generation's metadata.
|
||||
|
||||
Args:
|
||||
generation_id: Generation ID
|
||||
name: New name
|
||||
tags: New tags (replaces existing)
|
||||
notes: New notes
|
||||
project_id: Move to different project
|
||||
|
||||
Returns:
|
||||
Updated generation, or None if not found
|
||||
"""
|
||||
generation = await self.database.get_generation(generation_id)
|
||||
if generation is None:
|
||||
return None
|
||||
|
||||
if name is not None:
|
||||
generation.name = name
|
||||
if tags is not None:
|
||||
generation.tags = tags
|
||||
if notes is not None:
|
||||
generation.notes = notes
|
||||
if project_id is not None:
|
||||
generation.project_id = project_id
|
||||
|
||||
await self.database.update_generation(generation)
|
||||
logger.info(f"Updated generation: {generation_id}")
|
||||
return generation
|
||||
|
||||
async def delete_generation(
|
||||
self, generation_id: str, delete_file: bool = True
|
||||
) -> bool:
|
||||
"""Delete a generation.
|
||||
|
||||
Args:
|
||||
generation_id: Generation ID
|
||||
delete_file: If True, also delete audio file
|
||||
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
if delete_file:
|
||||
generation = await self.database.get_generation(generation_id)
|
||||
if generation and generation.audio_path:
|
||||
try:
|
||||
Path(generation.audio_path).unlink(missing_ok=True)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete audio file: {e}")
|
||||
|
||||
result = await self.database.delete_generation(generation_id)
|
||||
if result:
|
||||
logger.info(f"Deleted generation: {generation_id}")
|
||||
return result
|
||||
|
||||
async def move_generations_to_project(
|
||||
self, generation_ids: list[str], project_id: Optional[str]
|
||||
) -> int:
|
||||
"""Move multiple generations to a project.
|
||||
|
||||
Args:
|
||||
generation_ids: List of generation IDs
|
||||
project_id: Target project ID (None to unlink)
|
||||
|
||||
Returns:
|
||||
Number of generations moved
|
||||
"""
|
||||
moved = 0
|
||||
for gen_id in generation_ids:
|
||||
result = await self.update_generation(gen_id, project_id=project_id)
|
||||
if result:
|
||||
moved += 1
|
||||
|
||||
logger.info(f"Moved {moved} generations to project {project_id}")
|
||||
return moved
|
||||
|
||||
# Preset Operations
|
||||
|
||||
async def create_preset(
|
||||
self,
|
||||
model: str,
|
||||
name: str,
|
||||
parameters: dict[str, Any],
|
||||
description: str = "",
|
||||
) -> Preset:
|
||||
"""Create a custom preset.
|
||||
|
||||
Args:
|
||||
model: Model family this preset is for
|
||||
name: Preset name
|
||||
parameters: Generation parameters
|
||||
description: Optional description
|
||||
|
||||
Returns:
|
||||
Created preset
|
||||
"""
|
||||
preset = Preset.create(model, name, parameters, description)
|
||||
await self.database.create_preset(preset)
|
||||
logger.info(f"Created preset: {preset.id} ({name}) for {model}")
|
||||
return preset
|
||||
|
||||
async def list_presets(
|
||||
self, model: Optional[str] = None, include_builtin: bool = True
|
||||
) -> list[Preset]:
|
||||
"""List presets with optional model filter.
|
||||
|
||||
Args:
|
||||
model: Filter by model family
|
||||
include_builtin: Include built-in presets
|
||||
|
||||
Returns:
|
||||
List of presets
|
||||
"""
|
||||
return await self.database.list_presets(model, include_builtin)
|
||||
|
||||
async def get_preset(self, preset_id: str) -> Optional[Preset]:
|
||||
"""Get a preset by ID."""
|
||||
return await self.database.get_preset(preset_id)
|
||||
|
||||
async def delete_preset(self, preset_id: str) -> bool:
|
||||
"""Delete a custom preset.
|
||||
|
||||
Note: Built-in presets cannot be deleted.
|
||||
|
||||
Args:
|
||||
preset_id: Preset ID
|
||||
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
result = await self.database.delete_preset(preset_id)
|
||||
if result:
|
||||
logger.info(f"Deleted preset: {preset_id}")
|
||||
return result
|
||||
|
||||
# Export Operations
|
||||
|
||||
async def export_project(
|
||||
self, project_id: str, output_path: Path, include_metadata: bool = True
|
||||
) -> Path:
|
||||
"""Export a project as a ZIP archive.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
output_path: Output ZIP file path
|
||||
include_metadata: Include JSON metadata file
|
||||
|
||||
Returns:
|
||||
Path to created ZIP file
|
||||
"""
|
||||
import json
|
||||
import tempfile
|
||||
import zipfile
|
||||
|
||||
project = await self.database.get_project(project_id)
|
||||
if project is None:
|
||||
raise ValueError(f"Project not found: {project_id}")
|
||||
|
||||
generations = await self.database.list_generations(
|
||||
project_id=project_id, limit=10000
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmppath = Path(tmpdir)
|
||||
|
||||
# Copy audio files
|
||||
for gen in generations:
|
||||
if gen.audio_path and Path(gen.audio_path).exists():
|
||||
src = Path(gen.audio_path)
|
||||
dst = tmppath / src.name
|
||||
shutil.copy2(src, dst)
|
||||
|
||||
# Create metadata file
|
||||
if include_metadata:
|
||||
metadata = {
|
||||
"project": {
|
||||
"id": project.id,
|
||||
"name": project.name,
|
||||
"description": project.description,
|
||||
"created_at": project.created_at.isoformat(),
|
||||
},
|
||||
"generations": [
|
||||
{
|
||||
"id": g.id,
|
||||
"model": g.model,
|
||||
"variant": g.variant,
|
||||
"prompt": g.prompt,
|
||||
"parameters": g.parameters,
|
||||
"duration": g.duration_seconds,
|
||||
"audio_file": Path(g.audio_path).name if g.audio_path else None,
|
||||
"created_at": g.created_at.isoformat(),
|
||||
"tags": g.tags,
|
||||
"seed": g.seed,
|
||||
}
|
||||
for g in generations
|
||||
],
|
||||
}
|
||||
|
||||
metadata_path = tmppath / "metadata.json"
|
||||
metadata_path.write_text(json.dumps(metadata, indent=2))
|
||||
|
||||
# Create ZIP
|
||||
output_path = Path(output_path)
|
||||
with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
for file in tmppath.iterdir():
|
||||
zf.write(file, file.name)
|
||||
|
||||
logger.info(f"Exported project {project_id} to {output_path}")
|
||||
return output_path
|
||||
|
||||
# Statistics
|
||||
|
||||
async def get_stats(self) -> dict[str, Any]:
|
||||
"""Get overall statistics."""
|
||||
return await self.database.get_stats()
|
||||
5
src/storage/__init__.py
Normal file
5
src/storage/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Storage module for AudioCraft Studio."""
|
||||
|
||||
from src.storage.database import Database, Generation, Project, Preset
|
||||
|
||||
__all__ = ["Database", "Generation", "Project", "Preset"]
|
||||
550
src/storage/database.py
Normal file
550
src/storage/database.py
Normal file
@@ -0,0 +1,550 @@
|
||||
"""SQLite database for projects, generations, and presets."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import aiosqlite
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class Project:
    """Project entity for organizing generations."""

    id: str  # "proj_" followed by 12 hex characters
    name: str
    created_at: datetime
    updated_at: datetime
    description: str = ""

    @classmethod
    def create(cls, name: str, description: str = "") -> "Project":
        """Build a fresh project with a generated ID.

        Both timestamps start equal; updated_at moves forward on edits.
        """
        stamp = datetime.utcnow()
        new_id = f"proj_{uuid.uuid4().hex[:12]}"
        return cls(
            id=new_id,
            name=name,
            description=description,
            created_at=stamp,
            updated_at=stamp,
        )
|
||||
|
||||
|
||||
@dataclass
class Generation:
    """Audio generation record."""

    id: str  # "gen_" followed by 12 hex characters
    project_id: Optional[str]  # owning project, or None when unlinked
    model: str
    variant: str
    prompt: str
    parameters: dict[str, Any]
    created_at: datetime
    audio_path: Optional[str] = None
    duration_seconds: Optional[float] = None
    sample_rate: Optional[int] = None
    preset_used: Optional[str] = None
    conditioning: dict[str, Any] = field(default_factory=dict)
    name: Optional[str] = None
    tags: list[str] = field(default_factory=list)
    notes: Optional[str] = None
    seed: Optional[int] = None

    @classmethod
    def create(
        cls,
        model: str,
        variant: str,
        prompt: str,
        parameters: dict[str, Any],
        project_id: Optional[str] = None,
        **kwargs,
    ) -> "Generation":
        """Build a new generation record with a generated ID.

        Extra keyword arguments are forwarded to the dataclass constructor
        (e.g. seed, audio_path, tags).
        """
        record_id = f"gen_{uuid.uuid4().hex[:12]}"
        return cls(
            id=record_id,
            created_at=datetime.utcnow(),
            model=model,
            variant=variant,
            prompt=prompt,
            parameters=parameters,
            project_id=project_id,
            **kwargs,
        )
|
||||
|
||||
|
||||
@dataclass
class Preset:
    """Generation parameter preset."""

    id: str  # "preset_" followed by 12 hex characters
    model: str
    name: str
    parameters: dict[str, Any]
    created_at: datetime
    description: str = ""
    is_builtin: bool = False  # built-ins cannot be deleted via the database

    @classmethod
    def create(
        cls,
        model: str,
        name: str,
        parameters: dict[str, Any],
        description: str = "",
    ) -> "Preset":
        """Build a new user-defined (never builtin) preset with a generated ID."""
        preset_id = f"preset_{uuid.uuid4().hex[:12]}"
        return cls(
            id=preset_id,
            created_at=datetime.utcnow(),
            model=model,
            name=name,
            parameters=parameters,
            description=description,
            is_builtin=False,
        )
|
||||
|
||||
|
||||
class Database:
    """Async SQLite database for AudioCraft Studio.

    Handles storage of projects, generations, and presets. JSON-valued
    columns (parameters, conditioning, tags) are serialized with
    ``json.dumps`` on write and parsed back on read; timestamps are stored
    as ISO-8601 strings (naive UTC via ``datetime.utcnow``).
    """

    SCHEMA = """
    CREATE TABLE IF NOT EXISTS projects (
        id TEXT PRIMARY KEY,
        name TEXT NOT NULL,
        description TEXT DEFAULT '',
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
    );

    CREATE TABLE IF NOT EXISTS generations (
        id TEXT PRIMARY KEY,
        project_id TEXT REFERENCES projects(id) ON DELETE SET NULL,
        model TEXT NOT NULL,
        variant TEXT NOT NULL,
        prompt TEXT NOT NULL,
        parameters JSON NOT NULL,
        preset_used TEXT,
        conditioning JSON,
        audio_path TEXT,
        duration_seconds REAL,
        sample_rate INTEGER,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        name TEXT,
        tags JSON,
        notes TEXT,
        seed INTEGER
    );

    CREATE TABLE IF NOT EXISTS presets (
        id TEXT PRIMARY KEY,
        model TEXT NOT NULL,
        name TEXT NOT NULL,
        description TEXT DEFAULT '',
        parameters JSON NOT NULL,
        is_builtin BOOLEAN DEFAULT FALSE,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
    );

    CREATE INDEX IF NOT EXISTS idx_generations_project ON generations(project_id);
    CREATE INDEX IF NOT EXISTS idx_generations_created ON generations(created_at DESC);
    CREATE INDEX IF NOT EXISTS idx_generations_model ON generations(model);
    CREATE INDEX IF NOT EXISTS idx_presets_model ON presets(model);
    """

    def __init__(self, db_path: Path):
        """Initialize database.

        Args:
            db_path: Path to SQLite database file
        """
        self.db_path = db_path
        self._connection: Optional[aiosqlite.Connection] = None

    async def connect(self) -> None:
        """Open database connection and initialize schema."""
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._connection = await aiosqlite.connect(self.db_path)
        self._connection.row_factory = aiosqlite.Row

        # SQLite disables foreign-key enforcement by default; without this
        # pragma the schema's "ON DELETE SET NULL" never fires, and deleting
        # a project would leave generations pointing at a dead project_id.
        await self._connection.execute("PRAGMA foreign_keys = ON")

        # Initialize schema (idempotent: everything is CREATE ... IF NOT EXISTS).
        await self._connection.executescript(self.SCHEMA)
        await self._connection.commit()

        logger.info(f"Database connected: {self.db_path}")

    async def close(self) -> None:
        """Close database connection."""
        if self._connection:
            await self._connection.close()
            self._connection = None

    @property
    def conn(self) -> aiosqlite.Connection:
        """Get active connection.

        Raises:
            RuntimeError: If connect() has not been called (or after close()).
        """
        if not self._connection:
            raise RuntimeError("Database not connected")
        return self._connection

    # Project Methods

    async def create_project(self, project: Project) -> Project:
        """Insert a new project row."""
        await self.conn.execute(
            """
            INSERT INTO projects (id, name, description, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?)
            """,
            (
                project.id,
                project.name,
                project.description,
                project.created_at.isoformat(),
                project.updated_at.isoformat(),
            ),
        )
        await self.conn.commit()
        return project

    async def get_project(self, project_id: str) -> Optional[Project]:
        """Get a project by ID, or None if not found."""
        async with self.conn.execute(
            "SELECT * FROM projects WHERE id = ?", (project_id,)
        ) as cursor:
            row = await cursor.fetchone()
            if row:
                return Project(
                    id=row["id"],
                    name=row["name"],
                    description=row["description"] or "",
                    created_at=datetime.fromisoformat(row["created_at"]),
                    updated_at=datetime.fromisoformat(row["updated_at"]),
                )
            return None

    async def list_projects(
        self, limit: int = 100, offset: int = 0
    ) -> list[Project]:
        """List all projects, ordered by last update (newest first)."""
        async with self.conn.execute(
            """
            SELECT * FROM projects
            ORDER BY updated_at DESC
            LIMIT ? OFFSET ?
            """,
            (limit, offset),
        ) as cursor:
            rows = await cursor.fetchall()
            return [
                Project(
                    id=row["id"],
                    name=row["name"],
                    description=row["description"] or "",
                    created_at=datetime.fromisoformat(row["created_at"]),
                    updated_at=datetime.fromisoformat(row["updated_at"]),
                )
                for row in rows
            ]

    async def update_project(self, project: Project) -> None:
        """Update a project row; bumps updated_at on the passed object too."""
        project.updated_at = datetime.utcnow()
        await self.conn.execute(
            """
            UPDATE projects SET name = ?, description = ?, updated_at = ?
            WHERE id = ?
            """,
            (project.name, project.description, project.updated_at.isoformat(), project.id),
        )
        await self.conn.commit()

    async def delete_project(self, project_id: str) -> bool:
        """Delete a project (generations are kept but unlinked via FK action).

        Returns:
            True if a row was deleted.
        """
        result = await self.conn.execute(
            "DELETE FROM projects WHERE id = ?", (project_id,)
        )
        await self.conn.commit()
        return result.rowcount > 0

    # Generation Methods

    async def create_generation(self, generation: Generation) -> Generation:
        """Insert a new generation record.

        The insert and the parent project's updated_at bump are committed
        in one transaction so a failure cannot leave them half-applied.
        """
        await self.conn.execute(
            """
            INSERT INTO generations (
                id, project_id, model, variant, prompt, parameters,
                preset_used, conditioning, audio_path, duration_seconds,
                sample_rate, created_at, name, tags, notes, seed
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                generation.id,
                generation.project_id,
                generation.model,
                generation.variant,
                generation.prompt,
                json.dumps(generation.parameters),
                generation.preset_used,
                json.dumps(generation.conditioning),
                generation.audio_path,
                generation.duration_seconds,
                generation.sample_rate,
                generation.created_at.isoformat(),
                generation.name,
                json.dumps(generation.tags),
                generation.notes,
                generation.seed,
            ),
        )

        # Touch the project's updated_at so list_projects reflects recent
        # generation activity, not just metadata edits.
        if generation.project_id:
            await self.conn.execute(
                "UPDATE projects SET updated_at = ? WHERE id = ?",
                (datetime.utcnow().isoformat(), generation.project_id),
            )

        await self.conn.commit()
        return generation

    async def get_generation(self, generation_id: str) -> Optional[Generation]:
        """Get a generation by ID, or None if not found."""
        async with self.conn.execute(
            "SELECT * FROM generations WHERE id = ?", (generation_id,)
        ) as cursor:
            row = await cursor.fetchone()
            if row:
                return self._row_to_generation(row)
            return None

    async def list_generations(
        self,
        project_id: Optional[str] = None,
        model: Optional[str] = None,
        limit: int = 100,
        offset: int = 0,
        search: Optional[str] = None,
    ) -> list[Generation]:
        """List generations (newest first) with optional filters.

        ``search`` does a substring LIKE match against prompt, name, and the
        JSON-serialized tags column.
        """
        conditions = []
        params = []

        if project_id:
            conditions.append("project_id = ?")
            params.append(project_id)

        if model:
            conditions.append("model = ?")
            params.append(model)

        if search:
            conditions.append("(prompt LIKE ? OR name LIKE ? OR tags LIKE ?)")
            search_pattern = f"%{search}%"
            params.extend([search_pattern, search_pattern, search_pattern])

        # Safe to interpolate: built only from the fixed literals above;
        # all user-supplied values go through ? placeholders.
        where_clause = " AND ".join(conditions) if conditions else "1=1"

        async with self.conn.execute(
            f"""
            SELECT * FROM generations
            WHERE {where_clause}
            ORDER BY created_at DESC
            LIMIT ? OFFSET ?
            """,
            (*params, limit, offset),
        ) as cursor:
            rows = await cursor.fetchall()
            return [self._row_to_generation(row) for row in rows]

    async def update_generation(self, generation: Generation) -> None:
        """Update the mutable columns of a generation record.

        Note: model/variant/prompt/parameters are treated as immutable and
        are not written here.
        """
        await self.conn.execute(
            """
            UPDATE generations SET
                project_id = ?, name = ?, tags = ?, notes = ?,
                audio_path = ?, duration_seconds = ?, sample_rate = ?
            WHERE id = ?
            """,
            (
                generation.project_id,
                generation.name,
                json.dumps(generation.tags),
                generation.notes,
                generation.audio_path,
                generation.duration_seconds,
                generation.sample_rate,
                generation.id,
            ),
        )
        await self.conn.commit()

    async def delete_generation(self, generation_id: str) -> bool:
        """Delete a generation record; True if a row was deleted."""
        result = await self.conn.execute(
            "DELETE FROM generations WHERE id = ?", (generation_id,)
        )
        await self.conn.commit()
        return result.rowcount > 0

    async def count_generations(
        self, project_id: Optional[str] = None, model: Optional[str] = None
    ) -> int:
        """Count generations with optional project/model filters."""
        conditions = []
        params = []

        if project_id:
            conditions.append("project_id = ?")
            params.append(project_id)

        if model:
            conditions.append("model = ?")
            params.append(model)

        # Same fixed-literal construction as list_generations.
        where_clause = " AND ".join(conditions) if conditions else "1=1"

        async with self.conn.execute(
            f"SELECT COUNT(*) FROM generations WHERE {where_clause}",
            params,
        ) as cursor:
            row = await cursor.fetchone()
            return row[0] if row else 0

    def _row_to_generation(self, row: aiosqlite.Row) -> Generation:
        """Convert a database row to a Generation object (parses JSON columns)."""
        return Generation(
            id=row["id"],
            project_id=row["project_id"],
            model=row["model"],
            variant=row["variant"],
            prompt=row["prompt"],
            parameters=json.loads(row["parameters"]),
            preset_used=row["preset_used"],
            conditioning=json.loads(row["conditioning"]) if row["conditioning"] else {},
            audio_path=row["audio_path"],
            duration_seconds=row["duration_seconds"],
            sample_rate=row["sample_rate"],
            created_at=datetime.fromisoformat(row["created_at"]),
            name=row["name"],
            tags=json.loads(row["tags"]) if row["tags"] else [],
            notes=row["notes"],
            seed=row["seed"],
        )

    # Preset Methods

    async def create_preset(self, preset: Preset) -> Preset:
        """Insert a new preset row."""
        await self.conn.execute(
            """
            INSERT INTO presets (id, model, name, description, parameters, is_builtin, created_at)
            VALUES (?, ?, ?, ?, ?, ?, ?)
            """,
            (
                preset.id,
                preset.model,
                preset.name,
                preset.description,
                json.dumps(preset.parameters),
                preset.is_builtin,
                preset.created_at.isoformat(),
            ),
        )
        await self.conn.commit()
        return preset

    async def get_preset(self, preset_id: str) -> Optional[Preset]:
        """Get a preset by ID, or None if not found."""
        async with self.conn.execute(
            "SELECT * FROM presets WHERE id = ?", (preset_id,)
        ) as cursor:
            row = await cursor.fetchone()
            if row:
                return self._row_to_preset(row)
            return None

    async def list_presets(
        self, model: Optional[str] = None, include_builtin: bool = True
    ) -> list[Preset]:
        """List presets (built-ins first, then alphabetical) with optional filter."""
        conditions = []
        params = []

        if model:
            conditions.append("model = ?")
            params.append(model)

        if not include_builtin:
            conditions.append("is_builtin = FALSE")

        # Fixed literals only; values go through ? placeholders.
        where_clause = " AND ".join(conditions) if conditions else "1=1"

        async with self.conn.execute(
            f"""
            SELECT * FROM presets
            WHERE {where_clause}
            ORDER BY is_builtin DESC, name ASC
            """,
            params,
        ) as cursor:
            rows = await cursor.fetchall()
            return [self._row_to_preset(row) for row in rows]

    async def delete_preset(self, preset_id: str) -> bool:
        """Delete a preset; built-in presets are protected by the WHERE clause."""
        result = await self.conn.execute(
            "DELETE FROM presets WHERE id = ? AND is_builtin = FALSE",
            (preset_id,),
        )
        await self.conn.commit()
        return result.rowcount > 0

    def _row_to_preset(self, row: aiosqlite.Row) -> Preset:
        """Convert a database row to a Preset object (parses JSON parameters)."""
        return Preset(
            id=row["id"],
            model=row["model"],
            name=row["name"],
            description=row["description"] or "",
            parameters=json.loads(row["parameters"]),
            is_builtin=bool(row["is_builtin"]),
            created_at=datetime.fromisoformat(row["created_at"]),
        )

    # Utility Methods

    async def get_stats(self) -> dict[str, Any]:
        """Get database statistics: row counts and a per-model histogram."""
        stats = {}

        async with self.conn.execute("SELECT COUNT(*) FROM projects") as cursor:
            row = await cursor.fetchone()
            stats["projects"] = row[0] if row else 0

        async with self.conn.execute("SELECT COUNT(*) FROM generations") as cursor:
            row = await cursor.fetchone()
            stats["generations"] = row[0] if row else 0

        async with self.conn.execute("SELECT COUNT(*) FROM presets") as cursor:
            row = await cursor.fetchone()
            stats["presets"] = row[0] if row else 0

        async with self.conn.execute(
            "SELECT model, COUNT(*) as count FROM generations GROUP BY model"
        ) as cursor:
            rows = await cursor.fetchall()
            stats["generations_by_model"] = {row["model"]: row["count"] for row in rows}

        return stats
|
||||
5
src/ui/__init__.py
Normal file
5
src/ui/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Gradio UI for AudioCraft Studio."""
|
||||
|
||||
from src.ui.app import create_app
|
||||
|
||||
__all__ = ["create_app"]
|
||||
355
src/ui/app.py
Normal file
355
src/ui/app.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""Main Gradio application for AudioCraft Studio."""
|
||||
|
||||
import asyncio
|
||||
import gradio as gr
|
||||
from typing import Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from src.ui.theme import create_audiocraft_theme, get_custom_css
|
||||
from src.ui.state import UIState, DEFAULT_PRESETS, PROMPT_SUGGESTIONS
|
||||
from src.ui.components.vram_monitor import create_vram_monitor
|
||||
from src.ui.tabs import (
|
||||
create_dashboard_tab,
|
||||
create_musicgen_tab,
|
||||
create_audiogen_tab,
|
||||
create_magnet_tab,
|
||||
create_style_tab,
|
||||
create_jasco_tab,
|
||||
)
|
||||
from src.ui.pages import create_projects_page, create_settings_page
|
||||
|
||||
from config.settings import get_settings
|
||||
|
||||
|
||||
class AudioCraftApp:
|
||||
"""Main AudioCraft Studio Gradio application."""
|
||||
|
||||
    def __init__(
        self,
        generation_service: Any = None,
        batch_processor: Any = None,
        project_service: Any = None,
        gpu_manager: Any = None,
        model_registry: Any = None,
    ):
        """Initialize the application.

        All collaborators are optional (default None); UI callbacks check
        for their presence and fall back to defaults or raise RuntimeError
        when a required service is missing.

        Args:
            generation_service: Service for handling generations
            batch_processor: Service for batch/queue processing
            project_service: Service for project management
            gpu_manager: GPU memory manager
            model_registry: Model registry for loading/unloading
        """
        self.settings = get_settings()
        self.generation_service = generation_service
        self.batch_processor = batch_processor
        self.project_service = project_service
        self.gpu_manager = gpu_manager
        self.model_registry = model_registry

        # Shared UI state container.
        self.ui_state = UIState()
        # Holds the gr.Blocks app once constructed; None until then.
        self.app: Optional[gr.Blocks] = None
|
||||
|
||||
def _get_queue_status(self) -> dict[str, Any]:
|
||||
"""Get current queue status."""
|
||||
if self.batch_processor:
|
||||
return {
|
||||
"queue_size": len(self.batch_processor.queue),
|
||||
"active_jobs": self.batch_processor.active_count,
|
||||
"completed_today": self.batch_processor.completed_count,
|
||||
}
|
||||
return {"queue_size": 0, "active_jobs": 0, "completed_today": 0}
|
||||
|
||||
def _get_recent_generations(self, limit: int = 5) -> list[dict[str, Any]]:
|
||||
"""Get recent generations."""
|
||||
if self.project_service:
|
||||
try:
|
||||
return asyncio.run(self.project_service.get_recent_generations(limit))
|
||||
except Exception:
|
||||
pass
|
||||
return []
|
||||
|
||||
def _get_gpu_status(self) -> dict[str, Any]:
|
||||
"""Get GPU memory status."""
|
||||
if self.gpu_manager:
|
||||
return {
|
||||
"used_gb": self.gpu_manager.get_used_memory() / 1024**3,
|
||||
"total_gb": self.gpu_manager.total_memory / 1024**3,
|
||||
"utilization_percent": self.gpu_manager.get_utilization(),
|
||||
"available_gb": self.gpu_manager.get_available_memory() / 1024**3,
|
||||
}
|
||||
return {"used_gb": 0, "total_gb": 24, "utilization_percent": 0, "available_gb": 24}
|
||||
|
||||
async def _generate(self, **kwargs) -> tuple[Any, Any]:
|
||||
"""Generate audio using the generation service."""
|
||||
if self.generation_service:
|
||||
return await self.generation_service.generate(**kwargs)
|
||||
raise RuntimeError("Generation service not configured")
|
||||
|
||||
def _add_to_queue(self, **kwargs) -> Any:
|
||||
"""Add generation job to queue."""
|
||||
if self.batch_processor:
|
||||
return self.batch_processor.add_job(**kwargs)
|
||||
raise RuntimeError("Batch processor not configured")
|
||||
|
||||
def _get_projects(self) -> list[dict]:
|
||||
"""Get all projects."""
|
||||
if self.project_service:
|
||||
try:
|
||||
return asyncio.run(self.project_service.list_projects())
|
||||
except Exception:
|
||||
pass
|
||||
return []
|
||||
|
||||
def _get_generations(self, project_id: str, limit: int, offset: int) -> list[dict]:
|
||||
"""Get generations for a project."""
|
||||
if self.project_service:
|
||||
try:
|
||||
return asyncio.run(
|
||||
self.project_service.list_generations(project_id, limit, offset)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return []
|
||||
|
||||
def _delete_generation(self, generation_id: str) -> bool:
|
||||
"""Delete a generation."""
|
||||
if self.project_service:
|
||||
try:
|
||||
asyncio.run(self.project_service.delete_generation(generation_id))
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _export_project(self, project_id: str) -> str:
|
||||
"""Export project as ZIP."""
|
||||
if self.project_service:
|
||||
return asyncio.run(self.project_service.export_project_zip(project_id))
|
||||
raise RuntimeError("Project service not configured")
|
||||
|
||||
def _create_project(self, name: str, description: str) -> dict:
|
||||
"""Create a new project."""
|
||||
if self.project_service:
|
||||
return asyncio.run(self.project_service.create_project(name, description))
|
||||
raise RuntimeError("Project service not configured")
|
||||
|
||||
def _get_app_settings(self) -> dict:
|
||||
"""Get application settings."""
|
||||
return {
|
||||
"output_dir": str(self.settings.output_dir),
|
||||
"default_format": self.settings.default_format,
|
||||
"sample_rate": self.settings.sample_rate,
|
||||
"normalize_audio": self.settings.normalize_audio,
|
||||
"theme_mode": "Dark",
|
||||
"show_advanced": False,
|
||||
"auto_play": True,
|
||||
"comfyui_reserve_gb": self.settings.comfyui_reserve_gb,
|
||||
"idle_timeout_minutes": self.settings.idle_unload_minutes,
|
||||
"max_loaded_models": self.settings.max_loaded_models,
|
||||
"musicgen_variant": "medium",
|
||||
"musicgen_duration": 10,
|
||||
"audiogen_duration": 5,
|
||||
"magnet_variant": "medium",
|
||||
"magnet_decoding_steps": 20,
|
||||
"api_enabled": self.settings.api_enabled,
|
||||
"api_port": self.settings.api_port,
|
||||
"rate_limit": self.settings.api_rate_limit,
|
||||
"max_batch_size": self.settings.max_batch_size,
|
||||
"max_queue_size": self.settings.max_queue_size,
|
||||
"max_workers": 1,
|
||||
"priority_queue": False,
|
||||
}
|
||||
|
||||
def _update_app_settings(self, settings: dict) -> bool:
|
||||
"""Update application settings."""
|
||||
# In a real implementation, this would persist settings
|
||||
# For now, just return success
|
||||
return True
|
||||
|
||||
def _clear_cache(self) -> bool:
|
||||
"""Clear model cache."""
|
||||
if self.model_registry:
|
||||
try:
|
||||
self.model_registry.clear_cache()
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _unload_all_models(self) -> bool:
|
||||
"""Unload all models from memory."""
|
||||
if self.model_registry:
|
||||
try:
|
||||
asyncio.run(self.model_registry.unload_all())
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
def build(self) -> gr.Blocks:
    """Build the Gradio application.

    Assembles the full Blocks UI: a header row with the VRAM monitor,
    one tab per model family (MusicGen, AudioGen, MAGNeT, Style, JASCO),
    a dashboard, projects and settings pages, and a footer. Component
    handles are stored in ``self.components`` and the Blocks object in
    ``self.app`` so ``launch()`` can reuse a previously built app.

    Returns:
        The constructed (not yet launched) gr.Blocks application.
    """
    theme = create_audiocraft_theme()
    css = get_custom_css()

    with gr.Blocks(
        theme=theme,
        css=css,
        title="AudioCraft Studio",
        analytics_enabled=False,
    ) as app:
        # Header with VRAM monitor
        with gr.Row():
            with gr.Column(scale=4):
                gr.Markdown("# AudioCraft Studio")
            with gr.Column(scale=1):
                # NOTE(review): create_vram_monitor (src/ui/components/
                # vram_monitor.py) declares parameters (get_gpu_status,
                # get_loaded_models, unload_model, load_model) — it does not
                # accept get_status_fn/update_interval keywords as written.
                # Confirm this call against the component's signature.
                vram_monitor = create_vram_monitor(
                    get_status_fn=self._get_gpu_status,
                    update_interval=5,
                )

        # Main tabs (main_tabs handle currently unused; kept for future
        # programmatic tab switching)
        with gr.Tabs() as main_tabs:
            # Dashboard
            with gr.TabItem("Dashboard", id="dashboard"):
                dashboard = create_dashboard_tab(
                    get_queue_status=self._get_queue_status,
                    get_recent_generations=self._get_recent_generations,
                    get_gpu_status=self._get_gpu_status,
                )

            # Model tabs — each gets the same generate/queue callbacks
            with gr.TabItem("MusicGen", id="musicgen"):
                musicgen = create_musicgen_tab(
                    generate_fn=self._generate,
                    add_to_queue_fn=self._add_to_queue,
                )

            with gr.TabItem("AudioGen", id="audiogen"):
                audiogen = create_audiogen_tab(
                    generate_fn=self._generate,
                    add_to_queue_fn=self._add_to_queue,
                )

            with gr.TabItem("MAGNeT", id="magnet"):
                magnet = create_magnet_tab(
                    generate_fn=self._generate,
                    add_to_queue_fn=self._add_to_queue,
                )

            with gr.TabItem("Style", id="style"):
                style = create_style_tab(
                    generate_fn=self._generate,
                    add_to_queue_fn=self._add_to_queue,
                )

            with gr.TabItem("JASCO", id="jasco"):
                jasco = create_jasco_tab(
                    generate_fn=self._generate,
                    add_to_queue_fn=self._add_to_queue,
                )

            # Projects
            with gr.TabItem("Projects", id="projects"):
                projects = create_projects_page(
                    get_projects=self._get_projects,
                    get_generations=self._get_generations,
                    delete_generation=self._delete_generation,
                    export_project=self._export_project,
                    create_project=self._create_project,
                )

            # Settings
            with gr.TabItem("Settings", id="settings"):
                settings = create_settings_page(
                    get_settings=self._get_app_settings,
                    update_settings=self._update_app_settings,
                    get_gpu_info=self._get_gpu_status,
                    clear_cache=self._clear_cache,
                    unload_all_models=self._unload_all_models,
                )

        # Footer
        gr.Markdown("---")
        gr.Markdown(
            "AudioCraft Studio | "
            "[Documentation](https://github.com/facebookresearch/audiocraft) | "
            "Powered by Meta AudioCraft"
        )

        # Store component references so callers can wire extra events later
        self.components = {
            "vram_monitor": vram_monitor,
            "dashboard": dashboard,
            "musicgen": musicgen,
            "audiogen": audiogen,
            "magnet": magnet,
            "style": style,
            "jasco": jasco,
            "projects": projects,
            "settings": settings,
        }

    self.app = app
    return app
|
||||
|
||||
def launch(
    self,
    server_name: Optional[str] = None,
    server_port: Optional[int] = None,
    share: bool = False,
    **kwargs,
) -> None:
    """Build the app if needed, then launch the Gradio server.

    Args:
        server_name: Hostname to bind; falls back to ``settings.host``.
        server_port: Port to bind; falls back to ``settings.gradio_port``.
        share: Whether to create a public Gradio share link.
        **kwargs: Extra options forwarded to ``gr.Blocks.launch()``.
    """
    if self.app is None:
        self.build()

    host = server_name or self.settings.host
    port = server_port or self.settings.gradio_port
    self.app.launch(
        server_name=host,
        server_port=port,
        share=share,
        show_error=True,  # surface handler tracebacks in the UI
        **kwargs,
    )
|
||||
|
||||
|
||||
def create_app(
    generation_service: Any = None,
    batch_processor: Any = None,
    project_service: Any = None,
    gpu_manager: Any = None,
    model_registry: Any = None,
) -> AudioCraftApp:
    """Factory for a fully wired AudioCraftApp.

    Args:
        generation_service: Service for handling generations.
        batch_processor: Service for batch/queue processing.
        project_service: Service for project management.
        gpu_manager: GPU memory manager.
        model_registry: Model registry for loading/unloading.

    Any service left as None is treated as unavailable and the UI
    degrades gracefully.

    Returns:
        Configured AudioCraftApp instance (not yet built or launched).
    """
    services = dict(
        generation_service=generation_service,
        batch_processor=batch_processor,
        project_service=project_service,
        gpu_manager=gpu_manager,
        model_registry=model_registry,
    )
    return AudioCraftApp(**services)
|
||||
|
||||
|
||||
# Standalone launch for development/testing
|
||||
if __name__ == "__main__":
    # Development entry point: run the UI with no backing services wired.
    create_app().launch()
|
||||
13
src/ui/components/__init__.py
Normal file
13
src/ui/components/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Reusable UI components for AudioCraft Studio."""
|
||||
|
||||
from src.ui.components.vram_monitor import create_vram_monitor
|
||||
from src.ui.components.audio_player import create_audio_player
|
||||
from src.ui.components.preset_selector import create_preset_selector
|
||||
from src.ui.components.generation_params import create_generation_params
|
||||
|
||||
__all__ = [
|
||||
"create_vram_monitor",
|
||||
"create_audio_player",
|
||||
"create_preset_selector",
|
||||
"create_generation_params",
|
||||
]
|
||||
178
src/ui/components/audio_player.py
Normal file
178
src/ui/components/audio_player.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""Audio player component with waveform visualization."""
|
||||
|
||||
import gradio as gr
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Callable
|
||||
|
||||
|
||||
def create_audio_player(
    label: str = "Generated Audio",
    show_waveform: bool = True,
    show_download: bool = True,
    show_info: bool = True,
) -> dict[str, Any]:
    """Create audio player component with optional waveform.

    Args:
        label: Label for the audio component
        show_waveform: Show waveform image
        show_download: Show download buttons
        show_info: Show audio info (duration, sample rate)

    Returns:
        Dictionary with component references under the keys
        "audio", "waveform", "duration", "sample_rate", "seed",
        "download_wav", "download_mp3", "download_flac". Components
        disabled by the show_* flags are present but set to None.
    """

    with gr.Group():
        # Audio player
        audio_output = gr.Audio(
            label=label,
            type="filepath",
            interactive=False,
            show_download_button=show_download,
        )

        # Waveform visualization — created hidden; callers are expected to
        # toggle visibility when an image is available (see
        # update_audio_player below).
        if show_waveform:
            waveform_image = gr.Image(
                label="Waveform",
                type="filepath",
                interactive=False,
                height=100,
                visible=False,
            )
        else:
            waveform_image = None

        # Audio info (read-only metadata fields)
        if show_info:
            with gr.Row():
                duration_text = gr.Textbox(
                    label="Duration",
                    value="",
                    interactive=False,
                    scale=1,
                )
                sample_rate_text = gr.Textbox(
                    label="Sample Rate",
                    value="",
                    interactive=False,
                    scale=1,
                )
                seed_text = gr.Textbox(
                    label="Seed",
                    value="",
                    interactive=False,
                    scale=1,
                )
        else:
            duration_text = None
            sample_rate_text = None
            seed_text = None

        # Download buttons — click handlers are not wired here; the caller
        # is responsible for attaching conversion/download logic.
        if show_download:
            with gr.Row():
                download_wav = gr.Button("Download WAV", size="sm")
                download_mp3 = gr.Button("Download MP3", size="sm")
                download_flac = gr.Button("Download FLAC", size="sm")
        else:
            download_wav = download_mp3 = download_flac = None

    return {
        "audio": audio_output,
        "waveform": waveform_image,
        "duration": duration_text,
        "sample_rate": sample_rate_text,
        "seed": seed_text,
        "download_wav": download_wav,
        "download_mp3": download_mp3,
        "download_flac": download_flac,
    }
|
||||
|
||||
|
||||
def update_audio_player(
    audio_path: Optional[str],
    duration: Optional[float] = None,
    sample_rate: Optional[int] = None,
    seed: Optional[int] = None,
    waveform_path: Optional[str] = None,
) -> tuple:
    """Update audio player with new audio.

    Args:
        audio_path: Path to audio file (or None to clear the player)
        duration: Audio duration in seconds
        sample_rate: Audio sample rate in Hz
        seed: Generation seed
        waveform_path: Path to waveform image, if one was rendered

    Returns:
        Tuple of update values in the component order produced by
        create_audio_player: (audio, waveform update, duration text,
        sample-rate text, seed text).
    """
    # Use explicit None checks (not truthiness) so legitimate zero values
    # — duration 0.0, sample rate 0, seed 0 — are still displayed instead
    # of being blanked out.
    duration_str = f"{duration:.2f}s" if duration is not None else ""
    sample_rate_str = f"{sample_rate} Hz" if sample_rate is not None else ""
    seed_str = str(seed) if seed is not None else ""

    # Hide the waveform image entirely when no image was produced.
    waveform_update = gr.update(value=waveform_path, visible=waveform_path is not None)

    return (
        audio_path,
        waveform_update,
        duration_str,
        sample_rate_str,
        seed_str,
    )
|
||||
|
||||
|
||||
def create_generation_output() -> dict[str, Any]:
    """Create generation output section with audio player and metadata.

    Composes the shared audio player (create_audio_player) with a status
    line, a hidden progress slider, a collapsible parameters JSON view,
    and post-generation action buttons. Button click handlers are not
    wired here; the caller attaches them.

    Returns:
        Dictionary with component references under the keys "status",
        "progress", "player" (the nested create_audio_player dict),
        "info", "save_btn", "regenerate_btn", "add_queue_btn".
    """
    with gr.Group():
        gr.Markdown("### Output")

        # Status/progress — the slider doubles as a progress bar and starts
        # hidden until a generation is running.
        with gr.Row():
            status_text = gr.Markdown("Ready to generate")
            progress_bar = gr.Slider(
                minimum=0,
                maximum=100,
                value=0,
                label="Progress",
                interactive=False,
                visible=False,
            )

        # Audio player
        player = create_audio_player(
            label="Generated Audio",
            show_waveform=True,
            show_download=True,
            show_info=True,
        )

        # Generation metadata (raw parameter dict, collapsed by default)
        with gr.Accordion("Generation Details", open=False):
            generation_info = gr.JSON(
                label="Parameters",
                value={},
            )

        # Actions
        with gr.Row():
            save_btn = gr.Button("Save to Project", variant="secondary")
            regenerate_btn = gr.Button("Regenerate", variant="secondary")
            add_queue_btn = gr.Button("Add to Queue", variant="secondary")

    return {
        "status": status_text,
        "progress": progress_bar,
        "player": player,
        "info": generation_info,
        "save_btn": save_btn,
        "regenerate_btn": regenerate_btn,
        "add_queue_btn": add_queue_btn,
    }
|
||||
199
src/ui/components/generation_params.py
Normal file
199
src/ui/components/generation_params.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""Generation parameters component."""
|
||||
|
||||
import gradio as gr
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
def create_generation_params(
    model_id: str,
    show_advanced: bool = False,
    max_duration: float = 30.0,
) -> dict[str, Any]:
    """Create generation parameters panel.

    Builds the duration slider plus an "Advanced Parameters" accordion
    (temperature, CFG, top-k, top-p, seed), a reset-to-defaults button,
    and a random-seed toggle that disables manual seed entry.

    Args:
        model_id: Model family for customizing available options
        show_advanced: Whether to show advanced parameters by default
        max_duration: Maximum allowed duration

    Returns:
        Dictionary with component references under the keys "duration",
        "temperature", "cfg_coef", "top_k", "top_p", "seed",
        "use_random_seed", "reset_btn".
    """
    # Model-specific defaults; unknown model_ids fall back to musicgen's.
    # Note MAGNeT's distinct sampling profile (high temperature, top-p
    # instead of top-k).
    defaults = {
        "musicgen": {"duration": 10, "temperature": 1.0, "top_k": 250, "top_p": 0.0, "cfg_coef": 3.0},
        "audiogen": {"duration": 5, "temperature": 1.0, "top_k": 250, "top_p": 0.0, "cfg_coef": 3.0},
        "magnet": {"duration": 10, "temperature": 3.0, "top_k": 0, "top_p": 0.9, "cfg_coef": 3.0},
        "musicgen-style": {"duration": 10, "temperature": 1.0, "top_k": 250, "top_p": 0.0, "cfg_coef": 3.0},
        "jasco": {"duration": 10, "temperature": 1.0, "top_k": 250, "top_p": 0.0, "cfg_coef": 3.0},
    }

    d = defaults.get(model_id, defaults["musicgen"])

    with gr.Group():
        # Basic parameters (always visible)
        duration_slider = gr.Slider(
            minimum=1,
            maximum=max_duration,
            value=d["duration"],
            step=1,
            label="Duration (seconds)",
            info="Length of audio to generate",
        )

        # Advanced parameters (expandable)
        with gr.Accordion("Advanced Parameters", open=show_advanced):
            with gr.Row():
                temperature_slider = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=d["temperature"],
                    step=0.05,
                    label="Temperature",
                    info="Higher = more random, lower = more deterministic",
                )
                cfg_slider = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=d["cfg_coef"],
                    step=0.5,
                    label="CFG Coefficient",
                    info="Classifier-free guidance strength",
                )

            with gr.Row():
                top_k_slider = gr.Slider(
                    minimum=0,
                    maximum=500,
                    value=d["top_k"],
                    step=10,
                    label="Top-K",
                    info="Token selection limit (0 = disabled)",
                )
                top_p_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=d["top_p"],
                    step=0.05,
                    label="Top-P (Nucleus)",
                    info="Cumulative probability threshold (0 = disabled)",
                )

            with gr.Row():
                # precision=0 forces an integer seed value
                seed_input = gr.Number(
                    value=None,
                    label="Seed",
                    info="Random seed for reproducibility (leave empty for random)",
                    precision=0,
                )
                use_random_seed = gr.Checkbox(
                    value=True,
                    label="Random Seed",
                )

        # Reset button
        reset_btn = gr.Button("Reset to Defaults", size="sm", variant="secondary")

        def reset_params():
            """Reset all parameters to defaults (seed cleared, random on)."""
            return (
                d["duration"],
                d["temperature"],
                d["cfg_coef"],
                d["top_k"],
                d["top_p"],
                None,
                True,
            )

        reset_btn.click(
            fn=reset_params,
            outputs=[
                duration_slider,
                temperature_slider,
                cfg_slider,
                top_k_slider,
                top_p_slider,
                seed_input,
                use_random_seed,
            ],
        )

        # Link random seed checkbox to seed input: checking it clears and
        # locks the seed field; unchecking re-enables manual entry.
        # (current_seed is received from the inputs list but not used.)
        def toggle_seed(use_random: bool, current_seed: Optional[int]):
            if use_random:
                return gr.update(value=None, interactive=False)
            return gr.update(interactive=True)

        use_random_seed.change(
            fn=toggle_seed,
            inputs=[use_random_seed, seed_input],
            outputs=[seed_input],
        )

    return {
        "duration": duration_slider,
        "temperature": temperature_slider,
        "cfg_coef": cfg_slider,
        "top_k": top_k_slider,
        "top_p": top_p_slider,
        "seed": seed_input,
        "use_random_seed": use_random_seed,
        "reset_btn": reset_btn,
    }
|
||||
|
||||
|
||||
def create_model_variant_selector(
    model_id: str,
    variants: list[dict[str, Any]],
    default_variant: str = "medium",
) -> dict[str, Any]:
    """Create model variant selector.

    Renders a dropdown of variants labeled with their VRAM footprint
    (converted from MB to GB) and a markdown area that shows the
    selected variant's description.

    Args:
        model_id: Model family ID (currently informational only; the
            selector is driven entirely by ``variants``)
        variants: List of variant configurations; each dict may carry
            "name" (or "id" as fallback), "vram_mb", and "description"
        default_variant: Default variant to select

    Returns:
        Dictionary with component references under the keys "dropdown",
        "info", and the raw "variants" list.
    """
    # Build (label, value) dropdown choices; the value is the variant
    # name so callers receive a plain identifier, not the display label.
    choices = []
    for v in variants:
        name = v.get("name", v.get("id", "unknown"))
        vram = v.get("vram_mb", 0)
        label = f"{name} ({vram/1024:.1f}GB)"
        choices.append((label, name))

    with gr.Group():
        variant_dropdown = gr.Dropdown(
            label="Model Variant",
            choices=choices,
            value=default_variant,
            interactive=True,
        )

        variant_info = gr.Markdown(
            value="",
            visible=True,
        )

        def update_info(variant_name: str):
            """Return the description of the selected variant ("" if none)."""
            for v in variants:
                if v.get("name", v.get("id")) == variant_name:
                    return v.get("description", "")
            return ""

        variant_dropdown.change(
            fn=update_info,
            inputs=[variant_dropdown],
            outputs=[variant_info],
        )

    return {
        "dropdown": variant_dropdown,
        "info": variant_info,
        "variants": variants,
    }
|
||||
103
src/ui/components/preset_selector.py
Normal file
103
src/ui/components/preset_selector.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Preset selector component."""
|
||||
|
||||
import gradio as gr
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from src.ui.state import DEFAULT_PRESETS
|
||||
|
||||
|
||||
def create_preset_selector(
    model_id: str,
    on_preset_select: Optional[Callable[[dict], None]] = None,
) -> dict[str, Any]:
    """Create preset selector component for a model.

    Builds a dropdown of the model's presets plus a "Custom" entry, and
    returns helper callables ("get_preset", "on_change") the caller can
    wire to other components.

    Args:
        model_id: Model family ID
        on_preset_select: Callback when preset is selected (accepted for
            API compatibility; wiring is left to the caller)

    Returns:
        Dictionary with component references and helper callables.
    """
    presets = DEFAULT_PRESETS.get(model_id, [])

    # Dropdown (label, value) pairs, with "Custom" always last.
    choices = [(entry["name"], entry["id"]) for entry in presets]
    choices.append(("Custom", "custom"))

    def get_preset_by_id(preset_id: str) -> Optional[dict]:
        """Look up a preset dict by its ID; None when not found."""
        matches = [entry for entry in presets if entry["id"] == preset_id]
        return matches[0] if matches else None

    def on_change(preset_id: str):
        """Map a selection to (custom-visibility update, parameter dict).

        Known presets hide the custom panel and surface their parameters;
        "custom" or an unknown ID shows the panel with empty parameters.
        """
        if preset_id != "custom":
            preset = get_preset_by_id(preset_id)
            if preset:
                return gr.update(visible=False), preset.get("parameters", {})
        return gr.update(visible=True), {}

    # Default to the first preset when any exist, otherwise "custom".
    initial_value = presets[0]["id"] if presets else "custom"
    initial_description = presets[0]["description"] if presets else ""

    with gr.Group():
        preset_dropdown = gr.Dropdown(
            label="Preset",
            choices=choices,
            value=initial_value,
            interactive=True,
        )

        preset_description = gr.Markdown(
            value=initial_description,
            visible=True,
        )

    return {
        "dropdown": preset_dropdown,
        "description": preset_description,
        "presets": presets,
        "get_preset": get_preset_by_id,
        "on_change": on_change,
    }
|
||||
|
||||
|
||||
def create_preset_chips(
    model_id: str,
    on_select: Callable[[str], None],
) -> dict[str, Any]:
    """Create preset selector as clickable chips/buttons.

    Args:
        model_id: Model family ID used to look up presets
        on_select: Callback when preset is clicked (accepted for API
            compatibility; click wiring is left to the caller)

    Returns:
        Dictionary with "buttons" (list of (button, preset) pairs),
        "custom_btn", and the raw "presets" list.
    """
    presets = DEFAULT_PRESETS.get(model_id, [])

    with gr.Row():
        # One small secondary button per preset, paired with its data.
        buttons = [
            (gr.Button(entry["name"], size="sm", variant="secondary"), entry)
            for entry in presets
        ]

        custom_btn = gr.Button(
            "Custom",
            size="sm",
            variant="secondary",
        )

    return {
        "buttons": buttons,
        "custom_btn": custom_btn,
        "presets": presets,
    }
|
||||
151
src/ui/components/vram_monitor.py
Normal file
151
src/ui/components/vram_monitor.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""VRAM monitor component for GPU memory tracking."""
|
||||
|
||||
import gradio as gr
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
|
||||
def create_vram_monitor(
    get_gpu_status: Optional[Callable[[], dict[str, Any]]] = None,
    get_loaded_models: Optional[Callable[[], list[dict[str, Any]]]] = None,
    unload_model: Optional[Callable[[str, str], bool]] = None,
    load_model: Optional[Callable[[str, str], bool]] = None,
    get_status_fn: Optional[Callable[[], dict[str, Any]]] = None,
    update_interval: int = 5,
) -> dict[str, Any]:
    """Create VRAM monitor component.

    Args:
        get_gpu_status: Function returning a GPU status dict with
            used_gb/total_gb/utilization_percent (and optionally device)
        get_loaded_models: Function returning loaded-model dicts with
            model_id/variant/idle_seconds keys
        unload_model: Function to unload a model (model_id, variant)
        load_model: Function to load a model (model_id, variant);
            accepted for API symmetry, not used by this component yet
        get_status_fn: Alias for get_gpu_status — app.build() passes the
            status callback under this keyword, so accept both names
        update_interval: Intended auto-refresh period in seconds;
            currently informational only (refresh is manual)

    Returns:
        Dictionary with component references and a "refresh_fn" callable.
    """
    # All callbacks are optional so the monitor degrades gracefully when a
    # backend service is absent (e.g. standalone UI development). This also
    # fixes the call in AudioCraftApp.build(), which passes only
    # get_status_fn=/update_interval= and previously did not match this
    # signature.
    status_fn = get_gpu_status or get_status_fn or (lambda: {})
    models_fn = get_loaded_models or (lambda: [])
    unload_fn = unload_model or (lambda model_id, variant: False)

    def refresh_status():
        """Refresh GPU status display."""
        status = status_fn()
        loaded = models_fn()

        # Format VRAM bar; defaults assume a 24 GB card when unknown.
        used_gb = status.get("used_gb", 0)
        total_gb = status.get("total_gb", 24)
        util_pct = status.get("utilization_percent", 0)

        vram_text = f"{used_gb:.1f} / {total_gb:.1f} GB ({util_pct:.0f}%)"

        # Format loaded models list
        if loaded:
            models_text = "\n".join([
                f"• {m['model_id']}/{m['variant']} "
                f"(idle: {m['idle_seconds']:.0f}s)"
                for m in loaded
            ])
        else:
            models_text = "No models loaded"

        # Traffic-light indicator keyed to utilization thresholds.
        if util_pct > 90:
            status_color = "🔴"
        elif util_pct > 75:
            status_color = "🟡"
        else:
            status_color = "🟢"

        status_text = f"{status_color} GPU: {status.get('device', 'N/A')}"

        return vram_text, util_pct, models_text, status_text

    def handle_unload(model_selection: str):
        """Handle model unload for a "model_id/variant" selection."""
        if not model_selection or "/" not in model_selection:
            return "Select a model to unload", *refresh_status()

        parts = model_selection.split("/")
        model_id, variant = parts[0], parts[1]

        success = unload_fn(model_id, variant)
        if success:
            msg = f"Unloaded {model_id}/{variant}"
        else:
            msg = f"Failed to unload {model_id}/{variant}"

        return msg, *refresh_status()

    with gr.Group():
        gr.Markdown("### GPU Memory")

        status_text = gr.Markdown("🟢 GPU: Checking...")

        with gr.Row():
            vram_display = gr.Textbox(
                label="VRAM Usage",
                value="Loading...",
                interactive=False,
                scale=3,
            )
            refresh_btn = gr.Button("🔄", scale=1, min_width=50)

        # Read-only slider abused as a 0-100% utilization bar.
        vram_slider = gr.Slider(
            minimum=0,
            maximum=100,
            value=0,
            label="",
            interactive=False,
            visible=True,
        )

        gr.Markdown("### Loaded Models")

        models_display = gr.Textbox(
            label="",
            value="No models loaded",
            interactive=False,
            lines=4,
            max_lines=6,
        )

        with gr.Row():
            model_selector = gr.Dropdown(
                label="Select Model",
                choices=[],
                interactive=True,
                scale=3,
            )
            unload_btn = gr.Button("Unload", variant="secondary", scale=1)

        unload_status = gr.Markdown("")

        # Event handlers

        def update_model_choices():
            """Rebuild the dropdown from currently loaded models."""
            loaded = models_fn()
            choices = [f"{m['model_id']}/{m['variant']}" for m in loaded]
            return gr.update(choices=choices, value=None)

        refresh_btn.click(
            fn=refresh_status,
            outputs=[vram_display, vram_slider, models_display, status_text],
        ).then(
            fn=update_model_choices,
            outputs=[model_selector],
        )

        unload_btn.click(
            fn=handle_unload,
            inputs=[model_selector],
            outputs=[unload_status, vram_display, vram_slider, models_display, status_text],
        ).then(
            fn=update_model_choices,
            outputs=[model_selector],
        )

    return {
        "vram_display": vram_display,
        "vram_slider": vram_slider,
        "models_display": models_display,
        "status_text": status_text,
        "model_selector": model_selector,
        "refresh_btn": refresh_btn,
        "unload_btn": unload_btn,
        "refresh_fn": refresh_status,
    }
|
||||
9
src/ui/pages/__init__.py
Normal file
9
src/ui/pages/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""UI pages for AudioCraft Studio."""
|
||||
|
||||
from src.ui.pages.projects_page import create_projects_page
|
||||
from src.ui.pages.settings_page import create_settings_page
|
||||
|
||||
__all__ = [
|
||||
"create_projects_page",
|
||||
"create_settings_page",
|
||||
]
|
||||
374
src/ui/pages/projects_page.py
Normal file
374
src/ui/pages/projects_page.py
Normal file
@@ -0,0 +1,374 @@
|
||||
"""Projects page for managing generations and history."""
|
||||
|
||||
import gradio as gr
|
||||
from typing import Any, Callable, Optional
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def create_projects_page(
|
||||
get_projects: Callable[[], list[dict]],
|
||||
get_generations: Callable[[str, int, int], list[dict]],
|
||||
delete_generation: Callable[[str], bool],
|
||||
export_project: Callable[[str], str],
|
||||
create_project: Callable[[str, str], dict],
|
||||
) -> dict[str, Any]:
|
||||
"""Create projects management page.
|
||||
|
||||
Args:
|
||||
get_projects: Function to get all projects
|
||||
get_generations: Function to get generations (project_id, limit, offset)
|
||||
delete_generation: Function to delete a generation
|
||||
export_project: Function to export project as ZIP
|
||||
create_project: Function to create new project
|
||||
|
||||
Returns:
|
||||
Dictionary with component references
|
||||
"""
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown("# Projects")
|
||||
gr.Markdown("Browse and manage your generations")
|
||||
|
||||
with gr.Row():
|
||||
# Left sidebar - project list
|
||||
with gr.Column(scale=1):
|
||||
gr.Markdown("### Projects")
|
||||
|
||||
with gr.Row():
|
||||
new_project_name = gr.Textbox(
|
||||
placeholder="New project name...",
|
||||
show_label=False,
|
||||
scale=3,
|
||||
)
|
||||
new_project_btn = gr.Button("+", size="sm", scale=1)
|
||||
|
||||
project_list = gr.Dataframe(
|
||||
headers=["ID", "Name", "Count"],
|
||||
datatype=["str", "str", "number"],
|
||||
col_count=(3, "fixed"),
|
||||
interactive=False,
|
||||
height=400,
|
||||
)
|
||||
|
||||
refresh_projects_btn = gr.Button("Refresh Projects", size="sm")
|
||||
|
||||
# Main content - generations
|
||||
with gr.Column(scale=3):
|
||||
# Selected project info
|
||||
selected_project_id = gr.State(value=None)
|
||||
selected_project_name = gr.Markdown("### Select a project")
|
||||
|
||||
# Filters
|
||||
with gr.Row():
|
||||
model_filter = gr.Dropdown(
|
||||
label="Model",
|
||||
choices=[
|
||||
("All", "all"),
|
||||
("MusicGen", "musicgen"),
|
||||
("AudioGen", "audiogen"),
|
||||
("MAGNeT", "magnet"),
|
||||
("Style", "musicgen-style"),
|
||||
("JASCO", "jasco"),
|
||||
],
|
||||
value="all",
|
||||
scale=1,
|
||||
)
|
||||
sort_by = gr.Dropdown(
|
||||
label="Sort By",
|
||||
choices=[
|
||||
("Newest First", "newest"),
|
||||
("Oldest First", "oldest"),
|
||||
("Duration (Long)", "duration_desc"),
|
||||
("Duration (Short)", "duration_asc"),
|
||||
],
|
||||
value="newest",
|
||||
scale=1,
|
||||
)
|
||||
search_input = gr.Textbox(
|
||||
label="Search Prompts",
|
||||
placeholder="Search...",
|
||||
scale=2,
|
||||
)
|
||||
|
||||
# Generations grid
|
||||
generations_gallery = gr.Gallery(
|
||||
label="Generations",
|
||||
columns=3,
|
||||
rows=3,
|
||||
height=400,
|
||||
object_fit="contain",
|
||||
show_label=False,
|
||||
)
|
||||
|
||||
# Pagination
|
||||
with gr.Row():
|
||||
prev_page_btn = gr.Button("← Previous", size="sm")
|
||||
page_info = gr.Markdown("Page 1 of 1")
|
||||
next_page_btn = gr.Button("Next →", size="sm")
|
||||
|
||||
current_page = gr.State(value=1)
|
||||
total_pages = gr.State(value=1)
|
||||
|
||||
# Selected generation details
|
||||
gr.Markdown("---")
|
||||
gr.Markdown("### Generation Details")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
selected_audio = gr.Audio(
|
||||
label="Audio",
|
||||
interactive=False,
|
||||
)
|
||||
|
||||
with gr.Column(scale=2):
|
||||
selected_prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
interactive=False,
|
||||
lines=2,
|
||||
)
|
||||
with gr.Row():
|
||||
selected_model = gr.Textbox(
|
||||
label="Model",
|
||||
interactive=False,
|
||||
)
|
||||
selected_duration = gr.Textbox(
|
||||
label="Duration",
|
||||
interactive=False,
|
||||
)
|
||||
with gr.Row():
|
||||
selected_seed = gr.Textbox(
|
||||
label="Seed",
|
||||
interactive=False,
|
||||
)
|
||||
selected_date = gr.Textbox(
|
||||
label="Created",
|
||||
interactive=False,
|
||||
)
|
||||
|
||||
# Action buttons
|
||||
with gr.Row():
|
||||
regenerate_btn = gr.Button("Regenerate", variant="secondary")
|
||||
download_btn = gr.Button("Download", variant="secondary")
|
||||
delete_btn = gr.Button("Delete", variant="stop")
|
||||
export_project_btn = gr.Button("Export Project", variant="secondary")
|
||||
|
||||
# Event handlers
|
||||
|
||||
def load_projects():
|
||||
"""Load all projects into the list."""
|
||||
projects = get_projects()
|
||||
data = []
|
||||
for p in projects:
|
||||
data.append([
|
||||
p.get("id", ""),
|
||||
p.get("name", "Untitled"),
|
||||
p.get("generation_count", 0),
|
||||
])
|
||||
return data
|
||||
|
||||
def on_project_select(evt: gr.SelectData, df):
    """Resolve a dataframe click to (project_id, markdown heading).

    Returns (None, placeholder heading) when nothing valid is selected
    or the clicked row index is out of range of the current table data.
    """
    unselected = (None, "### Select a project")

    if evt.index is None:
        return unselected

    row_idx = evt.index[0]
    if row_idx >= len(df):
        return unselected

    # Column 0 holds the project id, column 1 the display name
    # (see load_projects for the row layout).
    project_id, project_name = df[row_idx][0], df[row_idx][1]
    return project_id, f"### {project_name}"
|
||||
|
||||
def load_generations(project_id, page, model, sort, search):
    """Load one page of generations for the selected project.

    Args:
        project_id: Id of the selected project; falsy means no selection.
        page: 1-based page number.
        model: Model-id filter, or "all" for no filtering.
        sort: One of "oldest", "duration_desc", "duration_asc"; anything
            else keeps the DB order (presumably newest first — see note).
        search: Case-insensitive substring to match against prompts.

    Returns:
        (gallery items, "Page X of Y" label, current page, total pages).

    NOTE(review): filtering and sorting happen AFTER the page slice is
    taken from the DB, so filtered pages can come back short/empty even
    when matching rows exist on other pages, and the total-page estimate
    ignores the filters. A correct fix needs filtering pushed into
    get_generations — flagging rather than changing behavior here.
    """
    if not project_id:
        return [], "Page 0 of 0", 1, 1

    limit = 9  # 3x3 grid
    offset = (page - 1) * limit

    # Fetch one extra row purely to detect whether another page exists.
    gens = get_generations(project_id, limit + 1, offset)

    # Check if there are more pages
    has_more = len(gens) > limit
    gens = gens[:limit]

    # Filter by model if needed
    if model != "all":
        gens = [g for g in gens if g.get("model") == model]

    # Filter by search (case-insensitive prompt substring match)
    if search:
        search_lower = search.lower()
        gens = [g for g in gens if search_lower in g.get("prompt", "").lower()]

    # Sort within the current page only (see NOTE above).
    if sort == "oldest":
        gens = sorted(gens, key=lambda x: x.get("created_at", ""))
    elif sort == "duration_desc":
        gens = sorted(gens, key=lambda x: x.get("duration_seconds", 0), reverse=True)
    elif sort == "duration_asc":
        gens = sorted(gens, key=lambda x: x.get("duration_seconds", 0))
    # Default is newest first (already sorted from DB)

    # Build gallery items (using waveform images if available)
    gallery_items = []
    for g in gens:
        waveform = g.get("waveform_path")
        if waveform:
            gallery_items.append((waveform, g.get("prompt", "")[:50]))
        else:
            # Placeholder entry so gallery index still lines up with gens
            gallery_items.append((None, g.get("prompt", "")[:50]))

    # Calculate total pages (estimate: only knows "at least one more page")
    total = offset + len(gens) + (1 if has_more else 0)
    total_p = max(1, (total + limit - 1) // limit)

    return gallery_items, f"Page {page} of {total_p}", page, total_p
|
||||
|
||||
def on_generation_select(evt: gr.SelectData, project_id):
    """Handle generation selection from the gallery.

    Returns a 6-tuple for the details panel:
    (audio path, prompt, model, duration label, seed, created timestamp),
    or all-empty values when nothing usable is selected.

    NOTE(review): evt.index is the position within the CURRENT gallery
    page (after model/search filtering in load_generations), but this
    lookup indexes into the unfiltered first 100 generations — on page 2+
    or with filters active the wrong record can be shown. Fixing this
    needs the page/filter state passed in; flagging only.
    """
    if evt.index is None or not project_id:
        return None, "", "", "", "", ""

    # Get generations again to find the selected one
    gens = get_generations(project_id, 100, 0)
    if evt.index < len(gens):
        gen = gens[evt.index]
        return (
            gen.get("audio_path"),
            gen.get("prompt", ""),
            gen.get("model", ""),
            f"{gen.get('duration_seconds', 0):.1f}s",
            str(gen.get("seed", "")),
            # Trim ISO timestamp to "YYYY-MM-DDTHH:MM:SS"
            gen.get("created_at", "")[:19] if gen.get("created_at") else "",
        )

    return None, "", "", "", "", ""
|
||||
|
||||
def do_create_project(name):
    """Create a new project from the name textbox.

    Args:
        name: Raw project name from the UI; surrounding whitespace is
            stripped before creation.

    Returns:
        (project table rows, status message). When the name is blank the
        table is left untouched via gr.update().
    """
    if not name.strip():
        return gr.update(), "Please enter a project name"

    # The created record itself is not needed here; the refreshed project
    # list below already reflects the new entry (was an unused local).
    create_project(name.strip(), "")
    projects_data = load_projects()
    return projects_data, f"Created project: {name}"
|
||||
|
||||
def do_delete_generation(project_id, audio_path):
    """Delete the generation whose audio file matches the current selection.

    The gallery selection only exposes the audio path, so the record is
    matched back against the project's recent generations.
    """
    if not audio_path:
        return "No generation selected"

    matching = next(
        (
            g
            for g in get_generations(project_id, 100, 0)
            if g.get("audio_path") == audio_path
        ),
        None,
    )
    if matching is None:
        return "Generation not found"

    return "Generation deleted" if delete_generation(matching.get("id")) else "Failed to delete"
|
||||
|
||||
def do_export_project(project_id):
    """Export the selected project as a ZIP archive and report the outcome.

    Returns a human-readable status string for the header markdown.
    """
    if not project_id:
        return "No project selected"

    try:
        return f"Exported to: {export_project(project_id)}"
    except Exception as e:
        return f"Export failed: {str(e)}"
|
||||
|
||||
# Wire up events
|
||||
|
||||
refresh_projects_btn.click(
|
||||
fn=load_projects,
|
||||
outputs=[project_list],
|
||||
)
|
||||
|
||||
project_list.select(
|
||||
fn=on_project_select,
|
||||
inputs=[project_list],
|
||||
outputs=[selected_project_id, selected_project_name],
|
||||
).then(
|
||||
fn=load_generations,
|
||||
inputs=[selected_project_id, current_page, model_filter, sort_by, search_input],
|
||||
outputs=[generations_gallery, page_info, current_page, total_pages],
|
||||
)
|
||||
|
||||
# Filter changes reload generations
|
||||
for component in [model_filter, sort_by, search_input]:
|
||||
component.change(
|
||||
fn=load_generations,
|
||||
inputs=[selected_project_id, current_page, model_filter, sort_by, search_input],
|
||||
outputs=[generations_gallery, page_info, current_page, total_pages],
|
||||
)
|
||||
|
||||
# Pagination
|
||||
def go_prev(page, total):
    """Step back one page, clamping at page 1. `total` is unused but kept
    so both pagination handlers share one signature."""
    return page - 1 if page > 1 else 1
|
||||
|
||||
def go_next(page, total):
    """Advance one page, clamping at the last page."""
    return page + 1 if page < total else total
|
||||
|
||||
prev_page_btn.click(
|
||||
fn=go_prev,
|
||||
inputs=[current_page, total_pages],
|
||||
outputs=[current_page],
|
||||
).then(
|
||||
fn=load_generations,
|
||||
inputs=[selected_project_id, current_page, model_filter, sort_by, search_input],
|
||||
outputs=[generations_gallery, page_info, current_page, total_pages],
|
||||
)
|
||||
|
||||
next_page_btn.click(
|
||||
fn=go_next,
|
||||
inputs=[current_page, total_pages],
|
||||
outputs=[current_page],
|
||||
).then(
|
||||
fn=load_generations,
|
||||
inputs=[selected_project_id, current_page, model_filter, sort_by, search_input],
|
||||
outputs=[generations_gallery, page_info, current_page, total_pages],
|
||||
)
|
||||
|
||||
# Generation selection
|
||||
generations_gallery.select(
|
||||
fn=on_generation_select,
|
||||
inputs=[selected_project_id],
|
||||
outputs=[selected_audio, selected_prompt, selected_model, selected_duration, selected_seed, selected_date],
|
||||
)
|
||||
|
||||
# Actions
|
||||
new_project_btn.click(
|
||||
fn=do_create_project,
|
||||
inputs=[new_project_name],
|
||||
outputs=[project_list, selected_project_name],
|
||||
)
|
||||
|
||||
delete_btn.click(
|
||||
fn=do_delete_generation,
|
||||
inputs=[selected_project_id, selected_audio],
|
||||
outputs=[selected_project_name],
|
||||
).then(
|
||||
fn=load_generations,
|
||||
inputs=[selected_project_id, current_page, model_filter, sort_by, search_input],
|
||||
outputs=[generations_gallery, page_info, current_page, total_pages],
|
||||
)
|
||||
|
||||
export_project_btn.click(
|
||||
fn=do_export_project,
|
||||
inputs=[selected_project_id],
|
||||
outputs=[selected_project_name],
|
||||
)
|
||||
|
||||
return {
|
||||
"project_list": project_list,
|
||||
"generations_gallery": generations_gallery,
|
||||
"selected_audio": selected_audio,
|
||||
"selected_project_id": selected_project_id,
|
||||
"refresh_fn": load_projects,
|
||||
}
|
||||
397
src/ui/pages/settings_page.py
Normal file
397
src/ui/pages/settings_page.py
Normal file
@@ -0,0 +1,397 @@
|
||||
"""Settings page for application configuration."""
|
||||
|
||||
import gradio as gr
|
||||
from typing import Any, Callable, Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def create_settings_page(
|
||||
get_settings: Callable[[], dict],
|
||||
update_settings: Callable[[dict], bool],
|
||||
get_gpu_info: Callable[[], dict],
|
||||
clear_cache: Callable[[], bool],
|
||||
unload_all_models: Callable[[], bool],
|
||||
) -> dict[str, Any]:
|
||||
"""Create settings management page.
|
||||
|
||||
Args:
|
||||
get_settings: Function to get current settings
|
||||
update_settings: Function to update settings
|
||||
get_gpu_info: Function to get GPU information
|
||||
clear_cache: Function to clear model cache
|
||||
unload_all_models: Function to unload all models
|
||||
|
||||
Returns:
|
||||
Dictionary with component references
|
||||
"""
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown("# Settings")
|
||||
gr.Markdown("Configure AudioCraft Studio")
|
||||
|
||||
with gr.Tabs():
|
||||
# General Settings Tab
|
||||
with gr.TabItem("General"):
|
||||
with gr.Group():
|
||||
gr.Markdown("### Output Settings")
|
||||
|
||||
output_dir = gr.Textbox(
|
||||
label="Output Directory",
|
||||
placeholder="/path/to/output",
|
||||
info="Where generated audio files are saved",
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
default_format = gr.Dropdown(
|
||||
label="Default Audio Format",
|
||||
choices=[("WAV", "wav"), ("MP3", "mp3"), ("FLAC", "flac"), ("OGG", "ogg")],
|
||||
value="wav",
|
||||
)
|
||||
sample_rate = gr.Dropdown(
|
||||
label="Sample Rate",
|
||||
choices=[
|
||||
("32000 Hz (AudioCraft default)", 32000),
|
||||
("44100 Hz (CD quality)", 44100),
|
||||
("48000 Hz (Video standard)", 48000),
|
||||
],
|
||||
value=32000,
|
||||
)
|
||||
|
||||
normalize_audio = gr.Checkbox(
|
||||
label="Normalize audio output",
|
||||
value=True,
|
||||
info="Normalize audio levels to prevent clipping",
|
||||
)
|
||||
|
||||
with gr.Group():
|
||||
gr.Markdown("### Interface Settings")
|
||||
|
||||
theme_mode = gr.Radio(
|
||||
label="Theme",
|
||||
choices=["Dark", "Light", "System"],
|
||||
value="Dark",
|
||||
)
|
||||
|
||||
show_advanced = gr.Checkbox(
|
||||
label="Show advanced parameters by default",
|
||||
value=False,
|
||||
)
|
||||
|
||||
auto_play = gr.Checkbox(
|
||||
label="Auto-play generated audio",
|
||||
value=True,
|
||||
)
|
||||
|
||||
# GPU & Memory Tab
|
||||
with gr.TabItem("GPU & Memory"):
|
||||
with gr.Group():
|
||||
gr.Markdown("### GPU Information")
|
||||
|
||||
gpu_info_display = gr.JSON(
|
||||
label="GPU Status",
|
||||
value={},
|
||||
)
|
||||
|
||||
refresh_gpu_btn = gr.Button("Refresh GPU Info", size="sm")
|
||||
|
||||
with gr.Group():
|
||||
gr.Markdown("### Memory Management")
|
||||
|
||||
comfyui_reserve = gr.Slider(
|
||||
minimum=0,
|
||||
maximum=16,
|
||||
value=10,
|
||||
step=0.5,
|
||||
label="ComfyUI VRAM Reserve (GB)",
|
||||
info="VRAM to reserve for ComfyUI when running alongside",
|
||||
)
|
||||
|
||||
idle_timeout = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=60,
|
||||
value=15,
|
||||
step=1,
|
||||
label="Idle Model Timeout (minutes)",
|
||||
info="Unload models after this period of inactivity",
|
||||
)
|
||||
|
||||
max_loaded = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=5,
|
||||
value=2,
|
||||
step=1,
|
||||
label="Maximum Loaded Models",
|
||||
info="Maximum number of models to keep in memory",
|
||||
)
|
||||
|
||||
with gr.Group():
|
||||
gr.Markdown("### Cache Management")
|
||||
|
||||
with gr.Row():
|
||||
clear_cache_btn = gr.Button("Clear Model Cache", variant="secondary")
|
||||
unload_models_btn = gr.Button("Unload All Models", variant="stop")
|
||||
|
||||
cache_status = gr.Markdown("Cache status: Ready")
|
||||
|
||||
# Model Defaults Tab
|
||||
with gr.TabItem("Model Defaults"):
|
||||
with gr.Group():
|
||||
gr.Markdown("### MusicGen Defaults")
|
||||
|
||||
with gr.Row():
|
||||
musicgen_variant = gr.Dropdown(
|
||||
label="Default Variant",
|
||||
choices=[
|
||||
("Small", "small"),
|
||||
("Medium", "medium"),
|
||||
("Large", "large"),
|
||||
("Melody", "melody"),
|
||||
],
|
||||
value="medium",
|
||||
)
|
||||
musicgen_duration = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=30,
|
||||
value=10,
|
||||
step=1,
|
||||
label="Default Duration (s)",
|
||||
)
|
||||
|
||||
with gr.Group():
|
||||
gr.Markdown("### AudioGen Defaults")
|
||||
|
||||
audiogen_duration = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
value=5,
|
||||
step=1,
|
||||
label="Default Duration (s)",
|
||||
)
|
||||
|
||||
with gr.Group():
|
||||
gr.Markdown("### MAGNeT Defaults")
|
||||
|
||||
with gr.Row():
|
||||
magnet_variant = gr.Dropdown(
|
||||
label="Default Variant",
|
||||
choices=[
|
||||
("Small Music", "small"),
|
||||
("Medium Music", "medium"),
|
||||
("Small Audio", "audio-small"),
|
||||
("Medium Audio", "audio-medium"),
|
||||
],
|
||||
value="medium",
|
||||
)
|
||||
magnet_decoding_steps = gr.Slider(
|
||||
minimum=10,
|
||||
maximum=100,
|
||||
value=20,
|
||||
step=5,
|
||||
label="Decoding Steps",
|
||||
)
|
||||
|
||||
# API Settings Tab
|
||||
with gr.TabItem("API"):
|
||||
with gr.Group():
|
||||
gr.Markdown("### REST API Configuration")
|
||||
|
||||
api_enabled = gr.Checkbox(
|
||||
label="Enable REST API",
|
||||
value=True,
|
||||
info="Enable FastAPI endpoints for programmatic access",
|
||||
)
|
||||
|
||||
api_port = gr.Number(
|
||||
value=8000,
|
||||
label="API Port",
|
||||
precision=0,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
api_key_display = gr.Textbox(
|
||||
label="API Key",
|
||||
value="••••••••",
|
||||
interactive=False,
|
||||
)
|
||||
regenerate_key_btn = gr.Button("Regenerate", size="sm")
|
||||
|
||||
with gr.Group():
|
||||
gr.Markdown("### Rate Limiting")
|
||||
|
||||
rate_limit = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=100,
|
||||
value=10,
|
||||
step=1,
|
||||
label="Requests per minute",
|
||||
)
|
||||
|
||||
max_batch_size = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
value=4,
|
||||
step=1,
|
||||
label="Maximum batch size",
|
||||
)
|
||||
|
||||
# Queue Settings Tab
|
||||
with gr.TabItem("Queue"):
|
||||
with gr.Group():
|
||||
gr.Markdown("### Batch Processing")
|
||||
|
||||
max_queue_size = gr.Slider(
|
||||
minimum=10,
|
||||
maximum=500,
|
||||
value=100,
|
||||
step=10,
|
||||
label="Maximum Queue Size",
|
||||
)
|
||||
|
||||
max_workers = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=4,
|
||||
value=1,
|
||||
step=1,
|
||||
label="Concurrent Workers",
|
||||
info="Number of parallel generation workers",
|
||||
)
|
||||
|
||||
priority_queue = gr.Checkbox(
|
||||
label="Enable priority queue",
|
||||
value=False,
|
||||
info="Allow high-priority jobs to skip the queue",
|
||||
)
|
||||
|
||||
# Save button
|
||||
gr.Markdown("---")
|
||||
with gr.Row():
|
||||
save_btn = gr.Button("Save Settings", variant="primary", scale=2)
|
||||
reset_btn = gr.Button("Reset to Defaults", variant="secondary", scale=1)
|
||||
|
||||
settings_status = gr.Markdown("")
|
||||
|
||||
# Event handlers
|
||||
|
||||
def load_settings():
    """Read persisted settings into the form.

    Returns values in the exact order the save_btn inputs are wired,
    substituting each key's default when it is absent.
    """
    # (key, fallback) pairs in component wiring order.
    schema = [
        ("output_dir", ""),
        ("default_format", "wav"),
        ("sample_rate", 32000),
        ("normalize_audio", True),
        ("theme_mode", "Dark"),
        ("show_advanced", False),
        ("auto_play", True),
        ("comfyui_reserve_gb", 10),
        ("idle_timeout_minutes", 15),
        ("max_loaded_models", 2),
        ("musicgen_variant", "medium"),
        ("musicgen_duration", 10),
        ("audiogen_duration", 5),
        ("magnet_variant", "medium"),
        ("magnet_decoding_steps", 20),
        ("api_enabled", True),
        ("api_port", 8000),
        ("rate_limit", 10),
        ("max_batch_size", 4),
        ("max_queue_size", 100),
        ("max_workers", 1),
        ("priority_queue", False),
    ]
    settings = get_settings()
    return tuple(settings.get(key, fallback) for key, fallback in schema)
|
||||
|
||||
def save_settings(
    out_dir, fmt, sr, norm, theme, adv, play,
    comfyui_res, idle_to, max_load,
    mg_var, mg_dur, ag_dur, mn_var, mn_steps,
    api_en, api_p, rate, batch, queue_sz, workers, priority
):
    """Persist the form values and return a status markdown string.

    Parameter order matches the save_btn input wiring; keys mirror the
    names load_settings reads back.
    """
    keys = (
        "output_dir", "default_format", "sample_rate", "normalize_audio",
        "theme_mode", "show_advanced", "auto_play",
        "comfyui_reserve_gb", "idle_timeout_minutes", "max_loaded_models",
        "musicgen_variant", "musicgen_duration", "audiogen_duration",
        "magnet_variant", "magnet_decoding_steps",
        "api_enabled", "api_port", "rate_limit", "max_batch_size",
        "max_queue_size", "max_workers", "priority_queue",
    )
    values = (
        out_dir, fmt, sr, norm, theme, adv, play,
        comfyui_res, idle_to, max_load,
        mg_var, mg_dur, ag_dur, mn_var, mn_steps,
        # gr.Number delivers a float; the port is stored as an int.
        api_en, int(api_p), rate, batch, queue_sz, workers, priority,
    )
    settings = dict(zip(keys, values))

    if update_settings(settings):
        return "✅ Settings saved successfully"
    return "❌ Failed to save settings"
|
||||
|
||||
def do_refresh_gpu():
    """Fetch a fresh GPU status snapshot for the JSON display.

    Thin delegate to the injected get_gpu_info callable; the returned
    dict is rendered verbatim by gr.JSON.
    """
    return get_gpu_info()
|
||||
|
||||
def do_clear_cache():
    """Clear the model cache and return a status string for the UI."""
    return "✅ Cache cleared" if clear_cache() else "❌ Failed to clear cache"
|
||||
|
||||
def do_unload_models():
    """Unload every cached model and return a status string for the UI."""
    return "✅ All models unloaded" if unload_all_models() else "❌ Failed to unload models"
|
||||
|
||||
# Wire up events
|
||||
|
||||
refresh_gpu_btn.click(
|
||||
fn=do_refresh_gpu,
|
||||
outputs=[gpu_info_display],
|
||||
)
|
||||
|
||||
clear_cache_btn.click(
|
||||
fn=do_clear_cache,
|
||||
outputs=[cache_status],
|
||||
)
|
||||
|
||||
unload_models_btn.click(
|
||||
fn=do_unload_models,
|
||||
outputs=[cache_status],
|
||||
)
|
||||
|
||||
save_btn.click(
|
||||
fn=save_settings,
|
||||
inputs=[
|
||||
output_dir, default_format, sample_rate, normalize_audio,
|
||||
theme_mode, show_advanced, auto_play,
|
||||
comfyui_reserve, idle_timeout, max_loaded,
|
||||
musicgen_variant, musicgen_duration, audiogen_duration,
|
||||
magnet_variant, magnet_decoding_steps,
|
||||
api_enabled, api_port, rate_limit, max_batch_size,
|
||||
max_queue_size, max_workers, priority_queue,
|
||||
],
|
||||
outputs=[settings_status],
|
||||
)
|
||||
|
||||
return {
|
||||
"output_dir": output_dir,
|
||||
"default_format": default_format,
|
||||
"sample_rate": sample_rate,
|
||||
"comfyui_reserve": comfyui_reserve,
|
||||
"idle_timeout": idle_timeout,
|
||||
"api_enabled": api_enabled,
|
||||
"save_btn": save_btn,
|
||||
"settings_status": settings_status,
|
||||
"load_fn": load_settings,
|
||||
}
|
||||
294
src/ui/state.py
Normal file
294
src/ui/state.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""State management for Gradio UI."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
@dataclass
class UIState:
    """Global UI state container.

    Mutable session state shared across the Gradio tabs: the active view,
    in-flight generation status, current selections, and the working set
    of generation parameters and conditioning inputs.
    """

    # Current view (tab id shown in the UI)
    current_tab: str = "dashboard"

    # Generation state
    is_generating: bool = False            # True while a job is in flight
    current_job_id: Optional[str] = None   # queue job id of the active generation

    # Selected items (None = nothing selected)
    selected_project_id: Optional[str] = None
    selected_generation_id: Optional[str] = None
    selected_preset_id: Optional[str] = None

    # Model state
    selected_model: str = "musicgen"
    selected_variant: str = "medium"

    # Generation parameters (current values)
    prompt: str = ""
    duration: float = 10.0
    temperature: float = 1.0
    top_k: int = 250
    top_p: float = 0.0
    cfg_coef: float = 3.0
    seed: Optional[int] = None  # None = random seed

    # Conditioning
    melody_audio: Optional[str] = None   # presumably a path to melody reference audio — TODO confirm
    style_audio: Optional[str] = None    # presumably a path to style reference audio — TODO confirm
    chords: list[dict[str, Any]] = field(default_factory=list)
    drums_pattern: str = ""
    bpm: float = 120.0

    # UI preferences
    show_advanced: bool = False
    auto_play: bool = True

    def reset_generation_params(self) -> None:
        """Reset generation parameters and conditioning to their defaults.

        Mirrors the dataclass field defaults above; selections, model
        choice, bpm, and UI preferences are left untouched.
        """
        self.prompt = ""
        self.duration = 10.0
        self.temperature = 1.0
        self.top_k = 250
        self.top_p = 0.0
        self.cfg_coef = 3.0
        self.seed = None
        self.melody_audio = None
        self.style_audio = None
        self.chords = []
        self.drums_pattern = ""

    def apply_preset(self, preset: dict[str, Any]) -> None:
        """Apply a preset's parameters, keeping current values for omitted keys.

        Args:
            preset: Preset dict with an optional "parameters" sub-dict
                (shape as in DEFAULT_PRESETS).

        NOTE(review): preset keys such as "bpm" and "eval_q" (present in
        some DEFAULT_PRESETS entries) are not applied here — confirm
        whether that is intentional.
        """
        params = preset.get("parameters", {})
        self.duration = params.get("duration", self.duration)
        self.temperature = params.get("temperature", self.temperature)
        self.top_k = params.get("top_k", self.top_k)
        self.top_p = params.get("top_p", self.top_p)
        self.cfg_coef = params.get("cfg_coef", self.cfg_coef)

    def to_generation_params(self) -> dict[str, Any]:
        """Return the current sampling parameters as a kwargs-style dict."""
        return {
            "duration": self.duration,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "cfg_coef": self.cfg_coef,
            "seed": self.seed,
        }
|
||||
|
||||
|
||||
# Default presets for each model
|
||||
DEFAULT_PRESETS = {
|
||||
"musicgen": [
|
||||
{
|
||||
"id": "cinematic",
|
||||
"name": "Cinematic",
|
||||
"description": "Epic orchestral soundscapes",
|
||||
"parameters": {
|
||||
"duration": 30,
|
||||
"temperature": 1.0,
|
||||
"top_k": 250,
|
||||
"cfg_coef": 3.0,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "electronic",
|
||||
"name": "Electronic",
|
||||
"description": "Synthesizers and beats",
|
||||
"parameters": {
|
||||
"duration": 15,
|
||||
"temperature": 1.1,
|
||||
"top_k": 200,
|
||||
"cfg_coef": 3.5,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "ambient",
|
||||
"name": "Ambient",
|
||||
"description": "Atmospheric and calm",
|
||||
"parameters": {
|
||||
"duration": 30,
|
||||
"temperature": 0.9,
|
||||
"top_k": 300,
|
||||
"cfg_coef": 2.5,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "rock",
|
||||
"name": "Rock",
|
||||
"description": "Guitar-driven energy",
|
||||
"parameters": {
|
||||
"duration": 20,
|
||||
"temperature": 1.0,
|
||||
"top_k": 250,
|
||||
"cfg_coef": 3.0,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "jazz",
|
||||
"name": "Jazz",
|
||||
"description": "Smooth and improvisational",
|
||||
"parameters": {
|
||||
"duration": 20,
|
||||
"temperature": 1.2,
|
||||
"top_k": 200,
|
||||
"cfg_coef": 2.5,
|
||||
},
|
||||
},
|
||||
],
|
||||
"audiogen": [
|
||||
{
|
||||
"id": "nature",
|
||||
"name": "Nature",
|
||||
"description": "Natural environments",
|
||||
"parameters": {
|
||||
"duration": 10,
|
||||
"temperature": 1.0,
|
||||
"top_k": 250,
|
||||
"cfg_coef": 3.0,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "urban",
|
||||
"name": "Urban",
|
||||
"description": "City sounds",
|
||||
"parameters": {
|
||||
"duration": 10,
|
||||
"temperature": 1.0,
|
||||
"top_k": 250,
|
||||
"cfg_coef": 3.0,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "mechanical",
|
||||
"name": "Mechanical",
|
||||
"description": "Machines and tools",
|
||||
"parameters": {
|
||||
"duration": 5,
|
||||
"temperature": 0.9,
|
||||
"top_k": 200,
|
||||
"cfg_coef": 3.5,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "weather",
|
||||
"name": "Weather",
|
||||
"description": "Rain, thunder, wind",
|
||||
"parameters": {
|
||||
"duration": 10,
|
||||
"temperature": 1.0,
|
||||
"top_k": 250,
|
||||
"cfg_coef": 3.0,
|
||||
},
|
||||
},
|
||||
],
|
||||
"magnet": [
|
||||
{
|
||||
"id": "fast",
|
||||
"name": "Fast",
|
||||
"description": "Quick generation",
|
||||
"parameters": {
|
||||
"duration": 10,
|
||||
"temperature": 3.0,
|
||||
"top_p": 0.9,
|
||||
"cfg_coef": 3.0,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "quality",
|
||||
"name": "Quality",
|
||||
"description": "Higher quality output",
|
||||
"parameters": {
|
||||
"duration": 10,
|
||||
"temperature": 2.5,
|
||||
"top_p": 0.85,
|
||||
"cfg_coef": 4.0,
|
||||
},
|
||||
},
|
||||
],
|
||||
"musicgen-style": [
|
||||
{
|
||||
"id": "style_transfer",
|
||||
"name": "Style Transfer",
|
||||
"description": "Copy style from reference",
|
||||
"parameters": {
|
||||
"duration": 15,
|
||||
"temperature": 1.0,
|
||||
"top_k": 250,
|
||||
"cfg_coef": 3.0,
|
||||
"eval_q": 3,
|
||||
"excerpt_length": 3.0,
|
||||
},
|
||||
},
|
||||
],
|
||||
"jasco": [
|
||||
{
|
||||
"id": "pop",
|
||||
"name": "Pop",
|
||||
"description": "Pop chord progressions",
|
||||
"parameters": {
|
||||
"duration": 10,
|
||||
"temperature": 1.0,
|
||||
"top_k": 250,
|
||||
"cfg_coef": 3.0,
|
||||
"bpm": 120,
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "blues",
|
||||
"name": "Blues",
|
||||
"description": "12-bar blues",
|
||||
"parameters": {
|
||||
"duration": 10,
|
||||
"temperature": 1.0,
|
||||
"top_k": 250,
|
||||
"cfg_coef": 3.0,
|
||||
"bpm": 100,
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# Prompt suggestions for each model
|
||||
PROMPT_SUGGESTIONS = {
|
||||
"musicgen": [
|
||||
"Epic orchestral music with dramatic strings and powerful brass",
|
||||
"Upbeat electronic dance music with synthesizers and heavy bass",
|
||||
"Calm acoustic guitar melody with soft piano accompaniment",
|
||||
"Energetic rock song with electric guitars and driving drums",
|
||||
"Smooth jazz with saxophone solo and walking bass",
|
||||
"Ambient soundscape with ethereal pads and gentle textures",
|
||||
"Cinematic trailer music building to an epic climax",
|
||||
"Lo-fi hip hop beats with vinyl crackle and mellow keys",
|
||||
],
|
||||
"audiogen": [
|
||||
"Thunder and heavy rain with occasional lightning strikes",
|
||||
"Busy city street with traffic, horns, and distant sirens",
|
||||
"Forest ambience with birds singing and wind in trees",
|
||||
"Ocean waves crashing on a rocky shore",
|
||||
"Crackling fireplace with wood popping",
|
||||
"Coffee shop atmosphere with murmuring voices and clinking cups",
|
||||
"Construction site with hammering and machinery",
|
||||
"Spaceship engine humming with occasional beeps",
|
||||
],
|
||||
"magnet": [
|
||||
"Energetic pop music with catchy melody",
|
||||
"Dark electronic music with deep bass",
|
||||
"Cheerful ukulele tune with whistling",
|
||||
"Dramatic piano piece with building intensity",
|
||||
],
|
||||
"musicgen-style": [
|
||||
"Generate music in the style of the uploaded reference",
|
||||
"Create a variation with similar instrumentation",
|
||||
"Compose a piece matching the mood of the reference",
|
||||
],
|
||||
"jasco": [
|
||||
"Upbeat pop song with the specified chord progression",
|
||||
"Mellow jazz piece following the chord changes",
|
||||
"Rock anthem with powerful drum pattern",
|
||||
"Electronic track with syncopated rhythms",
|
||||
],
|
||||
}
|
||||
17
src/ui/tabs/__init__.py
Normal file
17
src/ui/tabs/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Model tabs for AudioCraft Studio."""
|
||||
|
||||
from src.ui.tabs.dashboard_tab import create_dashboard_tab
|
||||
from src.ui.tabs.musicgen_tab import create_musicgen_tab
|
||||
from src.ui.tabs.audiogen_tab import create_audiogen_tab
|
||||
from src.ui.tabs.magnet_tab import create_magnet_tab
|
||||
from src.ui.tabs.style_tab import create_style_tab
|
||||
from src.ui.tabs.jasco_tab import create_jasco_tab
|
||||
|
||||
__all__ = [
|
||||
"create_dashboard_tab",
|
||||
"create_musicgen_tab",
|
||||
"create_audiogen_tab",
|
||||
"create_magnet_tab",
|
||||
"create_style_tab",
|
||||
"create_jasco_tab",
|
||||
]
|
||||
283
src/ui/tabs/audiogen_tab.py
Normal file
283
src/ui/tabs/audiogen_tab.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""AudioGen tab for text-to-sound generation."""
|
||||
|
||||
import gradio as gr
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
|
||||
from src.ui.components.audio_player import create_generation_output
|
||||
|
||||
|
||||
AUDIOGEN_VARIANTS = [
|
||||
{"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "1.5B params, balanced quality/speed"},
|
||||
]
|
||||
|
||||
|
||||
def create_audiogen_tab(
|
||||
generate_fn: Callable[..., Any],
|
||||
add_to_queue_fn: Callable[..., Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Create AudioGen generation tab.
|
||||
|
||||
Args:
|
||||
generate_fn: Function to call for generation
|
||||
add_to_queue_fn: Function to add to queue
|
||||
|
||||
Returns:
|
||||
Dictionary with component references
|
||||
"""
|
||||
presets = DEFAULT_PRESETS.get("audiogen", [])
|
||||
suggestions = PROMPT_SUGGESTIONS.get("audiogen", [])
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown("## 🔊 AudioGen")
|
||||
gr.Markdown("Generate sound effects and environmental audio from text")
|
||||
|
||||
with gr.Row():
|
||||
# Left column - inputs
|
||||
with gr.Column(scale=2):
|
||||
# Preset selector
|
||||
preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
|
||||
preset_dropdown = gr.Dropdown(
|
||||
label="Preset",
|
||||
choices=preset_choices,
|
||||
value=presets[0]["id"] if presets else "custom",
|
||||
)
|
||||
|
||||
# Model variant (AudioGen only has medium)
|
||||
variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in AUDIOGEN_VARIANTS]
|
||||
variant_dropdown = gr.Dropdown(
|
||||
label="Model Variant",
|
||||
choices=variant_choices,
|
||||
value="medium",
|
||||
)
|
||||
|
||||
# Prompt input
|
||||
prompt_input = gr.Textbox(
|
||||
label="Prompt",
|
||||
placeholder="Describe the sound you want to generate...",
|
||||
lines=3,
|
||||
max_lines=5,
|
||||
)
|
||||
|
||||
# Prompt suggestions
|
||||
with gr.Accordion("Prompt Suggestions", open=False):
|
||||
suggestion_btns = []
|
||||
for i, suggestion in enumerate(suggestions[:6]):
|
||||
btn = gr.Button(suggestion[:50] + "...", size="sm", variant="secondary")
|
||||
suggestion_btns.append((btn, suggestion))
|
||||
|
||||
# Parameters
|
||||
gr.Markdown("### Parameters")
|
||||
|
||||
duration_slider = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
value=5,
|
||||
step=1,
|
||||
label="Duration (seconds)",
|
||||
info="AudioGen works best with shorter clips",
|
||||
)
|
||||
|
||||
with gr.Accordion("Advanced Parameters", open=False):
|
||||
with gr.Row():
|
||||
temperature_slider = gr.Slider(
|
||||
minimum=0.0,
|
||||
maximum=2.0,
|
||||
value=1.0,
|
||||
step=0.05,
|
||||
label="Temperature",
|
||||
)
|
||||
cfg_slider = gr.Slider(
|
||||
minimum=1.0,
|
||||
maximum=10.0,
|
||||
value=3.0,
|
||||
step=0.5,
|
||||
label="CFG Coefficient",
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
top_k_slider = gr.Slider(
|
||||
minimum=0,
|
||||
maximum=500,
|
||||
value=250,
|
||||
step=10,
|
||||
label="Top-K",
|
||||
)
|
||||
top_p_slider = gr.Slider(
|
||||
minimum=0.0,
|
||||
maximum=1.0,
|
||||
value=0.0,
|
||||
step=0.05,
|
||||
label="Top-P",
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
seed_input = gr.Number(
|
||||
value=None,
|
||||
label="Seed (empty = random)",
|
||||
precision=0,
|
||||
)
|
||||
|
||||
# Generate buttons
|
||||
with gr.Row():
|
||||
generate_btn = gr.Button("🔊 Generate", variant="primary", scale=2)
|
||||
queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)
|
||||
|
||||
# Right column - output
|
||||
with gr.Column(scale=3):
|
||||
output = create_generation_output()
|
||||
|
||||
# Event handlers
|
||||
|
||||
# Preset change
|
||||
def apply_preset(preset_id: str):
    """Map a preset id to its five slider values.

    Unknown ids (including "custom") leave every slider untouched by
    returning gr.update() placeholders.
    """
    chosen = next((p for p in presets if p["id"] == preset_id), None)
    if chosen is None:
        return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()

    params = chosen["parameters"]
    # (key, fallback) pairs in slider wiring order.
    schema = (
        ("duration", 5),
        ("temperature", 1.0),
        ("cfg_coef", 3.0),
        ("top_k", 250),
        ("top_p", 0.0),
    )
    return tuple(params.get(key, fallback) for key, fallback in schema)
|
||||
|
||||
preset_dropdown.change(
|
||||
fn=apply_preset,
|
||||
inputs=[preset_dropdown],
|
||||
outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
|
||||
)
|
||||
|
||||
# Prompt suggestions
|
||||
for btn, suggestion in suggestion_btns:
|
||||
btn.click(
|
||||
fn=lambda s=suggestion: s,
|
||||
outputs=[prompt_input],
|
||||
)
|
||||
|
||||
# Generate
|
||||
async def do_generate(
|
||||
prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed
|
||||
):
|
||||
if not prompt:
|
||||
return (
|
||||
gr.update(value="Please enter a prompt"),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
)
|
||||
|
||||
yield (
|
||||
gr.update(value="🔄 Generating..."),
|
||||
gr.update(visible=True, value=0),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
)
|
||||
|
||||
try:
|
||||
result, generation = await generate_fn(
|
||||
model_id="audiogen",
|
||||
variant=variant,
|
||||
prompts=[prompt],
|
||||
duration=duration,
|
||||
temperature=temperature,
|
||||
top_k=int(top_k),
|
||||
top_p=top_p,
|
||||
cfg_coef=cfg_coef,
|
||||
seed=int(seed) if seed else None,
|
||||
)
|
||||
|
||||
yield (
|
||||
gr.update(value="✅ Generation complete!"),
|
||||
gr.update(visible=False),
|
||||
gr.update(value=generation.audio_path),
|
||||
gr.update(),
|
||||
gr.update(value=f"{result.duration:.2f}s"),
|
||||
gr.update(value=str(result.seed)),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
yield (
|
||||
gr.update(value=f"❌ Error: {str(e)}"),
|
||||
gr.update(visible=False),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
)
|
||||
|
||||
generate_btn.click(
|
||||
fn=do_generate,
|
||||
inputs=[
|
||||
prompt_input,
|
||||
variant_dropdown,
|
||||
duration_slider,
|
||||
temperature_slider,
|
||||
cfg_slider,
|
||||
top_k_slider,
|
||||
top_p_slider,
|
||||
seed_input,
|
||||
],
|
||||
outputs=[
|
||||
output["status"],
|
||||
output["progress"],
|
||||
output["player"]["audio"],
|
||||
output["player"]["waveform"],
|
||||
output["player"]["duration"],
|
||||
output["player"]["seed"],
|
||||
],
|
||||
)
|
||||
|
||||
# Add to queue
|
||||
def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed):
|
||||
if not prompt:
|
||||
return "Please enter a prompt"
|
||||
|
||||
job = add_to_queue_fn(
|
||||
model_id="audiogen",
|
||||
variant=variant,
|
||||
prompts=[prompt],
|
||||
duration=duration,
|
||||
temperature=temperature,
|
||||
top_k=int(top_k),
|
||||
top_p=top_p,
|
||||
cfg_coef=cfg_coef,
|
||||
seed=int(seed) if seed else None,
|
||||
)
|
||||
|
||||
return f"✅ Added to queue: {job.id}"
|
||||
|
||||
queue_btn.click(
|
||||
fn=do_add_queue,
|
||||
inputs=[
|
||||
prompt_input,
|
||||
variant_dropdown,
|
||||
duration_slider,
|
||||
temperature_slider,
|
||||
cfg_slider,
|
||||
top_k_slider,
|
||||
top_p_slider,
|
||||
seed_input,
|
||||
],
|
||||
outputs=[output["status"]],
|
||||
)
|
||||
|
||||
return {
|
||||
"preset": preset_dropdown,
|
||||
"variant": variant_dropdown,
|
||||
"prompt": prompt_input,
|
||||
"duration": duration_slider,
|
||||
"temperature": temperature_slider,
|
||||
"cfg_coef": cfg_slider,
|
||||
"top_k": top_k_slider,
|
||||
"top_p": top_p_slider,
|
||||
"seed": seed_input,
|
||||
"generate_btn": generate_btn,
|
||||
"queue_btn": queue_btn,
|
||||
"output": output,
|
||||
}
|
||||
166
src/ui/tabs/dashboard_tab.py
Normal file
166
src/ui/tabs/dashboard_tab.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Dashboard tab - home page with model overview and quick actions."""
|
||||
|
||||
import gradio as gr
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
|
||||
# Static metadata for each model family rendered as dashboard cards.
# Keyed by model id; each entry carries the display name, an emoji icon,
# a one-line description, and a list of capability strings that the
# dashboard joins into the "Features:" line of the card.
MODEL_INFO = {
    "musicgen": {
        "name": "MusicGen",
        "icon": "🎵",
        "description": "Text-to-music generation with optional melody conditioning",
        "capabilities": ["Text prompts", "Melody conditioning", "Stereo output"],
    },
    "audiogen": {
        "name": "AudioGen",
        "icon": "🔊",
        "description": "Text-to-sound effects and environmental audio",
        "capabilities": ["Sound effects", "Ambiences", "Foley"],
    },
    "magnet": {
        "name": "MAGNeT",
        "icon": "⚡",
        "description": "Fast non-autoregressive music generation",
        "capabilities": ["Fast generation", "Music", "Sound effects"],
    },
    "musicgen-style": {
        "name": "MusicGen Style",
        "icon": "🎨",
        "description": "Style-conditioned music from reference audio",
        "capabilities": ["Style transfer", "Reference audio", "Text prompts"],
    },
    "jasco": {
        "name": "JASCO",
        "icon": "🎹",
        "description": "Chord and drum-conditioned music generation",
        "capabilities": ["Chord control", "Drum patterns", "Symbolic conditioning"],
    },
}
|
||||
|
||||
|
||||
def create_dashboard_tab(
    get_queue_status: Callable[[], dict[str, Any]],
    get_recent_generations: Callable[[int], list[dict[str, Any]]],
    get_gpu_status: Callable[[], dict[str, Any]],
) -> dict[str, Any]:
    """Create dashboard tab with model overview and status.

    Args:
        get_queue_status: Function to get generation queue status
            (expected to return a dict with a ``queue_size`` key).
        get_recent_generations: Function to get recent generations; called
            with the maximum number of items to return.
        get_gpu_status: Function to get GPU status (expected keys:
            ``used_gb``, ``total_gb``, ``utilization_percent``).

    Returns:
        Dictionary with component references (``queue_status``,
        ``gpu_status``, ``recent_list``, ``refresh_btn``, ``refresh_fn``).
    """

    def refresh_dashboard():
        """Refresh all dashboard data.

        Returns a 3-tuple of Markdown strings in the same order as the
        ``refresh_btn.click`` outputs list below:
        (queue_status, recent_list, gpu_status).
        """
        queue = get_queue_status()
        recent = get_recent_generations(5)
        gpu = get_gpu_status()

        # Format queue status; missing key defaults to an empty queue.
        queue_size = queue.get("queue_size", 0)
        queue_text = f"**Queue:** {queue_size} job(s) pending"

        # Format recent generations as a Markdown bullet list.
        if recent:
            recent_items = []
            # [:5] is defensive — the provider was already asked for 5.
            for gen in recent[:5]:
                model = gen.get("model", "unknown")
                # Truncate prompts to 50 chars so cards stay compact.
                prompt = gen.get("prompt", "")[:50]
                duration = gen.get("duration_seconds", 0)
                recent_items.append(f"• **{model}** ({duration:.0f}s): {prompt}...")
            recent_text = "\n".join(recent_items)
        else:
            recent_text = "No recent generations"

        # Format GPU status; total defaults to 24 GB (RTX 4090 target).
        used_gb = gpu.get("used_gb", 0)
        total_gb = gpu.get("total_gb", 24)
        util = gpu.get("utilization_percent", 0)
        gpu_text = f"**GPU:** {used_gb:.1f}/{total_gb:.1f} GB ({util:.0f}%)"

        return queue_text, recent_text, gpu_text

    # Layout — component creation order below determines on-screen order.
    with gr.Column():
        # Header
        gr.Markdown("# AudioCraft Studio")
        gr.Markdown("AI-powered music and sound generation")

        # Status bar — placeholders are replaced on the first refresh.
        with gr.Row():
            queue_status = gr.Markdown("**Queue:** Loading...")
            gpu_status = gr.Markdown("**GPU:** Loading...")
            refresh_btn = gr.Button("🔄 Refresh", size="sm")

        gr.Markdown("---")

        # Model cards
        gr.Markdown("## Models")

        with gr.Row():
            # First row of cards
            for model_id in ["musicgen", "audiogen", "magnet"]:
                info = MODEL_INFO[model_id]
                with gr.Column(scale=1):
                    with gr.Group():
                        gr.Markdown(f"### {info['icon']} {info['name']}")
                        gr.Markdown(info["description"])
                        gr.Markdown("**Features:** " + ", ".join(info["capabilities"]))
                        # NOTE(review): no .click handler is wired here;
                        # presumably navigation is attached elsewhere via
                        # elem_id — TODO confirm.
                        gr.Button(
                            f"Open {info['name']}",
                            variant="primary",
                            size="sm",
                            elem_id=f"btn_{model_id}",
                        )

        with gr.Row():
            # Second row of cards
            for model_id in ["musicgen-style", "jasco"]:
                info = MODEL_INFO[model_id]
                with gr.Column(scale=1):
                    with gr.Group():
                        gr.Markdown(f"### {info['icon']} {info['name']}")
                        gr.Markdown(info["description"])
                        gr.Markdown("**Features:** " + ", ".join(info["capabilities"]))
                        gr.Button(
                            f"Open {info['name']}",
                            variant="primary",
                            size="sm",
                            elem_id=f"btn_{model_id}",
                        )

            # Empty column for balance
            with gr.Column(scale=1):
                pass

        gr.Markdown("---")

        # Recent generations and queue
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("## Recent Generations")
                recent_list = gr.Markdown("Loading...")

            with gr.Column(scale=1):
                gr.Markdown("## Quick Actions")
                with gr.Group():
                    # NOTE(review): these buttons have no handlers wired in
                    # this function — TODO confirm they are connected by the
                    # caller or are placeholders.
                    gr.Button("📁 Browse Projects", variant="secondary")
                    gr.Button("⚙️ Settings", variant="secondary")
                    gr.Button("📖 API Documentation", variant="secondary")

    # Refresh handler — output order must match refresh_dashboard's
    # return order: (queue_text, recent_text, gpu_text).
    refresh_btn.click(
        fn=refresh_dashboard,
        outputs=[queue_status, recent_list, gpu_status],
    )

    return {
        "queue_status": queue_status,
        "gpu_status": gpu_status,
        "recent_list": recent_list,
        "refresh_btn": refresh_btn,
        # Exposed so the app can trigger an initial refresh on load.
        "refresh_fn": refresh_dashboard,
    }
|
||||
364
src/ui/tabs/jasco_tab.py
Normal file
364
src/ui/tabs/jasco_tab.py
Normal file
@@ -0,0 +1,364 @@
|
||||
"""JASCO tab for chord and drum-conditioned generation."""
|
||||
|
||||
import gradio as gr
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
|
||||
from src.ui.components.audio_player import create_generation_output
|
||||
|
||||
|
||||
# Available JASCO checkpoints. vram_mb is the approximate VRAM footprint
# in MB; the variant dropdown divides it by 1024 to show a "(x.xGB)" label.
# The "chords-drums" id is matched by substring ("drums" in id) to toggle
# the drum-upload section.
JASCO_VARIANTS = [
    {"id": "chords", "name": "Chords", "vram_mb": 5000, "description": "Chord-conditioned generation"},
    {"id": "chords-drums", "name": "Chords + Drums", "vram_mb": 5500, "description": "Full symbolic conditioning"},
]

# Common chord progressions
# Rendered as one-click preset buttons that fill the chord textbox with
# the space-separated "chords" string.
CHORD_PRESETS = [
    {"name": "Pop I-V-vi-IV", "chords": "C G Am F"},
    {"name": "Jazz ii-V-I", "chords": "Dm7 G7 Cmaj7"},
    {"name": "Blues I-IV-V", "chords": "A7 D7 E7"},
    {"name": "Rock I-bVII-IV", "chords": "E D A"},
    {"name": "Minor i-VI-III-VII", "chords": "Am F C G"},
]
|
||||
|
||||
|
||||
def create_jasco_tab(
    generate_fn: Callable[..., Any],
    add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
    """Create JASCO generation tab.

    Args:
        generate_fn: Async function to call for generation; returns a
            (result, generation) pair.
        add_to_queue_fn: Function to add a job to the batch queue; returns
            an object with an ``id`` attribute.

    Returns:
        Dictionary with component references.
    """
    presets = DEFAULT_PRESETS.get("jasco", [])

    with gr.Column():
        gr.Markdown("## 🎹 JASCO")
        gr.Markdown("Generate music conditioned on chords and drum patterns")

        with gr.Row():
            # Left column - inputs
            with gr.Column(scale=2):
                # Preset selector ("custom" leaves sliders untouched)
                preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
                preset_dropdown = gr.Dropdown(
                    label="Preset",
                    choices=preset_choices,
                    value=presets[0]["id"] if presets else "custom",
                )

                # Model variant with approximate VRAM cost in the label
                variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in JASCO_VARIANTS]
                variant_dropdown = gr.Dropdown(
                    label="Model Variant",
                    choices=variant_choices,
                    value="chords-drums",
                )

                # Optional text prompt (chords are the required input)
                prompt_input = gr.Textbox(
                    label="Text Prompt",
                    placeholder="Describe the music style, mood, instruments...",
                    lines=2,
                    max_lines=4,
                )

                # Chord conditioning
                gr.Markdown("### Chord Progression")

                chord_input = gr.Textbox(
                    label="Chords",
                    placeholder="C G Am F or Cmaj7 Dm7 G7 Cmaj7",
                    lines=1,
                    info="Space-separated chord symbols",
                )

                # Chord presets — buttons that fill the chord textbox
                with gr.Accordion("Chord Presets", open=False):
                    chord_preset_btns = []
                    with gr.Row():
                        for cp in CHORD_PRESETS[:3]:
                            btn = gr.Button(cp["name"], size="sm", variant="secondary")
                            chord_preset_btns.append((btn, cp["chords"]))
                    with gr.Row():
                        for cp in CHORD_PRESETS[3:]:
                            btn = gr.Button(cp["name"], size="sm", variant="secondary")
                            chord_preset_btns.append((btn, cp["chords"]))

                # Drum conditioning (shown only for the chords-drums variant;
                # visible initially because the default variant includes drums)
                with gr.Group(visible=True) as drum_section:
                    gr.Markdown("### Drum Pattern")

                    drum_input = gr.Audio(
                        label="Drum Reference",
                        type="filepath",
                        sources=["upload"],
                    )
                    gr.Markdown("*Upload a drum loop to condition the rhythm*")

                # Parameters
                gr.Markdown("### Parameters")

                duration_slider = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)",
                )

                bpm_slider = gr.Slider(
                    minimum=60,
                    maximum=180,
                    value=120,
                    step=1,
                    label="BPM",
                    info="Tempo for chord timing",
                )

                with gr.Accordion("Advanced Parameters", open=False):
                    with gr.Row():
                        temperature_slider = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            value=1.0,
                            step=0.05,
                            label="Temperature",
                        )
                        cfg_slider = gr.Slider(
                            minimum=1.0,
                            maximum=10.0,
                            value=3.0,
                            step=0.5,
                            label="CFG Coefficient",
                        )

                    with gr.Row():
                        top_k_slider = gr.Slider(
                            minimum=0,
                            maximum=500,
                            value=250,
                            step=10,
                            label="Top-K",
                        )
                        top_p_slider = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.0,
                            step=0.05,
                            label="Top-P",
                        )

                    with gr.Row():
                        seed_input = gr.Number(
                            value=None,
                            label="Seed (empty = random)",
                            precision=0,
                        )

                # Generate buttons
                with gr.Row():
                    generate_btn = gr.Button("🎹 Generate", variant="primary", scale=2)
                    queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)

            # Right column - output
            with gr.Column(scale=3):
                output = create_generation_output()

    # --- Event handlers ---

    # Preset change — push preset parameters into the sliders; the
    # "custom" entry returns no-op updates so values are preserved.
    def apply_preset(preset_id: str):
        for p in presets:
            if p["id"] == preset_id:
                params = p["parameters"]
                return (
                    params.get("duration", 10),
                    params.get("bpm", 120),
                    params.get("temperature", 1.0),
                    params.get("cfg_coef", 3.0),
                    params.get("top_k", 250),
                    params.get("top_p", 0.0),
                )
        return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()

    preset_dropdown.change(
        fn=apply_preset,
        inputs=[preset_dropdown],
        outputs=[duration_slider, bpm_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
    )

    # Variant change - show/hide drum section
    def on_variant_change(variant: str):
        show_drums = "drums" in variant.lower()
        return gr.update(visible=show_drums)

    variant_dropdown.change(
        fn=on_variant_change,
        inputs=[variant_dropdown],
        outputs=[drum_section],
    )

    # Chord presets — default argument binds each button's chord string.
    for btn, chords in chord_preset_btns:
        btn.click(
            fn=lambda c=chords: c,
            outputs=[chord_input],
        )

    def _build_conditioning(chords, drums, variant, bpm):
        """Assemble the JASCO conditioning dict; drums only for drum variants."""
        conditioning = {
            "chords": chords,
            "bpm": bpm,
        }
        if drums and "drums" in variant.lower():
            conditioning["drums"] = drums
        return conditioning

    def _parse_seed(seed):
        """Normalize the seed widget value; None means random. 0 is a valid seed."""
        # BUG FIX: `if seed` treated a user-entered 0 as "random".
        return int(seed) if seed is not None else None

    # Generate — async *generator* (contains `yield`), so early exits must
    # use a bare `return`; `return <value>` in an async generator is a
    # SyntaxError (PEP 525). The original code did exactly that and would
    # prevent the module from importing — fixed below.
    async def do_generate(
        prompt, variant, chords, drums, duration, bpm, temperature, cfg_coef, top_k, top_p, seed
    ):
        """Stream 6-tuples of gr.update: (status, progress, audio, waveform, duration, seed)."""
        if not chords:
            yield (
                gr.update(value="Please enter a chord progression"),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
            )
            return

        yield (
            gr.update(value="🔄 Generating..."),
            gr.update(visible=True, value=0),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
        )

        try:
            result, generation = await generate_fn(
                model_id="jasco",
                variant=variant,
                prompts=[prompt] if prompt else [""],
                duration=duration,
                temperature=temperature,
                top_k=int(top_k),
                top_p=top_p,
                cfg_coef=cfg_coef,
                seed=_parse_seed(seed),
                conditioning=_build_conditioning(chords, drums, variant, bpm),
            )

            yield (
                gr.update(value="✅ Generation complete!"),
                gr.update(visible=False),
                gr.update(value=generation.audio_path),
                gr.update(),
                gr.update(value=f"{result.duration:.2f}s"),
                gr.update(value=str(result.seed)),
            )

        except Exception as e:
            # Surface the error in the status field; keep other outputs as-is.
            yield (
                gr.update(value=f"❌ Error: {str(e)}"),
                gr.update(visible=False),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
            )

    generate_btn.click(
        fn=do_generate,
        inputs=[
            prompt_input,
            variant_dropdown,
            chord_input,
            drum_input,
            duration_slider,
            bpm_slider,
            temperature_slider,
            cfg_slider,
            top_k_slider,
            top_p_slider,
            seed_input,
        ],
        outputs=[
            output["status"],
            output["progress"],
            output["player"]["audio"],
            output["player"]["waveform"],
            output["player"]["duration"],
            output["player"]["seed"],
        ],
    )

    # Add to queue (synchronous; only updates the status markdown)
    def do_add_queue(prompt, variant, chords, drums, duration, bpm, temperature, cfg_coef, top_k, top_p, seed):
        if not chords:
            return "Please enter a chord progression"

        job = add_to_queue_fn(
            model_id="jasco",
            variant=variant,
            prompts=[prompt] if prompt else [""],
            duration=duration,
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
            cfg_coef=cfg_coef,
            seed=_parse_seed(seed),
            conditioning=_build_conditioning(chords, drums, variant, bpm),
        )

        return f"✅ Added to queue: {job.id}"

    queue_btn.click(
        fn=do_add_queue,
        inputs=[
            prompt_input,
            variant_dropdown,
            chord_input,
            drum_input,
            duration_slider,
            bpm_slider,
            temperature_slider,
            cfg_slider,
            top_k_slider,
            top_p_slider,
            seed_input,
        ],
        outputs=[output["status"]],
    )

    return {
        "preset": preset_dropdown,
        "variant": variant_dropdown,
        "prompt": prompt_input,
        "chords": chord_input,
        "drums": drum_input,
        "duration": duration_slider,
        "bpm": bpm_slider,
        "temperature": temperature_slider,
        "cfg_coef": cfg_slider,
        "top_k": top_k_slider,
        "top_p": top_p_slider,
        "seed": seed_input,
        "generate_btn": generate_btn,
        "queue_btn": queue_btn,
        "output": output,
    }
|
||||
316
src/ui/tabs/magnet_tab.py
Normal file
316
src/ui/tabs/magnet_tab.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""MAGNeT tab for fast non-autoregressive generation."""
|
||||
|
||||
import gradio as gr
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
|
||||
from src.ui.components.audio_player import create_generation_output
|
||||
|
||||
|
||||
# Available MAGNeT checkpoints. vram_mb is the approximate VRAM footprint
# in MB; the variant dropdown divides it by 1024 to show a "(x.xGB)" label.
# "audio-*" ids are the sound-effect checkpoints; plain ids are music.
MAGNET_VARIANTS = [
    {"id": "small", "name": "Small Music", "vram_mb": 2000, "description": "Fast music, 300M params"},
    {"id": "medium", "name": "Medium Music", "vram_mb": 5000, "description": "Balanced music, 1.5B params"},
    {"id": "audio-small", "name": "Small Audio", "vram_mb": 2000, "description": "Fast sound effects"},
    {"id": "audio-medium", "name": "Medium Audio", "vram_mb": 5000, "description": "Balanced sound effects"},
]
|
||||
|
||||
|
||||
def create_magnet_tab(
    generate_fn: Callable[..., Any],
    add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
    """Create MAGNeT generation tab.

    Args:
        generate_fn: Async function to call for generation; returns a
            (result, generation) pair.
        add_to_queue_fn: Function to add a job to the batch queue; returns
            an object with an ``id`` attribute.

    Returns:
        Dictionary with component references.
    """
    presets = DEFAULT_PRESETS.get("magnet", [])
    suggestions = PROMPT_SUGGESTIONS.get("musicgen", [])  # Reuse music suggestions

    with gr.Column():
        gr.Markdown("## ⚡ MAGNeT")
        gr.Markdown("Fast non-autoregressive music and sound generation")

        with gr.Row():
            # Left column - inputs
            with gr.Column(scale=2):
                # Preset selector ("custom" leaves sliders untouched)
                preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
                preset_dropdown = gr.Dropdown(
                    label="Preset",
                    choices=preset_choices,
                    value=presets[0]["id"] if presets else "custom",
                )

                # Model variant with approximate VRAM cost in the label
                variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in MAGNET_VARIANTS]
                variant_dropdown = gr.Dropdown(
                    label="Model Variant",
                    choices=variant_choices,
                    value="medium",
                )

                # Prompt input
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the music or sound you want to generate...",
                    lines=3,
                    max_lines=5,
                )

                # Prompt suggestions — one button per suggestion, truncated label
                with gr.Accordion("Prompt Suggestions", open=False):
                    suggestion_btns = []
                    for suggestion in suggestions[:4]:
                        btn = gr.Button(suggestion[:60] + "...", size="sm", variant="secondary")
                        suggestion_btns.append((btn, suggestion))

                # Parameters
                gr.Markdown("### Parameters")

                duration_slider = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)",
                )

                with gr.Accordion("Advanced Parameters", open=False):
                    gr.Markdown("*MAGNeT uses different sampling compared to MusicGen*")

                    with gr.Row():
                        temperature_slider = gr.Slider(
                            minimum=1.0,
                            maximum=5.0,
                            value=3.0,
                            step=0.1,
                            label="Temperature",
                            info="Higher values recommended (3.0 default)",
                        )
                        cfg_slider = gr.Slider(
                            minimum=1.0,
                            maximum=10.0,
                            value=3.0,
                            step=0.5,
                            label="CFG Coefficient",
                        )

                    with gr.Row():
                        top_k_slider = gr.Slider(
                            minimum=0,
                            maximum=500,
                            value=0,
                            step=10,
                            label="Top-K",
                            info="0 recommended for MAGNeT",
                        )
                        top_p_slider = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.9,
                            step=0.05,
                            label="Top-P",
                            info="0.9 recommended for MAGNeT",
                        )

                    with gr.Row():
                        decoding_steps_slider = gr.Slider(
                            minimum=10,
                            maximum=100,
                            value=20,
                            step=5,
                            label="Decoding Steps",
                            info="More steps = better quality, slower",
                        )
                        span_arrangement = gr.Dropdown(
                            label="Span Arrangement",
                            choices=[("No Overlap", "nonoverlap"), ("Overlap", "stride1")],
                            value="nonoverlap",
                        )

                    with gr.Row():
                        seed_input = gr.Number(
                            value=None,
                            label="Seed (empty = random)",
                            precision=0,
                        )

                # Generate buttons
                with gr.Row():
                    generate_btn = gr.Button("⚡ Generate", variant="primary", scale=2)
                    queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)

            # Right column - output
            with gr.Column(scale=3):
                output = create_generation_output()

    # --- Event handlers ---

    # Preset change — push preset parameters into the sliders; the
    # "custom" entry returns no-op updates so values are preserved.
    def apply_preset(preset_id: str):
        for p in presets:
            if p["id"] == preset_id:
                params = p["parameters"]
                return (
                    params.get("duration", 10),
                    params.get("temperature", 3.0),
                    params.get("cfg_coef", 3.0),
                    params.get("top_k", 0),
                    params.get("top_p", 0.9),
                    params.get("decoding_steps", 20),
                )
        return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()

    preset_dropdown.change(
        fn=apply_preset,
        inputs=[preset_dropdown],
        outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider, decoding_steps_slider],
    )

    # Prompt suggestions — default argument binds each button's text.
    for btn, suggestion in suggestion_btns:
        btn.click(
            fn=lambda s=suggestion: s,
            outputs=[prompt_input],
        )

    def _parse_seed(seed):
        """Normalize the seed widget value; None means random. 0 is a valid seed."""
        # BUG FIX: `if seed` treated a user-entered 0 as "random".
        return int(seed) if seed is not None else None

    # Generate — async *generator* (contains `yield`), so early exits must
    # use a bare `return`; `return <value>` in an async generator is a
    # SyntaxError (PEP 525). The original code did exactly that and would
    # prevent the module from importing — fixed below.
    async def do_generate(
        prompt, variant, duration, temperature, cfg_coef, top_k, top_p, decoding_steps, span_arr, seed
    ):
        """Stream 6-tuples of gr.update: (status, progress, audio, waveform, duration, seed)."""
        if not prompt:
            yield (
                gr.update(value="Please enter a prompt"),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
            )
            return

        yield (
            gr.update(value="🔄 Generating..."),
            gr.update(visible=True, value=0),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
        )

        try:
            result, generation = await generate_fn(
                model_id="magnet",
                variant=variant,
                prompts=[prompt],
                duration=duration,
                temperature=temperature,
                top_k=int(top_k),
                top_p=top_p,
                cfg_coef=cfg_coef,
                decoding_steps=int(decoding_steps),
                span_arrangement=span_arr,
                seed=_parse_seed(seed),
            )

            yield (
                gr.update(value="✅ Generation complete!"),
                gr.update(visible=False),
                gr.update(value=generation.audio_path),
                gr.update(),
                gr.update(value=f"{result.duration:.2f}s"),
                gr.update(value=str(result.seed)),
            )

        except Exception as e:
            # Surface the error in the status field; keep other outputs as-is.
            yield (
                gr.update(value=f"❌ Error: {str(e)}"),
                gr.update(visible=False),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
            )

    generate_btn.click(
        fn=do_generate,
        inputs=[
            prompt_input,
            variant_dropdown,
            duration_slider,
            temperature_slider,
            cfg_slider,
            top_k_slider,
            top_p_slider,
            decoding_steps_slider,
            span_arrangement,
            seed_input,
        ],
        outputs=[
            output["status"],
            output["progress"],
            output["player"]["audio"],
            output["player"]["waveform"],
            output["player"]["duration"],
            output["player"]["seed"],
        ],
    )

    # Add to queue (synchronous; only updates the status markdown)
    def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, decoding_steps, span_arr, seed):
        if not prompt:
            return "Please enter a prompt"

        job = add_to_queue_fn(
            model_id="magnet",
            variant=variant,
            prompts=[prompt],
            duration=duration,
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
            cfg_coef=cfg_coef,
            decoding_steps=int(decoding_steps),
            span_arrangement=span_arr,
            seed=_parse_seed(seed),
        )

        return f"✅ Added to queue: {job.id}"

    queue_btn.click(
        fn=do_add_queue,
        inputs=[
            prompt_input,
            variant_dropdown,
            duration_slider,
            temperature_slider,
            cfg_slider,
            top_k_slider,
            top_p_slider,
            decoding_steps_slider,
            span_arrangement,
            seed_input,
        ],
        outputs=[output["status"]],
    )

    return {
        "preset": preset_dropdown,
        "variant": variant_dropdown,
        "prompt": prompt_input,
        "duration": duration_slider,
        "temperature": temperature_slider,
        "cfg_coef": cfg_slider,
        "top_k": top_k_slider,
        "top_p": top_p_slider,
        "decoding_steps": decoding_steps_slider,
        "span_arrangement": span_arrangement,
        "seed": seed_input,
        "generate_btn": generate_btn,
        "queue_btn": queue_btn,
        "output": output,
    }
|
||||
325
src/ui/tabs/musicgen_tab.py
Normal file
325
src/ui/tabs/musicgen_tab.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""MusicGen tab for text-to-music generation."""
|
||||
|
||||
import gradio as gr
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
|
||||
from src.ui.components.audio_player import create_generation_output
|
||||
|
||||
|
||||
# Available MusicGen checkpoints. vram_mb is the approximate VRAM footprint
# in MB; the variant dropdown divides it by 1024 to show a "(x.xGB)" label.
# Ids containing "melody" are matched by substring to reveal the melody-
# conditioning upload section.
MUSICGEN_VARIANTS = [
    {"id": "small", "name": "Small", "vram_mb": 1500, "description": "Fast, 300M params"},
    {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "Balanced, 1.5B params"},
    {"id": "large", "name": "Large", "vram_mb": 10000, "description": "Best quality, 3.3B params"},
    {"id": "melody", "name": "Melody", "vram_mb": 5000, "description": "With melody conditioning"},
    {"id": "stereo-small", "name": "Stereo Small", "vram_mb": 1800, "description": "Stereo, 300M params"},
    {"id": "stereo-medium", "name": "Stereo Medium", "vram_mb": 6000, "description": "Stereo, 1.5B params"},
    {"id": "stereo-large", "name": "Stereo Large", "vram_mb": 12000, "description": "Stereo, 3.3B params"},
    {"id": "stereo-melody", "name": "Stereo Melody", "vram_mb": 6000, "description": "Stereo with melody"},
]
|
||||
|
||||
|
||||
def create_musicgen_tab(
|
||||
generate_fn: Callable[..., Any],
|
||||
add_to_queue_fn: Callable[..., Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Create MusicGen generation tab.
|
||||
|
||||
Args:
|
||||
generate_fn: Function to call for generation
|
||||
add_to_queue_fn: Function to add to queue
|
||||
|
||||
Returns:
|
||||
Dictionary with component references
|
||||
"""
|
||||
presets = DEFAULT_PRESETS.get("musicgen", [])
|
||||
suggestions = PROMPT_SUGGESTIONS.get("musicgen", [])
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown("## 🎵 MusicGen")
|
||||
gr.Markdown("Generate music from text descriptions")
|
||||
|
||||
with gr.Row():
|
||||
# Left column - inputs
|
||||
with gr.Column(scale=2):
|
||||
# Preset selector
|
||||
preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
|
||||
preset_dropdown = gr.Dropdown(
|
||||
label="Preset",
|
||||
choices=preset_choices,
|
||||
value=presets[0]["id"] if presets else "custom",
|
||||
)
|
||||
|
||||
# Model variant
|
||||
variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in MUSICGEN_VARIANTS]
|
||||
variant_dropdown = gr.Dropdown(
|
||||
label="Model Variant",
|
||||
choices=variant_choices,
|
||||
value="medium",
|
||||
)
|
||||
|
||||
# Prompt input
|
||||
prompt_input = gr.Textbox(
|
||||
label="Prompt",
|
||||
placeholder="Describe the music you want to generate...",
|
||||
lines=3,
|
||||
max_lines=5,
|
||||
)
|
||||
|
||||
# Prompt suggestions
|
||||
with gr.Accordion("Prompt Suggestions", open=False):
|
||||
suggestion_btns = []
|
||||
for i, suggestion in enumerate(suggestions[:4]):
|
||||
btn = gr.Button(suggestion[:60] + "...", size="sm", variant="secondary")
|
||||
suggestion_btns.append((btn, suggestion))
|
||||
|
||||
# Melody conditioning (for melody variants)
|
||||
with gr.Group(visible=False) as melody_section:
|
||||
gr.Markdown("### Melody Conditioning")
|
||||
melody_input = gr.Audio(
|
||||
label="Reference Melody",
|
||||
type="filepath",
|
||||
sources=["upload", "microphone"],
|
||||
)
|
||||
gr.Markdown("*Upload audio to condition generation on its melody*")
|
||||
|
||||
# Parameters
|
||||
gr.Markdown("### Parameters")
|
||||
|
||||
duration_slider = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=30,
|
||||
value=10,
|
||||
step=1,
|
||||
label="Duration (seconds)",
|
||||
)
|
||||
|
||||
with gr.Accordion("Advanced Parameters", open=False):
|
||||
with gr.Row():
|
||||
temperature_slider = gr.Slider(
|
||||
minimum=0.0,
|
||||
maximum=2.0,
|
||||
value=1.0,
|
||||
step=0.05,
|
||||
label="Temperature",
|
||||
)
|
||||
cfg_slider = gr.Slider(
|
||||
minimum=1.0,
|
||||
maximum=10.0,
|
||||
value=3.0,
|
||||
step=0.5,
|
||||
label="CFG Coefficient",
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
top_k_slider = gr.Slider(
|
||||
minimum=0,
|
||||
maximum=500,
|
||||
value=250,
|
||||
step=10,
|
||||
label="Top-K",
|
||||
)
|
||||
top_p_slider = gr.Slider(
|
||||
minimum=0.0,
|
||||
maximum=1.0,
|
||||
value=0.0,
|
||||
step=0.05,
|
||||
label="Top-P",
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
seed_input = gr.Number(
|
||||
value=None,
|
||||
label="Seed (empty = random)",
|
||||
precision=0,
|
||||
)
|
||||
|
||||
# Generate buttons
|
||||
with gr.Row():
|
||||
generate_btn = gr.Button("🎵 Generate", variant="primary", scale=2)
|
||||
queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)
|
||||
|
||||
# Right column - output
|
||||
with gr.Column(scale=3):
|
||||
output = create_generation_output()
|
||||
|
||||
# Event handlers
|
||||
|
||||
# Preset change
|
||||
def apply_preset(preset_id: str):
    """Resolve a preset id to the five slider values; no-op for "custom"."""
    matched = next((entry for entry in presets if entry["id"] == preset_id), None)
    if matched is None:
        # "Custom" (or unknown) preset: emit no-op updates so the
        # sliders keep whatever values the user has set.
        return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
    params = matched["parameters"]
    slider_defaults = (
        ("duration", 10),
        ("temperature", 1.0),
        ("cfg_coef", 3.0),
        ("top_k", 250),
        ("top_p", 0.0),
    )
    return tuple(params.get(key, fallback) for key, fallback in slider_defaults)
|
||||
|
||||
preset_dropdown.change(
|
||||
fn=apply_preset,
|
||||
inputs=[preset_dropdown],
|
||||
outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
|
||||
)
|
||||
|
||||
# Variant change - show/hide melody section
|
||||
def on_variant_change(variant: str):
    """Show the melody-conditioning section only for melody-capable variants."""
    return gr.update(visible="melody" in variant.lower())
|
||||
|
||||
variant_dropdown.change(
|
||||
fn=on_variant_change,
|
||||
inputs=[variant_dropdown],
|
||||
outputs=[melody_section],
|
||||
)
|
||||
|
||||
# Prompt suggestions
|
||||
for btn, suggestion in suggestion_btns:
|
||||
btn.click(
|
||||
fn=lambda s=suggestion: s,
|
||||
outputs=[prompt_input],
|
||||
)
|
||||
|
||||
# Generate
|
||||
async def do_generate(
    prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody
):
    """Run a MusicGen generation, streaming UI updates to the output panel.

    Yields 6-tuples of gr.update() for:
    (status, progress, audio, waveform, duration label, seed label).
    """
    if not prompt:
        # BUGFIX: this function contains `yield`, making it an async
        # *generator*; `return <value>` inside an async generator is a
        # SyntaxError (PEP 525). Yield the error state, then stop.
        yield (
            gr.update(value="Please enter a prompt"),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
        )
        return

    # Update status: show busy text and reveal the progress bar.
    yield (
        gr.update(value="🔄 Generating..."),
        gr.update(visible=True, value=0),
        gr.update(),
        gr.update(),
        gr.update(),
        gr.update(),
    )

    try:
        conditioning = {}
        if melody:
            conditioning["melody"] = melody

        result, generation = await generate_fn(
            model_id="musicgen",
            variant=variant,
            prompts=[prompt],
            duration=duration,
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
            cfg_coef=cfg_coef,
            # BUGFIX: test `is not None`, not truthiness, so an explicit
            # seed of 0 is honored instead of being treated as "random".
            seed=int(seed) if seed is not None else None,
            conditioning=conditioning,
        )

        yield (
            gr.update(value="✅ Generation complete!"),
            gr.update(visible=False),
            gr.update(value=generation.audio_path),
            gr.update(),
            gr.update(value=f"{result.duration:.2f}s"),
            gr.update(value=str(result.seed)),
        )

    except Exception as e:
        # Surface the failure in the status field and hide the progress bar.
        yield (
            gr.update(value=f"❌ Error: {str(e)}"),
            gr.update(visible=False),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
        )
|
||||
|
||||
generate_btn.click(
|
||||
fn=do_generate,
|
||||
inputs=[
|
||||
prompt_input,
|
||||
variant_dropdown,
|
||||
duration_slider,
|
||||
temperature_slider,
|
||||
cfg_slider,
|
||||
top_k_slider,
|
||||
top_p_slider,
|
||||
seed_input,
|
||||
melody_input,
|
||||
],
|
||||
outputs=[
|
||||
output["status"],
|
||||
output["progress"],
|
||||
output["player"]["audio"],
|
||||
output["player"]["waveform"],
|
||||
output["player"]["duration"],
|
||||
output["player"]["seed"],
|
||||
],
|
||||
)
|
||||
|
||||
# Add to queue
|
||||
def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody):
    """Enqueue a MusicGen generation job; returns a status string for the UI."""
    if not prompt:
        return "Please enter a prompt"

    conditioning = {}
    if melody:
        conditioning["melody"] = melody

    job = add_to_queue_fn(
        model_id="musicgen",
        variant=variant,
        prompts=[prompt],
        duration=duration,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        cfg_coef=cfg_coef,
        # BUGFIX: test `is not None`, not truthiness, so an explicit
        # seed of 0 is honored instead of being treated as "random".
        seed=int(seed) if seed is not None else None,
        conditioning=conditioning,
    )

    return f"✅ Added to queue: {job.id}"
|
||||
|
||||
queue_btn.click(
|
||||
fn=do_add_queue,
|
||||
inputs=[
|
||||
prompt_input,
|
||||
variant_dropdown,
|
||||
duration_slider,
|
||||
temperature_slider,
|
||||
cfg_slider,
|
||||
top_k_slider,
|
||||
top_p_slider,
|
||||
seed_input,
|
||||
melody_input,
|
||||
],
|
||||
outputs=[output["status"]],
|
||||
)
|
||||
|
||||
return {
|
||||
"preset": preset_dropdown,
|
||||
"variant": variant_dropdown,
|
||||
"prompt": prompt_input,
|
||||
"melody": melody_input,
|
||||
"duration": duration_slider,
|
||||
"temperature": temperature_slider,
|
||||
"cfg_coef": cfg_slider,
|
||||
"top_k": top_k_slider,
|
||||
"top_p": top_p_slider,
|
||||
"seed": seed_input,
|
||||
"generate_btn": generate_btn,
|
||||
"queue_btn": queue_btn,
|
||||
"output": output,
|
||||
}
|
||||
292
src/ui/tabs/style_tab.py
Normal file
292
src/ui/tabs/style_tab.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""MusicGen Style tab for style-conditioned generation."""
|
||||
|
||||
import gradio as gr
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
|
||||
from src.ui.components.audio_player import create_generation_output
|
||||
|
||||
|
||||
# Model variants offered on the MusicGen Style tab. Only one checkpoint is
# listed; `vram_mb` is the approximate GPU footprint used to build the
# dropdown label (rendered in GB) in create_style_tab.
STYLE_VARIANTS = [
    {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "1.5B params, style conditioning"},
]
|
||||
|
||||
|
||||
def create_style_tab(
    generate_fn: Callable[..., Any],
    add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
    """Create MusicGen Style generation tab.

    Builds the tab layout (preset/variant selectors, optional text prompt,
    required style-reference audio, generation parameters, output panel)
    and wires the generate / queue event handlers.

    Args:
        generate_fn: Async callable that performs a generation and returns
            a ``(result, generation)`` pair.
        add_to_queue_fn: Callable that enqueues a generation job and
            returns the created job object.

    Returns:
        Dictionary with component references
    """
    presets = DEFAULT_PRESETS.get("musicgen-style", [])
    # NOTE: PROMPT_SUGGESTIONS was previously fetched here but never used
    # on this tab (no suggestion buttons are rendered); the dead local
    # has been removed.

    with gr.Column():
        gr.Markdown("## 🎨 MusicGen Style")
        gr.Markdown("Generate music conditioned on the style of reference audio")

        with gr.Row():
            # Left column - inputs
            with gr.Column(scale=2):
                # Preset selector ("Custom" leaves the sliders untouched)
                preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
                preset_dropdown = gr.Dropdown(
                    label="Preset",
                    choices=preset_choices,
                    value=presets[0]["id"] if presets else "custom",
                )

                # Model variant (label shows approximate VRAM need in GB)
                variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in STYLE_VARIANTS]
                variant_dropdown = gr.Dropdown(
                    label="Model Variant",
                    choices=variant_choices,
                    value="medium",
                )

                # Optional free-text prompt, combined with the style reference
                prompt_input = gr.Textbox(
                    label="Text Prompt",
                    placeholder="Describe additional characteristics for the music...",
                    lines=3,
                    max_lines=5,
                    info="Optional: combine with style conditioning",
                )

                # Style conditioning (required)
                gr.Markdown("### Style Conditioning")
                gr.Markdown("*Upload reference audio to extract musical style*")

                style_input = gr.Audio(
                    label="Style Reference",
                    type="filepath",
                    sources=["upload", "microphone"],
                )

                style_info = gr.Markdown(
                    "*The model will learn the style (instrumentation, tempo, mood) from this audio*"
                )

                # Parameters
                gr.Markdown("### Parameters")

                duration_slider = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)",
                )

                with gr.Accordion("Advanced Parameters", open=False):
                    with gr.Row():
                        temperature_slider = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            value=1.0,
                            step=0.05,
                            label="Temperature",
                        )
                        cfg_slider = gr.Slider(
                            minimum=1.0,
                            maximum=10.0,
                            value=3.0,
                            step=0.5,
                            label="CFG Coefficient",
                        )

                    with gr.Row():
                        top_k_slider = gr.Slider(
                            minimum=0,
                            maximum=500,
                            value=250,
                            step=10,
                            label="Top-K",
                        )
                        top_p_slider = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.0,
                            step=0.05,
                            label="Top-P",
                        )

                    with gr.Row():
                        seed_input = gr.Number(
                            value=None,
                            label="Seed (empty = random)",
                            precision=0,
                        )

                # Generate buttons
                with gr.Row():
                    generate_btn = gr.Button("🎨 Generate", variant="primary", scale=2)
                    queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)

            # Right column - output
            with gr.Column(scale=3):
                output = create_generation_output()

    # Event handlers

    # Preset change
    def apply_preset(preset_id: str):
        """Map a preset id onto the five parameter sliders."""
        for p in presets:
            if p["id"] == preset_id:
                params = p["parameters"]
                return (
                    params.get("duration", 10),
                    params.get("temperature", 1.0),
                    params.get("cfg_coef", 3.0),
                    params.get("top_k", 250),
                    params.get("top_p", 0.0),
                )
        # "Custom" (or unknown) preset: no-op updates keep current values.
        return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()

    preset_dropdown.change(
        fn=apply_preset,
        inputs=[preset_dropdown],
        outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
    )

    # Generate
    async def do_generate(
        prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed
    ):
        """Run a style-conditioned generation, streaming UI updates.

        Yields 6-tuples of gr.update() for:
        (status, progress, audio, waveform, duration label, seed label).
        """
        if not style_audio:
            # BUGFIX: this function contains `yield`, making it an async
            # *generator*; `return <value>` inside an async generator is a
            # SyntaxError (PEP 525). Yield the error state, then stop.
            yield (
                gr.update(value="Please upload a style reference audio"),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
            )
            return

        yield (
            gr.update(value="🔄 Generating..."),
            gr.update(visible=True, value=0),
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
        )

        try:
            conditioning = {"style": style_audio}

            result, generation = await generate_fn(
                model_id="musicgen-style",
                variant=variant,
                prompts=[prompt] if prompt else [""],
                duration=duration,
                temperature=temperature,
                top_k=int(top_k),
                top_p=top_p,
                cfg_coef=cfg_coef,
                # BUGFIX: `is not None` so an explicit seed of 0 is honored
                # instead of being treated as "random".
                seed=int(seed) if seed is not None else None,
                conditioning=conditioning,
            )

            yield (
                gr.update(value="✅ Generation complete!"),
                gr.update(visible=False),
                gr.update(value=generation.audio_path),
                gr.update(),
                gr.update(value=f"{result.duration:.2f}s"),
                gr.update(value=str(result.seed)),
            )

        except Exception as e:
            # Surface the failure in the status field; hide the progress bar.
            yield (
                gr.update(value=f"❌ Error: {str(e)}"),
                gr.update(visible=False),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
            )

    generate_btn.click(
        fn=do_generate,
        inputs=[
            prompt_input,
            variant_dropdown,
            style_input,
            duration_slider,
            temperature_slider,
            cfg_slider,
            top_k_slider,
            top_p_slider,
            seed_input,
        ],
        outputs=[
            output["status"],
            output["progress"],
            output["player"]["audio"],
            output["player"]["waveform"],
            output["player"]["duration"],
            output["player"]["seed"],
        ],
    )

    # Add to queue
    def do_add_queue(prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed):
        """Enqueue a style-conditioned job; returns a status string."""
        if not style_audio:
            return "Please upload a style reference audio"

        conditioning = {"style": style_audio}

        job = add_to_queue_fn(
            model_id="musicgen-style",
            variant=variant,
            prompts=[prompt] if prompt else [""],
            duration=duration,
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
            cfg_coef=cfg_coef,
            # BUGFIX: `is not None` so an explicit seed of 0 is honored.
            seed=int(seed) if seed is not None else None,
            conditioning=conditioning,
        )

        return f"✅ Added to queue: {job.id}"

    queue_btn.click(
        fn=do_add_queue,
        inputs=[
            prompt_input,
            variant_dropdown,
            style_input,
            duration_slider,
            temperature_slider,
            cfg_slider,
            top_k_slider,
            top_p_slider,
            seed_input,
        ],
        outputs=[output["status"]],
    )

    return {
        "preset": preset_dropdown,
        "variant": variant_dropdown,
        "prompt": prompt_input,
        "style": style_input,
        "duration": duration_slider,
        "temperature": temperature_slider,
        "cfg_coef": cfg_slider,
        "top_k": top_k_slider,
        "top_p": top_p_slider,
        "seed": seed_input,
        "generate_btn": generate_btn,
        "queue_btn": queue_btn,
        "output": output,
    }
|
||||
303
src/ui/theme.py
Normal file
303
src/ui/theme.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""Custom Gradio theme for AudioCraft Studio."""
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def create_theme() -> gr.themes.Base:
    """Create custom theme for AudioCraft Studio.

    Builds on gr.themes.Soft with a fixed dark slate/blue palette. Each
    color is set for both the default and the ``_dark`` variant with the
    same value, so the UI renders identically regardless of the browser's
    light/dark preference.

    Returns:
        Gradio theme instance
    """
    return gr.themes.Soft(
        primary_hue=gr.themes.colors.blue,
        secondary_hue=gr.themes.colors.slate,
        neutral_hue=gr.themes.colors.gray,
        font=[
            gr.themes.GoogleFont("Inter"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ],
        font_mono=[
            gr.themes.GoogleFont("JetBrains Mono"),
            "ui-monospace",
            "monospace",
        ],
    ).set(
        # Colors (slate-900 page background, slate-800 panels)
        body_background_fill="#0f172a",
        body_background_fill_dark="#0f172a",
        background_fill_primary="#1e293b",
        background_fill_primary_dark="#1e293b",
        background_fill_secondary="#334155",
        background_fill_secondary_dark="#334155",
        border_color_primary="#475569",
        border_color_primary_dark="#475569",

        # Text
        body_text_color="#e2e8f0",
        body_text_color_dark="#e2e8f0",
        body_text_color_subdued="#94a3b8",
        body_text_color_subdued_dark="#94a3b8",

        # Buttons (primary = blue-500, hover = blue-600)
        button_primary_background_fill="#3b82f6",
        button_primary_background_fill_dark="#3b82f6",
        button_primary_background_fill_hover="#2563eb",
        button_primary_background_fill_hover_dark="#2563eb",
        button_primary_text_color="#ffffff",
        button_primary_text_color_dark="#ffffff",

        button_secondary_background_fill="#475569",
        button_secondary_background_fill_dark="#475569",
        button_secondary_background_fill_hover="#64748b",
        button_secondary_background_fill_hover_dark="#64748b",

        # Inputs (blue focus ring to match the primary button)
        input_background_fill="#1e293b",
        input_background_fill_dark="#1e293b",
        input_border_color="#475569",
        input_border_color_dark="#475569",
        input_border_color_focus="#3b82f6",
        input_border_color_focus_dark="#3b82f6",

        # Blocks
        block_background_fill="#1e293b",
        block_background_fill_dark="#1e293b",
        block_border_color="#334155",
        block_border_color_dark="#334155",
        block_label_background_fill="#334155",
        block_label_background_fill_dark="#334155",
        block_label_text_color="#e2e8f0",
        block_label_text_color_dark="#e2e8f0",
        block_title_text_color="#f1f5f9",
        block_title_text_color_dark="#f1f5f9",

        # Tabs
        tab_nav_background_fill="#1e293b",

        # Sliders
        slider_color="#3b82f6",
        slider_color_dark="#3b82f6",

        # Shadows
        shadow_spread="4px",
        block_shadow="0 4px 6px -1px rgba(0, 0, 0, 0.3)",

        # Spacing
        layout_gap="16px",
        block_padding="16px",
        panel_border_width="1px",

        # Radius
        radius_sm="6px",
        radius_md="8px",
        radius_lg="12px",
    )
|
||||
|
||||
|
||||
# CSS overrides for additional customization.
# Injected alongside the theme (e.g. via gr.Blocks(css=CUSTOM_CSS)); the
# hex colors mirror the slate/blue palette defined in create_theme above.
# NOTE: the string content is consumed by the browser at runtime — keep it
# verbatim when editing this file.
CUSTOM_CSS = """
/* Global styles */
.gradio-container {
    max-width: 100% !important;
}

/* Header styling */
.header-title {
    font-size: 1.5rem;
    font-weight: 700;
    color: #f1f5f9;
}

/* Sidebar styling */
.sidebar {
    background: #1e293b;
    border-right: 1px solid #334155;
    padding: 1rem;
}

.sidebar-nav-btn {
    width: 100%;
    justify-content: flex-start;
    margin-bottom: 0.5rem;
}

/* Model cards */
.model-card {
    background: #334155;
    border-radius: 12px;
    padding: 1rem;
    transition: transform 0.2s, box-shadow 0.2s;
}

.model-card:hover {
    transform: translateY(-2px);
    box-shadow: 0 8px 25px rgba(0, 0, 0, 0.3);
}

/* Audio player */
.audio-player {
    background: #1e293b;
    border-radius: 8px;
    padding: 1rem;
}

/* Progress bar */
.progress-bar {
    background: #334155;
    border-radius: 4px;
    overflow: hidden;
}

.progress-fill {
    background: linear-gradient(90deg, #3b82f6, #8b5cf6);
    height: 100%;
    transition: width 0.3s ease;
}

/* VRAM monitor */
.vram-bar {
    background: #334155;
    border-radius: 4px;
    height: 24px;
    position: relative;
    overflow: hidden;
}

.vram-fill {
    position: absolute;
    left: 0;
    top: 0;
    height: 100%;
    background: linear-gradient(90deg, #22c55e, #eab308, #ef4444);
    transition: width 0.5s ease;
}

.vram-text {
    position: absolute;
    width: 100%;
    text-align: center;
    line-height: 24px;
    font-size: 0.875rem;
    font-weight: 500;
    color: white;
    text-shadow: 0 1px 2px rgba(0, 0, 0, 0.5);
}

/* Queue badge */
.queue-badge {
    background: #3b82f6;
    color: white;
    padding: 0.25rem 0.75rem;
    border-radius: 9999px;
    font-size: 0.875rem;
    font-weight: 500;
}

/* Generation card */
.generation-card {
    background: #334155;
    border-radius: 8px;
    padding: 1rem;
    margin-bottom: 0.5rem;
}

/* Preset chips */
.preset-chip {
    display: inline-block;
    background: #475569;
    color: #e2e8f0;
    padding: 0.25rem 0.75rem;
    border-radius: 9999px;
    font-size: 0.875rem;
    margin: 0.25rem;
    cursor: pointer;
    transition: background 0.2s;
}

.preset-chip:hover {
    background: #3b82f6;
}

.preset-chip.active {
    background: #3b82f6;
}

/* Tag input */
.tag {
    display: inline-flex;
    align-items: center;
    background: #475569;
    color: #e2e8f0;
    padding: 0.25rem 0.5rem;
    border-radius: 4px;
    font-size: 0.75rem;
    margin: 0.125rem;
}

/* Accordion tweaks */
.accordion-header {
    font-weight: 600;
    color: #f1f5f9;
}

/* Status indicators */
.status-dot {
    width: 8px;
    height: 8px;
    border-radius: 50%;
    display: inline-block;
    margin-right: 0.5rem;
}

.status-dot.loaded {
    background: #22c55e;
}

.status-dot.unloaded {
    background: #64748b;
}

.status-dot.loading {
    background: #eab308;
    animation: pulse 1s infinite;
}

@keyframes pulse {
    0%, 100% { opacity: 1; }
    50% { opacity: 0.5; }
}

/* Tooltip */
.tooltip {
    position: relative;
}

.tooltip:hover::after {
    content: attr(data-tooltip);
    position: absolute;
    bottom: 100%;
    left: 50%;
    transform: translateX(-50%);
    background: #1e293b;
    color: #e2e8f0;
    padding: 0.5rem;
    border-radius: 4px;
    font-size: 0.75rem;
    white-space: nowrap;
    z-index: 100;
}

/* Responsive adjustments */
@media (max-width: 768px) {
    .sidebar {
        display: none;
    }

    .mobile-nav {
        display: flex !important;
    }
}
"""
|
||||
Reference in New Issue
Block a user