fix: apply LlamaConfig patch globally at import time
All checks were successful
Build and Push RunPod Docker Image / build-and-push (push) Successful in 14s
All checks were successful
Build and Push RunPod Docker Image / build-and-push (push) Successful in 14s
Previous approach patched DiT.__init__ at runtime, but models were already instantiated and cached. This version patches LlamaConfig globally BEFORE any DiffRhythm imports, ensuring all model instances use the correct config. Key changes: - Created PatchedLlamaConfig subclass that auto-calculates attention heads - Replaced LlamaConfig in transformers.models.llama module at import time - Patch applies to all LlamaConfig instances, including pre-loaded models This should finally fix the tensor dimension mismatch error. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -7,7 +7,7 @@ Known Issue: DiffRhythm's DIT model doesn't specify num_attention_heads and
|
|||||||
num_key_value_heads in LlamaConfig, causing "The size of tensor a (32) must
|
num_key_value_heads in LlamaConfig, causing "The size of tensor a (32) must
|
||||||
match the size of tensor b (64)" error in rotary position embeddings.
|
match the size of tensor b (64)" error in rotary position embeddings.
|
||||||
|
|
||||||
This patch adds the missing parameters to LlamaConfig initialization.
|
This patch globally intercepts LlamaConfig at import time.
|
||||||
|
|
||||||
Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/44
|
Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/44
|
||||||
Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/48
|
Reference: https://github.com/billwuhao/ComfyUI_DiffRhythm/issues/48
|
||||||
@@ -18,62 +18,54 @@ Author: valknar@pivoine.art
|
|||||||
import sys
|
import sys
|
||||||
sys.path.append('/workspace/ComfyUI/custom_nodes/ComfyUI_DiffRhythm')
|
sys.path.append('/workspace/ComfyUI/custom_nodes/ComfyUI_DiffRhythm')
|
||||||
|
|
||||||
# Monkey-patch decode_audio from infer_utils to force chunked=False
|
# CRITICAL: Patch LlamaConfig BEFORE any DiffRhythm imports
|
||||||
|
# This must happen at module import time, not at runtime
|
||||||
|
from transformers.models.llama import LlamaConfig as _OriginalLlamaConfig
|
||||||
|
|
||||||
|
class PatchedLlamaConfig(_OriginalLlamaConfig):
|
||||||
|
"""
|
||||||
|
Patched LlamaConfig that automatically adds missing attention head parameters.
|
||||||
|
|
||||||
|
Fixes the tensor dimension mismatch (32 vs 64) in DiffRhythm's rotary
|
||||||
|
position embeddings by ensuring num_attention_heads and num_key_value_heads
|
||||||
|
are properly set based on hidden_size.
|
||||||
|
"""
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
# If hidden_size is provided but num_attention_heads is not, calculate it
|
||||||
|
if 'hidden_size' in kwargs and 'num_attention_heads' not in kwargs:
|
||||||
|
hidden_size = kwargs['hidden_size']
|
||||||
|
# Standard Llama architecture: head_dim = 64, so num_heads = hidden_size // 64
|
||||||
|
kwargs['num_attention_heads'] = hidden_size // 64
|
||||||
|
|
||||||
|
# If num_key_value_heads is not provided, use GQA configuration
|
||||||
|
if 'num_attention_heads' in kwargs and 'num_key_value_heads' not in kwargs:
|
||||||
|
# For GQA (Grouped Query Attention), typically num_kv_heads = num_heads // 4
|
||||||
|
kwargs['num_key_value_heads'] = max(1, kwargs['num_attention_heads'] // 4)
|
||||||
|
|
||||||
|
# Call original __init__ with patched parameters
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
# Replace LlamaConfig in transformers module BEFORE DiffRhythm imports it
|
||||||
|
import transformers.models.llama
|
||||||
|
transformers.models.llama.LlamaConfig = PatchedLlamaConfig
|
||||||
|
|
||||||
|
# Also replace in modeling_llama module if it's already imported
|
||||||
|
import transformers.models.llama.modeling_llama
|
||||||
|
transformers.models.llama.modeling_llama.LlamaConfig = PatchedLlamaConfig
|
||||||
|
|
||||||
|
# Now import DiffRhythm modules - they will use our patched LlamaConfig
|
||||||
import infer_utils
|
import infer_utils
|
||||||
|
|
||||||
|
# Monkey-patch decode_audio to force chunked=False
|
||||||
_original_decode_audio = infer_utils.decode_audio
|
_original_decode_audio = infer_utils.decode_audio
|
||||||
|
|
||||||
def patched_decode_audio(latent, vae_model, chunked=True):
|
def patched_decode_audio(latent, vae_model, chunked=True):
|
||||||
"""Patched version that always uses chunked=False"""
|
"""Patched version that always uses chunked=False"""
|
||||||
return _original_decode_audio(latent, vae_model, chunked=False)
|
return _original_decode_audio(latent, vae_model, chunked=False)
|
||||||
|
|
||||||
# Apply the decode_audio monkey patch
|
|
||||||
infer_utils.decode_audio = patched_decode_audio
|
infer_utils.decode_audio = patched_decode_audio
|
||||||
|
|
||||||
# Monkey-patch DiT __init__ to fix LlamaConfig initialization
|
# Import DiffRhythm node
|
||||||
from diffrhythm.model import dit
|
|
||||||
from transformers.models.llama import LlamaConfig
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
_original_dit_init = dit.DiT.__init__
|
|
||||||
|
|
||||||
def patched_dit_init(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Patched DiT.__init__ that adds missing num_attention_heads and
|
|
||||||
num_key_value_heads to LlamaConfig initialization.
|
|
||||||
|
|
||||||
This fixes the tensor dimension mismatch (32 vs 64) error in
|
|
||||||
rotary position embeddings with transformers 4.49.0+.
|
|
||||||
"""
|
|
||||||
# Call original __init__ but intercept the LlamaConfig creation
|
|
||||||
_original_llama_config = LlamaConfig
|
|
||||||
|
|
||||||
def patched_llama_config(*config_args, **config_kwargs):
|
|
||||||
"""Add missing attention head parameters to LlamaConfig"""
|
|
||||||
hidden_size = config_kwargs.get('hidden_size', config_args[0] if config_args else 1024)
|
|
||||||
|
|
||||||
# Standard Llama architecture: head_dim = 64, so num_heads = hidden_size // 64
|
|
||||||
# For GQA (Grouped Query Attention), num_key_value_heads is usually num_heads // 4
|
|
||||||
num_attention_heads = hidden_size // 64
|
|
||||||
num_key_value_heads = max(1, num_attention_heads // 4)
|
|
||||||
|
|
||||||
config_kwargs['num_attention_heads'] = config_kwargs.get('num_attention_heads', num_attention_heads)
|
|
||||||
config_kwargs['num_key_value_heads'] = config_kwargs.get('num_key_value_heads', num_key_value_heads)
|
|
||||||
|
|
||||||
return _original_llama_config(*config_args, **config_kwargs)
|
|
||||||
|
|
||||||
# Temporarily replace LlamaConfig in the dit module
|
|
||||||
dit.LlamaConfig = patched_llama_config
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Call the original __init__
|
|
||||||
_original_dit_init(self, *args, **kwargs)
|
|
||||||
finally:
|
|
||||||
# Restore original LlamaConfig
|
|
||||||
dit.LlamaConfig = _original_llama_config
|
|
||||||
|
|
||||||
# Apply the DiT init monkey patch
|
|
||||||
dit.DiT.__init__ = patched_dit_init
|
|
||||||
|
|
||||||
from DiffRhythmNode import DiffRhythmRun
|
from DiffRhythmNode import DiffRhythmRun
|
||||||
|
|
||||||
class PivoineDiffRhythmRun(DiffRhythmRun):
|
class PivoineDiffRhythmRun(DiffRhythmRun):
|
||||||
@@ -81,7 +73,7 @@ class PivoineDiffRhythmRun(DiffRhythmRun):
|
|||||||
Pivoine version of DiffRhythmRun with enhanced compatibility and error handling.
|
Pivoine version of DiffRhythmRun with enhanced compatibility and error handling.
|
||||||
|
|
||||||
Changes from original:
|
Changes from original:
|
||||||
- Patches DIT.__init__ to add missing num_attention_heads and num_key_value_heads to LlamaConfig
|
- Globally patches LlamaConfig to add missing num_attention_heads and num_key_value_heads
|
||||||
- Monkey-patches decode_audio to always use chunked=False for stability
|
- Monkey-patches decode_audio to always use chunked=False for stability
|
||||||
- Fixes tensor dimension mismatch in rotary position embeddings (32 vs 64)
|
- Fixes tensor dimension mismatch in rotary position embeddings (32 vs 64)
|
||||||
- Compatible with transformers 4.49.0+
|
- Compatible with transformers 4.49.0+
|
||||||
@@ -91,6 +83,7 @@ class PivoineDiffRhythmRun(DiffRhythmRun):
|
|||||||
- Sets num_attention_heads = hidden_size // 64 (standard Llama architecture)
|
- Sets num_attention_heads = hidden_size // 64 (standard Llama architecture)
|
||||||
- Sets num_key_value_heads = num_attention_heads // 4 (GQA configuration)
|
- Sets num_key_value_heads = num_attention_heads // 4 (GQA configuration)
|
||||||
- This ensures head_dim = hidden_size // num_attention_heads = 64 (not 32)
|
- This ensures head_dim = hidden_size // num_attention_heads = 64 (not 32)
|
||||||
|
- Patch is applied globally at import time, affecting all LlamaConfig instances
|
||||||
"""
|
"""
|
||||||
|
|
||||||
CATEGORY = "🌸Pivoine/Audio"
|
CATEGORY = "🌸Pivoine/Audio"
|
||||||
|
|||||||
Reference in New Issue
Block a user