Fix model registration + add long-form generation support

- Fix critical bug: register_all_adapters() now called in main.py
- Add generate_long() method to MusicGen adapter for continuation-based
  extended tracks (up to 5 minutes)
- Add long-form checkbox in UI that raises the duration slider max to 300s
- Update GenerationService to route to generate_long when long-form mode is
  enabled and duration > 30s (MusicGen only)
- Update BatchProcessor to support long_form parameter

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-27 13:22:06 +01:00
parent e53314023c
commit 503b3ce473
5 changed files with 228 additions and 7 deletions

View File

@@ -17,6 +17,7 @@ from src.services.batch_processor import BatchProcessor
from src.services.project_service import ProjectService
from src.storage.database import Database
from src.ui.app import create_app
from src.models import register_all_adapters
# Configure logging
@@ -56,6 +57,10 @@ async def initialize_services():
idle_timeout_minutes=settings.idle_unload_minutes,
)
# Register all model adapters
logger.info("Registering model adapters...")
register_all_adapters(model_registry)
# Initialize services
logger.info("Initializing services...")
generation_service = GenerationService(
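
Note: this hunk only adds the call site; register_all_adapters itself (exported from src.models, presumably src/models/__init__.py) is not part of this diff. A minimal sketch of the pattern it likely follows, with the adapter module path and the registry.register signature assumed:

```python
# Hypothetical sketch of register_all_adapters (body not shown in this diff).
# The bug was simply that main.py never called it, so the registry held no
# adapters when generation requests arrived.
from src.models.musicgen import MusicGenAdapter  # assumed module path

def register_all_adapters(registry) -> None:
    """Attach every bundled adapter to the shared model registry."""
    registry.register("musicgen", MusicGenAdapter)  # assumed registry API
```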

View File

@@ -3,8 +3,9 @@
import gc
import logging
import random
-from typing import Any, Optional
from typing import Any, Callable, Optional
import numpy as np
import torch
from src.core.base_model import (
@@ -13,6 +14,7 @@ from src.core.base_model import (
GenerationRequest,
GenerationResult,
)
from src.core.audio_utils import concatenate_audio
logger = logging.getLogger(__name__)
@@ -279,6 +281,152 @@ class MusicGenAdapter(BaseAudioModel):
seed=seed,
)
def generate_long(
self,
request: GenerationRequest,
target_duration: float,
overlap_seconds: float = 15.0,
crossfade_ms: float = 500.0,
progress_callback: Optional[Callable[[float], None]] = None,
) -> GenerationResult:
"""Generate long-form audio via continuation.
Uses AudioCraft's continuation API to generate audio longer than 30s
by generating segments and using the tail of each as conditioning
for the next segment.
Args:
request: Base generation request (prompts, temperature, etc.)
target_duration: Total target duration in seconds
overlap_seconds: Seconds of overlap between segments for continuation
crossfade_ms: Crossfade duration between segments in milliseconds
progress_callback: Optional callback receiving progress (0.0-1.0)
Returns:
GenerationResult with concatenated long-form audio
"""
if not self.is_loaded:
raise RuntimeError(f"Model {self.model_id}/{self.variant} is not loaded")
# Segment duration (max 30s for MusicGen)
segment_duration = min(30.0, self.max_duration)
# Set random seed for reproducibility
seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
# Configure generation parameters
self._model.set_generation_params(
duration=segment_duration,
temperature=request.temperature,
top_k=request.top_k,
top_p=request.top_p,
cfg_coef=request.cfg_coef,
)
logger.info(
f"Starting long-form generation: target={target_duration}s, "
f"segment={segment_duration}s, overlap={overlap_seconds}s"
)
segments = []
generated_duration = 0.0
# First segment - standard generation
logger.info("Generating first segment...")
with torch.inference_mode():
first_audio = self._model.generate(request.prompts)
segments.append(first_audio[0].cpu().numpy()) # [channels, samples]
generated_duration = segment_duration
if progress_callback:
progress_callback(generated_duration / target_duration)
# Continuation segments
segment_num = 1
while generated_duration < target_duration:
segment_num += 1
logger.info(
f"Generating segment {segment_num} "
f"(progress: {generated_duration:.1f}/{target_duration:.1f}s)"
)
# Get tail of previous segment as conditioning
prev_audio = segments[-1]
overlap_samples = int(overlap_seconds * self.sample_rate)
# Ensure we don't try to use more overlap than exists
overlap_samples = min(overlap_samples, prev_audio.shape[-1])
overlap_audio = prev_audio[:, -overlap_samples:]
# Convert back to tensor for continuation
overlap_tensor = torch.tensor(overlap_audio).unsqueeze(0).to(self._device)
# Generate continuation
with torch.inference_mode():
continuation = self._model.generate_continuation(
prompt=overlap_tensor,
prompt_sample_rate=self.sample_rate,
descriptions=request.prompts,
progress=False,
)
# continuation shape: [batch, channels, samples]
# Trim the overlap portion from the start
new_audio = continuation[0, :, overlap_samples:].cpu().numpy()
if new_audio.shape[-1] > 0:
segments.append(new_audio)
generated_duration += (segment_duration - overlap_seconds)
else:
logger.warning("Empty continuation segment, stopping generation")
break
if progress_callback:
progress_callback(min(1.0, generated_duration / target_duration))
# Concatenate all segments with crossfade
logger.info(f"Concatenating {len(segments)} segments with {crossfade_ms}ms crossfade...")
final_audio = concatenate_audio(
segments,
sample_rate=self.sample_rate,
crossfade_ms=crossfade_ms,
)
# Convert back to tensor with batch dimension
final_tensor = torch.tensor(final_audio).unsqueeze(0) # [1, channels, samples]
actual_duration = final_tensor.shape[-1] / self.sample_rate
logger.info(
f"Long-form generation complete: {actual_duration:.2f}s from {len(segments)} segments"
)
return GenerationResult(
audio=final_tensor,
sample_rate=self.sample_rate,
duration=actual_duration,
model_id=self.model_id,
variant=self._variant,
parameters={
"duration": request.duration,
"target_duration": target_duration,
"temperature": request.temperature,
"top_k": request.top_k,
"top_p": request.top_p,
"cfg_coef": request.cfg_coef,
"prompts": request.prompts,
"long_form": True,
"segments": len(segments),
"overlap_seconds": overlap_seconds,
"crossfade_ms": crossfade_ms,
},
seed=seed,
)
def get_default_params(self) -> dict[str, Any]:
"""Get default generation parameters."""
return {
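
The docstring above describes the continuation mechanism in prose. For readers unfamiliar with AudioCraft, here is a standalone sketch of one continuation step plus a crossfade join. MusicGen.get_pretrained, generate, and generate_continuation are the public AudioCraft APIs the adapter wraps; crossfade_concat is a hypothetical stand-in for the repo's concatenate_audio helper, whose implementation is not shown in this diff:

```python
# Standalone sketch of the pattern generate_long() wraps, assuming the stock
# facebook/musicgen-small checkpoint.
import numpy as np
from audiocraft.models import MusicGen

model = MusicGen.get_pretrained("facebook/musicgen-small")
model.set_generation_params(duration=30)
desc = ["warm analog synthwave"]

wav = model.generate(desc)                 # [batch, channels, samples]
seg = wav[0].cpu().numpy()

overlap = int(15 * model.sample_rate)      # 15 s conditioning tail
tail = wav[..., -overlap:]                 # keep batch dim for the API
cont = model.generate_continuation(
    tail, model.sample_rate, descriptions=desc, progress=False
)
new = cont[0, :, overlap:].cpu().numpy()   # drop the re-generated overlap

def crossfade_concat(a, b, sr, ms=500.0):
    """Linear crossfade from the tail of `a` into the head of `b`
    (both [channels, samples]); stand-in for concatenate_audio."""
    n = min(int(sr * ms / 1000), a.shape[-1], b.shape[-1])
    fade = np.linspace(0.0, 1.0, n)
    blend = a[:, -n:] * (1.0 - fade) + b[:, :n] * fade
    return np.concatenate([a[:, :-n], blend, b[:, n:]], axis=-1)

track = crossfade_concat(seg, new, model.sample_rate)  # ~45 s total
```

Looping the tail-extract / continue / trim steps and joining each new segment this way is what lets generate_long reach the 300 s ceiling.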

View File

@@ -65,6 +65,7 @@ class GenerationJob:
project_id: Optional[str] = None,
preset_used: Optional[str] = None,
tags: Optional[list[str]] = None,
long_form: bool = False,
) -> "GenerationJob":
"""Create a new generation job."""
return cls(
@@ -79,6 +80,7 @@ class GenerationJob:
"top_p": top_p,
"cfg_coef": cfg_coef,
"seed": seed,
"long_form": long_form,
},
conditioning=conditioning or {},
project_id=project_id,
@@ -352,6 +354,7 @@ class BatchProcessor:
preset_used=job.preset_used,
tags=job.tags,
progress_callback=progress_callback,
long_form=job.parameters.get("long_form", False),
)
job.status = JobStatus.COMPLETED
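
A hedged sketch of enqueueing a long-form job; only the tail of GenerationJob.create's signature is visible in this hunk, so the leading parameters (model_id, prompt, duration) are assumed:

```python
# Hedged sketch: the flag is persisted in job.parameters["long_form"] and
# read back by BatchProcessor at dispatch time (see the hunk above).
job = GenerationJob.create(
    model_id="musicgen",       # assumed leading parameters
    prompt="cinematic orchestral build",
    duration=180,
    long_form=True,
)
```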

View File

@@ -71,6 +71,7 @@ class GenerationService:
preset_used: Optional[str] = None,
tags: Optional[list[str]] = None,
progress_callback: Optional[Callable[[float, str], None]] = None,
long_form: bool = False,
) -> tuple[GenerationResult, Generation]:
"""Generate audio and save to database.
@@ -89,6 +90,7 @@ class GenerationService:
preset_used: Name of preset used (for metadata)
tags: Optional tags for organization
progress_callback: Optional callback for progress updates
long_form: Enable long-form generation for durations > 30s (uses continuation)
Returns:
Tuple of (GenerationResult, Generation database record)
@@ -133,12 +135,31 @@ class GenerationService:
if progress_callback:
progress_callback(0.2, f"Loading {model_id}/{actual_variant}...")
# Determine if we should use long-form generation
use_long_form = long_form and duration > 30 and model_id == "musicgen"
@self.oom_handler.with_oom_recovery
def do_generation() -> GenerationResult:
with self.registry.get_model(model_id, actual_variant) as model:
-if progress_callback:
-progress_callback(0.4, "Generating audio...")
-return model.generate(request)
if use_long_form and hasattr(model, 'generate_long'):
if progress_callback:
progress_callback(0.4, f"Generating long-form audio ({duration}s)...")
# Create progress wrapper that maps to our progress range
def long_progress(p: float) -> None:
if progress_callback:
# Map 0-1 progress to 0.4-0.8 range
progress_callback(0.4 + p * 0.4, f"Generating segment... ({int(p*100)}%)")
return model.generate_long(
request,
target_duration=duration,
progress_callback=long_progress,
)
else:
if progress_callback:
progress_callback(0.4, "Generating audio...")
return model.generate(request)
result = do_generation()
@@ -159,6 +180,7 @@ class GenerationService:
"top_k": top_k,
"top_p": top_p,
"cfg_coef": cfg_coef,
"long_form": use_long_form,
},
project_id=project_id,
preset_used=preset_used,
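
From the caller's side, the new routing looks roughly like this; parameter names other than long_form and progress_callback are inferred from the docstring above, and whether generate() must be awaited is not visible in this diff:

```python
# Hedged usage sketch: a 120 s request that takes the long-form branch.
# Segment progress from generate_long() is remapped into the 0.4-0.8 window.
def on_progress(frac: float, msg: str) -> None:
    print(f"[{frac:4.0%}] {msg}")

result, generation = generation_service.generate(  # add `await` if async
    model_id="musicgen",
    prompt="ambient drone with slow swells",       # assumed parameter name
    duration=120,
    long_form=True,
    progress_callback=on_progress,
)
print(f"{result.duration:.1f}s at {result.sample_rate} Hz")
```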

View File

@@ -86,6 +86,13 @@ def create_musicgen_tab(
# Parameters
gr.Markdown("### Parameters")
# Long-form generation checkbox
long_form_checkbox = gr.Checkbox(
label="Long-form generation",
value=False,
info="Enable for tracks > 30s (uses continuation, takes longer)",
)
duration_slider = gr.Slider(
minimum=1,
maximum=30,
@@ -94,6 +101,13 @@ def create_musicgen_tab(
label="Duration (seconds)",
)
# Info text for long-form mode
long_form_info = gr.Markdown(
value="*Long-form mode uses continuation to generate extended tracks. "
"Generation time increases significantly for longer durations.*",
visible=False,
)
with gr.Accordion("Advanced Parameters", open=False):
with gr.Row():
temperature_slider = gr.Slider(
@@ -177,6 +191,25 @@ def create_musicgen_tab(
outputs=[melody_section],
)
# Long-form checkbox - update duration slider max
def on_long_form_change(long_form: bool):
if long_form:
return (
gr.update(maximum=300, label="Duration (seconds) - up to 5 min"),
gr.update(visible=True),
)
else:
return (
gr.update(maximum=30, label="Duration (seconds)"),
gr.update(visible=False),
)
long_form_checkbox.change(
fn=on_long_form_change,
inputs=[long_form_checkbox],
outputs=[duration_slider, long_form_info],
)
# Prompt suggestions
for btn, suggestion in suggestion_btns:
btn.click(
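
The checkbox-to-slider wiring above is a reusable Gradio pattern. A minimal self-contained demo (hypothetical standalone app, not repo code):

```python
# Checkbox toggles the slider's maximum via gr.update, which changes
# component properties without recreating the component.
import gradio as gr

with gr.Blocks() as demo:
    long_form = gr.Checkbox(label="Long-form generation", value=False)
    duration = gr.Slider(minimum=1, maximum=30, value=10,
                         label="Duration (seconds)")

    def toggle(enabled: bool):
        return gr.update(maximum=300 if enabled else 30)

    long_form.change(fn=toggle, inputs=[long_form], outputs=[duration])

demo.launch()
```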
@@ -186,7 +219,7 @@ def create_musicgen_tab(
# Generate
async def do_generate(
-prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody
prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody, long_form
):
if not prompt:
yield (
@@ -199,9 +232,14 @@ def create_musicgen_tab(
)
return
# Generate status message
status_msg = "🔄 Generating..."
if long_form and duration > 30:
status_msg = f"🔄 Long-form generation ({duration}s, may take several minutes)..."
# Update status
yield (
gr.update(value="🔄 Generating..."),
gr.update(value=status_msg),
gr.update(visible=True, value=0),
gr.update(),
gr.update(),
@@ -225,6 +263,7 @@ def create_musicgen_tab(
cfg_coef=cfg_coef,
seed=int(seed) if seed else None,
conditioning=conditioning,
long_form=long_form,
)
yield (
@@ -258,6 +297,7 @@ def create_musicgen_tab(
top_p_slider,
seed_input,
melody_input,
long_form_checkbox,
],
outputs=[
output["status"],
@@ -270,7 +310,7 @@ def create_musicgen_tab(
)
# Add to queue
-def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody):
def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody, long_form):
if not prompt:
return "Please enter a prompt"
@@ -289,6 +329,7 @@ def create_musicgen_tab(
cfg_coef=cfg_coef,
seed=int(seed) if seed else None,
conditioning=conditioning,
long_form=long_form,
)
return f"✅ Added to queue: {job.id}"
@@ -305,6 +346,7 @@ def create_musicgen_tab(
top_p_slider,
seed_input,
melody_input,
long_form_checkbox,
],
outputs=[output["status"]],
)
@@ -314,6 +356,7 @@ def create_musicgen_tab(
"variant": variant_dropdown,
"prompt": prompt_input,
"melody": melody_input,
"long_form": long_form_checkbox,
"duration": duration_slider,
"temperature": temperature_slider,
"cfg_coef": cfg_slider,