Initial implementation of AudioCraft Studio
Complete web interface for Meta's AudioCraft AI audio generation:

- Gradio UI with tabs for all 5 model families (MusicGen, AudioGen, MAGNeT, MusicGen Style, JASCO)
- REST API with FastAPI, OpenAPI docs, and API key auth
- VRAM management with ComfyUI coexistence support
- SQLite database for project/generation history
- Batch processing queue for async generation
- Docker deployment optimized for RunPod with RTX 4090

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
17
src/ui/tabs/__init__.py
Normal file
17
src/ui/tabs/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Model tabs for AudioCraft Studio."""
|
||||
|
||||
from src.ui.tabs.dashboard_tab import create_dashboard_tab
|
||||
from src.ui.tabs.musicgen_tab import create_musicgen_tab
|
||||
from src.ui.tabs.audiogen_tab import create_audiogen_tab
|
||||
from src.ui.tabs.magnet_tab import create_magnet_tab
|
||||
from src.ui.tabs.style_tab import create_style_tab
|
||||
from src.ui.tabs.jasco_tab import create_jasco_tab
|
||||
|
||||
__all__ = [
|
||||
"create_dashboard_tab",
|
||||
"create_musicgen_tab",
|
||||
"create_audiogen_tab",
|
||||
"create_magnet_tab",
|
||||
"create_style_tab",
|
||||
"create_jasco_tab",
|
||||
]
|
||||
283
src/ui/tabs/audiogen_tab.py
Normal file
283
src/ui/tabs/audiogen_tab.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""AudioGen tab for text-to-sound generation."""
|
||||
|
||||
import gradio as gr
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
|
||||
from src.ui.components.audio_player import create_generation_output
|
||||
|
||||
|
||||
# Model variants offered in the AudioGen tab's "Model Variant" dropdown.
# Only a single variant is listed; the tab labels each entry with its
# ``vram_mb`` value divided by 1024 and suffixed "GB".
AUDIOGEN_VARIANTS = [
    {
        "id": "medium",
        "name": "Medium",
        "vram_mb": 5000,
        "description": "1.5B params, balanced quality/speed",
    },
]
|
||||
|
||||
|
||||
def create_audiogen_tab(
    generate_fn: Callable[..., Any],
    add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
    """Create the AudioGen text-to-sound tab and wire its event handlers.

    NOTE(review): no ``gr.Blocks`` is opened here, so the caller presumably
    invokes this inside its own Blocks context - confirm against the caller.

    Args:
        generate_fn: Awaited with keyword arguments ``model_id``, ``variant``,
            ``prompts``, ``duration``, ``temperature``, ``top_k``, ``top_p``,
            ``cfg_coef`` and ``seed``. Must resolve to a ``(result,
            generation)`` pair; ``result.duration``, ``result.seed`` and
            ``generation.audio_path`` are read from it.
        add_to_queue_fn: Called synchronously with the same keyword
            arguments; the returned object's ``id`` is shown in the status.

    Returns:
        Dictionary mapping short names to the created components, plus the
        ``output`` mapping returned by ``create_generation_output()``.
    """
    presets = DEFAULT_PRESETS.get("audiogen", [])
    suggestions = PROMPT_SUGGESTIONS.get("audiogen", [])

    with gr.Column():
        gr.Markdown("## 🔊 AudioGen")
        gr.Markdown("Generate sound effects and environmental audio from text")

        with gr.Row():
            # Left column - inputs
            with gr.Column(scale=2):
                # Preset selector; "Custom" is always available as a fallback.
                preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
                preset_dropdown = gr.Dropdown(
                    label="Preset",
                    choices=preset_choices,
                    value=presets[0]["id"] if presets else "custom",
                )

                # Model variant (AudioGen only has medium)
                variant_choices = [
                    (f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"])
                    for v in AUDIOGEN_VARIANTS
                ]
                variant_dropdown = gr.Dropdown(
                    label="Model Variant",
                    choices=variant_choices,
                    value="medium",
                )

                # Prompt input
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the sound you want to generate...",
                    lines=3,
                    max_lines=5,
                )

                # Prompt suggestions: one button per suggestion (first six).
                with gr.Accordion("Prompt Suggestions", open=False):
                    suggestion_btns = []
                    for suggestion in suggestions[:6]:
                        # Only add an ellipsis when the label is really cut.
                        label = suggestion if len(suggestion) <= 50 else suggestion[:50] + "..."
                        btn = gr.Button(label, size="sm", variant="secondary")
                        suggestion_btns.append((btn, suggestion))

                # Parameters
                gr.Markdown("### Parameters")

                duration_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Duration (seconds)",
                    info="AudioGen works best with shorter clips",
                )

                with gr.Accordion("Advanced Parameters", open=False):
                    with gr.Row():
                        temperature_slider = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            value=1.0,
                            step=0.05,
                            label="Temperature",
                        )
                        cfg_slider = gr.Slider(
                            minimum=1.0,
                            maximum=10.0,
                            value=3.0,
                            step=0.5,
                            label="CFG Coefficient",
                        )

                    with gr.Row():
                        top_k_slider = gr.Slider(
                            minimum=0,
                            maximum=500,
                            value=250,
                            step=10,
                            label="Top-K",
                        )
                        top_p_slider = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.0,
                            step=0.05,
                            label="Top-P",
                        )

                    with gr.Row():
                        seed_input = gr.Number(
                            value=None,
                            label="Seed (empty = random)",
                            precision=0,
                        )

                # Generate buttons
                with gr.Row():
                    generate_btn = gr.Button("🔊 Generate", variant="primary", scale=2)
                    queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)

            # Right column - output
            with gr.Column(scale=3):
                output = create_generation_output()

    # Event handlers

    def status_frame(message, progress):
        """Return a 6-output frame that changes only status and progress."""
        return (
            gr.update(value=message),
            progress,
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
        )

    # Preset change: copy the preset's parameters into the sliders.
    def apply_preset(preset_id: str):
        for p in presets:
            if p["id"] == preset_id:
                params = p["parameters"]
                return (
                    params.get("duration", 5),
                    params.get("temperature", 1.0),
                    params.get("cfg_coef", 3.0),
                    params.get("top_k", 250),
                    params.get("top_p", 0.0),
                )
        # Unknown id (e.g. "custom"): leave every slider untouched.
        return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()

    preset_dropdown.change(
        fn=apply_preset,
        inputs=[preset_dropdown],
        outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
    )

    # Prompt suggestions: the default argument binds each button's own text
    # (avoids the late-binding closure pitfall).
    for btn, suggestion in suggestion_btns:
        btn.click(
            fn=lambda s=suggestion: s,
            outputs=[prompt_input],
        )

    # Generate
    async def do_generate(
        prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed
    ):
        if not prompt:
            # An async generator may not `return` a value (SyntaxError), so
            # the validation message is yielded and the generator then ends.
            yield status_frame("Please enter a prompt", gr.update())
            return

        yield status_frame("🔄 Generating...", gr.update(visible=True, value=0))

        try:
            result, generation = await generate_fn(
                model_id="audiogen",
                variant=variant,
                prompts=[prompt],
                duration=duration,
                temperature=temperature,
                top_k=int(top_k),
                top_p=top_p,
                cfg_coef=cfg_coef,
                # `is not None` keeps an explicit seed of 0 reproducible.
                seed=int(seed) if seed is not None else None,
            )

            yield (
                gr.update(value="✅ Generation complete!"),
                gr.update(visible=False),
                gr.update(value=generation.audio_path),
                gr.update(),
                gr.update(value=f"{result.duration:.2f}s"),
                gr.update(value=str(result.seed)),
            )

        except Exception as e:
            # UI boundary: surface any failure in the status field.
            yield status_frame(f"❌ Error: {str(e)}", gr.update(visible=False))

    # Both buttons read exactly the same controls, in handler-argument order.
    generation_inputs = [
        prompt_input,
        variant_dropdown,
        duration_slider,
        temperature_slider,
        cfg_slider,
        top_k_slider,
        top_p_slider,
        seed_input,
    ]

    generate_btn.click(
        fn=do_generate,
        inputs=generation_inputs,
        outputs=[
            output["status"],
            output["progress"],
            output["player"]["audio"],
            output["player"]["waveform"],
            output["player"]["duration"],
            output["player"]["seed"],
        ],
    )

    # Add to queue
    def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed):
        if not prompt:
            return "Please enter a prompt"

        job = add_to_queue_fn(
            model_id="audiogen",
            variant=variant,
            prompts=[prompt],
            duration=duration,
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
            cfg_coef=cfg_coef,
            seed=int(seed) if seed is not None else None,
        )

        return f"✅ Added to queue: {job.id}"

    queue_btn.click(
        fn=do_add_queue,
        inputs=generation_inputs,
        outputs=[output["status"]],
    )

    return {
        "preset": preset_dropdown,
        "variant": variant_dropdown,
        "prompt": prompt_input,
        "duration": duration_slider,
        "temperature": temperature_slider,
        "cfg_coef": cfg_slider,
        "top_k": top_k_slider,
        "top_p": top_p_slider,
        "seed": seed_input,
        "generate_btn": generate_btn,
        "queue_btn": queue_btn,
        "output": output,
    }
