Fix model loading - use device param instead of .to()
AudioCraft's get_pretrained() accepts device parameter directly. The model objects don't have a .to() method. Fixed in all adapters: - MusicGen - AudioGen - MAGNeT - MusicGen Style - JASCO 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -104,8 +104,7 @@ class AudioGenAdapter(BaseAudioModel):
|
||||
from audiocraft.models import AudioGen
|
||||
|
||||
self._device = torch.device(device)
|
||||
self._model = AudioGen.get_pretrained(self._config["hf_id"])
|
||||
self._model.to(self._device)
|
||||
self._model = AudioGen.get_pretrained(self._config["hf_id"], device=device)
|
||||
|
||||
logger.info(
|
||||
f"AudioGen {self._variant} loaded successfully "
|
||||
|
||||
@@ -116,8 +116,7 @@ class JASCOAdapter(BaseAudioModel):
|
||||
from audiocraft.models import JASCO
|
||||
|
||||
self._device = torch.device(device)
|
||||
self._model = JASCO.get_pretrained(self._config["hf_id"])
|
||||
self._model.to(self._device)
|
||||
self._model = JASCO.get_pretrained(self._config["hf_id"], device=device)
|
||||
|
||||
logger.info(
|
||||
f"JASCO {self._variant} loaded successfully "
|
||||
|
||||
@@ -142,8 +142,7 @@ class MAGNeTAdapter(BaseAudioModel):
|
||||
from audiocraft.models import MAGNeT
|
||||
|
||||
self._device = torch.device(device)
|
||||
self._model = MAGNeT.get_pretrained(self._config["hf_id"])
|
||||
self._model.to(self._device)
|
||||
self._model = MAGNeT.get_pretrained(self._config["hf_id"], device=device)
|
||||
|
||||
logger.info(
|
||||
f"MAGNeT {self._variant} loaded successfully "
|
||||
|
||||
@@ -168,8 +168,8 @@ class MusicGenAdapter(BaseAudioModel):
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
self._device = torch.device(device)
|
||||
self._model = MusicGen.get_pretrained(self._config["hf_id"])
|
||||
self._model.to(self._device)
|
||||
# MusicGen.get_pretrained() accepts device parameter directly
|
||||
self._model = MusicGen.get_pretrained(self._config["hf_id"], device=device)
|
||||
|
||||
logger.info(
|
||||
f"MusicGen {self._variant} loaded successfully "
|
||||
|
||||
@@ -105,8 +105,7 @@ class MusicGenStyleAdapter(BaseAudioModel):
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
self._device = torch.device(device)
|
||||
self._model = MusicGen.get_pretrained(self._config["hf_id"])
|
||||
self._model.to(self._device)
|
||||
self._model = MusicGen.get_pretrained(self._config["hf_id"], device=device)
|
||||
|
||||
logger.info(
|
||||
f"MusicGen Style {self._variant} loaded successfully "
|
||||
|
||||
Reference in New Issue
Block a user