diff --git a/main.py b/main.py index 7e86200..68fcd38 100644 --- a/main.py +++ b/main.py @@ -17,6 +17,7 @@ from src.services.batch_processor import BatchProcessor from src.services.project_service import ProjectService from src.storage.database import Database from src.ui.app import create_app +from src.models import register_all_adapters # Configure logging @@ -56,6 +57,10 @@ async def initialize_services(): idle_timeout_minutes=settings.idle_unload_minutes, ) + # Register all model adapters + logger.info("Registering model adapters...") + register_all_adapters(model_registry) + # Initialize services logger.info("Initializing services...") generation_service = GenerationService( diff --git a/src/models/musicgen/adapter.py b/src/models/musicgen/adapter.py index 4998613..e316a3b 100644 --- a/src/models/musicgen/adapter.py +++ b/src/models/musicgen/adapter.py @@ -3,8 +3,9 @@ import gc import logging import random -from typing import Any, Optional +from typing import Any, Callable, Optional +import numpy as np import torch from src.core.base_model import ( @@ -13,6 +14,7 @@ from src.core.base_model import ( GenerationRequest, GenerationResult, ) +from src.core.audio_utils import concatenate_audio logger = logging.getLogger(__name__) @@ -279,6 +281,152 @@ class MusicGenAdapter(BaseAudioModel): seed=seed, ) + def generate_long( + self, + request: GenerationRequest, + target_duration: float, + overlap_seconds: float = 15.0, + crossfade_ms: float = 500.0, + progress_callback: Optional[Callable[[float], None]] = None, + ) -> GenerationResult: + """Generate long-form audio via continuation. + + Uses AudioCraft's continuation API to generate audio longer than 30s + by generating segments and using the tail of each as conditioning + for the next segment. + + Args: + request: Base generation request (prompts, temperature, etc.) + target_duration: Total target duration in seconds + overlap_seconds: Seconds of overlap between segments for continuation + crossfade_ms: Crossfade duration between segments in milliseconds + progress_callback: Optional callback receiving progress (0.0-1.0) + + Returns: + GenerationResult with concatenated long-form audio + """ + if not self.is_loaded: + raise RuntimeError(f"Model {self.model_id}/{self.variant} is not loaded") + + # Segment duration (max 30s for MusicGen) + segment_duration = min(30.0, self.max_duration) + + # Set random seed for reproducibility + seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + # Configure generation parameters + self._model.set_generation_params( + duration=segment_duration, + temperature=request.temperature, + top_k=request.top_k, + top_p=request.top_p, + cfg_coef=request.cfg_coef, + ) + + logger.info( + f"Starting long-form generation: target={target_duration}s, " + f"segment={segment_duration}s, overlap={overlap_seconds}s" + ) + + segments = [] + generated_duration = 0.0 + + # First segment - standard generation + logger.info("Generating first segment...") + with torch.inference_mode(): + first_audio = self._model.generate(request.prompts) + + segments.append(first_audio[0].cpu().numpy()) # [channels, samples] + generated_duration = segment_duration + + if progress_callback: + progress_callback(generated_duration / target_duration) + + # Continuation segments + segment_num = 1 + while generated_duration < target_duration: + segment_num += 1 + logger.info( + f"Generating segment {segment_num} " + f"(progress: {generated_duration:.1f}/{target_duration:.1f}s)" + ) + + # Get tail of previous segment as conditioning + prev_audio = segments[-1] + overlap_samples = int(overlap_seconds * self.sample_rate) + + # Ensure we don't try to use more overlap than exists + overlap_samples = min(overlap_samples, prev_audio.shape[-1]) + overlap_audio = prev_audio[:, -overlap_samples:] + + # Convert back to tensor for continuation + overlap_tensor = torch.tensor(overlap_audio).unsqueeze(0).to(self._device) + + # Generate continuation + with torch.inference_mode(): + continuation = self._model.generate_continuation( + prompt=overlap_tensor, + prompt_sample_rate=self.sample_rate, + descriptions=request.prompts, + progress=False, + ) + + # continuation shape: [batch, channels, samples] + # Trim the overlap portion from the start + new_audio = continuation[0, :, overlap_samples:].cpu().numpy() + + if new_audio.shape[-1] > 0: + segments.append(new_audio) + generated_duration += (segment_duration - overlap_seconds) + else: + logger.warning("Empty continuation segment, stopping generation") + break + + if progress_callback: + progress_callback(min(1.0, generated_duration / target_duration)) + + # Concatenate all segments with crossfade + logger.info(f"Concatenating {len(segments)} segments with {crossfade_ms}ms crossfade...") + + final_audio = concatenate_audio( + segments, + sample_rate=self.sample_rate, + crossfade_ms=crossfade_ms, + ) + + # Convert back to tensor with batch dimension + final_tensor = torch.tensor(final_audio).unsqueeze(0) # [1, channels, samples] + actual_duration = final_tensor.shape[-1] / self.sample_rate + + logger.info( + f"Long-form generation complete: {actual_duration:.2f}s from {len(segments)} segments" + ) + + return GenerationResult( + audio=final_tensor, + sample_rate=self.sample_rate, + duration=actual_duration, + model_id=self.model_id, + variant=self._variant, + parameters={ + "duration": request.duration, + "target_duration": target_duration, + "temperature": request.temperature, + "top_k": request.top_k, + "top_p": request.top_p, + "cfg_coef": request.cfg_coef, + "prompts": request.prompts, + "long_form": True, + "segments": len(segments), + "overlap_seconds": overlap_seconds, + "crossfade_ms": crossfade_ms, + }, + seed=seed, + ) + def get_default_params(self) -> dict[str, Any]: """Get default generation parameters.""" return { diff --git a/src/services/batch_processor.py b/src/services/batch_processor.py index db37a10..9cec467 100644 --- a/src/services/batch_processor.py +++ b/src/services/batch_processor.py @@ -65,6 +65,7 @@ class GenerationJob: project_id: Optional[str] = None, preset_used: Optional[str] = None, tags: Optional[list[str]] = None, + long_form: bool = False, ) -> "GenerationJob": """Create a new generation job.""" return cls( @@ -79,6 +80,7 @@ class GenerationJob: "top_p": top_p, "cfg_coef": cfg_coef, "seed": seed, + "long_form": long_form, }, conditioning=conditioning or {}, project_id=project_id, @@ -352,6 +354,7 @@ class BatchProcessor: preset_used=job.preset_used, tags=job.tags, progress_callback=progress_callback, + long_form=job.parameters.get("long_form", False), ) job.status = JobStatus.COMPLETED diff --git a/src/services/generation_service.py b/src/services/generation_service.py index 228ab09..29e3183 100644 --- a/src/services/generation_service.py +++ b/src/services/generation_service.py @@ -71,6 +71,7 @@ class GenerationService: preset_used: Optional[str] = None, tags: Optional[list[str]] = None, progress_callback: Optional[Callable[[float, str], None]] = None, + long_form: bool = False, ) -> tuple[GenerationResult, Generation]: """Generate audio and save to database. @@ -89,6 +90,7 @@ class GenerationService: preset_used: Name of preset used (for metadata) tags: Optional tags for organization progress_callback: Optional callback for progress updates + long_form: Enable long-form generation for durations > 30s (uses continuation) Returns: Tuple of (GenerationResult, Generation database record) @@ -133,12 +135,31 @@ class GenerationService: if progress_callback: progress_callback(0.2, f"Loading {model_id}/{actual_variant}...") + # Determine if we should use long-form generation + use_long_form = long_form and duration > 30 and model_id == "musicgen" + @self.oom_handler.with_oom_recovery def do_generation() -> GenerationResult: with self.registry.get_model(model_id, actual_variant) as model: - if progress_callback: - progress_callback(0.4, "Generating audio...") - return model.generate(request) + if use_long_form and hasattr(model, 'generate_long'): + if progress_callback: + progress_callback(0.4, f"Generating long-form audio ({duration}s)...") + + # Create progress wrapper that maps to our progress range + def long_progress(p: float) -> None: + if progress_callback: + # Map 0-1 progress to 0.4-0.8 range + progress_callback(0.4 + p * 0.4, f"Generating segment... ({int(p*100)}%)") + + return model.generate_long( + request, + target_duration=duration, + progress_callback=long_progress, + ) + else: + if progress_callback: + progress_callback(0.4, "Generating audio...") + return model.generate(request) result = do_generation() @@ -159,6 +180,7 @@ class GenerationService: "top_k": top_k, "top_p": top_p, "cfg_coef": cfg_coef, + "long_form": use_long_form, }, project_id=project_id, preset_used=preset_used, diff --git a/src/ui/tabs/musicgen_tab.py b/src/ui/tabs/musicgen_tab.py index c6dd751..5205e2e 100644 --- a/src/ui/tabs/musicgen_tab.py +++ b/src/ui/tabs/musicgen_tab.py @@ -86,6 +86,13 @@ def create_musicgen_tab( # Parameters gr.Markdown("### Parameters") + # Long-form generation checkbox + long_form_checkbox = gr.Checkbox( + label="Long-form generation", + value=False, + info="Enable for tracks > 30s (uses continuation, takes longer)", + ) + duration_slider = gr.Slider( minimum=1, maximum=30, @@ -94,6 +101,13 @@ def create_musicgen_tab( label="Duration (seconds)", ) + # Info text for long-form mode + long_form_info = gr.Markdown( + value="*Long-form mode uses continuation to generate extended tracks. " + "Generation time increases significantly for longer durations.*", + visible=False, + ) + with gr.Accordion("Advanced Parameters", open=False): with gr.Row(): temperature_slider = gr.Slider( @@ -177,6 +191,25 @@ def create_musicgen_tab( outputs=[melody_section], ) + # Long-form checkbox - update duration slider max + def on_long_form_change(long_form: bool): + if long_form: + return ( + gr.update(maximum=300, label="Duration (seconds) - up to 5 min"), + gr.update(visible=True), + ) + else: + return ( + gr.update(maximum=30, label="Duration (seconds)"), + gr.update(visible=False), + ) + + long_form_checkbox.change( + fn=on_long_form_change, + inputs=[long_form_checkbox], + outputs=[duration_slider, long_form_info], + ) + # Prompt suggestions for btn, suggestion in suggestion_btns: btn.click( @@ -186,7 +219,7 @@ def create_musicgen_tab( # Generate async def do_generate( - prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody + prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody, long_form ): if not prompt: yield ( @@ -199,9 +232,14 @@ def create_musicgen_tab( ) return + # Generate status message + status_msg = "🔄 Generating..." + if long_form and duration > 30: + status_msg = f"🔄 Long-form generation ({duration}s, may take several minutes)..." + # Update status yield ( - gr.update(value="🔄 Generating..."), + gr.update(value=status_msg), gr.update(visible=True, value=0), gr.update(), gr.update(), @@ -225,6 +263,7 @@ def create_musicgen_tab( cfg_coef=cfg_coef, seed=int(seed) if seed else None, conditioning=conditioning, + long_form=long_form, ) yield ( @@ -258,6 +297,7 @@ def create_musicgen_tab( top_p_slider, seed_input, melody_input, + long_form_checkbox, ], outputs=[ output["status"], @@ -270,7 +310,7 @@ def create_musicgen_tab( ) # Add to queue - def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody): + def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody, long_form): if not prompt: return "Please enter a prompt" @@ -289,6 +329,7 @@ def create_musicgen_tab( cfg_coef=cfg_coef, seed=int(seed) if seed else None, conditioning=conditioning, + long_form=long_form, ) return f"✅ Added to queue: {job.id}" @@ -305,6 +346,7 @@ def create_musicgen_tab( top_p_slider, seed_input, melody_input, + long_form_checkbox, ], outputs=[output["status"]], ) @@ -314,6 +356,7 @@ def create_musicgen_tab( "variant": variant_dropdown, "prompt": prompt_input, "melody": melody_input, + "long_form": long_form_checkbox, "duration": duration_slider, "temperature": temperature_slider, "cfg_coef": cfg_slider,