#!/usr/bin/env python3
"""Pre-download AudioCraft models for faster startup."""

import argparse
import os
import sys
from pathlib import Path
from typing import Callable, Optional

_RULE = "=" * 50


def magnet_repo_id(variant: str) -> str:
    """Translate a MAGNeT variant name into a Hugging Face repo id.

    The published MAGNeT checkpoints do not share one naming scheme
    (see the AudioCraft MAGNeT model list): music models are
    ``facebook/magnet-{size}-{10secs|30secs}`` while sound-effect models
    are ``facebook/audio-magnet-{size}``.  Shorthand spellings accepted by
    earlier versions of this script are mapped onto those names:

    * ``small`` / ``medium``       -> ``facebook/magnet-{size}-10secs``
    * ``audio-{size}[-10secs]``    -> ``facebook/audio-magnet-{size}``
    * anything containing ``/``    -> returned unchanged (full repo id)
    * anything else                -> ``facebook/magnet-{variant}``
    """
    if "/" in variant:
        return variant

    audio_prefix = "audio-"
    if variant.startswith(audio_prefix):
        size = variant[len(audio_prefix):]
        # Tolerate the repo-style spelling "audio-magnet-small" as well.
        if size.startswith("magnet-"):
            size = size[len("magnet-"):]
        # Sound-effect repos carry no duration suffix.
        if size.endswith("-10secs"):
            size = size[: -len("-10secs")]
        return f"facebook/audio-magnet-{size}"

    if variant in ("small", "medium"):
        # A bare size has no matching repo; default to the 10-second model.
        return f"facebook/magnet-{variant}-10secs"

    return f"facebook/magnet-{variant}"


def download_checkpoint(
    label: str,
    loader: Callable[[str], object],
    repo_id: str,
    failure_label: Optional[str] = None,
) -> bool:
    """Fetch one checkpoint by calling ``loader(repo_id)`` and report the result.

    Args:
        label: Human-readable name used in the progress messages.
        loader: Callable that loads the model for ``repo_id``; it is called
            only for its side effect and the returned object is discarded.
        repo_id: Identifier passed to ``loader``.
        failure_label: Name used in the failure message; defaults to ``label``.

    Returns:
        ``True`` if ``loader`` returned normally, ``False`` if it raised.
        Errors are printed rather than propagated so that one bad
        checkpoint does not abort the remaining downloads.
    """
    print(f"Downloading {label}...")
    try:
        model = loader(repo_id)
    except Exception as e:  # best-effort: report and let the caller continue
        print(f" ✗ Failed to download {failure_label or label}: {e}")
        return False
    # Drop the reference right away; only the loader's side effect is wanted.
    del model
    print(f" ✓ {label} downloaded")
    return True


def download_musicgen_models(variants: Optional[list[str]] = None) -> list[str]:
    """Download MusicGen models.

    Args:
        variants: Variant suffixes appended to ``facebook/musicgen-``.
            ``None`` or an empty list selects small, medium, large and melody.

    Returns:
        The repo ids that failed to download (empty when all succeeded).
    """
    # Imported lazily so main() can set the cache environment variables
    # first, and so that --help works without audiocraft installed.
    from audiocraft.models import MusicGen

    variants = variants or ["small", "medium", "large", "melody"]
    failed = []
    for variant in variants:
        repo_id = f"facebook/musicgen-{variant}"
        if not download_checkpoint(
            f"MusicGen {variant}", MusicGen.get_pretrained, repo_id
        ):
            failed.append(repo_id)
    return failed


def download_audiogen_models() -> list[str]:
    """Download AudioGen models.

    Returns:
        The repo ids that failed to download (empty when all succeeded).
    """
    # Imported lazily; see download_musicgen_models.
    from audiocraft.models import AudioGen

    repo_id = "facebook/audiogen-medium"
    succeeded = download_checkpoint(
        "AudioGen medium",
        AudioGen.get_pretrained,
        repo_id,
        failure_label="AudioGen",
    )
    return [] if succeeded else [repo_id]


def download_magnet_models(variants: Optional[list[str]] = None) -> list[str]:
    """Download MAGNeT models.

    Args:
        variants: Variant names resolved through :func:`magnet_repo_id`.
            ``None`` or an empty list selects the small and medium music
            models plus the small and medium sound-effect models.

    Returns:
        The repo ids that failed to download (empty when all succeeded).
    """
    # Imported lazily; see download_musicgen_models.
    from audiocraft.models import MAGNeT

    variants = variants or [
        "small",
        "medium",
        "audio-small-10secs",
        "audio-medium-10secs",
    ]
    failed = []
    for variant in variants:
        repo_id = magnet_repo_id(variant)
        if not download_checkpoint(
            f"MAGNeT {variant}", MAGNeT.get_pretrained, repo_id
        ):
            failed.append(repo_id)
    return failed


def main(argv: Optional[list[str]] = None) -> int:
    """Parse command-line options and download the requested model families.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        Process exit status: ``0`` when every download succeeded, ``1`` when
        at least one failed.
    """
    parser = argparse.ArgumentParser(description="Pre-download AudioCraft models")
    parser.add_argument(
        "--models",
        nargs="+",
        choices=["musicgen", "audiogen", "magnet", "all"],
        default=["all"],
        help="Models to download",
    )
    parser.add_argument(
        "--musicgen-variants",
        nargs="+",
        default=["small", "medium"],
        help="MusicGen variants to download",
    )
    parser.add_argument(
        "--magnet-variants",
        nargs="+",
        default=["small", "medium"],
        help="MAGNeT variants to download",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Model cache directory",
    )
    args = parser.parse_args(argv)

    # Set cache directory.  This must happen before any download function
    # runs, because those functions are what import audiocraft.
    if args.cache_dir:
        os.environ["HF_HOME"] = args.cache_dir
        os.environ["TORCH_HOME"] = args.cache_dir
        Path(args.cache_dir).mkdir(parents=True, exist_ok=True)

    models = args.models
    if "all" in models:
        models = ["musicgen", "audiogen", "magnet"]

    print(_RULE)
    print("AudioCraft Model Downloader")
    print(_RULE)
    print(f"Cache directory: {os.environ.get('HF_HOME', 'default')}")
    print(f"Models to download: {models}")
    print(_RULE)

    failed: list[str] = []
    if "musicgen" in models:
        failed.extend(download_musicgen_models(args.musicgen_variants))
    if "audiogen" in models:
        failed.extend(download_audiogen_models())
    if "magnet" in models:
        failed.extend(download_magnet_models(args.magnet_variants))

    print(_RULE)
    if failed:
        # Do not claim success when checkpoints are missing.
        print(f"Download finished with {len(failed)} failure(s): {', '.join(failed)}")
    else:
        print("Download complete!")
    print(_RULE)

    return 1 if failed else 0


if __name__ == "__main__":
    sys.exit(main())