Files
runpod/vllm/server_embedding.py

202 lines
6.6 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""
vLLM Embedding Server for BAAI/bge-large-en-v1.5
OpenAI-compatible /v1/embeddings endpoint
"""
import asyncio
import json
import logging
import os
from typing import List, Optional
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from vllm import AsyncLLMEngine, AsyncEngineArgs
from vllm.utils import random_uuid
# Root logging: INFO and above, each record stamped with time, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# ASGI application that every route in this module is registered on.
app = FastAPI(version="1.0.0", title="vLLM Embedding Server")

# Shared engine handle; None until the startup hook assigns it.
engine: Optional[AsyncLLMEngine] = None
# Hugging Face model id served by this dedicated embedding server.
model_name: str = "BAAI/bge-large-en-v1.5"
# TCP port used when this file is run as a script.
port: int = 8002
# Request/Response models
class EmbeddingRequest(BaseModel):
    """Body of an OpenAI-compatible ``POST /v1/embeddings`` request."""

    # Model alias; the endpoint echoes it back rather than using it to pick a model.
    model: str = Field("bge-large-en-v1.5")
    # Either one string or a batch of strings.
    input: str | List[str] = Field(..., description="Text input(s) to embed")
    # Requested serialization of each vector.
    encoding_format: str = Field("float", description="float or base64")
    # Optional end-user identifier (OpenAI field); accepted but not read here.
    user: Optional[str] = Field(default=None)
@app.on_event("startup")
async def startup_event():
    """Create the vLLM engine for ``model_name`` and store it in ``engine``.

    Request handlers check ``engine`` and answer 503 / "initializing" while
    it is unset.
    """
    # Only ``engine`` is rebound here; ``model_name`` is merely read, so it
    # needs no ``global`` declaration.
    global engine
    logger.info(f"Initializing vLLM embedding engine with model: {model_name}")
    engine_args = AsyncEngineArgs(
        model=model_name,
        tensor_parallel_size=1,  # Single GPU
        gpu_memory_utilization=0.50,  # Conservative: leave GPU memory for other processes
        dtype="auto",  # Let vLLM pick the dtype from the model config
        download_dir="/workspace/huggingface_cache",  # Large disk
        # NOTE(review): trust_remote_code lets the model repo run its own
        # Python code at load time. Confirm this model actually needs it
        # before keeping it enabled.
        trust_remote_code=True,
        # Run the model eagerly and skip CUDA-graph capture, which saves
        # startup time and GPU memory. This is unrelated to streaming.
        enforce_eager=True,
        max_model_len=512,  # BGE max token length
        # NOTE: native pooling/embedding mode is selected via an engine
        # option whose availability depends on the vLLM version; it is left
        # disabled here.
        # task="embed",
    )
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    logger.info("vLLM embedding engine initialized successfully")
@app.get("/")
async def root():
    """Liveness probe: confirm the process is up and name the served model."""
    payload = dict(status="ok", model=model_name, task="embedding")
    return payload
@app.get("/health")
async def health():
    """Readiness probe: report whether the embedding engine exists yet."""
    is_ready = engine is not None
    state = "initializing"
    if engine:
        state = "healthy"
    return {
        "status": state,
        "model": model_name,
        "ready": is_ready,
        "task": "embedding",
    }
@app.get("/v1/models")
async def list_models():
    """Describe the single served model in OpenAI ``/v1/models`` list format."""
    model_card = {
        "id": "bge-large-en-v1.5",
        "object": "model",
        "created": 1234567890,  # fixed placeholder timestamp
        "owned_by": "pivoine-gpu",
        "permission": [],
        "root": model_name,
        "parent": None,
    }
    return {"object": "list", "data": [model_card]}
# Output width of BAAI/bge-large-en-v1.5; placeholder vectors use the same size.
_EMBEDDING_DIM = 1024


def _placeholder_embedding(text: str, dim: int = _EMBEDDING_DIM) -> List[float]:
    """Return a deterministic PLACEHOLDER vector of length *dim* for *text*.

    This is NOT a semantic embedding. Every component holds the same value
    in [0, 1), derived from the SHA-256 digest of the UTF-8 text, so all
    outputs point in the same direction. It only keeps responses correctly
    shaped until the model's pooled output is wired in.
    """
    import hashlib

    text_hash = int(hashlib.sha256(text.encode()).hexdigest(), 16)
    return [(text_hash % 1000000) / 1000000.0] * dim


def _format_embedding(values: List[float], encoding_format: str):
    """Serialize *values* according to the request's ``encoding_format``.

    ``"base64"`` returns the base64 text of the values packed as
    little-endian float32, matching the OpenAI embeddings API's base64
    format. Any other value returns *values* unchanged as a list of floats,
    which keeps the endpoint's previous lenient behaviour.
    """
    if encoding_format != "base64":
        return values
    import base64
    import struct

    packed = struct.pack(f"<{len(values)}f", *values)
    return base64.b64encode(packed).decode("ascii")


@app.post("/v1/embeddings")
async def create_embeddings(request: EmbeddingRequest):
    """OpenAI-compatible embeddings endpoint.

    Accepts a single string or a list of strings. Returns one entry per
    input, in input order, with ``index`` set to the input's position.

    WARNING: the vectors are placeholders (see ``_placeholder_embedding``).
    Each input is still run through ``engine.generate``, so engine-side
    failures surface here, and token usage can be taken from the engine
    output.

    Returns:
        An OpenAI-style ``list`` object on success. Otherwise a
        ``JSONResponse`` with status 503 if the engine is not initialized,
        or 500 if any input fails; in that case the whole batch is rejected.
    """
    if not engine:
        return JSONResponse(
            status_code=503,
            content={"error": "Engine not initialized"}
        )

    # Normalise the single-string form into a one-element batch.
    inputs = [request.input] if isinstance(request.input, str) else request.input

    from vllm import SamplingParams

    # One greedy token per prompt. The generated text is discarded; the call
    # only runs the prompt through the engine (see WARNING above).
    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=1,
        n=1,
    )

    embeddings = []
    total_tokens = 0
    for idx, text in enumerate(inputs):
        request_id = random_uuid()
        final_output = None  # stays None if the engine yields nothing
        try:
            async for output in engine.generate(text, sampling_params, request_id):
                final_output = output
            embedding = _format_embedding(
                _placeholder_embedding(text), request.encoding_format
            )
        except Exception as e:
            # Request boundary: keep the traceback in the logs and reject the batch.
            logger.exception("Error generating embedding: %s", e)
            return JSONResponse(
                status_code=500,
                content={"error": f"Failed to generate embedding: {str(e)}"}
            )

        embeddings.append({
            "object": "embedding",
            "embedding": embedding,
            "index": idx,
        })
        # Usage should count model tokens. Fall back to a whitespace word
        # count only when the engine output does not expose prompt token ids.
        prompt_token_ids = getattr(final_output, "prompt_token_ids", None)
        if prompt_token_ids is not None:
            total_tokens += len(prompt_token_ids)
        else:
            total_tokens += len(text.split())

    return {
        "object": "list",
        "data": embeddings,
        "model": request.model,
        "usage": {
            "prompt_tokens": total_tokens,
            "total_tokens": total_tokens,
        }
    }
if __name__ == "__main__":
    import uvicorn

    # Bind to all interfaces so the server is reachable from other hosts.
    host = "0.0.0.0"
    # `port` is the module-level setting defined near the top of this file.
    logger.info(f"Starting vLLM embedding server on {host}:{port}")
    # These lines warn that embeddings are placeholders. Emit them at WARNING
    # level so level-based log filtering and alerting surface them.
    logger.warning("WARNING: This is a placeholder implementation.")
    logger.warning("For production use, vLLM needs --task embed support or use sentence-transformers directly.")
    uvicorn.run(
        app,
        host=host,
        port=port,
        log_level="info",
        access_log=True,
    )