Files
runpod/models/musicgen/server.py
Sebastian Krüger 9ee626a78e feat: implement Ansible-based process architecture for RunPod
Major architecture overhaul to address RunPod Docker limitations:

Core Infrastructure:
- Add base_service.py: Abstract base class for all AI services
- Add service_manager.py: Process lifecycle management
- Add core/requirements.txt: Core dependencies

Model Services (Standalone Python):
- Add models/vllm/server.py: Qwen 2.5 7B text generation
- Add models/flux/server.py: Flux.1 Schnell image generation
- Add models/musicgen/server.py: MusicGen Medium music generation
- Each service inherits from GPUService base class
- OpenAI-compatible APIs
- Standalone execution support

Ansible Deployment:
- Add playbook.yml: Comprehensive deployment automation
- Add ansible.cfg: Ansible configuration
- Add inventory.yml: Localhost inventory
- Tags: base, python, dependencies, models, tailscale, validate, cleanup

Scripts:
- Add scripts/install.sh: Full installation wrapper
- Add scripts/download-models.sh: Model download wrapper
- Add scripts/start-all.sh: Start orchestrator
- Add scripts/stop-all.sh: Stop all services

Documentation:
- Update ARCHITECTURE.md: Document distributed VPS+GPU architecture

Benefits:
- No Docker: Avoids RunPod CAP_SYS_ADMIN limitations
- Fully reproducible via Ansible
- Extensible: Add models in 3 steps
- Direct Python execution (no container overhead)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-21 15:37:18 +01:00

173 lines
6.1 KiB
Python

#!/usr/bin/env python3
"""
MusicGen Music Generation Service
OpenAI-compatible music generation using Meta's MusicGen Medium model.
Provides /v1/audio/generations endpoint.
"""
import base64
import io
import os
import tempfile
from typing import Optional
import torch
import torchaudio
from audiocraft.models import MusicGen
from fastapi import HTTPException
from pydantic import BaseModel, Field
# Import base service class
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))
from core.base_service import GPUService
class AudioGenerationRequest(BaseModel):
    """Music generation request"""

    # Body of POST /v1/audio/generations. Every constraint below is enforced
    # by pydantic before the route handler runs.

    # Model identifier supplied by the client.
    model: str = Field(
        default="musicgen-medium",
        description="Model name",
    )
    # The only required field: the text the music is conditioned on.
    prompt: str = Field(
        ...,
        description="Text description of the music to generate",
    )
    # Clip length, bounded to 1-30 seconds.
    duration: float = Field(
        default=30.0,
        ge=1.0,
        le=30.0,
        description="Duration in seconds",
    )
    # Sampling controls (temperature, top-k, top-p, guidance coefficient).
    temperature: float = Field(
        default=1.0,
        ge=0.1,
        le=2.0,
        description="Sampling temperature",
    )
    top_k: int = Field(
        default=250,
        ge=0,
        le=500,
        description="Top-k sampling",
    )
    top_p: float = Field(
        default=0.0,
        ge=0.0,
        le=1.0,
        description="Top-p (nucleus) sampling",
    )
    cfg_coef: float = Field(
        default=3.0,
        ge=1.0,
        le=15.0,
        description="Classifier-free guidance coefficient",
    )
    # Requested container format; accepted as a free-form string.
    response_format: str = Field(
        default="wav",
        description="Audio format (wav or mp3)",
    )
class AudioGenerationResponse(BaseModel):
    """Music generation response"""

    # Payload returned by POST /v1/audio/generations; all fields are required.

    # Encoded audio file contents, base64 text so it can travel inside JSON.
    audio: str = Field(
        ...,
        description="Base64-encoded audio data",
    )
    # Container format of the bytes carried in `audio`.
    format: str = Field(
        ...,
        description="Audio format (wav or mp3)",
    )
    # Clip length reported back to the client.
    duration: float = Field(
        ...,
        description="Duration in seconds",
    )
    # Sample rate a client needs to play the audio back correctly.
    sample_rate: int = Field(
        ...,
        description="Sample rate in Hz",
    )
