feat: support the chat completions API in the Rust CLI (#862)
This is a substantial PR to add support for the chat completions API, which in turn makes it possible to use non-OpenAI model providers (just like in the TypeScript CLI): * It moves a number of structs from `client.rs` to `client_common.rs` so they can be shared. * It introduces support for the chat completions API in `chat_completions.rs`. * It updates `ModelProviderInfo` so that `env_key` is `Option<String>` instead of `String` (for e.g., ollama) and adds a `wire_api` field * It updates `client.rs` to choose between `stream_responses()` and `stream_chat_completions()` based on the `wire_api` for the `ModelProviderInfo` * It updates the `exec` and TUI CLIs to no longer fail if the `OPENAI_API_KEY` environment variable is not set * It updates the TUI so that `EventMsg::Error` is displayed more prominently when it occurs, particularly now that it is important to alert users to the `CodexErr::EnvVar` variant. * `CodexErr::EnvVar` was updated to include an optional `instructions` field so we can preserve the behavior where we direct users to https://platform.openai.com if `OPENAI_API_KEY` is not set. * Cleaned up the "welcome message" in the TUI to ensure the model provider is displayed. * Updated the docs in `codex-rs/README.md`. To exercise the chat completions API from OpenAI models, I added the following to my `config.toml`: ```toml model = "gpt-4o" model_provider = "openai-chat-completions" [model_providers.openai-chat-completions] name = "OpenAI using Chat Completions" base_url = "https://api.openai.com/v1" env_key = "OPENAI_API_KEY" wire_api = "chat" ``` Though to test a non-OpenAI provider, I installed ollama with mistral locally on my Mac because ChatGPT said that would be a good match for my hardware: ```shell brew install ollama ollama serve ollama pull mistral ``` Then I added the following to my `~/.codex/config.toml`: ```toml model = "mistral" model_provider = "ollama" ``` Note this code could certainly use more test coverage, but I want to get this in so folks can start playing with it. For reference, I believe https://github.com/openai/codex/pull/247 was roughly the comparable PR on the TypeScript side.
This commit is contained in:
@@ -31,10 +31,13 @@ use tracing::info;
|
||||
use tracing::trace;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::WireApi;
|
||||
use crate::chat_completions::AggregateStreamExt;
|
||||
use crate::client::ModelClient;
|
||||
use crate::client::Prompt;
|
||||
use crate::client::ResponseEvent;
|
||||
use crate::client_common::Prompt;
|
||||
use crate::client_common::ResponseEvent;
|
||||
use crate::config::Config;
|
||||
use crate::conversation_history::ConversationHistory;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result as CodexResult;
|
||||
use crate::exec::ExecParams;
|
||||
@@ -65,7 +68,6 @@ use crate::safety::assess_command_safety;
|
||||
use crate::safety::assess_patch_safety;
|
||||
use crate::user_notification::UserNotification;
|
||||
use crate::util::backoff;
|
||||
use crate::zdr_transcript::ZdrTranscript;
|
||||
|
||||
/// The high-level interface to the Codex system.
|
||||
/// It operates as a queue pair where you send submissions and receive events.
|
||||
@@ -181,7 +183,7 @@ struct State {
|
||||
previous_response_id: Option<String>,
|
||||
pending_approvals: HashMap<String, oneshot::Sender<ReviewDecision>>,
|
||||
pending_input: Vec<ResponseInputItem>,
|
||||
zdr_transcript: Option<ZdrTranscript>,
|
||||
zdr_transcript: Option<ConversationHistory>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
@@ -416,11 +418,15 @@ impl Drop for Session {
|
||||
}
|
||||
|
||||
impl State {
|
||||
pub fn partial_clone(&self) -> Self {
|
||||
pub fn partial_clone(&self, retain_zdr_transcript: bool) -> Self {
|
||||
Self {
|
||||
approved_commands: self.approved_commands.clone(),
|
||||
previous_response_id: self.previous_response_id.clone(),
|
||||
zdr_transcript: self.zdr_transcript.clone(),
|
||||
zdr_transcript: if retain_zdr_transcript {
|
||||
self.zdr_transcript.clone()
|
||||
} else {
|
||||
None
|
||||
},
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
@@ -534,14 +540,19 @@ async fn submission_loop(
|
||||
let client = ModelClient::new(model.clone(), provider.clone());
|
||||
|
||||
// abort any current running session and clone its state
|
||||
let retain_zdr_transcript =
|
||||
record_conversation_history(disable_response_storage, provider.wire_api);
|
||||
let state = match sess.take() {
|
||||
Some(sess) => {
|
||||
sess.abort();
|
||||
sess.state.lock().unwrap().partial_clone()
|
||||
sess.state
|
||||
.lock()
|
||||
.unwrap()
|
||||
.partial_clone(retain_zdr_transcript)
|
||||
}
|
||||
None => State {
|
||||
zdr_transcript: if disable_response_storage {
|
||||
Some(ZdrTranscript::new())
|
||||
zdr_transcript: if retain_zdr_transcript {
|
||||
Some(ConversationHistory::new())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
@@ -670,21 +681,35 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
let pending_input = sess.get_pending_input().into_iter().map(ResponseItem::from);
|
||||
net_new_turn_input.extend(pending_input);
|
||||
|
||||
// Persist only the net-new items of this turn to the rollout.
|
||||
sess.record_rollout_items(&net_new_turn_input).await;
|
||||
|
||||
// Construct the input that we will send to the model. When using the
|
||||
// Chat completions API (or ZDR clients), the model needs the full
|
||||
// conversation history on each turn. The rollout file, however, should
|
||||
// only record the new items that originated in this turn so that it
|
||||
// represents an append-only log without duplicates.
|
||||
let turn_input: Vec<ResponseItem> =
|
||||
if let Some(transcript) = sess.state.lock().unwrap().zdr_transcript.as_mut() {
|
||||
// If we are using ZDR, we need to send the transcript with every turn.
|
||||
let mut full_transcript = transcript.contents();
|
||||
full_transcript.extend(net_new_turn_input.clone());
|
||||
// If we are using Chat/ZDR, we need to send the transcript with every turn.
|
||||
|
||||
// 1. Build up the conversation history for the next turn.
|
||||
let full_transcript = [transcript.contents(), net_new_turn_input.clone()].concat();
|
||||
|
||||
// 2. Update the in-memory transcript so that future turns
|
||||
// include these items as part of the history.
|
||||
transcript.record_items(net_new_turn_input);
|
||||
|
||||
// Note that `transcript.record_items()` does some filtering
|
||||
// such that `full_transcript` may include items that were
|
||||
// excluded from `transcript`.
|
||||
full_transcript
|
||||
} else {
|
||||
// Responses API path – we can just send the new items and
|
||||
// record the same.
|
||||
net_new_turn_input
|
||||
};
|
||||
|
||||
// Persist the input part of the turn to the rollout (user messages /
|
||||
// function_call_output from previous step).
|
||||
sess.record_rollout_items(&turn_input).await;
|
||||
|
||||
let turn_input_messages: Vec<String> = turn_input
|
||||
.iter()
|
||||
.filter_map(|item| match item {
|
||||
@@ -794,6 +819,7 @@ async fn run_turn(
|
||||
match try_run_turn(sess, &sub_id, &prompt).await {
|
||||
Ok(output) => return Ok(output),
|
||||
Err(CodexErr::Interrupted) => return Err(CodexErr::Interrupted),
|
||||
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
||||
Err(e) => {
|
||||
if retries < *OPENAI_STREAM_MAX_RETRIES {
|
||||
retries += 1;
|
||||
@@ -838,7 +864,7 @@ async fn try_run_turn(
|
||||
sub_id: &str,
|
||||
prompt: &Prompt,
|
||||
) -> CodexResult<Vec<ProcessedResponseItem>> {
|
||||
let mut stream = sess.client.clone().stream(prompt).await?;
|
||||
let mut stream = sess.client.clone().stream(prompt).await?.aggregate();
|
||||
|
||||
// Buffer all the incoming messages from the stream first, then execute them.
|
||||
// If we execute a function call in the middle of handling the stream, it can time out.
|
||||
@@ -1612,3 +1638,15 @@ fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -> Option<St
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// See [`ConversationHistory`] for details.
|
||||
fn record_conversation_history(disable_response_storage: bool, wire_api: WireApi) -> bool {
|
||||
if disable_response_storage {
|
||||
return true;
|
||||
}
|
||||
|
||||
match wire_api {
|
||||
WireApi::Responses => false,
|
||||
WireApi::Chat => true,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user