#!/usr/bin/env python3
"""
AI Model Orchestrator for RunPod (Process-Based)

Manages sequential loading of AI models using subprocess instead of Docker.

Simplified architecture for RunPod's containerized environment:
- No Docker-in-Docker complexity
- Direct process management via subprocess
- Models run as Python background processes
- GPU memory efficient (sequential model loading)
"""

import asyncio
import json
import logging
import os
import signal
import subprocess
import time
from pathlib import Path
from typing import Any, Dict, Optional

import httpx
import psutil
import yaml
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, Response, StreamingResponse
from pydantic import BaseModel

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(title="AI Model Orchestrator (Process-Based)", version="2.0.0")

# Global state
current_model: Optional[str] = None  # name of the model last confirmed healthy
model_processes: Dict[str, subprocess.Popen] = {}  # model name -> child process
model_registry: Dict[str, Dict[str, Any]] = {}  # 'models' section of models.yaml
config: Dict[str, Any] = {}  # 'config' section of models.yaml


def load_model_registry():
    """Load the model registry and global config from models.yaml.

    The path is taken from the MODELS_CONFIG environment variable and falls
    back to /workspace/ai/model-orchestrator/models.yaml.

    Raises:
        OSError: if the config file cannot be opened.
        yaml.YAMLError: if the file is not valid YAML.
    """
    global model_registry, config

    config_path = os.getenv("MODELS_CONFIG", "/workspace/ai/model-orchestrator/models.yaml")
    logger.info(f"Loading model registry from {config_path}")

    with open(config_path, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty document; treat that as "no models".
        data = yaml.safe_load(f) or {}

    model_registry = data.get('models') or {}
    config = data.get('config') or {}

    logger.info(f"Loaded {len(model_registry)} models")
    for model_name, model_config in model_registry.items():
        logger.info(f"  - {model_name}: {model_config.get('type')} ({model_config.get('framework')})")


def _is_process_alive(model_name: str) -> bool:
    """Return True if a tracked process for *model_name* exists and has not exited."""
    proc = model_processes.get(model_name)
    return proc is not None and proc.poll() is None


async def start_model_process(model_name: str) -> bool:
    """Start a model as a subprocess and wait until its /health endpoint answers 200.

    Any other model recorded as ``current_model`` is stopped first, so only one
    model is meant to be loaded at a time.

    Args:
        model_name: Key of the model in ``model_registry``.

    Returns:
        True if the model is running (already, or after a successful start);
        False if it is unknown, misconfigured, exited early, or did not become
        healthy within ``startup_time_seconds`` (default 120).
    """
    global current_model

    if model_name not in model_registry:
        logger.error(f"Model {model_name} not found in registry")
        return False

    model_config = model_registry[model_name]

    # Stop current model if running
    if current_model and current_model != model_name:
        await stop_model_process(current_model)

    # Check if already running
    if _is_process_alive(model_name):
        logger.info(f"Model {model_name} already running")
        return True

    logger.info(f"Starting model {model_name}...")

    try:
        # Get service command from config
        service_script = model_config.get('service_script')
        if not service_script:
            logger.error(f"No service_script defined for {model_name}")
            return False

        script_path = Path(f"/workspace/ai/{service_script}")
        if not script_path.exists():
            logger.error(f"Service script not found: {script_path}")
            return False

        # Start process
        port = model_config.get('port', 8000)
        env = os.environ.copy()
        env.update({
            'HF_TOKEN': os.getenv('HF_TOKEN', ''),
            'PORT': str(port),
            'HOST': '0.0.0.0'
        })

        # stdout/stderr are inherited rather than piped: nothing in this module
        # reads from the child, and an undrained PIPE blocks the child once the
        # OS pipe buffer fills up.
        proc = subprocess.Popen(
            ['python3', str(script_path)],
            env=env,
            start_new_session=True,  # own process group, so killpg reaches its children
        )
        model_processes[model_name] = proc

        # Wait for service to be ready
        max_wait = model_config.get('startup_time_seconds', 120)
        health_url = f"http://localhost:{port}/health"
        deadline = time.monotonic() + max_wait

        async with httpx.AsyncClient() as client:
            while time.monotonic() < deadline:
                if proc.poll() is not None:
                    logger.error(f"Process for {model_name} exited prematurely")
                    # Drop the dead entry so it is not reported as active.
                    await stop_model_process(model_name)
                    return False

                try:
                    response = await client.get(health_url, timeout=5.0)
                except httpx.HTTPError:
                    # Connection refused / timeout while the service is still booting.
                    pass
                else:
                    if response.status_code == 200:
                        logger.info(f"Model {model_name} is ready on port {port}")
                        current_model = model_name
                        return True

                # Back off after *every* unsuccessful probe, including non-200 replies.
                await asyncio.sleep(2)

        logger.error(f"Model {model_name} failed to start within {max_wait}s")
        await stop_model_process(model_name)
        return False

    except Exception as e:
        logger.error(f"Error starting {model_name}: {e}")
        # Do not leave a half-started process behind.
        if model_name in model_processes:
            await stop_model_process(model_name)
        return False


async def stop_model_process(model_name: str):
    """Stop a tracked model process and remove it from ``model_processes``.

    Sends SIGTERM to the child's process group, waits up to 10 seconds, then
    escalates to SIGKILL. Errors while signalling are logged, not raised.
    Clears ``current_model`` if it pointed at this model.

    Args:
        model_name: Key of the model in ``model_processes``.
    """
    global current_model

    proc = model_processes.get(model_name)
    if proc is None:
        logger.warning(f"Model {model_name} not in process registry")
        return

    if proc.poll() is None:  # Still running
        logger.info(f"Stopping model {model_name}...")
        try:
            # Send SIGTERM to process group
            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)

            # Wait for graceful shutdown.
            # NOTE: this wait is synchronous on purpose; yielding to the event
            # loop here would let another request observe a half-stopped model.
            try:
                proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                # Force kill if not terminated
                os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
                proc.wait()

            logger.info(f"Model {model_name} stopped")
        except Exception as e:
            logger.error(f"Error stopping {model_name}: {e}")

    # pop() instead of del: tolerate the entry having been removed already.
    model_processes.pop(model_name, None)
    if current_model == model_name:
        current_model = None


