feat: implement Ansible-based process architecture for RunPod
Major architecture overhaul to address RunPod Docker limitations.

Core Infrastructure:
- Add base_service.py: abstract base class for all AI services
- Add service_manager.py: process lifecycle management
- Add core/requirements.txt: core dependencies

Model Services (standalone Python):
- Add models/vllm/server.py: Qwen 2.5 7B text generation
- Add models/flux/server.py: Flux.1 Schnell image generation
- Add models/musicgen/server.py: MusicGen Medium music generation
- Each service inherits from the GPUService base class
- OpenAI-compatible APIs
- Standalone execution support

Ansible Deployment:
- Add playbook.yml: comprehensive deployment automation
- Add ansible.cfg: Ansible configuration
- Add inventory.yml: localhost inventory
- Tags: base, python, dependencies, models, tailscale, validate, cleanup

Scripts:
- Add scripts/install.sh: full installation wrapper
- Add scripts/download-models.sh: model download wrapper
- Add scripts/start-all.sh: start orchestrator
- Add scripts/stop-all.sh: stop all services

Documentation:
- Update ARCHITECTURE.md: document the distributed VPS+GPU architecture

Benefits:
- No Docker: avoids RunPod CAP_SYS_ADMIN limitations
- Fully reproducible via Ansible
- Extensible: add new models in three steps
- Direct Python execution (no container overhead)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
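All three model servers in this diff import GPUService from core/base_service.py, which is not part of this hunk. For orientation only, a minimal sketch of the interface they appear to rely on (a name/port constructor, self.app, self.logger, async initialize/cleanup hooks, a create_app route hook, and a blocking run) might look like the code below. Every name here is inferred from the call sites, not copied from the actual file, and the real class presumably does more (health endpoints, GPU checks, signal handling).

#!/usr/bin/env python3
"""Minimal sketch of the GPUService interface assumed by the model servers (inferred, not the real file)."""

import logging
from contextlib import asynccontextmanager

import uvicorn
from fastapi import FastAPI


class GPUService:
    """Base class: owns the FastAPI app, logger, and service lifecycle."""

    def __init__(self, name: str, port: int):
        self.name = name
        self.port = port
        self.logger = logging.getLogger(name)
        logging.basicConfig(level=logging.INFO)

        @asynccontextmanager
        async def lifespan(app: FastAPI):
            await self.initialize()   # subclasses load model weights here
            yield
            await self.cleanup()      # subclasses release GPU memory here

        self.app = FastAPI(title=name, lifespan=lifespan)
        self.create_app()             # subclasses register their routes

    async def initialize(self):
        self.logger.info("Initializing %s", self.name)

    async def cleanup(self):
        self.logger.info("Cleaning up %s", self.name)

    def create_app(self):
        """Subclasses override this to add routes to self.app."""

    def run(self):
        """Blocking entry point used by `python server.py`."""
        uvicorn.run(self.app, host="0.0.0.0", port=self.port)

With a base like this, each model server only has to load its model in initialize() and register routes in create_app(), which matches how the three servers below are structured.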
models/flux/requirements.txt  (new file, 21 lines)
@@ -0,0 +1,21 @@
# Flux.1 Image Generation Service Dependencies

# Diffusers library (for Flux.1 pipeline)
diffusers==0.30.0

# PyTorch (required by diffusers)
torch==2.1.0
torchvision==0.16.0

# Transformers (for model components)
transformers==4.36.0

# Image processing
Pillow==10.1.0

# Accelerate (for optimizations)
accelerate==0.25.0

# Additional dependencies for Flux
sentencepiece==0.1.99
protobuf==4.25.1
models/flux/server.py  (new file, 193 lines)
@@ -0,0 +1,193 @@
#!/usr/bin/env python3
"""
Flux.1 Image Generation Service

OpenAI-compatible image generation using Flux.1 Schnell model.
Provides /v1/images/generations endpoint.
"""

import base64
import io
import os
from typing import Optional

import torch
from diffusers import FluxPipeline
from fastapi import HTTPException
from PIL import Image
from pydantic import BaseModel, Field

# Import base service class
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))
from core.base_service import GPUService


class ImageGenerationRequest(BaseModel):
    """Image generation request (OpenAI-compatible)"""
    model: str = Field(default="flux-schnell", description="Model name")
    prompt: str = Field(..., description="Text description of the image to generate")
    n: int = Field(default=1, ge=1, le=4, description="Number of images to generate")
    size: str = Field(default="1024x1024", description="Image size (e.g., 512x512, 1024x1024)")
    response_format: str = Field(default="b64_json", description="Response format: url or b64_json")
    quality: str = Field(default="standard", description="Image quality: standard or hd")
    style: str = Field(default="natural", description="Image style: natural or vivid")


class ImageGenerationResponse(BaseModel):
    """Image generation response (OpenAI-compatible)"""
    created: int = Field(..., description="Unix timestamp")
    data: list = Field(..., description="List of generated images")


class FluxService(GPUService):
    """Flux.1 Schnell image generation service"""

    def __init__(self):
        # Get port from environment or use default
        port = int(os.getenv("PORT", "8002"))
        super().__init__(name="flux-schnell", port=port)

        # Service-specific attributes
        self.pipeline: Optional[FluxPipeline] = None
        self.model_name = os.getenv("MODEL_NAME", "black-forest-labs/FLUX.1-schnell")

    async def initialize(self):
        """Initialize Flux.1 pipeline"""
        await super().initialize()

        self.logger.info(f"Loading Flux.1 pipeline: {self.model_name}")

        # Load pipeline
        self.pipeline = FluxPipeline.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            cache_dir=os.getenv("HF_CACHE_DIR", "/workspace/huggingface_cache")
        )

        # Move to GPU
        if torch.cuda.is_available():
            self.pipeline = self.pipeline.to("cuda")
            self.logger.info("Flux.1 pipeline loaded on GPU")
        else:
            self.logger.warning("GPU not available, running on CPU (very slow)")

        # Enable memory optimizations
        if hasattr(self.pipeline, 'enable_model_cpu_offload'):
            # This moves models to GPU only when needed, saving VRAM
            self.pipeline.enable_model_cpu_offload()

        self.logger.info("Flux.1 pipeline initialized successfully")

    async def cleanup(self):
        """Cleanup resources"""
        await super().cleanup()
        if self.pipeline:
            self.logger.info("Flux.1 pipeline cleanup")
            self.pipeline = None

    def parse_size(self, size_str: str) -> tuple[int, int]:
        """Parse size string like '1024x1024' into (width, height)"""
        try:
            parts = size_str.lower().split('x')
            if len(parts) != 2:
                return (1024, 1024)
            width = int(parts[0])
            height = int(parts[1])
            return (width, height)
        except:
            return (1024, 1024)

    def image_to_base64(self, image: Image.Image) -> str:
        """Convert PIL Image to base64 string"""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_bytes = buffered.getvalue()
        return base64.b64encode(img_bytes).decode('utf-8')

    def create_app(self):
        """Create FastAPI routes"""

        @self.app.get("/")
        async def root():
            """Root endpoint"""
            return {
                "service": "Flux.1 Schnell Image Generation",
                "model": self.model_name,
                "max_images": 4
            }

        @self.app.get("/v1/models")
        async def list_models():
            """List available models (OpenAI-compatible)"""
            return {
                "object": "list",
                "data": [
                    {
                        "id": "flux-schnell",
                        "object": "model",
                        "created": 1234567890,
                        "owned_by": "black-forest-labs",
                        "permission": [],
                        "root": self.model_name,
                        "parent": None,
                    }
                ]
            }

        @self.app.post("/v1/images/generations")
        async def generate_image(request: ImageGenerationRequest) -> ImageGenerationResponse:
            """Generate images from text prompt (OpenAI-compatible)"""
            if not self.pipeline:
                raise HTTPException(status_code=503, detail="Model not initialized")

            self.logger.info(f"Generating {request.n} image(s): {request.prompt[:100]}...")

            try:
                # Parse image size
                width, height = self.parse_size(request.size)
                self.logger.info(f"Size: {width}x{height}")

                # Generate images
                images = []
                for i in range(request.n):
                    self.logger.info(f"Generating image {i+1}/{request.n}")

                    # Flux.1 Schnell uses 4 inference steps for speed
                    image = self.pipeline(
                        prompt=request.prompt,
                        width=width,
                        height=height,
                        num_inference_steps=4,  # Schnell is optimized for 4 steps
                        guidance_scale=0.0,  # Schnell doesn't use guidance
                    ).images[0]

                    # Convert to base64
                    if request.response_format == "b64_json":
                        image_data = {
                            "b64_json": self.image_to_base64(image)
                        }
                    else:
                        # For URL format, we'd need to save and serve the file
                        # For now, we'll return base64 anyway
                        image_data = {
                            "b64_json": self.image_to_base64(image)
                        }

                    images.append(image_data)

                self.logger.info(f"Generated {request.n} image(s) successfully")

                return ImageGenerationResponse(
                    created=1234567890,
                    data=images
                )

            except Exception as e:
                self.logger.error(f"Error generating image: {e}", exc_info=True)
                raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    service = FluxService()
    service.run()
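Once the service is running, the /v1/images/generations endpoint can be exercised with any HTTP client. The sketch below is illustrative only: it assumes the default port 8002 from FluxService, a locally reachable host, and the `requests` package, and it reads the fields defined by ImageGenerationResponse above.

import base64
import requests  # assumption: the `requests` package is available

# Host and prompt are examples; 8002 is the service's default PORT.
resp = requests.post(
    "http://localhost:8002/v1/images/generations",
    json={"prompt": "a watercolor painting of a peony", "n": 1, "size": "512x512"},
    timeout=300,
)
resp.raise_for_status()
b64_png = resp.json()["data"][0]["b64_json"]

# The service returns base64-encoded PNG bytes.
with open("peony.png", "wb") as f:
    f.write(base64.b64decode(b64_png))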
models/musicgen/requirements.txt  (new file, 11 lines)
@@ -0,0 +1,11 @@
# MusicGen Music Generation Service Dependencies

# AudioCraft (contains MusicGen)
audiocraft==1.3.0

# PyTorch (required by AudioCraft)
torch==2.1.0
torchaudio==2.1.0

# Additional dependencies
transformers==4.36.0
models/musicgen/server.py  (new file, 172 lines)
@@ -0,0 +1,172 @@
#!/usr/bin/env python3
"""
MusicGen Music Generation Service

OpenAI-compatible music generation using Meta's MusicGen Medium model.
Provides /v1/audio/generations endpoint.
"""

import base64
import io
import os
import tempfile
from typing import Optional

import torch
import torchaudio
from audiocraft.models import MusicGen
from fastapi import HTTPException
from pydantic import BaseModel, Field

# Import base service class
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))
from core.base_service import GPUService


class AudioGenerationRequest(BaseModel):
    """Music generation request"""
    model: str = Field(default="musicgen-medium", description="Model name")
    prompt: str = Field(..., description="Text description of the music to generate")
    duration: float = Field(default=30.0, ge=1.0, le=30.0, description="Duration in seconds")
    temperature: float = Field(default=1.0, ge=0.1, le=2.0, description="Sampling temperature")
    top_k: int = Field(default=250, ge=0, le=500, description="Top-k sampling")
    top_p: float = Field(default=0.0, ge=0.0, le=1.0, description="Top-p (nucleus) sampling")
    cfg_coef: float = Field(default=3.0, ge=1.0, le=15.0, description="Classifier-free guidance coefficient")
    response_format: str = Field(default="wav", description="Audio format (wav or mp3)")


class AudioGenerationResponse(BaseModel):
    """Music generation response"""
    audio: str = Field(..., description="Base64-encoded audio data")
    format: str = Field(..., description="Audio format (wav or mp3)")
    duration: float = Field(..., description="Duration in seconds")
    sample_rate: int = Field(..., description="Sample rate in Hz")


class MusicGenService(GPUService):
    """MusicGen music generation service"""

    def __init__(self):
        # Get port from environment or use default
        port = int(os.getenv("PORT", "8003"))
        super().__init__(name="musicgen-medium", port=port)

        # Service-specific attributes
        self.model: Optional[MusicGen] = None
        self.model_name = os.getenv("MODEL_NAME", "facebook/musicgen-medium")

    async def initialize(self):
        """Initialize MusicGen model"""
        await super().initialize()

        self.logger.info(f"Loading MusicGen model: {self.model_name}")

        # Load model
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = MusicGen.get_pretrained(self.model_name, device=device)

        self.logger.info("MusicGen model loaded successfully")
        self.logger.info(f"Max duration: 30 seconds at {self.model.sample_rate}Hz")

    async def cleanup(self):
        """Cleanup resources"""
        await super().cleanup()
        if self.model:
            self.logger.info("MusicGen model cleanup")
            self.model = None

    def create_app(self):
        """Create FastAPI routes"""

        @self.app.get("/")
        async def root():
            """Root endpoint"""
            return {
                "service": "MusicGen API Server",
                "model": self.model_name,
                "max_duration": 30.0,
                "sample_rate": self.model.sample_rate if self.model else 32000
            }

        @self.app.get("/v1/models")
        async def list_models():
            """List available models (OpenAI-compatible)"""
            return {
                "object": "list",
                "data": [
                    {
                        "id": "musicgen-medium",
                        "object": "model",
                        "created": 1234567890,
                        "owned_by": "meta",
                        "permission": [],
                        "root": self.model_name,
                        "parent": None,
                    }
                ]
            }

        @self.app.post("/v1/audio/generations")
        async def generate_audio(request: AudioGenerationRequest) -> AudioGenerationResponse:
            """Generate music from text prompt"""
            if not self.model:
                raise HTTPException(status_code=503, detail="Model not initialized")

            self.logger.info(f"Generating music: {request.prompt[:100]}...")
            self.logger.info(f"Duration: {request.duration}s, Temperature: {request.temperature}")

            try:
                # Set generation parameters
                self.model.set_generation_params(
                    duration=request.duration,
                    temperature=request.temperature,
                    top_k=request.top_k,
                    top_p=request.top_p,
                    cfg_coef=request.cfg_coef,
                )

                # Generate audio
                descriptions = [request.prompt]
                with torch.no_grad():
                    wav = self.model.generate(descriptions)

                # wav shape: [batch_size, channels, samples]
                # Extract first batch item
                audio_data = wav[0].cpu()  # [channels, samples]

                # Get sample rate
                sample_rate = self.model.sample_rate

                # Save to temporary file
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                    temp_path = temp_file.name
                    torchaudio.save(temp_path, audio_data, sample_rate)

                # Read audio file and encode to base64
                with open(temp_path, 'rb') as f:
                    audio_bytes = f.read()

                # Clean up temporary file
                os.unlink(temp_path)

                # Encode to base64
                audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')

                self.logger.info(f"Generated {request.duration}s of audio")

                return AudioGenerationResponse(
                    audio=audio_base64,
                    format="wav",
                    duration=request.duration,
                    sample_rate=sample_rate
                )

            except Exception as e:
                self.logger.error(f"Error generating audio: {e}")
                raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    service = MusicGenService()
    service.run()
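As with the image service, the audio endpoint returns base64 data in a JSON body. A hypothetical call, assuming the default port 8003, a locally reachable host, and the `requests` package:

import base64
import requests  # assumption: the `requests` package is available

# Host and prompt are examples; 8003 is the service's default PORT.
resp = requests.post(
    "http://localhost:8003/v1/audio/generations",
    json={"prompt": "lo-fi hip hop with warm piano", "duration": 10.0},
    timeout=600,
)
resp.raise_for_status()
body = resp.json()

# The response carries base64-encoded WAV bytes plus metadata.
with open("clip.wav", "wb") as f:
    f.write(base64.b64decode(body["audio"]))
print(body["format"], body["duration"], body["sample_rate"])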
models/vllm/requirements.txt  (new file, 13 lines)
@@ -0,0 +1,13 @@
# vLLM Text Generation Service Dependencies

# vLLM engine
vllm==0.6.4.post1

# PyTorch (required by vLLM)
torch==2.1.0

# Transformers (for model loading)
transformers==4.36.0

# Additional dependencies
accelerate==0.25.0
models/vllm/server.py  (new file, 297 lines)
@@ -0,0 +1,297 @@
#!/usr/bin/env python3
"""
vLLM Text Generation Service

OpenAI-compatible text generation using vLLM and Qwen 2.5 7B Instruct model.
Provides /v1/completions and /v1/chat/completions endpoints.
"""

import asyncio
import json
import os
from typing import AsyncIterator, Dict, List, Optional

from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
from vllm.utils import random_uuid

# Import base service class
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))
from core.base_service import GPUService


# Request/Response models
class CompletionRequest(BaseModel):
    """OpenAI-compatible completion request"""
    model: str = Field(default="qwen-2.5-7b")
    prompt: str | List[str] = Field(..., description="Text prompt(s)")
    max_tokens: int = Field(default=512, ge=1, le=4096)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=1.0, ge=0.0, le=1.0)
    n: int = Field(default=1, ge=1, le=10)
    stream: bool = Field(default=False)
    stop: Optional[str | List[str]] = None
    presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
    frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)


class ChatMessage(BaseModel):
    """Chat message format"""
    role: str = Field(..., description="Role: system, user, or assistant")
    content: str = Field(..., description="Message content")


class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible chat completion request"""
    model: str = Field(default="qwen-2.5-7b")
    messages: List[ChatMessage] = Field(..., description="Chat messages")
    max_tokens: int = Field(default=512, ge=1, le=4096)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=1.0, ge=0.0, le=1.0)
    n: int = Field(default=1, ge=1, le=10)
    stream: bool = Field(default=False)
    stop: Optional[str | List[str]] = None


class VLLMService(GPUService):
    """vLLM text generation service"""

    def __init__(self):
        # Get port from environment or use default
        port = int(os.getenv("PORT", "8001"))
        super().__init__(name="vllm-qwen", port=port)

        # Service-specific attributes
        self.engine: Optional[AsyncLLMEngine] = None
        self.model_name = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")

    async def initialize(self):
        """Initialize vLLM engine"""
        await super().initialize()

        self.logger.info(f"Initializing vLLM AsyncLLMEngine with model: {self.model_name}")

        # Configure engine
        engine_args = AsyncEngineArgs(
            model=self.model_name,
            tensor_parallel_size=1,  # Single GPU
            gpu_memory_utilization=0.85,  # Use 85% of GPU memory
            max_model_len=4096,  # Context length
            dtype="auto",  # Auto-detect dtype
            download_dir=os.getenv("HF_CACHE_DIR", "/workspace/huggingface_cache"),
            trust_remote_code=True,  # Some models require this
            enforce_eager=False,  # Use CUDA graphs for better performance
        )

        # Create async engine
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

        self.logger.info("vLLM AsyncLLMEngine initialized successfully")

    async def cleanup(self):
        """Cleanup resources"""
        await super().cleanup()
        if self.engine:
            # vLLM doesn't have an explicit shutdown method
            self.logger.info("vLLM engine cleanup")
            self.engine = None

    def messages_to_prompt(self, messages: List[ChatMessage]) -> str:
        """Convert chat messages to Qwen 2.5 prompt format"""
        prompt_parts = []

        for msg in messages:
            role = msg.role
            content = msg.content

            if role == "system":
                prompt_parts.append(f"<|im_start|>system\n{content}<|im_end|>")
            elif role == "user":
                prompt_parts.append(f"<|im_start|>user\n{content}<|im_end|>")
            elif role == "assistant":
                prompt_parts.append(f"<|im_start|>assistant\n{content}<|im_end|>")

        # Add final assistant prompt
        prompt_parts.append("<|im_start|>assistant\n")

        return "\n".join(prompt_parts)

    def create_app(self):
        """Create FastAPI routes"""

        @self.app.get("/")
        async def root():
            """Root endpoint"""
            return {"status": "ok", "model": self.model_name}

        @self.app.get("/v1/models")
        async def list_models():
            """OpenAI-compatible models endpoint"""
            return {
                "object": "list",
                "data": [
                    {
                        "id": "qwen-2.5-7b",
                        "object": "model",
                        "created": 1234567890,
                        "owned_by": "pivoine-gpu",
                        "permission": [],
                        "root": self.model_name,
                        "parent": None,
                    }
                ]
            }

        @self.app.post("/v1/completions")
        async def create_completion(request: CompletionRequest):
            """OpenAI-compatible completion endpoint"""
            if not self.engine:
                return JSONResponse(
                    status_code=503,
                    content={"error": "Engine not initialized"}
                )

            # Handle both single prompt and batch prompts
            prompts = [request.prompt] if isinstance(request.prompt, str) else request.prompt

            # Configure sampling parameters
            sampling_params = SamplingParams(
                temperature=request.temperature,
                top_p=request.top_p,
                max_tokens=request.max_tokens,
                n=request.n,
                stop=request.stop if request.stop else [],
                presence_penalty=request.presence_penalty,
                frequency_penalty=request.frequency_penalty,
            )

            # Generate completions
            results = []
            for prompt in prompts:
                request_id = random_uuid()

                if request.stream:
                    # Streaming response
                    async def generate_stream():
                        async for output in self.engine.generate(prompt, sampling_params, request_id):
                            chunk = {
                                "id": request_id,
                                "object": "text_completion",
                                "created": 1234567890,
                                "model": request.model,
                                "choices": [
                                    {
                                        "text": output.outputs[0].text,
                                        "index": 0,
                                        "logprobs": None,
                                        "finish_reason": output.outputs[0].finish_reason,
                                    }
                                ]
                            }
                            yield f"data: {json.dumps(chunk)}\n\n"
                        yield "data: [DONE]\n\n"

                    return StreamingResponse(generate_stream(), media_type="text/event-stream")
                else:
                    # Non-streaming response
                    async for output in self.engine.generate(prompt, sampling_params, request_id):
                        final_output = output

                    results.append({
                        "text": final_output.outputs[0].text,
                        "index": len(results),
                        "logprobs": None,
                        "finish_reason": final_output.outputs[0].finish_reason,
                    })

            return {
                "id": random_uuid(),
                "object": "text_completion",
                "created": 1234567890,
                "model": request.model,
                "choices": results,
                "usage": {
                    "prompt_tokens": 0,  # vLLM doesn't expose this easily
                    "completion_tokens": 0,
                    "total_tokens": 0,
                }
            }

        @self.app.post("/v1/chat/completions")
        async def create_chat_completion(request: ChatCompletionRequest):
            """OpenAI-compatible chat completion endpoint"""
            if not self.engine:
                return JSONResponse(
                    status_code=503,
                    content={"error": "Engine not initialized"}
                )

            # Convert messages to prompt
            prompt = self.messages_to_prompt(request.messages)

            # Configure sampling parameters
            sampling_params = SamplingParams(
                temperature=request.temperature,
                top_p=request.top_p,
                max_tokens=request.max_tokens,
                n=request.n,
                stop=request.stop if request.stop else ["<|im_end|>"],
            )

            request_id = random_uuid()

            if request.stream:
                # Streaming response
                async def generate_stream():
                    async for output in self.engine.generate(prompt, sampling_params, request_id):
                        chunk = {
                            "id": request_id,
                            "object": "chat.completion.chunk",
                            "created": 1234567890,
                            "model": request.model,
                            "choices": [
                                {
                                    "index": 0,
                                    "delta": {"content": output.outputs[0].text},
                                    "finish_reason": output.outputs[0].finish_reason,
                                }
                            ]
                        }
                        yield f"data: {json.dumps(chunk)}\n\n"
                    yield "data: [DONE]\n\n"

                return StreamingResponse(generate_stream(), media_type="text/event-stream")
            else:
                # Non-streaming response
                async for output in self.engine.generate(prompt, sampling_params, request_id):
                    final_output = output

                return {
                    "id": request_id,
                    "object": "chat.completion",
                    "created": 1234567890,
                    "model": request.model,
                    "choices": [
                        {
                            "index": 0,
                            "message": {
                                "role": "assistant",
                                "content": final_output.outputs[0].text,
                            },
                            "finish_reason": final_output.outputs[0].finish_reason,
                        }
                    ],
                    "usage": {
                        "prompt_tokens": 0,
                        "completion_tokens": 0,
                        "total_tokens": 0,
                    }
                }


if __name__ == "__main__":
    service = VLLMService()
    service.run()
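Because the chat endpoint follows the OpenAI request/response shape, a plain HTTP POST is enough to try it. The example below is a sketch that assumes the default port 8001 from VLLMService, a locally reachable host, and the `requests` package; the message content is illustrative.

import requests  # assumption: the `requests` package is available

# Host is an example; 8001 is the service's default PORT.
resp = requests.post(
    "http://localhost:8001/v1/chat/completions",
    json={
        "messages": [
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "Name three uses for Ansible tags."},
        ],
        "max_tokens": 128,
        "temperature": 0.7,
    },
    timeout=120,
)
resp.raise_for_status()
# The assistant's reply follows the OpenAI chat completion layout.
print(resp.json()["choices"][0]["message"]["content"])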