|
||||
166
src/ui/tabs/dashboard_tab.py
Normal file
166
src/ui/tabs/dashboard_tab.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Dashboard tab - home page with model overview and quick actions."""
|
||||
|
||||
import gradio as gr
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
|
||||
# Display metadata for the dashboard's model cards, keyed by model id.
# Each entry supplies the card title ("icon" + "name"), its description
# line, and the list joined into the "Features" line.
MODEL_INFO = {
    "musicgen": dict(
        name="MusicGen",
        icon="🎵",
        description="Text-to-music generation with optional melody conditioning",
        capabilities=["Text prompts", "Melody conditioning", "Stereo output"],
    ),
    "audiogen": dict(
        name="AudioGen",
        icon="🔊",
        description="Text-to-sound effects and environmental audio",
        capabilities=["Sound effects", "Ambiences", "Foley"],
    ),
    "magnet": dict(
        name="MAGNeT",
        icon="⚡",
        description="Fast non-autoregressive music generation",
        capabilities=["Fast generation", "Music", "Sound effects"],
    ),
    "musicgen-style": dict(
        name="MusicGen Style",
        icon="🎨",
        description="Style-conditioned music from reference audio",
        capabilities=["Style transfer", "Reference audio", "Text prompts"],
    ),
    "jasco": dict(
        name="JASCO",
        icon="🎹",
        description="Chord and drum-conditioned music generation",
        capabilities=["Chord control", "Drum patterns", "Symbolic conditioning"],
    ),
}
|
||||
|
||||
|
||||
def create_dashboard_tab(
    get_queue_status: Callable[[], dict[str, Any]],
    get_recent_generations: Callable[[int], list[dict[str, Any]]],
    get_gpu_status: Callable[[], dict[str, Any]],
) -> dict[str, Any]:
    """Create the dashboard tab with model overview and status.

    The status texts start as "Loading..." placeholders and are only filled
    in when ``refresh_dashboard`` runs (via the Refresh button, or by the
    caller using the returned ``refresh_fn``).

    Args:
        get_queue_status: Returns a mapping; its ``queue_size`` entry is shown.
        get_recent_generations: Called with ``5``; each returned mapping is
            read for ``model``, ``prompt`` and ``duration_seconds``.
        get_gpu_status: Returns a mapping read for ``used_gb``, ``total_gb``
            and ``utilization_percent``.

    Returns:
        Dictionary with the status components, the refresh button and the
        refresh callback (``refresh_fn``).
    """

    def refresh_dashboard() -> tuple[str, str, str]:
        """Return Markdown for queue status, recent generations and GPU status."""
        queue = get_queue_status()
        recent = get_recent_generations(5)
        gpu = get_gpu_status()

        # Format queue status
        queue_size = queue.get("queue_size", 0)
        queue_text = f"**Queue:** {queue_size} job(s) pending"

        # Format recent generations
        if recent:
            recent_items = []
            for gen in recent[:5]:
                model = gen.get("model", "unknown")
                # `or` also covers keys that are present but hold None,
                # which would otherwise break slicing / float formatting.
                full_prompt = gen.get("prompt") or ""
                duration = gen.get("duration_seconds") or 0
                # Only mark the prompt as truncated when it actually was.
                prompt = full_prompt if len(full_prompt) <= 50 else full_prompt[:50] + "..."
                recent_items.append(f"• **{model}** ({duration:.0f}s): {prompt}")
            recent_text = "\n".join(recent_items)
        else:
            recent_text = "No recent generations"

        # Format GPU status; None values fall back to the same defaults as
        # missing keys so the float formatting cannot raise.
        used_gb = gpu.get("used_gb") or 0
        total_gb = gpu.get("total_gb", 24)
        if total_gb is None:
            total_gb = 24
        util = gpu.get("utilization_percent") or 0
        gpu_text = f"**GPU:** {used_gb:.1f}/{total_gb:.1f} GB ({util:.0f}%)"

        return queue_text, recent_text, gpu_text

    def render_model_card(model_id: str) -> None:
        """Render one model card from MODEL_INFO inside the current row."""
        info = MODEL_INFO[model_id]
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown(f"### {info['icon']} {info['name']}")
                gr.Markdown(info["description"])
                gr.Markdown("**Features:** " + ", ".join(info["capabilities"]))
                # No click handler is attached here; the button is only
                # addressable through its elem_id.
                gr.Button(
                    f"Open {info['name']}",
                    variant="primary",
                    size="sm",
                    elem_id=f"btn_{model_id}",
                )

    with gr.Column():
        # Header
        gr.Markdown("# AudioCraft Studio")
        gr.Markdown("AI-powered music and sound generation")

        # Status bar
        with gr.Row():
            queue_status = gr.Markdown("**Queue:** Loading...")
            gpu_status = gr.Markdown("**GPU:** Loading...")
            refresh_btn = gr.Button("🔄 Refresh", size="sm")

        gr.Markdown("---")

        # Model cards
        gr.Markdown("## Models")

        with gr.Row():
            # First row of cards
            for model_id in ["musicgen", "audiogen", "magnet"]:
                render_model_card(model_id)

        with gr.Row():
            # Second row of cards
            for model_id in ["musicgen-style", "jasco"]:
                render_model_card(model_id)

            # Empty column for balance
            with gr.Column(scale=1):
                pass

        gr.Markdown("---")

        # Recent generations and queue
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("## Recent Generations")
                recent_list = gr.Markdown("Loading...")

            with gr.Column(scale=1):
                gr.Markdown("## Quick Actions")
                # These buttons are created without handlers and are not
                # returned to the caller.
                with gr.Group():
                    gr.Button("📁 Browse Projects", variant="secondary")
                    gr.Button("⚙️ Settings", variant="secondary")
                    gr.Button("📖 API Documentation", variant="secondary")

    # Refresh handler
    refresh_btn.click(
        fn=refresh_dashboard,
        outputs=[queue_status, recent_list, gpu_status],
    )

    return {
        "queue_status": queue_status,
        "gpu_status": gpu_status,
        "recent_list": recent_list,
        "refresh_btn": refresh_btn,
        "refresh_fn": refresh_dashboard,
    }
|
||||
364
src/ui/tabs/jasco_tab.py
Normal file
364
src/ui/tabs/jasco_tab.py
Normal file
@@ -0,0 +1,364 @@
|
||||
"""JASCO tab for chord and drum-conditioned generation."""
|
||||
|
||||
import gradio as gr
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
|
||||
from src.ui.components.audio_player import create_generation_output
|
||||
|
||||
|
||||
# Model variants offered in the JASCO tab's "Model Variant" dropdown.
# Rows are (id, name, vram_mb, description).
JASCO_VARIANTS = [
    {"id": variant_id, "name": name, "vram_mb": vram_mb, "description": description}
    for variant_id, name, vram_mb, description in (
        ("chords", "Chords", 5000, "Chord-conditioned generation"),
        ("chords-drums", "Chords + Drums", 5500, "Full symbolic conditioning"),
    )
]
|
||||
|
||||
# Common chord progressions, offered as one-click buttons in the JASCO tab.
# Rows are (button label, space-separated chord symbols).
CHORD_PRESETS = [
    {"name": name, "chords": chords}
    for name, chords in (
        ("Pop I-V-vi-IV", "C G Am F"),
        ("Jazz ii-V-I", "Dm7 G7 Cmaj7"),
        ("Blues I-IV-V", "A7 D7 E7"),
        ("Rock I-bVII-IV", "E D A"),
        ("Minor i-VI-III-VII", "Am F C G"),
    )
]
|
||||
|
||||
|
||||
def create_jasco_tab(
    generate_fn: Callable[..., Any],
    add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
    """Create the JASCO chord/drum-conditioned tab and wire its handlers.

    NOTE(review): no ``gr.Blocks`` is opened here, so the caller presumably
    invokes this inside its own Blocks context - confirm against the caller.

    Args:
        generate_fn: Awaited with keyword arguments ``model_id``, ``variant``,
            ``prompts``, ``duration``, ``temperature``, ``top_k``, ``top_p``,
            ``cfg_coef``, ``seed`` and ``conditioning``. Must resolve to a
            ``(result, generation)`` pair; ``result.duration``,
            ``result.seed`` and ``generation.audio_path`` are read from it.
        add_to_queue_fn: Called synchronously with the same keyword
            arguments; the returned object's ``id`` is shown in the status.

    Returns:
        Dictionary mapping short names to the created components, plus the
        ``output`` mapping returned by ``create_generation_output()``.
    """
    presets = DEFAULT_PRESETS.get("jasco", [])

    with gr.Column():
        gr.Markdown("## 🎹 JASCO")
        gr.Markdown("Generate music conditioned on chords and drum patterns")

        with gr.Row():
            # Left column - inputs
            with gr.Column(scale=2):
                # Preset selector; "Custom" is always available as a fallback.
                preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
                preset_dropdown = gr.Dropdown(
                    label="Preset",
                    choices=preset_choices,
                    value=presets[0]["id"] if presets else "custom",
                )

                # Model variant
                variant_choices = [
                    (f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"])
                    for v in JASCO_VARIANTS
                ]
                variant_dropdown = gr.Dropdown(
                    label="Model Variant",
                    choices=variant_choices,
                    value="chords-drums",
                )

                # Prompt input (optional: an empty prompt is sent as "")
                prompt_input = gr.Textbox(
                    label="Text Prompt",
                    placeholder="Describe the music style, mood, instruments...",
                    lines=2,
                    max_lines=4,
                )

                # Chord conditioning
                gr.Markdown("### Chord Progression")

                chord_input = gr.Textbox(
                    label="Chords",
                    placeholder="C G Am F or Cmaj7 Dm7 G7 Cmaj7",
                    lines=1,
                    info="Space-separated chord symbols",
                )

                # Chord presets, laid out as two rows of buttons.
                with gr.Accordion("Chord Presets", open=False):
                    chord_preset_btns = []
                    with gr.Row():
                        for cp in CHORD_PRESETS[:3]:
                            btn = gr.Button(cp["name"], size="sm", variant="secondary")
                            chord_preset_btns.append((btn, cp["chords"]))
                    with gr.Row():
                        for cp in CHORD_PRESETS[3:]:
                            btn = gr.Button(cp["name"], size="sm", variant="secondary")
                            chord_preset_btns.append((btn, cp["chords"]))

                # Drum conditioning (for chords-drums variant); starts
                # visible because "chords-drums" is the default variant.
                with gr.Group(visible=True) as drum_section:
                    gr.Markdown("### Drum Pattern")

                    drum_input = gr.Audio(
                        label="Drum Reference",
                        type="filepath",
                        sources=["upload"],
                    )
                    gr.Markdown("*Upload a drum loop to condition the rhythm*")

                # Parameters
                gr.Markdown("### Parameters")

                duration_slider = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)",
                )

                bpm_slider = gr.Slider(
                    minimum=60,
                    maximum=180,
                    value=120,
                    step=1,
                    label="BPM",
                    info="Tempo for chord timing",
                )

                with gr.Accordion("Advanced Parameters", open=False):
                    with gr.Row():
                        temperature_slider = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            value=1.0,
                            step=0.05,
                            label="Temperature",
                        )
                        cfg_slider = gr.Slider(
                            minimum=1.0,
                            maximum=10.0,
                            value=3.0,
                            step=0.5,
                            label="CFG Coefficient",
                        )

                    with gr.Row():
                        top_k_slider = gr.Slider(
                            minimum=0,
                            maximum=500,
                            value=250,
                            step=10,
                            label="Top-K",
                        )
                        top_p_slider = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.0,
                            step=0.05,
                            label="Top-P",
                        )

                    with gr.Row():
                        seed_input = gr.Number(
                            value=None,
                            label="Seed (empty = random)",
                            precision=0,
                        )

                # Generate buttons
                with gr.Row():
                    generate_btn = gr.Button("🎹 Generate", variant="primary", scale=2)
                    queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)

            # Right column - output
            with gr.Column(scale=3):
                output = create_generation_output()

    # Event handlers

    def status_frame(message, progress):
        """Return a 6-output frame that changes only status and progress."""
        return (
            gr.update(value=message),
            progress,
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
        )

    def uses_drums(variant) -> bool:
        """Tell whether the variant id names a drum-conditioned model."""
        # A cleared dropdown may deliver None; treat that as "no drums".
        return "drums" in (variant or "").lower()

    def build_conditioning(chords, bpm, drums, variant) -> dict[str, Any]:
        """Assemble the conditioning mapping shared by generate and queue."""
        conditioning = {
            "chords": chords,
            "bpm": bpm,
        }
        # The drum reference is only forwarded for drum-capable variants.
        if drums and uses_drums(variant):
            conditioning["drums"] = drums
        return conditioning

    # Preset change: copy the preset's parameters into the sliders.
    def apply_preset(preset_id: str):
        for p in presets:
            if p["id"] == preset_id:
                params = p["parameters"]
                return (
                    params.get("duration", 10),
                    params.get("bpm", 120),
                    params.get("temperature", 1.0),
                    params.get("cfg_coef", 3.0),
                    params.get("top_k", 250),
                    params.get("top_p", 0.0),
                )
        # Unknown id (e.g. "custom"): leave every slider untouched.
        return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()

    preset_dropdown.change(
        fn=apply_preset,
        inputs=[preset_dropdown],
        outputs=[duration_slider, bpm_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
    )

    # Variant change - show/hide drum section
    def on_variant_change(variant: str):
        return gr.update(visible=uses_drums(variant))

    variant_dropdown.change(
        fn=on_variant_change,
        inputs=[variant_dropdown],
        outputs=[drum_section],
    )

    # Chord presets: the default argument binds each button's own chords
    # (avoids the late-binding closure pitfall).
    for btn, chords in chord_preset_btns:
        btn.click(
            fn=lambda c=chords: c,
            outputs=[chord_input],
        )

    # Generate
    async def do_generate(
        prompt, variant, chords, drums, duration, bpm, temperature, cfg_coef, top_k, top_p, seed
    ):
        if not chords:
            # An async generator may not `return` a value (SyntaxError), so
            # the validation message is yielded and the generator then ends.
            yield status_frame("Please enter a chord progression", gr.update())
            return

        yield status_frame("🔄 Generating...", gr.update(visible=True, value=0))

        try:
            conditioning = build_conditioning(chords, bpm, drums, variant)

            result, generation = await generate_fn(
                model_id="jasco",
                variant=variant,
                prompts=[prompt] if prompt else [""],
                duration=duration,
                temperature=temperature,
                top_k=int(top_k),
                top_p=top_p,
                cfg_coef=cfg_coef,
                # `is not None` keeps an explicit seed of 0 reproducible.
                seed=int(seed) if seed is not None else None,
                conditioning=conditioning,
            )

            yield (
                gr.update(value="✅ Generation complete!"),
                gr.update(visible=False),
                gr.update(value=generation.audio_path),
                gr.update(),
                gr.update(value=f"{result.duration:.2f}s"),
                gr.update(value=str(result.seed)),
            )

        except Exception as e:
            # UI boundary: surface any failure in the status field.
            yield status_frame(f"❌ Error: {str(e)}", gr.update(visible=False))

    # Both buttons read exactly the same controls, in handler-argument order.
    generation_inputs = [
        prompt_input,
        variant_dropdown,
        chord_input,
        drum_input,
        duration_slider,
        bpm_slider,
        temperature_slider,
        cfg_slider,
        top_k_slider,
        top_p_slider,
        seed_input,
    ]

    generate_btn.click(
        fn=do_generate,
        inputs=generation_inputs,
        outputs=[
            output["status"],
            output["progress"],
            output["player"]["audio"],
            output["player"]["waveform"],
            output["player"]["duration"],
            output["player"]["seed"],
        ],
    )

    # Add to queue
    def do_add_queue(prompt, variant, chords, drums, duration, bpm, temperature, cfg_coef, top_k, top_p, seed):
        if not chords:
            return "Please enter a chord progression"

        job = add_to_queue_fn(
            model_id="jasco",
            variant=variant,
            prompts=[prompt] if prompt else [""],
            duration=duration,
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
            cfg_coef=cfg_coef,
            seed=int(seed) if seed is not None else None,
            conditioning=build_conditioning(chords, bpm, drums, variant),
        )

        return f"✅ Added to queue: {job.id}"

    queue_btn.click(
        fn=do_add_queue,
        inputs=generation_inputs,
        outputs=[output["status"]],
    )

    return {
        "preset": preset_dropdown,
        "variant": variant_dropdown,
        "prompt": prompt_input,
        "chords": chord_input,
        "drums": drum_input,
        "duration": duration_slider,
        "bpm": bpm_slider,
        "temperature": temperature_slider,
        "cfg_coef": cfg_slider,
        "top_k": top_k_slider,
        "top_p": top_p_slider,
        "seed": seed_input,
        "generate_btn": generate_btn,
        "queue_btn": queue_btn,
        "output": output,
    }
|
||||
316
src/ui/tabs/magnet_tab.py
Normal file
316
src/ui/tabs/magnet_tab.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""MAGNeT tab for fast non-autoregressive generation."""
|
||||
|
||||
import gradio as gr
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
|
||||
from src.ui.components.audio_player import create_generation_output
|
||||
|
||||
|
||||
# Model variants offered in the MAGNeT tab's "Model Variant" dropdown.
# Rows are (id, name, vram_mb, description).
MAGNET_VARIANTS = [
    {"id": variant_id, "name": name, "vram_mb": vram_mb, "description": description}
    for variant_id, name, vram_mb, description in (
        ("small", "Small Music", 2000, "Fast music, 300M params"),
        ("medium", "Medium Music", 5000, "Balanced music, 1.5B params"),
        ("audio-small", "Small Audio", 2000, "Fast sound effects"),
        ("audio-medium", "Medium Audio", 5000, "Balanced sound effects"),
    )
]
|
||||
|
||||
|
||||
def create_magnet_tab(
|
||||
generate_fn: Callable[..., Any],
|
||||
add_to_queue_fn: Callable[..., Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Create MAGNeT generation tab.
|
||||
|
||||
Args:
|
||||
generate_fn: Function to call for generation
|
||||
add_to_queue_fn: Function to add to queue
|
||||
|
||||
Returns:
|
||||
Dictionary with component references
|
||||
"""
|
||||
presets = DEFAULT_PRESETS.get("magnet", [])
|
||||
suggestions = PROMPT_SUGGESTIONS.get("musicgen", []) # Reuse music suggestions
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown("## ⚡ MAGNeT")
|
||||
gr.Markdown("Fast non-autoregressive music and sound generation")
|
||||
|
||||
with gr.Row():
|
||||
# Left column - inputs
|
||||
with gr.Column(scale=2):
|
||||
# Preset selector
|
||||
preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
|
||||
preset_dropdown = gr.Dropdown(
|
||||
label="Preset",
|
||||
choices=preset_choices,
|
||||
value=presets[0]["id"] if presets else "custom",
|
||||
)
|
||||
|
||||
# Model variant
|
||||
variant_choices = [(f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"]) for v in MAGNET_VARIANTS]
|
||||
variant_dropdown = gr.Dropdown(
|
||||
label="Model Variant",
|
||||
choices=variant_choices,
|
||||
value="medium",
|
||||
)
|
||||
|
||||
# Prompt input
|
||||
prompt_input = gr.Textbox(
|
||||
label="Prompt",
|
||||
placeholder="Describe the music or sound you want to generate...",
|
||||
lines=3,
|
||||
max_lines=5,
|
||||
)
|
||||
|
||||
# Prompt suggestions
|
||||
with gr.Accordion("Prompt Suggestions", open=False):
|
||||
suggestion_btns = []
|
||||
for i, suggestion in enumerate(suggestions[:4]):
|
||||
btn = gr.Button(suggestion[:60] + "...", size="sm", variant="secondary")
|
||||
suggestion_btns.append((btn, suggestion))
|
||||
|
||||
# Parameters
|
||||
gr.Markdown("### Parameters")
|
||||
|
||||
duration_slider = gr.Slider(
|
||||
minimum=1,
|
||||
maximum=30,
|
||||
value=10,
|
||||
step=1,
|
||||
label="Duration (seconds)",
|
||||
)
|
||||
|
||||
with gr.Accordion("Advanced Parameters", open=False):
|
||||
gr.Markdown("*MAGNeT uses different sampling compared to MusicGen*")
|
||||
|
||||
with gr.Row():
|
||||
temperature_slider = gr.Slider(
|
||||
minimum=1.0,
|
||||
maximum=5.0,
|
||||
value=3.0,
|
||||
step=0.1,
|
||||
label="Temperature",
|
||||
info="Higher values recommended (3.0 default)",
|
||||
)
|
||||
cfg_slider = gr.Slider(
|
||||
minimum=1.0,
|
||||
maximum=10.0,
|
||||
value=3.0,
|
||||
step=0.5,
|
||||
label="CFG Coefficient",
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
top_k_slider = gr.Slider(
|
||||
minimum=0,
|
||||
maximum=500,
|
||||
value=0,
|
||||
step=10,
|
||||
label="Top-K",
|
||||
info="0 recommended for MAGNeT",
|
||||
)
|
||||
top_p_slider = gr.Slider(
|
||||
minimum=0.0,
|
||||
maximum=1.0,
|
||||
value=0.9,
|
||||
step=0.05,
|
||||
label="Top-P",
|
||||
info="0.9 recommended for MAGNeT",
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
decoding_steps_slider = gr.Slider(
|
||||
minimum=10,
|
||||
maximum=100,
|
||||
value=20,
|
||||
step=5,
|
||||
label="Decoding Steps",
|
||||
info="More steps = better quality, slower",
|
||||
)
|
||||
span_arrangement = gr.Dropdown(
|
||||
label="Span Arrangement",
|
||||
choices=[("No Overlap", "nonoverlap"), ("Overlap", "stride1")],
|
||||
value="nonoverlap",
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
seed_input = gr.Number(
|
||||
value=None,
|
||||
label="Seed (empty = random)",
|
||||
precision=0,
|
||||
)
|
||||
|
||||
# Generate buttons
|
||||
with gr.Row():
|
||||
generate_btn = gr.Button("⚡ Generate", variant="primary", scale=2)
|
||||
queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)
|
||||
|
||||
# Right column - output
|
||||
with gr.Column(scale=3):
|
||||
output = create_generation_output()
|
||||
|
||||
# Event handlers
|
||||
|
||||
# Preset change
|
||||
def apply_preset(preset_id: str):
|
||||
for p in presets:
|
||||
if p["id"] == preset_id:
|
||||
params = p["parameters"]
|
||||
return (
|
||||
params.get("duration", 10),
|
||||
params.get("temperature", 3.0),
|
||||
params.get("cfg_coef", 3.0),
|
||||
params.get("top_k", 0),
|
||||
params.get("top_p", 0.9),
|
||||
params.get("decoding_steps", 20),
|
||||
)
|
||||
return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
|
||||
|
||||
preset_dropdown.change(
|
||||
fn=apply_preset,
|
||||
inputs=[preset_dropdown],
|
||||
outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider, decoding_steps_slider],
|
||||
)
|
||||
|
||||
# Prompt suggestions
|
||||
for btn, suggestion in suggestion_btns:
|
||||
btn.click(
|
||||
fn=lambda s=suggestion: s,
|
||||
outputs=[prompt_input],
|
||||
)
|
||||
|
||||
# Generate
|
||||
async def do_generate(
|
||||
prompt, variant, duration, temperature, cfg_coef, top_k, top_p, decoding_steps, span_arr, seed
|
||||
):
|
||||
if not prompt:
|
||||
return (
|
||||
gr.update(value="Please enter a prompt"),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
)
|
||||
|
||||
yield (
|
||||
gr.update(value="🔄 Generating..."),
|
||||
gr.update(visible=True, value=0),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
)
|
||||
|
||||
try:
|
||||
result, generation = await generate_fn(
|
||||
model_id="magnet",
|
||||
variant=variant,
|
||||
prompts=[prompt],
|
||||
duration=duration,
|
||||
temperature=temperature,
|
||||
top_k=int(top_k),
|
||||
top_p=top_p,
|
||||
cfg_coef=cfg_coef,
|
||||
decoding_steps=int(decoding_steps),
|
||||
span_arrangement=span_arr,
|
||||
seed=int(seed) if seed else None,
|
||||
)
|
||||
|
||||
yield (
|
||||
gr.update(value="✅ Generation complete!"),
|
||||
gr.update(visible=False),
|
||||
gr.update(value=generation.audio_path),
|
||||
gr.update(),
|
||||
gr.update(value=f"{result.duration:.2f}s"),
|
||||
gr.update(value=str(result.seed)),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
yield (
|
||||
gr.update(value=f"❌ Error: {str(e)}"),
|
||||
gr.update(visible=False),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
)
|
||||
|
||||
generate_btn.click(
|
||||
fn=do_generate,
|
||||
inputs=[
|
||||
prompt_input,
|
||||
variant_dropdown,
|
||||
duration_slider,
|
||||
temperature_slider,
|
||||
cfg_slider,
|
||||
top_k_slider,
|
||||
top_p_slider,
|
||||
decoding_steps_slider,
|
||||
span_arrangement,
|
||||
seed_input,
|
||||
],
|
||||
outputs=[
|
||||
output["status"],
|
||||
output["progress"],
|
||||
output["player"]["audio"],
|
||||
output["player"]["waveform"],
|
||||
output["player"]["duration"],
|
||||
output["player"]["seed"],
|
||||
],
|
||||
)
|
||||
|
||||
# Add to queue
|
||||
def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, decoding_steps, span_arr, seed):
    """Queue a MAGNeT generation job and return a status message.

    Returns the validation message when the prompt is empty or only
    whitespace; otherwise calls ``add_to_queue_fn`` with the request
    keyword arguments and reports the returned job's ``id``.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt"

    job = add_to_queue_fn(
        model_id="magnet",
        variant=variant,
        prompts=[prompt],
        duration=duration,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        cfg_coef=cfg_coef,
        decoding_steps=int(decoding_steps),
        span_arrangement=span_arr,
        # ``is not None`` keeps an explicit seed of 0 instead of treating
        # it as "no seed".
        seed=int(seed) if seed is not None else None,
    )

    return f"✅ Added to queue: {job.id}"
|
||||
|
||||
queue_btn.click(
|
||||
fn=do_add_queue,
|
||||
inputs=[
|
||||
prompt_input,
|
||||
variant_dropdown,
|
||||
duration_slider,
|
||||
temperature_slider,
|
||||
cfg_slider,
|
||||
top_k_slider,
|
||||
top_p_slider,
|
||||
decoding_steps_slider,
|
||||
span_arrangement,
|
||||
seed_input,
|
||||
],
|
||||
outputs=[output["status"]],
|
||||
)
|
||||
|
||||
return {
|
||||
"preset": preset_dropdown,
|
||||
"variant": variant_dropdown,
|
||||
"prompt": prompt_input,
|
||||
"duration": duration_slider,
|
||||
"temperature": temperature_slider,
|
||||
"cfg_coef": cfg_slider,
|
||||
"top_k": top_k_slider,
|
||||
"top_p": top_p_slider,
|
||||
"decoding_steps": decoding_steps_slider,
|
||||
"span_arrangement": span_arrangement,
|
||||
"seed": seed_input,
|
||||
"generate_btn": generate_btn,
|
||||
"queue_btn": queue_btn,
|
||||
"output": output,
|
||||
}
|
||||
325
src/ui/tabs/musicgen_tab.py
Normal file
325
src/ui/tabs/musicgen_tab.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""MusicGen tab for text-to-music generation."""
|
||||
|
||||
import gradio as gr
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
|
||||
from src.ui.components.audio_player import create_generation_output
|
||||
|
||||
|
||||
# MusicGen model variants offered in the "Model Variant" dropdown.
# Each entry carries: "id" (dropdown value), "name" (display label),
# "vram_mb" (rendered as GB in the dropdown label) and "description".
MUSICGEN_VARIANTS = [
    {"id": "small", "name": "Small", "vram_mb": 1500, "description": "Fast, 300M params"},
    {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "Balanced, 1.5B params"},
    {"id": "large", "name": "Large", "vram_mb": 10000, "description": "Best quality, 3.3B params"},
    {"id": "melody", "name": "Melody", "vram_mb": 5000, "description": "With melody conditioning"},
    {"id": "stereo-small", "name": "Stereo Small", "vram_mb": 1800, "description": "Stereo, 300M params"},
    {"id": "stereo-medium", "name": "Stereo Medium", "vram_mb": 6000, "description": "Stereo, 1.5B params"},
    {"id": "stereo-large", "name": "Stereo Large", "vram_mb": 12000, "description": "Stereo, 3.3B params"},
    {"id": "stereo-melody", "name": "Stereo Melody", "vram_mb": 6000, "description": "Stereo with melody"},
]
|
||||
|
||||
|
||||
def create_musicgen_tab(
    generate_fn: Callable[..., Any],
    add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
    """Create MusicGen generation tab.

    Builds the input controls (preset, variant, prompt, optional melody
    reference, sampling parameters), the generation output panel, and wires
    the preset, variant, suggestion, generate and queue event handlers.

    Args:
        generate_fn: Function to call for generation. It is awaited with
            keyword arguments (``model_id``, ``variant``, ``prompts``,
            ``duration``, ``temperature``, ``top_k``, ``top_p``,
            ``cfg_coef``, ``seed``, ``conditioning``) and its result is
            unpacked as ``(result, generation)``; ``result.duration``,
            ``result.seed`` and ``generation.audio_path`` are displayed.
        add_to_queue_fn: Function to add to queue. Called with the same
            keyword arguments; the returned job's ``id`` is shown in the
            status field.

    Returns:
        Dictionary with component references
    """
    presets = DEFAULT_PRESETS.get("musicgen", [])
    suggestions = PROMPT_SUGGESTIONS.get("musicgen", [])

    with gr.Column():
        gr.Markdown("## 🎵 MusicGen")
        gr.Markdown("Generate music from text descriptions")

        with gr.Row():
            # Left column - inputs
            with gr.Column(scale=2):
                # Preset selector
                preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
                preset_dropdown = gr.Dropdown(
                    label="Preset",
                    choices=preset_choices,
                    value=presets[0]["id"] if presets else "custom",
                )

                # Model variant
                variant_choices = [
                    (f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"])
                    for v in MUSICGEN_VARIANTS
                ]
                variant_dropdown = gr.Dropdown(
                    label="Model Variant",
                    choices=variant_choices,
                    value="medium",
                )

                # Prompt input
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the music you want to generate...",
                    lines=3,
                    max_lines=5,
                )

                # Prompt suggestions
                with gr.Accordion("Prompt Suggestions", open=False):
                    suggestion_btns = []
                    for suggestion in suggestions[:4]:
                        # Only add the ellipsis when the label is truncated.
                        if len(suggestion) > 60:
                            label = suggestion[:60] + "..."
                        else:
                            label = suggestion
                        btn = gr.Button(label, size="sm", variant="secondary")
                        suggestion_btns.append((btn, suggestion))

                # Melody conditioning (for melody variants)
                with gr.Group(visible=False) as melody_section:
                    gr.Markdown("### Melody Conditioning")
                    melody_input = gr.Audio(
                        label="Reference Melody",
                        type="filepath",
                        sources=["upload", "microphone"],
                    )
                    gr.Markdown("*Upload audio to condition generation on its melody*")

                # Parameters
                gr.Markdown("### Parameters")

                duration_slider = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)",
                )

                with gr.Accordion("Advanced Parameters", open=False):
                    with gr.Row():
                        temperature_slider = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            value=1.0,
                            step=0.05,
                            label="Temperature",
                        )
                        cfg_slider = gr.Slider(
                            minimum=1.0,
                            maximum=10.0,
                            value=3.0,
                            step=0.5,
                            label="CFG Coefficient",
                        )

                    with gr.Row():
                        top_k_slider = gr.Slider(
                            minimum=0,
                            maximum=500,
                            value=250,
                            step=10,
                            label="Top-K",
                        )
                        top_p_slider = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.0,
                            step=0.05,
                            label="Top-P",
                        )

                    with gr.Row():
                        seed_input = gr.Number(
                            value=None,
                            label="Seed (empty = random)",
                            precision=0,
                        )

                # Generate buttons
                with gr.Row():
                    generate_btn = gr.Button("🎵 Generate", variant="primary", scale=2)
                    queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)

            # Right column - output
            with gr.Column(scale=3):
                output = create_generation_output()

    # Event handlers

    def _status_only(message: str, progress: Any) -> tuple:
        """Build a six-output update touching only status and progress."""
        return (
            gr.update(value=message),
            progress,
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
        )

    def _build_request(
        prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody
    ) -> dict[str, Any]:
        """Translate raw control values into backend keyword arguments."""
        conditioning = {}
        # NOTE(review): the melody file is forwarded whenever one is set, even
        # if a non-melody variant is selected -- confirm the backend ignores it.
        if melody:
            conditioning["melody"] = melody
        return {
            "model_id": "musicgen",
            "variant": variant,
            "prompts": [prompt],
            "duration": duration,
            "temperature": temperature,
            "top_k": int(top_k),
            "top_p": top_p,
            "cfg_coef": cfg_coef,
            # ``is not None`` keeps an explicit seed of 0 instead of silently
            # turning it into a random seed.
            "seed": int(seed) if seed is not None else None,
            "conditioning": conditioning,
        }

    # Preset change
    def apply_preset(preset_id: str):
        preset = next((p for p in presets if p["id"] == preset_id), None)
        if preset is None:
            # Custom preset - don't change values
            return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
        params = preset["parameters"]
        return (
            params.get("duration", 10),
            params.get("temperature", 1.0),
            params.get("cfg_coef", 3.0),
            params.get("top_k", 250),
            params.get("top_p", 0.0),
        )

    preset_dropdown.change(
        fn=apply_preset,
        inputs=[preset_dropdown],
        outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
    )

    # Variant change - show/hide melody section
    def on_variant_change(variant: str):
        # Guard against a cleared dropdown, which would otherwise raise on
        # ``None.lower()``.
        show_melody = "melody" in (variant or "").lower()
        return gr.update(visible=show_melody)

    variant_dropdown.change(
        fn=on_variant_change,
        inputs=[variant_dropdown],
        outputs=[melody_section],
    )

    # Prompt suggestions
    for btn, suggestion in suggestion_btns:
        btn.click(
            # Default argument binds the current suggestion (avoids the
            # late-binding closure pitfall).
            fn=lambda s=suggestion: s,
            outputs=[prompt_input],
        )

    # Inputs shared by the generate and queue handlers, in handler-argument order.
    request_inputs = [
        prompt_input,
        variant_dropdown,
        duration_slider,
        temperature_slider,
        cfg_slider,
        top_k_slider,
        top_p_slider,
        seed_input,
        melody_input,
    ]

    # Generate
    async def do_generate(
        prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody
    ):
        if not prompt or not prompt.strip():
            # An async generator may not ``return`` a value (SyntaxError), so
            # the validation message is yielded and the generator then ends.
            yield _status_only("Please enter a prompt", gr.update())
            return

        # Update status
        yield _status_only("🔄 Generating...", gr.update(visible=True, value=0))

        try:
            request = _build_request(
                prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody
            )
            result, generation = await generate_fn(**request)

            yield (
                gr.update(value="✅ Generation complete!"),
                gr.update(visible=False),
                gr.update(value=generation.audio_path),
                gr.update(),
                gr.update(value=f"{result.duration:.2f}s"),
                gr.update(value=str(result.seed)),
            )

        except Exception as e:
            # UI boundary: report any failure in the status field.
            yield _status_only(f"❌ Error: {str(e)}", gr.update(visible=False))

    generate_btn.click(
        fn=do_generate,
        inputs=request_inputs,
        outputs=[
            output["status"],
            output["progress"],
            output["player"]["audio"],
            output["player"]["waveform"],
            output["player"]["duration"],
            output["player"]["seed"],
        ],
    )

    # Add to queue
    def do_add_queue(prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody):
        if not prompt or not prompt.strip():
            return "Please enter a prompt"

        job = add_to_queue_fn(
            **_build_request(
                prompt, variant, duration, temperature, cfg_coef, top_k, top_p, seed, melody
            )
        )

        return f"✅ Added to queue: {job.id}"

    queue_btn.click(
        fn=do_add_queue,
        inputs=request_inputs,
        outputs=[output["status"]],
    )

    return {
        "preset": preset_dropdown,
        "variant": variant_dropdown,
        "prompt": prompt_input,
        "melody": melody_input,
        "duration": duration_slider,
        "temperature": temperature_slider,
        "cfg_coef": cfg_slider,
        "top_k": top_k_slider,
        "top_p": top_p_slider,
        "seed": seed_input,
        "generate_btn": generate_btn,
        "queue_btn": queue_btn,
        "output": output,
    }
|
||||
292
src/ui/tabs/style_tab.py
Normal file
292
src/ui/tabs/style_tab.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""MusicGen Style tab for style-conditioned generation."""
|
||||
|
||||
import gradio as gr
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from src.ui.state import DEFAULT_PRESETS, PROMPT_SUGGESTIONS
|
||||
from src.ui.components.audio_player import create_generation_output
|
||||
|
||||
|
||||
# MusicGen Style model variants offered in the "Model Variant" dropdown.
# Each entry carries: "id" (dropdown value), "name" (display label),
# "vram_mb" (rendered as GB in the dropdown label) and "description".
STYLE_VARIANTS = [
    {"id": "medium", "name": "Medium", "vram_mb": 5000, "description": "1.5B params, style conditioning"},
]
|
||||
|
||||
|
||||
def create_style_tab(
    generate_fn: Callable[..., Any],
    add_to_queue_fn: Callable[..., Any],
) -> dict[str, Any]:
    """Create MusicGen Style generation tab.

    Builds the input controls (preset, variant, optional text prompt,
    required style reference audio, sampling parameters), the generation
    output panel, and wires the preset, generate and queue event handlers.

    Args:
        generate_fn: Function to call for generation. It is awaited with
            keyword arguments (``model_id``, ``variant``, ``prompts``,
            ``duration``, ``temperature``, ``top_k``, ``top_p``,
            ``cfg_coef``, ``seed``, ``conditioning``) and its result is
            unpacked as ``(result, generation)``; ``result.duration``,
            ``result.seed`` and ``generation.audio_path`` are displayed.
        add_to_queue_fn: Function to add to queue. Called with the same
            keyword arguments; the returned job's ``id`` is shown in the
            status field.

    Returns:
        Dictionary with component references
    """
    presets = DEFAULT_PRESETS.get("musicgen-style", [])

    with gr.Column():
        gr.Markdown("## 🎨 MusicGen Style")
        gr.Markdown("Generate music conditioned on the style of reference audio")

        with gr.Row():
            # Left column - inputs
            with gr.Column(scale=2):
                # Preset selector
                preset_choices = [(p["name"], p["id"]) for p in presets] + [("Custom", "custom")]
                preset_dropdown = gr.Dropdown(
                    label="Preset",
                    choices=preset_choices,
                    value=presets[0]["id"] if presets else "custom",
                )

                # Model variant
                variant_choices = [
                    (f"{v['name']} ({v['vram_mb']/1024:.1f}GB)", v["id"])
                    for v in STYLE_VARIANTS
                ]
                variant_dropdown = gr.Dropdown(
                    label="Model Variant",
                    choices=variant_choices,
                    value="medium",
                )

                # Prompt input
                prompt_input = gr.Textbox(
                    label="Text Prompt",
                    placeholder="Describe additional characteristics for the music...",
                    lines=3,
                    max_lines=5,
                    info="Optional: combine with style conditioning",
                )

                # Style conditioning (required)
                gr.Markdown("### Style Conditioning")
                gr.Markdown("*Upload reference audio to extract musical style*")

                style_input = gr.Audio(
                    label="Style Reference",
                    type="filepath",
                    sources=["upload", "microphone"],
                )

                gr.Markdown(
                    "*The model will learn the style (instrumentation, tempo, mood) from this audio*"
                )

                # Parameters
                gr.Markdown("### Parameters")

                duration_slider = gr.Slider(
                    minimum=1,
                    maximum=30,
                    value=10,
                    step=1,
                    label="Duration (seconds)",
                )

                with gr.Accordion("Advanced Parameters", open=False):
                    with gr.Row():
                        temperature_slider = gr.Slider(
                            minimum=0.0,
                            maximum=2.0,
                            value=1.0,
                            step=0.05,
                            label="Temperature",
                        )
                        cfg_slider = gr.Slider(
                            minimum=1.0,
                            maximum=10.0,
                            value=3.0,
                            step=0.5,
                            label="CFG Coefficient",
                        )

                    with gr.Row():
                        top_k_slider = gr.Slider(
                            minimum=0,
                            maximum=500,
                            value=250,
                            step=10,
                            label="Top-K",
                        )
                        top_p_slider = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=0.0,
                            step=0.05,
                            label="Top-P",
                        )

                    with gr.Row():
                        seed_input = gr.Number(
                            value=None,
                            label="Seed (empty = random)",
                            precision=0,
                        )

                # Generate buttons
                with gr.Row():
                    generate_btn = gr.Button("🎨 Generate", variant="primary", scale=2)
                    queue_btn = gr.Button("Add to Queue", variant="secondary", scale=1)

            # Right column - output
            with gr.Column(scale=3):
                output = create_generation_output()

    # Event handlers

    def _status_only(message: str, progress: Any) -> tuple:
        """Build a six-output update touching only status and progress."""
        return (
            gr.update(value=message),
            progress,
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
        )

    def _build_request(
        prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed
    ) -> dict[str, Any]:
        """Translate raw control values into backend keyword arguments."""
        return {
            "model_id": "musicgen-style",
            "variant": variant,
            # The text prompt is optional here; an empty string is sent when
            # none was entered.
            "prompts": [prompt] if prompt else [""],
            "duration": duration,
            "temperature": temperature,
            "top_k": int(top_k),
            "top_p": top_p,
            "cfg_coef": cfg_coef,
            # ``is not None`` keeps an explicit seed of 0 instead of silently
            # turning it into a random seed.
            "seed": int(seed) if seed is not None else None,
            "conditioning": {"style": style_audio},
        }

    # Preset change
    def apply_preset(preset_id: str):
        preset = next((p for p in presets if p["id"] == preset_id), None)
        if preset is None:
            # Custom preset - don't change values
            return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
        params = preset["parameters"]
        return (
            params.get("duration", 10),
            params.get("temperature", 1.0),
            params.get("cfg_coef", 3.0),
            params.get("top_k", 250),
            params.get("top_p", 0.0),
        )

    preset_dropdown.change(
        fn=apply_preset,
        inputs=[preset_dropdown],
        outputs=[duration_slider, temperature_slider, cfg_slider, top_k_slider, top_p_slider],
    )

    # Inputs shared by the generate and queue handlers, in handler-argument order.
    request_inputs = [
        prompt_input,
        variant_dropdown,
        style_input,
        duration_slider,
        temperature_slider,
        cfg_slider,
        top_k_slider,
        top_p_slider,
        seed_input,
    ]

    # Generate
    async def do_generate(
        prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed
    ):
        if not style_audio:
            # An async generator may not ``return`` a value (SyntaxError), so
            # the validation message is yielded and the generator then ends.
            yield _status_only("Please upload a style reference audio", gr.update())
            return

        yield _status_only("🔄 Generating...", gr.update(visible=True, value=0))

        try:
            request = _build_request(
                prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed
            )
            result, generation = await generate_fn(**request)

            yield (
                gr.update(value="✅ Generation complete!"),
                gr.update(visible=False),
                gr.update(value=generation.audio_path),
                gr.update(),
                gr.update(value=f"{result.duration:.2f}s"),
                gr.update(value=str(result.seed)),
            )

        except Exception as e:
            # UI boundary: report any failure in the status field.
            yield _status_only(f"❌ Error: {str(e)}", gr.update(visible=False))

    generate_btn.click(
        fn=do_generate,
        inputs=request_inputs,
        outputs=[
            output["status"],
            output["progress"],
            output["player"]["audio"],
            output["player"]["waveform"],
            output["player"]["duration"],
            output["player"]["seed"],
        ],
    )

    # Add to queue
    def do_add_queue(prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed):
        if not style_audio:
            return "Please upload a style reference audio"

        job = add_to_queue_fn(
            **_build_request(
                prompt, variant, style_audio, duration, temperature, cfg_coef, top_k, top_p, seed
            )
        )

        return f"✅ Added to queue: {job.id}"

    queue_btn.click(
        fn=do_add_queue,
        inputs=request_inputs,
        outputs=[output["status"]],
    )

    return {
        "preset": preset_dropdown,
        "variant": variant_dropdown,
        "prompt": prompt_input,
        "style": style_input,
        "duration": duration_slider,
        "temperature": temperature_slider,
        "cfg_coef": cfg_slider,
        "top_k": top_k_slider,
        "top_p": top_p_slider,
        "seed": seed_input,
        "generate_btn": generate_btn,
        "queue_btn": queue_btn,
        "output": output,
    }
|
||||
Reference in New Issue
Block a user