diff --git a/model-orchestrator/models.yaml b/model-orchestrator/models.yaml
index caf6a95..f2f55bc 100644
--- a/model-orchestrator/models.yaml
+++ b/model-orchestrator/models.yaml
@@ -6,7 +6,7 @@ models:
   qwen-2.5-7b:
     type: text
     framework: vllm
-    docker_service: vllm-qwen
+    service_script: vllm/server.py
     port: 8001
     vram_gb: 14
     startup_time_seconds: 120
@@ -17,7 +17,7 @@ models:
   flux-schnell:
     type: image
     framework: openedai-images
-    docker_service: flux
+    service_script: flux/server.py
     port: 8002
     vram_gb: 14
     startup_time_seconds: 60
@@ -28,7 +28,7 @@ models:
   musicgen-medium:
     type: audio
     framework: audiocraft
-    docker_service: musicgen
+    service_script: musicgen/server.py
     port: 8003
     vram_gb: 11
     startup_time_seconds: 45
diff --git a/model-orchestrator/orchestrator_subprocess.py b/model-orchestrator/orchestrator_subprocess.py
new file mode 100644
index 0000000..3245d4d
--- /dev/null
+++ b/model-orchestrator/orchestrator_subprocess.py
@@ -0,0 +1,263 @@
+#!/usr/bin/env python3
+"""
+AI Model Orchestrator for RunPod (Process-Based)
+Manages sequential loading of AI models using subprocess instead of Docker.
+
+Simplified architecture for RunPod's containerized environment:
+- No Docker-in-Docker complexity
+- Direct process management via subprocess
+- Models run as Python background processes
+- GPU memory efficient (sequential model loading)
+"""
+
+import asyncio
+import logging
+import os
+import subprocess
+import time
+import signal
+from typing import Dict, Optional, Any
+from pathlib import Path
+
+import httpx
+import yaml
+from fastapi import FastAPI, Request, HTTPException
+from fastapi.responses import Response
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+# FastAPI app
+app = FastAPI(title="AI Model Orchestrator (Process-Based)", version="2.0.0")
+
+# Global state
+current_model: Optional[str] = None
+model_processes: Dict[str, subprocess.Popen] = {}
+model_registry: Dict[str, Dict[str, Any]] = {}
+config: Dict[str, Any] = {}
+
+
+def load_model_registry():
+    """Load model registry from models.yaml"""
+    global model_registry, config
+
+    config_path = os.getenv("MODELS_CONFIG", "/workspace/ai/model-orchestrator/models.yaml")
+    logger.info(f"Loading model registry from {config_path}")
+
+    with open(config_path, 'r') as f:
+        data = yaml.safe_load(f)
+
+    model_registry = data.get('models', {})
+    config = data.get('config', {})
+
+    logger.info(f"Loaded {len(model_registry)} models")
+    for model_name, model_config in model_registry.items():
+        logger.info(f"  - {model_name}: {model_config.get('type')} ({model_config.get('framework')})")
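+
+
+# For reference, each registry entry is expected to look like the sketch
+# below. Note that get_model_for_endpoint() further down relies on an
+# "endpoint" key that the models.yaml hunk above does not show, so the
+# endpoint value here is an illustrative assumption, not taken from the
+# actual config:
+#
+#   qwen-2.5-7b:
+#     type: text
+#     framework: vllm
+#     service_script: vllm/server.py   # launched as: python3 /workspace/ai/vllm/server.py
+#     port: 8001
+#     vram_gb: 14
+#     startup_time_seconds: 120
+#     endpoint: /v1                    # request-path prefix routed to this model (assumed)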
+
+
+async def start_model_process(model_name: str) -> bool:
+    """Start a model as a subprocess"""
+    global current_model, model_processes
+
+    if model_name not in model_registry:
+        logger.error(f"Model {model_name} not found in registry")
+        return False
+
+    model_config = model_registry[model_name]
+
+    # Stop current model if running
+    if current_model and current_model != model_name:
+        await stop_model_process(current_model)
+
+    # Check if already running
+    if model_name in model_processes:
+        proc = model_processes[model_name]
+        if proc.poll() is None:  # Still running
+            logger.info(f"Model {model_name} already running")
+            return True
+
+    logger.info(f"Starting model {model_name}...")
+
+    try:
+        # Get service script from config
+        service_script = model_config.get('service_script')
+        if not service_script:
+            logger.error(f"No service_script defined for {model_name}")
+            return False
+
+        script_path = Path(f"/workspace/ai/{service_script}")
+        if not script_path.exists():
+            logger.error(f"Service script not found: {script_path}")
+            return False
+
+        # Start process
+        port = model_config.get('port', 8000)
+        env = os.environ.copy()
+        env.update({
+            'HF_TOKEN': os.getenv('HF_TOKEN', ''),
+            'PORT': str(port),
+            'HOST': '0.0.0.0'
+        })
+
+        # Redirect output to a per-model log file (location assumed alongside
+        # the service scripts); leaving stdout/stderr as unread PIPEs would
+        # eventually fill the pipe buffer and block the child process.
+        log_dir = Path("/workspace/ai/logs")
+        log_dir.mkdir(parents=True, exist_ok=True)
+        log_file = open(log_dir / f"{model_name}.log", 'ab')
+
+        proc = subprocess.Popen(
+            ['python3', str(script_path)],
+            env=env,
+            stdout=log_file,
+            stderr=subprocess.STDOUT,
+            preexec_fn=os.setsid  # New process group, so the whole tree can be signalled
+        )
+        log_file.close()  # Child keeps its own duplicate of the handle
+
+        model_processes[model_name] = proc
+
+        # Wait for service to be ready
+        max_wait = model_config.get('startup_time_seconds', 120)
+        start_time = time.time()
+
+        while time.time() - start_time < max_wait:
+            if proc.poll() is not None:
+                logger.error(f"Process for {model_name} exited prematurely")
+                return False
+
+            try:
+                async with httpx.AsyncClient() as client:
+                    response = await client.get(
+                        f"http://localhost:{port}/health",
+                        timeout=5.0
+                    )
+                if response.status_code == 200:
+                    logger.info(f"Model {model_name} is ready on port {port}")
+                    current_model = model_name
+                    return True
+            except httpx.HTTPError:
+                pass  # Service not accepting connections yet
+
+            await asyncio.sleep(2)
+
+        logger.error(f"Model {model_name} failed to start within {max_wait}s")
+        await stop_model_process(model_name)
+        return False
+
+    except Exception as e:
+        logger.error(f"Error starting {model_name}: {e}")
+        return False
+
+
+async def stop_model_process(model_name: str):
+    """Stop a running model process"""
+    global model_processes, current_model
+
+    if model_name not in model_processes:
+        logger.warning(f"Model {model_name} not in process registry")
+        return
+
+    proc = model_processes[model_name]
+
+    if proc.poll() is None:  # Still running
+        logger.info(f"Stopping model {model_name}...")
+        try:
+            # Send SIGTERM to the process group
+            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
+
+            # Wait for graceful shutdown
+            try:
+                proc.wait(timeout=10)
+            except subprocess.TimeoutExpired:
+                # Force kill if not terminated
+                os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
+                proc.wait()
+
+            logger.info(f"Model {model_name} stopped")
+        except Exception as e:
+            logger.error(f"Error stopping {model_name}: {e}")
+
+    del model_processes[model_name]
+    if current_model == model_name:
+        current_model = None
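+
+
+# Worked example of the sequential-swap flow (model names as in models.yaml):
+# with qwen-2.5-7b active, a request routed to flux-schnell makes
+# start_model_process() first SIGTERM qwen's process group, wait up to 10
+# seconds (escalating to SIGKILL), and only then launch flux/server.py --
+# so at most one model's weights occupy GPU VRAM at a time.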
+
+
+def get_model_for_endpoint(endpoint: str) -> Optional[str]:
+    """Determine which model handles this endpoint"""
+    for model_name, model_config in model_registry.items():
+        prefix = model_config.get('endpoint')
+        # Skip models without an endpoint prefix; matching against an empty
+        # string would route every request to the first model in the registry.
+        if prefix and endpoint.startswith(prefix):
+            return model_name
+    return None
+
+
+@app.on_event("startup")
+async def startup_event():
+    """Initialize on startup"""
+    logger.info("Starting AI Model Orchestrator (Process-Based)")
+    load_model_registry()
+
+
+@app.on_event("shutdown")
+async def shutdown_event():
+    """Cleanup on shutdown"""
+    logger.info("Shutting down orchestrator...")
+    for model_name in list(model_processes.keys()):
+        await stop_model_process(model_name)
+
+
+@app.get("/health")
+async def health_check():
+    """Health check endpoint"""
+    return {
+        "status": "healthy",
+        "current_model": current_model,
+        "active_processes": len(model_processes),
+        "available_models": list(model_registry.keys())
+    }
+
+
+@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
+async def proxy_request(request: Request, path: str):
+    """Proxy requests to the appropriate model service"""
+    endpoint = f"/{path}"
+
+    # Determine which model should handle this
+    target_model = get_model_for_endpoint(endpoint)
+
+    if not target_model:
+        raise HTTPException(status_code=404, detail=f"No model configured for endpoint: {endpoint}")
+
+    # Ensure model is running
+    if current_model != target_model:
+        logger.info(f"Switching to model {target_model}")
+        success = await start_model_process(target_model)
+        if not success:
+            raise HTTPException(status_code=503, detail=f"Failed to start model {target_model}")
+
+    # Proxy the request
+    model_config = model_registry[target_model]
+    target_url = f"http://localhost:{model_config['port']}/{path}"
+
+    # Drop headers that no longer apply to the proxied connection
+    headers = {
+        k: v for k, v in request.headers.items()
+        if k.lower() not in ('host', 'content-length')
+    }
+
+    try:
+        async with httpx.AsyncClient(timeout=300.0) as client:
+            response = await client.request(
+                method=request.method,
+                url=target_url,
+                headers=headers,
+                content=await request.body()
+            )
+
+        # Pass the upstream body through verbatim; re-encoding response.text
+        # as JSON would corrupt non-JSON payloads such as image or audio bytes.
+        return Response(
+            content=response.content,
+            status_code=response.status_code,
+            media_type=response.headers.get('content-type')
+        )
+    except httpx.HTTPError as e:
+        logger.error(f"Error proxying request: {e}")
+        raise HTTPException(status_code=502, detail=str(e))
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    port = int(os.getenv("PORT", "9000"))
+    host = os.getenv("HOST", "0.0.0.0")
+
+    logger.info(f"Starting orchestrator on {host}:{port}")
+    uvicorn.run(app, host=host, port=port, log_level="info")
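+
+
+# ---------------------------------------------------------------------------
+# Service-script contract (a minimal sketch, not the actual vllm/flux/musicgen
+# servers): each service_script in models.yaml is launched as
+# `python3 <script>` with PORT, HOST and HF_TOKEN set in its environment, and
+# the orchestrator polls GET /health until it returns HTTP 200. Assuming
+# FastAPI, a conforming skeleton would be:
+#
+#   import os
+#
+#   import uvicorn
+#   from fastapi import FastAPI, HTTPException
+#
+#   app = FastAPI()
+#   READY = False  # a real server would set this True once its model is loaded
+#
+#   @app.get("/health")
+#   def health():
+#       if not READY:
+#           raise HTTPException(status_code=503, detail="model still loading")
+#       return {"status": "ok"}
+#
+#   if __name__ == "__main__":
+#       uvicorn.run(app, host=os.getenv("HOST", "0.0.0.0"),
+#                   port=int(os.getenv("PORT", "8000")))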