#!/usr/bin/env python3
"""
MusicGen API Server

OpenAI-compatible API for music generation using Meta's MusicGen.

Endpoints:
- POST /v1/audio/generations - Generate music from text prompt
- GET /v1/models - List available models
- GET /health - Health check
- GET / - Service info

Environment variables:
- MODEL_NAME - pretrained model to load (default: facebook/musicgen-medium)
- HOST / PORT - bind address when run as a script (default: 0.0.0.0:8000)
"""

import asyncio
import base64
import io
import logging
import os
import tempfile
from typing import Optional, Tuple

import torch
import torchaudio
from audiocraft.models import MusicGen
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(title="MusicGen API Server", version="1.0.0")

# Global model instance (populated by the startup hook)
model: Optional[MusicGen] = None
model_name: str = os.getenv("MODEL_NAME", "facebook/musicgen-medium")
device: str = "cuda" if torch.cuda.is_available() else "cpu"

# Serializes generation requests. The model object is shared and
# set_generation_params() mutates it, so two generations must never overlap.
# Created lazily inside the running event loop (see _get_generation_lock).
_generation_lock: Optional[asyncio.Lock] = None


class AudioGenerationRequest(BaseModel):
    """Music generation request"""

    model: str = Field(default="musicgen-medium", description="Model name")
    prompt: str = Field(..., description="Text description of the music to generate")
    duration: float = Field(default=30.0, ge=1.0, le=30.0, description="Duration in seconds")
    temperature: float = Field(default=1.0, ge=0.1, le=2.0, description="Sampling temperature")
    top_k: int = Field(default=250, ge=0, le=500, description="Top-k sampling")
    top_p: float = Field(default=0.0, ge=0.0, le=1.0, description="Top-p (nucleus) sampling")
    cfg_coef: float = Field(default=3.0, ge=1.0, le=15.0, description="Classifier-free guidance coefficient")
    # NOTE: accepted for API compatibility; the server currently always emits WAV.
    response_format: str = Field(default="wav", description="Audio format (wav or mp3)")


class AudioGenerationResponse(BaseModel):
    """Music generation response"""

    audio: str = Field(..., description="Base64-encoded audio data")
    format: str = Field(..., description="Audio format (wav or mp3)")
    duration: float = Field(..., description="Duration in seconds")
    sample_rate: int = Field(..., description="Sample rate in Hz")


def _get_generation_lock() -> asyncio.Lock:
    """Return the process-wide generation lock, creating it on first use.

    Must be called from within the running event loop so the lock is bound to
    the loop that serves requests.
    """
    global _generation_lock
    if _generation_lock is None:
        _generation_lock = asyncio.Lock()
    return _generation_lock


def _generate_wav_bytes(request: AudioGenerationRequest) -> Tuple[bytes, int]:
    """Run MusicGen for ``request`` and return ``(wav_file_bytes, sample_rate)``.

    This is blocking, compute-heavy work and is meant to be executed in a
    worker thread while the caller holds the generation lock.
    """
    # Set generation parameters
    model.set_generation_params(
        duration=request.duration,
        temperature=request.temperature,
        top_k=request.top_k,
        top_p=request.top_p,
        cfg_coef=request.cfg_coef,
    )

    # Grad mode is thread-local, so no_grad must be entered in the thread
    # that actually performs the generation.
    with torch.no_grad():
        wav = model.generate([request.prompt])

    # The first batch item is the only one: a single prompt was submitted.
    audio_data = wav[0].cpu()
    sample_rate = model.sample_rate

    # Reserve a path, close the handle, then let torchaudio write to it.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
        temp_path = temp_file.name
    try:
        torchaudio.save(temp_path, audio_data, sample_rate)
        with open(temp_path, 'rb') as f:
            audio_bytes = f.read()
    finally:
        # Always remove the temporary file, including when saving or
        # reading fails, so failed requests do not leak files.
        os.unlink(temp_path)

    return audio_bytes, sample_rate


@app.on_event("startup")
async def startup_event():
    """Load MusicGen model on startup"""
    global model

    logger.info(f"Loading MusicGen model: {model_name}")
    logger.info(f"Device: {device}")

    # Load model
    model = MusicGen.get_pretrained(model_name, device=device)

    # Bind the generation lock to the serving event loop.
    _get_generation_lock()

    logger.info("MusicGen model loaded successfully")
    logger.info("Max duration: 30 seconds at 32kHz")


@app.get("/")
async def root():
    """Root endpoint"""
    return {
        "service": "MusicGen API Server",
        "model": model_name,
        "device": device,
        "max_duration": 30.0,
        "sample_rate": 32000
    }


@app.get("/health")
async def health():
    """Health check endpoint"""
    ready = model is not None
    return {
        "status": "healthy" if ready else "initializing",
        "model": model_name,
        "device": device,
        "ready": ready,
        "gpu_available": torch.cuda.is_available()
    }


@app.post("/v1/audio/generations")
async def generate_audio(request: AudioGenerationRequest) -> AudioGenerationResponse:
    """Generate music from text prompt

    Returns the generated clip as base64-encoded WAV data.

    Raises:
        HTTPException: 503 if the model has not been loaded yet; 500 if
            generation or encoding fails.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not initialized")

    logger.info(f"Generating music: {request.prompt[:100]}...")
    logger.info(f"Duration: {request.duration}s, Temperature: {request.temperature}")

    if request.response_format.lower() != "wav":
        # Only WAV output is implemented; make the fallback visible instead
        # of silently ignoring the requested format.
        logger.warning(
            "response_format %r is not supported; returning wav",
            request.response_format,
        )

    try:
        # Generation can take a long time. Run it off the event loop so other
        # endpoints (e.g. /health) stay responsive, and hold the lock so only
        # one request at a time touches the shared model.
        async with _get_generation_lock():
            audio_bytes, sample_rate = await run_in_threadpool(
                _generate_wav_bytes, request
            )

        # Encode to base64
        audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')

        logger.info(f"Generated {request.duration}s of audio")

        return AudioGenerationResponse(
            audio=audio_base64,
            format="wav",
            duration=request.duration,
            sample_rate=sample_rate
        )

    except Exception as e:
        # Top-level boundary: log with traceback and surface as HTTP 500.
        logger.exception("Error generating audio: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/v1/models")
async def list_models():
    """List available models (OpenAI-compatible)"""
    return {
        "object": "list",
        "data": [
            {
                "id": "musicgen-medium",
                "object": "model",
                "created": 1234567890,
                "owned_by": "meta",
                "permission": [],
                "root": model_name,
                "parent": None,
            }
        ]
    }


if __name__ == "__main__":
    import uvicorn

    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "8000"))

    logger.info(f"Starting MusicGen API server on {host}:{port}")

    uvicorn.run(
        app,
        host=host,
        port=port,
        log_level="info",
        access_log=True,
    )