chore: remove orchestrator - replaced with dedicated vLLM servers
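The registry and orchestrators removed below are superseded by dedicated, always-on vLLM servers, one per model. A minimal launch sketch (the Hugging Face model ID is an assumption; the port matches the qwen-2.5-7b entry in the removed registry):

# Hypothetical launcher for one dedicated vLLM OpenAI-compatible server,
# replacing the orchestrator's on-demand model switching.
import subprocess

subprocess.run([
    "python", "-m", "vllm.entrypoints.openai.api_server",
    "--model", "Qwen/Qwen2.5-7B-Instruct",  # assumed HF repo for qwen-2.5-7b
    "--port", "8000",                        # port taken from the removed registry
])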
@@ -1,77 +0,0 @@
# Model Registry for AI Orchestrator
# Add new models by appending to this file

models:
  # Text Generation Models
  qwen-2.5-7b:
    type: text
    framework: vllm
    service_script: vllm/server.py
    port: 8000
    vram_gb: 14
    startup_time_seconds: 120
    endpoint: /v1/chat/completions
    description: "Qwen 2.5 7B Instruct - Fast text generation, no authentication required"

  llama-3.1-8b:
    type: text
    framework: vllm
    service_script: vllm/server.py
    port: 8001
    vram_gb: 17
    startup_time_seconds: 120
    endpoint: /v1/chat/completions
    description: "Llama 3.1 8B Instruct - Meta's latest model"

  # Example: Add more models easily by uncommenting and customizing below

  # Future Text Models:
  # llama-3.1-8b:
  #   type: text
  #   framework: vllm
  #   docker_service: vllm-llama
  #   port: 8004
  #   vram_gb: 17
  #   startup_time_seconds: 120
  #   endpoint: /v1/chat/completions
  #   description: "Llama 3.1 8B Instruct - Meta's latest model"

  # Future Image Models:
  # sdxl:
  #   type: image
  #   framework: openedai-images
  #   docker_service: sdxl
  #   port: 8005
  #   vram_gb: 10
  #   startup_time_seconds: 45
  #   endpoint: /v1/images/generations
  #   description: "Stable Diffusion XL - High quality image generation"

  # Future Audio Models:
  # whisper-large:
  #   type: audio
  #   framework: faster-whisper
  #   docker_service: whisper
  #   port: 8006
  #   vram_gb: 3
  #   startup_time_seconds: 30
  #   endpoint: /v1/audio/transcriptions
  #   description: "Whisper Large v3 - Speech-to-text transcription"
  #
  # xtts-v2:
  #   type: audio
  #   framework: openedai-speech
  #   docker_service: tts
  #   port: 8007
  #   vram_gb: 3
  #   startup_time_seconds: 30
  #   endpoint: /v1/audio/speech
  #   description: "XTTS v2 - High-quality text-to-speech with voice cloning"

# Configuration
config:
  gpu_memory_total_gb: 24
  allow_concurrent_loading: false # Sequential loading only
  model_switch_timeout_seconds: 300 # 5 minutes max for model switching
  health_check_interval_seconds: 10
  default_model: qwen-2.5-7b
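For reference, clients reached these models through the orchestrator's OpenAI-compatible API rather than hitting the per-model ports directly. A minimal sketch, assuming the orchestrator runs on localhost at its default port 9000 (see the server code below):

# Minimal client sketch against the orchestrator's chat completions endpoint.
import httpx

response = httpx.post(
    "http://localhost:9000/v1/chat/completions",
    json={
        "model": "qwen-2.5-7b",  # default_model from the registry above
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    timeout=300.0,
)
print(response.json()["choices"][0]["message"]["content"])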
@@ -1,404 +0,0 @@
#!/usr/bin/env python3
"""
AI Model Orchestrator for RunPod RTX 4090
Manages sequential loading of text, image, and music models on a single GPU

Features:
- Automatic model switching based on request type
- OpenAI-compatible API endpoints
- Docker Compose service management
- GPU memory monitoring
- Simple YAML configuration for adding new models
"""

import asyncio
import logging
import os
import time
from typing import Dict, Optional, Any

import docker
import httpx
import yaml
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(title="AI Model Orchestrator", version="1.0.0")

# Docker client
docker_client = docker.from_env()

# Global state
current_model: Optional[str] = None
model_registry: Dict[str, Dict[str, Any]] = {}
config: Dict[str, Any] = {}


def load_model_registry():
    """Load model registry from models.yaml"""
    global model_registry, config

    config_path = os.getenv("MODELS_CONFIG", "/app/models.yaml")
    logger.info(f"Loading model registry from {config_path}")

    with open(config_path, 'r') as f:
        data = yaml.safe_load(f)

    model_registry = data.get('models', {})
    config = data.get('config', {})

    logger.info(f"Loaded {len(model_registry)} models from registry")
    for model_name, model_info in model_registry.items():
        logger.info(f" - {model_name}: {model_info['description']}")


def get_docker_service_name(service_name: str) -> str:
    """Get full Docker service name with project prefix"""
    project_name = os.getenv("COMPOSE_PROJECT_NAME", "ai")
    return f"{project_name}_{service_name}_1"


async def stop_current_model():
    """Stop the currently running model service"""
    global current_model

    if not current_model:
        logger.info("No model currently running")
        return

    model_info = model_registry.get(current_model)
    if not model_info:
        logger.warning(f"Model {current_model} not found in registry")
        current_model = None
        return

    service_name = get_docker_service_name(model_info['docker_service'])
    logger.info(f"Stopping model: {current_model} (service: {service_name})")

    try:
        container = docker_client.containers.get(service_name)
        container.stop(timeout=30)
        logger.info(f"Stopped {current_model}")
        current_model = None
    except docker.errors.NotFound:
        logger.warning(f"Container {service_name} not found (already stopped?)")
        current_model = None
    except Exception as e:
        logger.error(f"Error stopping {service_name}: {e}")
        raise


async def start_model(model_name: str):
    """Start a model service"""
    global current_model

    if model_name not in model_registry:
        raise HTTPException(status_code=404, detail=f"Model {model_name} not found in registry")

    model_info = model_registry[model_name]
    service_name = get_docker_service_name(model_info['docker_service'])

    logger.info(f"Starting model: {model_name} (service: {service_name})")
    logger.info(f" VRAM requirement: {model_info['vram_gb']} GB")
    logger.info(f" Estimated startup time: {model_info['startup_time_seconds']}s")

    try:
        # Start the container
        container = docker_client.containers.get(service_name)
        container.start()

        # Wait for service to be healthy
        port = model_info['port']
        endpoint = model_info.get('endpoint', '/')
        base_url = f"http://localhost:{port}"

        logger.info(f"Waiting for {model_name} to be ready at {base_url}...")

        max_wait = model_info['startup_time_seconds'] + 60  # Add buffer
        start_time = time.time()

        async with httpx.AsyncClient() as client:
            while time.time() - start_time < max_wait:
                try:
                    # Try health check or root endpoint
                    health_url = f"{base_url}/health"
                    try:
                        response = await client.get(health_url, timeout=5.0)
                        if response.status_code == 200:
                            logger.info(f"{model_name} is ready!")
                            current_model = model_name
                            return
                    except:
                        # Try root endpoint if /health doesn't exist
                        response = await client.get(base_url, timeout=5.0)
                        if response.status_code == 200:
                            logger.info(f"{model_name} is ready!")
                            current_model = model_name
                            return
                except Exception as e:
                    logger.debug(f"Waiting for {model_name}... ({e})")

                await asyncio.sleep(5)

        raise HTTPException(
            status_code=503,
            detail=f"Model {model_name} failed to start within {max_wait}s"
        )

    except docker.errors.NotFound:
        raise HTTPException(
            status_code=500,
            detail=f"Docker service {service_name} not found. Is it defined in docker-compose?"
        )
    except Exception as e:
        logger.error(f"Error starting {model_name}: {e}")
        raise HTTPException(status_code=500, detail=str(e))


async def ensure_model_running(model_name: str):
    """Ensure the specified model is running, switching if necessary"""
    global current_model

    if current_model == model_name:
        logger.info(f"Model {model_name} already running")
        return

    logger.info(f"Switching model: {current_model} -> {model_name}")

    # Stop current model
    await stop_current_model()

    # Start requested model
    await start_model(model_name)

    logger.info(f"Model switch complete: {model_name} is now active")


async def proxy_request(model_name: str, request: Request):
    """Proxy request to the active model service"""
    model_info = model_registry[model_name]
    port = model_info['port']

    # Get request details
    path = request.url.path
    method = request.method
    headers = dict(request.headers)
    headers.pop('host', None)  # Remove host header

    # Build target URL
    target_url = f"http://localhost:{port}{path}"

    # Check if this is a streaming request
    body = await request.body()
    is_streaming = False
    if method == "POST" and body:
        try:
            import json
            body_json = json.loads(body)
            is_streaming = body_json.get('stream', False)
        except:
            pass

    logger.info(f"Proxying {method} request to {target_url} (streaming: {is_streaming})")

    if is_streaming:
        # For streaming requests, use httpx streaming and yield chunks
        async def stream_response():
            async with httpx.AsyncClient(timeout=300.0) as client:
                async with client.stream(method, target_url, content=body, headers=headers) as response:
                    async for chunk in response.aiter_bytes():
                        yield chunk

        return StreamingResponse(
            stream_response(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
        )
    else:
        # For non-streaming requests, use the original behavior
        async with httpx.AsyncClient(timeout=300.0) as client:
            if method == "GET":
                response = await client.get(target_url, headers=headers)
            elif method == "POST":
                response = await client.post(target_url, content=body, headers=headers)
            else:
                raise HTTPException(status_code=405, detail=f"Method {method} not supported")

        # Return response
        return JSONResponse(
            content=response.json() if response.headers.get('content-type', '').startswith('application/json') else response.text,
            status_code=response.status_code,
            headers=dict(response.headers)
        )


@app.on_event("startup")
async def startup_event():
    """Load model registry on startup"""
    load_model_registry()
    logger.info("AI Model Orchestrator started successfully")
    logger.info(f"GPU Memory: {config.get('gpu_memory_total_gb', 24)} GB")
    logger.info(f"Default model: {config.get('default_model', 'qwen-2.5-7b')}")


@app.get("/")
async def root():
    """Root endpoint"""
    return {
        "service": "AI Model Orchestrator",
        "version": "1.0.0",
        "current_model": current_model,
        "available_models": list(model_registry.keys())
    }


@app.get("/health")
async def health():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "current_model": current_model,
        "model_info": model_registry.get(current_model) if current_model else None,
        "gpu_memory_total_gb": config.get('gpu_memory_total_gb', 24),
        "models_available": len(model_registry)
    }


@app.get("/models")
async def list_models():
    """List all available models"""
    return {
        "models": model_registry,
        "current_model": current_model
    }


@app.get("/v1/models")
async def list_models_openai():
    """OpenAI-compatible models listing endpoint"""
    models_list = []
    for model_name, model_info in model_registry.items():
        models_list.append({
            "id": model_name,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "pivoine-gpu",
            "permission": [],
            "root": model_name,
            "parent": None,
        })

    return {
        "object": "list",
        "data": models_list
    }


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat completions endpoint (text models)"""
    # Parse request to get model name
    body = await request.json()
    model_name = body.get('model', config.get('default_model', 'qwen-2.5-7b'))

    # Validate model type
    if model_name not in model_registry:
        raise HTTPException(status_code=404, detail=f"Model {model_name} not found")

    if model_registry[model_name]['type'] != 'text':
        raise HTTPException(status_code=400, detail=f"Model {model_name} is not a text model")

    # Ensure model is running
    await ensure_model_running(model_name)

    # Proxy request to model
    return await proxy_request(model_name, request)


@app.post("/v1/images/generations")
async def image_generations(request: Request):
    """OpenAI-compatible image generation endpoint"""
    # Parse request to get model name
    body = await request.json()
    model_name = body.get('model', 'flux-schnell')

    # Validate model type
    if model_name not in model_registry:
        raise HTTPException(status_code=404, detail=f"Model {model_name} not found")

    if model_registry[model_name]['type'] != 'image':
        raise HTTPException(status_code=400, detail=f"Model {model_name} is not an image model")

    # Ensure model is running
    await ensure_model_running(model_name)

    # Proxy request to model
    return await proxy_request(model_name, request)


@app.post("/v1/audio/generations")
async def audio_generations(request: Request):
    """Custom audio generation endpoint (music/sound effects)"""
    # Parse request to get model name
    body = await request.json()
    model_name = body.get('model', 'musicgen-medium')

    # Validate model type
    if model_name not in model_registry:
        raise HTTPException(status_code=404, detail=f"Model {model_name} not found")

    if model_registry[model_name]['type'] != 'audio':
        raise HTTPException(status_code=400, detail=f"Model {model_name} is not an audio model")

    # Ensure model is running
    await ensure_model_running(model_name)

    # Proxy request to model
    return await proxy_request(model_name, request)


@app.post("/switch")
async def switch_model(request: Request):
    """Manually switch to a specific model"""
    body = await request.json()
    model_name = body.get('model')

    if not model_name:
        raise HTTPException(status_code=400, detail="Model name required")

    if model_name not in model_registry:
        raise HTTPException(status_code=404, detail=f"Model {model_name} not found")

    await ensure_model_running(model_name)

    return {
        "status": "success",
        "model": model_name,
        "message": f"Switched to {model_name}"
    }


if __name__ == "__main__":
    import uvicorn

    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "9000"))

    logger.info(f"Starting AI Model Orchestrator on {host}:{port}")

    uvicorn.run(
        app,
        host=host,
        port=port,
        log_level="info",
        access_log=True,
    )
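Besides the automatic switching above, the /switch endpoint allowed pre-loading a model explicitly before sending inference traffic. A minimal sketch, again assuming the orchestrator's default host and port:

# Manually switch the orchestrator to another registered model.
import httpx

response = httpx.post(
    "http://localhost:9000/switch",
    json={"model": "llama-3.1-8b"},
    timeout=600.0,  # model switching can take several minutes
)
print(response.json())  # e.g. {"status": "success", "model": "llama-3.1-8b", ...}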
@@ -1,323 +0,0 @@
#!/usr/bin/env python3
"""
AI Model Orchestrator for RunPod (Process-Based)
Manages sequential loading of AI models using subprocess instead of Docker

Simplified architecture for RunPod's containerized environment:
- No Docker-in-Docker complexity
- Direct process management via subprocess
- Models run as Python background processes
- GPU memory efficient (sequential model loading)
"""

import asyncio
import logging
import os
import subprocess
import time
import signal
from typing import Dict, Optional, Any
from pathlib import Path

import httpx
import yaml
import psutil
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(title="AI Model Orchestrator (Process-Based)", version="2.0.0")

# Global state
current_model: Optional[str] = None
model_processes: Dict[str, subprocess.Popen] = {}
model_registry: Dict[str, Dict[str, Any]] = {}
config: Dict[str, Any] = {}


def load_model_registry():
    """Load model registry from models.yaml"""
    global model_registry, config

    config_path = os.getenv("MODELS_CONFIG", "/workspace/ai/model-orchestrator/models.yaml")
    logger.info(f"Loading model registry from {config_path}")

    with open(config_path, 'r') as f:
        data = yaml.safe_load(f)

    model_registry = data.get('models', {})
    config = data.get('config', {})

    logger.info(f"Loaded {len(model_registry)} models")
    for model_name, model_config in model_registry.items():
        logger.info(f" - {model_name}: {model_config.get('type')} ({model_config.get('framework')})")


async def start_model_process(model_name: str) -> bool:
    """Start a model as a subprocess"""
    global current_model, model_processes

    if model_name not in model_registry:
        logger.error(f"Model {model_name} not found in registry")
        return False

    model_config = model_registry[model_name]

    # Stop current model if running
    if current_model and current_model != model_name:
        await stop_model_process(current_model)

    # Check if already running
    if model_name in model_processes:
        proc = model_processes[model_name]
        if proc.poll() is None:  # Still running
            logger.info(f"Model {model_name} already running")
            return True

    logger.info(f"Starting model {model_name}...")

    try:
        # Get service command from config
        service_script = model_config.get('service_script')
        if not service_script:
            logger.error(f"No service_script defined for {model_name}")
            return False

        script_path = Path(f"/workspace/ai/{service_script}")
        if not script_path.exists():
            logger.error(f"Service script not found: {script_path}")
            return False

        # Start process
        port = model_config.get('port', 8000)
        env = os.environ.copy()
        env.update({
            'HF_TOKEN': os.getenv('HF_TOKEN', ''),
            'PORT': str(port),
            'HOST': '0.0.0.0',
            'MODEL_NAME': model_config.get('model_name', model_name)
        })

        # Use venv python if it exists
        script_dir = script_path.parent
        venv_python = script_dir / 'venv' / 'bin' / 'python3'
        python_cmd = str(venv_python) if venv_python.exists() else 'python3'

        proc = subprocess.Popen(
            [python_cmd, str(script_path)],
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            preexec_fn=os.setsid  # Create new process group
        )

        model_processes[model_name] = proc

        # Wait for service to be ready
        max_wait = model_config.get('startup_time_seconds', 120)
        start_time = time.time()

        while time.time() - start_time < max_wait:
            if proc.poll() is not None:
                logger.error(f"Process for {model_name} exited prematurely")
                return False

            try:
                async with httpx.AsyncClient() as client:
                    response = await client.get(
                        f"http://localhost:{port}/health",
                        timeout=5.0
                    )
                    if response.status_code == 200:
                        logger.info(f"Model {model_name} is ready on port {port}")
                        current_model = model_name
                        return True
            except:
                await asyncio.sleep(2)

        logger.error(f"Model {model_name} failed to start within {max_wait}s")
        await stop_model_process(model_name)
        return False

    except Exception as e:
        logger.error(f"Error starting {model_name}: {e}")
        return False


async def stop_model_process(model_name: str):
    """Stop a running model process"""
    global model_processes, current_model

    if model_name not in model_processes:
        logger.warning(f"Model {model_name} not in process registry")
        return

    proc = model_processes[model_name]

    if proc.poll() is None:  # Still running
        logger.info(f"Stopping model {model_name}...")
        try:
            # Send SIGTERM to process group
            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)

            # Wait for graceful shutdown
            try:
                proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                # Force kill if not terminated
                os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
                proc.wait()

            logger.info(f"Model {model_name} stopped")
        except Exception as e:
            logger.error(f"Error stopping {model_name}: {e}")

    del model_processes[model_name]
    if current_model == model_name:
        current_model = None


def get_model_for_endpoint(endpoint: str) -> Optional[str]:
    """Determine which model handles this endpoint"""
    for model_name, model_config in model_registry.items():
        if endpoint.startswith(model_config.get('endpoint', '')):
            return model_name
    return None


@app.on_event("startup")
async def startup_event():
    """Initialize on startup"""
    logger.info("Starting AI Model Orchestrator (Process-Based)")
    load_model_registry()


@app.on_event("shutdown")
async def shutdown_event():
    """Cleanup on shutdown"""
    logger.info("Shutting down orchestrator...")
    for model_name in list(model_processes.keys()):
        await stop_model_process(model_name)


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "current_model": current_model,
        "active_processes": len(model_processes),
        "available_models": list(model_registry.keys())
    }


@app.get("/v1/models")
async def list_models_openai():
    """OpenAI-compatible models listing endpoint"""
    models_list = []
    for model_name, model_info in model_registry.items():
        models_list.append({
            "id": model_name,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "pivoine-gpu",
            "permission": [],
            "root": model_name,
            "parent": None,
        })

    return {
        "object": "list",
        "data": models_list
    }


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
async def proxy_request(request: Request, path: str):
    """Proxy requests to appropriate model service"""
    endpoint = f"/{path}"

    # Determine which model should handle this
    target_model = get_model_for_endpoint(endpoint)

    if not target_model:
        raise HTTPException(status_code=404, detail=f"No model configured for endpoint: {endpoint}")

    # Ensure model is running
    if current_model != target_model:
        logger.info(f"Switching to model {target_model}")
        success = await start_model_process(target_model)
        if not success:
            raise HTTPException(status_code=503, detail=f"Failed to start model {target_model}")

    # Proxy the request
    model_config = model_registry[target_model]
    target_url = f"http://localhost:{model_config['port']}/{path}"

    # Get request details
    method = request.method
    headers = dict(request.headers)
    headers.pop('host', None)  # Remove host header
    body = await request.body()

    # Check if this is a streaming request
    is_streaming = False
    if method == "POST" and body:
        try:
            import json
            body_json = json.loads(body)
            is_streaming = body_json.get('stream', False)
        except:
            pass

    logger.info(f"Proxying {method} request to {target_url} (streaming: {is_streaming})")

    try:
        if is_streaming:
            # For streaming requests, use httpx streaming and yield chunks
            async def stream_response():
                async with httpx.AsyncClient(timeout=300.0) as client:
                    async with client.stream(method, target_url, content=body, headers=headers) as response:
                        async for chunk in response.aiter_bytes():
                            yield chunk

            return StreamingResponse(
                stream_response(),
                media_type="text/event-stream",
                headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
            )
        else:
            # For non-streaming requests, use the original behavior
            async with httpx.AsyncClient(timeout=300.0) as client:
                response = await client.request(
                    method=method,
                    url=target_url,
                    headers=headers,
                    content=body
                )

            return JSONResponse(
                content=response.json() if response.headers.get('content-type') == 'application/json' else response.text,
                status_code=response.status_code
            )
    except Exception as e:
        logger.error(f"Error proxying request: {e}")
        raise HTTPException(status_code=502, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("PORT", "9000"))
    host = os.getenv("HOST", "0.0.0.0")

    logger.info(f"Starting orchestrator on {host}:{port}")
    uvicorn.run(app, host=host, port=port, log_level="info")
@@ -1,7 +0,0 @@
fastapi==0.104.1
uvicorn[standard]==0.24.0
httpx==0.25.1
docker==6.1.3
pyyaml==6.0.1
pydantic==2.5.0
psutil==5.9.6