feat: add header for task kind (#5142)

Add a header in the responses API request for the task kind (compact,
review, ...) for observability purpose
The header name is `codex-task-type`
This commit is contained in:
jif-oai
2025-10-14 16:17:00 +01:00
committed by GitHub
parent 5346cc422d
commit 268a10f917
7 changed files with 161 additions and 7 deletions

View File

@@ -47,6 +47,7 @@ use crate::openai_tools::create_tools_json_for_responses_api;
use crate::protocol::RateLimitSnapshot;
use crate::protocol::RateLimitWindow;
use crate::protocol::TokenUsage;
use crate::state::TaskKind;
use crate::token_data::PlanType;
use crate::util::backoff;
use codex_otel::otel_event_manager::OtelEventManager;
@@ -123,8 +124,16 @@ impl ModelClient {
/// the provider config. Public callers always invoke `stream()` the
/// specialised helpers are private to avoid accidental misuse.
pub async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
self.stream_with_task_kind(prompt, TaskKind::Regular).await
}
pub(crate) async fn stream_with_task_kind(
&self,
prompt: &Prompt,
task_kind: TaskKind,
) -> Result<ResponseStream> {
match self.provider.wire_api {
WireApi::Responses => self.stream_responses(prompt).await,
WireApi::Responses => self.stream_responses(prompt, task_kind).await,
WireApi::Chat => {
// Create the raw streaming connection first.
let response_stream = stream_chat_completions(
@@ -165,7 +174,11 @@ impl ModelClient {
}
/// Implementation for the OpenAI *Responses* experimental API.
async fn stream_responses(&self, prompt: &Prompt) -> Result<ResponseStream> {
async fn stream_responses(
&self,
prompt: &Prompt,
task_kind: TaskKind,
) -> Result<ResponseStream> {
if let Some(path) = &*CODEX_RS_SSE_FIXTURE {
// short circuit for tests
warn!(path, "Streaming from fixture");
@@ -244,7 +257,7 @@ impl ModelClient {
let max_attempts = self.provider.request_max_retries();
for attempt in 0..=max_attempts {
match self
.attempt_stream_responses(attempt, &payload_json, &auth_manager)
.attempt_stream_responses(attempt, &payload_json, &auth_manager, task_kind)
.await
{
Ok(stream) => {
@@ -272,6 +285,7 @@ impl ModelClient {
attempt: u64,
payload_json: &Value,
auth_manager: &Option<Arc<AuthManager>>,
task_kind: TaskKind,
) -> std::result::Result<ResponseStream, StreamAttemptError> {
// Always fetch the latest auth in case a prior attempt refreshed the token.
let auth = auth_manager.as_ref().and_then(|m| m.auth());
@@ -294,6 +308,7 @@ impl ModelClient {
.header("conversation_id", self.conversation_id.to_string())
.header("session_id", self.conversation_id.to_string())
.header(reqwest::header::ACCEPT, "text/event-stream")
.header("Codex-Task-Type", task_kind.header_value())
.json(payload_json);
if let Some(auth) = auth.as_ref()

View File

@@ -99,6 +99,7 @@ use crate::rollout::RolloutRecorderParams;
use crate::shell;
use crate::state::ActiveTurn;
use crate::state::SessionServices;
use crate::state::TaskKind;
use crate::tasks::CompactTask;
use crate::tasks::RegularTask;
use crate::tasks::ReviewTask;
@@ -1634,6 +1635,7 @@ pub(crate) async fn run_task(
turn_context: Arc<TurnContext>,
sub_id: String,
input: Vec<InputItem>,
task_kind: TaskKind,
) -> Option<String> {
if input.is_empty() {
return None;
@@ -1717,6 +1719,7 @@ pub(crate) async fn run_task(
Arc::clone(&turn_diff_tracker),
sub_id.clone(),
turn_input,
task_kind,
)
.await
{
@@ -1942,6 +1945,7 @@ async fn run_turn(
turn_diff_tracker: SharedTurnDiffTracker,
sub_id: String,
input: Vec<ResponseItem>,
task_kind: TaskKind,
) -> CodexResult<TurnRunResult> {
let mcp_tools = sess.services.mcp_connection_manager.list_all_tools();
let router = Arc::new(ToolRouter::from_config(
@@ -1971,6 +1975,7 @@ async fn run_turn(
Arc::clone(&turn_diff_tracker),
&sub_id,
&prompt,
task_kind,
)
.await
{
@@ -2044,6 +2049,7 @@ async fn try_run_turn(
turn_diff_tracker: SharedTurnDiffTracker,
sub_id: &str,
prompt: &Prompt,
task_kind: TaskKind,
) -> CodexResult<TurnRunResult> {
// call_ids that are part of this response.
let completed_call_ids = prompt
@@ -2109,7 +2115,11 @@ async fn try_run_turn(
summary: turn_context.client.get_reasoning_summary(),
});
sess.persist_rollout_items(&[rollout_item]).await;
let mut stream = turn_context.client.clone().stream(&prompt).await?;
let mut stream = turn_context
.client
.clone()
.stream_with_task_kind(prompt.as_ref(), task_kind)
.await?;
let tool_runtime = ToolCallRuntime::new(
Arc::clone(&router),

View File

@@ -16,6 +16,7 @@ use crate::protocol::InputItem;
use crate::protocol::InputMessageKind;
use crate::protocol::TaskStartedEvent;
use crate::protocol::TurnContextItem;
use crate::state::TaskKind;
use crate::truncate::truncate_middle;
use crate::util::backoff;
use askama::Template;
@@ -258,7 +259,11 @@ async fn drain_to_completed(
sub_id: &str,
prompt: &Prompt,
) -> CodexResult<()> {
let mut stream = turn_context.client.clone().stream(prompt).await?;
let mut stream = turn_context
.client
.clone()
.stream_with_task_kind(prompt, TaskKind::Compact)
.await?;
loop {
let maybe_event = stream.next().await;
let Some(event) = maybe_event else {

View File

@@ -34,6 +34,16 @@ pub(crate) enum TaskKind {
Compact,
}
impl TaskKind {
pub(crate) fn header_value(self) -> &'static str {
match self {
TaskKind::Regular => "standard",
TaskKind::Review => "review",
TaskKind::Compact => "compact",
}
}
}
#[derive(Clone)]
pub(crate) struct RunningTask {
pub(crate) handle: AbortHandle,
@@ -113,3 +123,15 @@ impl ActiveTurn {
}
}
}
#[cfg(test)]
mod tests {
use super::TaskKind;
#[test]
fn header_value_matches_expected_labels() {
assert_eq!(TaskKind::Regular.header_value(), "standard");
assert_eq!(TaskKind::Review.header_value(), "review");
assert_eq!(TaskKind::Compact.header_value(), "compact");
}
}

View File

@@ -27,6 +27,6 @@ impl SessionTask for RegularTask {
input: Vec<InputItem>,
) -> Option<String> {
let sess = session.clone_session();
run_task(sess, ctx, sub_id, input).await
run_task(sess, ctx, sub_id, input, TaskKind::Regular).await
}
}

View File

@@ -28,7 +28,7 @@ impl SessionTask for ReviewTask {
input: Vec<InputItem>,
) -> Option<String> {
let sess = session.clone_session();
run_task(sess, ctx, sub_id, input).await
run_task(sess, ctx, sub_id, input, TaskKind::Review).await
}
async fn abort(&self, session: Arc<SessionTaskContext>, sub_id: &str) {

View File

@@ -0,0 +1,102 @@
use std::sync::Arc;
use codex_app_server_protocol::AuthMode;
use codex_core::ContentItem;
use codex_core::ModelClient;
use codex_core::ModelProviderInfo;
use codex_core::Prompt;
use codex_core::ResponseEvent;
use codex_core::ResponseItem;
use codex_core::WireApi;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::ConversationId;
use core_test_support::load_default_config_for_test;
use core_test_support::responses;
use futures::StreamExt;
use tempfile::TempDir;
use wiremock::matchers::header;
#[tokio::test]
async fn responses_stream_includes_task_type_header() {
core_test_support::skip_if_no_network!();
let server = responses::start_mock_server().await;
let response_body = responses::sse(vec![
responses::ev_response_created("resp-1"),
responses::ev_completed("resp-1"),
]);
let request_recorder = responses::mount_sse_once_match(
&server,
header("Codex-Task-Type", "standard"),
response_body,
)
.await;
let provider = ModelProviderInfo {
name: "mock".into(),
base_url: Some(format!("{}/v1", server.uri())),
env_key: None,
env_key_instructions: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: Some(0),
stream_max_retries: Some(0),
stream_idle_timeout_ms: Some(5_000),
requires_openai_auth: false,
};
let codex_home = TempDir::new().expect("failed to create TempDir");
let mut config = load_default_config_for_test(&codex_home);
config.model_provider_id = provider.name.clone();
config.model_provider = provider.clone();
let effort = config.model_reasoning_effort;
let summary = config.model_reasoning_summary;
let config = Arc::new(config);
let conversation_id = ConversationId::new();
let otel_event_manager = OtelEventManager::new(
conversation_id,
config.model.as_str(),
config.model_family.slug.as_str(),
None,
Some(AuthMode::ChatGPT),
false,
"test".to_string(),
);
let client = ModelClient::new(
Arc::clone(&config),
None,
otel_event_manager,
provider,
effort,
summary,
conversation_id,
);
let mut prompt = Prompt::default();
prompt.input = vec![ResponseItem::Message {
id: None,
role: "user".into(),
content: vec![ContentItem::InputText {
text: "hello".into(),
}],
}];
let mut stream = client.stream(&prompt).await.expect("stream failed");
while let Some(event) = stream.next().await {
if matches!(event, Ok(ResponseEvent::Completed { .. })) {
break;
}
}
let request = request_recorder.single_request();
assert_eq!(
request.header("Codex-Task-Type").as_deref(),
Some("standard")
);
}