Files
audiocraft-ui/scripts/download_models.py
Sebastian Krüger ffbf02b12c Initial implementation of AudioCraft Studio
Complete web interface for Meta's AudioCraft AI audio generation:

- Gradio UI with tabs for all 5 model families (MusicGen, AudioGen,
  MAGNeT, MusicGen Style, JASCO)
- REST API with FastAPI, OpenAPI docs, and API key auth
- VRAM management with ComfyUI coexistence support
- SQLite database for project/generation history
- Batch processing queue for async generation
- Docker deployment optimized for RunPod with RTX 4090

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-25 19:34:27 +01:00

117 lines
3.2 KiB
Python
Executable File

#!/usr/bin/env python3
"""Pre-download AudioCraft models for faster startup."""
import argparse
import os
from pathlib import Path
from typing import Optional
def musicgen_repo_id(variant: str) -> str:
    """Return the Hugging Face repo id for a MusicGen *variant*.

    Short names such as ``"small"`` or ``"melody"`` are expanded to
    ``facebook/musicgen-<variant>``. A value that already contains a ``/`` is
    treated as a complete repo id and returned unchanged, so custom or
    fine-tuned checkpoints can be pre-downloaded too.
    """
    if "/" in variant:
        return variant
    return f"facebook/musicgen-{variant}"


def download_musicgen_models(variants: Optional[list[str]] = None) -> list[str]:
    """Download MusicGen models by loading each one once.

    Args:
        variants: MusicGen variant names (or full repo ids). ``None`` or an
            empty list selects small, medium, large and melody.

    Returns:
        The repo ids whose download failed; an empty list when all succeeded.
        Failures are printed and skipped rather than raised, so one bad
        variant does not stop the remaining downloads.
    """
    # Imported here rather than at module level: main() applies --cache-dir
    # to the environment before calling this function.
    from audiocraft.models import MusicGen

    failed: list[str] = []
    for variant in variants or ["small", "medium", "large", "melody"]:
        repo_id = musicgen_repo_id(variant)
        print(f"Downloading MusicGen {variant}...")
        try:
            # Loading is what fetches the weights; the instance itself is not
            # needed, so the reference is dropped immediately.
            model = MusicGen.get_pretrained(repo_id)
            del model
        except Exception as e:  # best-effort boundary: report, then continue
            print(f" ✗ Failed to download MusicGen {variant}: {e}")
            failed.append(repo_id)
        else:
            print(f" ✓ MusicGen {variant} downloaded")
    return failed
def download_audiogen_models() -> list[str]:
    """Download the AudioGen medium model by loading it once.

    Returns:
        ``["facebook/audiogen-medium"]`` if the download failed, otherwise an
        empty list. The error is printed rather than raised.
    """
    # Imported here rather than at module level: main() applies --cache-dir
    # to the environment before calling this function.
    from audiocraft.models import AudioGen

    repo_id = "facebook/audiogen-medium"
    print("Downloading AudioGen medium...")
    try:
        # Loading is what fetches the weights; the instance is discarded.
        model = AudioGen.get_pretrained(repo_id)
        del model
    except Exception as e:  # best-effort boundary: report instead of aborting
        print(f" ✗ Failed to download AudioGen: {e}")
        return [repo_id]
    print(" ✓ AudioGen medium downloaded")
    return []
# Short MAGNeT names -> published Hugging Face checkpoints. The text-to-music
# checkpoints carry a duration suffix and the sound-effect ones are named
# "audio-magnet-*", so a plain f"facebook/magnet-{variant}" does not exist
# for these names.
MAGNET_REPO_ALIASES = {
    "small": "facebook/magnet-small-10secs",
    "medium": "facebook/magnet-medium-10secs",
    "audio-small": "facebook/audio-magnet-small",
    "audio-medium": "facebook/audio-magnet-medium",
    # Spellings used by earlier versions of this script.
    "audio-small-10secs": "facebook/audio-magnet-small",
    "audio-medium-10secs": "facebook/audio-magnet-medium",
}


