Initial implementation of AudioCraft Studio
Complete web interface for Meta's AudioCraft AI audio generation:

- Gradio UI with tabs for all 5 model families (MusicGen, AudioGen, MAGNeT, MusicGen Style, JASCO)
- REST API with FastAPI, OpenAPI docs, and API key auth
- VRAM management with ComfyUI coexistence support
- SQLite database for project/generation history
- Batch processing queue for async generation
- Docker deployment optimized for RunPod with RTX 4090

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
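As one illustration of the "API key auth" mentioned above, a minimal FastAPI dependency might look like the sketch below. The header name, environment variable, and route are hypothetical, not the actual implementation in this commit:

```python
# Hypothetical sketch of API-key auth in FastAPI; the header name, env var,
# and route are illustrative, not the actual routes added by this commit.
import os
import secrets

from fastapi import Depends, FastAPI, Header, HTTPException

app = FastAPI()


def require_api_key(x_api_key: str = Header(...)) -> None:
    # Constant-time comparison against a key configured via the environment
    expected = os.environ.get("AUDIOCRAFT_API_KEY", "")
    if not expected or not secrets.compare_digest(x_api_key, expected):
        raise HTTPException(status_code=401, detail="Invalid API key")


@app.post("/generate", dependencies=[Depends(require_api_key)])
def generate(prompt: str) -> dict:
    # Placeholder: the real service would enqueue a generation job here
    return {"status": "queued", "prompt": prompt}
```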
277
src/models/musicgen_style/adapter.py
Normal file
@@ -0,0 +1,277 @@
"""MusicGen Style model adapter for style-conditioned music generation."""

import gc
import logging
import random
from typing import Any, Optional

import torch
import torchaudio

from src.core.base_model import (
    BaseAudioModel,
    ConditioningType,
    GenerationRequest,
    GenerationResult,
)

logger = logging.getLogger(__name__)


class MusicGenStyleAdapter(BaseAudioModel):
    """Adapter for Facebook's MusicGen Style model.

    Generates music conditioned on both a text prompt and a style reference
    audio clip: style features are extracted from the reference and applied
    to new generations.
    """

    VARIANTS = {
        "medium": {
            "hf_id": "facebook/musicgen-style",
            "vram_mb": 5000,
            "max_duration": 30,
            "channels": 1,
        },
    }

    def __init__(self, variant: str = "medium"):
        """Initialize the MusicGen Style adapter.

        Args:
            variant: Model variant (currently only 'medium' is available)
        """
        if variant not in self.VARIANTS:
            raise ValueError(
                f"Unknown MusicGen Style variant: {variant}. "
                f"Available: {list(self.VARIANTS.keys())}"
            )

        self._variant = variant
        self._config = self.VARIANTS[variant]
        self._model = None
        self._device: Optional[torch.device] = None

    @property
    def model_id(self) -> str:
        return "musicgen-style"

    @property
    def variant(self) -> str:
        return self._variant

    @property
    def display_name(self) -> str:
        return f"MusicGen Style ({self._variant})"

    @property
    def description(self) -> str:
        return "Style-conditioned music generation from reference audio"

    @property
    def vram_estimate_mb(self) -> int:
        return self._config["vram_mb"]

    @property
    def max_duration(self) -> float:
        return self._config["max_duration"]

    @property
    def sample_rate(self) -> int:
        if self._model is not None:
            return self._model.sample_rate
        return 32000

    @property
    def supports_conditioning(self) -> list[ConditioningType]:
        return [ConditioningType.TEXT, ConditioningType.STYLE]

    @property
    def is_loaded(self) -> bool:
        return self._model is not None

    @property
    def device(self) -> Optional[torch.device]:
        return self._device

    def load(self, device: str = "cuda") -> None:
        """Load the MusicGen Style model."""
        if self._model is not None:
            logger.warning(f"MusicGen Style {self._variant} already loaded")
            return

        logger.info(f"Loading MusicGen Style {self._variant}...")

        try:
            from audiocraft.models import MusicGen

            self._device = torch.device(device)
            # Pass the device to get_pretrained: the MusicGen wrapper is not
            # an nn.Module, so it cannot be moved with .to() after loading.
            self._model = MusicGen.get_pretrained(
                self._config["hf_id"], device=device
            )

            logger.info(
                f"MusicGen Style {self._variant} loaded successfully "
                f"(sample_rate={self._model.sample_rate})"
            )

        except Exception as e:
            self._model = None
            self._device = None
            logger.error(f"Failed to load MusicGen Style {self._variant}: {e}")
            raise RuntimeError(f"Failed to load MusicGen Style: {e}") from e

    def unload(self) -> None:
        """Unload the model and free GPU memory."""
        if self._model is None:
            return

        logger.info(f"Unloading MusicGen Style {self._variant}...")

        del self._model
        self._model = None
        self._device = None

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _load_style_audio(
        self, style_input: Any, target_sr: int
    ) -> tuple[torch.Tensor, int]:
        """Load and prepare style reference audio.

        Args:
            style_input: File path, tensor, or numpy array
            target_sr: Target sample rate

        Returns:
            Tuple of (audio_tensor, sample_rate)
        """
        if isinstance(style_input, str):
            # Load from file and resample to the model's rate if needed
            audio, sr = torchaudio.load(style_input)
            if sr != target_sr:
                audio = torchaudio.functional.resample(audio, sr, target_sr)
            return audio.to(self._device), target_sr
        elif isinstance(style_input, torch.Tensor):
            # Tensors are assumed to already be sampled at target_sr
            return style_input.to(self._device), target_sr
        else:
            # Assume a numpy array, also already at target_sr
            audio = torch.as_tensor(style_input, dtype=torch.float32)
            return audio.to(self._device), target_sr

    def generate(self, request: GenerationRequest) -> GenerationResult:
        """Generate music conditioned on text and a style reference.

        Args:
            request: Generation parameters including prompts and style conditioning

        Returns:
            GenerationResult with audio tensor and metadata

        Note:
            Style conditioning requires 'style' in request.conditioning with one of:
            - a file path to audio
            - an audio tensor
            - a numpy array
        """
        self.validate_request(request)

        # Set the random seed for reproducible generations
        seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

        # Get style conditioning parameters
        style_audio = request.conditioning.get("style")
        eval_q = request.conditioning.get("eval_q", 3)
        excerpt_length = request.conditioning.get("excerpt_length", 3.0)

        # Configure generation parameters
        self._model.set_generation_params(
            duration=request.duration,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            cfg_coef=request.cfg_coef,
        )

        logger.info(
            f"Generating {len(request.prompts)} sample(s) with MusicGen Style "
            f"(duration={request.duration}s, style_conditioned={style_audio is not None})"
        )

        with torch.inference_mode():
            if style_audio is not None:
                # Load the style reference
                style_tensor, style_sr = self._load_style_audio(
                    style_audio, self.sample_rate
                )

                # Ensure proper shape [batch, channels, samples]
                if style_tensor.dim() == 1:
                    style_tensor = style_tensor.unsqueeze(0).unsqueeze(0)
                elif style_tensor.dim() == 2:
                    style_tensor = style_tensor.unsqueeze(0)

                # Set style conditioner parameters (eval_q sets the style-token
                # quantization level; excerpt_length is the reference excerpt
                # duration in seconds)
                if hasattr(self._model, 'set_style_conditioner_params'):
                    self._model.set_style_conditioner_params(
                        eval_q=eval_q,
                        excerpt_length=excerpt_length,
                    )

                # Expand the style batch to match the number of prompts if needed
                if style_tensor.shape[0] == 1 and len(request.prompts) > 1:
                    style_tensor = style_tensor.expand(len(request.prompts), -1, -1)

                # Generate with style conditioning (the style model reuses the
                # wav-conditioning entry point)
                audio = self._model.generate_with_chroma(
                    descriptions=request.prompts,
                    melody_wavs=style_tensor,
                    melody_sample_rate=style_sr,
                )
            else:
                # Fall back to standard text-only MusicGen behavior
                logger.warning(
                    "No style reference provided, generating without style conditioning"
                )
                audio = self._model.generate(request.prompts)

        actual_duration = audio.shape[-1] / self.sample_rate

        logger.info(
            f"Generated {audio.shape[0]} sample(s), duration={actual_duration:.2f}s"
        )

        return GenerationResult(
            audio=audio.cpu(),
            sample_rate=self.sample_rate,
            duration=actual_duration,
            model_id=self.model_id,
            variant=self._variant,
            parameters={
                "duration": request.duration,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p,
                "cfg_coef": request.cfg_coef,
                "prompts": request.prompts,
                "style_conditioned": style_audio is not None,
                "eval_q": eval_q,
                "excerpt_length": excerpt_length,
            },
            seed=seed,
        )

    def get_default_params(self) -> dict[str, Any]:
        """Get default generation parameters for MusicGen Style."""
        return {
            "duration": 10.0,
            "temperature": 1.0,
            "top_k": 250,
            "top_p": 0.0,
            "cfg_coef": 3.0,
            "eval_q": 3,
            "excerpt_length": 3.0,
        }
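
For reference, a minimal sketch of how this adapter might be driven end to end. The `GenerationRequest` field names mirror those the adapter reads above, but its exact constructor signature is an assumption from this diff, not a confirmed interface:

```python
# Hypothetical usage sketch -- GenerationRequest's constructor is assumed
# from the fields referenced in adapter.py, not from its actual definition.
import torchaudio

from src.core.base_model import GenerationRequest
from src.models.musicgen_style.adapter import MusicGenStyleAdapter

adapter = MusicGenStyleAdapter(variant="medium")
adapter.load(device="cuda")

request = GenerationRequest(
    prompts=["lo-fi hip hop with warm vinyl crackle"],
    duration=10.0,
    temperature=1.0,
    top_k=250,
    top_p=0.0,
    cfg_coef=3.0,
    seed=42,
    conditioning={"style": "reference.wav", "eval_q": 3, "excerpt_length": 3.0},
)

result = adapter.generate(request)

# result.audio is [batch, channels, samples] on CPU
torchaudio.save("output.wav", result.audio[0], result.sample_rate)
adapter.unload()
```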