The orchestrator was calling response.json(), which buffered the entire streaming response before returning it. As a result, LiteLLM received a single chunk with empty content instead of token-by-token streaming.

Changes:
- Detect streaming requests by parsing the request body for 'stream': true
- Use client.stream() with aiter_bytes() for streaming requests
- Return a StreamingResponse with proper SSE headers
- Keep the original JSONResponse behavior for non-streaming requests

This fixes streaming through the vLLM → orchestrator → LiteLLM chain.
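To sanity-check the fix end to end, here is a minimal client sketch (it assumes the orchestrator is reachable on its default port 9000 and that a text model named `qwen-2.5-7b` is defined in `models.yaml`; adjust the URL and model name for your deployment):

```python
# Hypothetical smoke test for the streaming path; port, path, and model name
# are taken from this file's defaults and may differ in your deployment.
import httpx

payload = {
    "model": "qwen-2.5-7b",
    "messages": [{"role": "user", "content": "Say hello in five words."}],
    "stream": True,  # exercises the new client.stream() / SSE branch
}

with httpx.stream("POST", "http://localhost:9000/v1/chat/completions",
                  json=payload, timeout=300.0) as response:
    # With the fix, SSE lines arrive incrementally instead of one buffered blob.
    for line in response.iter_lines():
        if line:
            print(line)
```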
384 lines · 12 KiB · Python
#!/usr/bin/env python3
"""
AI Model Orchestrator for RunPod RTX 4090
Manages sequential loading of text, image, and music models on a single GPU

Features:
- Automatic model switching based on request type
- OpenAI-compatible API endpoints
- Docker Compose service management
- GPU memory monitoring
- Simple YAML configuration for adding new models
"""

import asyncio
import json
import logging
import os
import time
from typing import Dict, Optional, Any

import docker
import httpx
import yaml
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(title="AI Model Orchestrator", version="1.0.0")

# Docker client
docker_client = docker.from_env()

# Global state
current_model: Optional[str] = None
model_registry: Dict[str, Dict[str, Any]] = {}
config: Dict[str, Any] = {}


def load_model_registry():
    """Load model registry from models.yaml"""
    global model_registry, config

    config_path = os.getenv("MODELS_CONFIG", "/app/models.yaml")
    logger.info(f"Loading model registry from {config_path}")

    with open(config_path, 'r') as f:
        data = yaml.safe_load(f)

    model_registry = data.get('models', {})
    config = data.get('config', {})

    logger.info(f"Loaded {len(model_registry)} models from registry")
    for model_name, model_info in model_registry.items():
        logger.info(f"  - {model_name}: {model_info['description']}")

def get_docker_service_name(service_name: str) -> str:
    """Get full Docker service name with project prefix"""
    # NOTE: this assumes docker-compose v1 style container names
    # (<project>_<service>_1); Compose v2 uses hyphens (<project>-<service>-1)
    # by default, so adjust here if the deployment uses v2 naming.
    project_name = os.getenv("COMPOSE_PROJECT_NAME", "ai")
    return f"{project_name}_{service_name}_1"

async def stop_current_model():
    """Stop the currently running model service"""
    global current_model

    if not current_model:
        logger.info("No model currently running")
        return

    model_info = model_registry.get(current_model)
    if not model_info:
        logger.warning(f"Model {current_model} not found in registry")
        current_model = None
        return

    service_name = get_docker_service_name(model_info['docker_service'])
    logger.info(f"Stopping model: {current_model} (service: {service_name})")

    try:
        container = docker_client.containers.get(service_name)
        container.stop(timeout=30)
        logger.info(f"Stopped {current_model}")
        current_model = None
    except docker.errors.NotFound:
        logger.warning(f"Container {service_name} not found (already stopped?)")
        current_model = None
    except Exception as e:
        logger.error(f"Error stopping {service_name}: {e}")
        raise

async def start_model(model_name: str):
    """Start a model service"""
    global current_model

    if model_name not in model_registry:
        raise HTTPException(status_code=404, detail=f"Model {model_name} not found in registry")

    model_info = model_registry[model_name]
    service_name = get_docker_service_name(model_info['docker_service'])

    logger.info(f"Starting model: {model_name} (service: {service_name})")
    logger.info(f"  VRAM requirement: {model_info['vram_gb']} GB")
    logger.info(f"  Estimated startup time: {model_info['startup_time_seconds']}s")

    try:
        # Start the container
        container = docker_client.containers.get(service_name)
        container.start()

        # Wait for service to be healthy
        port = model_info['port']
        endpoint = model_info.get('endpoint', '/')
        base_url = f"http://localhost:{port}"

        logger.info(f"Waiting for {model_name} to be ready at {base_url}...")

        max_wait = model_info['startup_time_seconds'] + 60  # Add buffer
        start_time = time.time()

        async with httpx.AsyncClient() as client:
            while time.time() - start_time < max_wait:
                try:
                    # Prefer the /health endpoint; fall back to the root
                    # endpoint if it is missing or returns a non-200
                    response = await client.get(f"{base_url}/health", timeout=5.0)
                    if response.status_code != 200:
                        response = await client.get(base_url, timeout=5.0)
                    if response.status_code == 200:
                        logger.info(f"{model_name} is ready!")
                        current_model = model_name
                        return
                except Exception as e:
                    logger.debug(f"Waiting for {model_name}... ({e})")

                await asyncio.sleep(5)

        raise HTTPException(
            status_code=503,
            detail=f"Model {model_name} failed to start within {max_wait}s"
        )

    except HTTPException:
        # Propagate the 503 timeout above without rewrapping it as a 500
        raise
    except docker.errors.NotFound:
        raise HTTPException(
            status_code=500,
            detail=f"Docker service {service_name} not found. Is it defined in docker-compose?"
        )
    except Exception as e:
        logger.error(f"Error starting {model_name}: {e}")
        raise HTTPException(status_code=500, detail=str(e))

async def ensure_model_running(model_name: str):
    """Ensure the specified model is running, switching if necessary"""
    global current_model

    if current_model == model_name:
        logger.info(f"Model {model_name} already running")
        return

    logger.info(f"Switching model: {current_model} -> {model_name}")

    # Stop current model
    await stop_current_model()

    # Start requested model
    await start_model(model_name)

    logger.info(f"Model switch complete: {model_name} is now active")

async def proxy_request(model_name: str, request: Request):
    """Proxy request to the active model service"""
    model_info = model_registry[model_name]
    port = model_info['port']

    # Get request details
    path = request.url.path
    method = request.method
    headers = dict(request.headers)
    headers.pop('host', None)  # Remove host header

    # Build target URL
    target_url = f"http://localhost:{port}{path}"

    # Check if this is a streaming request ('stream': true in the JSON body)
    body = await request.body()
    is_streaming = False
    if method == "POST" and body:
        try:
            body_json = json.loads(body)
            is_streaming = body_json.get('stream', False)
        except (ValueError, UnicodeDecodeError):
            pass

    logger.info(f"Proxying {method} request to {target_url} (streaming: {is_streaming})")

    if is_streaming:
        # For streaming requests, relay bytes as they arrive instead of
        # buffering the whole response (response.json() would collapse the
        # SSE stream into a single chunk)
        async def stream_response():
            async with httpx.AsyncClient(timeout=300.0) as client:
                async with client.stream(method, target_url, content=body, headers=headers) as response:
                    async for chunk in response.aiter_bytes():
                        yield chunk

        return StreamingResponse(
            stream_response(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
        )
    else:
        # For non-streaming requests, use the original behavior
        async with httpx.AsyncClient(timeout=300.0) as client:
            if method == "GET":
                response = await client.get(target_url, headers=headers)
            elif method == "POST":
                response = await client.post(target_url, content=body, headers=headers)
            else:
                raise HTTPException(status_code=405, detail=f"Method {method} not supported")

            # Return response
            return JSONResponse(
                content=response.json() if response.headers.get('content-type', '').startswith('application/json') else response.text,
                status_code=response.status_code,
                headers=dict(response.headers)
            )

@app.on_event("startup")
|
|
async def startup_event():
|
|
"""Load model registry on startup"""
|
|
load_model_registry()
|
|
logger.info("AI Model Orchestrator started successfully")
|
|
logger.info(f"GPU Memory: {config.get('gpu_memory_total_gb', 24)} GB")
|
|
logger.info(f"Default model: {config.get('default_model', 'qwen-2.5-7b')}")
|
|
|
|
|
|
@app.get("/")
|
|
async def root():
|
|
"""Root endpoint"""
|
|
return {
|
|
"service": "AI Model Orchestrator",
|
|
"version": "1.0.0",
|
|
"current_model": current_model,
|
|
"available_models": list(model_registry.keys())
|
|
}
|
|
|
|
|
|
@app.get("/health")
|
|
async def health():
|
|
"""Health check endpoint"""
|
|
return {
|
|
"status": "healthy",
|
|
"current_model": current_model,
|
|
"model_info": model_registry.get(current_model) if current_model else None,
|
|
"gpu_memory_total_gb": config.get('gpu_memory_total_gb', 24),
|
|
"models_available": len(model_registry)
|
|
}
|
|
|
|
|
|
@app.get("/models")
|
|
async def list_models():
|
|
"""List all available models"""
|
|
return {
|
|
"models": model_registry,
|
|
"current_model": current_model
|
|
}
|
|
|
|
|
|
@app.post("/v1/chat/completions")
|
|
async def chat_completions(request: Request):
|
|
"""OpenAI-compatible chat completions endpoint (text models)"""
|
|
# Parse request to get model name
|
|
body = await request.json()
|
|
model_name = body.get('model', config.get('default_model', 'qwen-2.5-7b'))
|
|
|
|
# Validate model type
|
|
if model_name not in model_registry:
|
|
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
|
|
|
|
if model_registry[model_name]['type'] != 'text':
|
|
raise HTTPException(status_code=400, detail=f"Model {model_name} is not a text model")
|
|
|
|
# Ensure model is running
|
|
await ensure_model_running(model_name)
|
|
|
|
# Proxy request to model
|
|
return await proxy_request(model_name, request)
|
|
|
|
|
|
@app.post("/v1/images/generations")
|
|
async def image_generations(request: Request):
|
|
"""OpenAI-compatible image generation endpoint"""
|
|
# Parse request to get model name
|
|
body = await request.json()
|
|
model_name = body.get('model', 'flux-schnell')
|
|
|
|
# Validate model type
|
|
if model_name not in model_registry:
|
|
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
|
|
|
|
if model_registry[model_name]['type'] != 'image':
|
|
raise HTTPException(status_code=400, detail=f"Model {model_name} is not an image model")
|
|
|
|
# Ensure model is running
|
|
await ensure_model_running(model_name)
|
|
|
|
# Proxy request to model
|
|
return await proxy_request(model_name, request)
|
|
|
|
|
|
@app.post("/v1/audio/generations")
|
|
async def audio_generations(request: Request):
|
|
"""Custom audio generation endpoint (music/sound effects)"""
|
|
# Parse request to get model name
|
|
body = await request.json()
|
|
model_name = body.get('model', 'musicgen-medium')
|
|
|
|
# Validate model type
|
|
if model_name not in model_registry:
|
|
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
|
|
|
|
if model_registry[model_name]['type'] != 'audio':
|
|
raise HTTPException(status_code=400, detail=f"Model {model_name} is not an audio model")
|
|
|
|
# Ensure model is running
|
|
await ensure_model_running(model_name)
|
|
|
|
# Proxy request to model
|
|
return await proxy_request(model_name, request)
|
|
|
|
|
|
@app.post("/switch")
|
|
async def switch_model(request: Request):
|
|
"""Manually switch to a specific model"""
|
|
body = await request.json()
|
|
model_name = body.get('model')
|
|
|
|
if not model_name:
|
|
raise HTTPException(status_code=400, detail="Model name required")
|
|
|
|
if model_name not in model_registry:
|
|
raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
|
|
|
|
await ensure_model_running(model_name)
|
|
|
|
return {
|
|
"status": "success",
|
|
"model": model_name,
|
|
"message": f"Switched to {model_name}"
|
|
}
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
|
|
host = os.getenv("HOST", "0.0.0.0")
|
|
port = int(os.getenv("PORT", "9000"))
|
|
|
|
logger.info(f"Starting AI Model Orchestrator on {host}:{port}")
|
|
|
|
uvicorn.run(
|
|
app,
|
|
host=host,
|
|
port=port,
|
|
log_level="info",
|
|
access_log=True,
|
|
)
|
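For context on the registry the orchestrator loads at startup, a minimal `models.yaml` sketch using only the fields this file reads (`description`, `docker_service`, `port`, `endpoint`, `vram_gb`, `startup_time_seconds`, `type`, plus the `config` block); the service name and values below are illustrative, not the project's actual configuration:

```yaml
# Illustrative only -- field names match what the orchestrator reads,
# values are placeholders.
config:
  gpu_memory_total_gb: 24
  default_model: qwen-2.5-7b

models:
  qwen-2.5-7b:
    description: "Qwen 2.5 7B text model served by vLLM"
    docker_service: vllm
    port: 8000
    endpoint: /v1/chat/completions
    vram_gb: 16
    startup_time_seconds: 120
    type: text
```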