Use anyhow::Result in tests for error propagation (#4105)

This commit is contained in:
pakrym-oai
2025-09-23 13:31:36 -07:00
committed by GitHub
parent c6e8671b2a
commit 0f9a796617
11 changed files with 96 additions and 68 deletions

2
codex-rs/Cargo.lock generated
View File

@@ -819,6 +819,7 @@ dependencies = [
name = "codex-login"
version = "0.0.0"
dependencies = [
"anyhow",
"base64",
"chrono",
"codex-core",
@@ -899,6 +900,7 @@ dependencies = [
name = "codex-protocol"
version = "0.0.0"
dependencies = [
"anyhow",
"base64",
"icu_decimal",
"icu_locale_core",

View File

@@ -62,9 +62,10 @@ pub(crate) enum UserNotification {
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
#[test]
fn test_user_notification() {
fn test_user_notification() -> Result<()> {
let notification = UserNotification::AgentTurnComplete {
turn_id: "12345".to_string(),
input_messages: vec!["Rename `foo` to `bar` and update the callsites.".to_string()],
@@ -72,10 +73,11 @@ mod tests {
"Rename complete and verified `cargo build` succeeds.".to_string(),
),
};
let serialized = serde_json::to_string(&notification).unwrap();
let serialized = serde_json::to_string(&notification)?;
assert_eq!(
serialized,
r#"{"type":"agent-turn-complete","turn-id":"12345","input-messages":["Rename `foo` to `bar` and update the callsites."],"last-assistant-message":"Rename complete and verified `cargo build` succeeds."}"#
);
Ok(())
}
}

View File

@@ -146,6 +146,8 @@ mod tests {
use super::*;
use crate::MatchedArg;
use crate::PolicyParser;
use anyhow::Result;
use anyhow::anyhow;
fn setup(fake_cp: &Path) -> ExecvChecker {
let source = format!(
@@ -164,7 +166,7 @@ system_path=[{fake_cp:?}]
#[test]
fn test_check_valid_input_files() -> Result<()> {
let temp_dir = TempDir::new().unwrap();
let temp_dir = TempDir::new()?;
// Create an executable file that can be used with the system_path arg.
let fake_cp = temp_dir.path().join("cp");
@@ -172,14 +174,14 @@ system_path=[{fake_cp:?}]
{
use std::os::unix::fs::PermissionsExt;
let fake_cp_file = std::fs::File::create(&fake_cp).unwrap();
let mut permissions = fake_cp_file.metadata().unwrap().permissions();
let fake_cp_file = std::fs::File::create(&fake_cp)?;
let mut permissions = fake_cp_file.metadata()?.permissions();
permissions.set_mode(0o755);
std::fs::set_permissions(&fake_cp, permissions).unwrap();
std::fs::set_permissions(&fake_cp, permissions)?;
}
#[cfg(windows)]
{
std::fs::File::create(&fake_cp).unwrap();
std::fs::File::create(&fake_cp)?;
}
// Create root_path and reference to files under the root.
@@ -199,7 +201,7 @@ system_path=[{fake_cp:?}]
program: "cp".into(),
args: vec![source, dest.clone()],
};
let valid_exec = match checker.r#match(&exec_call)? {
let valid_exec = match checker.r#match(&exec_call).map_err(|e| anyhow!("{e:?}"))? {
MatchedExec::Match { exec } => exec,
unexpected => panic!("Expected a safe exec but got {unexpected:?}"),
};
@@ -244,7 +246,10 @@ system_path=[{fake_cp:?}]
program: "cp".into(),
args: vec![root.clone(), root],
};
let valid_exec_call_folders_as_args = match checker.r#match(&exec_call_folders_as_args)? {
let valid_exec_call_folders_as_args = match checker
.r#match(&exec_call_folders_as_args)
.map_err(|e| anyhow!("{e:?}"))?
{
MatchedExec::Match { exec } => exec,
_ => panic!("Expected a safe exec"),
};
@@ -266,8 +271,9 @@ system_path=[{fake_cp:?}]
0,
ArgType::ReadableFile,
root_path.parent().unwrap().to_str().unwrap(),
)?,
MatchedArg::new(1, ArgType::WriteableFile, &dest)?,
)
.map_err(|e| anyhow!("{e:?}"))?,
MatchedArg::new(1, ArgType::WriteableFile, &dest).map_err(|e| anyhow!("{e:?}"))?,
],
..Default::default()
};

View File

@@ -30,5 +30,6 @@ urlencoding = { workspace = true }
webbrowser = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
core_test_support = { workspace = true }
tempfile = { workspace = true }

View File

@@ -5,6 +5,7 @@ use std::net::TcpListener;
use std::thread;
use std::time::Duration;
use anyhow::Result;
use base64::Engine;
use codex_login::ServerOptions;
use codex_login::run_login_server;
@@ -76,13 +77,13 @@ fn start_mock_issuer() -> (SocketAddr, thread::JoinHandle<()>) {
}
#[tokio::test]
async fn end_to_end_login_flow_persists_auth_json() {
non_sandbox_test!();
async fn end_to_end_login_flow_persists_auth_json() -> Result<()> {
non_sandbox_test!(result);
let (issuer_addr, issuer_handle) = start_mock_issuer();
let issuer = format!("http://{}:{}", issuer_addr.ip(), issuer_addr.port());
let tmp = tempdir().unwrap();
let tmp = tempdir()?;
let codex_home = tmp.path().to_path_buf();
// Seed auth.json with stale API key + tokens that should be overwritten.
@@ -97,9 +98,8 @@ async fn end_to_end_login_flow_persists_auth_json() {
});
std::fs::write(
codex_home.join("auth.json"),
serde_json::to_string_pretty(&stale_auth).unwrap(),
)
.unwrap();
serde_json::to_string_pretty(&stale_auth)?,
)?;
let state = "test_state_123".to_string();
@@ -114,25 +114,24 @@ async fn end_to_end_login_flow_persists_auth_json() {
open_browser: false,
force_state: Some(state),
};
let server = run_login_server(opts).unwrap();
let server = run_login_server(opts)?;
let login_port = server.actual_port;
// Simulate browser callback, and follow redirect to /success
let client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::limited(5))
.build()
.unwrap();
.build()?;
let url = format!("http://127.0.0.1:{login_port}/auth/callback?code=abc&state=test_state_123");
let resp = client.get(&url).send().await.unwrap();
let resp = client.get(&url).send().await?;
assert!(resp.status().is_success());
// Wait for server shutdown
server.block_until_done().await.unwrap();
server.block_until_done().await?;
// Validate auth.json
let auth_path = codex_home.join("auth.json");
let data = std::fs::read_to_string(&auth_path).unwrap();
let json: serde_json::Value = serde_json::from_str(&data).unwrap();
let data = std::fs::read_to_string(&auth_path)?;
let json: serde_json::Value = serde_json::from_str(&data)?;
// The following assert is here because of the old oauth flow that exchanges tokens for an
// API key. See obtain_api_key in server.rs for details. Once we remove this old mechanism
// from the code, this test should be updated to expect that the API key is no longer present.
@@ -143,16 +142,17 @@ async fn end_to_end_login_flow_persists_auth_json() {
// Stop mock issuer
drop(issuer_handle);
Ok(())
}
#[tokio::test]
async fn creates_missing_codex_home_dir() {
non_sandbox_test!();
async fn creates_missing_codex_home_dir() -> Result<()> {
non_sandbox_test!(result);
let (issuer_addr, _issuer_handle) = start_mock_issuer();
let issuer = format!("http://{}:{}", issuer_addr.ip(), issuer_addr.port());
let tmp = tempdir().unwrap();
let tmp = tempdir()?;
let codex_home = tmp.path().join("missing-subdir"); // does not exist
let state = "state2".to_string();
@@ -167,31 +167,32 @@ async fn creates_missing_codex_home_dir() {
open_browser: false,
force_state: Some(state),
};
let server = run_login_server(opts).unwrap();
let server = run_login_server(opts)?;
let login_port = server.actual_port;
let client = reqwest::Client::new();
let url = format!("http://127.0.0.1:{login_port}/auth/callback?code=abc&state=state2");
let resp = client.get(&url).send().await.unwrap();
let resp = client.get(&url).send().await?;
assert!(resp.status().is_success());
server.block_until_done().await.unwrap();
server.block_until_done().await?;
let auth_path = codex_home.join("auth.json");
assert!(
auth_path.exists(),
"auth.json should be created even if parent dir was missing"
);
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn cancels_previous_login_server_when_port_is_in_use() {
non_sandbox_test!();
async fn cancels_previous_login_server_when_port_is_in_use() -> Result<()> {
non_sandbox_test!(result);
let (issuer_addr, _issuer_handle) = start_mock_issuer();
let issuer = format!("http://{}:{}", issuer_addr.ip(), issuer_addr.port());
let first_tmp = tempdir().unwrap();
let first_tmp = tempdir()?;
let first_codex_home = first_tmp.path().to_path_buf();
let first_opts = ServerOptions {
@@ -203,13 +204,13 @@ async fn cancels_previous_login_server_when_port_is_in_use() {
force_state: Some("cancel_state".to_string()),
};
let first_server = run_login_server(first_opts).unwrap();
let first_server = run_login_server(first_opts)?;
let login_port = first_server.actual_port;
let first_server_task = tokio::spawn(async move { first_server.block_until_done().await });
tokio::time::sleep(Duration::from_millis(100)).await;
let second_tmp = tempdir().unwrap();
let second_tmp = tempdir()?;
let second_codex_home = second_tmp.path().to_path_buf();
let second_opts = ServerOptions {
@@ -221,7 +222,7 @@ async fn cancels_previous_login_server_when_port_is_in_use() {
force_state: Some("cancel_state_2".to_string()),
};
let second_server = run_login_server(second_opts).unwrap();
let second_server = run_login_server(second_opts)?;
assert_eq!(second_server.actual_port, login_port);
let cancel_result = first_server_task
@@ -232,11 +233,12 @@ async fn cancels_previous_login_server_when_port_is_in_use() {
let client = reqwest::Client::new();
let cancel_url = format!("http://127.0.0.1:{login_port}/cancel");
let resp = client.get(cancel_url).send().await.unwrap();
let resp = client.get(cancel_url).send().await?;
assert!(resp.status().is_success());
second_server
.block_until_done()
.await
.expect_err("second login server should report cancellation");
Ok(())
}

View File

@@ -1410,13 +1410,13 @@ fn extract_conversation_summary(
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
use pretty_assertions::assert_eq;
use serde_json::json;
#[test]
fn extract_conversation_summary_prefers_plain_user_messages() {
let conversation_id =
ConversationId::from_string("3f941c35-29b3-493b-b0a4-e25800d9aeb0").unwrap();
fn extract_conversation_summary_prefers_plain_user_messages() -> Result<()> {
let conversation_id = ConversationId::from_string("3f941c35-29b3-493b-b0a4-e25800d9aeb0")?;
let timestamp = Some("2025-09-05T16:53:11.850Z".to_string());
let path = PathBuf::from("rollout.jsonl");
@@ -1456,5 +1456,6 @@ mod tests {
);
assert_eq!(summary.path, path);
assert_eq!(summary.preview, "Count to 5");
Ok(())
}
}

View File

@@ -256,6 +256,7 @@ pub(crate) struct OutgoingError {
#[cfg(test)]
mod tests {
use anyhow::Result;
use codex_core::protocol::EventMsg;
use codex_core::protocol::SessionConfiguredEvent;
use codex_protocol::config_types::ReasoningEffort;
@@ -269,12 +270,12 @@ mod tests {
use super::*;
#[tokio::test]
async fn test_send_event_as_notification() {
async fn test_send_event_as_notification() -> Result<()> {
let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded_channel::<OutgoingMessage>();
let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx);
let conversation_id = ConversationId::new();
let rollout_file = NamedTempFile::new().unwrap();
let rollout_file = NamedTempFile::new()?;
let event = Event {
id: "1".to_string(),
msg: EventMsg::SessionConfigured(SessionConfiguredEvent {
@@ -302,15 +303,16 @@ mod tests {
panic!("Event must serialize");
};
assert_eq!(params, Some(expected_params));
Ok(())
}
#[tokio::test]
async fn test_send_event_as_notification_with_meta() {
async fn test_send_event_as_notification_with_meta() -> Result<()> {
let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded_channel::<OutgoingMessage>();
let outgoing_message_sender = OutgoingMessageSender::new(outgoing_tx);
let conversation_id = ConversationId::new();
let rollout_file = NamedTempFile::new().unwrap();
let rollout_file = NamedTempFile::new()?;
let session_configured_event = SessionConfiguredEvent {
session_id: conversation_id,
model: "gpt-4o".to_string(),
@@ -353,6 +355,7 @@ mod tests {
}
});
assert_eq!(params.unwrap(), expected_params);
Ok(())
}
#[test]

View File

@@ -31,6 +31,7 @@ ts-rs = { workspace = true, features = [
uuid = { workspace = true, features = ["serde", "v7"] }
[dev-dependencies]
anyhow = { workspace = true }
pretty_assertions = { workspace = true }
tempfile = { workspace = true }

View File

@@ -702,11 +702,12 @@ impl ServerNotification {
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
use pretty_assertions::assert_eq;
use serde_json::json;
#[test]
fn serialize_new_conversation() {
fn serialize_new_conversation() -> Result<()> {
let request = ClientRequest::NewConversation {
request_id: RequestId::Integer(42),
params: NewConversationParams {
@@ -730,8 +731,9 @@ mod tests {
"approvalPolicy": "on-request"
}
}),
serde_json::to_value(&request).unwrap(),
serde_json::to_value(&request)?,
);
Ok(())
}
#[test]
@@ -741,23 +743,25 @@ mod tests {
}
#[test]
fn conversation_id_serializes_as_plain_string() {
let id = ConversationId::from_string("67e55044-10b1-426f-9247-bb680e5fe0c8").unwrap();
fn conversation_id_serializes_as_plain_string() -> Result<()> {
let id = ConversationId::from_string("67e55044-10b1-426f-9247-bb680e5fe0c8")?;
assert_eq!(
json!("67e55044-10b1-426f-9247-bb680e5fe0c8"),
serde_json::to_value(id).unwrap()
serde_json::to_value(id)?
);
Ok(())
}
#[test]
fn conversation_id_deserializes_from_plain_string() {
fn conversation_id_deserializes_from_plain_string() -> Result<()> {
let id: ConversationId =
serde_json::from_value(json!("67e55044-10b1-426f-9247-bb680e5fe0c8")).unwrap();
serde_json::from_value(json!("67e55044-10b1-426f-9247-bb680e5fe0c8"))?;
assert_eq!(
ConversationId::from_string("67e55044-10b1-426f-9247-bb680e5fe0c8").unwrap(),
ConversationId::from_string("67e55044-10b1-426f-9247-bb680e5fe0c8")?,
id,
);
Ok(())
}
}

View File

@@ -318,9 +318,10 @@ impl std::ops::Deref for FunctionCallOutputPayload {
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
#[test]
fn serializes_success_as_plain_string() {
fn serializes_success_as_plain_string() -> Result<()> {
let item = ResponseInputItem::FunctionCallOutput {
call_id: "call1".into(),
output: FunctionCallOutputPayload {
@@ -329,15 +330,16 @@ mod tests {
},
};
let json = serde_json::to_string(&item).unwrap();
let v: serde_json::Value = serde_json::from_str(&json).unwrap();
let json = serde_json::to_string(&item)?;
let v: serde_json::Value = serde_json::from_str(&json)?;
// Success case -> output should be a plain string
assert_eq!(v.get("output").unwrap().as_str().unwrap(), "ok");
Ok(())
}
#[test]
fn serializes_failure_as_string() {
fn serializes_failure_as_string() -> Result<()> {
let item = ResponseInputItem::FunctionCallOutput {
call_id: "call1".into(),
output: FunctionCallOutputPayload {
@@ -346,21 +348,22 @@ mod tests {
},
};
let json = serde_json::to_string(&item).unwrap();
let v: serde_json::Value = serde_json::from_str(&json).unwrap();
let json = serde_json::to_string(&item)?;
let v: serde_json::Value = serde_json::from_str(&json)?;
assert_eq!(v.get("output").unwrap().as_str().unwrap(), "bad");
Ok(())
}
#[test]
fn deserialize_shell_tool_call_params() {
fn deserialize_shell_tool_call_params() -> Result<()> {
let json = r#"{
"command": ["ls", "-l"],
"workdir": "/tmp",
"timeout": 1000
}"#;
let params: ShellToolCallParams = serde_json::from_str(json).unwrap();
let params: ShellToolCallParams = serde_json::from_str(json)?;
assert_eq!(
ShellToolCallParams {
command: vec!["ls".to_string(), "-l".to_string()],
@@ -371,5 +374,6 @@ mod tests {
},
params
);
Ok(())
}
}

View File

@@ -1261,16 +1261,16 @@ pub enum TurnAbortReason {
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
use serde_json::json;
use tempfile::NamedTempFile;
/// Serialize Event to verify that its JSON representation has the expected
/// amount of nesting.
#[test]
fn serialize_event() {
let conversation_id =
ConversationId::from_string("67e55044-10b1-426f-9247-bb680e5fe0c8").unwrap();
let rollout_file = NamedTempFile::new().unwrap();
fn serialize_event() -> Result<()> {
let conversation_id = ConversationId::from_string("67e55044-10b1-426f-9247-bb680e5fe0c8")?;
let rollout_file = NamedTempFile::new()?;
let event = Event {
id: "1234".to_string(),
msg: EventMsg::SessionConfigured(SessionConfiguredEvent {
@@ -1296,23 +1296,25 @@ mod tests {
"rollout_path": format!("{}", rollout_file.path().display()),
}
});
assert_eq!(expected, serde_json::to_value(&event).unwrap());
assert_eq!(expected, serde_json::to_value(&event)?);
Ok(())
}
#[test]
fn vec_u8_as_base64_serialization_and_deserialization() {
fn vec_u8_as_base64_serialization_and_deserialization() -> Result<()> {
let event = ExecCommandOutputDeltaEvent {
call_id: "call21".to_string(),
stream: ExecOutputStream::Stdout,
chunk: vec![1, 2, 3, 4, 5],
};
let serialized = serde_json::to_string(&event).unwrap();
let serialized = serde_json::to_string(&event)?;
assert_eq!(
r#"{"call_id":"call21","stream":"stdout","chunk":"AQIDBAU="}"#,
serialized,
);
let deserialized: ExecCommandOutputDeltaEvent = serde_json::from_str(&serialized).unwrap();
let deserialized: ExecCommandOutputDeltaEvent = serde_json::from_str(&serialized)?;
assert_eq!(deserialized, event);
Ok(())
}
}