""" Pivoine DiffRhythm Node Custom wrapper for DiffRhythm that fixes LlamaConfig initialization issues with transformers 4.49.0+ to prevent tensor dimension mismatches. Known Issue: DiffRhythm's DIT model doesn't specify num_attention_heads and num_key_value_heads in LlamaConfig, causing "The size of tensor a (32) must match the size of tensor b (64)" error in rotary position embeddings. This patch adds the missing parameters to LlamaConfig initialization. Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/44 Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/48 Author: valknar@pivoine.art """ import sys sys.path.append('/workspace/ComfyUI/custom_nodes/ComfyUI_DiffRhythm') # Monkey-patch decode_audio from infer_utils to force chunked=False import infer_utils _original_decode_audio = infer_utils.decode_audio def patched_decode_audio(latent, vae_model, chunked=True): """Patched version that always uses chunked=False""" return _original_decode_audio(latent, vae_model, chunked=False) # Apply the decode_audio monkey patch infer_utils.decode_audio = patched_decode_audio # Monkey-patch DiT __init__ to fix LlamaConfig initialization from diffrhythm.model import dit from transformers.models.llama import LlamaConfig import torch.nn as nn _original_dit_init = dit.DiT.__init__ def patched_dit_init(self, *args, **kwargs): """ Patched DiT.__init__ that adds missing num_attention_heads and num_key_value_heads to LlamaConfig initialization. This fixes the tensor dimension mismatch (32 vs 64) error in rotary position embeddings with transformers 4.49.0+. """ # Call original __init__ but intercept the LlamaConfig creation _original_llama_config = LlamaConfig def patched_llama_config(*config_args, **config_kwargs): """Add missing attention head parameters to LlamaConfig""" hidden_size = config_kwargs.get('hidden_size', config_args[0] if config_args else 1024) # Standard Llama architecture: head_dim = 64, so num_heads = hidden_size // 64 # For GQA (Grouped Query Attention), num_key_value_heads is usually num_heads // 4 num_attention_heads = hidden_size // 64 num_key_value_heads = max(1, num_attention_heads // 4) config_kwargs['num_attention_heads'] = config_kwargs.get('num_attention_heads', num_attention_heads) config_kwargs['num_key_value_heads'] = config_kwargs.get('num_key_value_heads', num_key_value_heads) return _original_llama_config(*config_args, **config_kwargs) # Temporarily replace LlamaConfig in the dit module dit.LlamaConfig = patched_llama_config try: # Call the original __init__ _original_dit_init(self, *args, **kwargs) finally: # Restore original LlamaConfig dit.LlamaConfig = _original_llama_config # Apply the DiT init monkey patch dit.DiT.__init__ = patched_dit_init from DiffRhythmNode import DiffRhythmRun class PivoineDiffRhythmRun(DiffRhythmRun): """ Pivoine version of DiffRhythmRun with enhanced compatibility and error handling. 

from DiffRhythmNode import DiffRhythmRun


class PivoineDiffRhythmRun(DiffRhythmRun):
    """
    Pivoine version of DiffRhythmRun with enhanced compatibility and error handling.

    Changes from the original:
    - Patches DiT.__init__ to add the missing num_attention_heads and
      num_key_value_heads to LlamaConfig
    - Monkey-patches decode_audio to always use chunked=False for stability
    - Fixes the tensor dimension mismatch (32 vs 64) in rotary position embeddings
    - Compatible with transformers 4.49.0+
    - Requires ~12-16 GB VRAM; works reliably on an RTX 4090

    Technical details:
    - Sets num_attention_heads = hidden_size // 64 (standard Llama architecture)
    - Sets num_key_value_heads = num_attention_heads // 4 (GQA configuration)
    - This ensures head_dim = hidden_size // num_attention_heads = 64 (not 32)
    """

    CATEGORY = "🌸Pivoine/Audio"

    @classmethod
    def INPUT_TYPES(cls):
        return super().INPUT_TYPES()


NODE_CLASS_MAPPINGS = {
    "PivoineDiffRhythmRun": PivoineDiffRhythmRun,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "PivoineDiffRhythmRun": "Pivoine DiffRhythm Run",
}
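

# Optional self-check for running this file directly (assuming the DiffRhythm
# imports above resolve outside ComfyUI, which may not hold in every setup).
# It builds a LlamaConfig with the same explicit head counts the patch
# supplies; hidden_size=2048 is an assumed example value, not a value read
# from DiffRhythm's checkpoints. The printed head_dim should be 64, not the
# mismatched 32.
if __name__ == "__main__":
    _hidden_size = 2048
    _num_heads = _hidden_size // 64
    _cfg = LlamaConfig(
        hidden_size=_hidden_size,
        num_attention_heads=_num_heads,
        num_key_value_heads=max(1, _num_heads // 4),
    )
    print(
        f"num_attention_heads={_cfg.num_attention_heads}, "
        f"num_key_value_heads={_cfg.num_key_value_heads}, "
        f"head_dim={_cfg.hidden_size // _cfg.num_attention_heads}"
    )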