"""MusicGen Style model adapter for style-conditioned music generation."""
|
|
|
|
import gc
|
|
import logging
|
|
import random
|
|
from typing import Any, Optional
|
|
|
|
import torch
|
|
import torchaudio
|
|
|
|
from src.core.base_model import (
|
|
BaseAudioModel,
|
|
ConditioningType,
|
|
GenerationRequest,
|
|
GenerationResult,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MusicGenStyleAdapter(BaseAudioModel):
    """Adapter for Facebook's MusicGen Style model.

    Generates music conditioned on both a text prompt and a style reference
    audio clip. Style features are extracted from the reference and applied
    to new generations.
    """

    VARIANTS = {
        "medium": {
            "hf_id": "facebook/musicgen-style",
            "vram_mb": 5000,
            "max_duration": 30,
            "channels": 1,
        },
    }

def __init__(self, variant: str = "medium"):
|
|
"""Initialize MusicGen Style adapter.
|
|
|
|
Args:
|
|
variant: Model variant (currently only 'medium' available)
|
|
"""
|
|
if variant not in self.VARIANTS:
|
|
raise ValueError(
|
|
f"Unknown MusicGen Style variant: {variant}. "
|
|
f"Available: {list(self.VARIANTS.keys())}"
|
|
)
|
|
|
|
self._variant = variant
|
|
self._config = self.VARIANTS[variant]
|
|
self._model = None
|
|
self._device: Optional[torch.device] = None
|
|
|
|
    @property
    def model_id(self) -> str:
        return "musicgen-style"

    @property
    def variant(self) -> str:
        return self._variant

    @property
    def display_name(self) -> str:
        return f"MusicGen Style ({self._variant})"

    @property
    def description(self) -> str:
        return "Style-conditioned music generation from reference audio"

    @property
    def vram_estimate_mb(self) -> int:
        return self._config["vram_mb"]

    @property
    def max_duration(self) -> float:
        return self._config["max_duration"]

    @property
    def sample_rate(self) -> int:
        if self._model is not None:
            return self._model.sample_rate
        return 32000

    @property
    def supports_conditioning(self) -> list[ConditioningType]:
        return [ConditioningType.TEXT, ConditioningType.STYLE]

    @property
    def is_loaded(self) -> bool:
        return self._model is not None

    @property
    def device(self) -> Optional[torch.device]:
        return self._device

def load(self, device: str = "cuda") -> None:
|
|
"""Load the MusicGen Style model."""
|
|
if self._model is not None:
|
|
logger.warning(f"MusicGen Style {self._variant} already loaded")
|
|
return
|
|
|
|
logger.info(f"Loading MusicGen Style {self._variant}...")
|
|
|
|
try:
|
|
from audiocraft.models import MusicGen
|
|
|
|
self._device = torch.device(device)
|
|
self._model = MusicGen.get_pretrained(self._config["hf_id"])
|
|
self._model.to(self._device)
|
|
|
|
logger.info(
|
|
f"MusicGen Style {self._variant} loaded successfully "
|
|
f"(sample_rate={self._model.sample_rate})"
|
|
)
|
|
|
|
except Exception as e:
|
|
self._model = None
|
|
self._device = None
|
|
logger.error(f"Failed to load MusicGen Style {self._variant}: {e}")
|
|
raise RuntimeError(f"Failed to load MusicGen Style: {e}") from e
|
|
|
|
    def unload(self) -> None:
        """Unload the model and free memory."""
        if self._model is None:
            return

        logger.info(f"Unloading MusicGen Style {self._variant}...")

        del self._model
        self._model = None
        self._device = None

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _load_style_audio(
        self, style_input: Any, target_sr: int
    ) -> tuple[torch.Tensor, int]:
        """Load and prepare style reference audio.

        Args:
            style_input: File path, tensor, or numpy array
            target_sr: Target sample rate

        Returns:
            Tuple of (audio_tensor, sample_rate)
        """
        if isinstance(style_input, str):
            # Load from file and resample to the model's sample rate if needed
            audio, sr = torchaudio.load(style_input)
            if sr != target_sr:
                audio = torchaudio.functional.resample(audio, sr, target_sr)
            return audio.to(self._device), target_sr
        elif isinstance(style_input, torch.Tensor):
            # Tensors are assumed to already be sampled at target_sr
            return style_input.to(self._device), target_sr
        else:
            # Assume a numpy array of float PCM already sampled at target_sr
            return torch.tensor(style_input).to(self._device), target_sr
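
    # Shape note (illustrative, using a hypothetical "reference.wav"): torchaudio.load
    # returns a [channels, samples] tensor, so _load_style_audio("reference.wav", 32000)
    # yields a [channels, samples] tensor on self._device plus the target sample rate;
    # generate() adds the batch dimension before handing the reference to the model.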
    def generate(self, request: GenerationRequest) -> GenerationResult:
        """Generate music conditioned on text and a style reference.

        Args:
            request: Generation parameters including prompts and style conditioning

        Returns:
            GenerationResult with audio tensor and metadata

        Note:
            Style conditioning requires 'style' in request.conditioning with either:
            - File path to audio
            - Audio tensor
            - Numpy array
        """
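        # Illustrative conditioning payload (keys mirror the lookups below; the
        # exact request schema lives in src.core.base_model):
        #   request.conditioning = {
        #       "style": "/path/to/reference.wav",  # or a torch.Tensor / numpy array
        #       "eval_q": 3,             # forwarded to set_style_conditioner_params
        #       "excerpt_length": 3.0,   # seconds of the reference; also forwarded
        #   }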
        self.validate_request(request)

        # Set random seed
        seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

        # Get style conditioning parameters
        style_audio = request.conditioning.get("style")
        eval_q = request.conditioning.get("eval_q", 3)
        excerpt_length = request.conditioning.get("excerpt_length", 3.0)

        # Configure generation parameters
        self._model.set_generation_params(
            duration=request.duration,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            cfg_coef=request.cfg_coef,
        )

        logger.info(
            f"Generating {len(request.prompts)} sample(s) with MusicGen Style "
            f"(duration={request.duration}s, style_conditioned={style_audio is not None})"
        )

        with torch.inference_mode():
            if style_audio is not None:
                # Load style reference
                style_tensor, style_sr = self._load_style_audio(
                    style_audio, self.sample_rate
                )

                # Ensure proper shape [batch, channels, samples]
                if style_tensor.dim() == 1:
                    style_tensor = style_tensor.unsqueeze(0).unsqueeze(0)
                elif style_tensor.dim() == 2:
                    style_tensor = style_tensor.unsqueeze(0)

                # Set style conditioner parameters
                if hasattr(self._model, 'set_style_conditioner_params'):
                    self._model.set_style_conditioner_params(
                        eval_q=eval_q,
                        excerpt_length=excerpt_length,
                    )

                # Expand the style reference to match the number of prompts if needed
                if style_tensor.shape[0] == 1 and len(request.prompts) > 1:
                    style_tensor = style_tensor.expand(len(request.prompts), -1, -1)

                # Generate with style conditioning. MusicGen Style reuses the
                # melody_wavs argument of generate_with_chroma to carry the
                # style reference audio.
                audio = self._model.generate_with_chroma(
                    descriptions=request.prompts,
                    melody_wavs=style_tensor,
                    melody_sample_rate=style_sr,
                )
            else:
                # Generate without style (falls back to standard MusicGen behavior)
                logger.warning(
                    "No style reference provided, generating without style conditioning"
                )
                audio = self._model.generate(request.prompts)

        actual_duration = audio.shape[-1] / self.sample_rate

        logger.info(
            f"Generated {audio.shape[0]} sample(s), "
            f"duration={actual_duration:.2f}s"
        )

        return GenerationResult(
            audio=audio.cpu(),
            sample_rate=self.sample_rate,
            duration=actual_duration,
            model_id=self.model_id,
            variant=self._variant,
            parameters={
                "duration": request.duration,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p,
                "cfg_coef": request.cfg_coef,
                "prompts": request.prompts,
                "style_conditioned": style_audio is not None,
                "eval_q": eval_q,
                "excerpt_length": excerpt_length,
            },
            seed=seed,
        )

    def get_default_params(self) -> dict[str, Any]:
        """Get default generation parameters for MusicGen Style."""
        return {
            "duration": 10.0,
            "temperature": 1.0,
            "top_k": 250,
            "top_p": 0.0,
            "cfg_coef": 3.0,
            "eval_q": 3,
            "excerpt_length": 3.0,
        }
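

# The sketch below is illustrative only: it assumes GenerationRequest is a simple
# dataclass exposing the fields read above (prompts, duration, temperature, top_k,
# top_p, cfg_coef, seed, conditioning), and "reference.wav" / "style_out.wav" are
# hypothetical paths; adjust both to the actual definitions in src.core.base_model.
if __name__ == "__main__":
    adapter = MusicGenStyleAdapter("medium")
    adapter.load(device="cuda" if torch.cuda.is_available() else "cpu")
    try:
        request = GenerationRequest(
            prompts=["lo-fi hip hop beat with warm keys"],
            duration=10.0,
            temperature=1.0,
            top_k=250,
            top_p=0.0,
            cfg_coef=3.0,
            seed=42,
            conditioning={
                "style": "reference.wav",
                "eval_q": 3,
                "excerpt_length": 3.0,
            },
        )
        result = adapter.generate(request)
        # GenerationResult.audio is [batch, channels, samples] on the CPU;
        # save the first sample with torchaudio.
        torchaudio.save("style_out.wav", result.audio[0], result.sample_rate)
    finally:
        adapter.unload()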