57 lines
2.3 KiB
Diff
diff --git a/__init__.py b/__init__.py
index 1234567..abcdefg 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1,3 +1,51 @@
+"""
+DiffRhythm ComfyUI Node with LlamaConfig Patch
+
+PATCH: Fixes "The size of tensor a (32) must match the size of tensor b (64)" error
+in DiffRhythm's rotary position embeddings by patching LlamaConfig initialization.
+
+Issue: DiffRhythm's DIT model doesn't specify num_attention_heads and
+num_key_value_heads when creating LlamaConfig, causing transformers 4.49.0+
+to incorrectly infer head_dim = 32 instead of 64.
+
+Solution: Patch LlamaConfig globally before importing DiffRhythmNode.
+
+Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/44
+Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/48
+
+Patch author: valknar@pivoine.art
+"""
+
+# CRITICAL: Patch LlamaConfig BEFORE importing DiffRhythmNode
+from transformers.models.llama import LlamaConfig as _OriginalLlamaConfig
+
+class PatchedLlamaConfig(_OriginalLlamaConfig):
+    """
+    Patched LlamaConfig that automatically adds missing attention head parameters.
+
+    Standard Llama architecture assumptions:
+    - head_dim = 64 (fixed)
+    - num_attention_heads = hidden_size // head_dim
+    - num_key_value_heads = num_attention_heads // 4 (for GQA)
+    """
+    def __init__(self, *args, **kwargs):
+        # If hidden_size is provided but num_attention_heads is not, calculate it
+        if 'hidden_size' in kwargs and 'num_attention_heads' not in kwargs:
+            hidden_size = kwargs['hidden_size']
+            kwargs['num_attention_heads'] = hidden_size // 64
+
+        # If num_key_value_heads is not provided, use GQA configuration
+        if 'num_attention_heads' in kwargs and 'num_key_value_heads' not in kwargs:
+            kwargs['num_key_value_heads'] = max(1, kwargs['num_attention_heads'] // 4)
+
+        super().__init__(*args, **kwargs)
+
+# Replace LlamaConfig in transformers module BEFORE DiffRhythm imports it
+import transformers.models.llama
+transformers.models.llama.LlamaConfig = PatchedLlamaConfig
+import transformers.models.llama.modeling_llama
+transformers.models.llama.modeling_llama.LlamaConfig = PatchedLlamaConfig
+
 from .DiffRhythmNode import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
 
 __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
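A quick sanity check for the patch (a minimal sketch, not part of the diff above: it assumes the patched __init__.py has already been imported so that transformers.models.llama.LlamaConfig now resolves to PatchedLlamaConfig, and it uses hidden_size=1024 purely as an illustrative value rather than DiffRhythm's actual model dimension):

# Verification sketch. Assumptions: the patched __init__.py above has already been
# imported, and hidden_size=1024 is illustrative only (any hidden_size for which
# transformers' default of 32 attention heads would yield head_dim = 32).
from transformers.models.llama import LlamaConfig  # resolves to PatchedLlamaConfig

cfg = LlamaConfig(hidden_size=1024)  # no head counts supplied, as in DiffRhythm's DIT
print(cfg.num_attention_heads)       # 1024 // 64 = 16
print(cfg.num_key_value_heads)       # 16 // 4 = 4 (GQA)

# head_dim is hidden_size // num_attention_heads; older transformers releases do not
# expose it as a config attribute, hence the fallback.
head_dim = getattr(cfg, "head_dim", None) or cfg.hidden_size // cfg.num_attention_heads
assert head_dim == 64, f"expected head_dim 64, got {head_dim}"

Unpatched, the default of 32 attention heads would leave head_dim = 1024 // 32 = 32, which is the 32-vs-64 tensor size mismatch described in the module docstring.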