fix: monkey-patch DiffRhythm infer function to force chunked=False
All checks were successful
Build and Push RunPod Docker Image / build-and-push (push) Successful in 14s
All checks were successful
Build and Push RunPod Docker Image / build-and-push (push) Successful in 14s
The previous approach of overriding diffrhythmgen wasn't working because ComfyUI doesn't pass the chunked parameter when it's not in INPUT_TYPES. This fix monkey-patches the infer() function at module level to always force chunked=False, preventing the tensor dimension mismatch error. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -9,15 +9,28 @@ Author: valknar@pivoine.art
|
||||
import sys
|
||||
sys.path.append('/workspace/ComfyUI/custom_nodes/ComfyUI_DiffRhythm')
|
||||
|
||||
# Monkey-patch the infer function to force chunked=False
|
||||
import DiffRhythmNode
|
||||
_original_infer = DiffRhythmNode.infer
|
||||
|
||||
def patched_infer(*args, **kwargs):
|
||||
# Force chunked to False if present
|
||||
if 'chunked' in kwargs:
|
||||
kwargs['chunked'] = False
|
||||
return _original_infer(*args, chunked=False, **kwargs)
|
||||
|
||||
# Apply the monkey patch
|
||||
DiffRhythmNode.infer = patched_infer
|
||||
|
||||
from DiffRhythmNode import DiffRhythmRun
|
||||
|
||||
class PivoineDiffRhythmRun(DiffRhythmRun):
|
||||
"""
|
||||
Pivoine version of DiffRhythmRun with chunked decoding disabled.
|
||||
Pivoine version of DiffRhythmRun with chunked decoding forcibly disabled.
|
||||
|
||||
Changes from original:
|
||||
- chunked parameter defaults to False (was True)
|
||||
- Prevents tensor dimension mismatch in VAE
|
||||
- Monkey-patches the infer() function to always use chunked=False
|
||||
- Prevents tensor dimension mismatch in VAE (32 vs 64 channel error)
|
||||
- Requires more VRAM (~12-16GB) but works reliably on RTX 4090
|
||||
"""
|
||||
|
||||
@@ -27,39 +40,6 @@ class PivoineDiffRhythmRun(DiffRhythmRun):
|
||||
def INPUT_TYPES(cls):
|
||||
return super().INPUT_TYPES()
|
||||
|
||||
def diffrhythmgen(
|
||||
self,
|
||||
edit,
|
||||
model: str,
|
||||
style_prompt: str = None,
|
||||
lyrics_or_edit_lyrics: str = "",
|
||||
style_audio_or_edit_song = None,
|
||||
edit_segments: str = "",
|
||||
chunked: bool = False, # Changed from True to False
|
||||
odeint_method: str = "euler",
|
||||
steps: int = 30,
|
||||
cfg: int = 4,
|
||||
quality_or_speed: str = "speed",
|
||||
unload_model: bool = False,
|
||||
seed: int = 0
|
||||
):
|
||||
# Force chunked=False to avoid dimension mismatch
|
||||
return super().diffrhythmgen(
|
||||
edit=edit,
|
||||
model=model,
|
||||
style_prompt=style_prompt,
|
||||
lyrics_or_edit_lyrics=lyrics_or_edit_lyrics,
|
||||
style_audio_or_edit_song=style_audio_or_edit_song,
|
||||
edit_segments=edit_segments,
|
||||
chunked=False,
|
||||
odeint_method=odeint_method,
|
||||
steps=steps,
|
||||
cfg=cfg,
|
||||
quality_or_speed=quality_or_speed,
|
||||
unload_model=unload_model,
|
||||
seed=seed
|
||||
)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"PivoineDiffRhythmRun": PivoineDiffRhythmRun,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user