diff --git a/ai/simple_vllm_server.py b/ai/simple_vllm_server.py
new file mode 100644
index 0000000..0075bd2
--- /dev/null
+++ b/ai/simple_vllm_server.py
@@ -0,0 +1,302 @@
+#!/usr/bin/env python3
+"""
+Simple vLLM server using AsyncLLMEngine directly.
+Bypasses the multiprocessing issues we hit with the default vLLM API server.
+OpenAI-compatible endpoints: /v1/models, /v1/completions and /v1/chat/completions.
+"""
+
+import json
+import logging
+import os
+import time
+from typing import List, Optional
+
+from fastapi import FastAPI
+from fastapi.responses import JSONResponse, StreamingResponse
+from pydantic import BaseModel, Field
+from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams
+from vllm.utils import random_uuid
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+# FastAPI app
+app = FastAPI(title="Simple vLLM Server", version="1.0.0")
+
+# Global engine instance
+engine: Optional[AsyncLLMEngine] = None
+model_name: str = "Qwen/Qwen2.5-7B-Instruct"
+
+# Request/Response models
+class CompletionRequest(BaseModel):
+    """OpenAI-compatible completion request"""
+    model: str = Field(default="qwen-2.5-7b")
+    prompt: str | List[str] = Field(..., description="Text prompt(s)")
+    max_tokens: int = Field(default=512, ge=1, le=4096)
+    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
+    top_p: float = Field(default=1.0, ge=0.0, le=1.0)
+    n: int = Field(default=1, ge=1, le=10)
+    stream: bool = Field(default=False)
+    stop: Optional[str | List[str]] = None
+    presence_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
+    frequency_penalty: float = Field(default=0.0, ge=-2.0, le=2.0)
+
+class ChatMessage(BaseModel):
+    """Chat message format"""
+    role: str = Field(..., description="Role: system, user, or assistant")
+    content: str = Field(..., description="Message content")
+
+class ChatCompletionRequest(BaseModel):
+    """OpenAI-compatible chat completion request"""
+    model: str = Field(default="qwen-2.5-7b")
+    messages: List[ChatMessage] = Field(..., description="Chat messages")
+    max_tokens: int = Field(default=512, ge=1, le=4096)
+    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
+    top_p: float = Field(default=1.0, ge=0.0, le=1.0)
+    n: int = Field(default=1, ge=1, le=10)
+    stream: bool = Field(default=False)
+    stop: Optional[str | List[str]] = None
+
+@app.on_event("startup")
+async def startup_event():
+    """Initialize vLLM engine on startup"""
+    global engine, model_name
+
+    logger.info(f"Initializing vLLM AsyncLLMEngine with model: {model_name}")
+
+    # Configure engine
+    engine_args = AsyncEngineArgs(
+        model=model_name,
+        tensor_parallel_size=1,  # Single GPU
+        gpu_memory_utilization=0.85,  # Use 85% of GPU memory
+        max_model_len=4096,  # Context length
+        dtype="auto",  # Auto-detect dtype
+        download_dir="/workspace/huggingface_cache",  # Large disk
+        trust_remote_code=True,  # Some models require this
+        enforce_eager=False,  # Use CUDA graphs for better performance
+    )
+
+    # Create async engine
+    engine = AsyncLLMEngine.from_engine_args(engine_args)
+
+    logger.info("vLLM AsyncLLMEngine initialized successfully")
+
+@app.get("/")
+async def root():
+    """Health check endpoint"""
+    return {"status": "ok", "model": model_name}
+
+@app.get("/health")
+async def health():
+    """Detailed health check"""
+    return {
+        "status": "healthy" if engine else "initializing",
+        "model": model_name,
+        "ready": engine is not None
+    }
+
+@app.get("/v1/models")
+async def list_models():
+    """OpenAI-compatible models endpoint"""
+    return {
+        "object": "list",
+        "data": [
+            {
+                "id": "qwen-2.5-7b",
+                "object": "model",
+                "created": int(time.time()),
+                "owned_by": "pivoine-gpu",
+                "permission": [],
+                "root": model_name,
+                "parent": None,
+            }
+        ]
+    }
+
+def messages_to_prompt(messages: List[ChatMessage]) -> str:
+    """Convert chat messages to a single prompt string"""
+    # Qwen 2.5 chat template format
+    prompt_parts = []
+
+    for msg in messages:
+        role = msg.role
+        content = msg.content
+
+        if role == "system":
+            prompt_parts.append(f"<|im_start|>system\n{content}<|im_end|>")
+        elif role == "user":
+            prompt_parts.append(f"<|im_start|>user\n{content}<|im_end|>")
+        elif role == "assistant":
+            prompt_parts.append(f"<|im_start|>assistant\n{content}<|im_end|>")
+
+    # Add final assistant prompt
+    prompt_parts.append("<|im_start|>assistant\n")
+
+    return "\n".join(prompt_parts)
+
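+# Example request (illustrative only; assumes the server is reachable on the
+# default host/port configured at the bottom of this file, i.e. http://localhost:8000):
+#
+#   curl http://localhost:8000/v1/completions \
+#     -H "Content-Type: application/json" \
+#     -d '{"model": "qwen-2.5-7b", "prompt": "Hello", "max_tokens": 32}'
+#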
+@app.post("/v1/completions")
+async def create_completion(request: CompletionRequest):
+    """OpenAI-compatible completion endpoint"""
+    if not engine:
+        return JSONResponse(
+            status_code=503,
+            content={"error": "Engine not initialized"}
+        )
+
+    # Handle both single prompt and batch prompts
+    prompts = [request.prompt] if isinstance(request.prompt, str) else request.prompt
+
+    # Configure sampling parameters
+    sampling_params = SamplingParams(
+        temperature=request.temperature,
+        top_p=request.top_p,
+        max_tokens=request.max_tokens,
+        n=request.n,
+        stop=request.stop if request.stop else [],
+        presence_penalty=request.presence_penalty,
+        frequency_penalty=request.frequency_penalty,
+    )
+
+    if request.stream:
+        # Streaming only supports a single prompt. vLLM yields the cumulative
+        # text on every iteration, so emit only the newly generated part.
+        prompt = prompts[0]
+        request_id = random_uuid()
+
+        async def generate_stream():
+            previous_text = ""
+            async for output in engine.generate(prompt, sampling_params, request_id):
+                text = output.outputs[0].text
+                delta = text[len(previous_text):]
+                previous_text = text
+                chunk = {
+                    "id": request_id,
+                    "object": "text_completion",
+                    "created": int(time.time()),
+                    "model": request.model,
+                    "choices": [
+                        {
+                            "text": delta,
+                            "index": 0,
+                            "logprobs": None,
+                            "finish_reason": output.outputs[0].finish_reason,
+                        }
+                    ]
+                }
+                yield f"data: {json.dumps(chunk)}\n\n"
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(generate_stream(), media_type="text/event-stream")
+
+    # Non-streaming: run each prompt to completion and collect every output
+    results = []
+    prompt_tokens = 0
+    completion_tokens = 0
+    for prompt in prompts:
+        request_id = random_uuid()
+
+        final_output = None
+        async for output in engine.generate(prompt, sampling_params, request_id):
+            final_output = output
+
+        # Token counts come from the token id lists on the final output
+        prompt_tokens += len(final_output.prompt_token_ids)
+        for completion in final_output.outputs:
+            completion_tokens += len(completion.token_ids)
+            results.append({
+                "text": completion.text,
+                "index": len(results),
+                "logprobs": None,
+                "finish_reason": completion.finish_reason,
+            })
+
+    return {
+        "id": random_uuid(),
+        "object": "text_completion",
+        "created": int(time.time()),
+        "model": request.model,
+        "choices": results,
+        "usage": {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens,
+        }
+    }
+
+@app.post("/v1/chat/completions")
+async def create_chat_completion(request: ChatCompletionRequest):
+    """OpenAI-compatible chat completion endpoint"""
+    if not engine:
+        return JSONResponse(
+            status_code=503,
+            content={"error": "Engine not initialized"}
+        )
+
+    # Convert messages to prompt
+    prompt = messages_to_prompt(request.messages)
+
+    # Configure sampling parameters
+    sampling_params = SamplingParams(
+        temperature=request.temperature,
+        top_p=request.top_p,
+        max_tokens=request.max_tokens,
+        n=request.n,
+        stop=request.stop if request.stop else ["<|im_end|>"],
+    )
+
+    request_id = random_uuid()
+
+    if request.stream:
+        # Streaming response: vLLM yields cumulative text, so send only the delta
+        async def generate_stream():
+            previous_text = ""
+            async for output in engine.generate(prompt, sampling_params, request_id):
+                text = output.outputs[0].text
+                delta = text[len(previous_text):]
+                previous_text = text
+                chunk = {
+                    "id": request_id,
+                    "object": "chat.completion.chunk",
+                    "created": int(time.time()),
+                    "model": request.model,
+                    "choices": [
+                        {
+                            "index": 0,
+                            "delta": {"content": delta},
+                            "finish_reason": output.outputs[0].finish_reason,
+                        }
+                    ]
+                }
+                yield f"data: {json.dumps(chunk)}\n\n"
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(generate_stream(), media_type="text/event-stream")
+
+    # Non-streaming response
+    final_output = None
+    async for output in engine.generate(prompt, sampling_params, request_id):
+        final_output = output
+
+    prompt_tokens = len(final_output.prompt_token_ids)
+    completion_tokens = sum(len(c.token_ids) for c in final_output.outputs)
+
+    return {
+        "id": request_id,
+        "object": "chat.completion",
+        "created": int(time.time()),
+        "model": request.model,
+        "choices": [
+            {
+                "index": i,
+                "message": {
+                    "role": "assistant",
+                    "content": completion.text,
+                },
+                "finish_reason": completion.finish_reason,
+            }
+            for i, completion in enumerate(final_output.outputs)
+        ],
+        "usage": {
+            "prompt_tokens": prompt_tokens,
+            "completion_tokens": completion_tokens,
+            "total_tokens": prompt_tokens + completion_tokens,
+        }
+    }
+
+if __name__ == "__main__":
+    import uvicorn
+
+    # Get configuration from environment
+    host = os.getenv("VLLM_HOST", "0.0.0.0")
+    port = int(os.getenv("VLLM_PORT", "8000"))
+
+    logger.info(f"Starting vLLM server on {host}:{port}")
+
+    uvicorn.run(
+        app,
+        host=host,
+        port=port,
+        log_level="info",
+        access_log=True,
+    )
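Once the patch is applied and the server is up, a quick end-to-end check from the client side could look roughly like the sketch below. It assumes the defaults above (server listening on localhost:8000) and that the requests package is installed; the payload mirrors the ChatCompletionRequest model and the response shape returned by /v1/chat/completions.

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "qwen-2.5-7b",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello in one sentence."},
        ],
        "max_tokens": 64,
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])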