"""MusicGen tab for text-to-music generation."""

import gradio as gr
from typing import Any, Callable, Optional

from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
from src.ui.components.audio_player import create_generation_output


MUSICGEN_VARIANTS = [
    {"id": "small", "name": "Small", "vram_mb": 1500, "description": "Fast, 300M params"},
    {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "Balanced, 1.5B params"},
    {"id": "large", "name": "Large", "vram_mb": 10000, "description": "Best quality, 3.3B params"},
    {"id": "melody", "name": "Melody", "vram_mb": 5000, "description": "With melody conditioning"},
    {"id": "stereo-small", "name": "Stereo Small", "vram_mb": 1800, "description": "Stereo, 300M params"},
    {"id": "stereo-medium", "name": "Stereo Medium", "vram_mb": 6000, "description": "Stereo, 1.5B params"},
    {"id": "stereo-large", "name": "Stereo Large", "vram_mb": 12000, "description": "Stereo, 3.3B params"},
    {"id": "stereo-melody", "name": "Stereo Melody", "vram_mb": 6000, "description": "Stereo with melody"},
]

# Maximum duration (seconds) offered by the slider in normal (non-long-form) mode.
_MAX_SHORT_DURATION = 30
# Maximum duration (seconds) offered by the slider when long-form mode is enabled.
_MAX_LONG_DURATION = 300
# Number of prompt characters shown on a suggestion button before truncating.
_SUGGESTION_LABEL_LEN = 60


def _has_prompt(prompt: Optional[str]) -> bool:
    """Return True if *prompt* contains any non-whitespace text."""
    return bool(prompt and prompt.strip())


def _supports_melody(variant: Optional[str]) -> bool:
    """Return True if the variant id names a melody-conditioned model."""
    return bool(variant) and "melody" in variant.lower()


def _parse_seed(seed: Any) -> Optional[int]:
    """Convert the seed field value to an int, or None when the field is empty.

    A seed of 0 is a valid seed and is preserved; only an empty field means
    "random".
    """
    if seed is None or seed == "":
        return None
    return int(seed)


def _build_conditioning(variant: Optional[str], melody: Any) -> dict[str, Any]:
    """Build the conditioning dict passed to the generate/queue functions.

    The melody reference is forwarded only for melody-capable variants: the
    melody input keeps its value when the user switches to a non-melody
    variant (the section is merely hidden), so it must be filtered here.
    """
    conditioning: dict[str, Any] = {}
    if melody and _supports_melody(variant):
        conditioning["melody"] = melody
    return conditioning


def _suggestion_label(suggestion: str) -> str:
    """Return a button label for *suggestion*, adding "..." only if truncated."""
    if len(suggestion) > _SUGGESTION_LABEL_LEN:
        return suggestion[:_SUGGESTION_LABEL_LEN] + "..."
    return suggestion


