Optionally run using user profile (#1678)
This commit is contained in:
19
codex-rs/Cargo.lock
generated
19
codex-rs/Cargo.lock
generated
@@ -683,6 +683,7 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"sha1",
|
"sha1",
|
||||||
|
"shlex",
|
||||||
"strum_macros 0.27.2",
|
"strum_macros 0.27.2",
|
||||||
"tempfile",
|
"tempfile",
|
||||||
"thiserror 2.0.12",
|
"thiserror 2.0.12",
|
||||||
@@ -696,6 +697,7 @@ dependencies = [
|
|||||||
"tree-sitter-bash",
|
"tree-sitter-bash",
|
||||||
"uuid",
|
"uuid",
|
||||||
"walkdir",
|
"walkdir",
|
||||||
|
"whoami",
|
||||||
"wildmatch",
|
"wildmatch",
|
||||||
"wiremock",
|
"wiremock",
|
||||||
]
|
]
|
||||||
@@ -5128,6 +5130,12 @@ dependencies = [
|
|||||||
"wit-bindgen-rt",
|
"wit-bindgen-rt",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wasite"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "wasm-bindgen"
|
name = "wasm-bindgen"
|
||||||
version = "0.2.100"
|
version = "0.2.100"
|
||||||
@@ -5228,6 +5236,17 @@ version = "0.1.10"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a751b3277700db47d3e574514de2eced5e54dc8a5436a3bf7a0b248b2cee16f3"
|
checksum = "a751b3277700db47d3e574514de2eced5e54dc8a5436a3bf7a0b248b2cee16f3"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "whoami"
|
||||||
|
version = "1.6.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6994d13118ab492c3c80c1f81928718159254c53c472bf9ce36f8dae4add02a7"
|
||||||
|
dependencies = [
|
||||||
|
"redox_syscall",
|
||||||
|
"wasite",
|
||||||
|
"web-sys",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "wildmatch"
|
name = "wildmatch"
|
||||||
version = "2.4.0"
|
version = "2.4.0"
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ reqwest = { version = "0.12", features = ["json", "stream"] }
|
|||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
sha1 = "0.10.6"
|
sha1 = "0.10.6"
|
||||||
|
shlex = "1.3.0"
|
||||||
strum_macros = "0.27.2"
|
strum_macros = "0.27.2"
|
||||||
thiserror = "2.0.12"
|
thiserror = "2.0.12"
|
||||||
time = { version = "0.3", features = ["formatting", "local-offset", "macros"] }
|
time = { version = "0.3", features = ["formatting", "local-offset", "macros"] }
|
||||||
@@ -47,6 +48,8 @@ tree-sitter = "0.25.8"
|
|||||||
tree-sitter-bash = "0.25.0"
|
tree-sitter-bash = "0.25.0"
|
||||||
uuid = { version = "1", features = ["serde", "v4"] }
|
uuid = { version = "1", features = ["serde", "v4"] }
|
||||||
wildmatch = "2.4.0"
|
wildmatch = "2.4.0"
|
||||||
|
whoami = "1.6.0"
|
||||||
|
|
||||||
|
|
||||||
[target.'cfg(target_os = "linux")'.dependencies]
|
[target.'cfg(target_os = "linux")'.dependencies]
|
||||||
landlock = "0.4.1"
|
landlock = "0.4.1"
|
||||||
|
|||||||
@@ -85,6 +85,7 @@ use crate::rollout::RolloutRecorder;
|
|||||||
use crate::safety::SafetyCheck;
|
use crate::safety::SafetyCheck;
|
||||||
use crate::safety::assess_command_safety;
|
use crate::safety::assess_command_safety;
|
||||||
use crate::safety::assess_patch_safety;
|
use crate::safety::assess_patch_safety;
|
||||||
|
use crate::shell;
|
||||||
use crate::user_notification::UserNotification;
|
use crate::user_notification::UserNotification;
|
||||||
use crate::util::backoff;
|
use crate::util::backoff;
|
||||||
|
|
||||||
@@ -204,6 +205,7 @@ pub(crate) struct Session {
|
|||||||
rollout: Mutex<Option<RolloutRecorder>>,
|
rollout: Mutex<Option<RolloutRecorder>>,
|
||||||
state: Mutex<State>,
|
state: Mutex<State>,
|
||||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||||
|
user_shell: shell::Shell,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Session {
|
impl Session {
|
||||||
@@ -676,6 +678,7 @@ async fn submission_loop(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
let default_shell = shell::default_user_shell().await;
|
||||||
sess = Some(Arc::new(Session {
|
sess = Some(Arc::new(Session {
|
||||||
client,
|
client,
|
||||||
tx_event: tx_event.clone(),
|
tx_event: tx_event.clone(),
|
||||||
@@ -693,6 +696,7 @@ async fn submission_loop(
|
|||||||
rollout: Mutex::new(rollout_recorder),
|
rollout: Mutex::new(rollout_recorder),
|
||||||
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
|
||||||
disable_response_storage,
|
disable_response_storage,
|
||||||
|
user_shell: default_shell,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Patch restored state into the newly created session.
|
// Patch restored state into the newly created session.
|
||||||
@@ -1383,6 +1387,18 @@ fn parse_container_exec_arguments(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn maybe_run_with_user_profile(params: ExecParams, sess: &Session) -> ExecParams {
|
||||||
|
if sess.shell_environment_policy.use_profile {
|
||||||
|
let command = sess
|
||||||
|
.user_shell
|
||||||
|
.format_default_shell_invocation(params.command.clone());
|
||||||
|
if let Some(command) = command {
|
||||||
|
return ExecParams { command, ..params };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
params
|
||||||
|
}
|
||||||
|
|
||||||
async fn handle_container_exec_with_params(
|
async fn handle_container_exec_with_params(
|
||||||
params: ExecParams,
|
params: ExecParams,
|
||||||
sess: &Session,
|
sess: &Session,
|
||||||
@@ -1469,6 +1485,7 @@ async fn handle_container_exec_with_params(
|
|||||||
sess.notify_exec_command_begin(&sub_id, &call_id, ¶ms)
|
sess.notify_exec_command_begin(&sub_id, &call_id, ¶ms)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
|
let params = maybe_run_with_user_profile(params, sess);
|
||||||
let output_result = process_exec_tool_call(
|
let output_result = process_exec_tool_call(
|
||||||
params.clone(),
|
params.clone(),
|
||||||
sandbox_type,
|
sandbox_type,
|
||||||
|
|||||||
@@ -130,6 +130,8 @@ pub struct ShellEnvironmentPolicyToml {
|
|||||||
|
|
||||||
/// List of regular expressions.
|
/// List of regular expressions.
|
||||||
pub include_only: Option<Vec<String>>,
|
pub include_only: Option<Vec<String>>,
|
||||||
|
|
||||||
|
pub experimental_use_profile: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type EnvironmentVariablePattern = WildMatchPattern<'*', '?'>;
|
pub type EnvironmentVariablePattern = WildMatchPattern<'*', '?'>;
|
||||||
@@ -158,6 +160,9 @@ pub struct ShellEnvironmentPolicy {
|
|||||||
|
|
||||||
/// Environment variable names to retain in the environment.
|
/// Environment variable names to retain in the environment.
|
||||||
pub include_only: Vec<EnvironmentVariablePattern>,
|
pub include_only: Vec<EnvironmentVariablePattern>,
|
||||||
|
|
||||||
|
/// If true, the shell profile will be used to run the command.
|
||||||
|
pub use_profile: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<ShellEnvironmentPolicyToml> for ShellEnvironmentPolicy {
|
impl From<ShellEnvironmentPolicyToml> for ShellEnvironmentPolicy {
|
||||||
@@ -177,6 +182,7 @@ impl From<ShellEnvironmentPolicyToml> for ShellEnvironmentPolicy {
|
|||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|s| EnvironmentVariablePattern::new_case_insensitive(&s))
|
.map(|s| EnvironmentVariablePattern::new_case_insensitive(&s))
|
||||||
.collect();
|
.collect();
|
||||||
|
let use_profile = toml.experimental_use_profile.unwrap_or(false);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
inherit,
|
inherit,
|
||||||
@@ -184,6 +190,7 @@ impl From<ShellEnvironmentPolicyToml> for ShellEnvironmentPolicy {
|
|||||||
exclude,
|
exclude,
|
||||||
r#set,
|
r#set,
|
||||||
include_only,
|
include_only,
|
||||||
|
use_profile,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ use tokio::io::BufReader;
|
|||||||
use tokio::process::Child;
|
use tokio::process::Child;
|
||||||
use tokio::process::Command;
|
use tokio::process::Command;
|
||||||
use tokio::sync::Notify;
|
use tokio::sync::Notify;
|
||||||
|
use tracing::trace;
|
||||||
|
|
||||||
use crate::error::CodexErr;
|
use crate::error::CodexErr;
|
||||||
use crate::error::Result;
|
use crate::error::Result;
|
||||||
@@ -82,7 +83,8 @@ pub async fn process_exec_tool_call(
|
|||||||
) -> Result<ExecToolCallOutput> {
|
) -> Result<ExecToolCallOutput> {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
let raw_output_result = match sandbox_type {
|
let raw_output_result: std::result::Result<RawExecToolCallOutput, CodexErr> = match sandbox_type
|
||||||
|
{
|
||||||
SandboxType::None => exec(params, sandbox_policy, ctrl_c).await,
|
SandboxType::None => exec(params, sandbox_policy, ctrl_c).await,
|
||||||
SandboxType::MacosSeatbelt => {
|
SandboxType::MacosSeatbelt => {
|
||||||
let ExecParams {
|
let ExecParams {
|
||||||
@@ -372,6 +374,10 @@ async fn spawn_child_async(
|
|||||||
stdio_policy: StdioPolicy,
|
stdio_policy: StdioPolicy,
|
||||||
env: HashMap<String, String>,
|
env: HashMap<String, String>,
|
||||||
) -> std::io::Result<Child> {
|
) -> std::io::Result<Child> {
|
||||||
|
trace!(
|
||||||
|
"spawn_child_async: {program:?} {args:?} {arg0:?} {cwd:?} {sandbox_policy:?} {stdio_policy:?} {env:?}"
|
||||||
|
);
|
||||||
|
|
||||||
let mut cmd = Command::new(&program);
|
let mut cmd = Command::new(&program);
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
cmd.arg0(arg0.map_or_else(|| program.to_string_lossy().to_string(), String::from));
|
cmd.arg0(arg0.map_or_else(|| program.to_string_lossy().to_string(), String::from));
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ mod project_doc;
|
|||||||
pub mod protocol;
|
pub mod protocol;
|
||||||
mod rollout;
|
mod rollout;
|
||||||
mod safety;
|
mod safety;
|
||||||
|
pub mod shell;
|
||||||
mod user_notification;
|
mod user_notification;
|
||||||
pub mod util;
|
pub mod util;
|
||||||
|
|
||||||
|
|||||||
204
codex-rs/core/src/shell.rs
Normal file
204
codex-rs/core/src/shell.rs
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
use shlex;
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub struct ZshShell {
|
||||||
|
shell_path: String,
|
||||||
|
zshrc_path: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub enum Shell {
|
||||||
|
Zsh(ZshShell),
|
||||||
|
Unknown,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Shell {
|
||||||
|
pub fn format_default_shell_invocation(&self, command: Vec<String>) -> Option<Vec<String>> {
|
||||||
|
match self {
|
||||||
|
Shell::Zsh(zsh) => {
|
||||||
|
if !std::path::Path::new(&zsh.zshrc_path).exists() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = vec![zsh.shell_path.clone(), "-c".to_string()];
|
||||||
|
if let Ok(joined) = shlex::try_join(command.iter().map(|s| s.as_str())) {
|
||||||
|
result.push(format!("source {} && ({joined})", zsh.zshrc_path));
|
||||||
|
} else {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
Some(result)
|
||||||
|
}
|
||||||
|
Shell::Unknown => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
pub async fn default_user_shell() -> Shell {
|
||||||
|
use tokio::process::Command;
|
||||||
|
use whoami;
|
||||||
|
|
||||||
|
let user = whoami::username();
|
||||||
|
let home = format!("/Users/{user}");
|
||||||
|
let output = Command::new("dscl")
|
||||||
|
.args([".", "-read", &home, "UserShell"])
|
||||||
|
.output()
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
match output {
|
||||||
|
Some(o) => {
|
||||||
|
if !o.status.success() {
|
||||||
|
return Shell::Unknown;
|
||||||
|
}
|
||||||
|
let stdout = String::from_utf8_lossy(&o.stdout);
|
||||||
|
for line in stdout.lines() {
|
||||||
|
if let Some(shell_path) = line.strip_prefix("UserShell: ") {
|
||||||
|
if shell_path.ends_with("/zsh") {
|
||||||
|
return Shell::Zsh(ZshShell {
|
||||||
|
shell_path: shell_path.to_string(),
|
||||||
|
zshrc_path: format!("{home}/.zshrc"),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Shell::Unknown
|
||||||
|
}
|
||||||
|
_ => Shell::Unknown,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(target_os = "macos"))]
|
||||||
|
pub async fn default_user_shell() -> Shell {
|
||||||
|
Shell::Unknown
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
#[cfg(target_os = "macos")]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::process::Command;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[expect(clippy::unwrap_used)]
|
||||||
|
async fn test_current_shell_detects_zsh() {
|
||||||
|
let shell = Command::new("sh")
|
||||||
|
.arg("-c")
|
||||||
|
.arg("echo $SHELL")
|
||||||
|
.output()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let home = std::env::var("HOME").unwrap();
|
||||||
|
let shell_path = String::from_utf8_lossy(&shell.stdout).trim().to_string();
|
||||||
|
if shell_path.ends_with("/zsh") {
|
||||||
|
assert_eq!(
|
||||||
|
default_user_shell().await,
|
||||||
|
Shell::Zsh(ZshShell {
|
||||||
|
shell_path: shell_path.to_string(),
|
||||||
|
zshrc_path: format!("{home}/.zshrc",),
|
||||||
|
})
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_run_with_profile_zshrc_not_exists() {
|
||||||
|
let shell = Shell::Zsh(ZshShell {
|
||||||
|
shell_path: "/bin/zsh".to_string(),
|
||||||
|
zshrc_path: "/does/not/exist/.zshrc".to_string(),
|
||||||
|
});
|
||||||
|
let actual_cmd = shell.format_default_shell_invocation(vec!["myecho".to_string()]);
|
||||||
|
assert_eq!(actual_cmd, None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[expect(clippy::unwrap_used)]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_run_with_profile_escaping_and_execution() {
|
||||||
|
let shell_path = "/bin/zsh";
|
||||||
|
|
||||||
|
let cases = vec![
|
||||||
|
(
|
||||||
|
vec!["myecho"],
|
||||||
|
vec![shell_path, "-c", "source ZSHRC_PATH && (myecho)"],
|
||||||
|
Some("It works!\n"),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
vec!["bash", "-lc", "echo 'single' \"double\""],
|
||||||
|
vec![
|
||||||
|
shell_path,
|
||||||
|
"-c",
|
||||||
|
"source ZSHRC_PATH && (bash -lc \"echo 'single' \\\"double\\\"\")",
|
||||||
|
],
|
||||||
|
Some("single double\n"),
|
||||||
|
),
|
||||||
|
];
|
||||||
|
for (input, expected_cmd, expected_output) in cases {
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use tokio::sync::Notify;
|
||||||
|
|
||||||
|
use crate::exec::ExecParams;
|
||||||
|
use crate::exec::SandboxType;
|
||||||
|
use crate::exec::process_exec_tool_call;
|
||||||
|
use crate::protocol::SandboxPolicy;
|
||||||
|
|
||||||
|
// create a temp directory with a zshrc file in it
|
||||||
|
let temp_home = tempfile::tempdir().unwrap();
|
||||||
|
let zshrc_path = temp_home.path().join(".zshrc");
|
||||||
|
std::fs::write(
|
||||||
|
&zshrc_path,
|
||||||
|
r#"
|
||||||
|
set -x
|
||||||
|
function myecho {
|
||||||
|
echo 'It works!'
|
||||||
|
}
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
let shell = Shell::Zsh(ZshShell {
|
||||||
|
shell_path: shell_path.to_string(),
|
||||||
|
zshrc_path: zshrc_path.to_str().unwrap().to_string(),
|
||||||
|
});
|
||||||
|
|
||||||
|
let actual_cmd = shell
|
||||||
|
.format_default_shell_invocation(input.iter().map(|s| s.to_string()).collect());
|
||||||
|
let expected_cmd = expected_cmd
|
||||||
|
.iter()
|
||||||
|
.map(|s| {
|
||||||
|
s.replace("ZSHRC_PATH", zshrc_path.to_str().unwrap())
|
||||||
|
.to_string()
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
assert_eq!(actual_cmd, Some(expected_cmd));
|
||||||
|
// Actually run the command and check output/exit code
|
||||||
|
let output = process_exec_tool_call(
|
||||||
|
ExecParams {
|
||||||
|
command: actual_cmd.unwrap(),
|
||||||
|
cwd: PathBuf::from(temp_home.path()),
|
||||||
|
timeout_ms: None,
|
||||||
|
env: HashMap::from([(
|
||||||
|
"HOME".to_string(),
|
||||||
|
temp_home.path().to_str().unwrap().to_string(),
|
||||||
|
)]),
|
||||||
|
},
|
||||||
|
SandboxType::None,
|
||||||
|
Arc::new(Notify::new()),
|
||||||
|
&SandboxPolicy::DangerFullAccess,
|
||||||
|
&None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(output.exit_code, 0, "input: {input:?} output: {output:?}");
|
||||||
|
if let Some(expected) = expected_output {
|
||||||
|
assert_eq!(
|
||||||
|
output.stdout, expected,
|
||||||
|
"input: {input:?} output: {output:?}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -13,6 +13,7 @@ use tokio::sync::mpsc;
|
|||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
use tracing::error;
|
use tracing::error;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
|
use tracing_subscriber::EnvFilter;
|
||||||
|
|
||||||
mod codex_tool_config;
|
mod codex_tool_config;
|
||||||
mod codex_tool_runner;
|
mod codex_tool_runner;
|
||||||
@@ -43,6 +44,7 @@ pub async fn run_main(codex_linux_sandbox_exe: Option<PathBuf>) -> IoResult<()>
|
|||||||
// control the log level with `RUST_LOG`.
|
// control the log level with `RUST_LOG`.
|
||||||
tracing_subscriber::fmt()
|
tracing_subscriber::fmt()
|
||||||
.with_writer(std::io::stderr)
|
.with_writer(std::io::stderr)
|
||||||
|
.with_env_filter(EnvFilter::from_default_env())
|
||||||
.init();
|
.init();
|
||||||
|
|
||||||
// Set up channels.
|
// Set up channels.
|
||||||
|
|||||||
Reference in New Issue
Block a user