Fix model loading - use device param instead of .to()

AudioCraft's get_pretrained() accepts device parameter directly.
The model objects don't have a .to() method.

Fixed in all adapters:
- MusicGen
- AudioGen
- MAGNeT
- MusicGen Style
- JASCO

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-11-27 14:50:09 +01:00
parent 74f69707b8
commit b8a87e7109
5 changed files with 6 additions and 10 deletions

View File

@@ -104,8 +104,7 @@ class AudioGenAdapter(BaseAudioModel):
from audiocraft.models import AudioGen
self._device = torch.device(device)
self._model = AudioGen.get_pretrained(self._config["hf_id"])
self._model.to(self._device)
self._model = AudioGen.get_pretrained(self._config["hf_id"], device=device)
logger.info(
f"AudioGen {self._variant} loaded successfully "

View File

@@ -116,8 +116,7 @@ class JASCOAdapter(BaseAudioModel):
from audiocraft.models import JASCO
self._device = torch.device(device)
self._model = JASCO.get_pretrained(self._config["hf_id"])
self._model.to(self._device)
self._model = JASCO.get_pretrained(self._config["hf_id"], device=device)
logger.info(
f"JASCO {self._variant} loaded successfully "

View File

@@ -142,8 +142,7 @@ class MAGNeTAdapter(BaseAudioModel):
from audiocraft.models import MAGNeT
self._device = torch.device(device)
self._model = MAGNeT.get_pretrained(self._config["hf_id"])
self._model.to(self._device)
self._model = MAGNeT.get_pretrained(self._config["hf_id"], device=device)
logger.info(
f"MAGNeT {self._variant} loaded successfully "

View File

@@ -168,8 +168,8 @@ class MusicGenAdapter(BaseAudioModel):
from audiocraft.models import MusicGen
self._device = torch.device(device)
self._model = MusicGen.get_pretrained(self._config["hf_id"])
self._model.to(self._device)
# MusicGen.get_pretrained() accepts device parameter directly
self._model = MusicGen.get_pretrained(self._config["hf_id"], device=device)
logger.info(
f"MusicGen {self._variant} loaded successfully "

View File

@@ -105,8 +105,7 @@ class MusicGenStyleAdapter(BaseAudioModel):
from audiocraft.models import MusicGen
self._device = torch.device(device)
self._model = MusicGen.get_pretrained(self._config["hf_id"])
self._model.to(self._device)
self._model = MusicGen.get_pretrained(self._config["hf_id"], device=device)
logger.info(
f"MusicGen Style {self._variant} loaded successfully "