Pin transformers<4.40.0 for PyTorch 2.1 compatibility
transformers 4.40+ uses serialized_type_name kwarg in _register_pytree_node() which doesn't exist in PyTorch 2.1. Pin to older version. Also removes the ineffective pytree shim. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
14
main.py
14
main.py
@@ -16,20 +16,6 @@ os.chdir(PROJECT_ROOT)
|
|||||||
# Add project root to path
|
# Add project root to path
|
||||||
sys.path.insert(0, str(PROJECT_ROOT))
|
sys.path.insert(0, str(PROJECT_ROOT))
|
||||||
|
|
||||||
# PyTorch pytree compatibility shim for AudioCraft
|
|
||||||
# The API changed between PyTorch versions:
|
|
||||||
# - Older: _register_pytree_node (private, with underscore)
|
|
||||||
# - Newer: register_pytree_node (public, without underscore)
|
|
||||||
# Create aliases in both directions to support any version
|
|
||||||
try:
|
|
||||||
import torch.utils._pytree as _pytree
|
|
||||||
if hasattr(_pytree, '_register_pytree_node') and not hasattr(_pytree, 'register_pytree_node'):
|
|
||||||
_pytree.register_pytree_node = _pytree._register_pytree_node
|
|
||||||
if hasattr(_pytree, 'register_pytree_node') and not hasattr(_pytree, '_register_pytree_node'):
|
|
||||||
_pytree._register_pytree_node = _pytree.register_pytree_node
|
|
||||||
except Exception:
|
|
||||||
pass # Ignore if torch not installed or patch fails
|
|
||||||
|
|
||||||
from config.settings import get_settings
|
from config.settings import get_settings
|
||||||
from src.core.gpu_manager import GPUMemoryManager
|
from src.core.gpu_manager import GPUMemoryManager
|
||||||
from src.core.model_registry import ModelRegistry
|
from src.core.model_registry import ModelRegistry
|
||||||
|
|||||||
@@ -1,9 +1,13 @@
|
|||||||
# Core ML - Pin to AudioCraft-compatible versions (2.4+ breaks pytree API)
|
# Core ML - Pin to AudioCraft-compatible versions
|
||||||
torch>=2.1.0,<2.4.0
|
torch>=2.1.0,<2.4.0
|
||||||
torchaudio>=2.1.0,<2.4.0
|
torchaudio>=2.1.0,<2.4.0
|
||||||
audiocraft>=1.3.0
|
audiocraft>=1.3.0
|
||||||
xformers>=0.0.22
|
xformers>=0.0.22
|
||||||
|
|
||||||
|
# Pin transformers to version compatible with PyTorch 2.1
|
||||||
|
# Newer versions use serialized_type_name kwarg not in PyTorch 2.1
|
||||||
|
transformers>=4.30.0,<4.40.0
|
||||||
|
|
||||||
# UI
|
# UI
|
||||||
gradio>=4.0.0
|
gradio>=4.0.0
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user