feat: make cwd a required field of Config so we stop assuming std::env::current_dir() in a session (#800)
In order to expose Codex via an MCP server, I realized that we should be taking `cwd` as a parameter rather than assuming `std::env::current_dir()` as the `cwd`. Specifically, the user may want to start a session in a directory other than the one where the MCP server has been started. This PR makes `cwd: PathBuf` a required field of `Session` and threads it all the way through, though I think there is still an issue with not honoring `workdir` for `apply_patch`, which is something we also had to fix in the TypeScript version: https://github.com/openai/codex/pull/556. This also adds `-C`/`--cd` to change the cwd via the command line. To test, I ran: ``` cargo run --bin codex -- exec -C /tmp 'show the output of ls' ``` and verified it showed the contents of my `/tmp` folder instead of `$PWD`.
This commit is contained in:
@@ -22,6 +22,7 @@ use tokio::sync::oneshot;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::task::AbortHandle;
|
||||
use tracing::debug;
|
||||
use tracing::error;
|
||||
use tracing::info;
|
||||
use tracing::trace;
|
||||
use tracing::warn;
|
||||
@@ -40,6 +41,7 @@ use crate::models::ContentItem;
|
||||
use crate::models::FunctionCallOutputPayload;
|
||||
use crate::models::ResponseInputItem;
|
||||
use crate::models::ResponseItem;
|
||||
use crate::models::ShellToolCallParams;
|
||||
use crate::protocol::AskForApproval;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
@@ -190,6 +192,10 @@ struct Session {
|
||||
tx_event: Sender<Event>,
|
||||
ctrl_c: Arc<Notify>,
|
||||
|
||||
/// The session's current working directory. All relative paths provided by
|
||||
/// the model as well as sandbox policies are resolved against this path
|
||||
/// instead of `std::env::current_dir()`.
|
||||
cwd: PathBuf,
|
||||
instructions: Option<String>,
|
||||
approval_policy: AskForApproval,
|
||||
sandbox_policy: SandboxPolicy,
|
||||
@@ -198,10 +204,17 @@ struct Session {
|
||||
/// External notifier command (will be passed as args to exec()). When
|
||||
/// `None` this feature is disabled.
|
||||
notify: Option<Vec<String>>,
|
||||
|
||||
state: Mutex<State>,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
fn resolve_path(&self, path: Option<String>) -> PathBuf {
|
||||
path.as_ref()
|
||||
.map(PathBuf::from)
|
||||
.map_or_else(|| self.cwd.clone(), |p| self.cwd.join(p))
|
||||
}
|
||||
}
|
||||
|
||||
/// Mutable state of the agent
|
||||
#[derive(Default)]
|
||||
struct State {
|
||||
@@ -296,15 +309,8 @@ impl Session {
|
||||
sub_id: &str,
|
||||
call_id: &str,
|
||||
command: Vec<String>,
|
||||
cwd: Option<String>,
|
||||
cwd: PathBuf,
|
||||
) {
|
||||
let cwd = cwd
|
||||
.or_else(|| {
|
||||
std::env::current_dir()
|
||||
.ok()
|
||||
.map(|p| p.to_string_lossy().to_string())
|
||||
})
|
||||
.unwrap_or_else(|| "<unknown cwd>".to_string());
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::ExecCommandBegin {
|
||||
@@ -518,8 +524,22 @@ async fn submission_loop(
|
||||
sandbox_policy,
|
||||
disable_response_storage,
|
||||
notify,
|
||||
cwd,
|
||||
} => {
|
||||
info!(model, "Configuring session");
|
||||
if !cwd.is_absolute() {
|
||||
let message = format!("cwd is not absolute: {cwd:?}");
|
||||
error!(message);
|
||||
let event = Event {
|
||||
id: sub.id,
|
||||
msg: EventMsg::Error { message },
|
||||
};
|
||||
if let Err(e) = tx_event.send(event).await {
|
||||
error!("failed to send error message: {e:?}");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let client = ModelClient::new(model.clone());
|
||||
|
||||
// abort any current running session and clone its state
|
||||
@@ -538,7 +558,7 @@ async fn submission_loop(
|
||||
},
|
||||
};
|
||||
|
||||
// update session
|
||||
let writable_roots = Mutex::new(get_writable_roots(&cwd));
|
||||
sess = Some(Arc::new(Session {
|
||||
client,
|
||||
tx_event: tx_event.clone(),
|
||||
@@ -546,7 +566,8 @@ async fn submission_loop(
|
||||
instructions,
|
||||
approval_policy,
|
||||
sandbox_policy,
|
||||
writable_roots: Mutex::new(get_writable_roots()),
|
||||
cwd,
|
||||
writable_roots,
|
||||
notify,
|
||||
state: Mutex::new(state),
|
||||
}));
|
||||
@@ -865,7 +886,7 @@ async fn handle_function_call(
|
||||
match name.as_str() {
|
||||
"container.exec" | "shell" => {
|
||||
// parse command
|
||||
let params = match serde_json::from_str::<ExecParams>(&arguments) {
|
||||
let params = match serde_json::from_str::<ShellToolCallParams>(&arguments) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
// allow model to re-sample
|
||||
@@ -904,12 +925,7 @@ async fn handle_function_call(
|
||||
}
|
||||
|
||||
// this was not a valid patch, execute command
|
||||
let repo_root = std::env::current_dir().expect("no current dir");
|
||||
let workdir: PathBuf = params
|
||||
.workdir
|
||||
.as_ref()
|
||||
.map(PathBuf::from)
|
||||
.unwrap_or(repo_root.clone());
|
||||
let workdir = sess.resolve_path(params.workdir.clone());
|
||||
|
||||
// safety checks
|
||||
let safety = {
|
||||
@@ -968,12 +984,16 @@ async fn handle_function_call(
|
||||
&sub_id,
|
||||
&call_id,
|
||||
params.command.clone(),
|
||||
params.workdir.clone(),
|
||||
workdir.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let output_result = process_exec_tool_call(
|
||||
params.clone(),
|
||||
ExecParams {
|
||||
command: params.command.clone(),
|
||||
cwd: workdir.clone(),
|
||||
timeout_ms: params.timeout_ms,
|
||||
},
|
||||
sandbox_type,
|
||||
sess.ctrl_c.clone(),
|
||||
&sess.sandbox_policy,
|
||||
@@ -1051,18 +1071,23 @@ async fn handle_function_call(
|
||||
|
||||
// Emit a fresh Begin event so progress bars reset.
|
||||
let retry_call_id = format!("{call_id}-retry");
|
||||
let cwd = sess.resolve_path(params.workdir.clone());
|
||||
sess.notify_exec_command_begin(
|
||||
&sub_id,
|
||||
&retry_call_id,
|
||||
params.command.clone(),
|
||||
params.workdir.clone(),
|
||||
cwd.clone(),
|
||||
)
|
||||
.await;
|
||||
|
||||
// This is an escalated retry; the policy will not be
|
||||
// examined and the sandbox has been set to `None`.
|
||||
let retry_output_result = process_exec_tool_call(
|
||||
params.clone(),
|
||||
ExecParams {
|
||||
command: params.command.clone(),
|
||||
cwd: cwd.clone(),
|
||||
timeout_ms: params.timeout_ms,
|
||||
},
|
||||
SandboxType::None,
|
||||
sess.ctrl_c.clone(),
|
||||
&sess.sandbox_policy,
|
||||
@@ -1162,43 +1187,47 @@ async fn apply_patch(
|
||||
guard.clone()
|
||||
};
|
||||
|
||||
let auto_approved =
|
||||
match assess_patch_safety(&changes, sess.approval_policy, &writable_roots_snapshot) {
|
||||
SafetyCheck::AutoApprove { .. } => true,
|
||||
SafetyCheck::AskUser => {
|
||||
// Compute a readable summary of path changes to include in the
|
||||
// approval request so the user can make an informed decision.
|
||||
let rx_approve = sess
|
||||
.request_patch_approval(sub_id.clone(), &changes, None, None)
|
||||
.await;
|
||||
match rx_approve.await.unwrap_or_default() {
|
||||
ReviewDecision::Approved | ReviewDecision::ApprovedForSession => false,
|
||||
ReviewDecision::Denied | ReviewDecision::Abort => {
|
||||
return ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "patch rejected by user".to_string(),
|
||||
success: Some(false),
|
||||
},
|
||||
};
|
||||
}
|
||||
let auto_approved = match assess_patch_safety(
|
||||
&changes,
|
||||
sess.approval_policy,
|
||||
&writable_roots_snapshot,
|
||||
&sess.cwd,
|
||||
) {
|
||||
SafetyCheck::AutoApprove { .. } => true,
|
||||
SafetyCheck::AskUser => {
|
||||
// Compute a readable summary of path changes to include in the
|
||||
// approval request so the user can make an informed decision.
|
||||
let rx_approve = sess
|
||||
.request_patch_approval(sub_id.clone(), &changes, None, None)
|
||||
.await;
|
||||
match rx_approve.await.unwrap_or_default() {
|
||||
ReviewDecision::Approved | ReviewDecision::ApprovedForSession => false,
|
||||
ReviewDecision::Denied | ReviewDecision::Abort => {
|
||||
return ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: "patch rejected by user".to_string(),
|
||||
success: Some(false),
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
SafetyCheck::Reject { reason } => {
|
||||
return ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: format!("patch rejected: {reason}"),
|
||||
success: Some(false),
|
||||
},
|
||||
};
|
||||
}
|
||||
};
|
||||
}
|
||||
SafetyCheck::Reject { reason } => {
|
||||
return ResponseInputItem::FunctionCallOutput {
|
||||
call_id,
|
||||
output: FunctionCallOutputPayload {
|
||||
content: format!("patch rejected: {reason}"),
|
||||
success: Some(false),
|
||||
},
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
// Verify write permissions before touching the filesystem.
|
||||
let writable_snapshot = { sess.writable_roots.lock().unwrap().clone() };
|
||||
|
||||
if let Some(offending) = first_offending_path(&changes, &writable_snapshot) {
|
||||
if let Some(offending) = first_offending_path(&changes, &writable_snapshot, &sess.cwd) {
|
||||
let root = offending.parent().unwrap_or(&offending).to_path_buf();
|
||||
|
||||
let reason = Some(format!(
|
||||
@@ -1255,11 +1284,13 @@ async fn apply_patch(
|
||||
ApplyPatchFileChange::Update { .. } => path,
|
||||
};
|
||||
|
||||
// Reuse safety normalisation logic: treat absolute path.
|
||||
// Reuse safety normalization logic: treat absolute path.
|
||||
let abs = if path_ref.is_absolute() {
|
||||
path_ref.clone()
|
||||
} else {
|
||||
std::env::current_dir().unwrap_or_default().join(path_ref)
|
||||
// TODO(mbolin): If workdir was supplied with apply_patch call,
|
||||
// relative paths should be resolved against it.
|
||||
sess.cwd.join(path_ref)
|
||||
};
|
||||
|
||||
let writable = {
|
||||
@@ -1345,9 +1376,8 @@ async fn apply_patch(
|
||||
fn first_offending_path(
|
||||
changes: &HashMap<PathBuf, ApplyPatchFileChange>,
|
||||
writable_roots: &[PathBuf],
|
||||
cwd: &Path,
|
||||
) -> Option<PathBuf> {
|
||||
let cwd = std::env::current_dir().unwrap_or_default();
|
||||
|
||||
for (path, change) in changes {
|
||||
let candidate = match change {
|
||||
ApplyPatchFileChange::Add { .. } => path,
|
||||
@@ -1485,7 +1515,7 @@ fn apply_changes_from_apply_patch(
|
||||
})
|
||||
}
|
||||
|
||||
fn get_writable_roots() -> Vec<PathBuf> {
|
||||
fn get_writable_roots(cwd: &Path) -> Vec<std::path::PathBuf> {
|
||||
let mut writable_roots = Vec::new();
|
||||
if cfg!(target_os = "macos") {
|
||||
// On macOS, $TMPDIR is private to the user.
|
||||
@@ -1507,9 +1537,7 @@ fn get_writable_roots() -> Vec<PathBuf> {
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(cwd) = std::env::current_dir() {
|
||||
writable_roots.push(cwd);
|
||||
}
|
||||
writable_roots.push(cwd.to_path_buf());
|
||||
|
||||
writable_roots
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user