117 lines
3.2 KiB
Python
117 lines
3.2 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""Pre-download AudioCraft models for faster startup."""
|
||
|
|
|
||
|
|
from __future__ import annotations

import argparse
import os
from pathlib import Path
|
||
|
|
|
||
|
|
|
||
|
|
def download_musicgen_models(variants: list[str] | None = None) -> list[str]:
    """Download (and cache) MusicGen checkpoints.

    Each checkpoint is fetched by loading it once with
    ``MusicGen.get_pretrained`` and then discarding the model. A failure is
    reported for that variant and the remaining variants are still attempted.

    Args:
        variants: Short variant names such as ``"small"`` or ``"melody"``,
            expanded to ``facebook/musicgen-<variant>``, or full Hugging Face
            repo IDs (anything containing ``/``), which are used as-is.
            ``None`` or an empty list selects small, medium, large and melody.

    Returns:
        The repo IDs that failed to download; empty when all succeeded.
    """
    from audiocraft.models import MusicGen

    variants = variants or ["small", "medium", "large", "melody"]
    failed = []

    for variant in variants:
        repo_id = variant if "/" in variant else f"facebook/musicgen-{variant}"
        # flush so progress is visible before the (possibly long) download
        print(f"Downloading MusicGen {variant}...", flush=True)
        try:
            model = MusicGen.get_pretrained(repo_id)
        except Exception as e:  # best effort: record and keep going
            print(f" ✗ Failed to download MusicGen {variant}: {e}")
            failed.append(repo_id)
            continue
        # Only the files left in the cache are wanted; drop our reference.
        del model
        print(f" ✓ MusicGen {variant} downloaded")

    return failed
|
||
|
|
|
||
|
|
|
||
|
|
def download_audiogen_models() -> list[str]:
    """Download (and cache) the AudioGen medium checkpoint.

    The checkpoint is fetched by loading it once with
    ``AudioGen.get_pretrained`` and then discarding the model.

    Returns:
        ``["facebook/audiogen-medium"]`` if the download failed, otherwise an
        empty list.
    """
    from audiocraft.models import AudioGen

    repo_id = "facebook/audiogen-medium"
    # flush so progress is visible before the (possibly long) download
    print("Downloading AudioGen medium...", flush=True)
    try:
        model = AudioGen.get_pretrained(repo_id)
    except Exception as e:  # best effort: report instead of crashing the script
        print(f" ✗ Failed to download AudioGen: {e}")
        return [repo_id]
    # Only the files left in the cache are wanted; drop our reference.
    del model
    print(" ✓ AudioGen medium downloaded")
    return []
|
||
|
|
|
||
|
|
|
||
|
|
def _magnet_repo_id(variant: str) -> str:
    """Map a MAGNeT variant name to its Hugging Face repo ID.

    Music models are published as ``facebook/magnet-<size>-<len>secs`` and
    sound-effect models as ``facebook/audio-magnet-<size>``. Anything
    containing ``/`` is treated as a full repo ID and returned unchanged.
    """
    if "/" in variant:
        return variant
    # Names this script used to accept that have no checkpoint of their own
    # (e.g. there is no facebook/magnet-small); map them to real ones.
    legacy = {
        "small": "small-10secs",
        "medium": "medium-10secs",
        "audio-small-10secs": "audio-small",
        "audio-medium-10secs": "audio-medium",
    }
    variant = legacy.get(variant, variant)
    if variant.startswith("audio-"):
        return f"facebook/audio-magnet-{variant[len('audio-'):]}"
    return f"facebook/magnet-{variant}"


def download_magnet_models(variants: list[str] | None = None) -> list[str]:
    """Download (and cache) MAGNeT checkpoints.

    Each checkpoint is fetched by loading it once with
    ``MAGNeT.get_pretrained`` and then discarding the model. A failure is
    reported for that variant and the remaining variants are still attempted.

    Args:
        variants: Variant names such as ``"small-10secs"``,
            ``"medium-30secs"`` or ``"audio-small"``, or full Hugging Face repo
            IDs. The legacy names ``"small"``, ``"medium"``,
            ``"audio-small-10secs"`` and ``"audio-medium-10secs"`` are still
            accepted. ``None`` or an empty list selects the 10-second music
            models plus both sound-effect models.

    Returns:
        The repo IDs that failed to download; empty when all succeeded.
    """
    from audiocraft.models import MAGNeT

    variants = variants or ["small-10secs", "medium-10secs", "audio-small", "audio-medium"]
    failed = []

    for variant in variants:
        repo_id = _magnet_repo_id(variant)
        # flush so progress is visible before the (possibly long) download
        print(f"Downloading MAGNeT {variant}...", flush=True)
        try:
            model = MAGNeT.get_pretrained(repo_id)
        except Exception as e:  # best effort: record and keep going
            print(f" ✗ Failed to download MAGNeT {variant}: {e}")
            failed.append(repo_id)
            continue
        # Only the files left in the cache are wanted; drop our reference.
        del model
        print(f" ✓ MAGNeT {variant} downloaded")

    return failed
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Parse command-line options and pre-download the selected models.

    Raises:
        SystemExit: with status 1 when at least one model failed to download,
            so build scripts and CI jobs can detect an incomplete cache.
    """
    parser = argparse.ArgumentParser(description="Pre-download AudioCraft models")
    parser.add_argument(
        "--models",
        nargs="+",
        choices=["musicgen", "audiogen", "magnet", "all"],
        default=["all"],
        help="Models to download",
    )
    parser.add_argument(
        "--musicgen-variants",
        nargs="+",
        default=["small", "medium"],
        help="MusicGen variants to download (e.g. small medium large melody)",
    )
    parser.add_argument(
        "--magnet-variants",
        nargs="+",
        # Published MAGNeT checkpoints carry a clip-length suffix
        # (facebook/magnet-small-10secs, ...); "facebook/magnet-small" does
        # not exist, so the old default of ["small", "medium"] always failed.
        default=["small-10secs", "medium-10secs"],
        help="MAGNeT variants to download (e.g. small-10secs medium-30secs audio-small)",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Model cache directory",
    )

    args = parser.parse_args()

    # Set cache directory. The download functions import audiocraft lazily,
    # so exporting these before calling them should let the Hugging Face /
    # torch hub caches honour them (assumes those libraries read the
    # variables at import/use time).
    if args.cache_dir:
        # Expand "~" here: the shell does not expand it in "--cache-dir=~/x",
        # and mkdir would otherwise create a literal "~" directory.
        cache_dir = Path(args.cache_dir).expanduser().resolve()
        cache_dir.mkdir(parents=True, exist_ok=True)
        os.environ["HF_HOME"] = str(cache_dir)
        os.environ["TORCH_HOME"] = str(cache_dir)

    models = ["musicgen", "audiogen", "magnet"] if "all" in args.models else args.models

    print("=" * 50)
    print("AudioCraft Model Downloader")
    print("=" * 50)
    print(f"Cache directory: {os.environ.get('HF_HOME', 'default')}")
    print(f"Models to download: {models}")
    print("=" * 50)

    # `or []` keeps this working with download helpers that return None.
    failed = []
    if "musicgen" in models:
        failed.extend(download_musicgen_models(args.musicgen_variants) or [])
    if "audiogen" in models:
        failed.extend(download_audiogen_models() or [])
    if "magnet" in models:
        failed.extend(download_magnet_models(args.magnet_variants) or [])

    print("=" * 50)
    if failed:
        print(f"Download finished with {len(failed)} failure(s):")
        for name in failed:
            print(f"  - {name}")
        print("=" * 50)
        raise SystemExit(1)
    print("Download complete!")
    print("=" * 50)
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Run the downloader only when executed as a script, not when imported.
    main()
|