#!/usr/bin/env python3
"""
AI Model Orchestrator for RunPod RTX 4090

Manages sequential loading of text, image, and music models on a single GPU:
only one model container runs at a time, and requests for a different model
stop the current container and start the requested one.

Features:
- Automatic model switching based on request type
- OpenAI-compatible API endpoints
- Docker Compose service management
- GPU memory reporting (configured total from models.yaml)
- Simple YAML configuration for adding new models
"""

import asyncio
import functools
import json
import logging
import os
import time
from typing import Any, Callable, Dict, Optional

import docker
import httpx
import yaml
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from pydantic import BaseModel

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(title="AI Model Orchestrator", version="1.0.0")

# Docker client (talks to the local Docker daemon from the environment)
docker_client = docker.from_env()

# Global state
# Name of the model whose container this process last started successfully,
# or None when nothing has been started (or the last one was stopped).
current_model: Optional[str] = None
# Model name -> settings from the ``models`` section of models.yaml.
model_registry: Dict[str, Dict[str, Any]] = {}
# Global settings from the ``config`` section of models.yaml.
config: Dict[str, Any] = {}

# Serializes model switches so concurrent requests cannot interleave
# stop/start calls on the single GPU. Created lazily (inside a coroutine) so
# it is bound to the event loop uvicorn actually runs.
_switch_lock: Optional[asyncio.Lock] = None

# Headers that only describe one transport hop and must not be forwarded by
# a proxy in either direction.
_HOP_BY_HOP_HEADERS = frozenset({
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailers",
    "transfer-encoding",
    "upgrade",
})
# Request headers dropped before forwarding: host must name the upstream, and
# httpx computes content-length from the body it actually sends.
_DROPPED_REQUEST_HEADERS = _HOP_BY_HOP_HEADERS | {"host", "content-length"}
# Response headers dropped when relaying a buffered upstream response: httpx
# has already decoded any content-encoding, and Starlette recomputes
# content-length for the body it sends.
_DROPPED_RESPONSE_HEADERS = _HOP_BY_HOP_HEADERS | {"content-length", "content-encoding"}


def load_model_registry():
    """Load the model registry and global config from the YAML file.

    The path comes from the ``MODELS_CONFIG`` environment variable (default
    ``/app/models.yaml``). Populates the module-level ``model_registry`` from
    the ``models`` key and ``config`` from the ``config`` key; a missing or
    empty section becomes an empty dict.

    Raises:
        OSError: if the file cannot be opened.
        yaml.YAMLError: if the file is not valid YAML.
        ValueError: if the top level of the file is not a mapping.
    """
    global model_registry, config

    config_path = os.getenv("MODELS_CONFIG", "/app/models.yaml")
    logger.info(f"Loading model registry from {config_path}")

    with open(config_path, 'r', encoding='utf-8') as f:
        # safe_load never constructs arbitrary Python objects from the file;
        # an empty file yields None, which we treat as "no sections".
        data = yaml.safe_load(f) or {}
    if not isinstance(data, dict):
        raise ValueError(f"{config_path} must contain a YAML mapping at the top level")

    model_registry = data.get('models') or {}
    config = data.get('config') or {}

    logger.info(f"Loaded {len(model_registry)} models from registry")
    for model_name, model_info in model_registry.items():
        logger.info(f" - {model_name}: {model_info.get('description', '')}")


def get_docker_service_name(service_name: str) -> str:
    """Get full Docker service name with project prefix.

    Builds the Compose v1 style container name ``<project>_<service>_1``,
    where the project comes from ``COMPOSE_PROJECT_NAME`` (default ``ai``).
    """
    project_name = os.getenv("COMPOSE_PROJECT_NAME", "ai")
    return f"{project_name}_{service_name}_1"


