Files
audiocraft-ui/src/ui/tabs/musicgen_tab.py
Sebastian Krüger 503b3ce473 Fix model registration + add long-form generation support
- Fix critical bug: register_all_adapters() now called in main.py
- Add generate_long() method to MusicGen adapter for continuation-based
  extended tracks (up to 5 minutes)
- Add long-form checkbox in UI that unlocks duration slider to 300s
- Update GenerationService to route to generate_long when duration > 30s
- Update BatchProcessor to support long_form parameter

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-27 13:22:06 +01:00

370 lines
13 KiB
Python

"""MusicGen tab for text-to-music generation."""
import gradio as gr
from typing import Any, Callable, Optional
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
from src.ui.components.audio_player import create_generation_output
# Selectable MusicGen checkpoints offered in the "Model Variant" dropdown.
# Each row is (id, display name, VRAM figure in MB, short description); the
# id is the dropdown value, and vram_mb is rendered in GB in the label.
MUSICGEN_VARIANTS = [
    {"id": variant_id, "name": display_name, "vram_mb": vram, "description": blurb}
    for variant_id, display_name, vram, blurb in (
        ("small", "Small", 1500, "Fast, 300M params"),
        ("medium", "Medium", 5000, "Balanced, 1.5B params"),
        ("large", "Large", 10000, "Best quality, 3.3B params"),
        ("melody", "Melody", 5000, "With melody conditioning"),
        ("stereo-small", "Stereo Small", 1800, "Stereo, 300M params"),
        ("stereo-medium", "Stereo Medium", 6000, "Stereo, 1.5B params"),
        ("stereo-large", "Stereo Large", 12000, "Stereo, 3.3B params"),
        ("stereo-melody", "Stereo Melody", 6000, "Stereo with melody"),
    )
]
# Duration ceilings (seconds) for the duration slider: standard generation
# versus long-form (continuation) mode, toggled by the long-form checkbox.
_STANDARD_MAX_DURATION = 30
_LONG_FORM_MAX_DURATION = 300

# Suggestion buttons show at most this many characters of the prompt text.
_SUGGESTION_LABEL_CHARS = 60


def _suggestion_label(suggestion: str, limit: int = _SUGGESTION_LABEL_CHARS) -> str:
    """Return a button label for *suggestion*, adding "..." only if it was truncated."""
    if len(suggestion) <= limit:
        return suggestion
    return suggestion[:limit] + "..."


def _is_melody_variant(variant: Optional[str]) -> bool:
    """Return True when *variant* names a melody-conditioned checkpoint."""
    return "melody" in (variant or "").lower()


def _has_prompt(prompt: Optional[str]) -> bool:
    """Return True when *prompt* contains at least one non-whitespace character."""
    return bool(prompt and prompt.strip())


def _build_request(
    prompt: str,
    variant: str,
    duration: float,
    temperature: float,
    cfg_coef: float,
    top_k: float,
    top_p: float,
    seed: Optional[float],
    melody: Optional[str],
    long_form: bool,
) -> dict[str, Any]:
    """Translate raw widget values into keyword arguments for generate/queue callbacks.

    Returns:
        Keyword arguments shared by ``generate_fn`` and ``add_to_queue_fn``.

    Raises:
        TypeError / ValueError: if ``top_k`` or ``seed`` cannot be converted
            with ``int()``.
    """
    conditioning: dict[str, Any] = {}
    # The melody widget keeps its value while its section is hidden, so only
    # forward it when the selected variant is a melody variant.
    if melody and _is_melody_variant(variant):
        conditioning["melody"] = melody

    if not long_form:
        # The slider may still hold a long-form value after the checkbox was
        # cleared; never request more than the standard ceiling without it.
        duration = min(duration, _STANDARD_MAX_DURATION)

    return {
        "model_id": "musicgen",
        "variant": variant,
        "prompts": [prompt],
        "duration": duration,
        "temperature": temperature,
        "top_k": int(top_k),
        "top_p": top_p,
        "cfg_coef": cfg_coef,
        # 0 is a valid seed; only an empty field (None) means "random".
        "seed": int(seed) if seed is not None else None,
        "conditioning": conditioning,
        "long_form": long_form,
    }


