From 77148a5c612f161e7cc03f3df591ed0ca56a1b04 Mon Sep 17 00:00:00 2001 From: Gabriel Peal Date: Tue, 19 Aug 2025 19:50:28 -0700 Subject: [PATCH] Diff command (#2476) --- codex-rs/core/src/git_info.rs | 499 +++++++++++++++++- .../mcp-server/src/codex_message_processor.rs | 26 + codex-rs/protocol-ts/src/lib.rs | 1 + codex-rs/protocol/src/mcp_protocol.rs | 28 + 4 files changed, 553 insertions(+), 1 deletion(-) diff --git a/codex-rs/core/src/git_info.rs b/codex-rs/core/src/git_info.rs index ccb43ae5..5f25d8fe 100644 --- a/codex-rs/core/src/git_info.rs +++ b/codex-rs/core/src/git_info.rs @@ -1,11 +1,16 @@ +use std::collections::HashSet; use std::path::Path; +use codex_protocol::mcp_protocol::GitSha; +use futures::future::join_all; use serde::Deserialize; use serde::Serialize; use tokio::process::Command; use tokio::time::Duration as TokioDuration; use tokio::time::timeout; +use crate::util::is_inside_git_repo; + /// Timeout for git commands to prevent freezing on large repositories const GIT_COMMAND_TIMEOUT: TokioDuration = TokioDuration::from_secs(5); @@ -22,6 +27,12 @@ pub struct GitInfo { pub repository_url: Option, } +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct GitDiffToRemote { + pub sha: GitSha, + pub diff: String, +} + /// Collect git repository information from the given working directory using command-line git. /// Returns None if no git repository is found or if git operations fail. /// Uses timeouts to prevent freezing on large repositories. @@ -80,6 +91,23 @@ pub async fn collect_git_info(cwd: &Path) -> Option { Some(git_info) } +/// Returns the closest git sha to HEAD that is on a remote as well as the diff to that sha. +pub async fn git_diff_to_remote(cwd: &Path) -> Option { + if !is_inside_git_repo(cwd) { + return None; + } + + let remotes = get_git_remotes(cwd).await?; + let branches = branch_ancestry(cwd).await?; + let base_sha = find_closest_sha(cwd, &branches, &remotes).await?; + let diff = diff_against_sha(cwd, &base_sha).await?; + + Some(GitDiffToRemote { + sha: base_sha, + diff, + }) +} + /// Run a git command with a timeout to prevent blocking on large repositories async fn run_git_command_with_timeout(args: &[&str], cwd: &Path) -> Option { let result = timeout( @@ -94,6 +122,309 @@ async fn run_git_command_with_timeout(args: &[&str], cwd: &Path) -> Option Option> { + let output = run_git_command_with_timeout(&["remote"], cwd).await?; + if !output.status.success() { + return None; + } + let mut remotes: Vec = String::from_utf8(output.stdout) + .ok()? + .lines() + .map(|s| s.to_string()) + .collect(); + if let Some(pos) = remotes.iter().position(|r| r == "origin") { + let origin = remotes.remove(pos); + remotes.insert(0, origin); + } + Some(remotes) +} + +/// Attempt to determine the repository's default branch name. +/// +/// Preference order: +/// 1) The symbolic ref at `refs/remotes//HEAD` for the first remote (origin prioritized) +/// 2) `git remote show ` parsed for "HEAD branch: " +/// 3) Local fallback to existing `main` or `master` if present +async fn get_default_branch(cwd: &Path) -> Option { + // Prefer the first remote (with origin prioritized) + let remotes = get_git_remotes(cwd).await.unwrap_or_default(); + for remote in remotes { + // Try symbolic-ref, which returns something like: refs/remotes/origin/main + if let Some(symref_output) = run_git_command_with_timeout( + &[ + "symbolic-ref", + "--quiet", + &format!("refs/remotes/{remote}/HEAD"), + ], + cwd, + ) + .await + && symref_output.status.success() + && let Ok(sym) = String::from_utf8(symref_output.stdout) + { + let trimmed = sym.trim(); + if let Some((_, name)) = trimmed.rsplit_once('/') { + return Some(name.to_string()); + } + } + + // Fall back to parsing `git remote show ` output + if let Some(show_output) = + run_git_command_with_timeout(&["remote", "show", &remote], cwd).await + && show_output.status.success() + && let Ok(text) = String::from_utf8(show_output.stdout) + { + for line in text.lines() { + let line = line.trim(); + if let Some(rest) = line.strip_prefix("HEAD branch:") { + let name = rest.trim(); + if !name.is_empty() { + return Some(name.to_string()); + } + } + } + } + } + + // No remote-derived default; try common local defaults if they exist + for candidate in ["main", "master"] { + if let Some(verify) = run_git_command_with_timeout( + &[ + "rev-parse", + "--verify", + "--quiet", + &format!("refs/heads/{candidate}"), + ], + cwd, + ) + .await + && verify.status.success() + { + return Some(candidate.to_string()); + } + } + + None +} + +/// Build an ancestry of branches starting at the current branch and ending at the +/// repository's default branch (if determinable).. +async fn branch_ancestry(cwd: &Path) -> Option> { + // Discover current branch (ignore detached HEAD by treating it as None) + let current_branch = run_git_command_with_timeout(&["rev-parse", "--abbrev-ref", "HEAD"], cwd) + .await + .and_then(|o| { + if o.status.success() { + String::from_utf8(o.stdout).ok() + } else { + None + } + }) + .map(|s| s.trim().to_string()) + .filter(|s| s != "HEAD"); + + // Discover default branch + let default_branch = get_default_branch(cwd).await; + + let mut ancestry: Vec = Vec::new(); + let mut seen: HashSet = HashSet::new(); + if let Some(cb) = current_branch.clone() { + seen.insert(cb.clone()); + ancestry.push(cb); + } + if let Some(db) = default_branch + && !seen.contains(&db) + { + seen.insert(db.clone()); + ancestry.push(db); + } + + // Expand candidates: include any remote branches that already contain HEAD. + // This addresses cases where we're on a new local-only branch forked from a + // remote branch that isn't the repository default. We prioritize remotes in + // the order returned by get_git_remotes (origin first). + let remotes = get_git_remotes(cwd).await.unwrap_or_default(); + for remote in remotes { + if let Some(output) = run_git_command_with_timeout( + &[ + "for-each-ref", + "--format=%(refname:short)", + "--contains=HEAD", + &format!("refs/remotes/{remote}"), + ], + cwd, + ) + .await + && output.status.success() + && let Ok(text) = String::from_utf8(output.stdout) + { + for line in text.lines() { + let short = line.trim(); + // Expect format like: "origin/feature"; extract the branch path after "remote/" + if let Some(stripped) = short.strip_prefix(&format!("{remote}/")) + && !stripped.is_empty() + && !seen.contains(stripped) + { + seen.insert(stripped.to_string()); + ancestry.push(stripped.to_string()); + } + } + } + } + + // Ensure we return Some vector, even if empty, to allow caller logic to proceed + Some(ancestry) +} + +// Helper for a single branch: return the remote SHA if present on any remote +// and the distance (commits ahead of HEAD) for that branch. The first item is +// None if the branch is not present on any remote. Returns None if distance +// could not be computed due to git errors/timeouts. +async fn branch_remote_and_distance( + cwd: &Path, + branch: &str, + remotes: &[String], +) -> Option<(Option, usize)> { + // Try to find the first remote ref that exists for this branch (origin prioritized by caller). + let mut found_remote_sha: Option = None; + let mut found_remote_ref: Option = None; + for remote in remotes { + let remote_ref = format!("refs/remotes/{remote}/{branch}"); + let Some(verify_output) = + run_git_command_with_timeout(&["rev-parse", "--verify", "--quiet", &remote_ref], cwd) + .await + else { + // Mirror previous behavior: if the verify call times out/fails at the process level, + // treat the entire branch as unusable. + return None; + }; + if !verify_output.status.success() { + continue; + } + let Ok(sha) = String::from_utf8(verify_output.stdout) else { + // Mirror previous behavior and skip the entire branch on parse failure. + return None; + }; + found_remote_sha = Some(GitSha::new(sha.trim())); + found_remote_ref = Some(remote_ref); + break; + } + + // Compute distance as the number of commits HEAD is ahead of the branch. + // Prefer local branch name if it exists; otherwise fall back to the remote ref (if any). + let count_output = if let Some(local_count) = + run_git_command_with_timeout(&["rev-list", "--count", &format!("{branch}..HEAD")], cwd) + .await + { + if local_count.status.success() { + local_count + } else if let Some(remote_ref) = &found_remote_ref { + match run_git_command_with_timeout( + &["rev-list", "--count", &format!("{remote_ref}..HEAD")], + cwd, + ) + .await + { + Some(remote_count) => remote_count, + None => return None, + } + } else { + return None; + } + } else if let Some(remote_ref) = &found_remote_ref { + match run_git_command_with_timeout( + &["rev-list", "--count", &format!("{remote_ref}..HEAD")], + cwd, + ) + .await + { + Some(remote_count) => remote_count, + None => return None, + } + } else { + return None; + }; + + if !count_output.status.success() { + return None; + } + let Ok(distance_str) = String::from_utf8(count_output.stdout) else { + return None; + }; + let Ok(distance) = distance_str.trim().parse::() else { + return None; + }; + + Some((found_remote_sha, distance)) +} + +// Finds the closest sha that exist on any of branches and also exists on any of the remotes. +async fn find_closest_sha(cwd: &Path, branches: &[String], remotes: &[String]) -> Option { + // A sha and how many commits away from HEAD it is. + let mut closest_sha: Option<(GitSha, usize)> = None; + for branch in branches { + let Some((maybe_remote_sha, distance)) = + branch_remote_and_distance(cwd, branch, remotes).await + else { + continue; + }; + let Some(remote_sha) = maybe_remote_sha else { + // Preserve existing behavior: skip branches that are not present on a remote. + continue; + }; + match &closest_sha { + None => closest_sha = Some((remote_sha, distance)), + Some((_, best_distance)) if distance < *best_distance => { + closest_sha = Some((remote_sha, distance)); + } + _ => {} + } + } + closest_sha.map(|(sha, _)| sha) +} + +async fn diff_against_sha(cwd: &Path, sha: &GitSha) -> Option { + let output = run_git_command_with_timeout(&["diff", &sha.0], cwd).await?; + // 0 is success and no diff. + // 1 is success but there is a diff. + let exit_ok = output.status.code().is_some_and(|c| c == 0 || c == 1); + if !exit_ok { + return None; + } + let mut diff = String::from_utf8(output.stdout).ok()?; + + if let Some(untracked_output) = + run_git_command_with_timeout(&["ls-files", "--others", "--exclude-standard"], cwd).await + && untracked_output.status.success() + { + let untracked: Vec = String::from_utf8(untracked_output.stdout) + .ok()? + .lines() + .map(|s| s.to_string()) + .filter(|s| !s.is_empty()) + .collect(); + + if !untracked.is_empty() { + let futures_iter = untracked.into_iter().map(|file| async move { + let file_owned = file; + let args_vec: Vec<&str> = + vec!["diff", "--binary", "--no-index", "/dev/null", &file_owned]; + run_git_command_with_timeout(&args_vec, cwd).await + }); + let results = join_all(futures_iter).await; + for extra in results.into_iter().flatten() { + if extra.status.code().is_some_and(|c| c == 0 || c == 1) + && let Ok(s) = String::from_utf8(extra.stdout) + { + diff.push_str(&s); + } + } + } + } + + Some(diff) +} + #[cfg(test)] mod tests { use super::*; @@ -104,7 +435,8 @@ mod tests { // Helper function to create a test git repository async fn create_test_git_repo(temp_dir: &TempDir) -> PathBuf { - let repo_path = temp_dir.path().to_path_buf(); + let repo_path = temp_dir.path().join("repo"); + fs::create_dir(&repo_path).expect("Failed to create repo dir"); let envs = vec![ ("GIT_CONFIG_GLOBAL", "/dev/null"), ("GIT_CONFIG_NOSYSTEM", "1"), @@ -159,6 +491,41 @@ mod tests { repo_path } + async fn create_test_git_repo_with_remote(temp_dir: &TempDir) -> (PathBuf, String) { + let repo_path = create_test_git_repo(temp_dir).await; + let remote_path = temp_dir.path().join("remote.git"); + + Command::new("git") + .args(["init", "--bare", remote_path.to_str().unwrap()]) + .output() + .await + .expect("Failed to init bare remote"); + + Command::new("git") + .args(["remote", "add", "origin", remote_path.to_str().unwrap()]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to add remote"); + + let output = Command::new("git") + .args(["rev-parse", "--abbrev-ref", "HEAD"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to get branch"); + let branch = String::from_utf8(output.stdout).unwrap().trim().to_string(); + + Command::new("git") + .args(["push", "-u", "origin", &branch]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to push initial commit"); + + (repo_path, branch) + } + #[tokio::test] async fn test_collect_git_info_non_git_directory() { let temp_dir = TempDir::new().expect("Failed to create temp dir"); @@ -272,6 +639,136 @@ mod tests { assert_eq!(git_info.branch, Some("feature-branch".to_string())); } + #[tokio::test] + async fn test_get_git_working_tree_state_clean_repo() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let (repo_path, branch) = create_test_git_repo_with_remote(&temp_dir).await; + + let remote_sha = Command::new("git") + .args(["rev-parse", &format!("origin/{branch}")]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to rev-parse remote"); + let remote_sha = String::from_utf8(remote_sha.stdout) + .unwrap() + .trim() + .to_string(); + + let state = git_diff_to_remote(&repo_path) + .await + .expect("Should collect working tree state"); + assert_eq!(state.sha, GitSha::new(&remote_sha)); + assert!(state.diff.is_empty()); + } + + #[tokio::test] + async fn test_get_git_working_tree_state_with_changes() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let (repo_path, branch) = create_test_git_repo_with_remote(&temp_dir).await; + + let tracked = repo_path.join("test.txt"); + fs::write(&tracked, "modified").unwrap(); + fs::write(repo_path.join("untracked.txt"), "new").unwrap(); + + let remote_sha = Command::new("git") + .args(["rev-parse", &format!("origin/{branch}")]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to rev-parse remote"); + let remote_sha = String::from_utf8(remote_sha.stdout) + .unwrap() + .trim() + .to_string(); + + let state = git_diff_to_remote(&repo_path) + .await + .expect("Should collect working tree state"); + assert_eq!(state.sha, GitSha::new(&remote_sha)); + assert!(state.diff.contains("test.txt")); + assert!(state.diff.contains("untracked.txt")); + } + + #[tokio::test] + async fn test_get_git_working_tree_state_branch_fallback() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let (repo_path, _branch) = create_test_git_repo_with_remote(&temp_dir).await; + + Command::new("git") + .args(["checkout", "-b", "feature"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to create feature branch"); + Command::new("git") + .args(["push", "-u", "origin", "feature"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to push feature branch"); + + Command::new("git") + .args(["checkout", "-b", "local-branch"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to create local branch"); + + let remote_sha = Command::new("git") + .args(["rev-parse", "origin/feature"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to rev-parse remote"); + let remote_sha = String::from_utf8(remote_sha.stdout) + .unwrap() + .trim() + .to_string(); + + let state = git_diff_to_remote(&repo_path) + .await + .expect("Should collect working tree state"); + assert_eq!(state.sha, GitSha::new(&remote_sha)); + } + + #[tokio::test] + async fn test_get_git_working_tree_state_unpushed_commit() { + let temp_dir = TempDir::new().expect("Failed to create temp dir"); + let (repo_path, branch) = create_test_git_repo_with_remote(&temp_dir).await; + + let remote_sha = Command::new("git") + .args(["rev-parse", &format!("origin/{branch}")]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to rev-parse remote"); + let remote_sha = String::from_utf8(remote_sha.stdout) + .unwrap() + .trim() + .to_string(); + + fs::write(repo_path.join("test.txt"), "updated").unwrap(); + Command::new("git") + .args(["add", "test.txt"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to add file"); + Command::new("git") + .args(["commit", "-m", "local change"]) + .current_dir(&repo_path) + .output() + .await + .expect("Failed to commit"); + + let state = git_diff_to_remote(&repo_path) + .await + .expect("Should collect working tree state"); + assert_eq!(state.sha, GitSha::new(&remote_sha)); + assert!(state.diff.contains("updated")); + } + #[test] fn test_git_info_serialization() { let git_info = GitInfo { diff --git a/codex-rs/mcp-server/src/codex_message_processor.rs b/codex-rs/mcp-server/src/codex_message_processor.rs index 1decf11d..07e06d66 100644 --- a/codex-rs/mcp-server/src/codex_message_processor.rs +++ b/codex-rs/mcp-server/src/codex_message_processor.rs @@ -8,11 +8,13 @@ use codex_core::ConversationManager; use codex_core::NewConversation; use codex_core::config::Config; use codex_core::config::ConfigOverrides; +use codex_core::git_info::git_diff_to_remote; use codex_core::protocol::ApplyPatchApprovalRequestEvent; use codex_core::protocol::Event; use codex_core::protocol::EventMsg; use codex_core::protocol::ExecApprovalRequestEvent; use codex_core::protocol::ReviewDecision; +use codex_protocol::mcp_protocol::GitDiffToRemoteResponse; use mcp_types::JSONRPCErrorError; use mcp_types::RequestId; use tokio::sync::Mutex; @@ -126,6 +128,9 @@ impl CodexMessageProcessor { ClientRequest::CancelLoginChatGpt { request_id, params } => { self.cancel_login_chatgpt(request_id, params.login_id).await; } + ClientRequest::GitDiffToRemote { request_id, params } => { + self.git_diff_to_origin(request_id, params.cwd).await; + } } } @@ -514,6 +519,27 @@ impl CodexMessageProcessor { } } } + + async fn git_diff_to_origin(&self, request_id: RequestId, cwd: PathBuf) { + let diff = git_diff_to_remote(&cwd).await; + match diff { + Some(value) => { + let response = GitDiffToRemoteResponse { + sha: value.sha, + diff: value.diff, + }; + self.outgoing.send_response(request_id, response).await; + } + None => { + let error = JSONRPCErrorError { + code: INVALID_REQUEST_ERROR_CODE, + message: format!("failed to compute git diff to remote for cwd: {cwd:?}"), + data: None, + }; + self.outgoing.send_error(request_id, error).await; + } + } + } } async fn apply_bespoke_event_handling( diff --git a/codex-rs/protocol-ts/src/lib.rs b/codex-rs/protocol-ts/src/lib.rs index c2b196da..6bbc9269 100644 --- a/codex-rs/protocol-ts/src/lib.rs +++ b/codex-rs/protocol-ts/src/lib.rs @@ -36,6 +36,7 @@ pub fn generate_ts(out_dir: &Path, prettier: Option<&Path>) -> Result<()> { codex_protocol::mcp_protocol::LoginChatGptCompleteNotification::export_all_to(out_dir)?; codex_protocol::mcp_protocol::CancelLoginChatGptParams::export_all_to(out_dir)?; codex_protocol::mcp_protocol::CancelLoginChatGptResponse::export_all_to(out_dir)?; + codex_protocol::mcp_protocol::GitDiffToRemoteParams::export_all_to(out_dir)?; codex_protocol::mcp_protocol::ApplyPatchApprovalParams::export_all_to(out_dir)?; codex_protocol::mcp_protocol::ApplyPatchApprovalResponse::export_all_to(out_dir)?; codex_protocol::mcp_protocol::ExecCommandApprovalParams::export_all_to(out_dir)?; diff --git a/codex-rs/protocol/src/mcp_protocol.rs b/codex-rs/protocol/src/mcp_protocol.rs index 383b2033..68f5c01d 100644 --- a/codex-rs/protocol/src/mcp_protocol.rs +++ b/codex-rs/protocol/src/mcp_protocol.rs @@ -26,6 +26,16 @@ impl Display for ConversationId { } } +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, TS)] +#[ts(type = "string")] +pub struct GitSha(pub String); + +impl GitSha { + pub fn new(sha: &str) -> Self { + Self(sha.to_string()) + } +} + /// Request from the client to the server. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] #[serde(tag = "method", rename_all = "camelCase")] @@ -69,6 +79,11 @@ pub enum ClientRequest { request_id: RequestId, params: CancelLoginChatGptParams, }, + GitDiffToRemote { + #[serde(rename = "id")] + request_id: RequestId, + params: GitDiffToRemoteParams, + }, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default, TS)] @@ -139,6 +154,13 @@ pub struct LoginChatGptResponse { pub auth_url: String, } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] +#[serde(rename_all = "camelCase")] +pub struct GitDiffToRemoteResponse { + pub sha: GitSha, + pub diff: String, +} + // Event name for notifying client of login completion or failure. pub const LOGIN_CHATGPT_COMPLETE_EVENT: &str = "codex/event/login_chatgpt_complete"; @@ -157,6 +179,12 @@ pub struct CancelLoginChatGptParams { pub login_id: Uuid, } +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] +#[serde(rename_all = "camelCase")] +pub struct GitDiffToRemoteParams { + pub cwd: PathBuf, +} + #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)] #[serde(rename_all = "camelCase")] pub struct CancelLoginChatGptResponse {}