class MusicGenService(GPUService):
    """MusicGen music generation service.

    Loads a MusicGen checkpoint and registers three HTTP routes on
    ``self.app``: ``GET /``, ``GET /v1/models`` and
    ``POST /v1/audio/generations``.

    Environment variables:
        PORT: TCP port passed to the base service (default ``8003``).
        MODEL_NAME: Checkpoint passed to ``MusicGen.get_pretrained``
            (default ``facebook/musicgen-medium``).
    """

    def __init__(self):
        """Configure the service from the environment; the model is loaded later."""
        # Get port from environment or use default
        port = int(os.getenv("PORT", "8003"))
        super().__init__(name="musicgen-medium", port=port)

        # Service-specific attributes
        self.model: Optional[MusicGen] = None  # populated by initialize()
        self.model_name = os.getenv("MODEL_NAME", "facebook/musicgen-medium")

    async def initialize(self):
        """Run base-class initialization, then load the MusicGen model."""
        await super().initialize()
        self.logger.info(f"Loading MusicGen model: {self.model_name}")

        # Prefer the GPU; fall back to CPU so the service can still start without one.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = MusicGen.get_pretrained(self.model_name, device=device)

        self.logger.info("MusicGen model loaded successfully")
        self.logger.info(f"Max duration: 30 seconds at {self.model.sample_rate}Hz")

    async def cleanup(self):
        """Run base-class cleanup, then release the model."""
        await super().cleanup()
        if self.model is not None:
            self.logger.info("MusicGen model cleanup")
            self.model = None
            # Dropping the reference alone leaves freed blocks cached by the
            # CUDA allocator; hand them back to the device.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    @staticmethod
    def _encode_wav(audio_data, sample_rate: int) -> bytes:
        """Serialize an audio tensor to WAV bytes via a temporary file.

        Args:
            audio_data: Tensor handed to ``torchaudio.save``; the caller passes
                one generated clip already moved to the CPU.
            sample_rate: Sample rate in Hz written into the file.

        Returns:
            The raw contents of the written WAV file.

        Raises:
            Exception: Whatever ``torchaudio.save`` or file I/O raises; the
                temporary file is removed before the error propagates.
        """
        # delete=False keeps the path on disk after the handle is closed, so
        # torchaudio can write to it by name.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_path = temp_file.name
        try:
            torchaudio.save(temp_path, audio_data, sample_rate)
            with open(temp_path, 'rb') as f:
                return f.read()
        finally:
            # Always remove the file, including when save/read fails; otherwise
            # every failed request would leak a file in the temp directory.
            os.unlink(temp_path)

    def create_app(self):
        """Create FastAPI routes"""

        @self.app.get("/")
        async def root():
            """Root endpoint"""
            # Before the model is loaded, report 32000 Hz as the sample rate.
            return {
                "service": "MusicGen API Server",
                "model": self.model_name,
                "max_duration": 30.0,
                "sample_rate": self.model.sample_rate if self.model is not None else 32000
            }

        @self.app.get("/v1/models")
        async def list_models():
            """List available models (OpenAI-compatible)"""
            return {
                "object": "list",
                "data": [
                    {
                        "id": "musicgen-medium",
                        "object": "model",
                        "created": 1234567890,
                        "owned_by": "meta",
                        "permission": [],
                        "root": self.model_name,
                        "parent": None,
                    }
                ]
            }

        @self.app.post("/v1/audio/generations")
        async def generate_audio(request: AudioGenerationRequest) -> AudioGenerationResponse:
            """Generate music from text prompt"""
            # 503 while the model is not loaded; generation failures become 500.
            if self.model is None:
                raise HTTPException(status_code=503, detail="Model not initialized")

            self.logger.info(f"Generating music: {request.prompt[:100]}...")
            self.logger.info(f"Duration: {request.duration}s, Temperature: {request.temperature}")

            # Only WAV is produced. Make an ignored format request visible in
            # the logs instead of dropping it silently; the response's
            # `format` field still tells the client what it actually received.
            if request.response_format != "wav":
                self.logger.warning(
                    f"Requested response_format '{request.response_format}' "
                    "is not supported; returning wav"
                )

            try:
                # Set generation parameters
                # NOTE(review): these are stored on the shared model object, and
                # generate() below runs synchronously inside this async handler,
                # so requests are presumably processed one at a time - confirm
                # before moving generation to a worker thread.
                self.model.set_generation_params(
                    duration=request.duration,
                    temperature=request.temperature,
                    top_k=request.top_k,
                    top_p=request.top_p,
                    cfg_coef=request.cfg_coef,
                )

                # Generate audio for a batch containing the single prompt
                with torch.no_grad():
                    wav = self.model.generate([request.prompt])

                # wav shape: [batch_size, channels, samples]
                # Extract first batch item
                audio_data = wav[0].cpu()  # [channels, samples]
                sample_rate = self.model.sample_rate

                audio_bytes = self._encode_wav(audio_data, sample_rate)
                audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')

                self.logger.info(f"Generated {request.duration}s of audio")
                return AudioGenerationResponse(
                    audio=audio_base64,
                    format="wav",
                    duration=request.duration,
                    sample_rate=sample_rate
                )
            except Exception as e:
                # Route-level boundary: surface any failure as an HTTP 500 and
                # keep the original exception chained for debugging.
                self.logger.error(f"Error generating audio: {e}")
                raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # Standalone execution: build the service and hand control to its run() method.
    MusicGenService().run()