diff --git a/__init__.py b/__init__.py
index 1234567..abcdefg 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1,3 +1,51 @@
+"""
+DiffRhythm ComfyUI Node with LlamaConfig Patch
+
+PATCH: Fixes "The size of tensor a (32) must match the size of tensor b (64)" error
+in DiffRhythm's rotary position embeddings by patching LlamaConfig initialization.
+
+Issue: DiffRhythm's DIT model doesn't specify num_attention_heads and
+num_key_value_heads when creating LlamaConfig, causing transformers 4.49.0+
+to incorrectly infer head_dim = 32 instead of 64.
+
+Solution: Patch LlamaConfig globally before importing DiffRhythmNode.
+
+Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/44
+Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/48
+
+Patch author: valknar@pivoine.art
+"""
+
+# CRITICAL: Patch LlamaConfig BEFORE importing DiffRhythmNode
+from transformers.models.llama import LlamaConfig as _OriginalLlamaConfig
+
+class PatchedLlamaConfig(_OriginalLlamaConfig):
+    """
+    Patched LlamaConfig that automatically adds missing attention head parameters.
+
+    Standard Llama architecture assumptions:
+    - head_dim = 64 (fixed)
+    - num_attention_heads = hidden_size // head_dim
+    - num_key_value_heads = num_attention_heads // 4 (for GQA)
+    """
+    def __init__(self, *args, **kwargs):
+        # If hidden_size is provided but num_attention_heads is not, calculate it
+        if 'hidden_size' in kwargs and 'num_attention_heads' not in kwargs:
+            hidden_size = kwargs['hidden_size']
+            kwargs['num_attention_heads'] = hidden_size // 64
+
+        # If num_key_value_heads is not provided, use GQA configuration
+        if 'num_attention_heads' in kwargs and 'num_key_value_heads' not in kwargs:
+            kwargs['num_key_value_heads'] = max(1, kwargs['num_attention_heads'] // 4)
+
+        super().__init__(*args, **kwargs)
+
+# Replace LlamaConfig in transformers module BEFORE DiffRhythm imports it
+import transformers.models.llama
+transformers.models.llama.LlamaConfig = PatchedLlamaConfig
+import transformers.models.llama.modeling_llama
+transformers.models.llama.modeling_llama.LlamaConfig = PatchedLlamaConfig
+
 from .DiffRhythmNode import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
 
 __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
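
A quick way to confirm the patch took effect (a minimal sketch, not part of the diff): run the snippet below in the same Python environment after the patched __init__.py above has executed, e.g. once ComfyUI has loaded this custom node. The hidden_size of 1024 is only an illustrative value consistent with the reported head_dim = 32 inference; substitute whatever DiffRhythm's DIT actually passes.

    # Sanity check for the LlamaConfig patch -- a sketch, assuming transformers >= 4.49
    # and that the patched __init__.py above has already run.
    from transformers.models.llama import LlamaConfig  # now resolves to PatchedLlamaConfig

    cfg = LlamaConfig(hidden_size=1024)  # mimics DiffRhythm's DIT omitting the head counts
    heads = cfg.num_attention_heads                    # 16 (= 1024 // 64)
    kv_heads = cfg.num_key_value_heads                 # 4  (= 16 // 4, GQA)
    head_dim = getattr(cfg, "head_dim", cfg.hidden_size // heads)
    print(heads, kv_heads, head_dim)                   # expect head_dim == 64

If head_dim still comes out as 32, the patch most likely ran after DiffRhythmNode (or transformers) had already imported LlamaConfig, which is why the monkey-patch in the diff is placed before the DiffRhythmNode import.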