fix: patch DiffRhythm DIT to add missing LlamaConfig attention head parameters

Adds a monkey-patch for DiT.__init__() that configures LlamaConfig with the
num_attention_heads and num_key_value_heads parameters, which are missing
from the upstream DiffRhythm code.

Root cause: transformers 4.49.0+ requires these parameters, but DiffRhythm's
dit.py specifies only hidden_size, so the library infers head_dim as 32
instead of 64, leading to tensor dimension mismatches.
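
To make the misinference concrete, here is a minimal sketch (not part of this
commit) of what happens when only hidden_size is passed; the 1024 value mirrors
the fallback default used in the patch below:

```python
from transformers import LlamaConfig

# What DiffRhythm's dit.py effectively does: pass only hidden_size.
cfg = LlamaConfig(hidden_size=1024)

# num_attention_heads falls back to the library default (32), so the
# per-head dimension works out to 1024 // 32 = 32 rather than 64.
head_dim = getattr(cfg, "head_dim", None) or cfg.hidden_size // cfg.num_attention_heads
print(cfg.num_attention_heads, head_dim)  # 32 32
```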

Solution:
- Sets num_attention_heads = hidden_size // 64 (standard Llama architecture)
- Sets num_key_value_heads = num_attention_heads // 4 (GQA configuration)
- Ensures head_dim = 64, fixing the "tensor a (32) vs tensor b (64)" error

This is a proper fix rather than simply downgrading the transformers version.
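
As a rough sanity check of the values described above (a sketch under the same
1024 hidden-size assumption, not taken from the patch itself):

```python
from transformers import LlamaConfig

hidden_size = 1024
num_attention_heads = hidden_size // 64                  # 16, so head_dim = 64
num_key_value_heads = max(1, num_attention_heads // 4)   # 4 (GQA)

cfg = LlamaConfig(
    hidden_size=hidden_size,
    num_attention_heads=num_attention_heads,
    num_key_value_heads=num_key_value_heads,
)
assert cfg.hidden_size // cfg.num_attention_heads == 64
```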

References:
- https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/44
- https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/48

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-24 18:53:18 +01:00
parent 60ca8b08d0
commit 91f6e9bd59


@@ -1,11 +1,13 @@
"""
Pivoine DiffRhythm Node
Custom wrapper for DiffRhythm that ensures correct transformer library version
compatibility and provides fallback fixes for tensor dimension issues.
Custom wrapper for DiffRhythm that fixes LlamaConfig initialization issues
with transformers 4.49.0+ to prevent tensor dimension mismatches.
Known Issue: DiffRhythm requires transformers==4.49.0. Newer versions (4.50+)
cause "The size of tensor a (32) must match the size of tensor b (64)" error
in rotary position embeddings due to transformer block initialization changes.
Known Issue: DiffRhythm's DIT model doesn't specify num_attention_heads and
num_key_value_heads in LlamaConfig, causing "The size of tensor a (32) must
match the size of tensor b (64)" error in rotary position embeddings.
This patch adds the missing parameters to LlamaConfig initialization.
Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/44
Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/48
@@ -24,9 +26,54 @@ def patched_decode_audio(latent, vae_model, chunked=True):
"""Patched version that always uses chunked=False"""
return _original_decode_audio(latent, vae_model, chunked=False)
# Apply the monkey patch
# Apply the decode_audio monkey patch
infer_utils.decode_audio = patched_decode_audio
# Monkey-patch DiT __init__ to fix LlamaConfig initialization
from diffrhythm.model import dit
from transformers.models.llama import LlamaConfig
import torch.nn as nn
_original_dit_init = dit.DiT.__init__
def patched_dit_init(self, *args, **kwargs):
"""
Patched DiT.__init__ that adds missing num_attention_heads and
num_key_value_heads to LlamaConfig initialization.
This fixes the tensor dimension mismatch (32 vs 64) error in
rotary position embeddings with transformers 4.49.0+.
"""
# Call original __init__ but intercept the LlamaConfig creation
_original_llama_config = LlamaConfig
def patched_llama_config(*config_args, **config_kwargs):
"""Add missing attention head parameters to LlamaConfig"""
hidden_size = config_kwargs.get('hidden_size', config_args[0] if config_args else 1024)
# Standard Llama architecture: head_dim = 64, so num_heads = hidden_size // 64
# For GQA (Grouped Query Attention), num_key_value_heads is usually num_heads // 4
num_attention_heads = hidden_size // 64
num_key_value_heads = max(1, num_attention_heads // 4)
config_kwargs['num_attention_heads'] = config_kwargs.get('num_attention_heads', num_attention_heads)
config_kwargs['num_key_value_heads'] = config_kwargs.get('num_key_value_heads', num_key_value_heads)
return _original_llama_config(*config_args, **config_kwargs)
# Temporarily replace LlamaConfig in the dit module
dit.LlamaConfig = patched_llama_config
try:
# Call the original __init__
_original_dit_init(self, *args, **kwargs)
finally:
# Restore original LlamaConfig
dit.LlamaConfig = _original_llama_config
# Apply the DiT init monkey patch
dit.DiT.__init__ = patched_dit_init
from DiffRhythmNode import DiffRhythmRun
class PivoineDiffRhythmRun(DiffRhythmRun):
@@ -34,13 +81,16 @@ class PivoineDiffRhythmRun(DiffRhythmRun):
     Pivoine version of DiffRhythmRun with enhanced compatibility and error handling.

     Changes from original:
+    - Patches DIT.__init__ to add missing num_attention_heads and num_key_value_heads to LlamaConfig
     - Monkey-patches decode_audio to always use chunked=False for stability
-    - Ensures transformers library version compatibility (requires 4.49.0)
-    - Prevents tensor dimension mismatch in VAE decoding
-    - Requires more VRAM (~12-16GB) but works reliably on RTX 4090
+    - Fixes tensor dimension mismatch in rotary position embeddings (32 vs 64)
+    - Compatible with transformers 4.49.0+
+    - Requires ~12-16GB VRAM, works reliably on RTX 4090

-    Note: If you encounter "tensor a (32) must match tensor b (64)" errors,
-    ensure transformers==4.49.0 is installed in your ComfyUI venv.
+    Technical details:
+    - Sets num_attention_heads = hidden_size // 64 (standard Llama architecture)
+    - Sets num_key_value_heads = num_attention_heads // 4 (GQA configuration)
+    - This ensures head_dim = hidden_size // num_attention_heads = 64 (not 32)
     """

     CATEGORY = "🌸Pivoine/Audio"