#!/usr/bin/env python3
"""
AI Model Orchestrator for RunPod (Process-Based)
Manages sequential loading of AI models using subprocess instead of Docker

Simplified architecture for RunPod's containerized environment:
- No Docker-in-Docker complexity
- Direct process management via subprocess
- Models run as Python background processes
- GPU memory efficient (sequential model loading)
"""

import asyncio
import logging
import os
import subprocess
import time
import signal
from typing import Dict, Optional, Any
from pathlib import Path

import httpx
import yaml
import psutil
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(title="AI Model Orchestrator (Process-Based)", version="2.0.0")

# Global state
current_model: Optional[str] = None
model_processes: Dict[str, subprocess.Popen] = {}
model_registry: Dict[str, Dict[str, Any]] = {}
config: Dict[str, Any] = {}
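
# Illustrative sketch of the models.yaml structure this orchestrator reads.
# Only the keys consumed below (service_script, port, endpoint, model_name,
# startup_time_seconds, type, framework) matter; the model name and values
# shown here are assumptions, not the real registry.
#
#   models:
#     llama-chat:
#       type: llm
#       framework: vllm
#       service_script: llama-chat/serve.py     # relative to /workspace/ai/
#       model_name: meta-llama/Llama-3.1-8B-Instruct
#       port: 8001
#       endpoint: /v1
#       startup_time_seconds: 180
#   config: {}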


def load_model_registry():
    """Load model registry from models.yaml"""
    global model_registry, config

    config_path = os.getenv("MODELS_CONFIG", "/workspace/ai/model-orchestrator/models.yaml")
    logger.info(f"Loading model registry from {config_path}")

    with open(config_path, 'r') as f:
        data = yaml.safe_load(f)

    model_registry = data.get('models', {})
    config = data.get('config', {})

    logger.info(f"Loaded {len(model_registry)} models")
    for model_name, model_config in model_registry.items():
        logger.info(f" - {model_name}: {model_config.get('type')} ({model_config.get('framework')})")


async def start_model_process(model_name: str) -> bool:
    """Start a model as a subprocess"""
    global current_model, model_processes

    if model_name not in model_registry:
        logger.error(f"Model {model_name} not found in registry")
        return False

    model_config = model_registry[model_name]

    # Stop current model if running
    if current_model and current_model != model_name:
        await stop_model_process(current_model)

    # Check if already running
    if model_name in model_processes:
        proc = model_processes[model_name]
        if proc.poll() is None:  # Still running
            logger.info(f"Model {model_name} already running")
            return True

    logger.info(f"Starting model {model_name}...")

    try:
        # Get service command from config
        service_script = model_config.get('service_script')
        if not service_script:
            logger.error(f"No service_script defined for {model_name}")
            return False

        script_path = Path(f"/workspace/ai/{service_script}")
        if not script_path.exists():
            logger.error(f"Service script not found: {script_path}")
            return False

        # Start process
        port = model_config.get('port', 8000)
        env = os.environ.copy()
        env.update({
            'HF_TOKEN': os.getenv('HF_TOKEN', ''),
            'PORT': str(port),
            'HOST': '0.0.0.0',
            'MODEL_NAME': model_config.get('model_name', model_name)
        })

        # Use venv python if it exists
        script_dir = script_path.parent
        venv_python = script_dir / 'venv' / 'bin' / 'python3'
        python_cmd = str(venv_python) if venv_python.exists() else 'python3'

        # Launch in a new process group so the whole tree can be signalled on
        # shutdown. stdout/stderr are inherited from the orchestrator rather
        # than piped: nothing ever read the old PIPE handles, so a chatty model
        # would eventually fill the pipe buffer and stall.
        proc = subprocess.Popen(
            [python_cmd, str(script_path)],
            env=env,
            preexec_fn=os.setsid  # Create new process group
        )

        model_processes[model_name] = proc

        # Wait for service to be ready
        max_wait = model_config.get('startup_time_seconds', 120)
        start_time = time.time()

        while time.time() - start_time < max_wait:
            if proc.poll() is not None:
                logger.error(f"Process for {model_name} exited prematurely")
                return False

            try:
                async with httpx.AsyncClient() as client:
                    response = await client.get(
                        f"http://localhost:{port}/health",
                        timeout=5.0
                    )
                    if response.status_code == 200:
                        logger.info(f"Model {model_name} is ready on port {port}")
                        current_model = model_name
                        return True
            except httpx.HTTPError:
                pass  # Health endpoint not reachable yet

            await asyncio.sleep(2)

        logger.error(f"Model {model_name} failed to start within {max_wait}s")
        await stop_model_process(model_name)
        return False

    except Exception as e:
        logger.error(f"Error starting {model_name}: {e}")
        return False
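

# Illustrative sketch (not part of this service): the minimal contract a
# service_script is assumed to follow, inferred from the environment variables
# set and the health probe used in start_model_process(). File name and
# details are hypothetical.
#
#     import os
#     import uvicorn
#     from fastapi import FastAPI
#
#     app = FastAPI()
#
#     @app.get("/health")
#     def health():
#         # Report ready only once the model weights are actually loaded
#         return {"status": "ok", "model": os.getenv("MODEL_NAME")}
#
#     if __name__ == "__main__":
#         uvicorn.run(app,
#                     host=os.getenv("HOST", "0.0.0.0"),
#                     port=int(os.getenv("PORT", "8000")))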


async def stop_model_process(model_name: str):
    """Stop a running model process"""
    global model_processes, current_model

    if model_name not in model_processes:
        logger.warning(f"Model {model_name} not in process registry")
        return

    proc = model_processes[model_name]

    if proc.poll() is None:  # Still running
        logger.info(f"Stopping model {model_name}...")
        try:
            # Send SIGTERM to process group
            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)

            # Wait for graceful shutdown
            try:
                proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                # Force kill if not terminated
                os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
                proc.wait()

            logger.info(f"Model {model_name} stopped")
        except Exception as e:
            logger.error(f"Error stopping {model_name}: {e}")

    del model_processes[model_name]
    if current_model == model_name:
        current_model = None


def get_model_for_endpoint(endpoint: str) -> Optional[str]:
    """Determine which model handles this endpoint"""
    for model_name, model_config in model_registry.items():
        # Skip entries without a prefix; an empty string would match every path
        prefix = model_config.get('endpoint')
        if prefix and endpoint.startswith(prefix):
            return model_name
    return None
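

# Routing example (illustrative; the real prefixes come from models.yaml):
# a model whose entry declares `endpoint: /v1` receives every request whose
# path starts with /v1, e.g. /v1/chat/completions or /v1/completions; paths
# matching no configured prefix are rejected with 404 by the proxy below.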


@app.on_event("startup")
async def startup_event():
    """Initialize on startup"""
    logger.info("Starting AI Model Orchestrator (Process-Based)")
    load_model_registry()


@app.on_event("shutdown")
async def shutdown_event():
    """Cleanup on shutdown"""
    logger.info("Shutting down orchestrator...")
    for model_name in list(model_processes.keys()):
        await stop_model_process(model_name)


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "current_model": current_model,
        "active_processes": len(model_processes),
        "available_models": list(model_registry.keys())
    }


@app.get("/v1/models")
async def list_models_openai():
    """OpenAI-compatible models listing endpoint"""
    models_list = []
    for model_name, model_info in model_registry.items():
        models_list.append({
            "id": model_name,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "pivoine-gpu",
            "permission": [],
            "root": model_name,
            "parent": None,
        })

    return {
        "object": "list",
        "data": models_list
    }


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_request(request: Request, path: str):
    """Proxy requests to appropriate model service"""
    endpoint = f"/{path}"

    # Determine which model should handle this
    target_model = get_model_for_endpoint(endpoint)

    if not target_model:
        raise HTTPException(status_code=404, detail=f"No model configured for endpoint: {endpoint}")

    # Ensure model is running
    if current_model != target_model:
        logger.info(f"Switching to model {target_model}")
        success = await start_model_process(target_model)
        if not success:
            raise HTTPException(status_code=503, detail=f"Failed to start model {target_model}")

    # Proxy the request
    model_config = model_registry[target_model]
    target_url = f"http://localhost:{model_config['port']}/{path}"

    # Get request details
    method = request.method
    headers = dict(request.headers)
    headers.pop('host', None)  # Remove host header
    body = await request.body()

    # Check if this is a streaming request
    is_streaming = False
    if method == "POST" and body:
        try:
            import json
            body_json = json.loads(body)
            is_streaming = body_json.get('stream', False)
        except (ValueError, AttributeError):
            pass  # Body is not a JSON object; treat as non-streaming

    logger.info(f"Proxying {method} request to {target_url} (streaming: {is_streaming})")

    try:
        if is_streaming:
            # For streaming requests, use httpx streaming and yield chunks
            async def stream_response():
                async with httpx.AsyncClient(timeout=300.0) as client:
                    async with client.stream(method, target_url, content=body, headers=headers) as response:
                        async for chunk in response.aiter_bytes():
                            yield chunk

            return StreamingResponse(
                stream_response(),
                media_type="text/event-stream",
                headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
            )
        else:
            # For non-streaming requests, use the original behavior
            async with httpx.AsyncClient(timeout=300.0) as client:
                response = await client.request(
                    method=method,
                    url=target_url,
                    headers=headers,
                    content=body
                )

            return JSONResponse(
                # Content-Type may carry a charset suffix, so match on the prefix
                content=response.json() if response.headers.get('content-type', '').startswith('application/json') else response.text,
                status_code=response.status_code
            )
    except Exception as e:
        logger.error(f"Error proxying request: {e}")
        raise HTTPException(status_code=502, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", "9000"))
    host = os.getenv("HOST", "0.0.0.0")

    logger.info(f"Starting orchestrator on {host}:{port}")
    uvicorn.run(app, host=host, port=port, log_level="info")
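
# Example session (illustrative; file name, model name, and payload are
# assumptions, not part of this service):
#
#   python3 orchestrator.py                    # this file
#   curl http://localhost:9000/health
#   curl http://localhost:9000/v1/models
#   curl -X POST http://localhost:9000/v1/chat/completions \
#        -H "Content-Type: application/json" \
#        -d '{"model": "llama-chat", "messages": [{"role": "user", "content": "Hello"}], "stream": false}'
#
# The first /v1 request triggers start_model_process() for the matching model;
# later requests reuse the running process until a different model is needed.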
|