Fix model registration + add long-form generation support
- Fix critical bug: register_all_adapters() is now called in main.py
- Add generate_long() method to the MusicGen adapter for continuation-based extended tracks (up to 5 minutes)
- Add a long-form checkbox in the UI that unlocks the duration slider up to 300s
- Update GenerationService to route to generate_long when duration > 30s
- Update BatchProcessor to support the long_form parameter

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
5
main.py
5
main.py
@@ -17,6 +17,7 @@ from src.services.batch_processor import BatchProcessor
|
||||
from src.services.project_service import ProjectService
|
||||
from src.storage.database import Database
|
||||
from src.ui.app import create_app
|
||||
from src.models import register_all_adapters
|
||||
|
||||
|
||||
# Configure logging
|
||||
@@ -56,6 +57,10 @@ async def initialize_services():
|
||||
idle_timeout_minutes=settings.idle_unload_minutes,
|
||||
)
|
||||
|
||||
# Register all model adapters
|
||||
logger.info("Registering model adapters...")
|
||||
register_all_adapters(model_registry)
|
||||
|
||||
# Initialize services
|
||||
logger.info("Initializing services...")
|
||||
generation_service = GenerationService(
|
||||
|
||||
@@ -3,8 +3,9 @@
|
||||
import gc
|
||||
import logging
|
||||
import random
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from src.core.base_model import (
|
||||
@@ -13,6 +14,7 @@ from src.core.base_model import (
|
||||
GenerationRequest,
|
||||
GenerationResult,
|
||||
)
|
||||
from src.core.audio_utils import concatenate_audio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -279,6 +281,152 @@ class MusicGenAdapter(BaseAudioModel):
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
def generate_long(
    self,
    request: GenerationRequest,
    target_duration: float,
    overlap_seconds: float = 15.0,
    crossfade_ms: float = 500.0,
    progress_callback: Optional[Callable[[float], None]] = None,
) -> GenerationResult:
    """Generate long-form audio via continuation.

    Generates a first segment from the text prompts, then repeatedly feeds
    the tail of the most recent segment to the model's
    ``generate_continuation`` call and keeps only the newly generated part.
    All pieces are joined with ``concatenate_audio`` and the result is
    trimmed so it never exceeds ``target_duration``.

    Args:
        request: Base generation request (prompts, temperature, etc.)
        target_duration: Total target duration in seconds (must be > 0)
        overlap_seconds: Seconds of the previous segment used as the
            continuation prompt. Must satisfy
            ``0 < overlap_seconds < segment duration`` whenever more than
            one segment is required.
        crossfade_ms: Crossfade duration between segments in milliseconds
        progress_callback: Optional callback receiving progress (0.0-1.0)

    Returns:
        GenerationResult with concatenated long-form audio

    Raises:
        RuntimeError: If the model is not loaded.
        ValueError: If ``target_duration`` is not positive, or if
            ``overlap_seconds`` is out of range while continuation
            segments are needed.
    """
    if not self.is_loaded:
        raise RuntimeError(f"Model {self.model_id}/{self.variant} is not loaded")

    # A non-positive target would divide by zero in the progress maths.
    if target_duration <= 0:
        raise ValueError(f"target_duration must be positive, got {target_duration}")

    # Segment duration (max 30s for MusicGen)
    segment_duration = min(30.0, self.max_duration)

    # The overlap only matters when at least one continuation is needed.
    # overlap == 0 would make the ``[-0:]`` slice below select the WHOLE
    # previous segment; overlap >= segment leaves no room for new audio.
    if target_duration > segment_duration and not (
        0 < overlap_seconds < segment_duration
    ):
        raise ValueError(
            f"overlap_seconds must be between 0 and {segment_duration} "
            f"(exclusive), got {overlap_seconds}"
        )

    # Set random seed for reproducibility
    seed = request.seed if request.seed is not None else random.randint(0, 2**32 - 1)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Configure generation parameters. When the target fits in a single
    # segment, only generate as much audio as was actually requested.
    self._model.set_generation_params(
        duration=min(segment_duration, target_duration),
        temperature=request.temperature,
        top_k=request.top_k,
        top_p=request.top_p,
        cfg_coef=request.cfg_coef,
    )

    logger.info(
        f"Starting long-form generation: target={target_duration}s, "
        f"segment={segment_duration}s, overlap={overlap_seconds}s"
    )

    segments = []

    # First segment - standard generation
    logger.info("Generating first segment...")
    with torch.inference_mode():
        first_audio = self._model.generate(request.prompts)

    segments.append(first_audio[0].cpu().numpy())  # [channels, samples]
    generated_duration = segment_duration

    if progress_callback:
        # Clamp: a target shorter than one segment must not report > 1.0.
        progress_callback(min(1.0, generated_duration / target_duration))

    # Loop-invariant: requested prompt length in samples.
    requested_overlap_samples = int(overlap_seconds * self.sample_rate)

    # Continuation segments
    segment_num = 1
    while generated_duration < target_duration:
        segment_num += 1
        logger.info(
            f"Generating segment {segment_num} "
            f"(progress: {generated_duration:.1f}/{target_duration:.1f}s)"
        )

        # Get tail of previous segment as conditioning, never asking for
        # more samples than that segment actually contains.
        prev_audio = segments[-1]
        overlap_samples = min(requested_overlap_samples, prev_audio.shape[-1])
        overlap_audio = prev_audio[:, -overlap_samples:]

        # Convert back to tensor for continuation
        overlap_tensor = torch.tensor(overlap_audio).unsqueeze(0).to(self._device)

        # Generate continuation
        with torch.inference_mode():
            continuation = self._model.generate_continuation(
                prompt=overlap_tensor,
                prompt_sample_rate=self.sample_rate,
                descriptions=request.prompts,
                progress=False,
            )

        # continuation shape: [batch, channels, samples]
        # Trim the overlap portion from the start so the prompt audio is
        # not duplicated in the final track.
        new_audio = continuation[0, :, overlap_samples:].cpu().numpy()

        if new_audio.shape[-1] > 0:
            segments.append(new_audio)
            generated_duration += (segment_duration - overlap_seconds)
        else:
            logger.warning("Empty continuation segment, stopping generation")
            break

        if progress_callback:
            progress_callback(min(1.0, generated_duration / target_duration))

    # Concatenate all segments with crossfade
    logger.info(f"Concatenating {len(segments)} segments with {crossfade_ms}ms crossfade...")

    final_audio = concatenate_audio(
        segments,
        sample_rate=self.sample_rate,
        crossfade_ms=crossfade_ms,
    )

    # Convert back to tensor with batch dimension
    final_tensor = torch.tensor(final_audio).unsqueeze(0)  # [1, channels, samples]

    # Every continuation contributes a full (segment - overlap) chunk, so
    # the raw result can overshoot the request; cut it back to the target.
    # NOTE(review): if concatenate_audio overlaps segments while
    # crossfading, the result may instead end up slightly SHORTER than the
    # target - confirm against src.core.audio_utils.
    target_samples = int(round(target_duration * self.sample_rate))
    if final_tensor.shape[-1] > target_samples:
        final_tensor = final_tensor[..., :target_samples]

    actual_duration = final_tensor.shape[-1] / self.sample_rate

    logger.info(
        f"Long-form generation complete: {actual_duration:.2f}s from {len(segments)} segments"
    )

    return GenerationResult(
        audio=final_tensor,
        sample_rate=self.sample_rate,
        duration=actual_duration,
        model_id=self.model_id,
        variant=self._variant,
        parameters={
            "duration": request.duration,
            "target_duration": target_duration,
            "temperature": request.temperature,
            "top_k": request.top_k,
            "top_p": request.top_p,
            "cfg_coef": request.cfg_coef,
            "prompts": request.prompts,
            "long_form": True,
            "segments": len(segments),
            "overlap_seconds": overlap_seconds,
            "crossfade_ms": crossfade_ms,
        },
        seed=seed,
    )
|
||||
|
||||
def get_default_params(self) -> dict[str, Any]:
|
||||
"""Get default generation parameters."""
|
||||
return {
|
||||
|
||||
@@ -65,6 +65,7 @@ class GenerationJob:
|
||||
project_id: Optional[str] = None,
|
||||
preset_used: Optional[str] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
long_form: bool = False,
|
||||
) -> "GenerationJob":
|
||||
"""Create a new generation job."""
|
||||
return cls(
|
||||
@@ -79,6 +80,7 @@ class GenerationJob:
|
||||
"top_p": top_p,
|
||||
"cfg_coef": cfg_coef,
|
||||
"seed": seed,
|
||||
"long_form": long_form,
|
||||
},
|
||||
conditioning=conditioning or {},
|
||||
project_id=project_id,
|
||||
@@ -352,6 +354,7 @@ class BatchProcessor:
|
||||
preset_used=job.preset_used,
|
||||
tags=job.tags,
|
||||
progress_callback=progress_callback,
|
||||
long_form=job.parameters.get("long_form", False),
|
||||
)
|
||||
|
||||
job.status = JobStatus.COMPLETED
|
||||
|
||||
@@ -71,6 +71,7 @@ class GenerationService:
|
||||
preset_used: Optional[str] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
long_form: bool = False,
|
||||
) -> tuple[GenerationResult, Generation]:
|
||||
"""Generate audio and save to database.
|
||||
|
||||
@@ -89,6 +90,7 @@ class GenerationService:
|
||||
preset_used: Name of preset used (for metadata)
|
||||
tags: Optional tags for organization
|
||||
progress_callback: Optional callback for progress updates
|
||||
long_form: Enable long-form generation for durations > 30s (uses continuation)
|
||||
|
||||
Returns:
|
||||
Tuple of (GenerationResult, Generation database record)
|
||||
@@ -133,12 +135,31 @@ class GenerationService:
|
||||
if progress_callback:
|
||||
progress_callback(0.2, f"Loading {model_id}/{actual_variant}...")
|
||||
|
||||
# Determine if we should use long-form generation
|
||||
use_long_form = long_form and duration > 30 and model_id == "musicgen"
|
||||
|
||||
@self.oom_handler.with_oom_recovery
|
||||
def do_generation() -> GenerationResult:
|
||||
with self.registry.get_model(model_id, actual_variant) as model:
|
||||
if progress_callback:
|
||||
progress_callback(0.4, "Generating audio...")
|
||||
return model.generate(request)
|
||||
if use_long_form and hasattr(model, 'generate_long'):
|
||||
if progress_callback:
|
||||
progress_callback(0.4, f"Generating long-form audio ({duration}s)...")
|
||||
|
||||
# Create progress wrapper that maps to our progress range
|
||||
def long_progress(p: float) -> None:
|
||||
if progress_callback:
|
||||
# Map 0-1 progress to 0.4-0.8 range
|
||||
progress_callback(0.4 + p * 0.4, f"Generating segment... ({int(p*100)}%)")
|
||||
|
||||
return model.generate_long(
|
||||
request,
|
||||
target_duration=duration,
|
||||
progress_callback=long_progress,
|
||||
)
|
||||
else:
|
||||
if progress_callback:
|
||||
progress_callback(0.4, "Generating audio...")
|
||||
return model.generate(request)
|
||||
|
||||
result = do_generation()
|
||||
|
||||
@@ -159,6 +180,7 @@ class GenerationService:
|
||||
"top_k": top_k,
|
||||
"top_p": top_p,
|
||||
"cfg_coef": cfg_coef,
|
||||
"long_form": use_long_form,
|
||||
},
|
||||
project_id=project_id,
|
||||
preset_used=preset_used,
|
||||
|
||||
@@ -86,6 +86,13 @@ def create_musicgen_tab(
|
||||
# Parameters
|
||||
gr.Markdown("### Parameters")
|
||||
|
||||
# Long-form generation checkbox
|
||||
long_form_checkbox = gr.Checkbox(
|
||||
label="Long-form generation",
|
||||
value=False,
|
||||
info="Enable for tracks > 30s (uses continuation, takes longer)",
|
||||
)
|
||||
|
||||
duration_slider = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=30,
|
||||
@@ -94,6 +101,13 @@ def create_musicgen_tab(
|
||||
label="Duration (seconds)",
|
||||
)
|
||||
|
||||
# Info text for long-form mode
|
||||
long_form_info = gr.Markdown(
|
||||
value="*Long-form mode uses continuation to generate extended tracks. "
|
||||
"Generation time increases significantly for longer durations.*",
|
||||
visible=False,
|
||||
)
|
||||
|
||||
with gr.Accordion("Advanced Parameters", open=False):
|
||||
with gr.Row():
|
||||
temperature_slider = gr.Slider(
|
||||
@@ -177,6 +191,25 @@ def create_musicgen_tab(
|
||||
outputs=[melody_section],
|
||||
)
|
||||
|
||||
# Long-form checkbox - update duration slider max
|
||||
def on_long_form_change(long_form: bool):
    """Build UI updates for a change of the long-form checkbox.

    Returns a pair of ``gr.update`` objects: the first sets the duration
    slider's maximum and label (300 with a "up to 5 min" label when
    long-form is enabled, otherwise 30 with the plain label), the second
    sets the visibility of the long-form info text. The slider's current
    value is not part of either update.
    """
    slider_max = 300 if long_form else 30
    slider_label = (
        "Duration (seconds) - up to 5 min" if long_form else "Duration (seconds)"
    )
    info_visible = True if long_form else False

    slider_update = gr.update(maximum=slider_max, label=slider_label)
    info_update = gr.update(visible=info_visible)
    return (slider_update, info_update)
|
||||
|
||||
long_form_checkbox.change(
|
||||
fn=on_long_form_change,
|
||||
inputs=[long_form_checkbox],
|
||||
outputs=[duration_slider, long_form_info],
|
||||
)
|
||||
|
||||
# Prompt suggestions
|
||||
for btn, suggestion in suggestion_btns:
|
||||
btn.click(
|
||||
@@ -186,7 +219,7 @@ def create_musicgen_tab(
|
||||
|
||||
# Generate
|
||||
async def do_generate(
|
||||
prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody
|
||||
prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody, long_form
|
||||
):
|
||||
if not prompt:
|
||||
yield (
|
||||
@@ -199,9 +232,14 @@ def create_musicgen_tab(
|
||||
)
|
||||
return
|
||||
|
||||
# Generate status message
|
||||
status_msg = "🔄 Generating..."
|
||||
if long_form and duration > 30:
|
||||
status_msg = f"🔄 Long-form generation ({duration}s, may take several minutes)..."
|
||||
|
||||
# Update status
|
||||
yield (
|
||||
gr.update(value="🔄 Generating..."),
|
||||
gr.update(value=status_msg),
|
||||
gr.update(visible=True, value=0),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
@@ -225,6 +263,7 @@ def create_musicgen_tab(
|
||||
cfg_coef=cfg_coef,
|
||||
seed=int(seed) if seed else None,
|
||||
conditioning=conditioning,
|
||||
long_form=long_form,
|
||||
)
|
||||
|
||||
yield (
|
||||
@@ -258,6 +297,7 @@ def create_musicgen_tab(
|
||||
top_p_slider,
|
||||
seed_input,
|
||||
melody_input,
|
||||
long_form_checkbox,
|
||||
],
|
||||
outputs=[
|
||||
output["status"],
|
||||
@@ -270,7 +310,7 @@ def create_musicgen_tab(
|
||||
)
|
||||
|
||||
# Add to queue
|
||||
def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody):
|
||||
def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody, long_form):
|
||||
if not prompt:
|
||||
return "Please enter a prompt"
|
||||
|
||||
@@ -289,6 +329,7 @@ def create_musicgen_tab(
|
||||
cfg_coef=cfg_coef,
|
||||
seed=int(seed) if seed else None,
|
||||
conditioning=conditioning,
|
||||
long_form=long_form,
|
||||
)
|
||||
|
||||
return f"✅ Added to queue: {job.id}"
|
||||
@@ -305,6 +346,7 @@ def create_musicgen_tab(
|
||||
top_p_slider,
|
||||
seed_input,
|
||||
melody_input,
|
||||
long_form_checkbox,
|
||||
],
|
||||
outputs=[output["status"]],
|
||||
)
|
||||
@@ -314,6 +356,7 @@ def create_musicgen_tab(
|
||||
"variant": variant_dropdown,
|
||||
"prompt": prompt_input,
|
||||
"melody": melody_input,
|
||||
"long_form": long_form_checkbox,
|
||||
"duration": duration_slider,
|
||||
"temperature": temperature_slider,
|
||||
"cfg_coef": cfg_slider,
|
||||
|
||||
Reference in New Issue
Block a user