diff --git a/src/models/audiogen/adapter.py b/src/models/audiogen/adapter.py index 0a88e50..0003e42 100644 --- a/src/models/audiogen/adapter.py +++ b/src/models/audiogen/adapter.py @@ -104,8 +104,7 @@ class AudioGenAdapter(BaseAudioModel): from audiocraft.models import AudioGen self._device = torch.device(device) - self._model = AudioGen.get_pretrained(self._config["hf_id"]) - self._model.to(self._device) + self._model = AudioGen.get_pretrained(self._config["hf_id"], device=device) logger.info( f"AudioGen {self._variant} loaded successfully " diff --git a/src/models/jasco/adapter.py b/src/models/jasco/adapter.py index 1e739ca..e15082b 100644 --- a/src/models/jasco/adapter.py +++ b/src/models/jasco/adapter.py @@ -116,8 +116,7 @@ class JASCOAdapter(BaseAudioModel): from audiocraft.models import JASCO self._device = torch.device(device) - self._model = JASCO.get_pretrained(self._config["hf_id"]) - self._model.to(self._device) + self._model = JASCO.get_pretrained(self._config["hf_id"], device=device) logger.info( f"JASCO {self._variant} loaded successfully " diff --git a/src/models/magnet/adapter.py b/src/models/magnet/adapter.py index 98f2f2e..17e39b1 100644 --- a/src/models/magnet/adapter.py +++ b/src/models/magnet/adapter.py @@ -142,8 +142,7 @@ class MAGNeTAdapter(BaseAudioModel): from audiocraft.models import MAGNeT self._device = torch.device(device) - self._model = MAGNeT.get_pretrained(self._config["hf_id"]) - self._model.to(self._device) + self._model = MAGNeT.get_pretrained(self._config["hf_id"], device=device) logger.info( f"MAGNeT {self._variant} loaded successfully " diff --git a/src/models/musicgen/adapter.py b/src/models/musicgen/adapter.py index e316a3b..2dae3ff 100644 --- a/src/models/musicgen/adapter.py +++ b/src/models/musicgen/adapter.py @@ -168,8 +168,8 @@ class MusicGenAdapter(BaseAudioModel): from audiocraft.models import MusicGen self._device = torch.device(device) - self._model = MusicGen.get_pretrained(self._config["hf_id"]) - self._model.to(self._device) + # MusicGen.get_pretrained() accepts device parameter directly + self._model = MusicGen.get_pretrained(self._config["hf_id"], device=device) logger.info( f"MusicGen {self._variant} loaded successfully " diff --git a/src/models/musicgen_style/adapter.py b/src/models/musicgen_style/adapter.py index 7926ec2..abb7c03 100644 --- a/src/models/musicgen_style/adapter.py +++ b/src/models/musicgen_style/adapter.py @@ -105,8 +105,7 @@ class MusicGenStyleAdapter(BaseAudioModel): from audiocraft.models import MusicGen self._device = torch.device(device) - self._model = MusicGen.get_pretrained(self._config["hf_id"]) - self._model.to(self._device) + self._model = MusicGen.get_pretrained(self._config["hf_id"], device=device) logger.info( f"MusicGen Style {self._variant} loaded successfully "