diff --git a/codex-rs/core/src/shell.rs b/codex-rs/core/src/shell.rs index 3a69874e..7ef79076 100644 --- a/codex-rs/core/src/shell.rs +++ b/codex-rs/core/src/shell.rs @@ -9,6 +9,12 @@ pub struct ZshShell { zshrc_path: String, } +#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] +pub struct BashShell { + shell_path: String, + bashrc_path: String, +} + #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] pub struct PowerShellConfig { exe: String, // Executable name or path, e.g. "pwsh" or "powershell.exe". @@ -18,6 +24,7 @@ pub struct PowerShellConfig { #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] pub enum Shell { Zsh(ZshShell), + Bash(BashShell), PowerShell(PowerShellConfig), Unknown, } @@ -26,22 +33,10 @@ impl Shell { pub fn format_default_shell_invocation(&self, command: Vec) -> Option> { match self { Shell::Zsh(zsh) => { - if !std::path::Path::new(&zsh.zshrc_path).exists() { - return None; - } - - let mut result = vec![zsh.shell_path.clone()]; - result.push("-lc".to_string()); - - let joined = strip_bash_lc(&command) - .or_else(|| shlex::try_join(command.iter().map(|s| s.as_str())).ok()); - - if let Some(joined) = joined { - result.push(format!("source {} && ({joined})", zsh.zshrc_path)); - } else { - return None; - } - Some(result) + format_shell_invocation_with_rc(&command, &zsh.shell_path, &zsh.zshrc_path) + } + Shell::Bash(bash) => { + format_shell_invocation_with_rc(&command, &bash.shell_path, &bash.bashrc_path) } Shell::PowerShell(ps) => { // If model generated a bash command, prefer a detected bash fallback @@ -97,12 +92,32 @@ impl Shell { Shell::Zsh(zsh) => std::path::Path::new(&zsh.shell_path) .file_name() .map(|s| s.to_string_lossy().to_string()), + Shell::Bash(bash) => std::path::Path::new(&bash.shell_path) + .file_name() + .map(|s| s.to_string_lossy().to_string()), Shell::PowerShell(ps) => Some(ps.exe.clone()), Shell::Unknown => None, } } } +fn format_shell_invocation_with_rc( + command: &Vec, + shell_path: &str, + rc_path: &str, +) -> Option> { + let joined = strip_bash_lc(command) + .or_else(|| shlex::try_join(command.iter().map(|s| s.as_str())).ok())?; + + let rc_command = if std::path::Path::new(rc_path).exists() { + format!("source {rc_path} && ({joined})") + } else { + joined + }; + + Some(vec![shell_path.to_string(), "-lc".to_string(), rc_command]) +} + fn strip_bash_lc(command: &Vec) -> Option { match command.as_slice() { // exactly three items @@ -116,44 +131,43 @@ fn strip_bash_lc(command: &Vec) -> Option { } } -#[cfg(target_os = "macos")] -pub async fn default_user_shell() -> Shell { - use tokio::process::Command; - use whoami; +#[cfg(unix)] +fn detect_default_user_shell() -> Shell { + use libc::getpwuid; + use libc::getuid; + use std::ffi::CStr; - let user = whoami::username(); - let home = format!("/Users/{user}"); - let output = Command::new("dscl") - .args([".", "-read", &home, "UserShell"]) - .output() - .await - .ok(); - match output { - Some(o) => { - if !o.status.success() { - return Shell::Unknown; - } - let stdout = String::from_utf8_lossy(&o.stdout); - for line in stdout.lines() { - if let Some(shell_path) = line.strip_prefix("UserShell: ") - && shell_path.ends_with("/zsh") - { - return Shell::Zsh(ZshShell { - shell_path: shell_path.to_string(), - zshrc_path: format!("{home}/.zshrc"), - }); - } + unsafe { + let uid = getuid(); + let pw = getpwuid(uid); + + if !pw.is_null() { + let shell_path = CStr::from_ptr((*pw).pw_shell) + .to_string_lossy() + .into_owned(); + let home_path = CStr::from_ptr((*pw).pw_dir).to_string_lossy().into_owned(); + + if shell_path.ends_with("/zsh") { + return Shell::Zsh(ZshShell { + shell_path, + zshrc_path: format!("{home_path}/.zshrc"), + }); } - Shell::Unknown + if shell_path.ends_with("/bash") { + return Shell::Bash(BashShell { + shell_path, + bashrc_path: format!("{home_path}/.bashrc"), + }); + } } - _ => Shell::Unknown, } + Shell::Unknown } -#[cfg(all(not(target_os = "macos"), not(target_os = "windows")))] +#[cfg(unix)] pub async fn default_user_shell() -> Shell { - Shell::Unknown + detect_default_user_shell() } #[cfg(target_os = "windows")] @@ -196,8 +210,13 @@ pub async fn default_user_shell() -> Shell { } } +#[cfg(all(not(target_os = "windows"), not(unix)))] +pub async fn default_user_shell() -> Shell { + Shell::Unknown +} + #[cfg(test)] -#[cfg(target_os = "macos")] +#[cfg(unix)] mod tests { use super::*; use std::process::Command; @@ -230,9 +249,127 @@ mod tests { zshrc_path: "/does/not/exist/.zshrc".to_string(), }); let actual_cmd = shell.format_default_shell_invocation(vec!["myecho".to_string()]); - assert_eq!(actual_cmd, None); + assert_eq!( + actual_cmd, + Some(vec![ + "/bin/zsh".to_string(), + "-lc".to_string(), + "myecho".to_string() + ]) + ); } + #[tokio::test] + async fn test_run_with_profile_bashrc_not_exists() { + let shell = Shell::Bash(BashShell { + shell_path: "/bin/bash".to_string(), + bashrc_path: "/does/not/exist/.bashrc".to_string(), + }); + let actual_cmd = shell.format_default_shell_invocation(vec!["myecho".to_string()]); + assert_eq!( + actual_cmd, + Some(vec![ + "/bin/bash".to_string(), + "-lc".to_string(), + "myecho".to_string() + ]) + ); + } + + #[tokio::test] + async fn test_run_with_profile_bash_escaping_and_execution() { + let shell_path = "/bin/bash"; + + let cases = vec![ + ( + vec!["myecho"], + vec![shell_path, "-lc", "source BASHRC_PATH && (myecho)"], + Some("It works!\n"), + ), + ( + vec!["bash", "-lc", "echo 'single' \"double\""], + vec![ + shell_path, + "-lc", + "source BASHRC_PATH && (echo 'single' \"double\")", + ], + Some("single double\n"), + ), + ]; + + for (input, expected_cmd, expected_output) in cases { + use std::collections::HashMap; + + use crate::exec::ExecParams; + use crate::exec::SandboxType; + use crate::exec::process_exec_tool_call; + use crate::protocol::SandboxPolicy; + + let temp_home = tempfile::tempdir().unwrap(); + let bashrc_path = temp_home.path().join(".bashrc"); + std::fs::write( + &bashrc_path, + r#" + set -x + function myecho { + echo 'It works!' + } + "#, + ) + .unwrap(); + let shell = Shell::Bash(BashShell { + shell_path: shell_path.to_string(), + bashrc_path: bashrc_path.to_str().unwrap().to_string(), + }); + + let actual_cmd = shell + .format_default_shell_invocation(input.iter().map(|s| s.to_string()).collect()); + let expected_cmd = expected_cmd + .iter() + .map(|s| { + s.replace("BASHRC_PATH", bashrc_path.to_str().unwrap()) + .to_string() + }) + .collect(); + + assert_eq!(actual_cmd, Some(expected_cmd)); + + let output = process_exec_tool_call( + ExecParams { + command: actual_cmd.unwrap(), + cwd: PathBuf::from(temp_home.path()), + timeout_ms: None, + env: HashMap::from([( + "HOME".to_string(), + temp_home.path().to_str().unwrap().to_string(), + )]), + with_escalated_permissions: None, + justification: None, + }, + SandboxType::None, + &SandboxPolicy::DangerFullAccess, + &None, + None, + ) + .await + .unwrap(); + + assert_eq!(output.exit_code, 0, "input: {input:?} output: {output:?}"); + if let Some(expected) = expected_output { + assert_eq!( + output.stdout.text, expected, + "input: {input:?} output: {output:?}" + ); + } + } + } +} + +#[cfg(test)] +#[cfg(target_os = "macos")] +mod macos_tests { + use super::*; + #[tokio::test] async fn test_run_with_profile_escaping_and_execution() { let shell_path = "/bin/zsh";