def get_model_for_endpoint(endpoint: str) -> Optional[str]:
    """Return the first registered model whose 'endpoint' prefix matches *endpoint*.

    Models without a configured (non-empty) 'endpoint' are skipped; an empty
    prefix would otherwise match every path.

    Args:
        endpoint: Request path including the leading slash.

    Returns:
        The model name, or None if no model claims the path.
    """
    for model_name, model_config in model_registry.items():
        prefix = model_config.get('endpoint')
        if prefix and endpoint.startswith(prefix):
            return model_name
    return None


@app.on_event("startup")
async def startup_event():
    """Initialize on startup: load the model registry."""
    logger.info("Starting AI Model Orchestrator (Process-Based)")
    load_model_registry()


@app.on_event("shutdown")
async def shutdown_event():
    """Cleanup on shutdown: stop every tracked model process."""
    logger.info("Shutting down orchestrator...")
    # Iterate over a snapshot; stop_model_process mutates the dict.
    for model_name in list(model_processes.keys()):
        await stop_model_process(model_name)


@app.get("/health")
async def health_check():
    """Health check endpoint reporting the current model and known models."""
    return {
        "status": "healthy",
        "current_model": current_model,
        "active_processes": len(model_processes),
        "available_models": list(model_registry.keys())
    }


@app.get("/v1/models")
async def list_models_openai():
    """OpenAI-compatible models listing endpoint."""
    created = int(time.time())
    models_list = [
        {
            "id": model_name,
            "object": "model",
            "created": created,
            "owned_by": "pivoine-gpu",
            "permission": [],
            "root": model_name,
            "parent": None,
        }
        for model_name in model_registry
    ]

    return {
        "object": "list",
        "data": models_list
    }


def _wants_streaming(method: str, body: bytes) -> bool:
    """Return True if a POST body is a JSON object with a truthy 'stream' field."""
    if method != "POST" or not body:
        return False
    try:
        payload = json.loads(body)
    except ValueError:  # covers JSONDecodeError and UnicodeDecodeError
        return False
    return isinstance(payload, dict) and bool(payload.get('stream', False))


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_request(request: Request, path: str):
    """Proxy a request to the model service that owns the path.

    Starts (or switches to) the target model when it is not the current model
    or its process is no longer alive, then forwards method, headers (minus
    Host) and body to ``http://localhost:<port>/<path>``.

    Raises:
        HTTPException: 404 if no model is configured for the path, 503 if the
            model cannot be started, 502 if forwarding the request fails.
    """
    endpoint = f"/{path}"

    # Determine which model should handle this
    target_model = get_model_for_endpoint(endpoint)

    if not target_model:
        raise HTTPException(status_code=404, detail=f"No model configured for endpoint: {endpoint}")

    # Ensure model is running (also restart it if the process has died)
    if current_model != target_model or not _is_process_alive(target_model):
        logger.info(f"Switching to model {target_model}")
        success = await start_model_process(target_model)
        if not success:
            raise HTTPException(status_code=503, detail=f"Failed to start model {target_model}")

    # Proxy the request
    model_config = model_registry[target_model]
    target_url = f"http://localhost:{model_config['port']}/{path}"

    # Get request details
    method = request.method
    headers = dict(request.headers)
    headers.pop('host', None)  # Remove host header
    body = await request.body()

    # Check if this is a streaming request
    is_streaming = _wants_streaming(method, body)

    logger.info(f"Proxying {method} request to {target_url} (streaming: {is_streaming})")

    try:
        if is_streaming:
            # For streaming requests, use httpx streaming and yield chunks.
            # NOTE: the generator runs after this handler has returned, so
            # upstream errors raised while streaming are not turned into a 502
            # by the except clause below, and the upstream status code is not
            # propagated.
            async def stream_response():
                async with httpx.AsyncClient(timeout=300.0) as client:
                    async with client.stream(method, target_url, content=body, headers=headers) as response:
                        async for chunk in response.aiter_bytes():
                            yield chunk

            return StreamingResponse(
                stream_response(),
                media_type="text/event-stream",
                headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
            )

        # For non-streaming requests, buffer the upstream response
        async with httpx.AsyncClient(timeout=300.0) as client:
            response = await client.request(
                method=method,
                url=target_url,
                headers=headers,
                content=body
            )

        content_type = response.headers.get('content-type')
        # Compare the media type only; upstreams commonly append "; charset=...".
        media_type = (content_type or '').split(';', 1)[0].strip().lower()
        if media_type == 'application/json':
            return JSONResponse(
                content=response.json(),
                status_code=response.status_code
            )

        # Non-JSON payloads (text, images, audio, ...) are passed through as
        # raw bytes so binary content is not corrupted by text decoding.
        return Response(
            content=response.content,
            status_code=response.status_code,
            media_type=content_type
        )
    except Exception as e:
        logger.error(f"Error proxying request: {e}")
        raise HTTPException(status_code=502, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", "9000"))
    host = os.getenv("HOST", "0.0.0.0")

    logger.info(f"Starting orchestrator on {host}:{port}")

    uvicorn.run(app, host=host, port=port, log_level="info")