def magnet_repo_id(variant: str) -> str:
    """Return the Hugging Face repo id for a MAGNeT *variant*.

    A value containing ``/`` is taken as a complete repo id. Known short names
    are looked up in ``MAGNET_REPO_ALIASES``. Anything else (for example
    ``"small-30secs"``) is expanded to ``facebook/magnet-<variant>``.
    """
    if "/" in variant:
        return variant
    return MAGNET_REPO_ALIASES.get(variant, f"facebook/magnet-{variant}")


def download_magnet_models(variants: Optional[list[str]] = None) -> list[str]:
    """Download MAGNeT models by loading each one once.

    Args:
        variants: MAGNeT variant names (see ``magnet_repo_id``) or full repo
            ids. ``None`` or an empty list selects the small/medium music
            models and the small/medium sound-effect models.

    Returns:
        The repo ids whose download failed; an empty list when all succeeded.
        Failures are printed and skipped rather than raised.
    """
    # Imported here rather than at module level: main() applies --cache-dir
    # to the environment before calling this function.
    from audiocraft.models import MAGNeT

    failed: list[str] = []
    for variant in variants or ["small", "medium", "audio-small", "audio-medium"]:
        repo_id = magnet_repo_id(variant)
        # Show the resolved id: the alias mapping is not obvious from the name.
        print(f"Downloading MAGNeT {variant} ({repo_id})...")
        try:
            # Loading is what fetches the weights; the instance is discarded.
            model = MAGNeT.get_pretrained(repo_id)
            del model
        except Exception as e:  # best-effort boundary: report, then continue
            print(f" ✗ Failed to download MAGNeT {variant}: {e}")
            failed.append(repo_id)
        else:
            print(f" ✓ MAGNeT {variant} downloaded")
    return failed
def main() -> int:
    """Parse command-line options and pre-download the selected models.

    Returns:
        Process exit status: 0 when no download reported a failure, 1
        otherwise, so scripted callers (e.g. an image build step) can detect
        missing models instead of silently continuing.
    """
    parser = argparse.ArgumentParser(description="Pre-download AudioCraft models")
    parser.add_argument(
        "--models",
        nargs="+",
        choices=["musicgen", "audiogen", "magnet", "all"],
        default=["all"],
        help="Models to download",
    )
    parser.add_argument(
        "--musicgen-variants",
        nargs="+",
        default=["small", "medium"],
        help="MusicGen variants to download",
    )
    parser.add_argument(
        "--magnet-variants",
        nargs="+",
        # Duration-suffixed names so the resulting repo ids
        # (facebook/magnet-small-10secs, facebook/magnet-medium-10secs) are
        # published checkpoints.
        default=["small-10secs", "medium-10secs"],
        help="MAGNeT variants to download",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Model cache directory",
    )
    args = parser.parse_args()

    # Point the Hugging Face and torch caches at --cache-dir. The download
    # helpers import audiocraft only when called, i.e. after this runs.
    if args.cache_dir:
        os.environ["HF_HOME"] = args.cache_dir
        os.environ["TORCH_HOME"] = args.cache_dir
        Path(args.cache_dir).mkdir(parents=True, exist_ok=True)

    models = args.models
    if "all" in models:
        models = ["musicgen", "audiogen", "magnet"]

    print("=" * 50)
    print("AudioCraft Model Downloader")
    print("=" * 50)
    print(f"Cache directory: {os.environ.get('HF_HOME', 'default')}")
    print(f"Models to download: {models}")
    print("=" * 50)

    # Collect whatever failed repo ids the helpers report; a falsy return
    # (e.g. None) contributes no failures.
    failures: list[str] = []
    if "musicgen" in models:
        failures.extend(download_musicgen_models(args.musicgen_variants) or ())
    if "audiogen" in models:
        failures.extend(download_audiogen_models() or ())
    if "magnet" in models:
        failures.extend(download_magnet_models(args.magnet_variants) or ())

    print("=" * 50)
    if failures:
        failed_list = ", ".join(map(str, failures))
        print(f"Download finished with {len(failures)} failure(s): {failed_list}")
        print("=" * 50)
        return 1
    print("Download complete!")
    print("=" * 50)
    return 0


if __name__ == "__main__":
    # Propagate main()'s status as the process exit code.
    raise SystemExit(main())