"""MusicGen Style model adapter for style-conditioned music generation.""" import gc import logging import random from typing import Any, Optional import torch import torchaudio from src.core.base_model import ( BaseAudioModel, ConditioningType, GenerationRequest, GenerationResult, ) logger = logging.getLogger(__name__) class MusicGenStyleAdapter(BaseAudioModel): """Adapter for Facebook's MusicGen Style model. Generates music conditioned on both text and a style reference audio. Extracts style features from the reference and applies them to new generations. """ VARIANTS = { "medium": { "hf_id": "facebook/musicgen-style", "vram_mb": 5000, "max_duration": 30, "channels": 1, }, } def __init__(self, variant: str = "medium"): """Initialize MusicGen Style adapter. Args: variant: Model variant (currently only 'medium' available) """ if variant not in self.VARIANTS: raise ValueError( f"Unknown MusicGen Style variant: {variant}. " f"Available: {list(self.VARIANTS.keys())}" ) self._variant = variant self._config = self.VARIANTS[variant] self._model = None self._device: Optional[torch.device] = None @property def model_id(self) -> str: return "musicgen-style" @property def variant(self) -> str: return self._variant @property def display_name(self) -> str: return f"MusicGen Style ({self._variant})" @property def description(self) -> str: return "Style-conditioned music generation from reference audio" @property def vram_estimate_mb(self) -> int: return self._config["vram_mb"] @property def max_duration(self) -> float: return self._config["max_duration"] @property def sample_rate(self) -> int: if self._model is not None: return self._model.sample_rate return 32000 @property def supports_conditioning(self) -> list[ConditioningType]: return [ConditioningType.TEXT, ConditioningType.STYLE] @property def is_loaded(self) -> bool: return self._model is not None @property def device(self) -> Optional[torch.device]: return self._device def load(self, device: str = "cuda") -> None: """Load the MusicGen Style model.""" if self._model is not None: logger.warning(f"MusicGen Style {self._variant} already loaded") return logger.info(f"Loading MusicGen Style {self._variant}...") try: from audiocraft.models import MusicGen self._device = torch.device(device) self._model = MusicGen.get_pretrained(self._config["hf_id"]) self._model.to(self._device) logger.info( f"MusicGen Style {self._variant} loaded successfully " f"(sample_rate={self._model.sample_rate})" ) except Exception as e: self._model = None self._device = None logger.error(f"Failed to load MusicGen Style {self._variant}: {e}") raise RuntimeError(f"Failed to load MusicGen Style: {e}") from e def unload(self) -> None: """Unload the model and free memory.""" if self._model is None: return logger.info(f"Unloading MusicGen Style {self._variant}...") del self._model self._model = None self._device = None gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() def _load_style_audio( self, style_input: Any, target_sr: int ) -> tuple[torch.Tensor, int]: """Load and prepare style reference audio. 
Args: style_input: File path, tensor, or numpy array target_sr: Target sample rate Returns: Tuple of (audio_tensor, sample_rate) """ if isinstance(style_input, str): # Load from file audio, sr = torchaudio.load(style_input) if sr != target_sr: audio = torchaudio.functional.resample(audio, sr, target_sr) return audio.to(self._device), target_sr elif isinstance(style_input, torch.Tensor): return style_input.to(self._device), target_sr else: # Assume numpy array return torch.tensor(style_input).to(self._device), target_sr def generate(self, request: GenerationRequest) -> GenerationResult: """Generate music conditioned on text and style reference. Args: request: Generation parameters including prompts and style conditioning Returns: GenerationResult with audio tensor and metadata Note: Style conditioning requires 'style' in request.conditioning with either: - File path to audio - Audio tensor - Numpy array """ self.validate_request(request) # Set random seed seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) # Get style conditioning parameters style_audio = request.conditioning.get("style") eval_q = request.conditioning.get("eval_q", 3) excerpt_length = request.conditioning.get("excerpt_length", 3.0) # Configure generation parameters self._model.set_generation_params( duration=request.duration, temperature=request.temperature, top_k=request.top_k, top_p=request.top_p, cfg_coef=request.cfg_coef, ) logger.info( f"Generating {len(request.prompts)} sample(s) with MusicGen Style " f"(duration={request.duration}s, style_conditioned={style_audio is not None})" ) with torch.inference_mode(): if style_audio is not None: # Load style reference style_tensor, style_sr = self._load_style_audio( style_audio, self.sample_rate ) # Ensure proper shape [batch, channels, samples] if style_tensor.dim() == 1: style_tensor = style_tensor.unsqueeze(0).unsqueeze(0) elif style_tensor.dim() == 2: style_tensor = style_tensor.unsqueeze(0) # Set style conditioner parameters if hasattr(self._model, 'set_style_conditioner_params'): self._model.set_style_conditioner_params( eval_q=eval_q, excerpt_length=excerpt_length, ) # Generate with style conditioning # Expand style to match number of prompts if needed if style_tensor.shape[0] == 1 and len(request.prompts) > 1: style_tensor = style_tensor.expand(len(request.prompts), -1, -1) audio = self._model.generate_with_chroma( descriptions=request.prompts, melody_wavs=style_tensor, melody_sample_rate=style_sr, ) else: # Generate without style (falls back to standard MusicGen behavior) logger.warning( "No style reference provided, generating without style conditioning" ) audio = self._model.generate(request.prompts) actual_duration = audio.shape[-1] / self.sample_rate logger.info( f"Generated {audio.shape[0]} sample(s), " f"duration={actual_duration:.2f}s" ) return GenerationResult( audio=audio.cpu(), sample_rate=self.sample_rate, duration=actual_duration, model_id=self.model_id, variant=self._variant, parameters={ "duration": request.duration, "temperature": request.temperature, "top_k": request.top_k, "top_p": request.top_p, "cfg_coef": request.cfg_coef, "prompts": request.prompts, "style_conditioned": style_audio is not None, "eval_q": eval_q, "excerpt_length": excerpt_length, }, seed=seed, ) def get_default_params(self) -> dict[str, Any]: """Get default generation parameters for MusicGen Style.""" return { "duration": 10.0, "temperature": 1.0, "top_k": 250, "top_p": 
0.0, "cfg_coef": 3.0, "eval_q": 3, "excerpt_length": 3.0, }
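

# --- Usage sketch (illustrative only, not part of the adapter) ---
# A minimal example of driving the adapter end to end, assuming a
# GenerationRequest constructed with the fields referenced above (prompts,
# duration, conditioning, seed); the exact constructor signature lives in
# src.core.base_model and may differ, and "reference.wav" / "out.wav" are
# placeholder paths.
#
#     adapter = MusicGenStyleAdapter("medium")
#     adapter.load(device="cuda")
#     request = GenerationRequest(
#         prompts=["warm lo-fi piano with vinyl crackle"],
#         duration=10.0,
#         conditioning={"style": "reference.wav", "eval_q": 3, "excerpt_length": 3.0},
#     )
#     result = adapter.generate(request)
#     torchaudio.save("out.wav", result.audio[0], result.sample_rate)
#     adapter.unload()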