def create_musicgen_tab(
    generate_fn: Callable[..., Any],
    add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
    """Create MusicGen generation tab.

    Builds the input controls (preset, variant, prompt, melody, parameters),
    the output panel, and wires their event handlers. Event listeners are
    registered here, so this presumably must run inside an enclosing
    ``gr.Blocks`` context — confirm against the caller.

    Args:
        generate_fn: Awaitable callback for immediate generation. It is called
            with keyword arguments and must return ``(result, generation)``,
            where ``result`` exposes ``.duration``/``.seed`` and
            ``generation`` exposes ``.audio_path``.
        add_to_queue_fn: Callback that enqueues a job with the same keyword
            arguments and returns an object exposing ``.id``.

    Returns:
        Dictionary with component references
    """
    presets = DEFAULT_PRESETS.get("musicgen", [])
    suggestions = PROMPT_SUGGESTIONS.get("musicgen", [])

    with gr.Column():
        gr.Markdown("## 🎵 MusicGen")
        gr.Markdown("Generate music from text descriptions")

        with gr.Row():
            # Left column - inputs
            with gr.Column(scale=2):
                # Preset selector ("custom" leaves the parameters untouched)
                preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
                preset_dropdown = gr.Dropdown(
                    label="Preset",
                    choices=preset_choices,
                    value=presets[0]["id"] if presets else "custom",
                )

                # Model variant; labels show the VRAM figure converted to GB
                variant_choices = [
                    (f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"])
                    for v in MUSICGEN_VARIANTS
                ]
                variant_dropdown = gr.Dropdown(
                    label="Model Variant",
                    choices=variant_choices,
                    value="medium",
                )

                # Prompt input
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the music you want to generate...",
                    lines=3,
                    max_lines=5,
                )

                # Prompt suggestions (first four only)
                with gr.Accordion("Prompt Suggestions", open=False):
                    suggestion_btns = []
                    for suggestion in suggestions[:4]:
                        btn = gr.Button(
                            _suggestion_label(suggestion), size="sm", variant="secondary"
                        )
                        suggestion_btns.append((btn, suggestion))

                # Melody conditioning (made visible only for melody variants)
                with gr.Group(visible=False) as melody_section:
                    gr.Markdown("### Melody Conditioning")
                    melody_input = gr.Audio(
                        label="Reference Melody",
                        type="filepath",
                        sources=["upload", "microphone"],
                    )
                    gr.Markdown("*Upload audio to condition generation on its melody*")

                # Parameters
                gr.Markdown("### Parameters")

                # Long-form generation checkbox
                long_form_checkbox = gr.Checkbox(
                    label="Long-form generation",
                    value=False,
                    info="Enable for tracks > 30s (uses continuation, takes longer)",
                )
                duration_slider = gr.Slider(
                    minimum=1,
                    maximum=_STANDARD_MAX_DURATION,
                    value=10,
                    step=1,
                    label="Duration (seconds)",
                )
                # Info text for long-form mode
                long_form_info = gr.Markdown(
                    value="*Long-form mode uses continuation to generate extended tracks. "
                    "Generation time increases significantly for longer durations.*",
                    visible=False,
                )

                with gr.Accordion("Advanced Parameters", open=False):
                    with gr.Row():
                        temperature_slider = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            value=1.0,
                            step=0.05,
                            label="Temperature",
                        )
                        cfg_slider = gr.Slider(
                            minimum=1.0,
                            maximum=10.0,
                            value=3.0,
                            step=0.5,
                            label="CFG Coefficient",
                        )
                    with gr.Row():
                        top_k_slider = gr.Slider(
                            minimum=0,
                            maximum=500,
                            value=250,
                            step=10,
                            label="Top-K",
                        )
                        top_p_slider = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.0,
                            step=0.05,
                            label="Top-P",
                        )
                    with gr.Row():
                        seed_input = gr.Number(
                            value=None,
                            label="Seed (empty = random)",
                            precision=0,
                        )

                # Generate buttons
                with gr.Row():
                    generate_btn = gr.Button("🎵 Generate", variant="primary", scale=2)
                    queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)

            # Right column - output
            with gr.Column(scale=3):
                output = create_generation_output()

        # Event handlers

        # Preset change: copy the preset's parameters into the sliders
        def apply_preset(preset_id: str):
            for p in presets:
                if p["id"] == preset_id:
                    params = p.get("parameters", {})
                    return (
                        params.get("duration", 10),
                        params.get("temperature", 1.0),
                        params.get("cfg_coef", 3.0),
                        params.get("top_k", 250),
                        params.get("top_p", 0.0),
                    )
            # Custom preset - don't change values
            return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()

        preset_dropdown.change(
            fn=apply_preset,
            inputs=[preset_dropdown],
            outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
        )

        # Variant change - show/hide melody section
        def on_variant_change(variant: str):
            return gr.update(visible=_is_melody_variant(variant))

        variant_dropdown.change(
            fn=on_variant_change,
            inputs=[variant_dropdown],
            outputs=[melody_section],
        )

        # Long-form checkbox - update duration slider max (and info text)
        def on_long_form_change(long_form: bool, duration: float):
            if long_form:
                return (
                    gr.update(
                        maximum=_LONG_FORM_MAX_DURATION,
                        label="Duration (seconds) - up to 5 min",
                    ),
                    gr.update(visible=True),
                )
            # Pull the value back under the standard ceiling so it never sits
            # above the slider's new maximum.
            return (
                gr.update(
                    maximum=_STANDARD_MAX_DURATION,
                    value=min(duration, _STANDARD_MAX_DURATION),
                    label="Duration (seconds)",
                ),
                gr.update(visible=False),
            )

        long_form_checkbox.change(
            fn=on_long_form_change,
            inputs=[long_form_checkbox, duration_slider],
            outputs=[duration_slider, long_form_info],
        )

        # Prompt suggestions: the text is bound as a default argument so each
        # button fills in its own suggestion (avoids late-binding closures).
        for btn, suggestion in suggestion_btns:
            btn.click(
                fn=lambda s=suggestion: s,
                outputs=[prompt_input],
            )

        # Generate. Each yield updates, in order: status, progress, audio,
        # waveform, duration, seed.
        async def do_generate(
            prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody, long_form
        ):
            if not _has_prompt(prompt):
                yield (
                    gr.update(value="Please enter a prompt"),
                    gr.update(),
                    gr.update(),
                    gr.update(),
                    gr.update(),
                    gr.update(),
                )
                return

            # Generate status message
            status_msg = "🔄 Generating..."
            if long_form and duration > _STANDARD_MAX_DURATION:
                status_msg = f"🔄 Long-form generation ({duration}s, may take several minutes)..."

            # Update status
            yield (
                gr.update(value=status_msg),
                gr.update(visible=True, value=0),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
            )

            try:
                request = _build_request(
                    prompt, variant, duration, temperature, cfg_coef,
                    top_k, top_p, seed, melody, long_form,
                )
                result, generation = await generate_fn(**request)
                yield (
                    gr.update(value="✅ Generation complete!"),
                    gr.update(visible=False),
                    gr.update(value=generation.audio_path),
                    gr.update(),
                    gr.update(value=f"{result.duration:.2f}s"),
                    gr.update(value=str(result.seed)),
                )
            except Exception as e:
                # UI boundary: surface any failure in the status line.
                yield (
                    gr.update(value=f"❌ Error: {str(e)}"),
                    gr.update(visible=False),
                    gr.update(),
                    gr.update(),
                    gr.update(),
                    gr.update(),
                )

        generation_inputs = [
            prompt_input,
            variant_dropdown,
            duration_slider,
            temperature_slider,
            cfg_slider,
            top_k_slider,
            top_p_slider,
            seed_input,
            melody_input,
            long_form_checkbox,
        ]

        generate_btn.click(
            fn=do_generate,
            inputs=generation_inputs,
            outputs=[
                output["status"],
                output["progress"],
                output["player"]["audio"],
                output["player"]["waveform"],
                output["player"]["duration"],
                output["player"]["seed"],
            ],
        )

        # Add to queue; returns a status string for the status output.
        def do_add_queue(
            prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody, long_form
        ):
            if not _has_prompt(prompt):
                return "Please enter a prompt"
            try:
                job = add_to_queue_fn(
                    **_build_request(
                        prompt, variant, duration, temperature, cfg_coef,
                        top_k, top_p, seed, melody, long_form,
                    )
                )
            except Exception as e:
                # Match do_generate: report failures instead of crashing the handler.
                return f"❌ Error: {str(e)}"
            return f"✅ Added to queue: {job.id}"

        queue_btn.click(
            fn=do_add_queue,
            inputs=generation_inputs,
            outputs=[output["status"]],
        )

    return {
        "preset": preset_dropdown,
        "variant": variant_dropdown,
        "prompt": prompt_input,
        "melody": melody_input,
        "long_form": long_form_checkbox,
        "duration": duration_slider,
        "temperature": temperature_slider,
        "cfg_coef": cfg_slider,
        "top_k": top_k_slider,
        "top_p": top_p_slider,
        "seed": seed_input,
        "generate_btn": generate_btn,
        "queue_btn": queue_btn,
        "output": output,
    }