Fix model registration + add long-form generation support

- Fix critical bug: register_all_adapters() now called in main.py
- Add generate_long() method to MusicGen adapter for continuation-based
  extended tracks (up to 5 minutes)
- Add long-form checkbox in UI that unlocks duration slider to 300s
- Update GenerationService to route to generate_long when duration > 30s
- Update BatchProcessor to support long_form parameter

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-11-27 13:22:06 +01:00
parent e53314023c
commit 503b3ce473
5 changed files with 228 additions and 7 deletions

View File

@@ -17,6 +17,7 @@ from src.services.batch_processor import BatchProcessor
from src.services.project_service import ProjectService from src.services.project_service import ProjectService
from src.storage.database import Database from src.storage.database import Database
from src.ui.app import create_app from src.ui.app import create_app
from src.models import register_all_adapters
# Configure logging # Configure logging
@@ -56,6 +57,10 @@ async def initialize_services():
idle_timeout_minutes=settings.idle_unload_minutes, idle_timeout_minutes=settings.idle_unload_minutes,
) )
# Register all model adapters
logger.info("Registering model adapters...")
register_all_adapters(model_registry)
# Initialize services # Initialize services
logger.info("Initializing services...") logger.info("Initializing services...")
generation_service = GenerationService( generation_service = GenerationService(

View File

@@ -3,8 +3,9 @@
import gc import gc
import logging import logging
import random import random
from typing import Any, Optional from typing import Any, Callable, Optional
import numpy as np
import torch import torch
from src.core.base_model import ( from src.core.base_model import (
@@ -13,6 +14,7 @@ from src.core.base_model import (
GenerationRequest, GenerationRequest,
GenerationResult, GenerationResult,
) )
from src.core.audio_utils import concatenate_audio
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -279,6 +281,152 @@ class MusicGenAdapter(BaseAudioModel):
seed=seed, seed=seed,
) )
def generate_long(
self,
request: GenerationRequest,
target_duration: float,
overlap_seconds: float = 15.0,
crossfade_ms: float = 500.0,
progress_callback: Optional[Callable[[float], None]] = None,
) -> GenerationResult:
"""Generate long-form audio via continuation.
Uses AudioCraft's continuation API to generate audio longer than 30s
by generating segments and using the tail of each as conditioning
for the next segment.
Args:
request: Base generation request (prompts, temperature, etc.)
target_duration: Total target duration in seconds
overlap_seconds: Seconds of overlap between segments for continuation
crossfade_ms: Crossfade duration between segments in milliseconds
progress_callback: Optional callback receiving progress (0.0-1.0)
Returns:
GenerationResult with concatenated long-form audio
"""
if not self.is_loaded:
raise RuntimeError(f"Model {self.model_id}/{self.variant} is not loaded")
# Segment duration (max 30s for MusicGen)
segment_duration = min(30.0, self.max_duration)
# Set random seed for reproducibility
seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
# Configure generation parameters
self._model.set_generation_params(
duration=segment_duration,
temperature=request.temperature,
top_k=request.top_k,
top_p=request.top_p,
cfg_coef=request.cfg_coef,
)
logger.info(
f"Starting long-form generation: target={target_duration}s, "
f"segment={segment_duration}s, overlap={overlap_seconds}s"
)
segments = []
generated_duration = 0.0
# First segment - standard generation
logger.info("Generating first segment...")
with torch.inference_mode():
first_audio = self._model.generate(request.prompts)
segments.append(first_audio[0].cpu().numpy()) # [channels, samples]
generated_duration = segment_duration
if progress_callback:
progress_callback(generated_duration / target_duration)
# Continuation segments
segment_num = 1
while generated_duration < target_duration:
segment_num += 1
logger.info(
f"Generating segment {segment_num} "
f"(progress: {generated_duration:.1f}/{target_duration:.1f}s)"
)
# Get tail of previous segment as conditioning
prev_audio = segments[-1]
overlap_samples = int(overlap_seconds * self.sample_rate)
# Ensure we don't try to use more overlap than exists
overlap_samples = min(overlap_samples, prev_audio.shape[-1])
overlap_audio = prev_audio[:, -overlap_samples:]
# Convert back to tensor for continuation
overlap_tensor = torch.tensor(overlap_audio).unsqueeze(0).to(self._device)
# Generate continuation
with torch.inference_mode():
continuation = self._model.generate_continuation(
prompt=overlap_tensor,
prompt_sample_rate=self.sample_rate,
descriptions=request.prompts,
progress=False,
)
# continuation shape: [batch, channels, samples]
# Trim the overlap portion from the start
new_audio = continuation[0, :, overlap_samples:].cpu().numpy()
if new_audio.shape[-1] > 0:
segments.append(new_audio)
generated_duration += (segment_duration - overlap_seconds)
else:
logger.warning("Empty continuation segment, stopping generation")
break
if progress_callback:
progress_callback(min(1.0, generated_duration / target_duration))
# Concatenate all segments with crossfade
logger.info(f"Concatenating {len(segments)} segments with {crossfade_ms}ms crossfade...")
final_audio = concatenate_audio(
segments,
sample_rate=self.sample_rate,
crossfade_ms=crossfade_ms,
)
# Convert back to tensor with batch dimension
final_tensor = torch.tensor(final_audio).unsqueeze(0) # [1, channels, samples]
actual_duration = final_tensor.shape[-1] / self.sample_rate
logger.info(
f"Long-form generation complete: {actual_duration:.2f}s from {len(segments)} segments"
)
return GenerationResult(
audio=final_tensor,
sample_rate=self.sample_rate,
duration=actual_duration,
model_id=self.model_id,
variant=self._variant,
parameters={
"duration": request.duration,
"target_duration": target_duration,
"temperature": request.temperature,
"top_k": request.top_k,
"top_p": request.top_p,
"cfg_coef": request.cfg_coef,
"prompts": request.prompts,
"long_form": True,
"segments": len(segments),
"overlap_seconds": overlap_seconds,
"crossfade_ms": crossfade_ms,
},
seed=seed,
)
def get_default_params(self) -> dict[str, Any]: def get_default_params(self) -> dict[str, Any]:
"""Get default generation parameters.""" """Get default generation parameters."""
return { return {

View File

@@ -65,6 +65,7 @@ class GenerationJob:
project_id: Optional[str] = None, project_id: Optional[str] = None,
preset_used: Optional[str] = None, preset_used: Optional[str] = None,
tags: Optional[list[str]] = None, tags: Optional[list[str]] = None,
long_form: bool = False,
) -> "GenerationJob": ) -> "GenerationJob":
"""Create a new generation job.""" """Create a new generation job."""
return cls( return cls(
@@ -79,6 +80,7 @@ class GenerationJob:
"top_p": top_p, "top_p": top_p,
"cfg_coef": cfg_coef, "cfg_coef": cfg_coef,
"seed": seed, "seed": seed,
"long_form": long_form,
}, },
conditioning=conditioning or {}, conditioning=conditioning or {},
project_id=project_id, project_id=project_id,
@@ -352,6 +354,7 @@ class BatchProcessor:
preset_used=job.preset_used, preset_used=job.preset_used,
tags=job.tags, tags=job.tags,
progress_callback=progress_callback, progress_callback=progress_callback,
long_form=job.parameters.get("long_form", False),
) )
job.status = JobStatus.COMPLETED job.status = JobStatus.COMPLETED

