runpod/comfyui/nodes/pivoine_diffrhythm.py
Sebastian Krüger f74457b049
fix: apply LlamaConfig patch globally at import time
Previous approach patched DiT.__init__ at runtime, but models were already
instantiated and cached. This version patches LlamaConfig globally BEFORE
any DiffRhythm imports, ensuring all model instances use the correct config.

Key changes:
- Created PatchedLlamaConfig subclass that auto-calculates attention heads
- Replaced LlamaConfig in transformers.models.llama module at import time
- Patch applies to all LlamaConfig instances, including pre-loaded models

This should finally fix the tensor dimension mismatch error.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-24 19:00:29 +01:00


"""
Pivoine DiffRhythm Node
Custom wrapper for DiffRhythm that fixes LlamaConfig initialization issues
with transformers 4.49.0+ to prevent tensor dimension mismatches.
Known Issue: DiffRhythm's DIT model doesn't specify num_attention_heads and
num_key_value_heads in LlamaConfig, causing "The size of tensor a (32) must
match the size of tensor b (64)" error in rotary position embeddings.
This patch globally intercepts LlamaConfig at import time.
Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/44
Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/48
Author: valknar@pivoine.art
"""
import sys
sys.path.append('/workspace/ComfyUI/custom_nodes/ComfyUI_DiffRhythm')
# CRITICAL: Patch LlamaConfig BEFORE any DiffRhythm imports
# This must happen at module import time, not at runtime
from transformers.models.llama import LlamaConfig as _OriginalLlamaConfig


class PatchedLlamaConfig(_OriginalLlamaConfig):
    """
    Patched LlamaConfig that automatically adds missing attention head parameters.

    Fixes the tensor dimension mismatch (32 vs 64) in DiffRhythm's rotary
    position embeddings by ensuring num_attention_heads and num_key_value_heads
    are properly set based on hidden_size.
    """

    def __init__(self, *args, **kwargs):
        # If hidden_size is provided but num_attention_heads is not, derive it
        if 'hidden_size' in kwargs and 'num_attention_heads' not in kwargs:
            hidden_size = kwargs['hidden_size']
            # DiffRhythm's DiT expects head_dim = 64, so num_heads = hidden_size // 64
            kwargs['num_attention_heads'] = hidden_size // 64

        # If num_key_value_heads is not provided, use a GQA configuration
        if 'num_attention_heads' in kwargs and 'num_key_value_heads' not in kwargs:
            # For GQA (Grouped Query Attention), typically num_kv_heads = num_heads // 4
            kwargs['num_key_value_heads'] = max(1, kwargs['num_attention_heads'] // 4)

        # Call the original __init__ with the patched parameters
        super().__init__(*args, **kwargs)
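

# Worked example of the resulting geometry (illustrative only; the actual
# hidden_size used by DiffRhythm's DiT may differ):
#   - LlamaConfig defaults to num_attention_heads=32, so with hidden_size=1024
#     head_dim = 1024 // 32 = 32, which clashes with the 64-dim rotary embeddings
#   - with the patch: num_attention_heads = 1024 // 64 = 16,
#     num_key_value_heads = max(1, 16 // 4) = 4, head_dim = 1024 // 16 = 64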

# Replace LlamaConfig in the transformers module BEFORE DiffRhythm imports it
import transformers.models.llama
transformers.models.llama.LlamaConfig = PatchedLlamaConfig

# Also replace it in the modeling_llama module (imported here so its attribute
# can be rebound even if it was not loaded yet)
import transformers.models.llama.modeling_llama
transformers.models.llama.modeling_llama.LlamaConfig = PatchedLlamaConfig
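# Rebinding both names is deliberate: modeling_llama does
# `from .configuration_llama import LlamaConfig`, so it keeps its own reference
# to the class, and patching only the package attribute would not reach it.
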
# Now import DiffRhythm modules - they will use our patched LlamaConfig
import infer_utils
# Monkey-patch decode_audio to force chunked=False
_original_decode_audio = infer_utils.decode_audio


def patched_decode_audio(latent, vae_model, chunked=True):
    """Patched version that always uses chunked=False"""
    return _original_decode_audio(latent, vae_model, chunked=False)


infer_utils.decode_audio = patched_decode_audio
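# Note: the wrapper assumes decode_audio takes exactly (latent, vae_model,
# chunked); if a future DiffRhythm version adds parameters, it would need to
# forward *args/**kwargs as well.
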
# Import DiffRhythm node
from DiffRhythmNode import DiffRhythmRun


class PivoineDiffRhythmRun(DiffRhythmRun):
    """
    Pivoine version of DiffRhythmRun with enhanced compatibility and error handling.

    Changes from the original:
    - Globally patches LlamaConfig to add missing num_attention_heads and num_key_value_heads
    - Monkey-patches decode_audio to always use chunked=False for stability
    - Fixes the tensor dimension mismatch in rotary position embeddings (32 vs 64)
    - Compatible with transformers 4.49.0+
    - Requires ~12-16 GB VRAM; works reliably on an RTX 4090

    Technical details:
    - Sets num_attention_heads = hidden_size // 64 (so that head_dim = 64)
    - Sets num_key_value_heads = num_attention_heads // 4 (GQA configuration)
    - This ensures head_dim = hidden_size // num_attention_heads = 64 (not 32)
    - The patch is applied globally at import time, affecting all LlamaConfig instances
    """

    CATEGORY = "🌸Pivoine/Audio"

    @classmethod
    def INPUT_TYPES(cls):
        return super().INPUT_TYPES()


NODE_CLASS_MAPPINGS = {
    "PivoineDiffRhythmRun": PivoineDiffRhythmRun,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "PivoineDiffRhythmRun": "Pivoine DiffRhythm Run",
}
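
# ComfyUI registration note: ComfyUI scans custom_nodes/ at startup and reads
# NODE_CLASS_MAPPINGS / NODE_DISPLAY_NAME_MAPPINGS from each module it imports,
# which is how this node appears as "Pivoine DiffRhythm Run". If this file is
# placed inside a package, the package's __init__.py needs to re-export both dicts.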