2025-11-25 19:34:27 +01:00
|
|
|
"""MusicGen Style tab for style-conditioned generation."""
|
|
|
|
|
|
|
|
|
|
import gradio as gr
|
|
|
|
|
from typing import Any, Callable, Optional
|
|
|
|
|
|
|
|
|
|
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
|
|
|
|
|
from src.ui.components.audio_player import create_generation_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Model variants offered in the "Model Variant" dropdown of the Style tab.
# Each entry carries the variant id passed to the generation backend, a
# display name, the VRAM figure (in MB) shown next to the name, and a short
# human-readable description.
STYLE_VARIANTS: list[dict[str, Any]] = [
    {
        "id": "medium",
        "name": "Medium",
        "vram_mb": 5000,
        "description": "1.5B params, style conditioning",
    },
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_style_tab(
    generate_fn: Callable[..., Any],
    add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
    """Create the MusicGen Style generation tab.

    Builds the input widgets (preset, model variant, text prompt, style
    reference audio, sampling parameters), the output panel, and wires the
    "Generate" and "Add to Queue" buttons to the supplied callbacks.

    Args:
        generate_fn: Awaitable callable invoked with keyword arguments
            (``model_id``, ``variant``, ``prompts``, ``duration``,
            ``temperature``, ``top_k``, ``top_p``, ``cfg_coef``, ``seed``,
            ``conditioning``). Its awaited result is unpacked as
            ``(result, generation)``; ``result.duration``, ``result.seed``
            and ``generation.audio_path`` are read from it.
        add_to_queue_fn: Callable invoked synchronously with the same keyword
            arguments; the returned object's ``id`` attribute is shown in the
            status text.

    Returns:
        Dictionary with component references, keyed by ``"preset"``,
        ``"variant"``, ``"prompt"``, ``"style"``, ``"duration"``,
        ``"temperature"``, ``"cfg_coef"``, ``"top_k"``, ``"top_p"``,
        ``"seed"``, ``"generate_btn"``, ``"queue_btn"`` and ``"output"``.
    """
    presets = DEFAULT_PRESETS.get("musicgen-style", [])

    def _seed_or_none(seed: Any) -> Any:
        """Return *seed* as an int, or None when the seed field is empty.

        An explicit ``is None`` check is used instead of truthiness so that a
        seed of 0 is passed through rather than being treated as "random".
        """
        if seed is None or seed == "":
            return None
        return int(seed)

    with gr.Column():
        gr.Markdown("## 🎨 MusicGen Style")
        gr.Markdown("Generate music conditioned on the style of reference audio")

        with gr.Row():
            # Left column - inputs
            with gr.Column(scale=2):
                # Preset selector; "Custom" is always offered as the last choice.
                preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
                preset_dropdown = gr.Dropdown(
                    label="Preset",
                    choices=preset_choices,
                    value=presets[0]["id"] if presets else "custom",
                )

                # Model variant; the label shows the VRAM figure converted MB -> GB.
                variant_choices = [
                    (f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"])
                    for v in STYLE_VARIANTS
                ]
                variant_dropdown = gr.Dropdown(
                    label="Model Variant",
                    choices=variant_choices,
                    value="medium",
                )

                # Prompt input
                prompt_input = gr.Textbox(
                    label="Text Prompt",
                    placeholder="Describe additional characteristics for the music...",
                    lines=3,
                    max_lines=5,
                    info="Optional: combine with style conditioning",
                )

                # Style conditioning (required)
                gr.Markdown("### Style Conditioning")
                gr.Markdown("*Upload reference audio to extract musical style*")

                style_input = gr.Audio(
                    label="Style Reference",
                    type="filepath",
                    sources=["upload", "microphone"],
                )

                gr.Markdown(
                    "*The model will learn the style (instrumentation, tempo, mood) from this audio*"
                )

                # Parameters
                gr.Markdown("### Parameters")

                duration_slider = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)",
                )

                with gr.Accordion("Advanced Parameters", open=False):
                    with gr.Row():
                        temperature_slider = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            value=1.0,
                            step=0.05,
                            label="Temperature",
                        )
                        cfg_slider = gr.Slider(
                            minimum=1.0,
                            maximum=10.0,
                            value=3.0,
                            step=0.5,
                            label="CFG Coefficient",
                        )

                    with gr.Row():
                        top_k_slider = gr.Slider(
                            minimum=0,
                            maximum=500,
                            value=250,
                            step=10,
                            label="Top-K",
                        )
                        top_p_slider = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.0,
                            step=0.05,
                            label="Top-P",
                        )

                    with gr.Row():
                        seed_input = gr.Number(
                            value=None,
                            label="Seed (empty = random)",
                            precision=0,
                        )

                # Generate buttons
                with gr.Row():
                    generate_btn = gr.Button("🎨 Generate", variant="primary", scale=2)
                    queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)

            # Right column - output
            with gr.Column(scale=3):
                output = create_generation_output()

    # Event handlers

    # Preset change
    def apply_preset(preset_id: str):
        """Return slider values for the chosen preset, or no-op updates."""
        for p in presets:
            if p["id"] == preset_id:
                # A preset without a "parameters" entry falls back to defaults.
                params = p.get("parameters", {})
                return (
                    params.get("duration", 10),
                    params.get("temperature", 1.0),
                    params.get("cfg_coef", 3.0),
                    params.get("top_k", 250),
                    params.get("top_p", 0.0),
                )
        # Unknown id (e.g. "custom"): leave all five sliders untouched.
        return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()

    preset_dropdown.change(
        fn=apply_preset,
        inputs=[preset_dropdown],
        outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
    )

    # Generate
    async def do_generate(
        prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed
    ):
        """Run one generation, yielding UI updates for each stage.

        Every yielded tuple has six entries matching the ``outputs`` list of
        ``generate_btn.click``: status, progress, audio, waveform, duration,
        seed.
        """
        if not style_audio:
            # Style reference is mandatory; only the status text is changed.
            yield (
                gr.update(value="Please upload a style reference audio"),
                gr.skip(),
                gr.skip(),
                gr.skip(),
                gr.skip(),
                gr.skip(),
            )
            return

        yield (
            gr.update(value="🔄 Generating..."),
            gr.update(visible=True, value=0),
            gr.skip(),  # Don't update audio yet
            gr.skip(),  # Don't update waveform
            gr.skip(),  # Don't update duration
            gr.skip(),  # Don't update seed
        )

        try:
            conditioning = {"style": style_audio}

            result, generation = await generate_fn(
                model_id="musicgen-style",
                variant=variant,
                prompts=[prompt] if prompt else [""],
                duration=duration,
                temperature=temperature,
                top_k=int(top_k),
                top_p=top_p,
                cfg_coef=cfg_coef,
                seed=_seed_or_none(seed),
                conditioning=conditioning,
            )

            yield (
                gr.update(value="✅ Generation complete!"),
                gr.update(visible=False),
                gr.update(value=generation.audio_path),
                gr.update(),
                gr.update(value=f"{result.duration:.2f}s"),
                gr.update(value=str(result.seed)),
            )

        except Exception as e:
            # UI boundary: surface any failure in the status text and hide
            # the progress indicator instead of propagating the exception.
            yield (
                gr.update(value=f"❌ Error: {e}"),
                gr.update(visible=False),
                gr.skip(),
                gr.skip(),
                gr.skip(),
                gr.skip(),
            )

    generate_btn.click(
        fn=do_generate,
        inputs=[
            prompt_input,
            variant_dropdown,
            style_input,
            duration_slider,
            temperature_slider,
            cfg_slider,
            top_k_slider,
            top_p_slider,
            seed_input,
        ],
        outputs=[
            output["status"],
            output["progress"],
            output["player"]["audio"],
            output["player"]["waveform"],
            output["player"]["duration"],
            output["player"]["seed"],
        ],
    )

    # Add to queue
    def do_add_queue(prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed):
        """Queue a generation job and return the status text to display."""
        if not style_audio:
            return "Please upload a style reference audio"

        conditioning = {"style": style_audio}

        try:
            job = add_to_queue_fn(
                model_id="musicgen-style",
                variant=variant,
                prompts=[prompt] if prompt else [""],
                duration=duration,
                temperature=temperature,
                top_k=int(top_k),
                top_p=top_p,
                cfg_coef=cfg_coef,
                seed=_seed_or_none(seed),
                conditioning=conditioning,
            )
        except Exception as e:
            # Report queue failures in the status text, mirroring do_generate.
            return f"❌ Error: {e}"

        return f"✅ Added to queue: {job.id}"

    queue_btn.click(
        fn=do_add_queue,
        inputs=[
            prompt_input,
            variant_dropdown,
            style_input,
            duration_slider,
            temperature_slider,
            cfg_slider,
            top_k_slider,
            top_p_slider,
            seed_input,
        ],
        outputs=[output["status"]],
    )

    return {
        "preset": preset_dropdown,
        "variant": variant_dropdown,
        "prompt": prompt_input,
        "style": style_input,
        "duration": duration_slider,
        "temperature": temperature_slider,
        "cfg_coef": cfg_slider,
        "top_k": top_k_slider,
        "top_p": top_p_slider,
        "seed": seed_input,
        "generate_btn": generate_btn,
        "queue_btn": queue_btn,
        "output": output,
    }
|