diff --git a/model-orchestrator/orchestrator_subprocess.py b/model-orchestrator/orchestrator_subprocess.py
index 3245d4d..23ae90f 100644
--- a/model-orchestrator/orchestrator_subprocess.py
+++ b/model-orchestrator/orchestrator_subprocess.py
@@ -213,6 +213,27 @@ async def health_check():
     }
 
 
+@app.get("/v1/models")
+async def list_models_openai():
+    """OpenAI-compatible models listing endpoint"""
+    models_list = []
+    for model_name in model_registry:
+        models_list.append({
+            "id": model_name,
+            "object": "model",
+            "created": int(time.time()),
+            "owned_by": "pivoine-gpu",
+            "permission": [],
+            "root": model_name,
+            "parent": None,
+        })
+
+    return {
+        "object": "list",
+        "data": models_list
+    }
+
+
 @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
 async def proxy_request(request: Request, path: str):
     """Proxy requests to appropriate model service"""
@@ -235,19 +256,52 @@ async def proxy_request(request: Request, path: str):
     model_config = model_registry[target_model]
     target_url = f"http://localhost:{model_config['port']}/{path}"
 
-    try:
-        async with httpx.AsyncClient(timeout=300.0) as client:
-            response = await client.request(
-                method=request.method,
-                url=target_url,
-                headers=dict(request.headers),
-                content=await request.body()
-            )
+    # Get request details
+    method = request.method
+    headers = dict(request.headers)
+    headers.pop('host', None)  # Remove host header so httpx sets the correct one for the target
+    body = await request.body()
 
-            return JSONResponse(
-                content=response.json() if response.headers.get('content-type') == 'application/json' else response.text,
-                status_code=response.status_code
+    # Check if this is a streaming request
+    is_streaming = False
+    if method == "POST" and body:
+        try:
+            import json
+            body_json = json.loads(body)
+            is_streaming = body_json.get('stream', False)
+        except ValueError:  # body is not JSON; treat as non-streaming
+            pass
+
+    logger.info(f"Proxying {method} request to {target_url} (streaming: {is_streaming})")
+
+    try:
+        if is_streaming:
+            # For streaming requests, use httpx streaming and yield chunks
+            async def stream_response():
+                async with httpx.AsyncClient(timeout=300.0) as client:
+                    async with client.stream(method, target_url, content=body, headers=headers) as response:
+                        async for chunk in response.aiter_bytes():
+                            yield chunk
+
+            return StreamingResponse(
+                stream_response(),
+                media_type="text/event-stream",
+                headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
             )
+        else:
+            # For non-streaming requests, use the original behavior
+            async with httpx.AsyncClient(timeout=300.0) as client:
+                response = await client.request(
+                    method=method,
+                    url=target_url,
+                    headers=headers,
+                    content=body
+                )
+
+            return JSONResponse(
+                content=response.json() if response.headers.get('content-type', '').startswith('application/json') else response.text,
+                status_code=response.status_code
+            )
     except Exception as e:
         logger.error(f"Error proxying request: {e}")
        raise HTTPException(status_code=502, detail=str(e))
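
For quick verification, a minimal client-side sketch that exercises both additions. It assumes the orchestrator listens on localhost:8000, that routing picks the target model from the "model" field in the request body, and that the proxied backend exposes an OpenAI-style /v1/completions route — none of which is shown in this diff, so adjust to the actual deployment:

    import asyncio
    import httpx

    async def smoke_test():
        async with httpx.AsyncClient(base_url="http://localhost:8000", timeout=60.0) as client:
            # New endpoint: OpenAI-compatible model listing
            models = (await client.get("/v1/models")).json()
            print("models:", [m["id"] for m in models["data"]])

            # Streaming path: "stream": true in the body routes through StreamingResponse
            payload = {"model": models["data"][0]["id"], "prompt": "Hello", "stream": True}
            async with client.stream("POST", "/v1/completions", json=payload) as response:
                async for chunk in response.aiter_bytes():
                    print(chunk.decode(errors="replace"), end="", flush=True)

    asyncio.run(smoke_test())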