This is nice to have for debugging. While here, also cleaned up a bunch of unnecessary noise in `write_server_info()`.
209 lines
6.5 KiB
Rust
209 lines
6.5 KiB
Rust
use std::fs::File;
|
|
use std::fs::{self};
|
|
use std::io::Write;
|
|
use std::net::SocketAddr;
|
|
use std::net::TcpListener;
|
|
use std::path::Path;
|
|
use std::path::PathBuf;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
|
|
use anyhow::Context;
|
|
use anyhow::Result;
|
|
use anyhow::anyhow;
|
|
use clap::Parser;
|
|
use reqwest::blocking::Client;
|
|
use reqwest::header::AUTHORIZATION;
|
|
use reqwest::header::HOST;
|
|
use reqwest::header::HeaderMap;
|
|
use reqwest::header::HeaderName;
|
|
use reqwest::header::HeaderValue;
|
|
use serde::Serialize;
|
|
use tiny_http::Header;
|
|
use tiny_http::Method;
|
|
use tiny_http::Request;
|
|
use tiny_http::Response;
|
|
use tiny_http::Server;
|
|
use tiny_http::StatusCode;
|
|
|
|
mod read_api_key;
|
|
use read_api_key::read_auth_header_from_stdin;
|
|
|
|
/// CLI arguments for the proxy.
|
|
#[derive(Debug, Clone, Parser)]
|
|
#[command(name = "responses-api-proxy", about = "Minimal OpenAI responses proxy")]
|
|
pub struct Args {
|
|
/// Port to listen on. If not set, an ephemeral port is used.
|
|
#[arg(long)]
|
|
pub port: Option<u16>,
|
|
|
|
/// Path to a JSON file to write startup info (single line). Includes {"port": <u16>}.
|
|
#[arg(long, value_name = "FILE")]
|
|
pub server_info: Option<PathBuf>,
|
|
|
|
/// Enable HTTP shutdown endpoint at GET /shutdown
|
|
#[arg(long)]
|
|
pub http_shutdown: bool,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct ServerInfo {
|
|
port: u16,
|
|
pid: u32,
|
|
}
|
|
|
|
/// Entry point for the library main, for parity with other crates.
|
|
pub fn run_main(args: Args) -> Result<()> {
|
|
let auth_header = read_auth_header_from_stdin()?;
|
|
|
|
let (listener, bound_addr) = bind_listener(args.port)?;
|
|
if let Some(path) = args.server_info.as_ref() {
|
|
write_server_info(path, bound_addr.port())?;
|
|
}
|
|
let server = Server::from_listener(listener, None)
|
|
.map_err(|err| anyhow!("creating HTTP server: {err}"))?;
|
|
let client = Arc::new(
|
|
Client::builder()
|
|
// Disable reqwest's 30s default so long-lived response streams keep flowing.
|
|
.timeout(None::<Duration>)
|
|
.build()
|
|
.context("building reqwest client")?,
|
|
);
|
|
|
|
eprintln!("responses-api-proxy listening on {bound_addr}");
|
|
|
|
let http_shutdown = args.http_shutdown;
|
|
for request in server.incoming_requests() {
|
|
let client = client.clone();
|
|
std::thread::spawn(move || {
|
|
if http_shutdown && request.method() == &Method::Get && request.url() == "/shutdown" {
|
|
let _ = request.respond(Response::new_empty(StatusCode(200)));
|
|
std::process::exit(0);
|
|
}
|
|
|
|
if let Err(e) = forward_request(&client, auth_header, request) {
|
|
eprintln!("forwarding error: {e}");
|
|
}
|
|
});
|
|
}
|
|
|
|
Err(anyhow!("server stopped unexpectedly"))
|
|
}
|
|
|
|
fn bind_listener(port: Option<u16>) -> Result<(TcpListener, SocketAddr)> {
|
|
let addr = SocketAddr::from(([127, 0, 0, 1], port.unwrap_or(0)));
|
|
let listener = TcpListener::bind(addr).with_context(|| format!("failed to bind {addr}"))?;
|
|
let bound = listener.local_addr().context("failed to read local_addr")?;
|
|
Ok((listener, bound))
|
|
}
|
|
|
|
fn write_server_info(path: &Path, port: u16) -> Result<()> {
|
|
if let Some(parent) = path.parent()
|
|
&& !parent.as_os_str().is_empty()
|
|
{
|
|
fs::create_dir_all(parent)?;
|
|
}
|
|
|
|
let info = ServerInfo {
|
|
port,
|
|
pid: std::process::id(),
|
|
};
|
|
let mut data = serde_json::to_string(&info)?;
|
|
data.push('\n');
|
|
let mut f = File::create(path)?;
|
|
f.write_all(data.as_bytes())?;
|
|
Ok(())
|
|
}
|
|
|
|
fn forward_request(client: &Client, auth_header: &'static str, mut req: Request) -> Result<()> {
|
|
// Only allow POST /v1/responses exactly, no query string.
|
|
let method = req.method().clone();
|
|
let url_path = req.url().to_string();
|
|
let allow = method == Method::Post && url_path == "/v1/responses";
|
|
|
|
if !allow {
|
|
let resp = Response::new_empty(StatusCode(403));
|
|
let _ = req.respond(resp);
|
|
return Ok(());
|
|
}
|
|
|
|
// Read request body
|
|
let mut body = Vec::new();
|
|
let mut reader = req.as_reader();
|
|
std::io::Read::read_to_end(&mut reader, &mut body)?;
|
|
|
|
// Build headers for upstream, forwarding everything from the incoming
|
|
// request except Authorization (we replace it below).
|
|
let mut headers = HeaderMap::new();
|
|
for header in req.headers() {
|
|
let name_ascii = header.field.as_str();
|
|
let lower = name_ascii.to_ascii_lowercase();
|
|
if lower.as_str() == "authorization" || lower.as_str() == "host" {
|
|
continue;
|
|
}
|
|
|
|
let header_name = match HeaderName::from_bytes(lower.as_bytes()) {
|
|
Ok(name) => name,
|
|
Err(_) => continue,
|
|
};
|
|
if let Ok(value) = HeaderValue::from_bytes(header.value.as_bytes()) {
|
|
headers.append(header_name, value);
|
|
}
|
|
}
|
|
|
|
// As part of our effort to to keep `auth_header` secret, we use a
|
|
// combination of `from_static()` and `set_sensitive(true)`.
|
|
let mut auth_header_value = HeaderValue::from_static(auth_header);
|
|
auth_header_value.set_sensitive(true);
|
|
headers.insert(AUTHORIZATION, auth_header_value);
|
|
|
|
headers.insert(HOST, HeaderValue::from_static("api.openai.com"));
|
|
|
|
let upstream = "https://api.openai.com/v1/responses";
|
|
let upstream_resp = client
|
|
.post(upstream)
|
|
.headers(headers)
|
|
.body(body)
|
|
.send()
|
|
.context("forwarding request to upstream")?;
|
|
|
|
// We have to create an adapter between a `reqwest::blocking::Response`
|
|
// and a `tiny_http::Response`. Fortunately, `reqwest::blocking::Response`
|
|
// implements `Read`, so we can use it directly as the body of the
|
|
// `tiny_http::Response`.
|
|
let status = upstream_resp.status();
|
|
let mut response_headers = Vec::new();
|
|
for (name, value) in upstream_resp.headers().iter() {
|
|
// Skip headers that tiny_http manages itself.
|
|
if matches!(
|
|
name.as_str(),
|
|
"content-length" | "transfer-encoding" | "connection" | "trailer" | "upgrade"
|
|
) {
|
|
continue;
|
|
}
|
|
|
|
if let Ok(header) = Header::from_bytes(name.as_str().as_bytes(), value.as_bytes()) {
|
|
response_headers.push(header);
|
|
}
|
|
}
|
|
|
|
let content_length = upstream_resp.content_length().and_then(|len| {
|
|
if len <= usize::MAX as u64 {
|
|
Some(len as usize)
|
|
} else {
|
|
None
|
|
}
|
|
});
|
|
|
|
let response = Response::new(
|
|
StatusCode(status.as_u16()),
|
|
response_headers,
|
|
upstream_resp,
|
|
content_length,
|
|
None,
|
|
);
|
|
|
|
let _ = req.respond(response);
|
|
Ok(())
|
|
}
|