def create_musicgen_tab(
    generate_fn: Callable[..., Any],
    add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
    """Create MusicGen generation tab.

    Args:
        generate_fn: Async function awaited for immediate generation; it must
            return a ``(result, generation)`` pair where ``result`` exposes
            ``duration`` and ``seed`` and ``generation`` exposes ``audio_path``.
        add_to_queue_fn: Function to add a job to the queue; its return value
            must expose an ``id`` attribute.

    Returns:
        Dictionary with component references
    """
    presets = DEFAULT_PRESETS.get("musicgen", [])
    suggestions = PROMPT_SUGGESTIONS.get("musicgen", [])

    with gr.Column():
        gr.Markdown("## 🎵 MusicGen")
        gr.Markdown("Generate music from text descriptions")

        with gr.Row():
            # Left column - inputs
            with gr.Column(scale=2):
                # Preset selector; "Custom" leaves the current parameter values as-is
                preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
                preset_dropdown = gr.Dropdown(
                    label="Preset",
                    choices=preset_choices,
                    value=presets[0]["id"] if presets else "custom",
                )

                # Model variant; the label shows approximate VRAM in GB
                variant_choices = [
                    (f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"])
                    for v in MUSICGEN_VARIANTS
                ]
                variant_dropdown = gr.Dropdown(
                    label="Model Variant",
                    choices=variant_choices,
                    value="medium",
                )

                # Prompt input
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the music you want to generate...",
                    lines=3,
                    max_lines=5,
                )

                # Prompt suggestions (first four only)
                with gr.Accordion("Prompt Suggestions", open=False):
                    suggestion_btns = []
                    for suggestion in suggestions[:4]:
                        btn = gr.Button(
                            _suggestion_label(suggestion), size="sm", variant="secondary"
                        )
                        suggestion_btns.append((btn, suggestion))

                # Melody conditioning (made visible only for melody variants)
                with gr.Group(visible=False) as melody_section:
                    gr.Markdown("### Melody Conditioning")
                    melody_input = gr.Audio(
                        label="Reference Melody",
                        type="filepath",
                        sources=["upload", "microphone"],
                    )
                    gr.Markdown("*Upload audio to condition generation on its melody*")

                # Parameters
                gr.Markdown("### Parameters")

                # Long-form generation checkbox
                long_form_checkbox = gr.Checkbox(
                    label="Long-form generation",
                    value=False,
                    info="Enable for tracks > 30s (uses continuation, takes longer)",
                )

                duration_slider = gr.Slider(
                    minimum=1,
                    maximum=_MAX_SHORT_DURATION,
                    value=10,
                    step=1,
                    label="Duration (seconds)",
                )

                # Info text for long-form mode
                long_form_info = gr.Markdown(
                    value="*Long-form mode uses continuation to generate extended tracks. "
                    "Generation time increases significantly for longer durations.*",
                    visible=False,
                )

                with gr.Accordion("Advanced Parameters", open=False):
                    with gr.Row():
                        temperature_slider = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            value=1.0,
                            step=0.05,
                            label="Temperature",
                        )
                        cfg_slider = gr.Slider(
                            minimum=1.0,
                            maximum=10.0,
                            value=3.0,
                            step=0.5,
                            label="CFG Coefficient",
                        )

                    with gr.Row():
                        top_k_slider = gr.Slider(
                            minimum=0,
                            maximum=500,
                            value=250,
                            step=10,
                            label="Top-K",
                        )
                        top_p_slider = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.0,
                            step=0.05,
                            label="Top-P",
                        )

                    with gr.Row():
                        seed_input = gr.Number(
                            value=None,
                            label="Seed (empty = random)",
                            precision=0,
                        )

                # Generate buttons
                with gr.Row():
                    generate_btn = gr.Button("🎵 Generate", variant="primary", scale=2)
                    queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)

            # Right column - output
            with gr.Column(scale=3):
                output = create_generation_output()

        # Event handlers

        # Preset change: copy the preset's parameters into the sliders
        def apply_preset(preset_id: str):
            for p in presets:
                if p["id"] == preset_id:
                    params = p.get("parameters", {})
                    return (
                        params.get("duration", 10),
                        params.get("temperature", 1.0),
                        params.get("cfg_coef", 3.0),
                        params.get("top_k", 250),
                        params.get("top_p", 0.0),
                    )
            # Custom preset - don't change values
            return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()

        preset_dropdown.change(
            fn=apply_preset,
            inputs=[preset_dropdown],
            outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
        )

        # Variant change - show/hide melody section
        def on_variant_change(variant: str):
            return gr.update(visible=_supports_melody(variant))

        variant_dropdown.change(
            fn=on_variant_change,
            inputs=[variant_dropdown],
            outputs=[melody_section],
        )

        # Long-form checkbox - update duration slider max (and clamp its value
        # when returning to normal mode so an over-range duration is not submitted)
        def on_long_form_change(long_form: bool, duration: Optional[float]):
            if long_form:
                return (
                    gr.update(
                        maximum=_MAX_LONG_DURATION,
                        label="Duration (seconds) - up to 5 min",
                    ),
                    gr.update(visible=True),
                )
            if duration is not None and duration > _MAX_SHORT_DURATION:
                slider_update = gr.update(
                    maximum=_MAX_SHORT_DURATION,
                    label="Duration (seconds)",
                    value=_MAX_SHORT_DURATION,
                )
            else:
                slider_update = gr.update(
                    maximum=_MAX_SHORT_DURATION,
                    label="Duration (seconds)",
                )
            return slider_update, gr.update(visible=False)

        long_form_checkbox.change(
            fn=on_long_form_change,
            inputs=[long_form_checkbox, duration_slider],
            outputs=[duration_slider, long_form_info],
        )

        # Prompt suggestions; `s=suggestion` binds each value at definition
        # time (avoids the late-binding closure pitfall)
        for btn, suggestion in suggestion_btns:
            btn.click(
                fn=lambda s=suggestion: s,
                outputs=[prompt_input],
            )

        # Generate. Each yielded tuple maps to:
        # (status, progress, audio, waveform, duration, seed)
        async def do_generate(
            prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody, long_form
        ):
            if not _has_prompt(prompt):
                yield (
                    gr.update(value="Please enter a prompt"),
                    gr.update(),
                    gr.update(),
                    gr.update(),
                    gr.update(),
                    gr.update(),
                )
                return

            # Generate status message
            status_msg = "🔄 Generating..."
            if long_form and duration > _MAX_SHORT_DURATION:
                status_msg = f"🔄 Long-form generation ({duration}s, may take several minutes)..."

            # Update status and show the progress indicator
            yield (
                gr.update(value=status_msg),
                gr.update(visible=True, value=0),
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
            )

            try:
                result, generation = await generate_fn(
                    model_id="musicgen",
                    variant=variant,
                    prompts=[prompt],
                    duration=duration,
                    temperature=temperature,
                    top_k=int(top_k),
                    top_p=top_p,
                    cfg_coef=cfg_coef,
                    seed=_parse_seed(seed),
                    conditioning=_build_conditioning(variant, melody),
                    long_form=long_form,
                )

                yield (
                    gr.update(value="✅ Generation complete!"),
                    gr.update(visible=False),
                    gr.update(value=generation.audio_path),
                    gr.update(),
                    gr.update(value=f"{result.duration:.2f}s"),
                    gr.update(value=str(result.seed)),
                )
            except Exception as e:
                # UI boundary: surface the error in the status field instead of crashing
                yield (
                    gr.update(value=f"❌ Error: {str(e)}"),
                    gr.update(visible=False),
                    gr.update(),
                    gr.update(),
                    gr.update(),
                    gr.update(),
                )

        generate_btn.click(
            fn=do_generate,
            inputs=[
                prompt_input,
                variant_dropdown,
                duration_slider,
                temperature_slider,
                cfg_slider,
                top_k_slider,
                top_p_slider,
                seed_input,
                melody_input,
                long_form_checkbox,
            ],
            outputs=[
                output["status"],
                output["progress"],
                output["player"]["audio"],
                output["player"]["waveform"],
                output["player"]["duration"],
                output["player"]["seed"],
            ],
        )

        # Add to queue; returns a status string for the status field
        def do_add_queue(prompt, variant, duration, temperature,
                         cfg_coef, top_k, top_p, seed, melody, long_form):
            if not _has_prompt(prompt):
                return "Please enter a prompt"

            try:
                job = add_to_queue_fn(
                    model_id="musicgen",
                    variant=variant,
                    prompts=[prompt],
                    duration=duration,
                    temperature=temperature,
                    top_k=int(top_k),
                    top_p=top_p,
                    cfg_coef=cfg_coef,
                    seed=_parse_seed(seed),
                    conditioning=_build_conditioning(variant, melody),
                    long_form=long_form,
                )
            except Exception as e:
                # UI boundary: report the failure like do_generate does
                return f"❌ Error: {str(e)}"

            return f"✅ Added to queue: {job.id}"

        queue_btn.click(
            fn=do_add_queue,
            inputs=[
                prompt_input,
                variant_dropdown,
                duration_slider,
                temperature_slider,
                cfg_slider,
                top_k_slider,
                top_p_slider,
                seed_input,
                melody_input,
                long_form_checkbox,
            ],
            outputs=[output["status"]],
        )

    return {
        "preset": preset_dropdown,
        "variant": variant_dropdown,
        "prompt": prompt_input,
        "melody": melody_input,
        "long_form": long_form_checkbox,
        "duration": duration_slider,
        "temperature": temperature_slider,
        "cfg_coef": cfg_slider,
        "top_k": top_k_slider,
        "top_p": top_p_slider,
        "seed": seed_input,
        "generate_btn": generate_btn,
        "queue_btn": queue_btn,
        "output": output,
    }