def _find_container(compose_service: str):
    """Return the Docker container that backs a Compose service.

    Tries the exact name from get_docker_service_name() first. If no container
    has that name, falls back to the ``com.docker.compose.project`` /
    ``com.docker.compose.service`` labels that Docker Compose puts on the
    containers it creates, so naming schemes other than
    ``<project>_<service>_1`` are still found.

    Raises:
        docker.errors.NotFound: if neither lookup finds a container.
    """
    container_name = get_docker_service_name(compose_service)
    try:
        return docker_client.containers.get(container_name)
    except docker.errors.NotFound:
        pass

    project_name = os.getenv("COMPOSE_PROJECT_NAME", "ai")
    matches = docker_client.containers.list(
        all=True,  # include stopped containers: we may be about to start it
        filters={
            "label": [
                f"com.docker.compose.project={project_name}",
                f"com.docker.compose.service={compose_service}",
            ]
        },
    )
    if not matches:
        raise docker.errors.NotFound(
            f"No container for compose service {compose_service!r} (tried name {container_name!r})"
        )
    return matches[0]


async def _run_blocking(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Run a blocking call (e.g. the Docker SDK) in the default thread pool.

    Keeps the event loop free to serve other requests (such as /health) while
    a container is being stopped or started.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))


def _get_switch_lock() -> asyncio.Lock:
    """Return the process-wide model-switch lock, creating it on first use."""
    global _switch_lock
    if _switch_lock is None:
        _switch_lock = asyncio.Lock()
    return _switch_lock


async def stop_current_model():
    """Stop the currently running model service.

    Does nothing if no model is recorded as running. Clears ``current_model``
    when the container was stopped, was not found, or the model is no longer
    in the registry.

    Raises:
        Exception: any other Docker error while stopping is logged and
            re-raised; ``current_model`` is left unchanged in that case.
    """
    global current_model

    if not current_model:
        logger.info("No model currently running")
        return

    model_info = model_registry.get(current_model)
    if not model_info:
        logger.warning(f"Model {current_model} not found in registry")
        current_model = None
        return

    service_name = get_docker_service_name(model_info['docker_service'])
    logger.info(f"Stopping model: {current_model} (service: {service_name})")

    try:
        container = await _run_blocking(_find_container, model_info['docker_service'])
        await _run_blocking(container.stop, timeout=30)
    except docker.errors.NotFound:
        logger.warning(f"Container {service_name} not found (already stopped?)")
    except Exception as e:
        logger.error(f"Error stopping {service_name}: {e}")
        raise
    else:
        logger.info(f"Stopped {current_model}")
    current_model = None


async def _wait_until_ready(model_name: str, base_url: str, max_wait: float) -> bool:
    """Poll ``base_url`` until it answers 200, for at most ``max_wait`` seconds.

    Each round tries ``<base_url>/health`` first and then the root URL, so
    services without a /health route (which answer 404, not an exception)
    are still detected. Rounds are 5 seconds apart.

    Returns:
        True once either URL returns HTTP 200, False on timeout.
    """
    deadline = time.monotonic() + max_wait
    probe_urls = (f"{base_url}/health", base_url)

    async with httpx.AsyncClient(timeout=5.0) as client:
        while time.monotonic() < deadline:
            for url in probe_urls:
                try:
                    response = await client.get(url)
                except httpx.HTTPError as e:
                    # Connection refused / timeouts are expected while the
                    # container is still booting.
                    logger.debug(f"Waiting for {model_name}... ({e})")
                    continue
                if response.status_code == 200:
                    logger.info(f"{model_name} is ready!")
                    return True
            await asyncio.sleep(5)
    return False


