Fix model loading - use device param instead of .to()
AudioCraft's get_pretrained() accepts device parameter directly. The model objects don't have a .to() method. Fixed in all adapters: - MusicGen - AudioGen - MAGNeT - MusicGen Style - JASCO 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -104,8 +104,7 @@ class AudioGenAdapter(BaseAudioModel):
|
|||||||
from audiocraft.models import AudioGen
|
from audiocraft.models import AudioGen
|
||||||
|
|
||||||
self._device = torch.device(device)
|
self._device = torch.device(device)
|
||||||
self._model = AudioGen.get_pretrained(self._config["hf_id"])
|
self._model = AudioGen.get_pretrained(self._config["hf_id"], device=device)
|
||||||
self._model.to(self._device)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"AudioGen {self._variant} loaded successfully "
|
f"AudioGen {self._variant} loaded successfully "
|
||||||
|
|||||||
@@ -116,8 +116,7 @@ class JASCOAdapter(BaseAudioModel):
|
|||||||
from audiocraft.models import JASCO
|
from audiocraft.models import JASCO
|
||||||
|
|
||||||
self._device = torch.device(device)
|
self._device = torch.device(device)
|
||||||
self._model = JASCO.get_pretrained(self._config["hf_id"])
|
self._model = JASCO.get_pretrained(self._config["hf_id"], device=device)
|
||||||
self._model.to(self._device)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"JASCO {self._variant} loaded successfully "
|
f"JASCO {self._variant} loaded successfully "
|
||||||
|
|||||||
@@ -142,8 +142,7 @@ class MAGNeTAdapter(BaseAudioModel):
|
|||||||
from audiocraft.models import MAGNeT
|
from audiocraft.models import MAGNeT
|
||||||
|
|
||||||
self._device = torch.device(device)
|
self._device = torch.device(device)
|
||||||
self._model = MAGNeT.get_pretrained(self._config["hf_id"])
|
self._model = MAGNeT.get_pretrained(self._config["hf_id"], device=device)
|
||||||
self._model.to(self._device)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"MAGNeT {self._variant} loaded successfully "
|
f"MAGNeT {self._variant} loaded successfully "
|
||||||
|
|||||||
@@ -168,8 +168,8 @@ class MusicGenAdapter(BaseAudioModel):
|
|||||||
from audiocraft.models import MusicGen
|
from audiocraft.models import MusicGen
|
||||||
|
|
||||||
self._device = torch.device(device)
|
self._device = torch.device(device)
|
||||||
self._model = MusicGen.get_pretrained(self._config["hf_id"])
|
# MusicGen.get_pretrained() accepts device parameter directly
|
||||||
self._model.to(self._device)
|
self._model = MusicGen.get_pretrained(self._config["hf_id"], device=device)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"MusicGen {self._variant} loaded successfully "
|
f"MusicGen {self._variant} loaded successfully "
|
||||||
|
|||||||
@@ -105,8 +105,7 @@ class MusicGenStyleAdapter(BaseAudioModel):
|
|||||||
from audiocraft.models import MusicGen
|
from audiocraft.models import MusicGen
|
||||||
|
|
||||||
self._device = torch.device(device)
|
self._device = torch.device(device)
|
||||||
self._model = MusicGen.get_pretrained(self._config["hf_id"])
|
self._model = MusicGen.get_pretrained(self._config["hf_id"], device=device)
|
||||||
self._model.to(self._device)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"MusicGen Style {self._variant} loaded successfully "
|
f"MusicGen Style {self._variant} loaded successfully "
|
||||||
|
|||||||
Reference in New Issue
Block a user