fix: chat completions API now also passes tools along (#1167)

Prior to this PR, there were two big misses in `chat_completions.rs`:

1. The loop in `stream_chat_completions()` was only including items of
type `ResponseItem::Message` when building up the `"messages"` JSON for
the `POST` request to the `chat/completions` endpoint. This fixes things
by ensuring other variants (`FunctionCall`, `LocalShellCall`, and
`FunctionCallOutput`) are included, as well.
2. In `process_chat_sse()`, we were not recording tool calls and were
only emitting items of type
`ResponseEvent::OutputItemDone(ResponseItem::Message)` to the stream.
Now we introduce `FunctionCallState`, which is used to accumulate the
`delta`s of type `tool_calls`, so we can ultimately emit a
`ResponseItem::FunctionCall`, when appropriate.

While function calling now appears to work for chat completions in my
local testing, I believe that there are still edge cases that are not
covered and that this codepath would benefit from a battery of
integration tests. (As part of that further cleanup, we should also work
to support streaming responses in the UI.)

The other important part of this PR is some cleanup in
`core/src/codex.rs`. In particular, it was hard to reason about how
`run_task()` was building up the list of messages to include in a
request across the various cases:

- Responses API
- Chat Completions API
- Responses API used in concert with ZDR

I like to think things are a bit cleaner now where:

- `zdr_transcript` (if present) contains all messages in the history of
the conversation, which includes function call outputs that have not
been sent back to the model yet
- `pending_input` includes any messages the user has submitted while the
turn is in flight that need to be injected as part of the next `POST` to
the model
- `input_for_next_turn` includes the tool call outputs that have not
been sent back to the model yet
This commit is contained in:
Michael Bolin
2025-06-02 13:47:51 -07:00
committed by GitHub
parent e40f86b446
commit d7245cbbc9
2 changed files with 299 additions and 71 deletions

View File

@@ -28,8 +28,7 @@ use crate::models::ResponseItem;
use crate::openai_tools::create_tools_json_for_chat_completions_api; use crate::openai_tools::create_tools_json_for_chat_completions_api;
use crate::util::backoff; use crate::util::backoff;
/// Implementation for the classic Chat Completions API. This is intentionally /// Implementation for the classic Chat Completions API.
/// minimal: we only stream back plain assistant text.
pub(crate) async fn stream_chat_completions( pub(crate) async fn stream_chat_completions(
prompt: &Prompt, prompt: &Prompt,
model: &str, model: &str,
@@ -43,17 +42,67 @@ pub(crate) async fn stream_chat_completions(
messages.push(json!({"role": "system", "content": full_instructions})); messages.push(json!({"role": "system", "content": full_instructions}));
for item in &prompt.input { for item in &prompt.input {
if let ResponseItem::Message { role, content } = item { match item {
let mut text = String::new(); ResponseItem::Message { role, content } => {
for c in content { let mut text = String::new();
match c { for c in content {
ContentItem::InputText { text: t } | ContentItem::OutputText { text: t } => { match c {
text.push_str(t); ContentItem::InputText { text: t }
| ContentItem::OutputText { text: t } => {
text.push_str(t);
}
_ => {}
} }
_ => {}
} }
messages.push(json!({"role": role, "content": text}));
}
ResponseItem::FunctionCall {
name,
arguments,
call_id,
} => {
messages.push(json!({
"role": "assistant",
"content": null,
"tool_calls": [{
"id": call_id,
"type": "function",
"function": {
"name": name,
"arguments": arguments,
}
}]
}));
}
ResponseItem::LocalShellCall {
id,
call_id: _,
status,
action,
} => {
// Confirm with API team.
messages.push(json!({
"role": "assistant",
"content": null,
"tool_calls": [{
"id": id.clone().unwrap_or_else(|| "".to_string()),
"type": "local_shell_call",
"status": status,
"action": action,
}]
}));
}
ResponseItem::FunctionCallOutput { call_id, output } => {
messages.push(json!({
"role": "tool",
"tool_call_id": call_id,
"content": output.content,
}));
}
ResponseItem::Reasoning { .. } | ResponseItem::Other => {
// Omit these items from the conversation history.
continue;
} }
messages.push(json!({"role": role, "content": text}));
} }
} }
@@ -68,9 +117,8 @@ pub(crate) async fn stream_chat_completions(
let base_url = provider.base_url.trim_end_matches('/'); let base_url = provider.base_url.trim_end_matches('/');
let url = format!("{}/chat/completions", base_url); let url = format!("{}/chat/completions", base_url);
debug!(url, "POST (chat)"); debug!(
trace!( "POST to {url}: {}",
"request payload: {}",
serde_json::to_string_pretty(&payload).unwrap_or_default() serde_json::to_string_pretty(&payload).unwrap_or_default()
); );
@@ -140,6 +188,21 @@ where
let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS; let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS;
// State to accumulate a function call across streaming chunks.
// OpenAI may split the `arguments` string over multiple `delta` events
// until the chunk whose `finish_reason` is `tool_calls` is emitted. We
// keep collecting the pieces here and forward a single
// `ResponseItem::FunctionCall` once the call is complete.
#[derive(Default)]
struct FunctionCallState {
name: Option<String>,
arguments: String,
call_id: Option<String>,
active: bool,
}
let mut fn_call_state = FunctionCallState::default();
loop { loop {
let sse = match timeout(idle_timeout, stream.next()).await { let sse = match timeout(idle_timeout, stream.next()).await {
Ok(Some(Ok(ev))) => ev, Ok(Some(Ok(ev))) => ev,
@@ -179,23 +242,89 @@ where
Ok(v) => v, Ok(v) => v,
Err(_) => continue, Err(_) => continue,
}; };
trace!("chat_completions received SSE chunk: {chunk:?}");
let content_opt = chunk let choice_opt = chunk.get("choices").and_then(|c| c.get(0));
.get("choices")
.and_then(|c| c.get(0))
.and_then(|c| c.get("delta"))
.and_then(|d| d.get("content"))
.and_then(|c| c.as_str());
if let Some(content) = content_opt { if let Some(choice) = choice_opt {
let item = ResponseItem::Message { // Handle assistant content tokens.
role: "assistant".to_string(), if let Some(content) = choice
content: vec![ContentItem::OutputText { .get("delta")
text: content.to_string(), .and_then(|d| d.get("content"))
}], .and_then(|c| c.as_str())
}; {
let item = ResponseItem::Message {
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: content.to_string(),
}],
};
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
// Handle streaming function / tool calls.
if let Some(tool_calls) = choice
.get("delta")
.and_then(|d| d.get("tool_calls"))
.and_then(|tc| tc.as_array())
{
if let Some(tool_call) = tool_calls.first() {
// Mark that we have an active function call in progress.
fn_call_state.active = true;
// Extract call_id if present.
if let Some(id) = tool_call.get("id").and_then(|v| v.as_str()) {
fn_call_state.call_id.get_or_insert_with(|| id.to_string());
}
// Extract function details if present.
if let Some(function) = tool_call.get("function") {
if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
fn_call_state.name.get_or_insert_with(|| name.to_string());
}
if let Some(args_fragment) =
function.get("arguments").and_then(|a| a.as_str())
{
fn_call_state.arguments.push_str(args_fragment);
}
}
}
}
// Emit end-of-turn when finish_reason signals completion.
if let Some(finish_reason) = choice.get("finish_reason").and_then(|v| v.as_str()) {
match finish_reason {
"tool_calls" if fn_call_state.active => {
// Build the FunctionCall response item.
let item = ResponseItem::FunctionCall {
name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()),
arguments: fn_call_state.arguments.clone(),
call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new),
};
// Emit it downstream.
let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
}
"stop" => {
// Regular turn without tool-call.
}
_ => {}
}
// Emit Completed regardless of reason so the agent can advance.
let _ = tx_event
.send(Ok(ResponseEvent::Completed {
response_id: String::new(),
}))
.await;
// Prepare for potential next turn (should not happen in same stream).
// fn_call_state = FunctionCallState::default();
return; // End processing for this SSE stream.
}
} }
} }
} }
@@ -242,9 +371,14 @@ where
Poll::Ready(None) => return Poll::Ready(None), Poll::Ready(None) => return Poll::Ready(None),
Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => { Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
// Accumulate *assistant* text but do not emit yet. // If this is an incremental assistant message chunk, accumulate but
if let crate::models::ResponseItem::Message { role, content } = &item { // do NOT emit yet. Forward any other item (e.g. FunctionCall) right
if role == "assistant" { // away so downstream consumers see it.
let is_assistant_delta = matches!(&item, crate::models::ResponseItem::Message { role, .. } if role == "assistant");
if is_assistant_delta {
if let crate::models::ResponseItem::Message { content, .. } = &item {
if let Some(text) = content.iter().find_map(|c| match c { if let Some(text) = content.iter().find_map(|c| match c {
crate::models::ContentItem::OutputText { text } => Some(text), crate::models::ContentItem::OutputText { text } => Some(text),
_ => None, _ => None,
@@ -252,10 +386,13 @@ where
this.cumulative.push_str(text); this.cumulative.push_str(text);
} }
} }
// Swallow partial assistant chunk; keep polling.
continue;
} }
// Swallow partial event; keep polling. // Not an assistant message forward immediately.
continue; return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
} }
Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id }))) => { Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id }))) => {
if !this.cumulative.is_empty() { if !this.cumulative.is_empty() {

View File

@@ -20,6 +20,7 @@ use codex_apply_patch::MaybeApplyPatchVerified;
use codex_apply_patch::maybe_parse_apply_patch_verified; use codex_apply_patch::maybe_parse_apply_patch_verified;
use codex_apply_patch::print_summary; use codex_apply_patch::print_summary;
use futures::prelude::*; use futures::prelude::*;
use mcp_types::CallToolResult;
use serde::Serialize; use serde::Serialize;
use serde_json; use serde_json;
use tokio::sync::Notify; use tokio::sync::Notify;
@@ -295,6 +296,17 @@ impl Session {
state.approved_commands.insert(cmd); state.approved_commands.insert(cmd);
} }
/// Records items to both the rollout and the chat completions/ZDR
/// transcript, if enabled.
async fn record_conversation_items(&self, items: &[ResponseItem]) {
debug!("Recording items for conversation: {items:?}");
self.record_rollout_items(items).await;
if let Some(transcript) = self.state.lock().unwrap().zdr_transcript.as_mut() {
transcript.record_items(items);
}
}
/// Append the given items to the session's rollout transcript (if enabled) /// Append the given items to the session's rollout transcript (if enabled)
/// and persist them to disk. /// and persist them to disk.
async fn record_rollout_items(&self, items: &[ResponseItem]) { async fn record_rollout_items(&self, items: &[ResponseItem]) {
@@ -388,7 +400,7 @@ impl Session {
tool: &str, tool: &str,
arguments: Option<serde_json::Value>, arguments: Option<serde_json::Value>,
timeout: Option<Duration>, timeout: Option<Duration>,
) -> anyhow::Result<mcp_types::CallToolResult> { ) -> anyhow::Result<CallToolResult> {
self.mcp_connection_manager self.mcp_connection_manager
.call_tool(server, tool, arguments, timeout) .call_tool(server, tool, arguments, timeout)
.await .await
@@ -760,6 +772,19 @@ async fn submission_loop(
debug!("Agent loop exited"); debug!("Agent loop exited");
} }
/// Takes a user message as input and runs a loop where, at each turn, the model
/// replies with either:
///
/// - requested function calls
/// - an assistant message
///
/// While it is possible for the model to return multiple of these items in a
/// single turn, in practice, we generally see one item per turn:
///
/// - If the model requests a function call, we execute it and send the output
/// back to the model in the next turn.
/// - If the model sends only an assistant message, we record it in the
/// conversation history and consider the task complete.
async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) { async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
if input.is_empty() { if input.is_empty() {
return; return;
@@ -772,10 +797,14 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
return; return;
} }
let mut pending_response_input: Vec<ResponseInputItem> = vec![ResponseInputItem::from(input)]; let initial_input_for_turn = ResponseInputItem::from(input);
sess.record_conversation_items(&[initial_input_for_turn.clone().into()])
.await;
let mut input_for_next_turn: Vec<ResponseInputItem> = vec![initial_input_for_turn];
let last_agent_message: Option<String>; let last_agent_message: Option<String>;
loop { loop {
let mut net_new_turn_input = pending_response_input let mut net_new_turn_input = input_for_next_turn
.drain(..) .drain(..)
.map(ResponseItem::from) .map(ResponseItem::from)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@@ -783,11 +812,12 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
// Note that pending_input would be something like a message the user // Note that pending_input would be something like a message the user
// submitted through the UI while the model was running. Though the UI // submitted through the UI while the model was running. Though the UI
// may support this, the model might not. // may support this, the model might not.
let pending_input = sess.get_pending_input().into_iter().map(ResponseItem::from); let pending_input = sess
net_new_turn_input.extend(pending_input); .get_pending_input()
.into_iter()
// Persist only the net-new items of this turn to the rollout. .map(ResponseItem::from)
sess.record_rollout_items(&net_new_turn_input).await; .collect::<Vec<ResponseItem>>();
sess.record_conversation_items(&pending_input).await;
// Construct the input that we will send to the model. When using the // Construct the input that we will send to the model. When using the
// Chat completions API (or ZDR clients), the model needs the full // Chat completions API (or ZDR clients), the model needs the full
@@ -796,20 +826,24 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
// represents an append-only log without duplicates. // represents an append-only log without duplicates.
let turn_input: Vec<ResponseItem> = let turn_input: Vec<ResponseItem> =
if let Some(transcript) = sess.state.lock().unwrap().zdr_transcript.as_mut() { if let Some(transcript) = sess.state.lock().unwrap().zdr_transcript.as_mut() {
// If we are using Chat/ZDR, we need to send the transcript with every turn. // If we are using Chat/ZDR, we need to send the transcript with
// every turn. By induction, `transcript` already contains:
// 1. Build up the conversation history for the next turn. // - The `input` that kicked off this task.
let full_transcript = [transcript.contents(), net_new_turn_input.clone()].concat(); // - Each `ResponseItem` that was recorded in the previous turn.
// - Each response to a `ResponseItem` (in practice, the only
// 2. Update the in-memory transcript so that future turns // response type we seem to have is `FunctionCallOutput`).
// include these items as part of the history. //
transcript.record_items(&net_new_turn_input); // The only thing the `transcript` does not contain is the
// `pending_input` that was injected while the model was
// Note that `transcript.record_items()` does some filtering // running. We need to add that to the conversation history
// such that `full_transcript` may include items that were // so that the model can see it in the next turn.
// excluded from `transcript`. [transcript.contents(), pending_input].concat()
full_transcript
} else { } else {
// In practice, net_new_turn_input should contain only:
// - User messages
// - Outputs for function calls requested by the model
net_new_turn_input.extend(pending_input);
// Responses API path we can just send the new items and // Responses API path we can just send the new items and
// record the same. // record the same.
net_new_turn_input net_new_turn_input
@@ -830,29 +864,86 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
.collect(); .collect();
match run_turn(&sess, sub_id.clone(), turn_input).await { match run_turn(&sess, sub_id.clone(), turn_input).await {
Ok(turn_output) => { Ok(turn_output) => {
let (items, responses): (Vec<_>, Vec<_>) = turn_output let mut items_to_record_in_conversation_history = Vec::<ResponseItem>::new();
.into_iter() let mut responses = Vec::<ResponseInputItem>::new();
.map(|p| (p.item, p.response)) for processed_response_item in turn_output {
.unzip(); let ProcessedResponseItem { item, response } = processed_response_item;
let responses = responses match (&item, &response) {
.into_iter() (ResponseItem::Message { role, .. }, None) if role == "assistant" => {
.flatten() // If the model returned a message, we need to record it.
.collect::<Vec<ResponseInputItem>>(); items_to_record_in_conversation_history.push(item);
}
(
ResponseItem::LocalShellCall { .. },
Some(ResponseInputItem::FunctionCallOutput { call_id, output }),
) => {
items_to_record_in_conversation_history.push(item);
items_to_record_in_conversation_history.push(
ResponseItem::FunctionCallOutput {
call_id: call_id.clone(),
output: output.clone(),
},
);
}
(
ResponseItem::FunctionCall { .. },
Some(ResponseInputItem::FunctionCallOutput { call_id, output }),
) => {
items_to_record_in_conversation_history.push(item);
items_to_record_in_conversation_history.push(
ResponseItem::FunctionCallOutput {
call_id: call_id.clone(),
output: output.clone(),
},
);
}
(
ResponseItem::FunctionCall { .. },
Some(ResponseInputItem::McpToolCallOutput { call_id, result }),
) => {
items_to_record_in_conversation_history.push(item);
let (content, success): (String, Option<bool>) = match result {
Ok(CallToolResult { content, is_error }) => {
match serde_json::to_string(content) {
Ok(content) => (content, *is_error),
Err(e) => {
warn!("Failed to serialize MCP tool call output: {e}");
(e.to_string(), Some(true))
}
}
}
Err(e) => (e.clone(), Some(true)),
};
items_to_record_in_conversation_history.push(
ResponseItem::FunctionCallOutput {
call_id: call_id.clone(),
output: FunctionCallOutputPayload { content, success },
},
);
}
(ResponseItem::Reasoning { .. }, None) => {
// Omit from conversation history.
}
_ => {
warn!("Unexpected response item: {item:?} with response: {response:?}");
}
};
if let Some(response) = response {
responses.push(response);
}
}
// Only attempt to take the lock if there is something to record. // Only attempt to take the lock if there is something to record.
if !items.is_empty() { if !items_to_record_in_conversation_history.is_empty() {
// First persist model-generated output to the rollout file this only borrows. sess.record_conversation_items(&items_to_record_in_conversation_history)
sess.record_rollout_items(&items).await; .await;
// For ZDR we also need to keep a transcript clone.
if let Some(transcript) = sess.state.lock().unwrap().zdr_transcript.as_mut() {
transcript.record_items(&items);
}
} }
if responses.is_empty() { if responses.is_empty() {
debug!("Turn completed"); debug!("Turn completed");
last_agent_message = get_last_assistant_message_from_turn(&items); last_agent_message = get_last_assistant_message_from_turn(
&items_to_record_in_conversation_history,
);
sess.maybe_notify(UserNotification::AgentTurnComplete { sess.maybe_notify(UserNotification::AgentTurnComplete {
turn_id: sub_id.clone(), turn_id: sub_id.clone(),
input_messages: turn_input_messages, input_messages: turn_input_messages,
@@ -861,7 +952,7 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) {
break; break;
} }
pending_response_input = responses; input_for_next_turn = responses;
} }
Err(e) => { Err(e) => {
info!("Turn error: {e:#}"); info!("Turn error: {e:#}");