#!/usr/bin/env python3
"""
MusicGen Music Generation Service

OpenAI-compatible music generation using Meta's MusicGen Medium model.
Provides /v1/audio/generations endpoint.
"""
import base64
import io
import os
import tempfile
from typing import Optional

import torch
import torchaudio
from audiocraft.models import MusicGen
from fastapi import HTTPException
from pydantic import BaseModel, Field

# Import base service class
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))
from core.base_service import GPUService


class AudioGenerationRequest(BaseModel):
    """Music generation request body for POST /v1/audio/generations."""
    model: str = Field(default="musicgen-medium", description="Model name")
    prompt: str = Field(..., description="Text description of the music to generate")
    duration: float = Field(default=30.0, ge=1.0, le=30.0, description="Duration in seconds")
    temperature: float = Field(default=1.0, ge=0.1, le=2.0, description="Sampling temperature")
    top_k: int = Field(default=250, ge=0, le=500, description="Top-k sampling")
    top_p: float = Field(default=0.0, ge=0.0, le=1.0, description="Top-p (nucleus) sampling")
    cfg_coef: float = Field(default=3.0, ge=1.0, le=15.0, description="Classifier-free guidance coefficient")
    # NOTE(review): the handler below always emits WAV; an "mp3" request is
    # accepted but still returns WAV-encoded audio. Confirm intended behavior.
    response_format: str = Field(default="wav", description="Audio format (wav or mp3)")


class AudioGenerationResponse(BaseModel):
    """Music generation response: base64 audio plus metadata."""
    audio: str = Field(..., description="Base64-encoded audio data")
    format: str = Field(..., description="Audio format (wav or mp3)")
    duration: float = Field(..., description="Duration in seconds")
    sample_rate: int = Field(..., description="Sample rate in Hz")


class MusicGenService(GPUService):
    """MusicGen music generation service exposing OpenAI-style routes."""

    def __init__(self):
        # Port is configurable via the PORT environment variable.
        port = int(os.getenv("PORT", "8003"))
        super().__init__(name="musicgen-medium", port=port)

        # Model is loaded lazily in initialize(), not in the constructor.
        self.model: Optional[MusicGen] = None
        self.model_name = os.getenv("MODEL_NAME", "facebook/musicgen-medium")

    async def initialize(self):
        """Load the MusicGen model onto GPU when available, else CPU."""
        await super().initialize()

        self.logger.info(f"Loading MusicGen model: {self.model_name}")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = MusicGen.get_pretrained(self.model_name, device=device)

        self.logger.info("MusicGen model loaded successfully")
        self.logger.info(f"Max duration: 30 seconds at {self.model.sample_rate}Hz")

    async def cleanup(self):
        """Drop the model reference so GPU memory can be reclaimed."""
        await super().cleanup()
        if self.model:
            self.logger.info("MusicGen model cleanup")
            self.model = None

    def create_app(self):
        """Register FastAPI routes on self.app."""

        @self.app.get("/")
        async def root():
            """Root endpoint: basic service metadata."""
            return {
                "service": "MusicGen API Server",
                "model": self.model_name,
                "max_duration": 30.0,
                # Before the model loads, report MusicGen's native 32 kHz rate.
                "sample_rate": self.model.sample_rate if self.model else 32000,
            }

        @self.app.get("/v1/models")
        async def list_models():
            """List available models (OpenAI-compatible)."""
            return {
                "object": "list",
                "data": [
                    {
                        "id": "musicgen-medium",
                        "object": "model",
                        "created": 1234567890,
                        "owned_by": "meta",
                        "permission": [],
                        "root": self.model_name,
                        "parent": None,
                    }
                ],
            }

        @self.app.post("/v1/audio/generations")
        async def generate_audio(request: AudioGenerationRequest) -> AudioGenerationResponse:
            """Generate music from a text prompt and return it base64-encoded.

            Raises:
                HTTPException: 503 if the model is not initialized,
                    500 on any generation or encoding failure.
            """
            if not self.model:
                raise HTTPException(status_code=503, detail="Model not initialized")

            self.logger.info(f"Generating music: {request.prompt[:100]}...")
            self.logger.info(f"Duration: {request.duration}s, Temperature: {request.temperature}")

            try:
                # Set generation parameters
                self.model.set_generation_params(
                    duration=request.duration,
                    temperature=request.temperature,
                    top_k=request.top_k,
                    top_p=request.top_p,
                    cfg_coef=request.cfg_coef,
                )

                with torch.no_grad():
                    # wav shape: [batch_size, channels, samples]
                    wav = self.model.generate([request.prompt])

                # Extract first (only) batch item: [channels, samples]
                audio_data = wav[0].cpu()
                sample_rate = self.model.sample_rate

                # Encode via a temp file; guarantee removal even if save/read
                # fails (the original leaked the file when torchaudio.save raised).
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                    temp_path = temp_file.name
                try:
                    torchaudio.save(temp_path, audio_data, sample_rate)
                    with open(temp_path, 'rb') as f:
                        audio_bytes = f.read()
                finally:
                    os.unlink(temp_path)

                audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
                self.logger.info(f"Generated {request.duration}s of audio")

                return AudioGenerationResponse(
                    audio=audio_base64,
                    format="wav",
                    duration=request.duration,
                    sample_rate=sample_rate,
                )
            except Exception as e:
                self.logger.error(f"Error generating audio: {e}")
                raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    service = MusicGenService()
    service.run()