From b8a87e7109b61c8c96e9858d44bfb12ce5d702c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20Kr=C3=BCger?= Date: Thu, 27 Nov 2025 14:50:09 +0100 Subject: [PATCH] Fix model loading - use device param instead of .to() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AudioCraft's get_pretrained() accepts device parameter directly. The model objects don't have a .to() method. Fixed in all adapters: - MusicGen - AudioGen - MAGNeT - MusicGen Style - JASCO 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/models/audiogen/adapter.py | 3 +-- src/models/jasco/adapter.py | 3 +-- src/models/magnet/adapter.py | 3 +-- src/models/musicgen/adapter.py | 4 ++-- src/models/musicgen_style/adapter.py | 3 +-- 5 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/models/audiogen/adapter.py b/src/models/audiogen/adapter.py index 0a88e50..0003e42 100644 --- a/src/models/audiogen/adapter.py +++ b/src/models/audiogen/adapter.py @@ -104,8 +104,7 @@ class AudioGenAdapter(BaseAudioModel): from audiocraft.models import AudioGen self._device = torch.device(device) - self._model = AudioGen.get_pretrained(self._config["hf_id"]) - self._model.to(self._device) + self._model = AudioGen.get_pretrained(self._config["hf_id"], device=device) logger.info( f"AudioGen {self._variant} loaded successfully " diff --git a/src/models/jasco/adapter.py b/src/models/jasco/adapter.py index 1e739ca..e15082b 100644 --- a/src/models/jasco/adapter.py +++ b/src/models/jasco/adapter.py @@ -116,8 +116,7 @@ class JASCOAdapter(BaseAudioModel): from audiocraft.models import JASCO self._device = torch.device(device) - self._model = JASCO.get_pretrained(self._config["hf_id"]) - self._model.to(self._device) + self._model = JASCO.get_pretrained(self._config["hf_id"], device=device) logger.info( f"JASCO {self._variant} loaded successfully " diff --git a/src/models/magnet/adapter.py b/src/models/magnet/adapter.py index 98f2f2e..17e39b1 100644 --- a/src/models/magnet/adapter.py +++ b/src/models/magnet/adapter.py @@ -142,8 +142,7 @@ class MAGNeTAdapter(BaseAudioModel): from audiocraft.models import MAGNeT self._device = torch.device(device) - self._model = MAGNeT.get_pretrained(self._config["hf_id"]) - self._model.to(self._device) + self._model = MAGNeT.get_pretrained(self._config["hf_id"], device=device) logger.info( f"MAGNeT {self._variant} loaded successfully " diff --git a/src/models/musicgen/adapter.py b/src/models/musicgen/adapter.py index e316a3b..2dae3ff 100644 --- a/src/models/musicgen/adapter.py +++ b/src/models/musicgen/adapter.py @@ -168,8 +168,8 @@ class MusicGenAdapter(BaseAudioModel): from audiocraft.models import MusicGen self._device = torch.device(device) - self._model = MusicGen.get_pretrained(self._config["hf_id"]) - self._model.to(self._device) + # MusicGen.get_pretrained() accepts device parameter directly + self._model = MusicGen.get_pretrained(self._config["hf_id"], device=device) logger.info( f"MusicGen {self._variant} loaded successfully " diff --git a/src/models/musicgen_style/adapter.py b/src/models/musicgen_style/adapter.py index 7926ec2..abb7c03 100644 --- a/src/models/musicgen_style/adapter.py +++ b/src/models/musicgen_style/adapter.py @@ -105,8 +105,7 @@ class MusicGenStyleAdapter(BaseAudioModel): from audiocraft.models import MusicGen self._device = torch.device(device) - self._model = MusicGen.get_pretrained(self._config["hf_id"]) - self._model.to(self._device) + self._model = MusicGen.get_pretrained(self._config["hf_id"], device=device) logger.info( f"MusicGen Style {self._variant} loaded successfully "