Files
audiocraft-ui/src/ui/tabs/style_tab.py
Sebastian Krüger f5030245e5 Fix audio AbortError race condition in all generation tabs
Use gr.skip() instead of gr.update() for audio/waveform/duration/seed
outputs during the initial "Generating..." status yield. This prevents
Gradio's StaticAudio component from starting a fetch that gets aborted
when the actual audio path arrives in the second yield.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-27 17:04:55 +01:00

294 lines
9.9 KiB
Python

"""MusicGen Style tab for style-conditioned generation."""
import gradio as gr
from typing import Any, Callable, Optional
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
from src.ui.components.audio_player import create_generation_output
# Model variants offered in the tab's "Model Variant" dropdown. The dropdown
# label is built from ``name`` and ``vram_mb`` (shown in GB); ``id`` is the
# value handed to the generation/queue callbacks.
STYLE_VARIANTS = [
    dict(
        id="medium",
        name="Medium",
        vram_mb=5000,
        description="1.5B params, style conditioning",
    ),
]
def _normalize_seed(seed: Any) -> Optional[int]:
    """Convert the seed field's value to an int, or None when it is empty.

    Only ``None`` and the empty string count as "no seed"; an explicit ``0``
    is a valid seed and must not be replaced by a random one.
    """
    if seed is None or seed == "":
        return None
    return int(seed)


