From d7245cbbc9d8ff5446da45e5951761103492476d Mon Sep 17 00:00:00 2001
From: Michael Bolin
Date: Mon, 2 Jun 2025 13:47:51 -0700
Subject: [PATCH] fix: chat completions API now also passes tools along (#1167)

Prior to this PR, there were two big misses in `chat_completions.rs`:

1. The loop in `stream_chat_completions()` was only including items of type
   `ResponseItem::Message` when building up the `"messages"` JSON for the
   `POST` request to the `chat/completions` endpoint. This PR fixes that by
   ensuring the other variants (`FunctionCall`, `LocalShellCall`, and
   `FunctionCallOutput`) are included as well.
2. In `process_chat_sse()`, we were not recording tool calls and were only
   emitting items of type `ResponseEvent::OutputItemDone(ResponseItem::Message)`
   to the stream. Now we introduce `FunctionCallState`, which is used to
   accumulate the `delta`s of type `tool_calls`, so we can ultimately emit a
   `ResponseItem::FunctionCall` when appropriate.

While function calling now appears to work for chat completions in my local
testing, I believe there are still edge cases that are not covered and that
this codepath would benefit from a battery of integration tests. (As part of
that further cleanup, we should also work to support streaming responses in
the UI.)

The other important part of this PR is some cleanup in `core/src/codex.rs`.
In particular, it was hard to reason about how `run_task()` was building up
the list of messages to include in a request across the various cases:

- Responses API
- Chat Completions API
- Responses API used in concert with ZDR

I like to think things are a bit cleaner now:

- `zdr_transcript` (if present) contains all messages in the history of the
  conversation, which includes function call outputs that have not been sent
  back to the model yet
- `pending_input` includes any messages the user has submitted while the turn
  is in flight that need to be injected as part of the next `POST` to the
  model
- `input_for_next_turn` includes the tool call outputs that have not been
  sent back to the model yet
---
 codex-rs/core/src/chat_completions.rs | 201 ++++++++++++++++++++++----
 codex-rs/core/src/codex.rs            | 169 +++++++++++++++++-----
 2 files changed, 299 insertions(+), 71 deletions(-)

diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs
index f55512e5..416baafc 100644
--- a/codex-rs/core/src/chat_completions.rs
+++ b/codex-rs/core/src/chat_completions.rs
@@ -28,8 +28,7 @@ use crate::models::ResponseItem;
 use crate::openai_tools::create_tools_json_for_chat_completions_api;
 use crate::util::backoff;
 
-/// Implementation for the classic Chat Completions API. This is intentionally
-/// minimal: we only stream back plain assistant text.
+/// Implementation for the classic Chat Completions API.
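+///
+/// For reference, the request body we build has roughly the following shape.
+/// (Illustrative sketch only: the tool name and values are made up, and the
+/// exact field set is whatever `payload` below actually contains.)
+///
+/// ```json
+/// {
+///   "model": "...",
+///   "stream": true,
+///   "messages": [
+///     { "role": "system", "content": "<full instructions>" },
+///     { "role": "user", "content": "..." },
+///     { "role": "assistant", "content": null,
+///       "tool_calls": [{ "id": "call_1", "type": "function",
+///                        "function": { "name": "shell", "arguments": "{...}" } }] },
+///     { "role": "tool", "tool_call_id": "call_1", "content": "..." }
+///   ],
+///   "tools": [ { "type": "function", "function": { "...": "..." } } ]
+/// }
+/// ```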
pub(crate) async fn stream_chat_completions( prompt: &Prompt, model: &str, @@ -43,17 +42,67 @@ pub(crate) async fn stream_chat_completions( messages.push(json!({"role": "system", "content": full_instructions})); for item in &prompt.input { - if let ResponseItem::Message { role, content } = item { - let mut text = String::new(); - for c in content { - match c { - ContentItem::InputText { text: t } | ContentItem::OutputText { text: t } => { - text.push_str(t); + match item { + ResponseItem::Message { role, content } => { + let mut text = String::new(); + for c in content { + match c { + ContentItem::InputText { text: t } + | ContentItem::OutputText { text: t } => { + text.push_str(t); + } + _ => {} } - _ => {} } + messages.push(json!({"role": role, "content": text})); + } + ResponseItem::FunctionCall { + name, + arguments, + call_id, + } => { + messages.push(json!({ + "role": "assistant", + "content": null, + "tool_calls": [{ + "id": call_id, + "type": "function", + "function": { + "name": name, + "arguments": arguments, + } + }] + })); + } + ResponseItem::LocalShellCall { + id, + call_id: _, + status, + action, + } => { + // TODO: Confirm this `local_shell_call` encoding with the API team. + messages.push(json!({ + "role": "assistant", + "content": null, + "tool_calls": [{ + "id": id.clone().unwrap_or_else(|| "".to_string()), + "type": "local_shell_call", + "status": status, + "action": action, + }] + })); + } + ResponseItem::FunctionCallOutput { call_id, output } => { + messages.push(json!({ + "role": "tool", + "tool_call_id": call_id, + "content": output.content, + })); + } + ResponseItem::Reasoning { .. } | ResponseItem::Other => { + // Omit these items from the conversation history. + continue; } - messages.push(json!({"role": role, "content": text})); } } @@ -68,9 +117,8 @@ pub(crate) async fn stream_chat_completions( let base_url = provider.base_url.trim_end_matches('/'); let url = format!("{}/chat/completions", base_url); - debug!(url, "POST (chat)"); - trace!( - "request payload: {}", + debug!( + "POST to {url}: {}", serde_json::to_string_pretty(&payload).unwrap_or_default() ); @@ -140,6 +188,21 @@ where let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS; + // State to accumulate a function call across streaming chunks. + // OpenAI may split the `arguments` string over multiple `delta` events + // until the chunk whose `finish_reason` is `tool_calls` is emitted. We + // keep collecting the pieces here and forward a single + // `ResponseItem::FunctionCall` once the call is complete. + #[derive(Default)] + struct FunctionCallState { + name: Option<String>, + arguments: String, + call_id: Option<String>, + active: bool, + } + + let mut fn_call_state = FunctionCallState::default(); + loop { let sse = match timeout(idle_timeout, stream.next()).await { Ok(Some(Ok(ev))) => ev, @@ -179,23 +242,89 @@ where Ok(v) => v, Err(_) => continue, }; + trace!("chat_completions received SSE chunk: {chunk:?}"); - let content_opt = chunk - .get("choices") - .and_then(|c| c.get(0)) - .and_then(|c| c.get("delta")) - .and_then(|d| d.get("content")) - .and_then(|c| c.as_str()); + let choice_opt = chunk.get("choices").and_then(|c| c.get(0)); - if let Some(content) = content_opt { - let item = ResponseItem::Message { - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: content.to_string(), - }], - }; + if let Some(choice) = choice_opt { + // Handle assistant content tokens.
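+                // For orientation: a plain text delta arrives roughly as
+                //   {"choices":[{"delta":{"content":"Hel"},"finish_reason":null}]}
+                // while a tool call spreads `delta.tool_calls[0]` across several
+                // chunks (first the `id` and `function.name`, then successive
+                // `function.arguments` fragments), ending with a chunk whose
+                // `finish_reason` is "tool_calls". Shapes are illustrative, not
+                // an exhaustive schema.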
+ if let Some(content) = choice + .get("delta") + .and_then(|d| d.get("content")) + .and_then(|c| c.as_str()) + { + let item = ResponseItem::Message { + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: content.to_string(), + }], + }; - let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; + let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; + } + + // Handle streaming function / tool calls. + if let Some(tool_calls) = choice + .get("delta") + .and_then(|d| d.get("tool_calls")) + .and_then(|tc| tc.as_array()) + { + if let Some(tool_call) = tool_calls.first() { + // Mark that we have an active function call in progress. + fn_call_state.active = true; + + // Extract call_id if present. + if let Some(id) = tool_call.get("id").and_then(|v| v.as_str()) { + fn_call_state.call_id.get_or_insert_with(|| id.to_string()); + } + + // Extract function details if present. + if let Some(function) = tool_call.get("function") { + if let Some(name) = function.get("name").and_then(|n| n.as_str()) { + fn_call_state.name.get_or_insert_with(|| name.to_string()); + } + + if let Some(args_fragment) = + function.get("arguments").and_then(|a| a.as_str()) + { + fn_call_state.arguments.push_str(args_fragment); + } + } + } + } + + // Emit end-of-turn when finish_reason signals completion. + if let Some(finish_reason) = choice.get("finish_reason").and_then(|v| v.as_str()) { + match finish_reason { + "tool_calls" if fn_call_state.active => { + // Build the FunctionCall response item. + let item = ResponseItem::FunctionCall { + name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()), + arguments: fn_call_state.arguments.clone(), + call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new), + }; + + // Emit it downstream. + let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; + } + "stop" => { + // Regular turn without tool-call. + } + _ => {} + } + + // Emit Completed regardless of reason so the agent can advance. + let _ = tx_event + .send(Ok(ResponseEvent::Completed { + response_id: String::new(), + })) + .await; + + // Prepare for potential next turn (should not happen in same stream). + // fn_call_state = FunctionCallState::default(); + + return; // End processing for this SSE stream. + } } } } @@ -242,9 +371,14 @@ where Poll::Ready(None) => return Poll::Ready(None), Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => { - // Accumulate *assistant* text but do not emit yet. - if let crate::models::ResponseItem::Message { role, content } = &item { - if role == "assistant" { + // If this is an incremental assistant message chunk, accumulate but + // do NOT emit yet. Forward any other item (e.g. FunctionCall) right + // away so downstream consumers see it. + + let is_assistant_delta = matches!(&item, crate::models::ResponseItem::Message { role, .. } if role == "assistant"); + + if is_assistant_delta { + if let crate::models::ResponseItem::Message { content, .. } = &item { if let Some(text) = content.iter().find_map(|c| match c { crate::models::ContentItem::OutputText { text } => Some(text), _ => None, @@ -252,10 +386,13 @@ where this.cumulative.push_str(text); } } + + // Swallow partial assistant chunk; keep polling. + continue; } - // Swallow partial event; keep polling. - continue; + // Not an assistant message – forward immediately. 
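+                    // (The assistant chunks swallowed above are re-emitted as a
+                    // single aggregated `Message` when `Completed` arrives, so
+                    // downstream consumers still see the full text exactly once.)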
+ return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))); } Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id }))) => { if !this.cumulative.is_empty() { diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 2699a9ce..01ff459f 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -20,6 +20,7 @@ use codex_apply_patch::MaybeApplyPatchVerified; use codex_apply_patch::maybe_parse_apply_patch_verified; use codex_apply_patch::print_summary; use futures::prelude::*; +use mcp_types::CallToolResult; use serde::Serialize; use serde_json; use tokio::sync::Notify; @@ -295,6 +296,17 @@ impl Session { state.approved_commands.insert(cmd); } + /// Records items to both the rollout and the chat completions/ZDR + /// transcript, if enabled. + async fn record_conversation_items(&self, items: &[ResponseItem]) { + debug!("Recording items for conversation: {items:?}"); + self.record_rollout_items(items).await; + + if let Some(transcript) = self.state.lock().unwrap().zdr_transcript.as_mut() { + transcript.record_items(items); + } + } + /// Append the given items to the session's rollout transcript (if enabled) /// and persist them to disk. async fn record_rollout_items(&self, items: &[ResponseItem]) { @@ -388,7 +400,7 @@ impl Session { tool: &str, arguments: Option<serde_json::Value>, timeout: Option<Duration>, - ) -> anyhow::Result<mcp_types::CallToolResult> { + ) -> anyhow::Result<CallToolResult> { self.mcp_connection_manager .call_tool(server, tool, arguments, timeout) .await @@ -760,6 +772,19 @@ async fn submission_loop( debug!("Agent loop exited"); } +/// Takes a user message as input and runs a loop where, at each turn, the model +/// replies with either: +/// +/// - requested function calls +/// - an assistant message +/// +/// While it is possible for the model to return multiple of these items in a +/// single turn, in practice, we generally see one item per turn: +/// +/// - If the model requests a function call, we execute it and send the output +/// back to the model in the next turn. +/// - If the model sends only an assistant message, we record it in the +/// conversation history and consider the task complete. async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) { if input.is_empty() { return; } @@ -772,10 +797,14 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) { return; } - let mut pending_response_input: Vec<ResponseInputItem> = vec![ResponseInputItem::from(input)]; + let initial_input_for_turn = ResponseInputItem::from(input); + sess.record_conversation_items(&[initial_input_for_turn.clone().into()]) + .await; + + let mut input_for_next_turn: Vec<ResponseInputItem> = vec![initial_input_for_turn]; let last_agent_message: Option<String>; loop { - let mut net_new_turn_input = pending_response_input + let mut net_new_turn_input = input_for_next_turn .drain(..) .map(ResponseItem::from) .collect::<Vec<ResponseItem>>(); @@ -783,11 +812,12 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) { // Note that pending_input would be something like a message the user // submitted through the UI while the model was running. Though the UI // may support this, the model might not. - let pending_input = sess.get_pending_input().into_iter().map(ResponseItem::from); - net_new_turn_input.extend(pending_input); - - // Persist only the net-new items of this turn to the rollout. - sess.record_rollout_items(&net_new_turn_input).await; + let pending_input = sess + .get_pending_input() + .into_iter() + .map(ResponseItem::from) + .collect::<Vec<ResponseItem>>(); + sess.record_conversation_items(&pending_input).await; // Construct the input that we will send to the model.
When using the // Chat completions API (or ZDR clients), the model needs the full @@ -796,20 +826,24 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) { // represents an append-only log without duplicates. let turn_input: Vec<ResponseItem> = if let Some(transcript) = sess.state.lock().unwrap().zdr_transcript.as_mut() { - // If we are using Chat/ZDR, we need to send the transcript with every turn. - - // 1. Build up the conversation history for the next turn. - let full_transcript = [transcript.contents(), net_new_turn_input.clone()].concat(); - - // 2. Update the in-memory transcript so that future turns - // include these items as part of the history. - transcript.record_items(&net_new_turn_input); - - // Note that `transcript.record_items()` does some filtering - // such that `full_transcript` may include items that were - // excluded from `transcript`. - full_transcript + // If we are using Chat/ZDR, we need to send the transcript with + // every turn. By induction, `transcript` already contains: + // - The `input` that kicked off this task. + // - Each `ResponseItem` that was recorded in the previous turn. + // - Each response to a `ResponseItem` (in practice, the only + // response type we seem to have is `FunctionCallOutput`). + // + // The only thing the `transcript` does not contain is the + // `pending_input` that was injected while the model was + // running. We need to add that to the conversation history + // so that the model can see it in the next turn. + [transcript.contents(), pending_input].concat() } else { + // In practice, net_new_turn_input should contain only: + // - User messages + // - Outputs for function calls requested by the model + net_new_turn_input.extend(pending_input); + // Responses API path – we can just send the new items and // record the same. net_new_turn_input @@ -830,29 +864,86 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) { .collect(); match run_turn(&sess, sub_id.clone(), turn_input).await { Ok(turn_output) => { - let (items, responses): (Vec<_>, Vec<_>) = turn_output - .into_iter() - .map(|p| (p.item, p.response)) - .unzip(); - let responses = responses - .into_iter() - .flatten() - .collect::<Vec<ResponseInputItem>>(); + let mut items_to_record_in_conversation_history = Vec::<ResponseItem>::new(); + let mut responses = Vec::<ResponseInputItem>::new(); + for processed_response_item in turn_output { + let ProcessedResponseItem { item, response } = processed_response_item; + match (&item, &response) { + (ResponseItem::Message { role, .. }, None) if role == "assistant" => { + // If the model returned a message, we need to record it. + items_to_record_in_conversation_history.push(item); + } + ( + ResponseItem::LocalShellCall { .. }, + Some(ResponseInputItem::FunctionCallOutput { call_id, output }), + ) => { + items_to_record_in_conversation_history.push(item); + items_to_record_in_conversation_history.push( + ResponseItem::FunctionCallOutput { + call_id: call_id.clone(), + output: output.clone(), + }, + ); + } + ( + ResponseItem::FunctionCall { .. }, + Some(ResponseInputItem::FunctionCallOutput { call_id, output }), + ) => { + items_to_record_in_conversation_history.push(item); + items_to_record_in_conversation_history.push( + ResponseItem::FunctionCallOutput { + call_id: call_id.clone(), + output: output.clone(), + }, + ); + } + ( + ResponseItem::FunctionCall { ..
}, + Some(ResponseInputItem::McpToolCallOutput { call_id, result }), + ) => { + items_to_record_in_conversation_history.push(item); + let (content, success): (String, Option<bool>) = match result { + Ok(CallToolResult { content, is_error }) => { + match serde_json::to_string(content) { + Ok(content) => (content, *is_error), + Err(e) => { + warn!("Failed to serialize MCP tool call output: {e}"); + (e.to_string(), Some(true)) + } + } + } + Err(e) => (e.clone(), Some(true)), + }; + items_to_record_in_conversation_history.push( + ResponseItem::FunctionCallOutput { + call_id: call_id.clone(), + output: FunctionCallOutputPayload { content, success }, + }, + ); + } + (ResponseItem::Reasoning { .. }, None) => { + // Omit from conversation history. + } + _ => { + warn!("Unexpected response item: {item:?} with response: {response:?}"); + } + }; + if let Some(response) = response { + responses.push(response); + } + } // Only attempt to take the lock if there is something to record. - if !items.is_empty() { - // First persist model-generated output to the rollout file – this only borrows. - sess.record_rollout_items(&items).await; - - // For ZDR we also need to keep a transcript clone. - if let Some(transcript) = sess.state.lock().unwrap().zdr_transcript.as_mut() { - transcript.record_items(&items); - } + if !items_to_record_in_conversation_history.is_empty() { + sess.record_conversation_items(&items_to_record_in_conversation_history) + .await; } if responses.is_empty() { debug!("Turn completed"); - last_agent_message = get_last_assistant_message_from_turn(&items); + last_agent_message = get_last_assistant_message_from_turn( + &items_to_record_in_conversation_history, + ); sess.maybe_notify(UserNotification::AgentTurnComplete { turn_id: sub_id.clone(), input_messages: turn_input_messages, @@ -861,7 +952,7 @@ async fn run_task(sess: Arc<Session>, sub_id: String, input: Vec<InputItem>) { break; } - pending_response_input = responses; + input_for_next_turn = responses; } Err(e) => { info!("Turn error: {e:#}");