"""MusicGen Style model adapter for style-conditioned music generation."""
import gc
import logging
import random
from typing import Any, Optional
import torch
import torchaudio
from src.core.base_model import (
BaseAudioModel,
ConditioningType,
GenerationRequest,
GenerationResult,
)
logger = logging.getLogger(__name__)


class MusicGenStyleAdapter(BaseAudioModel):
    """Adapter for Facebook's MusicGen Style model.

    Generates music conditioned on both a text prompt and a style reference
    audio clip. Style features are extracted from the reference and applied
    to new generations.
    """
    VARIANTS = {
        "medium": {
            "hf_id": "facebook/musicgen-style",
            "vram_mb": 5000,
            "max_duration": 30,
            "channels": 1,
        },
    }

    def __init__(self, variant: str = "medium"):
        """Initialize MusicGen Style adapter.

        Args:
            variant: Model variant (currently only 'medium' is available)
        """
        if variant not in self.VARIANTS:
            raise ValueError(
                f"Unknown MusicGen Style variant: {variant}. "
                f"Available: {list(self.VARIANTS.keys())}"
            )
        self._variant = variant
        self._config = self.VARIANTS[variant]
        self._model = None
        self._device: Optional[torch.device] = None

    @property
    def model_id(self) -> str:
        return "musicgen-style"

    @property
    def variant(self) -> str:
        return self._variant

    @property
    def display_name(self) -> str:
        return f"MusicGen Style ({self._variant})"

    @property
    def description(self) -> str:
        return "Style-conditioned music generation from reference audio"

    @property
    def vram_estimate_mb(self) -> int:
        return self._config["vram_mb"]

    @property
    def max_duration(self) -> float:
        return self._config["max_duration"]

    @property
    def sample_rate(self) -> int:
        # Fall back to MusicGen's native 32 kHz before the model is loaded
        if self._model is not None:
            return self._model.sample_rate
        return 32000

    @property
    def supports_conditioning(self) -> list[ConditioningType]:
        return [ConditioningType.TEXT, ConditioningType.STYLE]

    @property
    def is_loaded(self) -> bool:
        return self._model is not None

    @property
    def device(self) -> Optional[torch.device]:
        return self._device

    def load(self, device: str = "cuda") -> None:
        """Load the MusicGen Style model."""
        if self._model is not None:
            logger.warning(f"MusicGen Style {self._variant} already loaded")
            return
        logger.info(f"Loading MusicGen Style {self._variant}...")
        try:
            from audiocraft.models import MusicGen

            self._device = torch.device(device)
            # get_pretrained places the model on the target device; the
            # MusicGen wrapper object itself does not expose a .to() method.
            self._model = MusicGen.get_pretrained(self._config["hf_id"], device=device)
            logger.info(
                f"MusicGen Style {self._variant} loaded successfully "
                f"(sample_rate={self._model.sample_rate})"
            )
        except Exception as e:
            self._model = None
            self._device = None
            logger.error(f"Failed to load MusicGen Style {self._variant}: {e}")
            raise RuntimeError(f"Failed to load MusicGen Style: {e}") from e

    def unload(self) -> None:
        """Unload the model and free memory."""
        if self._model is None:
            return
        logger.info(f"Unloading MusicGen Style {self._variant}...")
        del self._model
        self._model = None
        self._device = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _load_style_audio(
        self, style_input: Any, target_sr: int
    ) -> tuple[torch.Tensor, int]:
        """Load and prepare the style reference audio.

        Args:
            style_input: File path, audio tensor, or numpy array. Tensor and
                array inputs are assumed to already be at the target sample rate.
            target_sr: Target sample rate in Hz

        Returns:
            Tuple of (audio_tensor, sample_rate)
        """
        if isinstance(style_input, str):
            # Load from file and resample if needed
            audio, sr = torchaudio.load(style_input)
            if sr != target_sr:
                audio = torchaudio.functional.resample(audio, sr, target_sr)
            return audio.to(self._device), target_sr
        elif isinstance(style_input, torch.Tensor):
            return style_input.to(self._device), target_sr
        else:
            # Assume a numpy array already at the target sample rate
            return torch.tensor(style_input).to(self._device), target_sr

    def generate(self, request: GenerationRequest) -> GenerationResult:
        """Generate music conditioned on text and a style reference.

        Args:
            request: Generation parameters including prompts and style conditioning

        Returns:
            GenerationResult with audio tensor and metadata

        Note:
            Style conditioning requires 'style' in request.conditioning with either:
            - A file path to audio
            - An audio tensor
            - A numpy array
        """
        self.validate_request(request)

        # Set random seed
        seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

        # Get style conditioning parameters
        style_audio = request.conditioning.get("style")
        eval_q = request.conditioning.get("eval_q", 3)
        excerpt_length = request.conditioning.get("excerpt_length", 3.0)

        # Configure generation parameters
        self._model.set_generation_params(
            duration=request.duration,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            cfg_coef=request.cfg_coef,
        )
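        # The AudioCraft MusicGen-Style demo additionally sets cfg_coef_beta in
        # set_generation_params to enable double classifier-free guidance; that
        # knob is not exposed through this adapter's request schema.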

        logger.info(
            f"Generating {len(request.prompts)} sample(s) with MusicGen Style "
            f"(duration={request.duration}s, style_conditioned={style_audio is not None})"
        )

        with torch.inference_mode():
            if style_audio is not None:
                # Load style reference
                style_tensor, style_sr = self._load_style_audio(
                    style_audio, self.sample_rate
                )
                # Ensure proper shape [batch, channels, samples]
                if style_tensor.dim() == 1:
                    style_tensor = style_tensor.unsqueeze(0).unsqueeze(0)
                elif style_tensor.dim() == 2:
                    style_tensor = style_tensor.unsqueeze(0)
                # Set style conditioner parameters
                if hasattr(self._model, "set_style_conditioner_params"):
                    self._model.set_style_conditioner_params(
                        eval_q=eval_q,
                        excerpt_length=excerpt_length,
                    )
                # Generate with style conditioning;
                # expand style to match the number of prompts if needed
                if style_tensor.shape[0] == 1 and len(request.prompts) > 1:
                    style_tensor = style_tensor.expand(len(request.prompts), -1, -1)
                audio = self._model.generate_with_chroma(
                    descriptions=request.prompts,
                    melody_wavs=style_tensor,
                    melody_sample_rate=style_sr,
                )
            else:
                # Generate without style (falls back to standard MusicGen behavior)
                logger.warning(
                    "No style reference provided, generating without style conditioning"
                )
                audio = self._model.generate(request.prompts)

        actual_duration = audio.shape[-1] / self.sample_rate
        logger.info(
            f"Generated {audio.shape[0]} sample(s), "
            f"duration={actual_duration:.2f}s"
        )

        return GenerationResult(
            audio=audio.cpu(),
            sample_rate=self.sample_rate,
            duration=actual_duration,
            model_id=self.model_id,
            variant=self._variant,
            parameters={
                "duration": request.duration,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p,
                "cfg_coef": request.cfg_coef,
                "prompts": request.prompts,
                "style_conditioned": style_audio is not None,
                "eval_q": eval_q,
                "excerpt_length": excerpt_length,
            },
            seed=seed,
        )

    def get_default_params(self) -> dict[str, Any]:
        """Get default generation parameters for MusicGen Style."""
        return {
            "duration": 10.0,
            "temperature": 1.0,
            "top_k": 250,
            "top_p": 0.0,
            "cfg_coef": 3.0,
            "eval_q": 3,
            "excerpt_length": 3.0,
        }
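

# Minimal usage sketch (not part of the adapter). GenerationRequest's exact
# constructor lives in src.core.base_model and is not shown here; the field
# names below are inferred from how generate() reads the request and are
# therefore assumptions.
if __name__ == "__main__":
    adapter = MusicGenStyleAdapter("medium")
    adapter.load(device="cuda" if torch.cuda.is_available() else "cpu")
    request = GenerationRequest(
        prompts=["warm analog synth groove"],
        duration=10.0,
        temperature=1.0,
        top_k=250,
        top_p=0.0,
        cfg_coef=3.0,
        seed=42,
        conditioning={"style": "reference.wav", "eval_q": 3, "excerpt_length": 3.0},
    )
    result = adapter.generate(request)
    # result.audio is [batch, channels, samples] on CPU
    torchaudio.save("style_output.wav", result.audio[0], result.sample_rate)
    adapter.unload()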