#!/usr/bin/env python3
"""
vLLM Text Generation Service
OpenAI-compatible text generation using vLLM and the Qwen 2.5 7B Instruct model.
Provides /v1/completions and /v1/chat/completions endpoints.
"""
import asyncio
import json
import os
from typing import AsyncIterator, Dict, List, Optional
from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
from vllm.utils import random_uuid
# Import base service class
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))
from core.base_service import GPUService


# Request/Response models
class CompletionRequest(BaseModel):
    """OpenAI-compatible completion request"""
    model: str = Field(default="qwen-2.5-7b")
    prompt: str | List[str] = Field(..., description="Text prompt(s)")
    max_tokens: int = Field(default=512, ge=1, le=4096)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=1.0, ge=0.0, le=1.0)
    n: int = Field(default=1, ge=1, le=10)
    stream: bool = Field(default=False)
    stop: Optional[str | List[str]] = None
    presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
    frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
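# Illustrative /v1/completions request body matching CompletionRequest above
# (only "prompt" is required, the other fields fall back to their defaults):
#   {"prompt": "Once upon a time", "max_tokens": 128, "temperature": 0.7, "stream": false}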


class ChatMessage(BaseModel):
    """Chat message format"""
    role: str = Field(..., description="Role: system, user, or assistant")
    content: str = Field(..., description="Message content")


class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible chat completion request"""
    model: str = Field(default="qwen-2.5-7b")
    messages: List[ChatMessage] = Field(..., description="Chat messages")
    max_tokens: int = Field(default=512, ge=1, le=4096)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=1.0, ge=0.0, le=1.0)
    n: int = Field(default=1, ge=1, le=10)
    stream: bool = Field(default=False)
    stop: Optional[str | List[str]] = None
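# Illustrative /v1/chat/completions request body matching ChatCompletionRequest above:
#   {"messages": [{"role": "system", "content": "You are helpful."},
#                 {"role": "user", "content": "Hello!"}],
#    "max_tokens": 128}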


class VLLMService(GPUService):
    """vLLM text generation service"""

    def __init__(self):
        # Get port from environment or use default
        port = int(os.getenv("PORT", "8001"))
        super().__init__(name="vllm-qwen", port=port)
        # Service-specific attributes
        self.engine: Optional[AsyncLLMEngine] = None
        self.model_name = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")

    async def initialize(self):
        """Initialize vLLM engine"""
        await super().initialize()
        self.logger.info(f"Initializing vLLM AsyncLLMEngine with model: {self.model_name}")
        # Configure engine
        engine_args = AsyncEngineArgs(
            model=self.model_name,
            tensor_parallel_size=1,  # Single GPU
            gpu_memory_utilization=0.85,  # Use 85% of GPU memory
            max_model_len=4096,  # Context length
            dtype="auto",  # Auto-detect dtype
            download_dir=os.getenv("HF_CACHE_DIR", "/workspace/huggingface_cache"),
            trust_remote_code=True,  # Some models require this
            enforce_eager=False,  # Use CUDA graphs for better performance
        )
        # Create async engine
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)
        self.logger.info("vLLM AsyncLLMEngine initialized successfully")

    async def cleanup(self):
        """Cleanup resources"""
        await super().cleanup()
        if self.engine:
            # vLLM doesn't have an explicit shutdown method
            self.logger.info("vLLM engine cleanup")
            self.engine = None

    def messages_to_prompt(self, messages: List[ChatMessage]) -> str:
        """Convert chat messages to Qwen 2.5 prompt format"""
        prompt_parts = []
        for msg in messages:
            role = msg.role
            content = msg.content
            if role == "system":
                prompt_parts.append(f"<|im_start|>system\n{content}<|im_end|>")
            elif role == "user":
                prompt_parts.append(f"<|im_start|>user\n{content}<|im_end|>")
            elif role == "assistant":
                prompt_parts.append(f"<|im_start|>assistant\n{content}<|im_end|>")
        # Add final assistant prompt
        prompt_parts.append("<|im_start|>assistant\n")
        return "\n".join(prompt_parts)

    def create_app(self):
        """Create FastAPI routes"""

        @self.app.get("/")
        async def root():
            """Root endpoint"""
            return {"status": "ok", "model": self.model_name}

        @self.app.get("/v1/models")
        async def list_models():
            """OpenAI-compatible models endpoint"""
            return {
                "object": "list",
                "data": [
                    {
                        "id": "qwen-2.5-7b",
                        "object": "model",
                        "created": 1234567890,
                        "owned_by": "pivoine-gpu",
                        "permission": [],
                        "root": self.model_name,
                        "parent": None,
                    }
                ]
            }
@self.app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
"""OpenAI-compatible completion endpoint"""
if not self.engine:
return JSONResponse(
status_code=503,
content={"error": "Engine not initialized"}
)
# Handle both single prompt and batch prompts
prompts = [request.prompt] if isinstance(request.prompt, str) else request.prompt
# Configure sampling parameters
sampling_params = SamplingParams(
temperature=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens,
n=request.n,
stop=request.stop if request.stop else [],
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
)
# Generate completions
results = []
for prompt in prompts:
request_id = random_uuid()
if request.stream:
# Streaming response
async def generate_stream():
async for output in self.engine.generate(prompt, sampling_params, request_id):
chunk = {
"id": request_id,
"object": "text_completion",
"created": 1234567890,
"model": request.model,
"choices": [
{
"text": output.outputs[0].text,
"index": 0,
"logprobs": None,
"finish_reason": output.outputs[0].finish_reason,
}
]
}
yield f"data: {json.dumps(chunk)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(generate_stream(), media_type="text/event-stream")
else:
# Non-streaming response
async for output in self.engine.generate(prompt, sampling_params, request_id):
final_output = output
results.append({
"text": final_output.outputs[0].text,
"index": len(results),
"logprobs": None,
"finish_reason": final_output.outputs[0].finish_reason,
})
return {
"id": random_uuid(),
"object": "text_completion",
"created": 1234567890,
"model": request.model,
"choices": results,
"usage": {
"prompt_tokens": 0, # vLLM doesn't expose this easily
"completion_tokens": 0,
"total_tokens": 0,
}
}
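        # Illustrative non-streaming response from /v1/completions (token counts are
        # reported as 0, as noted above; finish_reason is typically "stop" or "length"):
        #   {"id": "...", "object": "text_completion", "model": "qwen-2.5-7b",
        #    "choices": [{"text": "...", "index": 0, "finish_reason": "stop"}],
        #    "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}}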
@self.app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
"""OpenAI-compatible chat completion endpoint"""
if not self.engine:
return JSONResponse(
status_code=503,
content={"error": "Engine not initialized"}
)
# Convert messages to prompt
prompt = self.messages_to_prompt(request.messages)
# Configure sampling parameters
sampling_params = SamplingParams(
temperature=request.temperature,
top_p=request.top_p,
max_tokens=request.max_tokens,
n=request.n,
stop=request.stop if request.stop else ["<|im_end|>"],
)
request_id = random_uuid()
if request.stream:
# Streaming response
async def generate_stream():
async for output in self.engine.generate(prompt, sampling_params, request_id):
chunk = {
"id": request_id,
"object": "chat.completion.chunk",
"created": 1234567890,
"model": request.model,
"choices": [
{
"index": 0,
"delta": {"content": output.outputs[0].text},
"finish_reason": output.outputs[0].finish_reason,
}
]
}
yield f"data: {json.dumps(chunk)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(generate_stream(), media_type="text/event-stream")
else:
# Non-streaming response
async for output in self.engine.generate(prompt, sampling_params, request_id):
final_output = output
return {
"id": request_id,
"object": "chat.completion",
"created": 1234567890,
"model": request.model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": final_output.outputs[0].text,
},
"finish_reason": final_output.outputs[0].finish_reason,
}
],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
}
}
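        # Streaming note: when "stream": true, both endpoints return Server-Sent Events.
        # Each event is a line of the form `data: {...json chunk...}` followed by a blank
        # line, and the stream ends with `data: [DONE]`. vLLM yields cumulative text on
        # each iteration, so each chunk here carries the generation so far rather than
        # only the newly added tokens.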
if __name__ == "__main__":
service = VLLMService()
service.run()
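# Usage sketch (assumes the `openai` Python package and that the service is reachable
# on localhost at the default port 8001; adjust base_url for your deployment):
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:8001/v1", api_key="unused")
#   resp = client.chat.completions.create(
#       model="qwen-2.5-7b",
#       messages=[{"role": "user", "content": "Hello!"}],
#   )
#   print(resp.choices[0].message.content)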