diff --git a/codex-rs/core/src/chat_completions.rs b/codex-rs/core/src/chat_completions.rs index f55512e5..416baafc 100644 --- a/codex-rs/core/src/chat_completions.rs +++ b/codex-rs/core/src/chat_completions.rs @@ -28,8 +28,7 @@ use crate::models::ResponseItem; use crate::openai_tools::create_tools_json_for_chat_completions_api; use crate::util::backoff; -/// Implementation for the classic Chat Completions API. This is intentionally -/// minimal: we only stream back plain assistant text. +/// Implementation for the classic Chat Completions API. pub(crate) async fn stream_chat_completions( prompt: &Prompt, model: &str, @@ -43,17 +42,67 @@ pub(crate) async fn stream_chat_completions( messages.push(json!({"role": "system", "content": full_instructions})); for item in &prompt.input { - if let ResponseItem::Message { role, content } = item { - let mut text = String::new(); - for c in content { - match c { - ContentItem::InputText { text: t } | ContentItem::OutputText { text: t } => { - text.push_str(t); + match item { + ResponseItem::Message { role, content } => { + let mut text = String::new(); + for c in content { + match c { + ContentItem::InputText { text: t } + | ContentItem::OutputText { text: t } => { + text.push_str(t); + } + _ => {} } - _ => {} } + messages.push(json!({"role": role, "content": text})); + } + ResponseItem::FunctionCall { + name, + arguments, + call_id, + } => { + messages.push(json!({ + "role": "assistant", + "content": null, + "tool_calls": [{ + "id": call_id, + "type": "function", + "function": { + "name": name, + "arguments": arguments, + } + }] + })); + } + ResponseItem::LocalShellCall { + id, + call_id: _, + status, + action, + } => { + // Confirm with API team. 
+ messages.push(json!({ + "role": "assistant", + "content": null, + "tool_calls": [{ + "id": id.clone().unwrap_or_else(|| "".to_string()), + "type": "local_shell_call", + "status": status, + "action": action, + }] + })); + } + ResponseItem::FunctionCallOutput { call_id, output } => { + messages.push(json!({ + "role": "tool", + "tool_call_id": call_id, + "content": output.content, + })); + } + ResponseItem::Reasoning { .. } | ResponseItem::Other => { + // Omit these items from the conversation history. + continue; } - messages.push(json!({"role": role, "content": text})); } } @@ -68,9 +117,8 @@ pub(crate) async fn stream_chat_completions( let base_url = provider.base_url.trim_end_matches('/'); let url = format!("{}/chat/completions", base_url); - debug!(url, "POST (chat)"); - trace!( - "request payload: {}", + debug!( + "POST to {url}: {}", serde_json::to_string_pretty(&payload).unwrap_or_default() ); @@ -140,6 +188,21 @@ where let idle_timeout = *OPENAI_STREAM_IDLE_TIMEOUT_MS; + // State to accumulate a function call across streaming chunks. + // OpenAI may split the `arguments` string over multiple `delta` events + // until the chunk whose `finish_reason` is `tool_calls` is emitted. We + // keep collecting the pieces here and forward a single + // `ResponseItem::FunctionCall` once the call is complete. 
+ #[derive(Default)] + struct FunctionCallState { + name: Option, + arguments: String, + call_id: Option, + active: bool, + } + + let mut fn_call_state = FunctionCallState::default(); + loop { let sse = match timeout(idle_timeout, stream.next()).await { Ok(Some(Ok(ev))) => ev, @@ -179,23 +242,89 @@ where Ok(v) => v, Err(_) => continue, }; + trace!("chat_completions received SSE chunk: {chunk:?}"); - let content_opt = chunk - .get("choices") - .and_then(|c| c.get(0)) - .and_then(|c| c.get("delta")) - .and_then(|d| d.get("content")) - .and_then(|c| c.as_str()); + let choice_opt = chunk.get("choices").and_then(|c| c.get(0)); - if let Some(content) = content_opt { - let item = ResponseItem::Message { - role: "assistant".to_string(), - content: vec![ContentItem::OutputText { - text: content.to_string(), - }], - }; + if let Some(choice) = choice_opt { + // Handle assistant content tokens. + if let Some(content) = choice + .get("delta") + .and_then(|d| d.get("content")) + .and_then(|c| c.as_str()) + { + let item = ResponseItem::Message { + role: "assistant".to_string(), + content: vec![ContentItem::OutputText { + text: content.to_string(), + }], + }; - let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; + let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; + } + + // Handle streaming function / tool calls. + if let Some(tool_calls) = choice + .get("delta") + .and_then(|d| d.get("tool_calls")) + .and_then(|tc| tc.as_array()) + { + if let Some(tool_call) = tool_calls.first() { + // Mark that we have an active function call in progress. + fn_call_state.active = true; + + // Extract call_id if present. + if let Some(id) = tool_call.get("id").and_then(|v| v.as_str()) { + fn_call_state.call_id.get_or_insert_with(|| id.to_string()); + } + + // Extract function details if present. 
+ if let Some(function) = tool_call.get("function") { + if let Some(name) = function.get("name").and_then(|n| n.as_str()) { + fn_call_state.name.get_or_insert_with(|| name.to_string()); + } + + if let Some(args_fragment) = + function.get("arguments").and_then(|a| a.as_str()) + { + fn_call_state.arguments.push_str(args_fragment); + } + } + } + } + + // Emit end-of-turn when finish_reason signals completion. + if let Some(finish_reason) = choice.get("finish_reason").and_then(|v| v.as_str()) { + match finish_reason { + "tool_calls" if fn_call_state.active => { + // Build the FunctionCall response item. + let item = ResponseItem::FunctionCall { + name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()), + arguments: fn_call_state.arguments.clone(), + call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new), + }; + + // Emit it downstream. + let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await; + } + "stop" => { + // Regular turn without tool-call. + } + _ => {} + } + + // Emit Completed regardless of reason so the agent can advance. + let _ = tx_event + .send(Ok(ResponseEvent::Completed { + response_id: String::new(), + })) + .await; + + // Prepare for potential next turn (should not happen in same stream). + // fn_call_state = FunctionCallState::default(); + + return; // End processing for this SSE stream. + } } } } @@ -242,9 +371,14 @@ where Poll::Ready(None) => return Poll::Ready(None), Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => { - // Accumulate *assistant* text but do not emit yet. - if let crate::models::ResponseItem::Message { role, content } = &item { - if role == "assistant" { + // If this is an incremental assistant message chunk, accumulate but + // do NOT emit yet. Forward any other item (e.g. FunctionCall) right + // away so downstream consumers see it. 
+ + let is_assistant_delta = matches!(&item, crate::models::ResponseItem::Message { role, .. } if role == "assistant"); + + if is_assistant_delta { + if let crate::models::ResponseItem::Message { content, .. } = &item { if let Some(text) = content.iter().find_map(|c| match c { crate::models::ContentItem::OutputText { text } => Some(text), _ => None, @@ -252,10 +386,13 @@ where this.cumulative.push_str(text); } } + + // Swallow partial assistant chunk; keep polling. + continue; } - // Swallow partial event; keep polling. - continue; + // Not an assistant message – forward immediately. + return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))); } Poll::Ready(Some(Ok(ResponseEvent::Completed { response_id }))) => { if !this.cumulative.is_empty() { diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 2699a9ce..01ff459f 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -20,6 +20,7 @@ use codex_apply_patch::MaybeApplyPatchVerified; use codex_apply_patch::maybe_parse_apply_patch_verified; use codex_apply_patch::print_summary; use futures::prelude::*; +use mcp_types::CallToolResult; use serde::Serialize; use serde_json; use tokio::sync::Notify; @@ -295,6 +296,17 @@ impl Session { state.approved_commands.insert(cmd); } + /// Records items to both the rollout and the chat completions/ZDR + /// transcript, if enabled. + async fn record_conversation_items(&self, items: &[ResponseItem]) { + debug!("Recording items for conversation: {items:?}"); + self.record_rollout_items(items).await; + + if let Some(transcript) = self.state.lock().unwrap().zdr_transcript.as_mut() { + transcript.record_items(items); + } + } + /// Append the given items to the session's rollout transcript (if enabled) /// and persist them to disk. 
async fn record_rollout_items(&self, items: &[ResponseItem]) { @@ -388,7 +400,7 @@ impl Session { tool: &str, arguments: Option, timeout: Option, - ) -> anyhow::Result { + ) -> anyhow::Result { self.mcp_connection_manager .call_tool(server, tool, arguments, timeout) .await @@ -760,6 +772,19 @@ async fn submission_loop( debug!("Agent loop exited"); } +/// Takes a user message as input and runs a loop where, at each turn, the model +/// replies with either: +/// +/// - requested function calls +/// - an assistant message +/// +/// While it is possible for the model to return multiple of these items in a +/// single turn, in practice, we generally see one item per turn: +/// +/// - If the model requests a function call, we execute it and send the output +/// back to the model in the next turn. +/// - If the model sends only an assistant message, we record it in the +/// conversation history and consider the task complete. async fn run_task(sess: Arc, sub_id: String, input: Vec) { if input.is_empty() { return; } @@ -772,10 +797,14 @@ async fn run_task(sess: Arc, sub_id: String, input: Vec) { return; } - let mut pending_response_input: Vec = vec![ResponseInputItem::from(input)]; + let initial_input_for_turn = ResponseInputItem::from(input); + sess.record_conversation_items(&[initial_input_for_turn.clone().into()]) + .await; + + let mut input_for_next_turn: Vec = vec![initial_input_for_turn]; let last_agent_message: Option; loop { - let mut net_new_turn_input = pending_response_input + let mut net_new_turn_input = input_for_next_turn .drain(..) .map(ResponseItem::from) .collect::>(); @@ -783,11 +812,12 @@ async fn run_task(sess: Arc, sub_id: String, input: Vec) { // Note that pending_input would be something like a message the user // submitted through the UI while the model was running. Though the UI // may support this, the model might not.
- let pending_input = sess.get_pending_input().into_iter().map(ResponseItem::from); - net_new_turn_input.extend(pending_input); - - // Persist only the net-new items of this turn to the rollout. - sess.record_rollout_items(&net_new_turn_input).await; + let pending_input = sess + .get_pending_input() + .into_iter() + .map(ResponseItem::from) + .collect::>(); + sess.record_conversation_items(&pending_input).await; // Construct the input that we will send to the model. When using the // Chat completions API (or ZDR clients), the model needs the full @@ -796,20 +826,24 @@ async fn run_task(sess: Arc, sub_id: String, input: Vec) { // represents an append-only log without duplicates. let turn_input: Vec = if let Some(transcript) = sess.state.lock().unwrap().zdr_transcript.as_mut() { - // If we are using Chat/ZDR, we need to send the transcript with every turn. - - // 1. Build up the conversation history for the next turn. - let full_transcript = [transcript.contents(), net_new_turn_input.clone()].concat(); - - // 2. Update the in-memory transcript so that future turns - // include these items as part of the history. - transcript.record_items(&net_new_turn_input); - - // Note that `transcript.record_items()` does some filtering - // such that `full_transcript` may include items that were - // excluded from `transcript`. - full_transcript + // If we are using Chat/ZDR, we need to send the transcript with + // every turn. By induction, `transcript` already contains: + // - The `input` that kicked off this task. + // - Each `ResponseItem` that was recorded in the previous turn. + // - Each response to a `ResponseItem` (in practice, the only + // response type we seem to have is `FunctionCallOutput`). + // + // The only thing the `transcript` does not contain is the + // `pending_input` that was injected while the model was + // running. We need to add that to the conversation history + // so that the model can see it in the next turn. 
+ [transcript.contents(), pending_input].concat() } else { + // In practice, net_new_turn_input should contain only: + // - User messages + // - Outputs for function calls requested by the model + net_new_turn_input.extend(pending_input); + // Responses API path – we can just send the new items and // record the same. net_new_turn_input @@ -830,29 +864,86 @@ async fn run_task(sess: Arc, sub_id: String, input: Vec) { .collect(); match run_turn(&sess, sub_id.clone(), turn_input).await { Ok(turn_output) => { - let (items, responses): (Vec<_>, Vec<_>) = turn_output - .into_iter() - .map(|p| (p.item, p.response)) - .unzip(); - let responses = responses - .into_iter() - .flatten() - .collect::>(); + let mut items_to_record_in_conversation_history = Vec::::new(); + let mut responses = Vec::::new(); + for processed_response_item in turn_output { + let ProcessedResponseItem { item, response } = processed_response_item; + match (&item, &response) { + (ResponseItem::Message { role, .. }, None) if role == "assistant" => { + // If the model returned a message, we need to record it. + items_to_record_in_conversation_history.push(item); + } + ( + ResponseItem::LocalShellCall { .. }, + Some(ResponseInputItem::FunctionCallOutput { call_id, output }), + ) => { + items_to_record_in_conversation_history.push(item); + items_to_record_in_conversation_history.push( + ResponseItem::FunctionCallOutput { + call_id: call_id.clone(), + output: output.clone(), + }, + ); + } + ( + ResponseItem::FunctionCall { .. }, + Some(ResponseInputItem::FunctionCallOutput { call_id, output }), + ) => { + items_to_record_in_conversation_history.push(item); + items_to_record_in_conversation_history.push( + ResponseItem::FunctionCallOutput { + call_id: call_id.clone(), + output: output.clone(), + }, + ); + } + ( + ResponseItem::FunctionCall { .. 
}, + Some(ResponseInputItem::McpToolCallOutput { call_id, result }), + ) => { + items_to_record_in_conversation_history.push(item); + let (content, success): (String, Option) = match result { + Ok(CallToolResult { content, is_error }) => { + match serde_json::to_string(content) { + Ok(content) => (content, *is_error), + Err(e) => { + warn!("Failed to serialize MCP tool call output: {e}"); + (e.to_string(), Some(true)) + } + } + } + Err(e) => (e.clone(), Some(true)), + }; + items_to_record_in_conversation_history.push( + ResponseItem::FunctionCallOutput { + call_id: call_id.clone(), + output: FunctionCallOutputPayload { content, success }, + }, + ); + } + (ResponseItem::Reasoning { .. }, None) => { + // Omit from conversation history. + } + _ => { + warn!("Unexpected response item: {item:?} with response: {response:?}"); + } + }; + if let Some(response) = response { + responses.push(response); + } + } // Only attempt to take the lock if there is something to record. - if !items.is_empty() { - // First persist model-generated output to the rollout file – this only borrows. - sess.record_rollout_items(&items).await; - - // For ZDR we also need to keep a transcript clone. 
- if let Some(transcript) = sess.state.lock().unwrap().zdr_transcript.as_mut() { - transcript.record_items(&items); - } + if !items_to_record_in_conversation_history.is_empty() { + sess.record_conversation_items(&items_to_record_in_conversation_history) + .await; } if responses.is_empty() { debug!("Turn completed"); - last_agent_message = get_last_assistant_message_from_turn(&items); + last_agent_message = get_last_assistant_message_from_turn( + &items_to_record_in_conversation_history, + ); sess.maybe_notify(UserNotification::AgentTurnComplete { turn_id: sub_id.clone(), input_messages: turn_input_messages, @@ -861,7 +952,7 @@ async fn run_task(sess: Arc, sub_id: String, input: Vec) { break; } - pending_response_input = responses; + input_for_next_turn = responses; } Err(e) => { info!("Turn error: {e:#}");