fix: chat completions API now also passes tools along (#1167)
Prior to this PR, there were two big misses in `chat_completions.rs`: 1. The loop in `stream_chat_completions()` was only including items of type `ResponseItem::Message` when building up the `"messages"` JSON for the `POST` request to the `chat/completions` endpoint. This fixes things by ensuring other variants (`FunctionCall`, `LocalShellCall`, and `FunctionCallOutput`) are included, as well. 2. In `process_chat_sse()`, we were not recording tool calls and were only emitting items of type `ResponseEvent::OutputItemDone(ResponseItem::Message)` to the stream. Now we introduce `FunctionCallState`, which is used to accumulate the `delta`s of type `tool_calls`, so we can ultimately emit a `ResponseItem::FunctionCall`, when appropriate. While function calling now appears to work for chat completions with my local testing, I believe that there are still edge cases that are not covered and that this codepath would benefit from a battery of integration tests. (As part of that further cleanup, we should also work to support streaming responses in the UI.) The other important part of this PR is some cleanup in `core/src/codex.rs`. In particular, it was hard to reason about how `run_task()` was building up the list of messages to include in a request across the various cases: - Responses API - Chat Completions API - Responses API used in concert with ZDR I like to think things are a bit cleaner now where: - `zdr_transcript` (if present) contains all messages in the history of the conversation, which includes function call outputs that have not been sent back to the model yet - `pending_input` includes any messages the user has submitted while the turn is in flight that need to be injected as part of the next `POST` to the model - `input_for_next_turn` includes the tool call outputs that have not been sent back to the model yet
This commit is contained in:
@@ -20,6 +20,7 @@ use codex_apply_patch::MaybeApplyPatchVerified;
|
||||
use codex_apply_patch::maybe_parse_apply_patch_verified;
|
||||
use codex_apply_patch::print_summary;
|
||||
use futures::prelude::*;
|
||||
use mcp_types::CallToolResult;
|
||||
use serde::Serialize;
|
||||
use serde_json;
|
||||
use tokio::sync::Notify;
|
||||
@@ -295,6 +296,17 @@ impl Session {
|
||||
state.approved_commands.insert(cmd);
|
||||
}
|
||||
|
||||
/// Records items to both the rollout and the chat completions/ZDR
|
||||
/// transcript, if enabled.
|
||||
async fn record_conversation_items(&self, items: &[ResponseItem]) {
|
||||
debug!("Recording items for conversation: {items:?}");
|
||||
self.record_rollout_items(items).await;
|
||||
|
||||
if let Some(transcript) = self.state.lock().unwrap().zdr_transcript.as_mut() {
|
||||
transcript.record_items(items);
|
||||
}
|
||||
}
|
||||
|
||||
/// Append the given items to the session's rollout transcript (if enabled)
|
||||
/// and persist them to disk.
|
||||
async fn record_rollout_items(&self, items: &[ResponseItem]) {
|
||||
@@ -388,7 +400,7 @@ impl Session {
|
||||
tool: &str,
|
||||
arguments: Option<serde_json::Value>,
|
||||
timeout: Option<Duration>,
|
||||
) -> anyhow::Result<mcp_types::CallToolResult> {
|
||||
) -> anyhow::Result<CallToolResult> {
|
||||
self.mcp_connection_manager
|
||||
.call_tool(server, tool, arguments, timeout)
|
||||
.await
|
||||
@@ -760,6 +772,19 @@ async fn submission_loop(
|
||||
debug!("Agent loop exited");
|
||||
}
|
||||
|
||||
/// Takes a user message as input and runs a loop where, at each turn, the model
|
||||
/// replies with either:
|
||||
///
|
||||
/// - requested function calls
|
||||
/// - an assistant message
|
||||
///
|
||||
/// While it is possible for the model to return multiple of these items in a
|
||||
/// single turn, in practice, we generally one item per turn:
|
||||
///
|
||||
/// - If the model requests a function call, we execute it and send the output
|
||||
/// back to the model in the next turn.
|
||||
/// - If the model sends only an assistant message, we record it in the
|
||||
/// conversation history and consider the task complete.
|
||||
async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
if input.is_empty() {
|
||||
return;
|
||||
@@ -772,10 +797,14 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut pending_response_input: Vec<ResponseInputItem> = vec![ResponseInputItem::from(input)];
|
||||
let initial_input_for_turn = ResponseInputItem::from(input);
|
||||
sess.record_conversation_items(&[initial_input_for_turn.clone().into()])
|
||||
.await;
|
||||
|
||||
let mut input_for_next_turn: Vec<ResponseInputItem> = vec![initial_input_for_turn];
|
||||
let last_agent_message: Option<String>;
|
||||
loop {
|
||||
let mut net_new_turn_input = pending_response_input
|
||||
let mut net_new_turn_input = input_for_next_turn
|
||||
.drain(..)
|
||||
.map(ResponseItem::from)
|
||||
.collect::<Vec<_>>();
|
||||
@@ -783,11 +812,12 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
// Note that pending_input would be something like a message the user
|
||||
// submitted through the UI while the model was running. Though the UI
|
||||
// may support this, the model might not.
|
||||
let pending_input = sess.get_pending_input().into_iter().map(ResponseItem::from);
|
||||
net_new_turn_input.extend(pending_input);
|
||||
|
||||
// Persist only the net-new items of this turn to the rollout.
|
||||
sess.record_rollout_items(&net_new_turn_input).await;
|
||||
let pending_input = sess
|
||||
.get_pending_input()
|
||||
.into_iter()
|
||||
.map(ResponseItem::from)
|
||||
.collect::<Vec<ResponseItem>>();
|
||||
sess.record_conversation_items(&pending_input).await;
|
||||
|
||||
// Construct the input that we will send to the model. When using the
|
||||
// Chat completions API (or ZDR clients), the model needs the full
|
||||
@@ -796,20 +826,24 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
// represents an append-only log without duplicates.
|
||||
let turn_input: Vec<ResponseItem> =
|
||||
if let Some(transcript) = sess.state.lock().unwrap().zdr_transcript.as_mut() {
|
||||
// If we are using Chat/ZDR, we need to send the transcript with every turn.
|
||||
|
||||
// 1. Build up the conversation history for the next turn.
|
||||
let full_transcript = [transcript.contents(), net_new_turn_input.clone()].concat();
|
||||
|
||||
// 2. Update the in-memory transcript so that future turns
|
||||
// include these items as part of the history.
|
||||
transcript.record_items(&net_new_turn_input);
|
||||
|
||||
// Note that `transcript.record_items()` does some filtering
|
||||
// such that `full_transcript` may include items that were
|
||||
// excluded from `transcript`.
|
||||
full_transcript
|
||||
// If we are using Chat/ZDR, we need to send the transcript with
|
||||
// every turn. By induction, `transcript` already contains:
|
||||
// - The `input` that kicked off this task.
|
||||
// - Each `ResponseItem` that was recorded in the previous turn.
|
||||
// - Each response to a `ResponseItem` (in practice, the only
|
||||
// response type we seem to have is `FunctionCallOutput`).
|
||||
//
|
||||
// The only thing the `transcript` does not contain is the
|
||||
// `pending_input` that was injected while the model was
|
||||
// running. We need to add that to the conversation history
|
||||
// so that the model can see it in the next turn.
|
||||
[transcript.contents(), pending_input].concat()
|
||||
} else {
|
||||
// In practice, net_new_turn_input should contain only:
|
||||
// - User messages
|
||||
// - Outputs for function calls requested by the model
|
||||
net_new_turn_input.extend(pending_input);
|
||||
|
||||
// Responses API path – we can just send the new items and
|
||||
// record the same.
|
||||
net_new_turn_input
|
||||
@@ -830,29 +864,86 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
.collect();
|
||||
match run_turn(&sess, sub_id.clone(), turn_input).await {
|
||||
Ok(turn_output) => {
|
||||
let (items, responses): (Vec<_>, Vec<_>) = turn_output
|
||||
.into_iter()
|
||||
.map(|p| (p.item, p.response))
|
||||
.unzip();
|
||||
let responses = responses
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.collect::<Vec<ResponseInputItem>>();
|
||||
let mut items_to_record_in_conversation_history = Vec::<ResponseItem>::new();
|
||||
let mut responses = Vec::<ResponseInputItem>::new();
|
||||
for processed_response_item in turn_output {
|
||||
let ProcessedResponseItem { item, response } = processed_response_item;
|
||||
match (&item, &response) {
|
||||
(ResponseItem::Message { role, .. }, None) if role == "assistant" => {
|
||||
// If the model returned a message, we need to record it.
|
||||
items_to_record_in_conversation_history.push(item);
|
||||
}
|
||||
(
|
||||
ResponseItem::LocalShellCall { .. },
|
||||
Some(ResponseInputItem::FunctionCallOutput { call_id, output }),
|
||||
) => {
|
||||
items_to_record_in_conversation_history.push(item);
|
||||
items_to_record_in_conversation_history.push(
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
output: output.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
(
|
||||
ResponseItem::FunctionCall { .. },
|
||||
Some(ResponseInputItem::FunctionCallOutput { call_id, output }),
|
||||
) => {
|
||||
items_to_record_in_conversation_history.push(item);
|
||||
items_to_record_in_conversation_history.push(
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
output: output.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
(
|
||||
ResponseItem::FunctionCall { .. },
|
||||
Some(ResponseInputItem::McpToolCallOutput { call_id, result }),
|
||||
) => {
|
||||
items_to_record_in_conversation_history.push(item);
|
||||
let (content, success): (String, Option<bool>) = match result {
|
||||
Ok(CallToolResult { content, is_error }) => {
|
||||
match serde_json::to_string(content) {
|
||||
Ok(content) => (content, *is_error),
|
||||
Err(e) => {
|
||||
warn!("Failed to serialize MCP tool call output: {e}");
|
||||
(e.to_string(), Some(true))
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => (e.clone(), Some(true)),
|
||||
};
|
||||
items_to_record_in_conversation_history.push(
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: call_id.clone(),
|
||||
output: FunctionCallOutputPayload { content, success },
|
||||
},
|
||||
);
|
||||
}
|
||||
(ResponseItem::Reasoning { .. }, None) => {
|
||||
// Omit from conversation history.
|
||||
}
|
||||
_ => {
|
||||
warn!("Unexpected response item: {item:?} with response: {response:?}");
|
||||
}
|
||||
};
|
||||
if let Some(response) = response {
|
||||
responses.push(response);
|
||||
}
|
||||
}
|
||||
|
||||
// Only attempt to take the lock if there is something to record.
|
||||
if !items.is_empty() {
|
||||
// First persist model-generated output to the rollout file – this only borrows.
|
||||
sess.record_rollout_items(&items).await;
|
||||
|
||||
// For ZDR we also need to keep a transcript clone.
|
||||
if let Some(transcript) = sess.state.lock().unwrap().zdr_transcript.as_mut() {
|
||||
transcript.record_items(&items);
|
||||
}
|
||||
if !items_to_record_in_conversation_history.is_empty() {
|
||||
sess.record_conversation_items(&items_to_record_in_conversation_history)
|
||||
.await;
|
||||
}
|
||||
|
||||
if responses.is_empty() {
|
||||
debug!("Turn completed");
|
||||
last_agent_message = get_last_assistant_message_from_turn(&items);
|
||||
last_agent_message = get_last_assistant_message_from_turn(
|
||||
&items_to_record_in_conversation_history,
|
||||
);
|
||||
sess.maybe_notify(UserNotification::AgentTurnComplete {
|
||||
turn_id: sub_id.clone(),
|
||||
input_messages: turn_input_messages,
|
||||
@@ -861,7 +952,7 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
|
||||
break;
|
||||
}
|
||||
|
||||
pending_response_input = responses;
|
||||
input_for_next_turn = responses;
|
||||
}
|
||||
Err(e) => {
|
||||
info!("Turn error: {e:#}");
|
||||
|
||||
Reference in New Issue
Block a user