View File

@@ -71,6 +71,7 @@ class GenerationService:
preset_used: Optional[str] = None, preset_used: Optional[str] = None,
tags: Optional[list[str]] = None, tags: Optional[list[str]] = None,
progress_callback: Optional[Callable[[float, str], None]] = None, progress_callback: Optional[Callable[[float, str], None]] = None,
long_form: bool = False,
) -> tuple[GenerationResult, Generation]: ) -> tuple[GenerationResult, Generation]:
"""Generate audio and save to database. """Generate audio and save to database.
@@ -89,6 +90,7 @@ class GenerationService:
preset_used: Name of preset used (for metadata) preset_used: Name of preset used (for metadata)
tags: Optional tags for organization tags: Optional tags for organization
progress_callback: Optional callback for progress updates progress_callback: Optional callback for progress updates
long_form: Enable long-form generation for durations > 30s (uses continuation)
Returns: Returns:
Tuple of (GenerationResult, Generation database record) Tuple of (GenerationResult, Generation database record)
@@ -133,12 +135,31 @@ class GenerationService:
if progress_callback: if progress_callback:
progress_callback(0.2, f"Loading {model_id}/{actual_variant}...") progress_callback(0.2, f"Loading {model_id}/{actual_variant}...")
# Determine if we should use long-form generation
use_long_form = long_form and duration > 30 and model_id == "musicgen"
@self.oom_handler.with_oom_recovery @self.oom_handler.with_oom_recovery
def do_generation() -> GenerationResult: def do_generation() -> GenerationResult:
with self.registry.get_model(model_id, actual_variant) as model: with self.registry.get_model(model_id, actual_variant) as model:
if progress_callback: if use_long_form and hasattr(model, 'generate_long'):
progress_callback(0.4, "Generating audio...") if progress_callback:
return model.generate(request) progress_callback(0.4, f"Generating long-form audio ({duration}s)...")
# Create progress wrapper that maps to our progress range
def long_progress(p: float) -> None:
if progress_callback:
# Map 0-1 progress to 0.4-0.8 range
progress_callback(0.4 + p * 0.4, f"Generating segment... ({int(p*100)}%)")
return model.generate_long(
request,
target_duration=duration,
progress_callback=long_progress,
)
else:
if progress_callback:
progress_callback(0.4, "Generating audio...")
return model.generate(request)
result = do_generation() result = do_generation()
@@ -159,6 +180,7 @@ class GenerationService:
"top_k": top_k, "top_k": top_k,
"top_p": top_p, "top_p": top_p,
"cfg_coef": cfg_coef, "cfg_coef": cfg_coef,
"long_form": use_long_form,
}, },
project_id=project_id, project_id=project_id,
preset_used=preset_used, preset_used=preset_used,

