This PR introduces an `--upstream-url` option to the proxy CLI that
determines the URL that Responses API requests should be forwarded to.
To preserve existing behavior, the default value is
`"https://api.openai.com/v1/responses"`.
The motivation for this change is that the [Codex GitHub
Action](https://github.com/openai/codex-action) should support those who
use the OpenAI Responses API via Azure. Relevant issues:
- https://github.com/openai/codex-action/issues/28
- https://github.com/openai/codex-action/issues/38
- https://github.com/openai/codex-action/pull/44
Though rather than introduce a bunch of new Azure-specific logic in the
action as https://github.com/openai/codex-action/pull/44 proposes, we
should leverage our Responses API proxy to get the _hardening_ benefits
it provides:
https://github.com/openai/codex/blob/d5853d9c47/codex-rs/responses-api-proxy/README.md#hardening-details
This PR should make this straightforward to incorporate in the action.
To see how the updated version of the action would consume these new
options, see https://github.com/openai/codex-action/pull/47.
238 lines
7.3 KiB
Rust
238 lines
7.3 KiB
Rust
use std::fs::File;
|
|
use std::fs::{self};
|
|
use std::io::Write;
|
|
use std::net::SocketAddr;
|
|
use std::net::TcpListener;
|
|
use std::path::Path;
|
|
use std::path::PathBuf;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
|
|
use anyhow::Context;
|
|
use anyhow::Result;
|
|
use anyhow::anyhow;
|
|
use clap::Parser;
|
|
use reqwest::Url;
|
|
use reqwest::blocking::Client;
|
|
use reqwest::header::AUTHORIZATION;
|
|
use reqwest::header::HOST;
|
|
use reqwest::header::HeaderMap;
|
|
use reqwest::header::HeaderName;
|
|
use reqwest::header::HeaderValue;
|
|
use serde::Serialize;
|
|
use tiny_http::Header;
|
|
use tiny_http::Method;
|
|
use tiny_http::Request;
|
|
use tiny_http::Response;
|
|
use tiny_http::Server;
|
|
use tiny_http::StatusCode;
|
|
|
|
mod read_api_key;
|
|
use read_api_key::read_auth_header_from_stdin;
|
|
|
|
/// CLI arguments for the proxy.
|
|
#[derive(Debug, Clone, Parser)]
|
|
#[command(name = "responses-api-proxy", about = "Minimal OpenAI responses proxy")]
|
|
pub struct Args {
|
|
/// Port to listen on. If not set, an ephemeral port is used.
|
|
#[arg(long)]
|
|
pub port: Option<u16>,
|
|
|
|
/// Path to a JSON file to write startup info (single line). Includes {"port": <u16>}.
|
|
#[arg(long, value_name = "FILE")]
|
|
pub server_info: Option<PathBuf>,
|
|
|
|
/// Enable HTTP shutdown endpoint at GET /shutdown
|
|
#[arg(long)]
|
|
pub http_shutdown: bool,
|
|
|
|
/// Absolute URL the proxy should forward requests to (defaults to OpenAI).
|
|
#[arg(long, default_value = "https://api.openai.com/v1/responses")]
|
|
pub upstream_url: String,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct ServerInfo {
|
|
port: u16,
|
|
pid: u32,
|
|
}
|
|
|
|
struct ForwardConfig {
|
|
upstream_url: Url,
|
|
host_header: HeaderValue,
|
|
}
|
|
|
|
/// Entry point for the library main, for parity with other crates.
|
|
pub fn run_main(args: Args) -> Result<()> {
|
|
let auth_header = read_auth_header_from_stdin()?;
|
|
|
|
let upstream_url = Url::parse(&args.upstream_url).context("parsing --upstream-url")?;
|
|
let host = match (upstream_url.host_str(), upstream_url.port()) {
|
|
(Some(host), Some(port)) => format!("{host}:{port}"),
|
|
(Some(host), None) => host.to_string(),
|
|
_ => return Err(anyhow!("upstream URL must include a host")),
|
|
};
|
|
let host_header =
|
|
HeaderValue::from_str(&host).context("constructing Host header from upstream URL")?;
|
|
|
|
let forward_config = Arc::new(ForwardConfig {
|
|
upstream_url,
|
|
host_header,
|
|
});
|
|
|
|
let (listener, bound_addr) = bind_listener(args.port)?;
|
|
if let Some(path) = args.server_info.as_ref() {
|
|
write_server_info(path, bound_addr.port())?;
|
|
}
|
|
let server = Server::from_listener(listener, None)
|
|
.map_err(|err| anyhow!("creating HTTP server: {err}"))?;
|
|
let client = Arc::new(
|
|
Client::builder()
|
|
// Disable reqwest's 30s default so long-lived response streams keep flowing.
|
|
.timeout(None::<Duration>)
|
|
.build()
|
|
.context("building reqwest client")?,
|
|
);
|
|
|
|
eprintln!("responses-api-proxy listening on {bound_addr}");
|
|
|
|
let http_shutdown = args.http_shutdown;
|
|
for request in server.incoming_requests() {
|
|
let client = client.clone();
|
|
let forward_config = forward_config.clone();
|
|
std::thread::spawn(move || {
|
|
if http_shutdown && request.method() == &Method::Get && request.url() == "/shutdown" {
|
|
let _ = request.respond(Response::new_empty(StatusCode(200)));
|
|
std::process::exit(0);
|
|
}
|
|
|
|
if let Err(e) = forward_request(&client, auth_header, &forward_config, request) {
|
|
eprintln!("forwarding error: {e}");
|
|
}
|
|
});
|
|
}
|
|
|
|
Err(anyhow!("server stopped unexpectedly"))
|
|
}
|
|
|
|
fn bind_listener(port: Option<u16>) -> Result<(TcpListener, SocketAddr)> {
|
|
let addr = SocketAddr::from(([127, 0, 0, 1], port.unwrap_or(0)));
|
|
let listener = TcpListener::bind(addr).with_context(|| format!("failed to bind {addr}"))?;
|
|
let bound = listener.local_addr().context("failed to read local_addr")?;
|
|
Ok((listener, bound))
|
|
}
|
|
|
|
fn write_server_info(path: &Path, port: u16) -> Result<()> {
|
|
if let Some(parent) = path.parent()
|
|
&& !parent.as_os_str().is_empty()
|
|
{
|
|
fs::create_dir_all(parent)?;
|
|
}
|
|
|
|
let info = ServerInfo {
|
|
port,
|
|
pid: std::process::id(),
|
|
};
|
|
let mut data = serde_json::to_string(&info)?;
|
|
data.push('\n');
|
|
let mut f = File::create(path)?;
|
|
f.write_all(data.as_bytes())?;
|
|
Ok(())
|
|
}
|
|
|
|
fn forward_request(
|
|
client: &Client,
|
|
auth_header: &'static str,
|
|
config: &ForwardConfig,
|
|
mut req: Request,
|
|
) -> Result<()> {
|
|
// Only allow POST /v1/responses exactly, no query string.
|
|
let method = req.method().clone();
|
|
let url_path = req.url().to_string();
|
|
let allow = method == Method::Post && url_path == "/v1/responses";
|
|
|
|
if !allow {
|
|
let resp = Response::new_empty(StatusCode(403));
|
|
let _ = req.respond(resp);
|
|
return Ok(());
|
|
}
|
|
|
|
// Read request body
|
|
let mut body = Vec::new();
|
|
let mut reader = req.as_reader();
|
|
std::io::Read::read_to_end(&mut reader, &mut body)?;
|
|
|
|
// Build headers for upstream, forwarding everything from the incoming
|
|
// request except Authorization (we replace it below).
|
|
let mut headers = HeaderMap::new();
|
|
for header in req.headers() {
|
|
let name_ascii = header.field.as_str();
|
|
let lower = name_ascii.to_ascii_lowercase();
|
|
if lower.as_str() == "authorization" || lower.as_str() == "host" {
|
|
continue;
|
|
}
|
|
|
|
let header_name = match HeaderName::from_bytes(lower.as_bytes()) {
|
|
Ok(name) => name,
|
|
Err(_) => continue,
|
|
};
|
|
if let Ok(value) = HeaderValue::from_bytes(header.value.as_bytes()) {
|
|
headers.append(header_name, value);
|
|
}
|
|
}
|
|
|
|
// As part of our effort to to keep `auth_header` secret, we use a
|
|
// combination of `from_static()` and `set_sensitive(true)`.
|
|
let mut auth_header_value = HeaderValue::from_static(auth_header);
|
|
auth_header_value.set_sensitive(true);
|
|
headers.insert(AUTHORIZATION, auth_header_value);
|
|
|
|
headers.insert(HOST, config.host_header.clone());
|
|
|
|
let upstream_resp = client
|
|
.post(config.upstream_url.clone())
|
|
.headers(headers)
|
|
.body(body)
|
|
.send()
|
|
.context("forwarding request to upstream")?;
|
|
|
|
// We have to create an adapter between a `reqwest::blocking::Response`
|
|
// and a `tiny_http::Response`. Fortunately, `reqwest::blocking::Response`
|
|
// implements `Read`, so we can use it directly as the body of the
|
|
// `tiny_http::Response`.
|
|
let status = upstream_resp.status();
|
|
let mut response_headers = Vec::new();
|
|
for (name, value) in upstream_resp.headers().iter() {
|
|
// Skip headers that tiny_http manages itself.
|
|
if matches!(
|
|
name.as_str(),
|
|
"content-length" | "transfer-encoding" | "connection" | "trailer" | "upgrade"
|
|
) {
|
|
continue;
|
|
}
|
|
|
|
if let Ok(header) = Header::from_bytes(name.as_str().as_bytes(), value.as_bytes()) {
|
|
response_headers.push(header);
|
|
}
|
|
}
|
|
|
|
let content_length = upstream_resp.content_length().and_then(|len| {
|
|
if len <= usize::MAX as u64 {
|
|
Some(len as usize)
|
|
} else {
|
|
None
|
|
}
|
|
});
|
|
|
|
let response = Response::new(
|
|
StatusCode(status.as_u16()),
|
|
response_headers,
|
|
upstream_resp,
|
|
content_length,
|
|
None,
|
|
);
|
|
|
|
let _ = req.respond(response);
|
|
Ok(())
|
|
}
|