fix: monkey-patch DiffRhythm infer function to force chunked=False
All checks were successful
Build and Push RunPod Docker Image / build-and-push (push) Successful in 14s

The previous approach of overriding diffrhythmgen wasn't working because
ComfyUI doesn't pass the chunked parameter when it's not in INPUT_TYPES.
This fix monkey-patches the infer() function at module level to always
force chunked=False, preventing the tensor dimension mismatch error.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-11-24 17:24:22 +01:00
parent 5096e3ffb5
commit 1981b7b256

View File

@@ -9,15 +9,28 @@ Author: valknar@pivoine.art
import sys import sys
sys.path.append('/workspace/ComfyUI/custom_nodes/ComfyUI_DiffRhythm') sys.path.append('/workspace/ComfyUI/custom_nodes/ComfyUI_DiffRhythm')
# Monkey-patch the infer function to force chunked=False
import DiffRhythmNode
_original_infer = DiffRhythmNode.infer
def patched_infer(*args, **kwargs):
# Force chunked to False if present
if 'chunked' in kwargs:
kwargs['chunked'] = False
return _original_infer(*args, chunked=False, **kwargs)
# Apply the monkey patch
DiffRhythmNode.infer = patched_infer
from DiffRhythmNode import DiffRhythmRun from DiffRhythmNode import DiffRhythmRun
class PivoineDiffRhythmRun(DiffRhythmRun): class PivoineDiffRhythmRun(DiffRhythmRun):
""" """
Pivoine version of DiffRhythmRun with chunked decoding disabled. Pivoine version of DiffRhythmRun with chunked decoding forcibly disabled.
Changes from original: Changes from original:
- chunked parameter defaults to False (was True) - Monkey-patches the infer() function to always use chunked=False
- Prevents tensor dimension mismatch in VAE - Prevents tensor dimension mismatch in VAE (32 vs 64 channel error)
- Requires more VRAM (~12-16GB) but works reliably on RTX 4090 - Requires more VRAM (~12-16GB) but works reliably on RTX 4090
""" """
@@ -27,39 +40,6 @@ class PivoineDiffRhythmRun(DiffRhythmRun):
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
return super().INPUT_TYPES() return super().INPUT_TYPES()
def diffrhythmgen(
self,
edit,
model: str,
style_prompt: str = None,
lyrics_or_edit_lyrics: str = "",
style_audio_or_edit_song = None,
edit_segments: str = "",
chunked: bool = False, # Changed from True to False
odeint_method: str = "euler",
steps: int = 30,
cfg: int = 4,
quality_or_speed: str = "speed",
unload_model: bool = False,
seed: int = 0
):
# Force chunked=False to avoid dimension mismatch
return super().diffrhythmgen(
edit=edit,
model=model,
style_prompt=style_prompt,
lyrics_or_edit_lyrics=lyrics_or_edit_lyrics,
style_audio_or_edit_song=style_audio_or_edit_song,
edit_segments=edit_segments,
chunked=False,
odeint_method=odeint_method,
steps=steps,
cfg=cfg,
quality_or_speed=quality_or_speed,
unload_model=unload_model,
seed=seed
)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"PivoineDiffRhythmRun": PivoineDiffRhythmRun, "PivoineDiffRhythmRun": PivoineDiffRhythmRun,
} }