""" Pivoine DiffRhythm Node Custom wrapper for DiffRhythm that fixes LlamaConfig initialization issues with transformers 4.49.0+ to prevent tensor dimension mismatches. Known Issue: DiffRhythm's DIT model doesn't specify num_attention_heads and num_key_value_heads in LlamaConfig, causing "The size of tensor a (32) must match the size of tensor b (64)" error in rotary position embeddings. This patch globally intercepts LlamaConfig at import time. Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/44 Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/48 Author: valknar@pivoine.art """ import sys sys.path.append('/workspace/ComfyUI/custom_nodes/ComfyUI_DiffRhythm') # CRITICAL: Patch LlamaConfig BEFORE any DiffRhythm imports # This must happen at module import time, not at runtime from transformers.models.llama import LlamaConfig as _OriginalLlamaConfig class PatchedLlamaConfig(_OriginalLlamaConfig): """ Patched LlamaConfig that automatically adds missing attention head parameters. Fixes the tensor dimension mismatch (32 vs 64) in DiffRhythm's rotary position embeddings by ensuring num_attention_heads and num_key_value_heads are properly set based on hidden_size. """ def __init__(self, *args, **kwargs): # If hidden_size is provided but num_attention_heads is not, calculate it if 'hidden_size' in kwargs and 'num_attention_heads' not in kwargs: hidden_size = kwargs['hidden_size'] # Standard Llama architecture: head_dim = 64, so num_heads = hidden_size // 64 kwargs['num_attention_heads'] = hidden_size // 64 # If num_key_value_heads is not provided, use GQA configuration if 'num_attention_heads' in kwargs and 'num_key_value_heads' not in kwargs: # For GQA (Grouped Query Attention), typically num_kv_heads = num_heads // 4 kwargs['num_key_value_heads'] = max(1, kwargs['num_attention_heads'] // 4) # Call original __init__ with patched parameters super().__init__(*args, **kwargs) # Replace LlamaConfig in transformers module BEFORE DiffRhythm imports it import transformers.models.llama transformers.models.llama.LlamaConfig = PatchedLlamaConfig # Also replace in modeling_llama module if it's already imported import transformers.models.llama.modeling_llama transformers.models.llama.modeling_llama.LlamaConfig = PatchedLlamaConfig # Now import DiffRhythm modules - they will use our patched LlamaConfig import infer_utils # Monkey-patch decode_audio to force chunked=False _original_decode_audio = infer_utils.decode_audio def patched_decode_audio(latent, vae_model, chunked=True): """Patched version that always uses chunked=False""" return _original_decode_audio(latent, vae_model, chunked=False) infer_utils.decode_audio = patched_decode_audio # Import DiffRhythm node from DiffRhythmNode import DiffRhythmRun class PivoineDiffRhythmRun(DiffRhythmRun): """ Pivoine version of DiffRhythmRun with enhanced compatibility and error handling. 
Changes from original: - Globally patches LlamaConfig to add missing num_attention_heads and num_key_value_heads - Monkey-patches decode_audio to always use chunked=False for stability - Fixes tensor dimension mismatch in rotary position embeddings (32 vs 64) - Compatible with transformers 4.49.0+ - Requires ~12-16GB VRAM, works reliably on RTX 4090 Technical details: - Sets num_attention_heads = hidden_size // 64 (standard Llama architecture) - Sets num_key_value_heads = num_attention_heads // 4 (GQA configuration) - This ensures head_dim = hidden_size // num_attention_heads = 64 (not 32) - Patch is applied globally at import time, affecting all LlamaConfig instances """ CATEGORY = "🌸Pivoine/Audio" @classmethod def INPUT_TYPES(cls): return super().INPUT_TYPES() NODE_CLASS_MAPPINGS = { "PivoineDiffRhythmRun": PivoineDiffRhythmRun, } NODE_DISPLAY_NAME_MAPPINGS = { "PivoineDiffRhythmRun": "Pivoine DiffRhythm Run", }
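

# Optional, illustrative sanity check of the head-dimension arithmetic that
# PatchedLlamaConfig enforces. This is a minimal sketch, not part of the node
# registration above; it only runs when the module is executed directly, which
# assumes the DiffRhythm custom node is installed so the imports above succeed.
# The hidden_size value of 2048 is an example chosen for illustration, not a
# value taken from DiffRhythm's actual DIT configuration.
if __name__ == "__main__":
    # With hidden_size=2048 and no head counts supplied, the patch should derive
    # num_attention_heads = 2048 // 64 = 32 and num_key_value_heads = 32 // 4 = 8,
    # giving head_dim = hidden_size // num_attention_heads = 64.
    cfg = PatchedLlamaConfig(hidden_size=2048)
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    print(f"num_attention_heads={cfg.num_attention_heads}, "
          f"num_key_value_heads={cfg.num_key_value_heads}, head_dim={head_dim}")
    assert head_dim == 64, "rotary embedding head_dim should be 64 after the patch"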