fix: patch DiffRhythm DIT to add missing LlamaConfig attention head parameters
All checks were successful
Build and Push RunPod Docker Image / build-and-push (push) Successful in 15s
Adds a monkey-patch for DiT.__init__() to properly configure LlamaConfig with
num_attention_heads and num_key_value_heads, which are missing in the upstream
DiffRhythm code.

Root cause: transformers 4.49.0+ requires these parameters, but DiffRhythm's
dit.py only specifies hidden_size, causing the library to incorrectly infer
head_dim as 32 instead of 64 and leading to tensor dimension mismatches.

Solution:
- Sets num_attention_heads = hidden_size // 64 (standard Llama architecture)
- Sets num_key_value_heads = num_attention_heads // 4 (GQA configuration)
- Ensures head_dim = 64, fixing the "tensor a (32) vs tensor b (64)" error

This is a proper fix rather than just downgrading the transformers version.

References:
- https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/44
- https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/48

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
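As a quick illustration of the head_dim arithmetic, a minimal sketch follows. Note that hidden_size = 1024 is an assumed value for illustration only; the exact hidden_size DiffRhythm passes is not shown in this diff.

    from transformers import LlamaConfig

    hidden_size = 1024  # illustrative value, not necessarily what DiffRhythm uses

    # Only hidden_size given: LlamaConfig keeps its default of 32 attention heads,
    # so the inferred head_dim becomes hidden_size // 32 = 32.
    default_cfg = LlamaConfig(hidden_size=hidden_size)
    print(default_cfg.hidden_size // default_cfg.num_attention_heads)  # 32

    # With the patched parameters (num_heads = hidden_size // 64, KV heads = num_heads // 4),
    # head_dim comes out to 64, the value the rotary position embeddings expect.
    patched_cfg = LlamaConfig(
        hidden_size=hidden_size,
        num_attention_heads=hidden_size // 64,
        num_key_value_heads=max(1, (hidden_size // 64) // 4),
    )
    print(patched_cfg.hidden_size // patched_cfg.num_attention_heads)  # 64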
@@ -1,11 +1,13 @@
 """
 Pivoine DiffRhythm Node
-Custom wrapper for DiffRhythm that ensures correct transformer library version
-compatibility and provides fallback fixes for tensor dimension issues.
+Custom wrapper for DiffRhythm that fixes LlamaConfig initialization issues
+with transformers 4.49.0+ to prevent tensor dimension mismatches.
 
-Known Issue: DiffRhythm requires transformers==4.49.0. Newer versions (4.50+)
-cause "The size of tensor a (32) must match the size of tensor b (64)" error
-in rotary position embeddings due to transformer block initialization changes.
+Known Issue: DiffRhythm's DIT model doesn't specify num_attention_heads and
+num_key_value_heads in LlamaConfig, causing "The size of tensor a (32) must
+match the size of tensor b (64)" error in rotary position embeddings.
+
+This patch adds the missing parameters to LlamaConfig initialization.
+
+Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/44
+Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/48
@@ -24,9 +26,54 @@ def patched_decode_audio(latent, vae_model, chunked=True):
     """Patched version that always uses chunked=False"""
     return _original_decode_audio(latent, vae_model, chunked=False)
 
-# Apply the monkey patch
+# Apply the decode_audio monkey patch
 infer_utils.decode_audio = patched_decode_audio
 
+# Monkey-patch DiT __init__ to fix LlamaConfig initialization
+from diffrhythm.model import dit
+from transformers.models.llama import LlamaConfig
+import torch.nn as nn
+
+_original_dit_init = dit.DiT.__init__
+
+def patched_dit_init(self, *args, **kwargs):
+    """
+    Patched DiT.__init__ that adds missing num_attention_heads and
+    num_key_value_heads to LlamaConfig initialization.
+
+    This fixes the tensor dimension mismatch (32 vs 64) error in
+    rotary position embeddings with transformers 4.49.0+.
+    """
+    # Call original __init__ but intercept the LlamaConfig creation
+    _original_llama_config = LlamaConfig
+
+    def patched_llama_config(*config_args, **config_kwargs):
+        """Add missing attention head parameters to LlamaConfig"""
+        hidden_size = config_kwargs.get('hidden_size', config_args[0] if config_args else 1024)
+
+        # Standard Llama architecture: head_dim = 64, so num_heads = hidden_size // 64
+        # For GQA (Grouped Query Attention), num_key_value_heads is usually num_heads // 4
+        num_attention_heads = hidden_size // 64
+        num_key_value_heads = max(1, num_attention_heads // 4)
+
+        config_kwargs['num_attention_heads'] = config_kwargs.get('num_attention_heads', num_attention_heads)
+        config_kwargs['num_key_value_heads'] = config_kwargs.get('num_key_value_heads', num_key_value_heads)
+
+        return _original_llama_config(*config_args, **config_kwargs)
+
+    # Temporarily replace LlamaConfig in the dit module
+    dit.LlamaConfig = patched_llama_config
+
+    try:
+        # Call the original __init__
+        _original_dit_init(self, *args, **kwargs)
+    finally:
+        # Restore original LlamaConfig
+        dit.LlamaConfig = _original_llama_config
+
+# Apply the DiT init monkey patch
+dit.DiT.__init__ = patched_dit_init
+
 from DiffRhythmNode import DiffRhythmRun
 
 class PivoineDiffRhythmRun(DiffRhythmRun):
@@ -34,13 +81,16 @@ class PivoineDiffRhythmRun(DiffRhythmRun):
     Pivoine version of DiffRhythmRun with enhanced compatibility and error handling.
 
     Changes from original:
+    - Patches DIT.__init__ to add missing num_attention_heads and num_key_value_heads to LlamaConfig
     - Monkey-patches decode_audio to always use chunked=False for stability
-    - Ensures transformers library version compatibility (requires 4.49.0)
-    - Prevents tensor dimension mismatch in VAE decoding
-    - Requires more VRAM (~12-16GB) but works reliably on RTX 4090
+    - Fixes tensor dimension mismatch in rotary position embeddings (32 vs 64)
+    - Compatible with transformers 4.49.0+
+    - Requires ~12-16GB VRAM, works reliably on RTX 4090
 
-    Note: If you encounter "tensor a (32) must match tensor b (64)" errors,
-    ensure transformers==4.49.0 is installed in your ComfyUI venv.
+    Technical details:
+    - Sets num_attention_heads = hidden_size // 64 (standard Llama architecture)
+    - Sets num_key_value_heads = num_attention_heads // 4 (GQA configuration)
+    - This ensures head_dim = hidden_size // num_attention_heads = 64 (not 32)
     """
 
     CATEGORY = "🌸Pivoine/Audio"
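Design note on the patched_dit_init hunk above: the patch relies on the upstream DiT.__init__ looking LlamaConfig up through the dit module's namespace, so rebinding dit.LlamaConfig for the duration of the call is enough, and the finally block guarantees the original class is restored even if initialization fails. A minimal standalone sketch of the same swap-and-restore pattern, using toy names rather than DiffRhythm code:

    import types

    # Toy module standing in for diffrhythm.model.dit (hypothetical, illustration only).
    fake_dit = types.ModuleType("fake_dit")

    class Config:
        def __init__(self, hidden_size=1024, num_attention_heads=32):
            self.hidden_size = hidden_size
            self.num_attention_heads = num_attention_heads

    fake_dit.Config = Config

    def build():
        # Stands in for upstream code that only passes hidden_size.
        return fake_dit.Config(hidden_size=1024)

    def patched_config(*args, **kwargs):
        # Fill in the missing head count the same way the real patch does.
        kwargs.setdefault("num_attention_heads", kwargs.get("hidden_size", 1024) // 64)
        return Config(*args, **kwargs)

    original = fake_dit.Config
    fake_dit.Config = patched_config   # swap in the wrapper
    try:
        cfg = build()                  # the upstream call now goes through the wrapper
    finally:
        fake_dit.Config = original     # always restore, even if build() raises

    print(cfg.hidden_size // cfg.num_attention_heads)  # 64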