Files
runpod/comfyui/patches/diffrhythm-llamaconfig-fix.patch

57 lines
2.3 KiB
Diff
Raw Permalink Normal View History

diff --git a/__init__.py b/__init__.py
index 1234567..abcdefg 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1,3 +1,51 @@
+"""
+DiffRhythm ComfyUI Node with LlamaConfig Patch
+
+PATCH: Fixes "The size of tensor a (32) must match the size of tensor b (64)" error
+in DiffRhythm's rotary position embeddings by patching LlamaConfig initialization.
+
+Issue: DiffRhythm's DIT model doesn't specify num_attention_heads and
+num_key_value_heads when creating LlamaConfig, causing transformers 4.49.0+
+to incorrectly infer head_dim = 32 instead of 64.
+
+Solution: Patch LlamaConfig globally before importing DiffRhythmNode.
+
+Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/44
+Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/48
+
+Patch author: valknar@pivoine.art
+"""
+
+# CRITICAL: Patch LlamaConfig BEFORE importing DiffRhythmNode
+from transformers.models.llama import LlamaConfig as _OriginalLlamaConfig
+
+class PatchedLlamaConfig(_OriginalLlamaConfig):
+    """
+    LlamaConfig subclass that fills in attention-head keyword arguments the caller omitted.
+
+    Only keyword arguments are inspected; positional arguments pass through untouched.
+    - num_attention_heads, if absent while hidden_size is given, becomes
+      max(1, hidden_size // 64), i.e. head_dim is assumed to be 64
+    - num_key_value_heads, if absent, becomes the largest divisor of num_attention_heads <= heads // 4
+    """
+    def __init__(self, *args, **kwargs):
+        # Derive the head count from hidden_size; clamp so hidden_size < 64 never yields 0 heads
+        if 'hidden_size' in kwargs and 'num_attention_heads' not in kwargs:
+            kwargs['num_attention_heads'] = max(1, kwargs['hidden_size'] // 64)
+
+        # Query heads must split evenly across KV heads: step down from heads // 4 to a divisor
+        if 'num_attention_heads' in kwargs and 'num_key_value_heads' not in kwargs:
+            heads = kwargs['num_attention_heads']
+            candidates = range(max(1, heads // 4), 0, -1)  # always ends at 1, which divides any int
+            kwargs['num_key_value_heads'] = next(k for k in candidates if heads % k == 0)
+        super().__init__(*args, **kwargs)
+
+# Rebind LlamaConfig on the llama package and on modeling_llama BEFORE DiffRhythm imports it
+import transformers.models.llama
+transformers.models.llama.LlamaConfig = PatchedLlamaConfig
+import transformers.models.llama.modeling_llama
+transformers.models.llama.modeling_llama.LlamaConfig = PatchedLlamaConfig
+
from .DiffRhythmNode import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]