diff --git a/comfyui/nodes/pivoine_diffrhythm.py b/comfyui/nodes/pivoine_diffrhythm.py
index 2c842a7..7f159fe 100644
--- a/comfyui/nodes/pivoine_diffrhythm.py
+++ b/comfyui/nodes/pivoine_diffrhythm.py
@@ -7,7 +7,7 @@
 Known Issue: DiffRhythm's DIT model doesn't specify num_attention_heads
 and num_key_value_heads in LlamaConfig, causing "The size of tensor a (32)
 must match the size of tensor b (64)" error in rotary position embeddings.
-This patch adds the missing parameters to LlamaConfig initialization.
+This patch globally intercepts LlamaConfig at import time.
 
 Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/44
 Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/48
@@ -18,62 +18,54 @@
 Author: valknar@pivoine.art
 import sys
 sys.path.append('/workspace/ComfyUI/custom_nodes/ComfyUI_DiffRhythm')
-# Monkey-patch decode_audio from infer_utils to force chunked=False
+# CRITICAL: Patch LlamaConfig BEFORE any DiffRhythm imports
+# This must happen at module import time, not at runtime
+from transformers.models.llama import LlamaConfig as _OriginalLlamaConfig
+
+class PatchedLlamaConfig(_OriginalLlamaConfig):
+    """
+    Patched LlamaConfig that automatically adds missing attention head parameters.
+
+    Fixes the tensor dimension mismatch (32 vs 64) in DiffRhythm's rotary
+    position embeddings by ensuring num_attention_heads and num_key_value_heads
+    are properly set based on hidden_size.
+    """
+    def __init__(self, *args, **kwargs):
+        # If hidden_size is provided but num_attention_heads is not, calculate it
+        if 'hidden_size' in kwargs and 'num_attention_heads' not in kwargs:
+            hidden_size = kwargs['hidden_size']
+            # Standard Llama architecture: head_dim = 64, so num_heads = hidden_size // 64
+            kwargs['num_attention_heads'] = hidden_size // 64
+
+        # If num_key_value_heads is not provided, use GQA configuration
+        if 'num_attention_heads' in kwargs and 'num_key_value_heads' not in kwargs:
+            # For GQA (Grouped Query Attention), typically num_kv_heads = num_heads // 4
+            kwargs['num_key_value_heads'] = max(1, kwargs['num_attention_heads'] // 4)
+
+        # Call original __init__ with patched parameters
+        super().__init__(*args, **kwargs)
+
+# Replace LlamaConfig in transformers module BEFORE DiffRhythm imports it
+import transformers.models.llama
+transformers.models.llama.LlamaConfig = PatchedLlamaConfig
+
+# Also replace in modeling_llama module if it's already imported
+import transformers.models.llama.modeling_llama
+transformers.models.llama.modeling_llama.LlamaConfig = PatchedLlamaConfig
+
+# Now import DiffRhythm modules - they will use our patched LlamaConfig
 import infer_utils
+
+# Monkey-patch decode_audio to force chunked=False
 _original_decode_audio = infer_utils.decode_audio
 
 def patched_decode_audio(latent, vae_model, chunked=True):
     """Patched version that always uses chunked=False"""
     return _original_decode_audio(latent, vae_model, chunked=False)
 
-# Apply the decode_audio monkey patch
 infer_utils.decode_audio = patched_decode_audio
 
-# Monkey-patch DiT __init__ to fix LlamaConfig initialization
-from diffrhythm.model import dit
-from transformers.models.llama import LlamaConfig
-import torch.nn as nn
-
-_original_dit_init = dit.DiT.__init__
-
-def patched_dit_init(self, *args, **kwargs):
-    """
-    Patched DiT.__init__ that adds missing num_attention_heads and
-    num_key_value_heads to LlamaConfig initialization.
-
-    This fixes the tensor dimension mismatch (32 vs 64) error in
-    rotary position embeddings with transformers 4.49.0+.
-    """
-    # Call original __init__ but intercept the LlamaConfig creation
-    _original_llama_config = LlamaConfig
-
-    def patched_llama_config(*config_args, **config_kwargs):
-        """Add missing attention head parameters to LlamaConfig"""
-        hidden_size = config_kwargs.get('hidden_size', config_args[0] if config_args else 1024)
-
-        # Standard Llama architecture: head_dim = 64, so num_heads = hidden_size // 64
-        # For GQA (Grouped Query Attention), num_key_value_heads is usually num_heads // 4
-        num_attention_heads = hidden_size // 64
-        num_key_value_heads = max(1, num_attention_heads // 4)
-
-        config_kwargs['num_attention_heads'] = config_kwargs.get('num_attention_heads', num_attention_heads)
-        config_kwargs['num_key_value_heads'] = config_kwargs.get('num_key_value_heads', num_key_value_heads)
-
-        return _original_llama_config(*config_args, **config_kwargs)
-
-    # Temporarily replace LlamaConfig in the dit module
-    dit.LlamaConfig = patched_llama_config
-
-    try:
-        # Call the original __init__
-        _original_dit_init(self, *args, **kwargs)
-    finally:
-        # Restore original LlamaConfig
-        dit.LlamaConfig = _original_llama_config
-
-# Apply the DiT init monkey patch
-dit.DiT.__init__ = patched_dit_init
-
+# Import DiffRhythm node
 from DiffRhythmNode import DiffRhythmRun
 
 class PivoineDiffRhythmRun(DiffRhythmRun):
@@ -81,7 +73,7 @@ class PivoineDiffRhythmRun(DiffRhythmRun):
     Pivoine version of DiffRhythmRun with enhanced compatibility and error handling.
 
     Changes from original:
-    - Patches DIT.__init__ to add missing num_attention_heads and num_key_value_heads to LlamaConfig
+    - Globally patches LlamaConfig to add missing num_attention_heads and num_key_value_heads
     - Monkey-patches decode_audio to always use chunked=False for stability
     - Fixes tensor dimension mismatch in rotary position embeddings (32 vs 64)
     - Compatible with transformers 4.49.0+
@@ -91,6 +83,7 @@ class PivoineDiffRhythmRun(DiffRhythmRun):
     - Sets num_attention_heads = hidden_size // 64 (standard Llama architecture)
    - Sets num_key_value_heads = num_attention_heads // 4 (GQA configuration)
     - This ensures head_dim = hidden_size // num_attention_heads = 64 (not 32)
+    - Patch is applied globally at import time, affecting all LlamaConfig instances
     """
 
     CATEGORY = "🌸Pivoine/Audio"