"""MusicGen Style tab for style-conditioned generation."""

import logging
from typing import Any, Callable, Optional

import gradio as gr

from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
from src.ui.components.audio_player import create_generation_output

logger = logging.getLogger(__name__)

STYLE_VARIANTS = [
    {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "1.5B params, style conditioning"},
]

# Model identifier used both for preset lookup and for generation/queue requests.
_MODEL_ID = "musicgen-style"

# Slider defaults; also the fallbacks when a preset omits a parameter.
_DEFAULT_DURATION = 10
_DEFAULT_TEMPERATURE = 1.0
_DEFAULT_CFG_COEF = 3.0
_DEFAULT_TOP_K = 250
_DEFAULT_TOP_P = 0.0

_MISSING_STYLE_MSG = "Please upload a style reference audio"


def _preset_values(preset: Optional[dict[str, Any]]) -> tuple[Any, Any, Any, Any, Any]:
    """Return ``(duration, temperature, cfg_coef, top_k, top_p)`` for a preset.

    Args:
        preset: Preset dict whose optional ``"parameters"`` mapping holds the
            slider values, or ``None``.

    Returns:
        Five values in slider order; any missing entry falls back to the
        corresponding slider default.
    """
    params = (preset or {}).get("parameters") or {}
    return (
        params.get("duration", _DEFAULT_DURATION),
        params.get("temperature", _DEFAULT_TEMPERATURE),
        params.get("cfg_coef", _DEFAULT_CFG_COEF),
        params.get("top_k", _DEFAULT_TOP_K),
        params.get("top_p", _DEFAULT_TOP_P),
    )


def _normalize_seed(seed: Optional[float]) -> Optional[int]:
    """Convert the seed widget value to an int, or ``None`` for a random seed.

    Only an empty field (``None``) means random; 0 is a valid, reproducible
    seed, so this compares against ``None`` instead of relying on truthiness.
    """
    return None if seed is None else int(seed)


def _build_request(
    prompt: Optional[str],
    variant: str,
    style_audio: str,
    duration: float,
    temperature: float,
    cfg_coef: float,
    top_k: float,
    top_p: float,
    seed: Optional[float],
) -> dict[str, Any]:
    """Build the keyword arguments shared by ``generate_fn`` and ``add_to_queue_fn``.

    Args:
        prompt: Optional text prompt; an empty/missing prompt becomes ``""``.
        variant: Model variant id from ``STYLE_VARIANTS``.
        style_audio: File path of the style reference audio.
        duration: Requested duration in seconds.
        temperature: Sampling temperature.
        cfg_coef: Classifier-free guidance coefficient.
        top_k: Top-K value (converted to ``int``).
        top_p: Top-P value.
        seed: Seed widget value; ``None`` means random.

    Returns:
        Keyword-argument dict for the generation/queue callbacks.
    """
    return {
        "model_id": _MODEL_ID,
        "variant": variant,
        "prompts": [prompt] if prompt else [""],
        "duration": duration,
        "temperature": temperature,
        "top_k": int(top_k),
        "top_p": top_p,
        "cfg_coef": cfg_coef,
        "seed": _normalize_seed(seed),
        "conditioning": {"style": style_audio},
    }