View File

@@ -86,6 +86,13 @@ def create_musicgen_tab(
# Parameters # Parameters
gr.Markdown("### Parameters") gr.Markdown("### Parameters")
# Long-form generation checkbox
long_form_checkbox = gr.Checkbox(
label="Long-form generation",
value=False,
info="Enable for tracks > 30s (uses continuation, takes longer)",
)
duration_slider = gr.Slider( duration_slider = gr.Slider(
minimum=1, minimum=1,
maximum=30, maximum=30,
@@ -94,6 +101,13 @@ def create_musicgen_tab(
label="Duration (seconds)", label="Duration (seconds)",
) )
# Info text for long-form mode
long_form_info = gr.Markdown(
value="*Long-form mode uses continuation to generate extended tracks. "
"Generation time increases significantly for longer durations.*",
visible=False,
)
with gr.Accordion("Advanced Parameters", open=False): with gr.Accordion("Advanced Parameters", open=False):
with gr.Row(): with gr.Row():
temperature_slider = gr.Slider( temperature_slider = gr.Slider(
@@ -177,6 +191,25 @@ def create_musicgen_tab(
outputs=[melody_section], outputs=[melody_section],
) )
# Long-form checkbox - update duration slider max
def on_long_form_change(long_form: bool):
if long_form:
return (
gr.update(maximum=300, label="Duration (seconds) - up to 5 min"),
gr.update(visible=True),
)
else:
return (
gr.update(maximum=30, label="Duration (seconds)"),
gr.update(visible=False),
)
long_form_checkbox.change(
fn=on_long_form_change,
inputs=[long_form_checkbox],
outputs=[duration_slider, long_form_info],
)
# Prompt suggestions # Prompt suggestions
for btn, suggestion in suggestion_btns: for btn, suggestion in suggestion_btns:
btn.click( btn.click(
@@ -186,7 +219,7 @@ def create_musicgen_tab(
# Generate # Generate
async def do_generate( async def do_generate(
prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody, long_form
): ):
if not prompt: if not prompt:
yield ( yield (
@@ -199,9 +232,14 @@ def create_musicgen_tab(
) )
return return
# Generate status message
status_msg = "🔄 Generating..."
if long_form and duration > 30:
status_msg = f"🔄 Long-form generation ({duration}s, may take several minutes)..."
# Update status # Update status
yield ( yield (
gr.update(value="🔄 Generating..."), gr.update(value=status_msg),
gr.update(visible=True, value=0), gr.update(visible=True, value=0),
gr.update(), gr.update(),
gr.update(), gr.update(),
@@ -225,6 +263,7 @@ def create_musicgen_tab(
cfg_coef=cfg_coef, cfg_coef=cfg_coef,
seed=int(seed) if seed else None, seed=int(seed) if seed else None,
conditioning=conditioning, conditioning=conditioning,
long_form=long_form,
) )
yield ( yield (
@@ -258,6 +297,7 @@ def create_musicgen_tab(
top_p_slider, top_p_slider,
seed_input, seed_input,
melody_input, melody_input,
long_form_checkbox,
], ],
outputs=[ outputs=[
output["status"], output["status"],
@@ -270,7 +310,7 @@ def create_musicgen_tab(
) )
# Add to queue # Add to queue
def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody): def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody, long_form):
if not prompt: if not prompt:
return "Please enter a prompt" return "Please enter a prompt"
@@ -289,6 +329,7 @@ def create_musicgen_tab(
cfg_coef=cfg_coef, cfg_coef=cfg_coef,
seed=int(seed) if seed else None, seed=int(seed) if seed else None,
conditioning=conditioning, conditioning=conditioning,
long_form=long_form,
) )
return f"✅ Added to queue: {job.id}" return f"✅ Added to queue: {job.id}"
@@ -305,6 +346,7 @@ def create_musicgen_tab(
top_p_slider, top_p_slider,
seed_input, seed_input,
melody_input, melody_input,
long_form_checkbox,
], ],
outputs=[output["status"]], outputs=[output["status"]],
) )
@@ -314,6 +356,7 @@ def create_musicgen_tab(
"variant": variant_dropdown, "variant": variant_dropdown,
"prompt": prompt_input, "prompt": prompt_input,
"melody": melody_input, "melody": melody_input,
"long_form": long_form_checkbox,
"duration": duration_slider, "duration": duration_slider,
"temperature": temperature_slider, "temperature": temperature_slider,
"cfg_coef": cfg_slider, "cfg_coef": cfg_slider,