def create_style_tab(
    generate_fn: Callable[..., Any],
    add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
    """Create MusicGen Style generation tab.

    Args:
        generate_fn: Awaited with the generation keyword arguments when
            "Generate" is clicked; its result is unpacked as
            ``(result, generation)``.
        add_to_queue_fn: Called synchronously with the same keyword
            arguments when "Add to Queue" is clicked; the ``id`` of the
            returned job is shown in the status output.

    Returns:
        Dictionary with component references, keyed by ``preset``,
        ``variant``, ``prompt``, ``style``, ``duration``, ``temperature``,
        ``cfg_coef``, ``top_k``, ``top_p``, ``seed``, ``generate_btn``,
        ``queue_btn`` and ``output``.
    """
    presets = DEFAULT_PRESETS.get("musicgen-style", [])

    with gr.Column():
        gr.Markdown("## 🎨 MusicGen Style")
        gr.Markdown("Generate music conditioned on the style of reference audio")

        with gr.Row():
            # Left column - inputs
            with gr.Column(scale=2):
                # Preset selector
                preset_choices = [(p["name"], p["id"]) for p in presets]
                preset_choices.append(("Custom", "custom"))
                preset_dropdown = gr.Dropdown(
                    label="Preset",
                    choices=preset_choices,
                    value=presets[0]["id"] if presets else "custom",
                )

                # Model variant
                variant_choices = [
                    (f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"])
                    for v in STYLE_VARIANTS
                ]
                variant_dropdown = gr.Dropdown(
                    label="Model Variant",
                    choices=variant_choices,
                    value="medium",
                )

                # Prompt input
                prompt_input = gr.Textbox(
                    label="Text Prompt",
                    placeholder="Describe additional characteristics for the music...",
                    lines=3,
                    max_lines=5,
                    info="Optional: combine with style conditioning",
                )

                # Style conditioning (required)
                gr.Markdown("### Style Conditioning")
                gr.Markdown("*Upload reference audio to extract musical style*")
                style_input = gr.Audio(
                    label="Style Reference",
                    type="filepath",
                    sources=["upload", "microphone"],
                )
                gr.Markdown(
                    "*The model will learn the style (instrumentation, tempo, mood) from this audio*"
                )

                # Parameters
                gr.Markdown("### Parameters")
                duration_slider = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)",
                )

                with gr.Accordion("Advanced Parameters", open=False):
                    with gr.Row():
                        temperature_slider = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            value=1.0,
                            step=0.05,
                            label="Temperature",
                        )
                        cfg_slider = gr.Slider(
                            minimum=1.0,
                            maximum=10.0,
                            value=3.0,
                            step=0.5,
                            label="CFG Coefficient",
                        )
                    with gr.Row():
                        top_k_slider = gr.Slider(
                            minimum=0,
                            maximum=500,
                            value=250,
                            step=10,
                            label="Top-K",
                        )
                        top_p_slider = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.0,
                            step=0.05,
                            label="Top-P",
                        )
                    with gr.Row():
                        seed_input = gr.Number(
                            value=None,
                            label="Seed (empty = random)",
                            precision=0,
                        )

                # Generate buttons
                with gr.Row():
                    generate_btn = gr.Button("🎨 Generate", variant="primary", scale=2)
                    queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)

            # Right column - output
            with gr.Column(scale=3):
                output = create_generation_output()

    # Event handlers

    # Both buttons read the same components, in the positional order the
    # handlers below expect.
    generation_inputs = [
        prompt_input,
        variant_dropdown,
        style_input,
        duration_slider,
        temperature_slider,
        cfg_slider,
        top_k_slider,
        top_p_slider,
        seed_input,
    ]

    def build_request(
        prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed
    ) -> dict[str, Any]:
        """Translate raw widget values into the callbacks' keyword arguments."""
        return {
            "model_id": "musicgen-style",
            "variant": variant,
            # An empty prompt is sent as one empty string so the prompt list
            # always has exactly one entry.
            "prompts": [prompt] if prompt else [""],
            "duration": duration,
            "temperature": temperature,
            "top_k": int(top_k),
            "top_p": top_p,
            "cfg_coef": cfg_coef,
            "seed": _normalize_seed(seed),
            "conditioning": {"style": style_audio},
        }

    def hold_player() -> tuple:
        """Return fresh no-op updates for audio, waveform, duration and seed."""
        return tuple(gr.skip() for _ in range(4))

    # Preset change
    def apply_preset(preset_id: str):
        """Return slider values for the chosen preset, or no-ops if unknown."""
        for preset in presets:
            if preset["id"] != preset_id:
                continue
            params = preset["parameters"]
            return (
                params.get("duration", 10),
                params.get("temperature", 1.0),
                params.get("cfg_coef", 3.0),
                params.get("top_k", 250),
                params.get("top_p", 0.0),
            )
        # "Custom" (or an unknown id): leave every slider where the user put it.
        return tuple(gr.update() for _ in range(5))

    preset_dropdown.change(
        fn=apply_preset,
        inputs=[preset_dropdown],
        outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
    )

    # Generate
    async def do_generate(
        prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed
    ):
        """Stream status, progress and player updates for one generation."""
        if not style_audio:
            yield (
                gr.update(value="Please upload a style reference audio"),
                gr.skip(),
                *hold_player(),
            )
            return

        # Leave the player outputs untouched while generating so the audio
        # component does not start a fetch that the final update would abort.
        yield (
            gr.update(value="🔄 Generating..."),
            gr.update(visible=True, value=0),
            *hold_player(),
        )

        try:
            result, generation = await generate_fn(
                **build_request(
                    prompt, variant, style_audio, duration,
                    temperature, cfg_coef, top_k, top_p, seed,
                )
            )
            yield (
                gr.update(value="✅ Generation complete!"),
                gr.update(visible=False),
                gr.update(value=generation.audio_path),
                gr.update(),
                gr.update(value=f"{result.duration:.2f}s"),
                gr.update(value=str(result.seed)),
            )
        except Exception as e:
            # UI boundary: report the failure in the status box and hide the
            # progress indicator instead of letting the handler crash.
            yield (
                gr.update(value=f"❌ Error: {str(e)}"),
                gr.update(visible=False),
                *hold_player(),
            )

    generate_btn.click(
        fn=do_generate,
        inputs=generation_inputs,
        outputs=[
            output["status"],
            output["progress"],
            output["player"]["audio"],
            output["player"]["waveform"],
            output["player"]["duration"],
            output["player"]["seed"],
        ],
    )

    # Add to queue
    def do_add_queue(
        prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed
    ):
        """Queue a generation job and return the status text to display."""
        if not style_audio:
            return "Please upload a style reference audio"

        try:
            job = add_to_queue_fn(
                **build_request(
                    prompt, variant, style_audio, duration,
                    temperature, cfg_coef, top_k, top_p, seed,
                )
            )
        except Exception as e:
            # Report failures the same way the Generate handler does.
            return f"❌ Error: {str(e)}"
        return f"✅ Added to queue: {job.id}"

    queue_btn.click(
        fn=do_add_queue,
        inputs=generation_inputs,
        outputs=[output["status"]],
    )

    return {
        "preset": preset_dropdown,
        "variant": variant_dropdown,
        "prompt": prompt_input,
        "style": style_input,
        "duration": duration_slider,
        "temperature": temperature_slider,
        "cfg_coef": cfg_slider,
        "top_k": top_k_slider,
        "top_p": top_p_slider,
        "seed": seed_input,
        "generate_btn": generate_btn,
        "queue_btn": queue_btn,
        "output": output,
    }