def create_style_tab(
    generate_fn: Callable[..., Any],
    add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
    """Create MusicGen Style generation tab.

    Args:
        generate_fn: Function to call for generation. It is awaited and its
            result unpacked as ``(result, generation)``; ``result.duration``,
            ``result.seed`` and ``generation.audio_path`` are read.
        add_to_queue_fn: Function to add to queue. Its return value's ``id``
            attribute is shown in the status field.

    Returns:
        Dictionary with component references
    """
    presets = DEFAULT_PRESETS.get(_MODEL_ID, [])
    initial_preset = presets[0] if presets else None
    # The dropdown's change event does not fire on load, so seed the sliders
    # with the initially selected preset's values directly.
    (
        init_duration,
        init_temperature,
        init_cfg_coef,
        init_top_k,
        init_top_p,
    ) = _preset_values(initial_preset)

    with gr.Column():
        gr.Markdown("## 🎨 MusicGen Style")
        gr.Markdown("Generate music conditioned on the style of reference audio")

        with gr.Row():
            # Left column - inputs
            with gr.Column(scale=2):
                # Preset selector
                preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
                preset_dropdown = gr.Dropdown(
                    label="Preset",
                    choices=preset_choices,
                    value=initial_preset["id"] if initial_preset else "custom",
                )

                # Model variant
                variant_choices = [
                    (f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in STYLE_VARIANTS
                ]
                variant_dropdown = gr.Dropdown(
                    label="Model Variant",
                    choices=variant_choices,
                    value="medium",
                )

                # Prompt input
                prompt_input = gr.Textbox(
                    label="Text Prompt",
                    placeholder="Describe additional characteristics for the music...",
                    lines=3,
                    max_lines=5,
                    info="Optional: combine with style conditioning",
                )

                # Style conditioning (required)
                gr.Markdown("### Style Conditioning")
                gr.Markdown("*Upload reference audio to extract musical style*")
                style_input = gr.Audio(
                    label="Style Reference",
                    type="filepath",
                    sources=["upload", "microphone"],
                )
                gr.Markdown(
                    "*The model will learn the style (instrumentation, tempo, mood) from this audio*"
                )

                # Parameters
                gr.Markdown("### Parameters")
                duration_slider = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=init_duration,
                    step=1,
                    label="Duration (seconds)",
                )

                with gr.Accordion("Advanced Parameters", open=False):
                    with gr.Row():
                        temperature_slider = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            value=init_temperature,
                            step=0.05,
                            label="Temperature",
                        )
                        cfg_slider = gr.Slider(
                            minimum=1.0,
                            maximum=10.0,
                            value=init_cfg_coef,
                            step=0.5,
                            label="CFG Coefficient",
                        )

                    with gr.Row():
                        top_k_slider = gr.Slider(
                            minimum=0,
                            maximum=500,
                            value=init_top_k,
                            step=10,
                            label="Top-K",
                        )
                        top_p_slider = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=init_top_p,
                            step=0.05,
                            label="Top-P",
                        )

                    with gr.Row():
                        seed_input = gr.Number(
                            value=None,
                            label="Seed (empty = random)",
                            precision=0,
                        )

                # Generate buttons
                with gr.Row():
                    generate_btn = gr.Button("🎨 Generate", variant="primary", scale=2)
                    queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)

            # Right column - output
            with gr.Column(scale=3):
                output = create_generation_output()

    # Event handlers

    # Preset change
    def apply_preset(preset_id: str):
        preset = next((p for p in presets if p["id"] == preset_id), None)
        if preset is None:
            # "Custom" (or an unknown id): keep whatever the user has set.
            return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
        return _preset_values(preset)

    preset_dropdown.change(
        fn=apply_preset,
        inputs=[preset_dropdown],
        outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
    )

    # Both actions read the same inputs, in the order expected by _build_request.
    request_inputs = [
        prompt_input,
        variant_dropdown,
        style_input,
        duration_slider,
        temperature_slider,
        cfg_slider,
        top_k_slider,
        top_p_slider,
        seed_input,
    ]

    # Generate
    async def do_generate(
        prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed
    ):
        if not style_audio:
            yield (
                gr.update(value=_MISSING_STYLE_MSG),
                gr.skip(),
                gr.skip(),
                gr.skip(),
                gr.skip(),
                gr.skip(),
            )
            return

        yield (
            gr.update(value="🔄 Generating..."),
            gr.update(visible=True, value=0),
            gr.skip(),  # Don't update audio yet
            gr.skip(),  # Don't update waveform
            gr.skip(),  # Don't update duration
            gr.skip(),  # Don't update seed
        )

        try:
            result, generation = await generate_fn(
                **_build_request(
                    prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed
                )
            )
            success = (
                gr.update(value="✅ Generation complete!"),
                gr.update(visible=False),
                gr.update(value=generation.audio_path),
                gr.update(),
                gr.update(value=f"{result.duration:.2f}s"),
                gr.update(value=str(result.seed)),
            )
        except Exception as e:
            # UI boundary: keep the traceback in the logs, show a short message.
            logger.exception("MusicGen Style generation failed")
            yield (
                gr.update(value=f"❌ Error: {e}"),
                gr.update(visible=False),
                gr.skip(),
                gr.skip(),
                gr.skip(),
                gr.skip(),
            )
            return

        yield success

    generate_btn.click(
        fn=do_generate,
        inputs=request_inputs,
        outputs=[
            output["status"],
            output["progress"],
            output["player"]["audio"],
            output["player"]["waveform"],
            output["player"]["duration"],
            output["player"]["seed"],
        ],
    )

    # Add to queue
    def do_add_queue(prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed):
        if not style_audio:
            return _MISSING_STYLE_MSG

        try:
            job = add_to_queue_fn(
                **_build_request(
                    prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed
                )
            )
        except Exception as e:
            # Mirror do_generate: report in the status field instead of crashing the event.
            logger.exception("Failed to add MusicGen Style job to queue")
            return f"❌ Error: {e}"

        return f"✅ Added to queue: {job.id}"

    queue_btn.click(
        fn=do_add_queue,
        inputs=request_inputs,
        outputs=[output["status"]],
    )

    return {
        "preset": preset_dropdown,
        "variant": variant_dropdown,
        "prompt": prompt_input,
        "style": style_input,
        "duration": duration_slider,
        "temperature": temperature_slider,
        "cfg_coef": cfg_slider,
        "top_k": top_k_slider,
        "top_p": top_p_slider,
        "seed": seed_input,
        "generate_btn": generate_btn,
        "queue_btn": queue_btn,
        "output": output,
    }