#!/usr/bin/env python3
"""
vLLM Text Generation Service

OpenAI-compatible text generation using vLLM and the Qwen 2.5 7B Instruct model.
Provides /v1/completions and /v1/chat/completions endpoints.
"""
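
# Example request against the service (a sketch: host and port are assumptions,
# matching the PORT default of 8001 set below, not a documented deployment):
#
#   curl -s http://localhost:8001/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "qwen-2.5-7b", "messages": [{"role": "user", "content": "Hello"}]}'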

import json
import os
import sys
import time
from typing import List, Optional

from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
from vllm.utils import random_uuid

# Import base service class; the repo root goes on sys.path so `core` resolves
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))
from core.base_service import GPUService


# Request/Response models
class CompletionRequest(BaseModel):
    """OpenAI-compatible completion request"""
    model: str = Field(default="qwen-2.5-7b")
    prompt: str | List[str] = Field(..., description="Text prompt(s)")
    max_tokens: int = Field(default=512, ge=1, le=4096)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=1.0, ge=0.0, le=1.0)
    n: int = Field(default=1, ge=1, le=10)
    stream: bool = Field(default=False)
    stop: Optional[str | List[str]] = None
    presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
    frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)


class ChatMessage(BaseModel):
    """Chat message format"""
    role: str = Field(..., description="Role: system, user, or assistant")
    content: str = Field(..., description="Message content")


class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible chat completion request"""
    model: str = Field(default="qwen-2.5-7b")
    messages: List[ChatMessage] = Field(..., description="Chat messages")
    max_tokens: int = Field(default=512, ge=1, le=4096)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=1.0, ge=0.0, le=1.0)
    n: int = Field(default=1, ge=1, le=10)
    stream: bool = Field(default=False)
    stop: Optional[str | List[str]] = None
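
# For reference, a JSON payload that validates against ChatCompletionRequest
# (values are illustrative, not defaults taken from anywhere):
#
#   {"model": "qwen-2.5-7b",
#    "messages": [{"role": "user", "content": "Hello"}],
#    "max_tokens": 128, "temperature": 0.7, "stream": false}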


class VLLMService(GPUService):
    """vLLM text generation service"""

    def __init__(self):
        # Get port from environment or use default
        port = int(os.getenv("PORT", "8001"))
        super().__init__(name="vllm-qwen", port=port)

        # Service-specific attributes
        self.engine: Optional[AsyncLLMEngine] = None
        self.model_name = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")
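
    # Illustrative launch with the environment knobs read above (the script
    # name is a placeholder, not the actual filename):
    #   PORT=8002 MODEL_NAME=Qwen/Qwen2.5-7B-Instruct python service.py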

    async def initialize(self):
        """Initialize vLLM engine"""
        await super().initialize()

        self.logger.info(f"Initializing vLLM AsyncLLMEngine with model: {self.model_name}")

        # Configure engine
        engine_args = AsyncEngineArgs(
            model=self.model_name,
            tensor_parallel_size=1,  # Single GPU
            gpu_memory_utilization=0.85,  # Use 85% of GPU memory
            max_model_len=4096,  # Context length
            dtype="auto",  # Auto-detect dtype
            download_dir=os.getenv("HF_CACHE_DIR", "/workspace/huggingface_cache"),
            trust_remote_code=True,  # Some models require this
            enforce_eager=False,  # Use CUDA graphs for better performance
        )
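
        # Rough arithmetic behind gpu_memory_utilization=0.85 (a sketch assuming
        # a single 24 GB GPU; the actual card is not specified here): vLLM
        # reserves 0.85 * 24 GB ~= 20.4 GB. Weights for a 7B model in bf16 take
        # ~14 GB, leaving roughly 6 GB for KV cache, which bounds how many
        # concurrent sequences fit at max_model_len=4096.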

        # Create async engine
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

        self.logger.info("vLLM AsyncLLMEngine initialized successfully")

    async def cleanup(self):
        """Cleanup resources"""
        await super().cleanup()
        if self.engine:
            # vLLM doesn't have an explicit shutdown method
            self.logger.info("vLLM engine cleanup")
            self.engine = None

    def messages_to_prompt(self, messages: List[ChatMessage]) -> str:
        """Convert chat messages to the Qwen 2.5 (ChatML) prompt format"""
        prompt_parts = []

        for msg in messages:
            role = msg.role
            content = msg.content

            if role == "system":
                prompt_parts.append(f"<|im_start|>system\n{content}<|im_end|>")
            elif role == "user":
                prompt_parts.append(f"<|im_start|>user\n{content}<|im_end|>")
            elif role == "assistant":
                prompt_parts.append(f"<|im_start|>assistant\n{content}<|im_end|>")
            # Unknown roles are silently dropped; OpenAI clients send only the three above

        # Add final assistant prompt so the model starts its reply
        prompt_parts.append("<|im_start|>assistant\n")

        return "\n".join(prompt_parts)
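
    # Worked example of messages_to_prompt (illustrative input, not taken from
    # the codebase):
    #   input:  [ChatMessage(role="system", content="Be brief."),
    #            ChatMessage(role="user", content="Hi")]
    #   output: "<|im_start|>system\nBe brief.<|im_end|>\n"
    #           "<|im_start|>user\nHi<|im_end|>\n"
    #           "<|im_start|>assistant\n"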

    def create_app(self):
        """Create FastAPI routes"""

        @self.app.get("/")
        async def root():
            """Root endpoint"""
            return {"status": "ok", "model": self.model_name}

        @self.app.get("/v1/models")
        async def list_models():
            """OpenAI-compatible models endpoint"""
            return {
                "object": "list",
                "data": [
                    {
                        "id": "qwen-2.5-7b",
                        "object": "model",
                        "created": int(time.time()),
                        "owned_by": "pivoine-gpu",
                        "permission": [],
                        "root": self.model_name,
                        "parent": None,
                    }
                ]
            }

        @self.app.post("/v1/completions")
        async def create_completion(request: CompletionRequest):
            """OpenAI-compatible completion endpoint"""
            if not self.engine:
                return JSONResponse(
                    status_code=503,
                    content={"error": "Engine not initialized"}
                )

            # Handle both single prompt and batch prompts
            prompts = [request.prompt] if isinstance(request.prompt, str) else request.prompt

            # Configure sampling parameters
            sampling_params = SamplingParams(
                temperature=request.temperature,
                top_p=request.top_p,
                max_tokens=request.max_tokens,
                n=request.n,
                stop=request.stop if request.stop else [],
                presence_penalty=request.presence_penalty,
                frequency_penalty=request.frequency_penalty,
            )

            if request.stream:
                # Streaming response. Only the first prompt is streamed: one SSE
                # stream carries one completion, so batch prompts are not
                # supported in streaming mode.
                prompt = prompts[0]
                request_id = random_uuid()

                async def generate_stream():
                    sent = 0  # vLLM yields cumulative text; emit only the new suffix
                    async for output in self.engine.generate(prompt, sampling_params, request_id):
                        text = output.outputs[0].text
                        chunk = {
                            "id": request_id,
                            "object": "text_completion",
                            "created": int(time.time()),
                            "model": request.model,
                            "choices": [
                                {
                                    "text": text[sent:],
                                    "index": 0,
                                    "logprobs": None,
                                    "finish_reason": output.outputs[0].finish_reason,
                                }
                            ]
                        }
                        sent = len(text)
                        yield f"data: {json.dumps(chunk)}\n\n"
                    yield "data: [DONE]\n\n"

                return StreamingResponse(generate_stream(), media_type="text/event-stream")

            # Non-streaming: run each prompt to completion
            results = []
            prompt_tokens = 0
            completion_tokens = 0
            for prompt in prompts:
                request_id = random_uuid()
                final_output = None
                async for output in self.engine.generate(prompt, sampling_params, request_id):
                    final_output = output

                if final_output is None:
                    continue  # engine produced no output for this prompt

                prompt_tokens += len(final_output.prompt_token_ids)
                completion_tokens += len(final_output.outputs[0].token_ids)
                # Note: only the first of the n candidates is returned
                results.append({
                    "text": final_output.outputs[0].text,
                    "index": len(results),
                    "logprobs": None,
                    "finish_reason": final_output.outputs[0].finish_reason,
                })

            return {
                "id": random_uuid(),
                "object": "text_completion",
                "created": int(time.time()),
                "model": request.model,
                "choices": results,
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                }
            }
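
        # A streamed completion from the endpoint above arrives as SSE lines
        # (illustrative, fields abbreviated):
        #   data: {"object": "text_completion", "choices": [{"text": "Hel", ...}]}
        #   data: {"object": "text_completion", "choices": [{"text": "lo", ...}]}
        #   data: [DONE]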

        @self.app.post("/v1/chat/completions")
        async def create_chat_completion(request: ChatCompletionRequest):
            """OpenAI-compatible chat completion endpoint"""
            if not self.engine:
                return JSONResponse(
                    status_code=503,
                    content={"error": "Engine not initialized"}
                )

            # Convert messages to prompt
            prompt = self.messages_to_prompt(request.messages)

            # Configure sampling parameters
            sampling_params = SamplingParams(
                temperature=request.temperature,
                top_p=request.top_p,
                max_tokens=request.max_tokens,
                n=request.n,
                stop=request.stop if request.stop else ["<|im_end|>"],
            )

            request_id = random_uuid()

            if request.stream:
                # Streaming response
                async def generate_stream():
                    sent = 0  # vLLM yields cumulative text; emit only the new suffix
                    async for output in self.engine.generate(prompt, sampling_params, request_id):
                        text = output.outputs[0].text
                        chunk = {
                            "id": request_id,
                            "object": "chat.completion.chunk",
                            "created": int(time.time()),
                            "model": request.model,
                            "choices": [
                                {
                                    "index": 0,
                                    "delta": {"content": text[sent:]},
                                    "finish_reason": output.outputs[0].finish_reason,
                                }
                            ]
                        }
                        sent = len(text)
                        yield f"data: {json.dumps(chunk)}\n\n"
                    yield "data: [DONE]\n\n"

                return StreamingResponse(generate_stream(), media_type="text/event-stream")

            # Non-streaming response
            final_output = None
            async for output in self.engine.generate(prompt, sampling_params, request_id):
                final_output = output

            if final_output is None:
                return JSONResponse(
                    status_code=500,
                    content={"error": "Generation produced no output"}
                )

            prompt_tokens = len(final_output.prompt_token_ids)
            completion_tokens = len(final_output.outputs[0].token_ids)
            return {
                "id": request_id,
                "object": "chat.completion",
                "created": int(time.time()),
                "model": request.model,
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": final_output.outputs[0].text,
                        },
                        "finish_reason": final_output.outputs[0].finish_reason,
                    }
                ],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": prompt_tokens + completion_tokens,
                }
            }


if __name__ == "__main__":
    service = VLLMService()
    service.run()
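
# Example client usage (a sketch: assumes the `openai` Python package v1+ and a
# service reachable at localhost:8001; neither is pinned by this file):
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:8001/v1", api_key="unused")
#   resp = client.chat.completions.create(
#       model="qwen-2.5-7b",
#       messages=[{"role": "user", "content": "Say hi"}],
#   )
#   print(resp.choices[0].message.content)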