Phase 1: Repository & Infrastructure Setup

- Renamed directories: codex-rs -> llmx-rs, codex-cli -> llmx-cli
- Updated package.json files:
  - Root: llmx-monorepo
  - CLI: @llmx/llmx
  - SDK: @llmx/llmx-sdk
- Updated pnpm workspace configuration
- Renamed binary: codex.js -> llmx.js
- Updated environment variables: CODEX_* -> LLMX_* (see the sketch below)
- Changed repository URLs to valknar/llmx

🤖 Generated with Claude Code
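As a small illustration of the variable rename, a hedged Rust sketch (LLMX_HOME is an assumed example of the new LLMX_* family, and the CODEX_HOME fallback is hypothetical migration behavior, not something this commit documents):

// Hypothetical helper: prefer the renamed LLMX_HOME variable and fall back
// to the legacy CODEX_HOME name while callers migrate.
fn resolve_home_dir() -> Option<String> {
    std::env::var("LLMX_HOME")
        .or_else(|_| std::env::var("CODEX_HOME"))
        .ok()
}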
142  llmx-rs/core/src/apply_patch.rs  Normal file
@@ -0,0 +1,142 @@
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::function_tool::FunctionCallError;
use crate::protocol::FileChange;
use crate::protocol::ReviewDecision;
use crate::safety::SafetyCheck;
use crate::safety::assess_patch_safety;
use codex_apply_patch::ApplyPatchAction;
use codex_apply_patch::ApplyPatchFileChange;
use std::collections::HashMap;
use std::path::PathBuf;

pub const CODEX_APPLY_PATCH_ARG1: &str = "--codex-run-as-apply-patch";

pub(crate) enum InternalApplyPatchInvocation {
    /// The `apply_patch` call was handled programmatically, without any sort
    /// of sandbox, because the user explicitly approved it. This is the
    /// result to use with the `shell` function call that contained `apply_patch`.
    Output(Result<String, FunctionCallError>),

    /// The `apply_patch` call was approved, either automatically because it
    /// appears that it should be allowed based on the user's sandbox policy
    /// *or* because the user explicitly approved it. In either case, we use
    /// exec with [`CODEX_APPLY_PATCH_ARG1`] to realize the `apply_patch` call,
    /// but [`ApplyPatchExec::user_explicitly_approved_this_action`] is used to
    /// determine the sandbox used with the `exec()`.
    DelegateToExec(ApplyPatchExec),
}

#[derive(Debug)]
pub(crate) struct ApplyPatchExec {
    pub(crate) action: ApplyPatchAction,
    pub(crate) user_explicitly_approved_this_action: bool,
}

pub(crate) async fn apply_patch(
    sess: &Session,
    turn_context: &TurnContext,
    call_id: &str,
    action: ApplyPatchAction,
) -> InternalApplyPatchInvocation {
    match assess_patch_safety(
        &action,
        turn_context.approval_policy,
        &turn_context.sandbox_policy,
        &turn_context.cwd,
    ) {
        SafetyCheck::AutoApprove {
            user_explicitly_approved,
            ..
        } => InternalApplyPatchInvocation::DelegateToExec(ApplyPatchExec {
            action,
            user_explicitly_approved_this_action: user_explicitly_approved,
        }),
        SafetyCheck::AskUser => {
            // Compute a readable summary of path changes to include in the
            // approval request so the user can make an informed decision.
            //
            // Note that it might be worth expanding this approval request to
            // give the user the option to expand the set of writable roots so
            // that similar patches can be auto-approved in the future during
            // this session.
            let rx_approve = sess
                .request_patch_approval(
                    turn_context,
                    call_id.to_owned(),
                    convert_apply_patch_to_protocol(&action),
                    None,
                    None,
                )
                .await;
            match rx_approve.await.unwrap_or_default() {
                ReviewDecision::Approved | ReviewDecision::ApprovedForSession => {
                    InternalApplyPatchInvocation::DelegateToExec(ApplyPatchExec {
                        action,
                        user_explicitly_approved_this_action: true,
                    })
                }
                ReviewDecision::Denied | ReviewDecision::Abort => {
                    InternalApplyPatchInvocation::Output(Err(FunctionCallError::RespondToModel(
                        "patch rejected by user".to_string(),
                    )))
                }
            }
        }
        SafetyCheck::Reject { reason } => InternalApplyPatchInvocation::Output(Err(
            FunctionCallError::RespondToModel(format!("patch rejected: {reason}")),
        )),
    }
}

pub(crate) fn convert_apply_patch_to_protocol(
    action: &ApplyPatchAction,
) -> HashMap<PathBuf, FileChange> {
    let changes = action.changes();
    let mut result = HashMap::with_capacity(changes.len());
    for (path, change) in changes {
        let protocol_change = match change {
            ApplyPatchFileChange::Add { content } => FileChange::Add {
                content: content.clone(),
            },
            ApplyPatchFileChange::Delete { content } => FileChange::Delete {
                content: content.clone(),
            },
            ApplyPatchFileChange::Update {
                unified_diff,
                move_path,
                new_content: _new_content,
            } => FileChange::Update {
                unified_diff: unified_diff.clone(),
                move_path: move_path.clone(),
            },
        };
        result.insert(path.clone(), protocol_change);
    }
    result
}

#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    use tempfile::tempdir;

    #[test]
    fn convert_apply_patch_maps_add_variant() {
        let tmp = tempdir().expect("tmp");
        let p = tmp.path().join("a.txt");
        // Create an action with a single Add change
        let action = ApplyPatchAction::new_add_for_test(&p, "hello".to_string());

        let got = convert_apply_patch_to_protocol(&action);

        assert_eq!(
            got.get(&p),
            Some(&FileChange::Add {
                content: "hello".to_string()
            })
        );
    }
}
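The approval flow above reduces to a three-way decision. A minimal standalone sketch of that shape, with stand-in types (the real Session, approval channel, and exec plumbing are omitted):

// Stand-in for crate::safety::SafetyCheck; fields trimmed for illustration.
enum SafetyCheck {
    AutoApprove,
    AskUser,
    Reject { reason: String },
}

// Mirrors apply_patch: auto-approve delegates to exec, AskUser consults the
// user, and Reject (or a denial) becomes an error string for the model.
fn decide(check: SafetyCheck, user_approves: bool) -> Result<&'static str, String> {
    match check {
        SafetyCheck::AutoApprove => Ok("delegate to exec"),
        SafetyCheck::AskUser if user_approves => Ok("delegate to exec (explicitly approved)"),
        SafetyCheck::AskUser => Err("patch rejected by user".to_string()),
        SafetyCheck::Reject { reason } => Err(format!("patch rejected: {reason}")),
    }
}

fn main() {
    assert_eq!(decide(SafetyCheck::AutoApprove, false), Ok("delegate to exec"));
    assert_eq!(
        decide(SafetyCheck::AskUser, true),
        Ok("delegate to exec (explicitly approved)")
    );
    assert!(decide(SafetyCheck::Reject { reason: "outside workspace".into() }, true).is_err());
}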
1193  llmx-rs/core/src/auth.rs  Normal file
File diff suppressed because it is too large
672  llmx-rs/core/src/auth/storage.rs  Normal file
@@ -0,0 +1,672 @@
use chrono::DateTime;
use chrono::Utc;
use serde::Deserialize;
use serde::Serialize;
use sha2::Digest;
use sha2::Sha256;
use std::fmt::Debug;
use std::fs::File;
use std::fs::OpenOptions;
use std::io::Read;
use std::io::Write;
#[cfg(unix)]
use std::os::unix::fs::OpenOptionsExt;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use tracing::warn;

use crate::token_data::TokenData;
use codex_keyring_store::DefaultKeyringStore;
use codex_keyring_store::KeyringStore;

/// Determine where Codex should store CLI auth credentials.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AuthCredentialsStoreMode {
    #[default]
    /// Persist credentials in CODEX_HOME/auth.json.
    File,
    /// Persist credentials in the keyring. Fail if unavailable.
    Keyring,
    /// Use keyring when available; otherwise, fall back to a file in CODEX_HOME.
    Auto,
}

/// Expected structure for $CODEX_HOME/auth.json.
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
pub struct AuthDotJson {
    #[serde(rename = "OPENAI_API_KEY")]
    pub openai_api_key: Option<String>,

    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tokens: Option<TokenData>,

    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub last_refresh: Option<DateTime<Utc>>,
}

pub(super) fn get_auth_file(codex_home: &Path) -> PathBuf {
    codex_home.join("auth.json")
}

pub(super) fn delete_file_if_exists(codex_home: &Path) -> std::io::Result<bool> {
    let auth_file = get_auth_file(codex_home);
    match std::fs::remove_file(&auth_file) {
        Ok(()) => Ok(true),
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(false),
        Err(err) => Err(err),
    }
}

pub(super) trait AuthStorageBackend: Debug + Send + Sync {
    fn load(&self) -> std::io::Result<Option<AuthDotJson>>;
    fn save(&self, auth: &AuthDotJson) -> std::io::Result<()>;
    fn delete(&self) -> std::io::Result<bool>;
}

#[derive(Clone, Debug)]
pub(super) struct FileAuthStorage {
    codex_home: PathBuf,
}

impl FileAuthStorage {
    pub(super) fn new(codex_home: PathBuf) -> Self {
        Self { codex_home }
    }

    /// Attempt to read and refresh the `auth.json` file in the given `CODEX_HOME` directory.
    /// Returns the full AuthDotJson structure after refreshing if necessary.
    pub(super) fn try_read_auth_json(&self, auth_file: &Path) -> std::io::Result<AuthDotJson> {
        let mut file = File::open(auth_file)?;
        let mut contents = String::new();
        file.read_to_string(&mut contents)?;
        let auth_dot_json: AuthDotJson = serde_json::from_str(&contents)?;

        Ok(auth_dot_json)
    }
}

impl AuthStorageBackend for FileAuthStorage {
    fn load(&self) -> std::io::Result<Option<AuthDotJson>> {
        let auth_file = get_auth_file(&self.codex_home);
        let auth_dot_json = match self.try_read_auth_json(&auth_file) {
            Ok(auth) => auth,
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
            Err(err) => return Err(err),
        };
        Ok(Some(auth_dot_json))
    }

    fn save(&self, auth_dot_json: &AuthDotJson) -> std::io::Result<()> {
        let auth_file = get_auth_file(&self.codex_home);

        if let Some(parent) = auth_file.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let json_data = serde_json::to_string_pretty(auth_dot_json)?;
        let mut options = OpenOptions::new();
        options.truncate(true).write(true).create(true);
        #[cfg(unix)]
        {
            options.mode(0o600);
        }
        let mut file = options.open(auth_file)?;
        file.write_all(json_data.as_bytes())?;
        file.flush()?;
        Ok(())
    }

    fn delete(&self) -> std::io::Result<bool> {
        delete_file_if_exists(&self.codex_home)
    }
}

const KEYRING_SERVICE: &str = "Codex Auth";

// turns codex_home path into a stable, short key string
fn compute_store_key(codex_home: &Path) -> std::io::Result<String> {
    let canonical = codex_home
        .canonicalize()
        .unwrap_or_else(|_| codex_home.to_path_buf());
    let path_str = canonical.to_string_lossy();
    let mut hasher = Sha256::new();
    hasher.update(path_str.as_bytes());
    let digest = hasher.finalize();
    let hex = format!("{digest:x}");
    let truncated = hex.get(..16).unwrap_or(&hex);
    Ok(format!("cli|{truncated}"))
}

#[derive(Clone, Debug)]
struct KeyringAuthStorage {
    codex_home: PathBuf,
    keyring_store: Arc<dyn KeyringStore>,
}

impl KeyringAuthStorage {
    fn new(codex_home: PathBuf, keyring_store: Arc<dyn KeyringStore>) -> Self {
        Self {
            codex_home,
            keyring_store,
        }
    }

    fn load_from_keyring(&self, key: &str) -> std::io::Result<Option<AuthDotJson>> {
        match self.keyring_store.load(KEYRING_SERVICE, key) {
            Ok(Some(serialized)) => serde_json::from_str(&serialized).map(Some).map_err(|err| {
                std::io::Error::other(format!(
                    "failed to deserialize CLI auth from keyring: {err}"
                ))
            }),
            Ok(None) => Ok(None),
            Err(error) => Err(std::io::Error::other(format!(
                "failed to load CLI auth from keyring: {}",
                error.message()
            ))),
        }
    }

    fn save_to_keyring(&self, key: &str, value: &str) -> std::io::Result<()> {
        match self.keyring_store.save(KEYRING_SERVICE, key, value) {
            Ok(()) => Ok(()),
            Err(error) => {
                let message = format!(
                    "failed to write OAuth tokens to keyring: {}",
                    error.message()
                );
                warn!("{message}");
                Err(std::io::Error::other(message))
            }
        }
    }
}

impl AuthStorageBackend for KeyringAuthStorage {
    fn load(&self) -> std::io::Result<Option<AuthDotJson>> {
        let key = compute_store_key(&self.codex_home)?;
        self.load_from_keyring(&key)
    }

    fn save(&self, auth: &AuthDotJson) -> std::io::Result<()> {
        let key = compute_store_key(&self.codex_home)?;
        // Simpler error mapping per style: prefer method reference over closure
        let serialized = serde_json::to_string(auth).map_err(std::io::Error::other)?;
        self.save_to_keyring(&key, &serialized)?;
        if let Err(err) = delete_file_if_exists(&self.codex_home) {
            warn!("failed to remove CLI auth fallback file: {err}");
        }
        Ok(())
    }

    fn delete(&self) -> std::io::Result<bool> {
        let key = compute_store_key(&self.codex_home)?;
        let keyring_removed = self
            .keyring_store
            .delete(KEYRING_SERVICE, &key)
            .map_err(|err| {
                std::io::Error::other(format!("failed to delete auth from keyring: {err}"))
            })?;
        let file_removed = delete_file_if_exists(&self.codex_home)?;
        Ok(keyring_removed || file_removed)
    }
}

#[derive(Clone, Debug)]
struct AutoAuthStorage {
    keyring_storage: Arc<KeyringAuthStorage>,
    file_storage: Arc<FileAuthStorage>,
}

impl AutoAuthStorage {
    fn new(codex_home: PathBuf, keyring_store: Arc<dyn KeyringStore>) -> Self {
        Self {
            keyring_storage: Arc::new(KeyringAuthStorage::new(codex_home.clone(), keyring_store)),
            file_storage: Arc::new(FileAuthStorage::new(codex_home)),
        }
    }
}

impl AuthStorageBackend for AutoAuthStorage {
    fn load(&self) -> std::io::Result<Option<AuthDotJson>> {
        match self.keyring_storage.load() {
            Ok(Some(auth)) => Ok(Some(auth)),
            Ok(None) => self.file_storage.load(),
            Err(err) => {
                warn!("failed to load CLI auth from keyring, falling back to file storage: {err}");
                self.file_storage.load()
            }
        }
    }

    fn save(&self, auth: &AuthDotJson) -> std::io::Result<()> {
        match self.keyring_storage.save(auth) {
            Ok(()) => Ok(()),
            Err(err) => {
                warn!("failed to save auth to keyring, falling back to file storage: {err}");
                self.file_storage.save(auth)
            }
        }
    }

    fn delete(&self) -> std::io::Result<bool> {
        // Keyring storage will delete from disk as well
        self.keyring_storage.delete()
    }
}

pub(super) fn create_auth_storage(
    codex_home: PathBuf,
    mode: AuthCredentialsStoreMode,
) -> Arc<dyn AuthStorageBackend> {
    let keyring_store: Arc<dyn KeyringStore> = Arc::new(DefaultKeyringStore);
    create_auth_storage_with_keyring_store(codex_home, mode, keyring_store)
}

fn create_auth_storage_with_keyring_store(
    codex_home: PathBuf,
    mode: AuthCredentialsStoreMode,
    keyring_store: Arc<dyn KeyringStore>,
) -> Arc<dyn AuthStorageBackend> {
    match mode {
        AuthCredentialsStoreMode::File => Arc::new(FileAuthStorage::new(codex_home)),
        AuthCredentialsStoreMode::Keyring => {
            Arc::new(KeyringAuthStorage::new(codex_home, keyring_store))
        }
        AuthCredentialsStoreMode::Auto => Arc::new(AutoAuthStorage::new(codex_home, keyring_store)),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::token_data::IdTokenInfo;
    use anyhow::Context;
    use base64::Engine;
    use pretty_assertions::assert_eq;
    use serde_json::json;
    use tempfile::tempdir;

    use codex_keyring_store::tests::MockKeyringStore;
    use keyring::Error as KeyringError;

    #[tokio::test]
    async fn file_storage_load_returns_auth_dot_json() -> anyhow::Result<()> {
        let codex_home = tempdir()?;
        let storage = FileAuthStorage::new(codex_home.path().to_path_buf());
        let auth_dot_json = AuthDotJson {
            openai_api_key: Some("test-key".to_string()),
            tokens: None,
            last_refresh: Some(Utc::now()),
        };

        storage
            .save(&auth_dot_json)
            .context("failed to save auth file")?;

        let loaded = storage.load().context("failed to load auth file")?;
        assert_eq!(Some(auth_dot_json), loaded);
        Ok(())
    }

    #[tokio::test]
    async fn file_storage_save_persists_auth_dot_json() -> anyhow::Result<()> {
        let codex_home = tempdir()?;
        let storage = FileAuthStorage::new(codex_home.path().to_path_buf());
        let auth_dot_json = AuthDotJson {
            openai_api_key: Some("test-key".to_string()),
            tokens: None,
            last_refresh: Some(Utc::now()),
        };

        let file = get_auth_file(codex_home.path());
        storage
            .save(&auth_dot_json)
            .context("failed to save auth file")?;

        let same_auth_dot_json = storage
            .try_read_auth_json(&file)
            .context("failed to read auth file after save")?;
        assert_eq!(auth_dot_json, same_auth_dot_json);
        Ok(())
    }

    #[test]
    fn file_storage_delete_removes_auth_file() -> anyhow::Result<()> {
        let dir = tempdir()?;
        let auth_dot_json = AuthDotJson {
            openai_api_key: Some("sk-test-key".to_string()),
            tokens: None,
            last_refresh: None,
        };
        let storage = create_auth_storage(dir.path().to_path_buf(), AuthCredentialsStoreMode::File);
        storage.save(&auth_dot_json)?;
        assert!(dir.path().join("auth.json").exists());
        let storage = FileAuthStorage::new(dir.path().to_path_buf());
        let removed = storage.delete()?;
        assert!(removed);
        assert!(!dir.path().join("auth.json").exists());
        Ok(())
    }

    fn seed_keyring_and_fallback_auth_file_for_delete<F>(
        mock_keyring: &MockKeyringStore,
        codex_home: &Path,
        compute_key: F,
    ) -> anyhow::Result<(String, PathBuf)>
    where
        F: FnOnce() -> std::io::Result<String>,
    {
        let key = compute_key()?;
        mock_keyring.save(KEYRING_SERVICE, &key, "{}")?;
        let auth_file = get_auth_file(codex_home);
        std::fs::write(&auth_file, "stale")?;
        Ok((key, auth_file))
    }

    fn seed_keyring_with_auth<F>(
        mock_keyring: &MockKeyringStore,
        compute_key: F,
        auth: &AuthDotJson,
    ) -> anyhow::Result<()>
    where
        F: FnOnce() -> std::io::Result<String>,
    {
        let key = compute_key()?;
        let serialized = serde_json::to_string(auth)?;
        mock_keyring.save(KEYRING_SERVICE, &key, &serialized)?;
        Ok(())
    }

    fn assert_keyring_saved_auth_and_removed_fallback(
        mock_keyring: &MockKeyringStore,
        key: &str,
        codex_home: &Path,
        expected: &AuthDotJson,
    ) {
        let saved_value = mock_keyring
            .saved_value(key)
            .expect("keyring entry should exist");
        let expected_serialized = serde_json::to_string(expected).expect("serialize expected auth");
        assert_eq!(saved_value, expected_serialized);
        let auth_file = get_auth_file(codex_home);
        assert!(
            !auth_file.exists(),
            "fallback auth.json should be removed after keyring save"
        );
    }

    fn id_token_with_prefix(prefix: &str) -> IdTokenInfo {
        #[derive(Serialize)]
        struct Header {
            alg: &'static str,
            typ: &'static str,
        }

        let header = Header {
            alg: "none",
            typ: "JWT",
        };
        let payload = json!({
            "email": format!("{prefix}@example.com"),
            "https://api.openai.com/auth": {
                "chatgpt_account_id": format!("{prefix}-account"),
            },
        });
        let encode = |bytes: &[u8]| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes);
        let header_b64 = encode(&serde_json::to_vec(&header).expect("serialize header"));
        let payload_b64 = encode(&serde_json::to_vec(&payload).expect("serialize payload"));
        let signature_b64 = encode(b"sig");
        let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");

        crate::token_data::parse_id_token(&fake_jwt).expect("fake JWT should parse")
    }

    fn auth_with_prefix(prefix: &str) -> AuthDotJson {
        AuthDotJson {
            openai_api_key: Some(format!("{prefix}-api-key")),
            tokens: Some(TokenData {
                id_token: id_token_with_prefix(prefix),
                access_token: format!("{prefix}-access"),
                refresh_token: format!("{prefix}-refresh"),
                account_id: Some(format!("{prefix}-account-id")),
            }),
            last_refresh: None,
        }
    }

    #[test]
    fn keyring_auth_storage_load_returns_deserialized_auth() -> anyhow::Result<()> {
        let codex_home = tempdir()?;
        let mock_keyring = MockKeyringStore::default();
        let storage = KeyringAuthStorage::new(
            codex_home.path().to_path_buf(),
            Arc::new(mock_keyring.clone()),
        );
        let expected = AuthDotJson {
            openai_api_key: Some("sk-test".to_string()),
            tokens: None,
            last_refresh: None,
        };
        seed_keyring_with_auth(
            &mock_keyring,
            || compute_store_key(codex_home.path()),
            &expected,
        )?;

        let loaded = storage.load()?;
        assert_eq!(Some(expected), loaded);
        Ok(())
    }

    #[test]
    fn keyring_auth_storage_compute_store_key_for_home_directory() -> anyhow::Result<()> {
        let codex_home = PathBuf::from("~/.codex");

        let key = compute_store_key(codex_home.as_path())?;

        assert_eq!(key, "cli|940db7b1d0e4eb40");
        Ok(())
    }

    #[test]
    fn keyring_auth_storage_save_persists_and_removes_fallback_file() -> anyhow::Result<()> {
        let codex_home = tempdir()?;
        let mock_keyring = MockKeyringStore::default();
        let storage = KeyringAuthStorage::new(
            codex_home.path().to_path_buf(),
            Arc::new(mock_keyring.clone()),
        );
        let auth_file = get_auth_file(codex_home.path());
        std::fs::write(&auth_file, "stale")?;
        let auth = AuthDotJson {
            openai_api_key: None,
            tokens: Some(TokenData {
                id_token: Default::default(),
                access_token: "access".to_string(),
                refresh_token: "refresh".to_string(),
                account_id: Some("account".to_string()),
            }),
            last_refresh: Some(Utc::now()),
        };

        storage.save(&auth)?;

        let key = compute_store_key(codex_home.path())?;
        assert_keyring_saved_auth_and_removed_fallback(
            &mock_keyring,
            &key,
            codex_home.path(),
            &auth,
        );
        Ok(())
    }

    #[test]
    fn keyring_auth_storage_delete_removes_keyring_and_file() -> anyhow::Result<()> {
        let codex_home = tempdir()?;
        let mock_keyring = MockKeyringStore::default();
        let storage = KeyringAuthStorage::new(
            codex_home.path().to_path_buf(),
            Arc::new(mock_keyring.clone()),
        );
        let (key, auth_file) = seed_keyring_and_fallback_auth_file_for_delete(
            &mock_keyring,
            codex_home.path(),
            || compute_store_key(codex_home.path()),
        )?;

        let removed = storage.delete()?;

        assert!(removed, "delete should report removal");
        assert!(
            !mock_keyring.contains(&key),
            "keyring entry should be removed"
        );
        assert!(
            !auth_file.exists(),
            "fallback auth.json should be removed after keyring delete"
        );
        Ok(())
    }

    #[test]
    fn auto_auth_storage_load_prefers_keyring_value() -> anyhow::Result<()> {
        let codex_home = tempdir()?;
        let mock_keyring = MockKeyringStore::default();
        let storage = AutoAuthStorage::new(
            codex_home.path().to_path_buf(),
            Arc::new(mock_keyring.clone()),
        );
        let keyring_auth = auth_with_prefix("keyring");
        seed_keyring_with_auth(
            &mock_keyring,
            || compute_store_key(codex_home.path()),
            &keyring_auth,
        )?;

        let file_auth = auth_with_prefix("file");
        storage.file_storage.save(&file_auth)?;

        let loaded = storage.load()?;
        assert_eq!(loaded, Some(keyring_auth));
        Ok(())
    }

    #[test]
    fn auto_auth_storage_load_uses_file_when_keyring_empty() -> anyhow::Result<()> {
        let codex_home = tempdir()?;
        let mock_keyring = MockKeyringStore::default();
        let storage = AutoAuthStorage::new(codex_home.path().to_path_buf(), Arc::new(mock_keyring));

        let expected = auth_with_prefix("file-only");
        storage.file_storage.save(&expected)?;

        let loaded = storage.load()?;
        assert_eq!(loaded, Some(expected));
        Ok(())
    }

    #[test]
    fn auto_auth_storage_load_falls_back_when_keyring_errors() -> anyhow::Result<()> {
        let codex_home = tempdir()?;
        let mock_keyring = MockKeyringStore::default();
        let storage = AutoAuthStorage::new(
            codex_home.path().to_path_buf(),
            Arc::new(mock_keyring.clone()),
        );
        let key = compute_store_key(codex_home.path())?;
        mock_keyring.set_error(&key, KeyringError::Invalid("error".into(), "load".into()));

        let expected = auth_with_prefix("fallback");
        storage.file_storage.save(&expected)?;

        let loaded = storage.load()?;
        assert_eq!(loaded, Some(expected));
        Ok(())
    }

    #[test]
    fn auto_auth_storage_save_prefers_keyring() -> anyhow::Result<()> {
        let codex_home = tempdir()?;
        let mock_keyring = MockKeyringStore::default();
        let storage = AutoAuthStorage::new(
            codex_home.path().to_path_buf(),
            Arc::new(mock_keyring.clone()),
        );
        let key = compute_store_key(codex_home.path())?;

        let stale = auth_with_prefix("stale");
        storage.file_storage.save(&stale)?;

        let expected = auth_with_prefix("to-save");
        storage.save(&expected)?;

        assert_keyring_saved_auth_and_removed_fallback(
            &mock_keyring,
            &key,
            codex_home.path(),
            &expected,
        );
        Ok(())
    }

    #[test]
    fn auto_auth_storage_save_falls_back_when_keyring_errors() -> anyhow::Result<()> {
        let codex_home = tempdir()?;
        let mock_keyring = MockKeyringStore::default();
        let storage = AutoAuthStorage::new(
            codex_home.path().to_path_buf(),
            Arc::new(mock_keyring.clone()),
        );
        let key = compute_store_key(codex_home.path())?;
        mock_keyring.set_error(&key, KeyringError::Invalid("error".into(), "save".into()));

        let auth = auth_with_prefix("fallback");
        storage.save(&auth)?;

        let auth_file = get_auth_file(codex_home.path());
        assert!(
            auth_file.exists(),
            "fallback auth.json should be created when keyring save fails"
        );
        let saved = storage
            .file_storage
            .load()?
            .context("fallback auth should exist")?;
        assert_eq!(saved, auth);
        assert!(
            mock_keyring.saved_value(&key).is_none(),
            "keyring should not contain value when save fails"
        );
        Ok(())
    }

    #[test]
    fn auto_auth_storage_delete_removes_keyring_and_file() -> anyhow::Result<()> {
        let codex_home = tempdir()?;
        let mock_keyring = MockKeyringStore::default();
        let storage = AutoAuthStorage::new(
            codex_home.path().to_path_buf(),
            Arc::new(mock_keyring.clone()),
        );
        let (key, auth_file) = seed_keyring_and_fallback_auth_file_for_delete(
            &mock_keyring,
            codex_home.path(),
            || compute_store_key(codex_home.path()),
        )?;

        let removed = storage.delete()?;

        assert!(removed, "delete should report removal");
        assert!(
            !mock_keyring.contains(&key),
            "keyring entry should be removed"
        );
        assert!(
            !auth_file.exists(),
            "fallback auth.json should be removed after delete"
        );
        Ok(())
    }
}
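The Auto mode above is a straightforward fallback chain: try the keyring first, and on an error or a miss use the file backend. A standalone sketch of the pattern with stand-in types (not the crate's real trait):

use std::io;

trait Backend {
    fn load(&self) -> io::Result<Option<String>>;
}

struct FailingKeyring;
struct FileStore;

impl Backend for FailingKeyring {
    fn load(&self) -> io::Result<Option<String>> {
        Err(io::Error::other("keyring unavailable"))
    }
}

impl Backend for FileStore {
    fn load(&self) -> io::Result<Option<String>> {
        Ok(Some("from-file".to_string()))
    }
}

// Prefer the primary backend; fall back on error or empty result, as
// AutoAuthStorage::load does (the real code also logs a warning on error).
fn load_with_fallback(
    primary: &dyn Backend,
    fallback: &dyn Backend,
) -> io::Result<Option<String>> {
    match primary.load() {
        Ok(Some(v)) => Ok(Some(v)),
        Ok(None) | Err(_) => fallback.load(),
    }
}

fn main() -> io::Result<()> {
    assert_eq!(
        load_with_fallback(&FailingKeyring, &FileStore)?,
        Some("from-file".to_string())
    );
    Ok(())
}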
261  llmx-rs/core/src/bash.rs  Normal file
@@ -0,0 +1,261 @@
use tree_sitter::Node;
use tree_sitter::Parser;
use tree_sitter::Tree;
use tree_sitter_bash::LANGUAGE as BASH;

/// Parse the provided bash source using tree-sitter-bash, returning a Tree on
/// success or None if parsing failed.
pub fn try_parse_shell(shell_lc_arg: &str) -> Option<Tree> {
    let lang = BASH.into();
    let mut parser = Parser::new();
    #[expect(clippy::expect_used)]
    parser.set_language(&lang).expect("load bash grammar");
    let old_tree: Option<&Tree> = None;
    parser.parse(shell_lc_arg, old_tree)
}

/// Parse a script which may contain multiple simple commands joined only by
/// the safe logical/pipe/sequencing operators: `&&`, `||`, `;`, `|`.
///
/// Returns `Some(Vec<command_words>)` if every command is a plain word-only
/// command and the parse tree does not contain disallowed constructs
/// (parentheses, redirections, substitutions, control flow, etc.). Otherwise
/// returns `None`.
pub fn try_parse_word_only_commands_sequence(tree: &Tree, src: &str) -> Option<Vec<Vec<String>>> {
    if tree.root_node().has_error() {
        return None;
    }

    // List of allowed (named) node kinds for a "word only commands sequence".
    // If we encounter a named node that is not in this list we reject.
    const ALLOWED_KINDS: &[&str] = &[
        // top level containers
        "program",
        "list",
        "pipeline",
        // commands & words
        "command",
        "command_name",
        "word",
        "string",
        "string_content",
        "raw_string",
        "number",
    ];
    // Allow only safe punctuation / operator tokens; anything else causes reject.
    const ALLOWED_PUNCT_TOKENS: &[&str] = &["&&", "||", ";", "|", "\"", "'"];

    let root = tree.root_node();
    let mut cursor = root.walk();
    let mut stack = vec![root];
    let mut command_nodes = Vec::new();
    while let Some(node) = stack.pop() {
        let kind = node.kind();
        if node.is_named() {
            if !ALLOWED_KINDS.contains(&kind) {
                return None;
            }
            if kind == "command" {
                command_nodes.push(node);
            }
        } else {
            // Reject any punctuation / operator tokens that are not explicitly allowed.
            if kind.chars().any(|c| "&;|".contains(c)) && !ALLOWED_PUNCT_TOKENS.contains(&kind) {
                return None;
            }
            if !(ALLOWED_PUNCT_TOKENS.contains(&kind) || kind.trim().is_empty()) {
                // If it's a quote token or operator it's allowed above; we also allow whitespace tokens.
                // Any other punctuation like parentheses, braces, redirects, backticks, etc are rejected.
                return None;
            }
        }
        for child in node.children(&mut cursor) {
            stack.push(child);
        }
    }

    // Walk uses a stack (LIFO), so re-sort by position to restore source order.
    command_nodes.sort_by_key(Node::start_byte);

    let mut commands = Vec::new();
    for node in command_nodes {
        if let Some(words) = parse_plain_command_from_node(node, src) {
            commands.push(words);
        } else {
            return None;
        }
    }
    Some(commands)
}

pub fn is_well_known_sh_shell(shell: &str) -> bool {
    if shell == "bash" || shell == "zsh" {
        return true;
    }

    let shell_name = std::path::Path::new(shell)
        .file_name()
        .and_then(|s| s.to_str())
        .unwrap_or(shell);
    matches!(shell_name, "bash" | "zsh")
}

pub fn extract_bash_command(command: &[String]) -> Option<(&str, &str)> {
    let [shell, flag, script] = command else {
        return None;
    };
    if flag != "-lc" || !is_well_known_sh_shell(shell) {
        return None;
    }
    Some((shell, script))
}

/// Returns the sequence of plain commands within a `bash -lc "..."` or
/// `zsh -lc "..."` invocation when the script only contains word-only commands
/// joined by safe operators.
pub fn parse_shell_lc_plain_commands(command: &[String]) -> Option<Vec<Vec<String>>> {
    let (_, script) = extract_bash_command(command)?;

    let tree = try_parse_shell(script)?;
    try_parse_word_only_commands_sequence(&tree, script)
}

fn parse_plain_command_from_node(cmd: tree_sitter::Node, src: &str) -> Option<Vec<String>> {
    if cmd.kind() != "command" {
        return None;
    }
    let mut words = Vec::new();
    let mut cursor = cmd.walk();
    for child in cmd.named_children(&mut cursor) {
        match child.kind() {
            "command_name" => {
                let word_node = child.named_child(0)?;
                if word_node.kind() != "word" {
                    return None;
                }
                words.push(word_node.utf8_text(src.as_bytes()).ok()?.to_owned());
            }
            "word" | "number" => {
                words.push(child.utf8_text(src.as_bytes()).ok()?.to_owned());
            }
            "string" => {
                if child.child_count() == 3
                    && child.child(0)?.kind() == "\""
                    && child.child(1)?.kind() == "string_content"
                    && child.child(2)?.kind() == "\""
                {
                    words.push(child.child(1)?.utf8_text(src.as_bytes()).ok()?.to_owned());
                } else {
                    return None;
                }
            }
            "raw_string" => {
                let raw_string = child.utf8_text(src.as_bytes()).ok()?;
                let stripped = raw_string
                    .strip_prefix('\'')
                    .and_then(|s| s.strip_suffix('\''));
                if let Some(s) = stripped {
                    words.push(s.to_owned());
                } else {
                    return None;
                }
            }
            _ => return None,
        }
    }
    Some(words)
}

#[cfg(test)]
mod tests {
    use super::*;

    fn parse_seq(src: &str) -> Option<Vec<Vec<String>>> {
        let tree = try_parse_shell(src)?;
        try_parse_word_only_commands_sequence(&tree, src)
    }

    #[test]
    fn accepts_single_simple_command() {
        let cmds = parse_seq("ls -1").unwrap();
        assert_eq!(cmds, vec![vec!["ls".to_string(), "-1".to_string()]]);
    }

    #[test]
    fn accepts_multiple_commands_with_allowed_operators() {
        let src = "ls && pwd; echo 'hi there' | wc -l";
        let cmds = parse_seq(src).unwrap();
        let expected: Vec<Vec<String>> = vec![
            vec!["ls".to_string()],
            vec!["pwd".to_string()],
            vec!["echo".to_string(), "hi there".to_string()],
            vec!["wc".to_string(), "-l".to_string()],
        ];
        assert_eq!(cmds, expected);
    }

    #[test]
    fn extracts_double_and_single_quoted_strings() {
        let cmds = parse_seq("echo \"hello world\"").unwrap();
        assert_eq!(
            cmds,
            vec![vec!["echo".to_string(), "hello world".to_string()]]
        );

        let cmds2 = parse_seq("echo 'hi there'").unwrap();
        assert_eq!(
            cmds2,
            vec![vec!["echo".to_string(), "hi there".to_string()]]
        );
    }

    #[test]
    fn accepts_numbers_as_words() {
        let cmds = parse_seq("echo 123 456").unwrap();
        assert_eq!(
            cmds,
            vec![vec![
                "echo".to_string(),
                "123".to_string(),
                "456".to_string()
            ]]
        );
    }

    #[test]
    fn rejects_parentheses_and_subshells() {
        assert!(parse_seq("(ls)").is_none());
        assert!(parse_seq("ls || (pwd && echo hi)").is_none());
    }

    #[test]
    fn rejects_redirections_and_unsupported_operators() {
        assert!(parse_seq("ls > out.txt").is_none());
        assert!(parse_seq("echo hi & echo bye").is_none());
    }

    #[test]
    fn rejects_command_and_process_substitutions_and_expansions() {
        assert!(parse_seq("echo $(pwd)").is_none());
        assert!(parse_seq("echo `pwd`").is_none());
        assert!(parse_seq("echo $HOME").is_none());
        assert!(parse_seq("echo \"hi $USER\"").is_none());
    }

    #[test]
    fn rejects_variable_assignment_prefix() {
        assert!(parse_seq("FOO=bar ls").is_none());
    }

    #[test]
    fn rejects_trailing_operator_parse_error() {
        assert!(parse_seq("ls &&").is_none());
    }

    #[test]
    fn parse_zsh_lc_plain_commands() {
        let command = vec!["zsh".to_string(), "-lc".to_string(), "ls".to_string()];
        let parsed = parse_shell_lc_plain_commands(&command).unwrap();
        assert_eq!(parsed, vec![vec!["ls".to_string()]]);
    }
}
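A usage sketch for the parser above, assuming these functions are in scope and the tree-sitter-bash dependency builds (it mirrors the module's own tests):

// Feed a `bash -lc` invocation through parse_shell_lc_plain_commands and
// inspect the extracted word-only commands.
fn demo() {
    let command = vec![
        "bash".to_string(),
        "-lc".to_string(),
        "ls -a && echo 'done'".to_string(),
    ];
    match parse_shell_lc_plain_commands(&command) {
        // Quotes are stripped: expect [["ls", "-a"], ["echo", "done"]].
        Some(commands) => assert_eq!(commands[1], vec!["echo".to_string(), "done".to_string()]),
        None => println!("script uses constructs outside the word-only subset"),
    }
}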
967  llmx-rs/core/src/chat_completions.rs  Normal file
@@ -0,0 +1,967 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::ModelProviderInfo;
|
||||
use crate::client_common::Prompt;
|
||||
use crate::client_common::ResponseEvent;
|
||||
use crate::client_common::ResponseStream;
|
||||
use crate::default_client::CodexHttpClient;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::ConnectionFailedError;
|
||||
use crate::error::ResponseStreamFailed;
|
||||
use crate::error::Result;
|
||||
use crate::error::RetryLimitReachedError;
|
||||
use crate::error::UnexpectedResponseError;
|
||||
use crate::model_family::ModelFamily;
|
||||
use crate::tools::spec::create_tools_json_for_chat_completions_api;
|
||||
use crate::util::backoff;
|
||||
use bytes::Bytes;
|
||||
use codex_otel::otel_event_manager::OtelEventManager;
|
||||
use codex_protocol::models::ContentItem;
|
||||
use codex_protocol::models::FunctionCallOutputContentItem;
|
||||
use codex_protocol::models::ReasoningItemContent;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::SubAgentSource;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures::Stream;
|
||||
use futures::StreamExt;
|
||||
use futures::TryStreamExt;
|
||||
use reqwest::StatusCode;
|
||||
use serde_json::json;
|
||||
use std::pin::Pin;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::timeout;
|
||||
use tracing::debug;
|
||||
use tracing::trace;
|
||||
|
||||
/// Implementation for the classic Chat Completions API.
|
||||
pub(crate) async fn stream_chat_completions(
|
||||
prompt: &Prompt,
|
||||
model_family: &ModelFamily,
|
||||
client: &CodexHttpClient,
|
||||
provider: &ModelProviderInfo,
|
||||
otel_event_manager: &OtelEventManager,
|
||||
session_source: &SessionSource,
|
||||
) -> Result<ResponseStream> {
|
||||
if prompt.output_schema.is_some() {
|
||||
return Err(CodexErr::UnsupportedOperation(
|
||||
"output_schema is not supported for Chat Completions API".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Build messages array
|
||||
let mut messages = Vec::<serde_json::Value>::new();
|
||||
|
||||
let full_instructions = prompt.get_full_instructions(model_family);
|
||||
messages.push(json!({"role": "system", "content": full_instructions}));
|
||||
|
||||
let input = prompt.get_formatted_input();
|
||||
|
||||
// Pre-scan: map Reasoning blocks to the adjacent assistant anchor after the last user.
|
||||
// - If the last emitted message is a user message, drop all reasoning.
|
||||
// - Otherwise, for each Reasoning item after the last user message, attach it
|
||||
// to the immediate previous assistant message (stop turns) or the immediate
|
||||
// next assistant anchor (tool-call turns: function/local shell call, or assistant message).
|
||||
let mut reasoning_by_anchor_index: std::collections::HashMap<usize, String> =
|
||||
std::collections::HashMap::new();
|
||||
|
||||
// Determine the last role that would be emitted to Chat Completions.
|
||||
let mut last_emitted_role: Option<&str> = None;
|
||||
for item in &input {
|
||||
match item {
|
||||
ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()),
|
||||
ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
|
||||
last_emitted_role = Some("assistant")
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"),
|
||||
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
|
||||
ResponseItem::CustomToolCall { .. } => {}
|
||||
ResponseItem::CustomToolCallOutput { .. } => {}
|
||||
ResponseItem::WebSearchCall { .. } => {}
|
||||
ResponseItem::GhostSnapshot { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Find the last user message index in the input.
|
||||
let mut last_user_index: Option<usize> = None;
|
||||
for (idx, item) in input.iter().enumerate() {
|
||||
if let ResponseItem::Message { role, .. } = item
|
||||
&& role == "user"
|
||||
{
|
||||
last_user_index = Some(idx);
|
||||
}
|
||||
}
|
||||
|
||||
// Attach reasoning only if the conversation does not end with a user message.
|
||||
if !matches!(last_emitted_role, Some("user")) {
|
||||
for (idx, item) in input.iter().enumerate() {
|
||||
// Only consider reasoning that appears after the last user message.
|
||||
if let Some(u_idx) = last_user_index
|
||||
&& idx <= u_idx
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if let ResponseItem::Reasoning {
|
||||
content: Some(items),
|
||||
..
|
||||
} = item
|
||||
{
|
||||
let mut text = String::new();
|
||||
for entry in items {
|
||||
match entry {
|
||||
ReasoningItemContent::ReasoningText { text: segment }
|
||||
| ReasoningItemContent::Text { text: segment } => text.push_str(segment),
|
||||
}
|
||||
}
|
||||
if text.trim().is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Prefer immediate previous assistant message (stop turns)
|
||||
let mut attached = false;
|
||||
if idx > 0
|
||||
&& let ResponseItem::Message { role, .. } = &input[idx - 1]
|
||||
&& role == "assistant"
|
||||
{
|
||||
reasoning_by_anchor_index
|
||||
.entry(idx - 1)
|
||||
.and_modify(|v| v.push_str(&text))
|
||||
.or_insert(text.clone());
|
||||
attached = true;
|
||||
}
|
||||
|
||||
// Otherwise, attach to immediate next assistant anchor (tool-calls or assistant message)
|
||||
if !attached && idx + 1 < input.len() {
|
||||
match &input[idx + 1] {
|
||||
ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
|
||||
reasoning_by_anchor_index
|
||||
.entry(idx + 1)
|
||||
.and_modify(|v| v.push_str(&text))
|
||||
.or_insert(text.clone());
|
||||
}
|
||||
ResponseItem::Message { role, .. } if role == "assistant" => {
|
||||
reasoning_by_anchor_index
|
||||
.entry(idx + 1)
|
||||
.and_modify(|v| v.push_str(&text))
|
||||
.or_insert(text.clone());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Track last assistant text we emitted to avoid duplicate assistant messages
|
||||
// in the outbound Chat Completions payload (can happen if a final
|
||||
// aggregated assistant message was recorded alongside an earlier partial).
|
||||
let mut last_assistant_text: Option<String> = None;
|
||||
|
||||
for (idx, item) in input.iter().enumerate() {
|
||||
match item {
|
||||
ResponseItem::Message { role, content, .. } => {
|
||||
// Build content either as a plain string (typical for assistant text)
|
||||
// or as an array of content items when images are present (user/tool multimodal).
|
||||
let mut text = String::new();
|
||||
let mut items: Vec<serde_json::Value> = Vec::new();
|
||||
let mut saw_image = false;
|
||||
|
||||
for c in content {
|
||||
match c {
|
||||
ContentItem::InputText { text: t }
|
||||
| ContentItem::OutputText { text: t } => {
|
||||
text.push_str(t);
|
||||
items.push(json!({"type":"text","text": t}));
|
||||
}
|
||||
ContentItem::InputImage { image_url } => {
|
||||
saw_image = true;
|
||||
items.push(json!({"type":"image_url","image_url": {"url": image_url}}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Skip exact-duplicate assistant messages.
|
||||
if role == "assistant" {
|
||||
if let Some(prev) = &last_assistant_text
|
||||
&& prev == &text
|
||||
{
|
||||
continue;
|
||||
}
|
||||
last_assistant_text = Some(text.clone());
|
||||
}
|
||||
|
||||
// For assistant messages, always send a plain string for compatibility.
|
||||
// For user messages, if an image is present, send an array of content items.
|
||||
let content_value = if role == "assistant" {
|
||||
json!(text)
|
||||
} else if saw_image {
|
||||
json!(items)
|
||||
} else {
|
||||
json!(text)
|
||||
};
|
||||
|
||||
let mut msg = json!({"role": role, "content": content_value});
|
||||
if role == "assistant"
|
||||
&& let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
|
||||
&& let Some(obj) = msg.as_object_mut()
|
||||
{
|
||||
obj.insert("reasoning".to_string(), json!(reasoning));
|
||||
}
|
||||
messages.push(msg);
|
||||
}
|
||||
ResponseItem::FunctionCall {
|
||||
name,
|
||||
arguments,
|
||||
call_id,
|
||||
..
|
||||
} => {
|
||||
let mut msg = json!({
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"arguments": arguments,
|
||||
}
|
||||
}]
|
||||
});
|
||||
if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
|
||||
&& let Some(obj) = msg.as_object_mut()
|
||||
{
|
||||
obj.insert("reasoning".to_string(), json!(reasoning));
|
||||
}
|
||||
messages.push(msg);
|
||||
}
|
||||
ResponseItem::LocalShellCall {
|
||||
id,
|
||||
call_id: _,
|
||||
status,
|
||||
action,
|
||||
} => {
|
||||
// Confirm with API team.
|
||||
let mut msg = json!({
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{
|
||||
"id": id.clone().unwrap_or_else(|| "".to_string()),
|
||||
"type": "local_shell_call",
|
||||
"status": status,
|
||||
"action": action,
|
||||
}]
|
||||
});
|
||||
if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
|
||||
&& let Some(obj) = msg.as_object_mut()
|
||||
{
|
||||
obj.insert("reasoning".to_string(), json!(reasoning));
|
||||
}
|
||||
messages.push(msg);
|
||||
}
|
||||
ResponseItem::FunctionCallOutput { call_id, output } => {
|
||||
// Prefer structured content items when available (e.g., images)
|
||||
// otherwise fall back to the legacy plain-string content.
|
||||
let content_value = if let Some(items) = &output.content_items {
|
||||
let mapped: Vec<serde_json::Value> = items
|
||||
.iter()
|
||||
.map(|it| match it {
|
||||
FunctionCallOutputContentItem::InputText { text } => {
|
||||
json!({"type":"text","text": text})
|
||||
}
|
||||
FunctionCallOutputContentItem::InputImage { image_url } => {
|
||||
json!({"type":"image_url","image_url": {"url": image_url}})
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
json!(mapped)
|
||||
} else {
|
||||
json!(output.content)
|
||||
};
|
||||
|
||||
messages.push(json!({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": content_value,
|
||||
}));
|
||||
}
|
||||
ResponseItem::CustomToolCall {
|
||||
id,
|
||||
call_id: _,
|
||||
name,
|
||||
input,
|
||||
status: _,
|
||||
} => {
|
||||
messages.push(json!({
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [{
|
||||
"id": id,
|
||||
"type": "custom",
|
||||
"custom": {
|
||||
"name": name,
|
||||
"input": input,
|
||||
}
|
||||
}]
|
||||
}));
|
||||
}
|
||||
ResponseItem::CustomToolCallOutput { call_id, output } => {
|
||||
messages.push(json!({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"content": output,
|
||||
}));
|
||||
}
|
||||
ResponseItem::GhostSnapshot { .. } => {
|
||||
// Ghost snapshots annotate history but are not sent to the model.
|
||||
continue;
|
||||
}
|
||||
ResponseItem::Reasoning { .. }
|
||||
| ResponseItem::WebSearchCall { .. }
|
||||
| ResponseItem::Other => {
|
||||
// Omit these items from the conversation history.
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
|
||||
let payload = json!({
|
||||
"model": model_family.slug,
|
||||
"messages": messages,
|
||||
"stream": true,
|
||||
"tools": tools_json,
|
||||
});
|
||||
|
||||
debug!(
|
||||
"POST to {}: {}",
|
||||
provider.get_full_url(&None),
|
||||
serde_json::to_string_pretty(&payload).unwrap_or_default()
|
||||
);
|
||||
|
||||
let mut attempt = 0;
|
||||
let max_retries = provider.request_max_retries();
|
||||
loop {
|
||||
attempt += 1;
|
||||
|
||||
let mut req_builder = provider.create_request_builder(client, &None).await?;
|
||||
|
||||
// Include subagent header only for subagent sessions.
|
||||
if let SessionSource::SubAgent(sub) = session_source.clone() {
|
||||
let subagent = if let SubAgentSource::Other(label) = sub {
|
||||
label
|
||||
} else {
|
||||
serde_json::to_value(&sub)
|
||||
.ok()
|
||||
.and_then(|v| v.as_str().map(std::string::ToString::to_string))
|
||||
.unwrap_or_else(|| "other".to_string())
|
||||
};
|
||||
req_builder = req_builder.header("x-openai-subagent", subagent);
|
||||
}
|
||||
|
||||
let res = otel_event_manager
|
||||
.log_request(attempt, || {
|
||||
req_builder
|
||||
.header(reqwest::header::ACCEPT, "text/event-stream")
|
||||
.json(&payload)
|
||||
.send()
|
||||
})
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
|
||||
let stream = resp.bytes_stream().map_err(|e| {
|
||||
CodexErr::ResponseStreamFailed(ResponseStreamFailed {
|
||||
source: e,
|
||||
request_id: None,
|
||||
})
|
||||
});
|
||||
tokio::spawn(process_chat_sse(
|
||||
stream,
|
||||
tx_event,
|
||||
provider.stream_idle_timeout(),
|
||||
otel_event_manager.clone(),
|
||||
));
|
||||
return Ok(ResponseStream { rx_event });
|
||||
}
|
||||
Ok(res) => {
|
||||
let status = res.status();
|
||||
if !(status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()) {
|
||||
let body = (res.text().await).unwrap_or_default();
|
||||
return Err(CodexErr::UnexpectedStatus(UnexpectedResponseError {
|
||||
status,
|
||||
body,
|
||||
request_id: None,
|
||||
}));
|
||||
}
|
||||
|
||||
if attempt > max_retries {
|
||||
return Err(CodexErr::RetryLimit(RetryLimitReachedError {
|
||||
status,
|
||||
request_id: None,
|
||||
}));
|
||||
}
|
||||
|
||||
let retry_after_secs = res
|
||||
.headers()
|
||||
.get(reqwest::header::RETRY_AFTER)
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.parse::<u64>().ok());
|
||||
|
||||
let delay = retry_after_secs
|
||||
.map(|s| Duration::from_millis(s * 1_000))
|
||||
.unwrap_or_else(|| backoff(attempt));
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
Err(e) => {
|
||||
if attempt > max_retries {
|
||||
return Err(CodexErr::ConnectionFailed(ConnectionFailedError {
|
||||
source: e,
|
||||
}));
|
||||
}
|
||||
let delay = backoff(attempt);
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn append_assistant_text(
|
||||
tx_event: &mpsc::Sender<Result<ResponseEvent>>,
|
||||
assistant_item: &mut Option<ResponseItem>,
|
||||
text: String,
|
||||
) {
|
||||
if assistant_item.is_none() {
|
||||
let item = ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![],
|
||||
};
|
||||
*assistant_item = Some(item.clone());
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemAdded(item)))
|
||||
.await;
|
||||
}
|
||||
|
||||
if let Some(ResponseItem::Message { content, .. }) = assistant_item {
|
||||
content.push(ContentItem::OutputText { text: text.clone() });
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputTextDelta(text.clone())))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn append_reasoning_text(
|
||||
tx_event: &mpsc::Sender<Result<ResponseEvent>>,
|
||||
reasoning_item: &mut Option<ResponseItem>,
|
||||
text: String,
|
||||
) {
|
||||
if reasoning_item.is_none() {
|
||||
let item = ResponseItem::Reasoning {
|
||||
id: String::new(),
|
||||
summary: Vec::new(),
|
||||
content: Some(vec![]),
|
||||
encrypted_content: None,
|
||||
};
|
||||
*reasoning_item = Some(item.clone());
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemAdded(item)))
|
||||
.await;
|
||||
}
|
||||
|
||||
if let Some(ResponseItem::Reasoning {
|
||||
content: Some(content),
|
||||
..
|
||||
}) = reasoning_item
|
||||
{
|
||||
content.push(ReasoningItemContent::ReasoningText { text: text.clone() });
|
||||
|
||||
let _ = tx_event
|
||||
.send(Ok(ResponseEvent::ReasoningContentDelta(text.clone())))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
/// Lightweight SSE processor for the Chat Completions streaming format. The
|
||||
/// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest
|
||||
/// of the pipeline can stay agnostic of the underlying wire format.
async fn process_chat_sse<S>(
    stream: S,
    tx_event: mpsc::Sender<Result<ResponseEvent>>,
    idle_timeout: Duration,
    otel_event_manager: OtelEventManager,
) where
    S: Stream<Item = Result<Bytes>> + Unpin,
{
    let mut stream = stream.eventsource();

    // State to accumulate a function call across streaming chunks.
    // OpenAI may split the `arguments` string over multiple `delta` events
    // until the chunk whose `finish_reason` is `tool_calls` is emitted. We
    // keep collecting the pieces here and forward a single
    // `ResponseItem::FunctionCall` once the call is complete.
    #[derive(Default)]
    struct FunctionCallState {
        name: Option<String>,
        arguments: String,
        call_id: Option<String>,
        active: bool,
    }

    let mut fn_call_state = FunctionCallState::default();
    let mut assistant_item: Option<ResponseItem> = None;
    let mut reasoning_item: Option<ResponseItem> = None;

    loop {
        let start = std::time::Instant::now();
        let response = timeout(idle_timeout, stream.next()).await;
        let duration = start.elapsed();
        otel_event_manager.log_sse_event(&response, duration);

        let sse = match response {
            Ok(Some(Ok(ev))) => ev,
            Ok(Some(Err(e))) => {
                let _ = tx_event
                    .send(Err(CodexErr::Stream(e.to_string(), None)))
                    .await;
                return;
            }
            Ok(None) => {
                // Stream closed gracefully – emit Completed with dummy id.
                let _ = tx_event
                    .send(Ok(ResponseEvent::Completed {
                        response_id: String::new(),
                        token_usage: None,
                    }))
                    .await;
                return;
            }
            Err(_) => {
                let _ = tx_event
                    .send(Err(CodexErr::Stream(
                        "idle timeout waiting for SSE".into(),
                        None,
                    )))
                    .await;
                return;
            }
        };

        // OpenAI Chat streaming sends a literal string "[DONE]" when finished.
        if sse.data.trim() == "[DONE]" {
            // Emit any finalized items before closing so downstream consumers receive
            // terminal events for both assistant content and raw reasoning.
            if let Some(item) = assistant_item {
                let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
            }

            if let Some(item) = reasoning_item {
                let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
            }

            let _ = tx_event
                .send(Ok(ResponseEvent::Completed {
                    response_id: String::new(),
                    token_usage: None,
                }))
                .await;
            return;
        }

        // Parse JSON chunk
        let chunk: serde_json::Value = match serde_json::from_str(&sse.data) {
            Ok(v) => v,
            Err(_) => continue,
        };
        trace!("chat_completions received SSE chunk: {chunk:?}");

        let choice_opt = chunk.get("choices").and_then(|c| c.get(0));

        if let Some(choice) = choice_opt {
            // Handle assistant content tokens as streaming deltas.
            if let Some(content) = choice
                .get("delta")
                .and_then(|d| d.get("content"))
                .and_then(|c| c.as_str())
                && !content.is_empty()
            {
                append_assistant_text(&tx_event, &mut assistant_item, content.to_string()).await;
            }

            // Forward any reasoning/thinking deltas if present.
            // Some providers stream `reasoning` as a plain string while others
            // nest the text under an object (e.g. `{ "reasoning": { "text": "…" } }`).
            if let Some(reasoning_val) = choice.get("delta").and_then(|d| d.get("reasoning")) {
                let mut maybe_text = reasoning_val
                    .as_str()
                    .map(str::to_string)
                    .filter(|s| !s.is_empty());

                if maybe_text.is_none() && reasoning_val.is_object() {
                    if let Some(s) = reasoning_val
                        .get("text")
                        .and_then(|t| t.as_str())
                        .filter(|s| !s.is_empty())
                    {
                        maybe_text = Some(s.to_string());
                    } else if let Some(s) = reasoning_val
                        .get("content")
                        .and_then(|t| t.as_str())
                        .filter(|s| !s.is_empty())
                    {
                        maybe_text = Some(s.to_string());
                    }
                }

                if let Some(reasoning) = maybe_text {
                    // Accumulate so we can emit a terminal Reasoning item at the end.
                    append_reasoning_text(&tx_event, &mut reasoning_item, reasoning).await;
                }
            }

            // Some providers only include reasoning on the final message object.
            if let Some(message_reasoning) = choice.get("message").and_then(|m| m.get("reasoning"))
            {
                // Accept either a plain string or an object with { text | content }
                if let Some(s) = message_reasoning.as_str() {
                    if !s.is_empty() {
                        append_reasoning_text(&tx_event, &mut reasoning_item, s.to_string()).await;
                    }
                } else if let Some(obj) = message_reasoning.as_object()
                    && let Some(s) = obj
                        .get("text")
                        .and_then(|v| v.as_str())
                        .or_else(|| obj.get("content").and_then(|v| v.as_str()))
                    && !s.is_empty()
                {
                    append_reasoning_text(&tx_event, &mut reasoning_item, s.to_string()).await;
                }
            }

            // Handle streaming function / tool calls.
            if let Some(tool_calls) = choice
                .get("delta")
                .and_then(|d| d.get("tool_calls"))
                .and_then(|tc| tc.as_array())
                && let Some(tool_call) = tool_calls.first()
            {
                // Mark that we have an active function call in progress.
                fn_call_state.active = true;

                // Extract call_id if present.
                if let Some(id) = tool_call.get("id").and_then(|v| v.as_str()) {
                    fn_call_state.call_id.get_or_insert_with(|| id.to_string());
                }

                // Extract function details if present.
                if let Some(function) = tool_call.get("function") {
                    if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
                        fn_call_state.name.get_or_insert_with(|| name.to_string());
                    }

                    if let Some(args_fragment) = function.get("arguments").and_then(|a| a.as_str())
                    {
                        fn_call_state.arguments.push_str(args_fragment);
                    }
                }
            }

            // Emit end-of-turn when finish_reason signals completion.
            if let Some(finish_reason) = choice.get("finish_reason").and_then(|v| v.as_str()) {
                match finish_reason {
                    "tool_calls" if fn_call_state.active => {
                        // First, flush the terminal raw reasoning so UIs can finalize
                        // the reasoning stream before any exec/tool events begin.
                        if let Some(item) = reasoning_item.take() {
                            let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
                        }

                        // Then emit the FunctionCall response item.
                        let item = ResponseItem::FunctionCall {
                            id: None,
                            name: fn_call_state.name.clone().unwrap_or_default(),
                            arguments: fn_call_state.arguments.clone(),
                            call_id: fn_call_state.call_id.clone().unwrap_or_default(),
                        };

                        let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
                    }
                    "stop" => {
                        // Regular turn without tool-call. Emit the final assistant message
                        // as a single OutputItemDone so non-delta consumers see the result.
                        if let Some(item) = assistant_item.take() {
                            let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
                        }
                        // Also emit a terminal Reasoning item so UIs can finalize raw reasoning.
                        if let Some(item) = reasoning_item.take() {
                            let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
                        }
                    }
                    _ => {}
                }

                // Emit Completed regardless of reason so the agent can advance.
                let _ = tx_event
                    .send(Ok(ResponseEvent::Completed {
                        response_id: String::new(),
                        token_usage: None,
                    }))
                    .await;

                // Prepare for potential next turn (should not happen in same stream).
                // fn_call_state = FunctionCallState::default();

                return; // End processing for this SSE stream.
            }
        }
    }
}
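
// A minimal sketch (not part of this commit) of the two chunk shapes the loop
// above cares about. A streamed tool call typically arrives as fragments like
//   {"choices":[{"delta":{"tool_calls":[{"id":"call_1",
//       "function":{"name":"shell","arguments":"{\"cmd\""}}]}}]}
//   {"choices":[{"delta":{"tool_calls":[{"function":{"arguments":":[\"ls\"]}"}}]}}]}
//   {"choices":[{"finish_reason":"tool_calls"}]}
// `FunctionCallState` concatenates the `arguments` fragments, so a single
// `ResponseItem::FunctionCall` with the full JSON arguments is emitted only
// when `finish_reason` fires.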

/// Optional client-side aggregation helper
///
/// Stream adapter that merges the incremental `OutputItemDone` chunks coming from
/// [`process_chat_sse`] into a *running* assistant message, **suppressing the
/// per-token deltas**. The stream stays silent while the model is thinking
/// and only emits two events per turn:
///
/// 1. `ResponseEvent::OutputItemDone` with the *complete* assistant message
///    (fully concatenated).
/// 2. The original `ResponseEvent::Completed` right after it.
///
/// This mirrors the behaviour the TypeScript CLI exposes to its higher layers.
///
/// The adapter is intentionally *lossless*: callers who do **not** opt in via
/// [`AggregateStreamExt::aggregate()`] keep receiving the original unmodified
/// events.
#[derive(Copy, Clone, Eq, PartialEq)]
enum AggregateMode {
    AggregatedOnly,
    Streaming,
}

pub(crate) struct AggregatedChatStream<S> {
    inner: S,
    cumulative: String,
    cumulative_reasoning: String,
    pending: std::collections::VecDeque<ResponseEvent>,
    mode: AggregateMode,
}

impl<S> Stream for AggregatedChatStream<S>
where
    S: Stream<Item = Result<ResponseEvent>> + Unpin,
{
    type Item = Result<ResponseEvent>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        // First, flush any buffered events from the previous call.
        if let Some(ev) = this.pending.pop_front() {
            return Poll::Ready(Some(Ok(ev)));
        }

        loop {
            match Pin::new(&mut this.inner).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
                    // If this is an incremental assistant message chunk, accumulate but
                    // do NOT emit yet. Forward any other item (e.g. FunctionCall) right
                    // away so downstream consumers see it.

                    let is_assistant_message = matches!(
                        &item,
                        codex_protocol::models::ResponseItem::Message { role, .. } if role == "assistant"
                    );

                    if is_assistant_message {
                        match this.mode {
                            AggregateMode::AggregatedOnly => {
                                // Only use the final assistant message if we have not
                                // seen any deltas; otherwise, deltas already built the
                                // cumulative text and this would duplicate it.
                                if this.cumulative.is_empty()
                                    && let codex_protocol::models::ResponseItem::Message {
                                        content,
                                        ..
                                    } = &item
                                    && let Some(text) = content.iter().find_map(|c| match c {
                                        codex_protocol::models::ContentItem::OutputText {
                                            text,
                                        } => Some(text),
                                        _ => None,
                                    })
                                {
                                    this.cumulative.push_str(text);
                                }
                                // Swallow assistant message here; emit on Completed.
                                continue;
                            }
                            AggregateMode::Streaming => {
                                // In streaming mode, if we have not seen any deltas, forward
                                // the final assistant message directly. If deltas were seen,
                                // suppress the final message to avoid duplication.
                                if this.cumulative.is_empty() {
                                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
                                        item,
                                    ))));
                                } else {
                                    continue;
                                }
                            }
                        }
                    }

                    // Not an assistant message – forward immediately.
                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
                }
                Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
                    return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
                }
                Poll::Ready(Some(Ok(ResponseEvent::Completed {
                    response_id,
                    token_usage,
                }))) => {
                    // Build any aggregated items in the correct order: Reasoning first, then Message.
                    let mut emitted_any = false;

                    if !this.cumulative_reasoning.is_empty()
                        && matches!(this.mode, AggregateMode::AggregatedOnly)
                    {
                        let aggregated_reasoning =
                            codex_protocol::models::ResponseItem::Reasoning {
                                id: String::new(),
                                summary: Vec::new(),
                                content: Some(vec![
                                    codex_protocol::models::ReasoningItemContent::ReasoningText {
                                        text: std::mem::take(&mut this.cumulative_reasoning),
                                    },
                                ]),
                                encrypted_content: None,
                            };
                        this.pending
                            .push_back(ResponseEvent::OutputItemDone(aggregated_reasoning));
                        emitted_any = true;
                    }

                    // Always emit the final aggregated assistant message when any
                    // content deltas have been observed. In AggregatedOnly mode this
                    // is the sole assistant output; in Streaming mode this finalizes
                    // the streamed deltas into a terminal OutputItemDone so callers
                    // can persist/render the message once per turn.
                    if !this.cumulative.is_empty() {
                        let aggregated_message = codex_protocol::models::ResponseItem::Message {
                            id: None,
                            role: "assistant".to_string(),
                            content: vec![codex_protocol::models::ContentItem::OutputText {
                                text: std::mem::take(&mut this.cumulative),
                            }],
                        };
                        this.pending
                            .push_back(ResponseEvent::OutputItemDone(aggregated_message));
                        emitted_any = true;
                    }

                    // Always emit Completed last when anything was aggregated.
                    if emitted_any {
                        this.pending.push_back(ResponseEvent::Completed {
                            response_id: response_id.clone(),
                            token_usage: token_usage.clone(),
                        });
                        // Return the first pending event now.
                        if let Some(ev) = this.pending.pop_front() {
                            return Poll::Ready(Some(Ok(ev)));
                        }
                    }

                    // Nothing aggregated – forward Completed directly.
                    return Poll::Ready(Some(Ok(ResponseEvent::Completed {
                        response_id,
                        token_usage,
                    })));
                }
                Poll::Ready(Some(Ok(ResponseEvent::Created))) => {
                    // These events are exclusive to the Responses API and
                    // will never appear in a Chat Completions stream.
                    continue;
                }
                Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => {
                    // Always accumulate deltas so we can emit a final OutputItemDone at Completed.
                    this.cumulative.push_str(&delta);
                    if matches!(this.mode, AggregateMode::Streaming) {
                        // In streaming mode, also forward the delta immediately.
                        return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta))));
                    } else {
                        continue;
                    }
                }
                Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))) => {
                    // Always accumulate reasoning deltas so we can emit a final Reasoning item at Completed.
                    this.cumulative_reasoning.push_str(&delta);
                    if matches!(this.mode, AggregateMode::Streaming) {
                        // In streaming mode, also forward the delta immediately.
                        return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta))));
                    } else {
                        continue;
                    }
                }
                Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta(_)))) => {
                    continue;
                }
                Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded))) => {
                    continue;
                }
                Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => {
                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item))));
                }
            }
        }
    }
}

/// Extension trait that activates aggregation on any stream of [`ResponseEvent`].
pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized {
    /// Returns a new stream that emits **only** the final assistant message
    /// per turn instead of every incremental delta. The produced
    /// `ResponseEvent` sequence for a typical text turn looks like:
    ///
    /// ```ignore
    /// OutputItemDone(<full message>)
    /// Completed
    /// ```
    ///
    /// No other `OutputItemDone` events will be seen by the caller.
    ///
    /// Usage:
    ///
    /// ```ignore
    /// let mut agg_stream = client.stream(&prompt).await?.aggregate();
    /// while let Some(event) = agg_stream.next().await {
    ///     // event now contains cumulative text
    /// }
    /// ```
    fn aggregate(self) -> AggregatedChatStream<Self> {
        AggregatedChatStream::new(self, AggregateMode::AggregatedOnly)
    }
}

impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}

impl<S> AggregatedChatStream<S> {
    fn new(inner: S, mode: AggregateMode) -> Self {
        AggregatedChatStream {
            inner,
            cumulative: String::new(),
            cumulative_reasoning: String::new(),
            pending: std::collections::VecDeque::new(),
            mode,
        }
    }

    pub(crate) fn streaming_mode(inner: S) -> Self {
        Self::new(inner, AggregateMode::Streaming)
    }
}
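
// A minimal sketch (not part of this commit) of how `aggregate()` behaves,
// assuming `futures` is available as a dev-dependency: two text deltas are
// swallowed and re-emitted as one terminal `OutputItemDone` before `Completed`.
#[cfg(test)]
mod aggregate_behavior_sketch {
    use super::*;
    use futures::StreamExt;
    use futures::executor::block_on;

    #[test]
    fn deltas_collapse_into_one_terminal_message() {
        let events = futures::stream::iter(vec![
            Ok(ResponseEvent::OutputTextDelta("Hel".to_string())),
            Ok(ResponseEvent::OutputTextDelta("lo".to_string())),
            Ok(ResponseEvent::Completed {
                response_id: String::new(),
                token_usage: None,
            }),
        ]);

        let collected: Vec<Result<ResponseEvent>> = block_on(events.aggregate().collect());

        // One aggregated message plus the original Completed event.
        assert_eq!(collected.len(), 2);
        assert!(matches!(
            collected[0],
            Ok(ResponseEvent::OutputItemDone(_))
        ));
        assert!(matches!(collected[1], Ok(ResponseEvent::Completed { .. })));
    }
}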
1523
llmx-rs/core/src/client.rs
Normal file
File diff suppressed because it is too large
553
llmx-rs/core/src/client_common.rs
Normal file
@@ -0,0 +1,553 @@
use crate::client_common::tools::ToolSpec;
use crate::error::Result;
use crate::model_family::ModelFamily;
use crate::protocol::RateLimitSnapshot;
use crate::protocol::TokenUsage;
use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
use codex_protocol::config_types::Verbosity as VerbosityConfig;
use codex_protocol::models::ResponseItem;
use futures::Stream;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use std::borrow::Cow;
use std::collections::HashSet;
use std::ops::Deref;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use tokio::sync::mpsc;

/// Review thread system prompt. Edit `core/src/review_prompt.md` to customize.
pub const REVIEW_PROMPT: &str = include_str!("../review_prompt.md");

// Centralized templates for review-related user messages
pub const REVIEW_EXIT_SUCCESS_TMPL: &str = include_str!("../templates/review/exit_success.xml");
pub const REVIEW_EXIT_INTERRUPTED_TMPL: &str =
    include_str!("../templates/review/exit_interrupted.xml");

/// API request payload for a single model turn
#[derive(Default, Debug, Clone)]
pub struct Prompt {
    /// Conversation context input items.
    pub input: Vec<ResponseItem>,

    /// Tools available to the model, including additional tools sourced from
    /// external MCP servers.
    pub(crate) tools: Vec<ToolSpec>,

    /// Whether parallel tool calls are permitted for this prompt.
    pub(crate) parallel_tool_calls: bool,

    /// Optional override for the built-in BASE_INSTRUCTIONS.
    pub base_instructions_override: Option<String>,

    /// Optional output schema for the model's response.
    pub output_schema: Option<Value>,
}

impl Prompt {
    pub(crate) fn get_full_instructions<'a>(&'a self, model: &'a ModelFamily) -> Cow<'a, str> {
        let base = self
            .base_instructions_override
            .as_deref()
            .unwrap_or(model.base_instructions.deref());
        // When there are no custom instructions, add apply_patch_tool_instructions if:
        // - the model needs special instructions (4.1)
        //   AND
        // - there is no apply_patch tool present
        let is_apply_patch_tool_present = self.tools.iter().any(|tool| match tool {
            ToolSpec::Function(f) => f.name == "apply_patch",
            ToolSpec::Freeform(f) => f.name == "apply_patch",
            _ => false,
        });
        if self.base_instructions_override.is_none()
            && model.needs_special_apply_patch_instructions
            && !is_apply_patch_tool_present
        {
            Cow::Owned(format!("{base}\n{APPLY_PATCH_TOOL_INSTRUCTIONS}"))
        } else {
            Cow::Borrowed(base)
        }
    }

    pub(crate) fn get_formatted_input(&self) -> Vec<ResponseItem> {
        let mut input = self.input.clone();

        // When using the *Freeform* apply_patch tool specifically, tool outputs
        // should be structured text, not JSON. Do NOT reserialize when using
        // the Function tool - note that this differs from the check above for
        // instructions. We declare the result as a named variable for clarity.
        let is_freeform_apply_patch_tool_present = self.tools.iter().any(|tool| match tool {
            ToolSpec::Freeform(f) => f.name == "apply_patch",
            _ => false,
        });
        if is_freeform_apply_patch_tool_present {
            reserialize_shell_outputs(&mut input);
        }

        input
    }
}
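
// A minimal sketch (not part of this commit) of the instruction assembly rule
// above: for a model family that needs the special apply_patch guidance, with
// no override and no apply_patch tool configured, the returned instructions
// are "{base}\n{APPLY_PATCH_TOOL_INSTRUCTIONS}"; in every other case the base
// instructions are borrowed unchanged.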

fn reserialize_shell_outputs(items: &mut [ResponseItem]) {
    let mut shell_call_ids: HashSet<String> = HashSet::new();

    items.iter_mut().for_each(|item| match item {
        ResponseItem::LocalShellCall { call_id, id, .. } => {
            if let Some(identifier) = call_id.clone().or_else(|| id.clone()) {
                shell_call_ids.insert(identifier);
            }
        }
        ResponseItem::CustomToolCall {
            id: _,
            status: _,
            call_id,
            name,
            input: _,
        } => {
            if name == "apply_patch" {
                shell_call_ids.insert(call_id.clone());
            }
        }
        ResponseItem::CustomToolCallOutput { call_id, output } => {
            if shell_call_ids.remove(call_id)
                && let Some(structured) = parse_structured_shell_output(output)
            {
                *output = structured
            }
        }
        ResponseItem::FunctionCall { name, call_id, .. }
            if is_shell_tool_name(name) || name == "apply_patch" =>
        {
            shell_call_ids.insert(call_id.clone());
        }
        ResponseItem::FunctionCallOutput { call_id, output } => {
            if shell_call_ids.remove(call_id)
                && let Some(structured) = parse_structured_shell_output(&output.content)
            {
                output.content = structured
            }
        }
        _ => {}
    })
}

fn is_shell_tool_name(name: &str) -> bool {
    matches!(name, "shell" | "container.exec")
}

#[derive(Deserialize)]
struct ExecOutputJson {
    output: String,
    metadata: ExecOutputMetadataJson,
}

#[derive(Deserialize)]
struct ExecOutputMetadataJson {
    exit_code: i32,
    duration_seconds: f32,
}

fn parse_structured_shell_output(raw: &str) -> Option<String> {
    let parsed: ExecOutputJson = serde_json::from_str(raw).ok()?;
    Some(build_structured_output(&parsed))
}

fn build_structured_output(parsed: &ExecOutputJson) -> String {
    let mut sections = Vec::new();
    sections.push(format!("Exit code: {}", parsed.metadata.exit_code));
    sections.push(format!(
        "Wall time: {} seconds",
        parsed.metadata.duration_seconds
    ));

    let mut output = parsed.output.clone();
    if let Some(total_lines) = extract_total_output_lines(&parsed.output) {
        sections.push(format!("Total output lines: {total_lines}"));
        if let Some(stripped) = strip_total_output_header(&output) {
            output = stripped.to_string();
        }
    }

    sections.push("Output:".to_string());
    sections.push(output);

    sections.join("\n")
}

fn extract_total_output_lines(output: &str) -> Option<u32> {
    let marker_start = output.find("[... omitted ")?;
    let marker = &output[marker_start..];
    let (_, after_of) = marker.split_once(" of ")?;
    let (total_segment, _) = after_of.split_once(' ')?;
    total_segment.parse::<u32>().ok()
}

fn strip_total_output_header(output: &str) -> Option<&str> {
    let after_prefix = output.strip_prefix("Total output lines: ")?;
    let (_, remainder) = after_prefix.split_once('\n')?;
    let remainder = remainder.strip_prefix('\n').unwrap_or(remainder);
    Some(remainder)
}
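
// A minimal sketch (not part of this commit) of the reserialization performed
// above. Given the raw JSON tool output
//   {"output":"hello\n","metadata":{"exit_code":0,"duration_seconds":0.25}}
// `parse_structured_shell_output` produces the structured text:
//   Exit code: 0
//   Wall time: 0.25 seconds
//   Output:
//   hello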

#[derive(Debug)]
pub enum ResponseEvent {
    Created,
    OutputItemDone(ResponseItem),
    OutputItemAdded(ResponseItem),
    Completed {
        response_id: String,
        token_usage: Option<TokenUsage>,
    },
    OutputTextDelta(String),
    ReasoningSummaryDelta(String),
    ReasoningContentDelta(String),
    ReasoningSummaryPartAdded,
    RateLimits(RateLimitSnapshot),
}

#[derive(Debug, Serialize)]
pub(crate) struct Reasoning {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) effort: Option<ReasoningEffortConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) summary: Option<ReasoningSummaryConfig>,
}

#[derive(Debug, Serialize, Default, Clone)]
#[serde(rename_all = "snake_case")]
pub(crate) enum TextFormatType {
    #[default]
    JsonSchema,
}

#[derive(Debug, Serialize, Default, Clone)]
pub(crate) struct TextFormat {
    pub(crate) r#type: TextFormatType,
    pub(crate) strict: bool,
    pub(crate) schema: Value,
    pub(crate) name: String,
}

/// Controls under the `text` field in the Responses API for GPT-5.
#[derive(Debug, Serialize, Default, Clone)]
pub(crate) struct TextControls {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) verbosity: Option<OpenAiVerbosity>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) format: Option<TextFormat>,
}

#[derive(Debug, Serialize, Default, Clone)]
#[serde(rename_all = "lowercase")]
pub(crate) enum OpenAiVerbosity {
    Low,
    #[default]
    Medium,
    High,
}

impl From<VerbosityConfig> for OpenAiVerbosity {
    fn from(v: VerbosityConfig) -> Self {
        match v {
            VerbosityConfig::Low => OpenAiVerbosity::Low,
            VerbosityConfig::Medium => OpenAiVerbosity::Medium,
            VerbosityConfig::High => OpenAiVerbosity::High,
        }
    }
}

/// Request object that is serialized as JSON and POST'ed when using the
/// Responses API.
#[derive(Debug, Serialize)]
pub(crate) struct ResponsesApiRequest<'a> {
    pub(crate) model: &'a str,
    pub(crate) instructions: &'a str,
    // TODO(mbolin): ResponseItem::Other should not be serialized. Currently,
    // we code defensively to avoid this case, but perhaps we should use a
    // separate enum for serialization.
    pub(crate) input: &'a Vec<ResponseItem>,
    pub(crate) tools: &'a [serde_json::Value],
    pub(crate) tool_choice: &'static str,
    pub(crate) parallel_tool_calls: bool,
    pub(crate) reasoning: Option<Reasoning>,
    pub(crate) store: bool,
    pub(crate) stream: bool,
    pub(crate) include: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) prompt_cache_key: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) text: Option<TextControls>,
}

pub(crate) mod tools {
    use crate::tools::spec::JsonSchema;
    use serde::Deserialize;
    use serde::Serialize;

    /// When serialized as JSON, this produces a valid "Tool" in the OpenAI
    /// Responses API.
    #[derive(Debug, Clone, Serialize, PartialEq)]
    #[serde(tag = "type")]
    pub(crate) enum ToolSpec {
        #[serde(rename = "function")]
        Function(ResponsesApiTool),
        #[serde(rename = "local_shell")]
        LocalShell {},
        // TODO: Understand why we get an error on web_search although the API docs say it's supported.
        // https://platform.openai.com/docs/guides/tools-web-search?api-mode=responses#:~:text=%7B%20type%3A%20%22web_search%22%20%7D%2C
        #[serde(rename = "web_search")]
        WebSearch {},
        #[serde(rename = "custom")]
        Freeform(FreeformTool),
    }

    impl ToolSpec {
        pub(crate) fn name(&self) -> &str {
            match self {
                ToolSpec::Function(tool) => tool.name.as_str(),
                ToolSpec::LocalShell {} => "local_shell",
                ToolSpec::WebSearch {} => "web_search",
                ToolSpec::Freeform(tool) => tool.name.as_str(),
            }
        }
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct FreeformTool {
        pub(crate) name: String,
        pub(crate) description: String,
        pub(crate) format: FreeformToolFormat,
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct FreeformToolFormat {
        pub(crate) r#type: String,
        pub(crate) syntax: String,
        pub(crate) definition: String,
    }

    #[derive(Debug, Clone, Serialize, PartialEq)]
    pub struct ResponsesApiTool {
        pub(crate) name: String,
        pub(crate) description: String,
        /// TODO: Validation. When strict is set to true, the JSON schema,
        /// `required` and `additional_properties` must be present. All fields in
        /// `properties` must be present in `required`.
        pub(crate) strict: bool,
        pub(crate) parameters: JsonSchema,
    }
}
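
// A minimal sketch (not part of this commit) of how `#[serde(tag = "type")]`
// shapes the wire format: `ToolSpec::LocalShell {}` serializes to
//   {"type":"local_shell"}
// and a Freeform apply_patch tool to
//   {"type":"custom","name":"apply_patch","description":"...","format":{...}}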

pub(crate) fn create_reasoning_param_for_request(
    model_family: &ModelFamily,
    effort: Option<ReasoningEffortConfig>,
    summary: ReasoningSummaryConfig,
) -> Option<Reasoning> {
    if !model_family.supports_reasoning_summaries {
        return None;
    }

    Some(Reasoning {
        effort,
        summary: Some(summary),
    })
}

pub(crate) fn create_text_param_for_request(
    verbosity: Option<VerbosityConfig>,
    output_schema: &Option<Value>,
) -> Option<TextControls> {
    if verbosity.is_none() && output_schema.is_none() {
        return None;
    }

    Some(TextControls {
        verbosity: verbosity.map(std::convert::Into::into),
        format: output_schema.as_ref().map(|schema| TextFormat {
            r#type: TextFormatType::JsonSchema,
            strict: true,
            schema: schema.clone(),
            name: "codex_output_schema".to_string(),
        }),
    })
}

pub struct ResponseStream {
    pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent>>,
}

impl Stream for ResponseStream {
    type Item = Result<ResponseEvent>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.rx_event.poll_recv(cx)
    }
}

#[cfg(test)]
mod tests {
    use crate::model_family::find_family_for_model;
    use pretty_assertions::assert_eq;

    use super::*;

    struct InstructionsTestCase {
        pub slug: &'static str,
        pub expects_apply_patch_instructions: bool,
    }

    #[test]
    fn get_full_instructions_no_user_content() {
        let prompt = Prompt {
            ..Default::default()
        };
        let test_cases = vec![
            InstructionsTestCase {
                slug: "gpt-3.5",
                expects_apply_patch_instructions: true,
            },
            InstructionsTestCase {
                slug: "gpt-4.1",
                expects_apply_patch_instructions: true,
            },
            InstructionsTestCase {
                slug: "gpt-4o",
                expects_apply_patch_instructions: true,
            },
            InstructionsTestCase {
                slug: "gpt-5",
                expects_apply_patch_instructions: true,
            },
            InstructionsTestCase {
                slug: "codex-mini-latest",
                expects_apply_patch_instructions: true,
            },
            InstructionsTestCase {
                slug: "gpt-oss:120b",
                expects_apply_patch_instructions: false,
            },
            InstructionsTestCase {
                slug: "gpt-5-codex",
                expects_apply_patch_instructions: false,
            },
        ];
        for test_case in test_cases {
            let model_family = find_family_for_model(test_case.slug).expect("known model slug");
            let expected = if test_case.expects_apply_patch_instructions {
                format!(
                    "{}\n{}",
                    model_family.clone().base_instructions,
                    APPLY_PATCH_TOOL_INSTRUCTIONS
                )
            } else {
                model_family.clone().base_instructions
            };

            let full = prompt.get_full_instructions(&model_family);
            assert_eq!(full, expected);
        }
    }

    #[test]
    fn serializes_text_verbosity_when_set() {
        let input: Vec<ResponseItem> = vec![];
        let tools: Vec<serde_json::Value> = vec![];
        let req = ResponsesApiRequest {
            model: "gpt-5",
            instructions: "i",
            input: &input,
            tools: &tools,
            tool_choice: "auto",
            parallel_tool_calls: true,
            reasoning: None,
            store: false,
            stream: true,
            include: vec![],
            prompt_cache_key: None,
            text: Some(TextControls {
                verbosity: Some(OpenAiVerbosity::Low),
                format: None,
            }),
        };

        let v = serde_json::to_value(&req).expect("json");
        assert_eq!(
            v.get("text")
                .and_then(|t| t.get("verbosity"))
                .and_then(|s| s.as_str()),
            Some("low")
        );
    }

    #[test]
    fn serializes_text_schema_with_strict_format() {
        let input: Vec<ResponseItem> = vec![];
        let tools: Vec<serde_json::Value> = vec![];
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "answer": {"type": "string"}
            },
            "required": ["answer"],
        });
        let text_controls =
            create_text_param_for_request(None, &Some(schema.clone())).expect("text controls");

        let req = ResponsesApiRequest {
            model: "gpt-5",
            instructions: "i",
            input: &input,
            tools: &tools,
            tool_choice: "auto",
            parallel_tool_calls: true,
            reasoning: None,
            store: false,
            stream: true,
            include: vec![],
            prompt_cache_key: None,
            text: Some(text_controls),
        };

        let v = serde_json::to_value(&req).expect("json");
        let text = v.get("text").expect("text field");
        assert!(text.get("verbosity").is_none());
        let format = text.get("format").expect("format field");

        assert_eq!(
            format.get("name"),
            Some(&serde_json::Value::String("codex_output_schema".into()))
        );
        assert_eq!(
            format.get("type"),
            Some(&serde_json::Value::String("json_schema".into()))
        );
        assert_eq!(format.get("strict"), Some(&serde_json::Value::Bool(true)));
        assert_eq!(format.get("schema"), Some(&schema));
    }

    #[test]
    fn omits_text_when_not_set() {
        let input: Vec<ResponseItem> = vec![];
        let tools: Vec<serde_json::Value> = vec![];
        let req = ResponsesApiRequest {
            model: "gpt-5",
            instructions: "i",
            input: &input,
            tools: &tools,
            tool_choice: "auto",
            parallel_tool_calls: true,
            reasoning: None,
            store: false,
            stream: true,
            include: vec![],
            prompt_cache_key: None,
            text: None,
        };

        let v = serde_json::to_value(&req).expect("json");
        assert!(v.get("text").is_none());
    }
}
3149
llmx-rs/core/src/codex.rs
Normal file
File diff suppressed because it is too large
39
llmx-rs/core/src/codex_conversation.rs
Normal file
@@ -0,0 +1,39 @@
use crate::codex::Codex;
use crate::error::Result as CodexResult;
use crate::protocol::Event;
use crate::protocol::Op;
use crate::protocol::Submission;
use std::path::PathBuf;

pub struct CodexConversation {
    codex: Codex,
    rollout_path: PathBuf,
}

/// Conduit for the bidirectional stream of messages that compose a conversation
/// in Codex.
impl CodexConversation {
    pub(crate) fn new(codex: Codex, rollout_path: PathBuf) -> Self {
        Self {
            codex,
            rollout_path,
        }
    }

    pub async fn submit(&self, op: Op) -> CodexResult<String> {
        self.codex.submit(op).await
    }

    /// Use sparingly: this is intended to be removed soon.
    pub async fn submit_with_id(&self, sub: Submission) -> CodexResult<()> {
        self.codex.submit_with_id(sub).await
    }

    pub async fn next_event(&self) -> CodexResult<Event> {
        self.codex.next_event().await
    }

    pub fn rollout_path(&self) -> PathBuf {
        self.rollout_path.clone()
    }
}
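
// A minimal sketch (not part of this commit) of a typical driver loop over
// this handle; the `render` helper is hypothetical:
//   conversation.submit(Op::UserInput { items }).await?;
//   loop {
//       match conversation.next_event().await?.msg {
//           EventMsg::TaskComplete(_) => break,
//           other => render(other),
//       }
//   }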
300
llmx-rs/core/src/codex_delegate.rs
Normal file
@@ -0,0 +1,300 @@
use std::sync::Arc;
use std::sync::atomic::AtomicU64;

use async_channel::Receiver;
use async_channel::Sender;
use codex_async_utils::OrCancelExt;
use codex_protocol::protocol::ApplyPatchApprovalRequestEvent;
use codex_protocol::protocol::Event;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::ExecApprovalRequestEvent;
use codex_protocol::protocol::Op;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use codex_protocol::protocol::Submission;
use codex_protocol::user_input::UserInput;
use tokio_util::sync::CancellationToken;

use crate::AuthManager;
use crate::codex::Codex;
use crate::codex::CodexSpawnOk;
use crate::codex::SUBMISSION_CHANNEL_CAPACITY;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::config::Config;
use crate::error::CodexErr;
use codex_protocol::protocol::InitialHistory;

/// Start an interactive sub-Codex conversation and return IO channels.
///
/// The returned `events_rx` yields non-approval events emitted by the sub-agent.
/// Approval requests are handled via `parent_session` and are not surfaced.
/// The returned `ops_tx` allows the caller to submit additional `Op`s to the sub-agent.
pub(crate) async fn run_codex_conversation_interactive(
    config: Config,
    auth_manager: Arc<AuthManager>,
    parent_session: Arc<Session>,
    parent_ctx: Arc<TurnContext>,
    cancel_token: CancellationToken,
    initial_history: Option<InitialHistory>,
) -> Result<Codex, CodexErr> {
    let (tx_sub, rx_sub) = async_channel::bounded(SUBMISSION_CHANNEL_CAPACITY);
    let (tx_ops, rx_ops) = async_channel::bounded(SUBMISSION_CHANNEL_CAPACITY);

    let CodexSpawnOk { codex, .. } = Codex::spawn(
        config,
        auth_manager,
        initial_history.unwrap_or(InitialHistory::New),
        SessionSource::SubAgent(SubAgentSource::Review),
    )
    .await?;
    let codex = Arc::new(codex);

    // Use a child token so a parent cancel cascades but we can scope it to this task.
    let cancel_token_events = cancel_token.child_token();
    let cancel_token_ops = cancel_token.child_token();

    // Forward events from the sub-agent to the consumer, filtering approvals and
    // routing them to the parent session for decisions.
    let parent_session_clone = Arc::clone(&parent_session);
    let parent_ctx_clone = Arc::clone(&parent_ctx);
    let codex_for_events = Arc::clone(&codex);
    tokio::spawn(async move {
        let _ = forward_events(
            codex_for_events,
            tx_sub,
            parent_session_clone,
            parent_ctx_clone,
            cancel_token_events.clone(),
        )
        .or_cancel(&cancel_token_events)
        .await;
    });

    // Forward ops from the caller to the sub-agent.
    let codex_for_ops = Arc::clone(&codex);
    tokio::spawn(async move {
        forward_ops(codex_for_ops, rx_ops, cancel_token_ops).await;
    });

    Ok(Codex {
        next_id: AtomicU64::new(0),
        tx_sub: tx_ops,
        rx_event: rx_sub,
    })
}

/// Convenience wrapper for one-time use with an initial prompt.
///
/// Internally calls the interactive variant, then immediately submits the provided input.
pub(crate) async fn run_codex_conversation_one_shot(
    config: Config,
    auth_manager: Arc<AuthManager>,
    input: Vec<UserInput>,
    parent_session: Arc<Session>,
    parent_ctx: Arc<TurnContext>,
    cancel_token: CancellationToken,
    initial_history: Option<InitialHistory>,
) -> Result<Codex, CodexErr> {
    // Use a child token so we can stop the delegate after completion without
    // requiring the caller to cancel the parent token.
    let child_cancel = cancel_token.child_token();
    let io = run_codex_conversation_interactive(
        config,
        auth_manager,
        parent_session,
        parent_ctx,
        child_cancel.clone(),
        initial_history,
    )
    .await?;

    // Send the initial input to kick off the one-shot turn.
    io.submit(Op::UserInput { items: input }).await?;

    // Bridge events so we can observe completion and shut down automatically.
    let (tx_bridge, rx_bridge) = async_channel::bounded(SUBMISSION_CHANNEL_CAPACITY);
    let ops_tx = io.tx_sub.clone();
    let io_for_bridge = io;
    tokio::spawn(async move {
        while let Ok(event) = io_for_bridge.next_event().await {
            let should_shutdown = matches!(
                event.msg,
                EventMsg::TaskComplete(_) | EventMsg::TurnAborted(_)
            );
            let _ = tx_bridge.send(event).await;
            if should_shutdown {
                let _ = ops_tx
                    .send(Submission {
                        id: "shutdown".to_string(),
                        op: Op::Shutdown {},
                    })
                    .await;
                child_cancel.cancel();
                break;
            }
        }
    });

    // For one-shot usage, return a closed `tx_sub` so callers cannot submit
    // additional ops after the initial request. Create a channel and drop the
    // receiver to close it immediately.
    let (tx_closed, rx_closed) = async_channel::bounded(SUBMISSION_CHANNEL_CAPACITY);
    drop(rx_closed);

    Ok(Codex {
        next_id: AtomicU64::new(0),
        rx_event: rx_bridge,
        tx_sub: tx_closed,
    })
}
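
// A minimal sketch (not part of this commit) of consuming the one-shot handle:
// callers only read events, since the returned `tx_sub` is already closed, and
// the bridge task shuts the delegate down after TaskComplete or TurnAborted.
//   let delegate = run_codex_conversation_one_shot(/* ... */).await?;
//   while let Ok(event) = delegate.next_event().await {
//       // observe sub-agent progress here
//   }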

async fn forward_events(
    codex: Arc<Codex>,
    tx_sub: Sender<Event>,
    parent_session: Arc<Session>,
    parent_ctx: Arc<TurnContext>,
    cancel_token: CancellationToken,
) {
    while let Ok(event) = codex.next_event().await {
        match event {
            // Ignore all legacy delta events.
            Event {
                id: _,
                msg: EventMsg::AgentMessageDelta(_) | EventMsg::AgentReasoningDelta(_),
            } => continue,
            Event {
                id: _,
                msg: EventMsg::SessionConfigured(_),
            } => continue,
            Event {
                id,
                msg: EventMsg::ExecApprovalRequest(event),
            } => {
                // Initiate approval via the parent session; do not surface to the consumer.
                handle_exec_approval(
                    &codex,
                    id,
                    &parent_session,
                    &parent_ctx,
                    event,
                    &cancel_token,
                )
                .await;
            }
            Event {
                id,
                msg: EventMsg::ApplyPatchApprovalRequest(event),
            } => {
                handle_patch_approval(
                    &codex,
                    id,
                    &parent_session,
                    &parent_ctx,
                    event,
                    &cancel_token,
                )
                .await;
            }
            other => {
                let _ = tx_sub.send(other).await;
            }
        }
    }
}

/// Forward ops from a caller to a sub-agent, respecting cancellation.
async fn forward_ops(
    codex: Arc<Codex>,
    rx_ops: Receiver<Submission>,
    cancel_token_ops: CancellationToken,
) {
    loop {
        let op: Op = match rx_ops.recv().or_cancel(&cancel_token_ops).await {
            Ok(Ok(Submission { id: _, op })) => op,
            Ok(Err(_)) | Err(_) => break,
        };
        let _ = codex.submit(op).await;
    }
}

/// Handle an ExecApprovalRequest by consulting the parent session and replying.
async fn handle_exec_approval(
    codex: &Codex,
    id: String,
    parent_session: &Session,
    parent_ctx: &TurnContext,
    event: ExecApprovalRequestEvent,
    cancel_token: &CancellationToken,
) {
    // Race approval with cancellation to avoid hangs.
    let approval_fut = parent_session.request_command_approval(
        parent_ctx,
        parent_ctx.sub_id.clone(),
        event.command,
        event.cwd,
        event.reason,
        event.risk,
    );
    let decision = await_approval_with_cancel(
        approval_fut,
        parent_session,
        &parent_ctx.sub_id,
        cancel_token,
    )
    .await;

    let _ = codex.submit(Op::ExecApproval { id, decision }).await;
}

/// Handle an ApplyPatchApprovalRequest by consulting the parent session and replying.
async fn handle_patch_approval(
    codex: &Codex,
    id: String,
    parent_session: &Session,
    parent_ctx: &TurnContext,
    event: ApplyPatchApprovalRequestEvent,
    cancel_token: &CancellationToken,
) {
    let decision_rx = parent_session
        .request_patch_approval(
            parent_ctx,
            parent_ctx.sub_id.clone(),
            event.changes,
            event.reason,
            event.grant_root,
        )
        .await;
    let decision = await_approval_with_cancel(
        async move { decision_rx.await.unwrap_or_default() },
        parent_session,
        &parent_ctx.sub_id,
        cancel_token,
    )
    .await;
    let _ = codex.submit(Op::PatchApproval { id, decision }).await;
}

/// Await an approval decision, aborting on cancellation.
async fn await_approval_with_cancel<F>(
    fut: F,
    parent_session: &Session,
    sub_id: &str,
    cancel_token: &CancellationToken,
) -> codex_protocol::protocol::ReviewDecision
where
    F: core::future::Future<Output = codex_protocol::protocol::ReviewDecision>,
{
    tokio::select! {
        biased;
        _ = cancel_token.cancelled() => {
            parent_session
                .notify_approval(sub_id, codex_protocol::protocol::ReviewDecision::Abort)
                .await;
            codex_protocol::protocol::ReviewDecision::Abort
        }
        decision = fut => {
            decision
        }
    }
}
142
llmx-rs/core/src/command_safety/is_dangerous_command.rs
Normal file
@@ -0,0 +1,142 @@
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::SandboxPolicy;

use crate::bash::parse_shell_lc_plain_commands;
use crate::is_safe_command::is_known_safe_command;

pub fn requires_initial_appoval(
    policy: AskForApproval,
    sandbox_policy: &SandboxPolicy,
    command: &[String],
    with_escalated_permissions: bool,
) -> bool {
    if is_known_safe_command(command) {
        return false;
    }
    match policy {
        AskForApproval::Never | AskForApproval::OnFailure => false,
        AskForApproval::OnRequest => {
            // In DangerFullAccess, only prompt if the command looks dangerous.
            if matches!(sandbox_policy, SandboxPolicy::DangerFullAccess) {
                return command_might_be_dangerous(command);
            }

            // In restricted sandboxes (ReadOnly/WorkspaceWrite), do not prompt for
            // non‑escalated, non‑dangerous commands — let the sandbox enforce
            // restrictions (e.g., block network/write) without a user prompt.
            let wants_escalation: bool = with_escalated_permissions;
            if wants_escalation {
                return true;
            }
            command_might_be_dangerous(command)
        }
        AskForApproval::UnlessTrusted => !is_known_safe_command(command),
    }
}
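
// A minimal sketch (not part of this commit) of the gate above. Under
// OnRequest + DangerFullAccess, a safelisted `git status` runs without a
// prompt while `git reset --hard` asks first:
//   assert!(!requires_initial_appoval(
//       AskForApproval::OnRequest,
//       &SandboxPolicy::DangerFullAccess,
//       &["git".into(), "status".into()],
//       false,
//   ));
//   assert!(requires_initial_appoval(
//       AskForApproval::OnRequest,
//       &SandboxPolicy::DangerFullAccess,
//       &["git".into(), "reset".into(), "--hard".into()],
//       false,
//   ));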

pub fn command_might_be_dangerous(command: &[String]) -> bool {
    if is_dangerous_to_call_with_exec(command) {
        return true;
    }

    // Support `bash -lc "<script>"` where any part of the script might contain a dangerous command.
    if let Some(all_commands) = parse_shell_lc_plain_commands(command)
        && all_commands
            .iter()
            .any(|cmd| is_dangerous_to_call_with_exec(cmd))
    {
        return true;
    }

    false
}

fn is_dangerous_to_call_with_exec(command: &[String]) -> bool {
    let cmd0 = command.first().map(String::as_str);

    match cmd0 {
        Some(cmd) if cmd.ends_with("git") || cmd.ends_with("/git") => {
            matches!(command.get(1).map(String::as_str), Some("reset" | "rm"))
        }

        Some("rm") => matches!(command.get(1).map(String::as_str), Some("-f" | "-rf")),

        // For `sudo <cmd>`, simply run the check on `<cmd>`.
        Some("sudo") => is_dangerous_to_call_with_exec(&command[1..]),

        // ── anything else ─────────────────────────────────────────────────
        _ => false,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn vec_str(items: &[&str]) -> Vec<String> {
        items.iter().map(std::string::ToString::to_string).collect()
    }

    #[test]
    fn git_reset_is_dangerous() {
        assert!(command_might_be_dangerous(&vec_str(&["git", "reset"])));
    }

    #[test]
    fn bash_git_reset_is_dangerous() {
        assert!(command_might_be_dangerous(&vec_str(&[
            "bash",
            "-lc",
            "git reset --hard"
        ])));
    }

    #[test]
    fn zsh_git_reset_is_dangerous() {
        assert!(command_might_be_dangerous(&vec_str(&[
            "zsh",
            "-lc",
            "git reset --hard"
        ])));
    }

    #[test]
    fn git_status_is_not_dangerous() {
        assert!(!command_might_be_dangerous(&vec_str(&["git", "status"])));
    }

    #[test]
    fn bash_git_status_is_not_dangerous() {
        assert!(!command_might_be_dangerous(&vec_str(&[
            "bash",
            "-lc",
            "git status"
        ])));
    }

    #[test]
    fn sudo_git_reset_is_dangerous() {
        assert!(command_might_be_dangerous(&vec_str(&[
            "sudo", "git", "reset", "--hard"
        ])));
    }

    #[test]
    fn usr_bin_git_is_dangerous() {
        assert!(command_might_be_dangerous(&vec_str(&[
            "/usr/bin/git",
            "reset",
            "--hard"
        ])));
    }

    #[test]
    fn rm_rf_is_dangerous() {
        assert!(command_might_be_dangerous(&vec_str(&["rm", "-rf", "/"])));
    }

    #[test]
    fn rm_f_is_dangerous() {
        assert!(command_might_be_dangerous(&vec_str(&["rm", "-f", "/"])));
    }
}
366
llmx-rs/core/src/command_safety/is_safe_command.rs
Normal file
@@ -0,0 +1,366 @@
use crate::bash::parse_shell_lc_plain_commands;

pub fn is_known_safe_command(command: &[String]) -> bool {
    let command: Vec<String> = command
        .iter()
        .map(|s| {
            if s == "zsh" {
                "bash".to_string()
            } else {
                s.clone()
            }
        })
        .collect();
    #[cfg(target_os = "windows")]
    {
        use super::windows_safe_commands::is_safe_command_windows;
        if is_safe_command_windows(&command) {
            return true;
        }
    }

    if is_safe_to_call_with_exec(&command) {
        return true;
    }

    // Support `bash -lc "..."` where the script consists solely of one or
    // more "plain" commands (only bare words / quoted strings) combined with
    // a conservative allow‑list of shell operators that themselves do not
    // introduce side effects ("&&", "||", ";", and "|"). If every
    // individual command in the script is itself a known‑safe command, then
    // the composite expression is considered safe.
    if let Some(all_commands) = parse_shell_lc_plain_commands(&command)
        && !all_commands.is_empty()
        && all_commands
            .iter()
            .all(|cmd| is_safe_to_call_with_exec(cmd))
    {
        return true;
    }
    false
}
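
// A minimal sketch (not part of this commit) of the composite rule above:
//   is_known_safe_command(&["bash".into(), "-lc".into(), "ls && pwd | wc -l".into()])
// is true because every plain command (ls, pwd, wc) is safelisted, while
//   is_known_safe_command(&["bash".into(), "-lc".into(), "ls > out.txt".into()])
// is false because redirection is not one of the conservative operators.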

fn is_safe_to_call_with_exec(command: &[String]) -> bool {
    let Some(cmd0) = command.first().map(String::as_str) else {
        return false;
    };

    match std::path::Path::new(&cmd0)
        .file_name()
        .and_then(|osstr| osstr.to_str())
    {
        #[rustfmt::skip]
        Some(
            "cat" |
            "cd" |
            "echo" |
            "false" |
            "grep" |
            "head" |
            "ls" |
            "nl" |
            "pwd" |
            "tail" |
            "true" |
            "wc" |
            "which") => {
            true
        },

        Some("find") => {
            // Certain options to `find` can delete files, write to files, or
            // execute arbitrary commands, so we cannot auto-approve the
            // invocation of `find` in such cases.
            #[rustfmt::skip]
            const UNSAFE_FIND_OPTIONS: &[&str] = &[
                // Options that can execute arbitrary commands.
                "-exec", "-execdir", "-ok", "-okdir",
                // Option that deletes matching files.
                "-delete",
                // Options that write pathnames to a file.
                "-fls", "-fprint", "-fprint0", "-fprintf",
            ];

            !command
                .iter()
                .any(|arg| UNSAFE_FIND_OPTIONS.contains(&arg.as_str()))
        }

        // Ripgrep
        Some("rg") => {
            const UNSAFE_RIPGREP_OPTIONS_WITH_ARGS: &[&str] = &[
                // Takes an arbitrary command that is executed for each match.
                "--pre",
                // Takes a command that can be used to obtain the local hostname.
                "--hostname-bin",
            ];
            const UNSAFE_RIPGREP_OPTIONS_WITHOUT_ARGS: &[&str] = &[
                // Calls out to other decompression tools, so do not auto-approve
                // out of an abundance of caution.
                "--search-zip",
                "-z",
            ];

            !command.iter().any(|arg| {
                UNSAFE_RIPGREP_OPTIONS_WITHOUT_ARGS.contains(&arg.as_str())
                    || UNSAFE_RIPGREP_OPTIONS_WITH_ARGS
                        .iter()
                        .any(|&opt| arg == opt || arg.starts_with(&format!("{opt}=")))
            })
        }

        // Git
        Some("git") => matches!(
            command.get(1).map(String::as_str),
            Some("branch" | "status" | "log" | "diff" | "show")
        ),

        // Rust
        Some("cargo") if command.get(1).map(String::as_str) == Some("check") => true,

        // Special-case `sed -n {N|M,N}p`
        Some("sed")
            if {
                command.len() <= 4
                    && command.get(1).map(String::as_str) == Some("-n")
                    && is_valid_sed_n_arg(command.get(2).map(String::as_str))
            } =>
        {
            true
        }

        // ── anything else ─────────────────────────────────────────────────
        _ => false,
    }
}

// (bash parsing helpers implemented in crate::bash)

/* ----------------------------------------------------------
   Example
---------------------------------------------------------- */

/// Returns true if `arg` matches /^(\d+,)?\d+p$/
fn is_valid_sed_n_arg(arg: Option<&str>) -> bool {
    // Unwrap or bail.
    let s = match arg {
        Some(s) => s,
        None => return false,
    };

    // Must end with 'p'; strip it.
    let core = match s.strip_suffix('p') {
        Some(rest) => rest,
        None => return false,
    };

    // Split on ',' and ensure 1 or 2 numeric parts.
    let parts: Vec<&str> = core.split(',').collect();
    match parts.as_slice() {
        // single number, e.g. "10"
        [num] => !num.is_empty() && num.chars().all(|c| c.is_ascii_digit()),

        // two numbers, e.g. "1,5"
        [a, b] => {
            !a.is_empty()
                && !b.is_empty()
                && a.chars().all(|c| c.is_ascii_digit())
                && b.chars().all(|c| c.is_ascii_digit())
        }

        // anything else (more than one comma) is invalid
        _ => false,
    }
}
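
// A minimal sketch (not part of this commit) of what `is_valid_sed_n_arg`
// accepts, matching the `sed -n 10p` / `sed -n 1,5p` forms special-cased above:
//   assert!(is_valid_sed_n_arg(Some("10p")));
//   assert!(is_valid_sed_n_arg(Some("1,5p")));
//   assert!(!is_valid_sed_n_arg(Some("p")));      // no digits
//   assert!(!is_valid_sed_n_arg(Some("1,2,3p"))); // too many commas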

#[cfg(test)]
mod tests {
    use super::*;
    use std::string::ToString;

    fn vec_str(args: &[&str]) -> Vec<String> {
        args.iter().map(ToString::to_string).collect()
    }

    #[test]
    fn known_safe_examples() {
        assert!(is_safe_to_call_with_exec(&vec_str(&["ls"])));
        assert!(is_safe_to_call_with_exec(&vec_str(&["git", "status"])));
        assert!(is_safe_to_call_with_exec(&vec_str(&[
            "sed", "-n", "1,5p", "file.txt"
        ])));
        assert!(is_safe_to_call_with_exec(&vec_str(&[
            "nl",
            "-nrz",
            "Cargo.toml"
        ])));

        // Safe `find` command (no unsafe options).
        assert!(is_safe_to_call_with_exec(&vec_str(&[
            "find", ".", "-name", "file.txt"
        ])));
    }

    #[test]
    fn zsh_lc_safe_command_sequence() {
        assert!(is_known_safe_command(&vec_str(&["zsh", "-lc", "ls"])));
    }

    #[test]
    fn unknown_or_partial() {
        assert!(!is_safe_to_call_with_exec(&vec_str(&["foo"])));
        assert!(!is_safe_to_call_with_exec(&vec_str(&["git", "fetch"])));
        assert!(!is_safe_to_call_with_exec(&vec_str(&[
            "sed", "-n", "xp", "file.txt"
        ])));

        // Unsafe `find` commands.
        for args in [
            vec_str(&["find", ".", "-name", "file.txt", "-exec", "rm", "{}", ";"]),
            vec_str(&[
                "find", ".", "-name", "*.py", "-execdir", "python3", "{}", ";",
            ]),
            vec_str(&["find", ".", "-name", "file.txt", "-ok", "rm", "{}", ";"]),
            vec_str(&["find", ".", "-name", "*.py", "-okdir", "python3", "{}", ";"]),
            vec_str(&["find", ".", "-delete", "-name", "file.txt"]),
            vec_str(&["find", ".", "-fls", "/etc/passwd"]),
            vec_str(&["find", ".", "-fprint", "/etc/passwd"]),
            vec_str(&["find", ".", "-fprint0", "/etc/passwd"]),
            vec_str(&["find", ".", "-fprintf", "/root/suid.txt", "%#m %u %p\n"]),
        ] {
            assert!(
                !is_safe_to_call_with_exec(&args),
                "expected {args:?} to be unsafe"
            );
        }
    }

    #[test]
    fn ripgrep_rules() {
        // Safe ripgrep invocations – none of the unsafe flags are present.
        assert!(is_safe_to_call_with_exec(&vec_str(&[
            "rg",
            "Cargo.toml",
            "-n"
        ])));

        // Unsafe flags that do not take an argument (present verbatim).
        for args in [
            vec_str(&["rg", "--search-zip", "files"]),
            vec_str(&["rg", "-z", "files"]),
        ] {
            assert!(
                !is_safe_to_call_with_exec(&args),
                "expected {args:?} to be considered unsafe due to zip-search flag",
            );
        }

        // Unsafe flags that expect a value, provided in both split and = forms.
        for args in [
            vec_str(&["rg", "--pre", "pwned", "files"]),
            vec_str(&["rg", "--pre=pwned", "files"]),
            vec_str(&["rg", "--hostname-bin", "pwned", "files"]),
            vec_str(&["rg", "--hostname-bin=pwned", "files"]),
        ] {
            assert!(
                !is_safe_to_call_with_exec(&args),
                "expected {args:?} to be considered unsafe due to external-command flag",
            );
        }
    }

    #[test]
    fn bash_lc_safe_examples() {
        assert!(is_known_safe_command(&vec_str(&["bash", "-lc", "ls"])));
        assert!(is_known_safe_command(&vec_str(&["bash", "-lc", "ls -1"])));
        assert!(is_known_safe_command(&vec_str(&[
            "bash",
            "-lc",
            "git status"
        ])));
        assert!(is_known_safe_command(&vec_str(&[
            "bash",
            "-lc",
            "grep -R \"Cargo.toml\" -n"
        ])));
        assert!(is_known_safe_command(&vec_str(&[
            "bash",
            "-lc",
            "sed -n 1,5p file.txt"
        ])));
        assert!(is_known_safe_command(&vec_str(&[
            "bash",
            "-lc",
            "sed -n '1,5p' file.txt"
        ])));

        assert!(is_known_safe_command(&vec_str(&[
            "bash",
            "-lc",
            "find . -name file.txt"
        ])));
    }

    #[test]
    fn bash_lc_safe_examples_with_operators() {
        assert!(is_known_safe_command(&vec_str(&[
            "bash",
            "-lc",
            "grep -R \"Cargo.toml\" -n || true"
        ])));
        assert!(is_known_safe_command(&vec_str(&[
            "bash",
            "-lc",
            "ls && pwd"
        ])));
        assert!(is_known_safe_command(&vec_str(&[
            "bash",
            "-lc",
            "echo 'hi' ; ls"
        ])));
        assert!(is_known_safe_command(&vec_str(&[
            "bash",
            "-lc",
            "ls | wc -l"
        ])));
    }

    #[test]
    fn bash_lc_unsafe_examples() {
        assert!(
            !is_known_safe_command(&vec_str(&["bash", "-lc", "git", "status"])),
            "Four arg version is not known to be safe."
        );
        assert!(
            !is_known_safe_command(&vec_str(&["bash", "-lc", "'git status'"])),
            "The extra quoting around 'git status' makes it a program named 'git status' and is therefore unsafe."
        );

        assert!(
            !is_known_safe_command(&vec_str(&["bash", "-lc", "find . -name file.txt -delete"])),
            "Unsafe find option should not be auto-approved."
        );

        // Disallowed because of unsafe command in sequence.
        assert!(
            !is_known_safe_command(&vec_str(&["bash", "-lc", "ls && rm -rf /"])),
            "Sequence containing unsafe command must be rejected"
        );

        // Disallowed because of parentheses / subshell.
        assert!(
            !is_known_safe_command(&vec_str(&["bash", "-lc", "(ls)"])),
            "Parentheses (subshell) are not provably safe with the current parser"
        );
        assert!(
            !is_known_safe_command(&vec_str(&["bash", "-lc", "ls || (pwd && echo hi)"])),
            "Nested parentheses are not provably safe with the current parser"
        );

        // Disallowed redirection.
        assert!(
            !is_known_safe_command(&vec_str(&["bash", "-lc", "ls > out.txt"])),
            "> redirection should be rejected"
        );
    }
}
4
llmx-rs/core/src/command_safety/mod.rs
Normal file
@@ -0,0 +1,4 @@
pub mod is_dangerous_command;
pub mod is_safe_command;
#[cfg(target_os = "windows")]
pub mod windows_safe_commands;
431
llmx-rs/core/src/command_safety/windows_safe_commands.rs
Normal file
@@ -0,0 +1,431 @@
use shlex::split as shlex_split;

/// On Windows, we conservatively allow only clearly read-only PowerShell invocations
/// that match a small safelist. Anything else (including direct CMD commands) is unsafe.
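///
/// A minimal usage sketch (it mirrors the tests below rather than any external
/// contract): `is_safe_command_windows(&["pwsh".into(), "-Command".into(),
/// "Get-ChildItem".into()])` returns `true`, while any invocation that reaches
/// a side-effecting cmdlet such as `Remove-Item` returns `false`.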
pub fn is_safe_command_windows(command: &[String]) -> bool {
    if let Some(commands) = try_parse_powershell_command_sequence(command) {
        return commands
            .iter()
            .all(|cmd| is_safe_powershell_command(cmd.as_slice()));
    }
    // Only PowerShell invocations are allowed on Windows for now; anything else is unsafe.
    false
}

/// Returns each command sequence if the invocation starts with a PowerShell binary.
/// For example, the tokens from `pwsh Get-ChildItem | Measure-Object` become two sequences.
fn try_parse_powershell_command_sequence(command: &[String]) -> Option<Vec<Vec<String>>> {
    let (exe, rest) = command.split_first()?;
    if !is_powershell_executable(exe) {
        return None;
    }
    parse_powershell_invocation(rest)
}

/// Parses a PowerShell invocation into discrete command vectors, rejecting unsafe patterns.
fn parse_powershell_invocation(args: &[String]) -> Option<Vec<Vec<String>>> {
    if args.is_empty() {
        // Examples rejected here: "pwsh" and "powershell.exe" with no additional arguments.
        return None;
    }

    let mut idx = 0;
    while idx < args.len() {
        let arg = &args[idx];
        let lower = arg.to_ascii_lowercase();
        match lower.as_str() {
            "-command" | "/command" | "-c" => {
                let script = args.get(idx + 1)?;
                if idx + 2 != args.len() {
                    // Reject if there is more than one token representing the actual command.
                    // Examples rejected here: "pwsh -Command foo bar" and "powershell -c ls extra".
                    return None;
                }
                return parse_powershell_script(script);
            }
            _ if lower.starts_with("-command:") || lower.starts_with("/command:") => {
                if idx + 1 != args.len() {
                    // Reject if there are more tokens after the command itself.
                    // Examples rejected here: "pwsh -Command:dir C:\\" and "powershell /Command:dir C:\\" with trailing args.
                    return None;
                }
                let script = arg.split_once(':')?.1;
                return parse_powershell_script(script);
            }

            // Benign, no-arg flags we tolerate.
            "-nologo" | "-noprofile" | "-noninteractive" | "-mta" | "-sta" => {
                idx += 1;
                continue;
            }

            // Explicitly forbidden/opaque or unnecessary for read-only operations.
            "-encodedcommand" | "-ec" | "-file" | "/file" | "-windowstyle" | "-executionpolicy"
            | "-workingdirectory" => {
                // Examples rejected here: "pwsh -EncodedCommand ..." and "powershell -File script.ps1".
                return None;
            }

            // Unknown switch → bail conservatively.
            _ if lower.starts_with('-') => {
                // Examples rejected here: "pwsh -UnknownFlag" and "powershell -foo bar".
                return None;
            }

            // If we hit non-flag tokens, treat the remainder as a command sequence.
            // This happens if powershell is invoked without -Command, e.g.
            // ["pwsh", "-NoLogo", "git", "-c", "core.pager=cat", "status"]
            _ => {
                return split_into_commands(args[idx..].to_vec());
            }
        }
    }

    // Examples rejected here: "pwsh" and "powershell.exe -NoLogo" without a script.
    None
}

/// Tokenizes an inline PowerShell script and delegates to the command splitter.
/// Examples of when this is called: `pwsh.exe -Command '<script>'` or `pwsh.exe -Command:<script>`.
fn parse_powershell_script(script: &str) -> Option<Vec<Vec<String>>> {
    let tokens = shlex_split(script)?;
    split_into_commands(tokens)
}

/// Splits tokens into pipeline segments while ensuring no unsafe separators slip through.
/// e.g. `Get-ChildItem | Measure-Object` -> `[["Get-ChildItem"], ["Measure-Object"]]`
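/// A leading, trailing, or doubled separator (e.g. "| Get-ChildItem" or
/// "dir |") yields `None`, as exercised by the tests below.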
fn split_into_commands(tokens: Vec<String>) -> Option<Vec<Vec<String>>> {
    if tokens.is_empty() {
        // Examples rejected here: "pwsh -Command ''" and "powershell -Command \"\"".
        return None;
    }

    let mut commands = Vec::new();
    let mut current = Vec::new();
    for token in tokens.into_iter() {
        match token.as_str() {
            "|" | "||" | "&&" | ";" => {
                if current.is_empty() {
                    // Examples rejected here: "pwsh -Command '| Get-ChildItem'" and "pwsh -Command '; dir'".
                    return None;
                }
                commands.push(current);
                current = Vec::new();
            }
            // Reject if any token embeds separators, redirection, or call operator characters.
            _ if token.contains(['|', ';', '>', '<', '&']) || token.contains("$(") => {
                // Examples rejected here: "pwsh -Command 'dir|select'" and "pwsh -Command 'echo hi > out.txt'".
                return None;
            }
            _ => current.push(token),
        }
    }

    if current.is_empty() {
        // Examples rejected here: "pwsh -Command 'dir |'" and "pwsh -Command 'Get-ChildItem ;'".
        return None;
    }
    commands.push(current);
    Some(commands)
}

/// Returns true when the executable name is one of the supported PowerShell binaries.
fn is_powershell_executable(exe: &str) -> bool {
    matches!(
        exe.to_ascii_lowercase().as_str(),
        "powershell" | "powershell.exe" | "pwsh" | "pwsh.exe"
    )
}

/// Validates that a parsed PowerShell command stays within our read-only safelist.
/// Everything before this point is parsing and conservative rejection; this is
/// the final allow/deny decision.
fn is_safe_powershell_command(words: &[String]) -> bool {
    if words.is_empty() {
        // Examples rejected here: "pwsh -Command ''" and "pwsh -Command \"\"".
        return false;
    }

    // Reject nested unsafe cmdlets inside parentheses or arguments.
    for w in words.iter() {
        let inner = w
            .trim_matches(|c| c == '(' || c == ')')
            .trim_start_matches('-')
            .to_ascii_lowercase();
        if matches!(
            inner.as_str(),
            "set-content"
                | "add-content"
                | "out-file"
                | "new-item"
                | "remove-item"
                | "move-item"
                | "copy-item"
                | "rename-item"
                | "start-process"
                | "stop-process"
        ) {
            // Examples rejected here: "Write-Output (Set-Content foo6.txt 'abc')" and "Get-Content (New-Item bar.txt)".
            return false;
        }
    }

    // Block PowerShell call operator or any redirection explicitly.
    if words.iter().any(|w| {
        matches!(
            w.as_str(),
            "&" | ">" | ">>" | "1>" | "2>" | "2>&1" | "*>" | "<" | "<<"
        )
    }) {
        // Examples rejected here: "pwsh -Command '& Remove-Item foo'" and "pwsh -Command 'Get-Content foo > bar'".
        return false;
    }

    let command = words[0]
        .trim_matches(|c| c == '(' || c == ')')
        .trim_start_matches('-')
        .to_ascii_lowercase();
    match command.as_str() {
        "echo" | "write-output" | "write-host" => true, // (no redirection allowed)
        "dir" | "ls" | "get-childitem" | "gci" => true,
        "cat" | "type" | "gc" | "get-content" => true,
        "select-string" | "sls" | "findstr" => true,
        "measure-object" | "measure" => true,
        "get-location" | "gl" | "pwd" => true,
        "test-path" | "tp" => true,
        "resolve-path" | "rvpa" => true,
        "select-object" | "select" => true,
        "get-item" => true,

        "git" => is_safe_git_command(words),

        "rg" => is_safe_ripgrep(words),

        // Extra safety: explicitly prohibit common side-effecting cmdlets regardless of args.
        "set-content" | "add-content" | "out-file" | "new-item" | "remove-item" | "move-item"
        | "copy-item" | "rename-item" | "start-process" | "stop-process" => {
            // Examples rejected here: "pwsh -Command 'Set-Content notes.txt data'" and "pwsh -Command 'Remove-Item temp.log'".
            false
        }

        _ => {
            // Examples rejected here: "pwsh -Command 'Invoke-WebRequest https://example.com'" and "pwsh -Command 'Start-Service Spooler'".
            false
        }
    }
}

/// Checks that an `rg` invocation avoids options that can spawn arbitrary executables.
fn is_safe_ripgrep(words: &[String]) -> bool {
    const UNSAFE_RIPGREP_OPTIONS_WITH_ARGS: &[&str] = &["--pre", "--hostname-bin"];
    const UNSAFE_RIPGREP_OPTIONS_WITHOUT_ARGS: &[&str] = &["--search-zip", "-z"];

    !words.iter().skip(1).any(|arg| {
        let arg_lc = arg.to_ascii_lowercase();
        // Examples rejected here: "pwsh -Command 'rg --pre cat pattern'" and "pwsh -Command 'rg --search-zip pattern'".
        UNSAFE_RIPGREP_OPTIONS_WITHOUT_ARGS.contains(&arg_lc.as_str())
            || UNSAFE_RIPGREP_OPTIONS_WITH_ARGS
                .iter()
                .any(|opt| arg_lc == *opt || arg_lc.starts_with(&format!("{opt}=")))
    })
}

/// Ensures a Git command sticks to whitelisted read-only subcommands and flags.
fn is_safe_git_command(words: &[String]) -> bool {
    const SAFE_SUBCOMMANDS: &[&str] = &["status", "log", "show", "diff", "cat-file"];

    let mut iter = words.iter().skip(1);
    while let Some(arg) = iter.next() {
        let arg_lc = arg.to_ascii_lowercase();

        if arg.starts_with('-') {
            if arg.eq_ignore_ascii_case("-c") || arg.eq_ignore_ascii_case("--config") {
                if iter.next().is_none() {
                    // Examples rejected here: "pwsh -Command 'git -c'" and "pwsh -Command 'git --config'".
                    return false;
                }
                continue;
            }

            if arg_lc.starts_with("-c=")
                || arg_lc.starts_with("--config=")
                || arg_lc.starts_with("--git-dir=")
                || arg_lc.starts_with("--work-tree=")
            {
                continue;
            }

            if arg.eq_ignore_ascii_case("--git-dir") || arg.eq_ignore_ascii_case("--work-tree") {
                if iter.next().is_none() {
                    // Examples rejected here: "pwsh -Command 'git --git-dir'" and "pwsh -Command 'git --work-tree'".
                    return false;
                }
                continue;
            }

            continue;
        }

        return SAFE_SUBCOMMANDS.contains(&arg_lc.as_str());
    }

    // Examples rejected here: "pwsh -Command 'git'" and "pwsh -Command 'git status --short | Remove-Item foo'".
    false
}

#[cfg(test)]
mod tests {
    use super::is_safe_command_windows;
    use std::string::ToString;

    /// Converts a slice of string literals into owned `String`s for the tests.
    fn vec_str(args: &[&str]) -> Vec<String> {
        args.iter().map(ToString::to_string).collect()
    }

    #[test]
    fn recognizes_safe_powershell_wrappers() {
        assert!(is_safe_command_windows(&vec_str(&[
            "powershell.exe",
            "-NoLogo",
            "-Command",
            "Get-ChildItem -Path .",
        ])));

        assert!(is_safe_command_windows(&vec_str(&[
            "powershell.exe",
            "-NoProfile",
            "-Command",
            "git status",
        ])));

        assert!(is_safe_command_windows(&vec_str(&[
            "powershell.exe",
            "Get-Content",
            "Cargo.toml",
        ])));

        // pwsh parity
        assert!(is_safe_command_windows(&vec_str(&[
            "pwsh.exe",
            "-NoProfile",
            "-Command",
            "Get-ChildItem",
        ])));
    }

    #[test]
    fn allows_read_only_pipelines_and_git_usage() {
        assert!(is_safe_command_windows(&vec_str(&[
            "pwsh",
            "-NoLogo",
            "-NoProfile",
            "-Command",
            "rg --files-with-matches foo | Measure-Object | Select-Object -ExpandProperty Count",
        ])));

        assert!(is_safe_command_windows(&vec_str(&[
            "pwsh",
            "-NoLogo",
            "-NoProfile",
            "-Command",
            "Get-Content foo.rs | Select-Object -Skip 200",
        ])));

        assert!(is_safe_command_windows(&vec_str(&[
            "pwsh",
            "-NoLogo",
            "-NoProfile",
            "-Command",
            "git -c core.pager=cat show HEAD:foo.rs",
        ])));

        assert!(is_safe_command_windows(&vec_str(&[
            "pwsh",
            "-Command",
            "-git cat-file -p HEAD:foo.rs",
        ])));

        assert!(is_safe_command_windows(&vec_str(&[
            "pwsh",
            "-Command",
            "(Get-Content foo.rs -Raw)",
        ])));

        assert!(is_safe_command_windows(&vec_str(&[
            "pwsh",
            "-Command",
            "Get-Item foo.rs | Select-Object Length",
        ])));
    }

    #[test]
    fn rejects_powershell_commands_with_side_effects() {
        assert!(!is_safe_command_windows(&vec_str(&[
            "powershell.exe",
            "-NoLogo",
            "-Command",
            "Remove-Item foo.txt",
        ])));

        assert!(!is_safe_command_windows(&vec_str(&[
            "powershell.exe",
            "-NoProfile",
            "-Command",
            "rg --pre cat",
        ])));

        assert!(!is_safe_command_windows(&vec_str(&[
            "powershell.exe",
            "-Command",
            "Set-Content foo.txt 'hello'",
        ])));

        // Redirections are blocked
        assert!(!is_safe_command_windows(&vec_str(&[
            "powershell.exe",
            "-Command",
            "echo hi > out.txt",
        ])));
        assert!(!is_safe_command_windows(&vec_str(&[
            "powershell.exe",
            "-Command",
            "Get-Content x | Out-File y",
        ])));
        assert!(!is_safe_command_windows(&vec_str(&[
            "powershell.exe",
            "-Command",
            "Write-Output foo 2> err.txt",
        ])));

        // Call operator is blocked
        assert!(!is_safe_command_windows(&vec_str(&[
            "powershell.exe",
            "-Command",
            "& Remove-Item foo",
        ])));

        // Chained safe + unsafe must fail
        assert!(!is_safe_command_windows(&vec_str(&[
            "powershell.exe",
            "-Command",
            "Get-ChildItem; Remove-Item foo",
        ])));
        // Nested unsafe cmdlet inside safe command must fail
        assert!(!is_safe_command_windows(&vec_str(&[
            "powershell.exe",
            "-Command",
            "Write-Output (Set-Content foo6.txt 'abc')",
        ])));
        // Additional nested unsafe cmdlet examples must fail
        assert!(!is_safe_command_windows(&vec_str(&[
            "powershell.exe",
            "-Command",
            "Write-Host (Remove-Item foo.txt)",
        ])));
        assert!(!is_safe_command_windows(&vec_str(&[
            "powershell.exe",
            "-Command",
            "Get-Content (New-Item bar.txt)",
        ])));
    }
}
451
llmx-rs/core/src/compact.rs
Normal file
@@ -0,0 +1,451 @@
use std::sync::Arc;

use crate::Prompt;
use crate::client_common::ResponseEvent;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::codex::get_last_assistant_message_from_turn;
use crate::error::CodexErr;
use crate::error::Result as CodexResult;
use crate::protocol::AgentMessageEvent;
use crate::protocol::CompactedItem;
use crate::protocol::ErrorEvent;
use crate::protocol::EventMsg;
use crate::protocol::TaskStartedEvent;
use crate::protocol::TurnContextItem;
use crate::protocol::WarningEvent;
use crate::truncate::truncate_middle;
use crate::util::backoff;
use codex_protocol::items::TurnItem;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::user_input::UserInput;
use futures::prelude::*;
use tracing::error;

pub const SUMMARIZATION_PROMPT: &str = include_str!("../templates/compact/prompt.md");
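// Token budget for user messages carried into the compacted history; converted
// to an approximate byte budget in `build_compacted_history` (assuming roughly
// 4 bytes per token).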
const COMPACT_USER_MESSAGE_MAX_TOKENS: usize = 20_000;

pub(crate) async fn run_inline_auto_compact_task(
    sess: Arc<Session>,
    turn_context: Arc<TurnContext>,
) {
    let prompt = turn_context.compact_prompt().to_string();
    let input = vec![UserInput::Text { text: prompt }];
    run_compact_task_inner(sess, turn_context, input).await;
}

pub(crate) async fn run_compact_task(
    sess: Arc<Session>,
    turn_context: Arc<TurnContext>,
    input: Vec<UserInput>,
) -> Option<String> {
    let start_event = EventMsg::TaskStarted(TaskStartedEvent {
        model_context_window: turn_context.client.get_model_context_window(),
    });
    sess.send_event(&turn_context, start_event).await;
    run_compact_task_inner(sess.clone(), turn_context, input).await;
    None
}

async fn run_compact_task_inner(
    sess: Arc<Session>,
    turn_context: Arc<TurnContext>,
    input: Vec<UserInput>,
) {
    let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);

    let mut history = sess.clone_history().await;
    history.record_items(&[initial_input_for_turn.into()]);

    let mut truncated_count = 0usize;

    let max_retries = turn_context.client.get_provider().stream_max_retries();
    let mut retries = 0;

    let rollout_item = RolloutItem::TurnContext(TurnContextItem {
        cwd: turn_context.cwd.clone(),
        approval_policy: turn_context.approval_policy,
        sandbox_policy: turn_context.sandbox_policy.clone(),
        model: turn_context.client.get_model(),
        effort: turn_context.client.get_reasoning_effort(),
        summary: turn_context.client.get_reasoning_summary(),
    });
    sess.persist_rollout_items(&[rollout_item]).await;

    loop {
        let turn_input = history.get_history_for_prompt();
        let prompt = Prompt {
            input: turn_input.clone(),
            ..Default::default()
        };
        let attempt_result = drain_to_completed(&sess, turn_context.as_ref(), &prompt).await;

        match attempt_result {
            Ok(()) => {
                if truncated_count > 0 {
                    sess.notify_background_event(
                        turn_context.as_ref(),
                        format!(
                            "Trimmed {truncated_count} older conversation item(s) before compacting so the prompt fits the model context window."
                        ),
                    )
                    .await;
                }
                break;
            }
            Err(CodexErr::Interrupted) => {
                return;
            }
            Err(e @ CodexErr::ContextWindowExceeded) => {
                if turn_input.len() > 1 {
                    // Trim from the beginning to preserve cache (prefix-based) and keep recent messages intact.
                    error!(
                        "Context window exceeded while compacting; removing oldest history item. Error: {e}"
                    );
                    history.remove_first_item();
                    truncated_count += 1;
                    retries = 0;
                    continue;
                }
                sess.set_total_tokens_full(turn_context.as_ref()).await;
                let event = EventMsg::Error(ErrorEvent {
                    message: e.to_string(),
                });
                sess.send_event(&turn_context, event).await;
                return;
            }
            Err(e) => {
                if retries < max_retries {
                    retries += 1;
                    let delay = backoff(retries);
                    sess.notify_stream_error(
                        turn_context.as_ref(),
                        format!("Reconnecting... {retries}/{max_retries}"),
                    )
                    .await;
                    tokio::time::sleep(delay).await;
                    continue;
                } else {
                    let event = EventMsg::Error(ErrorEvent {
                        message: e.to_string(),
                    });
                    sess.send_event(&turn_context, event).await;
                    return;
                }
            }
        }
    }

    let history_snapshot = sess.clone_history().await.get_history();
    let summary_text = get_last_assistant_message_from_turn(&history_snapshot).unwrap_or_default();
    let user_messages = collect_user_messages(&history_snapshot);

    let initial_context = sess.build_initial_context(turn_context.as_ref());
    let mut new_history = build_compacted_history(initial_context, &user_messages, &summary_text);
    let ghost_snapshots: Vec<ResponseItem> = history_snapshot
        .iter()
        .filter(|item| matches!(item, ResponseItem::GhostSnapshot { .. }))
        .cloned()
        .collect();
    new_history.extend(ghost_snapshots);
    sess.replace_history(new_history).await;

    let rollout_item = RolloutItem::Compacted(CompactedItem {
        message: summary_text.clone(),
    });
    sess.persist_rollout_items(&[rollout_item]).await;

    let event = EventMsg::AgentMessage(AgentMessageEvent {
        message: "Compact task completed".to_string(),
    });
    sess.send_event(&turn_context, event).await;

    let warning = EventMsg::Warning(WarningEvent {
        message: "Heads up: Long conversations and multiple compactions can cause the model to be less accurate. Start a new conversation when possible to keep conversations small and targeted.".to_string(),
    });
    sess.send_event(&turn_context, warning).await;
}

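/// Joins the non-empty text segments of a content list with newlines, or
/// returns `None` when no text is present (e.g. image-only content).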
pub fn content_items_to_text(content: &[ContentItem]) -> Option<String> {
    let mut pieces = Vec::new();
    for item in content {
        match item {
            ContentItem::InputText { text } | ContentItem::OutputText { text } => {
                if !text.is_empty() {
                    pieces.push(text.as_str());
                }
            }
            ContentItem::InputImage { .. } => {}
        }
    }
    if pieces.is_empty() {
        None
    } else {
        Some(pieces.join("\n"))
    }
}

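/// Extracts the plain-text user messages from history items, skipping
/// assistant output and anything that does not parse as a user turn item.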
pub(crate) fn collect_user_messages(items: &[ResponseItem]) -> Vec<String> {
    items
        .iter()
        .filter_map(|item| match crate::event_mapping::parse_turn_item(item) {
            Some(TurnItem::UserMessage(user)) => Some(user.message()),
            _ => None,
        })
        .collect()
}

pub(crate) fn build_compacted_history(
    initial_context: Vec<ResponseItem>,
    user_messages: &[String],
    summary_text: &str,
) -> Vec<ResponseItem> {
    build_compacted_history_with_limit(
        initial_context,
        user_messages,
        summary_text,
        // ~4 bytes per token (heuristic).
        COMPACT_USER_MESSAGE_MAX_TOKENS * 4,
    )
}

fn build_compacted_history_with_limit(
    mut history: Vec<ResponseItem>,
    user_messages: &[String],
    summary_text: &str,
    max_bytes: usize,
) -> Vec<ResponseItem> {
    let mut selected_messages: Vec<String> = Vec::new();
    if max_bytes > 0 {
        let mut remaining = max_bytes;
        // Walk user messages from newest to oldest, keeping as many as fit.
        for message in user_messages.iter().rev() {
            if remaining == 0 {
                break;
            }
            if message.len() <= remaining {
                selected_messages.push(message.clone());
                remaining = remaining.saturating_sub(message.len());
            } else {
                let (truncated, _) = truncate_middle(message, remaining);
                selected_messages.push(truncated);
                break;
            }
        }
        selected_messages.reverse();
    }

    for message in &selected_messages {
        history.push(ResponseItem::Message {
            id: None,
            role: "user".to_string(),
            content: vec![ContentItem::InputText {
                text: message.clone(),
            }],
        });
    }

    let summary_text = if summary_text.is_empty() {
        "(no summary available)".to_string()
    } else {
        summary_text.to_string()
    };

    history.push(ResponseItem::Message {
        id: None,
        role: "user".to_string(),
        content: vec![ContentItem::InputText { text: summary_text }],
    });

    history
}

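/// Streams a single model turn to completion, recording each completed output
/// item into session history and updating rate-limit and token-usage
/// accounting along the way.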
async fn drain_to_completed(
    sess: &Session,
    turn_context: &TurnContext,
    prompt: &Prompt,
) -> CodexResult<()> {
    let mut stream = turn_context.client.clone().stream(prompt).await?;
    loop {
        let maybe_event = stream.next().await;
        let Some(event) = maybe_event else {
            return Err(CodexErr::Stream(
                "stream closed before response.completed".into(),
                None,
            ));
        };
        match event {
            Ok(ResponseEvent::OutputItemDone(item)) => {
                sess.record_into_history(std::slice::from_ref(&item)).await;
            }
            Ok(ResponseEvent::RateLimits(snapshot)) => {
                sess.update_rate_limits(turn_context, snapshot).await;
            }
            Ok(ResponseEvent::Completed { token_usage, .. }) => {
                sess.update_token_usage_info(turn_context, token_usage.as_ref())
                    .await;
                return Ok(());
            }
            Ok(_) => continue,
            Err(e) => return Err(e),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn content_items_to_text_joins_non_empty_segments() {
        let items = vec![
            ContentItem::InputText {
                text: "hello".to_string(),
            },
            ContentItem::OutputText {
                text: String::new(),
            },
            ContentItem::OutputText {
                text: "world".to_string(),
            },
        ];

        let joined = content_items_to_text(&items);

        assert_eq!(Some("hello\nworld".to_string()), joined);
    }

    #[test]
    fn content_items_to_text_ignores_image_only_content() {
        let items = vec![ContentItem::InputImage {
            image_url: "file://image.png".to_string(),
        }];

        let joined = content_items_to_text(&items);

        assert_eq!(None, joined);
    }

    #[test]
    fn collect_user_messages_extracts_user_text_only() {
        let items = vec![
            ResponseItem::Message {
                id: Some("assistant".to_string()),
                role: "assistant".to_string(),
                content: vec![ContentItem::OutputText {
                    text: "ignored".to_string(),
                }],
            },
            ResponseItem::Message {
                id: Some("user".to_string()),
                role: "user".to_string(),
                content: vec![ContentItem::InputText {
                    text: "first".to_string(),
                }],
            },
            ResponseItem::Other,
        ];

        let collected = collect_user_messages(&items);

        assert_eq!(vec!["first".to_string()], collected);
    }

    #[test]
    fn collect_user_messages_filters_session_prefix_entries() {
        let items = vec![
            ResponseItem::Message {
                id: None,
                role: "user".to_string(),
                content: vec![ContentItem::InputText {
                    text: "# AGENTS.md instructions for project\n\n<INSTRUCTIONS>\ndo things\n</INSTRUCTIONS>"
                        .to_string(),
                }],
            },
            ResponseItem::Message {
                id: None,
                role: "user".to_string(),
                content: vec![ContentItem::InputText {
                    text: "<ENVIRONMENT_CONTEXT>cwd=/tmp</ENVIRONMENT_CONTEXT>".to_string(),
                }],
            },
            ResponseItem::Message {
                id: None,
                role: "user".to_string(),
                content: vec![ContentItem::InputText {
                    text: "real user message".to_string(),
                }],
            },
        ];

        let collected = collect_user_messages(&items);

        assert_eq!(vec!["real user message".to_string()], collected);
    }

    #[test]
    fn build_compacted_history_truncates_overlong_user_messages() {
        // Use a small truncation limit so the test remains fast while still validating
        // that oversized user content is truncated.
        let max_bytes = 128;
        let big = "X".repeat(max_bytes + 50);
        let history = super::build_compacted_history_with_limit(
            Vec::new(),
            std::slice::from_ref(&big),
            "SUMMARY",
            max_bytes,
        );
        assert_eq!(history.len(), 2);

        let truncated_message = &history[0];
        let summary_message = &history[1];

        let truncated_text = match truncated_message {
            ResponseItem::Message { role, content, .. } if role == "user" => {
                content_items_to_text(content).unwrap_or_default()
            }
            other => panic!("unexpected item in history: {other:?}"),
        };

        assert!(
            truncated_text.contains("tokens truncated"),
            "expected truncation marker in truncated user message"
        );
        assert!(
            !truncated_text.contains(&big),
            "truncated user message should not include the full oversized user text"
        );

        let summary_text = match summary_message {
            ResponseItem::Message { role, content, .. } if role == "user" => {
                content_items_to_text(content).unwrap_or_default()
            }
            other => panic!("unexpected item in history: {other:?}"),
        };
        assert_eq!(summary_text, "SUMMARY");
    }

    #[test]
    fn build_compacted_history_appends_summary_message() {
        let initial_context: Vec<ResponseItem> = Vec::new();
        let user_messages = vec!["first user message".to_string()];
        let summary_text = "summary text";

        let history = build_compacted_history(initial_context, &user_messages, summary_text);
        assert!(
            !history.is_empty(),
            "expected compacted history to include summary"
        );

        let last = history.last().expect("history should have a summary entry");
        let summary = match last {
            ResponseItem::Message { role, content, .. } if role == "user" => {
                content_items_to_text(content).unwrap_or_default()
            }
            other => panic!("expected summary message, found {other:?}"),
        };
        assert_eq!(summary, summary_text);
    }
}
1008
llmx-rs/core/src/config/edit.rs
Normal file
File diff suppressed because it is too large
3298
llmx-rs/core/src/config/mod.rs
Normal file
File diff suppressed because it is too large
50
llmx-rs/core/src/config/profile.rs
Normal file
@@ -0,0 +1,50 @@
use serde::Deserialize;
use std::path::PathBuf;

use crate::protocol::AskForApproval;
use codex_protocol::config_types::ReasoningEffort;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::config_types::SandboxMode;
use codex_protocol::config_types::Verbosity;

/// Collection of common configuration options that a user can define as a unit
/// in `config.toml`.
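///
/// A hedged illustration (the profile table name and values are invented for
/// this sketch; the keys come from the fields below):
///
/// ```toml
/// [profiles.fast]
/// model = "o3"
/// approval_policy = "on-request"
/// ```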
#[derive(Debug, Clone, Default, PartialEq, Deserialize)]
pub struct ConfigProfile {
    pub model: Option<String>,
    /// The key in the `model_providers` map identifying the
    /// [`ModelProviderInfo`] to use.
    pub model_provider: Option<String>,
    pub approval_policy: Option<AskForApproval>,
    pub sandbox_mode: Option<SandboxMode>,
    pub model_reasoning_effort: Option<ReasoningEffort>,
    pub model_reasoning_summary: Option<ReasoningSummary>,
    pub model_verbosity: Option<Verbosity>,
    pub chatgpt_base_url: Option<String>,
    pub experimental_instructions_file: Option<PathBuf>,
    pub experimental_compact_prompt_file: Option<PathBuf>,
    pub include_apply_patch_tool: Option<bool>,
    pub experimental_use_unified_exec_tool: Option<bool>,
    pub experimental_use_rmcp_client: Option<bool>,
    pub experimental_use_freeform_apply_patch: Option<bool>,
    pub experimental_sandbox_command_assessment: Option<bool>,
    pub tools_web_search: Option<bool>,
    pub tools_view_image: Option<bool>,
    /// Optional feature toggles scoped to this profile.
    #[serde(default)]
    pub features: Option<crate::features::FeaturesToml>,
}

impl From<ConfigProfile> for codex_app_server_protocol::Profile {
    fn from(config_profile: ConfigProfile) -> Self {
        Self {
            model: config_profile.model,
            model_provider: config_profile.model_provider,
            approval_policy: config_profile.approval_policy,
            model_reasoning_effort: config_profile.model_reasoning_effort,
            model_reasoning_summary: config_profile.model_reasoning_summary,
            model_verbosity: config_profile.model_verbosity,
            chatgpt_base_url: config_profile.chatgpt_base_url,
        }
    }
}
771
llmx-rs/core/src/config/types.rs
Normal file
@@ -0,0 +1,771 @@
//! Types used to define the fields of [`crate::config::Config`].

// Note this file should generally be restricted to simple struct/enum
// definitions that do not contain business logic.

use serde::Deserializer;
use std::collections::HashMap;
use std::path::PathBuf;
use std::time::Duration;
use wildmatch::WildMatchPattern;

use serde::Deserialize;
use serde::Serialize;
use serde::de::Error as SerdeError;

pub const DEFAULT_OTEL_ENVIRONMENT: &str = "dev";

#[derive(Serialize, Debug, Clone, PartialEq)]
pub struct McpServerConfig {
    #[serde(flatten)]
    pub transport: McpServerTransportConfig,

    /// When `false`, Codex skips initializing this MCP server.
    #[serde(default = "default_enabled")]
    pub enabled: bool,

    /// Startup timeout in seconds for initializing the MCP server and initially listing tools.
    #[serde(
        default,
        with = "option_duration_secs",
        skip_serializing_if = "Option::is_none"
    )]
    pub startup_timeout_sec: Option<Duration>,

    /// Default timeout for MCP tool calls initiated via this server.
    #[serde(default, with = "option_duration_secs")]
    pub tool_timeout_sec: Option<Duration>,

    /// Explicit allow-list of tools exposed from this server. When set, only these tools will be registered.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub enabled_tools: Option<Vec<String>>,

    /// Explicit deny-list of tools. These tools will be removed after applying `enabled_tools`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub disabled_tools: Option<Vec<String>>,
}

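// Hand-rolled deserializer: the transport is inferred from which keys are
// present (`command` selects stdio, `url` selects streamable HTTP), and any
// field belonging to the other transport is rejected with an error naming the
// offending field.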
impl<'de> Deserialize<'de> for McpServerConfig {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        #[derive(Deserialize, Clone)]
        struct RawMcpServerConfig {
            // stdio
            command: Option<String>,
            #[serde(default)]
            args: Option<Vec<String>>,
            #[serde(default)]
            env: Option<HashMap<String, String>>,
            #[serde(default)]
            env_vars: Option<Vec<String>>,
            #[serde(default)]
            cwd: Option<PathBuf>,
            http_headers: Option<HashMap<String, String>>,
            #[serde(default)]
            env_http_headers: Option<HashMap<String, String>>,

            // streamable_http
            url: Option<String>,
            bearer_token: Option<String>,
            bearer_token_env_var: Option<String>,

            // shared
            #[serde(default)]
            startup_timeout_sec: Option<f64>,
            #[serde(default)]
            startup_timeout_ms: Option<u64>,
            #[serde(default, with = "option_duration_secs")]
            tool_timeout_sec: Option<Duration>,
            #[serde(default)]
            enabled: Option<bool>,
            #[serde(default)]
            enabled_tools: Option<Vec<String>>,
            #[serde(default)]
            disabled_tools: Option<Vec<String>>,
        }

        let mut raw = RawMcpServerConfig::deserialize(deserializer)?;

        let startup_timeout_sec = match (raw.startup_timeout_sec, raw.startup_timeout_ms) {
            (Some(sec), _) => {
                let duration = Duration::try_from_secs_f64(sec).map_err(SerdeError::custom)?;
                Some(duration)
            }
            (None, Some(ms)) => Some(Duration::from_millis(ms)),
            (None, None) => None,
        };
        let tool_timeout_sec = raw.tool_timeout_sec;
        let enabled = raw.enabled.unwrap_or_else(default_enabled);
        let enabled_tools = raw.enabled_tools.clone();
        let disabled_tools = raw.disabled_tools.clone();

        fn throw_if_set<E, T>(transport: &str, field: &str, value: Option<&T>) -> Result<(), E>
        where
            E: SerdeError,
        {
            if value.is_none() {
                return Ok(());
            }
            Err(E::custom(format!(
                "{field} is not supported for {transport}",
            )))
        }

        let transport = if let Some(command) = raw.command.clone() {
            throw_if_set("stdio", "url", raw.url.as_ref())?;
            throw_if_set(
                "stdio",
                "bearer_token_env_var",
                raw.bearer_token_env_var.as_ref(),
            )?;
            throw_if_set("stdio", "bearer_token", raw.bearer_token.as_ref())?;
            throw_if_set("stdio", "http_headers", raw.http_headers.as_ref())?;
            throw_if_set("stdio", "env_http_headers", raw.env_http_headers.as_ref())?;
            McpServerTransportConfig::Stdio {
                command,
                args: raw.args.clone().unwrap_or_default(),
                env: raw.env.clone(),
                env_vars: raw.env_vars.clone().unwrap_or_default(),
                cwd: raw.cwd.take(),
            }
        } else if let Some(url) = raw.url.clone() {
            throw_if_set("streamable_http", "args", raw.args.as_ref())?;
            throw_if_set("streamable_http", "env", raw.env.as_ref())?;
            throw_if_set("streamable_http", "env_vars", raw.env_vars.as_ref())?;
            throw_if_set("streamable_http", "cwd", raw.cwd.as_ref())?;
            throw_if_set("streamable_http", "bearer_token", raw.bearer_token.as_ref())?;
            McpServerTransportConfig::StreamableHttp {
                url,
                bearer_token_env_var: raw.bearer_token_env_var.clone(),
                http_headers: raw.http_headers.clone(),
                env_http_headers: raw.env_http_headers.take(),
            }
        } else {
            return Err(SerdeError::custom("invalid transport"));
        };

        Ok(Self {
            transport,
            startup_timeout_sec,
            tool_timeout_sec,
            enabled,
            enabled_tools,
            disabled_tools,
        })
    }
}

const fn default_enabled() -> bool {
    true
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(untagged, deny_unknown_fields, rename_all = "snake_case")]
pub enum McpServerTransportConfig {
    /// <https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#stdio>
    Stdio {
        command: String,
        #[serde(default)]
        args: Vec<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        env: Option<HashMap<String, String>>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        env_vars: Vec<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        cwd: Option<PathBuf>,
    },
    /// <https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http>
    StreamableHttp {
        url: String,
        /// Name of the environment variable to read for an HTTP bearer token.
        /// When set, requests will include the token via `Authorization: Bearer <token>`.
        /// The actual secret value must be provided via the environment.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        bearer_token_env_var: Option<String>,
        /// Additional HTTP headers to include in requests to this server.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        http_headers: Option<HashMap<String, String>>,
        /// HTTP headers where the value is sourced from an environment variable.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        env_http_headers: Option<HashMap<String, String>>,
    },
}

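/// Serde helper for optional durations expressed as (possibly fractional)
/// seconds; for example, `startup_timeout_sec = 2.5` round-trips through
/// `Duration::try_from_secs_f64(2.5)`.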
mod option_duration_secs {
    use serde::Deserialize;
    use serde::Deserializer;
    use serde::Serializer;
    use std::time::Duration;

    pub fn serialize<S>(value: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match value {
            Some(duration) => serializer.serialize_some(&duration.as_secs_f64()),
            None => serializer.serialize_none(),
        }
    }

    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let secs = Option::<f64>::deserialize(deserializer)?;
        secs.map(|secs| Duration::try_from_secs_f64(secs).map_err(serde::de::Error::custom))
            .transpose()
    }
}

#[derive(Deserialize, Debug, Copy, Clone, PartialEq)]
pub enum UriBasedFileOpener {
    #[serde(rename = "vscode")]
    VsCode,

    #[serde(rename = "vscode-insiders")]
    VsCodeInsiders,

    #[serde(rename = "windsurf")]
    Windsurf,

    #[serde(rename = "cursor")]
    Cursor,

    /// Option to disable the URI-based file opener.
    #[serde(rename = "none")]
    None,
}

impl UriBasedFileOpener {
    pub fn get_scheme(&self) -> Option<&str> {
        match self {
            UriBasedFileOpener::VsCode => Some("vscode"),
            UriBasedFileOpener::VsCodeInsiders => Some("vscode-insiders"),
            UriBasedFileOpener::Windsurf => Some("windsurf"),
            UriBasedFileOpener::Cursor => Some("cursor"),
            UriBasedFileOpener::None => None,
        }
    }
}

/// Settings that govern if and what will be written to `~/.codex/history.jsonl`.
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct History {
    /// Controls whether (and how) history entries are written to disk.
    pub persistence: HistoryPersistence,

    /// If set, the maximum size of the history file in bytes.
    /// TODO(mbolin): Not currently honored.
    pub max_bytes: Option<usize>,
}

#[derive(Deserialize, Debug, Copy, Clone, PartialEq, Default)]
#[serde(rename_all = "kebab-case")]
pub enum HistoryPersistence {
    /// Save all history entries to disk.
    #[default]
    SaveAll,
    /// Do not write history to disk.
    None,
}

// ===== OTEL configuration =====

#[derive(Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum OtelHttpProtocol {
    /// Binary payload
    Binary,
    /// JSON payload
    Json,
}

/// Which OTEL exporter to use.
#[derive(Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum OtelExporterKind {
    None,
    OtlpHttp {
        endpoint: String,
        headers: HashMap<String, String>,
        protocol: OtelHttpProtocol,
    },
    OtlpGrpc {
        endpoint: String,
        headers: HashMap<String, String>,
    },
}

/// OTEL settings loaded from config.toml. Fields are optional so we can apply defaults.
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct OtelConfigToml {
    /// Log user prompt in traces.
    pub log_user_prompt: Option<bool>,

    /// Mark traces with an environment (dev, staging, prod, test). Defaults to dev.
    pub environment: Option<String>,

    /// Exporter to use. Defaults to `none` (cf. `OtelConfig::default`).
    pub exporter: Option<OtelExporterKind>,
}

/// Effective OTEL settings after defaults are applied.
#[derive(Debug, Clone, PartialEq)]
pub struct OtelConfig {
    pub log_user_prompt: bool,
    pub environment: String,
    pub exporter: OtelExporterKind,
}

impl Default for OtelConfig {
    fn default() -> Self {
        OtelConfig {
            log_user_prompt: false,
            environment: DEFAULT_OTEL_ENVIRONMENT.to_owned(),
            exporter: OtelExporterKind::None,
        }
    }
}

#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(untagged)]
pub enum Notifications {
    Enabled(bool),
    Custom(Vec<String>),
}

impl Default for Notifications {
    fn default() -> Self {
        Self::Enabled(false)
    }
}

/// Collection of settings that are specific to the TUI.
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Tui {
    /// Enable desktop notifications from the TUI when the terminal is unfocused.
    /// Defaults to `false`.
    #[serde(default)]
    pub notifications: Notifications,
}

/// Settings for notices we display to users via the TUI and app-server clients
/// (primarily the Codex IDE extension). NOTE: these are different from
/// notifications - notices are warnings, NUX screens, acknowledgements, etc.
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Notice {
    /// Tracks whether the user has acknowledged the full access warning prompt.
    pub hide_full_access_warning: Option<bool>,
    /// Tracks whether the user has acknowledged the Windows world-writable directories warning.
    pub hide_world_writable_warning: Option<bool>,
    /// Tracks whether the user opted out of the rate limit model switch reminder.
    pub hide_rate_limit_model_nudge: Option<bool>,
}

impl Notice {
    /// Referenced by config_edit helpers when writing notice flags.
    pub(crate) const TABLE_KEY: &'static str = "notice";
}

#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct SandboxWorkspaceWrite {
    #[serde(default)]
    pub writable_roots: Vec<PathBuf>,
    #[serde(default)]
    pub network_access: bool,
    #[serde(default)]
    pub exclude_tmpdir_env_var: bool,
    #[serde(default)]
    pub exclude_slash_tmp: bool,
}

impl From<SandboxWorkspaceWrite> for codex_app_server_protocol::SandboxSettings {
    fn from(sandbox_workspace_write: SandboxWorkspaceWrite) -> Self {
        Self {
            writable_roots: sandbox_workspace_write.writable_roots,
            network_access: Some(sandbox_workspace_write.network_access),
            exclude_tmpdir_env_var: Some(sandbox_workspace_write.exclude_tmpdir_env_var),
            exclude_slash_tmp: Some(sandbox_workspace_write.exclude_slash_tmp),
        }
    }
}

#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
#[serde(rename_all = "kebab-case")]
pub enum ShellEnvironmentPolicyInherit {
    /// "Core" environment variables for the platform. On UNIX, this would
    /// include HOME, LOGNAME, PATH, SHELL, and USER, among others.
    Core,

    /// Inherits the full environment from the parent process.
    #[default]
    All,

    /// Do not inherit any environment variables from the parent process.
    None,
}

/// Policy for building the `env` when spawning a process via either the
/// `shell` or `local_shell` tool.
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct ShellEnvironmentPolicyToml {
    pub inherit: Option<ShellEnvironmentPolicyInherit>,

    pub ignore_default_excludes: Option<bool>,

    /// List of regular expressions.
    pub exclude: Option<Vec<String>>,

    pub r#set: Option<HashMap<String, String>>,

    /// List of regular expressions.
    pub include_only: Option<Vec<String>>,

    pub experimental_use_profile: Option<bool>,
}

pub type EnvironmentVariablePattern = WildMatchPattern<'*', '?'>;

/// Deriving the `env` based on this policy works as follows:
/// 1. Create an initial map based on the `inherit` policy.
/// 2. If `ignore_default_excludes` is false, filter the map using the default
///    exclude pattern(s), which are: `"*KEY*"` and `"*TOKEN*"`.
/// 3. If `exclude` is not empty, filter the map using the provided patterns.
/// 4. Insert any entries from `r#set` into the map.
/// 5. If non-empty, filter the map using the `include_only` patterns.
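///
/// A hedged worked example: with `inherit = "all"`, default excludes active,
/// `set = { FOO = "1" }`, and `include_only = ["PATH", "FOO"]`, a parent
/// environment of `{ PATH, HOME, MY_API_KEY }` first drops `MY_API_KEY`
/// (it matches `"*KEY*"`), gains `FOO=1`, and is finally filtered down to
/// `{ PATH, FOO }`.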
#[derive(Debug, Clone, PartialEq, Default)]
pub struct ShellEnvironmentPolicy {
    /// Starting point when building the environment.
    pub inherit: ShellEnvironmentPolicyInherit,

    /// True to skip the check to exclude default environment variables that
    /// contain "KEY" or "TOKEN" in their name.
    pub ignore_default_excludes: bool,

    /// Environment variable names to exclude from the environment.
    pub exclude: Vec<EnvironmentVariablePattern>,

    /// (key, value) pairs to insert in the environment.
    pub r#set: HashMap<String, String>,

    /// Environment variable names to retain in the environment.
    pub include_only: Vec<EnvironmentVariablePattern>,

    /// If true, the shell profile will be used to run the command.
    pub use_profile: bool,
}

impl From<ShellEnvironmentPolicyToml> for ShellEnvironmentPolicy {
    fn from(toml: ShellEnvironmentPolicyToml) -> Self {
        // Default to inheriting the full environment when not specified.
        let inherit = toml.inherit.unwrap_or(ShellEnvironmentPolicyInherit::All);
        let ignore_default_excludes = toml.ignore_default_excludes.unwrap_or(false);
        let exclude = toml
            .exclude
            .unwrap_or_default()
            .into_iter()
            .map(|s| EnvironmentVariablePattern::new_case_insensitive(&s))
            .collect();
        let r#set = toml.r#set.unwrap_or_default();
        let include_only = toml
            .include_only
            .unwrap_or_default()
            .into_iter()
            .map(|s| EnvironmentVariablePattern::new_case_insensitive(&s))
            .collect();
        let use_profile = toml.experimental_use_profile.unwrap_or(false);

        Self {
            inherit,
            ignore_default_excludes,
            exclude,
            r#set,
            include_only,
            use_profile,
        }
    }
}

#[derive(Deserialize, Debug, Clone, PartialEq, Eq, Default, Hash)]
#[serde(rename_all = "kebab-case")]
pub enum ReasoningSummaryFormat {
    #[default]
    None,
    Experimental,
}

#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn deserialize_stdio_command_server_config() {
        let cfg: McpServerConfig = toml::from_str(
            r#"
            command = "echo"
            "#,
        )
        .expect("should deserialize command config");

        assert_eq!(
            cfg.transport,
            McpServerTransportConfig::Stdio {
                command: "echo".to_string(),
                args: vec![],
                env: None,
                env_vars: Vec::new(),
                cwd: None,
            }
        );
        assert!(cfg.enabled);
        assert!(cfg.enabled_tools.is_none());
        assert!(cfg.disabled_tools.is_none());
    }

    #[test]
    fn deserialize_stdio_command_server_config_with_args() {
        let cfg: McpServerConfig = toml::from_str(
            r#"
            command = "echo"
            args = ["hello", "world"]
            "#,
        )
        .expect("should deserialize command config");

        assert_eq!(
            cfg.transport,
            McpServerTransportConfig::Stdio {
                command: "echo".to_string(),
                args: vec!["hello".to_string(), "world".to_string()],
                env: None,
                env_vars: Vec::new(),
                cwd: None,
            }
        );
        assert!(cfg.enabled);
    }

    #[test]
    fn deserialize_stdio_command_server_config_with_args_and_env() {
        let cfg: McpServerConfig = toml::from_str(
            r#"
            command = "echo"
            args = ["hello", "world"]
            env = { "FOO" = "BAR" }
            "#,
        )
        .expect("should deserialize command config");

        assert_eq!(
            cfg.transport,
            McpServerTransportConfig::Stdio {
                command: "echo".to_string(),
                args: vec!["hello".to_string(), "world".to_string()],
                env: Some(HashMap::from([("FOO".to_string(), "BAR".to_string())])),
                env_vars: Vec::new(),
                cwd: None,
            }
        );
        assert!(cfg.enabled);
    }

    #[test]
    fn deserialize_stdio_command_server_config_with_env_vars() {
        let cfg: McpServerConfig = toml::from_str(
            r#"
            command = "echo"
            env_vars = ["FOO", "BAR"]
            "#,
        )
        .expect("should deserialize command config with env_vars");

        assert_eq!(
            cfg.transport,
            McpServerTransportConfig::Stdio {
                command: "echo".to_string(),
                args: vec![],
                env: None,
                env_vars: vec!["FOO".to_string(), "BAR".to_string()],
                cwd: None,
            }
        );
    }

    #[test]
    fn deserialize_stdio_command_server_config_with_cwd() {
        let cfg: McpServerConfig = toml::from_str(
            r#"
            command = "echo"
            cwd = "/tmp"
            "#,
        )
        .expect("should deserialize command config with cwd");

        assert_eq!(
            cfg.transport,
            McpServerTransportConfig::Stdio {
                command: "echo".to_string(),
                args: vec![],
                env: None,
                env_vars: Vec::new(),
                cwd: Some(PathBuf::from("/tmp")),
            }
        );
    }

    #[test]
    fn deserialize_disabled_server_config() {
        let cfg: McpServerConfig = toml::from_str(
            r#"
            command = "echo"
            enabled = false
            "#,
        )
        .expect("should deserialize disabled server config");

        assert!(!cfg.enabled);
    }

    #[test]
    fn deserialize_streamable_http_server_config() {
        let cfg: McpServerConfig = toml::from_str(
            r#"
            url = "https://example.com/mcp"
            "#,
        )
        .expect("should deserialize http config");

        assert_eq!(
            cfg.transport,
            McpServerTransportConfig::StreamableHttp {
                url: "https://example.com/mcp".to_string(),
                bearer_token_env_var: None,
                http_headers: None,
                env_http_headers: None,
            }
        );
        assert!(cfg.enabled);
    }

    #[test]
    fn deserialize_streamable_http_server_config_with_env_var() {
        let cfg: McpServerConfig = toml::from_str(
            r#"
            url = "https://example.com/mcp"
            bearer_token_env_var = "GITHUB_TOKEN"
            "#,
        )
        .expect("should deserialize http config");

        assert_eq!(
            cfg.transport,
            McpServerTransportConfig::StreamableHttp {
                url: "https://example.com/mcp".to_string(),
                bearer_token_env_var: Some("GITHUB_TOKEN".to_string()),
                http_headers: None,
                env_http_headers: None,
            }
        );
        assert!(cfg.enabled);
    }

    #[test]
    fn deserialize_streamable_http_server_config_with_headers() {
        let cfg: McpServerConfig = toml::from_str(
            r#"
            url = "https://example.com/mcp"
            http_headers = { "X-Foo" = "bar" }
            env_http_headers = { "X-Token" = "TOKEN_ENV" }
            "#,
        )
        .expect("should deserialize http config with headers");

        assert_eq!(
            cfg.transport,
            McpServerTransportConfig::StreamableHttp {
                url: "https://example.com/mcp".to_string(),
                bearer_token_env_var: None,
                http_headers: Some(HashMap::from([("X-Foo".to_string(), "bar".to_string())])),
                env_http_headers: Some(HashMap::from([(
                    "X-Token".to_string(),
                    "TOKEN_ENV".to_string()
                )])),
            }
        );
    }

    #[test]
    fn deserialize_server_config_with_tool_filters() {
        let cfg: McpServerConfig = toml::from_str(
            r#"
            command = "echo"
            enabled_tools = ["allowed"]
            disabled_tools = ["blocked"]
            "#,
        )
        .expect("should deserialize tool filters");

        assert_eq!(cfg.enabled_tools, Some(vec!["allowed".to_string()]));
        assert_eq!(cfg.disabled_tools, Some(vec!["blocked".to_string()]));
    }

    #[test]
    fn deserialize_rejects_command_and_url() {
        toml::from_str::<McpServerConfig>(
            r#"
            command = "echo"
            url = "https://example.com"
            "#,
        )
        .expect_err("should reject command+url");
    }

    #[test]
    fn deserialize_rejects_env_for_http_transport() {
        toml::from_str::<McpServerConfig>(
            r#"
            url = "https://example.com"
            env = { "FOO" = "BAR" }
            "#,
        )
        .expect_err("should reject env for http transport");
    }

    #[test]
    fn deserialize_rejects_headers_for_stdio() {
        toml::from_str::<McpServerConfig>(
            r#"
            command = "echo"
            http_headers = { "X-Foo" = "bar" }
            "#,
        )
        .expect_err("should reject http_headers for stdio transport");
|
||||
|
||||
toml::from_str::<McpServerConfig>(
|
||||
r#"
|
||||
command = "echo"
|
||||
env_http_headers = { "X-Foo" = "BAR_ENV" }
|
||||
"#,
|
||||
)
|
||||
.expect_err("should reject env_http_headers for stdio transport");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deserialize_rejects_inline_bearer_token_field() {
|
||||
let err = toml::from_str::<McpServerConfig>(
|
||||
r#"
|
||||
url = "https://example.com"
|
||||
bearer_token = "secret"
|
||||
"#,
|
||||
)
|
||||
.expect_err("should reject bearer_token field");
|
||||
|
||||
assert!(
|
||||
err.to_string().contains("bearer_token is not supported"),
|
||||
"unexpected error: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
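Side note on the deserialization pattern the tests above exercise: selecting a transport variant from which keys are present (`command = ...` vs `url = ...`). A minimal standalone sketch of one way to get that behavior with a serde untagged enum follows; it is not part of this commit, `TransportSketch` and its fields are hypothetical stand-ins for the real `McpServerTransportConfig`, and it assumes the `serde` and `toml` crates.

// Sketch: an untagged enum routes `command = ...` to the stdio-style variant
// and `url = ...` to the HTTP-style variant, because serde tries variants in
// order and picks the first whose required fields are satisfied.
use serde::Deserialize;

#[derive(Debug, Deserialize, PartialEq)]
#[serde(untagged)]
enum TransportSketch {
    Stdio {
        command: String,
        #[serde(default)]
        args: Vec<String>,
    },
    Http {
        url: String,
    },
}

fn main() {
    let stdio: TransportSketch =
        toml::from_str(r#"command = "echo""#).expect("stdio variant");
    assert_eq!(
        stdio,
        TransportSketch::Stdio {
            command: "echo".to_string(),
            args: vec![],
        }
    );

    let http: TransportSketch =
        toml::from_str(r#"url = "https://example.com/mcp""#).expect("http variant");
    assert_eq!(
        http,
        TransportSketch::Http {
            url: "https://example.com/mcp".to_string(),
        }
    );
}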
118
llmx-rs/core/src/config_loader/macos.rs
Normal file
@@ -0,0 +1,118 @@
use std::io;
use toml::Value as TomlValue;

#[cfg(target_os = "macos")]
mod native {
    use super::*;
    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use core_foundation::base::TCFType;
    use core_foundation::string::CFString;
    use core_foundation::string::CFStringRef;
    use std::ffi::c_void;
    use tokio::task;

    pub(crate) async fn load_managed_admin_config_layer(
        override_base64: Option<&str>,
    ) -> io::Result<Option<TomlValue>> {
        if let Some(encoded) = override_base64 {
            let trimmed = encoded.trim();
            return if trimmed.is_empty() {
                Ok(None)
            } else {
                parse_managed_preferences_base64(trimmed).map(Some)
            };
        }

        const LOAD_ERROR: &str = "Failed to load managed preferences configuration";

        match task::spawn_blocking(load_managed_admin_config).await {
            Ok(result) => result,
            Err(join_err) => {
                if join_err.is_cancelled() {
                    tracing::error!("Managed preferences load task was cancelled");
                } else {
                    tracing::error!("Managed preferences load task failed: {join_err}");
                }
                Err(io::Error::other(LOAD_ERROR))
            }
        }
    }

    pub(super) fn load_managed_admin_config() -> io::Result<Option<TomlValue>> {
        #[link(name = "CoreFoundation", kind = "framework")]
        unsafe extern "C" {
            fn CFPreferencesCopyAppValue(
                key: CFStringRef,
                application_id: CFStringRef,
            ) -> *mut c_void;
        }

        const MANAGED_PREFERENCES_APPLICATION_ID: &str = "com.openai.codex";
        const MANAGED_PREFERENCES_CONFIG_KEY: &str = "config_toml_base64";

        let application_id = CFString::new(MANAGED_PREFERENCES_APPLICATION_ID);
        let key = CFString::new(MANAGED_PREFERENCES_CONFIG_KEY);

        let value_ref = unsafe {
            CFPreferencesCopyAppValue(
                key.as_concrete_TypeRef(),
                application_id.as_concrete_TypeRef(),
            )
        };

        if value_ref.is_null() {
            tracing::debug!(
                "Managed preferences for {} key {} not found",
                MANAGED_PREFERENCES_APPLICATION_ID,
                MANAGED_PREFERENCES_CONFIG_KEY
            );
            return Ok(None);
        }

        let value = unsafe { CFString::wrap_under_create_rule(value_ref as _) };
        let contents = value.to_string();
        let trimmed = contents.trim();

        parse_managed_preferences_base64(trimmed).map(Some)
    }

    pub(super) fn parse_managed_preferences_base64(encoded: &str) -> io::Result<TomlValue> {
        let decoded = BASE64_STANDARD.decode(encoded.as_bytes()).map_err(|err| {
            tracing::error!("Failed to decode managed preferences as base64: {err}");
            io::Error::new(io::ErrorKind::InvalidData, err)
        })?;

        let decoded_str = String::from_utf8(decoded).map_err(|err| {
            tracing::error!("Managed preferences base64 contents were not valid UTF-8: {err}");
            io::Error::new(io::ErrorKind::InvalidData, err)
        })?;

        match toml::from_str::<TomlValue>(&decoded_str) {
            Ok(TomlValue::Table(parsed)) => Ok(TomlValue::Table(parsed)),
            Ok(other) => {
                tracing::error!(
                    "Managed preferences TOML must have a table at the root, found {other:?}",
                );
                Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "managed preferences root must be a table",
                ))
            }
            Err(err) => {
                tracing::error!("Failed to parse managed preferences TOML: {err}");
                Err(io::Error::new(io::ErrorKind::InvalidData, err))
            }
        }
    }
}

#[cfg(target_os = "macos")]
pub(crate) use native::load_managed_admin_config_layer;

#[cfg(not(target_os = "macos"))]
pub(crate) async fn load_managed_admin_config_layer(
    _override_base64: Option<&str>,
) -> io::Result<Option<TomlValue>> {
    Ok(None)
}
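For reference, a round-trip sketch of the payload format the loader above expects: base64-encoded TOML whose root is a table. Illustrative only; it mirrors the decode path of parse_managed_preferences_base64 rather than calling it, and assumes the same base64 and toml crates used above.

// Encode a TOML document the way a managed profile would, then decode it
// the way the loader does: base64 -> UTF-8 string -> TOML table.
use base64::Engine;
use base64::prelude::BASE64_STANDARD;

fn main() {
    let payload = r#"
[nested]
value = "managed"
"#;
    let encoded = BASE64_STANDARD.encode(payload.as_bytes());

    let decoded = BASE64_STANDARD.decode(encoded.as_bytes()).expect("valid base64");
    let text = String::from_utf8(decoded).expect("valid UTF-8");
    let value: toml::Value = toml::from_str(&text).expect("valid TOML");
    // The loader rejects any document whose root is not a table.
    assert!(value.is_table(), "root must be a table");
}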
311
llmx-rs/core/src/config_loader/mod.rs
Normal file
@@ -0,0 +1,311 @@
mod macos;

use crate::config::CONFIG_TOML_FILE;
use macos::load_managed_admin_config_layer;
use std::io;
use std::path::Path;
use std::path::PathBuf;
use tokio::fs;
use toml::Value as TomlValue;

#[cfg(unix)]
const CODEX_MANAGED_CONFIG_SYSTEM_PATH: &str = "/etc/codex/managed_config.toml";

#[derive(Debug)]
pub(crate) struct LoadedConfigLayers {
    pub base: TomlValue,
    pub managed_config: Option<TomlValue>,
    pub managed_preferences: Option<TomlValue>,
}

#[derive(Debug, Default)]
pub(crate) struct LoaderOverrides {
    pub managed_config_path: Option<PathBuf>,
    #[cfg(target_os = "macos")]
    pub managed_preferences_base64: Option<String>,
}

// Configuration layering pipeline (top overrides bottom):
//
//   +-------------------------+
//   | Managed preferences (*) |
//   +-------------------------+
//               ^
//               |
//   +-------------------------+
//   |   managed_config.toml   |
//   +-------------------------+
//               ^
//               |
//   +-------------------------+
//   |   config.toml (base)    |
//   +-------------------------+
//
// (*) Only available on macOS via managed device profiles.

pub async fn load_config_as_toml(codex_home: &Path) -> io::Result<TomlValue> {
    load_config_as_toml_with_overrides(codex_home, LoaderOverrides::default()).await
}

fn default_empty_table() -> TomlValue {
    TomlValue::Table(Default::default())
}

pub(crate) async fn load_config_layers_with_overrides(
    codex_home: &Path,
    overrides: LoaderOverrides,
) -> io::Result<LoadedConfigLayers> {
    load_config_layers_internal(codex_home, overrides).await
}

async fn load_config_as_toml_with_overrides(
    codex_home: &Path,
    overrides: LoaderOverrides,
) -> io::Result<TomlValue> {
    let layers = load_config_layers_internal(codex_home, overrides).await?;
    Ok(apply_managed_layers(layers))
}

async fn load_config_layers_internal(
    codex_home: &Path,
    overrides: LoaderOverrides,
) -> io::Result<LoadedConfigLayers> {
    #[cfg(target_os = "macos")]
    let LoaderOverrides {
        managed_config_path,
        managed_preferences_base64,
    } = overrides;

    #[cfg(not(target_os = "macos"))]
    let LoaderOverrides {
        managed_config_path,
    } = overrides;

    let managed_config_path =
        managed_config_path.unwrap_or_else(|| managed_config_default_path(codex_home));

    let user_config_path = codex_home.join(CONFIG_TOML_FILE);
    let user_config = read_config_from_path(&user_config_path, true).await?;
    let managed_config = read_config_from_path(&managed_config_path, false).await?;

    #[cfg(target_os = "macos")]
    let managed_preferences =
        load_managed_admin_config_layer(managed_preferences_base64.as_deref()).await?;

    #[cfg(not(target_os = "macos"))]
    let managed_preferences = load_managed_admin_config_layer(None).await?;

    Ok(LoadedConfigLayers {
        base: user_config.unwrap_or_else(default_empty_table),
        managed_config,
        managed_preferences,
    })
}

async fn read_config_from_path(
    path: &Path,
    log_missing_as_info: bool,
) -> io::Result<Option<TomlValue>> {
    match fs::read_to_string(path).await {
        Ok(contents) => match toml::from_str::<TomlValue>(&contents) {
            Ok(value) => Ok(Some(value)),
            Err(err) => {
                tracing::error!("Failed to parse {}: {err}", path.display());
                Err(io::Error::new(io::ErrorKind::InvalidData, err))
            }
        },
        Err(err) if err.kind() == io::ErrorKind::NotFound => {
            if log_missing_as_info {
                tracing::info!("{} not found, using defaults", path.display());
            } else {
                tracing::debug!("{} not found", path.display());
            }
            Ok(None)
        }
        Err(err) => {
            tracing::error!("Failed to read {}: {err}", path.display());
            Err(err)
        }
    }
}

/// Merge config `overlay` into `base`, giving `overlay` precedence.
pub(crate) fn merge_toml_values(base: &mut TomlValue, overlay: &TomlValue) {
    if let TomlValue::Table(overlay_table) = overlay
        && let TomlValue::Table(base_table) = base
    {
        for (key, value) in overlay_table {
            if let Some(existing) = base_table.get_mut(key) {
                merge_toml_values(existing, value);
            } else {
                base_table.insert(key.clone(), value.clone());
            }
        }
    } else {
        *base = overlay.clone();
    }
}

fn managed_config_default_path(codex_home: &Path) -> PathBuf {
    #[cfg(unix)]
    {
        let _ = codex_home;
        PathBuf::from(CODEX_MANAGED_CONFIG_SYSTEM_PATH)
    }

    #[cfg(not(unix))]
    {
        codex_home.join("managed_config.toml")
    }
}

fn apply_managed_layers(layers: LoadedConfigLayers) -> TomlValue {
    let LoadedConfigLayers {
        mut base,
        managed_config,
        managed_preferences,
    } = layers;

    for overlay in [managed_config, managed_preferences].into_iter().flatten() {
        merge_toml_values(&mut base, &overlay);
    }

    base
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[tokio::test]
    async fn merges_managed_config_layer_on_top() {
        let tmp = tempdir().expect("tempdir");
        let managed_path = tmp.path().join("managed_config.toml");

        std::fs::write(
            tmp.path().join(CONFIG_TOML_FILE),
            r#"foo = 1

[nested]
value = "base"
"#,
        )
        .expect("write base");
        std::fs::write(
            &managed_path,
            r#"foo = 2

[nested]
value = "managed_config"
extra = true
"#,
        )
        .expect("write managed config");

        let overrides = LoaderOverrides {
            managed_config_path: Some(managed_path),
            #[cfg(target_os = "macos")]
            managed_preferences_base64: None,
        };

        let loaded = load_config_as_toml_with_overrides(tmp.path(), overrides)
            .await
            .expect("load config");
        let table = loaded.as_table().expect("top-level table expected");

        assert_eq!(table.get("foo"), Some(&TomlValue::Integer(2)));
        let nested = table
            .get("nested")
            .and_then(|v| v.as_table())
            .expect("nested");
        assert_eq!(
            nested.get("value"),
            Some(&TomlValue::String("managed_config".to_string()))
        );
        assert_eq!(nested.get("extra"), Some(&TomlValue::Boolean(true)));
    }

    #[tokio::test]
    async fn returns_empty_when_all_layers_missing() {
        let tmp = tempdir().expect("tempdir");
        let managed_path = tmp.path().join("managed_config.toml");
        let overrides = LoaderOverrides {
            managed_config_path: Some(managed_path),
            #[cfg(target_os = "macos")]
            managed_preferences_base64: None,
        };

        let layers = load_config_layers_with_overrides(tmp.path(), overrides)
            .await
            .expect("load layers");
        let base_table = layers.base.as_table().expect("base table expected");
        assert!(
            base_table.is_empty(),
            "expected empty base layer when configs missing"
        );
        assert!(
            layers.managed_config.is_none(),
            "managed config layer should be absent when file missing"
        );

        #[cfg(not(target_os = "macos"))]
        {
            let loaded = load_config_as_toml(tmp.path()).await.expect("load config");
            let table = loaded.as_table().expect("top-level table expected");
            assert!(
                table.is_empty(),
                "expected empty table when configs missing"
            );
        }
    }

    #[cfg(target_os = "macos")]
    #[tokio::test]
    async fn managed_preferences_take_highest_precedence() {
        use base64::Engine;

        let managed_payload = r#"
[nested]
value = "managed"
flag = false
"#;
        let encoded = base64::prelude::BASE64_STANDARD.encode(managed_payload.as_bytes());
        let tmp = tempdir().expect("tempdir");
        let managed_path = tmp.path().join("managed_config.toml");

        std::fs::write(
            tmp.path().join(CONFIG_TOML_FILE),
            r#"[nested]
value = "base"
"#,
        )
        .expect("write base");
        std::fs::write(
            &managed_path,
            r#"[nested]
value = "managed_config"
flag = true
"#,
        )
        .expect("write managed config");

        let overrides = LoaderOverrides {
            managed_config_path: Some(managed_path),
            managed_preferences_base64: Some(encoded),
        };

        let loaded = load_config_as_toml_with_overrides(tmp.path(), overrides)
            .await
            .expect("load config");
        let nested = loaded
            .get("nested")
            .and_then(|v| v.as_table())
            .expect("nested table");
        assert_eq!(
            nested.get("value"),
            Some(&TomlValue::String("managed".to_string()))
        );
        assert_eq!(nested.get("flag"), Some(&TomlValue::Boolean(false)));
    }
}
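A sketch of merge_toml_values' deep-merge semantics, written as a hypothetical extra test for the module above (not part of this commit): tables merge key by key and recurse, while scalars and arrays are replaced wholesale by the overlay.

    // Illustrative only: overlay wins on conflicting scalars, tables recurse,
    // and keys present only in the base survive the merge.
    #[test]
    fn merge_overlay_wins_on_scalars_and_recurses_into_tables() {
        let mut base: TomlValue = toml::from_str(
            r#"
a = 1

[t]
x = "base"
y = "kept"
"#,
        )
        .expect("base toml");
        let overlay: TomlValue = toml::from_str(
            r#"
a = 2

[t]
x = "overlay"
"#,
        )
        .expect("overlay toml");

        merge_toml_values(&mut base, &overlay);

        assert_eq!(base.get("a"), Some(&TomlValue::Integer(2)));
        let t = base.get("t").and_then(|v| v.as_table()).expect("table");
        assert_eq!(t.get("x"), Some(&TomlValue::String("overlay".to_string())));
        assert_eq!(t.get("y"), Some(&TomlValue::String("kept".to_string())));
    }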
174
llmx-rs/core/src/context_manager/history.rs
Normal file
@@ -0,0 +1,174 @@
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::TokenUsage;
use codex_protocol::protocol::TokenUsageInfo;
use std::ops::Deref;

use crate::context_manager::normalize;
use crate::context_manager::truncate::format_output_for_model_body;
use crate::context_manager::truncate::globally_truncate_function_output_items;

/// Transcript of conversation history
#[derive(Debug, Clone, Default)]
pub(crate) struct ContextManager {
    /// The oldest items are at the beginning of the vector.
    items: Vec<ResponseItem>,
    token_info: Option<TokenUsageInfo>,
}

impl ContextManager {
    pub(crate) fn new() -> Self {
        Self {
            items: Vec::new(),
            token_info: TokenUsageInfo::new_or_append(&None, &None, None),
        }
    }

    pub(crate) fn token_info(&self) -> Option<TokenUsageInfo> {
        self.token_info.clone()
    }

    pub(crate) fn set_token_usage_full(&mut self, context_window: i64) {
        match &mut self.token_info {
            Some(info) => info.fill_to_context_window(context_window),
            None => {
                self.token_info = Some(TokenUsageInfo::full_context_window(context_window));
            }
        }
    }

    /// `items` is ordered from oldest to newest.
    pub(crate) fn record_items<I>(&mut self, items: I)
    where
        I: IntoIterator,
        I::Item: std::ops::Deref<Target = ResponseItem>,
    {
        for item in items {
            let item_ref = item.deref();
            let is_ghost_snapshot = matches!(item_ref, ResponseItem::GhostSnapshot { .. });
            if !is_api_message(item_ref) && !is_ghost_snapshot {
                continue;
            }

            let processed = Self::process_item(&item);
            self.items.push(processed);
        }
    }

    pub(crate) fn get_history(&mut self) -> Vec<ResponseItem> {
        self.normalize_history();
        self.contents()
    }

    // Returns the history prepared for sending to the model,
    // with extra response items filtered out and ghost snapshots removed.
    pub(crate) fn get_history_for_prompt(&mut self) -> Vec<ResponseItem> {
        let mut history = self.get_history();
        Self::remove_ghost_snapshots(&mut history);
        history
    }

    pub(crate) fn remove_first_item(&mut self) {
        if !self.items.is_empty() {
            // Remove the oldest item (front of the list). Items are ordered from
            // oldest → newest, so index 0 is the first entry recorded.
            let removed = self.items.remove(0);
            // If the removed item participates in a call/output pair, also remove
            // its corresponding counterpart to keep the invariants intact without
            // running a full normalization pass.
            normalize::remove_corresponding_for(&mut self.items, &removed);
        }
    }

    pub(crate) fn replace(&mut self, items: Vec<ResponseItem>) {
        self.items = items;
    }

    pub(crate) fn update_token_info(
        &mut self,
        usage: &TokenUsage,
        model_context_window: Option<i64>,
    ) {
        self.token_info = TokenUsageInfo::new_or_append(
            &self.token_info,
            &Some(usage.clone()),
            model_context_window,
        );
    }

    /// This function enforces a couple of invariants on the in-memory history:
    /// 1. every call (function/custom) has a corresponding output entry
    /// 2. every output has a corresponding call entry
    fn normalize_history(&mut self) {
        // all function/tool calls must have a corresponding output
        normalize::ensure_call_outputs_present(&mut self.items);

        // all outputs must have a corresponding function/tool call
        normalize::remove_orphan_outputs(&mut self.items);
    }

    /// Returns a clone of the contents in the transcript.
    fn contents(&self) -> Vec<ResponseItem> {
        self.items.clone()
    }

    fn remove_ghost_snapshots(items: &mut Vec<ResponseItem>) {
        items.retain(|item| !matches!(item, ResponseItem::GhostSnapshot { .. }));
    }

    fn process_item(item: &ResponseItem) -> ResponseItem {
        match item {
            ResponseItem::FunctionCallOutput { call_id, output } => {
                let truncated = format_output_for_model_body(output.content.as_str());
                let truncated_items = output
                    .content_items
                    .as_ref()
                    .map(|items| globally_truncate_function_output_items(items));
                ResponseItem::FunctionCallOutput {
                    call_id: call_id.clone(),
                    output: FunctionCallOutputPayload {
                        content: truncated,
                        content_items: truncated_items,
                        success: output.success,
                    },
                }
            }
            ResponseItem::CustomToolCallOutput { call_id, output } => {
                let truncated = format_output_for_model_body(output);
                ResponseItem::CustomToolCallOutput {
                    call_id: call_id.clone(),
                    output: truncated,
                }
            }
            ResponseItem::Message { .. }
            | ResponseItem::Reasoning { .. }
            | ResponseItem::LocalShellCall { .. }
            | ResponseItem::FunctionCall { .. }
            | ResponseItem::WebSearchCall { .. }
            | ResponseItem::CustomToolCall { .. }
            | ResponseItem::GhostSnapshot { .. }
            | ResponseItem::Other => item.clone(),
        }
    }
}

/// API messages include every non-system item (user/assistant messages, reasoning,
/// tool calls, tool outputs, shell calls, and web-search calls).
fn is_api_message(message: &ResponseItem) -> bool {
    match message {
        ResponseItem::Message { role, .. } => role.as_str() != "system",
        ResponseItem::FunctionCallOutput { .. }
        | ResponseItem::FunctionCall { .. }
        | ResponseItem::CustomToolCall { .. }
        | ResponseItem::CustomToolCallOutput { .. }
        | ResponseItem::LocalShellCall { .. }
        | ResponseItem::Reasoning { .. }
        | ResponseItem::WebSearchCall { .. } => true,
        ResponseItem::GhostSnapshot { .. } => false,
        ResponseItem::Other => false,
    }
}

#[cfg(test)]
#[path = "history_tests.rs"]
mod tests;
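An aside on the iterator bound used by record_items above. A standalone sketch (hypothetical names, not from this commit) of how `I::Item: Deref<Target = T>` lets one signature accept both borrowed items and owned smart pointers without an extra clone at the call site:

// The bound only requires that each yielded item dereferences to the target
// type, so `iter()` over a Vec (yielding &T) and a Vec of Box<T> both work.
use std::ops::Deref;

fn collect_lens<I>(items: I) -> Vec<usize>
where
    I: IntoIterator,
    I::Item: Deref<Target = String>,
{
    items.into_iter().map(|s| s.deref().len()).collect()
}

fn main() {
    let owned = vec!["hi".to_string(), "there".to_string()];
    // Works with borrowed items (&String derefs to String)...
    assert_eq!(collect_lens(owned.iter()), vec![2, 5]);

    // ...and with owned Box<String> values.
    let boxed = vec![Box::new("ok".to_string())];
    assert_eq!(collect_lens(boxed), vec![2]);
}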
841
llmx-rs/core/src/context_manager/history_tests.rs
Normal file
@@ -0,0 +1,841 @@
use super::*;
use crate::context_manager::truncate;
use codex_git::GhostCommit;
use codex_protocol::models::ContentItem;
use codex_protocol::models::FunctionCallOutputContentItem;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::LocalShellAction;
use codex_protocol::models::LocalShellExecAction;
use codex_protocol::models::LocalShellStatus;
use codex_protocol::models::ReasoningItemContent;
use codex_protocol::models::ReasoningItemReasoningSummary;
use pretty_assertions::assert_eq;
use regex_lite::Regex;

fn assistant_msg(text: &str) -> ResponseItem {
    ResponseItem::Message {
        id: None,
        role: "assistant".to_string(),
        content: vec![ContentItem::OutputText {
            text: text.to_string(),
        }],
    }
}

fn create_history_with_items(items: Vec<ResponseItem>) -> ContextManager {
    let mut h = ContextManager::new();
    h.record_items(items.iter());
    h
}

fn user_msg(text: &str) -> ResponseItem {
    ResponseItem::Message {
        id: None,
        role: "user".to_string(),
        content: vec![ContentItem::OutputText {
            text: text.to_string(),
        }],
    }
}

fn reasoning_msg(text: &str) -> ResponseItem {
    ResponseItem::Reasoning {
        id: String::new(),
        summary: vec![ReasoningItemReasoningSummary::SummaryText {
            text: "summary".to_string(),
        }],
        content: Some(vec![ReasoningItemContent::ReasoningText {
            text: text.to_string(),
        }]),
        encrypted_content: None,
    }
}

#[test]
fn filters_non_api_messages() {
    let mut h = ContextManager::default();
    // System messages are not API messages; Other is ignored.
    let system = ResponseItem::Message {
        id: None,
        role: "system".to_string(),
        content: vec![ContentItem::OutputText {
            text: "ignored".to_string(),
        }],
    };
    let reasoning = reasoning_msg("thinking...");
    h.record_items([&system, &reasoning, &ResponseItem::Other]);

    // User and assistant should be retained.
    let u = user_msg("hi");
    let a = assistant_msg("hello");
    h.record_items([&u, &a]);

    let items = h.contents();
    assert_eq!(
        items,
        vec![
            ResponseItem::Reasoning {
                id: String::new(),
                summary: vec![ReasoningItemReasoningSummary::SummaryText {
                    text: "summary".to_string(),
                }],
                content: Some(vec![ReasoningItemContent::ReasoningText {
                    text: "thinking...".to_string(),
                }]),
                encrypted_content: None,
            },
            ResponseItem::Message {
                id: None,
                role: "user".to_string(),
                content: vec![ContentItem::OutputText {
                    text: "hi".to_string()
                }]
            },
            ResponseItem::Message {
                id: None,
                role: "assistant".to_string(),
                content: vec![ContentItem::OutputText {
                    text: "hello".to_string()
                }]
            }
        ]
    );
}

#[test]
fn get_history_for_prompt_drops_ghost_commits() {
    let items = vec![ResponseItem::GhostSnapshot {
        ghost_commit: GhostCommit::new("ghost-1".to_string(), None, Vec::new(), Vec::new()),
    }];
    let mut history = create_history_with_items(items);
    let filtered = history.get_history_for_prompt();
    assert_eq!(filtered, vec![]);
}

#[test]
fn remove_first_item_removes_matching_output_for_function_call() {
    let items = vec![
        ResponseItem::FunctionCall {
            id: None,
            name: "do_it".to_string(),
            arguments: "{}".to_string(),
            call_id: "call-1".to_string(),
        },
        ResponseItem::FunctionCallOutput {
            call_id: "call-1".to_string(),
            output: FunctionCallOutputPayload {
                content: "ok".to_string(),
                ..Default::default()
            },
        },
    ];
    let mut h = create_history_with_items(items);
    h.remove_first_item();
    assert_eq!(h.contents(), vec![]);
}

#[test]
fn remove_first_item_removes_matching_call_for_output() {
    let items = vec![
        ResponseItem::FunctionCallOutput {
            call_id: "call-2".to_string(),
            output: FunctionCallOutputPayload {
                content: "ok".to_string(),
                ..Default::default()
            },
        },
        ResponseItem::FunctionCall {
            id: None,
            name: "do_it".to_string(),
            arguments: "{}".to_string(),
            call_id: "call-2".to_string(),
        },
    ];
    let mut h = create_history_with_items(items);
    h.remove_first_item();
    assert_eq!(h.contents(), vec![]);
}

#[test]
fn remove_first_item_handles_local_shell_pair() {
    let items = vec![
        ResponseItem::LocalShellCall {
            id: None,
            call_id: Some("call-3".to_string()),
            status: LocalShellStatus::Completed,
            action: LocalShellAction::Exec(LocalShellExecAction {
                command: vec!["echo".to_string(), "hi".to_string()],
                timeout_ms: None,
                working_directory: None,
                env: None,
                user: None,
            }),
        },
        ResponseItem::FunctionCallOutput {
            call_id: "call-3".to_string(),
            output: FunctionCallOutputPayload {
                content: "ok".to_string(),
                ..Default::default()
            },
        },
    ];
    let mut h = create_history_with_items(items);
    h.remove_first_item();
    assert_eq!(h.contents(), vec![]);
}

#[test]
fn remove_first_item_handles_custom_tool_pair() {
    let items = vec![
        ResponseItem::CustomToolCall {
            id: None,
            status: None,
            call_id: "tool-1".to_string(),
            name: "my_tool".to_string(),
            input: "{}".to_string(),
        },
        ResponseItem::CustomToolCallOutput {
            call_id: "tool-1".to_string(),
            output: "ok".to_string(),
        },
    ];
    let mut h = create_history_with_items(items);
    h.remove_first_item();
    assert_eq!(h.contents(), vec![]);
}

#[test]
fn normalization_retains_local_shell_outputs() {
    let items = vec![
        ResponseItem::LocalShellCall {
            id: None,
            call_id: Some("shell-1".to_string()),
            status: LocalShellStatus::Completed,
            action: LocalShellAction::Exec(LocalShellExecAction {
                command: vec!["echo".to_string(), "hi".to_string()],
                timeout_ms: None,
                working_directory: None,
                env: None,
                user: None,
            }),
        },
        ResponseItem::FunctionCallOutput {
            call_id: "shell-1".to_string(),
            output: FunctionCallOutputPayload {
                content: "ok".to_string(),
                ..Default::default()
            },
        },
    ];

    let mut history = create_history_with_items(items.clone());
    let normalized = history.get_history();
    assert_eq!(normalized, items);
}

#[test]
fn record_items_truncates_function_call_output_content() {
    let mut history = ContextManager::new();
    let long_line = "a very long line to trigger truncation\n";
    let long_output = long_line.repeat(2_500);
    let item = ResponseItem::FunctionCallOutput {
        call_id: "call-100".to_string(),
        output: FunctionCallOutputPayload {
            content: long_output.clone(),
            success: Some(true),
            ..Default::default()
        },
    };

    history.record_items([&item]);

    assert_eq!(history.items.len(), 1);
    match &history.items[0] {
        ResponseItem::FunctionCallOutput { output, .. } => {
            assert_ne!(output.content, long_output);
            assert!(
                output.content.starts_with("Total output lines:"),
                "expected truncated summary, got {}",
                output.content
            );
        }
        other => panic!("unexpected history item: {other:?}"),
    }
}

#[test]
fn record_items_truncates_custom_tool_call_output_content() {
    let mut history = ContextManager::new();
    let line = "custom output that is very long\n";
    let long_output = line.repeat(2_500);
    let item = ResponseItem::CustomToolCallOutput {
        call_id: "tool-200".to_string(),
        output: long_output.clone(),
    };

    history.record_items([&item]);

    assert_eq!(history.items.len(), 1);
    match &history.items[0] {
        ResponseItem::CustomToolCallOutput { output, .. } => {
            assert_ne!(output, &long_output);
            assert!(
                output.starts_with("Total output lines:"),
                "expected truncated summary, got {output}"
            );
        }
        other => panic!("unexpected history item: {other:?}"),
    }
}

fn assert_truncated_message_matches(message: &str, line: &str, total_lines: usize) {
    let pattern = truncated_message_pattern(line, total_lines);
    let regex = Regex::new(&pattern).unwrap_or_else(|err| {
        panic!("failed to compile regex {pattern}: {err}");
    });
    let captures = regex
        .captures(message)
        .unwrap_or_else(|| panic!("message failed to match pattern {pattern}: {message}"));
    let body = captures
        .name("body")
        .expect("missing body capture")
        .as_str();
    assert!(
        body.len() <= truncate::MODEL_FORMAT_MAX_BYTES,
        "body exceeds byte limit: {} bytes",
        body.len()
    );
}

fn truncated_message_pattern(line: &str, total_lines: usize) -> String {
    let head_take = truncate::MODEL_FORMAT_HEAD_LINES.min(total_lines);
    let tail_take = truncate::MODEL_FORMAT_TAIL_LINES.min(total_lines.saturating_sub(head_take));
    let omitted = total_lines.saturating_sub(head_take + tail_take);
    let escaped_line = regex_lite::escape(line);
    if omitted == 0 {
        return format!(
            r"(?s)^Total output lines: {total_lines}\n\n(?P<body>{escaped_line}.*\n\[\.{{3}} output truncated to fit {max_bytes} bytes \.{{3}}]\n\n.*)$",
            max_bytes = truncate::MODEL_FORMAT_MAX_BYTES,
        );
    }
    format!(
        r"(?s)^Total output lines: {total_lines}\n\n(?P<body>{escaped_line}.*\n\[\.{{3}} omitted {omitted} of {total_lines} lines \.{{3}}]\n\n.*)$",
    )
}

#[test]
fn format_exec_output_truncates_large_error() {
    let line = "very long execution error line that should trigger truncation\n";
    let large_error = line.repeat(2_500); // way beyond both byte and line limits

    let truncated = truncate::format_output_for_model_body(&large_error);

    let total_lines = large_error.lines().count();
    assert_truncated_message_matches(&truncated, line, total_lines);
    assert_ne!(truncated, large_error);
}

#[test]
fn format_exec_output_marks_byte_truncation_without_omitted_lines() {
    let long_line = "a".repeat(truncate::MODEL_FORMAT_MAX_BYTES + 50);
    let truncated = truncate::format_output_for_model_body(&long_line);

    assert_ne!(truncated, long_line);
    let marker_line = format!(
        "[... output truncated to fit {} bytes ...]",
        truncate::MODEL_FORMAT_MAX_BYTES
    );
    assert!(
        truncated.contains(&marker_line),
        "missing byte truncation marker: {truncated}"
    );
    assert!(
        !truncated.contains("omitted"),
        "line omission marker should not appear when no lines were dropped: {truncated}"
    );
}

#[test]
fn format_exec_output_returns_original_when_within_limits() {
    let content = "example output\n".repeat(10);

    assert_eq!(truncate::format_output_for_model_body(&content), content);
}

#[test]
fn format_exec_output_reports_omitted_lines_and_keeps_head_and_tail() {
    let total_lines = truncate::MODEL_FORMAT_MAX_LINES + 100;
    let content: String = (0..total_lines)
        .map(|idx| format!("line-{idx}\n"))
        .collect();

    let truncated = truncate::format_output_for_model_body(&content);
    let omitted = total_lines - truncate::MODEL_FORMAT_MAX_LINES;
    let expected_marker = format!("[... omitted {omitted} of {total_lines} lines ...]");

    assert!(
        truncated.contains(&expected_marker),
        "missing omitted marker: {truncated}"
    );
    assert!(
        truncated.contains("line-0\n"),
        "expected head line to remain: {truncated}"
    );

    let last_line = format!("line-{}\n", total_lines - 1);
    assert!(
        truncated.contains(&last_line),
        "expected tail line to remain: {truncated}"
    );
}

#[test]
fn format_exec_output_prefers_line_marker_when_both_limits_exceeded() {
    let total_lines = truncate::MODEL_FORMAT_MAX_LINES + 42;
    let long_line = "x".repeat(256);
    let content: String = (0..total_lines)
        .map(|idx| format!("line-{idx}-{long_line}\n"))
        .collect();

    let truncated = truncate::format_output_for_model_body(&content);

    assert!(
        truncated.contains("[... omitted 42 of 298 lines ...]"),
        "expected omitted marker when line count exceeds limit: {truncated}"
    );
    assert!(
        !truncated.contains("output truncated to fit"),
        "line omission marker should take precedence over byte marker: {truncated}"
    );
}

#[test]
fn truncates_across_multiple_under_limit_texts_and_reports_omitted() {
    // Arrange: several text items, none exceeding per-item limit, but total exceeds budget.
    let budget = truncate::MODEL_FORMAT_MAX_BYTES;
    let t1_len = (budget / 2).saturating_sub(10);
    let t2_len = (budget / 2).saturating_sub(10);
    let remaining_after_t1_t2 = budget.saturating_sub(t1_len + t2_len);
    let t3_len = 50; // gets truncated to remaining_after_t1_t2
    let t4_len = 5; // omitted
    let t5_len = 7; // omitted

    let t1 = "a".repeat(t1_len);
    let t2 = "b".repeat(t2_len);
    let t3 = "c".repeat(t3_len);
    let t4 = "d".repeat(t4_len);
    let t5 = "e".repeat(t5_len);

    let item = ResponseItem::FunctionCallOutput {
        call_id: "call-omit".to_string(),
        output: FunctionCallOutputPayload {
            content: "irrelevant".to_string(),
            content_items: Some(vec![
                FunctionCallOutputContentItem::InputText { text: t1 },
                FunctionCallOutputContentItem::InputText { text: t2 },
                FunctionCallOutputContentItem::InputImage {
                    image_url: "img:mid".to_string(),
                },
                FunctionCallOutputContentItem::InputText { text: t3 },
                FunctionCallOutputContentItem::InputText { text: t4 },
                FunctionCallOutputContentItem::InputText { text: t5 },
            ]),
            success: Some(true),
        },
    };

    let mut history = ContextManager::new();
    history.record_items([&item]);
    assert_eq!(history.items.len(), 1);
    let json = serde_json::to_value(&history.items[0]).expect("serialize to json");

    let output = json
        .get("output")
        .expect("output field")
        .as_array()
        .expect("array output");

    // Expect: t1 (full), t2 (full), image, t3 (truncated), summary mentioning 2 omitted.
    assert_eq!(output.len(), 5);

    let first = output[0].as_object().expect("first obj");
    assert_eq!(first.get("type").unwrap(), "input_text");
    let first_text = first.get("text").unwrap().as_str().unwrap();
    assert_eq!(first_text.len(), t1_len);

    let second = output[1].as_object().expect("second obj");
    assert_eq!(second.get("type").unwrap(), "input_text");
    let second_text = second.get("text").unwrap().as_str().unwrap();
    assert_eq!(second_text.len(), t2_len);

    assert_eq!(
        output[2],
        serde_json::json!({"type": "input_image", "image_url": "img:mid"})
    );

    let fourth = output[3].as_object().expect("fourth obj");
    assert_eq!(fourth.get("type").unwrap(), "input_text");
    let fourth_text = fourth.get("text").unwrap().as_str().unwrap();
    assert_eq!(fourth_text.len(), remaining_after_t1_t2);

    let summary = output[4].as_object().expect("summary obj");
    assert_eq!(summary.get("type").unwrap(), "input_text");
    let summary_text = summary.get("text").unwrap().as_str().unwrap();
    assert!(summary_text.contains("omitted 2 text items"));
}

//TODO(aibrahim): run CI in release mode.
#[cfg(not(debug_assertions))]
#[test]
fn normalize_adds_missing_output_for_function_call() {
    let items = vec![ResponseItem::FunctionCall {
        id: None,
        name: "do_it".to_string(),
        arguments: "{}".to_string(),
        call_id: "call-x".to_string(),
    }];
    let mut h = create_history_with_items(items);

    h.normalize_history();

    assert_eq!(
        h.contents(),
        vec![
            ResponseItem::FunctionCall {
                id: None,
                name: "do_it".to_string(),
                arguments: "{}".to_string(),
                call_id: "call-x".to_string(),
            },
            ResponseItem::FunctionCallOutput {
                call_id: "call-x".to_string(),
                output: FunctionCallOutputPayload {
                    content: "aborted".to_string(),
                    ..Default::default()
                },
            },
        ]
    );
}

#[cfg(not(debug_assertions))]
#[test]
fn normalize_adds_missing_output_for_custom_tool_call() {
    let items = vec![ResponseItem::CustomToolCall {
        id: None,
        status: None,
        call_id: "tool-x".to_string(),
        name: "custom".to_string(),
        input: "{}".to_string(),
    }];
    let mut h = create_history_with_items(items);

    h.normalize_history();

    assert_eq!(
        h.contents(),
        vec![
            ResponseItem::CustomToolCall {
                id: None,
                status: None,
                call_id: "tool-x".to_string(),
                name: "custom".to_string(),
                input: "{}".to_string(),
            },
            ResponseItem::CustomToolCallOutput {
                call_id: "tool-x".to_string(),
                output: "aborted".to_string(),
            },
        ]
    );
}

#[cfg(not(debug_assertions))]
#[test]
fn normalize_adds_missing_output_for_local_shell_call_with_id() {
    let items = vec![ResponseItem::LocalShellCall {
        id: None,
        call_id: Some("shell-1".to_string()),
        status: LocalShellStatus::Completed,
        action: LocalShellAction::Exec(LocalShellExecAction {
            command: vec!["echo".to_string(), "hi".to_string()],
            timeout_ms: None,
            working_directory: None,
            env: None,
            user: None,
        }),
    }];
    let mut h = create_history_with_items(items);

    h.normalize_history();

    assert_eq!(
        h.contents(),
        vec![
            ResponseItem::LocalShellCall {
                id: None,
                call_id: Some("shell-1".to_string()),
                status: LocalShellStatus::Completed,
                action: LocalShellAction::Exec(LocalShellExecAction {
                    command: vec!["echo".to_string(), "hi".to_string()],
                    timeout_ms: None,
                    working_directory: None,
                    env: None,
                    user: None,
                }),
            },
            ResponseItem::FunctionCallOutput {
                call_id: "shell-1".to_string(),
                output: FunctionCallOutputPayload {
                    content: "aborted".to_string(),
                    ..Default::default()
                },
            },
        ]
    );
}

#[cfg(not(debug_assertions))]
#[test]
fn normalize_removes_orphan_function_call_output() {
    let items = vec![ResponseItem::FunctionCallOutput {
        call_id: "orphan-1".to_string(),
        output: FunctionCallOutputPayload {
            content: "ok".to_string(),
            ..Default::default()
        },
    }];
    let mut h = create_history_with_items(items);

    h.normalize_history();

    assert_eq!(h.contents(), vec![]);
}

#[cfg(not(debug_assertions))]
#[test]
fn normalize_removes_orphan_custom_tool_call_output() {
    let items = vec![ResponseItem::CustomToolCallOutput {
        call_id: "orphan-2".to_string(),
        output: "ok".to_string(),
    }];
    let mut h = create_history_with_items(items);

    h.normalize_history();

    assert_eq!(h.contents(), vec![]);
}

#[cfg(not(debug_assertions))]
#[test]
fn normalize_mixed_inserts_and_removals() {
    let items = vec![
        // Will get an inserted output
        ResponseItem::FunctionCall {
            id: None,
            name: "f1".to_string(),
            arguments: "{}".to_string(),
            call_id: "c1".to_string(),
        },
        // Orphan output that should be removed
        ResponseItem::FunctionCallOutput {
            call_id: "c2".to_string(),
            output: FunctionCallOutputPayload {
                content: "ok".to_string(),
                ..Default::default()
            },
        },
        // Will get an inserted custom tool output
        ResponseItem::CustomToolCall {
            id: None,
            status: None,
            call_id: "t1".to_string(),
            name: "tool".to_string(),
            input: "{}".to_string(),
        },
        // Local shell call also gets an inserted function call output
        ResponseItem::LocalShellCall {
            id: None,
            call_id: Some("s1".to_string()),
            status: LocalShellStatus::Completed,
            action: LocalShellAction::Exec(LocalShellExecAction {
                command: vec!["echo".to_string()],
                timeout_ms: None,
                working_directory: None,
                env: None,
                user: None,
            }),
        },
    ];
    let mut h = create_history_with_items(items);

    h.normalize_history();

    assert_eq!(
        h.contents(),
        vec![
            ResponseItem::FunctionCall {
                id: None,
                name: "f1".to_string(),
                arguments: "{}".to_string(),
                call_id: "c1".to_string(),
            },
            ResponseItem::FunctionCallOutput {
                call_id: "c1".to_string(),
                output: FunctionCallOutputPayload {
                    content: "aborted".to_string(),
                    ..Default::default()
                },
            },
            ResponseItem::CustomToolCall {
                id: None,
                status: None,
                call_id: "t1".to_string(),
                name: "tool".to_string(),
                input: "{}".to_string(),
            },
            ResponseItem::CustomToolCallOutput {
                call_id: "t1".to_string(),
                output: "aborted".to_string(),
            },
            ResponseItem::LocalShellCall {
                id: None,
                call_id: Some("s1".to_string()),
                status: LocalShellStatus::Completed,
                action: LocalShellAction::Exec(LocalShellExecAction {
                    command: vec!["echo".to_string()],
                    timeout_ms: None,
                    working_directory: None,
                    env: None,
                    user: None,
                }),
            },
            ResponseItem::FunctionCallOutput {
                call_id: "s1".to_string(),
                output: FunctionCallOutputPayload {
                    content: "aborted".to_string(),
                    ..Default::default()
                },
            },
        ]
    );
}

// In debug builds we panic on normalization errors instead of silently fixing them.
#[cfg(debug_assertions)]
#[test]
#[should_panic]
fn normalize_adds_missing_output_for_function_call_panics_in_debug() {
    let items = vec![ResponseItem::FunctionCall {
        id: None,
        name: "do_it".to_string(),
        arguments: "{}".to_string(),
        call_id: "call-x".to_string(),
    }];
    let mut h = create_history_with_items(items);
    h.normalize_history();
}

#[cfg(debug_assertions)]
#[test]
#[should_panic]
fn normalize_adds_missing_output_for_custom_tool_call_panics_in_debug() {
    let items = vec![ResponseItem::CustomToolCall {
        id: None,
        status: None,
        call_id: "tool-x".to_string(),
        name: "custom".to_string(),
        input: "{}".to_string(),
    }];
    let mut h = create_history_with_items(items);
    h.normalize_history();
}

#[cfg(debug_assertions)]
#[test]
#[should_panic]
fn normalize_adds_missing_output_for_local_shell_call_with_id_panics_in_debug() {
    let items = vec![ResponseItem::LocalShellCall {
        id: None,
        call_id: Some("shell-1".to_string()),
        status: LocalShellStatus::Completed,
        action: LocalShellAction::Exec(LocalShellExecAction {
            command: vec!["echo".to_string(), "hi".to_string()],
            timeout_ms: None,
            working_directory: None,
            env: None,
            user: None,
        }),
    }];
    let mut h = create_history_with_items(items);
    h.normalize_history();
}

#[cfg(debug_assertions)]
#[test]
#[should_panic]
fn normalize_removes_orphan_function_call_output_panics_in_debug() {
    let items = vec![ResponseItem::FunctionCallOutput {
        call_id: "orphan-1".to_string(),
        output: FunctionCallOutputPayload {
            content: "ok".to_string(),
            ..Default::default()
        },
    }];
    let mut h = create_history_with_items(items);
    h.normalize_history();
}

#[cfg(debug_assertions)]
#[test]
#[should_panic]
fn normalize_removes_orphan_custom_tool_call_output_panics_in_debug() {
    let items = vec![ResponseItem::CustomToolCallOutput {
        call_id: "orphan-2".to_string(),
        output: "ok".to_string(),
    }];
    let mut h = create_history_with_items(items);
    h.normalize_history();
}

#[cfg(debug_assertions)]
#[test]
#[should_panic]
fn normalize_mixed_inserts_and_removals_panics_in_debug() {
    let items = vec![
        ResponseItem::FunctionCall {
            id: None,
            name: "f1".to_string(),
            arguments: "{}".to_string(),
            call_id: "c1".to_string(),
        },
        ResponseItem::FunctionCallOutput {
            call_id: "c2".to_string(),
            output: FunctionCallOutputPayload {
                content: "ok".to_string(),
                ..Default::default()
            },
        },
        ResponseItem::CustomToolCall {
            id: None,
            status: None,
            call_id: "t1".to_string(),
            name: "tool".to_string(),
            input: "{}".to_string(),
        },
        ResponseItem::LocalShellCall {
            id: None,
            call_id: Some("s1".to_string()),
            status: LocalShellStatus::Completed,
            action: LocalShellAction::Exec(LocalShellExecAction {
                command: vec!["echo".to_string()],
                timeout_ms: None,
                working_directory: None,
                env: None,
                user: None,
            }),
        },
    ];
    let mut h = create_history_with_items(items);
    h.normalize_history();
}
6
llmx-rs/core/src/context_manager/mod.rs
Normal file
@@ -0,0 +1,6 @@
mod history;
mod normalize;
mod truncate;

pub(crate) use history::ContextManager;
pub(crate) use truncate::format_output_for_model_body;
213
llmx-rs/core/src/context_manager/normalize.rs
Normal file
@@ -0,0 +1,213 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
|
||||
use crate::util::error_or_panic;
|
||||
|
||||
pub(crate) fn ensure_call_outputs_present(items: &mut Vec<ResponseItem>) {
|
||||
// Collect synthetic outputs to insert immediately after their calls.
|
||||
// Store the insertion position (index of call) alongside the item so
|
||||
// we can insert in reverse order and avoid index shifting.
|
||||
let mut missing_outputs_to_insert: Vec<(usize, ResponseItem)> = Vec::new();
|
||||
|
||||
for (idx, item) in items.iter().enumerate() {
|
||||
match item {
|
||||
ResponseItem::FunctionCall { call_id, .. } => {
|
||||
let has_output = items.iter().any(|i| match i {
|
||||
ResponseItem::FunctionCallOutput {
|
||||
call_id: existing, ..
|
||||
} => existing == call_id,
|
||||
_ => false,
|
||||
});
|
||||
|
||||
if !has_output {
|
||||
error_or_panic(format!(
|
||||
"Function call output is missing for call id: {call_id}"
|
||||
));
|
||||
                    missing_outputs_to_insert.push((
                        idx,
                        ResponseItem::FunctionCallOutput {
                            call_id: call_id.clone(),
                            output: FunctionCallOutputPayload {
                                content: "aborted".to_string(),
                                ..Default::default()
                            },
                        },
                    ));
                }
            }
            ResponseItem::CustomToolCall { call_id, .. } => {
                let has_output = items.iter().any(|i| match i {
                    ResponseItem::CustomToolCallOutput {
                        call_id: existing, ..
                    } => existing == call_id,
                    _ => false,
                });

                if !has_output {
                    error_or_panic(format!(
                        "Custom tool call output is missing for call id: {call_id}"
                    ));
                    missing_outputs_to_insert.push((
                        idx,
                        ResponseItem::CustomToolCallOutput {
                            call_id: call_id.clone(),
                            output: "aborted".to_string(),
                        },
                    ));
                }
            }
            // LocalShellCall is represented in upstream streams by a FunctionCallOutput
            ResponseItem::LocalShellCall { call_id, .. } => {
                if let Some(call_id) = call_id.as_ref() {
                    let has_output = items.iter().any(|i| match i {
                        ResponseItem::FunctionCallOutput {
                            call_id: existing, ..
                        } => existing == call_id,
                        _ => false,
                    });

                    if !has_output {
                        error_or_panic(format!(
                            "Local shell call output is missing for call id: {call_id}"
                        ));
                        missing_outputs_to_insert.push((
                            idx,
                            ResponseItem::FunctionCallOutput {
                                call_id: call_id.clone(),
                                output: FunctionCallOutputPayload {
                                    content: "aborted".to_string(),
                                    ..Default::default()
                                },
                            },
                        ));
                    }
                }
            }
            _ => {}
        }
    }

    // Insert synthetic outputs in reverse index order to avoid re-indexing.
    for (idx, output_item) in missing_outputs_to_insert.into_iter().rev() {
        items.insert(idx + 1, output_item);
    }
}

pub(crate) fn remove_orphan_outputs(items: &mut Vec<ResponseItem>) {
    let function_call_ids: HashSet<String> = items
        .iter()
        .filter_map(|i| match i {
            ResponseItem::FunctionCall { call_id, .. } => Some(call_id.clone()),
            _ => None,
        })
        .collect();

    let local_shell_call_ids: HashSet<String> = items
        .iter()
        .filter_map(|i| match i {
            ResponseItem::LocalShellCall {
                call_id: Some(call_id),
                ..
            } => Some(call_id.clone()),
            _ => None,
        })
        .collect();

    let custom_tool_call_ids: HashSet<String> = items
        .iter()
        .filter_map(|i| match i {
            ResponseItem::CustomToolCall { call_id, .. } => Some(call_id.clone()),
            _ => None,
        })
        .collect();

    items.retain(|item| match item {
        ResponseItem::FunctionCallOutput { call_id, .. } => {
            let has_match =
                function_call_ids.contains(call_id) || local_shell_call_ids.contains(call_id);
            if !has_match {
                error_or_panic(format!(
                    "Orphan function call output for call id: {call_id}"
                ));
            }
            has_match
        }
        ResponseItem::CustomToolCallOutput { call_id, .. } => {
            let has_match = custom_tool_call_ids.contains(call_id);
            if !has_match {
                error_or_panic(format!(
                    "Orphan custom tool call output for call id: {call_id}"
                ));
            }
            has_match
        }
        _ => true,
    });
}

pub(crate) fn remove_corresponding_for(items: &mut Vec<ResponseItem>, item: &ResponseItem) {
    match item {
        ResponseItem::FunctionCall { call_id, .. } => {
            remove_first_matching(items, |i| {
                matches!(
                    i,
                    ResponseItem::FunctionCallOutput {
                        call_id: existing, ..
                    } if existing == call_id
                )
            });
        }
        ResponseItem::FunctionCallOutput { call_id, .. } => {
            if let Some(pos) = items.iter().position(|i| {
                matches!(i, ResponseItem::FunctionCall { call_id: existing, .. } if existing == call_id)
            }) {
                items.remove(pos);
            } else if let Some(pos) = items.iter().position(|i| {
                matches!(i, ResponseItem::LocalShellCall { call_id: Some(existing), .. } if existing == call_id)
            }) {
                items.remove(pos);
            }
        }
        ResponseItem::CustomToolCall { call_id, .. } => {
            remove_first_matching(items, |i| {
                matches!(
                    i,
                    ResponseItem::CustomToolCallOutput {
                        call_id: existing, ..
                    } if existing == call_id
                )
            });
        }
        ResponseItem::CustomToolCallOutput { call_id, .. } => {
            remove_first_matching(
                items,
                |i| matches!(i, ResponseItem::CustomToolCall { call_id: existing, .. } if existing == call_id),
            );
        }
        ResponseItem::LocalShellCall {
            call_id: Some(call_id),
            ..
        } => {
            remove_first_matching(items, |i| {
                matches!(
                    i,
                    ResponseItem::FunctionCallOutput {
                        call_id: existing, ..
                    } if existing == call_id
                )
            });
        }
        _ => {}
    }
}

fn remove_first_matching<F>(items: &mut Vec<ResponseItem>, predicate: F)
where
    F: Fn(&ResponseItem) -> bool,
{
    if let Some(pos) = items.iter().position(predicate) {
        items.remove(pos);
    }
}
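
// Illustrative sketch (not part of this commit): how `remove_corresponding_for`
// pairs a call with its output. The field values below are hypothetical.
#[cfg(test)]
mod pairing_example {
    use super::*;

    #[test]
    fn removes_output_paired_with_function_call() {
        let call = ResponseItem::FunctionCall {
            id: None,
            name: "tool".to_string(),
            arguments: "{}".to_string(),
            call_id: "c1".to_string(),
        };
        let mut items = vec![ResponseItem::FunctionCallOutput {
            call_id: "c1".to_string(),
            output: FunctionCallOutputPayload {
                content: "ok".to_string(),
                ..Default::default()
            },
        }];
        // Given the call, its corresponding output is removed from `items`.
        remove_corresponding_for(&mut items, &call);
        assert!(items.is_empty());
    }
}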
128
llmx-rs/core/src/context_manager/truncate.rs
Normal file
@@ -0,0 +1,128 @@
use codex_protocol::models::FunctionCallOutputContentItem;
use codex_utils_string::take_bytes_at_char_boundary;
use codex_utils_string::take_last_bytes_at_char_boundary;

// Model-formatting limits: clients get full streams; only content sent to the model is truncated.
pub(crate) const MODEL_FORMAT_MAX_BYTES: usize = 10 * 1024; // 10 KiB
pub(crate) const MODEL_FORMAT_MAX_LINES: usize = 256; // lines
pub(crate) const MODEL_FORMAT_HEAD_LINES: usize = MODEL_FORMAT_MAX_LINES / 2;
pub(crate) const MODEL_FORMAT_TAIL_LINES: usize = MODEL_FORMAT_MAX_LINES - MODEL_FORMAT_HEAD_LINES; // 128
pub(crate) const MODEL_FORMAT_HEAD_BYTES: usize = MODEL_FORMAT_MAX_BYTES / 2;

pub(crate) fn globally_truncate_function_output_items(
    items: &[FunctionCallOutputContentItem],
) -> Vec<FunctionCallOutputContentItem> {
    let mut out: Vec<FunctionCallOutputContentItem> = Vec::with_capacity(items.len());
    let mut remaining = MODEL_FORMAT_MAX_BYTES;
    let mut omitted_text_items = 0usize;

    for it in items {
        match it {
            FunctionCallOutputContentItem::InputText { text } => {
                if remaining == 0 {
                    omitted_text_items += 1;
                    continue;
                }

                let len = text.len();
                if len <= remaining {
                    out.push(FunctionCallOutputContentItem::InputText { text: text.clone() });
                    remaining -= len;
                } else {
                    let slice = take_bytes_at_char_boundary(text, remaining);
                    if !slice.is_empty() {
                        out.push(FunctionCallOutputContentItem::InputText {
                            text: slice.to_string(),
                        });
                    }
                    remaining = 0;
                }
            }
            // todo(aibrahim): handle input images; resize
            FunctionCallOutputContentItem::InputImage { image_url } => {
                out.push(FunctionCallOutputContentItem::InputImage {
                    image_url: image_url.clone(),
                });
            }
        }
    }

    if omitted_text_items > 0 {
        out.push(FunctionCallOutputContentItem::InputText {
            text: format!("[omitted {omitted_text_items} text items ...]"),
        });
    }

    out
}

pub(crate) fn format_output_for_model_body(content: &str) -> String {
    // Head+tail truncation for the model: show the beginning and end with an elision.
    // Clients still receive full streams; only this formatted summary is capped.
    let total_lines = content.lines().count();
    if content.len() <= MODEL_FORMAT_MAX_BYTES && total_lines <= MODEL_FORMAT_MAX_LINES {
        return content.to_string();
    }
    let output = truncate_formatted_exec_output(content, total_lines);
    format!("Total output lines: {total_lines}\n\n{output}")
}

fn truncate_formatted_exec_output(content: &str, total_lines: usize) -> String {
    let segments: Vec<&str> = content.split_inclusive('\n').collect();
    let head_take = MODEL_FORMAT_HEAD_LINES.min(segments.len());
    let tail_take = MODEL_FORMAT_TAIL_LINES.min(segments.len().saturating_sub(head_take));
    let omitted = segments.len().saturating_sub(head_take + tail_take);

    let head_slice_end: usize = segments
        .iter()
        .take(head_take)
        .map(|segment| segment.len())
        .sum();
    let tail_slice_start: usize = if tail_take == 0 {
        content.len()
    } else {
        content.len()
            - segments
                .iter()
                .rev()
                .take(tail_take)
                .map(|segment| segment.len())
                .sum::<usize>()
    };
    let head_slice = &content[..head_slice_end];
    let tail_slice = &content[tail_slice_start..];
    let truncated_by_bytes = content.len() > MODEL_FORMAT_MAX_BYTES;
    // Note: this count is slightly off; it includes metadata lines, not just shell output lines.
    let marker = if omitted > 0 {
        Some(format!(
            "\n[... omitted {omitted} of {total_lines} lines ...]\n\n"
        ))
    } else if truncated_by_bytes {
        Some(format!(
            "\n[... output truncated to fit {MODEL_FORMAT_MAX_BYTES} bytes ...]\n\n"
        ))
    } else {
        None
    };

    let marker_len = marker.as_ref().map_or(0, String::len);
    let base_head_budget = MODEL_FORMAT_HEAD_BYTES.min(MODEL_FORMAT_MAX_BYTES);
    let head_budget = base_head_budget.min(MODEL_FORMAT_MAX_BYTES.saturating_sub(marker_len));
    let head_part = take_bytes_at_char_boundary(head_slice, head_budget);
    let mut result = String::with_capacity(MODEL_FORMAT_MAX_BYTES.min(content.len()));

    result.push_str(head_part);
    if let Some(marker_text) = marker.as_ref() {
        result.push_str(marker_text);
    }

    let remaining = MODEL_FORMAT_MAX_BYTES.saturating_sub(result.len());
    if remaining == 0 {
        return result;
    }

    let tail_part = take_last_bytes_at_char_boundary(tail_slice, remaining);
    result.push_str(tail_part);

    result
}
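
// Illustrative sketch (not part of this commit): shape of a truncated body.
// For output over the caps, `format_output_for_model_body` keeps the head and
// tail and splices in an elision marker. Input values here are hypothetical.
#[cfg(test)]
mod truncation_example {
    use super::*;

    #[test]
    fn long_output_keeps_head_and_tail() {
        let content: String = (0..1000).map(|i| format!("line {i}\n")).collect();
        let formatted = format_output_for_model_body(&content);
        assert!(formatted.starts_with("Total output lines: 1000"));
        assert!(formatted.contains("line 0\n"));
        assert!(formatted.contains("[... omitted"));
        assert!(formatted.contains("line 999"));
    }
}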
339
llmx-rs/core/src/conversation_manager.rs
Normal file
@@ -0,0 +1,339 @@
use crate::AuthManager;
use crate::CodexAuth;
use crate::codex::Codex;
use crate::codex::CodexSpawnOk;
use crate::codex::INITIAL_SUBMIT_ID;
use crate::codex_conversation::CodexConversation;
use crate::config::Config;
use crate::error::CodexErr;
use crate::error::Result as CodexResult;
use crate::protocol::Event;
use crate::protocol::EventMsg;
use crate::protocol::SessionConfiguredEvent;
use crate::rollout::RolloutRecorder;
use codex_protocol::ConversationId;
use codex_protocol::items::TurnItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::InitialHistory;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::SessionSource;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;

/// Represents a newly created Codex conversation, including the first event
/// (which is [`EventMsg::SessionConfigured`]).
pub struct NewConversation {
    pub conversation_id: ConversationId,
    pub conversation: Arc<CodexConversation>,
    pub session_configured: SessionConfiguredEvent,
}

/// [`ConversationManager`] is responsible for creating conversations and
/// maintaining them in memory.
pub struct ConversationManager {
    conversations: Arc<RwLock<HashMap<ConversationId, Arc<CodexConversation>>>>,
    auth_manager: Arc<AuthManager>,
    session_source: SessionSource,
}

impl ConversationManager {
    pub fn new(auth_manager: Arc<AuthManager>, session_source: SessionSource) -> Self {
        Self {
            conversations: Arc::new(RwLock::new(HashMap::new())),
            auth_manager,
            session_source,
        }
    }

    /// Construct with a dummy AuthManager containing the provided CodexAuth.
    /// Used for integration tests: should not be used by ordinary business logic.
    pub fn with_auth(auth: CodexAuth) -> Self {
        Self::new(
            crate::AuthManager::from_auth_for_testing(auth),
            SessionSource::Exec,
        )
    }

    pub async fn new_conversation(&self, config: Config) -> CodexResult<NewConversation> {
        self.spawn_conversation(config, self.auth_manager.clone())
            .await
    }

    async fn spawn_conversation(
        &self,
        config: Config,
        auth_manager: Arc<AuthManager>,
    ) -> CodexResult<NewConversation> {
        let CodexSpawnOk {
            codex,
            conversation_id,
        } = Codex::spawn(
            config,
            auth_manager,
            InitialHistory::New,
            self.session_source.clone(),
        )
        .await?;
        self.finalize_spawn(codex, conversation_id).await
    }

    async fn finalize_spawn(
        &self,
        codex: Codex,
        conversation_id: ConversationId,
    ) -> CodexResult<NewConversation> {
        // The first event must be `SessionConfigured`. Validate it and forward
        // it to the caller so that they can display it in the conversation
        // history.
        let event = codex.next_event().await?;
        let session_configured = match event {
            Event {
                id,
                msg: EventMsg::SessionConfigured(session_configured),
            } if id == INITIAL_SUBMIT_ID => session_configured,
            _ => {
                return Err(CodexErr::SessionConfiguredNotFirstEvent);
            }
        };

        let conversation = Arc::new(CodexConversation::new(
            codex,
            session_configured.rollout_path.clone(),
        ));
        self.conversations
            .write()
            .await
            .insert(conversation_id, conversation.clone());

        Ok(NewConversation {
            conversation_id,
            conversation,
            session_configured,
        })
    }

    pub async fn get_conversation(
        &self,
        conversation_id: ConversationId,
    ) -> CodexResult<Arc<CodexConversation>> {
        let conversations = self.conversations.read().await;
        conversations
            .get(&conversation_id)
            .cloned()
            .ok_or_else(|| CodexErr::ConversationNotFound(conversation_id))
    }

    pub async fn resume_conversation_from_rollout(
        &self,
        config: Config,
        rollout_path: PathBuf,
        auth_manager: Arc<AuthManager>,
    ) -> CodexResult<NewConversation> {
        let initial_history = RolloutRecorder::get_rollout_history(&rollout_path).await?;
        self.resume_conversation_with_history(config, initial_history, auth_manager)
            .await
    }

    pub async fn resume_conversation_with_history(
        &self,
        config: Config,
        initial_history: InitialHistory,
        auth_manager: Arc<AuthManager>,
    ) -> CodexResult<NewConversation> {
        let CodexSpawnOk {
            codex,
            conversation_id,
        } = Codex::spawn(
            config,
            auth_manager,
            initial_history,
            self.session_source.clone(),
        )
        .await?;
        self.finalize_spawn(codex, conversation_id).await
    }

    /// Removes the conversation from the manager's internal map. Because the
    /// conversation is stored as an `Arc<CodexConversation>`, other references
    /// to it may still exist elsewhere. Returns the conversation if it was
    /// found and removed.
    pub async fn remove_conversation(
        &self,
        conversation_id: &ConversationId,
    ) -> Option<Arc<CodexConversation>> {
        self.conversations.write().await.remove(conversation_id)
    }

    /// Fork an existing conversation by taking messages up to the given position
    /// (not including the message at the given position) and starting a new
    /// conversation with identical configuration (unless overridden by the
    /// caller's `config`). The new conversation will have a fresh id.
    pub async fn fork_conversation(
        &self,
        nth_user_message: usize,
        config: Config,
        path: PathBuf,
    ) -> CodexResult<NewConversation> {
        // Compute the prefix up to the cut point.
        let history = RolloutRecorder::get_rollout_history(&path).await?;
        let history = truncate_before_nth_user_message(history, nth_user_message);

        // Spawn a new conversation with the computed initial history.
        let auth_manager = self.auth_manager.clone();
        let CodexSpawnOk {
            codex,
            conversation_id,
        } = Codex::spawn(config, auth_manager, history, self.session_source.clone()).await?;

        self.finalize_spawn(codex, conversation_id).await
    }
}
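
// Illustrative sketch (not part of this commit): a caller forking a recorded
// session. `config` and `rollout_path` are assumed inputs; n = 1 keeps only
// history strictly before the second user message (the index is 0-based).
#[allow(dead_code)]
async fn example_fork(
    manager: &ConversationManager,
    config: Config,
    rollout_path: PathBuf,
) -> CodexResult<NewConversation> {
    manager.fork_conversation(1, config, rollout_path).await
}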

/// Return a prefix of `items` obtained by cutting strictly before the nth user message
/// (0-based) and all items that follow it.
fn truncate_before_nth_user_message(history: InitialHistory, n: usize) -> InitialHistory {
    // Work directly on rollout items, and cut the vector at the nth user message input.
    let items: Vec<RolloutItem> = history.get_rollout_items();

    // Find indices of user message inputs in rollout order.
    let mut user_positions: Vec<usize> = Vec::new();
    for (idx, item) in items.iter().enumerate() {
        if let RolloutItem::ResponseItem(item @ ResponseItem::Message { .. }) = item
            && matches!(
                crate::event_mapping::parse_turn_item(item),
                Some(TurnItem::UserMessage(_))
            )
        {
            user_positions.push(idx);
        }
    }

    // If fewer than or equal to n user messages exist, treat as empty (out of range).
    if user_positions.len() <= n {
        return InitialHistory::New;
    }

    // Cut strictly before the nth user message (do not keep the nth itself).
    let cut_idx = user_positions[n];
    let rolled: Vec<RolloutItem> = items.into_iter().take(cut_idx).collect();

    if rolled.is_empty() {
        InitialHistory::New
    } else {
        InitialHistory::Forked(rolled)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::codex::make_session_and_context;
    use assert_matches::assert_matches;
    use codex_protocol::models::ContentItem;
    use codex_protocol::models::ReasoningItemReasoningSummary;
    use codex_protocol::models::ResponseItem;
    use pretty_assertions::assert_eq;

    fn user_msg(text: &str) -> ResponseItem {
        ResponseItem::Message {
            id: None,
            role: "user".to_string(),
            content: vec![ContentItem::OutputText {
                text: text.to_string(),
            }],
        }
    }
    fn assistant_msg(text: &str) -> ResponseItem {
        ResponseItem::Message {
            id: None,
            role: "assistant".to_string(),
            content: vec![ContentItem::OutputText {
                text: text.to_string(),
            }],
        }
    }

    #[test]
    fn drops_from_last_user_only() {
        let items = [
            user_msg("u1"),
            assistant_msg("a1"),
            assistant_msg("a2"),
            user_msg("u2"),
            assistant_msg("a3"),
            ResponseItem::Reasoning {
                id: "r1".to_string(),
                summary: vec![ReasoningItemReasoningSummary::SummaryText {
                    text: "s".to_string(),
                }],
                content: None,
                encrypted_content: None,
            },
            ResponseItem::FunctionCall {
                id: None,
                name: "tool".to_string(),
                arguments: "{}".to_string(),
                call_id: "c1".to_string(),
            },
            assistant_msg("a4"),
        ];

        // Wrap as InitialHistory::Forked with response items only.
        let initial: Vec<RolloutItem> = items
            .iter()
            .cloned()
            .map(RolloutItem::ResponseItem)
            .collect();
        let truncated = truncate_before_nth_user_message(InitialHistory::Forked(initial), 1);
        let got_items = truncated.get_rollout_items();
        let expected_items = vec![
            RolloutItem::ResponseItem(items[0].clone()),
            RolloutItem::ResponseItem(items[1].clone()),
            RolloutItem::ResponseItem(items[2].clone()),
        ];
        assert_eq!(
            serde_json::to_value(&got_items).unwrap(),
            serde_json::to_value(&expected_items).unwrap()
        );

        let initial2: Vec<RolloutItem> = items
            .iter()
            .cloned()
            .map(RolloutItem::ResponseItem)
            .collect();
        let truncated2 = truncate_before_nth_user_message(InitialHistory::Forked(initial2), 2);
        assert_matches!(truncated2, InitialHistory::New);
    }

    #[test]
    fn ignores_session_prefix_messages_when_truncating() {
        let (session, turn_context) = make_session_and_context();
        let mut items = session.build_initial_context(&turn_context);
        items.push(user_msg("feature request"));
        items.push(assistant_msg("ack"));
        items.push(user_msg("second question"));
        items.push(assistant_msg("answer"));

        let rollout_items: Vec<RolloutItem> = items
            .iter()
            .cloned()
            .map(RolloutItem::ResponseItem)
            .collect();

        let truncated = truncate_before_nth_user_message(InitialHistory::Forked(rollout_items), 1);
        let got_items = truncated.get_rollout_items();

        let expected: Vec<RolloutItem> = vec![
            RolloutItem::ResponseItem(items[0].clone()),
            RolloutItem::ResponseItem(items[1].clone()),
            RolloutItem::ResponseItem(items[2].clone()),
        ];

        assert_eq!(
            serde_json::to_value(&got_items).unwrap(),
            serde_json::to_value(&expected).unwrap()
        );
    }
}
244
llmx-rs/core/src/custom_prompts.rs
Normal file
@@ -0,0 +1,244 @@
use codex_protocol::custom_prompts::CustomPrompt;
use std::collections::HashSet;
use std::path::Path;
use std::path::PathBuf;
use tokio::fs;

/// Return the default prompts directory: `$CODEX_HOME/prompts`.
/// If `CODEX_HOME` cannot be resolved, returns `None`.
pub fn default_prompts_dir() -> Option<PathBuf> {
    crate::config::find_codex_home()
        .ok()
        .map(|home| home.join("prompts"))
}

/// Discover prompt files in the given directory, returning entries sorted by name.
/// Non-files are ignored. If the directory does not exist or cannot be read, returns empty.
pub async fn discover_prompts_in(dir: &Path) -> Vec<CustomPrompt> {
    discover_prompts_in_excluding(dir, &HashSet::new()).await
}

/// Discover prompt files in the given directory, excluding any with names in `exclude`.
/// Returns entries sorted by name. Non-files are ignored. Missing/unreadable dir yields empty.
pub async fn discover_prompts_in_excluding(
    dir: &Path,
    exclude: &HashSet<String>,
) -> Vec<CustomPrompt> {
    let mut out: Vec<CustomPrompt> = Vec::new();
    let mut entries = match fs::read_dir(dir).await {
        Ok(entries) => entries,
        Err(_) => return out,
    };

    while let Ok(Some(entry)) = entries.next_entry().await {
        let path = entry.path();
        let is_file_like = fs::metadata(&path)
            .await
            .map(|m| m.is_file())
            .unwrap_or(false);
        if !is_file_like {
            continue;
        }
        // Only include Markdown files with a .md extension.
        let is_md = path
            .extension()
            .and_then(|s| s.to_str())
            .map(|ext| ext.eq_ignore_ascii_case("md"))
            .unwrap_or(false);
        if !is_md {
            continue;
        }
        let Some(name) = path
            .file_stem()
            .and_then(|s| s.to_str())
            .map(str::to_string)
        else {
            continue;
        };
        if exclude.contains(&name) {
            continue;
        }
        let content = match fs::read_to_string(&path).await {
            Ok(s) => s,
            Err(_) => continue,
        };
        let (description, argument_hint, body) = parse_frontmatter(&content);
        out.push(CustomPrompt {
            name,
            path,
            content: body,
            description,
            argument_hint,
        });
    }
    out.sort_by(|a, b| a.name.cmp(&b.name));
    out
}

/// Parse optional YAML-like frontmatter at the beginning of `content`.
/// Supported keys:
/// - `description`: short description shown in the slash popup
/// - `argument-hint` or `argument_hint`: brief hint string shown after the description
/// Returns (description, argument_hint, body_without_frontmatter).
fn parse_frontmatter(content: &str) -> (Option<String>, Option<String>, String) {
    let mut segments = content.split_inclusive('\n');
    let Some(first_segment) = segments.next() else {
        return (None, None, String::new());
    };
    let first_line = first_segment.trim_end_matches(['\r', '\n']);
    if first_line.trim() != "---" {
        return (None, None, content.to_string());
    }

    let mut desc: Option<String> = None;
    let mut hint: Option<String> = None;
    let mut frontmatter_closed = false;
    let mut consumed = first_segment.len();

    for segment in segments {
        let line = segment.trim_end_matches(['\r', '\n']);
        let trimmed = line.trim();

        if trimmed == "---" {
            frontmatter_closed = true;
            consumed += segment.len();
            break;
        }

        if trimmed.is_empty() || trimmed.starts_with('#') {
            consumed += segment.len();
            continue;
        }

        if let Some((k, v)) = trimmed.split_once(':') {
            let key = k.trim().to_ascii_lowercase();
            let mut val = v.trim().to_string();
            if val.len() >= 2 {
                let bytes = val.as_bytes();
                let first = bytes[0];
                let last = bytes[bytes.len() - 1];
                if (first == b'\"' && last == b'\"') || (first == b'\'' && last == b'\'') {
                    val = val[1..val.len().saturating_sub(1)].to_string();
                }
            }
            match key.as_str() {
                "description" => desc = Some(val),
                "argument-hint" | "argument_hint" => hint = Some(val),
                _ => {}
            }
        }

        consumed += segment.len();
    }

    if !frontmatter_closed {
        // Unterminated frontmatter: treat input as-is.
        return (None, None, content.to_string());
    }

    let body = if consumed >= content.len() {
        String::new()
    } else {
        content[consumed..].to_string()
    };
    (desc, hint, body)
}
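
// Illustrative sketch (not part of this commit): what `parse_frontmatter`
// returns for a typical prompt file. The file content is hypothetical.
#[cfg(test)]
mod frontmatter_example {
    use super::*;

    #[test]
    fn extracts_keys_and_body() {
        let content = "---\ndescription: Review a file\nargument-hint: \"[file]\"\n---\nBody text\n";
        let (desc, hint, body) = parse_frontmatter(content);
        assert_eq!(desc.as_deref(), Some("Review a file"));
        assert_eq!(hint.as_deref(), Some("[file]"));
        assert_eq!(body, "Body text\n");
    }
}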

#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    #[tokio::test]
    async fn empty_when_dir_missing() {
        let tmp = tempdir().expect("create TempDir");
        let missing = tmp.path().join("nope");
        let found = discover_prompts_in(&missing).await;
        assert!(found.is_empty());
    }

    #[tokio::test]
    async fn discovers_and_sorts_files() {
        let tmp = tempdir().expect("create TempDir");
        let dir = tmp.path();
        fs::write(dir.join("b.md"), b"b").unwrap();
        fs::write(dir.join("a.md"), b"a").unwrap();
        fs::create_dir(dir.join("subdir")).unwrap();
        let found = discover_prompts_in(dir).await;
        let names: Vec<String> = found.into_iter().map(|e| e.name).collect();
        assert_eq!(names, vec!["a", "b"]);
    }

    #[tokio::test]
    async fn excludes_builtins() {
        let tmp = tempdir().expect("create TempDir");
        let dir = tmp.path();
        fs::write(dir.join("init.md"), b"ignored").unwrap();
        fs::write(dir.join("foo.md"), b"ok").unwrap();
        let mut exclude = HashSet::new();
        exclude.insert("init".to_string());
        let found = discover_prompts_in_excluding(dir, &exclude).await;
        let names: Vec<String> = found.into_iter().map(|e| e.name).collect();
        assert_eq!(names, vec!["foo"]);
    }

    #[tokio::test]
    async fn skips_non_utf8_files() {
        let tmp = tempdir().expect("create TempDir");
        let dir = tmp.path();
        // Valid UTF-8 file
        fs::write(dir.join("good.md"), b"hello").unwrap();
        // Invalid UTF-8 content in .md file (e.g., lone 0xFF byte)
        fs::write(dir.join("bad.md"), vec![0xFF, 0xFE, b'\n']).unwrap();
        let found = discover_prompts_in(dir).await;
        let names: Vec<String> = found.into_iter().map(|e| e.name).collect();
        assert_eq!(names, vec!["good"]);
    }

    #[tokio::test]
    #[cfg(unix)]
    async fn discovers_symlinked_md_files() {
        let tmp = tempdir().expect("create TempDir");
        let dir = tmp.path();

        // Create a real file
        fs::write(dir.join("real.md"), b"real content").unwrap();

        // Create a symlink to the real file
        std::os::unix::fs::symlink(dir.join("real.md"), dir.join("link.md")).unwrap();

        let found = discover_prompts_in(dir).await;
        let names: Vec<String> = found.into_iter().map(|e| e.name).collect();

        // Both real and link should be discovered, sorted alphabetically
        assert_eq!(names, vec!["link", "real"]);
    }

    #[tokio::test]
    async fn parses_frontmatter_and_strips_from_body() {
        let tmp = tempdir().expect("create TempDir");
        let dir = tmp.path();
        let file = dir.join("withmeta.md");
        let text = "---\nname: ignored\ndescription: \"Quick review command\"\nargument-hint: \"[file] [priority]\"\n---\nActual body with $1 and $ARGUMENTS";
        fs::write(&file, text).unwrap();

        let found = discover_prompts_in(dir).await;
        assert_eq!(found.len(), 1);
        let p = &found[0];
        assert_eq!(p.name, "withmeta");
        assert_eq!(p.description.as_deref(), Some("Quick review command"));
        assert_eq!(p.argument_hint.as_deref(), Some("[file] [priority]"));
        // Body should not include the frontmatter delimiters.
        assert_eq!(p.content, "Actual body with $1 and $ARGUMENTS");
    }

    #[test]
    fn parse_frontmatter_preserves_body_newlines() {
        let content = "---\r\ndescription: \"Line endings\"\r\nargument_hint: \"[arg]\"\r\n---\r\nFirst line\r\nSecond line\r\n";
        let (desc, hint, body) = parse_frontmatter(content);
        assert_eq!(desc.as_deref(), Some("Line endings"));
        assert_eq!(hint.as_deref(), Some("[arg]"));
        assert_eq!(body, "First line\r\nSecond line\r\n");
    }
}
375
llmx-rs/core/src/default_client.rs
Normal file
@@ -0,0 +1,375 @@
use crate::spawn::CODEX_SANDBOX_ENV_VAR;
use http::Error as HttpError;
use reqwest::IntoUrl;
use reqwest::Method;
use reqwest::Response;
use reqwest::header::HeaderName;
use reqwest::header::HeaderValue;
use serde::Serialize;
use std::collections::HashMap;
use std::fmt::Display;
use std::sync::LazyLock;
use std::sync::Mutex;
use std::sync::OnceLock;

/// Set this to add a suffix to the User-Agent string.
///
/// It is not ideal that we're using a global singleton for this.
/// This is primarily designed to differentiate MCP clients from each other.
/// Because there can only be one MCP server per process, it should be safe for this to be a global static.
/// However, future users of this should use this with caution as a result.
/// In addition, we want to be confident that this value is used for ALL clients, and doing that requires a
/// lot of wiring, and it's easy to miss code paths by doing so.
/// See https://github.com/openai/codex/pull/3388/files for an example of what that would look like.
/// Finally, we want to make sure this is set for ALL MCP clients without needing to know a special env var
/// or having to set data that they already specified in the MCP initialize request somewhere else.
///
/// A space is automatically added between the suffix and the rest of the User-Agent string.
/// The full user agent string is returned from the MCP initialize response.
/// Parentheses will be added by Codex. This should only specify what goes inside of the parentheses.
pub static USER_AGENT_SUFFIX: LazyLock<Mutex<Option<String>>> = LazyLock::new(|| Mutex::new(None));
pub const DEFAULT_ORIGINATOR: &str = "codex_cli_rs";
pub const CODEX_INTERNAL_ORIGINATOR_OVERRIDE_ENV_VAR: &str = "CODEX_INTERNAL_ORIGINATOR_OVERRIDE";
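
// Illustrative sketch (not part of this commit): an MCP server process could
// tag its outbound clients by setting the suffix once at startup, before any
// client is created. The client name below is hypothetical.
#[allow(dead_code)]
fn example_set_user_agent_suffix() {
    if let Ok(mut guard) = USER_AGENT_SUFFIX.lock() {
        // Rendered inside parentheses at the end of the User-Agent string.
        *guard = Some("example-mcp-client/1.2.3".to_string());
    }
}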

#[derive(Clone, Debug)]
pub struct CodexHttpClient {
    inner: reqwest::Client,
}

impl CodexHttpClient {
    fn new(inner: reqwest::Client) -> Self {
        Self { inner }
    }

    pub fn get<U>(&self, url: U) -> CodexRequestBuilder
    where
        U: IntoUrl,
    {
        self.request(Method::GET, url)
    }

    pub fn post<U>(&self, url: U) -> CodexRequestBuilder
    where
        U: IntoUrl,
    {
        self.request(Method::POST, url)
    }

    pub fn request<U>(&self, method: Method, url: U) -> CodexRequestBuilder
    where
        U: IntoUrl,
    {
        let url_str = url.as_str().to_string();
        CodexRequestBuilder::new(self.inner.request(method.clone(), url), method, url_str)
    }
}

#[must_use = "requests are not sent unless `send` is awaited"]
#[derive(Debug)]
pub struct CodexRequestBuilder {
    builder: reqwest::RequestBuilder,
    method: Method,
    url: String,
}

impl CodexRequestBuilder {
    fn new(builder: reqwest::RequestBuilder, method: Method, url: String) -> Self {
        Self {
            builder,
            method,
            url,
        }
    }

    fn map(self, f: impl FnOnce(reqwest::RequestBuilder) -> reqwest::RequestBuilder) -> Self {
        Self {
            builder: f(self.builder),
            method: self.method,
            url: self.url,
        }
    }

    pub fn header<K, V>(self, key: K, value: V) -> Self
    where
        HeaderName: TryFrom<K>,
        <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
        HeaderValue: TryFrom<V>,
        <HeaderValue as TryFrom<V>>::Error: Into<HttpError>,
    {
        self.map(|builder| builder.header(key, value))
    }

    pub fn bearer_auth<T>(self, token: T) -> Self
    where
        T: Display,
    {
        self.map(|builder| builder.bearer_auth(token))
    }

    pub fn json<T>(self, value: &T) -> Self
    where
        T: ?Sized + Serialize,
    {
        self.map(|builder| builder.json(value))
    }

    pub async fn send(self) -> Result<Response, reqwest::Error> {
        match self.builder.send().await {
            Ok(response) => {
                let request_ids = Self::extract_request_ids(&response);
                tracing::debug!(
                    method = %self.method,
                    url = %self.url,
                    status = %response.status(),
                    request_ids = ?request_ids,
                    version = ?response.version(),
                    "Request completed"
                );

                Ok(response)
            }
            Err(error) => {
                let status = error.status();
                tracing::debug!(
                    method = %self.method,
                    url = %self.url,
                    status = status.map(|s| s.as_u16()),
                    error = %error,
                    "Request failed"
                );
                Err(error)
            }
        }
    }

    fn extract_request_ids(response: &Response) -> HashMap<String, String> {
        ["cf-ray", "x-request-id", "x-oai-request-id"]
            .iter()
            .filter_map(|&name| {
                let header_name = HeaderName::from_static(name);
                let value = response.headers().get(header_name)?;
                let value = value.to_str().ok()?.to_owned();
                Some((name.to_owned(), value))
            })
            .collect()
    }
}
#[derive(Debug, Clone)]
pub struct Originator {
    pub value: String,
    pub header_value: HeaderValue,
}
static ORIGINATOR: OnceLock<Originator> = OnceLock::new();

#[derive(Debug)]
pub enum SetOriginatorError {
    InvalidHeaderValue,
    AlreadyInitialized,
}

fn get_originator_value(provided: Option<String>) -> Originator {
    let value = std::env::var(CODEX_INTERNAL_ORIGINATOR_OVERRIDE_ENV_VAR)
        .ok()
        .or(provided)
        .unwrap_or(DEFAULT_ORIGINATOR.to_string());

    match HeaderValue::from_str(&value) {
        Ok(header_value) => Originator {
            value,
            header_value,
        },
        Err(e) => {
            tracing::error!("Unable to turn originator override {value} into header value: {e}");
            Originator {
                value: DEFAULT_ORIGINATOR.to_string(),
                header_value: HeaderValue::from_static(DEFAULT_ORIGINATOR),
            }
        }
    }
}

pub fn set_default_originator(value: String) -> Result<(), SetOriginatorError> {
    let originator = get_originator_value(Some(value));
    ORIGINATOR
        .set(originator)
        .map_err(|_| SetOriginatorError::AlreadyInitialized)
}

pub fn originator() -> &'static Originator {
    ORIGINATOR.get_or_init(|| get_originator_value(None))
}

pub fn get_codex_user_agent() -> String {
    let build_version = env!("CARGO_PKG_VERSION");
    let os_info = os_info::get();
    let prefix = format!(
        "{}/{build_version} ({} {}; {}) {}",
        originator().value.as_str(),
        os_info.os_type(),
        os_info.version(),
        os_info.architecture().unwrap_or("unknown"),
        crate::terminal::user_agent()
    );
    let suffix = USER_AGENT_SUFFIX
        .lock()
        .ok()
        .and_then(|guard| guard.clone());
    let suffix = suffix
        .as_deref()
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .map_or_else(String::new, |value| format!(" ({value})"));

    let candidate = format!("{prefix}{suffix}");
    sanitize_user_agent(candidate, &prefix)
}

/// Sanitize the user agent string.
///
/// Invalid characters are replaced with an underscore.
///
/// If the user agent fails to parse, it falls back to `fallback`, and then to the originator value.
fn sanitize_user_agent(candidate: String, fallback: &str) -> String {
    if HeaderValue::from_str(candidate.as_str()).is_ok() {
        return candidate;
    }

    let sanitized: String = candidate
        .chars()
        .map(|ch| if matches!(ch, ' '..='~') { ch } else { '_' })
        .collect();
    if !sanitized.is_empty() && HeaderValue::from_str(sanitized.as_str()).is_ok() {
        tracing::warn!(
            "Sanitized Codex user agent because provided suffix contained invalid header characters"
        );
        sanitized
    } else if HeaderValue::from_str(fallback).is_ok() {
        tracing::warn!(
            "Falling back to base Codex user agent because provided suffix could not be sanitized"
        );
        fallback.to_string()
    } else {
        tracing::warn!(
            "Falling back to default Codex originator because base user agent string is invalid"
        );
        originator().value.clone()
    }
}

/// Create an HTTP client with default `originator` and `User-Agent` headers set.
pub fn create_client() -> CodexHttpClient {
    use reqwest::header::HeaderMap;

    let mut headers = HeaderMap::new();
    headers.insert("originator", originator().header_value.clone());
    let ua = get_codex_user_agent();

    let mut builder = reqwest::Client::builder()
        // Set UA via dedicated helper to avoid header validation pitfalls
        .user_agent(ua)
        .default_headers(headers);
    if is_sandboxed() {
        builder = builder.no_proxy();
    }

    let inner = builder.build().unwrap_or_else(|_| reqwest::Client::new());
    CodexHttpClient::new(inner)
}

fn is_sandboxed() -> bool {
    std::env::var(CODEX_SANDBOX_ENV_VAR).as_deref() == Ok("seatbelt")
}

#[cfg(test)]
mod tests {
    use super::*;
    use core_test_support::skip_if_no_network;

    #[test]
    fn test_get_codex_user_agent() {
        let user_agent = get_codex_user_agent();
        assert!(user_agent.starts_with("codex_cli_rs/"));
    }

    #[tokio::test]
    async fn test_create_client_sets_default_headers() {
        skip_if_no_network!();

        use wiremock::Mock;
        use wiremock::MockServer;
        use wiremock::ResponseTemplate;
        use wiremock::matchers::method;
        use wiremock::matchers::path;

        let client = create_client();

        // Spin up a local mock server and capture a request.
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let resp = client
            .get(server.uri())
            .send()
            .await
            .expect("failed to send request");
        assert!(resp.status().is_success());

        let requests = server
            .received_requests()
            .await
            .expect("failed to fetch received requests");
        assert!(!requests.is_empty());
        let headers = &requests[0].headers;

        // originator header is set to the provided value
        let originator_header = headers
            .get("originator")
            .expect("originator header missing");
        assert_eq!(originator_header.to_str().unwrap(), "codex_cli_rs");

        // User-Agent matches the computed Codex UA for that originator
        let expected_ua = get_codex_user_agent();
        let ua_header = headers
            .get("user-agent")
            .expect("user-agent header missing");
        assert_eq!(ua_header.to_str().unwrap(), expected_ua);
    }

    #[test]
    fn test_invalid_suffix_is_sanitized() {
        let prefix = "codex_cli_rs/0.0.0";
        let suffix = "bad\rsuffix";

        assert_eq!(
            sanitize_user_agent(format!("{prefix} ({suffix})"), prefix),
            "codex_cli_rs/0.0.0 (bad_suffix)"
        );
    }

    #[test]
    fn test_invalid_suffix_is_sanitized2() {
        let prefix = "codex_cli_rs/0.0.0";
        let suffix = "bad\0suffix";

        assert_eq!(
            sanitize_user_agent(format!("{prefix} ({suffix})"), prefix),
            "codex_cli_rs/0.0.0 (bad_suffix)"
        );
    }

    #[test]
    #[cfg(target_os = "macos")]
    fn test_macos() {
        use regex_lite::Regex;
        let user_agent = get_codex_user_agent();
        let re = Regex::new(
            r"^codex_cli_rs/\d+\.\d+\.\d+ \(Mac OS \d+\.\d+\.\d+; (x86_64|arm64)\) (\S+)$",
        )
        .unwrap();
        assert!(re.is_match(&user_agent));
    }
}
347
llmx-rs/core/src/environment_context.rs
Normal file
@@ -0,0 +1,347 @@
use serde::Deserialize;
use serde::Serialize;
use strum_macros::Display as DeriveDisplay;

use crate::codex::TurnContext;
use crate::protocol::AskForApproval;
use crate::protocol::SandboxPolicy;
use crate::shell::Shell;
use codex_protocol::config_types::SandboxMode;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::ENVIRONMENT_CONTEXT_CLOSE_TAG;
use codex_protocol::protocol::ENVIRONMENT_CONTEXT_OPEN_TAG;
use std::path::PathBuf;

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, DeriveDisplay)]
#[serde(rename_all = "kebab-case")]
#[strum(serialize_all = "kebab-case")]
pub enum NetworkAccess {
    Restricted,
    Enabled,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename = "environment_context", rename_all = "snake_case")]
pub(crate) struct EnvironmentContext {
    pub cwd: Option<PathBuf>,
    pub approval_policy: Option<AskForApproval>,
    pub sandbox_mode: Option<SandboxMode>,
    pub network_access: Option<NetworkAccess>,
    pub writable_roots: Option<Vec<PathBuf>>,
    pub shell: Option<Shell>,
}

impl EnvironmentContext {
    pub fn new(
        cwd: Option<PathBuf>,
        approval_policy: Option<AskForApproval>,
        sandbox_policy: Option<SandboxPolicy>,
        shell: Option<Shell>,
    ) -> Self {
        Self {
            cwd,
            approval_policy,
            sandbox_mode: match sandbox_policy {
                Some(SandboxPolicy::DangerFullAccess) => Some(SandboxMode::DangerFullAccess),
                Some(SandboxPolicy::ReadOnly) => Some(SandboxMode::ReadOnly),
                Some(SandboxPolicy::WorkspaceWrite { .. }) => Some(SandboxMode::WorkspaceWrite),
                None => None,
            },
            network_access: match sandbox_policy {
                Some(SandboxPolicy::DangerFullAccess) => Some(NetworkAccess::Enabled),
                Some(SandboxPolicy::ReadOnly) => Some(NetworkAccess::Restricted),
                Some(SandboxPolicy::WorkspaceWrite { network_access, .. }) => {
                    if network_access {
                        Some(NetworkAccess::Enabled)
                    } else {
                        Some(NetworkAccess::Restricted)
                    }
                }
                None => None,
            },
            writable_roots: match sandbox_policy {
                Some(SandboxPolicy::WorkspaceWrite { writable_roots, .. }) => {
                    if writable_roots.is_empty() {
                        None
                    } else {
                        Some(writable_roots)
                    }
                }
                _ => None,
            },
            shell,
        }
    }

    /// Compares two environment contexts, ignoring the shell. Useful when
    /// comparing turn to turn, since the initial environment_context will
    /// include the shell, and then it is not configurable from turn to turn.
    pub fn equals_except_shell(&self, other: &EnvironmentContext) -> bool {
        let EnvironmentContext {
            cwd,
            approval_policy,
            sandbox_mode,
            network_access,
            writable_roots,
            // should compare all fields except shell
            shell: _,
        } = other;

        self.cwd == *cwd
            && self.approval_policy == *approval_policy
            && self.sandbox_mode == *sandbox_mode
            && self.network_access == *network_access
            && self.writable_roots == *writable_roots
    }

    pub fn diff(before: &TurnContext, after: &TurnContext) -> Self {
        let cwd = if before.cwd != after.cwd {
            Some(after.cwd.clone())
        } else {
            None
        };
        let approval_policy = if before.approval_policy != after.approval_policy {
            Some(after.approval_policy)
        } else {
            None
        };
        let sandbox_policy = if before.sandbox_policy != after.sandbox_policy {
            Some(after.sandbox_policy.clone())
        } else {
            None
        };
        EnvironmentContext::new(cwd, approval_policy, sandbox_policy, None)
    }
}

impl From<&TurnContext> for EnvironmentContext {
    fn from(turn_context: &TurnContext) -> Self {
        Self::new(
            Some(turn_context.cwd.clone()),
            Some(turn_context.approval_policy),
            Some(turn_context.sandbox_policy.clone()),
            // Shell is not configurable from turn to turn
            None,
        )
    }
}

impl EnvironmentContext {
    /// Serializes the environment context to XML. Libraries like `quick-xml`
    /// require custom macros to handle Enums with newtypes, so we just do it
    /// manually, to keep things simple. Output looks like:
    ///
    /// ```xml
    /// <environment_context>
    ///   <cwd>...</cwd>
    ///   <approval_policy>...</approval_policy>
    ///   <sandbox_mode>...</sandbox_mode>
    ///   <writable_roots>...</writable_roots>
    ///   <network_access>...</network_access>
    ///   <shell>...</shell>
    /// </environment_context>
    /// ```
    pub fn serialize_to_xml(self) -> String {
        let mut lines = vec![ENVIRONMENT_CONTEXT_OPEN_TAG.to_string()];
        if let Some(cwd) = self.cwd {
            lines.push(format!("  <cwd>{}</cwd>", cwd.to_string_lossy()));
        }
        if let Some(approval_policy) = self.approval_policy {
            lines.push(format!(
                "  <approval_policy>{approval_policy}</approval_policy>"
            ));
        }
        if let Some(sandbox_mode) = self.sandbox_mode {
            lines.push(format!("  <sandbox_mode>{sandbox_mode}</sandbox_mode>"));
        }
        if let Some(network_access) = self.network_access {
            lines.push(format!(
                "  <network_access>{network_access}</network_access>"
            ));
        }
        if let Some(writable_roots) = self.writable_roots {
            lines.push("  <writable_roots>".to_string());
            for writable_root in writable_roots {
                lines.push(format!(
                    "    <root>{}</root>",
                    writable_root.to_string_lossy()
                ));
            }
            lines.push("  </writable_roots>".to_string());
        }
        if let Some(shell) = self.shell
            && let Some(shell_name) = shell.name()
        {
            lines.push(format!("  <shell>{shell_name}</shell>"));
        }
        lines.push(ENVIRONMENT_CONTEXT_CLOSE_TAG.to_string());
        lines.join("\n")
    }
}

impl From<EnvironmentContext> for ResponseItem {
    fn from(ec: EnvironmentContext) -> Self {
        ResponseItem::Message {
            id: None,
            role: "user".to_string(),
            content: vec![ContentItem::InputText {
                text: ec.serialize_to_xml(),
            }],
        }
    }
}
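
// Illustrative sketch (not part of this commit): the conversion above turns a
// context into the XML-bearing user message. Values here are hypothetical.
#[cfg(test)]
mod conversion_example {
    use super::*;

    #[test]
    fn context_becomes_user_message() {
        let context = EnvironmentContext::new(
            Some(PathBuf::from("/repo")),
            Some(AskForApproval::OnRequest),
            Some(SandboxPolicy::ReadOnly),
            None,
        );
        let item: ResponseItem = context.into();
        match item {
            ResponseItem::Message { role, content, .. } => {
                assert_eq!(role, "user");
                assert_eq!(content.len(), 1);
            }
            _ => panic!("expected a user message"),
        }
    }
}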

#[cfg(test)]
mod tests {
    use crate::shell::BashShell;
    use crate::shell::ZshShell;

    use super::*;
    use pretty_assertions::assert_eq;

    fn workspace_write_policy(writable_roots: Vec<&str>, network_access: bool) -> SandboxPolicy {
        SandboxPolicy::WorkspaceWrite {
            writable_roots: writable_roots.into_iter().map(PathBuf::from).collect(),
            network_access,
            exclude_tmpdir_env_var: false,
            exclude_slash_tmp: false,
        }
    }

    #[test]
    fn serialize_workspace_write_environment_context() {
        let context = EnvironmentContext::new(
            Some(PathBuf::from("/repo")),
            Some(AskForApproval::OnRequest),
            Some(workspace_write_policy(vec!["/repo", "/tmp"], false)),
            None,
        );

        let expected = r#"<environment_context>
  <cwd>/repo</cwd>
  <approval_policy>on-request</approval_policy>
  <sandbox_mode>workspace-write</sandbox_mode>
  <network_access>restricted</network_access>
  <writable_roots>
    <root>/repo</root>
    <root>/tmp</root>
  </writable_roots>
</environment_context>"#;

        assert_eq!(context.serialize_to_xml(), expected);
    }

    #[test]
    fn serialize_read_only_environment_context() {
        let context = EnvironmentContext::new(
            None,
            Some(AskForApproval::Never),
            Some(SandboxPolicy::ReadOnly),
            None,
        );

        let expected = r#"<environment_context>
  <approval_policy>never</approval_policy>
  <sandbox_mode>read-only</sandbox_mode>
  <network_access>restricted</network_access>
</environment_context>"#;

        assert_eq!(context.serialize_to_xml(), expected);
    }

    #[test]
    fn serialize_full_access_environment_context() {
        let context = EnvironmentContext::new(
            None,
            Some(AskForApproval::OnFailure),
            Some(SandboxPolicy::DangerFullAccess),
            None,
        );

        let expected = r#"<environment_context>
  <approval_policy>on-failure</approval_policy>
  <sandbox_mode>danger-full-access</sandbox_mode>
  <network_access>enabled</network_access>
</environment_context>"#;

        assert_eq!(context.serialize_to_xml(), expected);
    }

    #[test]
    fn equals_except_shell_compares_approval_policy() {
        // Approval policy
        let context1 = EnvironmentContext::new(
            Some(PathBuf::from("/repo")),
            Some(AskForApproval::OnRequest),
            Some(workspace_write_policy(vec!["/repo"], false)),
            None,
        );
        let context2 = EnvironmentContext::new(
            Some(PathBuf::from("/repo")),
            Some(AskForApproval::Never),
            Some(workspace_write_policy(vec!["/repo"], true)),
            None,
        );
        assert!(!context1.equals_except_shell(&context2));
    }

    #[test]
    fn equals_except_shell_compares_sandbox_policy() {
        let context1 = EnvironmentContext::new(
            Some(PathBuf::from("/repo")),
            Some(AskForApproval::OnRequest),
            Some(SandboxPolicy::new_read_only_policy()),
            None,
        );
        let context2 = EnvironmentContext::new(
            Some(PathBuf::from("/repo")),
            Some(AskForApproval::OnRequest),
            Some(SandboxPolicy::new_workspace_write_policy()),
            None,
        );

        assert!(!context1.equals_except_shell(&context2));
    }

    #[test]
    fn equals_except_shell_compares_workspace_write_policy() {
        let context1 = EnvironmentContext::new(
            Some(PathBuf::from("/repo")),
            Some(AskForApproval::OnRequest),
            Some(workspace_write_policy(vec!["/repo", "/tmp", "/var"], false)),
            None,
        );
        let context2 = EnvironmentContext::new(
            Some(PathBuf::from("/repo")),
            Some(AskForApproval::OnRequest),
            Some(workspace_write_policy(vec!["/repo", "/tmp"], true)),
            None,
        );

        assert!(!context1.equals_except_shell(&context2));
    }

    #[test]
    fn equals_except_shell_ignores_shell() {
        let context1 = EnvironmentContext::new(
            Some(PathBuf::from("/repo")),
            Some(AskForApproval::OnRequest),
            Some(workspace_write_policy(vec!["/repo"], false)),
            Some(Shell::Bash(BashShell {
                shell_path: "/bin/bash".into(),
                bashrc_path: "/home/user/.bashrc".into(),
            })),
        );
        let context2 = EnvironmentContext::new(
            Some(PathBuf::from("/repo")),
            Some(AskForApproval::OnRequest),
            Some(workspace_write_policy(vec!["/repo"], false)),
            Some(Shell::Zsh(ZshShell {
                shell_path: "/bin/zsh".into(),
                zshrc_path: "/home/user/.zshrc".into(),
            })),
        );

        assert!(context1.equals_except_shell(&context2));
    }
}
773
llmx-rs/core/src/error.rs
Normal file
@@ -0,0 +1,773 @@
use crate::codex::ProcessedResponseItem;
use crate::exec::ExecToolCallOutput;
use crate::token_data::KnownPlan;
use crate::token_data::PlanType;
use crate::truncate::truncate_middle;
use chrono::DateTime;
use chrono::Datelike;
use chrono::Local;
use chrono::Utc;
use codex_async_utils::CancelErr;
use codex_protocol::ConversationId;
use codex_protocol::protocol::RateLimitSnapshot;
use reqwest::StatusCode;
use serde_json;
use std::io;
use std::time::Duration;
use thiserror::Error;
use tokio::task::JoinError;

pub type Result<T> = std::result::Result<T, CodexErr>;

/// Limit UI error messages to a reasonable size while keeping useful context.
const ERROR_MESSAGE_UI_MAX_BYTES: usize = 2 * 1024; // 2 KiB

#[derive(Error, Debug)]
pub enum SandboxErr {
    /// Error from sandbox execution
    #[error(
        "sandbox denied exec error, exit code: {}, stdout: {}, stderr: {}",
        .output.exit_code, .output.stdout.text, .output.stderr.text
    )]
    Denied { output: Box<ExecToolCallOutput> },

    /// Error from linux seccomp filter setup
    #[cfg(target_os = "linux")]
    #[error("seccomp setup error")]
    SeccompInstall(#[from] seccompiler::Error),

    /// Error from linux seccomp backend
    #[cfg(target_os = "linux")]
    #[error("seccomp backend error")]
    SeccompBackend(#[from] seccompiler::BackendError),

    /// Command timed out
    #[error("command timed out")]
    Timeout { output: Box<ExecToolCallOutput> },

    /// Command was killed by a signal
    #[error("command was killed by a signal")]
    Signal(i32),

    /// Error from linux landlock
    #[error("Landlock was not able to fully enforce all sandbox rules")]
    LandlockRestrict,
}

#[derive(Error, Debug)]
pub enum CodexErr {
    // todo(aibrahim): get rid of this error carrying the dangling artifacts
    #[error("turn aborted. Something went wrong? Hit `/feedback` to report the issue.")]
    TurnAborted {
        dangling_artifacts: Vec<ProcessedResponseItem>,
    },

    /// Returned by ResponsesClient when the SSE stream disconnects or errors out **after** the HTTP
    /// handshake has succeeded but **before** it finished emitting `response.completed`.
    ///
    /// The Session loop treats this as a transient error and will automatically retry the turn.
    ///
    /// Optionally includes the requested delay before retrying the turn.
    #[error("stream disconnected before completion: {0}")]
    Stream(String, Option<Duration>),

    #[error(
        "Codex ran out of room in the model's context window. Start a new conversation or clear earlier history before retrying."
    )]
    ContextWindowExceeded,

    #[error("no conversation with id: {0}")]
    ConversationNotFound(ConversationId),

    #[error("session configured event was not the first event in the stream")]
    SessionConfiguredNotFirstEvent,

    /// Returned by run_command_stream when the spawned child process timed out (10s).
    #[error("timeout waiting for child process to exit")]
    Timeout,

    /// Returned by run_command_stream when the child could not be spawned (its stdout/stderr pipes
    /// could not be captured). Analogous to the previous `CodexError::Spawn` variant.
    #[error("spawn failed: child stdout/stderr not captured")]
    Spawn,

    /// Returned by run_command_stream when the user pressed Ctrl‑C (SIGINT). Session uses this to
    /// surface a polite FunctionCallOutput back to the model instead of crashing the CLI.
    #[error("interrupted (Ctrl-C). Something went wrong? Hit `/feedback` to report the issue.")]
    Interrupted,

    /// Unexpected HTTP status code.
    #[error("{0}")]
    UnexpectedStatus(UnexpectedResponseError),

    #[error("{0}")]
    UsageLimitReached(UsageLimitReachedError),

    #[error("{0}")]
    ResponseStreamFailed(ResponseStreamFailed),

    #[error("{0}")]
    ConnectionFailed(ConnectionFailedError),

    #[error("Quota exceeded. Check your plan and billing details.")]
    QuotaExceeded,

    #[error(
        "To use Codex with your ChatGPT plan, upgrade to Plus: https://openai.com/chatgpt/pricing."
    )]
    UsageNotIncluded,

    #[error("We're currently experiencing high demand, which may cause temporary errors.")]
    InternalServerError,

    /// Retry limit exceeded.
    #[error("{0}")]
    RetryLimit(RetryLimitReachedError),

    /// Agent loop died unexpectedly
    #[error("internal error; agent loop died unexpectedly")]
    InternalAgentDied,

    /// Sandbox error
    #[error("sandbox error: {0}")]
    Sandbox(#[from] SandboxErr),

    #[error("codex-linux-sandbox was required but not provided")]
    LandlockSandboxExecutableNotProvided,

    #[error("unsupported operation: {0}")]
    UnsupportedOperation(String),

    #[error("{0}")]
    RefreshTokenFailed(RefreshTokenFailedError),

    #[error("Fatal error: {0}")]
    Fatal(String),

    // -----------------------------------------------------------------
    // Automatic conversions for common external error types
    // -----------------------------------------------------------------
    #[error(transparent)]
    Io(#[from] io::Error),

    #[error(transparent)]
    Json(#[from] serde_json::Error),

    #[cfg(target_os = "linux")]
    #[error(transparent)]
    LandlockRuleset(#[from] landlock::RulesetError),

    #[cfg(target_os = "linux")]
    #[error(transparent)]
    LandlockPathFd(#[from] landlock::PathFdError),

    #[error(transparent)]
    TokioJoin(#[from] JoinError),

    #[error("{0}")]
    EnvVar(EnvVarError),
}
|
||||
|
||||
impl From<CancelErr> for CodexErr {
|
||||
fn from(_: CancelErr) -> Self {
|
||||
CodexErr::TurnAborted {
|
||||
dangling_artifacts: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
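
// Illustrative sketch (added for this write-up, not part of the original
// file): the `#[from]` conversions above let call sites bubble external
// errors with `?` and no manual mapping; `_read_state_file` is hypothetical.
fn _read_state_file(path: &std::path::Path) -> std::result::Result<serde_json::Value, CodexErr> {
    // `io::Error` converts into `CodexErr::Io`; `serde_json::Error` into `CodexErr::Json`.
    let text = std::fs::read_to_string(path)?;
    Ok(serde_json::from_str(&text)?)
}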

#[derive(Debug)]
pub struct ConnectionFailedError {
    pub source: reqwest::Error,
}

impl std::fmt::Display for ConnectionFailedError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Connection failed: {}", self.source)
    }
}

#[derive(Debug)]
pub struct ResponseStreamFailed {
    pub source: reqwest::Error,
    pub request_id: Option<String>,
}

impl std::fmt::Display for ResponseStreamFailed {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Error while reading the server response: {}{}",
            self.source,
            self.request_id
                .as_ref()
                .map(|id| format!(", request id: {id}"))
                .unwrap_or_default()
        )
    }
}

#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[error("{message}")]
pub struct RefreshTokenFailedError {
    pub reason: RefreshTokenFailedReason,
    pub message: String,
}

impl RefreshTokenFailedError {
    pub fn new(reason: RefreshTokenFailedReason, message: impl Into<String>) -> Self {
        Self {
            reason,
            message: message.into(),
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RefreshTokenFailedReason {
    Expired,
    Exhausted,
    Revoked,
    Other,
}

#[derive(Debug)]
pub struct UnexpectedResponseError {
    pub status: StatusCode,
    pub body: String,
    pub request_id: Option<String>,
}

const CLOUDFLARE_BLOCKED_MESSAGE: &str =
    "Access blocked by Cloudflare. This usually happens when connecting from a restricted region";

impl UnexpectedResponseError {
    fn friendly_message(&self) -> Option<String> {
        if self.status != StatusCode::FORBIDDEN {
            return None;
        }

        if !self.body.contains("Cloudflare") || !self.body.contains("blocked") {
            return None;
        }

        let mut message = format!("{CLOUDFLARE_BLOCKED_MESSAGE} (status {})", self.status);
        if let Some(id) = &self.request_id {
            message.push_str(&format!(", request id: {id}"));
        }

        Some(message)
    }
}

impl std::fmt::Display for UnexpectedResponseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Some(friendly) = self.friendly_message() {
            write!(f, "{friendly}")
        } else {
            write!(
                f,
                "unexpected status {}: {}{}",
                self.status,
                self.body,
                self.request_id
                    .as_ref()
                    .map(|id| format!(", request id: {id}"))
                    .unwrap_or_default()
            )
        }
    }
}

impl std::error::Error for UnexpectedResponseError {}
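
// Illustrative sketch (added for this write-up, not part of the original
// file): a Cloudflare 403 HTML body collapses to the short friendly message,
// while any other status falls through to the generic formatting.
fn _example_cloudflare_message() -> String {
    let err = UnexpectedResponseError {
        status: StatusCode::FORBIDDEN,
        body: "<html>Cloudflare: you have been blocked</html>".to_string(),
        request_id: None,
    };
    // Renders as "Access blocked by Cloudflare. ... (status 403 Forbidden)".
    err.to_string()
}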

#[derive(Debug)]
pub struct RetryLimitReachedError {
    pub status: StatusCode,
    pub request_id: Option<String>,
}

impl std::fmt::Display for RetryLimitReachedError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "exceeded retry limit, last status: {}{}",
            self.status,
            self.request_id
                .as_ref()
                .map(|id| format!(", request id: {id}"))
                .unwrap_or_default()
        )
    }
}

#[derive(Debug)]
pub struct UsageLimitReachedError {
    pub(crate) plan_type: Option<PlanType>,
    pub(crate) resets_at: Option<DateTime<Utc>>,
    pub(crate) rate_limits: Option<RateLimitSnapshot>,
}

impl std::fmt::Display for UsageLimitReachedError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self.plan_type.as_ref() {
            Some(PlanType::Known(KnownPlan::Plus)) => format!(
                "You've hit your usage limit. Upgrade to Pro (https://openai.com/chatgpt/pricing), visit https://chatgpt.com/codex/settings/usage to purchase more credits{}",
                retry_suffix_after_or(self.resets_at.as_ref())
            ),
            Some(PlanType::Known(KnownPlan::Team)) | Some(PlanType::Known(KnownPlan::Business)) => {
                format!(
                    "You've hit your usage limit. To get more access now, send a request to your admin{}",
                    retry_suffix_after_or(self.resets_at.as_ref())
                )
            }
            Some(PlanType::Known(KnownPlan::Free)) => {
                "You've hit your usage limit. Upgrade to Plus to continue using Codex (https://openai.com/chatgpt/pricing)."
                    .to_string()
            }
            Some(PlanType::Known(KnownPlan::Pro)) => format!(
                "You've hit your usage limit. Visit https://chatgpt.com/codex/settings/usage to purchase more credits{}",
                retry_suffix_after_or(self.resets_at.as_ref())
            ),
            Some(PlanType::Known(KnownPlan::Enterprise))
            | Some(PlanType::Known(KnownPlan::Edu)) => format!(
                "You've hit your usage limit.{}",
                retry_suffix(self.resets_at.as_ref())
            ),
            Some(PlanType::Unknown(_)) | None => format!(
                "You've hit your usage limit.{}",
                retry_suffix(self.resets_at.as_ref())
            ),
        };

        write!(f, "{message}")
    }
}

fn retry_suffix(resets_at: Option<&DateTime<Utc>>) -> String {
    if let Some(resets_at) = resets_at {
        let formatted = format_retry_timestamp(resets_at);
        format!(" Try again at {formatted}.")
    } else {
        " Try again later.".to_string()
    }
}

fn retry_suffix_after_or(resets_at: Option<&DateTime<Utc>>) -> String {
    if let Some(resets_at) = resets_at {
        let formatted = format_retry_timestamp(resets_at);
        format!(" or try again at {formatted}.")
    } else {
        " or try again later.".to_string()
    }
}

fn format_retry_timestamp(resets_at: &DateTime<Utc>) -> String {
    let local_reset = resets_at.with_timezone(&Local);
    let local_now = now_for_retry().with_timezone(&Local);
    if local_reset.date_naive() == local_now.date_naive() {
        local_reset.format("%-I:%M %p").to_string()
    } else {
        let suffix = day_suffix(local_reset.day());
        local_reset
            .format(&format!("%b %-d{suffix}, %Y %-I:%M %p"))
            .to_string()
    }
}

fn day_suffix(day: u32) -> &'static str {
    match day {
        11..=13 => "th",
        _ => match day % 10 {
            1 => "st",
            2 => "nd", // codespell:ignore
            3 => "rd",
            _ => "th",
        },
    }
}
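
// Quick sanity sketch of the ordinal rules above (added for this write-up,
// not part of the original file): 11-13 always take "th"; otherwise the last
// digit decides.
#[test]
fn _day_suffix_examples() {
    assert_eq!(day_suffix(1), "st");
    assert_eq!(day_suffix(2), "nd"); // codespell:ignore
    assert_eq!(day_suffix(3), "rd");
    assert_eq!(day_suffix(11), "th");
    assert_eq!(day_suffix(13), "th");
    assert_eq!(day_suffix(21), "st");
}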

#[cfg(test)]
thread_local! {
    static NOW_OVERRIDE: std::cell::RefCell<Option<DateTime<Utc>>> =
        const { std::cell::RefCell::new(None) };
}

fn now_for_retry() -> DateTime<Utc> {
    #[cfg(test)]
    {
        if let Some(now) = NOW_OVERRIDE.with(|cell| *cell.borrow()) {
            return now;
        }
    }
    Utc::now()
}

#[derive(Debug)]
pub struct EnvVarError {
    /// Name of the environment variable that is missing.
    pub var: String,

    /// Optional instructions to help the user get a valid value for the
    /// variable and set it.
    pub instructions: Option<String>,
}

impl std::fmt::Display for EnvVarError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Missing environment variable: `{}`.", self.var)?;
        if let Some(instructions) = &self.instructions {
            write!(f, " {instructions}")?;
        }
        Ok(())
    }
}

impl CodexErr {
    /// Minimal shim so that existing `e.downcast_ref::<CodexErr>()` checks continue to compile
    /// after replacing `anyhow::Error` in the return signature. This mirrors the behavior of
    /// `anyhow::Error::downcast_ref` but works directly on our concrete enum.
    pub fn downcast_ref<T: std::any::Any>(&self) -> Option<&T> {
        (self as &dyn std::any::Any).downcast_ref::<T>()
    }
}
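
// Illustrative use of the shim above (added for this write-up, not part of
// the original file): call sites written against `anyhow::Error` keep their
// downcast-style checks unchanged.
fn _example_downcast(err: &CodexErr) -> bool {
    // With the shim, downcasting to the concrete enum always succeeds.
    err.downcast_ref::<CodexErr>().is_some()
}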

pub fn get_error_message_ui(e: &CodexErr) -> String {
    let message = match e {
        CodexErr::Sandbox(SandboxErr::Denied { output }) => {
            let aggregated = output.aggregated_output.text.trim();
            if !aggregated.is_empty() {
                output.aggregated_output.text.clone()
            } else {
                let stderr = output.stderr.text.trim();
                let stdout = output.stdout.text.trim();
                match (stderr.is_empty(), stdout.is_empty()) {
                    (false, false) => format!("{stderr}\n{stdout}"),
                    (false, true) => output.stderr.text.clone(),
                    (true, false) => output.stdout.text.clone(),
                    (true, true) => format!(
                        "command failed inside sandbox with exit code {}",
                        output.exit_code
                    ),
                }
            }
        }
        // Timeouts are not sandbox errors from a UX perspective; present them plainly
        CodexErr::Sandbox(SandboxErr::Timeout { output }) => {
            format!(
                "error: command timed out after {} ms",
                output.duration.as_millis()
            )
        }
        _ => e.to_string(),
    };

    truncate_middle(&message, ERROR_MESSAGE_UI_MAX_BYTES).0
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::exec::StreamOutput;
    use chrono::DateTime;
    use chrono::Duration as ChronoDuration;
    use chrono::TimeZone;
    use chrono::Utc;
    use codex_protocol::protocol::RateLimitWindow;
    use pretty_assertions::assert_eq;

    fn rate_limit_snapshot() -> RateLimitSnapshot {
        let primary_reset_at = Utc
            .with_ymd_and_hms(2024, 1, 1, 1, 0, 0)
            .unwrap()
            .timestamp();
        let secondary_reset_at = Utc
            .with_ymd_and_hms(2024, 1, 1, 2, 0, 0)
            .unwrap()
            .timestamp();
        RateLimitSnapshot {
            primary: Some(RateLimitWindow {
                used_percent: 50.0,
                window_minutes: Some(60),
                resets_at: Some(primary_reset_at),
            }),
            secondary: Some(RateLimitWindow {
                used_percent: 30.0,
                window_minutes: Some(120),
                resets_at: Some(secondary_reset_at),
            }),
        }
    }

    fn with_now_override<T>(now: DateTime<Utc>, f: impl FnOnce() -> T) -> T {
        NOW_OVERRIDE.with(|cell| {
            *cell.borrow_mut() = Some(now);
            let result = f();
            *cell.borrow_mut() = None;
            result
        })
    }

    #[test]
    fn usage_limit_reached_error_formats_plus_plan() {
        let err = UsageLimitReachedError {
            plan_type: Some(PlanType::Known(KnownPlan::Plus)),
            resets_at: None,
            rate_limits: Some(rate_limit_snapshot()),
        };
        assert_eq!(
            err.to_string(),
            "You've hit your usage limit. Upgrade to Pro (https://openai.com/chatgpt/pricing), visit https://chatgpt.com/codex/settings/usage to purchase more credits or try again later."
        );
    }

    #[test]
    fn sandbox_denied_uses_aggregated_output_when_stderr_empty() {
        let output = ExecToolCallOutput {
            exit_code: 77,
            stdout: StreamOutput::new(String::new()),
            stderr: StreamOutput::new(String::new()),
            aggregated_output: StreamOutput::new("aggregate detail".to_string()),
            duration: Duration::from_millis(10),
            timed_out: false,
        };
        let err = CodexErr::Sandbox(SandboxErr::Denied {
            output: Box::new(output),
        });
        assert_eq!(get_error_message_ui(&err), "aggregate detail");
    }

    #[test]
    fn sandbox_denied_reports_both_streams_when_available() {
        let output = ExecToolCallOutput {
            exit_code: 9,
            stdout: StreamOutput::new("stdout detail".to_string()),
            stderr: StreamOutput::new("stderr detail".to_string()),
            aggregated_output: StreamOutput::new(String::new()),
            duration: Duration::from_millis(10),
            timed_out: false,
        };
        let err = CodexErr::Sandbox(SandboxErr::Denied {
            output: Box::new(output),
        });
        assert_eq!(get_error_message_ui(&err), "stderr detail\nstdout detail");
    }

    #[test]
    fn sandbox_denied_reports_stdout_when_no_stderr() {
        let output = ExecToolCallOutput {
            exit_code: 11,
            stdout: StreamOutput::new("stdout only".to_string()),
            stderr: StreamOutput::new(String::new()),
            aggregated_output: StreamOutput::new(String::new()),
            duration: Duration::from_millis(8),
            timed_out: false,
        };
        let err = CodexErr::Sandbox(SandboxErr::Denied {
            output: Box::new(output),
        });
        assert_eq!(get_error_message_ui(&err), "stdout only");
    }

    #[test]
    fn sandbox_denied_reports_exit_code_when_no_output_available() {
        let output = ExecToolCallOutput {
            exit_code: 13,
            stdout: StreamOutput::new(String::new()),
            stderr: StreamOutput::new(String::new()),
            aggregated_output: StreamOutput::new(String::new()),
            duration: Duration::from_millis(5),
            timed_out: false,
        };
        let err = CodexErr::Sandbox(SandboxErr::Denied {
            output: Box::new(output),
        });
        assert_eq!(
            get_error_message_ui(&err),
            "command failed inside sandbox with exit code 13"
        );
    }

    #[test]
    fn usage_limit_reached_error_formats_free_plan() {
        let err = UsageLimitReachedError {
            plan_type: Some(PlanType::Known(KnownPlan::Free)),
            resets_at: None,
            rate_limits: Some(rate_limit_snapshot()),
        };
        assert_eq!(
            err.to_string(),
            "You've hit your usage limit. Upgrade to Plus to continue using Codex (https://openai.com/chatgpt/pricing)."
        );
    }

    #[test]
    fn usage_limit_reached_error_formats_default_when_none() {
        let err = UsageLimitReachedError {
            plan_type: None,
            resets_at: None,
            rate_limits: Some(rate_limit_snapshot()),
        };
        assert_eq!(
            err.to_string(),
            "You've hit your usage limit. Try again later."
        );
    }

    #[test]
    fn usage_limit_reached_error_formats_team_plan() {
        let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
        let resets_at = base + ChronoDuration::hours(1);
        with_now_override(base, move || {
            let expected_time = format_retry_timestamp(&resets_at);
            let err = UsageLimitReachedError {
                plan_type: Some(PlanType::Known(KnownPlan::Team)),
                resets_at: Some(resets_at),
                rate_limits: Some(rate_limit_snapshot()),
            };
            let expected = format!(
                "You've hit your usage limit. To get more access now, send a request to your admin or try again at {expected_time}."
            );
            assert_eq!(err.to_string(), expected);
        });
    }

    #[test]
    fn usage_limit_reached_error_formats_business_plan_without_reset() {
        let err = UsageLimitReachedError {
            plan_type: Some(PlanType::Known(KnownPlan::Business)),
            resets_at: None,
            rate_limits: Some(rate_limit_snapshot()),
        };
        assert_eq!(
            err.to_string(),
            "You've hit your usage limit. To get more access now, send a request to your admin or try again later."
        );
    }

    #[test]
    fn usage_limit_reached_error_formats_default_for_other_plans() {
        let err = UsageLimitReachedError {
            plan_type: Some(PlanType::Known(KnownPlan::Enterprise)),
            resets_at: None,
            rate_limits: Some(rate_limit_snapshot()),
        };
        assert_eq!(
            err.to_string(),
            "You've hit your usage limit. Try again later."
        );
    }

    #[test]
    fn usage_limit_reached_error_formats_pro_plan_with_reset() {
        let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
        let resets_at = base + ChronoDuration::hours(1);
        with_now_override(base, move || {
            let expected_time = format_retry_timestamp(&resets_at);
            let err = UsageLimitReachedError {
                plan_type: Some(PlanType::Known(KnownPlan::Pro)),
                resets_at: Some(resets_at),
                rate_limits: Some(rate_limit_snapshot()),
            };
            let expected = format!(
                "You've hit your usage limit. Visit https://chatgpt.com/codex/settings/usage to purchase more credits or try again at {expected_time}."
            );
            assert_eq!(err.to_string(), expected);
        });
    }

    #[test]
    fn usage_limit_reached_includes_minutes_when_available() {
        let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
        let resets_at = base + ChronoDuration::minutes(5);
        with_now_override(base, move || {
            let expected_time = format_retry_timestamp(&resets_at);
            let err = UsageLimitReachedError {
                plan_type: None,
                resets_at: Some(resets_at),
                rate_limits: Some(rate_limit_snapshot()),
            };
            let expected = format!("You've hit your usage limit. Try again at {expected_time}.");
            assert_eq!(err.to_string(), expected);
        });
    }

    #[test]
    fn unexpected_status_cloudflare_html_is_simplified() {
        let err = UnexpectedResponseError {
            status: StatusCode::FORBIDDEN,
            body: "<html><body>Cloudflare error: Sorry, you have been blocked</body></html>"
                .to_string(),
            request_id: Some("ray-id".to_string()),
        };
        let status = StatusCode::FORBIDDEN.to_string();
        assert_eq!(
            err.to_string(),
            format!("{CLOUDFLARE_BLOCKED_MESSAGE} (status {status}), request id: ray-id")
        );
    }

    #[test]
    fn unexpected_status_non_html_is_unchanged() {
        let err = UnexpectedResponseError {
            status: StatusCode::FORBIDDEN,
            body: "plain text error".to_string(),
            request_id: None,
        };
        let status = StatusCode::FORBIDDEN.to_string();
        assert_eq!(
            err.to_string(),
            format!("unexpected status {status}: plain text error")
        );
    }

    #[test]
    fn usage_limit_reached_includes_hours_and_minutes() {
        let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
        let resets_at = base + ChronoDuration::hours(3) + ChronoDuration::minutes(32);
        with_now_override(base, move || {
            let expected_time = format_retry_timestamp(&resets_at);
            let err = UsageLimitReachedError {
                plan_type: Some(PlanType::Known(KnownPlan::Plus)),
                resets_at: Some(resets_at),
                rate_limits: Some(rate_limit_snapshot()),
            };
            let expected = format!(
                "You've hit your usage limit. Upgrade to Pro (https://openai.com/chatgpt/pricing), visit https://chatgpt.com/codex/settings/usage to purchase more credits or try again at {expected_time}."
            );
            assert_eq!(err.to_string(), expected);
        });
    }

    #[test]
    fn usage_limit_reached_includes_days_hours_minutes() {
        let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
        let resets_at =
            base + ChronoDuration::days(2) + ChronoDuration::hours(3) + ChronoDuration::minutes(5);
        with_now_override(base, move || {
            let expected_time = format_retry_timestamp(&resets_at);
            let err = UsageLimitReachedError {
                plan_type: None,
                resets_at: Some(resets_at),
                rate_limits: Some(rate_limit_snapshot()),
            };
            let expected = format!("You've hit your usage limit. Try again at {expected_time}.");
            assert_eq!(err.to_string(), expected);
        });
    }

    #[test]
    fn usage_limit_reached_less_than_minute() {
        let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
        let resets_at = base + ChronoDuration::seconds(30);
        with_now_override(base, move || {
            let expected_time = format_retry_timestamp(&resets_at);
            let err = UsageLimitReachedError {
                plan_type: None,
                resets_at: Some(resets_at),
                rate_limits: Some(rate_limit_snapshot()),
            };
            let expected = format!("You've hit your usage limit. Try again at {expected_time}.");
            assert_eq!(err.to_string(), expected);
        });
    }
}
323
llmx-rs/core/src/event_mapping.rs
Normal file
@@ -0,0 +1,323 @@
use codex_protocol::items::AgentMessageContent;
use codex_protocol::items::AgentMessageItem;
use codex_protocol::items::ReasoningItem;
use codex_protocol::items::TurnItem;
use codex_protocol::items::UserMessageItem;
use codex_protocol::items::WebSearchItem;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ReasoningItemContent;
use codex_protocol::models::ReasoningItemReasoningSummary;
use codex_protocol::models::ResponseItem;
use codex_protocol::models::WebSearchAction;
use codex_protocol::user_input::UserInput;
use tracing::warn;
use uuid::Uuid;

use crate::user_instructions::UserInstructions;
use crate::user_shell_command::is_user_shell_command_text;

fn is_session_prefix(text: &str) -> bool {
    let trimmed = text.trim_start();
    let lowered = trimmed.to_ascii_lowercase();
    lowered.starts_with("<environment_context>")
}

fn parse_user_message(message: &[ContentItem]) -> Option<UserMessageItem> {
    if UserInstructions::is_user_instructions(message) {
        return None;
    }

    let mut content: Vec<UserInput> = Vec::new();

    for content_item in message.iter() {
        match content_item {
            ContentItem::InputText { text } => {
                if is_session_prefix(text) || is_user_shell_command_text(text) {
                    return None;
                }
                content.push(UserInput::Text { text: text.clone() });
            }
            ContentItem::InputImage { image_url } => {
                content.push(UserInput::Image {
                    image_url: image_url.clone(),
                });
            }
            ContentItem::OutputText { text } => {
                if is_session_prefix(text) {
                    return None;
                }
                warn!("Output text in user message: {}", text);
            }
        }
    }

    Some(UserMessageItem::new(&content))
}

fn parse_agent_message(id: Option<&String>, message: &[ContentItem]) -> AgentMessageItem {
    let mut content: Vec<AgentMessageContent> = Vec::new();
    for content_item in message.iter() {
        match content_item {
            ContentItem::OutputText { text } => {
                content.push(AgentMessageContent::Text { text: text.clone() });
            }
            _ => {
                warn!(
                    "Unexpected content item in agent message: {:?}",
                    content_item
                );
            }
        }
    }
    let id = id.cloned().unwrap_or_else(|| Uuid::new_v4().to_string());
    AgentMessageItem { id, content }
}

pub fn parse_turn_item(item: &ResponseItem) -> Option<TurnItem> {
    match item {
        ResponseItem::Message { role, content, id } => match role.as_str() {
            "user" => parse_user_message(content).map(TurnItem::UserMessage),
            "assistant" => Some(TurnItem::AgentMessage(parse_agent_message(
                id.as_ref(),
                content,
            ))),
            "system" => None,
            _ => None,
        },
        ResponseItem::Reasoning {
            id,
            summary,
            content,
            ..
        } => {
            let summary_text = summary
                .iter()
                .map(|entry| match entry {
                    ReasoningItemReasoningSummary::SummaryText { text } => text.clone(),
                })
                .collect();
            let raw_content = content
                .clone()
                .unwrap_or_default()
                .into_iter()
                .map(|entry| match entry {
                    ReasoningItemContent::ReasoningText { text }
                    | ReasoningItemContent::Text { text } => text,
                })
                .collect();
            Some(TurnItem::Reasoning(ReasoningItem {
                id: id.clone(),
                summary_text,
                raw_content,
            }))
        }
        ResponseItem::WebSearchCall {
            id,
            action: WebSearchAction::Search { query },
            ..
        } => Some(TurnItem::WebSearch(WebSearchItem {
            id: id.clone().unwrap_or_default(),
            query: query.clone(),
        })),
        _ => None,
    }
}
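
// Illustrative sketch (added for this write-up, not part of the original
// file): a recorded history is typically mapped wholesale, dropping response
// items that have no `TurnItem` representation.
fn _collect_turn_items(history: &[ResponseItem]) -> Vec<TurnItem> {
    history.iter().filter_map(parse_turn_item).collect()
}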

#[cfg(test)]
mod tests {
    use super::parse_turn_item;
    use codex_protocol::items::AgentMessageContent;
    use codex_protocol::items::TurnItem;
    use codex_protocol::models::ContentItem;
    use codex_protocol::models::ReasoningItemContent;
    use codex_protocol::models::ReasoningItemReasoningSummary;
    use codex_protocol::models::ResponseItem;
    use codex_protocol::models::WebSearchAction;
    use codex_protocol::user_input::UserInput;
    use pretty_assertions::assert_eq;

    #[test]
    fn parses_user_message_with_text_and_two_images() {
        let img1 = "https://example.com/one.png".to_string();
        let img2 = "https://example.com/two.jpg".to_string();

        let item = ResponseItem::Message {
            id: None,
            role: "user".to_string(),
            content: vec![
                ContentItem::InputText {
                    text: "Hello world".to_string(),
                },
                ContentItem::InputImage {
                    image_url: img1.clone(),
                },
                ContentItem::InputImage {
                    image_url: img2.clone(),
                },
            ],
        };

        let turn_item = parse_turn_item(&item).expect("expected user message turn item");

        match turn_item {
            TurnItem::UserMessage(user) => {
                let expected_content = vec![
                    UserInput::Text {
                        text: "Hello world".to_string(),
                    },
                    UserInput::Image { image_url: img1 },
                    UserInput::Image { image_url: img2 },
                ];
                assert_eq!(user.content, expected_content);
            }
            other => panic!("expected TurnItem::UserMessage, got {other:?}"),
        }
    }

    #[test]
    fn skips_user_instructions_and_env() {
        let items = vec![
            ResponseItem::Message {
                id: None,
                role: "user".to_string(),
                content: vec![ContentItem::InputText {
                    text: "<user_instructions>test_text</user_instructions>".to_string(),
                }],
            },
            ResponseItem::Message {
                id: None,
                role: "user".to_string(),
                content: vec![ContentItem::InputText {
                    text: "<environment_context>test_text</environment_context>".to_string(),
                }],
            },
            ResponseItem::Message {
                id: None,
                role: "user".to_string(),
                content: vec![ContentItem::InputText {
                    text: "# AGENTS.md instructions for test_directory\n\n<INSTRUCTIONS>\ntest_text\n</INSTRUCTIONS>".to_string(),
                }],
            },
            ResponseItem::Message {
                id: None,
                role: "user".to_string(),
                content: vec![ContentItem::InputText {
                    text: "<user_shell_command>echo 42</user_shell_command>".to_string(),
                }],
            },
        ];

        for item in items {
            let turn_item = parse_turn_item(&item);
            assert!(turn_item.is_none(), "expected none, got {turn_item:?}");
        }
    }

    #[test]
    fn parses_agent_message() {
        let item = ResponseItem::Message {
            id: Some("msg-1".to_string()),
            role: "assistant".to_string(),
            content: vec![ContentItem::OutputText {
                text: "Hello from Codex".to_string(),
            }],
        };

        let turn_item = parse_turn_item(&item).expect("expected agent message turn item");

        match turn_item {
            TurnItem::AgentMessage(message) => {
                let Some(AgentMessageContent::Text { text }) = message.content.first() else {
                    panic!("expected agent message text content");
                };
                assert_eq!(text, "Hello from Codex");
            }
            other => panic!("expected TurnItem::AgentMessage, got {other:?}"),
        }
    }

    #[test]
    fn parses_reasoning_summary_and_raw_content() {
        let item = ResponseItem::Reasoning {
            id: "reasoning_1".to_string(),
            summary: vec![
                ReasoningItemReasoningSummary::SummaryText {
                    text: "Step 1".to_string(),
                },
                ReasoningItemReasoningSummary::SummaryText {
                    text: "Step 2".to_string(),
                },
            ],
            content: Some(vec![ReasoningItemContent::ReasoningText {
                text: "raw details".to_string(),
            }]),
            encrypted_content: None,
        };

        let turn_item = parse_turn_item(&item).expect("expected reasoning turn item");

        match turn_item {
            TurnItem::Reasoning(reasoning) => {
                assert_eq!(
                    reasoning.summary_text,
                    vec!["Step 1".to_string(), "Step 2".to_string()]
                );
                assert_eq!(reasoning.raw_content, vec!["raw details".to_string()]);
            }
            other => panic!("expected TurnItem::Reasoning, got {other:?}"),
        }
    }

    #[test]
    fn parses_reasoning_including_raw_content() {
        let item = ResponseItem::Reasoning {
            id: "reasoning_2".to_string(),
            summary: vec![ReasoningItemReasoningSummary::SummaryText {
                text: "Summarized step".to_string(),
            }],
            content: Some(vec![
                ReasoningItemContent::ReasoningText {
                    text: "raw step".to_string(),
                },
                ReasoningItemContent::Text {
                    text: "final thought".to_string(),
                },
            ]),
            encrypted_content: None,
        };

        let turn_item = parse_turn_item(&item).expect("expected reasoning turn item");

        match turn_item {
            TurnItem::Reasoning(reasoning) => {
                assert_eq!(reasoning.summary_text, vec!["Summarized step".to_string()]);
                assert_eq!(
                    reasoning.raw_content,
                    vec!["raw step".to_string(), "final thought".to_string()]
                );
            }
            other => panic!("expected TurnItem::Reasoning, got {other:?}"),
        }
    }

    #[test]
    fn parses_web_search_call() {
        let item = ResponseItem::WebSearchCall {
            id: Some("ws_1".to_string()),
            status: Some("completed".to_string()),
            action: WebSearchAction::Search {
                query: "weather".to_string(),
            },
        };

        let turn_item = parse_turn_item(&item).expect("expected web search turn item");

        match turn_item {
            TurnItem::WebSearch(search) => {
                assert_eq!(search.id, "ws_1");
                assert_eq!(search.query, "weather");
            }
            other => panic!("expected TurnItem::WebSearch, got {other:?}"),
        }
    }
}
777
llmx-rs/core/src/exec.rs
Normal file
@@ -0,0 +1,777 @@
#[cfg(unix)]
use std::os::unix::process::ExitStatusExt;

use std::collections::HashMap;
use std::io;
use std::path::Path;
use std::path::PathBuf;
use std::process::ExitStatus;
use std::time::Duration;
use std::time::Instant;

use async_channel::Sender;
use tokio::io::AsyncRead;
use tokio::io::AsyncReadExt;
use tokio::io::BufReader;
use tokio::process::Child;

use crate::error::CodexErr;
use crate::error::Result;
use crate::error::SandboxErr;
use crate::protocol::Event;
use crate::protocol::EventMsg;
use crate::protocol::ExecCommandOutputDeltaEvent;
use crate::protocol::ExecOutputStream;
use crate::protocol::SandboxPolicy;
use crate::sandboxing::CommandSpec;
use crate::sandboxing::ExecEnv;
use crate::sandboxing::SandboxManager;
use crate::spawn::StdioPolicy;
use crate::spawn::spawn_child_async;

const DEFAULT_TIMEOUT_MS: u64 = 10_000;

// Hardcode these since it does not seem worth including the libc crate just
// for these.
const SIGKILL_CODE: i32 = 9;
const TIMEOUT_CODE: i32 = 64;
const EXIT_CODE_SIGNAL_BASE: i32 = 128; // conventional shell: 128 + signal
const EXEC_TIMEOUT_EXIT_CODE: i32 = 124; // conventional timeout exit code

// I/O buffer sizing
const READ_CHUNK_SIZE: usize = 8192; // bytes per read
const AGGREGATE_BUFFER_INITIAL_CAPACITY: usize = 8 * 1024; // 8 KiB

/// Limit the number of ExecCommandOutputDelta events emitted per exec call.
/// Aggregation still collects full output; only the live event stream is capped.
pub(crate) const MAX_EXEC_OUTPUT_DELTAS_PER_CALL: usize = 10_000;

#[derive(Clone, Debug)]
pub struct ExecParams {
    pub command: Vec<String>,
    pub cwd: PathBuf,
    pub timeout_ms: Option<u64>,
    pub env: HashMap<String, String>,
    pub with_escalated_permissions: Option<bool>,
    pub justification: Option<String>,
    pub arg0: Option<String>,
}

impl ExecParams {
    pub fn timeout_duration(&self) -> Duration {
        Duration::from_millis(self.timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS))
    }
}
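
// Illustrative check of the default above (added for this write-up, not part
// of the original file): with no explicit timeout, a command is allotted
// DEFAULT_TIMEOUT_MS (10 seconds).
fn _has_default_timeout(params: &ExecParams) -> bool {
    params.timeout_ms.is_none()
        && params.timeout_duration() == Duration::from_millis(DEFAULT_TIMEOUT_MS)
}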

#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SandboxType {
    None,

    /// Only available on macOS.
    MacosSeatbelt,

    /// Only available on Linux.
    LinuxSeccomp,

    /// Only available on Windows.
    WindowsRestrictedToken,
}

#[derive(Clone)]
pub struct StdoutStream {
    pub sub_id: String,
    pub call_id: String,
    pub tx_event: Sender<Event>,
}

pub async fn process_exec_tool_call(
    params: ExecParams,
    sandbox_type: SandboxType,
    sandbox_policy: &SandboxPolicy,
    sandbox_cwd: &Path,
    codex_linux_sandbox_exe: &Option<PathBuf>,
    stdout_stream: Option<StdoutStream>,
) -> Result<ExecToolCallOutput> {
    let ExecParams {
        command,
        cwd,
        timeout_ms,
        env,
        with_escalated_permissions,
        justification,
        arg0: _,
    } = params;

    let (program, args) = command.split_first().ok_or_else(|| {
        CodexErr::Io(io::Error::new(
            io::ErrorKind::InvalidInput,
            "command args are empty",
        ))
    })?;

    let spec = CommandSpec {
        program: program.clone(),
        args: args.to_vec(),
        cwd,
        env,
        timeout_ms,
        with_escalated_permissions,
        justification,
    };

    let manager = SandboxManager::new();
    let exec_env = manager
        .transform(
            &spec,
            sandbox_policy,
            sandbox_type,
            sandbox_cwd,
            codex_linux_sandbox_exe.as_ref(),
        )
        .map_err(CodexErr::from)?;

    // Route through the sandboxing module for a single, unified execution path.
    crate::sandboxing::execute_env(&exec_env, sandbox_policy, stdout_stream).await
}

pub(crate) async fn execute_exec_env(
    env: ExecEnv,
    sandbox_policy: &SandboxPolicy,
    stdout_stream: Option<StdoutStream>,
) -> Result<ExecToolCallOutput> {
    let ExecEnv {
        command,
        cwd,
        env,
        timeout_ms,
        sandbox,
        with_escalated_permissions,
        justification,
        arg0,
    } = env;

    let params = ExecParams {
        command,
        cwd,
        timeout_ms,
        env,
        with_escalated_permissions,
        justification,
        arg0,
    };

    let start = Instant::now();
    let raw_output_result = exec(params, sandbox, sandbox_policy, stdout_stream).await;
    let duration = start.elapsed();
    finalize_exec_result(raw_output_result, sandbox, duration)
}

#[cfg(target_os = "windows")]
async fn exec_windows_sandbox(
    params: ExecParams,
    sandbox_policy: &SandboxPolicy,
) -> Result<RawExecToolCallOutput> {
    use crate::config::find_codex_home;
    use codex_windows_sandbox::run_windows_sandbox_capture;

    let ExecParams {
        command,
        cwd,
        env,
        timeout_ms,
        ..
    } = params;

    let policy_str = match sandbox_policy {
        SandboxPolicy::DangerFullAccess => "workspace-write",
        SandboxPolicy::ReadOnly => "read-only",
        SandboxPolicy::WorkspaceWrite { .. } => "workspace-write",
    };

    let sandbox_cwd = cwd.clone();
    let logs_base_dir = find_codex_home().ok();
    let spawn_res = tokio::task::spawn_blocking(move || {
        run_windows_sandbox_capture(
            policy_str,
            &sandbox_cwd,
            command,
            &cwd,
            env,
            timeout_ms,
            logs_base_dir.as_deref(),
        )
    })
    .await;

    let capture = match spawn_res {
        Ok(Ok(v)) => v,
        Ok(Err(err)) => {
            return Err(CodexErr::Io(io::Error::other(format!(
                "windows sandbox: {err}"
            ))));
        }
        Err(join_err) => {
            return Err(CodexErr::Io(io::Error::other(format!(
                "windows sandbox join error: {join_err}"
            ))));
        }
    };

    let exit_status = synthetic_exit_status(capture.exit_code);
    let stdout = StreamOutput {
        text: capture.stdout,
        truncated_after_lines: None,
    };
    let stderr = StreamOutput {
        text: capture.stderr,
        truncated_after_lines: None,
    };
    // Best-effort aggregate: stdout then stderr
    let mut aggregated = Vec::with_capacity(stdout.text.len() + stderr.text.len());
    append_all(&mut aggregated, &stdout.text);
    append_all(&mut aggregated, &stderr.text);
    let aggregated_output = StreamOutput {
        text: aggregated,
        truncated_after_lines: None,
    };

    Ok(RawExecToolCallOutput {
        exit_status,
        stdout,
        stderr,
        aggregated_output,
        timed_out: capture.timed_out,
    })
}

fn finalize_exec_result(
    raw_output_result: std::result::Result<RawExecToolCallOutput, CodexErr>,
    sandbox_type: SandboxType,
    duration: Duration,
) -> Result<ExecToolCallOutput> {
    match raw_output_result {
        Ok(raw_output) => {
            #[allow(unused_mut)]
            let mut timed_out = raw_output.timed_out;

            #[cfg(target_family = "unix")]
            {
                if let Some(signal) = raw_output.exit_status.signal() {
                    if signal == TIMEOUT_CODE {
                        timed_out = true;
                    } else {
                        return Err(CodexErr::Sandbox(SandboxErr::Signal(signal)));
                    }
                }
            }

            let mut exit_code = raw_output.exit_status.code().unwrap_or(-1);
            if timed_out {
                exit_code = EXEC_TIMEOUT_EXIT_CODE;
            }

            let stdout = raw_output.stdout.from_utf8_lossy();
            let stderr = raw_output.stderr.from_utf8_lossy();
            let aggregated_output = raw_output.aggregated_output.from_utf8_lossy();
            let exec_output = ExecToolCallOutput {
                exit_code,
                stdout,
                stderr,
                aggregated_output,
                duration,
                timed_out,
            };

            if timed_out {
                return Err(CodexErr::Sandbox(SandboxErr::Timeout {
                    output: Box::new(exec_output),
                }));
            }

            if is_likely_sandbox_denied(sandbox_type, &exec_output) {
                return Err(CodexErr::Sandbox(SandboxErr::Denied {
                    output: Box::new(exec_output),
                }));
            }

            Ok(exec_output)
        }
        Err(err) => {
            tracing::error!("exec error: {err}");
            Err(err)
        }
    }
}
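
// Illustrative invariant of the normalization above (added for this
// write-up, not part of the original file): whenever a command times out,
// the reported exit code is the conventional 124.
fn _timed_out_implies_exit_124(out: &ExecToolCallOutput) -> bool {
    !out.timed_out || out.exit_code == EXEC_TIMEOUT_EXIT_CODE
}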

pub(crate) mod errors {
    use super::CodexErr;
    use crate::sandboxing::SandboxTransformError;

    impl From<SandboxTransformError> for CodexErr {
        fn from(err: SandboxTransformError) -> Self {
            match err {
                SandboxTransformError::MissingLinuxSandboxExecutable => {
                    CodexErr::LandlockSandboxExecutableNotProvided
                }
                #[cfg(not(target_os = "macos"))]
                SandboxTransformError::SeatbeltUnavailable => CodexErr::UnsupportedOperation(
                    "seatbelt sandbox is only available on macOS".to_string(),
                ),
            }
        }
    }
}

/// We don't have a fully deterministic way to tell if our command failed
/// because of the sandbox - a command in the user's zshrc file might hit an
/// error, but the command itself might fail or succeed for other reasons.
/// For now, we conservatively check for well-known command failure exit codes and
/// also look for common sandbox denial keywords in the command output.
pub(crate) fn is_likely_sandbox_denied(
    sandbox_type: SandboxType,
    exec_output: &ExecToolCallOutput,
) -> bool {
    if sandbox_type == SandboxType::None || exec_output.exit_code == 0 {
        return false;
    }

    // Quick rejects: well-known non-sandbox shell exit codes
    // 2: misuse of shell builtins
    // 126: permission denied
    // 127: command not found
    const SANDBOX_DENIED_KEYWORDS: [&str; 7] = [
        "operation not permitted",
        "permission denied",
        "read-only file system",
        "seccomp",
        "sandbox",
        "landlock",
        "failed to write file",
    ];

    let has_sandbox_keyword = [
        &exec_output.stderr.text,
        &exec_output.stdout.text,
        &exec_output.aggregated_output.text,
    ]
    .into_iter()
    .any(|section| {
        let lower = section.to_lowercase();
        SANDBOX_DENIED_KEYWORDS
            .iter()
            .any(|needle| lower.contains(needle))
    });

    if has_sandbox_keyword {
        return true;
    }

    const QUICK_REJECT_EXIT_CODES: [i32; 3] = [2, 126, 127];
    if QUICK_REJECT_EXIT_CODES.contains(&exec_output.exit_code) {
        return false;
    }

    #[cfg(unix)]
    {
        const SIGSYS_CODE: i32 = libc::SIGSYS;
        if sandbox_type == SandboxType::LinuxSeccomp
            && exec_output.exit_code == EXIT_CODE_SIGNAL_BASE + SIGSYS_CODE
        {
            return true;
        }
    }

    false
}

#[derive(Debug, Clone)]
pub struct StreamOutput<T: Clone> {
    pub text: T,
    pub truncated_after_lines: Option<u32>,
}

#[derive(Debug)]
struct RawExecToolCallOutput {
    pub exit_status: ExitStatus,
    pub stdout: StreamOutput<Vec<u8>>,
    pub stderr: StreamOutput<Vec<u8>>,
    pub aggregated_output: StreamOutput<Vec<u8>>,
    pub timed_out: bool,
}

impl StreamOutput<String> {
    pub fn new(text: String) -> Self {
        Self {
            text,
            truncated_after_lines: None,
        }
    }
}

impl StreamOutput<Vec<u8>> {
    pub fn from_utf8_lossy(&self) -> StreamOutput<String> {
        StreamOutput {
            text: String::from_utf8_lossy(&self.text).to_string(),
            truncated_after_lines: self.truncated_after_lines,
        }
    }
}

#[inline]
fn append_all(dst: &mut Vec<u8>, src: &[u8]) {
    dst.extend_from_slice(src);
}

#[derive(Clone, Debug)]
pub struct ExecToolCallOutput {
    pub exit_code: i32,
    pub stdout: StreamOutput<String>,
    pub stderr: StreamOutput<String>,
    pub aggregated_output: StreamOutput<String>,
    pub duration: Duration,
    pub timed_out: bool,
}

#[cfg_attr(not(target_os = "windows"), allow(unused_variables))]
async fn exec(
    params: ExecParams,
    sandbox: SandboxType,
    sandbox_policy: &SandboxPolicy,
    stdout_stream: Option<StdoutStream>,
) -> Result<RawExecToolCallOutput> {
    #[cfg(target_os = "windows")]
    if sandbox == SandboxType::WindowsRestrictedToken {
        return exec_windows_sandbox(params, sandbox_policy).await;
    }
    let timeout = params.timeout_duration();
    let ExecParams {
        command,
        cwd,
        env,
        arg0,
        ..
    } = params;

    let (program, args) = command.split_first().ok_or_else(|| {
        CodexErr::Io(io::Error::new(
            io::ErrorKind::InvalidInput,
            "command args are empty",
        ))
    })?;
    let arg0_ref = arg0.as_deref();
    let child = spawn_child_async(
        PathBuf::from(program),
        args.into(),
        arg0_ref,
        cwd,
        sandbox_policy,
        StdioPolicy::RedirectForShellTool,
        env,
    )
    .await?;
    consume_truncated_output(child, timeout, stdout_stream).await
}

/// Consumes the output of a child process, truncating it so it is suitable for
/// use as the output of a `shell` tool call. Also enforces the specified timeout.
async fn consume_truncated_output(
    mut child: Child,
    timeout: Duration,
    stdout_stream: Option<StdoutStream>,
) -> Result<RawExecToolCallOutput> {
    // Both stdout and stderr were configured with `Stdio::piped()`
    // above, therefore `take()` should normally return `Some`. If it doesn't,
    // we treat it as an exceptional I/O error.

    let stdout_reader = child.stdout.take().ok_or_else(|| {
        CodexErr::Io(io::Error::other(
            "stdout pipe was unexpectedly not available",
        ))
    })?;
    let stderr_reader = child.stderr.take().ok_or_else(|| {
        CodexErr::Io(io::Error::other(
            "stderr pipe was unexpectedly not available",
        ))
    })?;

    let (agg_tx, agg_rx) = async_channel::unbounded::<Vec<u8>>();

    let stdout_handle = tokio::spawn(read_capped(
        BufReader::new(stdout_reader),
        stdout_stream.clone(),
        false,
        Some(agg_tx.clone()),
    ));
    let stderr_handle = tokio::spawn(read_capped(
        BufReader::new(stderr_reader),
        stdout_stream.clone(),
        true,
        Some(agg_tx.clone()),
    ));

    let (exit_status, timed_out) = tokio::select! {
        result = tokio::time::timeout(timeout, child.wait()) => {
            match result {
                Ok(status_result) => {
                    let exit_status = status_result?;
                    (exit_status, false)
                }
                Err(_) => {
                    // timeout
                    kill_child_process_group(&mut child)?;
                    child.start_kill()?;
                    // Debatable whether `child.wait().await` should be called here.
                    (synthetic_exit_status(EXIT_CODE_SIGNAL_BASE + TIMEOUT_CODE), true)
                }
            }
        }
        _ = tokio::signal::ctrl_c() => {
            kill_child_process_group(&mut child)?;
            child.start_kill()?;
            (synthetic_exit_status(EXIT_CODE_SIGNAL_BASE + SIGKILL_CODE), false)
        }
    };

    let stdout = stdout_handle.await??;
    let stderr = stderr_handle.await??;

    drop(agg_tx);

    let mut combined_buf = Vec::with_capacity(AGGREGATE_BUFFER_INITIAL_CAPACITY);
    while let Ok(chunk) = agg_rx.recv().await {
        append_all(&mut combined_buf, &chunk);
    }
    let aggregated_output = StreamOutput {
        text: combined_buf,
        truncated_after_lines: None,
    };

    Ok(RawExecToolCallOutput {
        exit_status,
        stdout,
        stderr,
        aggregated_output,
        timed_out,
    })
}

async fn read_capped<R: AsyncRead + Unpin + Send + 'static>(
    mut reader: R,
    stream: Option<StdoutStream>,
    is_stderr: bool,
    aggregate_tx: Option<Sender<Vec<u8>>>,
) -> io::Result<StreamOutput<Vec<u8>>> {
    let mut buf = Vec::with_capacity(AGGREGATE_BUFFER_INITIAL_CAPACITY);
    let mut tmp = [0u8; READ_CHUNK_SIZE];
    let mut emitted_deltas: usize = 0;

    // No caps: append all bytes

    loop {
        let n = reader.read(&mut tmp).await?;
        if n == 0 {
            break;
        }

        if let Some(stream) = &stream
            && emitted_deltas < MAX_EXEC_OUTPUT_DELTAS_PER_CALL
        {
            let chunk = tmp[..n].to_vec();
            let msg = EventMsg::ExecCommandOutputDelta(ExecCommandOutputDeltaEvent {
                call_id: stream.call_id.clone(),
                stream: if is_stderr {
                    ExecOutputStream::Stderr
                } else {
                    ExecOutputStream::Stdout
                },
                chunk,
            });
            let event = Event {
                id: stream.sub_id.clone(),
                msg,
            };
            #[allow(clippy::let_unit_value)]
            let _ = stream.tx_event.send(event).await;
            emitted_deltas += 1;
        }

        if let Some(tx) = &aggregate_tx {
            let _ = tx.send(tmp[..n].to_vec()).await;
        }

        append_all(&mut buf, &tmp[..n]);
        // Continue reading to EOF to avoid back-pressure
    }

    Ok(StreamOutput {
        text: buf,
        truncated_after_lines: None,
    })
}

#[cfg(unix)]
fn synthetic_exit_status(code: i32) -> ExitStatus {
    use std::os::unix::process::ExitStatusExt;
    std::process::ExitStatus::from_raw(code)
}

#[cfg(windows)]
fn synthetic_exit_status(code: i32) -> ExitStatus {
    use std::os::windows::process::ExitStatusExt;
    // On Windows the raw status is a u32. Use a direct cast to avoid
    // panicking on negative i32 values produced by prior narrowing casts.
    std::process::ExitStatus::from_raw(code as u32)
}

#[cfg(unix)]
fn kill_child_process_group(child: &mut Child) -> io::Result<()> {
    use std::io::ErrorKind;

    if let Some(pid) = child.id() {
        let pid = pid as libc::pid_t;
        let pgid = unsafe { libc::getpgid(pid) };
        if pgid == -1 {
            let err = std::io::Error::last_os_error();
            if err.kind() != ErrorKind::NotFound {
                return Err(err);
            }
            return Ok(());
        }

        let result = unsafe { libc::killpg(pgid, libc::SIGKILL) };
        if result == -1 {
            let err = std::io::Error::last_os_error();
            if err.kind() != ErrorKind::NotFound {
                return Err(err);
            }
        }
    }

    Ok(())
}

#[cfg(not(unix))]
fn kill_child_process_group(_: &mut Child) -> io::Result<()> {
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    fn make_exec_output(
        exit_code: i32,
        stdout: &str,
        stderr: &str,
        aggregated: &str,
    ) -> ExecToolCallOutput {
        ExecToolCallOutput {
            exit_code,
            stdout: StreamOutput::new(stdout.to_string()),
            stderr: StreamOutput::new(stderr.to_string()),
            aggregated_output: StreamOutput::new(aggregated.to_string()),
            duration: Duration::from_millis(1),
            timed_out: false,
        }
    }

    #[test]
    fn sandbox_detection_requires_keywords() {
        let output = make_exec_output(1, "", "", "");
        assert!(!is_likely_sandbox_denied(
            SandboxType::LinuxSeccomp,
            &output
        ));
    }

    #[test]
    fn sandbox_detection_identifies_keyword_in_stderr() {
        let output = make_exec_output(1, "", "Operation not permitted", "");
        assert!(is_likely_sandbox_denied(SandboxType::LinuxSeccomp, &output));
    }

    #[test]
    fn sandbox_detection_respects_quick_reject_exit_codes() {
        let output = make_exec_output(127, "", "command not found", "");
        assert!(!is_likely_sandbox_denied(
            SandboxType::LinuxSeccomp,
            &output
        ));
    }

    #[test]
    fn sandbox_detection_ignores_non_sandbox_mode() {
        let output = make_exec_output(1, "", "Operation not permitted", "");
        assert!(!is_likely_sandbox_denied(SandboxType::None, &output));
    }

    #[test]
    fn sandbox_detection_uses_aggregated_output() {
        let output = make_exec_output(
            101,
            "",
            "",
            "cargo failed: Read-only file system when writing target",
        );
        assert!(is_likely_sandbox_denied(
            SandboxType::MacosSeatbelt,
            &output
        ));
    }

    #[cfg(unix)]
    #[test]
    fn sandbox_detection_flags_sigsys_exit_code() {
        let exit_code = EXIT_CODE_SIGNAL_BASE + libc::SIGSYS;
        let output = make_exec_output(exit_code, "", "", "");
        assert!(is_likely_sandbox_denied(SandboxType::LinuxSeccomp, &output));
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn kill_child_process_group_kills_grandchildren_on_timeout() -> Result<()> {
        let command = vec![
            "/bin/bash".to_string(),
            "-c".to_string(),
            "sleep 60 & echo $!; sleep 60".to_string(),
        ];
        let env: HashMap<String, String> = std::env::vars().collect();
        let params = ExecParams {
            command,
            cwd: std::env::current_dir()?,
            timeout_ms: Some(500),
            env,
            with_escalated_permissions: None,
            justification: None,
            arg0: None,
        };

        let output = exec(params, SandboxType::None, &SandboxPolicy::ReadOnly, None).await?;
        assert!(output.timed_out);

        let stdout = output.stdout.from_utf8_lossy().text;
        let pid_line = stdout.lines().next().unwrap_or("").trim();
        let pid: i32 = pid_line.parse().map_err(|error| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Failed to parse pid from stdout '{pid_line}': {error}"),
            )
        })?;

        let mut killed = false;
        for _ in 0..20 {
            // Use kill(pid, 0) to check if the process is alive.
            if unsafe { libc::kill(pid, 0) } == -1
                && let Some(libc::ESRCH) = std::io::Error::last_os_error().raw_os_error()
            {
                killed = true;
                break;
            }
            tokio::time::sleep(Duration::from_millis(100)).await;
        }

        assert!(killed, "grandchild process with pid {pid} is still alive");
        Ok(())
    }
}
194
llmx-rs/core/src/exec_env.rs
Normal file
@@ -0,0 +1,194 @@
use crate::config::types::EnvironmentVariablePattern;
use crate::config::types::ShellEnvironmentPolicy;
use crate::config::types::ShellEnvironmentPolicyInherit;
use std::collections::HashMap;
use std::collections::HashSet;

/// Construct an environment map based on the rules in the specified policy. The
/// resulting map can be passed directly to `Command::envs()` after calling
/// `env_clear()` to ensure no unintended variables are leaked to the spawned
/// process.
///
/// The derivation follows the algorithm documented in the struct-level comment
/// for [`ShellEnvironmentPolicy`].
pub fn create_env(policy: &ShellEnvironmentPolicy) -> HashMap<String, String> {
    populate_env(std::env::vars(), policy)
}

fn populate_env<I>(vars: I, policy: &ShellEnvironmentPolicy) -> HashMap<String, String>
where
    I: IntoIterator<Item = (String, String)>,
{
    // Step 1 – determine the starting set of variables based on the
    // `inherit` strategy.
    let mut env_map: HashMap<String, String> = match policy.inherit {
        ShellEnvironmentPolicyInherit::All => vars.into_iter().collect(),
        ShellEnvironmentPolicyInherit::None => HashMap::new(),
        ShellEnvironmentPolicyInherit::Core => {
            const CORE_VARS: &[&str] = &[
                "HOME", "LOGNAME", "PATH", "SHELL", "USER", "USERNAME", "TMPDIR", "TEMP", "TMP",
            ];
            let allow: HashSet<&str> = CORE_VARS.iter().copied().collect();
            vars.into_iter()
                .filter(|(k, _)| allow.contains(k.as_str()))
                .collect()
        }
    };

    // Internal helper – does `name` match **any** pattern in `patterns`?
    let matches_any = |name: &str, patterns: &[EnvironmentVariablePattern]| -> bool {
        patterns.iter().any(|pattern| pattern.matches(name))
    };

    // Step 2 – Apply the default exclude if not disabled.
    if !policy.ignore_default_excludes {
        let default_excludes = vec![
            EnvironmentVariablePattern::new_case_insensitive("*KEY*"),
            EnvironmentVariablePattern::new_case_insensitive("*SECRET*"),
            EnvironmentVariablePattern::new_case_insensitive("*TOKEN*"),
        ];
        env_map.retain(|k, _| !matches_any(k, &default_excludes));
    }

    // Step 3 – Apply custom excludes.
    if !policy.exclude.is_empty() {
        env_map.retain(|k, _| !matches_any(k, &policy.exclude));
    }

    // Step 4 – Apply user-provided overrides.
    for (key, val) in &policy.r#set {
        env_map.insert(key.clone(), val.clone());
    }

    // Step 5 – If include_only is non-empty, keep *only* the matching vars.
    if !policy.include_only.is_empty() {
        env_map.retain(|k, _| matches_any(k, &policy.include_only));
    }

    env_map
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::types::ShellEnvironmentPolicyInherit;
|
||||
use maplit::hashmap;
|
||||
|
||||
fn make_vars(pairs: &[(&str, &str)]) -> Vec<(String, String)> {
|
||||
pairs
|
||||
.iter()
|
||||
.map(|(k, v)| (k.to_string(), v.to_string()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_core_inherit_and_default_excludes() {
|
||||
let vars = make_vars(&[
|
||||
("PATH", "/usr/bin"),
|
||||
("HOME", "/home/user"),
|
||||
("API_KEY", "secret"),
|
||||
("SECRET_TOKEN", "t"),
|
||||
]);
|
||||
|
||||
let policy = ShellEnvironmentPolicy::default(); // inherit Core, default excludes on
|
||||
let result = populate_env(vars, &policy);
|
||||
|
||||
let expected: HashMap<String, String> = hashmap! {
|
||||
"PATH".to_string() => "/usr/bin".to_string(),
|
||||
"HOME".to_string() => "/home/user".to_string(),
|
||||
};
|
||||
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_include_only() {
|
||||
let vars = make_vars(&[("PATH", "/usr/bin"), ("FOO", "bar")]);
|
||||
|
||||
let policy = ShellEnvironmentPolicy {
|
||||
// skip default excludes so nothing is removed prematurely
|
||||
ignore_default_excludes: true,
|
||||
include_only: vec![EnvironmentVariablePattern::new_case_insensitive("*PATH")],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = populate_env(vars, &policy);
|
||||
|
||||
let expected: HashMap<String, String> = hashmap! {
|
||||
"PATH".to_string() => "/usr/bin".to_string(),
|
||||
};
|
||||
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_set_overrides() {
|
||||
let vars = make_vars(&[("PATH", "/usr/bin")]);
|
||||
|
||||
let mut policy = ShellEnvironmentPolicy {
|
||||
ignore_default_excludes: true,
|
||||
..Default::default()
|
||||
};
|
||||
policy.r#set.insert("NEW_VAR".to_string(), "42".to_string());
|
||||
|
||||
let result = populate_env(vars, &policy);
|
||||
|
||||
let expected: HashMap<String, String> = hashmap! {
|
||||
"PATH".to_string() => "/usr/bin".to_string(),
|
||||
"NEW_VAR".to_string() => "42".to_string(),
|
||||
};
|
||||
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inherit_all() {
|
||||
let vars = make_vars(&[("PATH", "/usr/bin"), ("FOO", "bar")]);
|
||||
|
||||
let policy = ShellEnvironmentPolicy {
|
||||
inherit: ShellEnvironmentPolicyInherit::All,
|
||||
ignore_default_excludes: true, // keep everything
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = populate_env(vars.clone(), &policy);
|
||||
let expected: HashMap<String, String> = vars.into_iter().collect();
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inherit_all_with_default_excludes() {
|
||||
let vars = make_vars(&[("PATH", "/usr/bin"), ("API_KEY", "secret")]);
|
||||
|
||||
let policy = ShellEnvironmentPolicy {
|
||||
inherit: ShellEnvironmentPolicyInherit::All,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = populate_env(vars, &policy);
|
||||
let expected: HashMap<String, String> = hashmap! {
|
||||
"PATH".to_string() => "/usr/bin".to_string(),
|
||||
};
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inherit_none() {
|
||||
let vars = make_vars(&[("PATH", "/usr/bin"), ("HOME", "/home")]);
|
||||
|
||||
let mut policy = ShellEnvironmentPolicy {
|
||||
inherit: ShellEnvironmentPolicyInherit::None,
|
||||
ignore_default_excludes: true,
|
||||
..Default::default()
|
||||
};
|
||||
policy
|
||||
.r#set
|
||||
.insert("ONLY_VAR".to_string(), "yes".to_string());
|
||||
|
||||
let result = populate_env(vars, &policy);
|
||||
let expected: HashMap<String, String> = hashmap! {
|
||||
"ONLY_VAR".to_string() => "yes".to_string(),
|
||||
};
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
}
|
||||
295
llmx-rs/core/src/features.rs
Normal file
295
llmx-rs/core/src/features.rs
Normal file
@@ -0,0 +1,295 @@
//! Centralized feature flags and metadata.
//!
//! This module defines a small set of toggles that gate experimental and
//! optional behavior across the codebase. Instead of wiring individual
//! booleans through multiple types, call sites consult a single `Features`
//! container attached to `Config`.

use crate::config::ConfigToml;
use crate::config::profile::ConfigProfile;
use serde::Deserialize;
use std::collections::BTreeMap;
use std::collections::BTreeSet;

mod legacy;
pub(crate) use legacy::LegacyFeatureToggles;

/// High-level lifecycle stage for a feature.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Stage {
    Experimental,
    Beta,
    Stable,
    Deprecated,
    Removed,
}

/// Unique features toggled via configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Feature {
    /// Use the single unified PTY-backed exec tool.
    UnifiedExec,
    /// Enable experimental RMCP features such as OAuth login.
    RmcpClient,
    /// Include the freeform apply_patch tool.
    ApplyPatchFreeform,
    /// Include the view_image tool.
    ViewImageTool,
    /// Allow the model to request web searches.
    WebSearchRequest,
    /// Enable the model-based risk assessments for sandboxed commands.
    SandboxCommandAssessment,
    /// Create a ghost commit at each turn.
    GhostCommit,
    /// Enable Windows sandbox (restricted token) on Windows.
    WindowsSandbox,
}

impl Feature {
    pub fn key(self) -> &'static str {
        self.info().key
    }

    pub fn stage(self) -> Stage {
        self.info().stage
    }

    pub fn default_enabled(self) -> bool {
        self.info().default_enabled
    }

    fn info(self) -> &'static FeatureSpec {
        FEATURES
            .iter()
            .find(|spec| spec.id == self)
            .unwrap_or_else(|| unreachable!("missing FeatureSpec for {:?}", self))
    }
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct LegacyFeatureUsage {
    pub alias: String,
    pub feature: Feature,
}

/// Holds the effective set of enabled features.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct Features {
    enabled: BTreeSet<Feature>,
    legacy_usages: BTreeSet<LegacyFeatureUsage>,
}

#[derive(Debug, Clone, Default)]
pub struct FeatureOverrides {
    pub include_apply_patch_tool: Option<bool>,
    pub web_search_request: Option<bool>,
    pub experimental_sandbox_command_assessment: Option<bool>,
}

impl FeatureOverrides {
    fn apply(self, features: &mut Features) {
        LegacyFeatureToggles {
            include_apply_patch_tool: self.include_apply_patch_tool,
            tools_web_search: self.web_search_request,
            ..Default::default()
        }
        .apply(features);
    }
}

impl Features {
    /// Starts with built-in defaults.
    pub fn with_defaults() -> Self {
        let mut set = BTreeSet::new();
        for spec in FEATURES {
            if spec.default_enabled {
                set.insert(spec.id);
            }
        }
        Self {
            enabled: set,
            legacy_usages: BTreeSet::new(),
        }
    }

    pub fn enabled(&self, f: Feature) -> bool {
        self.enabled.contains(&f)
    }

    pub fn enable(&mut self, f: Feature) -> &mut Self {
        self.enabled.insert(f);
        self
    }

    pub fn disable(&mut self, f: Feature) -> &mut Self {
        self.enabled.remove(&f);
        self
    }

    pub fn record_legacy_usage_force(&mut self, alias: &str, feature: Feature) {
        self.legacy_usages.insert(LegacyFeatureUsage {
            alias: alias.to_string(),
            feature,
        });
    }

    pub fn record_legacy_usage(&mut self, alias: &str, feature: Feature) {
        if alias == feature.key() {
            return;
        }
        self.record_legacy_usage_force(alias, feature);
    }

    pub fn legacy_feature_usages(&self) -> impl Iterator<Item = (&str, Feature)> + '_ {
        self.legacy_usages
            .iter()
            .map(|usage| (usage.alias.as_str(), usage.feature))
    }

    /// Apply a table of key -> bool toggles (e.g. from TOML).
    pub fn apply_map(&mut self, m: &BTreeMap<String, bool>) {
        for (k, v) in m {
            match feature_for_key(k) {
                Some(feat) => {
                    if k != feat.key() {
                        self.record_legacy_usage(k.as_str(), feat);
                    }
                    if *v {
                        self.enable(feat);
                    } else {
                        self.disable(feat);
                    }
                }
                None => {
                    tracing::warn!("unknown feature key in config: {k}");
                }
            }
        }
    }

    pub fn from_config(
        cfg: &ConfigToml,
        config_profile: &ConfigProfile,
        overrides: FeatureOverrides,
    ) -> Self {
        let mut features = Features::with_defaults();

        let base_legacy = LegacyFeatureToggles {
            experimental_sandbox_command_assessment: cfg.experimental_sandbox_command_assessment,
            experimental_use_freeform_apply_patch: cfg.experimental_use_freeform_apply_patch,
            experimental_use_unified_exec_tool: cfg.experimental_use_unified_exec_tool,
            experimental_use_rmcp_client: cfg.experimental_use_rmcp_client,
            tools_web_search: cfg.tools.as_ref().and_then(|t| t.web_search),
            tools_view_image: cfg.tools.as_ref().and_then(|t| t.view_image),
            ..Default::default()
        };
        base_legacy.apply(&mut features);

        if let Some(base_features) = cfg.features.as_ref() {
            features.apply_map(&base_features.entries);
        }

        let profile_legacy = LegacyFeatureToggles {
            include_apply_patch_tool: config_profile.include_apply_patch_tool,
            experimental_sandbox_command_assessment: config_profile
                .experimental_sandbox_command_assessment,
            experimental_use_freeform_apply_patch: config_profile
                .experimental_use_freeform_apply_patch,
            experimental_use_unified_exec_tool: config_profile.experimental_use_unified_exec_tool,
            experimental_use_rmcp_client: config_profile.experimental_use_rmcp_client,
            tools_web_search: config_profile.tools_web_search,
            tools_view_image: config_profile.tools_view_image,
        };
        profile_legacy.apply(&mut features);
        if let Some(profile_features) = config_profile.features.as_ref() {
            features.apply_map(&profile_features.entries);
        }

        overrides.apply(&mut features);

        features
    }
}
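
A short sketch of how a `[features]` table deserialized from TOML flows through `apply_map`, assuming the types above; the `web_search` key deliberately exercises the legacy-alias path handled in `legacy.rs`:

let mut features = Features::with_defaults();
let toggles: BTreeMap<String, bool> = BTreeMap::from([
    ("unified_exec".to_string(), true),
    ("web_search".to_string(), true), // legacy alias for web_search_request
]);
features.apply_map(&toggles);
assert!(features.enabled(Feature::UnifiedExec));
assert!(features.enabled(Feature::WebSearchRequest));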

/// Keys accepted in `[features]` tables.
fn feature_for_key(key: &str) -> Option<Feature> {
    for spec in FEATURES {
        if spec.key == key {
            return Some(spec.id);
        }
    }
    legacy::feature_for_key(key)
}

/// Returns `true` if the provided string matches a known feature toggle key.
pub fn is_known_feature_key(key: &str) -> bool {
    feature_for_key(key).is_some()
}

/// Deserializable features table for TOML.
#[derive(Deserialize, Debug, Clone, Default, PartialEq)]
pub struct FeaturesToml {
    #[serde(flatten)]
    pub entries: BTreeMap<String, bool>,
}

/// Single, easy-to-read registry of all feature definitions.
#[derive(Debug, Clone, Copy)]
pub struct FeatureSpec {
    pub id: Feature,
    pub key: &'static str,
    pub stage: Stage,
    pub default_enabled: bool,
}

pub const FEATURES: &[FeatureSpec] = &[
    FeatureSpec {
        id: Feature::UnifiedExec,
        key: "unified_exec",
        stage: Stage::Experimental,
        default_enabled: false,
    },
    FeatureSpec {
        id: Feature::RmcpClient,
        key: "rmcp_client",
        stage: Stage::Experimental,
        default_enabled: false,
    },
    FeatureSpec {
        id: Feature::ApplyPatchFreeform,
        key: "apply_patch_freeform",
        stage: Stage::Beta,
        default_enabled: false,
    },
    FeatureSpec {
        id: Feature::ViewImageTool,
        key: "view_image_tool",
        stage: Stage::Stable,
        default_enabled: true,
    },
    FeatureSpec {
        id: Feature::WebSearchRequest,
        key: "web_search_request",
        stage: Stage::Stable,
        default_enabled: false,
    },
    FeatureSpec {
        id: Feature::SandboxCommandAssessment,
        key: "experimental_sandbox_command_assessment",
        stage: Stage::Experimental,
        default_enabled: false,
    },
    FeatureSpec {
        id: Feature::GhostCommit,
        key: "ghost_commit",
        stage: Stage::Experimental,
        default_enabled: true,
    },
    FeatureSpec {
        id: Feature::WindowsSandbox,
        key: "enable_experimental_windows_sandbox",
        stage: Stage::Experimental,
        default_enabled: false,
    },
];
137
llmx-rs/core/src/features/legacy.rs
Normal file
137
llmx-rs/core/src/features/legacy.rs
Normal file
@@ -0,0 +1,137 @@
use super::Feature;
use super::Features;
use tracing::info;

#[derive(Clone, Copy)]
struct Alias {
    legacy_key: &'static str,
    feature: Feature,
}

const ALIASES: &[Alias] = &[
    Alias {
        legacy_key: "experimental_sandbox_command_assessment",
        feature: Feature::SandboxCommandAssessment,
    },
    Alias {
        legacy_key: "experimental_use_unified_exec_tool",
        feature: Feature::UnifiedExec,
    },
    Alias {
        legacy_key: "experimental_use_rmcp_client",
        feature: Feature::RmcpClient,
    },
    Alias {
        legacy_key: "experimental_use_freeform_apply_patch",
        feature: Feature::ApplyPatchFreeform,
    },
    Alias {
        legacy_key: "include_apply_patch_tool",
        feature: Feature::ApplyPatchFreeform,
    },
    Alias {
        legacy_key: "web_search",
        feature: Feature::WebSearchRequest,
    },
];

pub(crate) fn feature_for_key(key: &str) -> Option<Feature> {
    ALIASES
        .iter()
        .find(|alias| alias.legacy_key == key)
        .map(|alias| {
            log_alias(alias.legacy_key, alias.feature);
            alias.feature
        })
}

#[derive(Debug, Default)]
pub struct LegacyFeatureToggles {
    pub include_apply_patch_tool: Option<bool>,
    pub experimental_sandbox_command_assessment: Option<bool>,
    pub experimental_use_freeform_apply_patch: Option<bool>,
    pub experimental_use_unified_exec_tool: Option<bool>,
    pub experimental_use_rmcp_client: Option<bool>,
    pub tools_web_search: Option<bool>,
    pub tools_view_image: Option<bool>,
}

impl LegacyFeatureToggles {
    pub fn apply(self, features: &mut Features) {
        set_if_some(
            features,
            Feature::ApplyPatchFreeform,
            self.include_apply_patch_tool,
            "include_apply_patch_tool",
        );
        set_if_some(
            features,
            Feature::SandboxCommandAssessment,
            self.experimental_sandbox_command_assessment,
            "experimental_sandbox_command_assessment",
        );
        set_if_some(
            features,
            Feature::ApplyPatchFreeform,
            self.experimental_use_freeform_apply_patch,
            "experimental_use_freeform_apply_patch",
        );
        set_if_some(
            features,
            Feature::UnifiedExec,
            self.experimental_use_unified_exec_tool,
            "experimental_use_unified_exec_tool",
        );
        set_if_some(
            features,
            Feature::RmcpClient,
            self.experimental_use_rmcp_client,
            "experimental_use_rmcp_client",
        );
        set_if_some(
            features,
            Feature::WebSearchRequest,
            self.tools_web_search,
            "tools.web_search",
        );
        set_if_some(
            features,
            Feature::ViewImageTool,
            self.tools_view_image,
            "tools.view_image",
        );
    }
}
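
For illustration, a legacy toggle flowing through `apply` (a sketch, not part of the commit):

let mut features = Features::with_defaults();
LegacyFeatureToggles {
    tools_web_search: Some(true),
    ..Default::default()
}
.apply(&mut features);
assert!(features.enabled(Feature::WebSearchRequest));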

fn set_if_some(
    features: &mut Features,
    feature: Feature,
    maybe_value: Option<bool>,
    alias_key: &'static str,
) {
    if let Some(enabled) = maybe_value {
        set_feature(features, feature, enabled);
        log_alias(alias_key, feature);
        features.record_legacy_usage(alias_key, feature);
    }
}

fn set_feature(features: &mut Features, feature: Feature, enabled: bool) {
    if enabled {
        features.enable(feature);
    } else {
        features.disable(feature);
    }
}

fn log_alias(alias: &str, feature: Feature) {
    let canonical = feature.key();
    if alias == canonical {
        return;
    }
    info!(
        %alias,
        canonical,
        "legacy feature toggle detected; prefer `[features].{canonical}`"
    );
}
6
llmx-rs/core/src/flags.rs
Normal file
6
llmx-rs/core/src/flags.rs
Normal file
@@ -0,0 +1,6 @@
use env_flags::env_flags;

env_flags! {
    /// Fixture path for offline tests (see client.rs).
    pub CODEX_RS_SSE_FIXTURE: Option<&str> = None;
}
14
llmx-rs/core/src/function_tool.rs
Normal file
14
llmx-rs/core/src/function_tool.rs
Normal file
@@ -0,0 +1,14 @@
use thiserror::Error;

#[derive(Debug, Error, PartialEq)]
pub enum FunctionCallError {
    #[error("{0}")]
    RespondToModel(String),
    #[error("{0}")]
    #[allow(dead_code)] // TODO(jif) fix in a follow-up PR
    Denied(String),
    #[error("LocalShellCall without call_id or id")]
    MissingLocalShellCallId,
    #[error("Fatal error: {0}")]
    Fatal(String),
}
1124
llmx-rs/core/src/git_info.rs
Normal file
1124
llmx-rs/core/src/git_info.rs
Normal file
File diff suppressed because it is too large
72
llmx-rs/core/src/landlock.rs
Normal file
72
llmx-rs/core/src/landlock.rs
Normal file
@@ -0,0 +1,72 @@
use crate::protocol::SandboxPolicy;
use crate::spawn::StdioPolicy;
use crate::spawn::spawn_child_async;
use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
use tokio::process::Child;

/// Spawn a shell tool command under the Linux Landlock+seccomp sandbox helper
/// (codex-linux-sandbox).
///
/// Unlike macOS Seatbelt, where we directly embed the policy text, the Linux
/// helper receives the policy as serialized JSON via `--sandbox-policy`. We
/// convert the internal [`SandboxPolicy`] representation into the equivalent
/// CLI invocation.
pub async fn spawn_command_under_linux_sandbox<P>(
    codex_linux_sandbox_exe: P,
    command: Vec<String>,
    command_cwd: PathBuf,
    sandbox_policy: &SandboxPolicy,
    sandbox_policy_cwd: &Path,
    stdio_policy: StdioPolicy,
    env: HashMap<String, String>,
) -> std::io::Result<Child>
where
    P: AsRef<Path>,
{
    let args = create_linux_sandbox_command_args(command, sandbox_policy, sandbox_policy_cwd);
    let arg0 = Some("codex-linux-sandbox");
    spawn_child_async(
        codex_linux_sandbox_exe.as_ref().to_path_buf(),
        args,
        arg0,
        command_cwd,
        sandbox_policy,
        stdio_policy,
        env,
    )
    .await
}

/// Converts the sandbox policy into the CLI invocation for `codex-linux-sandbox`.
pub(crate) fn create_linux_sandbox_command_args(
    command: Vec<String>,
    sandbox_policy: &SandboxPolicy,
    sandbox_policy_cwd: &Path,
) -> Vec<String> {
    #[expect(clippy::expect_used)]
    let sandbox_policy_cwd = sandbox_policy_cwd
        .to_str()
        .expect("cwd must be valid UTF-8")
        .to_string();

    #[expect(clippy::expect_used)]
    let sandbox_policy_json =
        serde_json::to_string(sandbox_policy).expect("Failed to serialize SandboxPolicy to JSON");

    let mut linux_cmd: Vec<String> = vec![
        "--sandbox-policy-cwd".to_string(),
        sandbox_policy_cwd,
        "--sandbox-policy".to_string(),
        sandbox_policy_json,
        // Separator so that command arguments starting with `-` are not parsed as
        // options of the helper itself.
        "--".to_string(),
    ];

    // Append the original tool command.
    linux_cmd.extend(command);

    linux_cmd
}
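
A sketch of the argv this produces for `ls -la` under a read-only policy; the JSON payload shown is schematic, since the exact shape depends on `SandboxPolicy`'s serde representation:

let args = create_linux_sandbox_command_args(
    vec!["ls".to_string(), "-la".to_string()],
    &SandboxPolicy::ReadOnly,
    Path::new("/work"),
);
// => ["--sandbox-policy-cwd", "/work",
//     "--sandbox-policy", "{...serialized policy...}",
//     "--", "ls", "-la"]
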
111
llmx-rs/core/src/lib.rs
Normal file
111
llmx-rs/core/src/lib.rs
Normal file
@@ -0,0 +1,111 @@
//! Root of the `codex-core` library.

// Prevent accidental direct writes to stdout/stderr in library code. All
// user-visible output must go through the appropriate abstraction (e.g.,
// the TUI or the tracing stack).
#![deny(clippy::print_stdout, clippy::print_stderr)]

mod apply_patch;
pub mod auth;
pub mod bash;
mod chat_completions;
mod client;
mod client_common;
pub mod codex;
mod codex_conversation;
pub use codex_conversation::CodexConversation;
mod codex_delegate;
mod command_safety;
pub mod config;
pub mod config_loader;
mod context_manager;
pub mod custom_prompts;
mod environment_context;
pub mod error;
pub mod exec;
pub mod exec_env;
pub mod features;
mod flags;
pub mod git_info;
pub mod landlock;
pub mod mcp;
mod mcp_connection_manager;
mod mcp_tool_call;
mod message_history;
mod model_provider_info;
pub mod parse_command;
mod response_processing;
pub mod sandboxing;
pub mod token_data;
mod truncate;
mod unified_exec;
mod user_instructions;
pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID;
pub use model_provider_info::ModelProviderInfo;
pub use model_provider_info::WireApi;
pub use model_provider_info::built_in_model_providers;
pub use model_provider_info::create_oss_provider_with_base_url;
mod conversation_manager;
mod event_mapping;
pub mod review_format;
pub use codex_protocol::protocol::InitialHistory;
pub use conversation_manager::ConversationManager;
pub use conversation_manager::NewConversation;
// Re-export common auth types for workspace consumers
pub use auth::AuthManager;
pub use auth::CodexAuth;
pub mod default_client;
pub mod model_family;
mod openai_model_info;
pub mod project_doc;
mod rollout;
pub(crate) mod safety;
pub mod seatbelt;
pub mod shell;
pub mod spawn;
pub mod terminal;
mod tools;
pub mod turn_diff_tracker;
pub use rollout::ARCHIVED_SESSIONS_SUBDIR;
pub use rollout::INTERACTIVE_SESSION_SOURCES;
pub use rollout::RolloutRecorder;
pub use rollout::SESSIONS_SUBDIR;
pub use rollout::SessionMeta;
pub use rollout::find_conversation_path_by_id_str;
pub use rollout::list::ConversationItem;
pub use rollout::list::ConversationsPage;
pub use rollout::list::Cursor;
pub use rollout::list::parse_cursor;
pub use rollout::list::read_head_for_summary;
mod function_tool;
mod state;
mod tasks;
mod user_notification;
mod user_shell_command;
pub mod util;

pub use apply_patch::CODEX_APPLY_PATCH_ARG1;
pub use command_safety::is_safe_command;
pub use safety::get_platform_sandbox;
pub use safety::set_windows_sandbox_enabled;
// Re-export the protocol types from the standalone `codex-protocol` crate so existing
// `codex_core::protocol::...` references continue to work across the workspace.
pub use codex_protocol::protocol;
// Re-export protocol config enums to ensure call sites can use the same types
// as those in the protocol crate when constructing protocol messages.
pub use codex_protocol::config_types as protocol_config_types;

pub use client::ModelClient;
pub use client_common::Prompt;
pub use client_common::REVIEW_PROMPT;
pub use client_common::ResponseEvent;
pub use client_common::ResponseStream;
pub use codex_protocol::models::ContentItem;
pub use codex_protocol::models::LocalShellAction;
pub use codex_protocol::models::LocalShellExecAction;
pub use codex_protocol::models::LocalShellStatus;
pub use codex_protocol::models::ResponseItem;
pub use compact::content_items_to_text;
pub use event_mapping::parse_turn_item;
pub mod compact;
pub mod otel_init;
72
llmx-rs/core/src/mcp/auth.rs
Normal file
72
llmx-rs/core/src/mcp/auth.rs
Normal file
@@ -0,0 +1,72 @@
use std::collections::HashMap;

use anyhow::Result;
use codex_protocol::protocol::McpAuthStatus;
use codex_rmcp_client::OAuthCredentialsStoreMode;
use codex_rmcp_client::determine_streamable_http_auth_status;
use futures::future::join_all;
use tracing::warn;

use crate::config::types::McpServerConfig;
use crate::config::types::McpServerTransportConfig;

#[derive(Debug, Clone)]
pub struct McpAuthStatusEntry {
    pub config: McpServerConfig,
    pub auth_status: McpAuthStatus,
}

pub async fn compute_auth_statuses<'a, I>(
    servers: I,
    store_mode: OAuthCredentialsStoreMode,
) -> HashMap<String, McpAuthStatusEntry>
where
    I: IntoIterator<Item = (&'a String, &'a McpServerConfig)>,
{
    let futures = servers.into_iter().map(|(name, config)| {
        let name = name.clone();
        let config = config.clone();
        async move {
            let auth_status = match compute_auth_status(&name, &config, store_mode).await {
                Ok(status) => status,
                Err(error) => {
                    warn!("failed to determine auth status for MCP server `{name}`: {error:?}");
                    McpAuthStatus::Unsupported
                }
            };
            let entry = McpAuthStatusEntry {
                config,
                auth_status,
            };
            (name, entry)
        }
    });

    join_all(futures).await.into_iter().collect()
}

async fn compute_auth_status(
    server_name: &str,
    config: &McpServerConfig,
    store_mode: OAuthCredentialsStoreMode,
) -> Result<McpAuthStatus> {
    match &config.transport {
        McpServerTransportConfig::Stdio { .. } => Ok(McpAuthStatus::Unsupported),
        McpServerTransportConfig::StreamableHttp {
            url,
            bearer_token_env_var,
            http_headers,
            env_http_headers,
        } => {
            determine_streamable_http_auth_status(
                server_name,
                url,
                bearer_token_env_var.as_deref(),
                http_headers.clone(),
                env_http_headers.clone(),
                store_mode,
            )
            .await
        }
    }
}
1
llmx-rs/core/src/mcp/mod.rs
Normal file
1
llmx-rs/core/src/mcp/mod.rs
Normal file
@@ -0,0 +1 @@
pub mod auth;
822
llmx-rs/core/src/mcp_connection_manager.rs
Normal file
822
llmx-rs/core/src/mcp_connection_manager.rs
Normal file
@@ -0,0 +1,822 @@
//! Connection manager for Model Context Protocol (MCP) servers.
//!
//! The [`McpConnectionManager`] owns one [`codex_rmcp_client::RmcpClient`] per
//! configured server (keyed by the *server name*). It offers convenience
//! helpers to query the available tools across *all* servers and returns them
//! in a single aggregated map using the fully-qualified tool name
//! `"<server><MCP_TOOL_NAME_DELIMITER><tool>"` as the key.

use std::collections::HashMap;
use std::collections::HashSet;
use std::env;
use std::ffi::OsString;
use std::sync::Arc;
use std::time::Duration;

use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use codex_rmcp_client::OAuthCredentialsStoreMode;
use codex_rmcp_client::RmcpClient;
use mcp_types::ClientCapabilities;
use mcp_types::Implementation;
use mcp_types::ListResourceTemplatesRequestParams;
use mcp_types::ListResourceTemplatesResult;
use mcp_types::ListResourcesRequestParams;
use mcp_types::ListResourcesResult;
use mcp_types::ReadResourceRequestParams;
use mcp_types::ReadResourceResult;
use mcp_types::Resource;
use mcp_types::ResourceTemplate;
use mcp_types::Tool;

use serde_json::json;
use sha1::Digest;
use sha1::Sha1;
use tokio::task::JoinSet;
use tracing::info;
use tracing::warn;

use crate::config::types::McpServerConfig;
use crate::config::types::McpServerTransportConfig;

/// Delimiter used to separate the server name from the tool name in a fully
/// qualified tool name.
///
/// OpenAI requires tool names to conform to `^[a-zA-Z0-9_-]+$`, so we must
/// choose a delimiter from this character set.
const MCP_TOOL_NAME_DELIMITER: &str = "__";
const MAX_TOOL_NAME_LENGTH: usize = 64;

/// Default timeout for initializing MCP server & initially listing tools.
pub const DEFAULT_STARTUP_TIMEOUT: Duration = Duration::from_secs(10);

/// Default timeout for individual tool calls.
const DEFAULT_TOOL_TIMEOUT: Duration = Duration::from_secs(60);

/// Map that holds a startup error for every MCP server that could **not** be
/// spawned successfully.
pub type ClientStartErrors = HashMap<String, anyhow::Error>;

fn qualify_tools(tools: Vec<ToolInfo>) -> HashMap<String, ToolInfo> {
    let mut used_names = HashSet::new();
    let mut qualified_tools = HashMap::new();
    for tool in tools {
        let mut qualified_name = format!(
            "mcp{}{}{}{}",
            MCP_TOOL_NAME_DELIMITER, tool.server_name, MCP_TOOL_NAME_DELIMITER, tool.tool_name
        );
        if qualified_name.len() > MAX_TOOL_NAME_LENGTH {
            let mut hasher = Sha1::new();
            hasher.update(qualified_name.as_bytes());
            let sha1 = hasher.finalize();
            let sha1_str = format!("{sha1:x}");

            // Truncate to make room for the hash suffix
            let prefix_len = MAX_TOOL_NAME_LENGTH - sha1_str.len();

            qualified_name = format!("{}{}", &qualified_name[..prefix_len], sha1_str);
        }

        if used_names.contains(&qualified_name) {
            warn!("skipping duplicated tool {}", qualified_name);
            continue;
        }

        used_names.insert(qualified_name.clone());
        qualified_tools.insert(qualified_name, tool);
    }

    qualified_tools
}
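
Concretely, the two name shapes this yields (see the unit tests at the bottom of this file):

// Within the 64-char budget: mcp__server1__tool1
// Over budget: the first 24 chars of the qualified name followed by a 40-char
// SHA-1 hex digest, e.g.
// mcp__my_server__extremel119a2b97664e41363932dc84de21e2ff1b93b3e9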

struct ToolInfo {
    server_name: String,
    tool_name: String,
    tool: Tool,
}

struct ManagedClient {
    client: Arc<RmcpClient>,
    startup_timeout: Duration,
    tool_timeout: Option<Duration>,
}

/// A thin wrapper around a set of running [`RmcpClient`] instances.
#[derive(Default)]
pub(crate) struct McpConnectionManager {
    /// Server-name -> client instance.
    ///
    /// The server name originates from the keys of the `mcp_servers` map in
    /// the user configuration.
    clients: HashMap<String, ManagedClient>,

    /// Fully qualified tool name -> tool instance.
    tools: HashMap<String, ToolInfo>,

    /// Server-name -> configured tool filters.
    tool_filters: HashMap<String, ToolFilter>,
}

impl McpConnectionManager {
    /// Spawn a [`RmcpClient`] for each configured server.
    ///
    /// * `mcp_servers` – Map loaded from the user configuration where *keys*
    ///   are human-readable server identifiers and *values* are the spawn
    ///   instructions.
    ///
    /// Servers that fail to start are reported in `ClientStartErrors`: the
    /// user should be informed about these errors.
    pub async fn new(
        mcp_servers: HashMap<String, McpServerConfig>,
        store_mode: OAuthCredentialsStoreMode,
    ) -> Result<(Self, ClientStartErrors)> {
        // Early exit if no servers are configured.
        if mcp_servers.is_empty() {
            return Ok((Self::default(), ClientStartErrors::default()));
        }

        // Launch all configured servers concurrently.
        let mut join_set = JoinSet::new();
        let mut errors = ClientStartErrors::new();
        let mut tool_filters: HashMap<String, ToolFilter> = HashMap::new();

        for (server_name, cfg) in mcp_servers {
            // Validate server name before spawning
            if !is_valid_mcp_server_name(&server_name) {
                let error = anyhow::anyhow!(
                    "invalid server name '{server_name}': must match pattern ^[a-zA-Z0-9_-]+$"
                );
                errors.insert(server_name, error);
                continue;
            }

            if !cfg.enabled {
                tool_filters.insert(server_name, ToolFilter::from_config(&cfg));
                continue;
            }

            let startup_timeout = cfg.startup_timeout_sec.unwrap_or(DEFAULT_STARTUP_TIMEOUT);
            let tool_timeout = cfg.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT);
            tool_filters.insert(server_name.clone(), ToolFilter::from_config(&cfg));

            let resolved_bearer_token = match &cfg.transport {
                McpServerTransportConfig::StreamableHttp {
                    bearer_token_env_var,
                    ..
                } => resolve_bearer_token(&server_name, bearer_token_env_var.as_deref()),
                _ => Ok(None),
            };

            join_set.spawn(async move {
                let McpServerConfig { transport, .. } = cfg;
                let params = mcp_types::InitializeRequestParams {
                    capabilities: ClientCapabilities {
                        experimental: None,
                        roots: None,
                        sampling: None,
                        // https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities
                        // indicates this should be an empty object.
                        elicitation: Some(json!({})),
                    },
                    client_info: Implementation {
                        name: "codex-mcp-client".to_owned(),
                        version: env!("CARGO_PKG_VERSION").to_owned(),
                        title: Some("Codex".into()),
                        // This field is used by Codex when it is an MCP
                        // server: it should not be used when Codex is
                        // an MCP client.
                        user_agent: None,
                    },
                    protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(),
                };

                let resolved_bearer_token = resolved_bearer_token.unwrap_or_default();
                let client_result = match transport {
                    McpServerTransportConfig::Stdio {
                        command,
                        args,
                        env,
                        env_vars,
                        cwd,
                    } => {
                        let command_os: OsString = command.into();
                        let args_os: Vec<OsString> = args.into_iter().map(Into::into).collect();
                        match RmcpClient::new_stdio_client(command_os, args_os, env, &env_vars, cwd)
                            .await
                        {
                            Ok(client) => {
                                let client = Arc::new(client);
                                client
                                    .initialize(params.clone(), Some(startup_timeout))
                                    .await
                                    .map(|_| client)
                            }
                            Err(err) => Err(err.into()),
                        }
                    }
                    McpServerTransportConfig::StreamableHttp {
                        url,
                        http_headers,
                        env_http_headers,
                        ..
                    } => {
                        match RmcpClient::new_streamable_http_client(
                            &server_name,
                            &url,
                            resolved_bearer_token.clone(),
                            http_headers,
                            env_http_headers,
                            store_mode,
                        )
                        .await
                        {
                            Ok(client) => {
                                let client = Arc::new(client);
                                client
                                    .initialize(params.clone(), Some(startup_timeout))
                                    .await
                                    .map(|_| client)
                            }
                            Err(err) => Err(err),
                        }
                    }
                };

                (
                    (server_name, tool_timeout),
                    client_result.map(|client| (client, startup_timeout)),
                )
            });
        }

        let mut clients: HashMap<String, ManagedClient> = HashMap::with_capacity(join_set.len());

        while let Some(res) = join_set.join_next().await {
            let ((server_name, tool_timeout), client_res) = match res {
                Ok(result) => result,
                Err(e) => {
                    warn!("Task panic when starting MCP server: {e:#}");
                    continue;
                }
            };

            match client_res {
                Ok((client, startup_timeout)) => {
                    clients.insert(
                        server_name,
                        ManagedClient {
                            client,
                            startup_timeout,
                            tool_timeout: Some(tool_timeout),
                        },
                    );
                }
                Err(e) => {
                    errors.insert(server_name, e);
                }
            }
        }

        let all_tools = match list_all_tools(&clients).await {
            Ok(tools) => tools,
            Err(e) => {
                warn!("Failed to list tools from some MCP servers: {e:#}");
                Vec::new()
            }
        };

        let filtered_tools = filter_tools(all_tools, &tool_filters);
        let tools = qualify_tools(filtered_tools);

        Ok((
            Self {
                clients,
                tools,
                tool_filters,
            },
            errors,
        ))
    }

    /// Returns a single map that contains all tools. Each key is the
    /// fully-qualified name for the tool.
    pub fn list_all_tools(&self) -> HashMap<String, Tool> {
        self.tools
            .iter()
            .map(|(name, tool)| (name.clone(), tool.tool.clone()))
            .collect()
    }

    /// Returns a single map that contains all resources. Each key is the
    /// server name and the value is a vector of resources.
    pub async fn list_all_resources(&self) -> HashMap<String, Vec<Resource>> {
        let mut join_set = JoinSet::new();

        for (server_name, managed_client) in &self.clients {
            let server_name_cloned = server_name.clone();
            let client_clone = managed_client.client.clone();
            let timeout = managed_client.tool_timeout;

            join_set.spawn(async move {
                let mut collected: Vec<Resource> = Vec::new();
                let mut cursor: Option<String> = None;

                loop {
                    let params = cursor.as_ref().map(|next| ListResourcesRequestParams {
                        cursor: Some(next.clone()),
                    });
                    let response = match client_clone.list_resources(params, timeout).await {
                        Ok(result) => result,
                        Err(err) => return (server_name_cloned, Err(err)),
                    };

                    collected.extend(response.resources);

                    match response.next_cursor {
                        Some(next) => {
                            if cursor.as_ref() == Some(&next) {
                                return (
                                    server_name_cloned,
                                    Err(anyhow!("resources/list returned duplicate cursor")),
                                );
                            }
                            cursor = Some(next);
                        }
                        None => return (server_name_cloned, Ok(collected)),
                    }
                }
            });
        }

        let mut aggregated: HashMap<String, Vec<Resource>> = HashMap::new();

        while let Some(join_res) = join_set.join_next().await {
            match join_res {
                Ok((server_name, Ok(resources))) => {
                    aggregated.insert(server_name, resources);
                }
                Ok((server_name, Err(err))) => {
                    warn!("Failed to list resources for MCP server '{server_name}': {err:#}");
                }
                Err(err) => {
                    warn!("Task panic when listing resources for MCP server: {err:#}");
                }
            }
        }

        aggregated
    }

    /// Returns a single map that contains all resource templates. Each key is the
    /// server name and the value is a vector of resource templates.
    pub async fn list_all_resource_templates(&self) -> HashMap<String, Vec<ResourceTemplate>> {
        let mut join_set = JoinSet::new();

        for (server_name, managed_client) in &self.clients {
            let server_name_cloned = server_name.clone();
            let client_clone = managed_client.client.clone();
            let timeout = managed_client.tool_timeout;

            join_set.spawn(async move {
                let mut collected: Vec<ResourceTemplate> = Vec::new();
                let mut cursor: Option<String> = None;

                loop {
                    let params = cursor
                        .as_ref()
                        .map(|next| ListResourceTemplatesRequestParams {
                            cursor: Some(next.clone()),
                        });
                    let response = match client_clone.list_resource_templates(params, timeout).await
                    {
                        Ok(result) => result,
                        Err(err) => return (server_name_cloned, Err(err)),
                    };

                    collected.extend(response.resource_templates);

                    match response.next_cursor {
                        Some(next) => {
                            if cursor.as_ref() == Some(&next) {
                                return (
                                    server_name_cloned,
                                    Err(anyhow!(
                                        "resources/templates/list returned duplicate cursor"
                                    )),
                                );
                            }
                            cursor = Some(next);
                        }
                        None => return (server_name_cloned, Ok(collected)),
                    }
                }
            });
        }

        let mut aggregated: HashMap<String, Vec<ResourceTemplate>> = HashMap::new();

        while let Some(join_res) = join_set.join_next().await {
            match join_res {
                Ok((server_name, Ok(templates))) => {
                    aggregated.insert(server_name, templates);
                }
                Ok((server_name, Err(err))) => {
                    warn!(
                        "Failed to list resource templates for MCP server '{server_name}': {err:#}"
                    );
                }
                Err(err) => {
                    warn!("Task panic when listing resource templates for MCP server: {err:#}");
                }
            }
        }

        aggregated
    }

    /// Invoke the tool indicated by the (server, tool) pair.
    pub async fn call_tool(
        &self,
        server: &str,
        tool: &str,
        arguments: Option<serde_json::Value>,
    ) -> Result<mcp_types::CallToolResult> {
        if let Some(filter) = self.tool_filters.get(server)
            && !filter.allows(tool)
        {
            return Err(anyhow!(
                "tool '{tool}' is disabled for MCP server '{server}'"
            ));
        }
        let managed = self
            .clients
            .get(server)
            .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?;
        let client = &managed.client;
        let timeout = managed.tool_timeout;

        client
            .call_tool(tool.to_string(), arguments, timeout)
            .await
            .with_context(|| format!("tool call failed for `{server}/{tool}`"))
    }

    /// List resources from the specified server.
    pub async fn list_resources(
        &self,
        server: &str,
        params: Option<ListResourcesRequestParams>,
    ) -> Result<ListResourcesResult> {
        let managed = self
            .clients
            .get(server)
            .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?;
        let client = managed.client.clone();
        let timeout = managed.tool_timeout;

        client
            .list_resources(params, timeout)
            .await
            .with_context(|| format!("resources/list failed for `{server}`"))
    }

    /// List resource templates from the specified server.
    pub async fn list_resource_templates(
        &self,
        server: &str,
        params: Option<ListResourceTemplatesRequestParams>,
    ) -> Result<ListResourceTemplatesResult> {
        let managed = self
            .clients
            .get(server)
            .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?;
        let client = managed.client.clone();
        let timeout = managed.tool_timeout;

        client
            .list_resource_templates(params, timeout)
            .await
            .with_context(|| format!("resources/templates/list failed for `{server}`"))
    }

    /// Read a resource from the specified server.
    pub async fn read_resource(
        &self,
        server: &str,
        params: ReadResourceRequestParams,
    ) -> Result<ReadResourceResult> {
        let managed = self
            .clients
            .get(server)
            .ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?;
        let client = managed.client.clone();
        let timeout = managed.tool_timeout;
        let uri = params.uri.clone();

        client
            .read_resource(params, timeout)
            .await
            .with_context(|| format!("resources/read failed for `{server}` ({uri})"))
    }

    pub fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> {
        self.tools
            .get(tool_name)
            .map(|tool| (tool.server_name.clone(), tool.tool_name.clone()))
    }
}
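
A usage sketch for the happy path, assuming an `mcp_servers` map and a `store_mode` obtained from config elsewhere (both names are illustrative):

let (manager, start_errors) = McpConnectionManager::new(mcp_servers, store_mode).await?;
for (server, err) in &start_errors {
    tracing::warn!("MCP server `{server}` failed to start: {err:#}");
}
let tools = manager.list_all_tools(); // fully-qualified name -> Tool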

/// A tool is allowed to be used if both are true:
/// 1. enabled is None (no allowlist is set) or the tool is explicitly enabled.
/// 2. The tool is not explicitly disabled.
#[derive(Default, Clone)]
struct ToolFilter {
    enabled: Option<HashSet<String>>,
    disabled: HashSet<String>,
}

impl ToolFilter {
    fn from_config(cfg: &McpServerConfig) -> Self {
        let enabled = cfg
            .enabled_tools
            .as_ref()
            .map(|tools| tools.iter().cloned().collect::<HashSet<_>>());
        let disabled = cfg
            .disabled_tools
            .as_ref()
            .map(|tools| tools.iter().cloned().collect::<HashSet<_>>())
            .unwrap_or_default();

        Self { enabled, disabled }
    }

    fn allows(&self, tool_name: &str) -> bool {
        if let Some(enabled) = &self.enabled
            && !enabled.contains(tool_name)
        {
            return false;
        }

        !self.disabled.contains(tool_name)
    }
}

fn filter_tools(tools: Vec<ToolInfo>, filters: &HashMap<String, ToolFilter>) -> Vec<ToolInfo> {
    tools
        .into_iter()
        .filter(|tool| {
            filters
                .get(&tool.server_name)
                .is_none_or(|filter| filter.allows(&tool.tool_name))
        })
        .collect()
}

fn resolve_bearer_token(
    server_name: &str,
    bearer_token_env_var: Option<&str>,
) -> Result<Option<String>> {
    let Some(env_var) = bearer_token_env_var else {
        return Ok(None);
    };

    match env::var(env_var) {
        Ok(value) => {
            if value.is_empty() {
                Err(anyhow!(
                    "Environment variable {env_var} for MCP server '{server_name}' is empty"
                ))
            } else {
                Ok(Some(value))
            }
        }
        Err(env::VarError::NotPresent) => Err(anyhow!(
            "Environment variable {env_var} for MCP server '{server_name}' is not set"
        )),
        Err(env::VarError::NotUnicode(_)) => Err(anyhow!(
            "Environment variable {env_var} for MCP server '{server_name}' contains invalid Unicode"
        )),
    }
}

/// Query every server for its available tools and return a single map that
/// contains all tools. Each key is the fully-qualified name for the tool.
async fn list_all_tools(clients: &HashMap<String, ManagedClient>) -> Result<Vec<ToolInfo>> {
    let mut join_set = JoinSet::new();

    // Spawn one task per server so we can query them concurrently. This
    // keeps the overall latency roughly at the slowest server instead of
    // the cumulative latency.
    for (server_name, managed_client) in clients {
        let server_name_cloned = server_name.clone();
        let client_clone = managed_client.client.clone();
        let startup_timeout = managed_client.startup_timeout;
        join_set.spawn(async move {
            let res = client_clone.list_tools(None, Some(startup_timeout)).await;
            (server_name_cloned, res)
        });
    }

    let mut aggregated: Vec<ToolInfo> = Vec::with_capacity(join_set.len());

    while let Some(join_res) = join_set.join_next().await {
        let (server_name, list_result) = if let Ok(result) = join_res {
            result
        } else {
            warn!("Task panic when listing tools for MCP server: {join_res:#?}");
            continue;
        };

        let list_result = if let Ok(result) = list_result {
            result
        } else {
            warn!("Failed to list tools for MCP server '{server_name}': {list_result:#?}");
            continue;
        };

        for tool in list_result.tools {
            let tool_info = ToolInfo {
                server_name: server_name.clone(),
                tool_name: tool.name.clone(),
                tool,
            };
            aggregated.push(tool_info);
        }
    }

    info!(
        "aggregated {} tools from {} servers",
        aggregated.len(),
        clients.len()
    );

    Ok(aggregated)
}

fn is_valid_mcp_server_name(server_name: &str) -> bool {
    !server_name.is_empty()
        && server_name
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
}

#[cfg(test)]
mod tests {
    use super::*;
    use mcp_types::ToolInputSchema;
    use std::collections::HashSet;

    fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo {
        ToolInfo {
            server_name: server_name.to_string(),
            tool_name: tool_name.to_string(),
            tool: Tool {
                annotations: None,
                description: Some(format!("Test tool: {tool_name}")),
                input_schema: ToolInputSchema {
                    properties: None,
                    required: None,
                    r#type: "object".to_string(),
                },
                name: tool_name.to_string(),
                output_schema: None,
                title: None,
            },
        }
    }

    #[test]
    fn test_qualify_tools_short_non_duplicated_names() {
        let tools = vec![
            create_test_tool("server1", "tool1"),
            create_test_tool("server1", "tool2"),
        ];

        let qualified_tools = qualify_tools(tools);

        assert_eq!(qualified_tools.len(), 2);
        assert!(qualified_tools.contains_key("mcp__server1__tool1"));
        assert!(qualified_tools.contains_key("mcp__server1__tool2"));
    }

    #[test]
    fn test_qualify_tools_duplicated_names_skipped() {
        let tools = vec![
            create_test_tool("server1", "duplicate_tool"),
            create_test_tool("server1", "duplicate_tool"),
        ];

        let qualified_tools = qualify_tools(tools);

        // Only the first tool should remain, the second is skipped
        assert_eq!(qualified_tools.len(), 1);
        assert!(qualified_tools.contains_key("mcp__server1__duplicate_tool"));
    }

    #[test]
    fn test_qualify_tools_long_names_same_server() {
        let server_name = "my_server";

        let tools = vec![
            create_test_tool(
                server_name,
                "extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits",
            ),
            create_test_tool(
                server_name,
                "yet_another_extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits",
            ),
        ];

        let qualified_tools = qualify_tools(tools);

        assert_eq!(qualified_tools.len(), 2);

        let mut keys: Vec<_> = qualified_tools.keys().cloned().collect();
        keys.sort();

        assert_eq!(keys[0].len(), 64);
        assert_eq!(
            keys[0],
            "mcp__my_server__extremel119a2b97664e41363932dc84de21e2ff1b93b3e9"
        );

        assert_eq!(keys[1].len(), 64);
        assert_eq!(
            keys[1],
            "mcp__my_server__yet_anot419a82a89325c1b477274a41f8c65ea5f3a7f341"
        );
    }

    #[test]
    fn tool_filter_allows_by_default() {
        let filter = ToolFilter::default();

        assert!(filter.allows("any"));
    }

    #[test]
    fn tool_filter_applies_enabled_list() {
        let filter = ToolFilter {
            enabled: Some(HashSet::from(["allowed".to_string()])),
            disabled: HashSet::new(),
        };

        assert!(filter.allows("allowed"));
        assert!(!filter.allows("denied"));
    }

    #[test]
    fn tool_filter_applies_disabled_list() {
        let filter = ToolFilter {
            enabled: None,
            disabled: HashSet::from(["blocked".to_string()]),
        };

        assert!(!filter.allows("blocked"));
        assert!(filter.allows("open"));
    }

    #[test]
    fn tool_filter_applies_enabled_then_disabled() {
        let filter = ToolFilter {
            enabled: Some(HashSet::from(["keep".to_string(), "remove".to_string()])),
            disabled: HashSet::from(["remove".to_string()]),
        };

        assert!(filter.allows("keep"));
        assert!(!filter.allows("remove"));
        assert!(!filter.allows("unknown"));
    }

    #[test]
    fn filter_tools_applies_per_server_filters() {
        let tools = vec![
            create_test_tool("server1", "tool_a"),
            create_test_tool("server1", "tool_b"),
            create_test_tool("server2", "tool_a"),
        ];
        let mut filters = HashMap::new();
        filters.insert(
            "server1".to_string(),
            ToolFilter {
                enabled: Some(HashSet::from(["tool_a".to_string(), "tool_b".to_string()])),
                disabled: HashSet::from(["tool_b".to_string()]),
            },
        );
        filters.insert(
            "server2".to_string(),
            ToolFilter {
                enabled: None,
                disabled: HashSet::from(["tool_a".to_string()]),
            },
        );

        let filtered = filter_tools(tools, &filters);

        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].server_name, "server1");
        assert_eq!(filtered[0].tool_name, "tool_a");
    }
}
80
llmx-rs/core/src/mcp_tool_call.rs
Normal file
80
llmx-rs/core/src/mcp_tool_call.rs
Normal file
@@ -0,0 +1,80 @@
use std::time::Instant;

use tracing::error;

use crate::codex::Session;
use crate::codex::TurnContext;
use crate::protocol::EventMsg;
use crate::protocol::McpInvocation;
use crate::protocol::McpToolCallBeginEvent;
use crate::protocol::McpToolCallEndEvent;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseInputItem;

/// Handles the specified tool call and dispatches the appropriate
/// `McpToolCallBegin` and `McpToolCallEnd` events to the `Session`.
pub(crate) async fn handle_mcp_tool_call(
    sess: &Session,
    turn_context: &TurnContext,
    call_id: String,
    server: String,
    tool_name: String,
    arguments: String,
) -> ResponseInputItem {
    // Parse the `arguments` as JSON. An empty string is OK, but invalid JSON
    // is not.
    let arguments_value = if arguments.trim().is_empty() {
        None
    } else {
        match serde_json::from_str::<serde_json::Value>(&arguments) {
            Ok(value) => Some(value),
            Err(e) => {
                error!("failed to parse tool call arguments: {e}");
                return ResponseInputItem::FunctionCallOutput {
                    call_id: call_id.clone(),
                    output: FunctionCallOutputPayload {
                        content: format!("err: {e}"),
                        success: Some(false),
                        ..Default::default()
                    },
                };
            }
        }
    };

    let invocation = McpInvocation {
        server: server.clone(),
        tool: tool_name.clone(),
        arguments: arguments_value.clone(),
    };

    let tool_call_begin_event = EventMsg::McpToolCallBegin(McpToolCallBeginEvent {
        call_id: call_id.clone(),
        invocation: invocation.clone(),
    });
    notify_mcp_tool_call_event(sess, turn_context, tool_call_begin_event).await;

    let start = Instant::now();
    // Perform the tool call.
    let result = sess
        .call_tool(&server, &tool_name, arguments_value.clone())
        .await
        .map_err(|e| format!("tool call error: {e:?}"));
    if let Err(e) = &result {
        tracing::warn!("MCP tool call error: {e:?}");
    }
    let tool_call_end_event = EventMsg::McpToolCallEnd(McpToolCallEndEvent {
        call_id: call_id.clone(),
        invocation,
        duration: start.elapsed(),
        result: result.clone(),
    });

    notify_mcp_tool_call_event(sess, turn_context, tool_call_end_event.clone()).await;

    ResponseInputItem::McpToolCallOutput { call_id, result }
}

async fn notify_mcp_tool_call_event(sess: &Session, turn_context: &TurnContext, event: EventMsg) {
    sess.send_event(turn_context, event).await;
}
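The empty-vs-invalid distinction in the argument parsing above is easy to get wrong. A minimal standalone sketch of the same rule (the function name is illustrative, not part of this diff):

    use serde_json::Value;

    // Mirrors handle_mcp_tool_call: empty arguments mean "no arguments",
    // while malformed JSON is surfaced as an error string.
    fn parse_tool_arguments(arguments: &str) -> Result<Option<Value>, String> {
        if arguments.trim().is_empty() {
            return Ok(None);
        }
        serde_json::from_str::<Value>(arguments)
            .map(Some)
            .map_err(|e| format!("err: {e}"))
    }

    fn main() {
        assert_eq!(parse_tool_arguments("").unwrap(), None);
        assert!(parse_tool_arguments(r#"{"path":"README.md"}"#).unwrap().is_some());
        assert!(parse_tool_arguments("{not json").is_err());
    }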
286
llmx-rs/core/src/message_history.rs
Normal file
@@ -0,0 +1,286 @@
//! Persistence layer for the global, append-only *message history* file.
//!
//! The history is stored at `~/.codex/history.jsonl` with **one JSON object per
//! line** so that it can be efficiently appended to and parsed with standard
//! JSON-Lines tooling. Each record has the following schema:
//!
//! ````text
//! {"conversation_id":"<uuid>","ts":<unix_seconds>,"text":"<message>"}
//! ````
//!
//! To minimise the chance of interleaved writes when multiple processes are
//! appending concurrently, callers should *prepare the full line* (record +
//! trailing `\n`) and write it with a **single `write(2)` system call** while
//! the file descriptor is opened with the `O_APPEND` flag. POSIX guarantees
//! that writes up to `PIPE_BUF` bytes are atomic in that case.

use std::fs::File;
use std::fs::OpenOptions;
use std::io::Result;
use std::io::Write;
use std::path::PathBuf;

use serde::Deserialize;
use serde::Serialize;

use std::time::Duration;
use tokio::fs;
use tokio::io::AsyncReadExt;

use crate::config::Config;
use crate::config::types::HistoryPersistence;

use codex_protocol::ConversationId;
#[cfg(unix)]
use std::os::unix::fs::OpenOptionsExt;
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;

/// Filename that stores the message history inside `~/.codex`.
const HISTORY_FILENAME: &str = "history.jsonl";

const MAX_RETRIES: usize = 10;
const RETRY_SLEEP: Duration = Duration::from_millis(100);

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HistoryEntry {
    pub session_id: String,
    pub ts: u64,
    pub text: String,
}

fn history_filepath(config: &Config) -> PathBuf {
    let mut path = config.codex_home.clone();
    path.push(HISTORY_FILENAME);
    path
}

/// Append a `text` entry associated with `conversation_id` to the history file. Uses
/// advisory file locking to ensure that concurrent writes do not interleave,
/// which entails a small amount of blocking I/O internally.
pub(crate) async fn append_entry(
    text: &str,
    conversation_id: &ConversationId,
    config: &Config,
) -> Result<()> {
    match config.history.persistence {
        HistoryPersistence::SaveAll => {
            // Save everything: proceed.
        }
        HistoryPersistence::None => {
            // No history persistence requested.
            return Ok(());
        }
    }

    // TODO: check `text` for sensitive patterns

    // Resolve `~/.codex/history.jsonl` and ensure the parent directory exists.
    let path = history_filepath(config);
    if let Some(parent) = path.parent() {
        tokio::fs::create_dir_all(parent).await?;
    }

    // Compute timestamp (seconds since the Unix epoch).
    let ts = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_err(|e| std::io::Error::other(format!("system clock before Unix epoch: {e}")))?
        .as_secs();

    // Construct the JSON line first so we can write it in a single syscall.
    let entry = HistoryEntry {
        session_id: conversation_id.to_string(),
        ts,
        text: text.to_string(),
    };
    let mut line = serde_json::to_string(&entry)
        .map_err(|e| std::io::Error::other(format!("failed to serialise history entry: {e}")))?;
    line.push('\n');

    // Open in append-only mode.
    let mut options = OpenOptions::new();
    options.append(true).read(true).create(true);
    #[cfg(unix)]
    {
        options.mode(0o600);
    }

    let mut history_file = options.open(&path)?;

    // Ensure permissions.
    ensure_owner_only_permissions(&history_file).await?;

    // Perform a blocking write under an advisory write lock using std::fs.
    tokio::task::spawn_blocking(move || -> Result<()> {
        // Retry a few times to avoid indefinite blocking when contended.
        for _ in 0..MAX_RETRIES {
            match history_file.try_lock() {
                Ok(()) => {
                    // While holding the exclusive lock, write the full line.
                    history_file.write_all(line.as_bytes())?;
                    history_file.flush()?;
                    return Ok(());
                }
                Err(std::fs::TryLockError::WouldBlock) => {
                    std::thread::sleep(RETRY_SLEEP);
                }
                Err(e) => return Err(e.into()),
            }
        }

        Err(std::io::Error::new(
            std::io::ErrorKind::WouldBlock,
            "could not acquire exclusive lock on history file after multiple attempts",
        ))
    })
    .await??;

    Ok(())
}

/// Asynchronously fetch the history file's *identifier* (inode on Unix) and
/// the current number of entries by counting newline characters.
pub(crate) async fn history_metadata(config: &Config) -> (u64, usize) {
    let path = history_filepath(config);

    #[cfg(unix)]
    let log_id = {
        use std::os::unix::fs::MetadataExt;
        // Obtain metadata (async) to get the identifier.
        let meta = match fs::metadata(&path).await {
            Ok(m) => m,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return (0, 0),
            Err(_) => return (0, 0),
        };
        meta.ino()
    };
    #[cfg(not(unix))]
    let log_id = 0u64;

    // Open the file.
    let mut file = match fs::File::open(&path).await {
        Ok(f) => f,
        Err(_) => return (log_id, 0),
    };

    // Count newline bytes.
    let mut buf = [0u8; 8192];
    let mut count = 0usize;
    loop {
        match file.read(&mut buf).await {
            Ok(0) => break,
            Ok(n) => {
                count += buf[..n].iter().filter(|&&b| b == b'\n').count();
            }
            Err(_) => return (log_id, 0),
        }
    }

    (log_id, count)
}

/// Given a `log_id` (on Unix this is the file's inode number) and a zero-based
/// `offset`, return the corresponding `HistoryEntry` if the identifier matches
/// the current history file **and** the requested offset exists. Any I/O or
/// parsing errors are logged and result in `None`.
///
/// Note this function is not async because it uses a sync advisory file
/// locking API.
#[cfg(unix)]
pub(crate) fn lookup(log_id: u64, offset: usize, config: &Config) -> Option<HistoryEntry> {
    use std::io::BufRead;
    use std::io::BufReader;
    use std::os::unix::fs::MetadataExt;

    let path = history_filepath(config);
    let file: File = match OpenOptions::new().read(true).open(&path) {
        Ok(f) => f,
        Err(e) => {
            tracing::warn!(error = %e, "failed to open history file");
            return None;
        }
    };

    let metadata = match file.metadata() {
        Ok(m) => m,
        Err(e) => {
            tracing::warn!(error = %e, "failed to stat history file");
            return None;
        }
    };

    if metadata.ino() != log_id {
        return None;
    }

    // Open & lock file for reading using a shared lock.
    // Retry a few times to avoid indefinite blocking.
    for _ in 0..MAX_RETRIES {
        let lock_result = file.try_lock_shared();

        match lock_result {
            Ok(()) => {
                let reader = BufReader::new(&file);
                for (idx, line_res) in reader.lines().enumerate() {
                    let line = match line_res {
                        Ok(l) => l,
                        Err(e) => {
                            tracing::warn!(error = %e, "failed to read line from history file");
                            return None;
                        }
                    };

                    if idx == offset {
                        match serde_json::from_str::<HistoryEntry>(&line) {
                            Ok(entry) => return Some(entry),
                            Err(e) => {
                                tracing::warn!(error = %e, "failed to parse history entry");
                                return None;
                            }
                        }
                    }
                }
                // Not found at requested offset.
                return None;
            }
            Err(std::fs::TryLockError::WouldBlock) => {
                std::thread::sleep(RETRY_SLEEP);
            }
            Err(e) => {
                tracing::warn!(error = %e, "failed to acquire shared lock on history file");
                return None;
            }
        }
    }

    None
}

/// Fallback stub for non-Unix systems: currently always returns `None`.
#[cfg(not(unix))]
pub(crate) fn lookup(log_id: u64, offset: usize, config: &Config) -> Option<HistoryEntry> {
    let _ = (log_id, offset, config);
    None
}

/// On Unix systems ensure the file permissions are `0o600` (rw-------). If the
/// permissions cannot be changed the error is propagated to the caller.
#[cfg(unix)]
async fn ensure_owner_only_permissions(file: &File) -> Result<()> {
    let metadata = file.metadata()?;
    let current_mode = metadata.permissions().mode() & 0o777;
    if current_mode != 0o600 {
        let mut perms = metadata.permissions();
        perms.set_mode(0o600);
        let perms_clone = perms.clone();
        let file_clone = file.try_clone()?;
        tokio::task::spawn_blocking(move || file_clone.set_permissions(perms_clone)).await??;
    }
    Ok(())
}

#[cfg(not(unix))]
async fn ensure_owner_only_permissions(_file: &File) -> Result<()> {
    // For now, on non-Unix, simply succeed.
    Ok(())
}
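A minimal sketch of the single-write append technique the module doc above describes (std-only; the path and record are illustrative):

    use std::fs::OpenOptions;
    use std::io::Write;

    fn main() -> std::io::Result<()> {
        // Build the complete record, trailing newline included, before touching
        // the file so the whole line can go out in one write on an O_APPEND
        // descriptor (atomic up to PIPE_BUF bytes on POSIX).
        let line = format!("{}\n", r#"{"session_id":"demo","ts":0,"text":"hi"}"#);
        let mut file = OpenOptions::new()
            .append(true) // sets O_APPEND on Unix
            .create(true)
            .open("/tmp/history-demo.jsonl")?;
        file.write_all(line.as_bytes())?;
        Ok(())
    }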
193
llmx-rs/core/src/model_family.rs
Normal file
@@ -0,0 +1,193 @@
use crate::config::types::ReasoningSummaryFormat;
use crate::tools::handlers::apply_patch::ApplyPatchToolType;
use crate::tools::spec::ConfigShellToolType;

/// The `instructions` field in the payload sent to a model should always start
/// with this content.
const BASE_INSTRUCTIONS: &str = include_str!("../prompt.md");
const GPT_5_CODEX_INSTRUCTIONS: &str = include_str!("../gpt_5_codex_prompt.md");

/// A model family is a group of models that share certain characteristics.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelFamily {
    /// The full model slug used to derive this model family, e.g.
    /// "gpt-4.1-2025-04-14".
    pub slug: String,

    /// The model family name, e.g. "gpt-4.1". Note this should be able to be
    /// used with [`crate::openai_model_info::get_model_info`].
    pub family: String,

    /// True if the model needs additional instructions on how to use the
    /// "virtual" `apply_patch` CLI.
    pub needs_special_apply_patch_instructions: bool,

    // Whether the `reasoning` field can be set when making a request to this
    // model family. Note it has `effort` and `summary` subfields (though
    // `summary` is optional).
    pub supports_reasoning_summaries: bool,

    // Defines whether we need special handling of the reasoning summary.
    pub reasoning_summary_format: ReasoningSummaryFormat,

    /// Whether this model supports parallel tool calls when using the
    /// Responses API.
    pub supports_parallel_tool_calls: bool,

    /// Present if the model performs better when `apply_patch` is provided as
    /// a tool call instead of just a bash command.
    pub apply_patch_tool_type: Option<ApplyPatchToolType>,

    // Instructions to use for querying the model.
    pub base_instructions: String,

    /// Names of beta tools that should be exposed to this model family.
    pub experimental_supported_tools: Vec<String>,

    /// Percentage of the context window considered usable for inputs, after
    /// reserving headroom for system prompts, tool overhead, and model output.
    /// This is applied when computing the effective context window seen by
    /// consumers.
    pub effective_context_window_percent: i64,

    /// If the model family supports setting the verbosity level when using the Responses API.
    pub support_verbosity: bool,

    /// Preferred shell tool type for this model family when features do not override it.
    pub shell_type: ConfigShellToolType,
}

macro_rules! model_family {
    (
        $slug:expr, $family:expr $(, $key:ident : $value:expr )* $(,)?
    ) => {{
        // defaults
        #[allow(unused_mut)]
        let mut mf = ModelFamily {
            slug: $slug.to_string(),
            family: $family.to_string(),
            needs_special_apply_patch_instructions: false,
            supports_reasoning_summaries: false,
            reasoning_summary_format: ReasoningSummaryFormat::None,
            supports_parallel_tool_calls: false,
            apply_patch_tool_type: None,
            base_instructions: BASE_INSTRUCTIONS.to_string(),
            experimental_supported_tools: Vec::new(),
            effective_context_window_percent: 95,
            support_verbosity: false,
            shell_type: ConfigShellToolType::Default,
        };
        // apply overrides
        $(
            mf.$key = $value;
        )*
        Some(mf)
    }};
}

/// Returns a `ModelFamily` for the given model slug, or `None` if the slug
/// does not match any known model family.
pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
    if slug.starts_with("o3") {
        model_family!(
            slug, "o3",
            supports_reasoning_summaries: true,
            needs_special_apply_patch_instructions: true,
        )
    } else if slug.starts_with("o4-mini") {
        model_family!(
            slug, "o4-mini",
            supports_reasoning_summaries: true,
            needs_special_apply_patch_instructions: true,
        )
    } else if slug.starts_with("codex-mini-latest") {
        model_family!(
            slug, "codex-mini-latest",
            supports_reasoning_summaries: true,
            needs_special_apply_patch_instructions: true,
            shell_type: ConfigShellToolType::Local,
        )
    } else if slug.starts_with("gpt-4.1") {
        model_family!(
            slug, "gpt-4.1",
            needs_special_apply_patch_instructions: true,
        )
    } else if slug.starts_with("gpt-oss") || slug.starts_with("openai/gpt-oss") {
        model_family!(slug, "gpt-oss", apply_patch_tool_type: Some(ApplyPatchToolType::Function))
    } else if slug.starts_with("gpt-4o") {
        model_family!(slug, "gpt-4o", needs_special_apply_patch_instructions: true)
    } else if slug.starts_with("gpt-3.5") {
        model_family!(slug, "gpt-3.5", needs_special_apply_patch_instructions: true)
    } else if slug.starts_with("porcupine") {
        model_family!(slug, "porcupine", shell_type: ConfigShellToolType::UnifiedExec)
    } else if slug.starts_with("test-gpt-5-codex") {
        model_family!(
            slug, slug,
            supports_reasoning_summaries: true,
            reasoning_summary_format: ReasoningSummaryFormat::Experimental,
            base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
            experimental_supported_tools: vec![
                "grep_files".to_string(),
                "list_dir".to_string(),
                "read_file".to_string(),
                "test_sync_tool".to_string(),
            ],
            supports_parallel_tool_calls: true,
            support_verbosity: true,
        )

    // Internal models.
    } else if slug.starts_with("codex-exp-") {
        model_family!(
            slug, slug,
            supports_reasoning_summaries: true,
            reasoning_summary_format: ReasoningSummaryFormat::Experimental,
            base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
            apply_patch_tool_type: Some(ApplyPatchToolType::Freeform),
            experimental_supported_tools: vec![
                "grep_files".to_string(),
                "list_dir".to_string(),
                "read_file".to_string(),
            ],
            supports_parallel_tool_calls: true,
            support_verbosity: true,
        )

    // Production models.
    } else if slug.starts_with("gpt-5-codex") || slug.starts_with("codex-") {
        model_family!(
            slug, slug,
            supports_reasoning_summaries: true,
            reasoning_summary_format: ReasoningSummaryFormat::Experimental,
            base_instructions: GPT_5_CODEX_INSTRUCTIONS.to_string(),
            apply_patch_tool_type: Some(ApplyPatchToolType::Freeform),
            support_verbosity: false,
        )
    } else if slug.starts_with("gpt-5") {
        model_family!(
            slug, "gpt-5",
            supports_reasoning_summaries: true,
            needs_special_apply_patch_instructions: true,
            support_verbosity: true,
        )
    } else {
        None
    }
}

pub fn derive_default_model_family(model: &str) -> ModelFamily {
    ModelFamily {
        slug: model.to_string(),
        family: model.to_string(),
        needs_special_apply_patch_instructions: false,
        supports_reasoning_summaries: false,
        reasoning_summary_format: ReasoningSummaryFormat::None,
        supports_parallel_tool_calls: false,
        apply_patch_tool_type: None,
        base_instructions: BASE_INSTRUCTIONS.to_string(),
        experimental_supported_tools: Vec::new(),
        effective_context_window_percent: 95,
        support_verbosity: false,
        shell_type: ConfigShellToolType::Default,
    }
}
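Because `find_family_for_model` dispatches on `starts_with`, branch order matters: the `gpt-5-codex` arm must run before the `gpt-5` arm, or every Codex slug would fall into the generic family. A standalone sketch of that ordering concern (names illustrative):

    fn family_for(slug: &str) -> &'static str {
        // First matching prefix wins, so more specific prefixes come first.
        if slug.starts_with("gpt-5-codex") {
            "gpt-5-codex"
        } else if slug.starts_with("gpt-5") {
            "gpt-5"
        } else {
            "unknown"
        }
    }

    fn main() {
        assert_eq!(family_for("gpt-5-codex-2025"), "gpt-5-codex");
        assert_eq!(family_for("gpt-5-mini"), "gpt-5");
        assert_eq!(family_for("o3-pro"), "unknown");
    }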
532
llmx-rs/core/src/model_provider_info.rs
Normal file
@@ -0,0 +1,532 @@
//! Registry of model providers supported by Codex.
//!
//! Providers can be defined in two places:
//! 1. Built-in defaults compiled into the binary so Codex works out-of-the-box.
//! 2. User-defined entries inside `~/.codex/config.toml` under the `model_providers`
//!    key. These override or extend the defaults at runtime.

use crate::CodexAuth;
use crate::default_client::CodexHttpClient;
use crate::default_client::CodexRequestBuilder;
use codex_app_server_protocol::AuthMode;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::env::VarError;
use std::time::Duration;

use crate::error::EnvVarError;
const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
const DEFAULT_STREAM_MAX_RETRIES: u64 = 5;
const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
/// Hard cap for user-configured `stream_max_retries`.
const MAX_STREAM_MAX_RETRIES: u64 = 100;
/// Hard cap for user-configured `request_max_retries`.
const MAX_REQUEST_MAX_RETRIES: u64 = 100;

/// Wire protocol that the provider speaks. Most third-party services only
/// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI
/// itself (and a handful of others) additionally expose the more modern
/// *Responses* API. The two protocols use different request/response shapes
/// and *cannot* be auto-detected at runtime, therefore each provider entry
/// must declare which one it expects.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum WireApi {
    /// The Responses API exposed by OpenAI at `/v1/responses`.
    Responses,

    /// Regular Chat Completions compatible with `/v1/chat/completions`.
    #[default]
    Chat,
}

/// Serializable representation of a provider definition.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ModelProviderInfo {
    /// Friendly display name.
    pub name: String,
    /// Base URL for the provider's OpenAI-compatible API.
    pub base_url: Option<String>,
    /// Environment variable that stores the user's API key for this provider.
    pub env_key: Option<String>,

    /// Optional instructions to help the user get a valid value for the
    /// variable and set it.
    pub env_key_instructions: Option<String>,

    /// Value to use with `Authorization: Bearer <token>` header. Use of this
    /// config is discouraged in favor of `env_key` for security reasons, but
    /// this may be necessary when using this programmatically.
    pub experimental_bearer_token: Option<String>,

    /// Which wire protocol this provider expects.
    #[serde(default)]
    pub wire_api: WireApi,

    /// Optional query parameters to append to the base URL.
    pub query_params: Option<HashMap<String, String>>,

    /// Additional HTTP headers to include in requests to this provider where
    /// the (key, value) pairs are the header name and value.
    pub http_headers: Option<HashMap<String, String>>,

    /// Optional HTTP headers to include in requests to this provider where the
    /// (key, value) pairs are the header name and _environment variable_ whose
    /// value should be used. If the environment variable is not set, or the
    /// value is empty, the header will not be included in the request.
    pub env_http_headers: Option<HashMap<String, String>>,

    /// Maximum number of times to retry a failed HTTP request to this provider.
    pub request_max_retries: Option<u64>,

    /// Number of times to retry reconnecting a dropped streaming response before failing.
    pub stream_max_retries: Option<u64>,

    /// Idle timeout (in milliseconds) to wait for activity on a streaming response before treating
    /// the connection as lost.
    pub stream_idle_timeout_ms: Option<u64>,

    /// Does this provider require an OpenAI API key or ChatGPT login token? If true,
    /// the user is presented with a login screen on first run, and the login preference and
    /// token/key are stored in auth.json. If false (the default), the login screen is skipped,
    /// and the API key (if needed) comes from the `env_key` environment variable.
    #[serde(default)]
    pub requires_openai_auth: bool,
}

impl ModelProviderInfo {
    /// Construct a `POST` RequestBuilder for the given URL using the provided
    /// [`CodexHttpClient`] applying:
    /// • provider-specific headers (static + env based)
    /// • Bearer auth header when an API key is available.
    /// • Auth token for OAuth.
    ///
    /// If the provider declares an `env_key` but the variable is missing/empty, returns an [`Err`] identical to the
    /// one produced by [`ModelProviderInfo::api_key`].
    pub async fn create_request_builder<'a>(
        &'a self,
        client: &'a CodexHttpClient,
        auth: &Option<CodexAuth>,
    ) -> crate::error::Result<CodexRequestBuilder> {
        let effective_auth = if let Some(secret_key) = &self.experimental_bearer_token {
            Some(CodexAuth::from_api_key(secret_key))
        } else {
            match self.api_key() {
                Ok(Some(key)) => Some(CodexAuth::from_api_key(&key)),
                Ok(None) => auth.clone(),
                Err(err) => {
                    if auth.is_some() {
                        auth.clone()
                    } else {
                        return Err(err);
                    }
                }
            }
        };

        let url = self.get_full_url(&effective_auth);

        let mut builder = client.post(url);

        if let Some(auth) = effective_auth.as_ref() {
            builder = builder.bearer_auth(auth.get_token().await?);
        }

        Ok(self.apply_http_headers(builder))
    }

    fn get_query_string(&self) -> String {
        self.query_params
            .as_ref()
            .map_or_else(String::new, |params| {
                let full_params = params
                    .iter()
                    .map(|(k, v)| format!("{k}={v}"))
                    .collect::<Vec<_>>()
                    .join("&");
                format!("?{full_params}")
            })
    }

    pub(crate) fn get_full_url(&self, auth: &Option<CodexAuth>) -> String {
        let default_base_url = if matches!(
            auth,
            Some(CodexAuth {
                mode: AuthMode::ChatGPT,
                ..
            })
        ) {
            "https://chatgpt.com/backend-api/codex"
        } else {
            "https://api.openai.com/v1"
        };
        let query_string = self.get_query_string();
        let base_url = self
            .base_url
            .clone()
            .unwrap_or(default_base_url.to_string());

        match self.wire_api {
            WireApi::Responses => format!("{base_url}/responses{query_string}"),
            WireApi::Chat => format!("{base_url}/chat/completions{query_string}"),
        }
    }

    pub(crate) fn is_azure_responses_endpoint(&self) -> bool {
        if self.wire_api != WireApi::Responses {
            return false;
        }

        if self.name.eq_ignore_ascii_case("azure") {
            return true;
        }

        self.base_url
            .as_ref()
            .map(|base| matches_azure_responses_base_url(base))
            .unwrap_or(false)
    }

    /// Apply provider-specific HTTP headers (both static and environment-based)
    /// onto an existing [`CodexRequestBuilder`] and return the updated
    /// builder.
    fn apply_http_headers(&self, mut builder: CodexRequestBuilder) -> CodexRequestBuilder {
        if let Some(extra) = &self.http_headers {
            for (k, v) in extra {
                builder = builder.header(k, v);
            }
        }

        if let Some(env_headers) = &self.env_http_headers {
            for (header, env_var) in env_headers {
                if let Ok(val) = std::env::var(env_var)
                    && !val.trim().is_empty()
                {
                    builder = builder.header(header, val);
                }
            }
        }
        builder
    }

    /// If `env_key` is Some, returns the API key for this provider if present
    /// (and non-empty) in the environment. If `env_key` is required but
    /// cannot be found, returns an error.
    pub fn api_key(&self) -> crate::error::Result<Option<String>> {
        match &self.env_key {
            Some(env_key) => {
                let env_value = std::env::var(env_key);
                env_value
                    .and_then(|v| {
                        if v.trim().is_empty() {
                            Err(VarError::NotPresent)
                        } else {
                            Ok(Some(v))
                        }
                    })
                    .map_err(|_| {
                        crate::error::CodexErr::EnvVar(EnvVarError {
                            var: env_key.clone(),
                            instructions: self.env_key_instructions.clone(),
                        })
                    })
            }
            None => Ok(None),
        }
    }

    /// Effective maximum number of request retries for this provider.
    pub fn request_max_retries(&self) -> u64 {
        self.request_max_retries
            .unwrap_or(DEFAULT_REQUEST_MAX_RETRIES)
            .min(MAX_REQUEST_MAX_RETRIES)
    }

    /// Effective maximum number of stream reconnection attempts for this provider.
    pub fn stream_max_retries(&self) -> u64 {
        self.stream_max_retries
            .unwrap_or(DEFAULT_STREAM_MAX_RETRIES)
            .min(MAX_STREAM_MAX_RETRIES)
    }

    /// Effective idle timeout for streaming responses.
    pub fn stream_idle_timeout(&self) -> Duration {
        self.stream_idle_timeout_ms
            .map(Duration::from_millis)
            .unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS))
    }
}

const DEFAULT_OLLAMA_PORT: u32 = 11434;

pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "oss";

/// Built-in default provider list.
pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
    use ModelProviderInfo as P;

    // We do not want to be in the business of adjudicating which third-party
    // providers are bundled with Codex CLI, so we only include the OpenAI and
    // open source ("oss") providers by default. Users are encouraged to add to
    // `model_providers` in config.toml to add their own providers.
    [
        (
            "openai",
            P {
                name: "OpenAI".into(),
                // Allow users to override the default OpenAI endpoint by
                // exporting `OPENAI_BASE_URL`. This is useful when pointing
                // Codex at a proxy, mock server, or Azure-style deployment
                // without requiring a full TOML override for the built-in
                // OpenAI provider.
                base_url: std::env::var("OPENAI_BASE_URL")
                    .ok()
                    .filter(|v| !v.trim().is_empty()),
                env_key: None,
                env_key_instructions: None,
                experimental_bearer_token: None,
                wire_api: WireApi::Responses,
                query_params: None,
                http_headers: Some(
                    [("version".to_string(), env!("CARGO_PKG_VERSION").to_string())]
                        .into_iter()
                        .collect(),
                ),
                env_http_headers: Some(
                    [
                        (
                            "OpenAI-Organization".to_string(),
                            "OPENAI_ORGANIZATION".to_string(),
                        ),
                        ("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()),
                    ]
                    .into_iter()
                    .collect(),
                ),
                // Use global defaults for retry/timeout unless overridden in config.toml.
                request_max_retries: None,
                stream_max_retries: None,
                stream_idle_timeout_ms: None,
                requires_openai_auth: true,
            },
        ),
        (BUILT_IN_OSS_MODEL_PROVIDER_ID, create_oss_provider()),
    ]
    .into_iter()
    .map(|(k, v)| (k.to_string(), v))
    .collect()
}

pub fn create_oss_provider() -> ModelProviderInfo {
    // These CODEX_OSS_ environment variables are experimental: we may
    // switch to reading values from config.toml instead.
    let codex_oss_base_url = match std::env::var("CODEX_OSS_BASE_URL")
        .ok()
        .filter(|v| !v.trim().is_empty())
    {
        Some(url) => url,
        None => format!(
            "http://localhost:{port}/v1",
            port = std::env::var("CODEX_OSS_PORT")
                .ok()
                .filter(|v| !v.trim().is_empty())
                .and_then(|v| v.parse::<u32>().ok())
                .unwrap_or(DEFAULT_OLLAMA_PORT)
        ),
    };

    create_oss_provider_with_base_url(&codex_oss_base_url)
}

pub fn create_oss_provider_with_base_url(base_url: &str) -> ModelProviderInfo {
    ModelProviderInfo {
        name: "gpt-oss".into(),
        base_url: Some(base_url.into()),
        env_key: None,
        env_key_instructions: None,
        experimental_bearer_token: None,
        wire_api: WireApi::Chat,
        query_params: None,
        http_headers: None,
        env_http_headers: None,
        request_max_retries: None,
        stream_max_retries: None,
        stream_idle_timeout_ms: None,
        requires_openai_auth: false,
    }
}

fn matches_azure_responses_base_url(base_url: &str) -> bool {
    let base = base_url.to_ascii_lowercase();
    const AZURE_MARKERS: [&str; 5] = [
        "openai.azure.",
        "cognitiveservices.azure.",
        "aoai.azure.",
        "azure-api.",
        "azurefd.",
    ];
    AZURE_MARKERS.iter().any(|marker| base.contains(marker))
}

#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn test_deserialize_ollama_model_provider_toml() {
        let ollama_provider_toml = r#"
name = "Ollama"
base_url = "http://localhost:11434/v1"
"#;
        let expected_provider = ModelProviderInfo {
            name: "Ollama".into(),
            base_url: Some("http://localhost:11434/v1".into()),
            env_key: None,
            env_key_instructions: None,
            experimental_bearer_token: None,
            wire_api: WireApi::Chat,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
            request_max_retries: None,
            stream_max_retries: None,
            stream_idle_timeout_ms: None,
            requires_openai_auth: false,
        };

        let provider: ModelProviderInfo = toml::from_str(ollama_provider_toml).unwrap();
        assert_eq!(expected_provider, provider);
    }

    #[test]
    fn test_deserialize_azure_model_provider_toml() {
        let azure_provider_toml = r#"
name = "Azure"
base_url = "https://xxxxx.openai.azure.com/openai"
env_key = "AZURE_OPENAI_API_KEY"
query_params = { api-version = "2025-04-01-preview" }
"#;
        let expected_provider = ModelProviderInfo {
            name: "Azure".into(),
            base_url: Some("https://xxxxx.openai.azure.com/openai".into()),
            env_key: Some("AZURE_OPENAI_API_KEY".into()),
            env_key_instructions: None,
            experimental_bearer_token: None,
            wire_api: WireApi::Chat,
            query_params: Some(maplit::hashmap! {
                "api-version".to_string() => "2025-04-01-preview".to_string(),
            }),
            http_headers: None,
            env_http_headers: None,
            request_max_retries: None,
            stream_max_retries: None,
            stream_idle_timeout_ms: None,
            requires_openai_auth: false,
        };

        let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
        assert_eq!(expected_provider, provider);
    }

    #[test]
    fn test_deserialize_example_model_provider_toml() {
        let example_provider_toml = r#"
name = "Example"
base_url = "https://example.com"
env_key = "API_KEY"
http_headers = { "X-Example-Header" = "example-value" }
env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
"#;
        let expected_provider = ModelProviderInfo {
            name: "Example".into(),
            base_url: Some("https://example.com".into()),
            env_key: Some("API_KEY".into()),
            env_key_instructions: None,
            experimental_bearer_token: None,
            wire_api: WireApi::Chat,
            query_params: None,
            http_headers: Some(maplit::hashmap! {
                "X-Example-Header".to_string() => "example-value".to_string(),
            }),
            env_http_headers: Some(maplit::hashmap! {
                "X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(),
            }),
            request_max_retries: None,
            stream_max_retries: None,
            stream_idle_timeout_ms: None,
            requires_openai_auth: false,
        };

        let provider: ModelProviderInfo = toml::from_str(example_provider_toml).unwrap();
        assert_eq!(expected_provider, provider);
    }

    #[test]
    fn detects_azure_responses_base_urls() {
        fn provider_for(base_url: &str) -> ModelProviderInfo {
            ModelProviderInfo {
                name: "test".into(),
                base_url: Some(base_url.into()),
                env_key: None,
                env_key_instructions: None,
                experimental_bearer_token: None,
                wire_api: WireApi::Responses,
                query_params: None,
                http_headers: None,
                env_http_headers: None,
                request_max_retries: None,
                stream_max_retries: None,
                stream_idle_timeout_ms: None,
                requires_openai_auth: false,
            }
        }

        let positive_cases = [
            "https://foo.openai.azure.com/openai",
            "https://foo.openai.azure.us/openai/deployments/bar",
            "https://foo.cognitiveservices.azure.cn/openai",
            "https://foo.aoai.azure.com/openai",
            "https://foo.openai.azure-api.net/openai",
            "https://foo.z01.azurefd.net/",
        ];
        for base_url in positive_cases {
            let provider = provider_for(base_url);
            assert!(
                provider.is_azure_responses_endpoint(),
                "expected {base_url} to be detected as Azure"
            );
        }

        let named_provider = ModelProviderInfo {
            name: "Azure".into(),
            base_url: Some("https://example.com".into()),
            env_key: None,
            env_key_instructions: None,
            experimental_bearer_token: None,
            wire_api: WireApi::Responses,
            query_params: None,
            http_headers: None,
            env_http_headers: None,
            request_max_retries: None,
            stream_max_retries: None,
            stream_idle_timeout_ms: None,
            requires_openai_auth: false,
        };
        assert!(named_provider.is_azure_responses_endpoint());

        let negative_cases = [
            "https://api.openai.com/v1",
            "https://example.com/openai",
            "https://myproxy.azurewebsites.net/openai",
        ];
        for base_url in negative_cases {
            let provider = provider_for(base_url);
            assert!(
                !provider.is_azure_responses_endpoint(),
                "expected {base_url} not to be detected as Azure"
            );
        }
    }
}
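A standalone sketch of the URL assembly that `get_full_url` performs: base URL, wire-API path suffix, then an optional query string (simplified; the auth-dependent default base URLs are omitted):

    use std::collections::HashMap;

    enum WireApi {
        Responses,
        Chat,
    }

    fn full_url(base_url: &str, wire_api: &WireApi, query: &HashMap<String, String>) -> String {
        // Empty map means no query string at all, matching get_query_string.
        let query_string = if query.is_empty() {
            String::new()
        } else {
            let params: Vec<String> = query.iter().map(|(k, v)| format!("{k}={v}")).collect();
            format!("?{}", params.join("&"))
        };
        match wire_api {
            WireApi::Responses => format!("{base_url}/responses{query_string}"),
            WireApi::Chat => format!("{base_url}/chat/completions{query_string}"),
        }
    }

    fn main() {
        let mut q = HashMap::new();
        q.insert("api-version".to_string(), "2025-04-01-preview".to_string());
        assert_eq!(
            full_url("https://xxxxx.openai.azure.com/openai", &WireApi::Responses, &q),
            "https://xxxxx.openai.azure.com/openai/responses?api-version=2025-04-01-preview"
        );
        assert_eq!(
            full_url("http://localhost:11434/v1", &WireApi::Chat, &HashMap::new()),
            "http://localhost:11434/v1/chat/completions"
        );
    }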
87
llmx-rs/core/src/openai_model_info.rs
Normal file
@@ -0,0 +1,87 @@
use crate::model_family::ModelFamily;

// Shared constants for commonly used window/token sizes.
pub(crate) const CONTEXT_WINDOW_272K: i64 = 272_000;
pub(crate) const MAX_OUTPUT_TOKENS_128K: i64 = 128_000;

/// Metadata about a model, particularly OpenAI models.
/// We may want to consider including details like the pricing for
/// input tokens, output tokens, etc., though users will need to be able to
/// override this in config.toml, as this information can get out of date.
/// Though this would help present more accurate pricing information in the UI.
#[derive(Debug)]
pub(crate) struct ModelInfo {
    /// Size of the context window in tokens. This is the maximum size of the input context.
    pub(crate) context_window: i64,

    /// Maximum number of output tokens that can be generated for the model.
    pub(crate) max_output_tokens: i64,

    /// Token threshold where we should automatically compact conversation history. This considers
    /// input tokens + output tokens of this turn.
    pub(crate) auto_compact_token_limit: Option<i64>,
}

impl ModelInfo {
    const fn new(context_window: i64, max_output_tokens: i64) -> Self {
        Self {
            context_window,
            max_output_tokens,
            auto_compact_token_limit: Some(Self::default_auto_compact_limit(context_window)),
        }
    }

    const fn default_auto_compact_limit(context_window: i64) -> i64 {
        (context_window * 9) / 10
    }
}

pub(crate) fn get_model_info(model_family: &ModelFamily) -> Option<ModelInfo> {
    let slug = model_family.slug.as_str();
    match slug {
        // OSS models have a 128k shared token pool.
        // Arbitrarily splitting it: 3/4 input context, 1/4 output.
        // https://openai.com/index/gpt-oss-model-card/
        "gpt-oss-20b" => Some(ModelInfo::new(96_000, 32_000)),
        "gpt-oss-120b" => Some(ModelInfo::new(96_000, 32_000)),
        // https://platform.openai.com/docs/models/o3
        "o3" => Some(ModelInfo::new(200_000, 100_000)),

        // https://platform.openai.com/docs/models/o4-mini
        "o4-mini" => Some(ModelInfo::new(200_000, 100_000)),

        // https://platform.openai.com/docs/models/codex-mini-latest
        "codex-mini-latest" => Some(ModelInfo::new(200_000, 100_000)),

        // As of Jun 25, 2025, gpt-4.1 defaults to gpt-4.1-2025-04-14.
        // https://platform.openai.com/docs/models/gpt-4.1
        "gpt-4.1" | "gpt-4.1-2025-04-14" => Some(ModelInfo::new(1_047_576, 32_768)),

        // As of Jun 25, 2025, gpt-4o defaults to gpt-4o-2024-08-06.
        // https://platform.openai.com/docs/models/gpt-4o
        "gpt-4o" | "gpt-4o-2024-08-06" => Some(ModelInfo::new(128_000, 16_384)),

        // https://platform.openai.com/docs/models/gpt-4o?snapshot=gpt-4o-2024-05-13
        "gpt-4o-2024-05-13" => Some(ModelInfo::new(128_000, 4_096)),

        // https://platform.openai.com/docs/models/gpt-4o?snapshot=gpt-4o-2024-11-20
        "gpt-4o-2024-11-20" => Some(ModelInfo::new(128_000, 16_384)),

        // https://platform.openai.com/docs/models/gpt-3.5-turbo
        "gpt-3.5-turbo" => Some(ModelInfo::new(16_385, 4_096)),

        _ if slug.starts_with("gpt-5-codex") => {
            Some(ModelInfo::new(CONTEXT_WINDOW_272K, MAX_OUTPUT_TOKENS_128K))
        }

        _ if slug.starts_with("gpt-5") => {
            Some(ModelInfo::new(CONTEXT_WINDOW_272K, MAX_OUTPUT_TOKENS_128K))
        }

        _ if slug.starts_with("codex-") => {
            Some(ModelInfo::new(CONTEXT_WINDOW_272K, MAX_OUTPUT_TOKENS_128K))
        }

        _ => None,
    }
}
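The default auto-compact threshold above is 90% of the context window in integer arithmetic; for the shared 272k window that works out to 244,800 tokens. A quick check:

    const fn default_auto_compact_limit(context_window: i64) -> i64 {
        (context_window * 9) / 10
    }

    fn main() {
        assert_eq!(default_auto_compact_limit(272_000), 244_800);
        assert_eq!(default_auto_compact_limit(16_385), 14_746); // integer division truncates
    }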
61
llmx-rs/core/src/otel_init.rs
Normal file
@@ -0,0 +1,61 @@
use crate::config::Config;
use crate::config::types::OtelExporterKind as Kind;
use crate::config::types::OtelHttpProtocol as Protocol;
use crate::default_client::originator;
use codex_otel::config::OtelExporter;
use codex_otel::config::OtelHttpProtocol;
use codex_otel::config::OtelSettings;
use codex_otel::otel_provider::OtelProvider;
use std::error::Error;

/// Build an OpenTelemetry provider from the app Config.
///
/// Returns `None` when OTEL export is disabled.
pub fn build_provider(
    config: &Config,
    service_version: &str,
) -> Result<Option<OtelProvider>, Box<dyn Error>> {
    let exporter = match &config.otel.exporter {
        Kind::None => OtelExporter::None,
        Kind::OtlpHttp {
            endpoint,
            headers,
            protocol,
        } => {
            let protocol = match protocol {
                Protocol::Json => OtelHttpProtocol::Json,
                Protocol::Binary => OtelHttpProtocol::Binary,
            };

            OtelExporter::OtlpHttp {
                endpoint: endpoint.clone(),
                headers: headers
                    .iter()
                    .map(|(k, v)| (k.clone(), v.clone()))
                    .collect(),
                protocol,
            }
        }
        Kind::OtlpGrpc { endpoint, headers } => OtelExporter::OtlpGrpc {
            endpoint: endpoint.clone(),
            headers: headers
                .iter()
                .map(|(k, v)| (k.clone(), v.clone()))
                .collect(),
        },
    };

    OtelProvider::from(&OtelSettings {
        service_name: originator().value.to_owned(),
        service_version: service_version.to_string(),
        codex_home: config.codex_home.clone(),
        environment: config.otel.environment.to_string(),
        exporter,
    })
}

/// Filter predicate for exporting only Codex-owned events via OTEL.
/// Keeps events that originated from the `codex_otel` module.
pub fn codex_export_filter(meta: &tracing::Metadata<'_>) -> bool {
    meta.target().starts_with("codex_otel")
}
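The export filter above is a plain prefix test on the tracing target. A simplified standalone version operating on `&str` instead of `tracing::Metadata` (illustrative only):

    fn export_filter(target: &str) -> bool {
        // Only events emitted under the codex_otel target are exported.
        target.starts_with("codex_otel")
    }

    fn main() {
        assert!(export_filter("codex_otel::events"));
        assert!(!export_filter("codex_core::exec"));
    }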
1614
llmx-rs/core/src/parse_command.rs
Normal file
File diff suppressed because it is too large
450
llmx-rs/core/src/project_doc.rs
Normal file
@@ -0,0 +1,450 @@
//! Project-level documentation discovery.
//!
//! Project-level documentation is primarily stored in files named `AGENTS.md`.
//! Additional fallback filenames can be configured via `project_doc_fallback_filenames`.
//! We include the concatenation of all files found along the path from the
//! repository root to the current working directory as follows:
//!
//! 1. Determine the Git repository root by walking upwards from the current
//!    working directory until a `.git` directory or file is found. If no Git
//!    root is found, only the current working directory is considered.
//! 2. Collect every `AGENTS.md` found from the repository root down to the
//!    current working directory (inclusive) and concatenate their contents in
//!    that order.
//! 3. We do **not** walk past the Git root.

use crate::config::Config;
use dunce::canonicalize as normalize_path;
use std::path::PathBuf;
use tokio::io::AsyncReadExt;
use tracing::error;

/// Default filename scanned for project-level docs.
pub const DEFAULT_PROJECT_DOC_FILENAME: &str = "AGENTS.md";
/// Preferred local override for project-level docs.
pub const LOCAL_PROJECT_DOC_FILENAME: &str = "AGENTS.override.md";

/// When both `Config::instructions` and the project doc are present, they will
/// be concatenated with the following separator.
const PROJECT_DOC_SEPARATOR: &str = "\n\n--- project-doc ---\n\n";

/// Combines `Config::instructions` and `AGENTS.md` (if present) into a single
/// string of instructions.
pub(crate) async fn get_user_instructions(config: &Config) -> Option<String> {
    match read_project_docs(config).await {
        Ok(Some(project_doc)) => match &config.user_instructions {
            Some(original_instructions) => Some(format!(
                "{original_instructions}{PROJECT_DOC_SEPARATOR}{project_doc}"
            )),
            None => Some(project_doc),
        },
        Ok(None) => config.user_instructions.clone(),
        Err(e) => {
            error!("error trying to find project doc: {e:#}");
            config.user_instructions.clone()
        }
    }
}

/// Attempt to locate and load the project documentation.
///
/// On success returns `Ok(Some(contents))` where `contents` is the
/// concatenation of all discovered docs. If no documentation file is found the
/// function returns `Ok(None)`. Unexpected I/O failures bubble up as `Err` so
/// callers can decide how to handle them.
pub async fn read_project_docs(config: &Config) -> std::io::Result<Option<String>> {
    let max_total = config.project_doc_max_bytes;

    if max_total == 0 {
        return Ok(None);
    }

    let paths = discover_project_doc_paths(config)?;
    if paths.is_empty() {
        return Ok(None);
    }

    let mut remaining: u64 = max_total as u64;
    let mut parts: Vec<String> = Vec::new();

    for p in paths {
        if remaining == 0 {
            break;
        }

        let file = match tokio::fs::File::open(&p).await {
            Ok(f) => f,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => continue,
            Err(e) => return Err(e),
        };

        let size = file.metadata().await?.len();
        let mut reader = tokio::io::BufReader::new(file).take(remaining);
        let mut data: Vec<u8> = Vec::new();
        reader.read_to_end(&mut data).await?;

        if size > remaining {
            tracing::warn!(
                "Project doc `{}` exceeds remaining budget ({} bytes) - truncating.",
                p.display(),
                remaining,
            );
        }

        let text = String::from_utf8_lossy(&data).to_string();
        if !text.trim().is_empty() {
            parts.push(text);
            remaining = remaining.saturating_sub(data.len() as u64);
        }
    }

    if parts.is_empty() {
        Ok(None)
    } else {
        Ok(Some(parts.join("\n\n")))
    }
}

/// Discover the list of AGENTS.md files using the same search rules as
/// `read_project_docs`, but return the file paths instead of concatenated
/// contents. The list is ordered from repository root to the current working
/// directory (inclusive). Symlinks are allowed. When `project_doc_max_bytes`
/// is zero, returns an empty list.
pub fn discover_project_doc_paths(config: &Config) -> std::io::Result<Vec<PathBuf>> {
    let mut dir = config.cwd.clone();
    if let Ok(canon) = normalize_path(&dir) {
        dir = canon;
    }

    // Build chain from cwd upwards and detect git root.
    let mut chain: Vec<PathBuf> = vec![dir.clone()];
    let mut git_root: Option<PathBuf> = None;
    let mut cursor = dir;
    while let Some(parent) = cursor.parent() {
        let git_marker = cursor.join(".git");
        let git_exists = match std::fs::metadata(&git_marker) {
            Ok(_) => true,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => false,
            Err(e) => return Err(e),
        };

        if git_exists {
            git_root = Some(cursor.clone());
            break;
        }

        chain.push(parent.to_path_buf());
        cursor = parent.to_path_buf();
    }

    let search_dirs: Vec<PathBuf> = if let Some(root) = git_root {
        let mut dirs: Vec<PathBuf> = Vec::new();
        let mut saw_root = false;
        for p in chain.iter().rev() {
            if !saw_root {
                if p == &root {
                    saw_root = true;
                } else {
                    continue;
                }
            }
            dirs.push(p.clone());
        }
        dirs
    } else {
        vec![config.cwd.clone()]
    };

    let mut found: Vec<PathBuf> = Vec::new();
    let candidate_filenames = candidate_filenames(config);
    for d in search_dirs {
        for name in &candidate_filenames {
            let candidate = d.join(name);
            match std::fs::symlink_metadata(&candidate) {
                Ok(md) => {
                    let ft = md.file_type();
                    // Allow regular files and symlinks; opening will later fail for dangling links.
                    if ft.is_file() || ft.is_symlink() {
                        found.push(candidate);
                        break;
                    }
                }
                Err(e) if e.kind() == std::io::ErrorKind::NotFound => continue,
                Err(e) => return Err(e),
            }
        }
    }

    Ok(found)
}

fn candidate_filenames<'a>(config: &'a Config) -> Vec<&'a str> {
    let mut names: Vec<&'a str> =
        Vec::with_capacity(2 + config.project_doc_fallback_filenames.len());
    names.push(LOCAL_PROJECT_DOC_FILENAME);
    names.push(DEFAULT_PROJECT_DOC_FILENAME);
    for candidate in &config.project_doc_fallback_filenames {
        let candidate = candidate.as_str();
        if candidate.is_empty() {
            continue;
        }
        if !names.contains(&candidate) {
            names.push(candidate);
        }
    }
    names
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ConfigOverrides;
    use crate::config::ConfigToml;
    use std::fs;
    use tempfile::TempDir;

    /// Helper that returns a `Config` pointing at `root` and using `limit` as
    /// the maximum number of bytes to embed from AGENTS.md. The caller can
    /// optionally specify a custom `instructions` string – when `None` the
    /// value is cleared to mimic a scenario where no system instructions have
    /// been configured.
    fn make_config(root: &TempDir, limit: usize, instructions: Option<&str>) -> Config {
        let codex_home = TempDir::new().unwrap();
        let mut config = Config::load_from_base_config_with_overrides(
            ConfigToml::default(),
            ConfigOverrides::default(),
            codex_home.path().to_path_buf(),
        )
        .expect("defaults for test should always succeed");

        config.cwd = root.path().to_path_buf();
        config.project_doc_max_bytes = limit;

        config.user_instructions = instructions.map(ToOwned::to_owned);
        config
    }

    fn make_config_with_fallback(
        root: &TempDir,
        limit: usize,
        instructions: Option<&str>,
        fallbacks: &[&str],
    ) -> Config {
        let mut config = make_config(root, limit, instructions);
        config.project_doc_fallback_filenames = fallbacks
            .iter()
            .map(std::string::ToString::to_string)
            .collect();
        config
    }

    /// AGENTS.md missing – should yield `None`.
    #[tokio::test]
    async fn no_doc_file_returns_none() {
        let tmp = tempfile::tempdir().expect("tempdir");

        let res = get_user_instructions(&make_config(&tmp, 4096, None)).await;
        assert!(
            res.is_none(),
            "Expected None when AGENTS.md is absent and no system instructions provided"
        );
        assert!(res.is_none(), "Expected None when AGENTS.md is absent");
    }

    /// Small file within the byte-limit is returned unmodified.
    #[tokio::test]
    async fn doc_smaller_than_limit_is_returned() {
        let tmp = tempfile::tempdir().expect("tempdir");
        fs::write(tmp.path().join("AGENTS.md"), "hello world").unwrap();

        let res = get_user_instructions(&make_config(&tmp, 4096, None))
            .await
            .expect("doc expected");

        assert_eq!(
            res, "hello world",
            "The document should be returned verbatim when it is smaller than the limit and there are no existing instructions"
        );
    }

    /// Oversize file is truncated to `project_doc_max_bytes`.
    #[tokio::test]
    async fn doc_larger_than_limit_is_truncated() {
        const LIMIT: usize = 1024;
        let tmp = tempfile::tempdir().expect("tempdir");

        let huge = "A".repeat(LIMIT * 2); // 2 KiB
        fs::write(tmp.path().join("AGENTS.md"), &huge).unwrap();

        let res = get_user_instructions(&make_config(&tmp, LIMIT, None))
            .await
            .expect("doc expected");

        assert_eq!(res.len(), LIMIT, "doc should be truncated to LIMIT bytes");
        assert_eq!(res, huge[..LIMIT]);
    }

    /// When `cwd` is nested inside a repo, the search should locate AGENTS.md
    /// placed at the repository root (identified by `.git`).
    #[tokio::test]
    async fn finds_doc_in_repo_root() {
        let repo = tempfile::tempdir().expect("tempdir");

        // Simulate a git repository. Note .git can be a file or a directory.
        std::fs::write(
            repo.path().join(".git"),
            "gitdir: /path/to/actual/git/dir\n",
        )
        .unwrap();

        // Put the doc at the repo root.
        fs::write(repo.path().join("AGENTS.md"), "root level doc").unwrap();

        // Now create a nested working directory: repo/workspace/crate_a
        let nested = repo.path().join("workspace/crate_a");
        std::fs::create_dir_all(&nested).unwrap();

        // Build config pointing at the nested dir.
        let mut cfg = make_config(&repo, 4096, None);
        cfg.cwd = nested;

        let res = get_user_instructions(&cfg).await.expect("doc expected");
        assert_eq!(res, "root level doc");
    }

    /// Explicitly setting the byte-limit to zero disables project docs.
    #[tokio::test]
    async fn zero_byte_limit_disables_docs() {
        let tmp = tempfile::tempdir().expect("tempdir");
        fs::write(tmp.path().join("AGENTS.md"), "something").unwrap();

        let res = get_user_instructions(&make_config(&tmp, 0, None)).await;
        assert!(
            res.is_none(),
            "With limit 0 the function should return None"
        );
    }

    /// When both system instructions *and* a project doc are present the two
    /// should be concatenated with the separator.
    #[tokio::test]
    async fn merges_existing_instructions_with_project_doc() {
        let tmp = tempfile::tempdir().expect("tempdir");
        fs::write(tmp.path().join("AGENTS.md"), "proj doc").unwrap();

        const INSTRUCTIONS: &str = "base instructions";

        let res = get_user_instructions(&make_config(&tmp, 4096, Some(INSTRUCTIONS)))
            .await
            .expect("should produce a combined instruction string");

        let expected = format!("{INSTRUCTIONS}{PROJECT_DOC_SEPARATOR}{}", "proj doc");

        assert_eq!(res, expected);
    }

    /// If there are existing system instructions but the project doc is
    /// missing we expect the original instructions to be returned unchanged.
    #[tokio::test]
    async fn keeps_existing_instructions_when_doc_missing() {
        let tmp = tempfile::tempdir().expect("tempdir");

        const INSTRUCTIONS: &str = "some instructions";

        let res = get_user_instructions(&make_config(&tmp, 4096, Some(INSTRUCTIONS))).await;

        assert_eq!(res, Some(INSTRUCTIONS.to_string()));
    }

    /// When both the repository root and the working directory contain
    /// AGENTS.md files, their contents are concatenated from root to cwd.
    #[tokio::test]
    async fn concatenates_root_and_cwd_docs() {
        let repo = tempfile::tempdir().expect("tempdir");

        // Simulate a git repository.
        std::fs::write(
            repo.path().join(".git"),
            "gitdir: /path/to/actual/git/dir\n",
        )
        .unwrap();

        // Repo root doc.
        fs::write(repo.path().join("AGENTS.md"), "root doc").unwrap();

        // Nested working directory with its own doc.
        let nested = repo.path().join("workspace/crate_a");
        std::fs::create_dir_all(&nested).unwrap();
        fs::write(nested.join("AGENTS.md"), "crate doc").unwrap();

        let mut cfg = make_config(&repo, 4096, None);
        cfg.cwd = nested;

        let res = get_user_instructions(&cfg).await.expect("doc expected");
        assert_eq!(res, "root doc\n\ncrate doc");
    }

    /// AGENTS.override.md is preferred over AGENTS.md when both are present.
    #[tokio::test]
    async fn agents_local_md_preferred() {
        let tmp = tempfile::tempdir().expect("tempdir");
        fs::write(tmp.path().join(DEFAULT_PROJECT_DOC_FILENAME), "versioned").unwrap();
        fs::write(tmp.path().join(LOCAL_PROJECT_DOC_FILENAME), "local").unwrap();

        let cfg = make_config(&tmp, 4096, None);

        let res = get_user_instructions(&cfg)
            .await
            .expect("local doc expected");

        assert_eq!(res, "local");

        let discovery = discover_project_doc_paths(&cfg).expect("discover paths");
        assert_eq!(discovery.len(), 1);
        assert_eq!(
            discovery[0].file_name().unwrap().to_string_lossy(),
            LOCAL_PROJECT_DOC_FILENAME
        );
    }

    /// When AGENTS.md is absent but a configured fallback exists, the fallback is used.
    #[tokio::test]
    async fn uses_configured_fallback_when_agents_missing() {
        let tmp = tempfile::tempdir().expect("tempdir");
        fs::write(tmp.path().join("EXAMPLE.md"), "example instructions").unwrap();

        let cfg = make_config_with_fallback(&tmp, 4096, None, &["EXAMPLE.md"]);

        let res = get_user_instructions(&cfg)
            .await
            .expect("fallback doc expected");

        assert_eq!(res, "example instructions");
    }

    /// AGENTS.md remains preferred when both AGENTS.md and fallbacks are present.
    #[tokio::test]
    async fn agents_md_preferred_over_fallbacks() {
        let tmp = tempfile::tempdir().expect("tempdir");
        fs::write(tmp.path().join("AGENTS.md"), "primary").unwrap();
        fs::write(tmp.path().join("EXAMPLE.md"), "secondary").unwrap();

        let cfg = make_config_with_fallback(&tmp, 4096, None, &["EXAMPLE.md", ".example.md"]);

        let res = get_user_instructions(&cfg)
.await
|
||||
.expect("AGENTS.md should win");
|
||||
|
||||
assert_eq!(res, "primary");
|
||||
|
||||
let discovery = discover_project_doc_paths(&cfg).expect("discover paths");
|
||||
assert_eq!(discovery.len(), 1);
|
||||
assert!(
|
||||
discovery[0]
|
||||
.file_name()
|
||||
.unwrap()
|
||||
.to_string_lossy()
|
||||
.eq(DEFAULT_PROJECT_DOC_FILENAME)
|
||||
);
|
||||
}
|
||||
}
|
||||
104
llmx-rs/core/src/response_processing.rs
Normal file
@@ -0,0 +1,104 @@
use crate::codex::Session;
use crate::codex::TurnContext;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
use tracing::warn;

/// Process streamed `ResponseItem`s from the model into the pair of:
/// - items we should record in conversation history; and
/// - `ResponseInputItem`s to send back to the model on the next turn.
pub(crate) async fn process_items(
    processed_items: Vec<crate::codex::ProcessedResponseItem>,
    sess: &Session,
    turn_context: &TurnContext,
) -> (Vec<ResponseInputItem>, Vec<ResponseItem>) {
    let mut items_to_record_in_conversation_history = Vec::<ResponseItem>::new();
    let mut responses = Vec::<ResponseInputItem>::new();
    for processed_response_item in processed_items {
        let crate::codex::ProcessedResponseItem { item, response } = processed_response_item;
        match (&item, &response) {
            (ResponseItem::Message { role, .. }, None) if role == "assistant" => {
                // If the model returned a message, we need to record it.
                items_to_record_in_conversation_history.push(item);
            }
            (
                ResponseItem::LocalShellCall { .. },
                Some(ResponseInputItem::FunctionCallOutput { call_id, output }),
            ) => {
                items_to_record_in_conversation_history.push(item);
                items_to_record_in_conversation_history.push(ResponseItem::FunctionCallOutput {
                    call_id: call_id.clone(),
                    output: output.clone(),
                });
            }
            (
                ResponseItem::FunctionCall { .. },
                Some(ResponseInputItem::FunctionCallOutput { call_id, output }),
            ) => {
                items_to_record_in_conversation_history.push(item);
                items_to_record_in_conversation_history.push(ResponseItem::FunctionCallOutput {
                    call_id: call_id.clone(),
                    output: output.clone(),
                });
            }
            (
                ResponseItem::CustomToolCall { .. },
                Some(ResponseInputItem::CustomToolCallOutput { call_id, output }),
            ) => {
                items_to_record_in_conversation_history.push(item);
                items_to_record_in_conversation_history.push(ResponseItem::CustomToolCallOutput {
                    call_id: call_id.clone(),
                    output: output.clone(),
                });
            }
            (
                ResponseItem::FunctionCall { .. },
                Some(ResponseInputItem::McpToolCallOutput { call_id, result }),
            ) => {
                items_to_record_in_conversation_history.push(item);
                let output = match result {
                    Ok(call_tool_result) => FunctionCallOutputPayload::from(call_tool_result),
                    Err(err) => FunctionCallOutputPayload {
                        content: err.clone(),
                        success: Some(false),
                        ..Default::default()
                    },
                };
                items_to_record_in_conversation_history.push(ResponseItem::FunctionCallOutput {
                    call_id: call_id.clone(),
                    output,
                });
            }
            (
                ResponseItem::Reasoning {
                    id,
                    summary,
                    content,
                    encrypted_content,
                },
                None,
            ) => {
                items_to_record_in_conversation_history.push(ResponseItem::Reasoning {
                    id: id.clone(),
                    summary: summary.clone(),
                    content: content.clone(),
                    encrypted_content: encrypted_content.clone(),
                });
            }
            _ => {
                warn!("Unexpected response item: {item:?} with response: {response:?}");
            }
        };
        if let Some(response) = response {
            responses.push(response);
        }
    }

    // Only attempt to take the lock if there is something to record.
    if !items_to_record_in_conversation_history.is_empty() {
        sess.record_conversation_items(turn_context, &items_to_record_in_conversation_history)
            .await;
    }
    (responses, items_to_record_in_conversation_history)
}
55
llmx-rs/core/src/review_format.rs
Normal file
@@ -0,0 +1,55 @@
use crate::protocol::ReviewFinding;

// Note: We keep this module UI-agnostic. It returns plain strings that
// higher layers (e.g., TUI) may style as needed.

fn format_location(item: &ReviewFinding) -> String {
    let path = item.code_location.absolute_file_path.display();
    let start = item.code_location.line_range.start;
    let end = item.code_location.line_range.end;
    format!("{path}:{start}-{end}")
}

/// Format a full review findings block as plain text lines.
///
/// - When `selection` is `Some`, each item line includes a checkbox marker:
///   "[x]" for selected items and "[ ]" for unselected. Missing indices
///   default to selected.
/// - When `selection` is `None`, the marker is omitted and a simple bullet is
///   rendered ("- Title — path:start-end").
pub fn format_review_findings_block(
    findings: &[ReviewFinding],
    selection: Option<&[bool]>,
) -> String {
    let mut lines: Vec<String> = Vec::new();
    lines.push(String::new());

    // Header
    if findings.len() > 1 {
        lines.push("Full review comments:".to_string());
    } else {
        lines.push("Review comment:".to_string());
    }

    for (idx, item) in findings.iter().enumerate() {
        lines.push(String::new());

        let title = &item.title;
        let location = format_location(item);

        if let Some(flags) = selection {
            // Default to selected if index is out of bounds.
            let checked = flags.get(idx).copied().unwrap_or(true);
            let marker = if checked { "[x]" } else { "[ ]" };
            lines.push(format!("- {marker} {title} — {location}"));
        } else {
            lines.push(format!("- {title} — {location}"));
        }

        for body_line in item.body.lines() {
            lines.push(format!("  {body_line}"));
        }
    }

    lines.join("\n")
}
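
// Illustrative example (not part of the commit; titles, paths, and bodies are
// made up): for two findings with the second one deselected,
// `format_review_findings_block(&findings, Some(&[true, false]))` renders:
//
//   Full review comments:
//
//   - [x] Unchecked index — /repo/src/lib.rs:10-14
//     Indexing can panic when the slice is empty.
//
//   - [ ] Typo in doc comment — /repo/src/main.rs:3-3
//     "recieve" should be "receive".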
591
llmx-rs/core/src/rollout/list.rs
Normal file
@@ -0,0 +1,591 @@
use std::cmp::Reverse;
use std::io::{self};
use std::num::NonZero;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;

use time::OffsetDateTime;
use time::PrimitiveDateTime;
use time::format_description::FormatItem;
use time::macros::format_description;
use uuid::Uuid;

use super::SESSIONS_SUBDIR;
use crate::protocol::EventMsg;
use codex_file_search as file_search;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::RolloutLine;
use codex_protocol::protocol::SessionSource;

/// Returned page of conversation summaries.
#[derive(Debug, Default, PartialEq)]
pub struct ConversationsPage {
    /// Conversation summaries ordered newest first.
    pub items: Vec<ConversationItem>,
    /// Opaque pagination token to resume after the last item, or `None` if end.
    pub next_cursor: Option<Cursor>,
    /// Total number of files touched while scanning this request.
    pub num_scanned_files: usize,
    /// True if a hard scan cap was hit; consider resuming with `next_cursor`.
    pub reached_scan_cap: bool,
}

/// Summary information for a conversation rollout file.
#[derive(Debug, PartialEq)]
pub struct ConversationItem {
    /// Absolute path to the rollout file.
    pub path: PathBuf,
    /// First up to `HEAD_RECORD_LIMIT` JSONL records parsed as JSON (includes meta line).
    pub head: Vec<serde_json::Value>,
    /// Last up to `TAIL_RECORD_LIMIT` JSONL response records parsed as JSON.
    pub tail: Vec<serde_json::Value>,
    /// RFC3339 timestamp string for when the session was created, if available.
    pub created_at: Option<String>,
    /// RFC3339 timestamp string for the most recent response in the tail, if available.
    pub updated_at: Option<String>,
}

#[derive(Default)]
struct HeadTailSummary {
    head: Vec<serde_json::Value>,
    tail: Vec<serde_json::Value>,
    saw_session_meta: bool,
    saw_user_event: bool,
    source: Option<SessionSource>,
    model_provider: Option<String>,
    created_at: Option<String>,
    updated_at: Option<String>,
}

/// Hard cap to bound worst-case work per request.
const MAX_SCAN_FILES: usize = 10000;
const HEAD_RECORD_LIMIT: usize = 10;
const TAIL_RECORD_LIMIT: usize = 10;

/// Pagination cursor identifying a file by timestamp and UUID.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Cursor {
    ts: OffsetDateTime,
    id: Uuid,
}

impl Cursor {
    fn new(ts: OffsetDateTime, id: Uuid) -> Self {
        Self { ts, id }
    }
}

impl serde::Serialize for Cursor {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let ts_str = self
            .ts
            .format(&format_description!(
                "[year]-[month]-[day]T[hour]-[minute]-[second]"
            ))
            .map_err(|e| serde::ser::Error::custom(format!("format error: {e}")))?;
        serializer.serialize_str(&format!("{ts_str}|{}", self.id))
    }
}

impl<'de> serde::Deserialize<'de> for Cursor {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        parse_cursor(&s).ok_or_else(|| serde::de::Error::custom("invalid cursor"))
    }
}

/// Retrieve recorded conversation file paths with token pagination. The returned `next_cursor`
/// can be supplied on the next call to resume after the last returned item, resilient to
/// concurrent new sessions being appended. Ordering is stable by timestamp desc, then UUID desc.
pub(crate) async fn get_conversations(
    codex_home: &Path,
    page_size: usize,
    cursor: Option<&Cursor>,
    allowed_sources: &[SessionSource],
    model_providers: Option<&[String]>,
    default_provider: &str,
) -> io::Result<ConversationsPage> {
    let mut root = codex_home.to_path_buf();
    root.push(SESSIONS_SUBDIR);

    if !root.exists() {
        return Ok(ConversationsPage {
            items: Vec::new(),
            next_cursor: None,
            num_scanned_files: 0,
            reached_scan_cap: false,
        });
    }

    let anchor = cursor.cloned();

    let provider_matcher =
        model_providers.and_then(|filters| ProviderMatcher::new(filters, default_provider));

    let result = traverse_directories_for_paths(
        root.clone(),
        page_size,
        anchor,
        allowed_sources,
        provider_matcher.as_ref(),
    )
    .await?;
    Ok(result)
}

/// Load the full contents of a single conversation session file at `path`.
/// Returns the entire file contents as a String.
#[allow(dead_code)]
pub(crate) async fn get_conversation(path: &Path) -> io::Result<String> {
    tokio::fs::read_to_string(path).await
}

/// Load conversation file paths from disk using directory traversal.
///
/// Directory layout: `~/.codex/sessions/YYYY/MM/DD/rollout-YYYY-MM-DDThh-mm-ss-<uuid>.jsonl`
/// Returned newest (latest) first.
async fn traverse_directories_for_paths(
    root: PathBuf,
    page_size: usize,
    anchor: Option<Cursor>,
    allowed_sources: &[SessionSource],
    provider_matcher: Option<&ProviderMatcher<'_>>,
) -> io::Result<ConversationsPage> {
    let mut items: Vec<ConversationItem> = Vec::with_capacity(page_size);
    let mut scanned_files = 0usize;
    let mut anchor_passed = anchor.is_none();
    let (anchor_ts, anchor_id) = match anchor {
        Some(c) => (c.ts, c.id),
        None => (OffsetDateTime::UNIX_EPOCH, Uuid::nil()),
    };
    let mut more_matches_available = false;

    let year_dirs = collect_dirs_desc(&root, |s| s.parse::<u16>().ok()).await?;

    'outer: for (_year, year_path) in year_dirs.iter() {
        if scanned_files >= MAX_SCAN_FILES {
            break;
        }
        let month_dirs = collect_dirs_desc(year_path, |s| s.parse::<u8>().ok()).await?;
        for (_month, month_path) in month_dirs.iter() {
            if scanned_files >= MAX_SCAN_FILES {
                break 'outer;
            }
            let day_dirs = collect_dirs_desc(month_path, |s| s.parse::<u8>().ok()).await?;
            for (_day, day_path) in day_dirs.iter() {
                if scanned_files >= MAX_SCAN_FILES {
                    break 'outer;
                }
                let mut day_files = collect_files(day_path, |name_str, path| {
                    if !name_str.starts_with("rollout-") || !name_str.ends_with(".jsonl") {
                        return None;
                    }

                    parse_timestamp_uuid_from_filename(name_str)
                        .map(|(ts, id)| (ts, id, name_str.to_string(), path.to_path_buf()))
                })
                .await?;
                // Stable ordering within the same second: (timestamp desc, uuid desc)
                day_files.sort_by_key(|(ts, sid, _name_str, _path)| (Reverse(*ts), Reverse(*sid)));
                for (ts, sid, _name_str, path) in day_files.into_iter() {
                    scanned_files += 1;
                    if scanned_files >= MAX_SCAN_FILES && items.len() >= page_size {
                        more_matches_available = true;
                        break 'outer;
                    }
                    if !anchor_passed {
                        if ts < anchor_ts || (ts == anchor_ts && sid < anchor_id) {
                            anchor_passed = true;
                        } else {
                            continue;
                        }
                    }
                    if items.len() == page_size {
                        more_matches_available = true;
                        break 'outer;
                    }
                    // Read head and simultaneously detect message events within the same
                    // first N JSONL records to avoid a second file read.
                    let summary = read_head_and_tail(&path, HEAD_RECORD_LIMIT, TAIL_RECORD_LIMIT)
                        .await
                        .unwrap_or_default();
                    if !allowed_sources.is_empty()
                        && !summary
                            .source
                            .is_some_and(|source| allowed_sources.iter().any(|s| s == &source))
                    {
                        continue;
                    }
                    if let Some(matcher) = provider_matcher
                        && !matcher.matches(summary.model_provider.as_deref())
                    {
                        continue;
                    }
                    // Apply filters: must have session meta and at least one user message event
                    if summary.saw_session_meta && summary.saw_user_event {
                        let HeadTailSummary {
                            head,
                            tail,
                            created_at,
                            mut updated_at,
                            ..
                        } = summary;
                        updated_at = updated_at.or_else(|| created_at.clone());
                        items.push(ConversationItem {
                            path,
                            head,
                            tail,
                            created_at,
                            updated_at,
                        });
                    }
                }
            }
        }
    }

    let reached_scan_cap = scanned_files >= MAX_SCAN_FILES;
    if reached_scan_cap && !items.is_empty() {
        more_matches_available = true;
    }

    let next = if more_matches_available {
        build_next_cursor(&items)
    } else {
        None
    };
    Ok(ConversationsPage {
        items,
        next_cursor: next,
        num_scanned_files: scanned_files,
        reached_scan_cap,
    })
}

/// Pagination cursor token format: "<file_ts>|<uuid>" where `file_ts` matches the
/// filename timestamp portion (YYYY-MM-DDThh-mm-ss) used in rollout filenames.
/// The cursor orders files by timestamp desc, then UUID desc.
pub fn parse_cursor(token: &str) -> Option<Cursor> {
    let (file_ts, uuid_str) = token.split_once('|')?;

    let Ok(uuid) = Uuid::parse_str(uuid_str) else {
        return None;
    };

    let format: &[FormatItem] =
        format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
    let ts = PrimitiveDateTime::parse(file_ts, format).ok()?.assume_utc();

    Some(Cursor::new(ts, uuid))
}
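
// Example (illustrative, not part of the commit): a serialized cursor token
// pairs the filename-style timestamp with the session UUID, e.g.
//
//   "2025-05-07T17-24-21|5973b6c0-94b8-487b-a530-2aeb6098ae0e"
//
// and `parse_cursor` recovers the same (timestamp, UUID) pair, so the
// Serialize/Deserialize impls above round-trip.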

fn build_next_cursor(items: &[ConversationItem]) -> Option<Cursor> {
    let last = items.last()?;
    let file_name = last.path.file_name()?.to_string_lossy();
    let (ts, id) = parse_timestamp_uuid_from_filename(&file_name)?;
    Some(Cursor::new(ts, id))
}

/// Collects immediate subdirectories of `parent`, parses their (string) names with `parse`,
/// and returns them sorted descending by the parsed key.
async fn collect_dirs_desc<T, F>(parent: &Path, parse: F) -> io::Result<Vec<(T, PathBuf)>>
where
    T: Ord + Copy,
    F: Fn(&str) -> Option<T>,
{
    let mut dir = tokio::fs::read_dir(parent).await?;
    let mut vec: Vec<(T, PathBuf)> = Vec::new();
    while let Some(entry) = dir.next_entry().await? {
        if entry
            .file_type()
            .await
            .map(|ft| ft.is_dir())
            .unwrap_or(false)
            && let Some(s) = entry.file_name().to_str()
            && let Some(v) = parse(s)
        {
            vec.push((v, entry.path()));
        }
    }
    vec.sort_by_key(|(v, _)| Reverse(*v));
    Ok(vec)
}

/// Collects files in a directory and parses them with `parse`.
async fn collect_files<T, F>(parent: &Path, parse: F) -> io::Result<Vec<T>>
where
    F: Fn(&str, &Path) -> Option<T>,
{
    let mut dir = tokio::fs::read_dir(parent).await?;
    let mut collected: Vec<T> = Vec::new();
    while let Some(entry) = dir.next_entry().await? {
        if entry
            .file_type()
            .await
            .map(|ft| ft.is_file())
            .unwrap_or(false)
            && let Some(s) = entry.file_name().to_str()
            && let Some(v) = parse(s, &entry.path())
        {
            collected.push(v);
        }
    }
    Ok(collected)
}

fn parse_timestamp_uuid_from_filename(name: &str) -> Option<(OffsetDateTime, Uuid)> {
    // Expected: rollout-YYYY-MM-DDThh-mm-ss-<uuid>.jsonl
    let core = name.strip_prefix("rollout-")?.strip_suffix(".jsonl")?;

    // Scan from the right for a '-' such that the suffix parses as a UUID.
    let (sep_idx, uuid) = core
        .match_indices('-')
        .rev()
        .find_map(|(i, _)| Uuid::parse_str(&core[i + 1..]).ok().map(|u| (i, u)))?;

    let ts_str = &core[..sep_idx];
    let format: &[FormatItem] =
        format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
    let ts = PrimitiveDateTime::parse(ts_str, format).ok()?.assume_utc();
    Some((ts, uuid))
}
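
// Example (illustrative, not part of the commit): for the filename
//
//   rollout-2025-05-07T17-24-21-5973b6c0-94b8-487b-a530-2aeb6098ae0e.jsonl
//
// the scan walks '-' positions from the right until the remaining suffix
// parses as a full UUID, yielding ts_str = "2025-05-07T17-24-21" plus the
// UUID. Scanning from the right is what keeps the '-' separators inside the
// timestamp (and inside the UUID itself) from splitting in the wrong place.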

struct ProviderMatcher<'a> {
    filters: &'a [String],
    matches_default_provider: bool,
}

impl<'a> ProviderMatcher<'a> {
    fn new(filters: &'a [String], default_provider: &'a str) -> Option<Self> {
        if filters.is_empty() {
            return None;
        }

        let matches_default_provider = filters.iter().any(|provider| provider == default_provider);
        Some(Self {
            filters,
            matches_default_provider,
        })
    }

    fn matches(&self, session_provider: Option<&str>) -> bool {
        match session_provider {
            Some(provider) => self.filters.iter().any(|candidate| candidate == provider),
            None => self.matches_default_provider,
        }
    }
}

async fn read_head_and_tail(
    path: &Path,
    head_limit: usize,
    tail_limit: usize,
) -> io::Result<HeadTailSummary> {
    use tokio::io::AsyncBufReadExt;

    let file = tokio::fs::File::open(path).await?;
    let reader = tokio::io::BufReader::new(file);
    let mut lines = reader.lines();
    let mut summary = HeadTailSummary::default();

    while summary.head.len() < head_limit {
        let line_opt = lines.next_line().await?;
        let Some(line) = line_opt else { break };
        let trimmed = line.trim();
        if trimmed.is_empty() {
            continue;
        }

        let parsed: Result<RolloutLine, _> = serde_json::from_str(trimmed);
        let Ok(rollout_line) = parsed else { continue };

        match rollout_line.item {
            RolloutItem::SessionMeta(session_meta_line) => {
                summary.source = Some(session_meta_line.meta.source.clone());
                summary.model_provider = session_meta_line.meta.model_provider.clone();
                summary.created_at = summary
                    .created_at
                    .clone()
                    .or_else(|| Some(rollout_line.timestamp.clone()));
                if let Ok(val) = serde_json::to_value(session_meta_line) {
                    summary.head.push(val);
                    summary.saw_session_meta = true;
                }
            }
            RolloutItem::ResponseItem(item) => {
                summary.created_at = summary
                    .created_at
                    .clone()
                    .or_else(|| Some(rollout_line.timestamp.clone()));
                if let Ok(val) = serde_json::to_value(item) {
                    summary.head.push(val);
                }
            }
            RolloutItem::TurnContext(_) => {
                // Not included in `head`; skip.
            }
            RolloutItem::Compacted(_) => {
                // Not included in `head`; skip.
            }
            RolloutItem::EventMsg(ev) => {
                if matches!(ev, EventMsg::UserMessage(_)) {
                    summary.saw_user_event = true;
                }
            }
        }
    }

    if tail_limit != 0 {
        let (tail, updated_at) = read_tail_records(path, tail_limit).await?;
        summary.tail = tail;
        summary.updated_at = updated_at;
    }
    Ok(summary)
}

/// Read up to `HEAD_RECORD_LIMIT` records from the start of the rollout file at `path`.
/// This should be enough to produce a summary including the session meta line.
pub async fn read_head_for_summary(path: &Path) -> io::Result<Vec<serde_json::Value>> {
    let summary = read_head_and_tail(path, HEAD_RECORD_LIMIT, 0).await?;
    Ok(summary.head)
}

async fn read_tail_records(
    path: &Path,
    max_records: usize,
) -> io::Result<(Vec<serde_json::Value>, Option<String>)> {
    use std::io::SeekFrom;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncSeekExt;

    if max_records == 0 {
        return Ok((Vec::new(), None));
    }

    const CHUNK_SIZE: usize = 8192;

    let mut file = tokio::fs::File::open(path).await?;
    let mut pos = file.seek(SeekFrom::End(0)).await?;
    if pos == 0 {
        return Ok((Vec::new(), None));
    }

    let mut buffer: Vec<u8> = Vec::new();
    let mut latest_timestamp: Option<String> = None;

    loop {
        let slice_start = match (pos > 0, buffer.iter().position(|&b| b == b'\n')) {
            (true, Some(idx)) => idx + 1,
            _ => 0,
        };
        let (tail, newest_ts) = collect_last_response_values(&buffer[slice_start..], max_records);
        if latest_timestamp.is_none() {
            latest_timestamp = newest_ts.clone();
        }
        if tail.len() >= max_records || pos == 0 {
            return Ok((tail, latest_timestamp.or(newest_ts)));
        }

        let read_size = CHUNK_SIZE.min(pos as usize);
        if read_size == 0 {
            return Ok((tail, latest_timestamp.or(newest_ts)));
        }
        pos -= read_size as u64;
        file.seek(SeekFrom::Start(pos)).await?;
        let mut chunk = vec![0; read_size];
        file.read_exact(&mut chunk).await?;
        chunk.extend_from_slice(&buffer);
        buffer = chunk;
    }
}
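
// Walkthrough (illustrative, not part of the commit): the loop grows `buffer`
// backwards from EOF in 8 KiB chunks. While `pos > 0`, the bytes before the
// first '\n' in `buffer` may be a partial line cut by the chunk boundary, so
// `slice_start` skips them; once `pos` reaches 0 the window covers the whole
// file and every buffered line is complete and safe to parse.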

fn collect_last_response_values(
    buffer: &[u8],
    max_records: usize,
) -> (Vec<serde_json::Value>, Option<String>) {
    use std::borrow::Cow;

    if buffer.is_empty() || max_records == 0 {
        return (Vec::new(), None);
    }

    let text: Cow<'_, str> = String::from_utf8_lossy(buffer);
    let mut collected_rev: Vec<serde_json::Value> = Vec::new();
    let mut latest_timestamp: Option<String> = None;
    for line in text.lines().rev() {
        let trimmed = line.trim();
        if trimmed.is_empty() {
            continue;
        }
        let parsed: serde_json::Result<RolloutLine> = serde_json::from_str(trimmed);
        let Ok(rollout_line) = parsed else { continue };
        let RolloutLine { timestamp, item } = rollout_line;
        if let RolloutItem::ResponseItem(item) = item
            && let Ok(val) = serde_json::to_value(&item)
        {
            if latest_timestamp.is_none() {
                latest_timestamp = Some(timestamp.clone());
            }
            collected_rev.push(val);
            if collected_rev.len() == max_records {
                break;
            }
        }
    }
    collected_rev.reverse();
    (collected_rev, latest_timestamp)
}

/// Locate a recorded conversation rollout file by its UUID string using the existing
/// paginated listing implementation. Returns `Ok(Some(path))` if found, `Ok(None)` if not present
/// or the id is invalid.
pub async fn find_conversation_path_by_id_str(
    codex_home: &Path,
    id_str: &str,
) -> io::Result<Option<PathBuf>> {
    // Validate UUID format early.
    if Uuid::parse_str(id_str).is_err() {
        return Ok(None);
    }

    let mut root = codex_home.to_path_buf();
    root.push(SESSIONS_SUBDIR);
    if !root.exists() {
        return Ok(None);
    }
    // This is safe because we know the values are valid.
    #[allow(clippy::unwrap_used)]
    let limit = NonZero::new(1).unwrap();
    // This is safe because we know the values are valid.
    #[allow(clippy::unwrap_used)]
    let threads = NonZero::new(2).unwrap();
    let cancel = Arc::new(AtomicBool::new(false));
    let exclude: Vec<String> = Vec::new();
    let compute_indices = false;

    let results = file_search::run(
        id_str,
        limit,
        &root,
        exclude,
        threads,
        cancel,
        compute_indices,
        false,
    )
    .map_err(|e| io::Error::other(format!("file search failed: {e}")))?;

    Ok(results
        .matches
        .into_iter()
        .next()
        .map(|m| root.join(m.path)))
}
20
llmx-rs/core/src/rollout/mod.rs
Normal file
@@ -0,0 +1,20 @@
//! Rollout module: persistence and discovery of session rollout files.

use codex_protocol::protocol::SessionSource;

pub const SESSIONS_SUBDIR: &str = "sessions";
pub const ARCHIVED_SESSIONS_SUBDIR: &str = "archived_sessions";
pub const INTERACTIVE_SESSION_SOURCES: &[SessionSource] =
    &[SessionSource::Cli, SessionSource::VSCode];

pub mod list;
pub(crate) mod policy;
pub mod recorder;

pub use codex_protocol::protocol::SessionMeta;
pub use list::find_conversation_path_by_id_str;
pub use recorder::RolloutRecorder;
pub use recorder::RolloutRecorderParams;

#[cfg(test)]
pub mod tests;
86
llmx-rs/core/src/rollout/policy.rs
Normal file
@@ -0,0 +1,86 @@
use crate::protocol::EventMsg;
use crate::protocol::RolloutItem;
use codex_protocol::models::ResponseItem;

/// Whether a rollout `item` should be persisted in rollout files.
#[inline]
pub(crate) fn is_persisted_response_item(item: &RolloutItem) -> bool {
    match item {
        RolloutItem::ResponseItem(item) => should_persist_response_item(item),
        RolloutItem::EventMsg(ev) => should_persist_event_msg(ev),
        // Persist Codex executive markers so we can analyze flows (e.g., compaction, API turns).
        RolloutItem::Compacted(_) | RolloutItem::TurnContext(_) | RolloutItem::SessionMeta(_) => {
            true
        }
    }
}

/// Whether a `ResponseItem` should be persisted in rollout files.
#[inline]
pub(crate) fn should_persist_response_item(item: &ResponseItem) -> bool {
    match item {
        ResponseItem::Message { .. }
        | ResponseItem::Reasoning { .. }
        | ResponseItem::LocalShellCall { .. }
        | ResponseItem::FunctionCall { .. }
        | ResponseItem::FunctionCallOutput { .. }
        | ResponseItem::CustomToolCall { .. }
        | ResponseItem::CustomToolCallOutput { .. }
        | ResponseItem::WebSearchCall { .. }
        | ResponseItem::GhostSnapshot { .. } => true,
        ResponseItem::Other => false,
    }
}

/// Whether an `EventMsg` should be persisted in rollout files.
#[inline]
pub(crate) fn should_persist_event_msg(ev: &EventMsg) -> bool {
    match ev {
        EventMsg::UserMessage(_)
        | EventMsg::AgentMessage(_)
        | EventMsg::AgentReasoning(_)
        | EventMsg::AgentReasoningRawContent(_)
        | EventMsg::TokenCount(_)
        | EventMsg::EnteredReviewMode(_)
        | EventMsg::ExitedReviewMode(_)
        | EventMsg::UndoCompleted(_)
        | EventMsg::TurnAborted(_) => true,
        EventMsg::Error(_)
        | EventMsg::Warning(_)
        | EventMsg::TaskStarted(_)
        | EventMsg::TaskComplete(_)
        | EventMsg::AgentMessageDelta(_)
        | EventMsg::AgentReasoningDelta(_)
        | EventMsg::AgentReasoningRawContentDelta(_)
        | EventMsg::AgentReasoningSectionBreak(_)
        | EventMsg::RawResponseItem(_)
        | EventMsg::SessionConfigured(_)
        | EventMsg::McpToolCallBegin(_)
        | EventMsg::McpToolCallEnd(_)
        | EventMsg::WebSearchBegin(_)
        | EventMsg::WebSearchEnd(_)
        | EventMsg::ExecCommandBegin(_)
        | EventMsg::ExecCommandOutputDelta(_)
        | EventMsg::ExecCommandEnd(_)
        | EventMsg::ExecApprovalRequest(_)
        | EventMsg::ApplyPatchApprovalRequest(_)
        | EventMsg::BackgroundEvent(_)
        | EventMsg::StreamError(_)
        | EventMsg::PatchApplyBegin(_)
        | EventMsg::PatchApplyEnd(_)
        | EventMsg::TurnDiff(_)
        | EventMsg::GetHistoryEntryResponse(_)
        | EventMsg::UndoStarted(_)
        | EventMsg::McpListToolsResponse(_)
        | EventMsg::ListCustomPromptsResponse(_)
        | EventMsg::PlanUpdate(_)
        | EventMsg::ShutdownComplete
        | EventMsg::ViewImageToolCall(_)
        | EventMsg::DeprecationNotice(_)
        | EventMsg::ItemStarted(_)
        | EventMsg::ItemCompleted(_)
        | EventMsg::AgentMessageContentDelta(_)
        | EventMsg::ReasoningContentDelta(_)
        | EventMsg::ReasoningRawContentDelta(_) => false,
    }
}
424
llmx-rs/core/src/rollout/recorder.rs
Normal file
@@ -0,0 +1,424 @@
//! Persist Codex session rollouts (.jsonl) so sessions can be replayed or inspected later.

use std::fs::File;
use std::fs::{self};
use std::io::Error as IoError;
use std::path::Path;
use std::path::PathBuf;

use codex_protocol::ConversationId;
use serde_json::Value;
use time::OffsetDateTime;
use time::format_description::FormatItem;
use time::macros::format_description;
use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc::Sender;
use tokio::sync::mpsc::{self};
use tokio::sync::oneshot;
use tracing::info;
use tracing::warn;

use super::SESSIONS_SUBDIR;
use super::list::ConversationsPage;
use super::list::Cursor;
use super::list::get_conversations;
use super::policy::is_persisted_response_item;
use crate::config::Config;
use crate::default_client::originator;
use crate::git_info::collect_git_info;
use codex_protocol::protocol::InitialHistory;
use codex_protocol::protocol::ResumedHistory;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::RolloutLine;
use codex_protocol::protocol::SessionMeta;
use codex_protocol::protocol::SessionMetaLine;
use codex_protocol::protocol::SessionSource;

/// Records all [`ResponseItem`]s for a session and flushes them to disk after
/// every update.
///
/// Rollouts are recorded as JSONL and can be inspected with tools such as:
///
/// ```ignore
/// $ jq -C . ~/.codex/sessions/rollout-2025-05-07T17-24-21-5973b6c0-94b8-487b-a530-2aeb6098ae0e.jsonl
/// $ fx ~/.codex/sessions/rollout-2025-05-07T17-24-21-5973b6c0-94b8-487b-a530-2aeb6098ae0e.jsonl
/// ```
#[derive(Clone)]
pub struct RolloutRecorder {
    tx: Sender<RolloutCmd>,
    pub(crate) rollout_path: PathBuf,
}

#[derive(Clone)]
pub enum RolloutRecorderParams {
    Create {
        conversation_id: ConversationId,
        instructions: Option<String>,
        source: SessionSource,
    },
    Resume {
        path: PathBuf,
    },
}

enum RolloutCmd {
    AddItems(Vec<RolloutItem>),
    /// Ensure all prior writes are processed; respond when flushed.
    Flush {
        ack: oneshot::Sender<()>,
    },
    Shutdown {
        ack: oneshot::Sender<()>,
    },
}

impl RolloutRecorderParams {
    pub fn new(
        conversation_id: ConversationId,
        instructions: Option<String>,
        source: SessionSource,
    ) -> Self {
        Self::Create {
            conversation_id,
            instructions,
            source,
        }
    }

    pub fn resume(path: PathBuf) -> Self {
        Self::Resume { path }
    }
}
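
// Usage sketch (illustrative, not part of the commit; the variable names are
// assumptions): create a recorder for a fresh conversation, or resume one
// from an existing rollout file, then queue items and flush.
//
// let params = RolloutRecorderParams::new(conversation_id, None, SessionSource::Cli);
// let recorder = RolloutRecorder::new(&config, params).await?;
// recorder.record_items(&items).await?;
// recorder.flush().await?;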

impl RolloutRecorder {
    /// List conversations (rollout files) under the provided Codex home directory.
    pub async fn list_conversations(
        codex_home: &Path,
        page_size: usize,
        cursor: Option<&Cursor>,
        allowed_sources: &[SessionSource],
        model_providers: Option<&[String]>,
        default_provider: &str,
    ) -> std::io::Result<ConversationsPage> {
        get_conversations(
            codex_home,
            page_size,
            cursor,
            allowed_sources,
            model_providers,
            default_provider,
        )
        .await
    }

    /// Attempt to create a new [`RolloutRecorder`]. If the sessions directory
    /// cannot be created or the rollout file cannot be opened we return the
    /// error so the caller can decide whether to disable persistence.
    pub async fn new(config: &Config, params: RolloutRecorderParams) -> std::io::Result<Self> {
        let (file, rollout_path, meta) = match params {
            RolloutRecorderParams::Create {
                conversation_id,
                instructions,
                source,
            } => {
                let LogFileInfo {
                    file,
                    path,
                    conversation_id: session_id,
                    timestamp,
                } = create_log_file(config, conversation_id)?;

                let timestamp_format: &[FormatItem] = format_description!(
                    "[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
                );
                let timestamp = timestamp
                    .to_offset(time::UtcOffset::UTC)
                    .format(timestamp_format)
                    .map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;

                (
                    tokio::fs::File::from_std(file),
                    path,
                    Some(SessionMeta {
                        id: session_id,
                        timestamp,
                        cwd: config.cwd.clone(),
                        originator: originator().value.clone(),
                        cli_version: env!("CARGO_PKG_VERSION").to_string(),
                        instructions,
                        source,
                        model_provider: Some(config.model_provider_id.clone()),
                    }),
                )
            }
            RolloutRecorderParams::Resume { path } => (
                tokio::fs::OpenOptions::new()
                    .append(true)
                    .open(&path)
                    .await?,
                path,
                None,
            ),
        };

        // Clone the cwd for the spawned task to collect git info asynchronously
        let cwd = config.cwd.clone();

        // A reasonably-sized bounded channel. If the buffer fills up the send
        // future will yield, which is fine – we only need to ensure we do not
        // perform *blocking* I/O on the caller's thread.
        let (tx, rx) = mpsc::channel::<RolloutCmd>(256);

        // Spawn a Tokio task that owns the file handle and performs async
        // writes. Using `tokio::fs::File` keeps everything on the async I/O
        // driver instead of blocking the runtime.
        tokio::task::spawn(rollout_writer(file, rx, meta, cwd));

        Ok(Self { tx, rollout_path })
    }

    pub(crate) async fn record_items(&self, items: &[RolloutItem]) -> std::io::Result<()> {
        let mut filtered = Vec::new();
        for item in items {
            // Note that function calls may look a bit strange if they are
            // "fully qualified MCP tool calls," so we could consider
            // reformatting them in that case.
            if is_persisted_response_item(item) {
                filtered.push(item.clone());
            }
        }
        if filtered.is_empty() {
            return Ok(());
        }
        self.tx
            .send(RolloutCmd::AddItems(filtered))
            .await
            .map_err(|e| IoError::other(format!("failed to queue rollout items: {e}")))
    }

    /// Flush all queued writes and wait until they are committed by the writer task.
    pub async fn flush(&self) -> std::io::Result<()> {
        let (tx, rx) = oneshot::channel();
        self.tx
            .send(RolloutCmd::Flush { ack: tx })
            .await
            .map_err(|e| IoError::other(format!("failed to queue rollout flush: {e}")))?;
        rx.await
            .map_err(|e| IoError::other(format!("failed waiting for rollout flush: {e}")))
    }

    pub async fn get_rollout_history(path: &Path) -> std::io::Result<InitialHistory> {
        info!("Resuming rollout from {path:?}");
        let text = tokio::fs::read_to_string(path).await?;
        if text.trim().is_empty() {
            return Err(IoError::other("empty session file"));
        }

        let mut items: Vec<RolloutItem> = Vec::new();
        let mut conversation_id: Option<ConversationId> = None;
        for line in text.lines() {
            if line.trim().is_empty() {
                continue;
            }
            let v: Value = match serde_json::from_str(line) {
                Ok(v) => v,
                Err(e) => {
                    warn!("failed to parse line as JSON: {line:?}, error: {e}");
                    continue;
                }
            };

            // Parse the rollout line structure
            match serde_json::from_value::<RolloutLine>(v.clone()) {
                Ok(rollout_line) => match rollout_line.item {
                    RolloutItem::SessionMeta(session_meta_line) => {
                        // Use the FIRST SessionMeta encountered in the file as the canonical
                        // conversation id and main session information. Keep all items intact.
                        if conversation_id.is_none() {
                            conversation_id = Some(session_meta_line.meta.id);
                        }
                        items.push(RolloutItem::SessionMeta(session_meta_line));
                    }
                    RolloutItem::ResponseItem(item) => {
                        items.push(RolloutItem::ResponseItem(item));
                    }
                    RolloutItem::Compacted(item) => {
                        items.push(RolloutItem::Compacted(item));
                    }
                    RolloutItem::TurnContext(item) => {
                        items.push(RolloutItem::TurnContext(item));
                    }
                    RolloutItem::EventMsg(ev) => {
                        items.push(RolloutItem::EventMsg(ev));
                    }
                },
                Err(e) => {
                    warn!("failed to parse rollout line: {v:?}, error: {e}");
                }
            }
        }

        info!(
            "Resumed rollout with {} items, conversation ID: {:?}",
            items.len(),
            conversation_id
        );
        let conversation_id = conversation_id
            .ok_or_else(|| IoError::other("failed to parse conversation ID from rollout file"))?;

        if items.is_empty() {
            return Ok(InitialHistory::New);
        }

        info!("Resumed rollout successfully from {path:?}");
        Ok(InitialHistory::Resumed(ResumedHistory {
            conversation_id,
            history: items,
            rollout_path: path.to_path_buf(),
        }))
    }

    pub async fn shutdown(&self) -> std::io::Result<()> {
        let (tx_done, rx_done) = oneshot::channel();
        match self.tx.send(RolloutCmd::Shutdown { ack: tx_done }).await {
            Ok(_) => rx_done
                .await
                .map_err(|e| IoError::other(format!("failed waiting for rollout shutdown: {e}"))),
            Err(e) => {
                warn!("failed to send rollout shutdown command: {e}");
                Err(IoError::other(format!(
                    "failed to send rollout shutdown command: {e}"
                )))
            }
        }
    }
}

struct LogFileInfo {
    /// Opened file handle to the rollout file.
    file: File,

    /// Full path to the rollout file.
    path: PathBuf,

    /// Session ID (also embedded in filename).
    conversation_id: ConversationId,

    /// Timestamp for the start of the session.
    timestamp: OffsetDateTime,
}

fn create_log_file(
    config: &Config,
    conversation_id: ConversationId,
) -> std::io::Result<LogFileInfo> {
    // Resolve ~/.codex/sessions/YYYY/MM/DD and create it if missing.
    let timestamp = OffsetDateTime::now_local()
        .map_err(|e| IoError::other(format!("failed to get local time: {e}")))?;
    let mut dir = config.codex_home.clone();
    dir.push(SESSIONS_SUBDIR);
    dir.push(timestamp.year().to_string());
    dir.push(format!("{:02}", u8::from(timestamp.month())));
    dir.push(format!("{:02}", timestamp.day()));
    fs::create_dir_all(&dir)?;

    // Custom format for YYYY-MM-DDThh-mm-ss. Use `-` instead of `:` for
    // compatibility with filesystems that do not allow colons in filenames.
    let format: &[FormatItem] =
        format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
    let date_str = timestamp
        .format(format)
        .map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;

    let filename = format!("rollout-{date_str}-{conversation_id}.jsonl");

    let path = dir.join(filename);
    let file = std::fs::OpenOptions::new()
        .append(true)
        .create(true)
        .open(&path)?;

    Ok(LogFileInfo {
        file,
        path,
        conversation_id,
        timestamp,
    })
}
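
// Example (illustrative, not part of the commit): for a session started on
// 2025-05-07 at 17:24:21 local time, `create_log_file` produces
//
//   <codex_home>/sessions/2025/05/07/rollout-2025-05-07T17-24-21-<conversation_id>.jsonl
//
// with the month and day directories zero-padded to two digits.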

async fn rollout_writer(
    file: tokio::fs::File,
    mut rx: mpsc::Receiver<RolloutCmd>,
    mut meta: Option<SessionMeta>,
    cwd: std::path::PathBuf,
) -> std::io::Result<()> {
    let mut writer = JsonlWriter { file };

    // If we have a meta, collect git info asynchronously and write meta first
    if let Some(session_meta) = meta.take() {
        let git_info = collect_git_info(&cwd).await;
        let session_meta_line = SessionMetaLine {
            meta: session_meta,
            git: git_info,
        };

        // Write the SessionMeta as the first item in the file, wrapped in a rollout line
        writer
            .write_rollout_item(RolloutItem::SessionMeta(session_meta_line))
            .await?;
    }

    // Process rollout commands
    while let Some(cmd) = rx.recv().await {
        match cmd {
            RolloutCmd::AddItems(items) => {
                for item in items {
                    if is_persisted_response_item(&item) {
                        writer.write_rollout_item(item).await?;
                    }
                }
            }
            RolloutCmd::Flush { ack } => {
                // Ensure underlying file is flushed and then ack.
                if let Err(e) = writer.file.flush().await {
                    let _ = ack.send(());
                    return Err(e);
                }
                let _ = ack.send(());
            }
            RolloutCmd::Shutdown { ack } => {
                let _ = ack.send(());
            }
        }
    }

    Ok(())
}

struct JsonlWriter {
    file: tokio::fs::File,
}

impl JsonlWriter {
    async fn write_rollout_item(&mut self, rollout_item: RolloutItem) -> std::io::Result<()> {
        let timestamp_format: &[FormatItem] = format_description!(
            "[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
        );
        let timestamp = OffsetDateTime::now_utc()
            .format(timestamp_format)
            .map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;

        let line = RolloutLine {
            timestamp,
            item: rollout_item,
        };
        self.write_line(&line).await
    }

    async fn write_line(&mut self, item: &impl serde::Serialize) -> std::io::Result<()> {
        let mut json = serde_json::to_string(item)?;
        json.push('\n');
        self.file.write_all(json.as_bytes()).await?;
        self.file.flush().await?;
        Ok(())
    }
}
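
// Example (illustrative, not part of the commit; field contents are made up):
// each line of the .jsonl file is one serialized `RolloutLine`, roughly
//
//   {"timestamp":"2025-05-07T17:24:21.123Z","item":{...}}
//
// The exact shape of the `item` payload depends on how `RolloutItem` derives
// its serde representation in codex_protocol.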
1131
llmx-rs/core/src/rollout/tests.rs
Normal file
File diff suppressed because it is too large
245
llmx-rs/core/src/safety.rs
Normal file
@@ -0,0 +1,245 @@
use std::path::Component;
use std::path::Path;
use std::path::PathBuf;

use codex_apply_patch::ApplyPatchAction;
use codex_apply_patch::ApplyPatchFileChange;

use crate::exec::SandboxType;

use crate::protocol::AskForApproval;
use crate::protocol::SandboxPolicy;

#[cfg(target_os = "windows")]
use std::sync::atomic::AtomicBool;
#[cfg(target_os = "windows")]
use std::sync::atomic::Ordering;

#[cfg(target_os = "windows")]
static WINDOWS_SANDBOX_ENABLED: AtomicBool = AtomicBool::new(false);

#[cfg(target_os = "windows")]
pub fn set_windows_sandbox_enabled(enabled: bool) {
    WINDOWS_SANDBOX_ENABLED.store(enabled, Ordering::Relaxed);
}

#[cfg(not(target_os = "windows"))]
#[allow(dead_code)]
pub fn set_windows_sandbox_enabled(_enabled: bool) {}

#[derive(Debug, PartialEq)]
pub enum SafetyCheck {
    AutoApprove {
        sandbox_type: SandboxType,
        user_explicitly_approved: bool,
    },
    AskUser,
    Reject {
        reason: String,
    },
}

pub fn assess_patch_safety(
    action: &ApplyPatchAction,
    policy: AskForApproval,
    sandbox_policy: &SandboxPolicy,
    cwd: &Path,
) -> SafetyCheck {
    if action.is_empty() {
        return SafetyCheck::Reject {
            reason: "empty patch".to_string(),
        };
    }

    match policy {
        AskForApproval::OnFailure | AskForApproval::Never | AskForApproval::OnRequest => {
            // Continue to see if this can be auto-approved.
        }
        // TODO(ragona): I'm not sure this is actually correct? I believe in this case
        // we want to continue to the writable paths check before asking the user.
        AskForApproval::UnlessTrusted => {
            return SafetyCheck::AskUser;
        }
    }

    // Even though the patch appears to be constrained to writable paths, it is
    // possible that paths in the patch are hard links to files outside the
    // writable roots, so we should still run `apply_patch` in a sandbox in that case.
    if is_write_patch_constrained_to_writable_paths(action, sandbox_policy, cwd)
        || policy == AskForApproval::OnFailure
    {
        if matches!(sandbox_policy, SandboxPolicy::DangerFullAccess) {
            // DangerFullAccess is intended to bypass sandboxing entirely.
            SafetyCheck::AutoApprove {
                sandbox_type: SandboxType::None,
                user_explicitly_approved: false,
            }
        } else {
            // Only auto-approve when we can actually enforce a sandbox. Otherwise
            // fall back to asking the user because the patch may touch arbitrary
            // paths outside the project.
            match get_platform_sandbox() {
                Some(sandbox_type) => SafetyCheck::AutoApprove {
                    sandbox_type,
                    user_explicitly_approved: false,
                },
                None => SafetyCheck::AskUser,
            }
        }
    } else if policy == AskForApproval::Never {
        SafetyCheck::Reject {
            reason: "writing outside of the project; rejected by user approval settings"
                .to_string(),
        }
    } else {
        SafetyCheck::AskUser
    }
}

pub fn get_platform_sandbox() -> Option<SandboxType> {
    if cfg!(target_os = "macos") {
        Some(SandboxType::MacosSeatbelt)
    } else if cfg!(target_os = "linux") {
        Some(SandboxType::LinuxSeccomp)
    } else if cfg!(target_os = "windows") {
        #[cfg(target_os = "windows")]
        {
            if WINDOWS_SANDBOX_ENABLED.load(Ordering::Relaxed) {
                return Some(SandboxType::WindowsRestrictedToken);
            }
        }
        None
    } else {
        None
    }
}

fn is_write_patch_constrained_to_writable_paths(
    action: &ApplyPatchAction,
    sandbox_policy: &SandboxPolicy,
    cwd: &Path,
) -> bool {
    // Early-exit if there are no declared writable roots.
    let writable_roots = match sandbox_policy {
        SandboxPolicy::ReadOnly => {
            return false;
        }
        SandboxPolicy::DangerFullAccess => {
            return true;
        }
        SandboxPolicy::WorkspaceWrite { .. } => sandbox_policy.get_writable_roots_with_cwd(cwd),
    };

    // Normalize a path by removing `.` and resolving `..` without touching the
    // filesystem (works even if the file does not exist).
    fn normalize(path: &Path) -> Option<PathBuf> {
        let mut out = PathBuf::new();
        for comp in path.components() {
            match comp {
                Component::ParentDir => {
                    out.pop();
                }
                Component::CurDir => { /* skip */ }
                other => out.push(other.as_os_str()),
            }
        }
        Some(out)
    }
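
    // Example (illustrative, not part of the commit): `normalize` is purely
    // lexical, so "a/b/../c/./d" becomes "a/c/d" and "../x" becomes "x".
    // Symlinks and hard links are not resolved, which is why the caller above
    // still prefers to run `apply_patch` inside a sandbox.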

    // Determine whether `path` is inside **any** writable root. Both `path`
    // and roots are converted to absolute, normalized forms before the
    // prefix check.
    let is_path_writable = |p: &PathBuf| {
        let abs = if p.is_absolute() {
            p.clone()
        } else {
            cwd.join(p)
        };
        let abs = match normalize(&abs) {
            Some(v) => v,
            None => return false,
        };

        writable_roots
            .iter()
            .any(|writable_root| writable_root.is_path_writable(&abs))
    };

    for (path, change) in action.changes() {
        match change {
            ApplyPatchFileChange::Add { .. } | ApplyPatchFileChange::Delete { .. } => {
                if !is_path_writable(path) {
                    return false;
                }
            }
            ApplyPatchFileChange::Update { move_path, .. } => {
                if !is_path_writable(path) {
                    return false;
                }
                if let Some(dest) = move_path
                    && !is_path_writable(dest)
                {
                    return false;
                }
            }
        }
    }

    true
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_writable_roots_constraint() {
        // Use a temporary directory as our workspace to avoid touching
        // the real current working directory.
        let tmp = TempDir::new().unwrap();
        let cwd = tmp.path().to_path_buf();
        let parent = cwd.parent().unwrap().to_path_buf();

        // Helper to build a single-entry patch that adds a file at `p`.
        let make_add_change = |p: PathBuf| ApplyPatchAction::new_add_for_test(&p, "".to_string());

        let add_inside = make_add_change(cwd.join("inner.txt"));
        let add_outside = make_add_change(parent.join("outside.txt"));

        // Policy limited to the workspace only; exclude system temp roots so
        // only `cwd` is writable by default.
        let policy_workspace_only = SandboxPolicy::WorkspaceWrite {
            writable_roots: vec![],
            network_access: false,
            exclude_tmpdir_env_var: true,
            exclude_slash_tmp: true,
        };

        assert!(is_write_patch_constrained_to_writable_paths(
            &add_inside,
            &policy_workspace_only,
            &cwd,
        ));

        assert!(!is_write_patch_constrained_to_writable_paths(
            &add_outside,
            &policy_workspace_only,
            &cwd,
        ));

        // With the parent dir explicitly added as a writable root, the
        // outside write should be permitted.
        let policy_with_parent = SandboxPolicy::WorkspaceWrite {
            writable_roots: vec![parent],
            network_access: false,
            exclude_tmpdir_env_var: true,
            exclude_slash_tmp: true,
        };
        assert!(is_write_patch_constrained_to_writable_paths(
            &add_outside,
            &policy_with_parent,
            &cwd,
        ));
    }
}
260
llmx-rs/core/src/sandboxing/assessment.rs
Normal file
@@ -0,0 +1,260 @@
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;

use crate::AuthManager;
use crate::ModelProviderInfo;
use crate::client::ModelClient;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::config::Config;
use crate::protocol::SandboxPolicy;
use askama::Template;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::ConversationId;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::SandboxCommandAssessment;
use codex_protocol::protocol::SessionSource;
use futures::StreamExt;
use serde_json::json;
use tokio::time::timeout;
use tracing::warn;

const SANDBOX_ASSESSMENT_TIMEOUT: Duration = Duration::from_secs(5);

#[derive(Template)]
#[template(path = "sandboxing/assessment_prompt.md", escape = "none")]
struct SandboxAssessmentPromptTemplate<'a> {
    platform: &'a str,
    sandbox_policy: &'a str,
    filesystem_roots: Option<&'a str>,
    working_directory: &'a str,
    command_argv: &'a str,
    command_joined: &'a str,
    sandbox_failure_message: Option<&'a str>,
}

#[allow(clippy::too_many_arguments)]
pub(crate) async fn assess_command(
    config: Arc<Config>,
    provider: ModelProviderInfo,
    auth_manager: Arc<AuthManager>,
    parent_otel: &OtelEventManager,
    conversation_id: ConversationId,
    session_source: SessionSource,
    call_id: &str,
    command: &[String],
    sandbox_policy: &SandboxPolicy,
    cwd: &Path,
    failure_message: Option<&str>,
) -> Option<SandboxCommandAssessment> {
    if !config.experimental_sandbox_command_assessment || command.is_empty() {
        return None;
    }

    let command_json = serde_json::to_string(command).unwrap_or_else(|_| "[]".to_string());
    let command_joined =
        shlex::try_join(command.iter().map(String::as_str)).unwrap_or_else(|_| command.join(" "));
    let failure = failure_message
        .map(str::trim)
        .filter(|msg| !msg.is_empty())
        .map(str::to_string);

    let cwd_str = cwd.to_string_lossy().to_string();
    let sandbox_summary = summarize_sandbox_policy(sandbox_policy);
    let mut roots = sandbox_roots_for_prompt(sandbox_policy, cwd);
    roots.sort();
    roots.dedup();

    let platform = std::env::consts::OS;
    let roots_formatted = roots.iter().map(|root| root.to_string_lossy().to_string());
    let filesystem_roots = match roots_formatted.collect::<Vec<_>>() {
        collected if collected.is_empty() => None,
        collected => Some(collected.join(", ")),
    };

    let prompt_template = SandboxAssessmentPromptTemplate {
        platform,
        sandbox_policy: sandbox_summary.as_str(),
        filesystem_roots: filesystem_roots.as_deref(),
        working_directory: cwd_str.as_str(),
        command_argv: command_json.as_str(),
        command_joined: command_joined.as_str(),
        sandbox_failure_message: failure.as_deref(),
    };
    let rendered_prompt = match prompt_template.render() {
        Ok(rendered) => rendered,
        Err(err) => {
            warn!("failed to render sandbox assessment prompt: {err}");
            return None;
        }
    };
    let (system_prompt_section, user_prompt_section) = match rendered_prompt.split_once("\n---\n") {
        Some(split) => split,
        None => {
            warn!("rendered sandbox assessment prompt missing separator");
            return None;
        }
    };
    let system_prompt = system_prompt_section
        .strip_prefix("System Prompt:\n")
        .unwrap_or(system_prompt_section)
        .trim()
        .to_string();
    let user_prompt = user_prompt_section
        .strip_prefix("User Prompt:\n")
        .unwrap_or(user_prompt_section)
        .trim()
        .to_string();
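    // Given the `split_once` and `strip_prefix` calls above, the rendered
    // template is expected to take the form (placeholders are illustrative):
    //
    //   System Prompt:
    //   <assessment instructions>
    //   ---
    //   User Prompt:
    //   <platform, policy, cwd, and command details>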

    let prompt = Prompt {
        input: vec![ResponseItem::Message {
            id: None,
            role: "user".to_string(),
            content: vec![ContentItem::InputText { text: user_prompt }],
        }],
        tools: Vec::new(),
        parallel_tool_calls: false,
        base_instructions_override: Some(system_prompt),
        output_schema: Some(sandbox_assessment_schema()),
    };

    let child_otel =
        parent_otel.with_model(config.model.as_str(), config.model_family.slug.as_str());

    let client = ModelClient::new(
        Arc::clone(&config),
        Some(auth_manager),
        child_otel,
        provider,
        config.model_reasoning_effort,
        config.model_reasoning_summary,
        conversation_id,
        session_source,
    );

    let start = Instant::now();
    let assessment_result = timeout(SANDBOX_ASSESSMENT_TIMEOUT, async move {
        let mut stream = client.stream(&prompt).await?;
        let mut last_json: Option<String> = None;
        while let Some(event) = stream.next().await {
            match event {
                Ok(ResponseEvent::OutputItemDone(item)) => {
                    if let Some(text) = response_item_text(&item) {
                        last_json = Some(text);
                    }
                }
                Ok(ResponseEvent::RateLimits(_)) => {}
                Ok(ResponseEvent::Completed { .. }) => break,
                Ok(_) => continue,
                Err(err) => return Err(err),
            }
        }
        Ok(last_json)
    })
    .await;
    let duration = start.elapsed();
    parent_otel.sandbox_assessment_latency(call_id, duration);

    match assessment_result {
        Ok(Ok(Some(raw))) => match serde_json::from_str::<SandboxCommandAssessment>(raw.trim()) {
            Ok(assessment) => {
                parent_otel.sandbox_assessment(
                    call_id,
                    "success",
                    Some(assessment.risk_level),
                    duration,
                );
                return Some(assessment);
            }
            Err(err) => {
                warn!("failed to parse sandbox assessment JSON: {err}");
                parent_otel.sandbox_assessment(call_id, "parse_error", None, duration);
            }
        },
        Ok(Ok(None)) => {
            warn!("sandbox assessment response did not include any message");
            parent_otel.sandbox_assessment(call_id, "no_output", None, duration);
        }
        Ok(Err(err)) => {
            warn!("sandbox assessment failed: {err}");
            parent_otel.sandbox_assessment(call_id, "model_error", None, duration);
        }
        Err(_) => {
            warn!("sandbox assessment timed out");
            parent_otel.sandbox_assessment(call_id, "timeout", None, duration);
        }
    }

    None
}

fn summarize_sandbox_policy(policy: &SandboxPolicy) -> String {
    match policy {
        SandboxPolicy::DangerFullAccess => "danger-full-access".to_string(),
        SandboxPolicy::ReadOnly => "read-only".to_string(),
        SandboxPolicy::WorkspaceWrite { network_access, .. } => {
            let network = if *network_access {
                "network"
            } else {
                "no-network"
            };
            format!("workspace-write (network_access={network})")
        }
    }
}

fn sandbox_roots_for_prompt(policy: &SandboxPolicy, cwd: &Path) -> Vec<PathBuf> {
    let mut roots = vec![cwd.to_path_buf()];
    if let SandboxPolicy::WorkspaceWrite { writable_roots, .. } = policy {
        roots.extend(writable_roots.iter().cloned());
    }
    roots
}

fn sandbox_assessment_schema() -> serde_json::Value {
    json!({
        "type": "object",
        "required": ["description", "risk_level"],
        "properties": {
            "description": {
                "type": "string",
                "minLength": 1,
                "maxLength": 500
            },
            "risk_level": {
                "type": "string",
                "enum": ["low", "medium", "high"]
            },
        },
        "additionalProperties": false
    })
}
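// A conforming model reply is a single JSON object matching the schema above,
// e.g. {"description": "Pipes a remote script into sh", "risk_level": "high"}.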

fn response_item_text(item: &ResponseItem) -> Option<String> {
    match item {
        ResponseItem::Message { content, .. } => {
            let mut buffers: Vec<&str> = Vec::new();
            for segment in content {
                match segment {
                    ContentItem::InputText { text } | ContentItem::OutputText { text } => {
                        if !text.is_empty() {
                            buffers.push(text);
                        }
                    }
                    ContentItem::InputImage { .. } => {}
                }
            }
            if buffers.is_empty() {
                None
            } else {
                Some(buffers.join("\n"))
            }
        }
        ResponseItem::FunctionCallOutput { output, .. } => Some(output.content.clone()),
        _ => None,
    }
}
178
llmx-rs/core/src/sandboxing/mod.rs
Normal file
@@ -0,0 +1,178 @@
/*
Module: sandboxing

Build platform wrappers and produce ExecEnv for execution. Owns low-level
sandbox placement and transformation of portable CommandSpec into a
ready-to-spawn environment.
*/

pub mod assessment;

use crate::exec::ExecToolCallOutput;
use crate::exec::SandboxType;
use crate::exec::StdoutStream;
use crate::exec::execute_exec_env;
use crate::landlock::create_linux_sandbox_command_args;
use crate::protocol::SandboxPolicy;
#[cfg(target_os = "macos")]
use crate::seatbelt::MACOS_PATH_TO_SEATBELT_EXECUTABLE;
#[cfg(target_os = "macos")]
use crate::seatbelt::create_seatbelt_command_args;
#[cfg(target_os = "macos")]
use crate::spawn::CODEX_SANDBOX_ENV_VAR;
use crate::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
use crate::tools::sandboxing::SandboxablePreference;
use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;

#[derive(Clone, Debug)]
pub struct CommandSpec {
    pub program: String,
    pub args: Vec<String>,
    pub cwd: PathBuf,
    pub env: HashMap<String, String>,
    pub timeout_ms: Option<u64>,
    pub with_escalated_permissions: Option<bool>,
    pub justification: Option<String>,
}

#[derive(Clone, Debug)]
pub struct ExecEnv {
    pub command: Vec<String>,
    pub cwd: PathBuf,
    pub env: HashMap<String, String>,
    pub timeout_ms: Option<u64>,
    pub sandbox: SandboxType,
    pub with_escalated_permissions: Option<bool>,
    pub justification: Option<String>,
    pub arg0: Option<String>,
}

pub enum SandboxPreference {
    Auto,
    Require,
    Forbid,
}

#[derive(Debug, thiserror::Error)]
pub(crate) enum SandboxTransformError {
    #[error("missing codex-linux-sandbox executable path")]
    MissingLinuxSandboxExecutable,
    #[cfg(not(target_os = "macos"))]
    #[error("seatbelt sandbox is only available on macOS")]
    SeatbeltUnavailable,
}

#[derive(Default)]
pub struct SandboxManager;

impl SandboxManager {
    pub fn new() -> Self {
        Self
    }

    pub(crate) fn select_initial(
        &self,
        policy: &SandboxPolicy,
        pref: SandboxablePreference,
    ) -> SandboxType {
        match pref {
            SandboxablePreference::Forbid => SandboxType::None,
            SandboxablePreference::Require => {
                // Require a platform sandbox when available; on Windows this
                // respects the enable_experimental_windows_sandbox feature.
                crate::safety::get_platform_sandbox().unwrap_or(SandboxType::None)
            }
            SandboxablePreference::Auto => match policy {
                SandboxPolicy::DangerFullAccess => SandboxType::None,
                _ => crate::safety::get_platform_sandbox().unwrap_or(SandboxType::None),
            },
        }
    }

    pub(crate) fn transform(
        &self,
        spec: &CommandSpec,
        policy: &SandboxPolicy,
        sandbox: SandboxType,
        sandbox_policy_cwd: &Path,
        codex_linux_sandbox_exe: Option<&PathBuf>,
    ) -> Result<ExecEnv, SandboxTransformError> {
        let mut env = spec.env.clone();
        if !policy.has_full_network_access() {
            env.insert(
                CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR.to_string(),
                "1".to_string(),
            );
        }

        let mut command = Vec::with_capacity(1 + spec.args.len());
        command.push(spec.program.clone());
        command.extend(spec.args.iter().cloned());

        let (command, sandbox_env, arg0_override) = match sandbox {
            SandboxType::None => (command, HashMap::new(), None),
            #[cfg(target_os = "macos")]
            SandboxType::MacosSeatbelt => {
                let mut seatbelt_env = HashMap::new();
                seatbelt_env.insert(CODEX_SANDBOX_ENV_VAR.to_string(), "seatbelt".to_string());
                let mut args =
                    create_seatbelt_command_args(command.clone(), policy, sandbox_policy_cwd);
                let mut full_command = Vec::with_capacity(1 + args.len());
                full_command.push(MACOS_PATH_TO_SEATBELT_EXECUTABLE.to_string());
                full_command.append(&mut args);
                (full_command, seatbelt_env, None)
            }
            #[cfg(not(target_os = "macos"))]
            SandboxType::MacosSeatbelt => return Err(SandboxTransformError::SeatbeltUnavailable),
            SandboxType::LinuxSeccomp => {
                let exe = codex_linux_sandbox_exe
                    .ok_or(SandboxTransformError::MissingLinuxSandboxExecutable)?;
                let mut args =
                    create_linux_sandbox_command_args(command.clone(), policy, sandbox_policy_cwd);
                let mut full_command = Vec::with_capacity(1 + args.len());
                full_command.push(exe.to_string_lossy().to_string());
                full_command.append(&mut args);
                (
                    full_command,
                    HashMap::new(),
                    Some("codex-linux-sandbox".to_string()),
                )
            }
            // On Windows, the restricted token sandbox executes in-process via the
            // codex-windows-sandbox crate. We leave the command unchanged here and
            // branch during execution based on the sandbox type.
            #[cfg(target_os = "windows")]
            SandboxType::WindowsRestrictedToken => (command, HashMap::new(), None),
            // When building for non-Windows targets, this variant is never constructed.
            #[cfg(not(target_os = "windows"))]
            SandboxType::WindowsRestrictedToken => (command, HashMap::new(), None),
        };

        env.extend(sandbox_env);

        Ok(ExecEnv {
            command,
            cwd: spec.cwd.clone(),
            env,
            timeout_ms: spec.timeout_ms,
            sandbox,
            with_escalated_permissions: spec.with_escalated_permissions,
            justification: spec.justification.clone(),
            arg0: arg0_override,
        })
    }

    pub fn denied(&self, sandbox: SandboxType, out: &ExecToolCallOutput) -> bool {
        crate::exec::is_likely_sandbox_denied(sandbox, out)
    }
}

pub async fn execute_env(
    env: &ExecEnv,
    policy: &SandboxPolicy,
    stdout_stream: Option<StdoutStream>,
) -> crate::error::Result<ExecToolCallOutput> {
    execute_exec_env(env.clone(), policy, stdout_stream).await
}
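A minimal sketch of how these pieces compose for one command (in-crate usage; the bindings `policy`, `spec`, `cwd`, and `linux_sandbox_exe` are assumed to be in scope, and error handling is elided):

// Choose a sandbox, transform the portable spec, then execute it.
let manager = SandboxManager::new();
let sandbox = manager.select_initial(&policy, SandboxablePreference::Auto);
let exec_env = manager.transform(&spec, &policy, sandbox, &cwd, linux_sandbox_exe.as_ref())?;
let output = execute_env(&exec_env, &policy, None).await?;
if manager.denied(sandbox, &output) {
    // Output looks like a sandbox denial; a caller might retry after approval.
}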
368
llmx-rs/core/src/seatbelt.rs
Normal file
@@ -0,0 +1,368 @@
#![cfg(target_os = "macos")]

use std::collections::HashMap;
use std::ffi::CStr;
use std::path::Path;
use std::path::PathBuf;
use tokio::process::Child;

use crate::protocol::SandboxPolicy;
use crate::spawn::CODEX_SANDBOX_ENV_VAR;
use crate::spawn::StdioPolicy;
use crate::spawn::spawn_child_async;

const MACOS_SEATBELT_BASE_POLICY: &str = include_str!("seatbelt_base_policy.sbpl");
const MACOS_SEATBELT_NETWORK_POLICY: &str = include_str!("seatbelt_network_policy.sbpl");

/// When working with `sandbox-exec`, only consider `sandbox-exec` in `/usr/bin`
/// to defend against an attacker trying to inject a malicious version on the
/// PATH. If /usr/bin/sandbox-exec has been tampered with, then the attacker
/// already has root access.
pub(crate) const MACOS_PATH_TO_SEATBELT_EXECUTABLE: &str = "/usr/bin/sandbox-exec";

pub async fn spawn_command_under_seatbelt(
    command: Vec<String>,
    command_cwd: PathBuf,
    sandbox_policy: &SandboxPolicy,
    sandbox_policy_cwd: &Path,
    stdio_policy: StdioPolicy,
    mut env: HashMap<String, String>,
) -> std::io::Result<Child> {
    let args = create_seatbelt_command_args(command, sandbox_policy, sandbox_policy_cwd);
    let arg0 = None;
    env.insert(CODEX_SANDBOX_ENV_VAR.to_string(), "seatbelt".to_string());
    spawn_child_async(
        PathBuf::from(MACOS_PATH_TO_SEATBELT_EXECUTABLE),
        args,
        arg0,
        command_cwd,
        sandbox_policy,
        stdio_policy,
        env,
    )
    .await
}

pub(crate) fn create_seatbelt_command_args(
    command: Vec<String>,
    sandbox_policy: &SandboxPolicy,
    sandbox_policy_cwd: &Path,
) -> Vec<String> {
    let (file_write_policy, file_write_dir_params) = {
        if sandbox_policy.has_full_disk_write_access() {
            // Allegedly, this is more permissive than `(allow file-write*)`.
            (
                r#"(allow file-write* (regex #"^/"))"#.to_string(),
                Vec::new(),
            )
        } else {
            let writable_roots = sandbox_policy.get_writable_roots_with_cwd(sandbox_policy_cwd);

            let mut writable_folder_policies: Vec<String> = Vec::new();
            let mut file_write_params = Vec::new();

            for (index, wr) in writable_roots.iter().enumerate() {
                // Canonicalize to avoid mismatches like /var vs /private/var on macOS.
                let canonical_root = wr.root.canonicalize().unwrap_or_else(|_| wr.root.clone());
                let root_param = format!("WRITABLE_ROOT_{index}");
                file_write_params.push((root_param.clone(), canonical_root));

                if wr.read_only_subpaths.is_empty() {
                    writable_folder_policies.push(format!("(subpath (param \"{root_param}\"))"));
                } else {
                    // Add parameters for each read-only subpath and generate
                    // the `(require-not ...)` clauses.
                    let mut require_parts: Vec<String> = Vec::new();
                    require_parts.push(format!("(subpath (param \"{root_param}\"))"));
                    for (subpath_index, ro) in wr.read_only_subpaths.iter().enumerate() {
                        let canonical_ro = ro.canonicalize().unwrap_or_else(|_| ro.clone());
                        let ro_param = format!("WRITABLE_ROOT_{index}_RO_{subpath_index}");
                        require_parts
                            .push(format!("(require-not (subpath (param \"{ro_param}\")))"));
                        file_write_params.push((ro_param, canonical_ro));
                    }
                    let policy_component = format!("(require-all {} )", require_parts.join(" "));
                    writable_folder_policies.push(policy_component);
                }
            }

            if writable_folder_policies.is_empty() {
                ("".to_string(), Vec::new())
            } else {
                let file_write_policy = format!(
                    "(allow file-write*\n{}\n)",
                    writable_folder_policies.join(" ")
                );
                (file_write_policy, file_write_params)
            }
        }
    };

    let file_read_policy = if sandbox_policy.has_full_disk_read_access() {
        "; allow read-only file operations\n(allow file-read*)"
    } else {
        ""
    };

    // TODO(mbolin): apply_patch calls must also honor the SandboxPolicy.
    let network_policy = if sandbox_policy.has_full_network_access() {
        MACOS_SEATBELT_NETWORK_POLICY
    } else {
        ""
    };

    let full_policy = format!(
        "{MACOS_SEATBELT_BASE_POLICY}\n{file_read_policy}\n{file_write_policy}\n{network_policy}"
    );

    let dir_params = [file_write_dir_params, macos_dir_params()].concat();

    let mut seatbelt_args: Vec<String> = vec!["-p".to_string(), full_policy];
    let definition_args = dir_params
        .into_iter()
        .map(|(key, value)| format!("-D{key}={value}", value = value.to_string_lossy()));
    seatbelt_args.extend(definition_args);
    seatbelt_args.push("--".to_string());
    seatbelt_args.extend(command);
    seatbelt_args
}

/// Wraps libc::confstr to return a String.
fn confstr(name: libc::c_int) -> Option<String> {
    let mut buf = vec![0_i8; (libc::PATH_MAX as usize) + 1];
    let len = unsafe { libc::confstr(name, buf.as_mut_ptr(), buf.len()) };
    if len == 0 {
        return None;
    }
    // confstr guarantees NUL-termination when len > 0.
    let cstr = unsafe { CStr::from_ptr(buf.as_ptr()) };
    cstr.to_str().ok().map(ToString::to_string)
}

/// Wraps confstr to return a canonicalized PathBuf.
fn confstr_path(name: libc::c_int) -> Option<PathBuf> {
    let s = confstr(name)?;
    let path = PathBuf::from(s);
    path.canonicalize().ok().or(Some(path))
}

fn macos_dir_params() -> Vec<(String, PathBuf)> {
    if let Some(p) = confstr_path(libc::_CS_DARWIN_USER_CACHE_DIR) {
        return vec![("DARWIN_USER_CACHE_DIR".to_string(), p)];
    }
    vec![]
}

#[cfg(test)]
mod tests {
    use super::MACOS_SEATBELT_BASE_POLICY;
    use super::create_seatbelt_command_args;
    use super::macos_dir_params;
    use crate::protocol::SandboxPolicy;
    use pretty_assertions::assert_eq;
    use std::fs;
    use std::path::Path;
    use std::path::PathBuf;
    use tempfile::TempDir;

    #[test]
    fn create_seatbelt_args_with_read_only_git_subpath() {
        // Create a temporary workspace with two writable roots: one containing
        // a top-level .git directory and one without it.
        let tmp = TempDir::new().expect("tempdir");
        let PopulatedTmp {
            root_with_git,
            root_without_git,
            root_with_git_canon,
            root_with_git_git_canon,
            root_without_git_canon,
        } = populate_tmpdir(tmp.path());
        let cwd = tmp.path().join("cwd");

        // Build a policy that only includes the two test roots as writable and
        // does not automatically include the default TMPDIR or /tmp roots.
        let policy = SandboxPolicy::WorkspaceWrite {
            writable_roots: vec![root_with_git, root_without_git],
            network_access: false,
            exclude_tmpdir_env_var: true,
            exclude_slash_tmp: true,
        };

        let args = create_seatbelt_command_args(
            vec!["/bin/echo".to_string(), "hello".to_string()],
            &policy,
            &cwd,
        );

        // Build the expected policy text using a raw string for readability.
        // Note that the policy includes:
        // - the base policy,
        // - read-only access to the filesystem,
        // - write access to WRITABLE_ROOT_0 (but not its .git), WRITABLE_ROOT_1,
        //   and the cwd as WRITABLE_ROOT_2.
        let expected_policy = format!(
            r#"{MACOS_SEATBELT_BASE_POLICY}
; allow read-only file operations
(allow file-read*)
(allow file-write*
(require-all (subpath (param "WRITABLE_ROOT_0")) (require-not (subpath (param "WRITABLE_ROOT_0_RO_0"))) ) (subpath (param "WRITABLE_ROOT_1")) (subpath (param "WRITABLE_ROOT_2"))
)
"#,
        );

        let mut expected_args = vec![
            "-p".to_string(),
            expected_policy,
            format!(
                "-DWRITABLE_ROOT_0={}",
                root_with_git_canon.to_string_lossy()
            ),
            format!(
                "-DWRITABLE_ROOT_0_RO_0={}",
                root_with_git_git_canon.to_string_lossy()
            ),
            format!(
                "-DWRITABLE_ROOT_1={}",
                root_without_git_canon.to_string_lossy()
            ),
            format!("-DWRITABLE_ROOT_2={}", cwd.to_string_lossy()),
        ];

        expected_args.extend(
            macos_dir_params()
                .into_iter()
                .map(|(key, value)| format!("-D{key}={value}", value = value.to_string_lossy())),
        );

        expected_args.extend(vec![
            "--".to_string(),
            "/bin/echo".to_string(),
            "hello".to_string(),
        ]);

        assert_eq!(expected_args, args);
    }

    #[test]
    fn create_seatbelt_args_for_cwd_as_git_repo() {
        // Create a temporary workspace with two writable roots: one containing
        // a top-level .git directory and one without it.
        let tmp = TempDir::new().expect("tempdir");
        let PopulatedTmp {
            root_with_git,
            root_with_git_canon,
            root_with_git_git_canon,
            ..
        } = populate_tmpdir(tmp.path());

        // Build a policy that does not specify any writable_roots, but does
        // use the default ones (cwd, /tmp, and TMPDIR), and verify the `.git`
        // check is done properly for cwd.
        let policy = SandboxPolicy::WorkspaceWrite {
            writable_roots: vec![],
            network_access: false,
            exclude_tmpdir_env_var: false,
            exclude_slash_tmp: false,
        };

        let args = create_seatbelt_command_args(
            vec!["/bin/echo".to_string(), "hello".to_string()],
            &policy,
            root_with_git.as_path(),
        );

        let tmpdir_env_var = std::env::var("TMPDIR")
            .ok()
            .map(PathBuf::from)
            .and_then(|p| p.canonicalize().ok())
            .map(|p| p.to_string_lossy().to_string());

        let tempdir_policy_entry = if tmpdir_env_var.is_some() {
            r#" (subpath (param "WRITABLE_ROOT_2"))"#
        } else {
            ""
        };

        // Build the expected policy text using a raw string for readability.
        // Note that the policy includes:
        // - the base policy,
        // - read-only access to the filesystem,
        // - write access to WRITABLE_ROOT_0 (the cwd, but not its .git),
        //   WRITABLE_ROOT_1 (/tmp), and, when TMPDIR is set, WRITABLE_ROOT_2.
        let expected_policy = format!(
            r#"{MACOS_SEATBELT_BASE_POLICY}
; allow read-only file operations
(allow file-read*)
(allow file-write*
(require-all (subpath (param "WRITABLE_ROOT_0")) (require-not (subpath (param "WRITABLE_ROOT_0_RO_0"))) ) (subpath (param "WRITABLE_ROOT_1")){tempdir_policy_entry}
)
"#,
        );

        let mut expected_args = vec![
            "-p".to_string(),
            expected_policy,
            format!(
                "-DWRITABLE_ROOT_0={}",
                root_with_git_canon.to_string_lossy()
            ),
            format!(
                "-DWRITABLE_ROOT_0_RO_0={}",
                root_with_git_git_canon.to_string_lossy()
            ),
            format!(
                "-DWRITABLE_ROOT_1={}",
                PathBuf::from("/tmp")
                    .canonicalize()
                    .expect("canonicalize /tmp")
                    .to_string_lossy()
            ),
        ];

        if let Some(p) = tmpdir_env_var {
            expected_args.push(format!("-DWRITABLE_ROOT_2={p}"));
        }

        expected_args.extend(
            macos_dir_params()
                .into_iter()
                .map(|(key, value)| format!("-D{key}={value}", value = value.to_string_lossy())),
        );

        expected_args.extend(vec![
            "--".to_string(),
            "/bin/echo".to_string(),
            "hello".to_string(),
        ]);

        assert_eq!(expected_args, args);
    }

    struct PopulatedTmp {
        root_with_git: PathBuf,
        root_without_git: PathBuf,
        root_with_git_canon: PathBuf,
        root_with_git_git_canon: PathBuf,
        root_without_git_canon: PathBuf,
    }

    fn populate_tmpdir(tmp: &Path) -> PopulatedTmp {
        let root_with_git = tmp.join("with_git");
        let root_without_git = tmp.join("no_git");
        fs::create_dir_all(&root_with_git).expect("create with_git");
        fs::create_dir_all(&root_without_git).expect("create no_git");
        fs::create_dir_all(root_with_git.join(".git")).expect("create .git");

        // Ensure we have canonical paths for -D parameter matching.
        let root_with_git_canon = root_with_git.canonicalize().expect("canonicalize with_git");
        let root_with_git_git_canon = root_with_git_canon.join(".git");
        let root_without_git_canon = root_without_git
            .canonicalize()
            .expect("canonicalize no_git");
        PopulatedTmp {
            root_with_git,
            root_without_git,
            root_with_git_canon,
            root_with_git_git_canon,
            root_without_git_canon,
        }
    }
}
95
llmx-rs/core/src/seatbelt_base_policy.sbpl
Normal file
@@ -0,0 +1,95 @@
(version 1)

; inspired by Chrome's sandbox policy:
; https://source.chromium.org/chromium/chromium/src/+/main:sandbox/policy/mac/common.sb;l=273-319;drc=7b3962fe2e5fc9e2ee58000dc8fbf3429d84d3bd
; https://source.chromium.org/chromium/chromium/src/+/main:sandbox/policy/mac/renderer.sb;l=64;drc=7b3962fe2e5fc9e2ee58000dc8fbf3429d84d3bd

; start with closed-by-default
(deny default)

; child processes inherit the policy of their parent
(allow process-exec)
(allow process-fork)
(allow signal (target same-sandbox))

; Allow cf prefs to work.
(allow user-preference-read)

; process-info
(allow process-info* (target same-sandbox))

(allow file-write-data
  (require-all
    (path "/dev/null")
    (vnode-type CHARACTER-DEVICE)))

; sysctls permitted.
(allow sysctl-read
  (sysctl-name "hw.activecpu")
  (sysctl-name "hw.busfrequency_compat")
  (sysctl-name "hw.byteorder")
  (sysctl-name "hw.cacheconfig")
  (sysctl-name "hw.cachelinesize_compat")
  (sysctl-name "hw.cpufamily")
  (sysctl-name "hw.cpufrequency_compat")
  (sysctl-name "hw.cputype")
  (sysctl-name "hw.l1dcachesize_compat")
  (sysctl-name "hw.l1icachesize_compat")
  (sysctl-name "hw.l2cachesize_compat")
  (sysctl-name "hw.l3cachesize_compat")
  (sysctl-name "hw.logicalcpu_max")
  (sysctl-name "hw.machine")
  (sysctl-name "hw.memsize")
  (sysctl-name "hw.ncpu")
  (sysctl-name "hw.nperflevels")
  ; Chrome locks CPU feature detection down a bit more tightly, but mostly
  ; for fingerprinting concerns, which aren't an issue for codex.
  (sysctl-name-prefix "hw.optional.arm.")
  (sysctl-name-prefix "hw.optional.armv8_")
  (sysctl-name "hw.packages")
  (sysctl-name "hw.pagesize_compat")
  (sysctl-name "hw.pagesize")
  (sysctl-name "hw.physicalcpu")
  (sysctl-name "hw.physicalcpu_max")
  (sysctl-name "hw.tbfrequency_compat")
  (sysctl-name "hw.vectorunit")
  (sysctl-name "kern.hostname")
  (sysctl-name "kern.maxfilesperproc")
  (sysctl-name "kern.maxproc")
  (sysctl-name "kern.osproductversion")
  (sysctl-name "kern.osrelease")
  (sysctl-name "kern.ostype")
  (sysctl-name "kern.osvariant_status")
  (sysctl-name "kern.osversion")
  (sysctl-name "kern.secure_kernel")
  (sysctl-name "kern.usrstack64")
  (sysctl-name "kern.version")
  (sysctl-name "sysctl.proc_cputype")
  (sysctl-name "vm.loadavg")
  (sysctl-name-prefix "hw.perflevel")
  (sysctl-name-prefix "kern.proc.pgrp.")
  (sysctl-name-prefix "kern.proc.pid.")
  (sysctl-name-prefix "net.routetable.")
)

; Allow Java to set CPU type grade when required
(allow sysctl-write
  (sysctl-name "kern.grade_cputype"))

; IOKit
(allow iokit-open
  (iokit-registry-entry-class "RootDomainUserClient")
)

; needed to look up user info, see https://crbug.com/792228
(allow mach-lookup
  (global-name "com.apple.system.opendirectoryd.libinfo")
)

; Added on top of Chrome profile
; Needed for python multiprocessing on MacOS for the SemLock
(allow ipc-posix-sem)

(allow mach-lookup
  (global-name "com.apple.PowerManagement.control")
)
30
llmx-rs/core/src/seatbelt_network_policy.sbpl
Normal file
@@ -0,0 +1,30 @@
; when network access is enabled, these policies are added after those in seatbelt_base_policy.sbpl
; Ref https://source.chromium.org/chromium/chromium/src/+/main:sandbox/policy/mac/network.sb;drc=f8f264d5e4e7509c913f4c60c2639d15905a07e4

(allow network-outbound)
(allow network-inbound)
(allow system-socket)

(allow mach-lookup
  ; Used to look up the _CS_DARWIN_USER_CACHE_DIR in the sandbox.
  (global-name "com.apple.bsd.dirhelper")
  (global-name "com.apple.system.opendirectoryd.membership")

  ; Communicate with the security server for TLS certificate information.
  (global-name "com.apple.SecurityServer")
  (global-name "com.apple.networkd")
  (global-name "com.apple.ocspd")
  (global-name "com.apple.trustd.agent")

  ; Read network configuration.
  (global-name "com.apple.SystemConfiguration.DNSConfiguration")
  (global-name "com.apple.SystemConfiguration.configd")
)

(allow sysctl-read
  (sysctl-name-regex #"^net.routetable")
)

(allow file-write*
  (subpath (param "DARWIN_USER_CACHE_DIR"))
)
434
llmx-rs/core/src/shell.rs
Normal file
@@ -0,0 +1,434 @@
use serde::Deserialize;
use serde::Serialize;
use std::path::PathBuf;

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct ZshShell {
    pub(crate) shell_path: String,
    pub(crate) zshrc_path: String,
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct BashShell {
    pub(crate) shell_path: String,
    pub(crate) bashrc_path: String,
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct PowerShellConfig {
    pub(crate) exe: String, // Executable name or path, e.g. "pwsh" or "powershell.exe".
    pub(crate) bash_exe_fallback: Option<PathBuf>, // In case the model generates a bash command.
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub enum Shell {
    Zsh(ZshShell),
    Bash(BashShell),
    PowerShell(PowerShellConfig),
    Unknown,
}

impl Shell {
    pub fn name(&self) -> Option<String> {
        match self {
            Shell::Zsh(zsh) => std::path::Path::new(&zsh.shell_path)
                .file_name()
                .map(|s| s.to_string_lossy().to_string()),
            Shell::Bash(bash) => std::path::Path::new(&bash.shell_path)
                .file_name()
                .map(|s| s.to_string_lossy().to_string()),
            Shell::PowerShell(ps) => Some(ps.exe.clone()),
            Shell::Unknown => None,
        }
    }
}
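// For example, on a typical macOS account whose login shell is `/bin/zsh`,
// `default_user_shell().await` yields `Shell::Zsh(ZshShell { shell_path: "/bin/zsh".into(),
// zshrc_path: "<home>/.zshrc".into() })`, so `name()` returns `Some("zsh".to_string())`.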

#[cfg(unix)]
fn detect_default_user_shell() -> Shell {
    use libc::getpwuid;
    use libc::getuid;
    use std::ffi::CStr;

    unsafe {
        let uid = getuid();
        let pw = getpwuid(uid);

        if !pw.is_null() {
            let shell_path = CStr::from_ptr((*pw).pw_shell)
                .to_string_lossy()
                .into_owned();
            let home_path = CStr::from_ptr((*pw).pw_dir).to_string_lossy().into_owned();

            if shell_path.ends_with("/zsh") {
                return Shell::Zsh(ZshShell {
                    shell_path,
                    zshrc_path: format!("{home_path}/.zshrc"),
                });
            }

            if shell_path.ends_with("/bash") {
                return Shell::Bash(BashShell {
                    shell_path,
                    bashrc_path: format!("{home_path}/.bashrc"),
                });
            }
        }
    }
    Shell::Unknown
}

#[cfg(unix)]
pub async fn default_user_shell() -> Shell {
    detect_default_user_shell()
}

#[cfg(target_os = "windows")]
pub async fn default_user_shell() -> Shell {
    use tokio::process::Command;

    // Prefer PowerShell 7+ (`pwsh`) if available, otherwise fall back to Windows PowerShell.
    let has_pwsh = Command::new("pwsh")
        .arg("-NoLogo")
        .arg("-NoProfile")
        .arg("-Command")
        .arg("$PSVersionTable.PSVersion.Major")
        .output()
        .await
        .map(|o| o.status.success())
        .unwrap_or(false);
    let bash_exe = if Command::new("bash.exe")
        .arg("--version")
        .stdin(std::process::Stdio::null())
        .output()
        .await
        .ok()
        .map(|o| o.status.success())
        .unwrap_or(false)
    {
        which::which("bash.exe").ok()
    } else {
        None
    };

    if has_pwsh {
        Shell::PowerShell(PowerShellConfig {
            exe: "pwsh.exe".to_string(),
            bash_exe_fallback: bash_exe,
        })
    } else {
        Shell::PowerShell(PowerShellConfig {
            exe: "powershell.exe".to_string(),
            bash_exe_fallback: bash_exe,
        })
    }
}

#[cfg(all(not(target_os = "windows"), not(unix)))]
pub async fn default_user_shell() -> Shell {
    Shell::Unknown
}

#[cfg(test)]
#[cfg(unix)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::process::Command;

    #[tokio::test]
    async fn test_current_shell_detects_zsh() {
        let shell = Command::new("sh")
            .arg("-c")
            .arg("echo $SHELL")
            .output()
            .unwrap();

        let home = std::env::var("HOME").unwrap();
        let shell_path = String::from_utf8_lossy(&shell.stdout).trim().to_string();
        if shell_path.ends_with("/zsh") {
            assert_eq!(
                default_user_shell().await,
                Shell::Zsh(ZshShell {
                    shell_path: shell_path.to_string(),
                    zshrc_path: format!("{home}/.zshrc"),
                })
            );
        }
    }

    #[tokio::test]
    async fn test_run_with_profile_bash_escaping_and_execution() {
        let shell_path = "/bin/bash";

        let cases = vec![
            (
                vec!["myecho"],
                vec![shell_path, "-lc", "source BASHRC_PATH && (myecho)"],
                Some("It works!\n"),
            ),
            (
                vec!["bash", "-lc", "echo 'single' \"double\""],
                vec![
                    shell_path,
                    "-lc",
                    "source BASHRC_PATH && (echo 'single' \"double\")",
                ],
                Some("single double\n"),
            ),
        ];

        for (input, expected_cmd, expected_output) in cases {
            use std::collections::HashMap;

            use crate::exec::ExecParams;
            use crate::exec::SandboxType;
            use crate::exec::process_exec_tool_call;
            use crate::protocol::SandboxPolicy;

            let temp_home = tempfile::tempdir().unwrap();
            let bashrc_path = temp_home.path().join(".bashrc");
            std::fs::write(
                &bashrc_path,
                r#"
set -x
function myecho {
echo 'It works!'
}
"#,
            )
            .unwrap();
            let command = expected_cmd
                .iter()
                .map(|s| s.replace("BASHRC_PATH", bashrc_path.to_str().unwrap()))
                .collect::<Vec<_>>();

            let output = process_exec_tool_call(
                ExecParams {
                    command: command.clone(),
                    cwd: PathBuf::from(temp_home.path()),
                    timeout_ms: None,
                    env: HashMap::from([(
                        "HOME".to_string(),
                        temp_home.path().to_str().unwrap().to_string(),
                    )]),
                    with_escalated_permissions: None,
                    justification: None,
                    arg0: None,
                },
                SandboxType::None,
                &SandboxPolicy::DangerFullAccess,
                temp_home.path(),
                &None,
                None,
            )
            .await
            .unwrap();

            assert_eq!(output.exit_code, 0, "input: {input:?} output: {output:?}");
            if let Some(expected) = expected_output {
                assert_eq!(
                    output.stdout.text, expected,
                    "input: {input:?} output: {output:?}"
                );
            }
        }
    }
}

#[cfg(test)]
#[cfg(target_os = "macos")]
mod macos_tests {
    use std::path::PathBuf;

    #[tokio::test]
    async fn test_run_with_profile_escaping_and_execution() {
        let shell_path = "/bin/zsh";

        let cases = vec![
            (
                vec!["myecho"],
                vec![shell_path, "-lc", "source ZSHRC_PATH && (myecho)"],
                Some("It works!\n"),
            ),
            (
                vec!["myecho"],
                vec![shell_path, "-lc", "source ZSHRC_PATH && (myecho)"],
                Some("It works!\n"),
            ),
            (
                vec!["bash", "-c", "echo 'single' \"double\""],
                vec![
                    shell_path,
                    "-lc",
                    "source ZSHRC_PATH && (bash -c \"echo 'single' \\\"double\\\"\")",
                ],
                Some("single double\n"),
            ),
            (
                vec!["bash", "-lc", "echo 'single' \"double\""],
                vec![
                    shell_path,
                    "-lc",
                    "source ZSHRC_PATH && (echo 'single' \"double\")",
                ],
                Some("single double\n"),
            ),
        ];
        for (input, expected_cmd, expected_output) in cases {
            use std::collections::HashMap;

            use crate::exec::ExecParams;
            use crate::exec::SandboxType;
            use crate::exec::process_exec_tool_call;
            use crate::protocol::SandboxPolicy;

            let temp_home = tempfile::tempdir().unwrap();
            let zshrc_path = temp_home.path().join(".zshrc");
            std::fs::write(
                &zshrc_path,
                r#"
set -x
function myecho {
echo 'It works!'
}
"#,
            )
            .unwrap();
            let command = expected_cmd
                .iter()
                .map(|s| s.replace("ZSHRC_PATH", zshrc_path.to_str().unwrap()))
                .collect::<Vec<_>>();

            let output = process_exec_tool_call(
                ExecParams {
                    command: command.clone(),
                    cwd: PathBuf::from(temp_home.path()),
                    timeout_ms: None,
                    env: HashMap::from([(
                        "HOME".to_string(),
                        temp_home.path().to_str().unwrap().to_string(),
                    )]),
                    with_escalated_permissions: None,
                    justification: None,
                    arg0: None,
                },
                SandboxType::None,
                &SandboxPolicy::DangerFullAccess,
                temp_home.path(),
                &None,
                None,
            )
            .await
            .unwrap();

            assert_eq!(output.exit_code, 0, "input: {input:?} output: {output:?}");
            if let Some(expected) = expected_output {
                assert_eq!(
                    output.stdout.text, expected,
                    "input: {input:?} output: {output:?}"
                );
            }
        }
    }
}

#[cfg(test)]
#[cfg(target_os = "windows")]
mod tests_windows {
    use super::*;

    #[test]
    fn test_format_default_shell_invocation_powershell() {
        use std::path::PathBuf;

        let cases = vec![
            (
                PowerShellConfig {
                    exe: "pwsh.exe".to_string(),
                    bash_exe_fallback: None,
                },
                vec!["bash", "-lc", "echo hello"],
                vec!["pwsh.exe", "-NoProfile", "-Command", "echo hello"],
            ),
            (
                PowerShellConfig {
                    exe: "powershell.exe".to_string(),
                    bash_exe_fallback: None,
                },
                vec!["bash", "-lc", "echo hello"],
                vec!["powershell.exe", "-NoProfile", "-Command", "echo hello"],
            ),
            (
                PowerShellConfig {
                    exe: "pwsh.exe".to_string(),
                    bash_exe_fallback: Some(PathBuf::from("bash.exe")),
                },
                vec!["bash", "-lc", "echo hello"],
                vec!["bash.exe", "-lc", "echo hello"],
            ),
            (
                PowerShellConfig {
                    exe: "pwsh.exe".to_string(),
                    bash_exe_fallback: Some(PathBuf::from("bash.exe")),
                },
                vec![
                    "bash",
                    "-lc",
                    "apply_patch <<'EOF'\n*** Begin Patch\n*** Update File: destination_file.txt\n-original content\n+modified content\n*** End Patch\nEOF",
                ],
                vec![
                    "bash.exe",
                    "-lc",
                    "apply_patch <<'EOF'\n*** Begin Patch\n*** Update File: destination_file.txt\n-original content\n+modified content\n*** End Patch\nEOF",
                ],
            ),
            (
                PowerShellConfig {
                    exe: "pwsh.exe".to_string(),
                    bash_exe_fallback: Some(PathBuf::from("bash.exe")),
                },
                vec!["echo", "hello"],
                vec!["pwsh.exe", "-NoProfile", "-Command", "echo hello"],
            ),
            (
                PowerShellConfig {
                    exe: "pwsh.exe".to_string(),
                    bash_exe_fallback: Some(PathBuf::from("bash.exe")),
                },
                vec!["pwsh.exe", "-NoProfile", "-Command", "echo hello"],
                vec!["pwsh.exe", "-NoProfile", "-Command", "echo hello"],
            ),
            (
                PowerShellConfig {
                    exe: "powershell.exe".to_string(),
                    bash_exe_fallback: Some(PathBuf::from("bash.exe")),
                },
                vec![
                    "codex-mcp-server.exe",
                    "--codex-run-as-apply-patch",
                    "*** Begin Patch\n*** Update File: C:\\Users\\person\\destination_file.txt\n-original content\n+modified content\n*** End Patch",
                ],
                vec![
                    "codex-mcp-server.exe",
                    "--codex-run-as-apply-patch",
                    "*** Begin Patch\n*** Update File: C:\\Users\\person\\destination_file.txt\n-original content\n+modified content\n*** End Patch",
                ],
            ),
        ];

        for (config, input, expected_cmd) in cases {
            let command = expected_cmd
                .iter()
                .map(|s| (*s).to_string())
                .collect::<Vec<_>>();

            // These tests assert the final command for each scenario now that the helper
            // has been removed. The inputs remain to document the original coverage.
            let expected = expected_cmd
                .iter()
                .map(|s| (*s).to_string())
                .collect::<Vec<_>>();
            assert_eq!(command, expected, "input: {input:?} config: {config:?}");
        }
    }
}
117
llmx-rs/core/src/spawn.rs
Normal file
@@ -0,0 +1,117 @@
use std::collections::HashMap;
use std::path::PathBuf;
use std::process::Stdio;
use tokio::process::Child;
use tokio::process::Command;
use tracing::trace;

use crate::protocol::SandboxPolicy;

/// Experimental environment variable that will be set to some non-empty value
/// if both of the following are true:
///
/// 1. The process was spawned by Codex as part of a shell tool call.
/// 2. SandboxPolicy.has_full_network_access() was false for the tool call.
///
/// We may try to have just one environment variable for all sandboxing
/// attributes, so this may change in the future.
pub const CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR: &str = "CODEX_SANDBOX_NETWORK_DISABLED";

/// Should be set when the process is spawned under a sandbox. Currently, the
/// value is "seatbelt" for macOS, but it may change in the future to
/// accommodate sandboxing configuration and other sandboxing mechanisms.
pub const CODEX_SANDBOX_ENV_VAR: &str = "CODEX_SANDBOX";
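// Example: a spawned child (or a test helper) might branch on these variables, e.g.
//   if std::env::var("CODEX_SANDBOX_NETWORK_DISABLED").is_ok() {
//       // running inside a network-disabled sandbox; skip network-dependent work
//   }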

#[derive(Debug, Clone, Copy)]
pub enum StdioPolicy {
    RedirectForShellTool,
    Inherit,
}

/// Spawns the appropriate child process for the ExecParams and SandboxPolicy,
/// ensuring the args and environment variables used to create the `Command`
/// (and `Child`) honor the configuration.
///
/// For now, we take `SandboxPolicy` as a parameter to spawn_child() because
/// we need to determine whether to set the
/// `CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR` environment variable.
pub(crate) async fn spawn_child_async(
    program: PathBuf,
    args: Vec<String>,
    #[cfg_attr(not(unix), allow(unused_variables))] arg0: Option<&str>,
    cwd: PathBuf,
    sandbox_policy: &SandboxPolicy,
    stdio_policy: StdioPolicy,
    env: HashMap<String, String>,
) -> std::io::Result<Child> {
    trace!(
        "spawn_child_async: {program:?} {args:?} {arg0:?} {cwd:?} {sandbox_policy:?} {stdio_policy:?} {env:?}"
    );

    let mut cmd = Command::new(&program);
    #[cfg(unix)]
    cmd.arg0(arg0.map_or_else(|| program.to_string_lossy().to_string(), String::from));
    cmd.args(args);
    cmd.current_dir(cwd);
    cmd.env_clear();
    cmd.envs(env);

    if !sandbox_policy.has_full_network_access() {
        cmd.env(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR, "1");
    }

    // If this Codex process dies (including being killed via SIGKILL), we want
    // any child processes that were spawned as part of a `"shell"` tool call
    // to also be terminated.

    #[cfg(unix)]
    unsafe {
        #[cfg(target_os = "linux")]
        let parent_pid = libc::getpid();
        cmd.pre_exec(move || {
            if libc::setpgid(0, 0) == -1 {
                return Err(std::io::Error::last_os_error());
            }

            // This relies on prctl(2), so it only works on Linux.
            #[cfg(target_os = "linux")]
            {
                // This prctl call effectively requests, "deliver SIGTERM when my
                // current parent dies."
                if libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM) == -1 {
                    return Err(std::io::Error::last_os_error());
                }

                // If there is a race and this pre_exec() block runs _after_ the
                // parent (i.e., the Codex process) has already exited, the new
                // parent will be the closest configured "subreaper" ancestor
                // process, or PID 1 (init). If the Codex process has already
                // exited, so should the child process.
                if libc::getppid() != parent_pid {
                    libc::raise(libc::SIGTERM);
                }
            }
            Ok(())
        });
    }

    match stdio_policy {
        StdioPolicy::RedirectForShellTool => {
            // Do not create a file descriptor for stdin because otherwise some
            // commands may hang forever waiting for input. For example, ripgrep has
            // a heuristic where it may try to read from stdin as explained here:
            // https://github.com/BurntSushi/ripgrep/blob/e2362d4d5185d02fa857bf381e7bd52e66fafc73/crates/core/flags/hiargs.rs#L1101-L1103
            cmd.stdin(Stdio::null());

            cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
        }
        StdioPolicy::Inherit => {
            // Inherit stdin, stdout, and stderr from the parent process.
            cmd.stdin(Stdio::inherit())
                .stdout(Stdio::inherit())
                .stderr(Stdio::inherit());
        }
    }

    cmd.kill_on_drop(true).spawn()
}
9
llmx-rs/core/src/state/mod.rs
Normal file
@@ -0,0 +1,9 @@
mod service;
mod session;
mod turn;

pub(crate) use service::SessionServices;
pub(crate) use session::SessionState;
pub(crate) use turn::ActiveTurn;
pub(crate) use turn::RunningTask;
pub(crate) use turn::TaskKind;
22
llmx-rs/core/src/state/service.rs
Normal file
@@ -0,0 +1,22 @@
use std::sync::Arc;

use crate::AuthManager;
use crate::RolloutRecorder;
use crate::mcp_connection_manager::McpConnectionManager;
use crate::tools::sandboxing::ApprovalStore;
use crate::unified_exec::UnifiedExecSessionManager;
use crate::user_notification::UserNotifier;
use codex_otel::otel_event_manager::OtelEventManager;
use tokio::sync::Mutex;

pub(crate) struct SessionServices {
    pub(crate) mcp_connection_manager: McpConnectionManager,
    pub(crate) unified_exec_manager: UnifiedExecSessionManager,
    pub(crate) notifier: UserNotifier,
    pub(crate) rollout: Mutex<Option<RolloutRecorder>>,
    pub(crate) user_shell: crate::shell::Shell,
    pub(crate) show_raw_agent_reasoning: bool,
    pub(crate) auth_manager: Arc<AuthManager>,
    pub(crate) otel_event_manager: OtelEventManager,
    pub(crate) tool_approvals: Mutex<ApprovalStore>,
}
71
llmx-rs/core/src/state/session.rs
Normal file
@@ -0,0 +1,71 @@
//! Session-wide mutable state.

use codex_protocol::models::ResponseItem;

use crate::codex::SessionConfiguration;
use crate::context_manager::ContextManager;
use crate::protocol::RateLimitSnapshot;
use crate::protocol::TokenUsage;
use crate::protocol::TokenUsageInfo;

/// Persistent, session-scoped state previously stored directly on `Session`.
pub(crate) struct SessionState {
    pub(crate) session_configuration: SessionConfiguration,
    pub(crate) history: ContextManager,
    pub(crate) latest_rate_limits: Option<RateLimitSnapshot>,
}

impl SessionState {
    /// Create a new session state mirroring previous `State::default()` semantics.
    pub(crate) fn new(session_configuration: SessionConfiguration) -> Self {
        Self {
            session_configuration,
            history: ContextManager::new(),
            latest_rate_limits: None,
        }
    }

    // History helpers
    pub(crate) fn record_items<I>(&mut self, items: I)
    where
        I: IntoIterator,
        I::Item: std::ops::Deref<Target = ResponseItem>,
    {
        self.history.record_items(items)
    }

    pub(crate) fn clone_history(&self) -> ContextManager {
        self.history.clone()
    }

    pub(crate) fn replace_history(&mut self, items: Vec<ResponseItem>) {
        self.history.replace(items);
    }

    // Token/rate limit helpers
    pub(crate) fn update_token_info_from_usage(
        &mut self,
        usage: &TokenUsage,
        model_context_window: Option<i64>,
    ) {
        self.history.update_token_info(usage, model_context_window);
    }

    pub(crate) fn token_info(&self) -> Option<TokenUsageInfo> {
        self.history.token_info()
    }

    pub(crate) fn set_rate_limits(&mut self, snapshot: RateLimitSnapshot) {
        self.latest_rate_limits = Some(snapshot);
    }

    pub(crate) fn token_info_and_rate_limits(
        &self,
    ) -> (Option<TokenUsageInfo>, Option<RateLimitSnapshot>) {
        (self.token_info(), self.latest_rate_limits.clone())
    }

    pub(crate) fn set_token_usage_full(&mut self, context_window: i64) {
        self.history.set_token_usage_full(context_window);
    }
}
115
llmx-rs/core/src/state/turn.rs
Normal file
@@ -0,0 +1,115 @@
//! Turn-scoped state and active turn metadata scaffolding.
|
||||
|
||||
use indexmap::IndexMap;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::Notify;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::AbortOnDropHandle;
|
||||
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
use crate::codex::TurnContext;
|
||||
use crate::protocol::ReviewDecision;
|
||||
use crate::tasks::SessionTask;
|
||||
|
||||
/// Metadata about the currently running turn.
|
||||
pub(crate) struct ActiveTurn {
|
||||
pub(crate) tasks: IndexMap<String, RunningTask>,
|
||||
pub(crate) turn_state: Arc<Mutex<TurnState>>,
|
||||
}
|
||||
|
||||
impl Default for ActiveTurn {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
tasks: IndexMap::new(),
|
||||
turn_state: Arc::new(Mutex::new(TurnState::default())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
pub(crate) enum TaskKind {
|
||||
Regular,
|
||||
Review,
|
||||
Compact,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct RunningTask {
|
||||
pub(crate) done: Arc<Notify>,
|
||||
pub(crate) kind: TaskKind,
|
||||
pub(crate) task: Arc<dyn SessionTask>,
|
||||
pub(crate) cancellation_token: CancellationToken,
|
||||
pub(crate) handle: Arc<AbortOnDropHandle<()>>,
|
||||
pub(crate) turn_context: Arc<TurnContext>,
|
||||
}
|
||||
|
||||
impl ActiveTurn {
|
||||
pub(crate) fn add_task(&mut self, task: RunningTask) {
|
||||
let sub_id = task.turn_context.sub_id.clone();
|
||||
self.tasks.insert(sub_id, task);
|
||||
}
|
||||
|
||||
pub(crate) fn remove_task(&mut self, sub_id: &str) -> bool {
|
||||
self.tasks.swap_remove(sub_id);
|
||||
self.tasks.is_empty()
|
||||
}
|
||||
|
||||
pub(crate) fn drain_tasks(&mut self) -> Vec<RunningTask> {
|
||||
self.tasks.drain(..).map(|(_, task)| task).collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Mutable state for a single turn.
|
||||
#[derive(Default)]
|
||||
pub(crate) struct TurnState {
|
||||
pending_approvals: HashMap<String, oneshot::Sender<ReviewDecision>>,
|
||||
pending_input: Vec<ResponseInputItem>,
|
||||
}
|
||||
|
||||
impl TurnState {
|
||||
pub(crate) fn insert_pending_approval(
|
||||
&mut self,
|
||||
key: String,
|
||||
tx: oneshot::Sender<ReviewDecision>,
|
||||
) -> Option<oneshot::Sender<ReviewDecision>> {
|
||||
self.pending_approvals.insert(key, tx)
|
||||
}
|
||||
|
||||
pub(crate) fn remove_pending_approval(
|
||||
&mut self,
|
||||
key: &str,
|
||||
) -> Option<oneshot::Sender<ReviewDecision>> {
|
||||
self.pending_approvals.remove(key)
|
||||
}
|
||||
|
||||
pub(crate) fn clear_pending(&mut self) {
|
||||
self.pending_approvals.clear();
|
||||
self.pending_input.clear();
|
||||
}
|
||||
|
||||
pub(crate) fn push_pending_input(&mut self, input: ResponseInputItem) {
|
||||
self.pending_input.push(input);
|
||||
}
|
||||
|
||||
pub(crate) fn take_pending_input(&mut self) -> Vec<ResponseInputItem> {
|
||||
if self.pending_input.is_empty() {
|
||||
Vec::with_capacity(0)
|
||||
} else {
|
||||
let mut ret = Vec::new();
|
||||
std::mem::swap(&mut ret, &mut self.pending_input);
|
||||
ret
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ActiveTurn {
|
||||
/// Clear any pending approvals and input buffered for the current turn.
|
||||
pub(crate) async fn clear_pending(&self) {
|
||||
let mut ts = self.turn_state.lock().await;
|
||||
ts.clear_pending();
|
||||
}
|
||||
}
|
||||
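The `pending_approvals` map above keys one-shot channels by call id, so each waiting approval request can be resolved exactly once. A minimal standalone sketch of that pattern, assuming only tokio (the `Decision` enum and `Pending` wrapper are illustrative stand-ins, not crate types):

use std::collections::HashMap;
use tokio::sync::oneshot;

#[derive(Debug, PartialEq)]
enum Decision {
    Approved,
    Denied,
}

// Keyed one-shot approval channels, mirroring TurnState's
// `pending_approvals`: one waiting sender per call id.
struct Pending(HashMap<String, oneshot::Sender<Decision>>);

#[tokio::main]
async fn main() {
    let mut pending = Pending(HashMap::new());

    let (tx, rx) = oneshot::channel();
    pending.0.insert("call-1".to_string(), tx);

    // Later, a user decision arrives and resolves the waiting call.
    if let Some(tx) = pending.0.remove("call-1") {
        let _ = tx.send(Decision::Approved);
    }
    assert_eq!(rx.await.unwrap(), Decision::Approved);
}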
32
llmx-rs/core/src/tasks/compact.rs
Normal file
@@ -0,0 +1,32 @@
use std::sync::Arc;

use async_trait::async_trait;
use tokio_util::sync::CancellationToken;

use crate::codex::TurnContext;
use crate::compact;
use crate::state::TaskKind;
use codex_protocol::user_input::UserInput;

use super::SessionTask;
use super::SessionTaskContext;

#[derive(Clone, Copy, Default)]
pub(crate) struct CompactTask;

#[async_trait]
impl SessionTask for CompactTask {
    fn kind(&self) -> TaskKind {
        TaskKind::Compact
    }

    async fn run(
        self: Arc<Self>,
        session: Arc<SessionTaskContext>,
        ctx: Arc<TurnContext>,
        input: Vec<UserInput>,
        _cancellation_token: CancellationToken,
    ) -> Option<String> {
        compact::run_compact_task(session.clone_session(), ctx, input).await
    }
}
110
llmx-rs/core/src/tasks/ghost_snapshot.rs
Normal file
@@ -0,0 +1,110 @@
use crate::codex::TurnContext;
use crate::state::TaskKind;
use crate::tasks::SessionTask;
use crate::tasks::SessionTaskContext;
use async_trait::async_trait;
use codex_git::CreateGhostCommitOptions;
use codex_git::GitToolingError;
use codex_git::create_ghost_commit;
use codex_protocol::models::ResponseItem;
use codex_protocol::user_input::UserInput;
use codex_utils_readiness::Readiness;
use codex_utils_readiness::Token;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use tracing::info;
use tracing::warn;

pub(crate) struct GhostSnapshotTask {
    token: Token,
}

#[async_trait]
impl SessionTask for GhostSnapshotTask {
    fn kind(&self) -> TaskKind {
        TaskKind::Regular
    }

    async fn run(
        self: Arc<Self>,
        session: Arc<SessionTaskContext>,
        ctx: Arc<TurnContext>,
        _input: Vec<UserInput>,
        cancellation_token: CancellationToken,
    ) -> Option<String> {
        tokio::task::spawn(async move {
            let token = self.token;
            let ctx_for_task = Arc::clone(&ctx);
            let cancelled = tokio::select! {
                _ = cancellation_token.cancelled() => true,
                _ = async {
                    let repo_path = ctx_for_task.cwd.clone();
                    // Required to run in a dedicated blocking pool.
                    match tokio::task::spawn_blocking(move || {
                        let options = CreateGhostCommitOptions::new(&repo_path);
                        create_ghost_commit(&options)
                    })
                    .await
                    {
                        Ok(Ok(ghost_commit)) => {
                            info!("ghost snapshot blocking task finished");
                            session
                                .session
                                .record_conversation_items(&ctx, &[ResponseItem::GhostSnapshot {
                                    ghost_commit: ghost_commit.clone(),
                                }])
                                .await;
                            info!("ghost commit captured: {}", ghost_commit.id());
                        }
                        Ok(Err(err)) => {
                            warn!(
                                sub_id = ctx_for_task.sub_id.as_str(),
                                "failed to capture ghost snapshot: {err}"
                            );
                            let message = match err {
                                GitToolingError::NotAGitRepository { .. } => {
                                    "Snapshots disabled: current directory is not a Git repository."
                                        .to_string()
                                }
                                _ => format!("Snapshots disabled after ghost snapshot error: {err}."),
                            };
                            session
                                .session
                                .notify_background_event(&ctx_for_task, message)
                                .await;
                        }
                        Err(err) => {
                            warn!(
                                sub_id = ctx_for_task.sub_id.as_str(),
                                "ghost snapshot task panicked: {err}"
                            );
                            let message =
                                format!("Snapshots disabled after ghost snapshot panic: {err}.");
                            session
                                .session
                                .notify_background_event(&ctx_for_task, message)
                                .await;
                        }
                    }
                } => false,
            };

            if cancelled {
                info!("ghost snapshot task cancelled");
            }

            match ctx.tool_call_gate.mark_ready(token).await {
                Ok(true) => info!("ghost snapshot gate marked ready"),
                Ok(false) => warn!("ghost snapshot gate already ready"),
                Err(err) => warn!("failed to mark ghost snapshot ready: {err}"),
            }
        });
        None
    }
}

impl GhostSnapshotTask {
    pub(crate) fn new(token: Token) -> Self {
        Self { token }
    }
}
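The task above pushes the git work onto tokio's blocking pool and races it against the cancellation token. A minimal, self-contained sketch of that shape, assuming only the tokio and tokio-util crates (the `snapshot` function and its string result are illustrative stand-ins for `create_ghost_commit`):

use tokio_util::sync::CancellationToken;

// Run a blocking operation off the async runtime, racing it against
// cancellation, the same shape GhostSnapshotTask uses for git work.
async fn snapshot(cancel: CancellationToken) -> Option<String> {
    tokio::select! {
        _ = cancel.cancelled() => None,
        res = tokio::task::spawn_blocking(|| {
            // Stand-in for the blocking git call.
            "deadbeef".to_string()
        }) => res.ok(), // JoinError (panic) maps to None
    }
}

#[tokio::main]
async fn main() {
    let cancel = CancellationToken::new();
    assert_eq!(snapshot(cancel).await.as_deref(), Some("deadbeef"));
}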
225
llmx-rs/core/src/tasks/mod.rs
Normal file
@@ -0,0 +1,225 @@
mod compact;
mod ghost_snapshot;
mod regular;
mod review;
mod undo;
mod user_shell;

use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use tokio::select;
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
use tokio_util::task::AbortOnDropHandle;
use tracing::trace;
use tracing::warn;

use crate::AuthManager;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::protocol::EventMsg;
use crate::protocol::TaskCompleteEvent;
use crate::protocol::TurnAbortReason;
use crate::protocol::TurnAbortedEvent;
use crate::state::ActiveTurn;
use crate::state::RunningTask;
use crate::state::TaskKind;
use codex_protocol::user_input::UserInput;

pub(crate) use compact::CompactTask;
pub(crate) use ghost_snapshot::GhostSnapshotTask;
pub(crate) use regular::RegularTask;
pub(crate) use review::ReviewTask;
pub(crate) use undo::UndoTask;
pub(crate) use user_shell::UserShellCommandTask;

const GRACEFUL_INTERRUPTION_TIMEOUT_MS: u64 = 100;

/// Thin wrapper that exposes the parts of [`Session`] that task runners need.
#[derive(Clone)]
pub(crate) struct SessionTaskContext {
    session: Arc<Session>,
}

impl SessionTaskContext {
    pub(crate) fn new(session: Arc<Session>) -> Self {
        Self { session }
    }

    pub(crate) fn clone_session(&self) -> Arc<Session> {
        Arc::clone(&self.session)
    }

    pub(crate) fn auth_manager(&self) -> Arc<AuthManager> {
        Arc::clone(&self.session.services.auth_manager)
    }
}

/// Async task that drives a [`Session`] turn.
///
/// Implementations encapsulate a specific Codex workflow (regular chat,
/// reviews, ghost snapshots, etc.). Each task instance is owned by a
/// [`Session`] and executed on a background Tokio task. The trait is
/// intentionally small: implementers identify themselves via
/// [`SessionTask::kind`], perform their work in [`SessionTask::run`], and may
/// release resources in [`SessionTask::abort`].
#[async_trait]
pub(crate) trait SessionTask: Send + Sync + 'static {
    /// Describes the type of work the task performs so the session can
    /// surface it in telemetry and UI.
    fn kind(&self) -> TaskKind;

    /// Executes the task until completion or cancellation.
    ///
    /// Implementations typically stream protocol events using `session` and
    /// `ctx`, returning an optional final agent message when finished. The
    /// provided `cancellation_token` is cancelled when the session requests an
    /// abort; implementers should watch for it and terminate quickly once it
    /// fires. Returning [`Some`] yields a final message that
    /// [`Session::on_task_finished`] will emit to the client.
    async fn run(
        self: Arc<Self>,
        session: Arc<SessionTaskContext>,
        ctx: Arc<TurnContext>,
        input: Vec<UserInput>,
        cancellation_token: CancellationToken,
    ) -> Option<String>;

    /// Gives the task a chance to perform cleanup after an abort.
    ///
    /// The default implementation is a no-op; override this if additional
    /// teardown or notifications are required once
    /// [`Session::abort_all_tasks`] cancels the task.
    async fn abort(&self, session: Arc<SessionTaskContext>, ctx: Arc<TurnContext>) {
        let _ = (session, ctx);
    }
}

impl Session {
    pub async fn spawn_task<T: SessionTask>(
        self: &Arc<Self>,
        turn_context: Arc<TurnContext>,
        input: Vec<UserInput>,
        task: T,
    ) {
        self.abort_all_tasks(TurnAbortReason::Replaced).await;

        let task: Arc<dyn SessionTask> = Arc::new(task);
        let task_kind = task.kind();

        let cancellation_token = CancellationToken::new();
        let done = Arc::new(Notify::new());

        let done_clone = Arc::clone(&done);
        let handle = {
            let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
            let ctx = Arc::clone(&turn_context);
            let task_for_run = Arc::clone(&task);
            let task_cancellation_token = cancellation_token.child_token();
            tokio::spawn(async move {
                let ctx_for_finish = Arc::clone(&ctx);
                let last_agent_message = task_for_run
                    .run(
                        Arc::clone(&session_ctx),
                        ctx,
                        input,
                        task_cancellation_token.child_token(),
                    )
                    .await;
                session_ctx.clone_session().flush_rollout().await;
                if !task_cancellation_token.is_cancelled() {
                    // Emit completion uniformly from the spawn site so all tasks share the same lifecycle.
                    let sess = session_ctx.clone_session();
                    sess.on_task_finished(ctx_for_finish, last_agent_message)
                        .await;
                }
                done_clone.notify_waiters();
            })
        };

        let running_task = RunningTask {
            done,
            handle: Arc::new(AbortOnDropHandle::new(handle)),
            kind: task_kind,
            task,
            cancellation_token,
            turn_context: Arc::clone(&turn_context),
        };
        self.register_new_active_task(running_task).await;
    }

    pub async fn abort_all_tasks(self: &Arc<Self>, reason: TurnAbortReason) {
        for task in self.take_all_running_tasks().await {
            self.handle_task_abort(task, reason.clone()).await;
        }
    }

    pub async fn on_task_finished(
        self: &Arc<Self>,
        turn_context: Arc<TurnContext>,
        last_agent_message: Option<String>,
    ) {
        let mut active = self.active_turn.lock().await;
        if let Some(at) = active.as_mut()
            && at.remove_task(&turn_context.sub_id)
        {
            *active = None;
        }
        drop(active);
        let event = EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message });
        self.send_event(turn_context.as_ref(), event).await;
    }

    async fn register_new_active_task(&self, task: RunningTask) {
        let mut active = self.active_turn.lock().await;
        let mut turn = ActiveTurn::default();
        turn.add_task(task);
        *active = Some(turn);
    }

    async fn take_all_running_tasks(&self) -> Vec<RunningTask> {
        let mut active = self.active_turn.lock().await;
        match active.take() {
            Some(mut at) => {
                at.clear_pending().await;

                at.drain_tasks()
            }
            None => Vec::new(),
        }
    }

    async fn handle_task_abort(self: &Arc<Self>, task: RunningTask, reason: TurnAbortReason) {
        let sub_id = task.turn_context.sub_id.clone();
        if task.cancellation_token.is_cancelled() {
            return;
        }

        trace!(task_kind = ?task.kind, sub_id, "aborting running task");
        task.cancellation_token.cancel();
        let session_task = task.task;

        select! {
            _ = task.done.notified() => {
            },
            _ = tokio::time::sleep(Duration::from_millis(GRACEFUL_INTERRUPTION_TIMEOUT_MS)) => {
                warn!("task {sub_id} didn't complete gracefully after {}ms", GRACEFUL_INTERRUPTION_TIMEOUT_MS);
            }
        }

        task.handle.abort();

        let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
        session_task
            .abort(session_ctx, Arc::clone(&task.turn_context))
            .await;

        let event = EventMsg::TurnAborted(TurnAbortedEvent { reason });
        self.send_event(task.turn_context.as_ref(), event).await;
    }
}

#[cfg(test)]
mod tests {}
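`handle_task_abort` above follows a cancel, grace period, force-abort sequence: signal the token, wait briefly for the task's `done` notification, then abort the join handle. A minimal standalone sketch of that sequence, assuming only tokio and tokio-util (it uses `notify_one` instead of `notify_waiters` so the permit is stored even if the task finishes before the waiter polls):

use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
use tokio_util::task::AbortOnDropHandle;

#[tokio::main]
async fn main() {
    let cancel = CancellationToken::new();
    let done = Arc::new(Notify::new());

    let done_tx = Arc::clone(&done);
    let child = cancel.child_token();
    let handle = AbortOnDropHandle::new(tokio::spawn(async move {
        child.cancelled().await; // cooperative exit point
        done_tx.notify_one();
    }));

    // Ask the task to stop, give it a short grace period, then force-abort.
    cancel.cancel();
    tokio::select! {
        _ = done.notified() => println!("task exited gracefully"),
        _ = tokio::time::sleep(Duration::from_millis(100)) => {
            println!("grace period elapsed; aborting");
            handle.abort();
        }
    }
}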
33
llmx-rs/core/src/tasks/regular.rs
Normal file
@@ -0,0 +1,33 @@
use std::sync::Arc;

use async_trait::async_trait;
use tokio_util::sync::CancellationToken;

use crate::codex::TurnContext;
use crate::codex::run_task;
use crate::state::TaskKind;
use codex_protocol::user_input::UserInput;

use super::SessionTask;
use super::SessionTaskContext;

#[derive(Clone, Copy, Default)]
pub(crate) struct RegularTask;

#[async_trait]
impl SessionTask for RegularTask {
    fn kind(&self) -> TaskKind {
        TaskKind::Regular
    }

    async fn run(
        self: Arc<Self>,
        session: Arc<SessionTaskContext>,
        ctx: Arc<TurnContext>,
        input: Vec<UserInput>,
        cancellation_token: CancellationToken,
    ) -> Option<String> {
        let sess = session.clone_session();
        run_task(sess, ctx, input, cancellation_token).await
    }
}
210
llmx-rs/core/src/tasks/review.rs
Normal file
@@ -0,0 +1,210 @@
use std::sync::Arc;

use async_trait::async_trait;
use codex_protocol::items::TurnItem;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::AgentMessageContentDeltaEvent;
use codex_protocol::protocol::AgentMessageDeltaEvent;
use codex_protocol::protocol::Event;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::ExitedReviewModeEvent;
use codex_protocol::protocol::ItemCompletedEvent;
use codex_protocol::protocol::ReviewOutputEvent;
use tokio_util::sync::CancellationToken;

use crate::codex::Session;
use crate::codex::TurnContext;
use crate::codex_delegate::run_codex_conversation_one_shot;
use crate::review_format::format_review_findings_block;
use crate::state::TaskKind;
use codex_protocol::user_input::UserInput;

use super::SessionTask;
use super::SessionTaskContext;

#[derive(Clone, Copy, Default)]
pub(crate) struct ReviewTask;

#[async_trait]
impl SessionTask for ReviewTask {
    fn kind(&self) -> TaskKind {
        TaskKind::Review
    }

    async fn run(
        self: Arc<Self>,
        session: Arc<SessionTaskContext>,
        ctx: Arc<TurnContext>,
        input: Vec<UserInput>,
        cancellation_token: CancellationToken,
    ) -> Option<String> {
        // Start sub-codex conversation and get the receiver for events.
        let output = match start_review_conversation(
            session.clone(),
            ctx.clone(),
            input,
            cancellation_token.clone(),
        )
        .await
        {
            Some(receiver) => process_review_events(session.clone(), ctx.clone(), receiver).await,
            None => None,
        };
        if !cancellation_token.is_cancelled() {
            exit_review_mode(session.clone_session(), output.clone(), ctx.clone()).await;
        }
        None
    }

    async fn abort(&self, session: Arc<SessionTaskContext>, ctx: Arc<TurnContext>) {
        exit_review_mode(session.clone_session(), None, ctx).await;
    }
}

async fn start_review_conversation(
    session: Arc<SessionTaskContext>,
    ctx: Arc<TurnContext>,
    input: Vec<UserInput>,
    cancellation_token: CancellationToken,
) -> Option<async_channel::Receiver<Event>> {
    let config = ctx.client.config();
    let mut sub_agent_config = config.as_ref().clone();
    // Run with only the reviewer rubric; drop outer user_instructions.
    sub_agent_config.user_instructions = None;
    // Avoid loading project docs; the reviewer only needs findings.
    sub_agent_config.project_doc_max_bytes = 0;
    // Carry over review-only feature restrictions so the delegate cannot
    // re-enable blocked tools (web search, view image).
    sub_agent_config
        .features
        .disable(crate::features::Feature::WebSearchRequest)
        .disable(crate::features::Feature::ViewImageTool);

    // Set the explicit review rubric for the sub-agent.
    sub_agent_config.base_instructions = Some(crate::REVIEW_PROMPT.to_string());
    (run_codex_conversation_one_shot(
        sub_agent_config,
        session.auth_manager(),
        input,
        session.clone_session(),
        ctx.clone(),
        cancellation_token,
        None,
    )
    .await)
    .ok()
    .map(|io| io.rx_event)
}

async fn process_review_events(
    session: Arc<SessionTaskContext>,
    ctx: Arc<TurnContext>,
    receiver: async_channel::Receiver<Event>,
) -> Option<ReviewOutputEvent> {
    let mut prev_agent_message: Option<Event> = None;
    while let Ok(event) = receiver.recv().await {
        match event.clone().msg {
            EventMsg::AgentMessage(_) => {
                if let Some(prev) = prev_agent_message.take() {
                    session
                        .clone_session()
                        .send_event(ctx.as_ref(), prev.msg)
                        .await;
                }
                prev_agent_message = Some(event);
            }
            // Suppress ItemCompleted only for assistant messages: forwarding it
            // would trigger legacy AgentMessage via as_legacy_events(), which this
            // review flow intentionally hides in favor of structured output.
            EventMsg::ItemCompleted(ItemCompletedEvent {
                item: TurnItem::AgentMessage(_),
                ..
            })
            | EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { .. })
            | EventMsg::AgentMessageContentDelta(AgentMessageContentDeltaEvent { .. }) => {}
            EventMsg::TaskComplete(task_complete) => {
                // Parse review output from the last agent message (if present).
                let out = task_complete
                    .last_agent_message
                    .as_deref()
                    .map(parse_review_output_event);
                return out;
            }
            EventMsg::TurnAborted(_) => {
                // Cancellation or abort: the consumer will finalize with None.
                return None;
            }
            other => {
                session
                    .clone_session()
                    .send_event(ctx.as_ref(), other)
                    .await;
            }
        }
    }
    // Channel closed without TaskComplete: treat as interrupted.
    None
}

/// Parse a ReviewOutputEvent from a text blob returned by the reviewer model.
/// If the text is valid JSON matching ReviewOutputEvent, deserialize it.
/// Otherwise, attempt to extract the first JSON object substring and parse it.
/// If parsing still fails, return a structured fallback carrying the plain text
/// in `overall_explanation`.
fn parse_review_output_event(text: &str) -> ReviewOutputEvent {
    if let Ok(ev) = serde_json::from_str::<ReviewOutputEvent>(text) {
        return ev;
    }
    if let (Some(start), Some(end)) = (text.find('{'), text.rfind('}'))
        && start < end
        && let Some(slice) = text.get(start..=end)
        && let Ok(ev) = serde_json::from_str::<ReviewOutputEvent>(slice)
    {
        return ev;
    }
    ReviewOutputEvent {
        overall_explanation: text.to_string(),
        ..Default::default()
    }
}

/// Emits an ExitedReviewMode Event with optional ReviewOutput,
/// and records a developer message with the review output.
pub(crate) async fn exit_review_mode(
    session: Arc<Session>,
    review_output: Option<ReviewOutputEvent>,
    ctx: Arc<TurnContext>,
) {
    let user_message = if let Some(out) = review_output.clone() {
        let mut findings_str = String::new();
        let text = out.overall_explanation.trim();
        if !text.is_empty() {
            findings_str.push_str(text);
        }
        if !out.findings.is_empty() {
            let block = format_review_findings_block(&out.findings, None);
            findings_str.push_str(&format!("\n{block}"));
        }
        crate::client_common::REVIEW_EXIT_SUCCESS_TMPL.replace("{results}", &findings_str)
    } else {
        crate::client_common::REVIEW_EXIT_INTERRUPTED_TMPL.to_string()
    };

    session
        .record_conversation_items(
            &ctx,
            &[ResponseItem::Message {
                id: None,
                role: "user".to_string(),
                content: vec![ContentItem::InputText { text: user_message }],
            }],
        )
        .await;
    session
        .send_event(
            ctx.as_ref(),
            EventMsg::ExitedReviewMode(ExitedReviewModeEvent { review_output }),
        )
        .await;
}
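The fallback chain in `parse_review_output_event` above is worth seeing in isolation: whole-blob JSON first, then the first `{` to last `}` slice, then plain text. A self-contained sketch with a simplified stand-in type (`ReviewOutput` here is illustrative, not the crate's `ReviewOutputEvent`):

use serde::Deserialize;

#[derive(Debug, Default, Deserialize, PartialEq)]
struct ReviewOutput {
    #[serde(default)]
    overall_explanation: String,
}

// Mirror of parse_review_output_event's fallback chain: try the whole
// blob as JSON, then the first `{`..last `}` slice, then plain text.
fn parse(text: &str) -> ReviewOutput {
    if let Ok(ev) = serde_json::from_str::<ReviewOutput>(text) {
        return ev;
    }
    if let (Some(start), Some(end)) = (text.find('{'), text.rfind('}'))
        {
        if start < end {
            if let Some(slice) = text.get(start..=end) {
                if let Ok(ev) = serde_json::from_str::<ReviewOutput>(slice) {
                    return ev;
                }
            }
        }
    }
    ReviewOutput {
        overall_explanation: text.to_string(),
    }
}

fn main() {
    let wrapped = r#"Here you go: {"overall_explanation":"looks good"} done"#;
    assert_eq!(parse(wrapped).overall_explanation, "looks good");
    assert_eq!(parse("plain text").overall_explanation, "plain text");
}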
117
llmx-rs/core/src/tasks/undo.rs
Normal file
@@ -0,0 +1,117 @@
use std::sync::Arc;

use crate::codex::TurnContext;
use crate::protocol::EventMsg;
use crate::protocol::UndoCompletedEvent;
use crate::protocol::UndoStartedEvent;
use crate::state::TaskKind;
use crate::tasks::SessionTask;
use crate::tasks::SessionTaskContext;
use async_trait::async_trait;
use codex_git::restore_ghost_commit;
use codex_protocol::models::ResponseItem;
use codex_protocol::user_input::UserInput;
use tokio_util::sync::CancellationToken;
use tracing::error;
use tracing::info;
use tracing::warn;

pub(crate) struct UndoTask;

impl UndoTask {
    pub(crate) fn new() -> Self {
        Self
    }
}

#[async_trait]
impl SessionTask for UndoTask {
    fn kind(&self) -> TaskKind {
        TaskKind::Regular
    }

    async fn run(
        self: Arc<Self>,
        session: Arc<SessionTaskContext>,
        ctx: Arc<TurnContext>,
        _input: Vec<UserInput>,
        cancellation_token: CancellationToken,
    ) -> Option<String> {
        let sess = session.clone_session();
        sess.send_event(
            ctx.as_ref(),
            EventMsg::UndoStarted(UndoStartedEvent {
                message: Some("Undo in progress...".to_string()),
            }),
        )
        .await;

        if cancellation_token.is_cancelled() {
            sess.send_event(
                ctx.as_ref(),
                EventMsg::UndoCompleted(UndoCompletedEvent {
                    success: false,
                    message: Some("Undo cancelled.".to_string()),
                }),
            )
            .await;
            return None;
        }

        let mut history = sess.clone_history().await;
        let mut items = history.get_history();
        let mut completed = UndoCompletedEvent {
            success: false,
            message: None,
        };

        let Some((idx, ghost_commit)) =
            items
                .iter()
                .enumerate()
                .rev()
                .find_map(|(idx, item)| match item {
                    ResponseItem::GhostSnapshot { ghost_commit } => {
                        Some((idx, ghost_commit.clone()))
                    }
                    _ => None,
                })
        else {
            completed.message = Some("No ghost snapshot available to undo.".to_string());
            sess.send_event(ctx.as_ref(), EventMsg::UndoCompleted(completed))
                .await;
            return None;
        };

        let commit_id = ghost_commit.id().to_string();
        let repo_path = ctx.cwd.clone();
        let restore_result =
            tokio::task::spawn_blocking(move || restore_ghost_commit(&repo_path, &ghost_commit))
                .await;

        match restore_result {
            Ok(Ok(())) => {
                items.remove(idx);
                sess.replace_history(items).await;
                let short_id: String = commit_id.chars().take(7).collect();
                info!(commit_id = commit_id, "Undo restored ghost snapshot");
                completed.success = true;
                completed.message = Some(format!("Undo restored snapshot {short_id}."));
            }
            Ok(Err(err)) => {
                let message = format!("Failed to restore snapshot {commit_id}: {err}");
                warn!("{message}");
                completed.message = Some(message);
            }
            Err(err) => {
                let message = format!("Failed to restore snapshot {commit_id}: {err}");
                error!("{message}");
                completed.message = Some(message);
            }
        }

        sess.send_event(ctx.as_ref(), EventMsg::UndoCompleted(completed))
            .await;
        None
    }
}
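The reverse `find_map` scan above locates the most recent snapshot marker in the history. A small standalone sketch of the same scan (the `Item` enum here is an illustrative stand-in for `ResponseItem`):

// Locate the most recent snapshot marker in a history vector, the way
// UndoTask scans for the last ResponseItem::GhostSnapshot.
#[derive(Debug, Clone, PartialEq)]
enum Item {
    Message(String),
    Snapshot(String), // commit id
}

fn last_snapshot(items: &[Item]) -> Option<(usize, String)> {
    items.iter().enumerate().rev().find_map(|(idx, item)| match item {
        Item::Snapshot(id) => Some((idx, id.clone())),
        _ => None,
    })
}

fn main() {
    let items = vec![
        Item::Snapshot("aaa".into()),
        Item::Message("hi".into()),
        Item::Snapshot("bbb".into()),
    ];
    assert_eq!(last_snapshot(&items), Some((2, "bbb".to_string())));
}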
211
llmx-rs/core/src/tasks/user_shell.rs
Normal file
@@ -0,0 +1,211 @@
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use codex_async_utils::CancelErr;
use codex_async_utils::OrCancelExt;
use codex_protocol::user_input::UserInput;
use tokio_util::sync::CancellationToken;
use tracing::error;
use uuid::Uuid;

use crate::codex::TurnContext;
use crate::exec::ExecToolCallOutput;
use crate::exec::SandboxType;
use crate::exec::StdoutStream;
use crate::exec::StreamOutput;
use crate::exec::execute_exec_env;
use crate::exec_env::create_env;
use crate::parse_command::parse_command;
use crate::protocol::EventMsg;
use crate::protocol::ExecCommandBeginEvent;
use crate::protocol::ExecCommandEndEvent;
use crate::protocol::SandboxPolicy;
use crate::protocol::TaskStartedEvent;
use crate::sandboxing::ExecEnv;
use crate::state::TaskKind;
use crate::tools::format_exec_output_str;
use crate::user_shell_command::user_shell_command_record_item;

use super::SessionTask;
use super::SessionTaskContext;

#[derive(Clone)]
pub(crate) struct UserShellCommandTask {
    command: String,
}

impl UserShellCommandTask {
    pub(crate) fn new(command: String) -> Self {
        Self { command }
    }
}

#[async_trait]
impl SessionTask for UserShellCommandTask {
    fn kind(&self) -> TaskKind {
        TaskKind::Regular
    }

    async fn run(
        self: Arc<Self>,
        session: Arc<SessionTaskContext>,
        turn_context: Arc<TurnContext>,
        _input: Vec<UserInput>,
        cancellation_token: CancellationToken,
    ) -> Option<String> {
        let event = EventMsg::TaskStarted(TaskStartedEvent {
            model_context_window: turn_context.client.get_model_context_window(),
        });
        let session = session.clone_session();
        session.send_event(turn_context.as_ref(), event).await;

        // Execute the user's script under their default shell when known; this
        // allows commands that use shell features (pipes, &&, redirects, etc.).
        // We do not source rc files or otherwise reformat the script.
        let shell_invocation = match session.user_shell() {
            crate::shell::Shell::Zsh(zsh) => vec![
                zsh.shell_path.clone(),
                "-lc".to_string(),
                self.command.clone(),
            ],
            crate::shell::Shell::Bash(bash) => vec![
                bash.shell_path.clone(),
                "-lc".to_string(),
                self.command.clone(),
            ],
            crate::shell::Shell::PowerShell(ps) => vec![
                ps.exe.clone(),
                "-NoProfile".to_string(),
                "-Command".to_string(),
                self.command.clone(),
            ],
            crate::shell::Shell::Unknown => {
                shlex::split(&self.command).unwrap_or_else(|| vec![self.command.clone()])
            }
        };

        let call_id = Uuid::new_v4().to_string();
        let raw_command = self.command.clone();

        let parsed_cmd = parse_command(&shell_invocation);
        session
            .send_event(
                turn_context.as_ref(),
                EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
                    call_id: call_id.clone(),
                    command: shell_invocation.clone(),
                    cwd: turn_context.cwd.clone(),
                    parsed_cmd,
                    is_user_shell_command: true,
                }),
            )
            .await;

        let exec_env = ExecEnv {
            command: shell_invocation,
            cwd: turn_context.cwd.clone(),
            env: create_env(&turn_context.shell_environment_policy),
            timeout_ms: None,
            sandbox: SandboxType::None,
            with_escalated_permissions: None,
            justification: None,
            arg0: None,
        };

        let stdout_stream = Some(StdoutStream {
            sub_id: turn_context.sub_id.clone(),
            call_id: call_id.clone(),
            tx_event: session.get_tx_event(),
        });

        let sandbox_policy = SandboxPolicy::DangerFullAccess;
        let exec_result = execute_exec_env(exec_env, &sandbox_policy, stdout_stream)
            .or_cancel(&cancellation_token)
            .await;

        match exec_result {
            Err(CancelErr::Cancelled) => {
                let aborted_message = "command aborted by user".to_string();
                let exec_output = ExecToolCallOutput {
                    exit_code: -1,
                    stdout: StreamOutput::new(String::new()),
                    stderr: StreamOutput::new(aborted_message.clone()),
                    aggregated_output: StreamOutput::new(aborted_message.clone()),
                    duration: Duration::ZERO,
                    timed_out: false,
                };
                let output_items = [user_shell_command_record_item(&raw_command, &exec_output)];
                session
                    .record_conversation_items(turn_context.as_ref(), &output_items)
                    .await;
                session
                    .send_event(
                        turn_context.as_ref(),
                        EventMsg::ExecCommandEnd(ExecCommandEndEvent {
                            call_id,
                            stdout: String::new(),
                            stderr: aborted_message.clone(),
                            aggregated_output: aborted_message.clone(),
                            exit_code: -1,
                            duration: Duration::ZERO,
                            formatted_output: aborted_message,
                        }),
                    )
                    .await;
            }
            Ok(Ok(output)) => {
                session
                    .send_event(
                        turn_context.as_ref(),
                        EventMsg::ExecCommandEnd(ExecCommandEndEvent {
                            call_id: call_id.clone(),
                            stdout: output.stdout.text.clone(),
                            stderr: output.stderr.text.clone(),
                            aggregated_output: output.aggregated_output.text.clone(),
                            exit_code: output.exit_code,
                            duration: output.duration,
                            formatted_output: format_exec_output_str(&output),
                        }),
                    )
                    .await;

                let output_items = [user_shell_command_record_item(&raw_command, &output)];
                session
                    .record_conversation_items(turn_context.as_ref(), &output_items)
                    .await;
            }
            Ok(Err(err)) => {
                error!("user shell command failed: {err:?}");
                let message = format!("execution error: {err:?}");
                let exec_output = ExecToolCallOutput {
                    exit_code: -1,
                    stdout: StreamOutput::new(String::new()),
                    stderr: StreamOutput::new(message.clone()),
                    aggregated_output: StreamOutput::new(message.clone()),
                    duration: Duration::ZERO,
                    timed_out: false,
                };
                session
                    .send_event(
                        turn_context.as_ref(),
                        EventMsg::ExecCommandEnd(ExecCommandEndEvent {
                            call_id,
                            stdout: exec_output.stdout.text.clone(),
                            stderr: exec_output.stderr.text.clone(),
                            aggregated_output: exec_output.aggregated_output.text.clone(),
                            exit_code: exec_output.exit_code,
                            duration: exec_output.duration,
                            formatted_output: format_exec_output_str(&exec_output),
                        }),
                    )
                    .await;
                let output_items = [user_shell_command_record_item(&raw_command, &exec_output)];
                session
                    .record_conversation_items(turn_context.as_ref(), &output_items)
                    .await;
            }
        }
        None
    }
}
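The invocation-building step above either wraps the whole script in `shell -lc` (preserving pipes and redirects) or, for an unknown shell, falls back to `shlex::split` word splitting. A minimal sketch of that decision, assuming the shlex crate (the `shell_argv` helper is an illustrative stand-in):

// Build an argv for a user-supplied script: hand the whole string to a
// known shell with `-lc`, or fall back to shlex-style word splitting.
fn shell_argv(shell_path: Option<&str>, command: &str) -> Vec<String> {
    match shell_path {
        Some(sh) => vec![sh.to_string(), "-lc".to_string(), command.to_string()],
        // `shlex::split` returns None on unbalanced quotes.
        None => shlex::split(command).unwrap_or_else(|| vec![command.to_string()]),
    }
}

fn main() {
    assert_eq!(
        shell_argv(Some("/bin/bash"), "ls | wc -l"),
        vec!["/bin/bash", "-lc", "ls | wc -l"]
    );
    assert_eq!(
        shell_argv(None, "echo 'hello world'"),
        vec!["echo", "hello world"]
    );
}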
72
llmx-rs/core/src/terminal.rs
Normal file
@@ -0,0 +1,72 @@
use std::sync::OnceLock;

static TERMINAL: OnceLock<String> = OnceLock::new();

pub fn user_agent() -> String {
    TERMINAL.get_or_init(detect_terminal).to_string()
}

fn is_valid_header_value_char(c: char) -> bool {
    c.is_ascii_alphanumeric() || c == '-' || c == '_' || c == '.' || c == '/'
}

/// Sanitize a header value to be used in a User-Agent string.
///
/// This function replaces any characters that are not allowed in a User-Agent string with an underscore.
///
/// # Arguments
///
/// * `value` - The value to sanitize.
fn sanitize_header_value(value: String) -> String {
    value.replace(|c| !is_valid_header_value_char(c), "_")
}

fn detect_terminal() -> String {
    sanitize_header_value(
        if let Ok(tp) = std::env::var("TERM_PROGRAM")
            && !tp.trim().is_empty()
        {
            let ver = std::env::var("TERM_PROGRAM_VERSION").ok();
            match ver {
                Some(v) if !v.trim().is_empty() => format!("{tp}/{v}"),
                _ => tp,
            }
        } else if let Ok(v) = std::env::var("WEZTERM_VERSION") {
            if !v.trim().is_empty() {
                format!("WezTerm/{v}")
            } else {
                "WezTerm".to_string()
            }
        } else if std::env::var("KITTY_WINDOW_ID").is_ok()
            || std::env::var("TERM")
                .map(|t| t.contains("kitty"))
                .unwrap_or(false)
        {
            "kitty".to_string()
        } else if std::env::var("ALACRITTY_SOCKET").is_ok()
            || std::env::var("TERM")
                .map(|t| t == "alacritty")
                .unwrap_or(false)
        {
            "Alacritty".to_string()
        } else if let Ok(v) = std::env::var("KONSOLE_VERSION") {
            if !v.trim().is_empty() {
                format!("Konsole/{v}")
            } else {
                "Konsole".to_string()
            }
        } else if std::env::var("GNOME_TERMINAL_SCREEN").is_ok() {
            return "gnome-terminal".to_string();
        } else if let Ok(v) = std::env::var("VTE_VERSION") {
            if !v.trim().is_empty() {
                format!("VTE/{v}")
            } else {
                "VTE".to_string()
            }
        } else if std::env::var("WT_SESSION").is_ok() {
            return "WindowsTerminal".to_string();
        } else {
            std::env::var("TERM").unwrap_or_else(|_| "unknown".to_string())
        },
    )
}
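`sanitize_header_value` above maps anything outside the allowed character class to an underscore. A self-contained sketch of that transformation and what it does to a real-looking value:

// Character-class sanitization as in `sanitize_header_value`: anything
// outside [A-Za-z0-9-_./] becomes an underscore.
fn is_valid(c: char) -> bool {
    c.is_ascii_alphanumeric() || c == '-' || c == '_' || c == '.' || c == '/'
}

fn sanitize(value: &str) -> String {
    value.replace(|c: char| !is_valid(c), "_")
}

fn main() {
    // Already-valid values pass through unchanged.
    assert_eq!(sanitize("WezTerm/20240203"), "WezTerm/20240203");
    // Spaces and non-ASCII characters are each replaced by one underscore.
    assert_eq!(sanitize("iTerm.app 3.5 β"), "iTerm.app_3.5__");
}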
195
llmx-rs/core/src/token_data.rs
Normal file
@@ -0,0 +1,195 @@
use base64::Engine;
use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;

#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Default)]
pub struct TokenData {
    /// Flat info parsed from the JWT in auth.json.
    #[serde(
        deserialize_with = "deserialize_id_token",
        serialize_with = "serialize_id_token"
    )]
    pub id_token: IdTokenInfo,

    /// This is a JWT.
    pub access_token: String,

    pub refresh_token: String,

    pub account_id: Option<String>,
}

/// Flat subset of useful claims in id_token from auth.json.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct IdTokenInfo {
    pub email: Option<String>,
    /// The ChatGPT subscription plan type
    /// (e.g., "free", "plus", "pro", "business", "enterprise", "edu").
    /// (Note: values may vary by backend.)
    pub(crate) chatgpt_plan_type: Option<PlanType>,
    /// Organization/workspace identifier associated with the token, if present.
    pub chatgpt_account_id: Option<String>,
    pub raw_jwt: String,
}

impl IdTokenInfo {
    pub fn get_chatgpt_plan_type(&self) -> Option<String> {
        self.chatgpt_plan_type.as_ref().map(|t| match t {
            PlanType::Known(plan) => format!("{plan:?}"),
            PlanType::Unknown(s) => s.clone(),
        })
    }
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub(crate) enum PlanType {
    Known(KnownPlan),
    Unknown(String),
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub(crate) enum KnownPlan {
    Free,
    Plus,
    Pro,
    Team,
    Business,
    Enterprise,
    Edu,
}

#[derive(Deserialize)]
struct IdClaims {
    #[serde(default)]
    email: Option<String>,
    #[serde(rename = "https://api.openai.com/auth", default)]
    auth: Option<AuthClaims>,
}

#[derive(Deserialize)]
struct AuthClaims {
    #[serde(default)]
    chatgpt_plan_type: Option<PlanType>,
    #[serde(default)]
    chatgpt_account_id: Option<String>,
}

#[derive(Debug, Error)]
pub enum IdTokenInfoError {
    #[error("invalid ID token format")]
    InvalidFormat,
    #[error(transparent)]
    Base64(#[from] base64::DecodeError),
    #[error(transparent)]
    Json(#[from] serde_json::Error),
}

pub fn parse_id_token(id_token: &str) -> Result<IdTokenInfo, IdTokenInfoError> {
    // JWT format: header.payload.signature
    let mut parts = id_token.split('.');
    let (_header_b64, payload_b64, _sig_b64) = match (parts.next(), parts.next(), parts.next()) {
        (Some(h), Some(p), Some(s)) if !h.is_empty() && !p.is_empty() && !s.is_empty() => (h, p, s),
        _ => return Err(IdTokenInfoError::InvalidFormat),
    };

    let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload_b64)?;
    let claims: IdClaims = serde_json::from_slice(&payload_bytes)?;

    match claims.auth {
        Some(auth) => Ok(IdTokenInfo {
            email: claims.email,
            raw_jwt: id_token.to_string(),
            chatgpt_plan_type: auth.chatgpt_plan_type,
            chatgpt_account_id: auth.chatgpt_account_id,
        }),
        None => Ok(IdTokenInfo {
            email: claims.email,
            raw_jwt: id_token.to_string(),
            chatgpt_plan_type: None,
            chatgpt_account_id: None,
        }),
    }
}

fn deserialize_id_token<'de, D>(deserializer: D) -> Result<IdTokenInfo, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let s = String::deserialize(deserializer)?;
    parse_id_token(&s).map_err(serde::de::Error::custom)
}

fn serialize_id_token<S>(id_token: &IdTokenInfo, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    serializer.serialize_str(&id_token.raw_jwt)
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde::Serialize;

    #[test]
    fn id_token_info_parses_email_and_plan() {
        #[derive(Serialize)]
        struct Header {
            alg: &'static str,
            typ: &'static str,
        }
        let header = Header {
            alg: "none",
            typ: "JWT",
        };
        let payload = serde_json::json!({
            "email": "user@example.com",
            "https://api.openai.com/auth": {
                "chatgpt_plan_type": "pro"
            }
        });

        fn b64url_no_pad(bytes: &[u8]) -> String {
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
        }

        let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
        let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
        let signature_b64 = b64url_no_pad(b"sig");
        let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");

        let info = parse_id_token(&fake_jwt).expect("should parse");
        assert_eq!(info.email.as_deref(), Some("user@example.com"));
        assert_eq!(info.get_chatgpt_plan_type().as_deref(), Some("Pro"));
    }

    #[test]
    fn id_token_info_handles_missing_fields() {
        #[derive(Serialize)]
        struct Header {
            alg: &'static str,
            typ: &'static str,
        }
        let header = Header {
            alg: "none",
            typ: "JWT",
        };
        let payload = serde_json::json!({ "sub": "123" });

        fn b64url_no_pad(bytes: &[u8]) -> String {
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
        }

        let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
        let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
        let signature_b64 = b64url_no_pad(b"sig");
        let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");

        let info = parse_id_token(&fake_jwt).expect("should parse");
        assert!(info.email.is_none());
        assert!(info.get_chatgpt_plan_type().is_none());
    }
}
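`parse_id_token` above relies on the JWT layout: three dot-separated segments where the middle one is base64url-encoded (no padding) JSON. A minimal standalone sketch of just that decode step, assuming the base64 and serde_json crates (the `Claims` struct and `payload_claims` helper are illustrative; no signature verification is attempted):

use base64::Engine;
use serde::Deserialize;

#[derive(Deserialize)]
struct Claims {
    #[serde(default)]
    email: Option<String>,
}

// Decode the middle segment of a JWT (base64url, no padding) and parse
// the claims as JSON; signature verification is out of scope here.
fn payload_claims(jwt: &str) -> Option<Claims> {
    let payload_b64 = jwt.split('.').nth(1)?;
    let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
        .decode(payload_b64)
        .ok()?;
    serde_json::from_slice(&bytes).ok()
}

fn main() {
    let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
        .encode(br#"{"email":"user@example.com"}"#);
    let jwt = format!("h.{payload}.s");
    let claims = payload_claims(&jwt).expect("parses");
    assert_eq!(claims.email.as_deref(), Some("user@example.com"));
}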
268
llmx-rs/core/src/tools/context.rs
Normal file
@@ -0,0 +1,268 @@
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::tools::TELEMETRY_PREVIEW_MAX_BYTES;
use crate::tools::TELEMETRY_PREVIEW_MAX_LINES;
use crate::tools::TELEMETRY_PREVIEW_TRUNCATION_NOTICE;
use crate::turn_diff_tracker::TurnDiffTracker;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::models::FunctionCallOutputContentItem;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ShellToolCallParams;
use codex_protocol::protocol::FileChange;
use codex_utils_string::take_bytes_at_char_boundary;
use mcp_types::CallToolResult;
use std::borrow::Cow;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;

pub type SharedTurnDiffTracker = Arc<Mutex<TurnDiffTracker>>;

#[derive(Clone)]
pub struct ToolInvocation {
    pub session: Arc<Session>,
    pub turn: Arc<TurnContext>,
    pub tracker: SharedTurnDiffTracker,
    pub call_id: String,
    pub tool_name: String,
    pub payload: ToolPayload,
}

#[derive(Clone)]
pub enum ToolPayload {
    Function {
        arguments: String,
    },
    Custom {
        input: String,
    },
    LocalShell {
        params: ShellToolCallParams,
    },
    UnifiedExec {
        arguments: String,
    },
    Mcp {
        server: String,
        tool: String,
        raw_arguments: String,
    },
}

impl ToolPayload {
    pub fn log_payload(&self) -> Cow<'_, str> {
        match self {
            ToolPayload::Function { arguments } => Cow::Borrowed(arguments),
            ToolPayload::Custom { input } => Cow::Borrowed(input),
            ToolPayload::LocalShell { params } => Cow::Owned(params.command.join(" ")),
            ToolPayload::UnifiedExec { arguments } => Cow::Borrowed(arguments),
            ToolPayload::Mcp { raw_arguments, .. } => Cow::Borrowed(raw_arguments),
        }
    }
}

#[derive(Clone)]
pub enum ToolOutput {
    Function {
        // Plain text representation of the tool output.
        content: String,
        // Some tool calls such as MCP calls may return structured content that
        // can get parsed into an array of polymorphic content items.
        content_items: Option<Vec<FunctionCallOutputContentItem>>,
        success: Option<bool>,
    },
    Mcp {
        result: Result<CallToolResult, String>,
    },
}

impl ToolOutput {
    pub fn log_preview(&self) -> String {
        match self {
            ToolOutput::Function { content, .. } => telemetry_preview(content),
            ToolOutput::Mcp { result } => format!("{result:?}"),
        }
    }

    pub fn success_for_logging(&self) -> bool {
        match self {
            ToolOutput::Function { success, .. } => success.unwrap_or(true),
            ToolOutput::Mcp { result } => result.is_ok(),
        }
    }

    pub fn into_response(self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
        match self {
            ToolOutput::Function {
                content,
                content_items,
                success,
            } => {
                if matches!(payload, ToolPayload::Custom { .. }) {
                    ResponseInputItem::CustomToolCallOutput {
                        call_id: call_id.to_string(),
                        output: content,
                    }
                } else {
                    ResponseInputItem::FunctionCallOutput {
                        call_id: call_id.to_string(),
                        output: FunctionCallOutputPayload {
                            content,
                            content_items,
                            success,
                        },
                    }
                }
            }
            ToolOutput::Mcp { result } => ResponseInputItem::McpToolCallOutput {
                call_id: call_id.to_string(),
                result,
            },
        }
    }
}

fn telemetry_preview(content: &str) -> String {
    let truncated_slice = take_bytes_at_char_boundary(content, TELEMETRY_PREVIEW_MAX_BYTES);
    let truncated_by_bytes = truncated_slice.len() < content.len();

    let mut preview = String::new();
    let mut lines_iter = truncated_slice.lines();
    for idx in 0..TELEMETRY_PREVIEW_MAX_LINES {
        match lines_iter.next() {
            Some(line) => {
                if idx > 0 {
                    preview.push('\n');
                }
                preview.push_str(line);
            }
            None => break,
        }
    }
    let truncated_by_lines = lines_iter.next().is_some();

    if !truncated_by_bytes && !truncated_by_lines {
        return content.to_string();
    }

    if preview.len() < truncated_slice.len()
        && truncated_slice
            .as_bytes()
            .get(preview.len())
            .is_some_and(|byte| *byte == b'\n')
    {
        preview.push('\n');
    }

    if !preview.is_empty() && !preview.ends_with('\n') {
        preview.push('\n');
    }
    preview.push_str(TELEMETRY_PREVIEW_TRUNCATION_NOTICE);

    preview
}

#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn custom_tool_calls_should_roundtrip_as_custom_outputs() {
        let payload = ToolPayload::Custom {
            input: "patch".to_string(),
        };
        let response = ToolOutput::Function {
            content: "patched".to_string(),
            content_items: None,
            success: Some(true),
        }
        .into_response("call-42", &payload);

        match response {
            ResponseInputItem::CustomToolCallOutput { call_id, output } => {
                assert_eq!(call_id, "call-42");
                assert_eq!(output, "patched");
            }
            other => panic!("expected CustomToolCallOutput, got {other:?}"),
        }
    }

    #[test]
    fn function_payloads_remain_function_outputs() {
        let payload = ToolPayload::Function {
            arguments: "{}".to_string(),
        };
        let response = ToolOutput::Function {
            content: "ok".to_string(),
            content_items: None,
            success: Some(true),
        }
        .into_response("fn-1", &payload);

        match response {
            ResponseInputItem::FunctionCallOutput { call_id, output } => {
                assert_eq!(call_id, "fn-1");
                assert_eq!(output.content, "ok");
                assert!(output.content_items.is_none());
                assert_eq!(output.success, Some(true));
            }
            other => panic!("expected FunctionCallOutput, got {other:?}"),
        }
    }

    #[test]
    fn telemetry_preview_returns_original_within_limits() {
        let content = "short output";
        assert_eq!(telemetry_preview(content), content);
    }

    #[test]
    fn telemetry_preview_truncates_by_bytes() {
        let content = "x".repeat(TELEMETRY_PREVIEW_MAX_BYTES + 8);
        let preview = telemetry_preview(&content);

        assert!(preview.contains(TELEMETRY_PREVIEW_TRUNCATION_NOTICE));
        assert!(
            preview.len()
                <= TELEMETRY_PREVIEW_MAX_BYTES + TELEMETRY_PREVIEW_TRUNCATION_NOTICE.len() + 1
        );
    }

    #[test]
    fn telemetry_preview_truncates_by_lines() {
        let content = (0..(TELEMETRY_PREVIEW_MAX_LINES + 5))
            .map(|idx| format!("line {idx}"))
            .collect::<Vec<_>>()
            .join("\n");

        let preview = telemetry_preview(&content);
        let lines: Vec<&str> = preview.lines().collect();

        assert!(lines.len() <= TELEMETRY_PREVIEW_MAX_LINES + 1);
        assert_eq!(lines.last(), Some(&TELEMETRY_PREVIEW_TRUNCATION_NOTICE));
    }
}

#[derive(Clone, Debug)]
#[allow(dead_code)]
pub(crate) struct ExecCommandContext {
    pub(crate) turn: Arc<TurnContext>,
    pub(crate) call_id: String,
    pub(crate) command_for_display: Vec<String>,
    pub(crate) cwd: PathBuf,
    pub(crate) apply_patch: Option<ApplyPatchCommandContext>,
    pub(crate) tool_name: String,
    pub(crate) otel_event_manager: OtelEventManager,
    // TODO(abhisek-oai): Find a better way to track this.
    // https://github.com/openai/codex/pull/2471/files#r2470352242
    pub(crate) is_user_shell_command: bool,
}

#[derive(Clone, Debug)]
#[allow(dead_code)]
pub(crate) struct ApplyPatchCommandContext {
    pub(crate) user_explicitly_approved_this_action: bool,
    pub(crate) changes: HashMap<PathBuf, FileChange>,
}
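`telemetry_preview` above depends on `take_bytes_at_char_boundary` never splitting a UTF-8 code point when applying the byte budget. A self-contained sketch of that boundary-safe truncation, built only on `str::is_char_boundary` from std (the `take_at_char_boundary` name is an illustrative stand-in for the `codex_utils_string` helper):

// Byte-budget truncation that never splits a UTF-8 code point, the
// behavior telemetry previews rely on.
fn take_at_char_boundary(s: &str, max_bytes: usize) -> &str {
    if s.len() <= max_bytes {
        return s;
    }
    let mut end = max_bytes;
    while !s.is_char_boundary(end) {
        end -= 1; // back up to the previous code-point boundary
    }
    &s[..end]
}

fn main() {
    let s = "héllo"; // 'é' is two bytes
    assert_eq!(take_at_char_boundary(s, 2), "h"); // can't split 'é'
    assert_eq!(take_at_char_boundary(s, 3), "hé");
    assert_eq!(take_at_char_boundary(s, 64), "héllo");
}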
369
llmx-rs/core/src/tools/events.rs
Normal file
@@ -0,0 +1,369 @@
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::error::CodexErr;
use crate::error::SandboxErr;
use crate::exec::ExecToolCallOutput;
use crate::function_tool::FunctionCallError;
use crate::parse_command::parse_command;
use crate::protocol::EventMsg;
use crate::protocol::ExecCommandBeginEvent;
use crate::protocol::ExecCommandEndEvent;
use crate::protocol::FileChange;
use crate::protocol::PatchApplyBeginEvent;
use crate::protocol::PatchApplyEndEvent;
use crate::protocol::TurnDiffEvent;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::sandboxing::ToolError;
use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
use std::time::Duration;

use super::format_exec_output_str;

#[derive(Clone, Copy)]
pub(crate) struct ToolEventCtx<'a> {
    pub session: &'a Session,
    pub turn: &'a TurnContext,
    pub call_id: &'a str,
    pub turn_diff_tracker: Option<&'a SharedTurnDiffTracker>,
}

impl<'a> ToolEventCtx<'a> {
    pub fn new(
        session: &'a Session,
        turn: &'a TurnContext,
        call_id: &'a str,
        turn_diff_tracker: Option<&'a SharedTurnDiffTracker>,
    ) -> Self {
        Self {
            session,
            turn,
            call_id,
            turn_diff_tracker,
        }
    }
}

pub(crate) enum ToolEventStage {
    Begin,
    Success(ExecToolCallOutput),
    Failure(ToolEventFailure),
}

pub(crate) enum ToolEventFailure {
    Output(ExecToolCallOutput),
    Message(String),
}

pub(crate) async fn emit_exec_command_begin(
    ctx: ToolEventCtx<'_>,
    command: &[String],
    cwd: &Path,
    is_user_shell_command: bool,
) {
    ctx.session
        .send_event(
            ctx.turn,
            EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
                call_id: ctx.call_id.to_string(),
                command: command.to_vec(),
                cwd: cwd.to_path_buf(),
                parsed_cmd: parse_command(command),
                is_user_shell_command,
            }),
        )
        .await;
}

// Concrete, allocation-free emitter: avoid trait objects and boxed futures.
pub(crate) enum ToolEmitter {
    Shell {
        command: Vec<String>,
        cwd: PathBuf,
        is_user_shell_command: bool,
    },
    ApplyPatch {
        changes: HashMap<PathBuf, FileChange>,
        auto_approved: bool,
    },
    UnifiedExec {
        command: String,
        cwd: PathBuf,
        // True for `exec_command` and false for `write_stdin`.
        #[allow(dead_code)]
        is_startup_command: bool,
    },
}

impl ToolEmitter {
    pub fn shell(command: Vec<String>, cwd: PathBuf, is_user_shell_command: bool) -> Self {
        Self::Shell {
            command,
            cwd,
            is_user_shell_command,
        }
    }

    pub fn apply_patch(changes: HashMap<PathBuf, FileChange>, auto_approved: bool) -> Self {
        Self::ApplyPatch {
            changes,
            auto_approved,
        }
    }

    pub fn unified_exec(command: String, cwd: PathBuf, is_startup_command: bool) -> Self {
        Self::UnifiedExec {
            command,
            cwd,
            is_startup_command,
        }
    }

    pub async fn emit(&self, ctx: ToolEventCtx<'_>, stage: ToolEventStage) {
        match (self, stage) {
            (
                Self::Shell {
                    command,
                    cwd,
                    is_user_shell_command,
                },
                ToolEventStage::Begin,
            ) => {
                emit_exec_command_begin(ctx, command, cwd.as_path(), *is_user_shell_command).await;
            }
            (Self::Shell { .. }, ToolEventStage::Success(output)) => {
                emit_exec_end(
                    ctx,
                    output.stdout.text.clone(),
                    output.stderr.text.clone(),
                    output.aggregated_output.text.clone(),
                    output.exit_code,
                    output.duration,
                    format_exec_output_str(&output),
                )
                .await;
            }
            (Self::Shell { .. }, ToolEventStage::Failure(ToolEventFailure::Output(output))) => {
                emit_exec_end(
                    ctx,
                    output.stdout.text.clone(),
                    output.stderr.text.clone(),
                    output.aggregated_output.text.clone(),
                    output.exit_code,
                    output.duration,
                    format_exec_output_str(&output),
                )
                .await;
            }
            (Self::Shell { .. }, ToolEventStage::Failure(ToolEventFailure::Message(message))) => {
                emit_exec_end(
                    ctx,
                    String::new(),
                    (*message).to_string(),
                    (*message).to_string(),
                    -1,
                    Duration::ZERO,
                    message.clone(),
                )
                .await;
            }

            (
                Self::ApplyPatch {
                    changes,
                    auto_approved,
                },
                ToolEventStage::Begin,
            ) => {
                if let Some(tracker) = ctx.turn_diff_tracker {
                    let mut guard = tracker.lock().await;
                    guard.on_patch_begin(changes);
                }
                ctx.session
                    .send_event(
                        ctx.turn,
                        EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
                            call_id: ctx.call_id.to_string(),
                            auto_approved: *auto_approved,
                            changes: changes.clone(),
                        }),
                    )
                    .await;
            }
            (Self::ApplyPatch { .. }, ToolEventStage::Success(output)) => {
                emit_patch_end(
                    ctx,
                    output.stdout.text.clone(),
                    output.stderr.text.clone(),
                    output.exit_code == 0,
                )
                .await;
            }
            (
                Self::ApplyPatch { .. },
                ToolEventStage::Failure(ToolEventFailure::Output(output)),
            ) => {
                emit_patch_end(
                    ctx,
                    output.stdout.text.clone(),
                    output.stderr.text.clone(),
                    output.exit_code == 0,
                )
                .await;
            }
            (
                Self::ApplyPatch { .. },
                ToolEventStage::Failure(ToolEventFailure::Message(message)),
            ) => {
                emit_patch_end(ctx, String::new(), (*message).to_string(), false).await;
            }
            (Self::UnifiedExec { command, cwd, .. }, ToolEventStage::Begin) => {
                emit_exec_command_begin(ctx, &[command.to_string()], cwd.as_path(), false).await;
            }
            (Self::UnifiedExec { .. }, ToolEventStage::Success(output)) => {
                emit_exec_end(
                    ctx,
                    output.stdout.text.clone(),
                    output.stderr.text.clone(),
                    output.aggregated_output.text.clone(),
                    output.exit_code,
                    output.duration,
                    format_exec_output_str(&output),
                )
                .await;
            }
            (
                Self::UnifiedExec { .. },
                ToolEventStage::Failure(ToolEventFailure::Output(output)),
            ) => {
                emit_exec_end(
                    ctx,
                    output.stdout.text.clone(),
                    output.stderr.text.clone(),
                    output.aggregated_output.text.clone(),
                    output.exit_code,
                    output.duration,
                    format_exec_output_str(&output),
                )
                .await;
            }
            (
                Self::UnifiedExec { .. },
                ToolEventStage::Failure(ToolEventFailure::Message(message)),
            ) => {
                emit_exec_end(
                    ctx,
                    String::new(),
                    (*message).to_string(),
                    (*message).to_string(),
                    -1,
                    Duration::ZERO,
                    message.clone(),
                )
                .await;
            }
        }
    }

    pub async fn begin(&self, ctx: ToolEventCtx<'_>) {
        self.emit(ctx, ToolEventStage::Begin).await;
    }

    pub async fn finish(
        &self,
        ctx: ToolEventCtx<'_>,
        out: Result<ExecToolCallOutput, ToolError>,
    ) -> Result<String, FunctionCallError> {
        let (event, result) = match out {
            Ok(output) => {
                let content = super::format_exec_output_for_model(&output);
                let exit_code = output.exit_code;
                let event = ToolEventStage::Success(output);
                let result = if exit_code == 0 {
                    Ok(content)
                } else {
                    Err(FunctionCallError::RespondToModel(content))
                };
                (event, result)
            }
            Err(ToolError::Codex(CodexErr::Sandbox(SandboxErr::Timeout { output })))
            | Err(ToolError::Codex(CodexErr::Sandbox(SandboxErr::Denied { output }))) => {
                let response = super::format_exec_output_for_model(&output);
                let event = ToolEventStage::Failure(ToolEventFailure::Output(*output));
                let result = Err(FunctionCallError::RespondToModel(response));
                (event, result)
            }
            Err(ToolError::Codex(err)) => {
                let message = format!("execution error: {err:?}");
                let event = ToolEventStage::Failure(ToolEventFailure::Message(message.clone()));
                let result = Err(FunctionCallError::RespondToModel(message));
                (event, result)
            }
            Err(ToolError::Rejected(msg)) => {
                // Normalize common rejection messages for exec tools so tests and
                // users see a clear, consistent phrase.
                let normalized = if msg == "rejected by user" {
                    "exec command rejected by user".to_string()
                } else {
                    msg
                };
                let event = ToolEventStage::Failure(ToolEventFailure::Message(normalized.clone()));
                let result = Err(FunctionCallError::RespondToModel(normalized));
                (event, result)
            }
        };
        self.emit(ctx, event).await;
        result
    }
}

async fn emit_exec_end(
|
||||
ctx: ToolEventCtx<'_>,
|
||||
stdout: String,
|
||||
stderr: String,
|
||||
aggregated_output: String,
|
||||
exit_code: i32,
|
||||
duration: Duration,
|
||||
formatted_output: String,
|
||||
) {
|
||||
ctx.session
|
||||
.send_event(
|
||||
ctx.turn,
|
||||
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
|
||||
call_id: ctx.call_id.to_string(),
|
||||
stdout,
|
||||
stderr,
|
||||
aggregated_output,
|
||||
exit_code,
|
||||
duration,
|
||||
formatted_output,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn emit_patch_end(ctx: ToolEventCtx<'_>, stdout: String, stderr: String, success: bool) {
|
||||
ctx.session
|
||||
.send_event(
|
||||
ctx.turn,
|
||||
EventMsg::PatchApplyEnd(PatchApplyEndEvent {
|
||||
call_id: ctx.call_id.to_string(),
|
||||
stdout,
|
||||
stderr,
|
||||
success,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
if let Some(tracker) = ctx.turn_diff_tracker {
|
||||
let unified_diff = {
|
||||
let mut guard = tracker.lock().await;
|
||||
guard.get_unified_diff()
|
||||
};
|
||||
if let Ok(Some(unified_diff)) = unified_diff {
|
||||
ctx.session
|
||||
.send_event(ctx.turn, EventMsg::TurnDiff(TurnDiffEvent { unified_diff }))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
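// --- Editor's sketch (not part of this commit): the intended ToolEmitter
// lifecycle. `begin` fires the Begin event, the orchestrator/runtime runs the
// tool, and `finish` maps the Result onto Success/Failure events while
// producing the model-facing string. `run_shell_tool` is a hypothetical
// stand-in for the real runtime call; everything else mirrors the handler
// code below, which builds a fresh ToolEventCtx for each emission.
async fn tool_emitter_lifecycle_sketch(
    session: &Session,
    turn: &TurnContext,
    call_id: &str,
    command: Vec<String>,
    cwd: PathBuf,
) -> Result<String, FunctionCallError> {
    let emitter = ToolEmitter::shell(command, cwd, /* is_user_shell_command */ false);
    emitter.begin(ToolEventCtx::new(session, turn, call_id, None)).await;
    // Hypothetical executor returning Result<ExecToolCallOutput, ToolError>.
    let out = run_shell_tool().await;
    emitter.finish(ToolEventCtx::new(session, turn, call_id, None), out).await
}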
265
llmx-rs/core/src/tools/handlers/apply_patch.rs
Normal file
@@ -0,0 +1,265 @@
use std::collections::BTreeMap;

use crate::apply_patch;
use crate::apply_patch::InternalApplyPatchInvocation;
use crate::apply_patch::convert_apply_patch_to_protocol;
use crate::client_common::tools::FreeformTool;
use crate::client_common::tools::FreeformToolFormat;
use crate::client_common::tools::ResponsesApiTool;
use crate::client_common::tools::ToolSpec;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::events::ToolEmitter;
use crate::tools::events::ToolEventCtx;
use crate::tools::orchestrator::ToolOrchestrator;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use crate::tools::runtimes::apply_patch::ApplyPatchRequest;
use crate::tools::runtimes::apply_patch::ApplyPatchRuntime;
use crate::tools::sandboxing::ToolCtx;
use crate::tools::spec::ApplyPatchToolArgs;
use crate::tools::spec::JsonSchema;
use async_trait::async_trait;
use serde::Deserialize;
use serde::Serialize;

pub struct ApplyPatchHandler;

const APPLY_PATCH_LARK_GRAMMAR: &str = include_str!("tool_apply_patch.lark");

#[async_trait]
impl ToolHandler for ApplyPatchHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        matches!(
            payload,
            ToolPayload::Function { .. } | ToolPayload::Custom { .. }
        )
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            turn,
            tracker,
            call_id,
            tool_name,
            payload,
        } = invocation;

        let patch_input = match payload {
            ToolPayload::Function { arguments } => {
                let args: ApplyPatchToolArgs = serde_json::from_str(&arguments).map_err(|e| {
                    FunctionCallError::RespondToModel(format!(
                        "failed to parse function arguments: {e:?}"
                    ))
                })?;
                args.input
            }
            ToolPayload::Custom { input } => input,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "apply_patch handler received unsupported payload".to_string(),
                ));
            }
        };

        // Re-parse and verify the patch so we can compute changes and approval.
        // Avoid building temporary ExecParams/command vectors; derive directly from inputs.
        let cwd = turn.cwd.clone();
        let command = vec!["apply_patch".to_string(), patch_input.clone()];
        match codex_apply_patch::maybe_parse_apply_patch_verified(&command, &cwd) {
            codex_apply_patch::MaybeApplyPatchVerified::Body(changes) => {
                match apply_patch::apply_patch(session.as_ref(), turn.as_ref(), &call_id, changes)
                    .await
                {
                    InternalApplyPatchInvocation::Output(item) => {
                        let content = item?;
                        Ok(ToolOutput::Function {
                            content,
                            content_items: None,
                            success: Some(true),
                        })
                    }
                    InternalApplyPatchInvocation::DelegateToExec(apply) => {
                        let emitter = ToolEmitter::apply_patch(
                            convert_apply_patch_to_protocol(&apply.action),
                            !apply.user_explicitly_approved_this_action,
                        );
                        let event_ctx = ToolEventCtx::new(
                            session.as_ref(),
                            turn.as_ref(),
                            &call_id,
                            Some(&tracker),
                        );
                        emitter.begin(event_ctx).await;

                        let req = ApplyPatchRequest {
                            patch: apply.action.patch.clone(),
                            cwd: apply.action.cwd.clone(),
                            timeout_ms: None,
                            user_explicitly_approved: apply.user_explicitly_approved_this_action,
                            codex_exe: turn.codex_linux_sandbox_exe.clone(),
                        };

                        let mut orchestrator = ToolOrchestrator::new();
                        let mut runtime = ApplyPatchRuntime::new();
                        let tool_ctx = ToolCtx {
                            session: session.as_ref(),
                            turn: turn.as_ref(),
                            call_id: call_id.clone(),
                            tool_name: tool_name.to_string(),
                        };
                        let out = orchestrator
                            .run(&mut runtime, &req, &tool_ctx, &turn, turn.approval_policy)
                            .await;
                        let event_ctx = ToolEventCtx::new(
                            session.as_ref(),
                            turn.as_ref(),
                            &call_id,
                            Some(&tracker),
                        );
                        let content = emitter.finish(event_ctx, out).await?;
                        Ok(ToolOutput::Function {
                            content,
                            content_items: None,
                            success: Some(true),
                        })
                    }
                }
            }
            codex_apply_patch::MaybeApplyPatchVerified::CorrectnessError(parse_error) => {
                Err(FunctionCallError::RespondToModel(format!(
                    "apply_patch verification failed: {parse_error}"
                )))
            }
            codex_apply_patch::MaybeApplyPatchVerified::ShellParseError(error) => {
                tracing::trace!("Failed to parse apply_patch input, {error:?}");
                Err(FunctionCallError::RespondToModel(
                    "apply_patch handler received invalid patch input".to_string(),
                ))
            }
            codex_apply_patch::MaybeApplyPatchVerified::NotApplyPatch => {
                Err(FunctionCallError::RespondToModel(
                    "apply_patch handler received non-apply_patch input".to_string(),
                ))
            }
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum ApplyPatchToolType {
    Freeform,
    Function,
}

/// Returns a custom tool that can be used to edit files. Well-suited for GPT-5 models
/// https://platform.openai.com/docs/guides/function-calling#custom-tools
pub(crate) fn create_apply_patch_freeform_tool() -> ToolSpec {
    ToolSpec::Freeform(FreeformTool {
        name: "apply_patch".to_string(),
        description: "Use the `apply_patch` tool to edit files. This is a FREEFORM tool, so do not wrap the patch in JSON.".to_string(),
        format: FreeformToolFormat {
            r#type: "grammar".to_string(),
            syntax: "lark".to_string(),
            definition: APPLY_PATCH_LARK_GRAMMAR.to_string(),
        },
    })
}

/// Returns a JSON tool that can be used to edit files. Should only be used with gpt-oss models.
pub(crate) fn create_apply_patch_json_tool() -> ToolSpec {
    let mut properties = BTreeMap::new();
    properties.insert(
        "input".to_string(),
        JsonSchema::String {
            description: Some(r#"The entire contents of the apply_patch command"#.to_string()),
        },
    );

    ToolSpec::Function(ResponsesApiTool {
        name: "apply_patch".to_string(),
        description: r#"Use the `apply_patch` tool to edit files.
Your patch language is a stripped‑down, file‑oriented diff format designed to be easy to parse and safe to apply. You can think of it as a high‑level envelope:

*** Begin Patch
[ one or more file sections ]
*** End Patch

Within that envelope, you get a sequence of file operations.
You MUST include a header to specify the action you are taking.
Each operation starts with one of three headers:

*** Add File: <path> - create a new file. Every following line is a + line (the initial contents).
*** Delete File: <path> - remove an existing file. Nothing follows.
*** Update File: <path> - patch an existing file in place (optionally with a rename).

May be immediately followed by *** Move to: <new path> if you want to rename the file.
Then one or more “hunks”, each introduced by @@ (optionally followed by a hunk header).
Within a hunk each line starts with:

For instructions on [context_before] and [context_after]:
- By default, show 3 lines of code immediately above and 3 lines immediately below each change. If a change is within 3 lines of a previous change, do NOT duplicate the first change’s [context_after] lines in the second change’s [context_before] lines.
- If 3 lines of context is insufficient to uniquely identify the snippet of code within the file, use the @@ operator to indicate the class or function to which the snippet belongs. For instance, we might have:
@@ class BaseClass
[3 lines of pre-context]
- [old_code]
+ [new_code]
[3 lines of post-context]

- If a code block is repeated so many times in a class or function such that even a single `@@` statement and 3 lines of context cannot uniquely identify the snippet of code, you can use multiple `@@` statements to jump to the right context. For instance:

@@ class BaseClass
@@ def method():
[3 lines of pre-context]
- [old_code]
+ [new_code]
[3 lines of post-context]

The full grammar definition is below:
Patch := Begin { FileOp } End
Begin := "*** Begin Patch" NEWLINE
End := "*** End Patch" NEWLINE
FileOp := AddFile | DeleteFile | UpdateFile
AddFile := "*** Add File: " path NEWLINE { "+" line NEWLINE }
DeleteFile := "*** Delete File: " path NEWLINE
UpdateFile := "*** Update File: " path NEWLINE [ MoveTo ] { Hunk }
MoveTo := "*** Move to: " newPath NEWLINE
Hunk := "@@" [ header ] NEWLINE { HunkLine } [ "*** End of File" NEWLINE ]
HunkLine := (" " | "-" | "+") text NEWLINE

A full patch can combine several operations:

*** Begin Patch
*** Add File: hello.txt
+Hello world
*** Update File: src/app.py
*** Move to: src/main.py
@@ def greet():
-print("Hi")
+print("Hello, world!")
*** Delete File: obsolete.txt
*** End Patch

It is important to remember:

- You must include a header with your intended action (Add/Delete/Update)
- You must prefix new lines with `+` even when creating a new file
- File references can only be relative, NEVER ABSOLUTE.
"#
        .to_string(),
        strict: false,
        parameters: JsonSchema::Object {
            properties,
            required: Some(vec!["input".to_string()]),
            additional_properties: Some(false.into()),
        },
    })
}
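// --- Editor's sketch (not part of this commit): the two payload shapes this
// handler accepts funnel into the same patch text. Field names come from the
// handler above; the patch content itself is illustrative.
fn apply_patch_payload_sketch() -> (String, String) {
    let patch = "*** Begin Patch\n*** Add File: hello.txt\n+Hello world\n*** End Patch\n";
    // Function form: JSON-wrapped, deserialized into ApplyPatchToolArgs.
    let function_arguments = serde_json::json!({ "input": patch }).to_string();
    // Freeform (custom) form: the raw patch text itself is the payload.
    let custom_input = patch.to_string();
    // Both are re-verified via maybe_parse_apply_patch_verified on
    // ["apply_patch", <patch>] before any filesystem change happens.
    (function_arguments, custom_input)
}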
274
llmx-rs/core/src/tools/handlers/grep_files.rs
Normal file
@@ -0,0 +1,274 @@
use std::path::Path;
use std::time::Duration;

use async_trait::async_trait;
use serde::Deserialize;
use tokio::process::Command;
use tokio::time::timeout;

use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

pub struct GrepFilesHandler;

const DEFAULT_LIMIT: usize = 100;
const MAX_LIMIT: usize = 2000;
const COMMAND_TIMEOUT: Duration = Duration::from_secs(30);

fn default_limit() -> usize {
    DEFAULT_LIMIT
}

#[derive(Deserialize)]
struct GrepFilesArgs {
    pattern: String,
    #[serde(default)]
    include: Option<String>,
    #[serde(default)]
    path: Option<String>,
    #[serde(default = "default_limit")]
    limit: usize,
}

#[async_trait]
impl ToolHandler for GrepFilesHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation { payload, turn, .. } = invocation;

        let arguments = match payload {
            ToolPayload::Function { arguments } => arguments,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "grep_files handler received unsupported payload".to_string(),
                ));
            }
        };

        let args: GrepFilesArgs = serde_json::from_str(&arguments).map_err(|err| {
            FunctionCallError::RespondToModel(format!(
                "failed to parse function arguments: {err:?}"
            ))
        })?;

        let pattern = args.pattern.trim();
        if pattern.is_empty() {
            return Err(FunctionCallError::RespondToModel(
                "pattern must not be empty".to_string(),
            ));
        }

        if args.limit == 0 {
            return Err(FunctionCallError::RespondToModel(
                "limit must be greater than zero".to_string(),
            ));
        }

        let limit = args.limit.min(MAX_LIMIT);
        let search_path = turn.resolve_path(args.path.clone());

        verify_path_exists(&search_path).await?;

        let include = args.include.as_deref().map(str::trim).and_then(|val| {
            if val.is_empty() {
                None
            } else {
                Some(val.to_string())
            }
        });

        let search_results =
            run_rg_search(pattern, include.as_deref(), &search_path, limit, &turn.cwd).await?;

        if search_results.is_empty() {
            Ok(ToolOutput::Function {
                content: "No matches found.".to_string(),
                content_items: None,
                success: Some(false),
            })
        } else {
            Ok(ToolOutput::Function {
                content: search_results.join("\n"),
                content_items: None,
                success: Some(true),
            })
        }
    }
}

async fn verify_path_exists(path: &Path) -> Result<(), FunctionCallError> {
    tokio::fs::metadata(path).await.map_err(|err| {
        FunctionCallError::RespondToModel(format!("unable to access `{}`: {err}", path.display()))
    })?;
    Ok(())
}

async fn run_rg_search(
    pattern: &str,
    include: Option<&str>,
    search_path: &Path,
    limit: usize,
    cwd: &Path,
) -> Result<Vec<String>, FunctionCallError> {
    let mut command = Command::new("rg");
    command
        .current_dir(cwd)
        .arg("--files-with-matches")
        .arg("--sortr=modified")
        .arg("--regexp")
        .arg(pattern)
        .arg("--no-messages");

    if let Some(glob) = include {
        command.arg("--glob").arg(glob);
    }

    command.arg("--").arg(search_path);

    let output = timeout(COMMAND_TIMEOUT, command.output())
        .await
        .map_err(|_| {
            FunctionCallError::RespondToModel("rg timed out after 30 seconds".to_string())
        })?
        .map_err(|err| {
            FunctionCallError::RespondToModel(format!(
                "failed to launch rg: {err}. Ensure ripgrep is installed and on PATH."
            ))
        })?;

    match output.status.code() {
        Some(0) => Ok(parse_results(&output.stdout, limit)),
        Some(1) => Ok(Vec::new()),
        _ => {
            let stderr = String::from_utf8_lossy(&output.stderr);
            Err(FunctionCallError::RespondToModel(format!(
                "rg failed: {stderr}"
            )))
        }
    }
}

fn parse_results(stdout: &[u8], limit: usize) -> Vec<String> {
    let mut results = Vec::new();
    for line in stdout.split(|byte| *byte == b'\n') {
        if line.is_empty() {
            continue;
        }
        if let Ok(text) = std::str::from_utf8(line) {
            if text.is_empty() {
                continue;
            }
            results.push(text.to_string());
            if results.len() == limit {
                break;
            }
        }
    }
    results
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::process::Command as StdCommand;
    use tempfile::tempdir;

    #[test]
    fn parses_basic_results() {
        let stdout = b"/tmp/file_a.rs\n/tmp/file_b.rs\n";
        let parsed = parse_results(stdout, 10);
        assert_eq!(
            parsed,
            vec!["/tmp/file_a.rs".to_string(), "/tmp/file_b.rs".to_string()]
        );
    }

    #[test]
    fn parse_truncates_after_limit() {
        let stdout = b"/tmp/file_a.rs\n/tmp/file_b.rs\n/tmp/file_c.rs\n";
        let parsed = parse_results(stdout, 2);
        assert_eq!(
            parsed,
            vec!["/tmp/file_a.rs".to_string(), "/tmp/file_b.rs".to_string()]
        );
    }

    #[tokio::test]
    async fn run_search_returns_results() -> anyhow::Result<()> {
        if !rg_available() {
            return Ok(());
        }
        let temp = tempdir().expect("create temp dir");
        let dir = temp.path();
        std::fs::write(dir.join("match_one.txt"), "alpha beta gamma").unwrap();
        std::fs::write(dir.join("match_two.txt"), "alpha delta").unwrap();
        std::fs::write(dir.join("other.txt"), "omega").unwrap();

        let results = run_rg_search("alpha", None, dir, 10, dir).await?;
        assert_eq!(results.len(), 2);
        assert!(results.iter().any(|path| path.ends_with("match_one.txt")));
        assert!(results.iter().any(|path| path.ends_with("match_two.txt")));
        Ok(())
    }

    #[tokio::test]
    async fn run_search_with_glob_filter() -> anyhow::Result<()> {
        if !rg_available() {
            return Ok(());
        }
        let temp = tempdir().expect("create temp dir");
        let dir = temp.path();
        std::fs::write(dir.join("match_one.rs"), "alpha beta gamma").unwrap();
        std::fs::write(dir.join("match_two.txt"), "alpha delta").unwrap();

        let results = run_rg_search("alpha", Some("*.rs"), dir, 10, dir).await?;
        assert_eq!(results.len(), 1);
        assert!(results.iter().all(|path| path.ends_with("match_one.rs")));
        Ok(())
    }

    #[tokio::test]
    async fn run_search_respects_limit() -> anyhow::Result<()> {
        if !rg_available() {
            return Ok(());
        }
        let temp = tempdir().expect("create temp dir");
        let dir = temp.path();
        std::fs::write(dir.join("one.txt"), "alpha one").unwrap();
        std::fs::write(dir.join("two.txt"), "alpha two").unwrap();
        std::fs::write(dir.join("three.txt"), "alpha three").unwrap();

        let results = run_rg_search("alpha", None, dir, 2, dir).await?;
        assert_eq!(results.len(), 2);
        Ok(())
    }

    #[tokio::test]
    async fn run_search_handles_no_matches() -> anyhow::Result<()> {
        if !rg_available() {
            return Ok(());
        }
        let temp = tempdir().expect("create temp dir");
        let dir = temp.path();
        std::fs::write(dir.join("one.txt"), "omega").unwrap();

        let results = run_rg_search("alpha", None, dir, 5, dir).await?;
        assert!(results.is_empty());
        Ok(())
    }

    fn rg_available() -> bool {
        StdCommand::new("rg")
            .arg("--version")
            .output()
            .map(|output| output.status.success())
            .unwrap_or(false)
    }
}
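// --- Editor's note (not part of this commit): run_rg_search shells out to
// ripgrep rather than reimplementing search; for {"pattern": "alpha",
// "include": "*.rs"} the invocation built above is equivalent to
// `rg --files-with-matches --sortr=modified --regexp alpha --no-messages
// --glob *.rs -- <path>`, with truncation to `limit` applied to the output
// rather than passed to rg. A minimal sketch of the serde defaults, assuming
// only `pattern` is sent:
#[test]
fn grep_files_args_defaults_sketch() {
    let args: GrepFilesArgs = serde_json::from_str(r#"{"pattern":"alpha"}"#).expect("valid args");
    assert_eq!(args.limit, 100); // default_limit()
    assert!(args.include.is_none() && args.path.is_none());
}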
477
llmx-rs/core/src/tools/handlers/list_dir.rs
Normal file
@@ -0,0 +1,477 @@
use std::collections::VecDeque;
use std::ffi::OsStr;
use std::fs::FileType;
use std::path::Path;
use std::path::PathBuf;

use async_trait::async_trait;
use codex_utils_string::take_bytes_at_char_boundary;
use serde::Deserialize;
use tokio::fs;

use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

pub struct ListDirHandler;

const MAX_ENTRY_LENGTH: usize = 500;
const INDENTATION_SPACES: usize = 2;

fn default_offset() -> usize {
    1
}

fn default_limit() -> usize {
    25
}

fn default_depth() -> usize {
    2
}

#[derive(Deserialize)]
struct ListDirArgs {
    dir_path: String,
    #[serde(default = "default_offset")]
    offset: usize,
    #[serde(default = "default_limit")]
    limit: usize,
    #[serde(default = "default_depth")]
    depth: usize,
}

#[async_trait]
impl ToolHandler for ListDirHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation { payload, .. } = invocation;

        let arguments = match payload {
            ToolPayload::Function { arguments } => arguments,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "list_dir handler received unsupported payload".to_string(),
                ));
            }
        };

        let args: ListDirArgs = serde_json::from_str(&arguments).map_err(|err| {
            FunctionCallError::RespondToModel(format!(
                "failed to parse function arguments: {err:?}"
            ))
        })?;

        let ListDirArgs {
            dir_path,
            offset,
            limit,
            depth,
        } = args;

        if offset == 0 {
            return Err(FunctionCallError::RespondToModel(
                "offset must be a 1-indexed entry number".to_string(),
            ));
        }

        if limit == 0 {
            return Err(FunctionCallError::RespondToModel(
                "limit must be greater than zero".to_string(),
            ));
        }

        if depth == 0 {
            return Err(FunctionCallError::RespondToModel(
                "depth must be greater than zero".to_string(),
            ));
        }

        let path = PathBuf::from(&dir_path);
        if !path.is_absolute() {
            return Err(FunctionCallError::RespondToModel(
                "dir_path must be an absolute path".to_string(),
            ));
        }

        let entries = list_dir_slice(&path, offset, limit, depth).await?;
        let mut output = Vec::with_capacity(entries.len() + 1);
        output.push(format!("Absolute path: {}", path.display()));
        output.extend(entries);
        Ok(ToolOutput::Function {
            content: output.join("\n"),
            content_items: None,
            success: Some(true),
        })
    }
}

async fn list_dir_slice(
    path: &Path,
    offset: usize,
    limit: usize,
    depth: usize,
) -> Result<Vec<String>, FunctionCallError> {
    let mut entries = Vec::new();
    collect_entries(path, Path::new(""), depth, &mut entries).await?;

    if entries.is_empty() {
        return Ok(Vec::new());
    }

    let start_index = offset - 1;
    if start_index >= entries.len() {
        return Err(FunctionCallError::RespondToModel(
            "offset exceeds directory entry count".to_string(),
        ));
    }

    let remaining_entries = entries.len() - start_index;
    let capped_limit = limit.min(remaining_entries);
    let end_index = start_index + capped_limit;
    let mut selected_entries = entries[start_index..end_index].to_vec();
    selected_entries.sort_unstable_by(|a, b| a.name.cmp(&b.name));
    let mut formatted = Vec::with_capacity(selected_entries.len());

    for entry in &selected_entries {
        formatted.push(format_entry_line(entry));
    }

    if end_index < entries.len() {
        formatted.push(format!("More than {capped_limit} entries found"));
    }

    Ok(formatted)
}

async fn collect_entries(
    dir_path: &Path,
    relative_prefix: &Path,
    depth: usize,
    entries: &mut Vec<DirEntry>,
) -> Result<(), FunctionCallError> {
    let mut queue = VecDeque::new();
    queue.push_back((dir_path.to_path_buf(), relative_prefix.to_path_buf(), depth));

    while let Some((current_dir, prefix, remaining_depth)) = queue.pop_front() {
        let mut read_dir = fs::read_dir(&current_dir).await.map_err(|err| {
            FunctionCallError::RespondToModel(format!("failed to read directory: {err}"))
        })?;

        let mut dir_entries = Vec::new();

        while let Some(entry) = read_dir.next_entry().await.map_err(|err| {
            FunctionCallError::RespondToModel(format!("failed to read directory: {err}"))
        })? {
            let file_type = entry.file_type().await.map_err(|err| {
                FunctionCallError::RespondToModel(format!("failed to inspect entry: {err}"))
            })?;

            let file_name = entry.file_name();
            let relative_path = if prefix.as_os_str().is_empty() {
                PathBuf::from(&file_name)
            } else {
                prefix.join(&file_name)
            };

            let display_name = format_entry_component(&file_name);
            let display_depth = prefix.components().count();
            let sort_key = format_entry_name(&relative_path);
            let kind = DirEntryKind::from(&file_type);
            dir_entries.push((
                entry.path(),
                relative_path,
                kind,
                DirEntry {
                    name: sort_key,
                    display_name,
                    depth: display_depth,
                    kind,
                },
            ));
        }

        dir_entries.sort_unstable_by(|a, b| a.3.name.cmp(&b.3.name));

        for (entry_path, relative_path, kind, dir_entry) in dir_entries {
            if kind == DirEntryKind::Directory && remaining_depth > 1 {
                queue.push_back((entry_path, relative_path, remaining_depth - 1));
            }
            entries.push(dir_entry);
        }
    }

    Ok(())
}

fn format_entry_name(path: &Path) -> String {
    let normalized = path.to_string_lossy().replace("\\", "/");
    if normalized.len() > MAX_ENTRY_LENGTH {
        take_bytes_at_char_boundary(&normalized, MAX_ENTRY_LENGTH).to_string()
    } else {
        normalized
    }
}

fn format_entry_component(name: &OsStr) -> String {
    let normalized = name.to_string_lossy();
    if normalized.len() > MAX_ENTRY_LENGTH {
        take_bytes_at_char_boundary(&normalized, MAX_ENTRY_LENGTH).to_string()
    } else {
        normalized.to_string()
    }
}

fn format_entry_line(entry: &DirEntry) -> String {
    let indent = " ".repeat(entry.depth * INDENTATION_SPACES);
    let mut name = entry.display_name.clone();
    match entry.kind {
        DirEntryKind::Directory => name.push('/'),
        DirEntryKind::Symlink => name.push('@'),
        DirEntryKind::Other => name.push('?'),
        DirEntryKind::File => {}
    }
    format!("{indent}{name}")
}

#[derive(Clone)]
struct DirEntry {
    name: String,
    display_name: String,
    depth: usize,
    kind: DirEntryKind,
}

#[derive(Clone, Copy, PartialEq, Eq)]
enum DirEntryKind {
    Directory,
    File,
    Symlink,
    Other,
}

impl From<&FileType> for DirEntryKind {
    fn from(file_type: &FileType) -> Self {
        if file_type.is_symlink() {
            DirEntryKind::Symlink
        } else if file_type.is_dir() {
            DirEntryKind::Directory
        } else if file_type.is_file() {
            DirEntryKind::File
        } else {
            DirEntryKind::Other
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[tokio::test]
    async fn lists_directory_entries() {
        let temp = tempdir().expect("create tempdir");
        let dir_path = temp.path();

        let sub_dir = dir_path.join("nested");
        tokio::fs::create_dir(&sub_dir)
            .await
            .expect("create sub dir");

        let deeper_dir = sub_dir.join("deeper");
        tokio::fs::create_dir(&deeper_dir)
            .await
            .expect("create deeper dir");

        tokio::fs::write(dir_path.join("entry.txt"), b"content")
            .await
            .expect("write file");
        tokio::fs::write(sub_dir.join("child.txt"), b"child")
            .await
            .expect("write child");
        tokio::fs::write(deeper_dir.join("grandchild.txt"), b"grandchild")
            .await
            .expect("write grandchild");

        #[cfg(unix)]
        {
            use std::os::unix::fs::symlink;
            let link_path = dir_path.join("link");
            symlink(dir_path.join("entry.txt"), &link_path).expect("create symlink");
        }

        let entries = list_dir_slice(dir_path, 1, 20, 3)
            .await
            .expect("list directory");

        #[cfg(unix)]
        let expected = vec![
            "entry.txt".to_string(),
            "link@".to_string(),
            "nested/".to_string(),
            "  child.txt".to_string(),
            "  deeper/".to_string(),
            "    grandchild.txt".to_string(),
        ];

        #[cfg(not(unix))]
        let expected = vec![
            "entry.txt".to_string(),
            "nested/".to_string(),
            "  child.txt".to_string(),
            "  deeper/".to_string(),
            "    grandchild.txt".to_string(),
        ];

        assert_eq!(entries, expected);
    }

    #[tokio::test]
    async fn errors_when_offset_exceeds_entries() {
        let temp = tempdir().expect("create tempdir");
        let dir_path = temp.path();
        tokio::fs::create_dir(dir_path.join("nested"))
            .await
            .expect("create sub dir");

        let err = list_dir_slice(dir_path, 10, 1, 2)
            .await
            .expect_err("offset exceeds entries");
        assert_eq!(
            err,
            FunctionCallError::RespondToModel("offset exceeds directory entry count".to_string())
        );
    }

    #[tokio::test]
    async fn respects_depth_parameter() {
        let temp = tempdir().expect("create tempdir");
        let dir_path = temp.path();
        let nested = dir_path.join("nested");
        let deeper = nested.join("deeper");
        tokio::fs::create_dir(&nested).await.expect("create nested");
        tokio::fs::create_dir(&deeper).await.expect("create deeper");
        tokio::fs::write(dir_path.join("root.txt"), b"root")
            .await
            .expect("write root");
        tokio::fs::write(nested.join("child.txt"), b"child")
            .await
            .expect("write nested");
        tokio::fs::write(deeper.join("grandchild.txt"), b"deep")
            .await
            .expect("write deeper");

        let entries_depth_one = list_dir_slice(dir_path, 1, 10, 1)
            .await
            .expect("list depth 1");
        assert_eq!(
            entries_depth_one,
            vec!["nested/".to_string(), "root.txt".to_string(),]
        );

        let entries_depth_two = list_dir_slice(dir_path, 1, 20, 2)
            .await
            .expect("list depth 2");
        assert_eq!(
            entries_depth_two,
            vec![
                "nested/".to_string(),
                "  child.txt".to_string(),
                "  deeper/".to_string(),
                "root.txt".to_string(),
            ]
        );

        let entries_depth_three = list_dir_slice(dir_path, 1, 30, 3)
            .await
            .expect("list depth 3");
        assert_eq!(
            entries_depth_three,
            vec![
                "nested/".to_string(),
                "  child.txt".to_string(),
                "  deeper/".to_string(),
                "    grandchild.txt".to_string(),
                "root.txt".to_string(),
            ]
        );
    }

    #[tokio::test]
    async fn handles_large_limit_without_overflow() {
        let temp = tempdir().expect("create tempdir");
        let dir_path = temp.path();
        tokio::fs::write(dir_path.join("alpha.txt"), b"alpha")
            .await
            .expect("write alpha");
        tokio::fs::write(dir_path.join("beta.txt"), b"beta")
            .await
            .expect("write beta");
        tokio::fs::write(dir_path.join("gamma.txt"), b"gamma")
            .await
            .expect("write gamma");

        let entries = list_dir_slice(dir_path, 2, usize::MAX, 1)
            .await
            .expect("list without overflow");
        assert_eq!(
            entries,
            vec!["beta.txt".to_string(), "gamma.txt".to_string(),]
        );
    }

    #[tokio::test]
    async fn indicates_truncated_results() {
        let temp = tempdir().expect("create tempdir");
        let dir_path = temp.path();

        for idx in 0..40 {
            let file = dir_path.join(format!("file_{idx:02}.txt"));
            tokio::fs::write(file, b"content")
                .await
                .expect("write file");
        }

        let entries = list_dir_slice(dir_path, 1, 25, 1)
            .await
            .expect("list directory");
        assert_eq!(entries.len(), 26);
        assert_eq!(
            entries.last(),
            Some(&"More than 25 entries found".to_string())
        );
    }

    #[tokio::test]
    async fn bfs_truncation() -> anyhow::Result<()> {
        let temp = tempdir()?;
        let dir_path = temp.path();
        let nested = dir_path.join("nested");
        let deeper = nested.join("deeper");
        tokio::fs::create_dir(&nested).await?;
        tokio::fs::create_dir(&deeper).await?;
        tokio::fs::write(dir_path.join("root.txt"), b"root").await?;
        tokio::fs::write(nested.join("child.txt"), b"child").await?;
        tokio::fs::write(deeper.join("grandchild.txt"), b"deep").await?;

        let entries_depth_three = list_dir_slice(dir_path, 1, 3, 3).await?;
        assert_eq!(
            entries_depth_three,
            vec![
                "nested/".to_string(),
                "  child.txt".to_string(),
                "root.txt".to_string(),
                "More than 3 entries found".to_string()
            ]
        );

        Ok(())
    }
}
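// --- Editor's sketch (not part of this commit): the pagination semantics of
// list_dir_slice. Entries are gathered breadth-first, sliced by the 1-indexed
// offset and limit, and a trailing "More than N entries found" marker is
// appended when the slice truncates (see bfs_truncation above). The values
// here mirror the serde defaults in ListDirArgs.
async fn list_dir_first_page_sketch(dir: &Path) -> Result<Vec<String>, FunctionCallError> {
    // Entries 1..=25 at depth 2, i.e. the default first page.
    list_dir_slice(dir, 1, 25, 2).await
}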
75
llmx-rs/core/src/tools/handlers/mcp.rs
Normal file
@@ -0,0 +1,75 @@
use async_trait::async_trait;

use crate::function_tool::FunctionCallError;
use crate::mcp_tool_call::handle_mcp_tool_call;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

pub struct McpHandler;

#[async_trait]
impl ToolHandler for McpHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Mcp
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            turn,
            call_id,
            payload,
            ..
        } = invocation;

        let payload = match payload {
            ToolPayload::Mcp {
                server,
                tool,
                raw_arguments,
            } => (server, tool, raw_arguments),
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "mcp handler received unsupported payload".to_string(),
                ));
            }
        };

        let (server, tool, raw_arguments) = payload;
        let arguments_str = raw_arguments;

        let response = handle_mcp_tool_call(
            session.as_ref(),
            turn.as_ref(),
            call_id.clone(),
            server,
            tool,
            arguments_str,
        )
        .await;

        match response {
            codex_protocol::models::ResponseInputItem::McpToolCallOutput { result, .. } => {
                Ok(ToolOutput::Mcp { result })
            }
            codex_protocol::models::ResponseInputItem::FunctionCallOutput { output, .. } => {
                let codex_protocol::models::FunctionCallOutputPayload {
                    content,
                    content_items,
                    success,
                } = output;
                Ok(ToolOutput::Function {
                    content,
                    content_items,
                    success,
                })
            }
            _ => Err(FunctionCallError::RespondToModel(
                "mcp handler received unexpected response variant".to_string(),
            )),
        }
    }
}
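// --- Editor's sketch (not part of this commit): the payload variant this
// handler expects. The server/tool names and arguments JSON are illustrative
// only; raw_arguments is forwarded verbatim to handle_mcp_tool_call.
fn mcp_payload_sketch() -> ToolPayload {
    ToolPayload::Mcp {
        server: "example-server".to_string(),
        tool: "example_tool".to_string(),
        raw_arguments: r#"{"query":"llmx"}"#.to_string(),
    }
}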
789
llmx-rs/core/src/tools/handlers/mcp_resource.rs
Normal file
@@ -0,0 +1,789 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;

use async_trait::async_trait;
use mcp_types::CallToolResult;
use mcp_types::ContentBlock;
use mcp_types::ListResourceTemplatesRequestParams;
use mcp_types::ListResourceTemplatesResult;
use mcp_types::ListResourcesRequestParams;
use mcp_types::ListResourcesResult;
use mcp_types::ReadResourceRequestParams;
use mcp_types::ReadResourceResult;
use mcp_types::Resource;
use mcp_types::ResourceTemplate;
use mcp_types::TextContent;
use serde::Deserialize;
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde_json::Value;

use crate::codex::Session;
use crate::codex::TurnContext;
use crate::function_tool::FunctionCallError;
use crate::protocol::EventMsg;
use crate::protocol::McpInvocation;
use crate::protocol::McpToolCallBeginEvent;
use crate::protocol::McpToolCallEndEvent;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

pub struct McpResourceHandler;

#[derive(Debug, Deserialize, Default)]
struct ListResourcesArgs {
    /// Lists all resources from all servers if not specified.
    #[serde(default)]
    server: Option<String>,
    #[serde(default)]
    cursor: Option<String>,
}

#[derive(Debug, Deserialize, Default)]
struct ListResourceTemplatesArgs {
    /// Lists all resource templates from all servers if not specified.
    #[serde(default)]
    server: Option<String>,
    #[serde(default)]
    cursor: Option<String>,
}

#[derive(Debug, Deserialize)]
struct ReadResourceArgs {
    server: String,
    uri: String,
}

#[derive(Debug, Serialize)]
struct ResourceWithServer {
    server: String,
    #[serde(flatten)]
    resource: Resource,
}

impl ResourceWithServer {
    fn new(server: String, resource: Resource) -> Self {
        Self { server, resource }
    }
}

#[derive(Debug, Serialize)]
struct ResourceTemplateWithServer {
    server: String,
    #[serde(flatten)]
    template: ResourceTemplate,
}

impl ResourceTemplateWithServer {
    fn new(server: String, template: ResourceTemplate) -> Self {
        Self { server, template }
    }
}

#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ListResourcesPayload {
    #[serde(skip_serializing_if = "Option::is_none")]
    server: Option<String>,
    resources: Vec<ResourceWithServer>,
    #[serde(skip_serializing_if = "Option::is_none")]
    next_cursor: Option<String>,
}

impl ListResourcesPayload {
    fn from_single_server(server: String, result: ListResourcesResult) -> Self {
        let resources = result
            .resources
            .into_iter()
            .map(|resource| ResourceWithServer::new(server.clone(), resource))
            .collect();
        Self {
            server: Some(server),
            resources,
            next_cursor: result.next_cursor,
        }
    }

    fn from_all_servers(resources_by_server: HashMap<String, Vec<Resource>>) -> Self {
        let mut entries: Vec<(String, Vec<Resource>)> = resources_by_server.into_iter().collect();
        entries.sort_by(|a, b| a.0.cmp(&b.0));

        let mut resources = Vec::new();
        for (server, server_resources) in entries {
            for resource in server_resources {
                resources.push(ResourceWithServer::new(server.clone(), resource));
            }
        }

        Self {
            server: None,
            resources,
            next_cursor: None,
        }
    }
}

#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ListResourceTemplatesPayload {
    #[serde(skip_serializing_if = "Option::is_none")]
    server: Option<String>,
    resource_templates: Vec<ResourceTemplateWithServer>,
    #[serde(skip_serializing_if = "Option::is_none")]
    next_cursor: Option<String>,
}

impl ListResourceTemplatesPayload {
    fn from_single_server(server: String, result: ListResourceTemplatesResult) -> Self {
        let resource_templates = result
            .resource_templates
            .into_iter()
            .map(|template| ResourceTemplateWithServer::new(server.clone(), template))
            .collect();
        Self {
            server: Some(server),
            resource_templates,
            next_cursor: result.next_cursor,
        }
    }

    fn from_all_servers(templates_by_server: HashMap<String, Vec<ResourceTemplate>>) -> Self {
        let mut entries: Vec<(String, Vec<ResourceTemplate>)> =
            templates_by_server.into_iter().collect();
        entries.sort_by(|a, b| a.0.cmp(&b.0));

        let mut resource_templates = Vec::new();
        for (server, server_templates) in entries {
            for template in server_templates {
                resource_templates.push(ResourceTemplateWithServer::new(server.clone(), template));
            }
        }

        Self {
            server: None,
            resource_templates,
            next_cursor: None,
        }
    }
}

#[derive(Debug, Serialize)]
struct ReadResourcePayload {
    server: String,
    uri: String,
    #[serde(flatten)]
    result: ReadResourceResult,
}

#[async_trait]
impl ToolHandler for McpResourceHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            turn,
            call_id,
            tool_name,
            payload,
            ..
        } = invocation;

        let arguments = match payload {
            ToolPayload::Function { arguments } => arguments,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "mcp_resource handler received unsupported payload".to_string(),
                ));
            }
        };

        let arguments_value = parse_arguments(arguments.as_str())?;

        match tool_name.as_str() {
            "list_mcp_resources" => {
                handle_list_resources(
                    Arc::clone(&session),
                    Arc::clone(&turn),
                    call_id.clone(),
                    arguments_value.clone(),
                )
                .await
            }
            "list_mcp_resource_templates" => {
                handle_list_resource_templates(
                    Arc::clone(&session),
                    Arc::clone(&turn),
                    call_id.clone(),
                    arguments_value.clone(),
                )
                .await
            }
            "read_mcp_resource" => {
                handle_read_resource(
                    Arc::clone(&session),
                    Arc::clone(&turn),
                    call_id,
                    arguments_value,
                )
                .await
            }
            other => Err(FunctionCallError::RespondToModel(format!(
                "unsupported MCP resource tool: {other}"
            ))),
        }
    }
}

async fn handle_list_resources(
    session: Arc<Session>,
    turn: Arc<TurnContext>,
    call_id: String,
    arguments: Option<Value>,
) -> Result<ToolOutput, FunctionCallError> {
    let args: ListResourcesArgs = parse_args_with_default(arguments.clone())?;
    let ListResourcesArgs { server, cursor } = args;
    let server = normalize_optional_string(server);
    let cursor = normalize_optional_string(cursor);

    let invocation = McpInvocation {
        server: server.clone().unwrap_or_else(|| "codex".to_string()),
        tool: "list_mcp_resources".to_string(),
        arguments: arguments.clone(),
    };

    emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await;
    let start = Instant::now();

    let payload_result: Result<ListResourcesPayload, FunctionCallError> = async {
        if let Some(server_name) = server.clone() {
            let params = cursor.clone().map(|value| ListResourcesRequestParams {
                cursor: Some(value),
            });
            let result = session
                .list_resources(&server_name, params)
                .await
                .map_err(|err| {
                    FunctionCallError::RespondToModel(format!("resources/list failed: {err:#}"))
                })?;
            Ok(ListResourcesPayload::from_single_server(
                server_name,
                result,
            ))
        } else {
            if cursor.is_some() {
                return Err(FunctionCallError::RespondToModel(
                    "cursor can only be used when a server is specified".to_string(),
                ));
            }

            let resources = session
                .services
                .mcp_connection_manager
                .list_all_resources()
                .await;
            Ok(ListResourcesPayload::from_all_servers(resources))
        }
    }
    .await;

    match payload_result {
        Ok(payload) => match serialize_function_output(payload) {
            Ok(output) => {
                let ToolOutput::Function {
                    content, success, ..
                } = &output
                else {
                    unreachable!("MCP resource handler should return function output");
                };
                let duration = start.elapsed();
                emit_tool_call_end(
                    &session,
                    turn.as_ref(),
                    &call_id,
                    invocation,
                    duration,
                    Ok(call_tool_result_from_content(content, *success)),
                )
                .await;
                Ok(output)
            }
            Err(err) => {
                let duration = start.elapsed();
                let message = err.to_string();
                emit_tool_call_end(
                    &session,
                    turn.as_ref(),
                    &call_id,
                    invocation,
                    duration,
                    Err(message.clone()),
                )
                .await;
                Err(err)
            }
        },
        Err(err) => {
            let duration = start.elapsed();
            let message = err.to_string();
            emit_tool_call_end(
                &session,
                turn.as_ref(),
                &call_id,
                invocation,
                duration,
                Err(message.clone()),
            )
            .await;
            Err(err)
        }
    }
}

async fn handle_list_resource_templates(
    session: Arc<Session>,
    turn: Arc<TurnContext>,
    call_id: String,
    arguments: Option<Value>,
) -> Result<ToolOutput, FunctionCallError> {
    let args: ListResourceTemplatesArgs = parse_args_with_default(arguments.clone())?;
    let ListResourceTemplatesArgs { server, cursor } = args;
    let server = normalize_optional_string(server);
    let cursor = normalize_optional_string(cursor);

    let invocation = McpInvocation {
        server: server.clone().unwrap_or_else(|| "codex".to_string()),
        tool: "list_mcp_resource_templates".to_string(),
        arguments: arguments.clone(),
    };

    emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await;
    let start = Instant::now();

    let payload_result: Result<ListResourceTemplatesPayload, FunctionCallError> = async {
        if let Some(server_name) = server.clone() {
            let params = cursor
                .clone()
                .map(|value| ListResourceTemplatesRequestParams {
                    cursor: Some(value),
                });
            let result = session
                .list_resource_templates(&server_name, params)
                .await
                .map_err(|err| {
                    FunctionCallError::RespondToModel(format!(
                        "resources/templates/list failed: {err:#}"
                    ))
                })?;
            Ok(ListResourceTemplatesPayload::from_single_server(
                server_name,
                result,
            ))
        } else {
            if cursor.is_some() {
                return Err(FunctionCallError::RespondToModel(
                    "cursor can only be used when a server is specified".to_string(),
                ));
            }

            let templates = session
                .services
                .mcp_connection_manager
                .list_all_resource_templates()
                .await;
            Ok(ListResourceTemplatesPayload::from_all_servers(templates))
        }
    }
    .await;

    match payload_result {
        Ok(payload) => match serialize_function_output(payload) {
            Ok(output) => {
                let ToolOutput::Function {
                    content, success, ..
                } = &output
                else {
                    unreachable!("MCP resource handler should return function output");
                };
                let duration = start.elapsed();
                emit_tool_call_end(
                    &session,
                    turn.as_ref(),
                    &call_id,
                    invocation,
                    duration,
                    Ok(call_tool_result_from_content(content, *success)),
                )
                .await;
                Ok(output)
            }
            Err(err) => {
                let duration = start.elapsed();
                let message = err.to_string();
                emit_tool_call_end(
                    &session,
                    turn.as_ref(),
                    &call_id,
                    invocation,
                    duration,
                    Err(message.clone()),
                )
                .await;
                Err(err)
            }
        },
        Err(err) => {
            let duration = start.elapsed();
            let message = err.to_string();
            emit_tool_call_end(
                &session,
                turn.as_ref(),
                &call_id,
                invocation,
                duration,
                Err(message.clone()),
            )
            .await;
            Err(err)
        }
    }
}

async fn handle_read_resource(
    session: Arc<Session>,
    turn: Arc<TurnContext>,
    call_id: String,
    arguments: Option<Value>,
) -> Result<ToolOutput, FunctionCallError> {
    let args: ReadResourceArgs = parse_args(arguments.clone())?;
    let ReadResourceArgs { server, uri } = args;
    let server = normalize_required_string("server", server)?;
    let uri = normalize_required_string("uri", uri)?;

    let invocation = McpInvocation {
        server: server.clone(),
        tool: "read_mcp_resource".to_string(),
        arguments: arguments.clone(),
    };

    emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await;
    let start = Instant::now();

    let payload_result: Result<ReadResourcePayload, FunctionCallError> = async {
        let result = session
            .read_resource(&server, ReadResourceRequestParams { uri: uri.clone() })
            .await
            .map_err(|err| {
                FunctionCallError::RespondToModel(format!("resources/read failed: {err:#}"))
            })?;

        Ok(ReadResourcePayload {
            server,
            uri,
            result,
        })
    }
    .await;

    match payload_result {
        Ok(payload) => match serialize_function_output(payload) {
            Ok(output) => {
                let ToolOutput::Function {
                    content, success, ..
                } = &output
                else {
                    unreachable!("MCP resource handler should return function output");
                };
                let duration = start.elapsed();
                emit_tool_call_end(
                    &session,
                    turn.as_ref(),
                    &call_id,
                    invocation,
                    duration,
                    Ok(call_tool_result_from_content(content, *success)),
                )
                .await;
                Ok(output)
            }
            Err(err) => {
                let duration = start.elapsed();
                let message = err.to_string();
                emit_tool_call_end(
                    &session,
                    turn.as_ref(),
                    &call_id,
                    invocation,
                    duration,
                    Err(message.clone()),
                )
                .await;
                Err(err)
            }
        },
        Err(err) => {
            let duration = start.elapsed();
            let message = err.to_string();
            emit_tool_call_end(
                &session,
                turn.as_ref(),
                &call_id,
                invocation,
                duration,
                Err(message.clone()),
            )
            .await;
            Err(err)
        }
    }
}

fn call_tool_result_from_content(content: &str, success: Option<bool>) -> CallToolResult {
    CallToolResult {
        content: vec![ContentBlock::TextContent(TextContent {
            annotations: None,
            text: content.to_string(),
            r#type: "text".to_string(),
        })],
        is_error: success.map(|value| !value),
        structured_content: None,
    }
}

async fn emit_tool_call_begin(
    session: &Arc<Session>,
    turn: &TurnContext,
    call_id: &str,
    invocation: McpInvocation,
) {
    session
        .send_event(
            turn,
            EventMsg::McpToolCallBegin(McpToolCallBeginEvent {
                call_id: call_id.to_string(),
                invocation,
            }),
        )
        .await;
}

async fn emit_tool_call_end(
    session: &Arc<Session>,
    turn: &TurnContext,
    call_id: &str,
    invocation: McpInvocation,
    duration: Duration,
    result: Result<CallToolResult, String>,
) {
    session
        .send_event(
            turn,
            EventMsg::McpToolCallEnd(McpToolCallEndEvent {
                call_id: call_id.to_string(),
                invocation,
                duration,
                result,
            }),
        )
        .await;
}

fn normalize_optional_string(input: Option<String>) -> Option<String> {
    input.and_then(|value| {
        let trimmed = value.trim().to_string();
        if trimmed.is_empty() {
            None
        } else {
            Some(trimmed)
        }
    })
}

fn normalize_required_string(field: &str, value: String) -> Result<String, FunctionCallError> {
    match normalize_optional_string(Some(value)) {
        Some(normalized) => Ok(normalized),
        None => Err(FunctionCallError::RespondToModel(format!(
            "{field} must be provided"
        ))),
    }
}

fn serialize_function_output<T>(payload: T) -> Result<ToolOutput, FunctionCallError>
where
    T: Serialize,
{
    let content = serde_json::to_string(&payload).map_err(|err| {
        FunctionCallError::RespondToModel(format!(
            "failed to serialize MCP resource response: {err}"
        ))
    })?;

    Ok(ToolOutput::Function {
        content,
        content_items: None,
        success: Some(true),
    })
}

fn parse_arguments(raw_args: &str) -> Result<Option<Value>, FunctionCallError> {
    if raw_args.trim().is_empty() {
        Ok(None)
    } else {
        serde_json::from_str(raw_args).map(Some).map_err(|err| {
            FunctionCallError::RespondToModel(format!("failed to parse function arguments: {err}"))
        })
    }
}

fn parse_args<T>(arguments: Option<Value>) -> Result<T, FunctionCallError>
where
    T: DeserializeOwned,
{
    match arguments {
        Some(value) => serde_json::from_value(value).map_err(|err| {
            FunctionCallError::RespondToModel(format!("failed to parse function arguments: {err}"))
        }),
        None => Err(FunctionCallError::RespondToModel(
            "failed to parse function arguments: expected value".to_string(),
        )),
    }
}

fn parse_args_with_default<T>(arguments: Option<Value>) -> Result<T, FunctionCallError>
where
    T: DeserializeOwned + Default,
{
    match arguments {
        Some(value) => parse_args(Some(value)),
        None => Ok(T::default()),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use mcp_types::ListResourcesResult;
    use mcp_types::ResourceTemplate;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    fn resource(uri: &str, name: &str) -> Resource {
        Resource {
            annotations: None,
            description: None,
            mime_type: None,
            name: name.to_string(),
            size: None,
            title: None,
|
||||
uri: uri.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn template(uri_template: &str, name: &str) -> ResourceTemplate {
|
||||
ResourceTemplate {
|
||||
annotations: None,
|
||||
description: None,
|
||||
mime_type: None,
|
||||
name: name.to_string(),
|
||||
title: None,
|
||||
uri_template: uri_template.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resource_with_server_serializes_server_field() {
|
||||
let entry = ResourceWithServer::new("test".to_string(), resource("memo://id", "memo"));
|
||||
let value = serde_json::to_value(&entry).expect("serialize resource");
|
||||
|
||||
assert_eq!(value["server"], json!("test"));
|
||||
assert_eq!(value["uri"], json!("memo://id"));
|
||||
assert_eq!(value["name"], json!("memo"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_resources_payload_from_single_server_copies_next_cursor() {
|
||||
let result = ListResourcesResult {
|
||||
next_cursor: Some("cursor-1".to_string()),
|
||||
resources: vec![resource("memo://id", "memo")],
|
||||
};
|
||||
let payload = ListResourcesPayload::from_single_server("srv".to_string(), result);
|
||||
let value = serde_json::to_value(&payload).expect("serialize payload");
|
||||
|
||||
assert_eq!(value["server"], json!("srv"));
|
||||
assert_eq!(value["nextCursor"], json!("cursor-1"));
|
||||
let resources = value["resources"].as_array().expect("resources array");
|
||||
assert_eq!(resources.len(), 1);
|
||||
assert_eq!(resources[0]["server"], json!("srv"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_resources_payload_from_all_servers_is_sorted() {
|
||||
let mut map = HashMap::new();
|
||||
map.insert("beta".to_string(), vec![resource("memo://b-1", "b-1")]);
|
||||
map.insert(
|
||||
"alpha".to_string(),
|
||||
vec![resource("memo://a-1", "a-1"), resource("memo://a-2", "a-2")],
|
||||
);
|
||||
|
||||
let payload = ListResourcesPayload::from_all_servers(map);
|
||||
let value = serde_json::to_value(&payload).expect("serialize payload");
|
||||
let uris: Vec<String> = value["resources"]
|
||||
.as_array()
|
||||
.expect("resources array")
|
||||
.iter()
|
||||
.map(|entry| entry["uri"].as_str().unwrap().to_string())
|
||||
.collect();
|
||||
|
||||
assert_eq!(
|
||||
uris,
|
||||
vec![
|
||||
"memo://a-1".to_string(),
|
||||
"memo://a-2".to_string(),
|
||||
"memo://b-1".to_string()
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_tool_result_from_content_marks_success() {
|
||||
let result = call_tool_result_from_content("{}", Some(true));
|
||||
assert_eq!(result.is_error, Some(false));
|
||||
assert_eq!(result.content.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_arguments_handles_empty_and_json() {
|
||||
assert!(
|
||||
parse_arguments(" \n\t").unwrap().is_none(),
|
||||
"expected None for empty arguments"
|
||||
);
|
||||
|
||||
let value = parse_arguments(r#"{"server":"figma"}"#)
|
||||
.expect("parse json")
|
||||
.expect("value present");
|
||||
assert_eq!(value["server"], json!("figma"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn template_with_server_serializes_server_field() {
|
||||
let entry =
|
||||
ResourceTemplateWithServer::new("srv".to_string(), template("memo://{id}", "memo"));
|
||||
let value = serde_json::to_value(&entry).expect("serialize template");
|
||||
|
||||
assert_eq!(
|
||||
value,
|
||||
json!({
|
||||
"server": "srv",
|
||||
"uriTemplate": "memo://{id}",
|
||||
"name": "memo"
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
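A minimal sketch of how the parsing helpers above compose, using only serde and serde_json; `ExampleArgs` is a hypothetical argument type standing in for structs like `ReadResourceArgs`:

use serde::Deserialize;
use serde_json::Value;

#[derive(Deserialize, Default)]
struct ExampleArgs {
    // Hypothetical field; real handlers declare their own argument structs.
    server: Option<String>,
}

fn demo(raw_args: &str) -> Result<ExampleArgs, String> {
    // Mirror parse_arguments: empty input means "no arguments were supplied".
    let value: Option<Value> = if raw_args.trim().is_empty() {
        None
    } else {
        Some(serde_json::from_str(raw_args).map_err(|err| err.to_string())?)
    };
    // Mirror parse_args_with_default: fall back to Default when no value is present.
    match value {
        Some(value) => serde_json::from_value(value).map_err(|err| err.to_string()),
        None => Ok(ExampleArgs::default()),
    }
}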
25
llmx-rs/core/src/tools/handlers/mod.rs
Normal file
@@ -0,0 +1,25 @@
pub mod apply_patch;
mod grep_files;
mod list_dir;
mod mcp;
mod mcp_resource;
mod plan;
mod read_file;
mod shell;
mod test_sync;
mod unified_exec;
mod view_image;

pub use plan::PLAN_TOOL;

pub use apply_patch::ApplyPatchHandler;
pub use grep_files::GrepFilesHandler;
pub use list_dir::ListDirHandler;
pub use mcp::McpHandler;
pub use mcp_resource::McpResourceHandler;
pub use plan::PlanHandler;
pub use read_file::ReadFileHandler;
pub use shell::ShellHandler;
pub use test_sync::TestSyncHandler;
pub use unified_exec::UnifiedExecHandler;
pub use view_image::ViewImageHandler;
117
llmx-rs/core/src/tools/handlers/plan.rs
Normal file
@@ -0,0 +1,117 @@
use crate::client_common::tools::ResponsesApiTool;
use crate::client_common::tools::ToolSpec;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use crate::tools::spec::JsonSchema;
use async_trait::async_trait;
use codex_protocol::plan_tool::UpdatePlanArgs;
use codex_protocol::protocol::EventMsg;
use std::collections::BTreeMap;
use std::sync::LazyLock;

pub struct PlanHandler;

pub static PLAN_TOOL: LazyLock<ToolSpec> = LazyLock::new(|| {
    let mut plan_item_props = BTreeMap::new();
    plan_item_props.insert("step".to_string(), JsonSchema::String { description: None });
    plan_item_props.insert(
        "status".to_string(),
        JsonSchema::String {
            description: Some("One of: pending, in_progress, completed".to_string()),
        },
    );

    let plan_items_schema = JsonSchema::Array {
        description: Some("The list of steps".to_string()),
        items: Box::new(JsonSchema::Object {
            properties: plan_item_props,
            required: Some(vec!["step".to_string(), "status".to_string()]),
            additional_properties: Some(false.into()),
        }),
    };

    let mut properties = BTreeMap::new();
    properties.insert(
        "explanation".to_string(),
        JsonSchema::String { description: None },
    );
    properties.insert("plan".to_string(), plan_items_schema);

    ToolSpec::Function(ResponsesApiTool {
        name: "update_plan".to_string(),
        description: r#"Updates the task plan.
Provide an optional explanation and a list of plan items, each with a step and status.
At most one step can be in_progress at a time.
"#
        .to_string(),
        strict: false,
        parameters: JsonSchema::Object {
            properties,
            required: Some(vec!["plan".to_string()]),
            additional_properties: Some(false.into()),
        },
    })
});

#[async_trait]
impl ToolHandler for PlanHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            turn,
            call_id,
            payload,
            ..
        } = invocation;

        let arguments = match payload {
            ToolPayload::Function { arguments } => arguments,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "update_plan handler received unsupported payload".to_string(),
                ));
            }
        };

        let content =
            handle_update_plan(session.as_ref(), turn.as_ref(), arguments, call_id).await?;

        Ok(ToolOutput::Function {
            content,
            content_items: None,
            success: Some(true),
        })
    }
}

/// This function doesn't do anything useful on its own. However, it gives the model a structured way to record its plan that clients can read and render.
/// It is therefore the _inputs_ to this function that are useful to clients, not the outputs; neither is of direct use to the model beyond
/// forcing it to come up with and document a plan (TBD how that affects performance).
pub(crate) async fn handle_update_plan(
    session: &Session,
    turn_context: &TurnContext,
    arguments: String,
    _call_id: String,
) -> Result<String, FunctionCallError> {
    let args = parse_update_plan_arguments(&arguments)?;
    session
        .send_event(turn_context, EventMsg::PlanUpdate(args))
        .await;
    Ok("Plan updated".to_string())
}

fn parse_update_plan_arguments(arguments: &str) -> Result<UpdatePlanArgs, FunctionCallError> {
    serde_json::from_str::<UpdatePlanArgs>(arguments).map_err(|e| {
        FunctionCallError::RespondToModel(format!("failed to parse function arguments: {e}"))
    })
}
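For reference, a minimal sketch of a payload that parse_update_plan_arguments above accepts; the steps are illustrative, while the shape follows the PLAN_TOOL schema (optional `explanation`, required `plan` of {step, status} items):

let raw = r#"{
    "explanation": "Refactor in three passes",
    "plan": [
        {"step": "Survey call sites", "status": "completed"},
        {"step": "Extract helper", "status": "in_progress"},
        {"step": "Update tests", "status": "pending"}
    ]
}"#;
let args = parse_update_plan_arguments(raw).expect("valid update_plan payload");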
999
llmx-rs/core/src/tools/handlers/read_file.rs
Normal file
@@ -0,0 +1,999 @@
use std::collections::VecDeque;
use std::path::PathBuf;

use async_trait::async_trait;
use codex_utils_string::take_bytes_at_char_boundary;
use serde::Deserialize;

use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

pub struct ReadFileHandler;

const MAX_LINE_LENGTH: usize = 500;
const TAB_WIDTH: usize = 4;

// TODO(jif) add support for block comments
const COMMENT_PREFIXES: &[&str] = &["#", "//", "--"];

/// JSON arguments accepted by the `read_file` tool handler.
#[derive(Deserialize)]
struct ReadFileArgs {
    /// Absolute path to the file that will be read.
    file_path: String,
    /// 1-indexed line number to start reading from; defaults to 1.
    #[serde(default = "defaults::offset")]
    offset: usize,
    /// Maximum number of lines to return; defaults to 2000.
    #[serde(default = "defaults::limit")]
    limit: usize,
    /// Determines whether the handler reads a simple slice or an indentation-aware block.
    #[serde(default)]
    mode: ReadMode,
    /// Optional indentation configuration used when `mode` is `Indentation`.
    #[serde(default)]
    indentation: Option<IndentationArgs>,
}

#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
enum ReadMode {
    Slice,
    Indentation,
}

/// Additional configuration for indentation-aware reads.
#[derive(Deserialize, Clone)]
struct IndentationArgs {
    /// Optional explicit anchor line; defaults to `offset` when omitted.
    #[serde(default)]
    anchor_line: Option<usize>,
    /// Maximum indentation depth to collect; `0` means unlimited.
    #[serde(default = "defaults::max_levels")]
    max_levels: usize,
    /// Whether to include sibling blocks at the same indentation level.
    #[serde(default = "defaults::include_siblings")]
    include_siblings: bool,
    /// Whether to include header lines above the anchor block. This is done on a best-effort basis.
    #[serde(default = "defaults::include_header")]
    include_header: bool,
    /// Optional hard cap on returned lines; defaults to the global `limit`.
    #[serde(default)]
    max_lines: Option<usize>,
}

#[derive(Clone, Debug)]
struct LineRecord {
    number: usize,
    raw: String,
    display: String,
    indent: usize,
}

impl LineRecord {
    fn trimmed(&self) -> &str {
        self.raw.trim_start()
    }

    fn is_blank(&self) -> bool {
        self.trimmed().is_empty()
    }

    fn is_comment(&self) -> bool {
        COMMENT_PREFIXES
            .iter()
            .any(|prefix| self.raw.trim().starts_with(prefix))
    }
}

#[async_trait]
impl ToolHandler for ReadFileHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation { payload, .. } = invocation;

        let arguments = match payload {
            ToolPayload::Function { arguments } => arguments,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "read_file handler received unsupported payload".to_string(),
                ));
            }
        };

        let args: ReadFileArgs = serde_json::from_str(&arguments).map_err(|err| {
            FunctionCallError::RespondToModel(format!(
                "failed to parse function arguments: {err:?}"
            ))
        })?;

        let ReadFileArgs {
            file_path,
            offset,
            limit,
            mode,
            indentation,
        } = args;

        if offset == 0 {
            return Err(FunctionCallError::RespondToModel(
                "offset must be a 1-indexed line number".to_string(),
            ));
        }

        if limit == 0 {
            return Err(FunctionCallError::RespondToModel(
                "limit must be greater than zero".to_string(),
            ));
        }

        let path = PathBuf::from(&file_path);
        if !path.is_absolute() {
            return Err(FunctionCallError::RespondToModel(
                "file_path must be an absolute path".to_string(),
            ));
        }

        let collected = match mode {
            ReadMode::Slice => slice::read(&path, offset, limit).await?,
            ReadMode::Indentation => {
                let indentation = indentation.unwrap_or_default();
                indentation::read_block(&path, offset, limit, indentation).await?
            }
        };
        Ok(ToolOutput::Function {
            content: collected.join("\n"),
            content_items: None,
            success: Some(true),
        })
    }
}
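A minimal sketch of arguments the handler above accepts; field names and defaults follow ReadFileArgs, and the path is illustrative:

// Simple slice read: 40 lines starting at line 10.
let slice_call = r#"{"file_path": "/tmp/example.rs", "offset": 10, "limit": 40}"#;
// Indentation-aware read anchored at line 42, collecting up to two parent levels.
let indent_call = r#"{
    "file_path": "/tmp/example.rs",
    "offset": 42,
    "mode": "indentation",
    "indentation": {"max_levels": 2, "include_siblings": false}
}"#;
let args: ReadFileArgs = serde_json::from_str(indent_call).expect("valid read_file arguments");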
mod slice {
    use crate::function_tool::FunctionCallError;
    use crate::tools::handlers::read_file::format_line;
    use std::path::Path;
    use tokio::fs::File;
    use tokio::io::AsyncBufReadExt;
    use tokio::io::BufReader;

    pub async fn read(
        path: &Path,
        offset: usize,
        limit: usize,
    ) -> Result<Vec<String>, FunctionCallError> {
        let file = File::open(path).await.map_err(|err| {
            FunctionCallError::RespondToModel(format!("failed to read file: {err}"))
        })?;

        let mut reader = BufReader::new(file);
        let mut collected = Vec::new();
        let mut seen = 0usize;
        let mut buffer = Vec::new();

        loop {
            buffer.clear();
            let bytes_read = reader.read_until(b'\n', &mut buffer).await.map_err(|err| {
                FunctionCallError::RespondToModel(format!("failed to read file: {err}"))
            })?;

            if bytes_read == 0 {
                break;
            }

            if buffer.last() == Some(&b'\n') {
                buffer.pop();
                if buffer.last() == Some(&b'\r') {
                    buffer.pop();
                }
            }

            seen += 1;

            if seen < offset {
                continue;
            }

            if collected.len() == limit {
                break;
            }

            let formatted = format_line(&buffer);
            collected.push(format!("L{seen}: {formatted}"));

            if collected.len() == limit {
                break;
            }
        }

        if seen < offset {
            return Err(FunctionCallError::RespondToModel(
                "offset exceeds file length".to_string(),
            ));
        }

        Ok(collected)
    }
}
mod indentation {
    use crate::function_tool::FunctionCallError;
    use crate::tools::handlers::read_file::IndentationArgs;
    use crate::tools::handlers::read_file::LineRecord;
    use crate::tools::handlers::read_file::TAB_WIDTH;
    use crate::tools::handlers::read_file::format_line;
    use crate::tools::handlers::read_file::trim_empty_lines;
    use std::collections::VecDeque;
    use std::path::Path;
    use tokio::fs::File;
    use tokio::io::AsyncBufReadExt;
    use tokio::io::BufReader;

    pub async fn read_block(
        path: &Path,
        offset: usize,
        limit: usize,
        options: IndentationArgs,
    ) -> Result<Vec<String>, FunctionCallError> {
        let anchor_line = options.anchor_line.unwrap_or(offset);
        if anchor_line == 0 {
            return Err(FunctionCallError::RespondToModel(
                "anchor_line must be a 1-indexed line number".to_string(),
            ));
        }

        let guard_limit = options.max_lines.unwrap_or(limit);
        if guard_limit == 0 {
            return Err(FunctionCallError::RespondToModel(
                "max_lines must be greater than zero".to_string(),
            ));
        }

        let collected = collect_file_lines(path).await?;
        if collected.is_empty() || anchor_line > collected.len() {
            return Err(FunctionCallError::RespondToModel(
                "anchor_line exceeds file length".to_string(),
            ));
        }

        let anchor_index = anchor_line - 1;
        let effective_indents = compute_effective_indents(&collected);
        let anchor_indent = effective_indents[anchor_index];

        // Compute the minimum indent we are willing to collect.
        let min_indent = if options.max_levels == 0 {
            0
        } else {
            anchor_indent.saturating_sub(options.max_levels * TAB_WIDTH)
        };

        // Cap requested lines by guard_limit and file length.
        let final_limit = limit.min(guard_limit).min(collected.len());

        if final_limit == 1 {
            return Ok(vec![format!(
                "L{}: {}",
                collected[anchor_index].number, collected[anchor_index].display
            )]);
        }

        // Cursors
        let mut i: isize = anchor_index as isize - 1; // up (inclusive)
        let mut j: usize = anchor_index + 1; // down (inclusive)
        let mut i_counter_min_indent = 0;
        let mut j_counter_min_indent = 0;

        let mut out = VecDeque::with_capacity(limit);
        out.push_back(&collected[anchor_index]);

        while out.len() < final_limit {
            let mut progressed = 0;

            // Up.
            if i >= 0 {
                let iu = i as usize;
                if effective_indents[iu] >= min_indent {
                    out.push_front(&collected[iu]);
                    progressed += 1;
                    i -= 1;

                    // We do not include siblings (comments are exempt when headers are requested).
                    if effective_indents[iu] == min_indent && !options.include_siblings {
                        let allow_header_comment =
                            options.include_header && collected[iu].is_comment();
                        let can_take_line = allow_header_comment || i_counter_min_indent == 0;

                        if can_take_line {
                            i_counter_min_indent += 1;
                        } else {
                            // This line shouldn't have been taken.
                            out.pop_front();
                            progressed -= 1;
                            i = -1; // consider using Option<usize> or a control flag instead of a sentinel
                        }
                    }

                    // Short-circuit once we have collected enough lines.
                    if out.len() >= final_limit {
                        break;
                    }
                } else {
                    // Stop moving up.
                    i = -1;
                }
            }

            // Down.
            if j < collected.len() {
                let ju = j;
                if effective_indents[ju] >= min_indent {
                    out.push_back(&collected[ju]);
                    progressed += 1;
                    j += 1;

                    // We do not include siblings (this also applies to comments).
                    if effective_indents[ju] == min_indent && !options.include_siblings {
                        if j_counter_min_indent > 0 {
                            // This line shouldn't have been taken.
                            out.pop_back();
                            progressed -= 1;
                            j = collected.len();
                        }
                        j_counter_min_indent += 1;
                    }
                } else {
                    // Stop moving down.
                    j = collected.len();
                }
            }

            if progressed == 0 {
                break;
            }
        }

        // Trim leading and trailing empty lines.
        trim_empty_lines(&mut out);

        Ok(out
            .into_iter()
            .map(|record| format!("L{}: {}", record.number, record.display))
            .collect())
    }

    async fn collect_file_lines(path: &Path) -> Result<Vec<LineRecord>, FunctionCallError> {
        let file = File::open(path).await.map_err(|err| {
            FunctionCallError::RespondToModel(format!("failed to read file: {err}"))
        })?;

        let mut reader = BufReader::new(file);
        let mut buffer = Vec::new();
        let mut lines = Vec::new();
        let mut number = 0usize;

        loop {
            buffer.clear();
            let bytes_read = reader.read_until(b'\n', &mut buffer).await.map_err(|err| {
                FunctionCallError::RespondToModel(format!("failed to read file: {err}"))
            })?;

            if bytes_read == 0 {
                break;
            }

            if buffer.last() == Some(&b'\n') {
                buffer.pop();
                if buffer.last() == Some(&b'\r') {
                    buffer.pop();
                }
            }

            number += 1;
            let raw = String::from_utf8_lossy(&buffer).into_owned();
            let indent = measure_indent(&raw);
            let display = format_line(&buffer);
            lines.push(LineRecord {
                number,
                raw,
                display,
                indent,
            });
        }

        Ok(lines)
    }

    fn compute_effective_indents(records: &[LineRecord]) -> Vec<usize> {
        let mut effective = Vec::with_capacity(records.len());
        let mut previous_indent = 0usize;
        for record in records {
            if record.is_blank() {
                effective.push(previous_indent);
            } else {
                previous_indent = record.indent;
                effective.push(previous_indent);
            }
        }
        effective
    }

    fn measure_indent(line: &str) -> usize {
        line.chars()
            .take_while(|c| matches!(c, ' ' | '\t'))
            .map(|c| if c == '\t' { TAB_WIDTH } else { 1 })
            .sum()
    }
}
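To make the indentation bookkeeping above concrete, a minimal sketch of measure_indent's rules with TAB_WIDTH = 4 (the sample lines are illustrative; measure_indent is private to the module, so this only runs from within it):

assert_eq!(measure_indent("fn outer() {"), 0);
assert_eq!(measure_indent("    if cond {"), 4); // four spaces
assert_eq!(measure_indent("\tinner();"), 4); // one tab counts as TAB_WIDTH
// compute_effective_indents carries the previous indent across blank lines,
// so a blank line inside an indent-4 block is treated as indent 4 and never
// splits the block apart.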
fn format_line(bytes: &[u8]) -> String {
    let decoded = String::from_utf8_lossy(bytes);
    if decoded.len() > MAX_LINE_LENGTH {
        take_bytes_at_char_boundary(&decoded, MAX_LINE_LENGTH).to_string()
    } else {
        decoded.into_owned()
    }
}

fn trim_empty_lines(out: &mut VecDeque<&LineRecord>) {
    while matches!(out.front(), Some(line) if line.raw.trim().is_empty()) {
        out.pop_front();
    }
    while matches!(out.back(), Some(line) if line.raw.trim().is_empty()) {
        out.pop_back();
    }
}

mod defaults {
    use super::*;

    impl Default for IndentationArgs {
        fn default() -> Self {
            Self {
                anchor_line: None,
                max_levels: max_levels(),
                include_siblings: include_siblings(),
                include_header: include_header(),
                max_lines: None,
            }
        }
    }

    impl Default for ReadMode {
        fn default() -> Self {
            Self::Slice
        }
    }

    pub fn offset() -> usize {
        1
    }

    pub fn limit() -> usize {
        2000
    }

    pub fn max_levels() -> usize {
        0
    }

    pub fn include_siblings() -> bool {
        false
    }

    pub fn include_header() -> bool {
        true
    }
}
#[cfg(test)]
mod tests {
    use super::indentation::read_block;
    use super::slice::read;
    use super::*;
    use pretty_assertions::assert_eq;
    use tempfile::NamedTempFile;

    #[tokio::test]
    async fn reads_requested_range() -> anyhow::Result<()> {
        let mut temp = NamedTempFile::new()?;
        use std::io::Write as _;
        write!(
            temp,
            "alpha
beta
gamma
"
        )?;

        let lines = read(temp.path(), 2, 2).await?;
        assert_eq!(lines, vec!["L2: beta".to_string(), "L3: gamma".to_string()]);
        Ok(())
    }

    #[tokio::test]
    async fn errors_when_offset_exceeds_length() -> anyhow::Result<()> {
        let mut temp = NamedTempFile::new()?;
        use std::io::Write as _;
        writeln!(temp, "only")?;

        let err = read(temp.path(), 3, 1)
            .await
            .expect_err("offset exceeds length");
        assert_eq!(
            err,
            FunctionCallError::RespondToModel("offset exceeds file length".to_string())
        );
        Ok(())
    }

    #[tokio::test]
    async fn reads_non_utf8_lines() -> anyhow::Result<()> {
        let mut temp = NamedTempFile::new()?;
        use std::io::Write as _;
        temp.as_file_mut().write_all(b"\xff\xfe\nplain\n")?;

        let lines = read(temp.path(), 1, 2).await?;
        let expected_first = format!("L1: {}{}", '\u{FFFD}', '\u{FFFD}');
        assert_eq!(lines, vec![expected_first, "L2: plain".to_string()]);
        Ok(())
    }

    #[tokio::test]
    async fn trims_crlf_endings() -> anyhow::Result<()> {
        let mut temp = NamedTempFile::new()?;
        use std::io::Write as _;
        write!(temp, "one\r\ntwo\r\n")?;

        let lines = read(temp.path(), 1, 2).await?;
        assert_eq!(lines, vec!["L1: one".to_string(), "L2: two".to_string()]);
        Ok(())
    }

    #[tokio::test]
    async fn respects_limit_even_with_more_lines() -> anyhow::Result<()> {
        let mut temp = NamedTempFile::new()?;
        use std::io::Write as _;
        write!(
            temp,
            "first
second
third
"
        )?;

        let lines = read(temp.path(), 1, 2).await?;
        assert_eq!(
            lines,
            vec!["L1: first".to_string(), "L2: second".to_string()]
        );
        Ok(())
    }

    #[tokio::test]
    async fn truncates_lines_longer_than_max_length() -> anyhow::Result<()> {
        let mut temp = NamedTempFile::new()?;
        use std::io::Write as _;
        let long_line = "x".repeat(MAX_LINE_LENGTH + 50);
        writeln!(temp, "{long_line}")?;

        let lines = read(temp.path(), 1, 1).await?;
        let expected = "x".repeat(MAX_LINE_LENGTH);
        assert_eq!(lines, vec![format!("L1: {expected}")]);
        Ok(())
    }

    #[tokio::test]
    async fn indentation_mode_captures_block() -> anyhow::Result<()> {
        let mut temp = NamedTempFile::new()?;
        use std::io::Write as _;
        write!(
            temp,
            "fn outer() {{
    if cond {{
        inner();
    }}
    tail();
}}
"
        )?;

        let options = IndentationArgs {
            anchor_line: Some(3),
            include_siblings: false,
            max_levels: 1,
            ..Default::default()
        };

        let lines = read_block(temp.path(), 3, 10, options).await?;

        assert_eq!(
            lines,
            vec![
                "L2:     if cond {".to_string(),
                "L3:         inner();".to_string(),
                "L4:     }".to_string()
            ]
        );
        Ok(())
    }

    #[tokio::test]
    async fn indentation_mode_expands_parents() -> anyhow::Result<()> {
        let mut temp = NamedTempFile::new()?;
        use std::io::Write as _;
        write!(
            temp,
            "mod root {{
    fn outer() {{
        if cond {{
            inner();
        }}
    }}
}}
"
        )?;

        let mut options = IndentationArgs {
            anchor_line: Some(4),
            max_levels: 2,
            ..Default::default()
        };

        let lines = read_block(temp.path(), 4, 50, options.clone()).await?;
        assert_eq!(
            lines,
            vec![
                "L2:     fn outer() {".to_string(),
                "L3:         if cond {".to_string(),
                "L4:             inner();".to_string(),
                "L5:         }".to_string(),
                "L6:     }".to_string(),
            ]
        );

        options.max_levels = 3;
        let expanded = read_block(temp.path(), 4, 50, options).await?;
        assert_eq!(
            expanded,
            vec![
                "L1: mod root {".to_string(),
                "L2:     fn outer() {".to_string(),
                "L3:         if cond {".to_string(),
                "L4:             inner();".to_string(),
                "L5:         }".to_string(),
                "L6:     }".to_string(),
                "L7: }".to_string(),
            ]
        );
        Ok(())
    }

    #[tokio::test]
    async fn indentation_mode_respects_sibling_flag() -> anyhow::Result<()> {
        let mut temp = NamedTempFile::new()?;
        use std::io::Write as _;
        write!(
            temp,
            "fn wrapper() {{
    if first {{
        do_first();
    }}
    if second {{
        do_second();
    }}
}}
"
        )?;

        let mut options = IndentationArgs {
            anchor_line: Some(3),
            include_siblings: false,
            max_levels: 1,
            ..Default::default()
        };

        let lines = read_block(temp.path(), 3, 50, options.clone()).await?;
        assert_eq!(
            lines,
            vec![
                "L2:     if first {".to_string(),
                "L3:         do_first();".to_string(),
                "L4:     }".to_string(),
            ]
        );

        options.include_siblings = true;
        let with_siblings = read_block(temp.path(), 3, 50, options).await?;
        assert_eq!(
            with_siblings,
            vec![
                "L2:     if first {".to_string(),
                "L3:         do_first();".to_string(),
                "L4:     }".to_string(),
                "L5:     if second {".to_string(),
                "L6:         do_second();".to_string(),
                "L7:     }".to_string(),
            ]
        );
        Ok(())
    }

    #[tokio::test]
    async fn indentation_mode_handles_python_sample() -> anyhow::Result<()> {
        let mut temp = NamedTempFile::new()?;
        use std::io::Write as _;
        write!(
            temp,
            "class Foo:
    def __init__(self, size):
        self.size = size
    def double(self, value):
        if value is None:
            return 0
        result = value * self.size
        return result
class Bar:
    def compute(self):
        helper = Foo(2)
        return helper.double(5)
"
        )?;

        let options = IndentationArgs {
            anchor_line: Some(7),
            include_siblings: true,
            max_levels: 1,
            ..Default::default()
        };

        let lines = read_block(temp.path(), 1, 200, options).await?;
        assert_eq!(
            lines,
            vec![
                "L2:     def __init__(self, size):".to_string(),
                "L3:         self.size = size".to_string(),
                "L4:     def double(self, value):".to_string(),
                "L5:         if value is None:".to_string(),
                "L6:             return 0".to_string(),
                "L7:         result = value * self.size".to_string(),
                "L8:         return result".to_string(),
            ]
        );
        Ok(())
    }

    #[tokio::test]
    #[ignore]
    async fn indentation_mode_handles_javascript_sample() -> anyhow::Result<()> {
        let mut temp = NamedTempFile::new()?;
        use std::io::Write as _;
        write!(
            temp,
            "export function makeThing() {{
    const cache = new Map();
    function ensure(key) {{
        if (!cache.has(key)) {{
            cache.set(key, []);
        }}
        return cache.get(key);
    }}
    const handlers = {{
        init() {{
            console.log(\"init\");
        }},
        run() {{
            if (Math.random() > 0.5) {{
                return \"heads\";
            }}
            return \"tails\";
        }},
    }};
    return {{ cache, handlers }};
}}
export function other() {{
    return makeThing();
}}
"
        )?;

        let options = IndentationArgs {
            anchor_line: Some(15),
            max_levels: 1,
            ..Default::default()
        };

        let lines = read_block(temp.path(), 15, 200, options).await?;
        assert_eq!(
            lines,
            vec![
                "L10:         init() {".to_string(),
                "L11:             console.log(\"init\");".to_string(),
                "L12:         },".to_string(),
                "L13:         run() {".to_string(),
                "L14:             if (Math.random() > 0.5) {".to_string(),
                "L15:                 return \"heads\";".to_string(),
                "L16:             }".to_string(),
                "L17:             return \"tails\";".to_string(),
                "L18:         },".to_string(),
            ]
        );
        Ok(())
    }

    fn write_cpp_sample() -> anyhow::Result<NamedTempFile> {
        let mut temp = NamedTempFile::new()?;
        use std::io::Write as _;
        write!(
            temp,
            "#include <vector>
#include <string>

namespace sample {{
class Runner {{
public:
    void setup() {{
        if (enabled_) {{
            init();
        }}
    }}

    // Run the code
    int run() const {{
        switch (mode_) {{
            case Mode::Fast:
                return fast();
            case Mode::Slow:
                return slow();
            default:
                return fallback();
        }}
    }}

private:
    bool enabled_ = false;
    Mode mode_ = Mode::Fast;

    int fast() const {{
        return 1;
    }}
}};
}} // namespace sample
"
        )?;
        Ok(temp)
    }

    #[tokio::test]
    async fn indentation_mode_handles_cpp_sample_shallow() -> anyhow::Result<()> {
        let temp = write_cpp_sample()?;

        let options = IndentationArgs {
            include_siblings: false,
            anchor_line: Some(18),
            max_levels: 1,
            ..Default::default()
        };

        let lines = read_block(temp.path(), 18, 200, options).await?;
        assert_eq!(
            lines,
            vec![
                "L15:         switch (mode_) {".to_string(),
                "L16:             case Mode::Fast:".to_string(),
                "L17:                 return fast();".to_string(),
                "L18:             case Mode::Slow:".to_string(),
                "L19:                 return slow();".to_string(),
                "L20:             default:".to_string(),
                "L21:                 return fallback();".to_string(),
                "L22:         }".to_string(),
            ]
        );
        Ok(())
    }

    #[tokio::test]
    async fn indentation_mode_handles_cpp_sample() -> anyhow::Result<()> {
        let temp = write_cpp_sample()?;

        let options = IndentationArgs {
            include_siblings: false,
            anchor_line: Some(18),
            max_levels: 2,
            ..Default::default()
        };

        let lines = read_block(temp.path(), 18, 200, options).await?;
        assert_eq!(
            lines,
            vec![
                "L13:     // Run the code".to_string(),
                "L14:     int run() const {".to_string(),
                "L15:         switch (mode_) {".to_string(),
                "L16:             case Mode::Fast:".to_string(),
                "L17:                 return fast();".to_string(),
                "L18:             case Mode::Slow:".to_string(),
                "L19:                 return slow();".to_string(),
                "L20:             default:".to_string(),
                "L21:                 return fallback();".to_string(),
                "L22:         }".to_string(),
                "L23:     }".to_string(),
            ]
        );
        Ok(())
    }

    #[tokio::test]
    async fn indentation_mode_handles_cpp_sample_no_headers() -> anyhow::Result<()> {
        let temp = write_cpp_sample()?;

        let options = IndentationArgs {
            include_siblings: false,
            include_header: false,
            anchor_line: Some(18),
            max_levels: 2,
            ..Default::default()
        };

        let lines = read_block(temp.path(), 18, 200, options).await?;
        assert_eq!(
            lines,
            vec![
                "L14:     int run() const {".to_string(),
                "L15:         switch (mode_) {".to_string(),
                "L16:             case Mode::Fast:".to_string(),
                "L17:                 return fast();".to_string(),
                "L18:             case Mode::Slow:".to_string(),
                "L19:                 return slow();".to_string(),
                "L20:             default:".to_string(),
                "L21:                 return fallback();".to_string(),
                "L22:         }".to_string(),
                "L23:     }".to_string(),
            ]
        );
        Ok(())
    }

    #[tokio::test]
    async fn indentation_mode_handles_cpp_sample_siblings() -> anyhow::Result<()> {
        let temp = write_cpp_sample()?;

        let options = IndentationArgs {
            include_siblings: true,
            include_header: false,
            anchor_line: Some(18),
            max_levels: 2,
            ..Default::default()
        };

        let lines = read_block(temp.path(), 18, 200, options).await?;
        assert_eq!(
            lines,
            vec![
                "L7:     void setup() {".to_string(),
                "L8:         if (enabled_) {".to_string(),
                "L9:             init();".to_string(),
                "L10:         }".to_string(),
                "L11:     }".to_string(),
                "L12: ".to_string(),
                "L13:     // Run the code".to_string(),
                "L14:     int run() const {".to_string(),
                "L15:         switch (mode_) {".to_string(),
                "L16:             case Mode::Fast:".to_string(),
                "L17:                 return fast();".to_string(),
                "L18:             case Mode::Slow:".to_string(),
                "L19:                 return slow();".to_string(),
                "L20:             default:".to_string(),
                "L21:                 return fallback();".to_string(),
                "L22:         }".to_string(),
                "L23:     }".to_string(),
            ]
        );
        Ok(())
    }
}
242
llmx-rs/core/src/tools/handlers/shell.rs
Normal file
@@ -0,0 +1,242 @@
use async_trait::async_trait;
use codex_protocol::models::ShellToolCallParams;
use std::sync::Arc;

use crate::apply_patch;
use crate::apply_patch::InternalApplyPatchInvocation;
use crate::apply_patch::convert_apply_patch_to_protocol;
use crate::codex::TurnContext;
use crate::exec::ExecParams;
use crate::exec_env::create_env;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::events::ToolEmitter;
use crate::tools::events::ToolEventCtx;
use crate::tools::orchestrator::ToolOrchestrator;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use crate::tools::runtimes::apply_patch::ApplyPatchRequest;
use crate::tools::runtimes::apply_patch::ApplyPatchRuntime;
use crate::tools::runtimes::shell::ShellRequest;
use crate::tools::runtimes::shell::ShellRuntime;
use crate::tools::sandboxing::ToolCtx;

pub struct ShellHandler;

impl ShellHandler {
    fn to_exec_params(params: ShellToolCallParams, turn_context: &TurnContext) -> ExecParams {
        ExecParams {
            command: params.command,
            cwd: turn_context.resolve_path(params.workdir.clone()),
            timeout_ms: params.timeout_ms,
            env: create_env(&turn_context.shell_environment_policy),
            with_escalated_permissions: params.with_escalated_permissions,
            justification: params.justification,
            arg0: None,
        }
    }
}

#[async_trait]
impl ToolHandler for ShellHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        matches!(
            payload,
            ToolPayload::Function { .. } | ToolPayload::LocalShell { .. }
        )
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            turn,
            tracker,
            call_id,
            tool_name,
            payload,
        } = invocation;

        match payload {
            ToolPayload::Function { arguments } => {
                let params: ShellToolCallParams =
                    serde_json::from_str(&arguments).map_err(|e| {
                        FunctionCallError::RespondToModel(format!(
                            "failed to parse function arguments: {e:?}"
                        ))
                    })?;
                let exec_params = Self::to_exec_params(params, turn.as_ref());
                Self::run_exec_like(
                    tool_name.as_str(),
                    exec_params,
                    session,
                    turn,
                    tracker,
                    call_id,
                    false,
                )
                .await
            }
            ToolPayload::LocalShell { params } => {
                let exec_params = Self::to_exec_params(params, turn.as_ref());
                Self::run_exec_like(
                    tool_name.as_str(),
                    exec_params,
                    session,
                    turn,
                    tracker,
                    call_id,
                    true,
                )
                .await
            }
            _ => Err(FunctionCallError::RespondToModel(format!(
                "unsupported payload for shell handler: {tool_name}"
            ))),
        }
    }
}

impl ShellHandler {
    async fn run_exec_like(
        tool_name: &str,
        exec_params: ExecParams,
        session: Arc<crate::codex::Session>,
        turn: Arc<TurnContext>,
        tracker: crate::tools::context::SharedTurnDiffTracker,
        call_id: String,
        is_user_shell_command: bool,
    ) -> Result<ToolOutput, FunctionCallError> {
        // Approval policy guard for explicit escalation in non-OnRequest modes.
        if exec_params.with_escalated_permissions.unwrap_or(false)
            && !matches!(
                turn.approval_policy,
                codex_protocol::protocol::AskForApproval::OnRequest
            )
        {
            return Err(FunctionCallError::RespondToModel(format!(
                "approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}",
                policy = turn.approval_policy
            )));
        }

        // Intercept apply_patch if present.
        match codex_apply_patch::maybe_parse_apply_patch_verified(
            &exec_params.command,
            &exec_params.cwd,
        ) {
            codex_apply_patch::MaybeApplyPatchVerified::Body(changes) => {
                match apply_patch::apply_patch(session.as_ref(), turn.as_ref(), &call_id, changes)
                    .await
                {
                    InternalApplyPatchInvocation::Output(item) => {
                        // Programmatic apply_patch path; return its result.
                        let content = item?;
                        return Ok(ToolOutput::Function {
                            content,
                            content_items: None,
                            success: Some(true),
                        });
                    }
                    InternalApplyPatchInvocation::DelegateToExec(apply) => {
                        let emitter = ToolEmitter::apply_patch(
                            convert_apply_patch_to_protocol(&apply.action),
                            !apply.user_explicitly_approved_this_action,
                        );
                        let event_ctx = ToolEventCtx::new(
                            session.as_ref(),
                            turn.as_ref(),
                            &call_id,
                            Some(&tracker),
                        );
                        emitter.begin(event_ctx).await;

                        let req = ApplyPatchRequest {
                            patch: apply.action.patch.clone(),
                            cwd: apply.action.cwd.clone(),
                            timeout_ms: exec_params.timeout_ms,
                            user_explicitly_approved: apply.user_explicitly_approved_this_action,
                            codex_exe: turn.codex_linux_sandbox_exe.clone(),
                        };
                        let mut orchestrator = ToolOrchestrator::new();
                        let mut runtime = ApplyPatchRuntime::new();
                        let tool_ctx = ToolCtx {
                            session: session.as_ref(),
                            turn: turn.as_ref(),
                            call_id: call_id.clone(),
                            tool_name: tool_name.to_string(),
                        };
                        let out = orchestrator
                            .run(&mut runtime, &req, &tool_ctx, &turn, turn.approval_policy)
                            .await;
                        let event_ctx = ToolEventCtx::new(
                            session.as_ref(),
                            turn.as_ref(),
                            &call_id,
                            Some(&tracker),
                        );
                        let content = emitter.finish(event_ctx, out).await?;
                        return Ok(ToolOutput::Function {
                            content,
                            content_items: None,
                            success: Some(true),
                        });
                    }
                }
            }
            codex_apply_patch::MaybeApplyPatchVerified::CorrectnessError(parse_error) => {
                return Err(FunctionCallError::RespondToModel(format!(
                    "apply_patch verification failed: {parse_error}"
                )));
            }
            codex_apply_patch::MaybeApplyPatchVerified::ShellParseError(error) => {
                tracing::trace!("Failed to parse shell command, {error:?}");
                // Fall through to regular shell execution.
            }
            codex_apply_patch::MaybeApplyPatchVerified::NotApplyPatch => {
                // Fall through to regular shell execution.
            }
        }

        // Regular shell execution path.
        let emitter = ToolEmitter::shell(
            exec_params.command.clone(),
            exec_params.cwd.clone(),
            is_user_shell_command,
        );
        let event_ctx = ToolEventCtx::new(session.as_ref(), turn.as_ref(), &call_id, None);
        emitter.begin(event_ctx).await;

        let req = ShellRequest {
            command: exec_params.command.clone(),
            cwd: exec_params.cwd.clone(),
            timeout_ms: exec_params.timeout_ms,
            env: exec_params.env.clone(),
            with_escalated_permissions: exec_params.with_escalated_permissions,
            justification: exec_params.justification.clone(),
        };
        let mut orchestrator = ToolOrchestrator::new();
        let mut runtime = ShellRuntime::new();
        let tool_ctx = ToolCtx {
            session: session.as_ref(),
            turn: turn.as_ref(),
            call_id: call_id.clone(),
            tool_name: tool_name.to_string(),
        };
        let out = orchestrator
            .run(&mut runtime, &req, &tool_ctx, &turn, turn.approval_policy)
            .await;
        let event_ctx = ToolEventCtx::new(session.as_ref(), turn.as_ref(), &call_id, None);
        let content = emitter.finish(event_ctx, out).await?;
        Ok(ToolOutput::Function {
            content,
            content_items: None,
            success: Some(true),
        })
    }
}
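A minimal sketch of a tool call the handler above accepts; field names follow ShellToolCallParams as consumed by to_exec_params, and the command is illustrative:

let raw = r#"{
    "command": ["ls", "-la"],
    "workdir": "/tmp",
    "timeout_ms": 5000
}"#;
let params: ShellToolCallParams = serde_json::from_str(raw).expect("valid shell arguments");
// If `command` spells out an apply_patch invocation, run_exec_like intercepts it
// before any shell process is spawned.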
159
llmx-rs/core/src/tools/handlers/test_sync.rs
Normal file
@@ -0,0 +1,159 @@
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;

use async_trait::async_trait;
use serde::Deserialize;
use tokio::sync::Barrier;
use tokio::time::sleep;

use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;

pub struct TestSyncHandler;

const DEFAULT_TIMEOUT_MS: u64 = 1_000;

static BARRIERS: OnceLock<tokio::sync::Mutex<HashMap<String, BarrierState>>> = OnceLock::new();

struct BarrierState {
    barrier: Arc<Barrier>,
    participants: usize,
}

#[derive(Debug, Deserialize)]
struct BarrierArgs {
    id: String,
    participants: usize,
    #[serde(default = "default_timeout_ms")]
    timeout_ms: u64,
}

#[derive(Debug, Deserialize)]
struct TestSyncArgs {
    #[serde(default)]
    sleep_before_ms: Option<u64>,
    #[serde(default)]
    sleep_after_ms: Option<u64>,
    #[serde(default)]
    barrier: Option<BarrierArgs>,
}

fn default_timeout_ms() -> u64 {
    DEFAULT_TIMEOUT_MS
}

fn barrier_map() -> &'static tokio::sync::Mutex<HashMap<String, BarrierState>> {
    BARRIERS.get_or_init(|| tokio::sync::Mutex::new(HashMap::new()))
}

#[async_trait]
impl ToolHandler for TestSyncHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation { payload, .. } = invocation;

        let arguments = match payload {
            ToolPayload::Function { arguments } => arguments,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "test_sync_tool handler received unsupported payload".to_string(),
                ));
            }
        };

        let args: TestSyncArgs = serde_json::from_str(&arguments).map_err(|err| {
            FunctionCallError::RespondToModel(format!(
                "failed to parse function arguments: {err:?}"
            ))
        })?;

        if let Some(delay) = args.sleep_before_ms
            && delay > 0
        {
            sleep(Duration::from_millis(delay)).await;
        }

        if let Some(barrier) = args.barrier {
            wait_on_barrier(barrier).await?;
        }

        if let Some(delay) = args.sleep_after_ms
            && delay > 0
        {
            sleep(Duration::from_millis(delay)).await;
        }

        Ok(ToolOutput::Function {
            content: "ok".to_string(),
            content_items: None,
            success: Some(true),
        })
    }
}

async fn wait_on_barrier(args: BarrierArgs) -> Result<(), FunctionCallError> {
    if args.participants == 0 {
        return Err(FunctionCallError::RespondToModel(
            "barrier participants must be greater than zero".to_string(),
        ));
    }

    if args.timeout_ms == 0 {
        return Err(FunctionCallError::RespondToModel(
            "barrier timeout must be greater than zero".to_string(),
        ));
    }

    let barrier_id = args.id.clone();
    let barrier = {
        let mut map = barrier_map().lock().await;
        match map.entry(barrier_id.clone()) {
            Entry::Occupied(entry) => {
                let state = entry.get();
                if state.participants != args.participants {
                    let existing = state.participants;
                    return Err(FunctionCallError::RespondToModel(format!(
                        "barrier {barrier_id} already registered with {existing} participants"
                    )));
                }
                state.barrier.clone()
            }
            Entry::Vacant(entry) => {
                let barrier = Arc::new(Barrier::new(args.participants));
                entry.insert(BarrierState {
                    barrier: barrier.clone(),
                    participants: args.participants,
                });
                barrier
            }
        }
    };

    let timeout = Duration::from_millis(args.timeout_ms);
    let wait_result = tokio::time::timeout(timeout, barrier.wait())
        .await
        .map_err(|_| {
            FunctionCallError::RespondToModel("test_sync_tool barrier wait timed out".to_string())
        })?;

    if wait_result.is_leader() {
        let mut map = barrier_map().lock().await;
        if let Some(state) = map.get(&barrier_id)
            && Arc::ptr_eq(&state.barrier, &barrier)
        {
            map.remove(&barrier_id);
        }
    }

    Ok(())
}
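A minimal sketch of two test_sync_tool calls that rendezvous on the same barrier; both parse to TestSyncArgs as defined above, and the id is illustrative:

let first = r#"{"barrier": {"id": "sync-1", "participants": 2, "timeout_ms": 500}}"#;
let second = r#"{"sleep_before_ms": 50, "barrier": {"id": "sync-1", "participants": 2, "timeout_ms": 500}}"#;
// Issued from two concurrent tasks, both calls return "ok" once both reach the
// barrier; a lone call errors out with the barrier-wait timeout after 500 ms.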
19
llmx-rs/core/src/tools/handlers/tool_apply_patch.lark
Normal file
@@ -0,0 +1,19 @@
start: begin_patch hunk+ end_patch
begin_patch: "*** Begin Patch" LF
end_patch: "*** End Patch" LF?

hunk: add_hunk | delete_hunk | update_hunk
add_hunk: "*** Add File: " filename LF add_line+
delete_hunk: "*** Delete File: " filename LF
update_hunk: "*** Update File: " filename LF change_move? change?

filename: /(.+)/
add_line: "+" /(.*)/ LF -> line

change_move: "*** Move to: " filename LF
change: (change_context | change_line)+ eof_line?
change_context: ("@@" | "@@ " /(.+)/) LF
change_line: ("+" | "-" | " ") /(.*)/ LF
eof_line: "*** End of File" LF

%import common.LF
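A minimal sketch of a patch the grammar above accepts, embedded as a Rust string for readability; the file name and hunk contents are illustrative:

const EXAMPLE_PATCH: &str = "*** Begin Patch\n\
*** Update File: src/lib.rs\n\
@@\n\
-fn old() {}\n\
+fn new() {}\n\
*** End Patch\n";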
209
llmx-rs/core/src/tools/handlers/unified_exec.rs
Normal file
@@ -0,0 +1,209 @@
use std::path::PathBuf;

use async_trait::async_trait;
use serde::Deserialize;

use crate::function_tool::FunctionCallError;
use crate::protocol::EventMsg;
use crate::protocol::ExecCommandOutputDeltaEvent;
use crate::protocol::ExecOutputStream;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::events::ToolEmitter;
use crate::tools::events::ToolEventCtx;
use crate::tools::events::ToolEventStage;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use crate::unified_exec::ExecCommandRequest;
use crate::unified_exec::UnifiedExecContext;
use crate::unified_exec::UnifiedExecResponse;
use crate::unified_exec::UnifiedExecSessionManager;
use crate::unified_exec::WriteStdinRequest;

pub struct UnifiedExecHandler;

#[derive(Debug, Deserialize)]
struct ExecCommandArgs {
    cmd: String,
    #[serde(default)]
    workdir: Option<String>,
    #[serde(default = "default_shell")]
    shell: String,
    #[serde(default = "default_login")]
    login: bool,
    #[serde(default)]
    yield_time_ms: Option<u64>,
    #[serde(default)]
    max_output_tokens: Option<usize>,
}

#[derive(Debug, Deserialize)]
struct WriteStdinArgs {
    session_id: i32,
    #[serde(default)]
    chars: String,
    #[serde(default)]
    yield_time_ms: Option<u64>,
    #[serde(default)]
    max_output_tokens: Option<usize>,
}

fn default_shell() -> String {
    "/bin/bash".to_string()
}

fn default_login() -> bool {
    true
}
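A minimal sketch of exec_command arguments, following ExecCommandArgs above; omitted fields take the declared defaults (/bin/bash, login = true), and the command is illustrative:

let raw = r#"{"cmd": "echo hello", "workdir": "/tmp", "yield_time_ms": 250}"#;
let args: ExecCommandArgs = serde_json::from_str(raw).expect("valid exec_command arguments");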
#[async_trait]
impl ToolHandler for UnifiedExecHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        matches!(
            payload,
            ToolPayload::Function { .. } | ToolPayload::UnifiedExec { .. }
        )
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            turn,
            call_id,
            tool_name,
            payload,
            ..
        } = invocation;

        let arguments = match payload {
            ToolPayload::Function { arguments } => arguments,
            ToolPayload::UnifiedExec { arguments } => arguments,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "unified_exec handler received unsupported payload".to_string(),
                ));
            }
        };

        let manager: &UnifiedExecSessionManager = &session.services.unified_exec_manager;
        let context = UnifiedExecContext::new(session.clone(), turn.clone(), call_id.clone());

        let response = match tool_name.as_str() {
            "exec_command" => {
                let args: ExecCommandArgs = serde_json::from_str(&arguments).map_err(|err| {
                    FunctionCallError::RespondToModel(format!(
                        "failed to parse exec_command arguments: {err:?}"
                    ))
                })?;
                let workdir = args
                    .workdir
                    .as_deref()
                    .filter(|value| !value.is_empty())
                    .map(PathBuf::from);
                let cwd = workdir.clone().unwrap_or_else(|| context.turn.cwd.clone());

                let event_ctx = ToolEventCtx::new(
                    context.session.as_ref(),
                    context.turn.as_ref(),
                    &context.call_id,
                    None,
                );
                let emitter = ToolEmitter::unified_exec(args.cmd.clone(), cwd.clone(), true);
                emitter.emit(event_ctx, ToolEventStage::Begin).await;

                manager
                    .exec_command(
                        ExecCommandRequest {
                            command: &args.cmd,
                            shell: &args.shell,
                            login: args.login,
                            yield_time_ms: args.yield_time_ms,
                            max_output_tokens: args.max_output_tokens,
                            workdir,
                        },
                        &context,
                    )
                    .await
                    .map_err(|err| {
                        FunctionCallError::RespondToModel(format!("exec_command failed: {err:?}"))
                    })?
            }
            "write_stdin" => {
                let args: WriteStdinArgs = serde_json::from_str(&arguments).map_err(|err| {
                    FunctionCallError::RespondToModel(format!(
                        "failed to parse write_stdin arguments: {err:?}"
                    ))
                })?;
                manager
                    .write_stdin(WriteStdinRequest {
                        session_id: args.session_id,
                        input: &args.chars,
                        yield_time_ms: args.yield_time_ms,
                        max_output_tokens: args.max_output_tokens,
                    })
                    .await
                    .map_err(|err| {
                        FunctionCallError::RespondToModel(format!("write_stdin failed: {err:?}"))
                    })?
            }
            other => {
                return Err(FunctionCallError::RespondToModel(format!(
                    "unsupported unified exec function {other}"
                )));
            }
        };

        // Emit a delta event with the chunk of output we just produced, if any.
        if !response.output.is_empty() {
            let delta = ExecCommandOutputDeltaEvent {
                call_id: response.event_call_id.clone(),
                stream: ExecOutputStream::Stdout,
                chunk: response.output.as_bytes().to_vec(),
            };
            session
                .send_event(turn.as_ref(), EventMsg::ExecCommandOutputDelta(delta))
                .await;
        }

        let content = format_response(&response);

        Ok(ToolOutput::Function {
            content,
            content_items: None,
            success: Some(true),
        })
    }
}

fn format_response(response: &UnifiedExecResponse) -> String {
    let mut sections = Vec::new();

    if !response.chunk_id.is_empty() {
        sections.push(format!("Chunk ID: {}", response.chunk_id));
    }

    let wall_time_seconds = response.wall_time.as_secs_f64();
    sections.push(format!("Wall time: {wall_time_seconds:.4} seconds"));

    if let Some(exit_code) = response.exit_code {
        sections.push(format!("Process exited with code {exit_code}"));
|
||||
}
|
||||
|
||||
if let Some(session_id) = response.session_id {
|
||||
sections.push(format!("Process running with session ID {session_id}"));
|
||||
}
|
||||
|
||||
if let Some(original_token_count) = response.original_token_count {
|
||||
sections.push(format!("Original token count: {original_token_count}"));
|
||||
}
|
||||
|
||||
sections.push("Output:".to_string());
|
||||
sections.push(response.output.clone());
|
||||
|
||||
sections.join("\n")
|
||||
}
|
||||
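The exec_command arguments above arrive as a JSON string from the model, with serde supplying defaults for everything but `cmd`. A minimal standalone sketch of that shape (assumes serde with the derive feature plus serde_json as dependencies; not part of this commit):

use serde::Deserialize;

fn default_shell() -> String {
    "/bin/bash".to_string()
}

fn default_login() -> bool {
    true
}

#[derive(Debug, Deserialize)]
struct ExecCommandArgs {
    cmd: String,
    #[serde(default)]
    workdir: Option<String>,
    #[serde(default = "default_shell")]
    shell: String,
    #[serde(default = "default_login")]
    login: bool,
    #[serde(default)]
    yield_time_ms: Option<u64>,
    #[serde(default)]
    max_output_tokens: Option<usize>,
}

fn main() {
    // Only `cmd` is required; every other field falls back to its serde default.
    let args: ExecCommandArgs =
        serde_json::from_str(r#"{"cmd": "cargo test", "yield_time_ms": 750}"#).unwrap();
    assert_eq!(args.shell, "/bin/bash");
    assert!(args.login);
    assert_eq!(args.workdir, None);
}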
92
llmx-rs/core/src/tools/handlers/view_image.rs
Normal file
@@ -0,0 +1,92 @@
use async_trait::async_trait;
use serde::Deserialize;
use tokio::fs;

use crate::function_tool::FunctionCallError;
use crate::protocol::EventMsg;
use crate::protocol::ViewImageToolCallEvent;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use codex_protocol::user_input::UserInput;

pub struct ViewImageHandler;

#[derive(Deserialize)]
struct ViewImageArgs {
    path: String,
}

#[async_trait]
impl ToolHandler for ViewImageHandler {
    fn kind(&self) -> ToolKind {
        ToolKind::Function
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
        let ToolInvocation {
            session,
            turn,
            payload,
            call_id,
            ..
        } = invocation;

        let arguments = match payload {
            ToolPayload::Function { arguments } => arguments,
            _ => {
                return Err(FunctionCallError::RespondToModel(
                    "view_image handler received unsupported payload".to_string(),
                ));
            }
        };

        let args: ViewImageArgs = serde_json::from_str(&arguments).map_err(|e| {
            FunctionCallError::RespondToModel(format!("failed to parse function arguments: {e:?}"))
        })?;

        let abs_path = turn.resolve_path(Some(args.path));

        let metadata = fs::metadata(&abs_path).await.map_err(|error| {
            FunctionCallError::RespondToModel(format!(
                "unable to locate image at `{}`: {error}",
                abs_path.display()
            ))
        })?;

        if !metadata.is_file() {
            return Err(FunctionCallError::RespondToModel(format!(
                "image path `{}` is not a file",
                abs_path.display()
            )));
        }
        let event_path = abs_path.clone();

        session
            .inject_input(vec![UserInput::LocalImage { path: abs_path }])
            .await
            .map_err(|_| {
                FunctionCallError::RespondToModel(
                    "unable to attach image (no active task)".to_string(),
                )
            })?;

        session
            .send_event(
                turn.as_ref(),
                EventMsg::ViewImageToolCall(ViewImageToolCallEvent {
                    call_id,
                    path: event_path,
                }),
            )
            .await;

        Ok(ToolOutput::Function {
            content: "attached local image path".to_string(),
            content_items: None,
            success: Some(true),
        })
    }
}
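Before attaching anything, the handler requires that the resolved path exists and is a regular file. The same validation in isolation, as a runnable sketch (assumes tokio with the fs and macros features; the function name and path are illustrative):

use std::path::Path;

// Mirrors the two checks the handler performs: metadata lookup, then is_file().
async fn validate_image_path(path: &Path) -> Result<(), String> {
    let metadata = tokio::fs::metadata(path)
        .await
        .map_err(|error| format!("unable to locate image at `{}`: {error}", path.display()))?;
    if !metadata.is_file() {
        return Err(format!("image path `{}` is not a file", path.display()));
    }
    Ok(())
}

#[tokio::main]
async fn main() {
    match validate_image_path(Path::new("./screenshot.png")).await {
        Ok(()) => println!("image path is usable"),
        Err(message) => eprintln!("{message}"),
    }
}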
79
llmx-rs/core/src/tools/mod.rs
Normal file
@@ -0,0 +1,79 @@
pub mod context;
pub mod events;
pub(crate) mod handlers;
pub mod orchestrator;
pub mod parallel;
pub mod registry;
pub mod router;
pub mod runtimes;
pub mod sandboxing;
pub mod spec;

use crate::context_manager::format_output_for_model_body;
use crate::exec::ExecToolCallOutput;
pub use router::ToolRouter;
use serde::Serialize;

// Telemetry preview limits: keep log events smaller than model budgets.
pub(crate) const TELEMETRY_PREVIEW_MAX_BYTES: usize = 2 * 1024; // 2 KiB
pub(crate) const TELEMETRY_PREVIEW_MAX_LINES: usize = 64; // lines
pub(crate) const TELEMETRY_PREVIEW_TRUNCATION_NOTICE: &str =
    "[... telemetry preview truncated ...]";

/// Format the combined exec output for sending back to the model.
/// Includes exit code and duration metadata; truncates large bodies safely.
pub fn format_exec_output_for_model(exec_output: &ExecToolCallOutput) -> String {
    let ExecToolCallOutput {
        exit_code,
        duration,
        ..
    } = exec_output;

    #[derive(Serialize)]
    struct ExecMetadata {
        exit_code: i32,
        duration_seconds: f32,
    }

    #[derive(Serialize)]
    struct ExecOutput<'a> {
        output: &'a str,
        metadata: ExecMetadata,
    }

    // round to 1 decimal place
    let duration_seconds = ((duration.as_secs_f32()) * 10.0).round() / 10.0;

    let formatted_output = format_exec_output_str(exec_output);

    let payload = ExecOutput {
        output: &formatted_output,
        metadata: ExecMetadata {
            exit_code: *exit_code,
            duration_seconds,
        },
    };

    #[expect(clippy::expect_used)]
    serde_json::to_string(&payload).expect("serialize ExecOutput")
}

pub fn format_exec_output_str(exec_output: &ExecToolCallOutput) -> String {
    let ExecToolCallOutput {
        aggregated_output, ..
    } = exec_output;

    let content = aggregated_output.text.as_str();

    let body = if exec_output.timed_out {
        format!(
            "command timed out after {} milliseconds\n{content}",
            exec_output.duration.as_millis()
        )
    } else {
        content.to_string()
    };

    // Truncate for model consumption before serialization.
    format_output_for_model_body(&body)
}
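format_exec_output_for_model serializes the exec result into a small JSON envelope. A standalone sketch of that envelope using the same serde shapes as the module above (assumes serde derive plus serde_json; not part of this commit):

use serde::Serialize;

#[derive(Serialize)]
struct ExecMetadata {
    exit_code: i32,
    duration_seconds: f32,
}

#[derive(Serialize)]
struct ExecOutput<'a> {
    output: &'a str,
    metadata: ExecMetadata,
}

fn main() {
    // Duration is rounded to one decimal place, as in the module above.
    let duration_seconds = (0.2849_f32 * 10.0).round() / 10.0;
    let payload = ExecOutput {
        output: "hello\n",
        metadata: ExecMetadata {
            exit_code: 0,
            duration_seconds,
        },
    };
    // Prints {"output":"hello\n","metadata":{"exit_code":0,"duration_seconds":0.3}}
    println!("{}", serde_json::to_string(&payload).unwrap());
}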
186
llmx-rs/core/src/tools/orchestrator.rs
Normal file
@@ -0,0 +1,186 @@
/*
Module: orchestrator

Central place for approvals + sandbox selection + retry semantics. Drives a
simple sequence for any ToolRuntime: approval → select sandbox → attempt →
retry without sandbox on denial (no re‑approval thanks to caching).
*/
use crate::error::CodexErr;
use crate::error::SandboxErr;
use crate::error::get_error_message_ui;
use crate::exec::ExecToolCallOutput;
use crate::sandboxing::SandboxManager;
use crate::tools::sandboxing::ApprovalCtx;
use crate::tools::sandboxing::ProvidesSandboxRetryData;
use crate::tools::sandboxing::SandboxAttempt;
use crate::tools::sandboxing::ToolCtx;
use crate::tools::sandboxing::ToolError;
use crate::tools::sandboxing::ToolRuntime;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::ReviewDecision;

pub(crate) struct ToolOrchestrator {
    sandbox: SandboxManager,
}

impl ToolOrchestrator {
    pub fn new() -> Self {
        Self {
            sandbox: SandboxManager::new(),
        }
    }

    pub async fn run<Rq, Out, T>(
        &mut self,
        tool: &mut T,
        req: &Rq,
        tool_ctx: &ToolCtx<'_>,
        turn_ctx: &crate::codex::TurnContext,
        approval_policy: AskForApproval,
    ) -> Result<Out, ToolError>
    where
        T: ToolRuntime<Rq, Out>,
        Rq: ProvidesSandboxRetryData,
    {
        let otel = turn_ctx.client.get_otel_event_manager();
        let otel_tn = &tool_ctx.tool_name;
        let otel_ci = &tool_ctx.call_id;
        let otel_user = codex_otel::otel_event_manager::ToolDecisionSource::User;
        let otel_cfg = codex_otel::otel_event_manager::ToolDecisionSource::Config;

        // 1) Approval
        let needs_initial_approval =
            tool.wants_initial_approval(req, approval_policy, &turn_ctx.sandbox_policy);
        let mut already_approved = false;

        if needs_initial_approval {
            let mut risk = None;

            if let Some(metadata) = req.sandbox_retry_data() {
                risk = tool_ctx
                    .session
                    .assess_sandbox_command(turn_ctx, &tool_ctx.call_id, &metadata.command, None)
                    .await;
            }

            let approval_ctx = ApprovalCtx {
                session: tool_ctx.session,
                turn: turn_ctx,
                call_id: &tool_ctx.call_id,
                retry_reason: None,
                risk,
            };
            let decision = tool.start_approval_async(req, approval_ctx).await;

            otel.tool_decision(otel_tn, otel_ci, decision, otel_user.clone());

            match decision {
                ReviewDecision::Denied | ReviewDecision::Abort => {
                    return Err(ToolError::Rejected("rejected by user".to_string()));
                }
                ReviewDecision::Approved | ReviewDecision::ApprovedForSession => {}
            }
            already_approved = true;
        } else {
            otel.tool_decision(otel_tn, otel_ci, ReviewDecision::Approved, otel_cfg);
        }

        // 2) First attempt under the selected sandbox.
        let mut initial_sandbox = self
            .sandbox
            .select_initial(&turn_ctx.sandbox_policy, tool.sandbox_preference());
        if tool.wants_escalated_first_attempt(req) {
            initial_sandbox = crate::exec::SandboxType::None;
        }
        // Platform-specific flag gating is handled by SandboxManager::select_initial
        // via crate::safety::get_platform_sandbox().
        let initial_attempt = SandboxAttempt {
            sandbox: initial_sandbox,
            policy: &turn_ctx.sandbox_policy,
            manager: &self.sandbox,
            sandbox_cwd: &turn_ctx.cwd,
            codex_linux_sandbox_exe: turn_ctx.codex_linux_sandbox_exe.as_ref(),
        };

        match tool.run(req, &initial_attempt, tool_ctx).await {
            Ok(out) => {
                // We have a successful initial result
                Ok(out)
            }
            Err(ToolError::Codex(CodexErr::Sandbox(SandboxErr::Denied { output }))) => {
                if !tool.escalate_on_failure() {
                    return Err(ToolError::Codex(CodexErr::Sandbox(SandboxErr::Denied {
                        output,
                    })));
                }
                // Under `Never` or `OnRequest`, do not retry without sandbox; surface a concise
                // sandbox denial that preserves the original output.
                if !tool.wants_no_sandbox_approval(approval_policy) {
                    return Err(ToolError::Codex(CodexErr::Sandbox(SandboxErr::Denied {
                        output,
                    })));
                }

                // Ask for approval before retrying without sandbox.
                if !tool.should_bypass_approval(approval_policy, already_approved) {
                    let mut risk = None;

                    if let Some(metadata) = req.sandbox_retry_data() {
                        let err = SandboxErr::Denied {
                            output: output.clone(),
                        };
                        let friendly = get_error_message_ui(&CodexErr::Sandbox(err));
                        let failure_summary = format!("failed in sandbox: {friendly}");

                        risk = tool_ctx
                            .session
                            .assess_sandbox_command(
                                turn_ctx,
                                &tool_ctx.call_id,
                                &metadata.command,
                                Some(failure_summary.as_str()),
                            )
                            .await;
                    }

                    let reason_msg = build_denial_reason_from_output(output.as_ref());
                    let approval_ctx = ApprovalCtx {
                        session: tool_ctx.session,
                        turn: turn_ctx,
                        call_id: &tool_ctx.call_id,
                        retry_reason: Some(reason_msg),
                        risk,
                    };

                    let decision = tool.start_approval_async(req, approval_ctx).await;
                    otel.tool_decision(otel_tn, otel_ci, decision, otel_user);

                    match decision {
                        ReviewDecision::Denied | ReviewDecision::Abort => {
                            return Err(ToolError::Rejected("rejected by user".to_string()));
                        }
                        ReviewDecision::Approved | ReviewDecision::ApprovedForSession => {}
                    }
                }

                let escalated_attempt = SandboxAttempt {
                    sandbox: crate::exec::SandboxType::None,
                    policy: &turn_ctx.sandbox_policy,
                    manager: &self.sandbox,
                    sandbox_cwd: &turn_ctx.cwd,
                    codex_linux_sandbox_exe: None,
                };

                // Second attempt.
                (*tool).run(req, &escalated_attempt, tool_ctx).await
            }
            other => other,
        }
    }
}

fn build_denial_reason_from_output(_output: &ExecToolCallOutput) -> String {
    // Keep approval reason terse and stable for UX/tests, but accept the
    // output so we can evolve heuristics later without touching call sites.
    "command failed; retry without sandbox?".to_string()
}
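The run sequence above reduces to: an optional initial approval, a sandboxed first attempt, and (policy permitting) one approval-gated retry without a sandbox. A simplified sketch of that control flow with stand-in types — every name here is illustrative, not the crate's API, and the real code carries extra telemetry, risk assessment, and approval caching:

#[derive(Clone, Copy, PartialEq)]
enum Decision {
    Approved,
    Denied,
}

enum Attempt {
    Ok,
    SandboxDenied,
}

fn run_tool(
    needs_initial_approval: bool,
    may_retry_unsandboxed: bool,
    ask_user: impl Fn(&str) -> Decision,
    attempt: impl Fn(bool) -> Attempt, // true = run inside the sandbox
) -> Result<(), String> {
    // 1) Approval.
    if needs_initial_approval && ask_user("run this command?") == Decision::Denied {
        return Err("rejected by user".to_string());
    }
    // 2) First attempt under the sandbox.
    match attempt(true) {
        Attempt::Ok => Ok(()),
        Attempt::SandboxDenied if may_retry_unsandboxed => {
            // Ask before retrying without a sandbox (the real code can skip
            // this when the call was already explicitly approved).
            if ask_user("command failed; retry without sandbox?") == Decision::Denied {
                return Err("rejected by user".to_string());
            }
            match attempt(false) {
                Attempt::Ok => Ok(()),
                Attempt::SandboxDenied => Err("denied without sandbox".to_string()),
            }
        }
        Attempt::SandboxDenied => Err("sandbox denial".to_string()),
    }
}

fn main() {
    // Fails sandboxed, succeeds after the approved escalated retry.
    let result = run_tool(true, true, |_| Decision::Approved, |sandboxed| {
        if sandboxed { Attempt::SandboxDenied } else { Attempt::Ok }
    });
    assert!(result.is_ok());
}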
126
llmx-rs/core/src/tools/parallel.rs
Normal file
@@ -0,0 +1,126 @@
use std::sync::Arc;
use std::time::Instant;

use tokio::sync::RwLock;
use tokio_util::either::Either;
use tokio_util::sync::CancellationToken;
use tokio_util::task::AbortOnDropHandle;

use crate::codex::Session;
use crate::codex::TurnContext;
use crate::error::CodexErr;
use crate::function_tool::FunctionCallError;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::context::ToolPayload;
use crate::tools::router::ToolCall;
use crate::tools::router::ToolRouter;
use codex_protocol::models::FunctionCallOutputPayload;
use codex_protocol::models::ResponseInputItem;
use codex_utils_readiness::Readiness;

pub(crate) struct ToolCallRuntime {
    router: Arc<ToolRouter>,
    session: Arc<Session>,
    turn_context: Arc<TurnContext>,
    tracker: SharedTurnDiffTracker,
    parallel_execution: Arc<RwLock<()>>,
}

impl ToolCallRuntime {
    pub(crate) fn new(
        router: Arc<ToolRouter>,
        session: Arc<Session>,
        turn_context: Arc<TurnContext>,
        tracker: SharedTurnDiffTracker,
    ) -> Self {
        Self {
            router,
            session,
            turn_context,
            tracker,
            parallel_execution: Arc::new(RwLock::new(())),
        }
    }

    pub(crate) fn handle_tool_call(
        &self,
        call: ToolCall,
        cancellation_token: CancellationToken,
    ) -> impl std::future::Future<Output = Result<ResponseInputItem, CodexErr>> {
        let supports_parallel = self.router.tool_supports_parallel(&call.tool_name);

        let router = Arc::clone(&self.router);
        let session = Arc::clone(&self.session);
        let turn = Arc::clone(&self.turn_context);
        let tracker = Arc::clone(&self.tracker);
        let lock = Arc::clone(&self.parallel_execution);
        let started = Instant::now();
        let readiness = self.turn_context.tool_call_gate.clone();

        let handle: AbortOnDropHandle<Result<ResponseInputItem, FunctionCallError>> =
            AbortOnDropHandle::new(tokio::spawn(async move {
                tokio::select! {
                    _ = cancellation_token.cancelled() => {
                        let secs = started.elapsed().as_secs_f32().max(0.1);
                        Ok(Self::aborted_response(&call, secs))
                    },
                    res = async {
                        tracing::info!("waiting for tool gate");
                        readiness.wait_ready().await;
                        tracing::info!("tool gate released");
                        let _guard = if supports_parallel {
                            Either::Left(lock.read().await)
                        } else {
                            Either::Right(lock.write().await)
                        };

                        router
                            .dispatch_tool_call(session, turn, tracker, call.clone())
                            .await
                    } => res,
                }
            }));

        async move {
            match handle.await {
                Ok(Ok(response)) => Ok(response),
                Ok(Err(FunctionCallError::Fatal(message))) => Err(CodexErr::Fatal(message)),
                Ok(Err(other)) => Err(CodexErr::Fatal(other.to_string())),
                Err(err) => Err(CodexErr::Fatal(format!(
                    "tool task failed to receive: {err:?}"
                ))),
            }
        }
    }
}

impl ToolCallRuntime {
    fn aborted_response(call: &ToolCall, secs: f32) -> ResponseInputItem {
        match &call.payload {
            ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput {
                call_id: call.call_id.clone(),
                output: Self::abort_message(call, secs),
            },
            ToolPayload::Mcp { .. } => ResponseInputItem::McpToolCallOutput {
                call_id: call.call_id.clone(),
                result: Err(Self::abort_message(call, secs)),
            },
            _ => ResponseInputItem::FunctionCallOutput {
                call_id: call.call_id.clone(),
                output: FunctionCallOutputPayload {
                    content: Self::abort_message(call, secs),
                    ..Default::default()
                },
            },
        }
    }

    fn abort_message(call: &ToolCall, secs: f32) -> String {
        match call.tool_name.as_str() {
            "shell" | "container.exec" | "local_shell" | "unified_exec" => {
                format!("Wall time: {secs:.1} seconds\naborted by user")
            }
            _ => format!("aborted by user after {secs:.1}s"),
        }
    }
}
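The parallel_execution field uses a tokio RwLock as a gate rather than to protect data: tools whose spec supports parallel calls take the read half (many can hold it at once), while everything else takes the write half and runs exclusively. A minimal sketch of that gating pattern (assumes tokio with the rt-multi-thread, macros, and sync features):

use std::sync::Arc;

use tokio::sync::RwLock;

#[tokio::main]
async fn main() {
    let gate = Arc::new(RwLock::new(()));

    // A parallel-capable tool takes the read half, so several of these
    // can run concurrently.
    let parallel = {
        let gate = Arc::clone(&gate);
        tokio::spawn(async move {
            let _shared = gate.read().await;
            println!("parallel tool running");
        })
    };

    // An exclusive tool takes the write half, so it waits for all readers
    // to finish and blocks new ones while it runs.
    let exclusive = {
        let gate = Arc::clone(&gate);
        tokio::spawn(async move {
            let _alone = gate.write().await;
            println!("exclusive tool running alone");
        })
    };

    let _ = tokio::join!(parallel, exclusive);
}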
218
llmx-rs/core/src/tools/registry.rs
Normal file
@@ -0,0 +1,218 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use codex_protocol::models::ResponseInputItem;
use tracing::warn;

use crate::client_common::tools::ToolSpec;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ToolKind {
    Function,
    Mcp,
}

#[async_trait]
pub trait ToolHandler: Send + Sync {
    fn kind(&self) -> ToolKind;

    fn matches_kind(&self, payload: &ToolPayload) -> bool {
        matches!(
            (self.kind(), payload),
            (ToolKind::Function, ToolPayload::Function { .. })
                | (ToolKind::Mcp, ToolPayload::Mcp { .. })
        )
    }

    async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError>;
}

pub struct ToolRegistry {
    handlers: HashMap<String, Arc<dyn ToolHandler>>,
}

impl ToolRegistry {
    pub fn new(handlers: HashMap<String, Arc<dyn ToolHandler>>) -> Self {
        Self { handlers }
    }

    pub fn handler(&self, name: &str) -> Option<Arc<dyn ToolHandler>> {
        self.handlers.get(name).map(Arc::clone)
    }

    // TODO(jif) for dynamic tools.
    // pub fn register(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
    //     let name = name.into();
    //     if self.handlers.insert(name.clone(), handler).is_some() {
    //         warn!("overwriting handler for tool {name}");
    //     }
    // }

    pub async fn dispatch(
        &self,
        invocation: ToolInvocation,
    ) -> Result<ResponseInputItem, FunctionCallError> {
        let tool_name = invocation.tool_name.clone();
        let call_id_owned = invocation.call_id.clone();
        let otel = invocation.turn.client.get_otel_event_manager();
        let payload_for_response = invocation.payload.clone();
        let log_payload = payload_for_response.log_payload();

        let handler = match self.handler(tool_name.as_ref()) {
            Some(handler) => handler,
            None => {
                let message =
                    unsupported_tool_call_message(&invocation.payload, tool_name.as_ref());
                otel.tool_result(
                    tool_name.as_ref(),
                    &call_id_owned,
                    log_payload.as_ref(),
                    Duration::ZERO,
                    false,
                    &message,
                );
                return Err(FunctionCallError::RespondToModel(message));
            }
        };

        if !handler.matches_kind(&invocation.payload) {
            let message = format!("tool {tool_name} invoked with incompatible payload");
            otel.tool_result(
                tool_name.as_ref(),
                &call_id_owned,
                log_payload.as_ref(),
                Duration::ZERO,
                false,
                &message,
            );
            return Err(FunctionCallError::Fatal(message));
        }

        let output_cell = tokio::sync::Mutex::new(None);

        let result = otel
            .log_tool_result(
                tool_name.as_ref(),
                &call_id_owned,
                log_payload.as_ref(),
                || {
                    let handler = handler.clone();
                    let output_cell = &output_cell;
                    let invocation = invocation;
                    async move {
                        match handler.handle(invocation).await {
                            Ok(output) => {
                                let preview = output.log_preview();
                                let success = output.success_for_logging();
                                let mut guard = output_cell.lock().await;
                                *guard = Some(output);
                                Ok((preview, success))
                            }
                            Err(err) => Err(err),
                        }
                    }
                },
            )
            .await;

        match result {
            Ok(_) => {
                let mut guard = output_cell.lock().await;
                let output = guard.take().ok_or_else(|| {
                    FunctionCallError::Fatal("tool produced no output".to_string())
                })?;
                Ok(output.into_response(&call_id_owned, &payload_for_response))
            }
            Err(err) => Err(err),
        }
    }
}

#[derive(Debug, Clone)]
pub struct ConfiguredToolSpec {
    pub spec: ToolSpec,
    pub supports_parallel_tool_calls: bool,
}

impl ConfiguredToolSpec {
    pub fn new(spec: ToolSpec, supports_parallel_tool_calls: bool) -> Self {
        Self {
            spec,
            supports_parallel_tool_calls,
        }
    }
}

pub struct ToolRegistryBuilder {
    handlers: HashMap<String, Arc<dyn ToolHandler>>,
    specs: Vec<ConfiguredToolSpec>,
}

impl ToolRegistryBuilder {
    pub fn new() -> Self {
        Self {
            handlers: HashMap::new(),
            specs: Vec::new(),
        }
    }

    pub fn push_spec(&mut self, spec: ToolSpec) {
        self.push_spec_with_parallel_support(spec, false);
    }

    pub fn push_spec_with_parallel_support(
        &mut self,
        spec: ToolSpec,
        supports_parallel_tool_calls: bool,
    ) {
        self.specs
            .push(ConfiguredToolSpec::new(spec, supports_parallel_tool_calls));
    }

    pub fn register_handler(&mut self, name: impl Into<String>, handler: Arc<dyn ToolHandler>) {
        let name = name.into();
        if self
            .handlers
            .insert(name.clone(), handler.clone())
            .is_some()
        {
            warn!("overwriting handler for tool {name}");
        }
    }

    // TODO(jif) for dynamic tools.
    // pub fn register_many<I>(&mut self, names: I, handler: Arc<dyn ToolHandler>)
    // where
    //     I: IntoIterator,
    //     I::Item: Into<String>,
    // {
    //     for name in names {
    //         let name = name.into();
    //         if self
    //             .handlers
    //             .insert(name.clone(), handler.clone())
    //             .is_some()
    //         {
    //             warn!("overwriting handler for tool {name}");
    //         }
    //     }
    // }

    pub fn build(self) -> (Vec<ConfiguredToolSpec>, ToolRegistry) {
        let registry = ToolRegistry::new(self.handlers);
        (self.specs, registry)
    }
}

fn unsupported_tool_call_message(payload: &ToolPayload, tool_name: &str) -> String {
    match payload {
        ToolPayload::Custom { .. } => format!("unsupported custom tool call: {tool_name}"),
        _ => format!("unsupported call: {tool_name}"),
    }
}
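The registry is a flat name-to-handler map of trait objects; registration warns rather than errors when a name is reused, and lookup clones the Arc so dispatch never borrows the map across an await. The same pattern in a self-contained form (assumes async-trait and tokio as dependencies; the Echo handler is illustrative):

use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;

#[async_trait]
trait Handler: Send + Sync {
    async fn handle(&self, args: &str) -> Result<String, String>;
}

struct Echo;

#[async_trait]
impl Handler for Echo {
    async fn handle(&self, args: &str) -> Result<String, String> {
        Ok(format!("echo: {args}"))
    }
}

#[tokio::main]
async fn main() {
    let mut handlers: HashMap<String, Arc<dyn Handler>> = HashMap::new();
    // Last write wins; the real builder logs a warning on overwrite.
    if handlers.insert("echo".to_string(), Arc::new(Echo)).is_some() {
        eprintln!("overwriting handler for tool echo");
    }

    // Clone the Arc out of the map before awaiting the handler.
    let handler = handlers.get("echo").map(Arc::clone).expect("registered");
    println!("{}", handler.handle("hi").await.unwrap());
}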
189
llmx-rs/core/src/tools/router.rs
Normal file
@@ -0,0 +1,189 @@
use std::collections::HashMap;
use std::sync::Arc;

use crate::client_common::tools::ToolSpec;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::function_tool::FunctionCallError;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ConfiguredToolSpec;
use crate::tools::registry::ToolRegistry;
use crate::tools::spec::ToolsConfig;
use crate::tools::spec::build_specs;
use codex_protocol::models::LocalShellAction;
use codex_protocol::models::ResponseInputItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::models::ShellToolCallParams;

#[derive(Clone)]
pub struct ToolCall {
    pub tool_name: String,
    pub call_id: String,
    pub payload: ToolPayload,
}

pub struct ToolRouter {
    registry: ToolRegistry,
    specs: Vec<ConfiguredToolSpec>,
}

impl ToolRouter {
    pub fn from_config(
        config: &ToolsConfig,
        mcp_tools: Option<HashMap<String, mcp_types::Tool>>,
    ) -> Self {
        let builder = build_specs(config, mcp_tools);
        let (specs, registry) = builder.build();

        Self { registry, specs }
    }

    pub fn specs(&self) -> Vec<ToolSpec> {
        self.specs
            .iter()
            .map(|config| config.spec.clone())
            .collect()
    }

    pub fn tool_supports_parallel(&self, tool_name: &str) -> bool {
        self.specs
            .iter()
            .filter(|config| config.supports_parallel_tool_calls)
            .any(|config| config.spec.name() == tool_name)
    }

    pub fn build_tool_call(
        session: &Session,
        item: ResponseItem,
    ) -> Result<Option<ToolCall>, FunctionCallError> {
        match item {
            ResponseItem::FunctionCall {
                name,
                arguments,
                call_id,
                ..
            } => {
                if let Some((server, tool)) = session.parse_mcp_tool_name(&name) {
                    Ok(Some(ToolCall {
                        tool_name: name,
                        call_id,
                        payload: ToolPayload::Mcp {
                            server,
                            tool,
                            raw_arguments: arguments,
                        },
                    }))
                } else {
                    let payload = if name == "unified_exec" {
                        ToolPayload::UnifiedExec { arguments }
                    } else {
                        ToolPayload::Function { arguments }
                    };
                    Ok(Some(ToolCall {
                        tool_name: name,
                        call_id,
                        payload,
                    }))
                }
            }
            ResponseItem::CustomToolCall {
                name,
                input,
                call_id,
                ..
            } => Ok(Some(ToolCall {
                tool_name: name,
                call_id,
                payload: ToolPayload::Custom { input },
            })),
            ResponseItem::LocalShellCall {
                id,
                call_id,
                action,
                ..
            } => {
                let call_id = call_id
                    .or(id)
                    .ok_or(FunctionCallError::MissingLocalShellCallId)?;

                match action {
                    LocalShellAction::Exec(exec) => {
                        let params = ShellToolCallParams {
                            command: exec.command,
                            workdir: exec.working_directory,
                            timeout_ms: exec.timeout_ms,
                            with_escalated_permissions: None,
                            justification: None,
                        };
                        Ok(Some(ToolCall {
                            tool_name: "local_shell".to_string(),
                            call_id,
                            payload: ToolPayload::LocalShell { params },
                        }))
                    }
                }
            }
            _ => Ok(None),
        }
    }

    pub async fn dispatch_tool_call(
        &self,
        session: Arc<Session>,
        turn: Arc<TurnContext>,
        tracker: SharedTurnDiffTracker,
        call: ToolCall,
    ) -> Result<ResponseInputItem, FunctionCallError> {
        let ToolCall {
            tool_name,
            call_id,
            payload,
        } = call;
        let payload_outputs_custom = matches!(payload, ToolPayload::Custom { .. });
        let failure_call_id = call_id.clone();

        let invocation = ToolInvocation {
            session,
            turn,
            tracker,
            call_id,
            tool_name,
            payload,
        };

        match self.registry.dispatch(invocation).await {
            Ok(response) => Ok(response),
            Err(FunctionCallError::Fatal(message)) => Err(FunctionCallError::Fatal(message)),
            Err(err) => Ok(Self::failure_response(
                failure_call_id,
                payload_outputs_custom,
                err,
            )),
        }
    }

    fn failure_response(
        call_id: String,
        payload_outputs_custom: bool,
        err: FunctionCallError,
    ) -> ResponseInputItem {
        let message = err.to_string();
        if payload_outputs_custom {
            ResponseInputItem::CustomToolCallOutput {
                call_id,
                output: message,
            }
        } else {
            ResponseInputItem::FunctionCallOutput {
                call_id,
                output: codex_protocol::models::FunctionCallOutputPayload {
                    content: message,
                    success: Some(false),
                    ..Default::default()
                },
            }
        }
    }
}
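build_tool_call normalizes three response item shapes into one ToolCall, routing plain function calls by name: MCP-qualified names become Mcp payloads, unified_exec gets its dedicated variant, and everything else rides through as a generic function payload. That name-based branch in isolation (stand-in types; not the crate's API):

enum Payload {
    UnifiedExec { arguments: String },
    Function { arguments: String },
}

// Mirrors the non-MCP branch of build_tool_call above.
fn select_payload(name: &str, arguments: String) -> Payload {
    if name == "unified_exec" {
        Payload::UnifiedExec { arguments }
    } else {
        Payload::Function { arguments }
    }
}

fn main() {
    let routed = select_payload("unified_exec", r#"{"cmd":"ls"}"#.to_string());
    match routed {
        Payload::UnifiedExec { .. } => println!("routed to the unified exec handler"),
        Payload::Function { .. } => println!("routed as a generic function call"),
    }
}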
Some files were not shown because too many files have changed in this diff.