diff --git a/model-orchestrator/orchestrator.py b/model-orchestrator/orchestrator.py index 9091537..822d61b 100644 --- a/model-orchestrator/orchestrator.py +++ b/model-orchestrator/orchestrator.py @@ -197,24 +197,48 @@ async def proxy_request(model_name: str, request: Request): # Build target URL target_url = f"http://localhost:{port}{path}" - logger.info(f"Proxying {method} request to {target_url}") + # Check if this is a streaming request + body = await request.body() + is_streaming = False + if method == "POST" and body: + try: + import json + body_json = json.loads(body) + is_streaming = body_json.get('stream', False) + except: + pass - async with httpx.AsyncClient(timeout=300.0) as client: - # Handle different request types - if method == "GET": - response = await client.get(target_url, headers=headers) - elif method == "POST": - body = await request.body() - response = await client.post(target_url, content=body, headers=headers) - else: - raise HTTPException(status_code=405, detail=f"Method {method} not supported") + logger.info(f"Proxying {method} request to {target_url} (streaming: {is_streaming})") - # Return response - return JSONResponse( - content=response.json() if response.headers.get('content-type', '').startswith('application/json') else response.text, - status_code=response.status_code, - headers=dict(response.headers) + if is_streaming: + # For streaming requests, use httpx streaming and yield chunks + async def stream_response(): + async with httpx.AsyncClient(timeout=300.0) as client: + async with client.stream(method, target_url, content=body, headers=headers) as response: + async for chunk in response.aiter_bytes(): + yield chunk + + return StreamingResponse( + stream_response(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"} ) + else: + # For non-streaming requests, use the original behavior + async with httpx.AsyncClient(timeout=300.0) as client: + if method == "GET": + response = await client.get(target_url, headers=headers) + elif method == "POST": + response = await client.post(target_url, content=body, headers=headers) + else: + raise HTTPException(status_code=405, detail=f"Method {method} not supported") + + # Return response + return JSONResponse( + content=response.json() if response.headers.get('content-type', '').startswith('application/json') else response.text, + status_code=response.status_code, + headers=dict(response.headers) + ) @app.on_event("startup")