diff --git a/main.py b/main.py index d412677..e7582ba 100644 --- a/main.py +++ b/main.py @@ -16,6 +16,20 @@ os.chdir(PROJECT_ROOT) # Add project root to path sys.path.insert(0, str(PROJECT_ROOT)) +# PyTorch pytree compatibility shim for AudioCraft +# The API changed between PyTorch versions: +# - Older: _register_pytree_node (private, with underscore) +# - Newer: register_pytree_node (public, without underscore) +# Create aliases in both directions to support any version +try: + import torch.utils._pytree as _pytree + if hasattr(_pytree, '_register_pytree_node') and not hasattr(_pytree, 'register_pytree_node'): + _pytree.register_pytree_node = _pytree._register_pytree_node + if hasattr(_pytree, 'register_pytree_node') and not hasattr(_pytree, '_register_pytree_node'): + _pytree._register_pytree_node = _pytree.register_pytree_node +except Exception: + pass # Ignore if torch not installed or patch fails + from config.settings import get_settings from src.core.gpu_manager import GPUMemoryManager from src.core.model_registry import ModelRegistry