Initial commit: RunPod multi-modal AI orchestration stack
- Multi-modal AI infrastructure for RunPod RTX 4090
- Automatic model orchestration (text, image, music)
- Text: vLLM + Qwen 2.5 7B Instruct
- Image: Flux.1 Schnell via OpenEDAI
- Music: MusicGen Medium via AudioCraft
- Cost-optimized sequential loading on a single GPU
- Template preparation scripts for rapid deployment
- Comprehensive documentation (README, DEPLOYMENT, TEMPLATE)
This commit is contained in:
194
musicgen/server.py
Normal file
194
musicgen/server.py
Normal file
@@ -0,0 +1,194 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
MusicGen API Server
|
||||
OpenAI-compatible API for music generation using Meta's MusicGen
|
||||
|
||||
Endpoints:
|
||||
- POST /v1/audio/generations - Generate music from text prompt
|
||||
- GET /health - Health check
|
||||
- GET / - Service info
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(title="MusicGen API Server", version="1.0.0")

# Global model instance — stays None until the startup event loads it;
# request handlers must check for None before use.
model: Optional[MusicGen] = None
# Checkpoint to load; overridable via the MODEL_NAME environment variable.
model_name: str = os.getenv("MODEL_NAME", "facebook/musicgen-medium")
# Prefer the GPU when one is visible; CPU generation works but is much slower.
device: str = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
class AudioGenerationRequest(BaseModel):
    """Music generation request"""

    # Informational model name echoed back by the OpenAI-style API shape.
    model: str = Field(default="musicgen-medium", description="Model name")
    # Free-text description of the desired music; the only required field.
    prompt: str = Field(..., description="Text description of the music to generate")
    # Clip length; MusicGen is capped at 30 seconds per call.
    duration: float = Field(
        default=30.0, ge=1.0, le=30.0, description="Duration in seconds"
    )
    temperature: float = Field(
        default=1.0, ge=0.1, le=2.0, description="Sampling temperature"
    )
    top_k: int = Field(default=250, ge=0, le=500, description="Top-k sampling")
    top_p: float = Field(
        default=0.0, ge=0.0, le=1.0, description="Top-p (nucleus) sampling"
    )
    cfg_coef: float = Field(
        default=3.0, ge=1.0, le=15.0,
        description="Classifier-free guidance coefficient",
    )
    # NOTE(review): only "wav" is actually produced by the endpoint today.
    response_format: str = Field(default="wav", description="Audio format (wav or mp3)")
|
||||
|
||||
|
||||
class AudioGenerationResponse(BaseModel):
    """Music generation response"""

    # The generated clip, base64-encoded so it travels inside JSON.
    audio: str = Field(..., description="Base64-encoded audio data")
    format: str = Field(..., description="Audio format (wav or mp3)")
    duration: float = Field(..., description="Duration in seconds")
    sample_rate: int = Field(..., description="Sample rate in Hz")
|
||||
|
||||
|
||||
@app.on_event("startup")
async def startup_event():
    """Load MusicGen model on startup"""
    # NOTE(review): @app.on_event is deprecated in newer FastAPI versions in
    # favor of lifespan handlers; kept to preserve the existing startup contract.
    global model

    logger.info(f"Loading MusicGen model: {model_name}")
    logger.info(f"Device: {device}")

    # Downloads the checkpoint on first run, then loads it onto `device`.
    model = MusicGen.get_pretrained(model_name, device=device)

    logger.info("MusicGen model loaded successfully")
    logger.info("Max duration: 30 seconds at 32kHz")
|
||||
|
||||
|
||||
@app.get("/")
async def root():
    """Root endpoint"""
    # Static service metadata plus the configured model/device.
    info = {
        "service": "MusicGen API Server",
        "model": model_name,
        "device": device,
        "max_duration": 30.0,
        "sample_rate": 32000,
    }
    return info
|
||||
|
||||
|
||||
@app.get("/health")
async def health():
    """Health check endpoint"""
    ready = model is not None
    # "initializing" until the startup event finishes loading the model.
    return {
        "status": "healthy" if model else "initializing",
        "model": model_name,
        "device": device,
        "ready": ready,
        "gpu_available": torch.cuda.is_available(),
    }
|
||||
|
||||
|
||||
@app.post("/v1/audio/generations")
async def generate_audio(request: AudioGenerationRequest) -> AudioGenerationResponse:
    """Generate music from a text prompt.

    Runs MusicGen on the configured device and returns the clip as
    base64-encoded WAV audio.

    Raises:
        HTTPException: 503 if the model has not finished loading,
            500 on any generation or encoding failure.
    """
    if not model:
        raise HTTPException(status_code=503, detail="Model not initialized")

    logger.info(f"Generating music: {request.prompt[:100]}...")
    logger.info(f"Duration: {request.duration}s, Temperature: {request.temperature}")

    try:
        # Generation parameters are stateful on the model object, so they are
        # (re)applied for every request.
        model.set_generation_params(
            duration=request.duration,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            cfg_coef=request.cfg_coef,
        )

        # Generate audio; no_grad avoids building an autograd graph.
        descriptions = [request.prompt]
        with torch.no_grad():
            wav = model.generate(descriptions)

        # wav shape: [batch_size, channels, samples] — take the first batch item.
        audio_data = wav[0].cpu()  # [channels, samples]
        sample_rate = model.sample_rate

        # Encode the WAV entirely in memory. The previous implementation wrote
        # a NamedTemporaryFile(delete=False) and unlinked it manually, which
        # leaked the temp file whenever torchaudio.save raised before unlink.
        buffer = io.BytesIO()
        torchaudio.save(buffer, audio_data, sample_rate, format="wav")
        audio_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')

        # NOTE(review): request.response_format is accepted but ignored — only
        # WAV is produced; mp3 support would need an encoder here.
        logger.info(f"Generated {request.duration}s of audio")

        return AudioGenerationResponse(
            audio=audio_base64,
            format="wav",
            duration=request.duration,
            sample_rate=sample_rate,
        )

    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception(f"Error generating audio: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@app.get("/v1/models")
async def list_models():
    """List available models (OpenAI-compatible)"""
    # Single-model deployment, so the list always has exactly one entry.
    model_entry = {
        "id": "musicgen-medium",
        "object": "model",
        "created": 1234567890,
        "owned_by": "meta",
        "permission": [],
        "root": model_name,
        "parent": None,
    }
    return {"object": "list", "data": [model_entry]}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import uvicorn

    # Bind address and port are env-configurable for container deployments.
    bind_host = os.getenv("HOST", "0.0.0.0")
    bind_port = int(os.getenv("PORT", "8000"))

    logger.info(f"Starting MusicGen API server on {bind_host}:{bind_port}")

    uvicorn.run(app, host=bind_host, port=bind_port, log_level="info", access_log=True)
|
||||
Reference in New Issue
Block a user