Initial commit: RunPod multi-modal AI orchestration stack
- Multi-modal AI infrastructure for RunPod RTX 4090 - Automatic model orchestration (text, image, music) - Text: vLLM + Qwen 2.5 7B Instruct - Image: Flux.1 Schnell via OpenEDAI - Music: MusicGen Medium via AudioCraft - Cost-optimized sequential loading on single GPU - Template preparation scripts for rapid deployment - Comprehensive documentation (README, DEPLOYMENT, TEMPLATE)
This commit is contained in:
359
model-orchestrator/orchestrator.py
Normal file
359
model-orchestrator/orchestrator.py
Normal file
@@ -0,0 +1,359 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
AI Model Orchestrator for RunPod RTX 4090
|
||||
Manages sequential loading of text, image, and music models on a single GPU
|
||||
|
||||
Features:
|
||||
- Automatic model switching based on request type
|
||||
- OpenAI-compatible API endpoints
|
||||
- Docker Compose service management
|
||||
- GPU memory monitoring
|
||||
- Simple YAML configuration for adding new models
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
import docker
|
||||
import httpx
|
||||
import yaml
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# FastAPI app
|
||||
app = FastAPI(title="AI Model Orchestrator", version="1.0.0")
|
||||
|
||||
# Docker client
|
||||
docker_client = docker.from_env()
|
||||
|
||||
# Global state
|
||||
current_model: Optional[str] = None
|
||||
model_registry: Dict[str, Dict[str, Any]] = {}
|
||||
config: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def load_model_registry():
|
||||
"""Load model registry from models.yaml"""
|
||||
global model_registry, config
|
||||
|
||||
config_path = os.getenv("MODELS_CONFIG", "/app/models.yaml")
|
||||
logger.info(f"Loading model registry from {config_path}")
|
||||
|
||||
with open(config_path, 'r') as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
model_registry = data.get('models', {})
|
||||
config = data.get('config', {})
|
||||
|
||||
logger.info(f"Loaded {len(model_registry)} models from registry")
|
||||
for model_name, model_info in model_registry.items():
|
||||
logger.info(f" - {model_name}: {model_info['description']}")
|
||||
|
||||
|
||||
def get_docker_service_name(service_name: str) -> str:
|
||||
"""Get full Docker service name with project prefix"""
|
||||
project_name = os.getenv("COMPOSE_PROJECT_NAME", "ai")
|
||||
return f"{project_name}_{service_name}_1"
|
||||
|
||||
|
||||
async def stop_current_model():
|
||||
"""Stop the currently running model service"""
|
||||
global current_model
|
||||
|
||||
if not current_model:
|
||||
logger.info("No model currently running")
|
||||
return
|
||||
|
||||
model_info = model_registry.get(current_model)
|
||||
if not model_info:
|
||||
logger.warning(f"Model {current_model} not found in registry")
|
||||
current_model = None
|
||||
return
|
||||
|
||||
service_name = get_docker_service_name(model_info['docker_service'])
|
||||
logger.info(f"Stopping model: {current_model} (service: {service_name})")
|
||||
|
||||
try:
|
||||
container = docker_client.containers.get(service_name)
|
||||
container.stop(timeout=30)
|
||||
logger.info(f"Stopped {current_model}")
|
||||
current_model = None
|
||||
except docker.errors.NotFound:
|
||||
logger.warning(f"Container {service_name} not found (already stopped?)")
|
||||
current_model = None
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping {service_name}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def start_model(model_name: str):
|
||||
"""Start a model service"""
|
||||
global current_model
|
||||
|
||||
if model_name not in model_registry:
|
||||
raise HTTPException(status_code=404, detail=f"Model {model_name} not found in registry")
|
||||
|
||||
model_info = model_registry[model_name]
|
||||
service_name = get_docker_service_name(model_info['docker_service'])
|
||||
|
||||
logger.info(f"Starting model: {model_name} (service: {service_name})")
|
||||
logger.info(f" VRAM requirement: {model_info['vram_gb']} GB")
|
||||
logger.info(f" Estimated startup time: {model_info['startup_time_seconds']}s")
|
||||
|
||||
try:
|
||||
# Start the container
|
||||
container = docker_client.containers.get(service_name)
|
||||
container.start()
|
||||
|
||||
# Wait for service to be healthy
|
||||
port = model_info['port']
|
||||
endpoint = model_info.get('endpoint', '/')
|
||||
base_url = f"http://localhost:{port}"
|
||||
|
||||
logger.info(f"Waiting for {model_name} to be ready at {base_url}...")
|
||||
|
||||
max_wait = model_info['startup_time_seconds'] + 60 # Add buffer
|
||||
start_time = time.time()
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
while time.time() - start_time < max_wait:
|
||||
try:
|
||||
# Try health check or root endpoint
|
||||
health_url = f"{base_url}/health"
|
||||
try:
|
||||
response = await client.get(health_url, timeout=5.0)
|
||||
if response.status_code == 200:
|
||||
logger.info(f"{model_name} is ready!")
|
||||
current_model = model_name
|
||||
return
|
||||
except:
|
||||
# Try root endpoint if /health doesn't exist
|
||||
response = await client.get(base_url, timeout=5.0)
|
||||
if response.status_code == 200:
|
||||
logger.info(f"{model_name} is ready!")
|
||||
current_model = model_name
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug(f"Waiting for {model_name}... ({e})")
|
||||
|
||||
await asyncio.sleep(5)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=f"Model {model_name} failed to start within {max_wait}s"
|
||||
)
|
||||
|
||||
except docker.errors.NotFound:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Docker service {service_name} not found. Is it defined in docker-compose?"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting {model_name}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
async def ensure_model_running(model_name: str):
|
||||
"""Ensure the specified model is running, switching if necessary"""
|
||||
global current_model
|
||||
|
||||
if current_model == model_name:
|
||||
logger.info(f"Model {model_name} already running")
|
||||
return
|
||||
|
||||
logger.info(f"Switching model: {current_model} -> {model_name}")
|
||||
|
||||
# Stop current model
|
||||
await stop_current_model()
|
||||
|
||||
# Start requested model
|
||||
await start_model(model_name)
|
||||
|
||||
logger.info(f"Model switch complete: {model_name} is now active")
|
||||
|
||||
|
||||
async def proxy_request(model_name: str, request: Request):
|
||||
"""Proxy request to the active model service"""
|
||||
model_info = model_registry[model_name]
|
||||
port = model_info['port']
|
||||
|
||||
# Get request details
|
||||
path = request.url.path
|
||||
method = request.method
|
||||
headers = dict(request.headers)
|
||||
headers.pop('host', None) # Remove host header
|
||||
|
||||
# Build target URL
|
||||
target_url = f"http://localhost:{port}{path}"
|
||||
|
||||
logger.info(f"Proxying {method} request to {target_url}")
|
||||
|
||||
async with httpx.AsyncClient(timeout=300.0) as client:
|
||||
# Handle different request types
|
||||
if method == "GET":
|
||||
response = await client.get(target_url, headers=headers)
|
||||
elif method == "POST":
|
||||
body = await request.body()
|
||||
response = await client.post(target_url, content=body, headers=headers)
|
||||
else:
|
||||
raise HTTPException(status_code=405, detail=f"Method {method} not supported")
|
||||
|
||||
# Return response
|
||||
return JSONResponse(
|
||||
content=response.json() if response.headers.get('content-type', '').startswith('application/json') else response.text,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers)
|
||||
)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Load model registry on startup"""
|
||||
load_model_registry()
|
||||
logger.info("AI Model Orchestrator started successfully")
|
||||
logger.info(f"GPU Memory: {config.get('gpu_memory_total_gb', 24)} GB")
|
||||
logger.info(f"Default model: {config.get('default_model', 'qwen-2.5-7b')}")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint"""
|
||||
return {
|
||||
"service": "AI Model Orchestrator",
|
||||
"version": "1.0.0",
|
||||
"current_model": current_model,
|
||||
"available_models": list(model_registry.keys())
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
"""Health check endpoint"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"current_model": current_model,
|
||||
"model_info": model_registry.get(current_model) if current_model else None,
|
||||
"gpu_memory_total_gb": config.get('gpu_memory_total_gb', 24),
|
||||
"models_available": len(model_registry)
|
||||
}
|
||||
|
||||
|
||||
@app.get("/models")
|
||||
async def list_models():
|
||||
"""List all available models"""
|
||||
return {
|
||||
"models": model_registry,
|
||||
"current_model": current_model
|
||||
}
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: Request):
|
||||
"""OpenAI-compatible chat completions endpoint (text models)"""
|
||||
# Parse request to get model name
|
||||
body = await request.json()
|
||||
model_name = body.get('model', config.get('default_model', 'qwen-2.5-7b'))
|
||||
|
||||
# Validate model type
|
||||
if model_name not in model_registry:
|
||||
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
|
||||
|
||||
if model_registry[model_name]['type'] != 'text':
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not a text model")
|
||||
|
||||
# Ensure model is running
|
||||
await ensure_model_running(model_name)
|
||||
|
||||
# Proxy request to model
|
||||
return await proxy_request(model_name, request)
|
||||
|
||||
|
||||
@app.post("/v1/images/generations")
|
||||
async def image_generations(request: Request):
|
||||
"""OpenAI-compatible image generation endpoint"""
|
||||
# Parse request to get model name
|
||||
body = await request.json()
|
||||
model_name = body.get('model', 'flux-schnell')
|
||||
|
||||
# Validate model type
|
||||
if model_name not in model_registry:
|
||||
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
|
||||
|
||||
if model_registry[model_name]['type'] != 'image':
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not an image model")
|
||||
|
||||
# Ensure model is running
|
||||
await ensure_model_running(model_name)
|
||||
|
||||
# Proxy request to model
|
||||
return await proxy_request(model_name, request)
|
||||
|
||||
|
||||
@app.post("/v1/audio/generations")
|
||||
async def audio_generations(request: Request):
|
||||
"""Custom audio generation endpoint (music/sound effects)"""
|
||||
# Parse request to get model name
|
||||
body = await request.json()
|
||||
model_name = body.get('model', 'musicgen-medium')
|
||||
|
||||
# Validate model type
|
||||
if model_name not in model_registry:
|
||||
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
|
||||
|
||||
if model_registry[model_name]['type'] != 'audio':
|
||||
raise HTTPException(status_code=400, detail=f"Model {model_name} is not an audio model")
|
||||
|
||||
# Ensure model is running
|
||||
await ensure_model_running(model_name)
|
||||
|
||||
# Proxy request to model
|
||||
return await proxy_request(model_name, request)
|
||||
|
||||
|
||||
@app.post("/switch")
|
||||
async def switch_model(request: Request):
|
||||
"""Manually switch to a specific model"""
|
||||
body = await request.json()
|
||||
model_name = body.get('model')
|
||||
|
||||
if not model_name:
|
||||
raise HTTPException(status_code=400, detail="Model name required")
|
||||
|
||||
if model_name not in model_registry:
|
||||
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
|
||||
|
||||
await ensure_model_running(model_name)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"model": model_name,
|
||||
"message": f"Switched to {model_name}"
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
host = os.getenv("HOST", "0.0.0.0")
|
||||
port = int(os.getenv("PORT", "9000"))
|
||||
|
||||
logger.info(f"Starting AI Model Orchestrator on {host}:{port}")
|
||||
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level="info",
|
||||
access_log=True,
|
||||
)
|
||||
Reference in New Issue
Block a user