#!/usr/bin/env python3
"""
Base Service Class for AI Model Services

Provides common functionality for all model services:
- Health check endpoint
- Graceful shutdown handling
- Logging configuration
- Standard FastAPI setup
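
Typical usage (illustrative sketch; MyService and its route are placeholders,
not part of this module):

    class MyService(BaseService):
        def create_app(self):
            ...  # register routes on self.app

    MyService(name="my-service", port=8001).run()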
"""

import asyncio
import logging
import os
import signal
import sys
from abc import ABC, abstractmethod
from typing import Optional

from fastapi import FastAPI
import uvicorn


class BaseService(ABC):
    """Abstract base class for all AI model services"""

    def __init__(self, name: str, port: int, host: str = "0.0.0.0"):
        """
        Initialize base service

        Args:
            name: Service name (for logging)
            port: Port to run service on
            host: Host to bind to (default: 0.0.0.0)
        """
        self.name = name
        self.port = port
        self.host = host
        self.app = FastAPI(title=f"{name} Service", version="1.0.0")
        self.logger = self._setup_logging()
        self.shutdown_event = asyncio.Event()

        # Register standard endpoints
        self._register_health_endpoint()

        # Register signal handlers for graceful shutdown
        self._register_signal_handlers()

        # Allow subclasses to add custom routes
        self.create_app()

    def _setup_logging(self) -> logging.Logger:
        """Configure logging for the service"""
        logging.basicConfig(
            level=logging.INFO,
            format=f'%(asctime)s - {self.name} - %(levelname)s - %(message)s',
            handlers=[
                logging.StreamHandler(sys.stdout)
            ]
        )
        return logging.getLogger(self.name)

    def _register_health_endpoint(self):
        """Register standard health check endpoint"""
        @self.app.get("/health")
        async def health_check():
            """Health check endpoint"""
            return {
                "status": "healthy",
                "service": self.name,
                "port": self.port
            }

    def _register_signal_handlers(self):
        """Register signal handlers for graceful shutdown"""
        def signal_handler(sig, frame):
            self.logger.info(f"Received signal {sig}, initiating graceful shutdown...")
            self.shutdown_event.set()

        # Note: while uvicorn is serving in the main thread it typically installs
        # its own SIGINT/SIGTERM handlers, which can take precedence over these;
        # shutdown_event remains available for subclasses that wait on it.
        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)

    @abstractmethod
    def create_app(self):
        """
        Create FastAPI routes for this service.
        Subclasses must implement this to add their specific endpoints.

        Example:
            @self.app.post("/v1/generate")
            async def generate(request: MyRequest):
                return await self.model.generate(request)
        """
        pass

    async def initialize(self):
        """
        Initialize the service (load models, etc.).
        Subclasses can override this for custom initialization.
        """
        self.logger.info(f"Initializing {self.name} service...")

    async def cleanup(self):
        """
        Clean up resources on shutdown.
        Subclasses can override this for custom cleanup.
        """
        self.logger.info(f"Cleaning up {self.name} service...")

    def run(self):
        """
        Run the service.
        This is the main entry point that starts the FastAPI server.
        """
        try:
            self.logger.info(f"Starting {self.name} service on {self.host}:{self.port}")

            # Run initialization.
            # Note: initialize(), serve(), and cleanup() each execute under their
            # own asyncio.run() call, i.e. in separate event loops.
            asyncio.run(self.initialize())

            # Start uvicorn server
            config = uvicorn.Config(
                app=self.app,
                host=self.host,
                port=self.port,
                log_level="info",
                access_log=True
            )
            server = uvicorn.Server(config)

            # Run server
            asyncio.run(server.serve())

        except KeyboardInterrupt:
            self.logger.info("Keyboard interrupt received")
        except Exception as e:
            self.logger.error(f"Error running service: {e}", exc_info=True)
            sys.exit(1)
        finally:
            # Cleanup
            asyncio.run(self.cleanup())
            self.logger.info(f"{self.name} service stopped")


class GPUService(BaseService):
    """
    Base class for GPU-accelerated services.
    Provides additional GPU-specific functionality.
    """

    def __init__(self, name: str, port: int, host: str = "0.0.0.0"):
        super().__init__(name, port, host)
        self._check_gpu_availability()

    def _check_gpu_availability(self):
        """Check if GPU is available"""
        try:
            import torch
            if torch.cuda.is_available():
                gpu_count = torch.cuda.device_count()
                gpu_name = torch.cuda.get_device_name(0)
                self.logger.info(f"GPU available: {gpu_name} (count: {gpu_count})")
            else:
                self.logger.warning("No GPU available - service may run slowly")
        except ImportError:
            self.logger.warning("PyTorch not installed - cannot check GPU availability")
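

# --- Illustrative usage sketch (not part of the original module) ---
# The subclass, route, and port below are placeholders chosen for the example;
# a real service would register its own routes in create_app() and load its
# model in initialize().
if __name__ == "__main__":
    class EchoService(BaseService):
        """Minimal example service that echoes request payloads."""

        def create_app(self):
            @self.app.post("/v1/echo")
            async def echo(payload: dict):
                return {"service": self.name, "echo": payload}

        async def initialize(self):
            await super().initialize()
            self.logger.info("EchoService has no model to load")

    EchoService(name="echo", port=8001).run()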