feat: support the chat completions API in the Rust CLI (#862)
This is a substantial PR to add support for the chat completions API, which in turn makes it possible to use non-OpenAI model providers (just like in the TypeScript CLI):

* It moves a number of structs from `client.rs` to `client_common.rs` so they can be shared.
* It introduces support for the chat completions API in `chat_completions.rs`.
* It updates `ModelProviderInfo` so that `env_key` is `Option<String>` instead of `String` (e.g., for ollama) and adds a `wire_api` field.
* It updates `client.rs` to choose between `stream_responses()` and `stream_chat_completions()` based on the `wire_api` for the `ModelProviderInfo`.
* It updates the `exec` and TUI CLIs to no longer fail if the `OPENAI_API_KEY` environment variable is not set.
* It updates the TUI so that `EventMsg::Error` is displayed more prominently when it occurs, which is particularly important now that users must be alerted to the `CodexErr::EnvVar` variant.
* `CodexErr::EnvVar` was updated to include an optional `instructions` field so we can preserve the behavior where we direct users to https://platform.openai.com if `OPENAI_API_KEY` is not set.
* Cleaned up the "welcome message" in the TUI to ensure the model provider is displayed.
* Updated the docs in `codex-rs/README.md`.

To exercise the chat completions API with OpenAI models, I added the following to my `config.toml` (the sketch after this message shows one way these fields might map onto `ModelProviderInfo`):

```toml
model = "gpt-4o"
model_provider = "openai-chat-completions"

[model_providers.openai-chat-completions]
name = "OpenAI using Chat Completions"
base_url = "https://api.openai.com/v1"
env_key = "OPENAI_API_KEY"
wire_api = "chat"
```

To test a non-OpenAI provider, I installed ollama with mistral locally on my Mac because ChatGPT said that would be a good match for my hardware:

```shell
brew install ollama
ollama serve
ollama pull mistral
```

Then I added the following to my `~/.codex/config.toml`:

```toml
model = "mistral"
model_provider = "ollama"
```

Note this code could certainly use more test coverage, but I want to get this in so folks can start playing with it. For reference, I believe https://github.com/openai/codex/pull/247 was roughly the comparable PR on the TypeScript side.
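For reviewers, here is a minimal sketch of what the updated `ModelProviderInfo` might look like based on the description above. Only `name`, `base_url`, `env_key`, `wire_api`, and the `Responses`/`Chat` split are taken from this PR; the derives, error type, and helper body are assumptions for illustration, not the actual implementation:

```rust
// Hypothetical sketch only -- not the code in this PR.
use serde::Deserialize;

/// Which wire protocol a provider speaks (maps to `wire_api` in config.toml).
#[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum WireApi {
    /// OpenAI Responses API.
    Responses,
    /// OpenAI-compatible Chat Completions API.
    Chat,
}

#[derive(Debug, Clone, Deserialize)]
pub struct ModelProviderInfo {
    pub name: String,
    pub base_url: String,
    /// Optional so key-less providers such as ollama can omit it.
    pub env_key: Option<String>,
    pub wire_api: WireApi,
}

impl ModelProviderInfo {
    /// Reads the API key from the configured environment variable;
    /// `Ok(None)` means the provider does not require a key.
    pub fn api_key(&self) -> Result<Option<String>, std::env::VarError> {
        match &self.env_key {
            Some(var) => std::env::var(var).map(Some),
            None => Ok(None),
        }
    }
}
```

With a shape like this, the `[model_providers.openai-chat-completions]` table above deserializes directly, and ollama simply leaves `env_key` unset.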
@@ -1,11 +1,7 @@
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::io::BufRead;
use std::path::Path;
use std::pin::Pin;
use std::sync::LazyLock;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;

use bytes::Bytes;
@@ -23,66 +19,22 @@ use tracing::debug;
use tracing::trace;
use tracing::warn;

use crate::chat_completions::stream_chat_completions;
use crate::client_common::Payload;
use crate::client_common::Prompt;
use crate::client_common::Reasoning;
use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
use crate::error::CodexErr;
use crate::error::Result;
use crate::flags::CODEX_RS_SSE_FIXTURE;
use crate::flags::OPENAI_REQUEST_MAX_RETRIES;
use crate::flags::OPENAI_STREAM_IDLE_TIMEOUT_MS;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
use crate::models::ResponseItem;
use crate::util::backoff;

/// API request payload for a single model turn.
#[derive(Default, Debug, Clone)]
pub struct Prompt {
    /// Conversation context input items.
    pub input: Vec<ResponseItem>,
    /// Optional previous response ID (when storage is enabled).
    pub prev_id: Option<String>,
    /// Optional initial instructions (only sent on first turn).
    pub instructions: Option<String>,
    /// Whether to store response on server side (disable_response_storage = !store).
    pub store: bool,

    /// Additional tools sourced from external MCP servers. Note each key is
    /// the "fully qualified" tool name (i.e., prefixed with the server name),
    /// which should be reported to the model in place of Tool::name.
    pub extra_tools: HashMap<String, mcp_types::Tool>,
}

#[derive(Debug)]
pub enum ResponseEvent {
    OutputItemDone(ResponseItem),
    Completed { response_id: String },
}

#[derive(Debug, Serialize)]
struct Payload<'a> {
    model: &'a str,
    #[serde(skip_serializing_if = "Option::is_none")]
    instructions: Option<&'a String>,
    // TODO(mbolin): ResponseItem::Other should not be serialized. Currently,
    // we code defensively to avoid this case, but perhaps we should use a
    // separate enum for serialization.
    input: &'a Vec<ResponseItem>,
    tools: &'a [serde_json::Value],
    tool_choice: &'static str,
    parallel_tool_calls: bool,
    reasoning: Option<Reasoning>,
    #[serde(skip_serializing_if = "Option::is_none")]
    previous_response_id: Option<String>,
    /// true when using the Responses API.
    store: bool,
    stream: bool,
}

#[derive(Debug, Serialize)]
struct Reasoning {
    effort: &'static str,
    #[serde(skip_serializing_if = "Option::is_none")]
    generate_summary: Option<bool>,
}

/// When serialized as JSON, this produces a valid "Tool" in the OpenAI
/// Responses API.
#[derive(Debug, Serialize)]
@@ -152,7 +104,20 @@ impl ModelClient {
        }
    }

    pub async fn stream(&mut self, prompt: &Prompt) -> Result<ResponseStream> {
    /// Dispatches to either the Responses or Chat implementation depending on
    /// the provider config. Public callers always invoke `stream()` – the
    /// specialised helpers are private to avoid accidental misuse.
    pub async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
        match self.provider.wire_api {
            WireApi::Responses => self.stream_responses(prompt).await,
            WireApi::Chat => {
                stream_chat_completions(prompt, &self.model, &self.client, &self.provider).await
            }
        }
    }

    /// Implementation for the OpenAI *Responses* experimental API.
    async fn stream_responses(&self, prompt: &Prompt) -> Result<ResponseStream> {
        if let Some(path) = &*CODEX_RS_SSE_FIXTURE {
            // short circuit for tests
            warn!(path, "Streaming from fixture");
@@ -202,8 +167,8 @@ impl ModelClient {

        let api_key = self
            .provider
            .api_key()
            .ok_or_else(|| crate::error::CodexErr::EnvVar("API_KEY"))?;
            .api_key()?
            .expect("Repsones API requires an API key");
        let res = self
            .client
            .post(&url)
@@ -396,18 +361,6 @@ where
    }
}

pub struct ResponseStream {
    rx_event: mpsc::Receiver<Result<ResponseEvent>>,
}

impl Stream for ResponseStream {
    type Item = Result<ResponseEvent>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.rx_event.poll_recv(cx)
    }
}

/// used in tests to stream from a text SSE file
async fn stream_from_fixture(path: impl AsRef<Path>) -> Result<ResponseStream> {
    let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(16);
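For anyone poking at this locally, here is a rough sketch of how a caller might consume the `ResponseStream` returned by `ModelClient::stream()`. The `drain_stream` helper below is illustrative only (its name, logging, and error handling are assumptions), not code from this diff:

```rust
// Illustrative only: consumes a ResponseStream via futures::StreamExt.
use futures::StreamExt;

async fn drain_stream(mut stream: ResponseStream) -> Result<()> {
    while let Some(event) = stream.next().await {
        match event? {
            // Emitted once per completed output item (message, tool call, ...).
            ResponseEvent::OutputItemDone(item) => {
                tracing::debug!("output item: {item:?}");
            }
            // Emitted when the turn finishes; carries the server's response id.
            ResponseEvent::Completed { response_id } => {
                tracing::debug!("turn completed: {response_id}");
                break;
            }
        }
    }
    Ok(())
}
```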