async def start_model(model_name: str):
    """Start the container for ``model_name`` and wait until it answers HTTP.

    Sets the module-level ``current_model`` to ``model_name`` on success.

    Raises:
        HTTPException(404): the model is not in the registry.
        HTTPException(500): the container could not be found or started.
        HTTPException(503): the service did not become ready within
            ``startup_time_seconds`` + 60 seconds.
    """
    global current_model

    if model_name not in model_registry:
        raise HTTPException(status_code=404, detail=f"Model {model_name} not found in registry")

    model_info = model_registry[model_name]
    service_name = get_docker_service_name(model_info['docker_service'])

    logger.info(f"Starting model: {model_name} (service: {service_name})")
    logger.info(f" VRAM requirement: {model_info['vram_gb']} GB")
    logger.info(f" Estimated startup time: {model_info['startup_time_seconds']}s")

    try:
        container = await _run_blocking(_find_container, model_info['docker_service'])
        await _run_blocking(container.start)
    except docker.errors.NotFound as e:
        raise HTTPException(
            status_code=500,
            detail=f"Docker service {service_name} not found. Is it defined in docker-compose?"
        ) from e
    except Exception as e:
        logger.error(f"Error starting {model_name}: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    base_url = f"http://localhost:{model_info['port']}"
    max_wait = model_info['startup_time_seconds'] + 60  # buffer over the estimate
    logger.info(f"Waiting for {model_name} to be ready at {base_url}...")

    if await _wait_until_ready(model_name, base_url, max_wait):
        current_model = model_name
        return

    logger.error(f"Error starting {model_name}: not ready after {max_wait}s")
    # Best-effort cleanup: a half-started container would keep holding GPU
    # memory while current_model is None, so the next switch would not stop it.
    try:
        await _run_blocking(container.stop, timeout=30)
    except Exception as e:
        logger.warning(f"Could not stop {service_name} after failed start: {e}")

    raise HTTPException(
        status_code=503,
        detail=f"Model {model_name} failed to start within {max_wait}s"
    )


async def ensure_model_running(model_name: str):
    """Ensure the specified model is running, switching if necessary.

    The check-stop-start sequence runs under a process-wide asyncio lock, so
    concurrent requests for different models switch one after another.
    NOTE(review): the proxying that follows in the endpoints happens outside
    the lock, so a later switch can still stop a model while an earlier
    request to it is in flight.
    """
    async with _get_switch_lock():
        if current_model == model_name:
            logger.info(f"Model {model_name} already running")
            return

        logger.info(f"Switching model: {current_model} -> {model_name}")
        await stop_current_model()
        await start_model(model_name)
        logger.info(f"Model switch complete: {model_name} is now active")


def _is_streaming_request(method: str, body: bytes) -> bool:
    """Return True if ``body`` is a JSON object of a POST with a truthy "stream"."""
    if method != "POST" or not body:
        return False
    try:
        payload = json.loads(body)
    except ValueError:  # covers JSONDecodeError and invalid UTF-8
        return False
    return isinstance(payload, dict) and bool(payload.get('stream', False))


async def proxy_request(model_name: str, request: Request):
    """Proxy request to the active model service.

    Forwards the method, path, query string, body and end-to-end headers to
    ``http://localhost:<port>`` of ``model_name``. A POST whose JSON body has
    a truthy "stream" field is relayed chunk by chunk as text/event-stream.
    Any other request is buffered and relayed with the upstream status, body
    bytes and (filtered) headers unchanged.

    Raises:
        HTTPException(405): for methods other than GET/POST (non-streaming).
        HTTPException(502): when a non-streaming upstream request fails at the
            transport level.
    """
    port = model_registry[model_name]['port']

    method = request.method
    headers = {
        key: value
        for key, value in request.headers.items()
        if key.lower() not in _DROPPED_REQUEST_HEADERS
    }

    target_url = f"http://localhost:{port}{request.url.path}"
    if request.url.query:
        target_url = f"{target_url}?{request.url.query}"

    body = await request.body()
    is_streaming = _is_streaming_request(method, body)

    logger.info(f"Proxying {method} request to {target_url} (streaming: {is_streaming})")

    if is_streaming:
        # For streaming requests, use httpx streaming and yield chunks
        async def stream_response():
            async with httpx.AsyncClient(timeout=300.0) as client:
                async with client.stream(method, target_url, content=body, headers=headers) as response:
                    async for chunk in response.aiter_bytes():
                        yield chunk

        return StreamingResponse(
            stream_response(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
        )

    try:
        async with httpx.AsyncClient(timeout=300.0) as client:
            if method == "GET":
                response = await client.get(target_url, headers=headers)
            elif method == "POST":
                response = await client.post(target_url, content=body, headers=headers)
            else:
                raise HTTPException(status_code=405, detail=f"Method {method} not supported")
    except httpx.HTTPError as e:
        logger.error(f"Upstream request to {target_url} failed: {e}")
        raise HTTPException(status_code=502, detail=f"Upstream request to {model_name} failed: {e}") from e

    # Relay the raw bytes so JSON, text and binary (image/audio) payloads all
    # survive unchanged; content-type is kept from the upstream headers.
    response_headers = {
        key: value
        for key, value in response.headers.items()
        if key.lower() not in _DROPPED_RESPONSE_HEADERS
    }
    return Response(
        content=response.content,
        status_code=response.status_code,
        headers=response_headers,
    )


@app.on_event("startup")
async def startup_event():
    """Load model registry on startup"""
    load_model_registry()
    logger.info("AI Model Orchestrator started successfully")
    logger.info(f"GPU Memory: {config.get('gpu_memory_total_gb', 24)} GB")
    logger.info(f"Default model: {config.get('default_model', 'qwen-2.5-7b')}")


@app.get("/")
async def root():
    """Root endpoint: service identity, active model and registered model names."""
    return {
        "service": "AI Model Orchestrator",
        "version": "1.0.0",
        "current_model": current_model,
        "available_models": list(model_registry.keys())
    }


@app.get("/health")
async def health():
    """Health check endpoint (reports orchestrator state, not the model's)."""
    return {
        "status": "healthy",
        "current_model": current_model,
        "model_info": model_registry.get(current_model) if current_model else None,
        "gpu_memory_total_gb": config.get('gpu_memory_total_gb', 24),
        "models_available": len(model_registry)
    }


@app.get("/models")
async def list_models():
    """List all available models with their full registry settings."""
    return {
        "models": model_registry,
        "current_model": current_model
    }


@app.get("/v1/models")
async def list_models_openai():
    """OpenAI-compatible models listing endpoint"""
    created = int(time.time())
    return {
        "object": "list",
        "data": [
            {
                "id": model_name,
                "object": "model",
                "created": created,
                "owned_by": "pivoine-gpu",
                "permission": [],
                "root": model_name,
                "parent": None,
            }
            for model_name in model_registry
        ],
    }


async def _read_json_object(request: Request) -> Dict[str, Any]:
    """Parse the request body as a JSON object, or raise HTTPException(400)."""
    try:
        body = await request.json()
    except ValueError as e:  # JSONDecodeError / invalid UTF-8
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from e
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    return body


async def _handle_typed_request(request: Request, model_type: str, default_model: str):
    """Validate the requested model's type, switch to it, and proxy the request.

    ``model_type`` is the registry ``type`` the endpoint serves ("text",
    "image" or "audio"); ``default_model`` is used when the body has no
    "model" field.
    """
    body = await _read_json_object(request)
    model_name = body.get('model', default_model)

    if not isinstance(model_name, str):
        raise HTTPException(status_code=400, detail="Model name must be a string")
    if model_name not in model_registry:
        raise HTTPException(status_code=404, detail=f"Model {model_name} not found")
    if model_registry[model_name].get('type') != model_type:
        raise HTTPException(status_code=400, detail=f"Model {model_name} is not a {model_type} model")

    await ensure_model_running(model_name)
    # The body was cached by request.json(), so proxy_request can re-read it.
    return await proxy_request(model_name, request)


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat completions endpoint (text models)"""
    return await _handle_typed_request(
        request, 'text', config.get('default_model', 'qwen-2.5-7b')
    )


@app.post("/v1/images/generations")
async def image_generations(request: Request):
    """OpenAI-compatible image generation endpoint"""
    return await _handle_typed_request(request, 'image', 'flux-schnell')


@app.post("/v1/audio/generations")
async def audio_generations(request: Request):
    """Custom audio generation endpoint (music/sound effects)"""
    return await _handle_typed_request(request, 'audio', 'musicgen-medium')


@app.post("/switch")
async def switch_model(request: Request):
    """Manually switch to a specific model given as {"model": "<name>"}."""
    body = await _read_json_object(request)
    model_name = body.get('model')

    if not model_name:
        raise HTTPException(status_code=400, detail="Model name required")
    if not isinstance(model_name, str):
        raise HTTPException(status_code=400, detail="Model name must be a string")
    if model_name not in model_registry:
        raise HTTPException(status_code=404, detail=f"Model {model_name} not found")

    await ensure_model_running(model_name)

    return {
        "status": "success",
        "model": model_name,
        "message": f"Switched to {model_name}"
    }


if __name__ == "__main__":
    import uvicorn

    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "9000"))

    logger.info(f"Starting AI Model Orchestrator on {host}:{port}")

    uvicorn.run(
        app,
        host=host,
        port=port,
        log_level="info",
        access_log=True,
    )