chore: remove orchestrator - replaced with dedicated vLLM servers
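The removed orchestrator multiplexed every model behind a single FastAPI proxy (default port 9000) and started or stopped backend services on demand. With dedicated vLLM servers, each model keeps its own OpenAI-compatible endpoint and clients call it directly. A minimal sketch of the new call path, assuming the host and port (localhost:8000 for qwen-2.5-7b) carried over from the removed models.yaml and may differ in the new deployment:

# Sketch only: call the dedicated vLLM server directly instead of the old proxy.
# Host/port are taken from the removed registry below and are assumptions here.
import httpx

resp = httpx.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "qwen-2.5-7b",
        "messages": [{"role": "user", "content": "Hello"}],
    },
    timeout=60.0,
)
print(resp.json()["choices"][0]["message"]["content"])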
@@ -1,77 +0,0 @@
# Model Registry for AI Orchestrator
# Add new models by appending to this file

models:
  # Text Generation Models
  qwen-2.5-7b:
    type: text
    framework: vllm
    service_script: vllm/server.py
    port: 8000
    vram_gb: 14
    startup_time_seconds: 120
    endpoint: /v1/chat/completions
    description: "Qwen 2.5 7B Instruct - Fast text generation, no authentication required"

  llama-3.1-8b:
    type: text
    framework: vllm
    service_script: vllm/server.py
    port: 8001
    vram_gb: 17
    startup_time_seconds: 120
    endpoint: /v1/chat/completions
    description: "Llama 3.1 8B Instruct - Meta's latest model"

# Example: Add more models easily by uncommenting and customizing below

# Future Text Models:
# llama-3.1-8b:
#   type: text
#   framework: vllm
#   docker_service: vllm-llama
#   port: 8004
#   vram_gb: 17
#   startup_time_seconds: 120
#   endpoint: /v1/chat/completions
#   description: "Llama 3.1 8B Instruct - Meta's latest model"

# Future Image Models:
# sdxl:
#   type: image
#   framework: openedai-images
#   docker_service: sdxl
#   port: 8005
#   vram_gb: 10
#   startup_time_seconds: 45
#   endpoint: /v1/images/generations
#   description: "Stable Diffusion XL - High quality image generation"

# Future Audio Models:
# whisper-large:
#   type: audio
#   framework: faster-whisper
#   docker_service: whisper
#   port: 8006
#   vram_gb: 3
#   startup_time_seconds: 30
#   endpoint: /v1/audio/transcriptions
#   description: "Whisper Large v3 - Speech-to-text transcription"
#
# xtts-v2:
#   type: audio
#   framework: openedai-speech
#   docker_service: tts
#   port: 8007
#   vram_gb: 3
#   startup_time_seconds: 30
#   endpoint: /v1/audio/speech
#   description: "XTTS v2 - High-quality text-to-speech with voice cloning"

# Configuration
config:
  gpu_memory_total_gb: 24
  allow_concurrent_loading: false  # Sequential loading only
  model_switch_timeout_seconds: 300  # 5 minutes max for model switching
  health_check_interval_seconds: 10
  default_model: qwen-2.5-7b
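The config block captures why the orchestrator loaded models sequentially: the two active text models alone would need 14 + 17 = 31 GB of VRAM against the 24 GB budget of a single RTX 4090, so allow_concurrent_loading is false and only one model is ever resident. A minimal sketch of that check, assuming PyYAML and an illustrative local path to the registry:

# Sketch, assuming PyYAML and the models.yaml layout above; the path is illustrative.
import yaml

with open("models.yaml") as f:
    data = yaml.safe_load(f)

total_vram = sum(m["vram_gb"] for m in data["models"].values())
budget = data["config"]["gpu_memory_total_gb"]
# 14 + 17 = 31 GB requested vs. a 24 GB card -> models must be loaded one at a time
print(f"requested {total_vram} GB vs {budget} GB available; "
      f"concurrent loading allowed: {data['config']['allow_concurrent_loading']}")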
@@ -1,404 +0,0 @@
#!/usr/bin/env python3
"""
AI Model Orchestrator for RunPod RTX 4090
Manages sequential loading of text, image, and music models on a single GPU

Features:
- Automatic model switching based on request type
- OpenAI-compatible API endpoints
- Docker Compose service management
- GPU memory monitoring
- Simple YAML configuration for adding new models
"""

import asyncio
import logging
import os
import time
from typing import Dict, Optional, Any

import docker
import httpx
import yaml
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(title="AI Model Orchestrator", version="1.0.0")

# Docker client
docker_client = docker.from_env()

# Global state
current_model: Optional[str] = None
model_registry: Dict[str, Dict[str, Any]] = {}
config: Dict[str, Any] = {}


def load_model_registry():
    """Load model registry from models.yaml"""
    global model_registry, config

    config_path = os.getenv("MODELS_CONFIG", "/app/models.yaml")
    logger.info(f"Loading model registry from {config_path}")

    with open(config_path, 'r') as f:
        data = yaml.safe_load(f)

    model_registry = data.get('models', {})
    config = data.get('config', {})

    logger.info(f"Loaded {len(model_registry)} models from registry")
    for model_name, model_info in model_registry.items():
        logger.info(f"  - {model_name}: {model_info['description']}")


def get_docker_service_name(service_name: str) -> str:
    """Get full Docker service name with project prefix"""
    project_name = os.getenv("COMPOSE_PROJECT_NAME", "ai")
    return f"{project_name}_{service_name}_1"


async def stop_current_model():
    """Stop the currently running model service"""
    global current_model

    if not current_model:
        logger.info("No model currently running")
        return

    model_info = model_registry.get(current_model)
    if not model_info:
        logger.warning(f"Model {current_model} not found in registry")
        current_model = None
        return

    service_name = get_docker_service_name(model_info['docker_service'])
    logger.info(f"Stopping model: {current_model} (service: {service_name})")

    try:
        container = docker_client.containers.get(service_name)
        container.stop(timeout=30)
        logger.info(f"Stopped {current_model}")
        current_model = None
    except docker.errors.NotFound:
        logger.warning(f"Container {service_name} not found (already stopped?)")
        current_model = None
    except Exception as e:
        logger.error(f"Error stopping {service_name}: {e}")
        raise


async def start_model(model_name: str):
    """Start a model service"""
    global current_model

    if model_name not in model_registry:
        raise HTTPException(status_code=404, detail=f"Model {model_name} not found in registry")

    model_info = model_registry[model_name]
    service_name = get_docker_service_name(model_info['docker_service'])

    logger.info(f"Starting model: {model_name} (service: {service_name})")
    logger.info(f"  VRAM requirement: {model_info['vram_gb']} GB")
    logger.info(f"  Estimated startup time: {model_info['startup_time_seconds']}s")

    try:
        # Start the container
        container = docker_client.containers.get(service_name)
        container.start()

        # Wait for service to be healthy
        port = model_info['port']
        endpoint = model_info.get('endpoint', '/')
        base_url = f"http://localhost:{port}"

        logger.info(f"Waiting for {model_name} to be ready at {base_url}...")

        max_wait = model_info['startup_time_seconds'] + 60  # Add buffer
        start_time = time.time()

        async with httpx.AsyncClient() as client:
            while time.time() - start_time < max_wait:
                try:
                    # Try health check or root endpoint
                    health_url = f"{base_url}/health"
                    try:
                        response = await client.get(health_url, timeout=5.0)
                        if response.status_code == 200:
                            logger.info(f"{model_name} is ready!")
                            current_model = model_name
                            return
                    except:
                        # Try root endpoint if /health doesn't exist
                        response = await client.get(base_url, timeout=5.0)
                        if response.status_code == 200:
                            logger.info(f"{model_name} is ready!")
                            current_model = model_name
                            return
                except Exception as e:
                    logger.debug(f"Waiting for {model_name}... ({e})")

                await asyncio.sleep(5)

        raise HTTPException(
            status_code=503,
            detail=f"Model {model_name} failed to start within {max_wait}s"
        )

    except docker.errors.NotFound:
        raise HTTPException(
            status_code=500,
            detail=f"Docker service {service_name} not found. Is it defined in docker-compose?"
        )
    except Exception as e:
        logger.error(f"Error starting {model_name}: {e}")
        raise HTTPException(status_code=500, detail=str(e))


async def ensure_model_running(model_name: str):
    """Ensure the specified model is running, switching if necessary"""
    global current_model

    if current_model == model_name:
        logger.info(f"Model {model_name} already running")
        return

    logger.info(f"Switching model: {current_model} -> {model_name}")

    # Stop current model
    await stop_current_model()

    # Start requested model
    await start_model(model_name)

    logger.info(f"Model switch complete: {model_name} is now active")


async def proxy_request(model_name: str, request: Request):
    """Proxy request to the active model service"""
    model_info = model_registry[model_name]
    port = model_info['port']

    # Get request details
    path = request.url.path
    method = request.method
    headers = dict(request.headers)
    headers.pop('host', None)  # Remove host header

    # Build target URL
    target_url = f"http://localhost:{port}{path}"

    # Check if this is a streaming request
    body = await request.body()
    is_streaming = False
    if method == "POST" and body:
        try:
            import json
            body_json = json.loads(body)
            is_streaming = body_json.get('stream', False)
        except:
            pass

    logger.info(f"Proxying {method} request to {target_url} (streaming: {is_streaming})")

    if is_streaming:
        # For streaming requests, use httpx streaming and yield chunks
        async def stream_response():
            async with httpx.AsyncClient(timeout=300.0) as client:
                async with client.stream(method, target_url, content=body, headers=headers) as response:
                    async for chunk in response.aiter_bytes():
                        yield chunk

        return StreamingResponse(
            stream_response(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
        )
    else:
        # For non-streaming requests, use the original behavior
        async with httpx.AsyncClient(timeout=300.0) as client:
            if method == "GET":
                response = await client.get(target_url, headers=headers)
            elif method == "POST":
                response = await client.post(target_url, content=body, headers=headers)
            else:
                raise HTTPException(status_code=405, detail=f"Method {method} not supported")

        # Return response
        return JSONResponse(
            content=response.json() if response.headers.get('content-type', '').startswith('application/json') else response.text,
            status_code=response.status_code,
            headers=dict(response.headers)
        )


@app.on_event("startup")
async def startup_event():
    """Load model registry on startup"""
    load_model_registry()
    logger.info("AI Model Orchestrator started successfully")
    logger.info(f"GPU Memory: {config.get('gpu_memory_total_gb', 24)} GB")
    logger.info(f"Default model: {config.get('default_model', 'qwen-2.5-7b')}")


@app.get("/")
async def root():
    """Root endpoint"""
    return {
        "service": "AI Model Orchestrator",
        "version": "1.0.0",
        "current_model": current_model,
        "available_models": list(model_registry.keys())
    }


@app.get("/health")
async def health():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "current_model": current_model,
        "model_info": model_registry.get(current_model) if current_model else None,
        "gpu_memory_total_gb": config.get('gpu_memory_total_gb', 24),
        "models_available": len(model_registry)
    }


@app.get("/models")
async def list_models():
    """List all available models"""
    return {
        "models": model_registry,
        "current_model": current_model
    }


@app.get("/v1/models")
async def list_models_openai():
    """OpenAI-compatible models listing endpoint"""
    models_list = []
    for model_name, model_info in model_registry.items():
        models_list.append({
            "id": model_name,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "pivoine-gpu",
            "permission": [],
            "root": model_name,
            "parent": None,
        })

    return {
        "object": "list",
        "data": models_list
    }


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat completions endpoint (text models)"""
    # Parse request to get model name
    body = await request.json()
    model_name = body.get('model', config.get('default_model', 'qwen-2.5-7b'))

    # Validate model type
    if model_name not in model_registry:
        raise HTTPException(status_code=404, detail=f"Model {model_name} not found")

    if model_registry[model_name]['type'] != 'text':
        raise HTTPException(status_code=400, detail=f"Model {model_name} is not a text model")

    # Ensure model is running
    await ensure_model_running(model_name)

    # Proxy request to model
    return await proxy_request(model_name, request)


@app.post("/v1/images/generations")
async def image_generations(request: Request):
    """OpenAI-compatible image generation endpoint"""
    # Parse request to get model name
    body = await request.json()
    model_name = body.get('model', 'flux-schnell')

    # Validate model type
    if model_name not in model_registry:
        raise HTTPException(status_code=404, detail=f"Model {model_name} not found")

    if model_registry[model_name]['type'] != 'image':
        raise HTTPException(status_code=400, detail=f"Model {model_name} is not an image model")

    # Ensure model is running
    await ensure_model_running(model_name)

    # Proxy request to model
    return await proxy_request(model_name, request)


@app.post("/v1/audio/generations")
async def audio_generations(request: Request):
    """Custom audio generation endpoint (music/sound effects)"""
    # Parse request to get model name
    body = await request.json()
    model_name = body.get('model', 'musicgen-medium')

    # Validate model type
    if model_name not in model_registry:
        raise HTTPException(status_code=404, detail=f"Model {model_name} not found")

    if model_registry[model_name]['type'] != 'audio':
        raise HTTPException(status_code=400, detail=f"Model {model_name} is not an audio model")

    # Ensure model is running
    await ensure_model_running(model_name)

    # Proxy request to model
    return await proxy_request(model_name, request)


@app.post("/switch")
async def switch_model(request: Request):
    """Manually switch to a specific model"""
    body = await request.json()
    model_name = body.get('model')

    if not model_name:
        raise HTTPException(status_code=400, detail="Model name required")

    if model_name not in model_registry:
        raise HTTPException(status_code=404, detail=f"Model {model_name} not found")

    await ensure_model_running(model_name)

    return {
        "status": "success",
        "model": model_name,
        "message": f"Switched to {model_name}"
    }


if __name__ == "__main__":
    import uvicorn

    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "9000"))

    logger.info(f"Starting AI Model Orchestrator on {host}:{port}")

    uvicorn.run(
        app,
        host=host,
        port=port,
        log_level="info",
        access_log=True,
    )
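For reference, the Docker-based orchestrator above was driven in two ways: POST /switch to force a model change, or an OpenAI-style request to /v1/chat/completions, in which case it read the "model" field (default qwen-2.5-7b), switched containers if needed, and proxied the call. A sketch against its default port 9000 (taken from the script's PORT fallback), using httpx as the script itself does:

# Sketch of driving the removed orchestrator; port 9000 is its documented default.
import httpx

base = "http://localhost:9000"

# Force a model switch (stops the current container, starts the requested one)
httpx.post(f"{base}/switch", json={"model": "llama-3.1-8b"}, timeout=600.0)

# Or just send a chat request; the orchestrator switches automatically based on
# the "model" field, then proxies the call to the backing vLLM service.
resp = httpx.post(
    f"{base}/v1/chat/completions",
    json={"model": "qwen-2.5-7b", "messages": [{"role": "user", "content": "Hi"}]},
    timeout=600.0,
)
print(resp.status_code)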
@@ -1,323 +0,0 @@
#!/usr/bin/env python3
"""
AI Model Orchestrator for RunPod (Process-Based)
Manages sequential loading of AI models using subprocess instead of Docker

Simplified architecture for RunPod's containerized environment:
- No Docker-in-Docker complexity
- Direct process management via subprocess
- Models run as Python background processes
- GPU memory efficient (sequential model loading)
"""

import asyncio
import logging
import os
import subprocess
import time
import signal
from typing import Dict, Optional, Any
from pathlib import Path

import httpx
import yaml
import psutil
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(title="AI Model Orchestrator (Process-Based)", version="2.0.0")

# Global state
current_model: Optional[str] = None
model_processes: Dict[str, subprocess.Popen] = {}
model_registry: Dict[str, Dict[str, Any]] = {}
config: Dict[str, Any] = {}


def load_model_registry():
    """Load model registry from models.yaml"""
    global model_registry, config

    config_path = os.getenv("MODELS_CONFIG", "/workspace/ai/model-orchestrator/models.yaml")
    logger.info(f"Loading model registry from {config_path}")

    with open(config_path, 'r') as f:
        data = yaml.safe_load(f)

    model_registry = data.get('models', {})
    config = data.get('config', {})

    logger.info(f"Loaded {len(model_registry)} models")
    for model_name, model_config in model_registry.items():
        logger.info(f"  - {model_name}: {model_config.get('type')} ({model_config.get('framework')})")


async def start_model_process(model_name: str) -> bool:
    """Start a model as a subprocess"""
    global current_model, model_processes

    if model_name not in model_registry:
        logger.error(f"Model {model_name} not found in registry")
        return False

    model_config = model_registry[model_name]

    # Stop current model if running
    if current_model and current_model != model_name:
        await stop_model_process(current_model)

    # Check if already running
    if model_name in model_processes:
        proc = model_processes[model_name]
        if proc.poll() is None:  # Still running
            logger.info(f"Model {model_name} already running")
            return True

    logger.info(f"Starting model {model_name}...")

    try:
        # Get service command from config
        service_script = model_config.get('service_script')
        if not service_script:
            logger.error(f"No service_script defined for {model_name}")
            return False

        script_path = Path(f"/workspace/ai/{service_script}")
        if not script_path.exists():
            logger.error(f"Service script not found: {script_path}")
            return False

        # Start process
        port = model_config.get('port', 8000)
        env = os.environ.copy()
        env.update({
            'HF_TOKEN': os.getenv('HF_TOKEN', ''),
            'PORT': str(port),
            'HOST': '0.0.0.0',
            'MODEL_NAME': model_config.get('model_name', model_name)
        })

        # Use venv python if it exists
        script_dir = script_path.parent
        venv_python = script_dir / 'venv' / 'bin' / 'python3'
        python_cmd = str(venv_python) if venv_python.exists() else 'python3'

        proc = subprocess.Popen(
            [python_cmd, str(script_path)],
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            preexec_fn=os.setsid  # Create new process group
        )

        model_processes[model_name] = proc

        # Wait for service to be ready
        max_wait = model_config.get('startup_time_seconds', 120)
        start_time = time.time()

        while time.time() - start_time < max_wait:
            if proc.poll() is not None:
                logger.error(f"Process for {model_name} exited prematurely")
                return False

            try:
                async with httpx.AsyncClient() as client:
                    response = await client.get(
                        f"http://localhost:{port}/health",
                        timeout=5.0
                    )
                    if response.status_code == 200:
                        logger.info(f"Model {model_name} is ready on port {port}")
                        current_model = model_name
                        return True
            except:
                await asyncio.sleep(2)

        logger.error(f"Model {model_name} failed to start within {max_wait}s")
        await stop_model_process(model_name)
        return False

    except Exception as e:
        logger.error(f"Error starting {model_name}: {e}")
        return False


async def stop_model_process(model_name: str):
    """Stop a running model process"""
    global model_processes, current_model

    if model_name not in model_processes:
        logger.warning(f"Model {model_name} not in process registry")
        return

    proc = model_processes[model_name]

    if proc.poll() is None:  # Still running
        logger.info(f"Stopping model {model_name}...")
        try:
            # Send SIGTERM to process group
            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)

            # Wait for graceful shutdown
            try:
                proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                # Force kill if not terminated
                os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
                proc.wait()

            logger.info(f"Model {model_name} stopped")
        except Exception as e:
            logger.error(f"Error stopping {model_name}: {e}")

    del model_processes[model_name]
    if current_model == model_name:
        current_model = None


def get_model_for_endpoint(endpoint: str) -> Optional[str]:
    """Determine which model handles this endpoint"""
    for model_name, model_config in model_registry.items():
        if endpoint.startswith(model_config.get('endpoint', '')):
            return model_name
    return None


@app.on_event("startup")
async def startup_event():
    """Initialize on startup"""
    logger.info("Starting AI Model Orchestrator (Process-Based)")
    load_model_registry()


@app.on_event("shutdown")
async def shutdown_event():
    """Cleanup on shutdown"""
    logger.info("Shutting down orchestrator...")
    for model_name in list(model_processes.keys()):
        await stop_model_process(model_name)


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "current_model": current_model,
        "active_processes": len(model_processes),
        "available_models": list(model_registry.keys())
    }


@app.get("/v1/models")
async def list_models_openai():
    """OpenAI-compatible models listing endpoint"""
    models_list = []
    for model_name, model_info in model_registry.items():
        models_list.append({
            "id": model_name,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "pivoine-gpu",
            "permission": [],
            "root": model_name,
            "parent": None,
        })

    return {
        "object": "list",
        "data": models_list
    }


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_request(request: Request, path: str):
    """Proxy requests to appropriate model service"""
    endpoint = f"/{path}"

    # Determine which model should handle this
    target_model = get_model_for_endpoint(endpoint)

    if not target_model:
        raise HTTPException(status_code=404, detail=f"No model configured for endpoint: {endpoint}")

    # Ensure model is running
    if current_model != target_model:
        logger.info(f"Switching to model {target_model}")
        success = await start_model_process(target_model)
        if not success:
            raise HTTPException(status_code=503, detail=f"Failed to start model {target_model}")

    # Proxy the request
    model_config = model_registry[target_model]
    target_url = f"http://localhost:{model_config['port']}/{path}"

    # Get request details
    method = request.method
    headers = dict(request.headers)
    headers.pop('host', None)  # Remove host header
    body = await request.body()

    # Check if this is a streaming request
    is_streaming = False
    if method == "POST" and body:
        try:
            import json
            body_json = json.loads(body)
            is_streaming = body_json.get('stream', False)
        except:
            pass

    logger.info(f"Proxying {method} request to {target_url} (streaming: {is_streaming})")

    try:
        if is_streaming:
            # For streaming requests, use httpx streaming and yield chunks
            async def stream_response():
                async with httpx.AsyncClient(timeout=300.0) as client:
                    async with client.stream(method, target_url, content=body, headers=headers) as response:
                        async for chunk in response.aiter_bytes():
                            yield chunk

            return StreamingResponse(
                stream_response(),
                media_type="text/event-stream",
                headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
            )
        else:
            # For non-streaming requests, use the original behavior
            async with httpx.AsyncClient(timeout=300.0) as client:
                response = await client.request(
                    method=method,
                    url=target_url,
                    headers=headers,
                    content=body
                )

            return JSONResponse(
                content=response.json() if response.headers.get('content-type') == 'application/json' else response.text,
                status_code=response.status_code
            )
    except Exception as e:
        logger.error(f"Error proxying request: {e}")
        raise HTTPException(status_code=502, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", "9000"))
    host = os.getenv("HOST", "0.0.0.0")

    logger.info(f"Starting orchestrator on {host}:{port}")
    uvicorn.run(app, host=host, port=port, log_level="info")
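The process-based variant launched each model's server script in its own process group (preexec_fn=os.setsid) so the whole tree could be torn down with killpg, but it still kept only one model alive at a time. The replacement drops the switching logic: each model gets its own long-lived vLLM OpenAI-compatible server on its own port. A sketch of that idea using vLLM's bundled entrypoint; the Hugging Face model IDs are assumptions, the ports follow the old registry, and running both models concurrently needs more than the single 24 GB card budgeted above (e.g. separate pods or GPUs):

# Sketch of the replacement: one long-lived vLLM OpenAI server per model, no switching.
import subprocess
import sys

servers = {
    "Qwen/Qwen2.5-7B-Instruct": 8000,          # assumed HF ID for qwen-2.5-7b
    "meta-llama/Llama-3.1-8B-Instruct": 8001,  # assumed HF ID for llama-3.1-8b
}

procs = [
    subprocess.Popen([
        sys.executable, "-m", "vllm.entrypoints.openai.api_server",
        "--model", model, "--port", str(port),
    ])
    for model, port in servers.items()
]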
@@ -1,7 +0,0 @@
fastapi==0.104.1
uvicorn[standard]==0.24.0
httpx==0.25.1
docker==6.1.3
pyyaml==6.0.1
pydantic==2.5.0
psutil==5.9.6