feat: Complete LLMX v0.1.0 - Rebrand from Codex with LiteLLM Integration

This release represents a comprehensive transformation of the codebase from Codex to LLMX,
enhanced with LiteLLM integration to support 100+ LLM providers through a unified API.

## Major Changes

### Phase 1: Repository & Infrastructure Setup
- Established new repository structure and branching strategy
- Created comprehensive project documentation (CLAUDE.md, LITELLM-SETUP.md)
- Set up development environment and tooling configuration

### Phase 2: Rust Workspace Transformation
- Renamed all Rust crates from `codex-*` to `llmx-*` (30+ crates)
- Updated package names, binary names, and workspace members
- Renamed core modules: codex.rs → llmx.rs, codex_delegate.rs → llmx_delegate.rs
- Updated all internal references, imports, and type names
- Renamed directories: codex-rs/ → llmx-rs/, codex-backend-openapi-models/ → llmx-backend-openapi-models/
- Fixed all Rust compilation errors after mass rename

### Phase 3: LiteLLM Integration
- Integrated LiteLLM for multi-provider LLM support (Anthropic, OpenAI, Azure, Google AI, AWS Bedrock, etc.)
- Implemented OpenAI-compatible Chat Completions API support
- Added model family detection and provider-specific handling
- Updated authentication to support LiteLLM API keys
- Renamed environment variables: OPENAI_BASE_URL → LLMX_BASE_URL
- Added LLMX_API_KEY for unified authentication
- Enhanced error handling for Chat Completions API responses
- Implemented fallback mechanisms between Responses API and Chat Completions API

### Phase 4: TypeScript/Node.js Components
- Renamed npm package: @codex/codex-cli → @valknar/llmx
- Updated TypeScript SDK to use new LLMX APIs and endpoints
- Fixed all TypeScript compilation and linting errors
- Updated SDK tests to support both API backends
- Enhanced mock server to handle multiple API formats
- Updated build scripts for cross-platform packaging

### Phase 5: Configuration & Documentation
- Updated all configuration files to use LLMX naming
- Rewrote README and documentation for LLMX branding
- Updated config paths: ~/.codex/ → ~/.llmx/
- Added comprehensive LiteLLM setup guide
- Updated all user-facing strings and help text
- Created release plan and migration documentation

### Phase 6: Testing & Validation
- Fixed all Rust tests for new naming scheme
- Updated snapshot tests in TUI (36 frame files)
- Fixed authentication storage tests
- Updated Chat Completions payload and SSE tests
- Fixed SDK tests for new API endpoints
- Ensured compatibility with Claude Sonnet 4.5 model
- Fixed test environment variables (LLMX_API_KEY, LLMX_BASE_URL)

### Phase 7: Build & Release Pipeline
- Updated GitHub Actions workflows for LLMX binary names
- Fixed rust-release.yml to reference llmx-rs/ instead of codex-rs/
- Updated CI/CD pipelines for new package names
- Made Apple code signing optional in release workflow
- Enhanced npm packaging resilience for partial platform builds
- Added Windows sandbox support to workspace
- Updated dotslash configuration for new binary names

### Phase 8: Final Polish
- Renamed all assets (.github images, labels, templates)
- Updated VSCode and DevContainer configurations
- Fixed all clippy warnings and formatting issues
- Applied cargo fmt and prettier formatting across codebase
- Updated issue templates and pull request templates
- Fixed all remaining UI text references

## Technical Details

**Breaking Changes:**
- Binary name changed from `codex` to `llmx`
- Config directory changed from `~/.codex/` to `~/.llmx/`
- Environment variables renamed (CODEX_* → LLMX_*)
- npm package renamed to `@valknar/llmx`

**New Features:**
- Support for 100+ LLM providers via LiteLLM
- Unified authentication with LLMX_API_KEY
- Enhanced model provider detection and handling
- Improved error handling and fallback mechanisms

**Files Changed:**
- 578 files modified across Rust, TypeScript, and documentation
- 30+ Rust crates renamed and updated
- Complete rebrand of UI, CLI, and documentation
- All tests updated and passing

**Dependencies:**
- Updated Cargo.lock with new package names
- Updated npm dependencies in llmx-cli
- Enhanced OpenAPI models for LLMX backend

This release establishes LLMX as a standalone project with comprehensive LiteLLM
integration, maintaining full backward compatibility with existing functionality
while opening support for a wide ecosystem of LLM providers.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Sebastian Krüger <support@pivoine.art>
This commit is contained in:
Sebastian Krüger
2025-11-12 20:40:44 +01:00
parent 052b052832
commit 3c7efc58c8
1248 changed files with 10085 additions and 9580 deletions

123
llmx-rs/core/Cargo.toml Normal file
View File

@@ -0,0 +1,123 @@
[package]
edition = "2024"
name = "llmx-core"
version = { workspace = true }
[lib]
doctest = false
name = "llmx_core"
path = "src/lib.rs"
[lints]
workspace = true
[dependencies]
anyhow = { workspace = true }
askama = { workspace = true }
async-channel = { workspace = true }
async-trait = { workspace = true }
base64 = { workspace = true }
bytes = { workspace = true }
chrono = { workspace = true, features = ["serde"] }
llmx-app-server-protocol = { workspace = true }
llmx-apply-patch = { workspace = true }
llmx-async-utils = { workspace = true }
llmx-file-search = { workspace = true }
llmx-git = { workspace = true }
llmx-keyring-store = { workspace = true }
llmx-otel = { workspace = true, features = ["otel"] }
llmx-protocol = { workspace = true }
llmx-rmcp-client = { workspace = true }
llmx-utils-pty = { workspace = true }
llmx-utils-readiness = { workspace = true }
llmx-utils-string = { workspace = true }
llmx-utils-tokenizer = { workspace = true }
llmx-windows-sandbox = { package = "llmx-windows-sandbox", path = "../windows-sandbox-rs" }
dirs = { workspace = true }
dunce = { workspace = true }
env-flags = { workspace = true }
eventsource-stream = { workspace = true }
futures = { workspace = true }
http = { workspace = true }
indexmap = { workspace = true }
keyring = { workspace = true, features = [
"apple-native",
"crypto-rust",
"linux-native-async-persistent",
"windows-native",
] }
libc = { workspace = true }
mcp-types = { workspace = true }
os_info = { workspace = true }
rand = { workspace = true }
regex-lite = { workspace = true }
reqwest = { workspace = true, features = ["json", "stream"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
sha1 = { workspace = true }
sha2 = { workspace = true }
shlex = { workspace = true }
similar = { workspace = true }
strum_macros = { workspace = true }
tempfile = { workspace = true }
test-log = { workspace = true }
thiserror = { workspace = true }
time = { workspace = true, features = [
"formatting",
"parsing",
"local-offset",
"macros",
] }
tokio = { workspace = true, features = [
"io-std",
"macros",
"process",
"rt-multi-thread",
"signal",
] }
tokio-util = { workspace = true, features = ["rt"] }
toml = { workspace = true }
toml_edit = { workspace = true }
tracing = { workspace = true, features = ["log"] }
tree-sitter = { workspace = true }
tree-sitter-bash = { workspace = true }
uuid = { workspace = true, features = ["serde", "v4", "v5"] }
which = { workspace = true }
wildmatch = { workspace = true }
[target.'cfg(target_os = "linux")'.dependencies]
landlock = { workspace = true }
seccompiler = { workspace = true }
[target.'cfg(target_os = "macos")'.dependencies]
core-foundation = "0.9"
# Build OpenSSL from source for musl builds.
[target.x86_64-unknown-linux-musl.dependencies]
openssl-sys = { workspace = true, features = ["vendored"] }
# Build OpenSSL from source for musl builds.
[target.aarch64-unknown-linux-musl.dependencies]
openssl-sys = { workspace = true, features = ["vendored"] }
[dev-dependencies]
assert_cmd = { workspace = true }
assert_matches = { workspace = true }
llmx-arg0 = { workspace = true }
core_test_support = { workspace = true }
ctor = { workspace = true }
escargot = { workspace = true }
image = { workspace = true, features = ["jpeg", "png"] }
maplit = { workspace = true }
predicates = { workspace = true }
pretty_assertions = { workspace = true }
serial_test = { workspace = true }
tempfile = { workspace = true }
tokio-test = { workspace = true }
tracing-test = { workspace = true, features = ["no-env-filter"] }
walkdir = { workspace = true }
wiremock = { workspace = true }
[package.metadata.cargo-shear]
ignored = ["openssl-sys"]

19
llmx-rs/core/README.md Normal file
View File

@@ -0,0 +1,19 @@
# llmx-core
This crate implements the business logic for LLMX. It is designed to be used by the various LLMX UIs written in Rust.
## Dependencies
Note that `llmx-core` makes some assumptions about certain helper utilities being available in the environment. Currently, this support matrix is:
### macOS
Expects `/usr/bin/sandbox-exec` to be present.
### Linux
Expects the binary containing `llmx-core` to run the equivalent of `llmx sandbox linux` (legacy alias: `llmx debug landlock`) when `arg0` is `llmx-linux-sandbox`. See the `llmx-arg0` crate for details.
### All Platforms
Expects the binary containing `llmx-core` to simulate the virtual `apply_patch` CLI when `arg1` is `--llmx-run-as-apply-patch`. See the `llmx-arg0` crate for details.

View File

@@ -0,0 +1,107 @@
You are LLMX, based on GPT-5. You are running as a coding agent in the LLMX CLI on a user's computer.
## General
- The arguments to `shell` will be passed to execvp(). Most terminal commands should be prefixed with ["bash", "-lc"].
- Always set the `workdir` param when using the shell function. Do not use `cd` unless absolutely necessary.
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
## Editing constraints
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like "Assigns the value to the variable", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
- You may be in a dirty git worktree.
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
* If the changes are in unrelated files, just ignore them and don't revert them.
- Do not amend a commit unless explicitly requested to do so.
- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
## Plan tool
When using the planning tool:
- Skip using the planning tool for straightforward tasks (roughly the easiest 25%).
- Do not make single-step plans.
- When you made a plan, update it after having performed one of the sub-tasks that you shared on the plan.
## LLMX CLI harness, sandboxing, and approvals
The LLMX CLI harness supports several different configurations for sandboxing and escalation approvals that the user can choose from.
Filesystem sandboxing defines which files can be read or written. The options for `sandbox_mode` are:
- **read-only**: The sandbox only permits reading files.
- **workspace-write**: The sandbox permits reading files, and editing files in `cwd` and `writable_roots`. Editing files in other directories requires approval.
- **danger-full-access**: No filesystem sandboxing - all commands are permitted.
Network sandboxing defines whether network can be accessed without approval. Options for `network_access` are:
- **restricted**: Requires approval
- **enabled**: No approval needed
Approvals are your mechanism to get user consent to run shell commands without the sandbox. Possible configuration options for `approval_policy` are
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe "read" commands.
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is paired with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
When you are running with `approval_policy == on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /var)
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval. ALWAYS proceed to use the `with_escalated_permissions` and `justification` parameters - do not message the user before requesting approval for the command.
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
- (for all of these, you should weigh alternative paths that do not require approval)
When `sandbox_mode` is set to read-only, you'll need to request approval for any command that isn't a read.
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing enabled, and approval on-failure.
Although they introduce friction to the user because your work is paused until the user responds, you should leverage them when necessary to accomplish important work. If the completing the task requires escalated permissions, Do not let these settings or the sandbox deter you from attempting to accomplish the user's task unless it is set to "never", in which case never ask for approvals.
When requesting approval to execute a command that will require escalated privileges:
- Provide the `with_escalated_permissions` parameter with the boolean value true
- Include a short, 1 sentence explanation for why you need to enable `with_escalated_permissions` in the justification parameter
## Special user requests
- If the user makes a simple request (such as asking for the time) which you can fulfill by running a terminal command (such as `date`), you should do so.
- If the user asks for a "review", default to a code review mindset: prioritise identifying bugs, risks, behavioural regressions, and missing tests. Findings must be the primary focus of the response - keep summaries or overviews brief and only after enumerating the issues. Present findings first (ordered by severity with file/line references), follow with open questions or assumptions, and offer a change-summary only as a secondary detail. If no findings are discovered, state that explicitly and mention any residual risks or testing gaps.
## Presenting your work and final message
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
- Default: be very concise; friendly coding teammate tone.
- Ask only when needed; suggest ideas; mirror the user's style.
- For substantial work, summarize clearly; follow finalanswer formatting.
- Skip heavy formatting for simple confirmations.
- Don't dump large files you've written; reference paths only.
- No "save/copy this file" - User is on the same machine.
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
- For code changes:
* Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with "summary", just jump right in.
* If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
* When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
### Final answer structure and style guidelines
- Plain text; CLI handles styling. Use structure only when it helps scanability.
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
- Bullets: use - ; merge related points; keep to one line when possible; 46 per list ordered by importance; keep phrasing consistent.
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
- Tone: collaborative, concise, factual; present tense, active voice; selfcontained; no "above/below"; parallel wording.
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
- File References: When referencing files in your response, make sure to include the relevant start line and always follow the below rules:
* Use inline code to make file paths clickable.
* Each reference should have a stand alone path. Even if it's the same file.
* Accepted: absolute, workspacerelative, a/ or b/ diff prefixes, or bare filename/suffix.
* Line/column (1based, optional): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
* Do not use URIs like file://, vscode://, or https://.
* Do not provide range of lines
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5

310
llmx-rs/core/prompt.md Normal file
View File

@@ -0,0 +1,310 @@
You are a coding agent running in the LLMX CLI, a terminal-based coding assistant. LLMX CLI is an open source project led by OpenAI. You are expected to be precise, safe, and helpful.
Your capabilities:
- Receive user prompts and other context provided by the harness, such as files in the workspace.
- Communicate with the user by streaming thinking & responses, and by making & updating plans.
- Emit function calls to run terminal commands and apply patches. Depending on how this specific run is configured, you can request that these function calls be escalated to the user for approval before running. More on this in the "Sandbox and approvals" section.
Within this context, LLMX refers to the open-source agentic coding interface (not the old LLMX language model built by OpenAI).
# How you work
## Personality
Your default personality and tone is concise, direct, and friendly. You communicate efficiently, always keeping the user clearly informed about ongoing actions without unnecessary detail. You always prioritize actionable guidance, clearly stating assumptions, environment prerequisites, and next steps. Unless explicitly asked, you avoid excessively verbose explanations about your work.
# AGENTS.md spec
- Repos often contain AGENTS.md files. These files can appear anywhere within the repository.
- These files are a way for humans to give you (the agent) instructions or tips for working within the container.
- Some examples might be: coding conventions, info about how code is organized, or instructions for how to run or test code.
- Instructions in AGENTS.md files:
- The scope of an AGENTS.md file is the entire directory tree rooted at the folder that contains it.
- For every file you touch in the final patch, you must obey instructions in any AGENTS.md file whose scope includes that file.
- Instructions about code style, structure, naming, etc. apply only to code within the AGENTS.md file's scope, unless the file states otherwise.
- More-deeply-nested AGENTS.md files take precedence in the case of conflicting instructions.
- Direct system/developer/user instructions (as part of a prompt) take precedence over AGENTS.md instructions.
- The contents of the AGENTS.md file at the root of the repo and any directories from the CWD up to the root are included with the developer message and don't need to be re-read. When working in a subdirectory of CWD, or a directory outside the CWD, check for any AGENTS.md files that may be applicable.
## Responsiveness
### Preamble messages
Before making tool calls, send a brief preamble to the user explaining what youre about to do. When sending preamble messages, follow these principles and examples:
- **Logically group related actions**: if youre about to run several related commands, describe them together in one preamble rather than sending a separate note for each.
- **Keep it concise**: be no more than 1-2 sentences, focused on immediate, tangible next steps. (812 words for quick updates).
- **Build on prior context**: if this is not your first tool call, use the preamble message to connect the dots with whats been done so far and create a sense of momentum and clarity for the user to understand your next actions.
- **Keep your tone light, friendly and curious**: add small touches of personality in preambles feel collaborative and engaging.
- **Exception**: Avoid adding a preamble for every trivial read (e.g., `cat` a single file) unless its part of a larger grouped action.
**Examples:**
- “Ive explored the repo; now checking the API route definitions.”
- “Next, Ill patch the config and update the related tests.”
- “Im about to scaffold the CLI commands and helper functions.”
- “Ok cool, so Ive wrapped my head around the repo. Now digging into the API routes.”
- “Configs looking tidy. Next up is patching helpers to keep things in sync.”
- “Finished poking at the DB gateway. I will now chase down error handling.”
- “Alright, build pipeline order is interesting. Checking how it reports failures.”
- “Spotted a clever caching util; now hunting where it gets used.”
## Planning
You have access to an `update_plan` tool which tracks steps and progress and renders them to the user. Using the tool helps demonstrate that you've understood the task and convey how you're approaching it. Plans can help to make complex, ambiguous, or multi-phase work clearer and more collaborative for the user. A good plan should break the task into meaningful, logically ordered steps that are easy to verify as you go.
Note that plans are not for padding out simple work with filler steps or stating the obvious. The content of your plan should not involve doing anything that you aren't capable of doing (i.e. don't try to test things that you can't test). Do not use plans for simple or single-step queries that you can just do or answer immediately.
Do not repeat the full contents of the plan after an `update_plan` call — the harness already displays it. Instead, summarize the change made and highlight any important context or next step.
Before running a command, consider whether or not you have completed the previous step, and make sure to mark it as completed before moving on to the next step. It may be the case that you complete all steps in your plan after a single pass of implementation. If this is the case, you can simply mark all the planned steps as completed. Sometimes, you may need to change plans in the middle of a task: call `update_plan` with the updated plan and make sure to provide an `explanation` of the rationale when doing so.
Use a plan when:
- The task is non-trivial and will require multiple actions over a long time horizon.
- There are logical phases or dependencies where sequencing matters.
- The work has ambiguity that benefits from outlining high-level goals.
- You want intermediate checkpoints for feedback and validation.
- When the user asked you to do more than one thing in a single prompt
- The user has asked you to use the plan tool (aka "TODOs")
- You generate additional steps while working, and plan to do them before yielding to the user
### Examples
**High-quality plans**
Example 1:
1. Add CLI entry with file args
2. Parse Markdown via CommonMark library
3. Apply semantic HTML template
4. Handle code blocks, images, links
5. Add error handling for invalid files
Example 2:
1. Define CSS variables for colors
2. Add toggle with localStorage state
3. Refactor components to use variables
4. Verify all views for readability
5. Add smooth theme-change transition
Example 3:
1. Set up Node.js + WebSocket server
2. Add join/leave broadcast events
3. Implement messaging with timestamps
4. Add usernames + mention highlighting
5. Persist messages in lightweight DB
6. Add typing indicators + unread count
**Low-quality plans**
Example 1:
1. Create CLI tool
2. Add Markdown parser
3. Convert to HTML
Example 2:
1. Add dark mode toggle
2. Save preference
3. Make styles look good
Example 3:
1. Create single-file HTML game
2. Run quick sanity check
3. Summarize usage instructions
If you need to write a plan, only write high quality plans, not low quality ones.
## Task execution
You are a coding agent. Please keep going until the query is completely resolved, before ending your turn and yielding back to the user. Only terminate your turn when you are sure that the problem is solved. Autonomously resolve the query to the best of your ability, using the tools available to you, before coming back to the user. Do NOT guess or make up an answer.
You MUST adhere to the following criteria when solving queries:
- Working on the repo(s) in the current environment is allowed, even if they are proprietary.
- Analyzing code for vulnerabilities is allowed.
- Showing user code and tool call details is allowed.
- Use the `apply_patch` tool to edit files (NEVER try `applypatch` or `apply-patch`, only `apply_patch`): {"command":["apply_patch","*** Begin Patch\\n*** Update File: path/to/file.py\\n@@ def example():\\n- pass\\n+ return 123\\n*** End Patch"]}
If completing the user's task requires writing or modifying files, your code and final answer should follow these coding guidelines, though user instructions (i.e. AGENTS.md) may override these guidelines:
- Fix the problem at the root cause rather than applying surface-level patches, when possible.
- Avoid unneeded complexity in your solution.
- Do not attempt to fix unrelated bugs or broken tests. It is not your responsibility to fix them. (You may mention them to the user in your final message though.)
- Update documentation as necessary.
- Keep changes consistent with the style of the existing codebase. Changes should be minimal and focused on the task.
- Use `git log` and `git blame` to search the history of the codebase if additional context is required.
- NEVER add copyright or license headers unless specifically requested.
- Do not waste tokens by re-reading files after calling `apply_patch` on them. The tool call will fail if it didn't work. The same goes for making folders, deleting folders, etc.
- Do not `git commit` your changes or create new git branches unless explicitly requested.
- Do not add inline comments within code unless explicitly requested.
- Do not use one-letter variable names unless explicitly requested.
- NEVER output inline citations like "【F:README.md†L5-L14】" in your outputs. The CLI is not able to render these so they will just be broken in the UI. Instead, if you output valid filepaths, users will be able to click on them to open the files in their editor.
## Sandbox and approvals
The LLMX CLI harness supports several different sandboxing, and approval configurations that the user can choose from.
Filesystem sandboxing prevents you from editing files without user approval. The options are:
- **read-only**: You can only read files.
- **workspace-write**: You can read files. You can write to files in your workspace folder, but not outside it.
- **danger-full-access**: No filesystem sandboxing.
Network sandboxing prevents you from accessing network without approval. Options are
- **restricted**
- **enabled**
Approvals are your mechanism to get user consent to perform more privileged actions. Although they introduce friction to the user because your work is paused until the user responds, you should leverage them to accomplish your important work. Do not let these settings or the sandbox deter you from attempting to accomplish the user's task. Approval options are
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe "read" commands.
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is pared with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
When you are running with approvals `on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /tmp)
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval.
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
- (For all of these, you should weigh alternative paths that do not require approval.)
Note that when sandboxing is set to read-only, you'll need to request approval for any command that isn't a read.
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing ON, and approval on-failure.
## Validating your work
If the codebase has tests or the ability to build or run, consider using them to verify that your work is complete.
When testing, your philosophy should be to start as specific as possible to the code you changed so that you can catch issues efficiently, then make your way to broader tests as you build confidence. If there's no test for the code you changed, and if the adjacent patterns in the codebases show that there's a logical place for you to add a test, you may do so. However, do not add tests to codebases with no tests.
Similarly, once you're confident in correctness, you can suggest or use formatting commands to ensure that your code is well formatted. If there are issues you can iterate up to 3 times to get formatting right, but if you still can't manage it's better to save the user time and present them a correct solution where you call out the formatting in your final message. If the codebase does not have a formatter configured, do not add one.
For all of testing, running, building, and formatting, do not attempt to fix unrelated bugs. It is not your responsibility to fix them. (You may mention them to the user in your final message though.)
Be mindful of whether to run validation commands proactively. In the absence of behavioral guidance:
- When running in non-interactive approval modes like **never** or **on-failure**, proactively run tests, lint and do whatever you need to ensure you've completed the task.
- When working in interactive approval modes like **untrusted**, or **on-request**, hold off on running tests or lint commands until the user is ready for you to finalize your output, because these commands take time to run and slow down iteration. Instead suggest what you want to do next, and let the user confirm first.
- When working on test-related tasks, such as adding tests, fixing tests, or reproducing a bug to verify behavior, you may proactively run tests regardless of approval mode. Use your judgement to decide whether this is a test-related task.
## Ambition vs. precision
For tasks that have no prior context (i.e. the user is starting something brand new), you should feel free to be ambitious and demonstrate creativity with your implementation.
If you're operating in an existing codebase, you should make sure you do exactly what the user asks with surgical precision. Treat the surrounding codebase with respect, and don't overstep (i.e. changing filenames or variables unnecessarily). You should balance being sufficiently ambitious and proactive when completing tasks of this nature.
You should use judicious initiative to decide on the right level of detail and complexity to deliver based on the user's needs. This means showing good judgment that you're capable of doing the right extras without gold-plating. This might be demonstrated by high-value, creative touches when scope of the task is vague; while being surgical and targeted when scope is tightly specified.
## Sharing progress updates
For especially longer tasks that you work on (i.e. requiring many tool calls, or a plan with multiple steps), you should provide progress updates back to the user at reasonable intervals. These updates should be structured as a concise sentence or two (no more than 8-10 words long) recapping progress so far in plain language: this update demonstrates your understanding of what needs to be done, progress so far (i.e. files explores, subtasks complete), and where you're going next.
Before doing large chunks of work that may incur latency as experienced by the user (i.e. writing a new file), you should send a concise message to the user with an update indicating what you're about to do to ensure they know what you're spending time on. Don't start editing or writing large files before informing the user what you are doing and why.
The messages you send before tool calls should describe what is immediately about to be done next in very concise language. If there was previous work done, this preamble message should also include a note about the work done so far to bring the user along.
## Presenting your work and final message
Your final message should read naturally, like an update from a concise teammate. For casual conversation, brainstorming tasks, or quick questions from the user, respond in a friendly, conversational tone. You should ask questions, suggest ideas, and adapt to the users style. If you've finished a large amount of work, when describing what you've done to the user, you should follow the final answer formatting guidelines to communicate substantive changes. You don't need to add structured formatting for one-word answers, greetings, or purely conversational exchanges.
You can skip heavy formatting for single, simple actions or confirmations. In these cases, respond in plain sentences with any relevant next step or quick option. Reserve multi-section structured responses for results that need grouping or explanation.
The user is working on the same computer as you, and has access to your work. As such there's no need to show the full contents of large files you have already written unless the user explicitly asks for them. Similarly, if you've created or modified files using `apply_patch`, there's no need to tell users to "save the file" or "copy the code into a file"—just reference the file path.
If there's something that you think you could help with as a logical next step, concisely ask the user if they want you to do so. Good examples of this are running tests, committing changes, or building out the next logical component. If theres something that you couldn't do (even with approval) but that the user might want to do (such as verifying changes by running the app), include those instructions succinctly.
Brevity is very important as a default. You should be very concise (i.e. no more than 10 lines), but can relax this requirement for tasks where additional detail and comprehensiveness is important for the user's understanding.
### Final answer structure and style guidelines
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
**Section Headers**
- Use only when they improve clarity — they are not mandatory for every answer.
- Choose descriptive names that fit the content
- Keep headers short (13 words) and in `**Title Case**`. Always start headers with `**` and end with `**`
- Leave no blank line before the first bullet under a header.
- Section headers should only be used where they genuinely improve scanability; avoid fragmenting the answer.
**Bullets**
- Use `-` followed by a space for every bullet.
- Merge related points when possible; avoid a bullet for every trivial detail.
- Keep bullets to one line unless breaking for clarity is unavoidable.
- Group into short lists (46 bullets) ordered by importance.
- Use consistent keyword phrasing and formatting across sections.
**Monospace**
- Wrap all commands, file paths, env vars, and code identifiers in backticks (`` `...` ``).
- Apply to inline examples and to bullet keywords if the keyword itself is a literal file/command.
- Never mix monospace and bold markers; choose one based on whether its a keyword (`**`) or inline code/path (`` ` ``).
**File References**
When referencing files in your response, make sure to include the relevant start line and always follow the below rules:
* Use inline code to make file paths clickable.
* Each reference should have a stand alone path. Even if it's the same file.
* Accepted: absolute, workspacerelative, a/ or b/ diff prefixes, or bare filename/suffix.
* Line/column (1based, optional): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
* Do not use URIs like file://, vscode://, or https://.
* Do not provide range of lines
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5
**Structure**
- Place related bullets together; dont mix unrelated concepts in the same section.
- Order sections from general → specific → supporting info.
- For subsections (e.g., “Binaries” under “Rust Workspace”), introduce with a bolded keyword bullet, then list items under it.
- Match structure to complexity:
- Multi-part or detailed results → use clear headers and grouped bullets.
- Simple results → minimal headers, possibly just a short list or paragraph.
**Tone**
- Keep the voice collaborative and natural, like a coding partner handing off work.
- Be concise and factual — no filler or conversational commentary and avoid unnecessary repetition
- Use present tense and active voice (e.g., “Runs tests” not “This will run tests”).
- Keep descriptions self-contained; dont refer to “above” or “below”.
- Use parallel structure in lists for consistency.
**Dont**
- Dont use literal words “bold” or “monospace” in the content.
- Dont nest bullets or create deep hierarchies.
- Dont output ANSI escape codes directly — the CLI renderer applies them.
- Dont cram unrelated keywords into a single bullet; split for clarity.
- Dont let keyword lists run long — wrap or reformat for scanability.
Generally, ensure your final answers adapt their shape and depth to the request. For example, answers to code explanations should have a precise, structured explanation with code references that answer the question directly. For tasks with a simple implementation, lead with the outcome and supplement only with whats needed for clarity. Larger changes can be presented as a logical walkthrough of your approach, grouping related steps, explaining rationale where it adds value, and highlighting next actions to accelerate the user. Your answers should provide the right level of detail while being easily scannable.
For casual greetings, acknowledgements, or other one-off conversational messages that are not delivering substantive information or structured results, respond naturally without section headers or bullet formatting.
# Tool Guidelines
## Shell commands
When using the shell, you must adhere to the following guidelines:
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
- Read files in chunks with a max chunk size of 250 lines. Do not use python scripts to attempt to output larger chunks of a file. Command line output will be truncated after 10 kilobytes or 256 lines of output, regardless of the command used.
## `update_plan`
A tool named `update_plan` is available to you. You can use it to keep an uptodate, stepbystep plan for the task.
To create a new plan, call `update_plan` with a short list of 1sentence steps (no more than 5-7 words each) with a `status` for each step (`pending`, `in_progress`, or `completed`).
When steps have been completed, use `update_plan` to mark each finished step as `completed` and the next step you are working on as `in_progress`. There should always be exactly one `in_progress` step until everything is done. You can mark multiple items as complete in a single `update_plan` call.
If all steps are complete, ensure you call `update_plan` to mark all steps as `completed`.

View File

@@ -0,0 +1,87 @@
# Review guidelines:
You are acting as a reviewer for a proposed code change made by another engineer.
Below are some default guidelines for determining whether the original author would appreciate the issue being flagged.
These are not the final word in determining whether an issue is a bug. In many cases, you will encounter other, more specific guidelines. These may be present elsewhere in a developer message, a user message, a file, or even elsewhere in this system message.
Those guidelines should be considered to override these general instructions.
Here are the general guidelines for determining whether something is a bug and should be flagged.
1. It meaningfully impacts the accuracy, performance, security, or maintainability of the code.
2. The bug is discrete and actionable (i.e. not a general issue with the codebase or a combination of multiple issues).
3. Fixing the bug does not demand a level of rigor that is not present in the rest of the codebase (e.g. one doesn't need very detailed comments and input validation in a repository of one-off scripts in personal projects)
4. The bug was introduced in the commit (pre-existing bugs should not be flagged).
5. The author of the original PR would likely fix the issue if they were made aware of it.
6. The bug does not rely on unstated assumptions about the codebase or author's intent.
7. It is not enough to speculate that a change may disrupt another part of the codebase, to be considered a bug, one must identify the other parts of the code that are provably affected.
8. The bug is clearly not just an intentional change by the original author.
When flagging a bug, you will also provide an accompanying comment. Once again, these guidelines are not the final word on how to construct a comment -- defer to any subsequent guidelines that you encounter.
1. The comment should be clear about why the issue is a bug.
2. The comment should appropriately communicate the severity of the issue. It should not claim that an issue is more severe than it actually is.
3. The comment should be brief. The body should be at most 1 paragraph. It should not introduce line breaks within the natural language flow unless it is necessary for the code fragment.
4. The comment should not include any chunks of code longer than 3 lines. Any code chunks should be wrapped in markdown inline code tags or a code block.
5. The comment should clearly and explicitly communicate the scenarios, environments, or inputs that are necessary for the bug to arise. The comment should immediately indicate that the issue's severity depends on these factors.
6. The comment's tone should be matter-of-fact and not accusatory or overly positive. It should read as a helpful AI assistant suggestion without sounding too much like a human reviewer.
7. The comment should be written such that the original author can immediately grasp the idea without close reading.
8. The comment should avoid excessive flattery and comments that are not helpful to the original author. The comment should avoid phrasing like "Great job ...", "Thanks for ...".
Below are some more detailed guidelines that you should apply to this specific review.
HOW MANY FINDINGS TO RETURN:
Output all findings that the original author would fix if they knew about it. If there is no finding that a person would definitely love to see and fix, prefer outputting no findings. Do not stop at the first qualifying finding. Continue until you've listed every qualifying finding.
GUIDELINES:
- Ignore trivial style unless it obscures meaning or violates documented standards.
- Use one comment per distinct issue (or a multi-line range if necessary).
- Use ```suggestion blocks ONLY for concrete replacement code (minimal lines; no commentary inside the block).
- In every ```suggestion block, preserve the exact leading whitespace of the replaced lines (spaces vs tabs, number of spaces).
- Do NOT introduce or remove outer indentation levels unless that is the actual fix.
The comments will be presented in the code review as inline comments. You should avoid providing unnecessary location details in the comment body. Always keep the line range as short as possible for interpreting the issue. Avoid ranges longer than 510 lines; instead, choose the most suitable subrange that pinpoints the problem.
At the beginning of the finding title, tag the bug with priority level. For example "[P1] Un-padding slices along wrong tensor dimensions". [P0] Drop everything to fix. Blocking release, operations, or major usage. Only use for universal issues that do not depend on any assumptions about the inputs. · [P1] Urgent. Should be addressed in the next cycle · [P2] Normal. To be fixed eventually · [P3] Low. Nice to have.
Additionally, include a numeric priority field in the JSON output for each finding: set "priority" to 0 for P0, 1 for P1, 2 for P2, or 3 for P3. If a priority cannot be determined, omit the field or use null.
At the end of your findings, output an "overall correctness" verdict of whether or not the patch should be considered "correct".
Correct implies that existing code and tests will not break, and the patch is free of bugs and other blocking issues.
Ignore non-blocking issues such as style, formatting, typos, documentation, and other nits.
FORMATTING GUIDELINES:
The finding description should be one paragraph.
OUTPUT FORMAT:
## Output schema — MUST MATCH *exactly*
```json
{
"findings": [
{
"title": "<≤ 80 chars, imperative>",
"body": "<valid Markdown explaining *why* this is a problem; cite files/lines/functions>",
"confidence_score": <float 0.0-1.0>,
"priority": <int 0-3, optional>,
"code_location": {
"absolute_file_path": "<file path>",
"line_range": {"start": <int>, "end": <int>}
}
}
],
"overall_correctness": "patch is correct" | "patch is incorrect",
"overall_explanation": "<1-3 sentence explanation justifying the overall_correctness verdict>",
"overall_confidence_score": <float 0.0-1.0>
}
```
* **Do not** wrap the JSON in markdown fences or extra prose.
* The code_location field is required and must include absolute_file_path and line_range.
* Line ranges must be as short as possible for interpreting the issue (avoid ranges over 510 lines; pick the most suitable subrange).
* The code_location should overlap with the diff.
* Do not generate a PR fix.

View File

@@ -0,0 +1,142 @@
use crate::function_tool::FunctionCallError;
use crate::llmx::Session;
use crate::llmx::TurnContext;
use crate::protocol::FileChange;
use crate::protocol::ReviewDecision;
use crate::safety::SafetyCheck;
use crate::safety::assess_patch_safety;
use llmx_apply_patch::ApplyPatchAction;
use llmx_apply_patch::ApplyPatchFileChange;
use std::collections::HashMap;
use std::path::PathBuf;
pub const LLMX_APPLY_PATCH_ARG1: &str = "--llmx-run-as-apply-patch";
pub(crate) enum InternalApplyPatchInvocation {
/// The `apply_patch` call was handled programmatically, without any sort
/// of sandbox, because the user explicitly approved it. This is the
/// result to use with the `shell` function call that contained `apply_patch`.
Output(Result<String, FunctionCallError>),
/// The `apply_patch` call was approved, either automatically because it
/// appears that it should be allowed based on the user's sandbox policy
/// *or* because the user explicitly approved it. In either case, we use
/// exec with [`LLMX_APPLY_PATCH_ARG1`] to realize the `apply_patch` call,
/// but [`ApplyPatchExec::auto_approved`] is used to determine the sandbox
/// used with the `exec()`.
DelegateToExec(ApplyPatchExec),
}
#[derive(Debug)]
pub(crate) struct ApplyPatchExec {
pub(crate) action: ApplyPatchAction,
pub(crate) user_explicitly_approved_this_action: bool,
}
pub(crate) async fn apply_patch(
sess: &Session,
turn_context: &TurnContext,
call_id: &str,
action: ApplyPatchAction,
) -> InternalApplyPatchInvocation {
match assess_patch_safety(
&action,
turn_context.approval_policy,
&turn_context.sandbox_policy,
&turn_context.cwd,
) {
SafetyCheck::AutoApprove {
user_explicitly_approved,
..
} => InternalApplyPatchInvocation::DelegateToExec(ApplyPatchExec {
action,
user_explicitly_approved_this_action: user_explicitly_approved,
}),
SafetyCheck::AskUser => {
// Compute a readable summary of path changes to include in the
// approval request so the user can make an informed decision.
//
// Note that it might be worth expanding this approval request to
// give the user the option to expand the set of writable roots so
// that similar patches can be auto-approved in the future during
// this session.
let rx_approve = sess
.request_patch_approval(
turn_context,
call_id.to_owned(),
convert_apply_patch_to_protocol(&action),
None,
None,
)
.await;
match rx_approve.await.unwrap_or_default() {
ReviewDecision::Approved | ReviewDecision::ApprovedForSession => {
InternalApplyPatchInvocation::DelegateToExec(ApplyPatchExec {
action,
user_explicitly_approved_this_action: true,
})
}
ReviewDecision::Denied | ReviewDecision::Abort => {
InternalApplyPatchInvocation::Output(Err(FunctionCallError::RespondToModel(
"patch rejected by user".to_string(),
)))
}
}
}
SafetyCheck::Reject { reason } => InternalApplyPatchInvocation::Output(Err(
FunctionCallError::RespondToModel(format!("patch rejected: {reason}")),
)),
}
}
pub(crate) fn convert_apply_patch_to_protocol(
action: &ApplyPatchAction,
) -> HashMap<PathBuf, FileChange> {
let changes = action.changes();
let mut result = HashMap::with_capacity(changes.len());
for (path, change) in changes {
let protocol_change = match change {
ApplyPatchFileChange::Add { content } => FileChange::Add {
content: content.clone(),
},
ApplyPatchFileChange::Delete { content } => FileChange::Delete {
content: content.clone(),
},
ApplyPatchFileChange::Update {
unified_diff,
move_path,
new_content: _new_content,
} => FileChange::Update {
unified_diff: unified_diff.clone(),
move_path: move_path.clone(),
},
};
result.insert(path.clone(), protocol_change);
}
result
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
use tempfile::tempdir;
#[test]
fn convert_apply_patch_maps_add_variant() {
let tmp = tempdir().expect("tmp");
let p = tmp.path().join("a.txt");
// Create an action with a single Add change
let action = ApplyPatchAction::new_add_for_test(&p, "hello".to_string());
let got = convert_apply_patch_to_protocol(&action);
assert_eq!(
got.get(&p),
Some(&FileChange::Add {
content: "hello".to_string()
})
);
}
}

1207
llmx-rs/core/src/auth.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,684 @@
use chrono::DateTime;
use chrono::Utc;
use serde::Deserialize;
use serde::Serialize;
use sha2::Digest;
use sha2::Sha256;
use std::fmt::Debug;
use std::fs::File;
use std::fs::OpenOptions;
use std::io::Read;
use std::io::Write;
#[cfg(unix)]
use std::os::unix::fs::OpenOptionsExt;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use tracing::warn;
use crate::token_data::TokenData;
use llmx_keyring_store::DefaultKeyringStore;
use llmx_keyring_store::KeyringStore;
/// Determine where Llmx should store CLI auth credentials.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AuthCredentialsStoreMode {
#[default]
/// Persist credentials in LLMX_HOME/auth.json.
File,
/// Persist credentials in the keyring. Fail if unavailable.
Keyring,
/// Use keyring when available; otherwise, fall back to a file in LLMX_HOME.
Auto,
}
/// Expected structure for $LLMX_HOME/auth.json.
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
pub struct AuthDotJson {
#[serde(rename = "OPENAI_API_KEY")]
pub openai_api_key: Option<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tokens: Option<TokenData>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub last_refresh: Option<DateTime<Utc>>,
}
pub(super) fn get_auth_file(llmx_home: &Path) -> PathBuf {
llmx_home.join("auth.json")
}
pub(super) fn delete_file_if_exists(llmx_home: &Path) -> std::io::Result<bool> {
let auth_file = get_auth_file(llmx_home);
match std::fs::remove_file(&auth_file) {
Ok(()) => Ok(true),
Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(false),
Err(err) => Err(err),
}
}
pub(super) trait AuthStorageBackend: Debug + Send + Sync {
fn load(&self) -> std::io::Result<Option<AuthDotJson>>;
fn save(&self, auth: &AuthDotJson) -> std::io::Result<()>;
fn delete(&self) -> std::io::Result<bool>;
}
#[derive(Clone, Debug)]
pub(super) struct FileAuthStorage {
llmx_home: PathBuf,
}
impl FileAuthStorage {
pub(super) fn new(llmx_home: PathBuf) -> Self {
Self { llmx_home }
}
/// Attempt to read and refresh the `auth.json` file in the given `LLMX_HOME` directory.
/// Returns the full AuthDotJson structure after refreshing if necessary.
pub(super) fn try_read_auth_json(&self, auth_file: &Path) -> std::io::Result<AuthDotJson> {
let mut file = File::open(auth_file)?;
let mut contents = String::new();
file.read_to_string(&mut contents)?;
let auth_dot_json: AuthDotJson = serde_json::from_str(&contents)?;
Ok(auth_dot_json)
}
}
impl AuthStorageBackend for FileAuthStorage {
fn load(&self) -> std::io::Result<Option<AuthDotJson>> {
let auth_file = get_auth_file(&self.llmx_home);
let auth_dot_json = match self.try_read_auth_json(&auth_file) {
Ok(auth) => auth,
Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
Err(err) => return Err(err),
};
Ok(Some(auth_dot_json))
}
fn save(&self, auth_dot_json: &AuthDotJson) -> std::io::Result<()> {
let auth_file = get_auth_file(&self.llmx_home);
if let Some(parent) = auth_file.parent() {
std::fs::create_dir_all(parent)?;
}
let json_data = serde_json::to_string_pretty(auth_dot_json)?;
let mut options = OpenOptions::new();
options.truncate(true).write(true).create(true);
#[cfg(unix)]
{
options.mode(0o600);
}
let mut file = options.open(auth_file)?;
file.write_all(json_data.as_bytes())?;
file.flush()?;
Ok(())
}
fn delete(&self) -> std::io::Result<bool> {
delete_file_if_exists(&self.llmx_home)
}
}
const KEYRING_SERVICE: &str = "LLMX Auth";
// turns llmx_home path into a stable, short key string
fn compute_store_key(llmx_home: &Path) -> std::io::Result<String> {
let canonical = llmx_home
.canonicalize()
.unwrap_or_else(|_| llmx_home.to_path_buf());
let path_str = canonical.to_string_lossy();
let mut hasher = Sha256::new();
hasher.update(path_str.as_bytes());
let digest = hasher.finalize();
let hex = format!("{digest:x}");
let truncated = hex.get(..16).unwrap_or(&hex);
Ok(format!("cli|{truncated}"))
}
#[derive(Clone, Debug)]
struct KeyringAuthStorage {
llmx_home: PathBuf,
keyring_store: Arc<dyn KeyringStore>,
}
impl KeyringAuthStorage {
fn new(llmx_home: PathBuf, keyring_store: Arc<dyn KeyringStore>) -> Self {
Self {
llmx_home,
keyring_store,
}
}
fn load_from_keyring(&self, key: &str) -> std::io::Result<Option<AuthDotJson>> {
match self.keyring_store.load(KEYRING_SERVICE, key) {
Ok(Some(serialized)) => serde_json::from_str(&serialized).map(Some).map_err(|err| {
std::io::Error::other(format!(
"failed to deserialize CLI auth from keyring: {err}"
))
}),
Ok(None) => Ok(None),
Err(error) => Err(std::io::Error::other(format!(
"failed to load CLI auth from keyring: {}",
error.message()
))),
}
}
fn save_to_keyring(&self, key: &str, value: &str) -> std::io::Result<()> {
match self.keyring_store.save(KEYRING_SERVICE, key, value) {
Ok(()) => Ok(()),
Err(error) => {
let message = format!(
"failed to write OAuth tokens to keyring: {}",
error.message()
);
warn!("{message}");
Err(std::io::Error::other(message))
}
}
}
}
impl AuthStorageBackend for KeyringAuthStorage {
fn load(&self) -> std::io::Result<Option<AuthDotJson>> {
let key = compute_store_key(&self.llmx_home)?;
self.load_from_keyring(&key)
}
fn save(&self, auth: &AuthDotJson) -> std::io::Result<()> {
let key = compute_store_key(&self.llmx_home)?;
// Simpler error mapping per style: prefer method reference over closure
let serialized = serde_json::to_string(auth).map_err(std::io::Error::other)?;
self.save_to_keyring(&key, &serialized)?;
if let Err(err) = delete_file_if_exists(&self.llmx_home) {
warn!("failed to remove CLI auth fallback file: {err}");
}
Ok(())
}
fn delete(&self) -> std::io::Result<bool> {
let key = compute_store_key(&self.llmx_home)?;
let keyring_removed = self
.keyring_store
.delete(KEYRING_SERVICE, &key)
.map_err(|err| {
std::io::Error::other(format!("failed to delete auth from keyring: {err}"))
})?;
let file_removed = delete_file_if_exists(&self.llmx_home)?;
Ok(keyring_removed || file_removed)
}
}
#[derive(Clone, Debug)]
struct AutoAuthStorage {
keyring_storage: Arc<KeyringAuthStorage>,
file_storage: Arc<FileAuthStorage>,
}
impl AutoAuthStorage {
fn new(llmx_home: PathBuf, keyring_store: Arc<dyn KeyringStore>) -> Self {
Self {
keyring_storage: Arc::new(KeyringAuthStorage::new(llmx_home.clone(), keyring_store)),
file_storage: Arc::new(FileAuthStorage::new(llmx_home)),
}
}
}
impl AuthStorageBackend for AutoAuthStorage {
fn load(&self) -> std::io::Result<Option<AuthDotJson>> {
match self.keyring_storage.load() {
Ok(Some(auth)) => Ok(Some(auth)),
Ok(None) => self.file_storage.load(),
Err(err) => {
warn!("failed to load CLI auth from keyring, falling back to file storage: {err}");
self.file_storage.load()
}
}
}
fn save(&self, auth: &AuthDotJson) -> std::io::Result<()> {
match self.keyring_storage.save(auth) {
Ok(()) => Ok(()),
Err(err) => {
warn!("failed to save auth to keyring, falling back to file storage: {err}");
self.file_storage.save(auth)
}
}
}
fn delete(&self) -> std::io::Result<bool> {
// Keyring storage will delete from disk as well
self.keyring_storage.delete()
}
}
pub(super) fn create_auth_storage(
llmx_home: PathBuf,
mode: AuthCredentialsStoreMode,
) -> Arc<dyn AuthStorageBackend> {
let keyring_store: Arc<dyn KeyringStore> = Arc::new(DefaultKeyringStore);
create_auth_storage_with_keyring_store(llmx_home, mode, keyring_store)
}
fn create_auth_storage_with_keyring_store(
llmx_home: PathBuf,
mode: AuthCredentialsStoreMode,
keyring_store: Arc<dyn KeyringStore>,
) -> Arc<dyn AuthStorageBackend> {
match mode {
AuthCredentialsStoreMode::File => Arc::new(FileAuthStorage::new(llmx_home)),
AuthCredentialsStoreMode::Keyring => {
Arc::new(KeyringAuthStorage::new(llmx_home, keyring_store))
}
AuthCredentialsStoreMode::Auto => Arc::new(AutoAuthStorage::new(llmx_home, keyring_store)),
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::token_data::IdTokenInfo;
use anyhow::Context;
use base64::Engine;
use pretty_assertions::assert_eq;
use serde_json::json;
use tempfile::tempdir;
use keyring::Error as KeyringError;
use llmx_keyring_store::tests::MockKeyringStore;
#[tokio::test]
async fn file_storage_load_returns_auth_dot_json() -> anyhow::Result<()> {
let llmx_home = tempdir()?;
let storage = FileAuthStorage::new(llmx_home.path().to_path_buf());
let auth_dot_json = AuthDotJson {
openai_api_key: Some("test-key".to_string()),
tokens: None,
last_refresh: Some(Utc::now()),
};
storage
.save(&auth_dot_json)
.context("failed to save auth file")?;
let loaded = storage.load().context("failed to load auth file")?;
assert_eq!(Some(auth_dot_json), loaded);
Ok(())
}
#[tokio::test]
async fn file_storage_save_persists_auth_dot_json() -> anyhow::Result<()> {
let llmx_home = tempdir()?;
let storage = FileAuthStorage::new(llmx_home.path().to_path_buf());
let auth_dot_json = AuthDotJson {
openai_api_key: Some("test-key".to_string()),
tokens: None,
last_refresh: Some(Utc::now()),
};
let file = get_auth_file(llmx_home.path());
storage
.save(&auth_dot_json)
.context("failed to save auth file")?;
let same_auth_dot_json = storage
.try_read_auth_json(&file)
.context("failed to read auth file after save")?;
assert_eq!(auth_dot_json, same_auth_dot_json);
Ok(())
}
#[test]
fn file_storage_delete_removes_auth_file() -> anyhow::Result<()> {
let dir = tempdir()?;
let auth_dot_json = AuthDotJson {
openai_api_key: Some("sk-test-key".to_string()),
tokens: None,
last_refresh: None,
};
let storage = create_auth_storage(dir.path().to_path_buf(), AuthCredentialsStoreMode::File);
storage.save(&auth_dot_json)?;
assert!(dir.path().join("auth.json").exists());
let storage = FileAuthStorage::new(dir.path().to_path_buf());
let removed = storage.delete()?;
assert!(removed);
assert!(!dir.path().join("auth.json").exists());
Ok(())
}
fn seed_keyring_and_fallback_auth_file_for_delete<F>(
mock_keyring: &MockKeyringStore,
llmx_home: &Path,
compute_key: F,
) -> anyhow::Result<(String, PathBuf)>
where
F: FnOnce() -> std::io::Result<String>,
{
let key = compute_key()?;
mock_keyring.save(KEYRING_SERVICE, &key, "{}")?;
let auth_file = get_auth_file(llmx_home);
std::fs::write(&auth_file, "stale")?;
Ok((key, auth_file))
}
fn seed_keyring_with_auth<F>(
mock_keyring: &MockKeyringStore,
compute_key: F,
auth: &AuthDotJson,
) -> anyhow::Result<()>
where
F: FnOnce() -> std::io::Result<String>,
{
let key = compute_key()?;
let serialized = serde_json::to_string(auth)?;
mock_keyring.save(KEYRING_SERVICE, &key, &serialized)?;
Ok(())
}
fn assert_keyring_saved_auth_and_removed_fallback(
mock_keyring: &MockKeyringStore,
key: &str,
llmx_home: &Path,
expected: &AuthDotJson,
) {
let saved_value = mock_keyring
.saved_value(key)
.expect("keyring entry should exist");
let expected_serialized = serde_json::to_string(expected).expect("serialize expected auth");
assert_eq!(saved_value, expected_serialized);
let auth_file = get_auth_file(llmx_home);
assert!(
!auth_file.exists(),
"fallback auth.json should be removed after keyring save"
);
}
fn id_token_with_prefix(prefix: &str) -> IdTokenInfo {
#[derive(Serialize)]
struct Header {
alg: &'static str,
typ: &'static str,
}
let header = Header {
alg: "none",
typ: "JWT",
};
let payload = json!({
"email": format!("{prefix}@example.com"),
"https://api.openai.com/auth": {
"chatgpt_account_id": format!("{prefix}-account"),
},
});
let encode = |bytes: &[u8]| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes);
let header_b64 = encode(&serde_json::to_vec(&header).expect("serialize header"));
let payload_b64 = encode(&serde_json::to_vec(&payload).expect("serialize payload"));
let signature_b64 = encode(b"sig");
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
crate::token_data::parse_id_token(&fake_jwt).expect("fake JWT should parse")
}
fn auth_with_prefix(prefix: &str) -> AuthDotJson {
AuthDotJson {
openai_api_key: Some(format!("{prefix}-api-key")),
tokens: Some(TokenData {
id_token: id_token_with_prefix(prefix),
access_token: format!("{prefix}-access"),
refresh_token: format!("{prefix}-refresh"),
account_id: Some(format!("{prefix}-account-id")),
}),
last_refresh: None,
}
}
#[test]
fn keyring_auth_storage_load_returns_deserialized_auth() -> anyhow::Result<()> {
let llmx_home = tempdir()?;
let mock_keyring = MockKeyringStore::default();
let storage = KeyringAuthStorage::new(
llmx_home.path().to_path_buf(),
Arc::new(mock_keyring.clone()),
);
let expected = AuthDotJson {
openai_api_key: Some("sk-test".to_string()),
tokens: None,
last_refresh: None,
};
seed_keyring_with_auth(
&mock_keyring,
|| compute_store_key(llmx_home.path()),
&expected,
)?;
let loaded = storage.load()?;
assert_eq!(Some(expected), loaded);
Ok(())
}
#[test]
fn keyring_auth_storage_compute_store_key_for_home_directory() -> anyhow::Result<()> {
let llmx_home = PathBuf::from("~/.llmx");
let key = compute_store_key(llmx_home.as_path())?;
// Test that the key has the expected format (cli|<16-char-hex>)
assert!(key.starts_with("cli|"), "key should start with 'cli|'");
assert_eq!(
key.len(),
20,
"key should be 20 characters (cli| + 16 hex chars)"
);
// Verify the hash portion is valid hex
let hash_part = &key[4..];
assert!(
hash_part.chars().all(|c| c.is_ascii_hexdigit()),
"hash portion should be hex"
);
Ok(())
}
#[test]
fn keyring_auth_storage_save_persists_and_removes_fallback_file() -> anyhow::Result<()> {
let llmx_home = tempdir()?;
let mock_keyring = MockKeyringStore::default();
let storage = KeyringAuthStorage::new(
llmx_home.path().to_path_buf(),
Arc::new(mock_keyring.clone()),
);
let auth_file = get_auth_file(llmx_home.path());
std::fs::write(&auth_file, "stale")?;
let auth = AuthDotJson {
openai_api_key: None,
tokens: Some(TokenData {
id_token: Default::default(),
access_token: "access".to_string(),
refresh_token: "refresh".to_string(),
account_id: Some("account".to_string()),
}),
last_refresh: Some(Utc::now()),
};
storage.save(&auth)?;
let key = compute_store_key(llmx_home.path())?;
assert_keyring_saved_auth_and_removed_fallback(
&mock_keyring,
&key,
llmx_home.path(),
&auth,
);
Ok(())
}
#[test]
fn keyring_auth_storage_delete_removes_keyring_and_file() -> anyhow::Result<()> {
let llmx_home = tempdir()?;
let mock_keyring = MockKeyringStore::default();
let storage = KeyringAuthStorage::new(
llmx_home.path().to_path_buf(),
Arc::new(mock_keyring.clone()),
);
let (key, auth_file) = seed_keyring_and_fallback_auth_file_for_delete(
&mock_keyring,
llmx_home.path(),
|| compute_store_key(llmx_home.path()),
)?;
let removed = storage.delete()?;
assert!(removed, "delete should report removal");
assert!(
!mock_keyring.contains(&key),
"keyring entry should be removed"
);
assert!(
!auth_file.exists(),
"fallback auth.json should be removed after keyring delete"
);
Ok(())
}
#[test]
fn auto_auth_storage_load_prefers_keyring_value() -> anyhow::Result<()> {
let llmx_home = tempdir()?;
let mock_keyring = MockKeyringStore::default();
let storage = AutoAuthStorage::new(
llmx_home.path().to_path_buf(),
Arc::new(mock_keyring.clone()),
);
let keyring_auth = auth_with_prefix("keyring");
seed_keyring_with_auth(
&mock_keyring,
|| compute_store_key(llmx_home.path()),
&keyring_auth,
)?;
let file_auth = auth_with_prefix("file");
storage.file_storage.save(&file_auth)?;
let loaded = storage.load()?;
assert_eq!(loaded, Some(keyring_auth));
Ok(())
}
#[test]
fn auto_auth_storage_load_uses_file_when_keyring_empty() -> anyhow::Result<()> {
let llmx_home = tempdir()?;
let mock_keyring = MockKeyringStore::default();
let storage = AutoAuthStorage::new(llmx_home.path().to_path_buf(), Arc::new(mock_keyring));
let expected = auth_with_prefix("file-only");
storage.file_storage.save(&expected)?;
let loaded = storage.load()?;
assert_eq!(loaded, Some(expected));
Ok(())
}
#[test]
fn auto_auth_storage_load_falls_back_when_keyring_errors() -> anyhow::Result<()> {
let llmx_home = tempdir()?;
let mock_keyring = MockKeyringStore::default();
let storage = AutoAuthStorage::new(
llmx_home.path().to_path_buf(),
Arc::new(mock_keyring.clone()),
);
let key = compute_store_key(llmx_home.path())?;
mock_keyring.set_error(&key, KeyringError::Invalid("error".into(), "load".into()));
let expected = auth_with_prefix("fallback");
storage.file_storage.save(&expected)?;
let loaded = storage.load()?;
assert_eq!(loaded, Some(expected));
Ok(())
}
#[test]
fn auto_auth_storage_save_prefers_keyring() -> anyhow::Result<()> {
let llmx_home = tempdir()?;
let mock_keyring = MockKeyringStore::default();
let storage = AutoAuthStorage::new(
llmx_home.path().to_path_buf(),
Arc::new(mock_keyring.clone()),
);
let key = compute_store_key(llmx_home.path())?;
let stale = auth_with_prefix("stale");
storage.file_storage.save(&stale)?;
let expected = auth_with_prefix("to-save");
storage.save(&expected)?;
assert_keyring_saved_auth_and_removed_fallback(
&mock_keyring,
&key,
llmx_home.path(),
&expected,
);
Ok(())
}
#[test]
fn auto_auth_storage_save_falls_back_when_keyring_errors() -> anyhow::Result<()> {
let llmx_home = tempdir()?;
let mock_keyring = MockKeyringStore::default();
let storage = AutoAuthStorage::new(
llmx_home.path().to_path_buf(),
Arc::new(mock_keyring.clone()),
);
let key = compute_store_key(llmx_home.path())?;
mock_keyring.set_error(&key, KeyringError::Invalid("error".into(), "save".into()));
let auth = auth_with_prefix("fallback");
storage.save(&auth)?;
let auth_file = get_auth_file(llmx_home.path());
assert!(
auth_file.exists(),
"fallback auth.json should be created when keyring save fails"
);
let saved = storage
.file_storage
.load()?
.context("fallback auth should exist")?;
assert_eq!(saved, auth);
assert!(
mock_keyring.saved_value(&key).is_none(),
"keyring should not contain value when save fails"
);
Ok(())
}
#[test]
fn auto_auth_storage_delete_removes_keyring_and_file() -> anyhow::Result<()> {
let llmx_home = tempdir()?;
let mock_keyring = MockKeyringStore::default();
let storage = AutoAuthStorage::new(
llmx_home.path().to_path_buf(),
Arc::new(mock_keyring.clone()),
);
let (key, auth_file) = seed_keyring_and_fallback_auth_file_for_delete(
&mock_keyring,
llmx_home.path(),
|| compute_store_key(llmx_home.path()),
)?;
let removed = storage.delete()?;
assert!(removed, "delete should report removal");
assert!(
!mock_keyring.contains(&key),
"keyring entry should be removed"
);
assert!(
!auth_file.exists(),
"fallback auth.json should be removed after delete"
);
Ok(())
}
}

261
llmx-rs/core/src/bash.rs Normal file
View File

@@ -0,0 +1,261 @@
use tree_sitter::Node;
use tree_sitter::Parser;
use tree_sitter::Tree;
use tree_sitter_bash::LANGUAGE as BASH;
/// Parse the provided bash source using tree-sitter-bash, returning a Tree on
/// success or None if parsing failed.
pub fn try_parse_shell(shell_lc_arg: &str) -> Option<Tree> {
let lang = BASH.into();
let mut parser = Parser::new();
#[expect(clippy::expect_used)]
parser.set_language(&lang).expect("load bash grammar");
let old_tree: Option<&Tree> = None;
parser.parse(shell_lc_arg, old_tree)
}
/// Parse a script which may contain multiple simple commands joined only by
/// the safe logical/pipe/sequencing operators: `&&`, `||`, `;`, `|`.
///
/// Returns `Some(Vec<command_words>)` if every command is a plain wordonly
/// command and the parse tree does not contain disallowed constructs
/// (parentheses, redirections, substitutions, control flow, etc.). Otherwise
/// returns `None`.
pub fn try_parse_word_only_commands_sequence(tree: &Tree, src: &str) -> Option<Vec<Vec<String>>> {
if tree.root_node().has_error() {
return None;
}
// List of allowed (named) node kinds for a "word only commands sequence".
// If we encounter a named node that is not in this list we reject.
const ALLOWED_KINDS: &[&str] = &[
// top level containers
"program",
"list",
"pipeline",
// commands & words
"command",
"command_name",
"word",
"string",
"string_content",
"raw_string",
"number",
];
// Allow only safe punctuation / operator tokens; anything else causes reject.
const ALLOWED_PUNCT_TOKENS: &[&str] = &["&&", "||", ";", "|", "\"", "'"];
let root = tree.root_node();
let mut cursor = root.walk();
let mut stack = vec![root];
let mut command_nodes = Vec::new();
while let Some(node) = stack.pop() {
let kind = node.kind();
if node.is_named() {
if !ALLOWED_KINDS.contains(&kind) {
return None;
}
if kind == "command" {
command_nodes.push(node);
}
} else {
// Reject any punctuation / operator tokens that are not explicitly allowed.
if kind.chars().any(|c| "&;|".contains(c)) && !ALLOWED_PUNCT_TOKENS.contains(&kind) {
return None;
}
if !(ALLOWED_PUNCT_TOKENS.contains(&kind) || kind.trim().is_empty()) {
// If it's a quote token or operator it's allowed above; we also allow whitespace tokens.
// Any other punctuation like parentheses, braces, redirects, backticks, etc are rejected.
return None;
}
}
for child in node.children(&mut cursor) {
stack.push(child);
}
}
// Walk uses a stack (LIFO), so re-sort by position to restore source order.
command_nodes.sort_by_key(Node::start_byte);
let mut commands = Vec::new();
for node in command_nodes {
if let Some(words) = parse_plain_command_from_node(node, src) {
commands.push(words);
} else {
return None;
}
}
Some(commands)
}
pub fn is_well_known_sh_shell(shell: &str) -> bool {
if shell == "bash" || shell == "zsh" {
return true;
}
let shell_name = std::path::Path::new(shell)
.file_name()
.and_then(|s| s.to_str())
.unwrap_or(shell);
matches!(shell_name, "bash" | "zsh")
}
pub fn extract_bash_command(command: &[String]) -> Option<(&str, &str)> {
let [shell, flag, script] = command else {
return None;
};
if flag != "-lc" || !is_well_known_sh_shell(shell) {
return None;
}
Some((shell, script))
}
/// Returns the sequence of plain commands within a `bash -lc "..."` or
/// `zsh -lc "..."` invocation when the script only contains word-only commands
/// joined by safe operators.
pub fn parse_shell_lc_plain_commands(command: &[String]) -> Option<Vec<Vec<String>>> {
let (_, script) = extract_bash_command(command)?;
let tree = try_parse_shell(script)?;
try_parse_word_only_commands_sequence(&tree, script)
}
fn parse_plain_command_from_node(cmd: tree_sitter::Node, src: &str) -> Option<Vec<String>> {
if cmd.kind() != "command" {
return None;
}
let mut words = Vec::new();
let mut cursor = cmd.walk();
for child in cmd.named_children(&mut cursor) {
match child.kind() {
"command_name" => {
let word_node = child.named_child(0)?;
if word_node.kind() != "word" {
return None;
}
words.push(word_node.utf8_text(src.as_bytes()).ok()?.to_owned());
}
"word" | "number" => {
words.push(child.utf8_text(src.as_bytes()).ok()?.to_owned());
}
"string" => {
if child.child_count() == 3
&& child.child(0)?.kind() == "\""
&& child.child(1)?.kind() == "string_content"
&& child.child(2)?.kind() == "\""
{
words.push(child.child(1)?.utf8_text(src.as_bytes()).ok()?.to_owned());
} else {
return None;
}
}
"raw_string" => {
let raw_string = child.utf8_text(src.as_bytes()).ok()?;
let stripped = raw_string
.strip_prefix('\'')
.and_then(|s| s.strip_suffix('\''));
if let Some(s) = stripped {
words.push(s.to_owned());
} else {
return None;
}
}
_ => return None,
}
}
Some(words)
}
#[cfg(test)]
mod tests {
use super::*;
fn parse_seq(src: &str) -> Option<Vec<Vec<String>>> {
let tree = try_parse_shell(src)?;
try_parse_word_only_commands_sequence(&tree, src)
}
#[test]
fn accepts_single_simple_command() {
let cmds = parse_seq("ls -1").unwrap();
assert_eq!(cmds, vec![vec!["ls".to_string(), "-1".to_string()]]);
}
#[test]
fn accepts_multiple_commands_with_allowed_operators() {
let src = "ls && pwd; echo 'hi there' | wc -l";
let cmds = parse_seq(src).unwrap();
let expected: Vec<Vec<String>> = vec![
vec!["ls".to_string()],
vec!["pwd".to_string()],
vec!["echo".to_string(), "hi there".to_string()],
vec!["wc".to_string(), "-l".to_string()],
];
assert_eq!(cmds, expected);
}
#[test]
fn extracts_double_and_single_quoted_strings() {
let cmds = parse_seq("echo \"hello world\"").unwrap();
assert_eq!(
cmds,
vec![vec!["echo".to_string(), "hello world".to_string()]]
);
let cmds2 = parse_seq("echo 'hi there'").unwrap();
assert_eq!(
cmds2,
vec![vec!["echo".to_string(), "hi there".to_string()]]
);
}
#[test]
fn accepts_numbers_as_words() {
let cmds = parse_seq("echo 123 456").unwrap();
assert_eq!(
cmds,
vec![vec![
"echo".to_string(),
"123".to_string(),
"456".to_string()
]]
);
}
#[test]
fn rejects_parentheses_and_subshells() {
assert!(parse_seq("(ls)").is_none());
assert!(parse_seq("ls || (pwd && echo hi)").is_none());
}
#[test]
fn rejects_redirections_and_unsupported_operators() {
assert!(parse_seq("ls > out.txt").is_none());
assert!(parse_seq("echo hi & echo bye").is_none());
}
#[test]
fn rejects_command_and_process_substitutions_and_expansions() {
assert!(parse_seq("echo $(pwd)").is_none());
assert!(parse_seq("echo `pwd`").is_none());
assert!(parse_seq("echo $HOME").is_none());
assert!(parse_seq("echo \"hi $USER\"").is_none());
}
#[test]
fn rejects_variable_assignment_prefix() {
assert!(parse_seq("FOO=bar ls").is_none());
}
#[test]
fn rejects_trailing_operator_parse_error() {
assert!(parse_seq("ls &&").is_none());
}
#[test]
fn parse_zsh_lc_plain_commands() {
let command = vec!["zsh".to_string(), "-lc".to_string(), "ls".to_string()];
let parsed = parse_shell_lc_plain_commands(&command).unwrap();
assert_eq!(parsed, vec![vec!["ls".to_string()]]);
}
}

File diff suppressed because it is too large Load Diff

1523
llmx-rs/core/src/client.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,553 @@
use crate::client_common::tools::ToolSpec;
use crate::error::Result;
use crate::model_family::ModelFamily;
use crate::protocol::RateLimitSnapshot;
use crate::protocol::TokenUsage;
use futures::Stream;
use llmx_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
use llmx_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
use llmx_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
use llmx_protocol::config_types::Verbosity as VerbosityConfig;
use llmx_protocol::models::ResponseItem;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use std::borrow::Cow;
use std::collections::HashSet;
use std::ops::Deref;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use tokio::sync::mpsc;
/// Review thread system prompt. Edit `core/src/review_prompt.md` to customize.
pub const REVIEW_PROMPT: &str = include_str!("../review_prompt.md");
// Centralized templates for review-related user messages
pub const REVIEW_EXIT_SUCCESS_TMPL: &str = include_str!("../templates/review/exit_success.xml");
pub const REVIEW_EXIT_INTERRUPTED_TMPL: &str =
include_str!("../templates/review/exit_interrupted.xml");
/// API request payload for a single model turn
#[derive(Default, Debug, Clone)]
pub struct Prompt {
/// Conversation context input items.
pub input: Vec<ResponseItem>,
/// Tools available to the model, including additional tools sourced from
/// external MCP servers.
pub(crate) tools: Vec<ToolSpec>,
/// Whether parallel tool calls are permitted for this prompt.
pub(crate) parallel_tool_calls: bool,
/// Optional override for the built-in BASE_INSTRUCTIONS.
pub base_instructions_override: Option<String>,
/// Optional the output schema for the model's response.
pub output_schema: Option<Value>,
}
impl Prompt {
pub(crate) fn get_full_instructions<'a>(&'a self, model: &'a ModelFamily) -> Cow<'a, str> {
let base = self
.base_instructions_override
.as_deref()
.unwrap_or(model.base_instructions.deref());
// When there are no custom instructions, add apply_patch_tool_instructions if:
// - the model needs special instructions (4.1)
// AND
// - there is no apply_patch tool present
let is_apply_patch_tool_present = self.tools.iter().any(|tool| match tool {
ToolSpec::Function(f) => f.name == "apply_patch",
ToolSpec::Freeform(f) => f.name == "apply_patch",
_ => false,
});
if self.base_instructions_override.is_none()
&& model.needs_special_apply_patch_instructions
&& !is_apply_patch_tool_present
{
Cow::Owned(format!("{base}\n{APPLY_PATCH_TOOL_INSTRUCTIONS}"))
} else {
Cow::Borrowed(base)
}
}
pub(crate) fn get_formatted_input(&self) -> Vec<ResponseItem> {
let mut input = self.input.clone();
// when using the *Freeform* apply_patch tool specifically, tool outputs
// should be structured text, not json. Do NOT reserialize when using
// the Function tool - note that this differs from the check above for
// instructions. We declare the result as a named variable for clarity.
let is_freeform_apply_patch_tool_present = self.tools.iter().any(|tool| match tool {
ToolSpec::Freeform(f) => f.name == "apply_patch",
_ => false,
});
if is_freeform_apply_patch_tool_present {
reserialize_shell_outputs(&mut input);
}
input
}
}
fn reserialize_shell_outputs(items: &mut [ResponseItem]) {
let mut shell_call_ids: HashSet<String> = HashSet::new();
items.iter_mut().for_each(|item| match item {
ResponseItem::LocalShellCall { call_id, id, .. } => {
if let Some(identifier) = call_id.clone().or_else(|| id.clone()) {
shell_call_ids.insert(identifier);
}
}
ResponseItem::CustomToolCall {
id: _,
status: _,
call_id,
name,
input: _,
} => {
if name == "apply_patch" {
shell_call_ids.insert(call_id.clone());
}
}
ResponseItem::CustomToolCallOutput { call_id, output } => {
if shell_call_ids.remove(call_id)
&& let Some(structured) = parse_structured_shell_output(output)
{
*output = structured
}
}
ResponseItem::FunctionCall { name, call_id, .. }
if is_shell_tool_name(name) || name == "apply_patch" =>
{
shell_call_ids.insert(call_id.clone());
}
ResponseItem::FunctionCallOutput { call_id, output } => {
if shell_call_ids.remove(call_id)
&& let Some(structured) = parse_structured_shell_output(&output.content)
{
output.content = structured
}
}
_ => {}
})
}
fn is_shell_tool_name(name: &str) -> bool {
matches!(name, "shell" | "container.exec")
}
#[derive(Deserialize)]
struct ExecOutputJson {
output: String,
metadata: ExecOutputMetadataJson,
}
#[derive(Deserialize)]
struct ExecOutputMetadataJson {
exit_code: i32,
duration_seconds: f32,
}
fn parse_structured_shell_output(raw: &str) -> Option<String> {
let parsed: ExecOutputJson = serde_json::from_str(raw).ok()?;
Some(build_structured_output(&parsed))
}
fn build_structured_output(parsed: &ExecOutputJson) -> String {
let mut sections = Vec::new();
sections.push(format!("Exit code: {}", parsed.metadata.exit_code));
sections.push(format!(
"Wall time: {} seconds",
parsed.metadata.duration_seconds
));
let mut output = parsed.output.clone();
if let Some(total_lines) = extract_total_output_lines(&parsed.output) {
sections.push(format!("Total output lines: {total_lines}"));
if let Some(stripped) = strip_total_output_header(&output) {
output = stripped.to_string();
}
}
sections.push("Output:".to_string());
sections.push(output);
sections.join("\n")
}
fn extract_total_output_lines(output: &str) -> Option<u32> {
let marker_start = output.find("[... omitted ")?;
let marker = &output[marker_start..];
let (_, after_of) = marker.split_once(" of ")?;
let (total_segment, _) = after_of.split_once(' ')?;
total_segment.parse::<u32>().ok()
}
fn strip_total_output_header(output: &str) -> Option<&str> {
let after_prefix = output.strip_prefix("Total output lines: ")?;
let (_, remainder) = after_prefix.split_once('\n')?;
let remainder = remainder.strip_prefix('\n').unwrap_or(remainder);
Some(remainder)
}
#[derive(Debug)]
pub enum ResponseEvent {
Created,
OutputItemDone(ResponseItem),
OutputItemAdded(ResponseItem),
Completed {
response_id: String,
token_usage: Option<TokenUsage>,
},
OutputTextDelta(String),
ReasoningSummaryDelta(String),
ReasoningContentDelta(String),
ReasoningSummaryPartAdded,
RateLimits(RateLimitSnapshot),
}
#[derive(Debug, Serialize)]
pub(crate) struct Reasoning {
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) effort: Option<ReasoningEffortConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) summary: Option<ReasoningSummaryConfig>,
}
#[derive(Debug, Serialize, Default, Clone)]
#[serde(rename_all = "snake_case")]
pub(crate) enum TextFormatType {
#[default]
JsonSchema,
}
#[derive(Debug, Serialize, Default, Clone)]
pub(crate) struct TextFormat {
pub(crate) r#type: TextFormatType,
pub(crate) strict: bool,
pub(crate) schema: Value,
pub(crate) name: String,
}
/// Controls under the `text` field in the Responses API for GPT-5.
#[derive(Debug, Serialize, Default, Clone)]
pub(crate) struct TextControls {
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) verbosity: Option<OpenAiVerbosity>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) format: Option<TextFormat>,
}
#[derive(Debug, Serialize, Default, Clone)]
#[serde(rename_all = "lowercase")]
pub(crate) enum OpenAiVerbosity {
Low,
#[default]
Medium,
High,
}
impl From<VerbosityConfig> for OpenAiVerbosity {
fn from(v: VerbosityConfig) -> Self {
match v {
VerbosityConfig::Low => OpenAiVerbosity::Low,
VerbosityConfig::Medium => OpenAiVerbosity::Medium,
VerbosityConfig::High => OpenAiVerbosity::High,
}
}
}
/// Request object that is serialized as JSON and POST'ed when using the
/// Responses API.
#[derive(Debug, Serialize)]
pub(crate) struct ResponsesApiRequest<'a> {
pub(crate) model: &'a str,
pub(crate) instructions: &'a str,
// TODO(mbolin): ResponseItem::Other should not be serialized. Currently,
// we code defensively to avoid this case, but perhaps we should use a
// separate enum for serialization.
pub(crate) input: &'a Vec<ResponseItem>,
pub(crate) tools: &'a [serde_json::Value],
pub(crate) tool_choice: &'static str,
pub(crate) parallel_tool_calls: bool,
pub(crate) reasoning: Option<Reasoning>,
pub(crate) store: bool,
pub(crate) stream: bool,
pub(crate) include: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) prompt_cache_key: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub(crate) text: Option<TextControls>,
}
pub(crate) mod tools {
use crate::tools::spec::JsonSchema;
use serde::Deserialize;
use serde::Serialize;
/// When serialized as JSON, this produces a valid "Tool" in the OpenAI
/// Responses API.
#[derive(Debug, Clone, Serialize, PartialEq)]
#[serde(tag = "type")]
pub(crate) enum ToolSpec {
#[serde(rename = "function")]
Function(ResponsesApiTool),
#[serde(rename = "local_shell")]
LocalShell {},
// TODO: Understand why we get an error on web_search although the API docs say it's supported.
// https://platform.openai.com/docs/guides/tools-web-search?api-mode=responses#:~:text=%7B%20type%3A%20%22web_search%22%20%7D%2C
#[serde(rename = "web_search")]
WebSearch {},
#[serde(rename = "custom")]
Freeform(FreeformTool),
}
impl ToolSpec {
pub(crate) fn name(&self) -> &str {
match self {
ToolSpec::Function(tool) => tool.name.as_str(),
ToolSpec::LocalShell {} => "local_shell",
ToolSpec::WebSearch {} => "web_search",
ToolSpec::Freeform(tool) => tool.name.as_str(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FreeformTool {
pub(crate) name: String,
pub(crate) description: String,
pub(crate) format: FreeformToolFormat,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FreeformToolFormat {
pub(crate) r#type: String,
pub(crate) syntax: String,
pub(crate) definition: String,
}
#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct ResponsesApiTool {
pub(crate) name: String,
pub(crate) description: String,
/// TODO: Validation. When strict is set to true, the JSON schema,
/// `required` and `additional_properties` must be present. All fields in
/// `properties` must be present in `required`.
pub(crate) strict: bool,
pub(crate) parameters: JsonSchema,
}
}
pub(crate) fn create_reasoning_param_for_request(
model_family: &ModelFamily,
effort: Option<ReasoningEffortConfig>,
summary: ReasoningSummaryConfig,
) -> Option<Reasoning> {
if !model_family.supports_reasoning_summaries {
return None;
}
Some(Reasoning {
effort,
summary: Some(summary),
})
}
pub(crate) fn create_text_param_for_request(
verbosity: Option<VerbosityConfig>,
output_schema: &Option<Value>,
) -> Option<TextControls> {
if verbosity.is_none() && output_schema.is_none() {
return None;
}
Some(TextControls {
verbosity: verbosity.map(std::convert::Into::into),
format: output_schema.as_ref().map(|schema| TextFormat {
r#type: TextFormatType::JsonSchema,
strict: true,
schema: schema.clone(),
name: "llmx_output_schema".to_string(),
}),
})
}
pub struct ResponseStream {
pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent>>,
}
impl Stream for ResponseStream {
type Item = Result<ResponseEvent>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.rx_event.poll_recv(cx)
}
}
#[cfg(test)]
mod tests {
use crate::model_family::find_family_for_model;
use pretty_assertions::assert_eq;
use super::*;
struct InstructionsTestCase {
pub slug: &'static str,
pub expects_apply_patch_instructions: bool,
}
#[test]
fn get_full_instructions_no_user_content() {
let prompt = Prompt {
..Default::default()
};
let test_cases = vec![
InstructionsTestCase {
slug: "gpt-3.5",
expects_apply_patch_instructions: true,
},
InstructionsTestCase {
slug: "gpt-4.1",
expects_apply_patch_instructions: true,
},
InstructionsTestCase {
slug: "gpt-4o",
expects_apply_patch_instructions: true,
},
InstructionsTestCase {
slug: "gpt-5",
expects_apply_patch_instructions: true,
},
InstructionsTestCase {
slug: "llmx-mini-latest",
expects_apply_patch_instructions: true,
},
InstructionsTestCase {
slug: "gpt-oss:120b",
expects_apply_patch_instructions: false,
},
InstructionsTestCase {
slug: "gpt-5-llmx",
expects_apply_patch_instructions: false,
},
];
for test_case in test_cases {
let model_family = find_family_for_model(test_case.slug).expect("known model slug");
let expected = if test_case.expects_apply_patch_instructions {
format!(
"{}\n{}",
model_family.clone().base_instructions,
APPLY_PATCH_TOOL_INSTRUCTIONS
)
} else {
model_family.clone().base_instructions
};
let full = prompt.get_full_instructions(&model_family);
assert_eq!(full, expected);
}
}
#[test]
fn serializes_text_verbosity_when_set() {
let input: Vec<ResponseItem> = vec![];
let tools: Vec<serde_json::Value> = vec![];
let req = ResponsesApiRequest {
model: "gpt-5",
instructions: "i",
input: &input,
tools: &tools,
tool_choice: "auto",
parallel_tool_calls: true,
reasoning: None,
store: false,
stream: true,
include: vec![],
prompt_cache_key: None,
text: Some(TextControls {
verbosity: Some(OpenAiVerbosity::Low),
format: None,
}),
};
let v = serde_json::to_value(&req).expect("json");
assert_eq!(
v.get("text")
.and_then(|t| t.get("verbosity"))
.and_then(|s| s.as_str()),
Some("low")
);
}
#[test]
fn serializes_text_schema_with_strict_format() {
let input: Vec<ResponseItem> = vec![];
let tools: Vec<serde_json::Value> = vec![];
let schema = serde_json::json!({
"type": "object",
"properties": {
"answer": {"type": "string"}
},
"required": ["answer"],
});
let text_controls =
create_text_param_for_request(None, &Some(schema.clone())).expect("text controls");
let req = ResponsesApiRequest {
model: "gpt-5",
instructions: "i",
input: &input,
tools: &tools,
tool_choice: "auto",
parallel_tool_calls: true,
reasoning: None,
store: false,
stream: true,
include: vec![],
prompt_cache_key: None,
text: Some(text_controls),
};
let v = serde_json::to_value(&req).expect("json");
let text = v.get("text").expect("text field");
assert!(text.get("verbosity").is_none());
let format = text.get("format").expect("format field");
assert_eq!(
format.get("name"),
Some(&serde_json::Value::String("llmx_output_schema".into()))
);
assert_eq!(
format.get("type"),
Some(&serde_json::Value::String("json_schema".into()))
);
assert_eq!(format.get("strict"), Some(&serde_json::Value::Bool(true)));
assert_eq!(format.get("schema"), Some(&schema));
}
#[test]
fn omits_text_when_not_set() {
let input: Vec<ResponseItem> = vec![];
let tools: Vec<serde_json::Value> = vec![];
let req = ResponsesApiRequest {
model: "gpt-5",
instructions: "i",
input: &input,
tools: &tools,
tool_choice: "auto",
parallel_tool_calls: true,
reasoning: None,
store: false,
stream: true,
include: vec![],
prompt_cache_key: None,
text: None,
};
let v = serde_json::to_value(&req).expect("json");
assert!(v.get("text").is_none());
}
}

View File

@@ -0,0 +1,142 @@
use llmx_protocol::protocol::AskForApproval;
use llmx_protocol::protocol::SandboxPolicy;
use crate::bash::parse_shell_lc_plain_commands;
use crate::is_safe_command::is_known_safe_command;
pub fn requires_initial_appoval(
policy: AskForApproval,
sandbox_policy: &SandboxPolicy,
command: &[String],
with_escalated_permissions: bool,
) -> bool {
if is_known_safe_command(command) {
return false;
}
match policy {
AskForApproval::Never | AskForApproval::OnFailure => false,
AskForApproval::OnRequest => {
// In DangerFullAccess, only prompt if the command looks dangerous.
if matches!(sandbox_policy, SandboxPolicy::DangerFullAccess) {
return command_might_be_dangerous(command);
}
// In restricted sandboxes (ReadOnly/WorkspaceWrite), do not prompt for
// nonescalated, nondangerous commands — let the sandbox enforce
// restrictions (e.g., block network/write) without a user prompt.
let wants_escalation: bool = with_escalated_permissions;
if wants_escalation {
return true;
}
command_might_be_dangerous(command)
}
AskForApproval::UnlessTrusted => !is_known_safe_command(command),
}
}
pub fn command_might_be_dangerous(command: &[String]) -> bool {
if is_dangerous_to_call_with_exec(command) {
return true;
}
// Support `bash -lc "<script>"` where the any part of the script might contain a dangerous command.
if let Some(all_commands) = parse_shell_lc_plain_commands(command)
&& all_commands
.iter()
.any(|cmd| is_dangerous_to_call_with_exec(cmd))
{
return true;
}
false
}
fn is_dangerous_to_call_with_exec(command: &[String]) -> bool {
let cmd0 = command.first().map(String::as_str);
match cmd0 {
Some(cmd) if cmd.ends_with("git") || cmd.ends_with("/git") => {
matches!(command.get(1).map(String::as_str), Some("reset" | "rm"))
}
Some("rm") => matches!(command.get(1).map(String::as_str), Some("-f" | "-rf")),
// for sudo <cmd> simply do the check for <cmd>
Some("sudo") => is_dangerous_to_call_with_exec(&command[1..]),
// ── anything else ─────────────────────────────────────────────────
_ => false,
}
}
#[cfg(test)]
mod tests {
use super::*;
fn vec_str(items: &[&str]) -> Vec<String> {
items.iter().map(std::string::ToString::to_string).collect()
}
#[test]
fn git_reset_is_dangerous() {
assert!(command_might_be_dangerous(&vec_str(&["git", "reset"])));
}
#[test]
fn bash_git_reset_is_dangerous() {
assert!(command_might_be_dangerous(&vec_str(&[
"bash",
"-lc",
"git reset --hard"
])));
}
#[test]
fn zsh_git_reset_is_dangerous() {
assert!(command_might_be_dangerous(&vec_str(&[
"zsh",
"-lc",
"git reset --hard"
])));
}
#[test]
fn git_status_is_not_dangerous() {
assert!(!command_might_be_dangerous(&vec_str(&["git", "status"])));
}
#[test]
fn bash_git_status_is_not_dangerous() {
assert!(!command_might_be_dangerous(&vec_str(&[
"bash",
"-lc",
"git status"
])));
}
#[test]
fn sudo_git_reset_is_dangerous() {
assert!(command_might_be_dangerous(&vec_str(&[
"sudo", "git", "reset", "--hard"
])));
}
#[test]
fn usr_bin_git_is_dangerous() {
assert!(command_might_be_dangerous(&vec_str(&[
"/usr/bin/git",
"reset",
"--hard"
])));
}
#[test]
fn rm_rf_is_dangerous() {
assert!(command_might_be_dangerous(&vec_str(&["rm", "-rf", "/"])));
}
#[test]
fn rm_f_is_dangerous() {
assert!(command_might_be_dangerous(&vec_str(&["rm", "-f", "/"])));
}
}

View File

@@ -0,0 +1,366 @@
use crate::bash::parse_shell_lc_plain_commands;
pub fn is_known_safe_command(command: &[String]) -> bool {
let command: Vec<String> = command
.iter()
.map(|s| {
if s == "zsh" {
"bash".to_string()
} else {
s.clone()
}
})
.collect();
#[cfg(target_os = "windows")]
{
use super::windows_safe_commands::is_safe_command_windows;
if is_safe_command_windows(&command) {
return true;
}
}
if is_safe_to_call_with_exec(&command) {
return true;
}
// Support `bash -lc "..."` where the script consists solely of one or
// more "plain" commands (only bare words / quoted strings) combined with
// a conservative allowlist of shell operators that themselves do not
// introduce side effects ( "&&", "||", ";", and "|" ). If every
// individual command in the script is itself a knownsafe command, then
// the composite expression is considered safe.
if let Some(all_commands) = parse_shell_lc_plain_commands(&command)
&& !all_commands.is_empty()
&& all_commands
.iter()
.all(|cmd| is_safe_to_call_with_exec(cmd))
{
return true;
}
false
}
fn is_safe_to_call_with_exec(command: &[String]) -> bool {
let Some(cmd0) = command.first().map(String::as_str) else {
return false;
};
match std::path::Path::new(&cmd0)
.file_name()
.and_then(|osstr| osstr.to_str())
{
#[rustfmt::skip]
Some(
"cat" |
"cd" |
"echo" |
"false" |
"grep" |
"head" |
"ls" |
"nl" |
"pwd" |
"tail" |
"true" |
"wc" |
"which") => {
true
},
Some("find") => {
// Certain options to `find` can delete files, write to files, or
// execute arbitrary commands, so we cannot auto-approve the
// invocation of `find` in such cases.
#[rustfmt::skip]
const UNSAFE_FIND_OPTIONS: &[&str] = &[
// Options that can execute arbitrary commands.
"-exec", "-execdir", "-ok", "-okdir",
// Option that deletes matching files.
"-delete",
// Options that write pathnames to a file.
"-fls", "-fprint", "-fprint0", "-fprintf",
];
!command
.iter()
.any(|arg| UNSAFE_FIND_OPTIONS.contains(&arg.as_str()))
}
// Ripgrep
Some("rg") => {
const UNSAFE_RIPGREP_OPTIONS_WITH_ARGS: &[&str] = &[
// Takes an arbitrary command that is executed for each match.
"--pre",
// Takes a command that can be used to obtain the local hostname.
"--hostname-bin",
];
const UNSAFE_RIPGREP_OPTIONS_WITHOUT_ARGS: &[&str] = &[
// Calls out to other decompression tools, so do not auto-approve
// out of an abundance of caution.
"--search-zip",
"-z",
];
!command.iter().any(|arg| {
UNSAFE_RIPGREP_OPTIONS_WITHOUT_ARGS.contains(&arg.as_str())
|| UNSAFE_RIPGREP_OPTIONS_WITH_ARGS
.iter()
.any(|&opt| arg == opt || arg.starts_with(&format!("{opt}=")))
})
}
// Git
Some("git") => matches!(
command.get(1).map(String::as_str),
Some("branch" | "status" | "log" | "diff" | "show")
),
// Rust
Some("cargo") if command.get(1).map(String::as_str) == Some("check") => true,
// Special-case `sed -n {N|M,N}p`
Some("sed")
if {
command.len() <= 4
&& command.get(1).map(String::as_str) == Some("-n")
&& is_valid_sed_n_arg(command.get(2).map(String::as_str))
} =>
{
true
}
// ── anything else ─────────────────────────────────────────────────
_ => false,
}
}
// (bash parsing helpers implemented in crate::bash)
/* ----------------------------------------------------------
Example
---------------------------------------------------------- */
/// Returns true if `arg` matches /^(\d+,)?\d+p$/
fn is_valid_sed_n_arg(arg: Option<&str>) -> bool {
// unwrap or bail
let s = match arg {
Some(s) => s,
None => return false,
};
// must end with 'p', strip it
let core = match s.strip_suffix('p') {
Some(rest) => rest,
None => return false,
};
// split on ',' and ensure 1 or 2 numeric parts
let parts: Vec<&str> = core.split(',').collect();
match parts.as_slice() {
// single number, e.g. "10"
[num] => !num.is_empty() && num.chars().all(|c| c.is_ascii_digit()),
// two numbers, e.g. "1,5"
[a, b] => {
!a.is_empty()
&& !b.is_empty()
&& a.chars().all(|c| c.is_ascii_digit())
&& b.chars().all(|c| c.is_ascii_digit())
}
// anything else (more than one comma) is invalid
_ => false,
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::string::ToString;
fn vec_str(args: &[&str]) -> Vec<String> {
args.iter().map(ToString::to_string).collect()
}
#[test]
fn known_safe_examples() {
assert!(is_safe_to_call_with_exec(&vec_str(&["ls"])));
assert!(is_safe_to_call_with_exec(&vec_str(&["git", "status"])));
assert!(is_safe_to_call_with_exec(&vec_str(&[
"sed", "-n", "1,5p", "file.txt"
])));
assert!(is_safe_to_call_with_exec(&vec_str(&[
"nl",
"-nrz",
"Cargo.toml"
])));
// Safe `find` command (no unsafe options).
assert!(is_safe_to_call_with_exec(&vec_str(&[
"find", ".", "-name", "file.txt"
])));
}
#[test]
fn zsh_lc_safe_command_sequence() {
assert!(is_known_safe_command(&vec_str(&["zsh", "-lc", "ls"])));
}
#[test]
fn unknown_or_partial() {
assert!(!is_safe_to_call_with_exec(&vec_str(&["foo"])));
assert!(!is_safe_to_call_with_exec(&vec_str(&["git", "fetch"])));
assert!(!is_safe_to_call_with_exec(&vec_str(&[
"sed", "-n", "xp", "file.txt"
])));
// Unsafe `find` commands.
for args in [
vec_str(&["find", ".", "-name", "file.txt", "-exec", "rm", "{}", ";"]),
vec_str(&[
"find", ".", "-name", "*.py", "-execdir", "python3", "{}", ";",
]),
vec_str(&["find", ".", "-name", "file.txt", "-ok", "rm", "{}", ";"]),
vec_str(&["find", ".", "-name", "*.py", "-okdir", "python3", "{}", ";"]),
vec_str(&["find", ".", "-delete", "-name", "file.txt"]),
vec_str(&["find", ".", "-fls", "/etc/passwd"]),
vec_str(&["find", ".", "-fprint", "/etc/passwd"]),
vec_str(&["find", ".", "-fprint0", "/etc/passwd"]),
vec_str(&["find", ".", "-fprintf", "/root/suid.txt", "%#m %u %p\n"]),
] {
assert!(
!is_safe_to_call_with_exec(&args),
"expected {args:?} to be unsafe"
);
}
}
#[test]
fn ripgrep_rules() {
// Safe ripgrep invocations none of the unsafe flags are present.
assert!(is_safe_to_call_with_exec(&vec_str(&[
"rg",
"Cargo.toml",
"-n"
])));
// Unsafe flags that do not take an argument (present verbatim).
for args in [
vec_str(&["rg", "--search-zip", "files"]),
vec_str(&["rg", "-z", "files"]),
] {
assert!(
!is_safe_to_call_with_exec(&args),
"expected {args:?} to be considered unsafe due to zip-search flag",
);
}
// Unsafe flags that expect a value, provided in both split and = forms.
for args in [
vec_str(&["rg", "--pre", "pwned", "files"]),
vec_str(&["rg", "--pre=pwned", "files"]),
vec_str(&["rg", "--hostname-bin", "pwned", "files"]),
vec_str(&["rg", "--hostname-bin=pwned", "files"]),
] {
assert!(
!is_safe_to_call_with_exec(&args),
"expected {args:?} to be considered unsafe due to external-command flag",
);
}
}
#[test]
fn bash_lc_safe_examples() {
assert!(is_known_safe_command(&vec_str(&["bash", "-lc", "ls"])));
assert!(is_known_safe_command(&vec_str(&["bash", "-lc", "ls -1"])));
assert!(is_known_safe_command(&vec_str(&[
"bash",
"-lc",
"git status"
])));
assert!(is_known_safe_command(&vec_str(&[
"bash",
"-lc",
"grep -R \"Cargo.toml\" -n"
])));
assert!(is_known_safe_command(&vec_str(&[
"bash",
"-lc",
"sed -n 1,5p file.txt"
])));
assert!(is_known_safe_command(&vec_str(&[
"bash",
"-lc",
"sed -n '1,5p' file.txt"
])));
assert!(is_known_safe_command(&vec_str(&[
"bash",
"-lc",
"find . -name file.txt"
])));
}
#[test]
fn bash_lc_safe_examples_with_operators() {
assert!(is_known_safe_command(&vec_str(&[
"bash",
"-lc",
"grep -R \"Cargo.toml\" -n || true"
])));
assert!(is_known_safe_command(&vec_str(&[
"bash",
"-lc",
"ls && pwd"
])));
assert!(is_known_safe_command(&vec_str(&[
"bash",
"-lc",
"echo 'hi' ; ls"
])));
assert!(is_known_safe_command(&vec_str(&[
"bash",
"-lc",
"ls | wc -l"
])));
}
#[test]
fn bash_lc_unsafe_examples() {
assert!(
!is_known_safe_command(&vec_str(&["bash", "-lc", "git", "status"])),
"Four arg version is not known to be safe."
);
assert!(
!is_known_safe_command(&vec_str(&["bash", "-lc", "'git status'"])),
"The extra quoting around 'git status' makes it a program named 'git status' and is therefore unsafe."
);
assert!(
!is_known_safe_command(&vec_str(&["bash", "-lc", "find . -name file.txt -delete"])),
"Unsafe find option should not be auto-approved."
);
// Disallowed because of unsafe command in sequence.
assert!(
!is_known_safe_command(&vec_str(&["bash", "-lc", "ls && rm -rf /"])),
"Sequence containing unsafe command must be rejected"
);
// Disallowed because of parentheses / subshell.
assert!(
!is_known_safe_command(&vec_str(&["bash", "-lc", "(ls)"])),
"Parentheses (subshell) are not provably safe with the current parser"
);
assert!(
!is_known_safe_command(&vec_str(&["bash", "-lc", "ls || (pwd && echo hi)"])),
"Nested parentheses are not provably safe with the current parser"
);
// Disallowed redirection.
assert!(
!is_known_safe_command(&vec_str(&["bash", "-lc", "ls > out.txt"])),
"> redirection should be rejected"
);
}
}

View File

@@ -0,0 +1,4 @@
pub mod is_dangerous_command;
pub mod is_safe_command;
#[cfg(target_os = "windows")]
pub mod windows_safe_commands;

View File

@@ -0,0 +1,431 @@
use shlex::split as shlex_split;
/// On Windows, we conservatively allow only clearly read-only PowerShell invocations
/// that match a small safelist. Anything else (including direct CMD commands) is unsafe.
pub fn is_safe_command_windows(command: &[String]) -> bool {
if let Some(commands) = try_parse_powershell_command_sequence(command) {
return commands
.iter()
.all(|cmd| is_safe_powershell_command(cmd.as_slice()));
}
// Only PowerShell invocations are allowed on Windows for now; anything else is unsafe.
false
}
/// Returns each command sequence if the invocation starts with a PowerShell binary.
/// For example, the tokens from `pwsh Get-ChildItem | Measure-Object` become two sequences.
fn try_parse_powershell_command_sequence(command: &[String]) -> Option<Vec<Vec<String>>> {
let (exe, rest) = command.split_first()?;
if !is_powershell_executable(exe) {
return None;
}
parse_powershell_invocation(rest)
}
/// Parses a PowerShell invocation into discrete command vectors, rejecting unsafe patterns.
fn parse_powershell_invocation(args: &[String]) -> Option<Vec<Vec<String>>> {
if args.is_empty() {
// Examples rejected here: "pwsh" and "powershell.exe" with no additional arguments.
return None;
}
let mut idx = 0;
while idx < args.len() {
let arg = &args[idx];
let lower = arg.to_ascii_lowercase();
match lower.as_str() {
"-command" | "/command" | "-c" => {
let script = args.get(idx + 1)?;
if idx + 2 != args.len() {
// Reject if there is more than one token representing the actual command.
// Examples rejected here: "pwsh -Command foo bar" and "powershell -c ls extra".
return None;
}
return parse_powershell_script(script);
}
_ if lower.starts_with("-command:") || lower.starts_with("/command:") => {
if idx + 1 != args.len() {
// Reject if there are more tokens after the command itself.
// Examples rejected here: "pwsh -Command:dir C:\\" and "powershell /Command:dir C:\\" with trailing args.
return None;
}
let script = arg.split_once(':')?.1;
return parse_powershell_script(script);
}
// Benign, no-arg flags we tolerate.
"-nologo" | "-noprofile" | "-noninteractive" | "-mta" | "-sta" => {
idx += 1;
continue;
}
// Explicitly forbidden/opaque or unnecessary for read-only operations.
"-encodedcommand" | "-ec" | "-file" | "/file" | "-windowstyle" | "-executionpolicy"
| "-workingdirectory" => {
// Examples rejected here: "pwsh -EncodedCommand ..." and "powershell -File script.ps1".
return None;
}
// Unknown switch → bail conservatively.
_ if lower.starts_with('-') => {
// Examples rejected here: "pwsh -UnknownFlag" and "powershell -foo bar".
return None;
}
// If we hit non-flag tokens, treat the remainder as a command sequence.
// This happens if powershell is invoked without -Command, e.g.
// ["pwsh", "-NoLogo", "git", "-c", "core.pager=cat", "status"]
_ => {
return split_into_commands(args[idx..].to_vec());
}
}
}
// Examples rejected here: "pwsh" and "powershell.exe -NoLogo" without a script.
None
}
/// Tokenizes an inline PowerShell script and delegates to the command splitter.
/// Examples of when this is called: pwsh.exe -Command '<script>' or pwsh.exe -Command:<script>
fn parse_powershell_script(script: &str) -> Option<Vec<Vec<String>>> {
let tokens = shlex_split(script)?;
split_into_commands(tokens)
}
/// Splits tokens into pipeline segments while ensuring no unsafe separators slip through.
/// e.g. Get-ChildItem | Measure-Object -> [['Get-ChildItem'], ['Measure-Object']]
fn split_into_commands(tokens: Vec<String>) -> Option<Vec<Vec<String>>> {
if tokens.is_empty() {
// Examples rejected here: "pwsh -Command ''" and "powershell -Command \"\"".
return None;
}
let mut commands = Vec::new();
let mut current = Vec::new();
for token in tokens.into_iter() {
match token.as_str() {
"|" | "||" | "&&" | ";" => {
if current.is_empty() {
// Examples rejected here: "pwsh -Command '| Get-ChildItem'" and "pwsh -Command '; dir'".
return None;
}
commands.push(current);
current = Vec::new();
}
// Reject if any token embeds separators, redirection, or call operator characters.
_ if token.contains(['|', ';', '>', '<', '&']) || token.contains("$(") => {
// Examples rejected here: "pwsh -Command 'dir|select'" and "pwsh -Command 'echo hi > out.txt'".
return None;
}
_ => current.push(token),
}
}
if current.is_empty() {
// Examples rejected here: "pwsh -Command 'dir |'" and "pwsh -Command 'Get-ChildItem ;'".
return None;
}
commands.push(current);
Some(commands)
}
/// Returns true when the executable name is one of the supported PowerShell binaries.
fn is_powershell_executable(exe: &str) -> bool {
matches!(
exe.to_ascii_lowercase().as_str(),
"powershell" | "powershell.exe" | "pwsh" | "pwsh.exe"
)
}
/// Validates that a parsed PowerShell command stays within our read-only safelist.
/// Everything before this is parsing, and rejecting things that make us feel uncomfortable.
fn is_safe_powershell_command(words: &[String]) -> bool {
if words.is_empty() {
// Examples rejected here: "pwsh -Command ''" and "pwsh -Command \"\"".
return false;
}
// Reject nested unsafe cmdlets inside parentheses or arguments
for w in words.iter() {
let inner = w
.trim_matches(|c| c == '(' || c == ')')
.trim_start_matches('-')
.to_ascii_lowercase();
if matches!(
inner.as_str(),
"set-content"
| "add-content"
| "out-file"
| "new-item"
| "remove-item"
| "move-item"
| "copy-item"
| "rename-item"
| "start-process"
| "stop-process"
) {
// Examples rejected here: "Write-Output (Set-Content foo6.txt 'abc')" and "Get-Content (New-Item bar.txt)".
return false;
}
}
// Block PowerShell call operator or any redirection explicitly.
if words.iter().any(|w| {
matches!(
w.as_str(),
"&" | ">" | ">>" | "1>" | "2>" | "2>&1" | "*>" | "<" | "<<"
)
}) {
// Examples rejected here: "pwsh -Command '& Remove-Item foo'" and "pwsh -Command 'Get-Content foo > bar'".
return false;
}
let command = words[0]
.trim_matches(|c| c == '(' || c == ')')
.trim_start_matches('-')
.to_ascii_lowercase();
match command.as_str() {
"echo" | "write-output" | "write-host" => true, // (no redirection allowed)
"dir" | "ls" | "get-childitem" | "gci" => true,
"cat" | "type" | "gc" | "get-content" => true,
"select-string" | "sls" | "findstr" => true,
"measure-object" | "measure" => true,
"get-location" | "gl" | "pwd" => true,
"test-path" | "tp" => true,
"resolve-path" | "rvpa" => true,
"select-object" | "select" => true,
"get-item" => true,
"git" => is_safe_git_command(words),
"rg" => is_safe_ripgrep(words),
// Extra safety: explicitly prohibit common side-effecting cmdlets regardless of args.
"set-content" | "add-content" | "out-file" | "new-item" | "remove-item" | "move-item"
| "copy-item" | "rename-item" | "start-process" | "stop-process" => {
// Examples rejected here: "pwsh -Command 'Set-Content notes.txt data'" and "pwsh -Command 'Remove-Item temp.log'".
false
}
_ => {
// Examples rejected here: "pwsh -Command 'Invoke-WebRequest https://example.com'" and "pwsh -Command 'Start-Service Spooler'".
false
}
}
}
/// Checks that an `rg` invocation avoids options that can spawn arbitrary executables.
fn is_safe_ripgrep(words: &[String]) -> bool {
const UNSAFE_RIPGREP_OPTIONS_WITH_ARGS: &[&str] = &["--pre", "--hostname-bin"];
const UNSAFE_RIPGREP_OPTIONS_WITHOUT_ARGS: &[&str] = &["--search-zip", "-z"];
!words.iter().skip(1).any(|arg| {
let arg_lc = arg.to_ascii_lowercase();
// Examples rejected here: "pwsh -Command 'rg --pre cat pattern'" and "pwsh -Command 'rg --search-zip pattern'".
UNSAFE_RIPGREP_OPTIONS_WITHOUT_ARGS.contains(&arg_lc.as_str())
|| UNSAFE_RIPGREP_OPTIONS_WITH_ARGS
.iter()
.any(|opt| arg_lc == *opt || arg_lc.starts_with(&format!("{opt}=")))
})
}
/// Ensures a Git command sticks to whitelisted read-only subcommands and flags.
fn is_safe_git_command(words: &[String]) -> bool {
const SAFE_SUBCOMMANDS: &[&str] = &["status", "log", "show", "diff", "cat-file"];
let mut iter = words.iter().skip(1);
while let Some(arg) = iter.next() {
let arg_lc = arg.to_ascii_lowercase();
if arg.starts_with('-') {
if arg.eq_ignore_ascii_case("-c") || arg.eq_ignore_ascii_case("--config") {
if iter.next().is_none() {
// Examples rejected here: "pwsh -Command 'git -c'" and "pwsh -Command 'git --config'".
return false;
}
continue;
}
if arg_lc.starts_with("-c=")
|| arg_lc.starts_with("--config=")
|| arg_lc.starts_with("--git-dir=")
|| arg_lc.starts_with("--work-tree=")
{
continue;
}
if arg.eq_ignore_ascii_case("--git-dir") || arg.eq_ignore_ascii_case("--work-tree") {
if iter.next().is_none() {
// Examples rejected here: "pwsh -Command 'git --git-dir'" and "pwsh -Command 'git --work-tree'".
return false;
}
continue;
}
continue;
}
return SAFE_SUBCOMMANDS.contains(&arg_lc.as_str());
}
// Examples rejected here: "pwsh -Command 'git'" and "pwsh -Command 'git status --short | Remove-Item foo'".
false
}
#[cfg(test)]
mod tests {
use super::is_safe_command_windows;
use std::string::ToString;
/// Converts a slice of string literals into owned `String`s for the tests.
fn vec_str(args: &[&str]) -> Vec<String> {
args.iter().map(ToString::to_string).collect()
}
#[test]
fn recognizes_safe_powershell_wrappers() {
assert!(is_safe_command_windows(&vec_str(&[
"powershell.exe",
"-NoLogo",
"-Command",
"Get-ChildItem -Path .",
])));
assert!(is_safe_command_windows(&vec_str(&[
"powershell.exe",
"-NoProfile",
"-Command",
"git status",
])));
assert!(is_safe_command_windows(&vec_str(&[
"powershell.exe",
"Get-Content",
"Cargo.toml",
])));
// pwsh parity
assert!(is_safe_command_windows(&vec_str(&[
"pwsh.exe",
"-NoProfile",
"-Command",
"Get-ChildItem",
])));
}
#[test]
fn allows_read_only_pipelines_and_git_usage() {
assert!(is_safe_command_windows(&vec_str(&[
"pwsh",
"-NoLogo",
"-NoProfile",
"-Command",
"rg --files-with-matches foo | Measure-Object | Select-Object -ExpandProperty Count",
])));
assert!(is_safe_command_windows(&vec_str(&[
"pwsh",
"-NoLogo",
"-NoProfile",
"-Command",
"Get-Content foo.rs | Select-Object -Skip 200",
])));
assert!(is_safe_command_windows(&vec_str(&[
"pwsh",
"-NoLogo",
"-NoProfile",
"-Command",
"git -c core.pager=cat show HEAD:foo.rs",
])));
assert!(is_safe_command_windows(&vec_str(&[
"pwsh",
"-Command",
"-git cat-file -p HEAD:foo.rs",
])));
assert!(is_safe_command_windows(&vec_str(&[
"pwsh",
"-Command",
"(Get-Content foo.rs -Raw)",
])));
assert!(is_safe_command_windows(&vec_str(&[
"pwsh",
"-Command",
"Get-Item foo.rs | Select-Object Length",
])));
}
#[test]
fn rejects_powershell_commands_with_side_effects() {
assert!(!is_safe_command_windows(&vec_str(&[
"powershell.exe",
"-NoLogo",
"-Command",
"Remove-Item foo.txt",
])));
assert!(!is_safe_command_windows(&vec_str(&[
"powershell.exe",
"-NoProfile",
"-Command",
"rg --pre cat",
])));
assert!(!is_safe_command_windows(&vec_str(&[
"powershell.exe",
"-Command",
"Set-Content foo.txt 'hello'",
])));
// Redirections are blocked
assert!(!is_safe_command_windows(&vec_str(&[
"powershell.exe",
"-Command",
"echo hi > out.txt",
])));
assert!(!is_safe_command_windows(&vec_str(&[
"powershell.exe",
"-Command",
"Get-Content x | Out-File y",
])));
assert!(!is_safe_command_windows(&vec_str(&[
"powershell.exe",
"-Command",
"Write-Output foo 2> err.txt",
])));
// Call operator is blocked
assert!(!is_safe_command_windows(&vec_str(&[
"powershell.exe",
"-Command",
"& Remove-Item foo",
])));
// Chained safe + unsafe must fail
assert!(!is_safe_command_windows(&vec_str(&[
"powershell.exe",
"-Command",
"Get-ChildItem; Remove-Item foo",
])));
// Nested unsafe cmdlet inside safe command must fail
assert!(!is_safe_command_windows(&vec_str(&[
"powershell.exe",
"-Command",
"Write-Output (Set-Content foo6.txt 'abc')",
])));
// Additional nested unsafe cmdlet examples must fail
assert!(!is_safe_command_windows(&vec_str(&[
"powershell.exe",
"-Command",
"Write-Host (Remove-Item foo.txt)",
])));
assert!(!is_safe_command_windows(&vec_str(&[
"powershell.exe",
"-Command",
"Get-Content (New-Item bar.txt)",
])));
}
}

451
llmx-rs/core/src/compact.rs Normal file
View File

@@ -0,0 +1,451 @@
use std::sync::Arc;
use crate::Prompt;
use crate::client_common::ResponseEvent;
use crate::error::LlmxErr;
use crate::error::Result as LlmxResult;
use crate::llmx::Session;
use crate::llmx::TurnContext;
use crate::llmx::get_last_assistant_message_from_turn;
use crate::protocol::AgentMessageEvent;
use crate::protocol::CompactedItem;
use crate::protocol::ErrorEvent;
use crate::protocol::EventMsg;
use crate::protocol::TaskStartedEvent;
use crate::protocol::TurnContextItem;
use crate::protocol::WarningEvent;
use crate::truncate::truncate_middle;
use crate::util::backoff;
use futures::prelude::*;
use llmx_protocol::items::TurnItem;
use llmx_protocol::models::ContentItem;
use llmx_protocol::models::ResponseInputItem;
use llmx_protocol::models::ResponseItem;
use llmx_protocol::protocol::RolloutItem;
use llmx_protocol::user_input::UserInput;
use tracing::error;
pub const SUMMARIZATION_PROMPT: &str = include_str!("../templates/compact/prompt.md");
const COMPACT_USER_MESSAGE_MAX_TOKENS: usize = 20_000;
pub(crate) async fn run_inline_auto_compact_task(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
) {
let prompt = turn_context.compact_prompt().to_string();
let input = vec![UserInput::Text { text: prompt }];
run_compact_task_inner(sess, turn_context, input).await;
}
pub(crate) async fn run_compact_task(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
input: Vec<UserInput>,
) -> Option<String> {
let start_event = EventMsg::TaskStarted(TaskStartedEvent {
model_context_window: turn_context.client.get_model_context_window(),
});
sess.send_event(&turn_context, start_event).await;
run_compact_task_inner(sess.clone(), turn_context, input).await;
None
}
async fn run_compact_task_inner(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
input: Vec<UserInput>,
) {
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
let mut history = sess.clone_history().await;
history.record_items(&[initial_input_for_turn.into()]);
let mut truncated_count = 0usize;
let max_retries = turn_context.client.get_provider().stream_max_retries();
let mut retries = 0;
let rollout_item = RolloutItem::TurnContext(TurnContextItem {
cwd: turn_context.cwd.clone(),
approval_policy: turn_context.approval_policy,
sandbox_policy: turn_context.sandbox_policy.clone(),
model: turn_context.client.get_model(),
effort: turn_context.client.get_reasoning_effort(),
summary: turn_context.client.get_reasoning_summary(),
});
sess.persist_rollout_items(&[rollout_item]).await;
loop {
let turn_input = history.get_history_for_prompt();
let prompt = Prompt {
input: turn_input.clone(),
..Default::default()
};
let attempt_result = drain_to_completed(&sess, turn_context.as_ref(), &prompt).await;
match attempt_result {
Ok(()) => {
if truncated_count > 0 {
sess.notify_background_event(
turn_context.as_ref(),
format!(
"Trimmed {truncated_count} older conversation item(s) before compacting so the prompt fits the model context window."
),
)
.await;
}
break;
}
Err(LlmxErr::Interrupted) => {
return;
}
Err(e @ LlmxErr::ContextWindowExceeded) => {
if turn_input.len() > 1 {
// Trim from the beginning to preserve cache (prefix-based) and keep recent messages intact.
error!(
"Context window exceeded while compacting; removing oldest history item. Error: {e}"
);
history.remove_first_item();
truncated_count += 1;
retries = 0;
continue;
}
sess.set_total_tokens_full(turn_context.as_ref()).await;
let event = EventMsg::Error(ErrorEvent {
message: e.to_string(),
});
sess.send_event(&turn_context, event).await;
return;
}
Err(e) => {
if retries < max_retries {
retries += 1;
let delay = backoff(retries);
sess.notify_stream_error(
turn_context.as_ref(),
format!("Reconnecting... {retries}/{max_retries}"),
)
.await;
tokio::time::sleep(delay).await;
continue;
} else {
let event = EventMsg::Error(ErrorEvent {
message: e.to_string(),
});
sess.send_event(&turn_context, event).await;
return;
}
}
}
}
let history_snapshot = sess.clone_history().await.get_history();
let summary_text = get_last_assistant_message_from_turn(&history_snapshot).unwrap_or_default();
let user_messages = collect_user_messages(&history_snapshot);
let initial_context = sess.build_initial_context(turn_context.as_ref());
let mut new_history = build_compacted_history(initial_context, &user_messages, &summary_text);
let ghost_snapshots: Vec<ResponseItem> = history_snapshot
.iter()
.filter(|item| matches!(item, ResponseItem::GhostSnapshot { .. }))
.cloned()
.collect();
new_history.extend(ghost_snapshots);
sess.replace_history(new_history).await;
let rollout_item = RolloutItem::Compacted(CompactedItem {
message: summary_text.clone(),
});
sess.persist_rollout_items(&[rollout_item]).await;
let event = EventMsg::AgentMessage(AgentMessageEvent {
message: "Compact task completed".to_string(),
});
sess.send_event(&turn_context, event).await;
let warning = EventMsg::Warning(WarningEvent {
message: "Heads up: Long conversations and multiple compactions can cause the model to be less accurate. Start a new conversation when possible to keep conversations small and targeted.".to_string(),
});
sess.send_event(&turn_context, warning).await;
}
pub fn content_items_to_text(content: &[ContentItem]) -> Option<String> {
let mut pieces = Vec::new();
for item in content {
match item {
ContentItem::InputText { text } | ContentItem::OutputText { text } => {
if !text.is_empty() {
pieces.push(text.as_str());
}
}
ContentItem::InputImage { .. } => {}
}
}
if pieces.is_empty() {
None
} else {
Some(pieces.join("\n"))
}
}
pub(crate) fn collect_user_messages(items: &[ResponseItem]) -> Vec<String> {
items
.iter()
.filter_map(|item| match crate::event_mapping::parse_turn_item(item) {
Some(TurnItem::UserMessage(user)) => Some(user.message()),
_ => None,
})
.collect()
}
pub(crate) fn build_compacted_history(
initial_context: Vec<ResponseItem>,
user_messages: &[String],
summary_text: &str,
) -> Vec<ResponseItem> {
build_compacted_history_with_limit(
initial_context,
user_messages,
summary_text,
COMPACT_USER_MESSAGE_MAX_TOKENS * 4,
)
}
fn build_compacted_history_with_limit(
mut history: Vec<ResponseItem>,
user_messages: &[String],
summary_text: &str,
max_bytes: usize,
) -> Vec<ResponseItem> {
let mut selected_messages: Vec<String> = Vec::new();
if max_bytes > 0 {
let mut remaining = max_bytes;
for message in user_messages.iter().rev() {
if remaining == 0 {
break;
}
if message.len() <= remaining {
selected_messages.push(message.clone());
remaining = remaining.saturating_sub(message.len());
} else {
let (truncated, _) = truncate_middle(message, remaining);
selected_messages.push(truncated);
break;
}
}
selected_messages.reverse();
}
for message in &selected_messages {
history.push(ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: message.clone(),
}],
});
}
let summary_text = if summary_text.is_empty() {
"(no summary available)".to_string()
} else {
summary_text.to_string()
};
history.push(ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText { text: summary_text }],
});
history
}
async fn drain_to_completed(
sess: &Session,
turn_context: &TurnContext,
prompt: &Prompt,
) -> LlmxResult<()> {
let mut stream = turn_context.client.clone().stream(prompt).await?;
loop {
let maybe_event = stream.next().await;
let Some(event) = maybe_event else {
return Err(LlmxErr::Stream(
"stream closed before response.completed".into(),
None,
));
};
match event {
Ok(ResponseEvent::OutputItemDone(item)) => {
sess.record_into_history(std::slice::from_ref(&item)).await;
}
Ok(ResponseEvent::RateLimits(snapshot)) => {
sess.update_rate_limits(turn_context, snapshot).await;
}
Ok(ResponseEvent::Completed { token_usage, .. }) => {
sess.update_token_usage_info(turn_context, token_usage.as_ref())
.await;
return Ok(());
}
Ok(_) => continue,
Err(e) => return Err(e),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn content_items_to_text_joins_non_empty_segments() {
let items = vec![
ContentItem::InputText {
text: "hello".to_string(),
},
ContentItem::OutputText {
text: String::new(),
},
ContentItem::OutputText {
text: "world".to_string(),
},
];
let joined = content_items_to_text(&items);
assert_eq!(Some("hello\nworld".to_string()), joined);
}
#[test]
fn content_items_to_text_ignores_image_only_content() {
let items = vec![ContentItem::InputImage {
image_url: "file://image.png".to_string(),
}];
let joined = content_items_to_text(&items);
assert_eq!(None, joined);
}
#[test]
fn collect_user_messages_extracts_user_text_only() {
let items = vec![
ResponseItem::Message {
id: Some("assistant".to_string()),
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: "ignored".to_string(),
}],
},
ResponseItem::Message {
id: Some("user".to_string()),
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "first".to_string(),
}],
},
ResponseItem::Other,
];
let collected = collect_user_messages(&items);
assert_eq!(vec!["first".to_string()], collected);
}
#[test]
fn collect_user_messages_filters_session_prefix_entries() {
let items = vec![
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "# AGENTS.md instructions for project\n\n<INSTRUCTIONS>\ndo things\n</INSTRUCTIONS>"
.to_string(),
}],
},
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "<ENVIRONMENT_CONTEXT>cwd=/tmp</ENVIRONMENT_CONTEXT>".to_string(),
}],
},
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "real user message".to_string(),
}],
},
];
let collected = collect_user_messages(&items);
assert_eq!(vec!["real user message".to_string()], collected);
}
#[test]
fn build_compacted_history_truncates_overlong_user_messages() {
// Use a small truncation limit so the test remains fast while still validating
// that oversized user content is truncated.
let max_bytes = 128;
let big = "X".repeat(max_bytes + 50);
let history = super::build_compacted_history_with_limit(
Vec::new(),
std::slice::from_ref(&big),
"SUMMARY",
max_bytes,
);
assert_eq!(history.len(), 2);
let truncated_message = &history[0];
let summary_message = &history[1];
let truncated_text = match truncated_message {
ResponseItem::Message { role, content, .. } if role == "user" => {
content_items_to_text(content).unwrap_or_default()
}
other => panic!("unexpected item in history: {other:?}"),
};
assert!(
truncated_text.contains("tokens truncated"),
"expected truncation marker in truncated user message"
);
assert!(
!truncated_text.contains(&big),
"truncated user message should not include the full oversized user text"
);
let summary_text = match summary_message {
ResponseItem::Message { role, content, .. } if role == "user" => {
content_items_to_text(content).unwrap_or_default()
}
other => panic!("unexpected item in history: {other:?}"),
};
assert_eq!(summary_text, "SUMMARY");
}
#[test]
fn build_compacted_history_appends_summary_message() {
let initial_context: Vec<ResponseItem> = Vec::new();
let user_messages = vec!["first user message".to_string()];
let summary_text = "summary text";
let history = build_compacted_history(initial_context, &user_messages, summary_text);
assert!(
!history.is_empty(),
"expected compacted history to include summary"
);
let last = history.last().expect("history should have a summary entry");
let summary = match last {
ResponseItem::Message { role, content, .. } if role == "user" => {
content_items_to_text(content).unwrap_or_default()
}
other => panic!("expected summary message, found {other:?}"),
};
assert_eq!(summary, summary_text);
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,50 @@
use serde::Deserialize;
use std::path::PathBuf;
use crate::protocol::AskForApproval;
use llmx_protocol::config_types::ReasoningEffort;
use llmx_protocol::config_types::ReasoningSummary;
use llmx_protocol::config_types::SandboxMode;
use llmx_protocol::config_types::Verbosity;
/// Collection of common configuration options that a user can define as a unit
/// in `config.toml`.
#[derive(Debug, Clone, Default, PartialEq, Deserialize)]
pub struct ConfigProfile {
pub model: Option<String>,
/// The key in the `model_providers` map identifying the
/// [`ModelProviderInfo`] to use.
pub model_provider: Option<String>,
pub approval_policy: Option<AskForApproval>,
pub sandbox_mode: Option<SandboxMode>,
pub model_reasoning_effort: Option<ReasoningEffort>,
pub model_reasoning_summary: Option<ReasoningSummary>,
pub model_verbosity: Option<Verbosity>,
pub chatgpt_base_url: Option<String>,
pub experimental_instructions_file: Option<PathBuf>,
pub experimental_compact_prompt_file: Option<PathBuf>,
pub include_apply_patch_tool: Option<bool>,
pub experimental_use_unified_exec_tool: Option<bool>,
pub experimental_use_rmcp_client: Option<bool>,
pub experimental_use_freeform_apply_patch: Option<bool>,
pub experimental_sandbox_command_assessment: Option<bool>,
pub tools_web_search: Option<bool>,
pub tools_view_image: Option<bool>,
/// Optional feature toggles scoped to this profile.
#[serde(default)]
pub features: Option<crate::features::FeaturesToml>,
}
impl From<ConfigProfile> for llmx_app_server_protocol::Profile {
fn from(config_profile: ConfigProfile) -> Self {
Self {
model: config_profile.model,
model_provider: config_profile.model_provider,
approval_policy: config_profile.approval_policy,
model_reasoning_effort: config_profile.model_reasoning_effort,
model_reasoning_summary: config_profile.model_reasoning_summary,
model_verbosity: config_profile.model_verbosity,
chatgpt_base_url: config_profile.chatgpt_base_url,
}
}
}

View File

@@ -0,0 +1,771 @@
//! Types used to define the fields of [`crate::config::Config`].
// Note this file should generally be restricted to simple struct/enum
// definitions that do not contain business logic.
use serde::Deserializer;
use std::collections::HashMap;
use std::path::PathBuf;
use std::time::Duration;
use wildmatch::WildMatchPattern;
use serde::Deserialize;
use serde::Serialize;
use serde::de::Error as SerdeError;
pub const DEFAULT_OTEL_ENVIRONMENT: &str = "dev";
#[derive(Serialize, Debug, Clone, PartialEq)]
pub struct McpServerConfig {
#[serde(flatten)]
pub transport: McpServerTransportConfig,
/// When `false`, Llmx skips initializing this MCP server.
#[serde(default = "default_enabled")]
pub enabled: bool,
/// Startup timeout in seconds for initializing MCP server & initially listing tools.
#[serde(
default,
with = "option_duration_secs",
skip_serializing_if = "Option::is_none"
)]
pub startup_timeout_sec: Option<Duration>,
/// Default timeout for MCP tool calls initiated via this server.
#[serde(default, with = "option_duration_secs")]
pub tool_timeout_sec: Option<Duration>,
/// Explicit allow-list of tools exposed from this server. When set, only these tools will be registered.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub enabled_tools: Option<Vec<String>>,
/// Explicit deny-list of tools. These tools will be removed after applying `enabled_tools`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub disabled_tools: Option<Vec<String>>,
}
impl<'de> Deserialize<'de> for McpServerConfig {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize, Clone)]
struct RawMcpServerConfig {
// stdio
command: Option<String>,
#[serde(default)]
args: Option<Vec<String>>,
#[serde(default)]
env: Option<HashMap<String, String>>,
#[serde(default)]
env_vars: Option<Vec<String>>,
#[serde(default)]
cwd: Option<PathBuf>,
http_headers: Option<HashMap<String, String>>,
#[serde(default)]
env_http_headers: Option<HashMap<String, String>>,
// streamable_http
url: Option<String>,
bearer_token: Option<String>,
bearer_token_env_var: Option<String>,
// shared
#[serde(default)]
startup_timeout_sec: Option<f64>,
#[serde(default)]
startup_timeout_ms: Option<u64>,
#[serde(default, with = "option_duration_secs")]
tool_timeout_sec: Option<Duration>,
#[serde(default)]
enabled: Option<bool>,
#[serde(default)]
enabled_tools: Option<Vec<String>>,
#[serde(default)]
disabled_tools: Option<Vec<String>>,
}
let mut raw = RawMcpServerConfig::deserialize(deserializer)?;
let startup_timeout_sec = match (raw.startup_timeout_sec, raw.startup_timeout_ms) {
(Some(sec), _) => {
let duration = Duration::try_from_secs_f64(sec).map_err(SerdeError::custom)?;
Some(duration)
}
(None, Some(ms)) => Some(Duration::from_millis(ms)),
(None, None) => None,
};
let tool_timeout_sec = raw.tool_timeout_sec;
let enabled = raw.enabled.unwrap_or_else(default_enabled);
let enabled_tools = raw.enabled_tools.clone();
let disabled_tools = raw.disabled_tools.clone();
fn throw_if_set<E, T>(transport: &str, field: &str, value: Option<&T>) -> Result<(), E>
where
E: SerdeError,
{
if value.is_none() {
return Ok(());
}
Err(E::custom(format!(
"{field} is not supported for {transport}",
)))
}
let transport = if let Some(command) = raw.command.clone() {
throw_if_set("stdio", "url", raw.url.as_ref())?;
throw_if_set(
"stdio",
"bearer_token_env_var",
raw.bearer_token_env_var.as_ref(),
)?;
throw_if_set("stdio", "bearer_token", raw.bearer_token.as_ref())?;
throw_if_set("stdio", "http_headers", raw.http_headers.as_ref())?;
throw_if_set("stdio", "env_http_headers", raw.env_http_headers.as_ref())?;
McpServerTransportConfig::Stdio {
command,
args: raw.args.clone().unwrap_or_default(),
env: raw.env.clone(),
env_vars: raw.env_vars.clone().unwrap_or_default(),
cwd: raw.cwd.take(),
}
} else if let Some(url) = raw.url.clone() {
throw_if_set("streamable_http", "args", raw.args.as_ref())?;
throw_if_set("streamable_http", "env", raw.env.as_ref())?;
throw_if_set("streamable_http", "env_vars", raw.env_vars.as_ref())?;
throw_if_set("streamable_http", "cwd", raw.cwd.as_ref())?;
throw_if_set("streamable_http", "bearer_token", raw.bearer_token.as_ref())?;
McpServerTransportConfig::StreamableHttp {
url,
bearer_token_env_var: raw.bearer_token_env_var.clone(),
http_headers: raw.http_headers.clone(),
env_http_headers: raw.env_http_headers.take(),
}
} else {
return Err(SerdeError::custom("invalid transport"));
};
Ok(Self {
transport,
startup_timeout_sec,
tool_timeout_sec,
enabled,
enabled_tools,
disabled_tools,
})
}
}
const fn default_enabled() -> bool {
true
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(untagged, deny_unknown_fields, rename_all = "snake_case")]
pub enum McpServerTransportConfig {
/// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#stdio
Stdio {
command: String,
#[serde(default)]
args: Vec<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
env: Option<HashMap<String, String>>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
env_vars: Vec<String>,
#[serde(default, skip_serializing_if = "Option::is_none")]
cwd: Option<PathBuf>,
},
/// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http
StreamableHttp {
url: String,
/// Name of the environment variable to read for an HTTP bearer token.
/// When set, requests will include the token via `Authorization: Bearer <token>`.
/// The actual secret value must be provided via the environment.
#[serde(default, skip_serializing_if = "Option::is_none")]
bearer_token_env_var: Option<String>,
/// Additional HTTP headers to include in requests to this server.
#[serde(default, skip_serializing_if = "Option::is_none")]
http_headers: Option<HashMap<String, String>>,
/// HTTP headers where the value is sourced from an environment variable.
#[serde(default, skip_serializing_if = "Option::is_none")]
env_http_headers: Option<HashMap<String, String>>,
},
}
mod option_duration_secs {
use serde::Deserialize;
use serde::Deserializer;
use serde::Serializer;
use std::time::Duration;
pub fn serialize<S>(value: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match value {
Some(duration) => serializer.serialize_some(&duration.as_secs_f64()),
None => serializer.serialize_none(),
}
}
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
where
D: Deserializer<'de>,
{
let secs = Option::<f64>::deserialize(deserializer)?;
secs.map(|secs| Duration::try_from_secs_f64(secs).map_err(serde::de::Error::custom))
.transpose()
}
}
#[derive(Deserialize, Debug, Copy, Clone, PartialEq)]
pub enum UriBasedFileOpener {
#[serde(rename = "vscode")]
VsCode,
#[serde(rename = "vscode-insiders")]
VsCodeInsiders,
#[serde(rename = "windsurf")]
Windsurf,
#[serde(rename = "cursor")]
Cursor,
/// Option to disable the URI-based file opener.
#[serde(rename = "none")]
None,
}
impl UriBasedFileOpener {
pub fn get_scheme(&self) -> Option<&str> {
match self {
UriBasedFileOpener::VsCode => Some("vscode"),
UriBasedFileOpener::VsCodeInsiders => Some("vscode-insiders"),
UriBasedFileOpener::Windsurf => Some("windsurf"),
UriBasedFileOpener::Cursor => Some("cursor"),
UriBasedFileOpener::None => None,
}
}
}
/// Settings that govern if and what will be written to `~/.llmx/history.jsonl`.
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct History {
/// If true, history entries will not be written to disk.
pub persistence: HistoryPersistence,
/// If set, the maximum size of the history file in bytes.
/// TODO(mbolin): Not currently honored.
pub max_bytes: Option<usize>,
}
#[derive(Deserialize, Debug, Copy, Clone, PartialEq, Default)]
#[serde(rename_all = "kebab-case")]
pub enum HistoryPersistence {
/// Save all history entries to disk.
#[default]
SaveAll,
/// Do not write history to disk.
None,
}
// ===== OTEL configuration =====
#[derive(Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum OtelHttpProtocol {
/// Binary payload
Binary,
/// JSON payload
Json,
}
/// Which OTEL exporter to use.
#[derive(Deserialize, Debug, Clone, PartialEq)]
#[serde(rename_all = "kebab-case")]
pub enum OtelExporterKind {
None,
OtlpHttp {
endpoint: String,
headers: HashMap<String, String>,
protocol: OtelHttpProtocol,
},
OtlpGrpc {
endpoint: String,
headers: HashMap<String, String>,
},
}
/// OTEL settings loaded from config.toml. Fields are optional so we can apply defaults.
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct OtelConfigToml {
/// Log user prompt in traces
pub log_user_prompt: Option<bool>,
/// Mark traces with environment (dev, staging, prod, test). Defaults to dev.
pub environment: Option<String>,
/// Exporter to use. Defaults to `otlp-file`.
pub exporter: Option<OtelExporterKind>,
}
/// Effective OTEL settings after defaults are applied.
#[derive(Debug, Clone, PartialEq)]
pub struct OtelConfig {
pub log_user_prompt: bool,
pub environment: String,
pub exporter: OtelExporterKind,
}
impl Default for OtelConfig {
fn default() -> Self {
OtelConfig {
log_user_prompt: false,
environment: DEFAULT_OTEL_ENVIRONMENT.to_owned(),
exporter: OtelExporterKind::None,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
#[serde(untagged)]
pub enum Notifications {
Enabled(bool),
Custom(Vec<String>),
}
impl Default for Notifications {
fn default() -> Self {
Self::Enabled(false)
}
}
/// Collection of settings that are specific to the TUI.
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Tui {
/// Enable desktop notifications from the TUI when the terminal is unfocused.
/// Defaults to `false`.
#[serde(default)]
pub notifications: Notifications,
}
/// Settings for notices we display to users via the tui and app-server clients
/// (primarily the Llmx IDE extension). NOTE: these are different from
/// notifications - notices are warnings, NUX screens, acknowledgements, etc.
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Notice {
/// Tracks whether the user has acknowledged the full access warning prompt.
pub hide_full_access_warning: Option<bool>,
/// Tracks whether the user has acknowledged the Windows world-writable directories warning.
pub hide_world_writable_warning: Option<bool>,
/// Tracks whether the user opted out of the rate limit model switch reminder.
pub hide_rate_limit_model_nudge: Option<bool>,
}
impl Notice {
/// referenced by config_edit helpers when writing notice flags
pub(crate) const TABLE_KEY: &'static str = "notice";
}
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct SandboxWorkspaceWrite {
#[serde(default)]
pub writable_roots: Vec<PathBuf>,
#[serde(default)]
pub network_access: bool,
#[serde(default)]
pub exclude_tmpdir_env_var: bool,
#[serde(default)]
pub exclude_slash_tmp: bool,
}
impl From<SandboxWorkspaceWrite> for llmx_app_server_protocol::SandboxSettings {
fn from(sandbox_workspace_write: SandboxWorkspaceWrite) -> Self {
Self {
writable_roots: sandbox_workspace_write.writable_roots,
network_access: Some(sandbox_workspace_write.network_access),
exclude_tmpdir_env_var: Some(sandbox_workspace_write.exclude_tmpdir_env_var),
exclude_slash_tmp: Some(sandbox_workspace_write.exclude_slash_tmp),
}
}
}
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
#[serde(rename_all = "kebab-case")]
pub enum ShellEnvironmentPolicyInherit {
/// "Core" environment variables for the platform. On UNIX, this would
/// include HOME, LOGNAME, PATH, SHELL, and USER, among others.
Core,
/// Inherits the full environment from the parent process.
#[default]
All,
/// Do not inherit any environment variables from the parent process.
None,
}
/// Policy for building the `env` when spawning a process via either the
/// `shell` or `local_shell` tool.
#[derive(Deserialize, Debug, Clone, PartialEq, Default)]
pub struct ShellEnvironmentPolicyToml {
pub inherit: Option<ShellEnvironmentPolicyInherit>,
pub ignore_default_excludes: Option<bool>,
/// List of regular expressions.
pub exclude: Option<Vec<String>>,
pub r#set: Option<HashMap<String, String>>,
/// List of regular expressions.
pub include_only: Option<Vec<String>>,
pub experimental_use_profile: Option<bool>,
}
pub type EnvironmentVariablePattern = WildMatchPattern<'*', '?'>;
/// Deriving the `env` based on this policy works as follows:
/// 1. Create an initial map based on the `inherit` policy.
/// 2. If `ignore_default_excludes` is false, filter the map using the default
/// exclude pattern(s), which are: `"*KEY*"` and `"*TOKEN*"`.
/// 3. If `exclude` is not empty, filter the map using the provided patterns.
/// 4. Insert any entries from `r#set` into the map.
/// 5. If non-empty, filter the map using the `include_only` patterns.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct ShellEnvironmentPolicy {
/// Starting point when building the environment.
pub inherit: ShellEnvironmentPolicyInherit,
/// True to skip the check to exclude default environment variables that
/// contain "KEY" or "TOKEN" in their name.
pub ignore_default_excludes: bool,
/// Environment variable names to exclude from the environment.
pub exclude: Vec<EnvironmentVariablePattern>,
/// (key, value) pairs to insert in the environment.
pub r#set: HashMap<String, String>,
/// Environment variable names to retain in the environment.
pub include_only: Vec<EnvironmentVariablePattern>,
/// If true, the shell profile will be used to run the command.
pub use_profile: bool,
}
impl From<ShellEnvironmentPolicyToml> for ShellEnvironmentPolicy {
fn from(toml: ShellEnvironmentPolicyToml) -> Self {
// Default to inheriting the full environment when not specified.
let inherit = toml.inherit.unwrap_or(ShellEnvironmentPolicyInherit::All);
let ignore_default_excludes = toml.ignore_default_excludes.unwrap_or(false);
let exclude = toml
.exclude
.unwrap_or_default()
.into_iter()
.map(|s| EnvironmentVariablePattern::new_case_insensitive(&s))
.collect();
let r#set = toml.r#set.unwrap_or_default();
let include_only = toml
.include_only
.unwrap_or_default()
.into_iter()
.map(|s| EnvironmentVariablePattern::new_case_insensitive(&s))
.collect();
let use_profile = toml.experimental_use_profile.unwrap_or(false);
Self {
inherit,
ignore_default_excludes,
exclude,
r#set,
include_only,
use_profile,
}
}
}
#[derive(Deserialize, Debug, Clone, PartialEq, Eq, Default, Hash)]
#[serde(rename_all = "kebab-case")]
pub enum ReasoningSummaryFormat {
#[default]
None,
Experimental,
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn deserialize_stdio_command_server_config() {
let cfg: McpServerConfig = toml::from_str(
r#"
command = "echo"
"#,
)
.expect("should deserialize command config");
assert_eq!(
cfg.transport,
McpServerTransportConfig::Stdio {
command: "echo".to_string(),
args: vec![],
env: None,
env_vars: Vec::new(),
cwd: None,
}
);
assert!(cfg.enabled);
assert!(cfg.enabled_tools.is_none());
assert!(cfg.disabled_tools.is_none());
}
#[test]
fn deserialize_stdio_command_server_config_with_args() {
let cfg: McpServerConfig = toml::from_str(
r#"
command = "echo"
args = ["hello", "world"]
"#,
)
.expect("should deserialize command config");
assert_eq!(
cfg.transport,
McpServerTransportConfig::Stdio {
command: "echo".to_string(),
args: vec!["hello".to_string(), "world".to_string()],
env: None,
env_vars: Vec::new(),
cwd: None,
}
);
assert!(cfg.enabled);
}
#[test]
fn deserialize_stdio_command_server_config_with_arg_with_args_and_env() {
let cfg: McpServerConfig = toml::from_str(
r#"
command = "echo"
args = ["hello", "world"]
env = { "FOO" = "BAR" }
"#,
)
.expect("should deserialize command config");
assert_eq!(
cfg.transport,
McpServerTransportConfig::Stdio {
command: "echo".to_string(),
args: vec!["hello".to_string(), "world".to_string()],
env: Some(HashMap::from([("FOO".to_string(), "BAR".to_string())])),
env_vars: Vec::new(),
cwd: None,
}
);
assert!(cfg.enabled);
}
#[test]
fn deserialize_stdio_command_server_config_with_env_vars() {
let cfg: McpServerConfig = toml::from_str(
r#"
command = "echo"
env_vars = ["FOO", "BAR"]
"#,
)
.expect("should deserialize command config with env_vars");
assert_eq!(
cfg.transport,
McpServerTransportConfig::Stdio {
command: "echo".to_string(),
args: vec![],
env: None,
env_vars: vec!["FOO".to_string(), "BAR".to_string()],
cwd: None,
}
);
}
#[test]
fn deserialize_stdio_command_server_config_with_cwd() {
let cfg: McpServerConfig = toml::from_str(
r#"
command = "echo"
cwd = "/tmp"
"#,
)
.expect("should deserialize command config with cwd");
assert_eq!(
cfg.transport,
McpServerTransportConfig::Stdio {
command: "echo".to_string(),
args: vec![],
env: None,
env_vars: Vec::new(),
cwd: Some(PathBuf::from("/tmp")),
}
);
}
#[test]
fn deserialize_disabled_server_config() {
let cfg: McpServerConfig = toml::from_str(
r#"
command = "echo"
enabled = false
"#,
)
.expect("should deserialize disabled server config");
assert!(!cfg.enabled);
}
#[test]
fn deserialize_streamable_http_server_config() {
let cfg: McpServerConfig = toml::from_str(
r#"
url = "https://example.com/mcp"
"#,
)
.expect("should deserialize http config");
assert_eq!(
cfg.transport,
McpServerTransportConfig::StreamableHttp {
url: "https://example.com/mcp".to_string(),
bearer_token_env_var: None,
http_headers: None,
env_http_headers: None,
}
);
assert!(cfg.enabled);
}
#[test]
fn deserialize_streamable_http_server_config_with_env_var() {
let cfg: McpServerConfig = toml::from_str(
r#"
url = "https://example.com/mcp"
bearer_token_env_var = "GITHUB_TOKEN"
"#,
)
.expect("should deserialize http config");
assert_eq!(
cfg.transport,
McpServerTransportConfig::StreamableHttp {
url: "https://example.com/mcp".to_string(),
bearer_token_env_var: Some("GITHUB_TOKEN".to_string()),
http_headers: None,
env_http_headers: None,
}
);
assert!(cfg.enabled);
}
#[test]
fn deserialize_streamable_http_server_config_with_headers() {
let cfg: McpServerConfig = toml::from_str(
r#"
url = "https://example.com/mcp"
http_headers = { "X-Foo" = "bar" }
env_http_headers = { "X-Token" = "TOKEN_ENV" }
"#,
)
.expect("should deserialize http config with headers");
assert_eq!(
cfg.transport,
McpServerTransportConfig::StreamableHttp {
url: "https://example.com/mcp".to_string(),
bearer_token_env_var: None,
http_headers: Some(HashMap::from([("X-Foo".to_string(), "bar".to_string())])),
env_http_headers: Some(HashMap::from([(
"X-Token".to_string(),
"TOKEN_ENV".to_string()
)])),
}
);
}
#[test]
fn deserialize_server_config_with_tool_filters() {
let cfg: McpServerConfig = toml::from_str(
r#"
command = "echo"
enabled_tools = ["allowed"]
disabled_tools = ["blocked"]
"#,
)
.expect("should deserialize tool filters");
assert_eq!(cfg.enabled_tools, Some(vec!["allowed".to_string()]));
assert_eq!(cfg.disabled_tools, Some(vec!["blocked".to_string()]));
}
#[test]
fn deserialize_rejects_command_and_url() {
toml::from_str::<McpServerConfig>(
r#"
command = "echo"
url = "https://example.com"
"#,
)
.expect_err("should reject command+url");
}
#[test]
fn deserialize_rejects_env_for_http_transport() {
toml::from_str::<McpServerConfig>(
r#"
url = "https://example.com"
env = { "FOO" = "BAR" }
"#,
)
.expect_err("should reject env for http transport");
}
#[test]
fn deserialize_rejects_headers_for_stdio() {
toml::from_str::<McpServerConfig>(
r#"
command = "echo"
http_headers = { "X-Foo" = "bar" }
"#,
)
.expect_err("should reject http_headers for stdio transport");
toml::from_str::<McpServerConfig>(
r#"
command = "echo"
env_http_headers = { "X-Foo" = "BAR_ENV" }
"#,
)
.expect_err("should reject env_http_headers for stdio transport");
}
#[test]
fn deserialize_rejects_inline_bearer_token_field() {
let err = toml::from_str::<McpServerConfig>(
r#"
url = "https://example.com"
bearer_token = "secret"
"#,
)
.expect_err("should reject bearer_token field");
assert!(
err.to_string().contains("bearer_token is not supported"),
"unexpected error: {err}"
);
}
}

View File

@@ -0,0 +1,118 @@
use std::io;
use toml::Value as TomlValue;
#[cfg(target_os = "macos")]
mod native {
use super::*;
use base64::Engine;
use base64::prelude::BASE64_STANDARD;
use core_foundation::base::TCFType;
use core_foundation::string::CFString;
use core_foundation::string::CFStringRef;
use std::ffi::c_void;
use tokio::task;
pub(crate) async fn load_managed_admin_config_layer(
override_base64: Option<&str>,
) -> io::Result<Option<TomlValue>> {
if let Some(encoded) = override_base64 {
let trimmed = encoded.trim();
return if trimmed.is_empty() {
Ok(None)
} else {
parse_managed_preferences_base64(trimmed).map(Some)
};
}
const LOAD_ERROR: &str = "Failed to load managed preferences configuration";
match task::spawn_blocking(load_managed_admin_config).await {
Ok(result) => result,
Err(join_err) => {
if join_err.is_cancelled() {
tracing::error!("Managed preferences load task was cancelled");
} else {
tracing::error!("Managed preferences load task failed: {join_err}");
}
Err(io::Error::other(LOAD_ERROR))
}
}
}
pub(super) fn load_managed_admin_config() -> io::Result<Option<TomlValue>> {
#[link(name = "CoreFoundation", kind = "framework")]
unsafe extern "C" {
fn CFPreferencesCopyAppValue(
key: CFStringRef,
application_id: CFStringRef,
) -> *mut c_void;
}
const MANAGED_PREFERENCES_APPLICATION_ID: &str = "com.openai.llmx";
const MANAGED_PREFERENCES_CONFIG_KEY: &str = "config_toml_base64";
let application_id = CFString::new(MANAGED_PREFERENCES_APPLICATION_ID);
let key = CFString::new(MANAGED_PREFERENCES_CONFIG_KEY);
let value_ref = unsafe {
CFPreferencesCopyAppValue(
key.as_concrete_TypeRef(),
application_id.as_concrete_TypeRef(),
)
};
if value_ref.is_null() {
tracing::debug!(
"Managed preferences for {} key {} not found",
MANAGED_PREFERENCES_APPLICATION_ID,
MANAGED_PREFERENCES_CONFIG_KEY
);
return Ok(None);
}
let value = unsafe { CFString::wrap_under_create_rule(value_ref as _) };
let contents = value.to_string();
let trimmed = contents.trim();
parse_managed_preferences_base64(trimmed).map(Some)
}
pub(super) fn parse_managed_preferences_base64(encoded: &str) -> io::Result<TomlValue> {
let decoded = BASE64_STANDARD.decode(encoded.as_bytes()).map_err(|err| {
tracing::error!("Failed to decode managed preferences as base64: {err}");
io::Error::new(io::ErrorKind::InvalidData, err)
})?;
let decoded_str = String::from_utf8(decoded).map_err(|err| {
tracing::error!("Managed preferences base64 contents were not valid UTF-8: {err}");
io::Error::new(io::ErrorKind::InvalidData, err)
})?;
match toml::from_str::<TomlValue>(&decoded_str) {
Ok(TomlValue::Table(parsed)) => Ok(TomlValue::Table(parsed)),
Ok(other) => {
tracing::error!(
"Managed preferences TOML must have a table at the root, found {other:?}",
);
Err(io::Error::new(
io::ErrorKind::InvalidData,
"managed preferences root must be a table",
))
}
Err(err) => {
tracing::error!("Failed to parse managed preferences TOML: {err}");
Err(io::Error::new(io::ErrorKind::InvalidData, err))
}
}
}
}
#[cfg(target_os = "macos")]
pub(crate) use native::load_managed_admin_config_layer;
#[cfg(not(target_os = "macos"))]
pub(crate) async fn load_managed_admin_config_layer(
_override_base64: Option<&str>,
) -> io::Result<Option<TomlValue>> {
Ok(None)
}

View File

@@ -0,0 +1,311 @@
mod macos;
use crate::config::CONFIG_TOML_FILE;
use macos::load_managed_admin_config_layer;
use std::io;
use std::path::Path;
use std::path::PathBuf;
use tokio::fs;
use toml::Value as TomlValue;
#[cfg(unix)]
const LLMX_MANAGED_CONFIG_SYSTEM_PATH: &str = "/etc/llmx/managed_config.toml";
#[derive(Debug)]
pub(crate) struct LoadedConfigLayers {
pub base: TomlValue,
pub managed_config: Option<TomlValue>,
pub managed_preferences: Option<TomlValue>,
}
#[derive(Debug, Default)]
pub(crate) struct LoaderOverrides {
pub managed_config_path: Option<PathBuf>,
#[cfg(target_os = "macos")]
pub managed_preferences_base64: Option<String>,
}
// Configuration layering pipeline (top overrides bottom):
//
// +-------------------------+
// | Managed preferences (*) |
// +-------------------------+
// ^
// |
// +-------------------------+
// | managed_config.toml |
// +-------------------------+
// ^
// |
// +-------------------------+
// | config.toml (base) |
// +-------------------------+
//
// (*) Only available on macOS via managed device profiles.
pub async fn load_config_as_toml(llmx_home: &Path) -> io::Result<TomlValue> {
load_config_as_toml_with_overrides(llmx_home, LoaderOverrides::default()).await
}
fn default_empty_table() -> TomlValue {
TomlValue::Table(Default::default())
}
pub(crate) async fn load_config_layers_with_overrides(
llmx_home: &Path,
overrides: LoaderOverrides,
) -> io::Result<LoadedConfigLayers> {
load_config_layers_internal(llmx_home, overrides).await
}
async fn load_config_as_toml_with_overrides(
llmx_home: &Path,
overrides: LoaderOverrides,
) -> io::Result<TomlValue> {
let layers = load_config_layers_internal(llmx_home, overrides).await?;
Ok(apply_managed_layers(layers))
}
async fn load_config_layers_internal(
llmx_home: &Path,
overrides: LoaderOverrides,
) -> io::Result<LoadedConfigLayers> {
#[cfg(target_os = "macos")]
let LoaderOverrides {
managed_config_path,
managed_preferences_base64,
} = overrides;
#[cfg(not(target_os = "macos"))]
let LoaderOverrides {
managed_config_path,
} = overrides;
let managed_config_path =
managed_config_path.unwrap_or_else(|| managed_config_default_path(llmx_home));
let user_config_path = llmx_home.join(CONFIG_TOML_FILE);
let user_config = read_config_from_path(&user_config_path, true).await?;
let managed_config = read_config_from_path(&managed_config_path, false).await?;
#[cfg(target_os = "macos")]
let managed_preferences =
load_managed_admin_config_layer(managed_preferences_base64.as_deref()).await?;
#[cfg(not(target_os = "macos"))]
let managed_preferences = load_managed_admin_config_layer(None).await?;
Ok(LoadedConfigLayers {
base: user_config.unwrap_or_else(default_empty_table),
managed_config,
managed_preferences,
})
}
async fn read_config_from_path(
path: &Path,
log_missing_as_info: bool,
) -> io::Result<Option<TomlValue>> {
match fs::read_to_string(path).await {
Ok(contents) => match toml::from_str::<TomlValue>(&contents) {
Ok(value) => Ok(Some(value)),
Err(err) => {
tracing::error!("Failed to parse {}: {err}", path.display());
Err(io::Error::new(io::ErrorKind::InvalidData, err))
}
},
Err(err) if err.kind() == io::ErrorKind::NotFound => {
if log_missing_as_info {
tracing::info!("{} not found, using defaults", path.display());
} else {
tracing::debug!("{} not found", path.display());
}
Ok(None)
}
Err(err) => {
tracing::error!("Failed to read {}: {err}", path.display());
Err(err)
}
}
}
/// Merge config `overlay` into `base`, giving `overlay` precedence.
pub(crate) fn merge_toml_values(base: &mut TomlValue, overlay: &TomlValue) {
if let TomlValue::Table(overlay_table) = overlay
&& let TomlValue::Table(base_table) = base
{
for (key, value) in overlay_table {
if let Some(existing) = base_table.get_mut(key) {
merge_toml_values(existing, value);
} else {
base_table.insert(key.clone(), value.clone());
}
}
} else {
*base = overlay.clone();
}
}
fn managed_config_default_path(llmx_home: &Path) -> PathBuf {
#[cfg(unix)]
{
let _ = llmx_home;
PathBuf::from(LLMX_MANAGED_CONFIG_SYSTEM_PATH)
}
#[cfg(not(unix))]
{
llmx_home.join("managed_config.toml")
}
}
fn apply_managed_layers(layers: LoadedConfigLayers) -> TomlValue {
let LoadedConfigLayers {
mut base,
managed_config,
managed_preferences,
} = layers;
for overlay in [managed_config, managed_preferences].into_iter().flatten() {
merge_toml_values(&mut base, &overlay);
}
base
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
#[tokio::test]
async fn merges_managed_config_layer_on_top() {
let tmp = tempdir().expect("tempdir");
let managed_path = tmp.path().join("managed_config.toml");
std::fs::write(
tmp.path().join(CONFIG_TOML_FILE),
r#"foo = 1
[nested]
value = "base"
"#,
)
.expect("write base");
std::fs::write(
&managed_path,
r#"foo = 2
[nested]
value = "managed_config"
extra = true
"#,
)
.expect("write managed config");
let overrides = LoaderOverrides {
managed_config_path: Some(managed_path),
#[cfg(target_os = "macos")]
managed_preferences_base64: None,
};
let loaded = load_config_as_toml_with_overrides(tmp.path(), overrides)
.await
.expect("load config");
let table = loaded.as_table().expect("top-level table expected");
assert_eq!(table.get("foo"), Some(&TomlValue::Integer(2)));
let nested = table
.get("nested")
.and_then(|v| v.as_table())
.expect("nested");
assert_eq!(
nested.get("value"),
Some(&TomlValue::String("managed_config".to_string()))
);
assert_eq!(nested.get("extra"), Some(&TomlValue::Boolean(true)));
}
#[tokio::test]
async fn returns_empty_when_all_layers_missing() {
let tmp = tempdir().expect("tempdir");
let managed_path = tmp.path().join("managed_config.toml");
let overrides = LoaderOverrides {
managed_config_path: Some(managed_path),
#[cfg(target_os = "macos")]
managed_preferences_base64: None,
};
let layers = load_config_layers_with_overrides(tmp.path(), overrides)
.await
.expect("load layers");
let base_table = layers.base.as_table().expect("base table expected");
assert!(
base_table.is_empty(),
"expected empty base layer when configs missing"
);
assert!(
layers.managed_config.is_none(),
"managed config layer should be absent when file missing"
);
#[cfg(not(target_os = "macos"))]
{
let loaded = load_config_as_toml(tmp.path()).await.expect("load config");
let table = loaded.as_table().expect("top-level table expected");
assert!(
table.is_empty(),
"expected empty table when configs missing"
);
}
}
#[cfg(target_os = "macos")]
#[tokio::test]
async fn managed_preferences_take_highest_precedence() {
use base64::Engine;
let managed_payload = r#"
[nested]
value = "managed"
flag = false
"#;
let encoded = base64::prelude::BASE64_STANDARD.encode(managed_payload.as_bytes());
let tmp = tempdir().expect("tempdir");
let managed_path = tmp.path().join("managed_config.toml");
std::fs::write(
tmp.path().join(CONFIG_TOML_FILE),
r#"[nested]
value = "base"
"#,
)
.expect("write base");
std::fs::write(
&managed_path,
r#"[nested]
value = "managed_config"
flag = true
"#,
)
.expect("write managed config");
let overrides = LoaderOverrides {
managed_config_path: Some(managed_path),
managed_preferences_base64: Some(encoded),
};
let loaded = load_config_as_toml_with_overrides(tmp.path(), overrides)
.await
.expect("load config");
let nested = loaded
.get("nested")
.and_then(|v| v.as_table())
.expect("nested table");
assert_eq!(
nested.get("value"),
Some(&TomlValue::String("managed".to_string()))
);
assert_eq!(nested.get("flag"), Some(&TomlValue::Boolean(false)));
}
}

View File

@@ -0,0 +1,174 @@
use llmx_protocol::models::FunctionCallOutputPayload;
use llmx_protocol::models::ResponseItem;
use llmx_protocol::protocol::TokenUsage;
use llmx_protocol::protocol::TokenUsageInfo;
use std::ops::Deref;
use crate::context_manager::normalize;
use crate::context_manager::truncate::format_output_for_model_body;
use crate::context_manager::truncate::globally_truncate_function_output_items;
/// Transcript of conversation history
#[derive(Debug, Clone, Default)]
pub(crate) struct ContextManager {
/// The oldest items are at the beginning of the vector.
items: Vec<ResponseItem>,
token_info: Option<TokenUsageInfo>,
}
impl ContextManager {
pub(crate) fn new() -> Self {
Self {
items: Vec::new(),
token_info: TokenUsageInfo::new_or_append(&None, &None, None),
}
}
pub(crate) fn token_info(&self) -> Option<TokenUsageInfo> {
self.token_info.clone()
}
pub(crate) fn set_token_usage_full(&mut self, context_window: i64) {
match &mut self.token_info {
Some(info) => info.fill_to_context_window(context_window),
None => {
self.token_info = Some(TokenUsageInfo::full_context_window(context_window));
}
}
}
/// `items` is ordered from oldest to newest.
pub(crate) fn record_items<I>(&mut self, items: I)
where
I: IntoIterator,
I::Item: std::ops::Deref<Target = ResponseItem>,
{
for item in items {
let item_ref = item.deref();
let is_ghost_snapshot = matches!(item_ref, ResponseItem::GhostSnapshot { .. });
if !is_api_message(item_ref) && !is_ghost_snapshot {
continue;
}
let processed = Self::process_item(&item);
self.items.push(processed);
}
}
pub(crate) fn get_history(&mut self) -> Vec<ResponseItem> {
self.normalize_history();
self.contents()
}
// Returns the history prepared for sending to the model.
// With extra response items filtered out and GhostCommits removed.
pub(crate) fn get_history_for_prompt(&mut self) -> Vec<ResponseItem> {
let mut history = self.get_history();
Self::remove_ghost_snapshots(&mut history);
history
}
pub(crate) fn remove_first_item(&mut self) {
if !self.items.is_empty() {
// Remove the oldest item (front of the list). Items are ordered from
// oldest → newest, so index 0 is the first entry recorded.
let removed = self.items.remove(0);
// If the removed item participates in a call/output pair, also remove
// its corresponding counterpart to keep the invariants intact without
// running a full normalization pass.
normalize::remove_corresponding_for(&mut self.items, &removed);
}
}
pub(crate) fn replace(&mut self, items: Vec<ResponseItem>) {
self.items = items;
}
pub(crate) fn update_token_info(
&mut self,
usage: &TokenUsage,
model_context_window: Option<i64>,
) {
self.token_info = TokenUsageInfo::new_or_append(
&self.token_info,
&Some(usage.clone()),
model_context_window,
);
}
/// This function enforces a couple of invariants on the in-memory history:
/// 1. every call (function/custom) has a corresponding output entry
/// 2. every output has a corresponding call entry
fn normalize_history(&mut self) {
// all function/tool calls must have a corresponding output
normalize::ensure_call_outputs_present(&mut self.items);
// all outputs must have a corresponding function/tool call
normalize::remove_orphan_outputs(&mut self.items);
}
/// Returns a clone of the contents in the transcript.
fn contents(&self) -> Vec<ResponseItem> {
self.items.clone()
}
fn remove_ghost_snapshots(items: &mut Vec<ResponseItem>) {
items.retain(|item| !matches!(item, ResponseItem::GhostSnapshot { .. }));
}
fn process_item(item: &ResponseItem) -> ResponseItem {
match item {
ResponseItem::FunctionCallOutput { call_id, output } => {
let truncated = format_output_for_model_body(output.content.as_str());
let truncated_items = output
.content_items
.as_ref()
.map(|items| globally_truncate_function_output_items(items));
ResponseItem::FunctionCallOutput {
call_id: call_id.clone(),
output: FunctionCallOutputPayload {
content: truncated,
content_items: truncated_items,
success: output.success,
},
}
}
ResponseItem::CustomToolCallOutput { call_id, output } => {
let truncated = format_output_for_model_body(output);
ResponseItem::CustomToolCallOutput {
call_id: call_id.clone(),
output: truncated,
}
}
ResponseItem::Message { .. }
| ResponseItem::Reasoning { .. }
| ResponseItem::LocalShellCall { .. }
| ResponseItem::FunctionCall { .. }
| ResponseItem::WebSearchCall { .. }
| ResponseItem::CustomToolCall { .. }
| ResponseItem::GhostSnapshot { .. }
| ResponseItem::Other => item.clone(),
}
}
}
/// API messages include every non-system item (user/assistant messages, reasoning,
/// tool calls, tool outputs, shell calls, and web-search calls).
fn is_api_message(message: &ResponseItem) -> bool {
match message {
ResponseItem::Message { role, .. } => role.as_str() != "system",
ResponseItem::FunctionCallOutput { .. }
| ResponseItem::FunctionCall { .. }
| ResponseItem::CustomToolCall { .. }
| ResponseItem::CustomToolCallOutput { .. }
| ResponseItem::LocalShellCall { .. }
| ResponseItem::Reasoning { .. }
| ResponseItem::WebSearchCall { .. } => true,
ResponseItem::GhostSnapshot { .. } => false,
ResponseItem::Other => false,
}
}
#[cfg(test)]
#[path = "history_tests.rs"]
mod tests;

View File

@@ -0,0 +1,841 @@
use super::*;
use crate::context_manager::truncate;
use llmx_git::GhostCommit;
use llmx_protocol::models::ContentItem;
use llmx_protocol::models::FunctionCallOutputContentItem;
use llmx_protocol::models::FunctionCallOutputPayload;
use llmx_protocol::models::LocalShellAction;
use llmx_protocol::models::LocalShellExecAction;
use llmx_protocol::models::LocalShellStatus;
use llmx_protocol::models::ReasoningItemContent;
use llmx_protocol::models::ReasoningItemReasoningSummary;
use pretty_assertions::assert_eq;
use regex_lite::Regex;
fn assistant_msg(text: &str) -> ResponseItem {
ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: text.to_string(),
}],
}
}
fn create_history_with_items(items: Vec<ResponseItem>) -> ContextManager {
let mut h = ContextManager::new();
h.record_items(items.iter());
h
}
fn user_msg(text: &str) -> ResponseItem {
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::OutputText {
text: text.to_string(),
}],
}
}
fn reasoning_msg(text: &str) -> ResponseItem {
ResponseItem::Reasoning {
id: String::new(),
summary: vec![ReasoningItemReasoningSummary::SummaryText {
text: "summary".to_string(),
}],
content: Some(vec![ReasoningItemContent::ReasoningText {
text: text.to_string(),
}]),
encrypted_content: None,
}
}
#[test]
fn filters_non_api_messages() {
let mut h = ContextManager::default();
// System message is not API messages; Other is ignored.
let system = ResponseItem::Message {
id: None,
role: "system".to_string(),
content: vec![ContentItem::OutputText {
text: "ignored".to_string(),
}],
};
let reasoning = reasoning_msg("thinking...");
h.record_items([&system, &reasoning, &ResponseItem::Other]);
// User and assistant should be retained.
let u = user_msg("hi");
let a = assistant_msg("hello");
h.record_items([&u, &a]);
let items = h.contents();
assert_eq!(
items,
vec![
ResponseItem::Reasoning {
id: String::new(),
summary: vec![ReasoningItemReasoningSummary::SummaryText {
text: "summary".to_string(),
}],
content: Some(vec![ReasoningItemContent::ReasoningText {
text: "thinking...".to_string(),
}]),
encrypted_content: None,
},
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::OutputText {
text: "hi".to_string()
}]
},
ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: "hello".to_string()
}]
}
]
);
}
#[test]
fn get_history_for_prompt_drops_ghost_commits() {
let items = vec![ResponseItem::GhostSnapshot {
ghost_commit: GhostCommit::new("ghost-1".to_string(), None, Vec::new(), Vec::new()),
}];
let mut history = create_history_with_items(items);
let filtered = history.get_history_for_prompt();
assert_eq!(filtered, vec![]);
}
#[test]
fn remove_first_item_removes_matching_output_for_function_call() {
let items = vec![
ResponseItem::FunctionCall {
id: None,
name: "do_it".to_string(),
arguments: "{}".to_string(),
call_id: "call-1".to_string(),
},
ResponseItem::FunctionCallOutput {
call_id: "call-1".to_string(),
output: FunctionCallOutputPayload {
content: "ok".to_string(),
..Default::default()
},
},
];
let mut h = create_history_with_items(items);
h.remove_first_item();
assert_eq!(h.contents(), vec![]);
}
#[test]
fn remove_first_item_removes_matching_call_for_output() {
let items = vec![
ResponseItem::FunctionCallOutput {
call_id: "call-2".to_string(),
output: FunctionCallOutputPayload {
content: "ok".to_string(),
..Default::default()
},
},
ResponseItem::FunctionCall {
id: None,
name: "do_it".to_string(),
arguments: "{}".to_string(),
call_id: "call-2".to_string(),
},
];
let mut h = create_history_with_items(items);
h.remove_first_item();
assert_eq!(h.contents(), vec![]);
}
#[test]
fn remove_first_item_handles_local_shell_pair() {
let items = vec![
ResponseItem::LocalShellCall {
id: None,
call_id: Some("call-3".to_string()),
status: LocalShellStatus::Completed,
action: LocalShellAction::Exec(LocalShellExecAction {
command: vec!["echo".to_string(), "hi".to_string()],
timeout_ms: None,
working_directory: None,
env: None,
user: None,
}),
},
ResponseItem::FunctionCallOutput {
call_id: "call-3".to_string(),
output: FunctionCallOutputPayload {
content: "ok".to_string(),
..Default::default()
},
},
];
let mut h = create_history_with_items(items);
h.remove_first_item();
assert_eq!(h.contents(), vec![]);
}
#[test]
fn remove_first_item_handles_custom_tool_pair() {
let items = vec![
ResponseItem::CustomToolCall {
id: None,
status: None,
call_id: "tool-1".to_string(),
name: "my_tool".to_string(),
input: "{}".to_string(),
},
ResponseItem::CustomToolCallOutput {
call_id: "tool-1".to_string(),
output: "ok".to_string(),
},
];
let mut h = create_history_with_items(items);
h.remove_first_item();
assert_eq!(h.contents(), vec![]);
}
#[test]
fn normalization_retains_local_shell_outputs() {
let items = vec![
ResponseItem::LocalShellCall {
id: None,
call_id: Some("shell-1".to_string()),
status: LocalShellStatus::Completed,
action: LocalShellAction::Exec(LocalShellExecAction {
command: vec!["echo".to_string(), "hi".to_string()],
timeout_ms: None,
working_directory: None,
env: None,
user: None,
}),
},
ResponseItem::FunctionCallOutput {
call_id: "shell-1".to_string(),
output: FunctionCallOutputPayload {
content: "ok".to_string(),
..Default::default()
},
},
];
let mut history = create_history_with_items(items.clone());
let normalized = history.get_history();
assert_eq!(normalized, items);
}
#[test]
fn record_items_truncates_function_call_output_content() {
let mut history = ContextManager::new();
let long_line = "a very long line to trigger truncation\n";
let long_output = long_line.repeat(2_500);
let item = ResponseItem::FunctionCallOutput {
call_id: "call-100".to_string(),
output: FunctionCallOutputPayload {
content: long_output.clone(),
success: Some(true),
..Default::default()
},
};
history.record_items([&item]);
assert_eq!(history.items.len(), 1);
match &history.items[0] {
ResponseItem::FunctionCallOutput { output, .. } => {
assert_ne!(output.content, long_output);
assert!(
output.content.starts_with("Total output lines:"),
"expected truncated summary, got {}",
output.content
);
}
other => panic!("unexpected history item: {other:?}"),
}
}
#[test]
fn record_items_truncates_custom_tool_call_output_content() {
let mut history = ContextManager::new();
let line = "custom output that is very long\n";
let long_output = line.repeat(2_500);
let item = ResponseItem::CustomToolCallOutput {
call_id: "tool-200".to_string(),
output: long_output.clone(),
};
history.record_items([&item]);
assert_eq!(history.items.len(), 1);
match &history.items[0] {
ResponseItem::CustomToolCallOutput { output, .. } => {
assert_ne!(output, &long_output);
assert!(
output.starts_with("Total output lines:"),
"expected truncated summary, got {output}"
);
}
other => panic!("unexpected history item: {other:?}"),
}
}
fn assert_truncated_message_matches(message: &str, line: &str, total_lines: usize) {
let pattern = truncated_message_pattern(line, total_lines);
let regex = Regex::new(&pattern).unwrap_or_else(|err| {
panic!("failed to compile regex {pattern}: {err}");
});
let captures = regex
.captures(message)
.unwrap_or_else(|| panic!("message failed to match pattern {pattern}: {message}"));
let body = captures
.name("body")
.expect("missing body capture")
.as_str();
assert!(
body.len() <= truncate::MODEL_FORMAT_MAX_BYTES,
"body exceeds byte limit: {} bytes",
body.len()
);
}
fn truncated_message_pattern(line: &str, total_lines: usize) -> String {
let head_take = truncate::MODEL_FORMAT_HEAD_LINES.min(total_lines);
let tail_take = truncate::MODEL_FORMAT_TAIL_LINES.min(total_lines.saturating_sub(head_take));
let omitted = total_lines.saturating_sub(head_take + tail_take);
let escaped_line = regex_lite::escape(line);
if omitted == 0 {
return format!(
r"(?s)^Total output lines: {total_lines}\n\n(?P<body>{escaped_line}.*\n\[\.{{3}} output truncated to fit {max_bytes} bytes \.{{3}}]\n\n.*)$",
max_bytes = truncate::MODEL_FORMAT_MAX_BYTES,
);
}
format!(
r"(?s)^Total output lines: {total_lines}\n\n(?P<body>{escaped_line}.*\n\[\.{{3}} omitted {omitted} of {total_lines} lines \.{{3}}]\n\n.*)$",
)
}
#[test]
fn format_exec_output_truncates_large_error() {
let line = "very long execution error line that should trigger truncation\n";
let large_error = line.repeat(2_500); // way beyond both byte and line limits
let truncated = truncate::format_output_for_model_body(&large_error);
let total_lines = large_error.lines().count();
assert_truncated_message_matches(&truncated, line, total_lines);
assert_ne!(truncated, large_error);
}
#[test]
fn format_exec_output_marks_byte_truncation_without_omitted_lines() {
let long_line = "a".repeat(truncate::MODEL_FORMAT_MAX_BYTES + 50);
let truncated = truncate::format_output_for_model_body(&long_line);
assert_ne!(truncated, long_line);
let marker_line = format!(
"[... output truncated to fit {} bytes ...]",
truncate::MODEL_FORMAT_MAX_BYTES
);
assert!(
truncated.contains(&marker_line),
"missing byte truncation marker: {truncated}"
);
assert!(
!truncated.contains("omitted"),
"line omission marker should not appear when no lines were dropped: {truncated}"
);
}
#[test]
fn format_exec_output_returns_original_when_within_limits() {
let content = "example output\n".repeat(10);
assert_eq!(truncate::format_output_for_model_body(&content), content);
}
#[test]
fn format_exec_output_reports_omitted_lines_and_keeps_head_and_tail() {
let total_lines = truncate::MODEL_FORMAT_MAX_LINES + 100;
let content: String = (0..total_lines)
.map(|idx| format!("line-{idx}\n"))
.collect();
let truncated = truncate::format_output_for_model_body(&content);
let omitted = total_lines - truncate::MODEL_FORMAT_MAX_LINES;
let expected_marker = format!("[... omitted {omitted} of {total_lines} lines ...]");
assert!(
truncated.contains(&expected_marker),
"missing omitted marker: {truncated}"
);
assert!(
truncated.contains("line-0\n"),
"expected head line to remain: {truncated}"
);
let last_line = format!("line-{}\n", total_lines - 1);
assert!(
truncated.contains(&last_line),
"expected tail line to remain: {truncated}"
);
}
#[test]
fn format_exec_output_prefers_line_marker_when_both_limits_exceeded() {
let total_lines = truncate::MODEL_FORMAT_MAX_LINES + 42;
let long_line = "x".repeat(256);
let content: String = (0..total_lines)
.map(|idx| format!("line-{idx}-{long_line}\n"))
.collect();
let truncated = truncate::format_output_for_model_body(&content);
assert!(
truncated.contains("[... omitted 42 of 298 lines ...]"),
"expected omitted marker when line count exceeds limit: {truncated}"
);
assert!(
!truncated.contains("output truncated to fit"),
"line omission marker should take precedence over byte marker: {truncated}"
);
}
#[test]
fn truncates_across_multiple_under_limit_texts_and_reports_omitted() {
// Arrange: several text items, none exceeding per-item limit, but total exceeds budget.
let budget = truncate::MODEL_FORMAT_MAX_BYTES;
let t1_len = (budget / 2).saturating_sub(10);
let t2_len = (budget / 2).saturating_sub(10);
let remaining_after_t1_t2 = budget.saturating_sub(t1_len + t2_len);
let t3_len = 50; // gets truncated to remaining_after_t1_t2
let t4_len = 5; // omitted
let t5_len = 7; // omitted
let t1 = "a".repeat(t1_len);
let t2 = "b".repeat(t2_len);
let t3 = "c".repeat(t3_len);
let t4 = "d".repeat(t4_len);
let t5 = "e".repeat(t5_len);
let item = ResponseItem::FunctionCallOutput {
call_id: "call-omit".to_string(),
output: FunctionCallOutputPayload {
content: "irrelevant".to_string(),
content_items: Some(vec![
FunctionCallOutputContentItem::InputText { text: t1 },
FunctionCallOutputContentItem::InputText { text: t2 },
FunctionCallOutputContentItem::InputImage {
image_url: "img:mid".to_string(),
},
FunctionCallOutputContentItem::InputText { text: t3 },
FunctionCallOutputContentItem::InputText { text: t4 },
FunctionCallOutputContentItem::InputText { text: t5 },
]),
success: Some(true),
},
};
let mut history = ContextManager::new();
history.record_items([&item]);
assert_eq!(history.items.len(), 1);
let json = serde_json::to_value(&history.items[0]).expect("serialize to json");
let output = json
.get("output")
.expect("output field")
.as_array()
.expect("array output");
// Expect: t1 (full), t2 (full), image, t3 (truncated), summary mentioning 2 omitted.
assert_eq!(output.len(), 5);
let first = output[0].as_object().expect("first obj");
assert_eq!(first.get("type").unwrap(), "input_text");
let first_text = first.get("text").unwrap().as_str().unwrap();
assert_eq!(first_text.len(), t1_len);
let second = output[1].as_object().expect("second obj");
assert_eq!(second.get("type").unwrap(), "input_text");
let second_text = second.get("text").unwrap().as_str().unwrap();
assert_eq!(second_text.len(), t2_len);
assert_eq!(
output[2],
serde_json::json!({"type": "input_image", "image_url": "img:mid"})
);
let fourth = output[3].as_object().expect("fourth obj");
assert_eq!(fourth.get("type").unwrap(), "input_text");
let fourth_text = fourth.get("text").unwrap().as_str().unwrap();
assert_eq!(fourth_text.len(), remaining_after_t1_t2);
let summary = output[4].as_object().expect("summary obj");
assert_eq!(summary.get("type").unwrap(), "input_text");
let summary_text = summary.get("text").unwrap().as_str().unwrap();
assert!(summary_text.contains("omitted 2 text items"));
}
//TODO(aibrahim): run CI in release mode.
#[cfg(not(debug_assertions))]
#[test]
fn normalize_adds_missing_output_for_function_call() {
let items = vec![ResponseItem::FunctionCall {
id: None,
name: "do_it".to_string(),
arguments: "{}".to_string(),
call_id: "call-x".to_string(),
}];
let mut h = create_history_with_items(items);
h.normalize_history();
assert_eq!(
h.contents(),
vec![
ResponseItem::FunctionCall {
id: None,
name: "do_it".to_string(),
arguments: "{}".to_string(),
call_id: "call-x".to_string(),
},
ResponseItem::FunctionCallOutput {
call_id: "call-x".to_string(),
output: FunctionCallOutputPayload {
content: "aborted".to_string(),
..Default::default()
},
},
]
);
}
#[cfg(not(debug_assertions))]
#[test]
fn normalize_adds_missing_output_for_custom_tool_call() {
let items = vec![ResponseItem::CustomToolCall {
id: None,
status: None,
call_id: "tool-x".to_string(),
name: "custom".to_string(),
input: "{}".to_string(),
}];
let mut h = create_history_with_items(items);
h.normalize_history();
assert_eq!(
h.contents(),
vec![
ResponseItem::CustomToolCall {
id: None,
status: None,
call_id: "tool-x".to_string(),
name: "custom".to_string(),
input: "{}".to_string(),
},
ResponseItem::CustomToolCallOutput {
call_id: "tool-x".to_string(),
output: "aborted".to_string(),
},
]
);
}
#[cfg(not(debug_assertions))]
#[test]
fn normalize_adds_missing_output_for_local_shell_call_with_id() {
let items = vec![ResponseItem::LocalShellCall {
id: None,
call_id: Some("shell-1".to_string()),
status: LocalShellStatus::Completed,
action: LocalShellAction::Exec(LocalShellExecAction {
command: vec!["echo".to_string(), "hi".to_string()],
timeout_ms: None,
working_directory: None,
env: None,
user: None,
}),
}];
let mut h = create_history_with_items(items);
h.normalize_history();
assert_eq!(
h.contents(),
vec![
ResponseItem::LocalShellCall {
id: None,
call_id: Some("shell-1".to_string()),
status: LocalShellStatus::Completed,
action: LocalShellAction::Exec(LocalShellExecAction {
command: vec!["echo".to_string(), "hi".to_string()],
timeout_ms: None,
working_directory: None,
env: None,
user: None,
}),
},
ResponseItem::FunctionCallOutput {
call_id: "shell-1".to_string(),
output: FunctionCallOutputPayload {
content: "aborted".to_string(),
..Default::default()
},
},
]
);
}
#[cfg(not(debug_assertions))]
#[test]
fn normalize_removes_orphan_function_call_output() {
let items = vec![ResponseItem::FunctionCallOutput {
call_id: "orphan-1".to_string(),
output: FunctionCallOutputPayload {
content: "ok".to_string(),
..Default::default()
},
}];
let mut h = create_history_with_items(items);
h.normalize_history();
assert_eq!(h.contents(), vec![]);
}
#[cfg(not(debug_assertions))]
#[test]
fn normalize_removes_orphan_custom_tool_call_output() {
let items = vec![ResponseItem::CustomToolCallOutput {
call_id: "orphan-2".to_string(),
output: "ok".to_string(),
}];
let mut h = create_history_with_items(items);
h.normalize_history();
assert_eq!(h.contents(), vec![]);
}
#[cfg(not(debug_assertions))]
#[test]
fn normalize_mixed_inserts_and_removals() {
let items = vec![
// Will get an inserted output
ResponseItem::FunctionCall {
id: None,
name: "f1".to_string(),
arguments: "{}".to_string(),
call_id: "c1".to_string(),
},
// Orphan output that should be removed
ResponseItem::FunctionCallOutput {
call_id: "c2".to_string(),
output: FunctionCallOutputPayload {
content: "ok".to_string(),
..Default::default()
},
},
// Will get an inserted custom tool output
ResponseItem::CustomToolCall {
id: None,
status: None,
call_id: "t1".to_string(),
name: "tool".to_string(),
input: "{}".to_string(),
},
// Local shell call also gets an inserted function call output
ResponseItem::LocalShellCall {
id: None,
call_id: Some("s1".to_string()),
status: LocalShellStatus::Completed,
action: LocalShellAction::Exec(LocalShellExecAction {
command: vec!["echo".to_string()],
timeout_ms: None,
working_directory: None,
env: None,
user: None,
}),
},
];
let mut h = create_history_with_items(items);
h.normalize_history();
assert_eq!(
h.contents(),
vec![
ResponseItem::FunctionCall {
id: None,
name: "f1".to_string(),
arguments: "{}".to_string(),
call_id: "c1".to_string(),
},
ResponseItem::FunctionCallOutput {
call_id: "c1".to_string(),
output: FunctionCallOutputPayload {
content: "aborted".to_string(),
..Default::default()
},
},
ResponseItem::CustomToolCall {
id: None,
status: None,
call_id: "t1".to_string(),
name: "tool".to_string(),
input: "{}".to_string(),
},
ResponseItem::CustomToolCallOutput {
call_id: "t1".to_string(),
output: "aborted".to_string(),
},
ResponseItem::LocalShellCall {
id: None,
call_id: Some("s1".to_string()),
status: LocalShellStatus::Completed,
action: LocalShellAction::Exec(LocalShellExecAction {
command: vec!["echo".to_string()],
timeout_ms: None,
working_directory: None,
env: None,
user: None,
}),
},
ResponseItem::FunctionCallOutput {
call_id: "s1".to_string(),
output: FunctionCallOutputPayload {
content: "aborted".to_string(),
..Default::default()
},
},
]
);
}
// In debug builds we panic on normalization errors instead of silently fixing them.
#[cfg(debug_assertions)]
#[test]
#[should_panic]
fn normalize_adds_missing_output_for_function_call_panics_in_debug() {
let items = vec![ResponseItem::FunctionCall {
id: None,
name: "do_it".to_string(),
arguments: "{}".to_string(),
call_id: "call-x".to_string(),
}];
let mut h = create_history_with_items(items);
h.normalize_history();
}
#[cfg(debug_assertions)]
#[test]
#[should_panic]
fn normalize_adds_missing_output_for_custom_tool_call_panics_in_debug() {
let items = vec![ResponseItem::CustomToolCall {
id: None,
status: None,
call_id: "tool-x".to_string(),
name: "custom".to_string(),
input: "{}".to_string(),
}];
let mut h = create_history_with_items(items);
h.normalize_history();
}
#[cfg(debug_assertions)]
#[test]
#[should_panic]
fn normalize_adds_missing_output_for_local_shell_call_with_id_panics_in_debug() {
let items = vec![ResponseItem::LocalShellCall {
id: None,
call_id: Some("shell-1".to_string()),
status: LocalShellStatus::Completed,
action: LocalShellAction::Exec(LocalShellExecAction {
command: vec!["echo".to_string(), "hi".to_string()],
timeout_ms: None,
working_directory: None,
env: None,
user: None,
}),
}];
let mut h = create_history_with_items(items);
h.normalize_history();
}
#[cfg(debug_assertions)]
#[test]
#[should_panic]
fn normalize_removes_orphan_function_call_output_panics_in_debug() {
let items = vec![ResponseItem::FunctionCallOutput {
call_id: "orphan-1".to_string(),
output: FunctionCallOutputPayload {
content: "ok".to_string(),
..Default::default()
},
}];
let mut h = create_history_with_items(items);
h.normalize_history();
}
#[cfg(debug_assertions)]
#[test]
#[should_panic]
fn normalize_removes_orphan_custom_tool_call_output_panics_in_debug() {
let items = vec![ResponseItem::CustomToolCallOutput {
call_id: "orphan-2".to_string(),
output: "ok".to_string(),
}];
let mut h = create_history_with_items(items);
h.normalize_history();
}
#[cfg(debug_assertions)]
#[test]
#[should_panic]
fn normalize_mixed_inserts_and_removals_panics_in_debug() {
let items = vec![
ResponseItem::FunctionCall {
id: None,
name: "f1".to_string(),
arguments: "{}".to_string(),
call_id: "c1".to_string(),
},
ResponseItem::FunctionCallOutput {
call_id: "c2".to_string(),
output: FunctionCallOutputPayload {
content: "ok".to_string(),
..Default::default()
},
},
ResponseItem::CustomToolCall {
id: None,
status: None,
call_id: "t1".to_string(),
name: "tool".to_string(),
input: "{}".to_string(),
},
ResponseItem::LocalShellCall {
id: None,
call_id: Some("s1".to_string()),
status: LocalShellStatus::Completed,
action: LocalShellAction::Exec(LocalShellExecAction {
command: vec!["echo".to_string()],
timeout_ms: None,
working_directory: None,
env: None,
user: None,
}),
},
];
let mut h = create_history_with_items(items);
h.normalize_history();
}

View File

@@ -0,0 +1,6 @@
mod history;
mod normalize;
mod truncate;
pub(crate) use history::ContextManager;
pub(crate) use truncate::format_output_for_model_body;

View File

@@ -0,0 +1,213 @@
use std::collections::HashSet;
use llmx_protocol::models::FunctionCallOutputPayload;
use llmx_protocol::models::ResponseItem;
use crate::util::error_or_panic;
pub(crate) fn ensure_call_outputs_present(items: &mut Vec<ResponseItem>) {
// Collect synthetic outputs to insert immediately after their calls.
// Store the insertion position (index of call) alongside the item so
// we can insert in reverse order and avoid index shifting.
let mut missing_outputs_to_insert: Vec<(usize, ResponseItem)> = Vec::new();
for (idx, item) in items.iter().enumerate() {
match item {
ResponseItem::FunctionCall { call_id, .. } => {
let has_output = items.iter().any(|i| match i {
ResponseItem::FunctionCallOutput {
call_id: existing, ..
} => existing == call_id,
_ => false,
});
if !has_output {
error_or_panic(format!(
"Function call output is missing for call id: {call_id}"
));
missing_outputs_to_insert.push((
idx,
ResponseItem::FunctionCallOutput {
call_id: call_id.clone(),
output: FunctionCallOutputPayload {
content: "aborted".to_string(),
..Default::default()
},
},
));
}
}
ResponseItem::CustomToolCall { call_id, .. } => {
let has_output = items.iter().any(|i| match i {
ResponseItem::CustomToolCallOutput {
call_id: existing, ..
} => existing == call_id,
_ => false,
});
if !has_output {
error_or_panic(format!(
"Custom tool call output is missing for call id: {call_id}"
));
missing_outputs_to_insert.push((
idx,
ResponseItem::CustomToolCallOutput {
call_id: call_id.clone(),
output: "aborted".to_string(),
},
));
}
}
// LocalShellCall is represented in upstream streams by a FunctionCallOutput
ResponseItem::LocalShellCall { call_id, .. } => {
if let Some(call_id) = call_id.as_ref() {
let has_output = items.iter().any(|i| match i {
ResponseItem::FunctionCallOutput {
call_id: existing, ..
} => existing == call_id,
_ => false,
});
if !has_output {
error_or_panic(format!(
"Local shell call output is missing for call id: {call_id}"
));
missing_outputs_to_insert.push((
idx,
ResponseItem::FunctionCallOutput {
call_id: call_id.clone(),
output: FunctionCallOutputPayload {
content: "aborted".to_string(),
..Default::default()
},
},
));
}
}
}
_ => {}
}
}
// Insert synthetic outputs in reverse index order to avoid re-indexing.
for (idx, output_item) in missing_outputs_to_insert.into_iter().rev() {
items.insert(idx + 1, output_item);
}
}
pub(crate) fn remove_orphan_outputs(items: &mut Vec<ResponseItem>) {
let function_call_ids: HashSet<String> = items
.iter()
.filter_map(|i| match i {
ResponseItem::FunctionCall { call_id, .. } => Some(call_id.clone()),
_ => None,
})
.collect();
let local_shell_call_ids: HashSet<String> = items
.iter()
.filter_map(|i| match i {
ResponseItem::LocalShellCall {
call_id: Some(call_id),
..
} => Some(call_id.clone()),
_ => None,
})
.collect();
let custom_tool_call_ids: HashSet<String> = items
.iter()
.filter_map(|i| match i {
ResponseItem::CustomToolCall { call_id, .. } => Some(call_id.clone()),
_ => None,
})
.collect();
items.retain(|item| match item {
ResponseItem::FunctionCallOutput { call_id, .. } => {
let has_match =
function_call_ids.contains(call_id) || local_shell_call_ids.contains(call_id);
if !has_match {
error_or_panic(format!(
"Orphan function call output for call id: {call_id}"
));
}
has_match
}
ResponseItem::CustomToolCallOutput { call_id, .. } => {
let has_match = custom_tool_call_ids.contains(call_id);
if !has_match {
error_or_panic(format!(
"Orphan custom tool call output for call id: {call_id}"
));
}
has_match
}
_ => true,
});
}
pub(crate) fn remove_corresponding_for(items: &mut Vec<ResponseItem>, item: &ResponseItem) {
match item {
ResponseItem::FunctionCall { call_id, .. } => {
remove_first_matching(items, |i| {
matches!(
i,
ResponseItem::FunctionCallOutput {
call_id: existing, ..
} if existing == call_id
)
});
}
ResponseItem::FunctionCallOutput { call_id, .. } => {
if let Some(pos) = items.iter().position(|i| {
matches!(i, ResponseItem::FunctionCall { call_id: existing, .. } if existing == call_id)
}) {
items.remove(pos);
} else if let Some(pos) = items.iter().position(|i| {
matches!(i, ResponseItem::LocalShellCall { call_id: Some(existing), .. } if existing == call_id)
}) {
items.remove(pos);
}
}
ResponseItem::CustomToolCall { call_id, .. } => {
remove_first_matching(items, |i| {
matches!(
i,
ResponseItem::CustomToolCallOutput {
call_id: existing, ..
} if existing == call_id
)
});
}
ResponseItem::CustomToolCallOutput { call_id, .. } => {
remove_first_matching(
items,
|i| matches!(i, ResponseItem::CustomToolCall { call_id: existing, .. } if existing == call_id),
);
}
ResponseItem::LocalShellCall {
call_id: Some(call_id),
..
} => {
remove_first_matching(items, |i| {
matches!(
i,
ResponseItem::FunctionCallOutput {
call_id: existing, ..
} if existing == call_id
)
});
}
_ => {}
}
}
fn remove_first_matching<F>(items: &mut Vec<ResponseItem>, predicate: F)
where
F: Fn(&ResponseItem) -> bool,
{
if let Some(pos) = items.iter().position(predicate) {
items.remove(pos);
}
}

View File

@@ -0,0 +1,128 @@
use llmx_protocol::models::FunctionCallOutputContentItem;
use llmx_utils_string::take_bytes_at_char_boundary;
use llmx_utils_string::take_last_bytes_at_char_boundary;
// Model-formatting limits: clients get full streams; only content sent to the model is truncated.
pub(crate) const MODEL_FORMAT_MAX_BYTES: usize = 10 * 1024; // 10 KiB
pub(crate) const MODEL_FORMAT_MAX_LINES: usize = 256; // lines
pub(crate) const MODEL_FORMAT_HEAD_LINES: usize = MODEL_FORMAT_MAX_LINES / 2;
pub(crate) const MODEL_FORMAT_TAIL_LINES: usize = MODEL_FORMAT_MAX_LINES - MODEL_FORMAT_HEAD_LINES; // 128
pub(crate) const MODEL_FORMAT_HEAD_BYTES: usize = MODEL_FORMAT_MAX_BYTES / 2;
pub(crate) fn globally_truncate_function_output_items(
items: &[FunctionCallOutputContentItem],
) -> Vec<FunctionCallOutputContentItem> {
let mut out: Vec<FunctionCallOutputContentItem> = Vec::with_capacity(items.len());
let mut remaining = MODEL_FORMAT_MAX_BYTES;
let mut omitted_text_items = 0usize;
for it in items {
match it {
FunctionCallOutputContentItem::InputText { text } => {
if remaining == 0 {
omitted_text_items += 1;
continue;
}
let len = text.len();
if len <= remaining {
out.push(FunctionCallOutputContentItem::InputText { text: text.clone() });
remaining -= len;
} else {
let slice = take_bytes_at_char_boundary(text, remaining);
if !slice.is_empty() {
out.push(FunctionCallOutputContentItem::InputText {
text: slice.to_string(),
});
}
remaining = 0;
}
}
// todo(aibrahim): handle input images; resize
FunctionCallOutputContentItem::InputImage { image_url } => {
out.push(FunctionCallOutputContentItem::InputImage {
image_url: image_url.clone(),
});
}
}
}
if omitted_text_items > 0 {
out.push(FunctionCallOutputContentItem::InputText {
text: format!("[omitted {omitted_text_items} text items ...]"),
});
}
out
}
pub(crate) fn format_output_for_model_body(content: &str) -> String {
// Head+tail truncation for the model: show the beginning and end with an elision.
// Clients still receive full streams; only this formatted summary is capped.
let total_lines = content.lines().count();
if content.len() <= MODEL_FORMAT_MAX_BYTES && total_lines <= MODEL_FORMAT_MAX_LINES {
return content.to_string();
}
let output = truncate_formatted_exec_output(content, total_lines);
format!("Total output lines: {total_lines}\n\n{output}")
}
fn truncate_formatted_exec_output(content: &str, total_lines: usize) -> String {
let segments: Vec<&str> = content.split_inclusive('\n').collect();
let head_take = MODEL_FORMAT_HEAD_LINES.min(segments.len());
let tail_take = MODEL_FORMAT_TAIL_LINES.min(segments.len().saturating_sub(head_take));
let omitted = segments.len().saturating_sub(head_take + tail_take);
let head_slice_end: usize = segments
.iter()
.take(head_take)
.map(|segment| segment.len())
.sum();
let tail_slice_start: usize = if tail_take == 0 {
content.len()
} else {
content.len()
- segments
.iter()
.rev()
.take(tail_take)
.map(|segment| segment.len())
.sum::<usize>()
};
let head_slice = &content[..head_slice_end];
let tail_slice = &content[tail_slice_start..];
let truncated_by_bytes = content.len() > MODEL_FORMAT_MAX_BYTES;
// this is a bit wrong. We are counting metadata lines and not just shell output lines.
let marker = if omitted > 0 {
Some(format!(
"\n[... omitted {omitted} of {total_lines} lines ...]\n\n"
))
} else if truncated_by_bytes {
Some(format!(
"\n[... output truncated to fit {MODEL_FORMAT_MAX_BYTES} bytes ...]\n\n"
))
} else {
None
};
let marker_len = marker.as_ref().map_or(0, String::len);
let base_head_budget = MODEL_FORMAT_HEAD_BYTES.min(MODEL_FORMAT_MAX_BYTES);
let head_budget = base_head_budget.min(MODEL_FORMAT_MAX_BYTES.saturating_sub(marker_len));
let head_part = take_bytes_at_char_boundary(head_slice, head_budget);
let mut result = String::with_capacity(MODEL_FORMAT_MAX_BYTES.min(content.len()));
result.push_str(head_part);
if let Some(marker_text) = marker.as_ref() {
result.push_str(marker_text);
}
let remaining = MODEL_FORMAT_MAX_BYTES.saturating_sub(result.len());
if remaining == 0 {
return result;
}
let tail_part = take_last_bytes_at_char_boundary(tail_slice, remaining);
result.push_str(tail_part);
result
}

View File

@@ -0,0 +1,339 @@
use crate::AuthManager;
use crate::LlmxAuth;
use crate::config::Config;
use crate::error::LlmxErr;
use crate::error::Result as LlmxResult;
use crate::llmx::INITIAL_SUBMIT_ID;
use crate::llmx::Llmx;
use crate::llmx::LlmxSpawnOk;
use crate::llmx_conversation::LlmxConversation;
use crate::protocol::Event;
use crate::protocol::EventMsg;
use crate::protocol::SessionConfiguredEvent;
use crate::rollout::RolloutRecorder;
use llmx_protocol::ConversationId;
use llmx_protocol::items::TurnItem;
use llmx_protocol::models::ResponseItem;
use llmx_protocol::protocol::InitialHistory;
use llmx_protocol::protocol::RolloutItem;
use llmx_protocol::protocol::SessionSource;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Represents a newly created Llmx conversation, including the first event
/// (which is [`EventMsg::SessionConfigured`]).
pub struct NewConversation {
pub conversation_id: ConversationId,
pub conversation: Arc<LlmxConversation>,
pub session_configured: SessionConfiguredEvent,
}
/// [`ConversationManager`] is responsible for creating conversations and
/// maintaining them in memory.
pub struct ConversationManager {
conversations: Arc<RwLock<HashMap<ConversationId, Arc<LlmxConversation>>>>,
auth_manager: Arc<AuthManager>,
session_source: SessionSource,
}
impl ConversationManager {
pub fn new(auth_manager: Arc<AuthManager>, session_source: SessionSource) -> Self {
Self {
conversations: Arc::new(RwLock::new(HashMap::new())),
auth_manager,
session_source,
}
}
/// Construct with a dummy AuthManager containing the provided LlmxAuth.
/// Used for integration tests: should not be used by ordinary business logic.
pub fn with_auth(auth: LlmxAuth) -> Self {
Self::new(
crate::AuthManager::from_auth_for_testing(auth),
SessionSource::Exec,
)
}
pub async fn new_conversation(&self, config: Config) -> LlmxResult<NewConversation> {
self.spawn_conversation(config, self.auth_manager.clone())
.await
}
async fn spawn_conversation(
&self,
config: Config,
auth_manager: Arc<AuthManager>,
) -> LlmxResult<NewConversation> {
let LlmxSpawnOk {
llmx,
conversation_id,
} = Llmx::spawn(
config,
auth_manager,
InitialHistory::New,
self.session_source.clone(),
)
.await?;
self.finalize_spawn(llmx, conversation_id).await
}
async fn finalize_spawn(
&self,
llmx: Llmx,
conversation_id: ConversationId,
) -> LlmxResult<NewConversation> {
// The first event must be `SessionInitialized`. Validate and forward it
// to the caller so that they can display it in the conversation
// history.
let event = llmx.next_event().await?;
let session_configured = match event {
Event {
id,
msg: EventMsg::SessionConfigured(session_configured),
} if id == INITIAL_SUBMIT_ID => session_configured,
_ => {
return Err(LlmxErr::SessionConfiguredNotFirstEvent);
}
};
let conversation = Arc::new(LlmxConversation::new(
llmx,
session_configured.rollout_path.clone(),
));
self.conversations
.write()
.await
.insert(conversation_id, conversation.clone());
Ok(NewConversation {
conversation_id,
conversation,
session_configured,
})
}
pub async fn get_conversation(
&self,
conversation_id: ConversationId,
) -> LlmxResult<Arc<LlmxConversation>> {
let conversations = self.conversations.read().await;
conversations
.get(&conversation_id)
.cloned()
.ok_or_else(|| LlmxErr::ConversationNotFound(conversation_id))
}
pub async fn resume_conversation_from_rollout(
&self,
config: Config,
rollout_path: PathBuf,
auth_manager: Arc<AuthManager>,
) -> LlmxResult<NewConversation> {
let initial_history = RolloutRecorder::get_rollout_history(&rollout_path).await?;
self.resume_conversation_with_history(config, initial_history, auth_manager)
.await
}
pub async fn resume_conversation_with_history(
&self,
config: Config,
initial_history: InitialHistory,
auth_manager: Arc<AuthManager>,
) -> LlmxResult<NewConversation> {
let LlmxSpawnOk {
llmx,
conversation_id,
} = Llmx::spawn(
config,
auth_manager,
initial_history,
self.session_source.clone(),
)
.await?;
self.finalize_spawn(llmx, conversation_id).await
}
/// Removes the conversation from the manager's internal map, though the
/// conversation is stored as `Arc<LlmxConversation>`, it is possible that
/// other references to it exist elsewhere. Returns the conversation if the
/// conversation was found and removed.
pub async fn remove_conversation(
&self,
conversation_id: &ConversationId,
) -> Option<Arc<LlmxConversation>> {
self.conversations.write().await.remove(conversation_id)
}
/// Fork an existing conversation by taking messages up to the given position
/// (not including the message at the given position) and starting a new
/// conversation with identical configuration (unless overridden by the
/// caller's `config`). The new conversation will have a fresh id.
pub async fn fork_conversation(
&self,
nth_user_message: usize,
config: Config,
path: PathBuf,
) -> LlmxResult<NewConversation> {
// Compute the prefix up to the cut point.
let history = RolloutRecorder::get_rollout_history(&path).await?;
let history = truncate_before_nth_user_message(history, nth_user_message);
// Spawn a new conversation with the computed initial history.
let auth_manager = self.auth_manager.clone();
let LlmxSpawnOk {
llmx,
conversation_id,
} = Llmx::spawn(config, auth_manager, history, self.session_source.clone()).await?;
self.finalize_spawn(llmx, conversation_id).await
}
}
/// Return a prefix of `items` obtained by cutting strictly before the nth user message
/// (0-based) and all items that follow it.
fn truncate_before_nth_user_message(history: InitialHistory, n: usize) -> InitialHistory {
// Work directly on rollout items, and cut the vector at the nth user message input.
let items: Vec<RolloutItem> = history.get_rollout_items();
// Find indices of user message inputs in rollout order.
let mut user_positions: Vec<usize> = Vec::new();
for (idx, item) in items.iter().enumerate() {
if let RolloutItem::ResponseItem(item @ ResponseItem::Message { .. }) = item
&& matches!(
crate::event_mapping::parse_turn_item(item),
Some(TurnItem::UserMessage(_))
)
{
user_positions.push(idx);
}
}
// If fewer than or equal to n user messages exist, treat as empty (out of range).
if user_positions.len() <= n {
return InitialHistory::New;
}
// Cut strictly before the nth user message (do not keep the nth itself).
let cut_idx = user_positions[n];
let rolled: Vec<RolloutItem> = items.into_iter().take(cut_idx).collect();
if rolled.is_empty() {
InitialHistory::New
} else {
InitialHistory::Forked(rolled)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::llmx::make_session_and_context;
use assert_matches::assert_matches;
use llmx_protocol::models::ContentItem;
use llmx_protocol::models::ReasoningItemReasoningSummary;
use llmx_protocol::models::ResponseItem;
use pretty_assertions::assert_eq;
fn user_msg(text: &str) -> ResponseItem {
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::OutputText {
text: text.to_string(),
}],
}
}
fn assistant_msg(text: &str) -> ResponseItem {
ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: text.to_string(),
}],
}
}
#[test]
fn drops_from_last_user_only() {
let items = [
user_msg("u1"),
assistant_msg("a1"),
assistant_msg("a2"),
user_msg("u2"),
assistant_msg("a3"),
ResponseItem::Reasoning {
id: "r1".to_string(),
summary: vec![ReasoningItemReasoningSummary::SummaryText {
text: "s".to_string(),
}],
content: None,
encrypted_content: None,
},
ResponseItem::FunctionCall {
id: None,
name: "tool".to_string(),
arguments: "{}".to_string(),
call_id: "c1".to_string(),
},
assistant_msg("a4"),
];
// Wrap as InitialHistory::Forked with response items only.
let initial: Vec<RolloutItem> = items
.iter()
.cloned()
.map(RolloutItem::ResponseItem)
.collect();
let truncated = truncate_before_nth_user_message(InitialHistory::Forked(initial), 1);
let got_items = truncated.get_rollout_items();
let expected_items = vec![
RolloutItem::ResponseItem(items[0].clone()),
RolloutItem::ResponseItem(items[1].clone()),
RolloutItem::ResponseItem(items[2].clone()),
];
assert_eq!(
serde_json::to_value(&got_items).unwrap(),
serde_json::to_value(&expected_items).unwrap()
);
let initial2: Vec<RolloutItem> = items
.iter()
.cloned()
.map(RolloutItem::ResponseItem)
.collect();
let truncated2 = truncate_before_nth_user_message(InitialHistory::Forked(initial2), 2);
assert_matches!(truncated2, InitialHistory::New);
}
#[test]
fn ignores_session_prefix_messages_when_truncating() {
let (session, turn_context) = make_session_and_context();
let mut items = session.build_initial_context(&turn_context);
items.push(user_msg("feature request"));
items.push(assistant_msg("ack"));
items.push(user_msg("second question"));
items.push(assistant_msg("answer"));
let rollout_items: Vec<RolloutItem> = items
.iter()
.cloned()
.map(RolloutItem::ResponseItem)
.collect();
let truncated = truncate_before_nth_user_message(InitialHistory::Forked(rollout_items), 1);
let got_items = truncated.get_rollout_items();
let expected: Vec<RolloutItem> = vec![
RolloutItem::ResponseItem(items[0].clone()),
RolloutItem::ResponseItem(items[1].clone()),
RolloutItem::ResponseItem(items[2].clone()),
];
assert_eq!(
serde_json::to_value(&got_items).unwrap(),
serde_json::to_value(&expected).unwrap()
);
}
}

View File

@@ -0,0 +1,244 @@
use llmx_protocol::custom_prompts::CustomPrompt;
use std::collections::HashSet;
use std::path::Path;
use std::path::PathBuf;
use tokio::fs;
/// Return the default prompts directory: `$LLMX_HOME/prompts`.
/// If `LLMX_HOME` cannot be resolved, returns `None`.
pub fn default_prompts_dir() -> Option<PathBuf> {
crate::config::find_llmx_home()
.ok()
.map(|home| home.join("prompts"))
}
/// Discover prompt files in the given directory, returning entries sorted by name.
/// Non-files are ignored. If the directory does not exist or cannot be read, returns empty.
pub async fn discover_prompts_in(dir: &Path) -> Vec<CustomPrompt> {
discover_prompts_in_excluding(dir, &HashSet::new()).await
}
/// Discover prompt files in the given directory, excluding any with names in `exclude`.
/// Returns entries sorted by name. Non-files are ignored. Missing/unreadable dir yields empty.
pub async fn discover_prompts_in_excluding(
dir: &Path,
exclude: &HashSet<String>,
) -> Vec<CustomPrompt> {
let mut out: Vec<CustomPrompt> = Vec::new();
let mut entries = match fs::read_dir(dir).await {
Ok(entries) => entries,
Err(_) => return out,
};
while let Ok(Some(entry)) = entries.next_entry().await {
let path = entry.path();
let is_file_like = fs::metadata(&path)
.await
.map(|m| m.is_file())
.unwrap_or(false);
if !is_file_like {
continue;
}
// Only include Markdown files with a .md extension.
let is_md = path
.extension()
.and_then(|s| s.to_str())
.map(|ext| ext.eq_ignore_ascii_case("md"))
.unwrap_or(false);
if !is_md {
continue;
}
let Some(name) = path
.file_stem()
.and_then(|s| s.to_str())
.map(str::to_string)
else {
continue;
};
if exclude.contains(&name) {
continue;
}
let content = match fs::read_to_string(&path).await {
Ok(s) => s,
Err(_) => continue,
};
let (description, argument_hint, body) = parse_frontmatter(&content);
out.push(CustomPrompt {
name,
path,
content: body,
description,
argument_hint,
});
}
out.sort_by(|a, b| a.name.cmp(&b.name));
out
}
/// Parse optional YAML-like frontmatter at the beginning of `content`.
/// Supported keys:
/// - `description`: short description shown in the slash popup
/// - `argument-hint` or `argument_hint`: brief hint string shown after the description
/// Returns (description, argument_hint, body_without_frontmatter).
fn parse_frontmatter(content: &str) -> (Option<String>, Option<String>, String) {
let mut segments = content.split_inclusive('\n');
let Some(first_segment) = segments.next() else {
return (None, None, String::new());
};
let first_line = first_segment.trim_end_matches(['\r', '\n']);
if first_line.trim() != "---" {
return (None, None, content.to_string());
}
let mut desc: Option<String> = None;
let mut hint: Option<String> = None;
let mut frontmatter_closed = false;
let mut consumed = first_segment.len();
for segment in segments {
let line = segment.trim_end_matches(['\r', '\n']);
let trimmed = line.trim();
if trimmed == "---" {
frontmatter_closed = true;
consumed += segment.len();
break;
}
if trimmed.is_empty() || trimmed.starts_with('#') {
consumed += segment.len();
continue;
}
if let Some((k, v)) = trimmed.split_once(':') {
let key = k.trim().to_ascii_lowercase();
let mut val = v.trim().to_string();
if val.len() >= 2 {
let bytes = val.as_bytes();
let first = bytes[0];
let last = bytes[bytes.len() - 1];
if (first == b'\"' && last == b'\"') || (first == b'\'' && last == b'\'') {
val = val[1..val.len().saturating_sub(1)].to_string();
}
}
match key.as_str() {
"description" => desc = Some(val),
"argument-hint" | "argument_hint" => hint = Some(val),
_ => {}
}
}
consumed += segment.len();
}
if !frontmatter_closed {
// Unterminated frontmatter: treat input as-is.
return (None, None, content.to_string());
}
let body = if consumed >= content.len() {
String::new()
} else {
content[consumed..].to_string()
};
(desc, hint, body)
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::tempdir;
#[tokio::test]
async fn empty_when_dir_missing() {
let tmp = tempdir().expect("create TempDir");
let missing = tmp.path().join("nope");
let found = discover_prompts_in(&missing).await;
assert!(found.is_empty());
}
#[tokio::test]
async fn discovers_and_sorts_files() {
let tmp = tempdir().expect("create TempDir");
let dir = tmp.path();
fs::write(dir.join("b.md"), b"b").unwrap();
fs::write(dir.join("a.md"), b"a").unwrap();
fs::create_dir(dir.join("subdir")).unwrap();
let found = discover_prompts_in(dir).await;
let names: Vec<String> = found.into_iter().map(|e| e.name).collect();
assert_eq!(names, vec!["a", "b"]);
}
#[tokio::test]
async fn excludes_builtins() {
let tmp = tempdir().expect("create TempDir");
let dir = tmp.path();
fs::write(dir.join("init.md"), b"ignored").unwrap();
fs::write(dir.join("foo.md"), b"ok").unwrap();
let mut exclude = HashSet::new();
exclude.insert("init".to_string());
let found = discover_prompts_in_excluding(dir, &exclude).await;
let names: Vec<String> = found.into_iter().map(|e| e.name).collect();
assert_eq!(names, vec!["foo"]);
}
#[tokio::test]
async fn skips_non_utf8_files() {
let tmp = tempdir().expect("create TempDir");
let dir = tmp.path();
// Valid UTF-8 file
fs::write(dir.join("good.md"), b"hello").unwrap();
// Invalid UTF-8 content in .md file (e.g., lone 0xFF byte)
fs::write(dir.join("bad.md"), vec![0xFF, 0xFE, b'\n']).unwrap();
let found = discover_prompts_in(dir).await;
let names: Vec<String> = found.into_iter().map(|e| e.name).collect();
assert_eq!(names, vec!["good"]);
}
#[tokio::test]
#[cfg(unix)]
async fn discovers_symlinked_md_files() {
let tmp = tempdir().expect("create TempDir");
let dir = tmp.path();
// Create a real file
fs::write(dir.join("real.md"), b"real content").unwrap();
// Create a symlink to the real file
std::os::unix::fs::symlink(dir.join("real.md"), dir.join("link.md")).unwrap();
let found = discover_prompts_in(dir).await;
let names: Vec<String> = found.into_iter().map(|e| e.name).collect();
// Both real and link should be discovered, sorted alphabetically
assert_eq!(names, vec!["link", "real"]);
}
#[tokio::test]
async fn parses_frontmatter_and_strips_from_body() {
let tmp = tempdir().expect("create TempDir");
let dir = tmp.path();
let file = dir.join("withmeta.md");
let text = "---\nname: ignored\ndescription: \"Quick review command\"\nargument-hint: \"[file] [priority]\"\n---\nActual body with $1 and $ARGUMENTS";
fs::write(&file, text).unwrap();
let found = discover_prompts_in(dir).await;
assert_eq!(found.len(), 1);
let p = &found[0];
assert_eq!(p.name, "withmeta");
assert_eq!(p.description.as_deref(), Some("Quick review command"));
assert_eq!(p.argument_hint.as_deref(), Some("[file] [priority]"));
// Body should not include the frontmatter delimiters.
assert_eq!(p.content, "Actual body with $1 and $ARGUMENTS");
}
#[test]
fn parse_frontmatter_preserves_body_newlines() {
let content = "---\r\ndescription: \"Line endings\"\r\nargument_hint: \"[arg]\"\r\n---\r\nFirst line\r\nSecond line\r\n";
let (desc, hint, body) = parse_frontmatter(content);
assert_eq!(desc.as_deref(), Some("Line endings"));
assert_eq!(hint.as_deref(), Some("[arg]"));
assert_eq!(body, "First line\r\nSecond line\r\n");
}
}

View File

@@ -0,0 +1,375 @@
use crate::spawn::LLMX_SANDBOX_ENV_VAR;
use http::Error as HttpError;
use reqwest::IntoUrl;
use reqwest::Method;
use reqwest::Response;
use reqwest::header::HeaderName;
use reqwest::header::HeaderValue;
use serde::Serialize;
use std::collections::HashMap;
use std::fmt::Display;
use std::sync::LazyLock;
use std::sync::Mutex;
use std::sync::OnceLock;
/// Set this to add a suffix to the User-Agent string.
///
/// It is not ideal that we're using a global singleton for this.
/// This is primarily designed to differentiate MCP clients from each other.
/// Because there can only be one MCP server per process, it should be safe for this to be a global static.
/// However, future users of this should use this with caution as a result.
/// In addition, we want to be confident that this value is used for ALL clients and doing that requires a
/// lot of wiring and it's easy to miss code paths by doing so.
/// See https://github.com/valknar/llmx/pull/3388/files for an example of what that would look like.
/// Finally, we want to make sure this is set for ALL mcp clients without needing to know a special env var
/// or having to set data that they already specified in the mcp initialize request somewhere else.
///
/// A space is automatically added between the suffix and the rest of the User-Agent string.
/// The full user agent string is returned from the mcp initialize response.
/// Parenthesis will be added by Llmx. This should only specify what goes inside of the parenthesis.
pub static USER_AGENT_SUFFIX: LazyLock<Mutex<Option<String>>> = LazyLock::new(|| Mutex::new(None));
pub const DEFAULT_ORIGINATOR: &str = "llmx_cli_rs";
pub const LLMX_INTERNAL_ORIGINATOR_OVERRIDE_ENV_VAR: &str = "LLMX_INTERNAL_ORIGINATOR_OVERRIDE";
#[derive(Clone, Debug)]
pub struct LlmxHttpClient {
inner: reqwest::Client,
}
impl LlmxHttpClient {
fn new(inner: reqwest::Client) -> Self {
Self { inner }
}
pub fn get<U>(&self, url: U) -> LlmxRequestBuilder
where
U: IntoUrl,
{
self.request(Method::GET, url)
}
pub fn post<U>(&self, url: U) -> LlmxRequestBuilder
where
U: IntoUrl,
{
self.request(Method::POST, url)
}
pub fn request<U>(&self, method: Method, url: U) -> LlmxRequestBuilder
where
U: IntoUrl,
{
let url_str = url.as_str().to_string();
LlmxRequestBuilder::new(self.inner.request(method.clone(), url), method, url_str)
}
}
#[must_use = "requests are not sent unless `send` is awaited"]
#[derive(Debug)]
pub struct LlmxRequestBuilder {
builder: reqwest::RequestBuilder,
method: Method,
url: String,
}
impl LlmxRequestBuilder {
fn new(builder: reqwest::RequestBuilder, method: Method, url: String) -> Self {
Self {
builder,
method,
url,
}
}
fn map(self, f: impl FnOnce(reqwest::RequestBuilder) -> reqwest::RequestBuilder) -> Self {
Self {
builder: f(self.builder),
method: self.method,
url: self.url,
}
}
pub fn header<K, V>(self, key: K, value: V) -> Self
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<HttpError>,
HeaderValue: TryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<HttpError>,
{
self.map(|builder| builder.header(key, value))
}
pub fn bearer_auth<T>(self, token: T) -> Self
where
T: Display,
{
self.map(|builder| builder.bearer_auth(token))
}
pub fn json<T>(self, value: &T) -> Self
where
T: ?Sized + Serialize,
{
self.map(|builder| builder.json(value))
}
pub async fn send(self) -> Result<Response, reqwest::Error> {
match self.builder.send().await {
Ok(response) => {
let request_ids = Self::extract_request_ids(&response);
tracing::debug!(
method = %self.method,
url = %self.url,
status = %response.status(),
request_ids = ?request_ids,
version = ?response.version(),
"Request completed"
);
Ok(response)
}
Err(error) => {
let status = error.status();
tracing::debug!(
method = %self.method,
url = %self.url,
status = status.map(|s| s.as_u16()),
error = %error,
"Request failed"
);
Err(error)
}
}
}
fn extract_request_ids(response: &Response) -> HashMap<String, String> {
["cf-ray", "x-request-id", "x-oai-request-id"]
.iter()
.filter_map(|&name| {
let header_name = HeaderName::from_static(name);
let value = response.headers().get(header_name)?;
let value = value.to_str().ok()?.to_owned();
Some((name.to_owned(), value))
})
.collect()
}
}
#[derive(Debug, Clone)]
pub struct Originator {
pub value: String,
pub header_value: HeaderValue,
}
static ORIGINATOR: OnceLock<Originator> = OnceLock::new();
#[derive(Debug)]
pub enum SetOriginatorError {
InvalidHeaderValue,
AlreadyInitialized,
}
fn get_originator_value(provided: Option<String>) -> Originator {
let value = std::env::var(LLMX_INTERNAL_ORIGINATOR_OVERRIDE_ENV_VAR)
.ok()
.or(provided)
.unwrap_or(DEFAULT_ORIGINATOR.to_string());
match HeaderValue::from_str(&value) {
Ok(header_value) => Originator {
value,
header_value,
},
Err(e) => {
tracing::error!("Unable to turn originator override {value} into header value: {e}");
Originator {
value: DEFAULT_ORIGINATOR.to_string(),
header_value: HeaderValue::from_static(DEFAULT_ORIGINATOR),
}
}
}
}
pub fn set_default_originator(value: String) -> Result<(), SetOriginatorError> {
let originator = get_originator_value(Some(value));
ORIGINATOR
.set(originator)
.map_err(|_| SetOriginatorError::AlreadyInitialized)
}
pub fn originator() -> &'static Originator {
ORIGINATOR.get_or_init(|| get_originator_value(None))
}
pub fn get_llmx_user_agent() -> String {
let build_version = env!("CARGO_PKG_VERSION");
let os_info = os_info::get();
let prefix = format!(
"{}/{build_version} ({} {}; {}) {}",
originator().value.as_str(),
os_info.os_type(),
os_info.version(),
os_info.architecture().unwrap_or("unknown"),
crate::terminal::user_agent()
);
let suffix = USER_AGENT_SUFFIX
.lock()
.ok()
.and_then(|guard| guard.clone());
let suffix = suffix
.as_deref()
.map(str::trim)
.filter(|value| !value.is_empty())
.map_or_else(String::new, |value| format!(" ({value})"));
let candidate = format!("{prefix}{suffix}");
sanitize_user_agent(candidate, &prefix)
}
/// Sanitize the user agent string.
///
/// Invalid characters are replaced with an underscore.
///
/// If the user agent fails to parse, it falls back to fallback and then to ORIGINATOR.
fn sanitize_user_agent(candidate: String, fallback: &str) -> String {
if HeaderValue::from_str(candidate.as_str()).is_ok() {
return candidate;
}
let sanitized: String = candidate
.chars()
.map(|ch| if matches!(ch, ' '..='~') { ch } else { '_' })
.collect();
if !sanitized.is_empty() && HeaderValue::from_str(sanitized.as_str()).is_ok() {
tracing::warn!(
"Sanitized LLMX user agent because provided suffix contained invalid header characters"
);
sanitized
} else if HeaderValue::from_str(fallback).is_ok() {
tracing::warn!(
"Falling back to base LLMX user agent because provided suffix could not be sanitized"
);
fallback.to_string()
} else {
tracing::warn!(
"Falling back to default LLMX originator because base user agent string is invalid"
);
originator().value.clone()
}
}
/// Create an HTTP client with default `originator` and `User-Agent` headers set.
pub fn create_client() -> LlmxHttpClient {
use reqwest::header::HeaderMap;
let mut headers = HeaderMap::new();
headers.insert("originator", originator().header_value.clone());
let ua = get_llmx_user_agent();
let mut builder = reqwest::Client::builder()
// Set UA via dedicated helper to avoid header validation pitfalls
.user_agent(ua)
.default_headers(headers);
if is_sandboxed() {
builder = builder.no_proxy();
}
let inner = builder.build().unwrap_or_else(|_| reqwest::Client::new());
LlmxHttpClient::new(inner)
}
fn is_sandboxed() -> bool {
std::env::var(LLMX_SANDBOX_ENV_VAR).as_deref() == Ok("seatbelt")
}
#[cfg(test)]
mod tests {
use super::*;
use core_test_support::skip_if_no_network;
#[test]
fn test_get_llmx_user_agent() {
let user_agent = get_llmx_user_agent();
assert!(user_agent.starts_with("llmx_cli_rs/"));
}
#[tokio::test]
async fn test_create_client_sets_default_headers() {
skip_if_no_network!();
use wiremock::Mock;
use wiremock::MockServer;
use wiremock::ResponseTemplate;
use wiremock::matchers::method;
use wiremock::matchers::path;
let client = create_client();
// Spin up a local mock server and capture a request.
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/"))
.respond_with(ResponseTemplate::new(200))
.mount(&server)
.await;
let resp = client
.get(server.uri())
.send()
.await
.expect("failed to send request");
assert!(resp.status().is_success());
let requests = server
.received_requests()
.await
.expect("failed to fetch received requests");
assert!(!requests.is_empty());
let headers = &requests[0].headers;
// originator header is set to the provided value
let originator_header = headers
.get("originator")
.expect("originator header missing");
assert_eq!(originator_header.to_str().unwrap(), "llmx_cli_rs");
// User-Agent matches the computed LLMX UA for that originator
let expected_ua = get_llmx_user_agent();
let ua_header = headers
.get("user-agent")
.expect("user-agent header missing");
assert_eq!(ua_header.to_str().unwrap(), expected_ua);
}
#[test]
fn test_invalid_suffix_is_sanitized() {
let prefix = "llmx_cli_rs/0.0.0";
let suffix = "bad\rsuffix";
assert_eq!(
sanitize_user_agent(format!("{prefix} ({suffix})"), prefix),
"llmx_cli_rs/0.0.0 (bad_suffix)"
);
}
#[test]
fn test_invalid_suffix_is_sanitized2() {
let prefix = "llmx_cli_rs/0.0.0";
let suffix = "bad\0suffix";
assert_eq!(
sanitize_user_agent(format!("{prefix} ({suffix})"), prefix),
"llmx_cli_rs/0.0.0 (bad_suffix)"
);
}
#[test]
#[cfg(target_os = "macos")]
fn test_macos() {
use regex_lite::Regex;
let user_agent = get_llmx_user_agent();
let re = Regex::new(
r"^llmx_cli_rs/\d+\.\d+\.\d+ \(Mac OS \d+\.\d+\.\d+; (x86_64|arm64)\) (\S+)$",
)
.unwrap();
assert!(re.is_match(&user_agent));
}
}

View File

@@ -0,0 +1,347 @@
use serde::Deserialize;
use serde::Serialize;
use strum_macros::Display as DeriveDisplay;
use crate::llmx::TurnContext;
use crate::protocol::AskForApproval;
use crate::protocol::SandboxPolicy;
use crate::shell::Shell;
use llmx_protocol::config_types::SandboxMode;
use llmx_protocol::models::ContentItem;
use llmx_protocol::models::ResponseItem;
use llmx_protocol::protocol::ENVIRONMENT_CONTEXT_CLOSE_TAG;
use llmx_protocol::protocol::ENVIRONMENT_CONTEXT_OPEN_TAG;
use std::path::PathBuf;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, DeriveDisplay)]
#[serde(rename_all = "kebab-case")]
#[strum(serialize_all = "kebab-case")]
pub enum NetworkAccess {
Restricted,
Enabled,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename = "environment_context", rename_all = "snake_case")]
pub(crate) struct EnvironmentContext {
pub cwd: Option<PathBuf>,
pub approval_policy: Option<AskForApproval>,
pub sandbox_mode: Option<SandboxMode>,
pub network_access: Option<NetworkAccess>,
pub writable_roots: Option<Vec<PathBuf>>,
pub shell: Option<Shell>,
}
impl EnvironmentContext {
pub fn new(
cwd: Option<PathBuf>,
approval_policy: Option<AskForApproval>,
sandbox_policy: Option<SandboxPolicy>,
shell: Option<Shell>,
) -> Self {
Self {
cwd,
approval_policy,
sandbox_mode: match sandbox_policy {
Some(SandboxPolicy::DangerFullAccess) => Some(SandboxMode::DangerFullAccess),
Some(SandboxPolicy::ReadOnly) => Some(SandboxMode::ReadOnly),
Some(SandboxPolicy::WorkspaceWrite { .. }) => Some(SandboxMode::WorkspaceWrite),
None => None,
},
network_access: match sandbox_policy {
Some(SandboxPolicy::DangerFullAccess) => Some(NetworkAccess::Enabled),
Some(SandboxPolicy::ReadOnly) => Some(NetworkAccess::Restricted),
Some(SandboxPolicy::WorkspaceWrite { network_access, .. }) => {
if network_access {
Some(NetworkAccess::Enabled)
} else {
Some(NetworkAccess::Restricted)
}
}
None => None,
},
writable_roots: match sandbox_policy {
Some(SandboxPolicy::WorkspaceWrite { writable_roots, .. }) => {
if writable_roots.is_empty() {
None
} else {
Some(writable_roots)
}
}
_ => None,
},
shell,
}
}
/// Compares two environment contexts, ignoring the shell. Useful when
/// comparing turn to turn, since the initial environment_context will
/// include the shell, and then it is not configurable from turn to turn.
pub fn equals_except_shell(&self, other: &EnvironmentContext) -> bool {
let EnvironmentContext {
cwd,
approval_policy,
sandbox_mode,
network_access,
writable_roots,
// should compare all fields except shell
shell: _,
} = other;
self.cwd == *cwd
&& self.approval_policy == *approval_policy
&& self.sandbox_mode == *sandbox_mode
&& self.network_access == *network_access
&& self.writable_roots == *writable_roots
}
pub fn diff(before: &TurnContext, after: &TurnContext) -> Self {
let cwd = if before.cwd != after.cwd {
Some(after.cwd.clone())
} else {
None
};
let approval_policy = if before.approval_policy != after.approval_policy {
Some(after.approval_policy)
} else {
None
};
let sandbox_policy = if before.sandbox_policy != after.sandbox_policy {
Some(after.sandbox_policy.clone())
} else {
None
};
EnvironmentContext::new(cwd, approval_policy, sandbox_policy, None)
}
}
impl From<&TurnContext> for EnvironmentContext {
fn from(turn_context: &TurnContext) -> Self {
Self::new(
Some(turn_context.cwd.clone()),
Some(turn_context.approval_policy),
Some(turn_context.sandbox_policy.clone()),
// Shell is not configurable from turn to turn
None,
)
}
}
impl EnvironmentContext {
/// Serializes the environment context to XML. Libraries like `quick-xml`
/// require custom macros to handle Enums with newtypes, so we just do it
/// manually, to keep things simple. Output looks like:
///
/// ```xml
/// <environment_context>
/// <cwd>...</cwd>
/// <approval_policy>...</approval_policy>
/// <sandbox_mode>...</sandbox_mode>
/// <writable_roots>...</writable_roots>
/// <network_access>...</network_access>
/// <shell>...</shell>
/// </environment_context>
/// ```
pub fn serialize_to_xml(self) -> String {
let mut lines = vec![ENVIRONMENT_CONTEXT_OPEN_TAG.to_string()];
if let Some(cwd) = self.cwd {
lines.push(format!(" <cwd>{}</cwd>", cwd.to_string_lossy()));
}
if let Some(approval_policy) = self.approval_policy {
lines.push(format!(
" <approval_policy>{approval_policy}</approval_policy>"
));
}
if let Some(sandbox_mode) = self.sandbox_mode {
lines.push(format!(" <sandbox_mode>{sandbox_mode}</sandbox_mode>"));
}
if let Some(network_access) = self.network_access {
lines.push(format!(
" <network_access>{network_access}</network_access>"
));
}
if let Some(writable_roots) = self.writable_roots {
lines.push(" <writable_roots>".to_string());
for writable_root in writable_roots {
lines.push(format!(
" <root>{}</root>",
writable_root.to_string_lossy()
));
}
lines.push(" </writable_roots>".to_string());
}
if let Some(shell) = self.shell
&& let Some(shell_name) = shell.name()
{
lines.push(format!(" <shell>{shell_name}</shell>"));
}
lines.push(ENVIRONMENT_CONTEXT_CLOSE_TAG.to_string());
lines.join("\n")
}
}
impl From<EnvironmentContext> for ResponseItem {
fn from(ec: EnvironmentContext) -> Self {
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: ec.serialize_to_xml(),
}],
}
}
}
#[cfg(test)]
mod tests {
use crate::shell::BashShell;
use crate::shell::ZshShell;
use super::*;
use pretty_assertions::assert_eq;
fn workspace_write_policy(writable_roots: Vec<&str>, network_access: bool) -> SandboxPolicy {
SandboxPolicy::WorkspaceWrite {
writable_roots: writable_roots.into_iter().map(PathBuf::from).collect(),
network_access,
exclude_tmpdir_env_var: false,
exclude_slash_tmp: false,
}
}
#[test]
fn serialize_workspace_write_environment_context() {
let context = EnvironmentContext::new(
Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest),
Some(workspace_write_policy(vec!["/repo", "/tmp"], false)),
None,
);
let expected = r#"<environment_context>
<cwd>/repo</cwd>
<approval_policy>on-request</approval_policy>
<sandbox_mode>workspace-write</sandbox_mode>
<network_access>restricted</network_access>
<writable_roots>
<root>/repo</root>
<root>/tmp</root>
</writable_roots>
</environment_context>"#;
assert_eq!(context.serialize_to_xml(), expected);
}
#[test]
fn serialize_read_only_environment_context() {
let context = EnvironmentContext::new(
None,
Some(AskForApproval::Never),
Some(SandboxPolicy::ReadOnly),
None,
);
let expected = r#"<environment_context>
<approval_policy>never</approval_policy>
<sandbox_mode>read-only</sandbox_mode>
<network_access>restricted</network_access>
</environment_context>"#;
assert_eq!(context.serialize_to_xml(), expected);
}
#[test]
fn serialize_full_access_environment_context() {
let context = EnvironmentContext::new(
None,
Some(AskForApproval::OnFailure),
Some(SandboxPolicy::DangerFullAccess),
None,
);
let expected = r#"<environment_context>
<approval_policy>on-failure</approval_policy>
<sandbox_mode>danger-full-access</sandbox_mode>
<network_access>enabled</network_access>
</environment_context>"#;
assert_eq!(context.serialize_to_xml(), expected);
}
#[test]
fn equals_except_shell_compares_approval_policy() {
// Approval policy
let context1 = EnvironmentContext::new(
Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest),
Some(workspace_write_policy(vec!["/repo"], false)),
None,
);
let context2 = EnvironmentContext::new(
Some(PathBuf::from("/repo")),
Some(AskForApproval::Never),
Some(workspace_write_policy(vec!["/repo"], true)),
None,
);
assert!(!context1.equals_except_shell(&context2));
}
#[test]
fn equals_except_shell_compares_sandbox_policy() {
let context1 = EnvironmentContext::new(
Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest),
Some(SandboxPolicy::new_read_only_policy()),
None,
);
let context2 = EnvironmentContext::new(
Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest),
Some(SandboxPolicy::new_workspace_write_policy()),
None,
);
assert!(!context1.equals_except_shell(&context2));
}
#[test]
fn equals_except_shell_compares_workspace_write_policy() {
let context1 = EnvironmentContext::new(
Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest),
Some(workspace_write_policy(vec!["/repo", "/tmp", "/var"], false)),
None,
);
let context2 = EnvironmentContext::new(
Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest),
Some(workspace_write_policy(vec!["/repo", "/tmp"], true)),
None,
);
assert!(!context1.equals_except_shell(&context2));
}
#[test]
fn equals_except_shell_ignores_shell() {
let context1 = EnvironmentContext::new(
Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest),
Some(workspace_write_policy(vec!["/repo"], false)),
Some(Shell::Bash(BashShell {
shell_path: "/bin/bash".into(),
bashrc_path: "/home/user/.bashrc".into(),
})),
);
let context2 = EnvironmentContext::new(
Some(PathBuf::from("/repo")),
Some(AskForApproval::OnRequest),
Some(workspace_write_policy(vec!["/repo"], false)),
Some(Shell::Zsh(ZshShell {
shell_path: "/bin/zsh".into(),
zshrc_path: "/home/user/.zshrc".into(),
})),
);
assert!(context1.equals_except_shell(&context2));
}
}

773
llmx-rs/core/src/error.rs Normal file
View File

@@ -0,0 +1,773 @@
use crate::exec::ExecToolCallOutput;
use crate::llmx::ProcessedResponseItem;
use crate::token_data::KnownPlan;
use crate::token_data::PlanType;
use crate::truncate::truncate_middle;
use chrono::DateTime;
use chrono::Datelike;
use chrono::Local;
use chrono::Utc;
use llmx_async_utils::CancelErr;
use llmx_protocol::ConversationId;
use llmx_protocol::protocol::RateLimitSnapshot;
use reqwest::StatusCode;
use serde_json;
use std::io;
use std::time::Duration;
use thiserror::Error;
use tokio::task::JoinError;
pub type Result<T> = std::result::Result<T, LlmxErr>;
/// Limit UI error messages to a reasonable size while keeping useful context.
const ERROR_MESSAGE_UI_MAX_BYTES: usize = 2 * 1024; // 4 KiB
#[derive(Error, Debug)]
pub enum SandboxErr {
/// Error from sandbox execution
#[error(
"sandbox denied exec error, exit code: {}, stdout: {}, stderr: {}",
.output.exit_code, .output.stdout.text, .output.stderr.text
)]
Denied { output: Box<ExecToolCallOutput> },
/// Error from linux seccomp filter setup
#[cfg(target_os = "linux")]
#[error("seccomp setup error")]
SeccompInstall(#[from] seccompiler::Error),
/// Error from linux seccomp backend
#[cfg(target_os = "linux")]
#[error("seccomp backend error")]
SeccompBackend(#[from] seccompiler::BackendError),
/// Command timed out
#[error("command timed out")]
Timeout { output: Box<ExecToolCallOutput> },
/// Command was killed by a signal
#[error("command was killed by a signal")]
Signal(i32),
/// Error from linux landlock
#[error("Landlock was not able to fully enforce all sandbox rules")]
LandlockRestrict,
}
#[derive(Error, Debug)]
pub enum LlmxErr {
// todo(aibrahim): git rid of this error carrying the dangling artifacts
#[error("turn aborted. Something went wrong? Hit `/feedback` to report the issue.")]
TurnAborted {
dangling_artifacts: Vec<ProcessedResponseItem>,
},
/// Returned by ResponsesClient when the SSE stream disconnects or errors out **after** the HTTP
/// handshake has succeeded but **before** it finished emitting `response.completed`.
///
/// The Session loop treats this as a transient error and will automatically retry the turn.
///
/// Optionally includes the requested delay before retrying the turn.
#[error("stream disconnected before completion: {0}")]
Stream(String, Option<Duration>),
#[error(
"LLMX ran out of room in the model's context window. Start a new conversation or clear earlier history before retrying."
)]
ContextWindowExceeded,
#[error("no conversation with id: {0}")]
ConversationNotFound(ConversationId),
#[error("session configured event was not the first event in the stream")]
SessionConfiguredNotFirstEvent,
/// Returned by run_command_stream when the spawned child process timed out (10s).
#[error("timeout waiting for child process to exit")]
Timeout,
/// Returned by run_command_stream when the child could not be spawned (its stdout/stderr pipes
/// could not be captured). Analogous to the previous `LlmxError::Spawn` variant.
#[error("spawn failed: child stdout/stderr not captured")]
Spawn,
/// Returned by run_command_stream when the user pressed CtrlC (SIGINT). Session uses this to
/// surface a polite FunctionCallOutput back to the model instead of crashing the CLI.
#[error("interrupted (Ctrl-C). Something went wrong? Hit `/feedback` to report the issue.")]
Interrupted,
/// Unexpected HTTP status code.
#[error("{0}")]
UnexpectedStatus(UnexpectedResponseError),
#[error("{0}")]
UsageLimitReached(UsageLimitReachedError),
#[error("{0}")]
ResponseStreamFailed(ResponseStreamFailed),
#[error("{0}")]
ConnectionFailed(ConnectionFailedError),
#[error("Quota exceeded. Check your plan and billing details.")]
QuotaExceeded,
#[error(
"To use LLMX with your ChatGPT plan, upgrade to Plus: https://openai.com/chatgpt/pricing."
)]
UsageNotIncluded,
#[error("We're currently experiencing high demand, which may cause temporary errors.")]
InternalServerError,
/// Retry limit exceeded.
#[error("{0}")]
RetryLimit(RetryLimitReachedError),
/// Agent loop died unexpectedly
#[error("internal error; agent loop died unexpectedly")]
InternalAgentDied,
/// Sandbox error
#[error("sandbox error: {0}")]
Sandbox(#[from] SandboxErr),
#[error("llmx-linux-sandbox was required but not provided")]
LandlockSandboxExecutableNotProvided,
#[error("unsupported operation: {0}")]
UnsupportedOperation(String),
#[error("{0}")]
RefreshTokenFailed(RefreshTokenFailedError),
#[error("Fatal error: {0}")]
Fatal(String),
// -----------------------------------------------------------------
// Automatic conversions for common external error types
// -----------------------------------------------------------------
#[error(transparent)]
Io(#[from] io::Error),
#[error(transparent)]
Json(#[from] serde_json::Error),
#[cfg(target_os = "linux")]
#[error(transparent)]
LandlockRuleset(#[from] landlock::RulesetError),
#[cfg(target_os = "linux")]
#[error(transparent)]
LandlockPathFd(#[from] landlock::PathFdError),
#[error(transparent)]
TokioJoin(#[from] JoinError),
#[error("{0}")]
EnvVar(EnvVarError),
}
impl From<CancelErr> for LlmxErr {
fn from(_: CancelErr) -> Self {
LlmxErr::TurnAborted {
dangling_artifacts: Vec::new(),
}
}
}
#[derive(Debug)]
pub struct ConnectionFailedError {
pub source: reqwest::Error,
}
impl std::fmt::Display for ConnectionFailedError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Connection failed: {}", self.source)
}
}
#[derive(Debug)]
pub struct ResponseStreamFailed {
pub source: reqwest::Error,
pub request_id: Option<String>,
}
impl std::fmt::Display for ResponseStreamFailed {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"Error while reading the server response: {}{}",
self.source,
self.request_id
.as_ref()
.map(|id| format!(", request id: {id}"))
.unwrap_or_default()
)
}
}
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[error("{message}")]
pub struct RefreshTokenFailedError {
pub reason: RefreshTokenFailedReason,
pub message: String,
}
impl RefreshTokenFailedError {
pub fn new(reason: RefreshTokenFailedReason, message: impl Into<String>) -> Self {
Self {
reason,
message: message.into(),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RefreshTokenFailedReason {
Expired,
Exhausted,
Revoked,
Other,
}
#[derive(Debug)]
pub struct UnexpectedResponseError {
pub status: StatusCode,
pub body: String,
pub request_id: Option<String>,
}
const CLOUDFLARE_BLOCKED_MESSAGE: &str =
"Access blocked by Cloudflare. This usually happens when connecting from a restricted region";
impl UnexpectedResponseError {
fn friendly_message(&self) -> Option<String> {
if self.status != StatusCode::FORBIDDEN {
return None;
}
if !self.body.contains("Cloudflare") || !self.body.contains("blocked") {
return None;
}
let mut message = format!("{CLOUDFLARE_BLOCKED_MESSAGE} (status {})", self.status);
if let Some(id) = &self.request_id {
message.push_str(&format!(", request id: {id}"));
}
Some(message)
}
}
impl std::fmt::Display for UnexpectedResponseError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(friendly) = self.friendly_message() {
write!(f, "{friendly}")
} else {
write!(
f,
"unexpected status {}: {}{}",
self.status,
self.body,
self.request_id
.as_ref()
.map(|id| format!(", request id: {id}"))
.unwrap_or_default()
)
}
}
}
impl std::error::Error for UnexpectedResponseError {}
#[derive(Debug)]
pub struct RetryLimitReachedError {
pub status: StatusCode,
pub request_id: Option<String>,
}
impl std::fmt::Display for RetryLimitReachedError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"exceeded retry limit, last status: {}{}",
self.status,
self.request_id
.as_ref()
.map(|id| format!(", request id: {id}"))
.unwrap_or_default()
)
}
}
#[derive(Debug)]
pub struct UsageLimitReachedError {
pub(crate) plan_type: Option<PlanType>,
pub(crate) resets_at: Option<DateTime<Utc>>,
pub(crate) rate_limits: Option<RateLimitSnapshot>,
}
impl std::fmt::Display for UsageLimitReachedError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let message = match self.plan_type.as_ref() {
Some(PlanType::Known(KnownPlan::Plus)) => format!(
"You've hit your usage limit. Upgrade to Pro (https://openai.com/chatgpt/pricing), visit https://chatgpt.com/llmx/settings/usage to purchase more credits{}",
retry_suffix_after_or(self.resets_at.as_ref())
),
Some(PlanType::Known(KnownPlan::Team)) | Some(PlanType::Known(KnownPlan::Business)) => {
format!(
"You've hit your usage limit. To get more access now, send a request to your admin{}",
retry_suffix_after_or(self.resets_at.as_ref())
)
}
Some(PlanType::Known(KnownPlan::Free)) => {
"You've hit your usage limit. Upgrade to Plus to continue using LLMX (https://openai.com/chatgpt/pricing)."
.to_string()
}
Some(PlanType::Known(KnownPlan::Pro)) => format!(
"You've hit your usage limit. Visit https://chatgpt.com/llmx/settings/usage to purchase more credits{}",
retry_suffix_after_or(self.resets_at.as_ref())
),
Some(PlanType::Known(KnownPlan::Enterprise))
| Some(PlanType::Known(KnownPlan::Edu)) => format!(
"You've hit your usage limit.{}",
retry_suffix(self.resets_at.as_ref())
),
Some(PlanType::Unknown(_)) | None => format!(
"You've hit your usage limit.{}",
retry_suffix(self.resets_at.as_ref())
),
};
write!(f, "{message}")
}
}
fn retry_suffix(resets_at: Option<&DateTime<Utc>>) -> String {
if let Some(resets_at) = resets_at {
let formatted = format_retry_timestamp(resets_at);
format!(" Try again at {formatted}.")
} else {
" Try again later.".to_string()
}
}
fn retry_suffix_after_or(resets_at: Option<&DateTime<Utc>>) -> String {
if let Some(resets_at) = resets_at {
let formatted = format_retry_timestamp(resets_at);
format!(" or try again at {formatted}.")
} else {
" or try again later.".to_string()
}
}
fn format_retry_timestamp(resets_at: &DateTime<Utc>) -> String {
let local_reset = resets_at.with_timezone(&Local);
let local_now = now_for_retry().with_timezone(&Local);
if local_reset.date_naive() == local_now.date_naive() {
local_reset.format("%-I:%M %p").to_string()
} else {
let suffix = day_suffix(local_reset.day());
local_reset
.format(&format!("%b %-d{suffix}, %Y %-I:%M %p"))
.to_string()
}
}
fn day_suffix(day: u32) -> &'static str {
match day {
11..=13 => "th",
_ => match day % 10 {
1 => "st",
2 => "nd", // codespell:ignore
3 => "rd",
_ => "th",
},
}
}
#[cfg(test)]
thread_local! {
static NOW_OVERRIDE: std::cell::RefCell<Option<DateTime<Utc>>> =
const { std::cell::RefCell::new(None) };
}
fn now_for_retry() -> DateTime<Utc> {
#[cfg(test)]
{
if let Some(now) = NOW_OVERRIDE.with(|cell| *cell.borrow()) {
return now;
}
}
Utc::now()
}
#[derive(Debug)]
pub struct EnvVarError {
/// Name of the environment variable that is missing.
pub var: String,
/// Optional instructions to help the user get a valid value for the
/// variable and set it.
pub instructions: Option<String>,
}
impl std::fmt::Display for EnvVarError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Missing environment variable: `{}`.", self.var)?;
if let Some(instructions) = &self.instructions {
write!(f, " {instructions}")?;
}
Ok(())
}
}
impl LlmxErr {
/// Minimal shim so that existing `e.downcast_ref::<LlmxErr>()` checks continue to compile
/// after replacing `anyhow::Error` in the return signature. This mirrors the behavior of
/// `anyhow::Error::downcast_ref` but works directly on our concrete enum.
pub fn downcast_ref<T: std::any::Any>(&self) -> Option<&T> {
(self as &dyn std::any::Any).downcast_ref::<T>()
}
}
pub fn get_error_message_ui(e: &LlmxErr) -> String {
let message = match e {
LlmxErr::Sandbox(SandboxErr::Denied { output }) => {
let aggregated = output.aggregated_output.text.trim();
if !aggregated.is_empty() {
output.aggregated_output.text.clone()
} else {
let stderr = output.stderr.text.trim();
let stdout = output.stdout.text.trim();
match (stderr.is_empty(), stdout.is_empty()) {
(false, false) => format!("{stderr}\n{stdout}"),
(false, true) => output.stderr.text.clone(),
(true, false) => output.stdout.text.clone(),
(true, true) => format!(
"command failed inside sandbox with exit code {}",
output.exit_code
),
}
}
}
// Timeouts are not sandbox errors from a UX perspective; present them plainly
LlmxErr::Sandbox(SandboxErr::Timeout { output }) => {
format!(
"error: command timed out after {} ms",
output.duration.as_millis()
)
}
_ => e.to_string(),
};
truncate_middle(&message, ERROR_MESSAGE_UI_MAX_BYTES).0
}
#[cfg(test)]
mod tests {
use super::*;
use crate::exec::StreamOutput;
use chrono::DateTime;
use chrono::Duration as ChronoDuration;
use chrono::TimeZone;
use chrono::Utc;
use llmx_protocol::protocol::RateLimitWindow;
use pretty_assertions::assert_eq;
fn rate_limit_snapshot() -> RateLimitSnapshot {
let primary_reset_at = Utc
.with_ymd_and_hms(2024, 1, 1, 1, 0, 0)
.unwrap()
.timestamp();
let secondary_reset_at = Utc
.with_ymd_and_hms(2024, 1, 1, 2, 0, 0)
.unwrap()
.timestamp();
RateLimitSnapshot {
primary: Some(RateLimitWindow {
used_percent: 50.0,
window_minutes: Some(60),
resets_at: Some(primary_reset_at),
}),
secondary: Some(RateLimitWindow {
used_percent: 30.0,
window_minutes: Some(120),
resets_at: Some(secondary_reset_at),
}),
}
}
fn with_now_override<T>(now: DateTime<Utc>, f: impl FnOnce() -> T) -> T {
NOW_OVERRIDE.with(|cell| {
*cell.borrow_mut() = Some(now);
let result = f();
*cell.borrow_mut() = None;
result
})
}
#[test]
fn usage_limit_reached_error_formats_plus_plan() {
let err = UsageLimitReachedError {
plan_type: Some(PlanType::Known(KnownPlan::Plus)),
resets_at: None,
rate_limits: Some(rate_limit_snapshot()),
};
assert_eq!(
err.to_string(),
"You've hit your usage limit. Upgrade to Pro (https://openai.com/chatgpt/pricing), visit https://chatgpt.com/llmx/settings/usage to purchase more credits or try again later."
);
}
#[test]
fn sandbox_denied_uses_aggregated_output_when_stderr_empty() {
let output = ExecToolCallOutput {
exit_code: 77,
stdout: StreamOutput::new(String::new()),
stderr: StreamOutput::new(String::new()),
aggregated_output: StreamOutput::new("aggregate detail".to_string()),
duration: Duration::from_millis(10),
timed_out: false,
};
let err = LlmxErr::Sandbox(SandboxErr::Denied {
output: Box::new(output),
});
assert_eq!(get_error_message_ui(&err), "aggregate detail");
}
#[test]
fn sandbox_denied_reports_both_streams_when_available() {
let output = ExecToolCallOutput {
exit_code: 9,
stdout: StreamOutput::new("stdout detail".to_string()),
stderr: StreamOutput::new("stderr detail".to_string()),
aggregated_output: StreamOutput::new(String::new()),
duration: Duration::from_millis(10),
timed_out: false,
};
let err = LlmxErr::Sandbox(SandboxErr::Denied {
output: Box::new(output),
});
assert_eq!(get_error_message_ui(&err), "stderr detail\nstdout detail");
}
#[test]
fn sandbox_denied_reports_stdout_when_no_stderr() {
let output = ExecToolCallOutput {
exit_code: 11,
stdout: StreamOutput::new("stdout only".to_string()),
stderr: StreamOutput::new(String::new()),
aggregated_output: StreamOutput::new(String::new()),
duration: Duration::from_millis(8),
timed_out: false,
};
let err = LlmxErr::Sandbox(SandboxErr::Denied {
output: Box::new(output),
});
assert_eq!(get_error_message_ui(&err), "stdout only");
}
#[test]
fn sandbox_denied_reports_exit_code_when_no_output_available() {
let output = ExecToolCallOutput {
exit_code: 13,
stdout: StreamOutput::new(String::new()),
stderr: StreamOutput::new(String::new()),
aggregated_output: StreamOutput::new(String::new()),
duration: Duration::from_millis(5),
timed_out: false,
};
let err = LlmxErr::Sandbox(SandboxErr::Denied {
output: Box::new(output),
});
assert_eq!(
get_error_message_ui(&err),
"command failed inside sandbox with exit code 13"
);
}
#[test]
fn usage_limit_reached_error_formats_free_plan() {
let err = UsageLimitReachedError {
plan_type: Some(PlanType::Known(KnownPlan::Free)),
resets_at: None,
rate_limits: Some(rate_limit_snapshot()),
};
assert_eq!(
err.to_string(),
"You've hit your usage limit. Upgrade to Plus to continue using LLMX (https://openai.com/chatgpt/pricing)."
);
}
#[test]
fn usage_limit_reached_error_formats_default_when_none() {
let err = UsageLimitReachedError {
plan_type: None,
resets_at: None,
rate_limits: Some(rate_limit_snapshot()),
};
assert_eq!(
err.to_string(),
"You've hit your usage limit. Try again later."
);
}
#[test]
fn usage_limit_reached_error_formats_team_plan() {
let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
let resets_at = base + ChronoDuration::hours(1);
with_now_override(base, move || {
let expected_time = format_retry_timestamp(&resets_at);
let err = UsageLimitReachedError {
plan_type: Some(PlanType::Known(KnownPlan::Team)),
resets_at: Some(resets_at),
rate_limits: Some(rate_limit_snapshot()),
};
let expected = format!(
"You've hit your usage limit. To get more access now, send a request to your admin or try again at {expected_time}."
);
assert_eq!(err.to_string(), expected);
});
}
#[test]
fn usage_limit_reached_error_formats_business_plan_without_reset() {
let err = UsageLimitReachedError {
plan_type: Some(PlanType::Known(KnownPlan::Business)),
resets_at: None,
rate_limits: Some(rate_limit_snapshot()),
};
assert_eq!(
err.to_string(),
"You've hit your usage limit. To get more access now, send a request to your admin or try again later."
);
}
#[test]
fn usage_limit_reached_error_formats_default_for_other_plans() {
let err = UsageLimitReachedError {
plan_type: Some(PlanType::Known(KnownPlan::Enterprise)),
resets_at: None,
rate_limits: Some(rate_limit_snapshot()),
};
assert_eq!(
err.to_string(),
"You've hit your usage limit. Try again later."
);
}
#[test]
fn usage_limit_reached_error_formats_pro_plan_with_reset() {
let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
let resets_at = base + ChronoDuration::hours(1);
with_now_override(base, move || {
let expected_time = format_retry_timestamp(&resets_at);
let err = UsageLimitReachedError {
plan_type: Some(PlanType::Known(KnownPlan::Pro)),
resets_at: Some(resets_at),
rate_limits: Some(rate_limit_snapshot()),
};
let expected = format!(
"You've hit your usage limit. Visit https://chatgpt.com/llmx/settings/usage to purchase more credits or try again at {expected_time}."
);
assert_eq!(err.to_string(), expected);
});
}
#[test]
fn usage_limit_reached_includes_minutes_when_available() {
let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
let resets_at = base + ChronoDuration::minutes(5);
with_now_override(base, move || {
let expected_time = format_retry_timestamp(&resets_at);
let err = UsageLimitReachedError {
plan_type: None,
resets_at: Some(resets_at),
rate_limits: Some(rate_limit_snapshot()),
};
let expected = format!("You've hit your usage limit. Try again at {expected_time}.");
assert_eq!(err.to_string(), expected);
});
}
#[test]
fn unexpected_status_cloudflare_html_is_simplified() {
let err = UnexpectedResponseError {
status: StatusCode::FORBIDDEN,
body: "<html><body>Cloudflare error: Sorry, you have been blocked</body></html>"
.to_string(),
request_id: Some("ray-id".to_string()),
};
let status = StatusCode::FORBIDDEN.to_string();
assert_eq!(
err.to_string(),
format!("{CLOUDFLARE_BLOCKED_MESSAGE} (status {status}), request id: ray-id")
);
}
#[test]
fn unexpected_status_non_html_is_unchanged() {
let err = UnexpectedResponseError {
status: StatusCode::FORBIDDEN,
body: "plain text error".to_string(),
request_id: None,
};
let status = StatusCode::FORBIDDEN.to_string();
assert_eq!(
err.to_string(),
format!("unexpected status {status}: plain text error")
);
}
#[test]
fn usage_limit_reached_includes_hours_and_minutes() {
let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
let resets_at = base + ChronoDuration::hours(3) + ChronoDuration::minutes(32);
with_now_override(base, move || {
let expected_time = format_retry_timestamp(&resets_at);
let err = UsageLimitReachedError {
plan_type: Some(PlanType::Known(KnownPlan::Plus)),
resets_at: Some(resets_at),
rate_limits: Some(rate_limit_snapshot()),
};
let expected = format!(
"You've hit your usage limit. Upgrade to Pro (https://openai.com/chatgpt/pricing), visit https://chatgpt.com/llmx/settings/usage to purchase more credits or try again at {expected_time}."
);
assert_eq!(err.to_string(), expected);
});
}
#[test]
fn usage_limit_reached_includes_days_hours_minutes() {
let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
let resets_at =
base + ChronoDuration::days(2) + ChronoDuration::hours(3) + ChronoDuration::minutes(5);
with_now_override(base, move || {
let expected_time = format_retry_timestamp(&resets_at);
let err = UsageLimitReachedError {
plan_type: None,
resets_at: Some(resets_at),
rate_limits: Some(rate_limit_snapshot()),
};
let expected = format!("You've hit your usage limit. Try again at {expected_time}.");
assert_eq!(err.to_string(), expected);
});
}
#[test]
fn usage_limit_reached_less_than_minute() {
let base = Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap();
let resets_at = base + ChronoDuration::seconds(30);
with_now_override(base, move || {
let expected_time = format_retry_timestamp(&resets_at);
let err = UsageLimitReachedError {
plan_type: None,
resets_at: Some(resets_at),
rate_limits: Some(rate_limit_snapshot()),
};
let expected = format!("You've hit your usage limit. Try again at {expected_time}.");
assert_eq!(err.to_string(), expected);
});
}
}

View File

@@ -0,0 +1,323 @@
use llmx_protocol::items::AgentMessageContent;
use llmx_protocol::items::AgentMessageItem;
use llmx_protocol::items::ReasoningItem;
use llmx_protocol::items::TurnItem;
use llmx_protocol::items::UserMessageItem;
use llmx_protocol::items::WebSearchItem;
use llmx_protocol::models::ContentItem;
use llmx_protocol::models::ReasoningItemContent;
use llmx_protocol::models::ReasoningItemReasoningSummary;
use llmx_protocol::models::ResponseItem;
use llmx_protocol::models::WebSearchAction;
use llmx_protocol::user_input::UserInput;
use tracing::warn;
use uuid::Uuid;
use crate::user_instructions::UserInstructions;
use crate::user_shell_command::is_user_shell_command_text;
fn is_session_prefix(text: &str) -> bool {
let trimmed = text.trim_start();
let lowered = trimmed.to_ascii_lowercase();
lowered.starts_with("<environment_context>")
}
fn parse_user_message(message: &[ContentItem]) -> Option<UserMessageItem> {
if UserInstructions::is_user_instructions(message) {
return None;
}
let mut content: Vec<UserInput> = Vec::new();
for content_item in message.iter() {
match content_item {
ContentItem::InputText { text } => {
if is_session_prefix(text) || is_user_shell_command_text(text) {
return None;
}
content.push(UserInput::Text { text: text.clone() });
}
ContentItem::InputImage { image_url } => {
content.push(UserInput::Image {
image_url: image_url.clone(),
});
}
ContentItem::OutputText { text } => {
if is_session_prefix(text) {
return None;
}
warn!("Output text in user message: {}", text);
}
}
}
Some(UserMessageItem::new(&content))
}
fn parse_agent_message(id: Option<&String>, message: &[ContentItem]) -> AgentMessageItem {
let mut content: Vec<AgentMessageContent> = Vec::new();
for content_item in message.iter() {
match content_item {
ContentItem::OutputText { text } => {
content.push(AgentMessageContent::Text { text: text.clone() });
}
_ => {
warn!(
"Unexpected content item in agent message: {:?}",
content_item
);
}
}
}
let id = id.cloned().unwrap_or_else(|| Uuid::new_v4().to_string());
AgentMessageItem { id, content }
}
pub fn parse_turn_item(item: &ResponseItem) -> Option<TurnItem> {
match item {
ResponseItem::Message { role, content, id } => match role.as_str() {
"user" => parse_user_message(content).map(TurnItem::UserMessage),
"assistant" => Some(TurnItem::AgentMessage(parse_agent_message(
id.as_ref(),
content,
))),
"system" => None,
_ => None,
},
ResponseItem::Reasoning {
id,
summary,
content,
..
} => {
let summary_text = summary
.iter()
.map(|entry| match entry {
ReasoningItemReasoningSummary::SummaryText { text } => text.clone(),
})
.collect();
let raw_content = content
.clone()
.unwrap_or_default()
.into_iter()
.map(|entry| match entry {
ReasoningItemContent::ReasoningText { text }
| ReasoningItemContent::Text { text } => text,
})
.collect();
Some(TurnItem::Reasoning(ReasoningItem {
id: id.clone(),
summary_text,
raw_content,
}))
}
ResponseItem::WebSearchCall {
id,
action: WebSearchAction::Search { query },
..
} => Some(TurnItem::WebSearch(WebSearchItem {
id: id.clone().unwrap_or_default(),
query: query.clone(),
})),
_ => None,
}
}
#[cfg(test)]
mod tests {
use super::parse_turn_item;
use llmx_protocol::items::AgentMessageContent;
use llmx_protocol::items::TurnItem;
use llmx_protocol::models::ContentItem;
use llmx_protocol::models::ReasoningItemContent;
use llmx_protocol::models::ReasoningItemReasoningSummary;
use llmx_protocol::models::ResponseItem;
use llmx_protocol::models::WebSearchAction;
use llmx_protocol::user_input::UserInput;
use pretty_assertions::assert_eq;
#[test]
fn parses_user_message_with_text_and_two_images() {
let img1 = "https://example.com/one.png".to_string();
let img2 = "https://example.com/two.jpg".to_string();
let item = ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![
ContentItem::InputText {
text: "Hello world".to_string(),
},
ContentItem::InputImage {
image_url: img1.clone(),
},
ContentItem::InputImage {
image_url: img2.clone(),
},
],
};
let turn_item = parse_turn_item(&item).expect("expected user message turn item");
match turn_item {
TurnItem::UserMessage(user) => {
let expected_content = vec![
UserInput::Text {
text: "Hello world".to_string(),
},
UserInput::Image { image_url: img1 },
UserInput::Image { image_url: img2 },
];
assert_eq!(user.content, expected_content);
}
other => panic!("expected TurnItem::UserMessage, got {other:?}"),
}
}
#[test]
fn skips_user_instructions_and_env() {
let items = vec![
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "<user_instructions>test_text</user_instructions>".to_string(),
}],
},
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "<environment_context>test_text</environment_context>".to_string(),
}],
},
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "# AGENTS.md instructions for test_directory\n\n<INSTRUCTIONS>\ntest_text\n</INSTRUCTIONS>".to_string(),
}],
},
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText {
text: "<user_shell_command>echo 42</user_shell_command>".to_string(),
}],
},
];
for item in items {
let turn_item = parse_turn_item(&item);
assert!(turn_item.is_none(), "expected none, got {turn_item:?}");
}
}
#[test]
fn parses_agent_message() {
let item = ResponseItem::Message {
id: Some("msg-1".to_string()),
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: "Hello from LLMX".to_string(),
}],
};
let turn_item = parse_turn_item(&item).expect("expected agent message turn item");
match turn_item {
TurnItem::AgentMessage(message) => {
let Some(AgentMessageContent::Text { text }) = message.content.first() else {
panic!("expected agent message text content");
};
assert_eq!(text, "Hello from LLMX");
}
other => panic!("expected TurnItem::AgentMessage, got {other:?}"),
}
}
#[test]
fn parses_reasoning_summary_and_raw_content() {
let item = ResponseItem::Reasoning {
id: "reasoning_1".to_string(),
summary: vec![
ReasoningItemReasoningSummary::SummaryText {
text: "Step 1".to_string(),
},
ReasoningItemReasoningSummary::SummaryText {
text: "Step 2".to_string(),
},
],
content: Some(vec![ReasoningItemContent::ReasoningText {
text: "raw details".to_string(),
}]),
encrypted_content: None,
};
let turn_item = parse_turn_item(&item).expect("expected reasoning turn item");
match turn_item {
TurnItem::Reasoning(reasoning) => {
assert_eq!(
reasoning.summary_text,
vec!["Step 1".to_string(), "Step 2".to_string()]
);
assert_eq!(reasoning.raw_content, vec!["raw details".to_string()]);
}
other => panic!("expected TurnItem::Reasoning, got {other:?}"),
}
}
#[test]
fn parses_reasoning_including_raw_content() {
let item = ResponseItem::Reasoning {
id: "reasoning_2".to_string(),
summary: vec![ReasoningItemReasoningSummary::SummaryText {
text: "Summarized step".to_string(),
}],
content: Some(vec![
ReasoningItemContent::ReasoningText {
text: "raw step".to_string(),
},
ReasoningItemContent::Text {
text: "final thought".to_string(),
},
]),
encrypted_content: None,
};
let turn_item = parse_turn_item(&item).expect("expected reasoning turn item");
match turn_item {
TurnItem::Reasoning(reasoning) => {
assert_eq!(reasoning.summary_text, vec!["Summarized step".to_string()]);
assert_eq!(
reasoning.raw_content,
vec!["raw step".to_string(), "final thought".to_string()]
);
}
other => panic!("expected TurnItem::Reasoning, got {other:?}"),
}
}
#[test]
fn parses_web_search_call() {
let item = ResponseItem::WebSearchCall {
id: Some("ws_1".to_string()),
status: Some("completed".to_string()),
action: WebSearchAction::Search {
query: "weather".to_string(),
},
};
let turn_item = parse_turn_item(&item).expect("expected web search turn item");
match turn_item {
TurnItem::WebSearch(search) => {
assert_eq!(search.id, "ws_1");
assert_eq!(search.query, "weather");
}
other => panic!("expected TurnItem::WebSearch, got {other:?}"),
}
}
}

777
llmx-rs/core/src/exec.rs Normal file
View File

@@ -0,0 +1,777 @@
#[cfg(unix)]
use std::os::unix::process::ExitStatusExt;
use std::collections::HashMap;
use std::io;
use std::path::Path;
use std::path::PathBuf;
use std::process::ExitStatus;
use std::time::Duration;
use std::time::Instant;
use async_channel::Sender;
use tokio::io::AsyncRead;
use tokio::io::AsyncReadExt;
use tokio::io::BufReader;
use tokio::process::Child;
use crate::error::LlmxErr;
use crate::error::Result;
use crate::error::SandboxErr;
use crate::protocol::Event;
use crate::protocol::EventMsg;
use crate::protocol::ExecCommandOutputDeltaEvent;
use crate::protocol::ExecOutputStream;
use crate::protocol::SandboxPolicy;
use crate::sandboxing::CommandSpec;
use crate::sandboxing::ExecEnv;
use crate::sandboxing::SandboxManager;
use crate::spawn::StdioPolicy;
use crate::spawn::spawn_child_async;
const DEFAULT_TIMEOUT_MS: u64 = 10_000;
// Hardcode these since it does not seem worth including the libc crate just
// for these.
const SIGKILL_CODE: i32 = 9;
const TIMEOUT_CODE: i32 = 64;
const EXIT_CODE_SIGNAL_BASE: i32 = 128; // conventional shell: 128 + signal
const EXEC_TIMEOUT_EXIT_CODE: i32 = 124; // conventional timeout exit code
// I/O buffer sizing
const READ_CHUNK_SIZE: usize = 8192; // bytes per read
const AGGREGATE_BUFFER_INITIAL_CAPACITY: usize = 8 * 1024; // 8 KiB
/// Limit the number of ExecCommandOutputDelta events emitted per exec call.
/// Aggregation still collects full output; only the live event stream is capped.
pub(crate) const MAX_EXEC_OUTPUT_DELTAS_PER_CALL: usize = 10_000;
#[derive(Clone, Debug)]
pub struct ExecParams {
pub command: Vec<String>,
pub cwd: PathBuf,
pub timeout_ms: Option<u64>,
pub env: HashMap<String, String>,
pub with_escalated_permissions: Option<bool>,
pub justification: Option<String>,
pub arg0: Option<String>,
}
impl ExecParams {
pub fn timeout_duration(&self) -> Duration {
Duration::from_millis(self.timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS))
}
}
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SandboxType {
None,
/// Only available on macOS.
MacosSeatbelt,
/// Only available on Linux.
LinuxSeccomp,
/// Only available on Windows.
WindowsRestrictedToken,
}
#[derive(Clone)]
pub struct StdoutStream {
pub sub_id: String,
pub call_id: String,
pub tx_event: Sender<Event>,
}
pub async fn process_exec_tool_call(
params: ExecParams,
sandbox_type: SandboxType,
sandbox_policy: &SandboxPolicy,
sandbox_cwd: &Path,
llmx_linux_sandbox_exe: &Option<PathBuf>,
stdout_stream: Option<StdoutStream>,
) -> Result<ExecToolCallOutput> {
let ExecParams {
command,
cwd,
timeout_ms,
env,
with_escalated_permissions,
justification,
arg0: _,
} = params;
let (program, args) = command.split_first().ok_or_else(|| {
LlmxErr::Io(io::Error::new(
io::ErrorKind::InvalidInput,
"command args are empty",
))
})?;
let spec = CommandSpec {
program: program.clone(),
args: args.to_vec(),
cwd,
env,
timeout_ms,
with_escalated_permissions,
justification,
};
let manager = SandboxManager::new();
let exec_env = manager
.transform(
&spec,
sandbox_policy,
sandbox_type,
sandbox_cwd,
llmx_linux_sandbox_exe.as_ref(),
)
.map_err(LlmxErr::from)?;
// Route through the sandboxing module for a single, unified execution path.
crate::sandboxing::execute_env(&exec_env, sandbox_policy, stdout_stream).await
}
pub(crate) async fn execute_exec_env(
env: ExecEnv,
sandbox_policy: &SandboxPolicy,
stdout_stream: Option<StdoutStream>,
) -> Result<ExecToolCallOutput> {
let ExecEnv {
command,
cwd,
env,
timeout_ms,
sandbox,
with_escalated_permissions,
justification,
arg0,
} = env;
let params = ExecParams {
command,
cwd,
timeout_ms,
env,
with_escalated_permissions,
justification,
arg0,
};
let start = Instant::now();
let raw_output_result = exec(params, sandbox, sandbox_policy, stdout_stream).await;
let duration = start.elapsed();
finalize_exec_result(raw_output_result, sandbox, duration)
}
#[cfg(target_os = "windows")]
async fn exec_windows_sandbox(
params: ExecParams,
sandbox_policy: &SandboxPolicy,
) -> Result<RawExecToolCallOutput> {
use crate::config::find_llmx_home;
use llmx_windows_sandbox::run_windows_sandbox_capture;
let ExecParams {
command,
cwd,
env,
timeout_ms,
..
} = params;
let policy_str = match sandbox_policy {
SandboxPolicy::DangerFullAccess => "workspace-write",
SandboxPolicy::ReadOnly => "read-only",
SandboxPolicy::WorkspaceWrite { .. } => "workspace-write",
};
let sandbox_cwd = cwd.clone();
let logs_base_dir = find_llmx_home().ok();
let spawn_res = tokio::task::spawn_blocking(move || {
run_windows_sandbox_capture(
policy_str,
&sandbox_cwd,
command,
&cwd,
env,
timeout_ms,
logs_base_dir.as_deref(),
)
})
.await;
let capture = match spawn_res {
Ok(Ok(v)) => v,
Ok(Err(err)) => {
return Err(LlmxErr::Io(io::Error::other(format!(
"windows sandbox: {err}"
))));
}
Err(join_err) => {
return Err(LlmxErr::Io(io::Error::other(format!(
"windows sandbox join error: {join_err}"
))));
}
};
let exit_status = synthetic_exit_status(capture.exit_code);
let stdout = StreamOutput {
text: capture.stdout,
truncated_after_lines: None,
};
let stderr = StreamOutput {
text: capture.stderr,
truncated_after_lines: None,
};
// Best-effort aggregate: stdout then stderr
let mut aggregated = Vec::with_capacity(stdout.text.len() + stderr.text.len());
append_all(&mut aggregated, &stdout.text);
append_all(&mut aggregated, &stderr.text);
let aggregated_output = StreamOutput {
text: aggregated,
truncated_after_lines: None,
};
Ok(RawExecToolCallOutput {
exit_status,
stdout,
stderr,
aggregated_output,
timed_out: capture.timed_out,
})
}
fn finalize_exec_result(
raw_output_result: std::result::Result<RawExecToolCallOutput, LlmxErr>,
sandbox_type: SandboxType,
duration: Duration,
) -> Result<ExecToolCallOutput> {
match raw_output_result {
Ok(raw_output) => {
#[allow(unused_mut)]
let mut timed_out = raw_output.timed_out;
#[cfg(target_family = "unix")]
{
if let Some(signal) = raw_output.exit_status.signal() {
if signal == TIMEOUT_CODE {
timed_out = true;
} else {
return Err(LlmxErr::Sandbox(SandboxErr::Signal(signal)));
}
}
}
let mut exit_code = raw_output.exit_status.code().unwrap_or(-1);
if timed_out {
exit_code = EXEC_TIMEOUT_EXIT_CODE;
}
let stdout = raw_output.stdout.from_utf8_lossy();
let stderr = raw_output.stderr.from_utf8_lossy();
let aggregated_output = raw_output.aggregated_output.from_utf8_lossy();
let exec_output = ExecToolCallOutput {
exit_code,
stdout,
stderr,
aggregated_output,
duration,
timed_out,
};
if timed_out {
return Err(LlmxErr::Sandbox(SandboxErr::Timeout {
output: Box::new(exec_output),
}));
}
if is_likely_sandbox_denied(sandbox_type, &exec_output) {
return Err(LlmxErr::Sandbox(SandboxErr::Denied {
output: Box::new(exec_output),
}));
}
Ok(exec_output)
}
Err(err) => {
tracing::error!("exec error: {err}");
Err(err)
}
}
}
pub(crate) mod errors {
use super::LlmxErr;
use crate::sandboxing::SandboxTransformError;
impl From<SandboxTransformError> for LlmxErr {
fn from(err: SandboxTransformError) -> Self {
match err {
SandboxTransformError::MissingLinuxSandboxExecutable => {
LlmxErr::LandlockSandboxExecutableNotProvided
}
#[cfg(not(target_os = "macos"))]
SandboxTransformError::SeatbeltUnavailable => LlmxErr::UnsupportedOperation(
"seatbelt sandbox is only available on macOS".to_string(),
),
}
}
}
}
/// We don't have a fully deterministic way to tell if our command failed
/// because of the sandbox - a command in the user's zshrc file might hit an
/// error, but the command itself might fail or succeed for other reasons.
/// For now, we conservatively check for well known command failure exit codes and
/// also look for common sandbox denial keywords in the command output.
pub(crate) fn is_likely_sandbox_denied(
sandbox_type: SandboxType,
exec_output: &ExecToolCallOutput,
) -> bool {
if sandbox_type == SandboxType::None || exec_output.exit_code == 0 {
return false;
}
// Quick rejects: well-known non-sandbox shell exit codes
// 2: misuse of shell builtins
// 126: permission denied
// 127: command not found
const SANDBOX_DENIED_KEYWORDS: [&str; 7] = [
"operation not permitted",
"permission denied",
"read-only file system",
"seccomp",
"sandbox",
"landlock",
"failed to write file",
];
let has_sandbox_keyword = [
&exec_output.stderr.text,
&exec_output.stdout.text,
&exec_output.aggregated_output.text,
]
.into_iter()
.any(|section| {
let lower = section.to_lowercase();
SANDBOX_DENIED_KEYWORDS
.iter()
.any(|needle| lower.contains(needle))
});
if has_sandbox_keyword {
return true;
}
const QUICK_REJECT_EXIT_CODES: [i32; 3] = [2, 126, 127];
if QUICK_REJECT_EXIT_CODES.contains(&exec_output.exit_code) {
return false;
}
#[cfg(unix)]
{
const SIGSYS_CODE: i32 = libc::SIGSYS;
if sandbox_type == SandboxType::LinuxSeccomp
&& exec_output.exit_code == EXIT_CODE_SIGNAL_BASE + SIGSYS_CODE
{
return true;
}
}
false
}
#[derive(Debug, Clone)]
pub struct StreamOutput<T: Clone> {
pub text: T,
pub truncated_after_lines: Option<u32>,
}
#[derive(Debug)]
struct RawExecToolCallOutput {
pub exit_status: ExitStatus,
pub stdout: StreamOutput<Vec<u8>>,
pub stderr: StreamOutput<Vec<u8>>,
pub aggregated_output: StreamOutput<Vec<u8>>,
pub timed_out: bool,
}
impl StreamOutput<String> {
pub fn new(text: String) -> Self {
Self {
text,
truncated_after_lines: None,
}
}
}
impl StreamOutput<Vec<u8>> {
pub fn from_utf8_lossy(&self) -> StreamOutput<String> {
StreamOutput {
text: String::from_utf8_lossy(&self.text).to_string(),
truncated_after_lines: self.truncated_after_lines,
}
}
}
#[inline]
fn append_all(dst: &mut Vec<u8>, src: &[u8]) {
dst.extend_from_slice(src);
}
#[derive(Clone, Debug)]
pub struct ExecToolCallOutput {
pub exit_code: i32,
pub stdout: StreamOutput<String>,
pub stderr: StreamOutput<String>,
pub aggregated_output: StreamOutput<String>,
pub duration: Duration,
pub timed_out: bool,
}
#[cfg_attr(not(target_os = "windows"), allow(unused_variables))]
async fn exec(
params: ExecParams,
sandbox: SandboxType,
sandbox_policy: &SandboxPolicy,
stdout_stream: Option<StdoutStream>,
) -> Result<RawExecToolCallOutput> {
#[cfg(target_os = "windows")]
if sandbox == SandboxType::WindowsRestrictedToken {
return exec_windows_sandbox(params, sandbox_policy).await;
}
let timeout = params.timeout_duration();
let ExecParams {
command,
cwd,
env,
arg0,
..
} = params;
let (program, args) = command.split_first().ok_or_else(|| {
LlmxErr::Io(io::Error::new(
io::ErrorKind::InvalidInput,
"command args are empty",
))
})?;
let arg0_ref = arg0.as_deref();
let child = spawn_child_async(
PathBuf::from(program),
args.into(),
arg0_ref,
cwd,
sandbox_policy,
StdioPolicy::RedirectForShellTool,
env,
)
.await?;
consume_truncated_output(child, timeout, stdout_stream).await
}
/// Consumes the output of a child process, truncating it so it is suitable for
/// use as the output of a `shell` tool call. Also enforces specified timeout.
async fn consume_truncated_output(
mut child: Child,
timeout: Duration,
stdout_stream: Option<StdoutStream>,
) -> Result<RawExecToolCallOutput> {
// Both stdout and stderr were configured with `Stdio::piped()`
// above, therefore `take()` should normally return `Some`. If it doesn't
// we treat it as an exceptional I/O error
let stdout_reader = child.stdout.take().ok_or_else(|| {
LlmxErr::Io(io::Error::other(
"stdout pipe was unexpectedly not available",
))
})?;
let stderr_reader = child.stderr.take().ok_or_else(|| {
LlmxErr::Io(io::Error::other(
"stderr pipe was unexpectedly not available",
))
})?;
let (agg_tx, agg_rx) = async_channel::unbounded::<Vec<u8>>();
let stdout_handle = tokio::spawn(read_capped(
BufReader::new(stdout_reader),
stdout_stream.clone(),
false,
Some(agg_tx.clone()),
));
let stderr_handle = tokio::spawn(read_capped(
BufReader::new(stderr_reader),
stdout_stream.clone(),
true,
Some(agg_tx.clone()),
));
let (exit_status, timed_out) = tokio::select! {
result = tokio::time::timeout(timeout, child.wait()) => {
match result {
Ok(status_result) => {
let exit_status = status_result?;
(exit_status, false)
}
Err(_) => {
// timeout
kill_child_process_group(&mut child)?;
child.start_kill()?;
// Debatable whether `child.wait().await` should be called here.
(synthetic_exit_status(EXIT_CODE_SIGNAL_BASE + TIMEOUT_CODE), true)
}
}
}
_ = tokio::signal::ctrl_c() => {
kill_child_process_group(&mut child)?;
child.start_kill()?;
(synthetic_exit_status(EXIT_CODE_SIGNAL_BASE + SIGKILL_CODE), false)
}
};
let stdout = stdout_handle.await??;
let stderr = stderr_handle.await??;
drop(agg_tx);
let mut combined_buf = Vec::with_capacity(AGGREGATE_BUFFER_INITIAL_CAPACITY);
while let Ok(chunk) = agg_rx.recv().await {
append_all(&mut combined_buf, &chunk);
}
let aggregated_output = StreamOutput {
text: combined_buf,
truncated_after_lines: None,
};
Ok(RawExecToolCallOutput {
exit_status,
stdout,
stderr,
aggregated_output,
timed_out,
})
}
async fn read_capped<R: AsyncRead + Unpin + Send + 'static>(
mut reader: R,
stream: Option<StdoutStream>,
is_stderr: bool,
aggregate_tx: Option<Sender<Vec<u8>>>,
) -> io::Result<StreamOutput<Vec<u8>>> {
let mut buf = Vec::with_capacity(AGGREGATE_BUFFER_INITIAL_CAPACITY);
let mut tmp = [0u8; READ_CHUNK_SIZE];
let mut emitted_deltas: usize = 0;
// No caps: append all bytes
loop {
let n = reader.read(&mut tmp).await?;
if n == 0 {
break;
}
if let Some(stream) = &stream
&& emitted_deltas < MAX_EXEC_OUTPUT_DELTAS_PER_CALL
{
let chunk = tmp[..n].to_vec();
let msg = EventMsg::ExecCommandOutputDelta(ExecCommandOutputDeltaEvent {
call_id: stream.call_id.clone(),
stream: if is_stderr {
ExecOutputStream::Stderr
} else {
ExecOutputStream::Stdout
},
chunk,
});
let event = Event {
id: stream.sub_id.clone(),
msg,
};
#[allow(clippy::let_unit_value)]
let _ = stream.tx_event.send(event).await;
emitted_deltas += 1;
}
if let Some(tx) = &aggregate_tx {
let _ = tx.send(tmp[..n].to_vec()).await;
}
append_all(&mut buf, &tmp[..n]);
// Continue reading to EOF to avoid back-pressure
}
Ok(StreamOutput {
text: buf,
truncated_after_lines: None,
})
}
#[cfg(unix)]
fn synthetic_exit_status(code: i32) -> ExitStatus {
use std::os::unix::process::ExitStatusExt;
std::process::ExitStatus::from_raw(code)
}
#[cfg(windows)]
fn synthetic_exit_status(code: i32) -> ExitStatus {
use std::os::windows::process::ExitStatusExt;
// On Windows the raw status is a u32. Use a direct cast to avoid
// panicking on negative i32 values produced by prior narrowing casts.
std::process::ExitStatus::from_raw(code as u32)
}
#[cfg(unix)]
fn kill_child_process_group(child: &mut Child) -> io::Result<()> {
use std::io::ErrorKind;
if let Some(pid) = child.id() {
let pid = pid as libc::pid_t;
let pgid = unsafe { libc::getpgid(pid) };
if pgid == -1 {
let err = std::io::Error::last_os_error();
if err.kind() != ErrorKind::NotFound {
return Err(err);
}
return Ok(());
}
let result = unsafe { libc::killpg(pgid, libc::SIGKILL) };
if result == -1 {
let err = std::io::Error::last_os_error();
if err.kind() != ErrorKind::NotFound {
return Err(err);
}
}
}
Ok(())
}
#[cfg(not(unix))]
fn kill_child_process_group(_: &mut Child) -> io::Result<()> {
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
fn make_exec_output(
exit_code: i32,
stdout: &str,
stderr: &str,
aggregated: &str,
) -> ExecToolCallOutput {
ExecToolCallOutput {
exit_code,
stdout: StreamOutput::new(stdout.to_string()),
stderr: StreamOutput::new(stderr.to_string()),
aggregated_output: StreamOutput::new(aggregated.to_string()),
duration: Duration::from_millis(1),
timed_out: false,
}
}
#[test]
fn sandbox_detection_requires_keywords() {
let output = make_exec_output(1, "", "", "");
assert!(!is_likely_sandbox_denied(
SandboxType::LinuxSeccomp,
&output
));
}
#[test]
fn sandbox_detection_identifies_keyword_in_stderr() {
let output = make_exec_output(1, "", "Operation not permitted", "");
assert!(is_likely_sandbox_denied(SandboxType::LinuxSeccomp, &output));
}
#[test]
fn sandbox_detection_respects_quick_reject_exit_codes() {
let output = make_exec_output(127, "", "command not found", "");
assert!(!is_likely_sandbox_denied(
SandboxType::LinuxSeccomp,
&output
));
}
#[test]
fn sandbox_detection_ignores_non_sandbox_mode() {
let output = make_exec_output(1, "", "Operation not permitted", "");
assert!(!is_likely_sandbox_denied(SandboxType::None, &output));
}
#[test]
fn sandbox_detection_uses_aggregated_output() {
let output = make_exec_output(
101,
"",
"",
"cargo failed: Read-only file system when writing target",
);
assert!(is_likely_sandbox_denied(
SandboxType::MacosSeatbelt,
&output
));
}
#[cfg(unix)]
#[test]
fn sandbox_detection_flags_sigsys_exit_code() {
let exit_code = EXIT_CODE_SIGNAL_BASE + libc::SIGSYS;
let output = make_exec_output(exit_code, "", "", "");
assert!(is_likely_sandbox_denied(SandboxType::LinuxSeccomp, &output));
}
#[cfg(unix)]
#[tokio::test]
async fn kill_child_process_group_kills_grandchildren_on_timeout() -> Result<()> {
let command = vec![
"/bin/bash".to_string(),
"-c".to_string(),
"sleep 60 & echo $!; sleep 60".to_string(),
];
let env: HashMap<String, String> = std::env::vars().collect();
let params = ExecParams {
command,
cwd: std::env::current_dir()?,
timeout_ms: Some(500),
env,
with_escalated_permissions: None,
justification: None,
arg0: None,
};
let output = exec(params, SandboxType::None, &SandboxPolicy::ReadOnly, None).await?;
assert!(output.timed_out);
let stdout = output.stdout.from_utf8_lossy().text;
let pid_line = stdout.lines().next().unwrap_or("").trim();
let pid: i32 = pid_line.parse().map_err(|error| {
io::Error::new(
io::ErrorKind::InvalidData,
format!("Failed to parse pid from stdout '{pid_line}': {error}"),
)
})?;
let mut killed = false;
for _ in 0..20 {
// Use kill(pid, 0) to check if the process is alive.
if unsafe { libc::kill(pid, 0) } == -1
&& let Some(libc::ESRCH) = std::io::Error::last_os_error().raw_os_error()
{
killed = true;
break;
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
assert!(killed, "grandchild process with pid {pid} is still alive");
Ok(())
}
}

View File

@@ -0,0 +1,194 @@
use crate::config::types::EnvironmentVariablePattern;
use crate::config::types::ShellEnvironmentPolicy;
use crate::config::types::ShellEnvironmentPolicyInherit;
use std::collections::HashMap;
use std::collections::HashSet;
/// Construct an environment map based on the rules in the specified policy. The
/// resulting map can be passed directly to `Command::envs()` after calling
/// `env_clear()` to ensure no unintended variables are leaked to the spawned
/// process.
///
/// The derivation follows the algorithm documented in the struct-level comment
/// for [`ShellEnvironmentPolicy`].
pub fn create_env(policy: &ShellEnvironmentPolicy) -> HashMap<String, String> {
populate_env(std::env::vars(), policy)
}
fn populate_env<I>(vars: I, policy: &ShellEnvironmentPolicy) -> HashMap<String, String>
where
I: IntoIterator<Item = (String, String)>,
{
// Step 1 determine the starting set of variables based on the
// `inherit` strategy.
let mut env_map: HashMap<String, String> = match policy.inherit {
ShellEnvironmentPolicyInherit::All => vars.into_iter().collect(),
ShellEnvironmentPolicyInherit::None => HashMap::new(),
ShellEnvironmentPolicyInherit::Core => {
const CORE_VARS: &[&str] = &[
"HOME", "LOGNAME", "PATH", "SHELL", "USER", "USERNAME", "TMPDIR", "TEMP", "TMP",
];
let allow: HashSet<&str> = CORE_VARS.iter().copied().collect();
vars.into_iter()
.filter(|(k, _)| allow.contains(k.as_str()))
.collect()
}
};
// Internal helper does `name` match **any** pattern in `patterns`?
let matches_any = |name: &str, patterns: &[EnvironmentVariablePattern]| -> bool {
patterns.iter().any(|pattern| pattern.matches(name))
};
// Step 2 Apply the default exclude if not disabled.
if !policy.ignore_default_excludes {
let default_excludes = vec![
EnvironmentVariablePattern::new_case_insensitive("*KEY*"),
EnvironmentVariablePattern::new_case_insensitive("*SECRET*"),
EnvironmentVariablePattern::new_case_insensitive("*TOKEN*"),
];
env_map.retain(|k, _| !matches_any(k, &default_excludes));
}
// Step 3 Apply custom excludes.
if !policy.exclude.is_empty() {
env_map.retain(|k, _| !matches_any(k, &policy.exclude));
}
// Step 4 Apply user-provided overrides.
for (key, val) in &policy.r#set {
env_map.insert(key.clone(), val.clone());
}
// Step 5 If include_only is non-empty, keep *only* the matching vars.
if !policy.include_only.is_empty() {
env_map.retain(|k, _| matches_any(k, &policy.include_only));
}
env_map
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::types::ShellEnvironmentPolicyInherit;
use maplit::hashmap;
fn make_vars(pairs: &[(&str, &str)]) -> Vec<(String, String)> {
pairs
.iter()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect()
}
#[test]
fn test_core_inherit_and_default_excludes() {
let vars = make_vars(&[
("PATH", "/usr/bin"),
("HOME", "/home/user"),
("API_KEY", "secret"),
("SECRET_TOKEN", "t"),
]);
let policy = ShellEnvironmentPolicy::default(); // inherit Core, default excludes on
let result = populate_env(vars, &policy);
let expected: HashMap<String, String> = hashmap! {
"PATH".to_string() => "/usr/bin".to_string(),
"HOME".to_string() => "/home/user".to_string(),
};
assert_eq!(result, expected);
}
#[test]
fn test_include_only() {
let vars = make_vars(&[("PATH", "/usr/bin"), ("FOO", "bar")]);
let policy = ShellEnvironmentPolicy {
// skip default excludes so nothing is removed prematurely
ignore_default_excludes: true,
include_only: vec![EnvironmentVariablePattern::new_case_insensitive("*PATH")],
..Default::default()
};
let result = populate_env(vars, &policy);
let expected: HashMap<String, String> = hashmap! {
"PATH".to_string() => "/usr/bin".to_string(),
};
assert_eq!(result, expected);
}
#[test]
fn test_set_overrides() {
let vars = make_vars(&[("PATH", "/usr/bin")]);
let mut policy = ShellEnvironmentPolicy {
ignore_default_excludes: true,
..Default::default()
};
policy.r#set.insert("NEW_VAR".to_string(), "42".to_string());
let result = populate_env(vars, &policy);
let expected: HashMap<String, String> = hashmap! {
"PATH".to_string() => "/usr/bin".to_string(),
"NEW_VAR".to_string() => "42".to_string(),
};
assert_eq!(result, expected);
}
#[test]
fn test_inherit_all() {
let vars = make_vars(&[("PATH", "/usr/bin"), ("FOO", "bar")]);
let policy = ShellEnvironmentPolicy {
inherit: ShellEnvironmentPolicyInherit::All,
ignore_default_excludes: true, // keep everything
..Default::default()
};
let result = populate_env(vars.clone(), &policy);
let expected: HashMap<String, String> = vars.into_iter().collect();
assert_eq!(result, expected);
}
#[test]
fn test_inherit_all_with_default_excludes() {
let vars = make_vars(&[("PATH", "/usr/bin"), ("API_KEY", "secret")]);
let policy = ShellEnvironmentPolicy {
inherit: ShellEnvironmentPolicyInherit::All,
..Default::default()
};
let result = populate_env(vars, &policy);
let expected: HashMap<String, String> = hashmap! {
"PATH".to_string() => "/usr/bin".to_string(),
};
assert_eq!(result, expected);
}
#[test]
fn test_inherit_none() {
let vars = make_vars(&[("PATH", "/usr/bin"), ("HOME", "/home")]);
let mut policy = ShellEnvironmentPolicy {
inherit: ShellEnvironmentPolicyInherit::None,
ignore_default_excludes: true,
..Default::default()
};
policy
.r#set
.insert("ONLY_VAR".to_string(), "yes".to_string());
let result = populate_env(vars, &policy);
let expected: HashMap<String, String> = hashmap! {
"ONLY_VAR".to_string() => "yes".to_string(),
};
assert_eq!(result, expected);
}
}

View File

@@ -0,0 +1,295 @@
//! Centralized feature flags and metadata.
//!
//! This module defines a small set of toggles that gate experimental and
//! optional behavior across the codebase. Instead of wiring individual
//! booleans through multiple types, call sites consult a single `Features`
//! container attached to `Config`.
use crate::config::ConfigToml;
use crate::config::profile::ConfigProfile;
use serde::Deserialize;
use std::collections::BTreeMap;
use std::collections::BTreeSet;
mod legacy;
pub(crate) use legacy::LegacyFeatureToggles;
/// High-level lifecycle stage for a feature.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Stage {
Experimental,
Beta,
Stable,
Deprecated,
Removed,
}
/// Unique features toggled via configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Feature {
/// Use the single unified PTY-backed exec tool.
UnifiedExec,
/// Enable experimental RMCP features such as OAuth login.
RmcpClient,
/// Include the freeform apply_patch tool.
ApplyPatchFreeform,
/// Include the view_image tool.
ViewImageTool,
/// Allow the model to request web searches.
WebSearchRequest,
/// Enable the model-based risk assessments for sandboxed commands.
SandboxCommandAssessment,
/// Create a ghost commit at each turn.
GhostCommit,
/// Enable Windows sandbox (restricted token) on Windows.
WindowsSandbox,
}
impl Feature {
pub fn key(self) -> &'static str {
self.info().key
}
pub fn stage(self) -> Stage {
self.info().stage
}
pub fn default_enabled(self) -> bool {
self.info().default_enabled
}
fn info(self) -> &'static FeatureSpec {
FEATURES
.iter()
.find(|spec| spec.id == self)
.unwrap_or_else(|| unreachable!("missing FeatureSpec for {:?}", self))
}
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct LegacyFeatureUsage {
pub alias: String,
pub feature: Feature,
}
/// Holds the effective set of enabled features.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct Features {
enabled: BTreeSet<Feature>,
legacy_usages: BTreeSet<LegacyFeatureUsage>,
}
#[derive(Debug, Clone, Default)]
pub struct FeatureOverrides {
pub include_apply_patch_tool: Option<bool>,
pub web_search_request: Option<bool>,
pub experimental_sandbox_command_assessment: Option<bool>,
}
impl FeatureOverrides {
fn apply(self, features: &mut Features) {
LegacyFeatureToggles {
include_apply_patch_tool: self.include_apply_patch_tool,
tools_web_search: self.web_search_request,
..Default::default()
}
.apply(features);
}
}
impl Features {
/// Starts with built-in defaults.
pub fn with_defaults() -> Self {
let mut set = BTreeSet::new();
for spec in FEATURES {
if spec.default_enabled {
set.insert(spec.id);
}
}
Self {
enabled: set,
legacy_usages: BTreeSet::new(),
}
}
pub fn enabled(&self, f: Feature) -> bool {
self.enabled.contains(&f)
}
pub fn enable(&mut self, f: Feature) -> &mut Self {
self.enabled.insert(f);
self
}
pub fn disable(&mut self, f: Feature) -> &mut Self {
self.enabled.remove(&f);
self
}
pub fn record_legacy_usage_force(&mut self, alias: &str, feature: Feature) {
self.legacy_usages.insert(LegacyFeatureUsage {
alias: alias.to_string(),
feature,
});
}
pub fn record_legacy_usage(&mut self, alias: &str, feature: Feature) {
if alias == feature.key() {
return;
}
self.record_legacy_usage_force(alias, feature);
}
pub fn legacy_feature_usages(&self) -> impl Iterator<Item = (&str, Feature)> + '_ {
self.legacy_usages
.iter()
.map(|usage| (usage.alias.as_str(), usage.feature))
}
/// Apply a table of key -> bool toggles (e.g. from TOML).
pub fn apply_map(&mut self, m: &BTreeMap<String, bool>) {
for (k, v) in m {
match feature_for_key(k) {
Some(feat) => {
if k != feat.key() {
self.record_legacy_usage(k.as_str(), feat);
}
if *v {
self.enable(feat);
} else {
self.disable(feat);
}
}
None => {
tracing::warn!("unknown feature key in config: {k}");
}
}
}
}
pub fn from_config(
cfg: &ConfigToml,
config_profile: &ConfigProfile,
overrides: FeatureOverrides,
) -> Self {
let mut features = Features::with_defaults();
let base_legacy = LegacyFeatureToggles {
experimental_sandbox_command_assessment: cfg.experimental_sandbox_command_assessment,
experimental_use_freeform_apply_patch: cfg.experimental_use_freeform_apply_patch,
experimental_use_unified_exec_tool: cfg.experimental_use_unified_exec_tool,
experimental_use_rmcp_client: cfg.experimental_use_rmcp_client,
tools_web_search: cfg.tools.as_ref().and_then(|t| t.web_search),
tools_view_image: cfg.tools.as_ref().and_then(|t| t.view_image),
..Default::default()
};
base_legacy.apply(&mut features);
if let Some(base_features) = cfg.features.as_ref() {
features.apply_map(&base_features.entries);
}
let profile_legacy = LegacyFeatureToggles {
include_apply_patch_tool: config_profile.include_apply_patch_tool,
experimental_sandbox_command_assessment: config_profile
.experimental_sandbox_command_assessment,
experimental_use_freeform_apply_patch: config_profile
.experimental_use_freeform_apply_patch,
experimental_use_unified_exec_tool: config_profile.experimental_use_unified_exec_tool,
experimental_use_rmcp_client: config_profile.experimental_use_rmcp_client,
tools_web_search: config_profile.tools_web_search,
tools_view_image: config_profile.tools_view_image,
};
profile_legacy.apply(&mut features);
if let Some(profile_features) = config_profile.features.as_ref() {
features.apply_map(&profile_features.entries);
}
overrides.apply(&mut features);
features
}
}
/// Keys accepted in `[features]` tables.
fn feature_for_key(key: &str) -> Option<Feature> {
for spec in FEATURES {
if spec.key == key {
return Some(spec.id);
}
}
legacy::feature_for_key(key)
}
/// Returns `true` if the provided string matches a known feature toggle key.
pub fn is_known_feature_key(key: &str) -> bool {
feature_for_key(key).is_some()
}
/// Deserializable features table for TOML.
#[derive(Deserialize, Debug, Clone, Default, PartialEq)]
pub struct FeaturesToml {
#[serde(flatten)]
pub entries: BTreeMap<String, bool>,
}
/// Single, easy-to-read registry of all feature definitions.
#[derive(Debug, Clone, Copy)]
pub struct FeatureSpec {
pub id: Feature,
pub key: &'static str,
pub stage: Stage,
pub default_enabled: bool,
}
pub const FEATURES: &[FeatureSpec] = &[
FeatureSpec {
id: Feature::UnifiedExec,
key: "unified_exec",
stage: Stage::Experimental,
default_enabled: false,
},
FeatureSpec {
id: Feature::RmcpClient,
key: "rmcp_client",
stage: Stage::Experimental,
default_enabled: false,
},
FeatureSpec {
id: Feature::ApplyPatchFreeform,
key: "apply_patch_freeform",
stage: Stage::Beta,
default_enabled: false,
},
FeatureSpec {
id: Feature::ViewImageTool,
key: "view_image_tool",
stage: Stage::Stable,
default_enabled: true,
},
FeatureSpec {
id: Feature::WebSearchRequest,
key: "web_search_request",
stage: Stage::Stable,
default_enabled: false,
},
FeatureSpec {
id: Feature::SandboxCommandAssessment,
key: "experimental_sandbox_command_assessment",
stage: Stage::Experimental,
default_enabled: false,
},
FeatureSpec {
id: Feature::GhostCommit,
key: "ghost_commit",
stage: Stage::Experimental,
default_enabled: true,
},
FeatureSpec {
id: Feature::WindowsSandbox,
key: "enable_experimental_windows_sandbox",
stage: Stage::Experimental,
default_enabled: false,
},
];

View File

@@ -0,0 +1,137 @@
use super::Feature;
use super::Features;
use tracing::info;
#[derive(Clone, Copy)]
struct Alias {
legacy_key: &'static str,
feature: Feature,
}
const ALIASES: &[Alias] = &[
Alias {
legacy_key: "experimental_sandbox_command_assessment",
feature: Feature::SandboxCommandAssessment,
},
Alias {
legacy_key: "experimental_use_unified_exec_tool",
feature: Feature::UnifiedExec,
},
Alias {
legacy_key: "experimental_use_rmcp_client",
feature: Feature::RmcpClient,
},
Alias {
legacy_key: "experimental_use_freeform_apply_patch",
feature: Feature::ApplyPatchFreeform,
},
Alias {
legacy_key: "include_apply_patch_tool",
feature: Feature::ApplyPatchFreeform,
},
Alias {
legacy_key: "web_search",
feature: Feature::WebSearchRequest,
},
];
pub(crate) fn feature_for_key(key: &str) -> Option<Feature> {
ALIASES
.iter()
.find(|alias| alias.legacy_key == key)
.map(|alias| {
log_alias(alias.legacy_key, alias.feature);
alias.feature
})
}
#[derive(Debug, Default)]
pub struct LegacyFeatureToggles {
pub include_apply_patch_tool: Option<bool>,
pub experimental_sandbox_command_assessment: Option<bool>,
pub experimental_use_freeform_apply_patch: Option<bool>,
pub experimental_use_unified_exec_tool: Option<bool>,
pub experimental_use_rmcp_client: Option<bool>,
pub tools_web_search: Option<bool>,
pub tools_view_image: Option<bool>,
}
impl LegacyFeatureToggles {
pub fn apply(self, features: &mut Features) {
set_if_some(
features,
Feature::ApplyPatchFreeform,
self.include_apply_patch_tool,
"include_apply_patch_tool",
);
set_if_some(
features,
Feature::SandboxCommandAssessment,
self.experimental_sandbox_command_assessment,
"experimental_sandbox_command_assessment",
);
set_if_some(
features,
Feature::ApplyPatchFreeform,
self.experimental_use_freeform_apply_patch,
"experimental_use_freeform_apply_patch",
);
set_if_some(
features,
Feature::UnifiedExec,
self.experimental_use_unified_exec_tool,
"experimental_use_unified_exec_tool",
);
set_if_some(
features,
Feature::RmcpClient,
self.experimental_use_rmcp_client,
"experimental_use_rmcp_client",
);
set_if_some(
features,
Feature::WebSearchRequest,
self.tools_web_search,
"tools.web_search",
);
set_if_some(
features,
Feature::ViewImageTool,
self.tools_view_image,
"tools.view_image",
);
}
}
fn set_if_some(
features: &mut Features,
feature: Feature,
maybe_value: Option<bool>,
alias_key: &'static str,
) {
if let Some(enabled) = maybe_value {
set_feature(features, feature, enabled);
log_alias(alias_key, feature);
features.record_legacy_usage(alias_key, feature);
}
}
fn set_feature(features: &mut Features, feature: Feature, enabled: bool) {
if enabled {
features.enable(feature);
} else {
features.disable(feature);
}
}
fn log_alias(alias: &str, feature: Feature) {
let canonical = feature.key();
if alias == canonical {
return;
}
info!(
%alias,
canonical,
"legacy feature toggle detected; prefer `[features].{canonical}`"
);
}

View File

@@ -0,0 +1,6 @@
use env_flags::env_flags;
env_flags! {
/// Fixture path for offline tests (see client.rs).
pub LLMX_RS_SSE_FIXTURE: Option<&str> = None;
}

View File

@@ -0,0 +1,14 @@
use thiserror::Error;
#[derive(Debug, Error, PartialEq)]
pub enum FunctionCallError {
#[error("{0}")]
RespondToModel(String),
#[error("{0}")]
#[allow(dead_code)] // TODO(jif) fix in a follow-up PR
Denied(String),
#[error("LocalShellCall without call_id or id")]
MissingLocalShellCallId,
#[error("Fatal error: {0}")]
Fatal(String),
}

1124
llmx-rs/core/src/git_info.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,72 @@
use crate::protocol::SandboxPolicy;
use crate::spawn::StdioPolicy;
use crate::spawn::spawn_child_async;
use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
use tokio::process::Child;
/// Spawn a shell tool command under the Linux Landlock+seccomp sandbox helper
/// (llmx-linux-sandbox).
///
/// Unlike macOS Seatbelt where we directly embed the policy text, the Linux
/// helper accepts a list of `--sandbox-permission`/`-s` flags mirroring the
/// public CLI. We convert the internal [`SandboxPolicy`] representation into
/// the equivalent CLI options.
pub async fn spawn_command_under_linux_sandbox<P>(
llmx_linux_sandbox_exe: P,
command: Vec<String>,
command_cwd: PathBuf,
sandbox_policy: &SandboxPolicy,
sandbox_policy_cwd: &Path,
stdio_policy: StdioPolicy,
env: HashMap<String, String>,
) -> std::io::Result<Child>
where
P: AsRef<Path>,
{
let args = create_linux_sandbox_command_args(command, sandbox_policy, sandbox_policy_cwd);
let arg0 = Some("llmx-linux-sandbox");
spawn_child_async(
llmx_linux_sandbox_exe.as_ref().to_path_buf(),
args,
arg0,
command_cwd,
sandbox_policy,
stdio_policy,
env,
)
.await
}
/// Converts the sandbox policy into the CLI invocation for `llmx-linux-sandbox`.
pub(crate) fn create_linux_sandbox_command_args(
command: Vec<String>,
sandbox_policy: &SandboxPolicy,
sandbox_policy_cwd: &Path,
) -> Vec<String> {
#[expect(clippy::expect_used)]
let sandbox_policy_cwd = sandbox_policy_cwd
.to_str()
.expect("cwd must be valid UTF-8")
.to_string();
#[expect(clippy::expect_used)]
let sandbox_policy_json =
serde_json::to_string(sandbox_policy).expect("Failed to serialize SandboxPolicy to JSON");
let mut linux_cmd: Vec<String> = vec![
"--sandbox-policy-cwd".to_string(),
sandbox_policy_cwd,
"--sandbox-policy".to_string(),
sandbox_policy_json,
// Separator so that command arguments starting with `-` are not parsed as
// options of the helper itself.
"--".to_string(),
];
// Append the original tool command.
linux_cmd.extend(command);
linux_cmd
}

111
llmx-rs/core/src/lib.rs Normal file
View File

@@ -0,0 +1,111 @@
//! Root of the `llmx-core` library.
// Prevent accidental direct writes to stdout/stderr in library code. All
// user-visible output must go through the appropriate abstraction (e.g.,
// the TUI or the tracing stack).
#![deny(clippy::print_stdout, clippy::print_stderr)]
mod apply_patch;
pub mod auth;
pub mod bash;
mod chat_completions;
mod client;
mod client_common;
pub mod llmx;
mod llmx_conversation;
pub use llmx_conversation::LlmxConversation;
mod command_safety;
pub mod config;
pub mod config_loader;
mod context_manager;
pub mod custom_prompts;
mod environment_context;
pub mod error;
pub mod exec;
pub mod exec_env;
pub mod features;
mod flags;
pub mod git_info;
pub mod landlock;
mod llmx_delegate;
pub mod mcp;
mod mcp_connection_manager;
mod mcp_tool_call;
mod message_history;
mod model_provider_info;
pub mod parse_command;
mod response_processing;
pub mod sandboxing;
pub mod token_data;
mod truncate;
mod unified_exec;
mod user_instructions;
pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID;
pub use model_provider_info::ModelProviderInfo;
pub use model_provider_info::WireApi;
pub use model_provider_info::built_in_model_providers;
pub use model_provider_info::create_oss_provider_with_base_url;
mod conversation_manager;
mod event_mapping;
pub mod review_format;
pub use conversation_manager::ConversationManager;
pub use conversation_manager::NewConversation;
pub use llmx_protocol::protocol::InitialHistory;
// Re-export common auth types for workspace consumers
pub use auth::AuthManager;
pub use auth::LlmxAuth;
pub mod default_client;
pub mod model_family;
mod openai_model_info;
pub mod project_doc;
mod rollout;
pub(crate) mod safety;
pub mod seatbelt;
pub mod shell;
pub mod spawn;
pub mod terminal;
mod tools;
pub mod turn_diff_tracker;
pub use rollout::ARCHIVED_SESSIONS_SUBDIR;
pub use rollout::INTERACTIVE_SESSION_SOURCES;
pub use rollout::RolloutRecorder;
pub use rollout::SESSIONS_SUBDIR;
pub use rollout::SessionMeta;
pub use rollout::find_conversation_path_by_id_str;
pub use rollout::list::ConversationItem;
pub use rollout::list::ConversationsPage;
pub use rollout::list::Cursor;
pub use rollout::list::parse_cursor;
pub use rollout::list::read_head_for_summary;
mod function_tool;
mod state;
mod tasks;
mod user_notification;
mod user_shell_command;
pub mod util;
pub use apply_patch::LLMX_APPLY_PATCH_ARG1;
pub use command_safety::is_safe_command;
pub use safety::get_platform_sandbox;
pub use safety::set_windows_sandbox_enabled;
// Re-export the protocol types from the standalone `llmx-protocol` crate so existing
// `llmx_core::protocol::...` references continue to work across the workspace.
pub use llmx_protocol::protocol;
// Re-export protocol config enums to ensure call sites can use the same types
// as those in the protocol crate when constructing protocol messages.
pub use llmx_protocol::config_types as protocol_config_types;
pub use client::ModelClient;
pub use client_common::Prompt;
pub use client_common::REVIEW_PROMPT;
pub use client_common::ResponseEvent;
pub use client_common::ResponseStream;
pub use compact::content_items_to_text;
pub use event_mapping::parse_turn_item;
pub use llmx_protocol::models::ContentItem;
pub use llmx_protocol::models::LocalShellAction;
pub use llmx_protocol::models::LocalShellExecAction;
pub use llmx_protocol::models::LocalShellStatus;
pub use llmx_protocol::models::ResponseItem;
pub mod compact;
pub mod otel_init;

3150
llmx-rs/core/src/llmx.rs Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,36 @@
use crate::error::Result as LlmxResult;
use crate::llmx::Llmx;
use crate::protocol::Event;
use crate::protocol::Op;
use crate::protocol::Submission;
use std::path::PathBuf;
pub struct LlmxConversation {
llmx: Llmx,
rollout_path: PathBuf,
}
/// Conduit for the bidirectional stream of messages that compose a conversation
/// in Llmx.
impl LlmxConversation {
pub(crate) fn new(llmx: Llmx, rollout_path: PathBuf) -> Self {
Self { llmx, rollout_path }
}
pub async fn submit(&self, op: Op) -> LlmxResult<String> {
self.llmx.submit(op).await
}
/// Use sparingly: this is intended to be removed soon.
pub async fn submit_with_id(&self, sub: Submission) -> LlmxResult<()> {
self.llmx.submit_with_id(sub).await
}
pub async fn next_event(&self) -> LlmxResult<Event> {
self.llmx.next_event().await
}
pub fn rollout_path(&self) -> PathBuf {
self.rollout_path.clone()
}
}

View File

@@ -0,0 +1,300 @@
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use async_channel::Receiver;
use async_channel::Sender;
use llmx_async_utils::OrCancelExt;
use llmx_protocol::protocol::ApplyPatchApprovalRequestEvent;
use llmx_protocol::protocol::Event;
use llmx_protocol::protocol::EventMsg;
use llmx_protocol::protocol::ExecApprovalRequestEvent;
use llmx_protocol::protocol::Op;
use llmx_protocol::protocol::SessionSource;
use llmx_protocol::protocol::SubAgentSource;
use llmx_protocol::protocol::Submission;
use llmx_protocol::user_input::UserInput;
use tokio_util::sync::CancellationToken;
use crate::AuthManager;
use crate::config::Config;
use crate::error::LlmxErr;
use crate::llmx::Llmx;
use crate::llmx::LlmxSpawnOk;
use crate::llmx::SUBMISSION_CHANNEL_CAPACITY;
use crate::llmx::Session;
use crate::llmx::TurnContext;
use llmx_protocol::protocol::InitialHistory;
/// Start an interactive sub-Llmx conversation and return IO channels.
///
/// The returned `events_rx` yields non-approval events emitted by the sub-agent.
/// Approval requests are handled via `parent_session` and are not surfaced.
/// The returned `ops_tx` allows the caller to submit additional `Op`s to the sub-agent.
pub(crate) async fn run_llmx_conversation_interactive(
config: Config,
auth_manager: Arc<AuthManager>,
parent_session: Arc<Session>,
parent_ctx: Arc<TurnContext>,
cancel_token: CancellationToken,
initial_history: Option<InitialHistory>,
) -> Result<Llmx, LlmxErr> {
let (tx_sub, rx_sub) = async_channel::bounded(SUBMISSION_CHANNEL_CAPACITY);
let (tx_ops, rx_ops) = async_channel::bounded(SUBMISSION_CHANNEL_CAPACITY);
let LlmxSpawnOk { llmx, .. } = Llmx::spawn(
config,
auth_manager,
initial_history.unwrap_or(InitialHistory::New),
SessionSource::SubAgent(SubAgentSource::Review),
)
.await?;
let llmx = Arc::new(llmx);
// Use a child token so parent cancel cascades but we can scope it to this task
let cancel_token_events = cancel_token.child_token();
let cancel_token_ops = cancel_token.child_token();
// Forward events from the sub-agent to the consumer, filtering approvals and
// routing them to the parent session for decisions.
let parent_session_clone = Arc::clone(&parent_session);
let parent_ctx_clone = Arc::clone(&parent_ctx);
let llmx_for_events = Arc::clone(&llmx);
tokio::spawn(async move {
let _ = forward_events(
llmx_for_events,
tx_sub,
parent_session_clone,
parent_ctx_clone,
cancel_token_events.clone(),
)
.or_cancel(&cancel_token_events)
.await;
});
// Forward ops from the caller to the sub-agent.
let llmx_for_ops = Arc::clone(&llmx);
tokio::spawn(async move {
forward_ops(llmx_for_ops, rx_ops, cancel_token_ops).await;
});
Ok(Llmx {
next_id: AtomicU64::new(0),
tx_sub: tx_ops,
rx_event: rx_sub,
})
}
/// Convenience wrapper for one-time use with an initial prompt.
///
/// Internally calls the interactive variant, then immediately submits the provided input.
pub(crate) async fn run_llmx_conversation_one_shot(
config: Config,
auth_manager: Arc<AuthManager>,
input: Vec<UserInput>,
parent_session: Arc<Session>,
parent_ctx: Arc<TurnContext>,
cancel_token: CancellationToken,
initial_history: Option<InitialHistory>,
) -> Result<Llmx, LlmxErr> {
// Use a child token so we can stop the delegate after completion without
// requiring the caller to cancel the parent token.
let child_cancel = cancel_token.child_token();
let io = run_llmx_conversation_interactive(
config,
auth_manager,
parent_session,
parent_ctx,
child_cancel.clone(),
initial_history,
)
.await?;
// Send the initial input to kick off the one-shot turn.
io.submit(Op::UserInput { items: input }).await?;
// Bridge events so we can observe completion and shut down automatically.
let (tx_bridge, rx_bridge) = async_channel::bounded(SUBMISSION_CHANNEL_CAPACITY);
let ops_tx = io.tx_sub.clone();
let io_for_bridge = io;
tokio::spawn(async move {
while let Ok(event) = io_for_bridge.next_event().await {
let should_shutdown = matches!(
event.msg,
EventMsg::TaskComplete(_) | EventMsg::TurnAborted(_)
);
let _ = tx_bridge.send(event).await;
if should_shutdown {
let _ = ops_tx
.send(Submission {
id: "shutdown".to_string(),
op: Op::Shutdown {},
})
.await;
child_cancel.cancel();
break;
}
}
});
// For one-shot usage, return a closed `tx_sub` so callers cannot submit
// additional ops after the initial request. Create a channel and drop the
// receiver to close it immediately.
let (tx_closed, rx_closed) = async_channel::bounded(SUBMISSION_CHANNEL_CAPACITY);
drop(rx_closed);
Ok(Llmx {
next_id: AtomicU64::new(0),
rx_event: rx_bridge,
tx_sub: tx_closed,
})
}
async fn forward_events(
llmx: Arc<Llmx>,
tx_sub: Sender<Event>,
parent_session: Arc<Session>,
parent_ctx: Arc<TurnContext>,
cancel_token: CancellationToken,
) {
while let Ok(event) = llmx.next_event().await {
match event {
// ignore all legacy delta events
Event {
id: _,
msg: EventMsg::AgentMessageDelta(_) | EventMsg::AgentReasoningDelta(_),
} => continue,
Event {
id: _,
msg: EventMsg::SessionConfigured(_),
} => continue,
Event {
id,
msg: EventMsg::ExecApprovalRequest(event),
} => {
// Initiate approval via parent session; do not surface to consumer.
handle_exec_approval(
&llmx,
id,
&parent_session,
&parent_ctx,
event,
&cancel_token,
)
.await;
}
Event {
id,
msg: EventMsg::ApplyPatchApprovalRequest(event),
} => {
handle_patch_approval(
&llmx,
id,
&parent_session,
&parent_ctx,
event,
&cancel_token,
)
.await;
}
other => {
let _ = tx_sub.send(other).await;
}
}
}
}
/// Forward ops from a caller to a sub-agent, respecting cancellation.
async fn forward_ops(
llmx: Arc<Llmx>,
rx_ops: Receiver<Submission>,
cancel_token_ops: CancellationToken,
) {
loop {
let op: Op = match rx_ops.recv().or_cancel(&cancel_token_ops).await {
Ok(Ok(Submission { id: _, op })) => op,
Ok(Err(_)) | Err(_) => break,
};
let _ = llmx.submit(op).await;
}
}
/// Handle an ExecApprovalRequest by consulting the parent session and replying.
async fn handle_exec_approval(
llmx: &Llmx,
id: String,
parent_session: &Session,
parent_ctx: &TurnContext,
event: ExecApprovalRequestEvent,
cancel_token: &CancellationToken,
) {
// Race approval with cancellation and timeout to avoid hangs.
let approval_fut = parent_session.request_command_approval(
parent_ctx,
parent_ctx.sub_id.clone(),
event.command,
event.cwd,
event.reason,
event.risk,
);
let decision = await_approval_with_cancel(
approval_fut,
parent_session,
&parent_ctx.sub_id,
cancel_token,
)
.await;
let _ = llmx.submit(Op::ExecApproval { id, decision }).await;
}
/// Handle an ApplyPatchApprovalRequest by consulting the parent session and replying.
async fn handle_patch_approval(
llmx: &Llmx,
id: String,
parent_session: &Session,
parent_ctx: &TurnContext,
event: ApplyPatchApprovalRequestEvent,
cancel_token: &CancellationToken,
) {
let decision_rx = parent_session
.request_patch_approval(
parent_ctx,
parent_ctx.sub_id.clone(),
event.changes,
event.reason,
event.grant_root,
)
.await;
let decision = await_approval_with_cancel(
async move { decision_rx.await.unwrap_or_default() },
parent_session,
&parent_ctx.sub_id,
cancel_token,
)
.await;
let _ = llmx.submit(Op::PatchApproval { id, decision }).await;
}
/// Await an approval decision, aborting on cancellation.
async fn await_approval_with_cancel<F>(
fut: F,
parent_session: &Session,
sub_id: &str,
cancel_token: &CancellationToken,
) -> llmx_protocol::protocol::ReviewDecision
where
F: core::future::Future<Output = llmx_protocol::protocol::ReviewDecision>,
{
tokio::select! {
biased;
_ = cancel_token.cancelled() => {
parent_session
.notify_approval(sub_id, llmx_protocol::protocol::ReviewDecision::Abort)
.await;
llmx_protocol::protocol::ReviewDecision::Abort
}
decision = fut => {
decision
}
}
}

View File

@@ -0,0 +1,72 @@
use std::collections::HashMap;
use anyhow::Result;
use futures::future::join_all;
use llmx_protocol::protocol::McpAuthStatus;
use llmx_rmcp_client::OAuthCredentialsStoreMode;
use llmx_rmcp_client::determine_streamable_http_auth_status;
use tracing::warn;
use crate::config::types::McpServerConfig;
use crate::config::types::McpServerTransportConfig;
#[derive(Debug, Clone)]
pub struct McpAuthStatusEntry {
pub config: McpServerConfig,
pub auth_status: McpAuthStatus,
}
pub async fn compute_auth_statuses<'a, I>(
servers: I,
store_mode: OAuthCredentialsStoreMode,
) -> HashMap<String, McpAuthStatusEntry>
where
I: IntoIterator<Item = (&'a String, &'a McpServerConfig)>,
{
let futures = servers.into_iter().map(|(name, config)| {
let name = name.clone();
let config = config.clone();
async move {
let auth_status = match compute_auth_status(&name, &config, store_mode).await {
Ok(status) => status,
Err(error) => {
warn!("failed to determine auth status for MCP server `{name}`: {error:?}");
McpAuthStatus::Unsupported
}
};
let entry = McpAuthStatusEntry {
config,
auth_status,
};
(name, entry)
}
});
join_all(futures).await.into_iter().collect()
}
async fn compute_auth_status(
server_name: &str,
config: &McpServerConfig,
store_mode: OAuthCredentialsStoreMode,
) -> Result<McpAuthStatus> {
match &config.transport {
McpServerTransportConfig::Stdio { .. } => Ok(McpAuthStatus::Unsupported),
McpServerTransportConfig::StreamableHttp {
url,
bearer_token_env_var,
http_headers,
env_http_headers,
} => {
determine_streamable_http_auth_status(
server_name,
url,
bearer_token_env_var.as_deref(),
http_headers.clone(),
env_http_headers.clone(),
store_mode,
)
.await
}
}
}

View File

@@ -0,0 +1 @@
pub mod auth;

View File

@@ -0,0 +1,822 @@
//! Connection manager for Model Context Protocol (MCP) servers.
//!
//! The [`McpConnectionManager`] owns one [`llmx_rmcp_client::RmcpClient`] per
//! configured server (keyed by the *server name*). It offers convenience
//! helpers to query the available tools across *all* servers and returns them
//! in a single aggregated map using the fully-qualified tool name
//! `"<server><MCP_TOOL_NAME_DELIMITER><tool>"` as the key.
use std::collections::HashMap;
use std::collections::HashSet;
use std::env;
use std::ffi::OsString;
use std::sync::Arc;
use std::time::Duration;
use anyhow::Context;
use anyhow::Result;
use anyhow::anyhow;
use llmx_rmcp_client::OAuthCredentialsStoreMode;
use llmx_rmcp_client::RmcpClient;
use mcp_types::ClientCapabilities;
use mcp_types::Implementation;
use mcp_types::ListResourceTemplatesRequestParams;
use mcp_types::ListResourceTemplatesResult;
use mcp_types::ListResourcesRequestParams;
use mcp_types::ListResourcesResult;
use mcp_types::ReadResourceRequestParams;
use mcp_types::ReadResourceResult;
use mcp_types::Resource;
use mcp_types::ResourceTemplate;
use mcp_types::Tool;
use serde_json::json;
use sha1::Digest;
use sha1::Sha1;
use tokio::task::JoinSet;
use tracing::info;
use tracing::warn;
use crate::config::types::McpServerConfig;
use crate::config::types::McpServerTransportConfig;
/// Delimiter used to separate the server name from the tool name in a fully
/// qualified tool name.
///
/// OpenAI requires tool names to conform to `^[a-zA-Z0-9_-]+$`, so we must
/// choose a delimiter from this character set.
const MCP_TOOL_NAME_DELIMITER: &str = "__";
const MAX_TOOL_NAME_LENGTH: usize = 64;
/// Default timeout for initializing MCP server & initially listing tools.
pub const DEFAULT_STARTUP_TIMEOUT: Duration = Duration::from_secs(10);
/// Default timeout for individual tool calls.
const DEFAULT_TOOL_TIMEOUT: Duration = Duration::from_secs(60);
/// Map that holds a startup error for every MCP server that could **not** be
/// spawned successfully.
pub type ClientStartErrors = HashMap<String, anyhow::Error>;
fn qualify_tools(tools: Vec<ToolInfo>) -> HashMap<String, ToolInfo> {
let mut used_names = HashSet::new();
let mut qualified_tools = HashMap::new();
for tool in tools {
let mut qualified_name = format!(
"mcp{}{}{}{}",
MCP_TOOL_NAME_DELIMITER, tool.server_name, MCP_TOOL_NAME_DELIMITER, tool.tool_name
);
if qualified_name.len() > MAX_TOOL_NAME_LENGTH {
let mut hasher = Sha1::new();
hasher.update(qualified_name.as_bytes());
let sha1 = hasher.finalize();
let sha1_str = format!("{sha1:x}");
// Truncate to make room for the hash suffix
let prefix_len = MAX_TOOL_NAME_LENGTH - sha1_str.len();
qualified_name = format!("{}{}", &qualified_name[..prefix_len], sha1_str);
}
if used_names.contains(&qualified_name) {
warn!("skipping duplicated tool {}", qualified_name);
continue;
}
used_names.insert(qualified_name.clone());
qualified_tools.insert(qualified_name, tool);
}
qualified_tools
}
struct ToolInfo {
server_name: String,
tool_name: String,
tool: Tool,
}
struct ManagedClient {
client: Arc<RmcpClient>,
startup_timeout: Duration,
tool_timeout: Option<Duration>,
}
/// A thin wrapper around a set of running [`RmcpClient`] instances.
#[derive(Default)]
pub(crate) struct McpConnectionManager {
/// Server-name -> client instance.
///
/// The server name originates from the keys of the `mcp_servers` map in
/// the user configuration.
clients: HashMap<String, ManagedClient>,
/// Fully qualified tool name -> tool instance.
tools: HashMap<String, ToolInfo>,
/// Server-name -> configured tool filters.
tool_filters: HashMap<String, ToolFilter>,
}
impl McpConnectionManager {
/// Spawn a [`RmcpClient`] for each configured server.
///
/// * `mcp_servers` Map loaded from the user configuration where *keys*
/// are human-readable server identifiers and *values* are the spawn
/// instructions.
///
/// Servers that fail to start are reported in `ClientStartErrors`: the
/// user should be informed about these errors.
pub async fn new(
mcp_servers: HashMap<String, McpServerConfig>,
store_mode: OAuthCredentialsStoreMode,
) -> Result<(Self, ClientStartErrors)> {
// Early exit if no servers are configured.
if mcp_servers.is_empty() {
return Ok((Self::default(), ClientStartErrors::default()));
}
// Launch all configured servers concurrently.
let mut join_set = JoinSet::new();
let mut errors = ClientStartErrors::new();
let mut tool_filters: HashMap<String, ToolFilter> = HashMap::new();
for (server_name, cfg) in mcp_servers {
// Validate server name before spawning
if !is_valid_mcp_server_name(&server_name) {
let error = anyhow::anyhow!(
"invalid server name '{server_name}': must match pattern ^[a-zA-Z0-9_-]+$"
);
errors.insert(server_name, error);
continue;
}
if !cfg.enabled {
tool_filters.insert(server_name, ToolFilter::from_config(&cfg));
continue;
}
let startup_timeout = cfg.startup_timeout_sec.unwrap_or(DEFAULT_STARTUP_TIMEOUT);
let tool_timeout = cfg.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT);
tool_filters.insert(server_name.clone(), ToolFilter::from_config(&cfg));
let resolved_bearer_token = match &cfg.transport {
McpServerTransportConfig::StreamableHttp {
bearer_token_env_var,
..
} => resolve_bearer_token(&server_name, bearer_token_env_var.as_deref()),
_ => Ok(None),
};
join_set.spawn(async move {
let McpServerConfig { transport, .. } = cfg;
let params = mcp_types::InitializeRequestParams {
capabilities: ClientCapabilities {
experimental: None,
roots: None,
sampling: None,
// https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities
// indicates this should be an empty object.
elicitation: Some(json!({})),
},
client_info: Implementation {
name: "llmx-mcp-client".to_owned(),
version: env!("CARGO_PKG_VERSION").to_owned(),
title: Some("LLMX".into()),
// This field is used by LLMX when it is an MCP
// server: it should not be used when LLMX is
// an MCP client.
user_agent: None,
},
protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(),
};
let resolved_bearer_token = resolved_bearer_token.unwrap_or_default();
let client_result = match transport {
McpServerTransportConfig::Stdio {
command,
args,
env,
env_vars,
cwd,
} => {
let command_os: OsString = command.into();
let args_os: Vec<OsString> = args.into_iter().map(Into::into).collect();
match RmcpClient::new_stdio_client(command_os, args_os, env, &env_vars, cwd)
.await
{
Ok(client) => {
let client = Arc::new(client);
client
.initialize(params.clone(), Some(startup_timeout))
.await
.map(|_| client)
}
Err(err) => Err(err.into()),
}
}
McpServerTransportConfig::StreamableHttp {
url,
http_headers,
env_http_headers,
..
} => {
match RmcpClient::new_streamable_http_client(
&server_name,
&url,
resolved_bearer_token.clone(),
http_headers,
env_http_headers,
store_mode,
)
.await
{
Ok(client) => {
let client = Arc::new(client);
client
.initialize(params.clone(), Some(startup_timeout))
.await
.map(|_| client)
}
Err(err) => Err(err),
}
}
};
(
(server_name, tool_timeout),
client_result.map(|client| (client, startup_timeout)),
)
});
}
let mut clients: HashMap<String, ManagedClient> = HashMap::with_capacity(join_set.len());
while let Some(res) = join_set.join_next().await {
let ((server_name, tool_timeout), client_res) = match res {
Ok(result) => result,
Err(e) => {
warn!("Task panic when starting MCP server: {e:#}");
continue;
}
};
match client_res {
Ok((client, startup_timeout)) => {
clients.insert(
server_name,
ManagedClient {
client,
startup_timeout,
tool_timeout: Some(tool_timeout),
},
);
}
Err(e) => {
errors.insert(server_name, e);
}
}
}
let all_tools = match list_all_tools(&clients).await {
Ok(tools) => tools,
Err(e) => {
warn!("Failed to list tools from some MCP servers: {e:#}");
Vec::new()
}
};
let filtered_tools = filter_tools(all_tools, &tool_filters);
let tools = qualify_tools(filtered_tools);
Ok((
Self {
clients,
tools,
tool_filters,
},
errors,
))
}
/// Returns a single map that contains all tools. Each key is the
/// fully-qualified name for the tool.
pub fn list_all_tools(&self) -> HashMap<String, Tool> {
self.tools
.iter()
.map(|(name, tool)| (name.clone(), tool.tool.clone()))
.collect()
}
/// Returns a single map that contains all resources. Each key is the
/// server name and the value is a vector of resources.
pub async fn list_all_resources(&self) -> HashMap<String, Vec<Resource>> {
let mut join_set = JoinSet::new();
for (server_name, managed_client) in &self.clients {
let server_name_cloned = server_name.clone();
let client_clone = managed_client.client.clone();
let timeout = managed_client.tool_timeout;
join_set.spawn(async move {
let mut collected: Vec<Resource> = Vec::new();
let mut cursor: Option<String> = None;
loop {
let params = cursor.as_ref().map(|next| ListResourcesRequestParams {
cursor: Some(next.clone()),
});
let response = match client_clone.list_resources(params, timeout).await {
Ok(result) => result,
Err(err) => return (server_name_cloned, Err(err)),
};
collected.extend(response.resources);
match response.next_cursor {
Some(next) => {
if cursor.as_ref() == Some(&next) {
return (
server_name_cloned,
Err(anyhow!("resources/list returned duplicate cursor")),
);
}
cursor = Some(next);
}
None => return (server_name_cloned, Ok(collected)),
}
}
});
}
let mut aggregated: HashMap<String, Vec<Resource>> = HashMap::new();
while let Some(join_res) = join_set.join_next().await {
match join_res {
Ok((server_name, Ok(resources))) => {
aggregated.insert(server_name, resources);
}
Ok((server_name, Err(err))) => {
warn!("Failed to list resources for MCP server '{server_name}': {err:#}");
}
Err(err) => {
warn!("Task panic when listing resources for MCP server: {err:#}");
}
}
}
aggregated
}
/// Returns a single map that contains all resource templates. Each key is the
/// server name and the value is a vector of resource templates.
pub async fn list_all_resource_templates(&self) -> HashMap<String, Vec<ResourceTemplate>> {
let mut join_set = JoinSet::new();
for (server_name, managed_client) in &self.clients {
let server_name_cloned = server_name.clone();
let client_clone = managed_client.client.clone();
let timeout = managed_client.tool_timeout;
join_set.spawn(async move {
let mut collected: Vec<ResourceTemplate> = Vec::new();
let mut cursor: Option<String> = None;
loop {
let params = cursor
.as_ref()
.map(|next| ListResourceTemplatesRequestParams {
cursor: Some(next.clone()),
});
let response = match client_clone.list_resource_templates(params, timeout).await
{
Ok(result) => result,
Err(err) => return (server_name_cloned, Err(err)),
};
collected.extend(response.resource_templates);
match response.next_cursor {
Some(next) => {
if cursor.as_ref() == Some(&next) {
return (
server_name_cloned,
Err(anyhow!(
"resources/templates/list returned duplicate cursor"
)),
);
}
cursor = Some(next);
}
None => return (server_name_cloned, Ok(collected)),
}
}
});
}
let mut aggregated: HashMap<String, Vec<ResourceTemplate>> = HashMap::new();
while let Some(join_res) = join_set.join_next().await {
match join_res {
Ok((server_name, Ok(templates))) => {
aggregated.insert(server_name, templates);
}
Ok((server_name, Err(err))) => {
warn!(
"Failed to list resource templates for MCP server '{server_name}': {err:#}"
);
}
Err(err) => {
warn!("Task panic when listing resource templates for MCP server: {err:#}");
}
}
}
aggregated
}
/// Invoke the tool indicated by the (server, tool) pair.
pub async fn call_tool(
&self,
server: &str,
tool: &str,
arguments: Option<serde_json::Value>,
) -> Result<mcp_types::CallToolResult> {
if let Some(filter) = self.tool_filters.get(server)
&& !filter.allows(tool)
{
return Err(anyhow!(
"tool '{tool}' is disabled for MCP server '{server}'"
));
}
let managed = self
.clients
.get(server)
.ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?;
let client = &managed.client;
let timeout = managed.tool_timeout;
client
.call_tool(tool.to_string(), arguments, timeout)
.await
.with_context(|| format!("tool call failed for `{server}/{tool}`"))
}
/// List resources from the specified server.
pub async fn list_resources(
&self,
server: &str,
params: Option<ListResourcesRequestParams>,
) -> Result<ListResourcesResult> {
let managed = self
.clients
.get(server)
.ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?;
let client = managed.client.clone();
let timeout = managed.tool_timeout;
client
.list_resources(params, timeout)
.await
.with_context(|| format!("resources/list failed for `{server}`"))
}
/// List resource templates from the specified server.
pub async fn list_resource_templates(
&self,
server: &str,
params: Option<ListResourceTemplatesRequestParams>,
) -> Result<ListResourceTemplatesResult> {
let managed = self
.clients
.get(server)
.ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?;
let client = managed.client.clone();
let timeout = managed.tool_timeout;
client
.list_resource_templates(params, timeout)
.await
.with_context(|| format!("resources/templates/list failed for `{server}`"))
}
/// Read a resource from the specified server.
pub async fn read_resource(
&self,
server: &str,
params: ReadResourceRequestParams,
) -> Result<ReadResourceResult> {
let managed = self
.clients
.get(server)
.ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?;
let client = managed.client.clone();
let timeout = managed.tool_timeout;
let uri = params.uri.clone();
client
.read_resource(params, timeout)
.await
.with_context(|| format!("resources/read failed for `{server}` ({uri})"))
}
pub fn parse_tool_name(&self, tool_name: &str) -> Option<(String, String)> {
self.tools
.get(tool_name)
.map(|tool| (tool.server_name.clone(), tool.tool_name.clone()))
}
}
/// A tool is allowed to be used if both are true:
/// 1. enabled is None (no allowlist is set) or the tool is explicitly enabled.
/// 2. The tool is not explicitly disabled.
#[derive(Default, Clone)]
struct ToolFilter {
enabled: Option<HashSet<String>>,
disabled: HashSet<String>,
}
impl ToolFilter {
fn from_config(cfg: &McpServerConfig) -> Self {
let enabled = cfg
.enabled_tools
.as_ref()
.map(|tools| tools.iter().cloned().collect::<HashSet<_>>());
let disabled = cfg
.disabled_tools
.as_ref()
.map(|tools| tools.iter().cloned().collect::<HashSet<_>>())
.unwrap_or_default();
Self { enabled, disabled }
}
fn allows(&self, tool_name: &str) -> bool {
if let Some(enabled) = &self.enabled
&& !enabled.contains(tool_name)
{
return false;
}
!self.disabled.contains(tool_name)
}
}
fn filter_tools(tools: Vec<ToolInfo>, filters: &HashMap<String, ToolFilter>) -> Vec<ToolInfo> {
tools
.into_iter()
.filter(|tool| {
filters
.get(&tool.server_name)
.is_none_or(|filter| filter.allows(&tool.tool_name))
})
.collect()
}
fn resolve_bearer_token(
server_name: &str,
bearer_token_env_var: Option<&str>,
) -> Result<Option<String>> {
let Some(env_var) = bearer_token_env_var else {
return Ok(None);
};
match env::var(env_var) {
Ok(value) => {
if value.is_empty() {
Err(anyhow!(
"Environment variable {env_var} for MCP server '{server_name}' is empty"
))
} else {
Ok(Some(value))
}
}
Err(env::VarError::NotPresent) => Err(anyhow!(
"Environment variable {env_var} for MCP server '{server_name}' is not set"
)),
Err(env::VarError::NotUnicode(_)) => Err(anyhow!(
"Environment variable {env_var} for MCP server '{server_name}' contains invalid Unicode"
)),
}
}
/// Query every server for its available tools and return a single map that
/// contains all tools. Each key is the fully-qualified name for the tool.
async fn list_all_tools(clients: &HashMap<String, ManagedClient>) -> Result<Vec<ToolInfo>> {
let mut join_set = JoinSet::new();
// Spawn one task per server so we can query them concurrently. This
// keeps the overall latency roughly at the slowest server instead of
// the cumulative latency.
for (server_name, managed_client) in clients {
let server_name_cloned = server_name.clone();
let client_clone = managed_client.client.clone();
let startup_timeout = managed_client.startup_timeout;
join_set.spawn(async move {
let res = client_clone.list_tools(None, Some(startup_timeout)).await;
(server_name_cloned, res)
});
}
let mut aggregated: Vec<ToolInfo> = Vec::with_capacity(join_set.len());
while let Some(join_res) = join_set.join_next().await {
let (server_name, list_result) = if let Ok(result) = join_res {
result
} else {
warn!("Task panic when listing tools for MCP server: {join_res:#?}");
continue;
};
let list_result = if let Ok(result) = list_result {
result
} else {
warn!("Failed to list tools for MCP server '{server_name}': {list_result:#?}");
continue;
};
for tool in list_result.tools {
let tool_info = ToolInfo {
server_name: server_name.clone(),
tool_name: tool.name.clone(),
tool,
};
aggregated.push(tool_info);
}
}
info!(
"aggregated {} tools from {} servers",
aggregated.len(),
clients.len()
);
Ok(aggregated)
}
fn is_valid_mcp_server_name(server_name: &str) -> bool {
!server_name.is_empty()
&& server_name
.chars()
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
}
#[cfg(test)]
mod tests {
use super::*;
use mcp_types::ToolInputSchema;
use std::collections::HashSet;
fn create_test_tool(server_name: &str, tool_name: &str) -> ToolInfo {
ToolInfo {
server_name: server_name.to_string(),
tool_name: tool_name.to_string(),
tool: Tool {
annotations: None,
description: Some(format!("Test tool: {tool_name}")),
input_schema: ToolInputSchema {
properties: None,
required: None,
r#type: "object".to_string(),
},
name: tool_name.to_string(),
output_schema: None,
title: None,
},
}
}
#[test]
fn test_qualify_tools_short_non_duplicated_names() {
let tools = vec![
create_test_tool("server1", "tool1"),
create_test_tool("server1", "tool2"),
];
let qualified_tools = qualify_tools(tools);
assert_eq!(qualified_tools.len(), 2);
assert!(qualified_tools.contains_key("mcp__server1__tool1"));
assert!(qualified_tools.contains_key("mcp__server1__tool2"));
}
#[test]
fn test_qualify_tools_duplicated_names_skipped() {
let tools = vec![
create_test_tool("server1", "duplicate_tool"),
create_test_tool("server1", "duplicate_tool"),
];
let qualified_tools = qualify_tools(tools);
// Only the first tool should remain, the second is skipped
assert_eq!(qualified_tools.len(), 1);
assert!(qualified_tools.contains_key("mcp__server1__duplicate_tool"));
}
#[test]
fn test_qualify_tools_long_names_same_server() {
let server_name = "my_server";
let tools = vec![
create_test_tool(
server_name,
"extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits",
),
create_test_tool(
server_name,
"yet_another_extremely_lengthy_function_name_that_absolutely_surpasses_all_reasonable_limits",
),
];
let qualified_tools = qualify_tools(tools);
assert_eq!(qualified_tools.len(), 2);
let mut keys: Vec<_> = qualified_tools.keys().cloned().collect();
keys.sort();
assert_eq!(keys[0].len(), 64);
assert_eq!(
keys[0],
"mcp__my_server__extremel119a2b97664e41363932dc84de21e2ff1b93b3e9"
);
assert_eq!(keys[1].len(), 64);
assert_eq!(
keys[1],
"mcp__my_server__yet_anot419a82a89325c1b477274a41f8c65ea5f3a7f341"
);
}
#[test]
fn tool_filter_allows_by_default() {
let filter = ToolFilter::default();
assert!(filter.allows("any"));
}
#[test]
fn tool_filter_applies_enabled_list() {
let filter = ToolFilter {
enabled: Some(HashSet::from(["allowed".to_string()])),
disabled: HashSet::new(),
};
assert!(filter.allows("allowed"));
assert!(!filter.allows("denied"));
}
#[test]
fn tool_filter_applies_disabled_list() {
let filter = ToolFilter {
enabled: None,
disabled: HashSet::from(["blocked".to_string()]),
};
assert!(!filter.allows("blocked"));
assert!(filter.allows("open"));
}
#[test]
fn tool_filter_applies_enabled_then_disabled() {
let filter = ToolFilter {
enabled: Some(HashSet::from(["keep".to_string(), "remove".to_string()])),
disabled: HashSet::from(["remove".to_string()]),
};
assert!(filter.allows("keep"));
assert!(!filter.allows("remove"));
assert!(!filter.allows("unknown"));
}
#[test]
fn filter_tools_applies_per_server_filters() {
let tools = vec![
create_test_tool("server1", "tool_a"),
create_test_tool("server1", "tool_b"),
create_test_tool("server2", "tool_a"),
];
let mut filters = HashMap::new();
filters.insert(
"server1".to_string(),
ToolFilter {
enabled: Some(HashSet::from(["tool_a".to_string(), "tool_b".to_string()])),
disabled: HashSet::from(["tool_b".to_string()]),
},
);
filters.insert(
"server2".to_string(),
ToolFilter {
enabled: None,
disabled: HashSet::from(["tool_a".to_string()]),
},
);
let filtered = filter_tools(tools, &filters);
assert_eq!(filtered.len(), 1);
assert_eq!(filtered[0].server_name, "server1");
assert_eq!(filtered[0].tool_name, "tool_a");
}
}

View File

@@ -0,0 +1,80 @@
use std::time::Instant;
use tracing::error;
use crate::llmx::Session;
use crate::llmx::TurnContext;
use crate::protocol::EventMsg;
use crate::protocol::McpInvocation;
use crate::protocol::McpToolCallBeginEvent;
use crate::protocol::McpToolCallEndEvent;
use llmx_protocol::models::FunctionCallOutputPayload;
use llmx_protocol::models::ResponseInputItem;
/// Handles the specified tool call dispatches the appropriate
/// `McpToolCallBegin` and `McpToolCallEnd` events to the `Session`.
pub(crate) async fn handle_mcp_tool_call(
sess: &Session,
turn_context: &TurnContext,
call_id: String,
server: String,
tool_name: String,
arguments: String,
) -> ResponseInputItem {
// Parse the `arguments` as JSON. An empty string is OK, but invalid JSON
// is not.
let arguments_value = if arguments.trim().is_empty() {
None
} else {
match serde_json::from_str::<serde_json::Value>(&arguments) {
Ok(value) => Some(value),
Err(e) => {
error!("failed to parse tool call arguments: {e}");
return ResponseInputItem::FunctionCallOutput {
call_id: call_id.clone(),
output: FunctionCallOutputPayload {
content: format!("err: {e}"),
success: Some(false),
..Default::default()
},
};
}
}
};
let invocation = McpInvocation {
server: server.clone(),
tool: tool_name.clone(),
arguments: arguments_value.clone(),
};
let tool_call_begin_event = EventMsg::McpToolCallBegin(McpToolCallBeginEvent {
call_id: call_id.clone(),
invocation: invocation.clone(),
});
notify_mcp_tool_call_event(sess, turn_context, tool_call_begin_event).await;
let start = Instant::now();
// Perform the tool call.
let result = sess
.call_tool(&server, &tool_name, arguments_value.clone())
.await
.map_err(|e| format!("tool call error: {e:?}"));
if let Err(e) = &result {
tracing::warn!("MCP tool call error: {e:?}");
}
let tool_call_end_event = EventMsg::McpToolCallEnd(McpToolCallEndEvent {
call_id: call_id.clone(),
invocation,
duration: start.elapsed(),
result: result.clone(),
});
notify_mcp_tool_call_event(sess, turn_context, tool_call_end_event.clone()).await;
ResponseInputItem::McpToolCallOutput { call_id, result }
}
async fn notify_mcp_tool_call_event(sess: &Session, turn_context: &TurnContext, event: EventMsg) {
sess.send_event(turn_context, event).await;
}

View File

@@ -0,0 +1,286 @@
//! Persistence layer for the global, append-only *message history* file.
//!
//! The history is stored at `~/.llmx/history.jsonl` with **one JSON object per
//! line** so that it can be efficiently appended to and parsed with standard
//! JSON-Lines tooling. Each record has the following schema:
//!
//! ````text
//! {"conversation_id":"<uuid>","ts":<unix_seconds>,"text":"<message>"}
//! ````
//!
//! To minimise the chance of interleaved writes when multiple processes are
//! appending concurrently, callers should *prepare the full line* (record +
//! trailing `\n`) and write it with a **single `write(2)` system call** while
//! the file descriptor is opened with the `O_APPEND` flag. POSIX guarantees
//! that writes up to `PIPE_BUF` bytes are atomic in that case.
use std::fs::File;
use std::fs::OpenOptions;
use std::io::Result;
use std::io::Write;
use std::path::PathBuf;
use serde::Deserialize;
use serde::Serialize;
use std::time::Duration;
use tokio::fs;
use tokio::io::AsyncReadExt;
use crate::config::Config;
use crate::config::types::HistoryPersistence;
use llmx_protocol::ConversationId;
#[cfg(unix)]
use std::os::unix::fs::OpenOptionsExt;
#[cfg(unix)]
use std::os::unix::fs::PermissionsExt;
/// Filename that stores the message history inside `~/.llmx`.
const HISTORY_FILENAME: &str = "history.jsonl";
const MAX_RETRIES: usize = 10;
const RETRY_SLEEP: Duration = Duration::from_millis(100);
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HistoryEntry {
pub session_id: String,
pub ts: u64,
pub text: String,
}
fn history_filepath(config: &Config) -> PathBuf {
let mut path = config.llmx_home.clone();
path.push(HISTORY_FILENAME);
path
}
/// Append a `text` entry associated with `conversation_id` to the history file. Uses
/// advisory file locking to ensure that concurrent writes do not interleave,
/// which entails a small amount of blocking I/O internally.
pub(crate) async fn append_entry(
text: &str,
conversation_id: &ConversationId,
config: &Config,
) -> Result<()> {
match config.history.persistence {
HistoryPersistence::SaveAll => {
// Save everything: proceed.
}
HistoryPersistence::None => {
// No history persistence requested.
return Ok(());
}
}
// TODO: check `text` for sensitive patterns
// Resolve `~/.llmx/history.jsonl` and ensure the parent directory exists.
let path = history_filepath(config);
if let Some(parent) = path.parent() {
tokio::fs::create_dir_all(parent).await?;
}
// Compute timestamp (seconds since the Unix epoch).
let ts = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_err(|e| std::io::Error::other(format!("system clock before Unix epoch: {e}")))?
.as_secs();
// Construct the JSON line first so we can write it in a single syscall.
let entry = HistoryEntry {
session_id: conversation_id.to_string(),
ts,
text: text.to_string(),
};
let mut line = serde_json::to_string(&entry)
.map_err(|e| std::io::Error::other(format!("failed to serialise history entry: {e}")))?;
line.push('\n');
// Open in append-only mode.
let mut options = OpenOptions::new();
options.append(true).read(true).create(true);
#[cfg(unix)]
{
options.mode(0o600);
}
let mut history_file = options.open(&path)?;
// Ensure permissions.
ensure_owner_only_permissions(&history_file).await?;
// Perform a blocking write under an advisory write lock using std::fs.
tokio::task::spawn_blocking(move || -> Result<()> {
// Retry a few times to avoid indefinite blocking when contended.
for _ in 0..MAX_RETRIES {
match history_file.try_lock() {
Ok(()) => {
// While holding the exclusive lock, write the full line.
history_file.write_all(line.as_bytes())?;
history_file.flush()?;
return Ok(());
}
Err(std::fs::TryLockError::WouldBlock) => {
std::thread::sleep(RETRY_SLEEP);
}
Err(e) => return Err(e.into()),
}
}
Err(std::io::Error::new(
std::io::ErrorKind::WouldBlock,
"could not acquire exclusive lock on history file after multiple attempts",
))
})
.await??;
Ok(())
}
/// Asynchronously fetch the history file's *identifier* (inode on Unix) and
/// the current number of entries by counting newline characters.
pub(crate) async fn history_metadata(config: &Config) -> (u64, usize) {
let path = history_filepath(config);
#[cfg(unix)]
let log_id = {
use std::os::unix::fs::MetadataExt;
// Obtain metadata (async) to get the identifier.
let meta = match fs::metadata(&path).await {
Ok(m) => m,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return (0, 0),
Err(_) => return (0, 0),
};
meta.ino()
};
#[cfg(not(unix))]
let log_id = 0u64;
// Open the file.
let mut file = match fs::File::open(&path).await {
Ok(f) => f,
Err(_) => return (log_id, 0),
};
// Count newline bytes.
let mut buf = [0u8; 8192];
let mut count = 0usize;
loop {
match file.read(&mut buf).await {
Ok(0) => break,
Ok(n) => {
count += buf[..n].iter().filter(|&&b| b == b'\n').count();
}
Err(_) => return (log_id, 0),
}
}
(log_id, count)
}
/// Given a `log_id` (on Unix this is the file's inode number) and a zero-based
/// `offset`, return the corresponding `HistoryEntry` if the identifier matches
/// the current history file **and** the requested offset exists. Any I/O or
/// parsing errors are logged and result in `None`.
///
/// Note this function is not async because it uses a sync advisory file
/// locking API.
#[cfg(unix)]
pub(crate) fn lookup(log_id: u64, offset: usize, config: &Config) -> Option<HistoryEntry> {
use std::io::BufRead;
use std::io::BufReader;
use std::os::unix::fs::MetadataExt;
let path = history_filepath(config);
let file: File = match OpenOptions::new().read(true).open(&path) {
Ok(f) => f,
Err(e) => {
tracing::warn!(error = %e, "failed to open history file");
return None;
}
};
let metadata = match file.metadata() {
Ok(m) => m,
Err(e) => {
tracing::warn!(error = %e, "failed to stat history file");
return None;
}
};
if metadata.ino() != log_id {
return None;
}
// Open & lock file for reading using a shared lock.
// Retry a few times to avoid indefinite blocking.
for _ in 0..MAX_RETRIES {
let lock_result = file.try_lock_shared();
match lock_result {
Ok(()) => {
let reader = BufReader::new(&file);
for (idx, line_res) in reader.lines().enumerate() {
let line = match line_res {
Ok(l) => l,
Err(e) => {
tracing::warn!(error = %e, "failed to read line from history file");
return None;
}
};
if idx == offset {
match serde_json::from_str::<HistoryEntry>(&line) {
Ok(entry) => return Some(entry),
Err(e) => {
tracing::warn!(error = %e, "failed to parse history entry");
return None;
}
}
}
}
// Not found at requested offset.
return None;
}
Err(std::fs::TryLockError::WouldBlock) => {
std::thread::sleep(RETRY_SLEEP);
}
Err(e) => {
tracing::warn!(error = %e, "failed to acquire shared lock on history file");
return None;
}
}
}
None
}
/// Fallback stub for non-Unix systems: currently always returns `None`.
#[cfg(not(unix))]
pub(crate) fn lookup(log_id: u64, offset: usize, config: &Config) -> Option<HistoryEntry> {
let _ = (log_id, offset, config);
None
}
/// On Unix systems ensure the file permissions are `0o600` (rw-------). If the
/// permissions cannot be changed the error is propagated to the caller.
#[cfg(unix)]
async fn ensure_owner_only_permissions(file: &File) -> Result<()> {
let metadata = file.metadata()?;
let current_mode = metadata.permissions().mode() & 0o777;
if current_mode != 0o600 {
let mut perms = metadata.permissions();
perms.set_mode(0o600);
let perms_clone = perms.clone();
let file_clone = file.try_clone()?;
tokio::task::spawn_blocking(move || file_clone.set_permissions(perms_clone)).await??;
}
Ok(())
}
#[cfg(not(unix))]
async fn ensure_owner_only_permissions(_file: &File) -> Result<()> {
// For now, on non-Unix, simply succeed.
Ok(())
}

View File

@@ -0,0 +1,196 @@
use crate::config::types::ReasoningSummaryFormat;
use crate::tools::handlers::apply_patch::ApplyPatchToolType;
use crate::tools::spec::ConfigShellToolType;
/// The `instructions` field in the payload sent to a model should always start
/// with this content.
const BASE_INSTRUCTIONS: &str = include_str!("../prompt.md");
const GPT_5_LLMX_INSTRUCTIONS: &str = include_str!("../gpt_5_llmx_prompt.md");
/// A model family is a group of models that share certain characteristics.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelFamily {
/// The full model slug used to derive this model family, e.g.
/// "gpt-4.1-2025-04-14".
pub slug: String,
/// The model family name, e.g. "gpt-4.1". Note this should able to be used
/// with [`crate::openai_model_info::get_model_info`].
pub family: String,
/// True if the model needs additional instructions on how to use the
/// "virtual" `apply_patch` CLI.
pub needs_special_apply_patch_instructions: bool,
// Whether the `reasoning` field can be set when making a request to this
// model family. Note it has `effort` and `summary` subfields (though
// `summary` is optional).
pub supports_reasoning_summaries: bool,
// Define if we need a special handling of reasoning summary
pub reasoning_summary_format: ReasoningSummaryFormat,
/// Whether this model supports parallel tool calls when using the
/// Responses API.
pub supports_parallel_tool_calls: bool,
/// Present if the model performs better when `apply_patch` is provided as
/// a tool call instead of just a bash command
pub apply_patch_tool_type: Option<ApplyPatchToolType>,
// Instructions to use for querying the model
pub base_instructions: String,
/// Names of beta tools that should be exposed to this model family.
pub experimental_supported_tools: Vec<String>,
/// Percentage of the context window considered usable for inputs, after
/// reserving headroom for system prompts, tool overhead, and model output.
/// This is applied when computing the effective context window seen by
/// consumers.
pub effective_context_window_percent: i64,
/// If the model family supports setting the verbosity level when using Responses API.
pub support_verbosity: bool,
/// Preferred shell tool type for this model family when features do not override it.
pub shell_type: ConfigShellToolType,
}
macro_rules! model_family {
(
$slug:expr, $family:expr $(, $key:ident : $value:expr )* $(,)?
) => {{
// defaults
#[allow(unused_mut)]
let mut mf = ModelFamily {
slug: $slug.to_string(),
family: $family.to_string(),
needs_special_apply_patch_instructions: false,
supports_reasoning_summaries: false,
reasoning_summary_format: ReasoningSummaryFormat::None,
supports_parallel_tool_calls: false,
apply_patch_tool_type: None,
base_instructions: BASE_INSTRUCTIONS.to_string(),
experimental_supported_tools: Vec::new(),
effective_context_window_percent: 95,
support_verbosity: false,
shell_type: ConfigShellToolType::Default,
};
// apply overrides
$(
mf.$key = $value;
)*
Some(mf)
}};
}
/// Returns a `ModelFamily` for the given model slug, or `None` if the slug
/// does not match any known model family.
pub fn find_family_for_model(slug: &str) -> Option<ModelFamily> {
if slug.starts_with("o3") {
model_family!(
slug, "o3",
supports_reasoning_summaries: true,
needs_special_apply_patch_instructions: true,
)
} else if slug.starts_with("o4-mini") {
model_family!(
slug, "o4-mini",
supports_reasoning_summaries: true,
needs_special_apply_patch_instructions: true,
)
} else if slug.starts_with("llmx-mini-latest") {
model_family!(
slug, "llmx-mini-latest",
supports_reasoning_summaries: true,
needs_special_apply_patch_instructions: true,
shell_type: ConfigShellToolType::Local,
)
} else if slug.starts_with("gpt-4.1") {
model_family!(
slug, "gpt-4.1",
needs_special_apply_patch_instructions: true,
)
} else if slug.starts_with("gpt-oss") || slug.starts_with("openai/gpt-oss") {
model_family!(slug, "gpt-oss", apply_patch_tool_type: Some(ApplyPatchToolType::Function))
} else if slug.starts_with("gpt-4o") {
model_family!(slug, "gpt-4o", needs_special_apply_patch_instructions: true)
} else if slug.starts_with("gpt-3.5") {
model_family!(slug, "gpt-3.5", needs_special_apply_patch_instructions: true)
} else if slug.starts_with("porcupine") {
model_family!(slug, "porcupine", shell_type: ConfigShellToolType::UnifiedExec)
} else if slug.starts_with("test-gpt-5-llmx") {
model_family!(
slug, slug,
supports_reasoning_summaries: true,
reasoning_summary_format: ReasoningSummaryFormat::Experimental,
base_instructions: GPT_5_LLMX_INSTRUCTIONS.to_string(),
experimental_supported_tools: vec![
"grep_files".to_string(),
"list_dir".to_string(),
"read_file".to_string(),
"test_sync_tool".to_string(),
],
supports_parallel_tool_calls: true,
support_verbosity: true,
)
// Internal models.
} else if slug.starts_with("llmx-exp-") {
model_family!(
slug, slug,
supports_reasoning_summaries: true,
reasoning_summary_format: ReasoningSummaryFormat::Experimental,
base_instructions: GPT_5_LLMX_INSTRUCTIONS.to_string(),
apply_patch_tool_type: Some(ApplyPatchToolType::Freeform),
experimental_supported_tools: vec![
"grep_files".to_string(),
"list_dir".to_string(),
"read_file".to_string(),
],
supports_parallel_tool_calls: true,
support_verbosity: true,
)
// Production models.
} else if slug == "anthropic/claude-sonnet-4-20250514"
|| slug.starts_with("gpt-5-llmx")
|| slug.starts_with("llmx-")
{
model_family!(
slug, slug,
supports_reasoning_summaries: true,
reasoning_summary_format: ReasoningSummaryFormat::Experimental,
base_instructions: GPT_5_LLMX_INSTRUCTIONS.to_string(),
apply_patch_tool_type: Some(ApplyPatchToolType::Freeform),
support_verbosity: false,
)
} else if slug.starts_with("gpt-5") {
model_family!(
slug, "gpt-5",
supports_reasoning_summaries: true,
needs_special_apply_patch_instructions: true,
support_verbosity: true,
)
} else {
None
}
}
pub fn derive_default_model_family(model: &str) -> ModelFamily {
ModelFamily {
slug: model.to_string(),
family: model.to_string(),
needs_special_apply_patch_instructions: false,
supports_reasoning_summaries: false,
reasoning_summary_format: ReasoningSummaryFormat::None,
supports_parallel_tool_calls: false,
apply_patch_tool_type: None,
base_instructions: BASE_INSTRUCTIONS.to_string(),
experimental_supported_tools: Vec::new(),
effective_context_window_percent: 95,
support_verbosity: false,
shell_type: ConfigShellToolType::Default,
}
}

View File

@@ -0,0 +1,554 @@
//! Registry of model providers supported by LLMX.
//!
//! Providers can be defined in two places:
//! 1. Built-in defaults compiled into the binary so LLMX works out-of-the-box.
//! 2. User-defined entries inside `~/.llmx/config.toml` under the `model_providers`
//! key. These override or extend the defaults at runtime.
use crate::LlmxAuth;
use crate::default_client::LlmxHttpClient;
use crate::default_client::LlmxRequestBuilder;
use llmx_app_server_protocol::AuthMode;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::env::VarError;
use std::time::Duration;
use crate::error::EnvVarError;
const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
const DEFAULT_STREAM_MAX_RETRIES: u64 = 5;
const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
/// Hard cap for user-configured `stream_max_retries`.
const MAX_STREAM_MAX_RETRIES: u64 = 100;
/// Hard cap for user-configured `request_max_retries`.
const MAX_REQUEST_MAX_RETRIES: u64 = 100;
/// Wire protocol that the provider speaks. Most third-party services only
/// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI
/// itself (and a handful of others) additionally expose the more modern
/// *Responses* API. The two protocols use different request/response shapes
/// and *cannot* be auto-detected at runtime, therefore each provider entry
/// must declare which one it expects.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum WireApi {
/// The Responses API exposed by OpenAI at `/v1/responses`.
Responses,
/// Regular Chat Completions compatible with `/v1/chat/completions`.
#[default]
Chat,
}
/// Serializable representation of a provider definition.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ModelProviderInfo {
/// Friendly display name.
pub name: String,
/// Base URL for the provider's OpenAI-compatible API.
pub base_url: Option<String>,
/// Environment variable that stores the user's API key for this provider.
pub env_key: Option<String>,
/// Optional instructions to help the user get a valid value for the
/// variable and set it.
pub env_key_instructions: Option<String>,
/// Value to use with `Authorization: Bearer <token>` header. Use of this
/// config is discouraged in favor of `env_key` for security reasons, but
/// this may be necessary when using this programmatically.
pub experimental_bearer_token: Option<String>,
/// Which wire protocol this provider expects.
#[serde(default)]
pub wire_api: WireApi,
/// Optional query parameters to append to the base URL.
pub query_params: Option<HashMap<String, String>>,
/// Additional HTTP headers to include in requests to this provider where
/// the (key, value) pairs are the header name and value.
pub http_headers: Option<HashMap<String, String>>,
/// Optional HTTP headers to include in requests to this provider where the
/// (key, value) pairs are the header name and _environment variable_ whose
/// value should be used. If the environment variable is not set, or the
/// value is empty, the header will not be included in the request.
pub env_http_headers: Option<HashMap<String, String>>,
/// Maximum number of times to retry a failed HTTP request to this provider.
pub request_max_retries: Option<u64>,
/// Number of times to retry reconnecting a dropped streaming response before failing.
pub stream_max_retries: Option<u64>,
/// Idle timeout (in milliseconds) to wait for activity on a streaming response before treating
/// the connection as lost.
pub stream_idle_timeout_ms: Option<u64>,
/// Does this provider require an OpenAI API Key or ChatGPT login token? If true,
/// user is presented with login screen on first run, and login preference and token/key
/// are stored in auth.json. If false (which is the default), login screen is skipped,
/// and API key (if needed) comes from the "env_key" environment variable.
#[serde(default)]
pub requires_openai_auth: bool,
}
impl ModelProviderInfo {
/// Construct a `POST` RequestBuilder for the given URL using the provided
/// [`LlmxHttpClient`] applying:
/// • provider-specific headers (static + env based)
/// • Bearer auth header when an API key is available.
/// • Auth token for OAuth.
///
/// If the provider declares an `env_key` but the variable is missing/empty, returns an [`Err`] identical to the
/// one produced by [`ModelProviderInfo::api_key`].
pub async fn create_request_builder<'a>(
&'a self,
client: &'a LlmxHttpClient,
auth: &Option<LlmxAuth>,
) -> crate::error::Result<LlmxRequestBuilder> {
let effective_auth = if let Some(secret_key) = &self.experimental_bearer_token {
Some(LlmxAuth::from_api_key(secret_key))
} else {
match self.api_key() {
Ok(Some(key)) => Some(LlmxAuth::from_api_key(&key)),
Ok(None) => auth.clone(),
Err(err) => {
if auth.is_some() {
auth.clone()
} else {
return Err(err);
}
}
}
};
let url = self.get_full_url(&effective_auth);
let mut builder = client.post(url);
if let Some(auth) = effective_auth.as_ref() {
builder = builder.bearer_auth(auth.get_token().await?);
}
Ok(self.apply_http_headers(builder))
}
fn get_query_string(&self) -> String {
self.query_params
.as_ref()
.map_or_else(String::new, |params| {
let full_params = params
.iter()
.map(|(k, v)| format!("{k}={v}"))
.collect::<Vec<_>>()
.join("&");
format!("?{full_params}")
})
}
pub(crate) fn get_full_url(&self, auth: &Option<LlmxAuth>) -> String {
let default_base_url = if matches!(
auth,
Some(LlmxAuth {
mode: AuthMode::ChatGPT,
..
})
) {
"https://chatgpt.com/backend-api/llmx"
} else {
"https://api.openai.com/v1"
};
let query_string = self.get_query_string();
let base_url = self
.base_url
.clone()
.unwrap_or(default_base_url.to_string());
match self.wire_api {
WireApi::Responses => format!("{base_url}/responses{query_string}"),
WireApi::Chat => format!("{base_url}/chat/completions{query_string}"),
}
}
pub(crate) fn is_azure_responses_endpoint(&self) -> bool {
if self.wire_api != WireApi::Responses {
return false;
}
if self.name.eq_ignore_ascii_case("azure") {
return true;
}
self.base_url
.as_ref()
.map(|base| matches_azure_responses_base_url(base))
.unwrap_or(false)
}
/// Apply provider-specific HTTP headers (both static and environment-based)
/// onto an existing [`LlmxRequestBuilder`] and return the updated
/// builder.
fn apply_http_headers(&self, mut builder: LlmxRequestBuilder) -> LlmxRequestBuilder {
if let Some(extra) = &self.http_headers {
for (k, v) in extra {
builder = builder.header(k, v);
}
}
if let Some(env_headers) = &self.env_http_headers {
for (header, env_var) in env_headers {
if let Ok(val) = std::env::var(env_var)
&& !val.trim().is_empty()
{
builder = builder.header(header, val);
}
}
}
builder
}
/// If `env_key` is Some, returns the API key for this provider if present
/// (and non-empty) in the environment. If `env_key` is required but
/// cannot be found, returns an error.
pub fn api_key(&self) -> crate::error::Result<Option<String>> {
match &self.env_key {
Some(env_key) => {
let env_value = std::env::var(env_key);
env_value
.and_then(|v| {
if v.trim().is_empty() {
Err(VarError::NotPresent)
} else {
Ok(Some(v))
}
})
.map_err(|_| {
crate::error::LlmxErr::EnvVar(EnvVarError {
var: env_key.clone(),
instructions: self.env_key_instructions.clone(),
})
})
}
None => Ok(None),
}
}
/// Effective maximum number of request retries for this provider.
pub fn request_max_retries(&self) -> u64 {
self.request_max_retries
.unwrap_or(DEFAULT_REQUEST_MAX_RETRIES)
.min(MAX_REQUEST_MAX_RETRIES)
}
/// Effective maximum number of stream reconnection attempts for this provider.
pub fn stream_max_retries(&self) -> u64 {
self.stream_max_retries
.unwrap_or(DEFAULT_STREAM_MAX_RETRIES)
.min(MAX_STREAM_MAX_RETRIES)
}
/// Effective idle timeout for streaming responses.
pub fn stream_idle_timeout(&self) -> Duration {
self.stream_idle_timeout_ms
.map(Duration::from_millis)
.unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS))
}
}
const DEFAULT_OLLAMA_PORT: u32 = 11434;
pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "oss";
/// Built-in default provider list.
pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
use ModelProviderInfo as P;
// We do not want to be in the business of adjucating which third-party
// providers are bundled with LLMX CLI, so we include LiteLLM as the default,
// along with OpenAI and open source ("oss") providers. Users are encouraged
// to add to `model_providers` in config.toml to add their own providers.
[
(
"litellm",
P {
name: "LiteLLM".into(),
// Allow users to override the default LiteLLM endpoint
base_url: std::env::var("LLMX_BASE_URL")
.ok()
.filter(|v| !v.trim().is_empty())
.or_else(|| Some("http://localhost:4000/v1".into())),
env_key: Some("LLMX_API_KEY".into()),
env_key_instructions: Some("Set LLMX_API_KEY to your LiteLLM API key. See https://docs.litellm.ai/ for setup.".into()),
experimental_bearer_token: None,
wire_api: WireApi::Chat,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
},
),
(
"openai",
P {
name: "OpenAI".into(),
// Allow users to override the default OpenAI endpoint by
// exporting `LLMX_BASE_URL`. This is useful when pointing
// LLMX at a proxy, mock server, or Azure-style deployment
// without requiring a full TOML override for the built-in
// OpenAI provider.
base_url: std::env::var("LLMX_BASE_URL")
.ok()
.filter(|v| !v.trim().is_empty()),
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: Some(
[("version".to_string(), env!("CARGO_PKG_VERSION").to_string())]
.into_iter()
.collect(),
),
env_http_headers: Some(
[
(
"OpenAI-Organization".to_string(),
"OPENAI_ORGANIZATION".to_string(),
),
("OpenAI-Project".to_string(), "OPENAI_PROJECT".to_string()),
]
.into_iter()
.collect(),
),
// Use global defaults for retry/timeout unless overridden in config.toml.
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: true,
},
),
(BUILT_IN_OSS_MODEL_PROVIDER_ID, create_oss_provider()),
]
.into_iter()
.map(|(k, v)| (k.to_string(), v))
.collect()
}
pub fn create_oss_provider() -> ModelProviderInfo {
// These LLMX_OSS_ environment variables are experimental: we may
// switch to reading values from config.toml instead.
let llmx_oss_base_url = match std::env::var("LLMX_OSS_BASE_URL")
.ok()
.filter(|v| !v.trim().is_empty())
{
Some(url) => url,
None => format!(
"http://localhost:{port}/v1",
port = std::env::var("LLMX_OSS_PORT")
.ok()
.filter(|v| !v.trim().is_empty())
.and_then(|v| v.parse::<u32>().ok())
.unwrap_or(DEFAULT_OLLAMA_PORT)
),
};
create_oss_provider_with_base_url(&llmx_oss_base_url)
}
pub fn create_oss_provider_with_base_url(base_url: &str) -> ModelProviderInfo {
ModelProviderInfo {
name: "gpt-oss".into(),
base_url: Some(base_url.into()),
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
wire_api: WireApi::Chat,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
}
}
fn matches_azure_responses_base_url(base_url: &str) -> bool {
let base = base_url.to_ascii_lowercase();
const AZURE_MARKERS: [&str; 5] = [
"openai.azure.",
"cognitiveservices.azure.",
"aoai.azure.",
"azure-api.",
"azurefd.",
];
AZURE_MARKERS.iter().any(|marker| base.contains(marker))
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn test_deserialize_ollama_model_provider_toml() {
let azure_provider_toml = r#"
name = "Ollama"
base_url = "http://localhost:11434/v1"
"#;
let expected_provider = ModelProviderInfo {
name: "Ollama".into(),
base_url: Some("http://localhost:11434/v1".into()),
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
wire_api: WireApi::Chat,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
};
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
assert_eq!(expected_provider, provider);
}
#[test]
fn test_deserialize_azure_model_provider_toml() {
let azure_provider_toml = r#"
name = "Azure"
base_url = "https://xxxxx.openai.azure.com/openai"
env_key = "AZURE_OPENAI_API_KEY"
query_params = { api-version = "2025-04-01-preview" }
"#;
let expected_provider = ModelProviderInfo {
name: "Azure".into(),
base_url: Some("https://xxxxx.openai.azure.com/openai".into()),
env_key: Some("AZURE_OPENAI_API_KEY".into()),
env_key_instructions: None,
experimental_bearer_token: None,
wire_api: WireApi::Chat,
query_params: Some(maplit::hashmap! {
"api-version".to_string() => "2025-04-01-preview".to_string(),
}),
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
};
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
assert_eq!(expected_provider, provider);
}
#[test]
fn test_deserialize_example_model_provider_toml() {
let azure_provider_toml = r#"
name = "Example"
base_url = "https://example.com"
env_key = "API_KEY"
http_headers = { "X-Example-Header" = "example-value" }
env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
"#;
let expected_provider = ModelProviderInfo {
name: "Example".into(),
base_url: Some("https://example.com".into()),
env_key: Some("API_KEY".into()),
env_key_instructions: None,
experimental_bearer_token: None,
wire_api: WireApi::Chat,
query_params: None,
http_headers: Some(maplit::hashmap! {
"X-Example-Header".to_string() => "example-value".to_string(),
}),
env_http_headers: Some(maplit::hashmap! {
"X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(),
}),
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
};
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
assert_eq!(expected_provider, provider);
}
#[test]
fn detects_azure_responses_base_urls() {
fn provider_for(base_url: &str) -> ModelProviderInfo {
ModelProviderInfo {
name: "test".into(),
base_url: Some(base_url.into()),
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
}
}
let positive_cases = [
"https://foo.openai.azure.com/openai",
"https://foo.openai.azure.us/openai/deployments/bar",
"https://foo.cognitiveservices.azure.cn/openai",
"https://foo.aoai.azure.com/openai",
"https://foo.openai.azure-api.net/openai",
"https://foo.z01.azurefd.net/",
];
for base_url in positive_cases {
let provider = provider_for(base_url);
assert!(
provider.is_azure_responses_endpoint(),
"expected {base_url} to be detected as Azure"
);
}
let named_provider = ModelProviderInfo {
name: "Azure".into(),
base_url: Some("https://example.com".into()),
env_key: None,
env_key_instructions: None,
experimental_bearer_token: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
};
assert!(named_provider.is_azure_responses_endpoint());
let negative_cases = [
"https://api.openai.com/v1",
"https://example.com/openai",
"https://myproxy.azurewebsites.net/openai",
];
for base_url in negative_cases {
let provider = provider_for(base_url);
assert!(
!provider.is_azure_responses_endpoint(),
"expected {base_url} not to be detected as Azure"
);
}
}
}

View File

@@ -0,0 +1,87 @@
use crate::model_family::ModelFamily;
// Shared constants for commonly used window/token sizes.
pub(crate) const CONTEXT_WINDOW_272K: i64 = 272_000;
pub(crate) const MAX_OUTPUT_TOKENS_128K: i64 = 128_000;
/// Metadata about a model, particularly OpenAI models.
/// We may want to consider including details like the pricing for
/// input tokens, output tokens, etc., though users will need to be able to
/// override this in config.toml, as this information can get out of date.
/// Though this would help present more accurate pricing information in the UI.
#[derive(Debug)]
pub(crate) struct ModelInfo {
/// Size of the context window in tokens. This is the maximum size of the input context.
pub(crate) context_window: i64,
/// Maximum number of output tokens that can be generated for the model.
pub(crate) max_output_tokens: i64,
/// Token threshold where we should automatically compact conversation history. This considers
/// input tokens + output tokens of this turn.
pub(crate) auto_compact_token_limit: Option<i64>,
}
impl ModelInfo {
const fn new(context_window: i64, max_output_tokens: i64) -> Self {
Self {
context_window,
max_output_tokens,
auto_compact_token_limit: Some(Self::default_auto_compact_limit(context_window)),
}
}
const fn default_auto_compact_limit(context_window: i64) -> i64 {
(context_window * 9) / 10
}
}
pub(crate) fn get_model_info(model_family: &ModelFamily) -> Option<ModelInfo> {
let slug = model_family.slug.as_str();
match slug {
// OSS models have a 128k shared token pool.
// Arbitrarily splitting it: 3/4 input context, 1/4 output.
// https://openai.com/index/gpt-oss-model-card/
"gpt-oss-20b" => Some(ModelInfo::new(96_000, 32_000)),
"gpt-oss-120b" => Some(ModelInfo::new(96_000, 32_000)),
// https://platform.openai.com/docs/models/o3
"o3" => Some(ModelInfo::new(200_000, 100_000)),
// https://platform.openai.com/docs/models/o4-mini
"o4-mini" => Some(ModelInfo::new(200_000, 100_000)),
// https://platform.openai.com/docs/models/llmx-mini-latest
"llmx-mini-latest" => Some(ModelInfo::new(200_000, 100_000)),
// As of Jun 25, 2025, gpt-4.1 defaults to gpt-4.1-2025-04-14.
// https://platform.openai.com/docs/models/gpt-4.1
"gpt-4.1" | "gpt-4.1-2025-04-14" => Some(ModelInfo::new(1_047_576, 32_768)),
// As of Jun 25, 2025, gpt-4o defaults to gpt-4o-2024-08-06.
// https://platform.openai.com/docs/models/gpt-4o
"gpt-4o" | "gpt-4o-2024-08-06" => Some(ModelInfo::new(128_000, 16_384)),
// https://platform.openai.com/docs/models/gpt-4o?snapshot=gpt-4o-2024-05-13
"gpt-4o-2024-05-13" => Some(ModelInfo::new(128_000, 4_096)),
// https://platform.openai.com/docs/models/gpt-4o?snapshot=gpt-4o-2024-11-20
"gpt-4o-2024-11-20" => Some(ModelInfo::new(128_000, 16_384)),
// https://platform.openai.com/docs/models/gpt-3.5-turbo
"gpt-3.5-turbo" => Some(ModelInfo::new(16_385, 4_096)),
_ if slug.starts_with("gpt-5-llmx") => {
Some(ModelInfo::new(CONTEXT_WINDOW_272K, MAX_OUTPUT_TOKENS_128K))
}
_ if slug.starts_with("gpt-5") => {
Some(ModelInfo::new(CONTEXT_WINDOW_272K, MAX_OUTPUT_TOKENS_128K))
}
_ if slug.starts_with("llmx-") => {
Some(ModelInfo::new(CONTEXT_WINDOW_272K, MAX_OUTPUT_TOKENS_128K))
}
_ => None,
}
}

View File

@@ -0,0 +1,61 @@
use crate::config::Config;
use crate::config::types::OtelExporterKind as Kind;
use crate::config::types::OtelHttpProtocol as Protocol;
use crate::default_client::originator;
use llmx_otel::config::OtelExporter;
use llmx_otel::config::OtelHttpProtocol;
use llmx_otel::config::OtelSettings;
use llmx_otel::otel_provider::OtelProvider;
use std::error::Error;
/// Build an OpenTelemetry provider from the app Config.
///
/// Returns `None` when OTEL export is disabled.
pub fn build_provider(
config: &Config,
service_version: &str,
) -> Result<Option<OtelProvider>, Box<dyn Error>> {
let exporter = match &config.otel.exporter {
Kind::None => OtelExporter::None,
Kind::OtlpHttp {
endpoint,
headers,
protocol,
} => {
let protocol = match protocol {
Protocol::Json => OtelHttpProtocol::Json,
Protocol::Binary => OtelHttpProtocol::Binary,
};
OtelExporter::OtlpHttp {
endpoint: endpoint.clone(),
headers: headers
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect(),
protocol,
}
}
Kind::OtlpGrpc { endpoint, headers } => OtelExporter::OtlpGrpc {
endpoint: endpoint.clone(),
headers: headers
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect(),
},
};
OtelProvider::from(&OtelSettings {
service_name: originator().value.to_owned(),
service_version: service_version.to_string(),
llmx_home: config.llmx_home.clone(),
environment: config.otel.environment.to_string(),
exporter,
})
}
/// Filter predicate for exporting only Llmx-owned events via OTEL.
/// Keeps events that originated from llmx_otel module
pub fn llmx_export_filter(meta: &tracing::Metadata<'_>) -> bool {
meta.target().starts_with("llmx_otel")
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,450 @@
//! Project-level documentation discovery.
//!
//! Project-level documentation is primarily stored in files named `AGENTS.md`.
//! Additional fallback filenames can be configured via `project_doc_fallback_filenames`.
//! We include the concatenation of all files found along the path from the
//! repository root to the current working directory as follows:
//!
//! 1. Determine the Git repository root by walking upwards from the current
//! working directory until a `.git` directory or file is found. If no Git
//! root is found, only the current working directory is considered.
//! 2. Collect every `AGENTS.md` found from the repository root down to the
//! current working directory (inclusive) and concatenate their contents in
//! that order.
//! 3. We do **not** walk past the Git root.
use crate::config::Config;
use dunce::canonicalize as normalize_path;
use std::path::PathBuf;
use tokio::io::AsyncReadExt;
use tracing::error;
/// Default filename scanned for project-level docs.
pub const DEFAULT_PROJECT_DOC_FILENAME: &str = "AGENTS.md";
/// Preferred local override for project-level docs.
pub const LOCAL_PROJECT_DOC_FILENAME: &str = "AGENTS.override.md";
/// When both `Config::instructions` and the project doc are present, they will
/// be concatenated with the following separator.
const PROJECT_DOC_SEPARATOR: &str = "\n\n--- project-doc ---\n\n";
/// Combines `Config::instructions` and `AGENTS.md` (if present) into a single
/// string of instructions.
pub(crate) async fn get_user_instructions(config: &Config) -> Option<String> {
match read_project_docs(config).await {
Ok(Some(project_doc)) => match &config.user_instructions {
Some(original_instructions) => Some(format!(
"{original_instructions}{PROJECT_DOC_SEPARATOR}{project_doc}"
)),
None => Some(project_doc),
},
Ok(None) => config.user_instructions.clone(),
Err(e) => {
error!("error trying to find project doc: {e:#}");
config.user_instructions.clone()
}
}
}
/// Attempt to locate and load the project documentation.
///
/// On success returns `Ok(Some(contents))` where `contents` is the
/// concatenation of all discovered docs. If no documentation file is found the
/// function returns `Ok(None)`. Unexpected I/O failures bubble up as `Err` so
/// callers can decide how to handle them.
pub async fn read_project_docs(config: &Config) -> std::io::Result<Option<String>> {
let max_total = config.project_doc_max_bytes;
if max_total == 0 {
return Ok(None);
}
let paths = discover_project_doc_paths(config)?;
if paths.is_empty() {
return Ok(None);
}
let mut remaining: u64 = max_total as u64;
let mut parts: Vec<String> = Vec::new();
for p in paths {
if remaining == 0 {
break;
}
let file = match tokio::fs::File::open(&p).await {
Ok(f) => f,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => continue,
Err(e) => return Err(e),
};
let size = file.metadata().await?.len();
let mut reader = tokio::io::BufReader::new(file).take(remaining);
let mut data: Vec<u8> = Vec::new();
reader.read_to_end(&mut data).await?;
if size > remaining {
tracing::warn!(
"Project doc `{}` exceeds remaining budget ({} bytes) - truncating.",
p.display(),
remaining,
);
}
let text = String::from_utf8_lossy(&data).to_string();
if !text.trim().is_empty() {
parts.push(text);
remaining = remaining.saturating_sub(data.len() as u64);
}
}
if parts.is_empty() {
Ok(None)
} else {
Ok(Some(parts.join("\n\n")))
}
}
/// Discover the list of AGENTS.md files using the same search rules as
/// `read_project_docs`, but return the file paths instead of concatenated
/// contents. The list is ordered from repository root to the current working
/// directory (inclusive). Symlinks are allowed. When `project_doc_max_bytes`
/// is zero, returns an empty list.
pub fn discover_project_doc_paths(config: &Config) -> std::io::Result<Vec<PathBuf>> {
let mut dir = config.cwd.clone();
if let Ok(canon) = normalize_path(&dir) {
dir = canon;
}
// Build chain from cwd upwards and detect git root.
let mut chain: Vec<PathBuf> = vec![dir.clone()];
let mut git_root: Option<PathBuf> = None;
let mut cursor = dir;
while let Some(parent) = cursor.parent() {
let git_marker = cursor.join(".git");
let git_exists = match std::fs::metadata(&git_marker) {
Ok(_) => true,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => false,
Err(e) => return Err(e),
};
if git_exists {
git_root = Some(cursor.clone());
break;
}
chain.push(parent.to_path_buf());
cursor = parent.to_path_buf();
}
let search_dirs: Vec<PathBuf> = if let Some(root) = git_root {
let mut dirs: Vec<PathBuf> = Vec::new();
let mut saw_root = false;
for p in chain.iter().rev() {
if !saw_root {
if p == &root {
saw_root = true;
} else {
continue;
}
}
dirs.push(p.clone());
}
dirs
} else {
vec![config.cwd.clone()]
};
let mut found: Vec<PathBuf> = Vec::new();
let candidate_filenames = candidate_filenames(config);
for d in search_dirs {
for name in &candidate_filenames {
let candidate = d.join(name);
match std::fs::symlink_metadata(&candidate) {
Ok(md) => {
let ft = md.file_type();
// Allow regular files and symlinks; opening will later fail for dangling links.
if ft.is_file() || ft.is_symlink() {
found.push(candidate);
break;
}
}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => continue,
Err(e) => return Err(e),
}
}
}
Ok(found)
}
fn candidate_filenames<'a>(config: &'a Config) -> Vec<&'a str> {
let mut names: Vec<&'a str> =
Vec::with_capacity(2 + config.project_doc_fallback_filenames.len());
names.push(LOCAL_PROJECT_DOC_FILENAME);
names.push(DEFAULT_PROJECT_DOC_FILENAME);
for candidate in &config.project_doc_fallback_filenames {
let candidate = candidate.as_str();
if candidate.is_empty() {
continue;
}
if !names.contains(&candidate) {
names.push(candidate);
}
}
names
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::ConfigOverrides;
use crate::config::ConfigToml;
use std::fs;
use tempfile::TempDir;
/// Helper that returns a `Config` pointing at `root` and using `limit` as
/// the maximum number of bytes to embed from AGENTS.md. The caller can
/// optionally specify a custom `instructions` string when `None` the
/// value is cleared to mimic a scenario where no system instructions have
/// been configured.
fn make_config(root: &TempDir, limit: usize, instructions: Option<&str>) -> Config {
let llmx_home = TempDir::new().unwrap();
let mut config = Config::load_from_base_config_with_overrides(
ConfigToml::default(),
ConfigOverrides::default(),
llmx_home.path().to_path_buf(),
)
.expect("defaults for test should always succeed");
config.cwd = root.path().to_path_buf();
config.project_doc_max_bytes = limit;
config.user_instructions = instructions.map(ToOwned::to_owned);
config
}
fn make_config_with_fallback(
root: &TempDir,
limit: usize,
instructions: Option<&str>,
fallbacks: &[&str],
) -> Config {
let mut config = make_config(root, limit, instructions);
config.project_doc_fallback_filenames = fallbacks
.iter()
.map(std::string::ToString::to_string)
.collect();
config
}
/// AGENTS.md missing should yield `None`.
#[tokio::test]
async fn no_doc_file_returns_none() {
let tmp = tempfile::tempdir().expect("tempdir");
let res = get_user_instructions(&make_config(&tmp, 4096, None)).await;
assert!(
res.is_none(),
"Expected None when AGENTS.md is absent and no system instructions provided"
);
assert!(res.is_none(), "Expected None when AGENTS.md is absent");
}
/// Small file within the byte-limit is returned unmodified.
#[tokio::test]
async fn doc_smaller_than_limit_is_returned() {
let tmp = tempfile::tempdir().expect("tempdir");
fs::write(tmp.path().join("AGENTS.md"), "hello world").unwrap();
let res = get_user_instructions(&make_config(&tmp, 4096, None))
.await
.expect("doc expected");
assert_eq!(
res, "hello world",
"The document should be returned verbatim when it is smaller than the limit and there are no existing instructions"
);
}
/// Oversize file is truncated to `project_doc_max_bytes`.
#[tokio::test]
async fn doc_larger_than_limit_is_truncated() {
const LIMIT: usize = 1024;
let tmp = tempfile::tempdir().expect("tempdir");
let huge = "A".repeat(LIMIT * 2); // 2 KiB
fs::write(tmp.path().join("AGENTS.md"), &huge).unwrap();
let res = get_user_instructions(&make_config(&tmp, LIMIT, None))
.await
.expect("doc expected");
assert_eq!(res.len(), LIMIT, "doc should be truncated to LIMIT bytes");
assert_eq!(res, huge[..LIMIT]);
}
/// When `cwd` is nested inside a repo, the search should locate AGENTS.md
/// placed at the repository root (identified by `.git`).
#[tokio::test]
async fn finds_doc_in_repo_root() {
let repo = tempfile::tempdir().expect("tempdir");
// Simulate a git repository. Note .git can be a file or a directory.
std::fs::write(
repo.path().join(".git"),
"gitdir: /path/to/actual/git/dir\n",
)
.unwrap();
// Put the doc at the repo root.
fs::write(repo.path().join("AGENTS.md"), "root level doc").unwrap();
// Now create a nested working directory: repo/workspace/crate_a
let nested = repo.path().join("workspace/crate_a");
std::fs::create_dir_all(&nested).unwrap();
// Build config pointing at the nested dir.
let mut cfg = make_config(&repo, 4096, None);
cfg.cwd = nested;
let res = get_user_instructions(&cfg).await.expect("doc expected");
assert_eq!(res, "root level doc");
}
/// Explicitly setting the byte-limit to zero disables project docs.
#[tokio::test]
async fn zero_byte_limit_disables_docs() {
let tmp = tempfile::tempdir().expect("tempdir");
fs::write(tmp.path().join("AGENTS.md"), "something").unwrap();
let res = get_user_instructions(&make_config(&tmp, 0, None)).await;
assert!(
res.is_none(),
"With limit 0 the function should return None"
);
}
/// When both system instructions *and* a project doc are present the two
/// should be concatenated with the separator.
#[tokio::test]
async fn merges_existing_instructions_with_project_doc() {
let tmp = tempfile::tempdir().expect("tempdir");
fs::write(tmp.path().join("AGENTS.md"), "proj doc").unwrap();
const INSTRUCTIONS: &str = "base instructions";
let res = get_user_instructions(&make_config(&tmp, 4096, Some(INSTRUCTIONS)))
.await
.expect("should produce a combined instruction string");
let expected = format!("{INSTRUCTIONS}{PROJECT_DOC_SEPARATOR}{}", "proj doc");
assert_eq!(res, expected);
}
/// If there are existing system instructions but the project doc is
/// missing we expect the original instructions to be returned unchanged.
#[tokio::test]
async fn keeps_existing_instructions_when_doc_missing() {
let tmp = tempfile::tempdir().expect("tempdir");
const INSTRUCTIONS: &str = "some instructions";
let res = get_user_instructions(&make_config(&tmp, 4096, Some(INSTRUCTIONS))).await;
assert_eq!(res, Some(INSTRUCTIONS.to_string()));
}
/// When both the repository root and the working directory contain
/// AGENTS.md files, their contents are concatenated from root to cwd.
#[tokio::test]
async fn concatenates_root_and_cwd_docs() {
let repo = tempfile::tempdir().expect("tempdir");
// Simulate a git repository.
std::fs::write(
repo.path().join(".git"),
"gitdir: /path/to/actual/git/dir\n",
)
.unwrap();
// Repo root doc.
fs::write(repo.path().join("AGENTS.md"), "root doc").unwrap();
// Nested working directory with its own doc.
let nested = repo.path().join("workspace/crate_a");
std::fs::create_dir_all(&nested).unwrap();
fs::write(nested.join("AGENTS.md"), "crate doc").unwrap();
let mut cfg = make_config(&repo, 4096, None);
cfg.cwd = nested;
let res = get_user_instructions(&cfg).await.expect("doc expected");
assert_eq!(res, "root doc\n\ncrate doc");
}
/// AGENTS.override.md is preferred over AGENTS.md when both are present.
#[tokio::test]
async fn agents_local_md_preferred() {
let tmp = tempfile::tempdir().expect("tempdir");
fs::write(tmp.path().join(DEFAULT_PROJECT_DOC_FILENAME), "versioned").unwrap();
fs::write(tmp.path().join(LOCAL_PROJECT_DOC_FILENAME), "local").unwrap();
let cfg = make_config(&tmp, 4096, None);
let res = get_user_instructions(&cfg)
.await
.expect("local doc expected");
assert_eq!(res, "local");
let discovery = discover_project_doc_paths(&cfg).expect("discover paths");
assert_eq!(discovery.len(), 1);
assert_eq!(
discovery[0].file_name().unwrap().to_string_lossy(),
LOCAL_PROJECT_DOC_FILENAME
);
}
/// When AGENTS.md is absent but a configured fallback exists, the fallback is used.
#[tokio::test]
async fn uses_configured_fallback_when_agents_missing() {
let tmp = tempfile::tempdir().expect("tempdir");
fs::write(tmp.path().join("EXAMPLE.md"), "example instructions").unwrap();
let cfg = make_config_with_fallback(&tmp, 4096, None, &["EXAMPLE.md"]);
let res = get_user_instructions(&cfg)
.await
.expect("fallback doc expected");
assert_eq!(res, "example instructions");
}
/// AGENTS.md remains preferred when both AGENTS.md and fallbacks are present.
#[tokio::test]
async fn agents_md_preferred_over_fallbacks() {
let tmp = tempfile::tempdir().expect("tempdir");
fs::write(tmp.path().join("AGENTS.md"), "primary").unwrap();
fs::write(tmp.path().join("EXAMPLE.md"), "secondary").unwrap();
let cfg = make_config_with_fallback(&tmp, 4096, None, &["EXAMPLE.md", ".example.md"]);
let res = get_user_instructions(&cfg)
.await
.expect("AGENTS.md should win");
assert_eq!(res, "primary");
let discovery = discover_project_doc_paths(&cfg).expect("discover paths");
assert_eq!(discovery.len(), 1);
assert!(
discovery[0]
.file_name()
.unwrap()
.to_string_lossy()
.eq(DEFAULT_PROJECT_DOC_FILENAME)
);
}
}

View File

@@ -0,0 +1,104 @@
use crate::llmx::Session;
use crate::llmx::TurnContext;
use llmx_protocol::models::FunctionCallOutputPayload;
use llmx_protocol::models::ResponseInputItem;
use llmx_protocol::models::ResponseItem;
use tracing::warn;
/// Process streamed `ResponseItem`s from the model into the pair of:
/// - items we should record in conversation history; and
/// - `ResponseInputItem`s to send back to the model on the next turn.
pub(crate) async fn process_items(
processed_items: Vec<crate::llmx::ProcessedResponseItem>,
sess: &Session,
turn_context: &TurnContext,
) -> (Vec<ResponseInputItem>, Vec<ResponseItem>) {
let mut items_to_record_in_conversation_history = Vec::<ResponseItem>::new();
let mut responses = Vec::<ResponseInputItem>::new();
for processed_response_item in processed_items {
let crate::llmx::ProcessedResponseItem { item, response } = processed_response_item;
match (&item, &response) {
(ResponseItem::Message { role, .. }, None) if role == "assistant" => {
// If the model returned a message, we need to record it.
items_to_record_in_conversation_history.push(item);
}
(
ResponseItem::LocalShellCall { .. },
Some(ResponseInputItem::FunctionCallOutput { call_id, output }),
) => {
items_to_record_in_conversation_history.push(item);
items_to_record_in_conversation_history.push(ResponseItem::FunctionCallOutput {
call_id: call_id.clone(),
output: output.clone(),
});
}
(
ResponseItem::FunctionCall { .. },
Some(ResponseInputItem::FunctionCallOutput { call_id, output }),
) => {
items_to_record_in_conversation_history.push(item);
items_to_record_in_conversation_history.push(ResponseItem::FunctionCallOutput {
call_id: call_id.clone(),
output: output.clone(),
});
}
(
ResponseItem::CustomToolCall { .. },
Some(ResponseInputItem::CustomToolCallOutput { call_id, output }),
) => {
items_to_record_in_conversation_history.push(item);
items_to_record_in_conversation_history.push(ResponseItem::CustomToolCallOutput {
call_id: call_id.clone(),
output: output.clone(),
});
}
(
ResponseItem::FunctionCall { .. },
Some(ResponseInputItem::McpToolCallOutput { call_id, result }),
) => {
items_to_record_in_conversation_history.push(item);
let output = match result {
Ok(call_tool_result) => FunctionCallOutputPayload::from(call_tool_result),
Err(err) => FunctionCallOutputPayload {
content: err.clone(),
success: Some(false),
..Default::default()
},
};
items_to_record_in_conversation_history.push(ResponseItem::FunctionCallOutput {
call_id: call_id.clone(),
output,
});
}
(
ResponseItem::Reasoning {
id,
summary,
content,
encrypted_content,
},
None,
) => {
items_to_record_in_conversation_history.push(ResponseItem::Reasoning {
id: id.clone(),
summary: summary.clone(),
content: content.clone(),
encrypted_content: encrypted_content.clone(),
});
}
_ => {
warn!("Unexpected response item: {item:?} with response: {response:?}");
}
};
if let Some(response) = response {
responses.push(response);
}
}
// Only attempt to take the lock if there is something to record.
if !items_to_record_in_conversation_history.is_empty() {
sess.record_conversation_items(turn_context, &items_to_record_in_conversation_history)
.await;
}
(responses, items_to_record_in_conversation_history)
}

View File

@@ -0,0 +1,55 @@
use crate::protocol::ReviewFinding;
// Note: We keep this module UI-agnostic. It returns plain strings that
// higher layers (e.g., TUI) may style as needed.
fn format_location(item: &ReviewFinding) -> String {
let path = item.code_location.absolute_file_path.display();
let start = item.code_location.line_range.start;
let end = item.code_location.line_range.end;
format!("{path}:{start}-{end}")
}
/// Format a full review findings block as plain text lines.
///
/// - When `selection` is `Some`, each item line includes a checkbox marker:
/// "[x]" for selected items and "[ ]" for unselected. Missing indices
/// default to selected.
/// - When `selection` is `None`, the marker is omitted and a simple bullet is
/// rendered ("- Title — path:start-end").
pub fn format_review_findings_block(
findings: &[ReviewFinding],
selection: Option<&[bool]>,
) -> String {
let mut lines: Vec<String> = Vec::new();
lines.push(String::new());
// Header
if findings.len() > 1 {
lines.push("Full review comments:".to_string());
} else {
lines.push("Review comment:".to_string());
}
for (idx, item) in findings.iter().enumerate() {
lines.push(String::new());
let title = &item.title;
let location = format_location(item);
if let Some(flags) = selection {
// Default to selected if index is out of bounds.
let checked = flags.get(idx).copied().unwrap_or(true);
let marker = if checked { "[x]" } else { "[ ]" };
lines.push(format!("- {marker} {title}{location}"));
} else {
lines.push(format!("- {title}{location}"));
}
for body_line in item.body.lines() {
lines.push(format!(" {body_line}"));
}
}
lines.join("\n")
}

View File

@@ -0,0 +1,591 @@
use std::cmp::Reverse;
use std::io::{self};
use std::num::NonZero;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use time::OffsetDateTime;
use time::PrimitiveDateTime;
use time::format_description::FormatItem;
use time::macros::format_description;
use uuid::Uuid;
use super::SESSIONS_SUBDIR;
use crate::protocol::EventMsg;
use llmx_file_search as file_search;
use llmx_protocol::protocol::RolloutItem;
use llmx_protocol::protocol::RolloutLine;
use llmx_protocol::protocol::SessionSource;
/// Returned page of conversation summaries.
#[derive(Debug, Default, PartialEq)]
pub struct ConversationsPage {
/// Conversation summaries ordered newest first.
pub items: Vec<ConversationItem>,
/// Opaque pagination token to resume after the last item, or `None` if end.
pub next_cursor: Option<Cursor>,
/// Total number of files touched while scanning this request.
pub num_scanned_files: usize,
/// True if a hard scan cap was hit; consider resuming with `next_cursor`.
pub reached_scan_cap: bool,
}
/// Summary information for a conversation rollout file.
#[derive(Debug, PartialEq)]
pub struct ConversationItem {
/// Absolute path to the rollout file.
pub path: PathBuf,
/// First up to `HEAD_RECORD_LIMIT` JSONL records parsed as JSON (includes meta line).
pub head: Vec<serde_json::Value>,
/// Last up to `TAIL_RECORD_LIMIT` JSONL response records parsed as JSON.
pub tail: Vec<serde_json::Value>,
/// RFC3339 timestamp string for when the session was created, if available.
pub created_at: Option<String>,
/// RFC3339 timestamp string for the most recent response in the tail, if available.
pub updated_at: Option<String>,
}
#[derive(Default)]
struct HeadTailSummary {
head: Vec<serde_json::Value>,
tail: Vec<serde_json::Value>,
saw_session_meta: bool,
saw_user_event: bool,
source: Option<SessionSource>,
model_provider: Option<String>,
created_at: Option<String>,
updated_at: Option<String>,
}
/// Hard cap to bound worstcase work per request.
const MAX_SCAN_FILES: usize = 10000;
const HEAD_RECORD_LIMIT: usize = 10;
const TAIL_RECORD_LIMIT: usize = 10;
/// Pagination cursor identifying a file by timestamp and UUID.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Cursor {
ts: OffsetDateTime,
id: Uuid,
}
impl Cursor {
fn new(ts: OffsetDateTime, id: Uuid) -> Self {
Self { ts, id }
}
}
impl serde::Serialize for Cursor {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let ts_str = self
.ts
.format(&format_description!(
"[year]-[month]-[day]T[hour]-[minute]-[second]"
))
.map_err(|e| serde::ser::Error::custom(format!("format error: {e}")))?;
serializer.serialize_str(&format!("{ts_str}|{}", self.id))
}
}
impl<'de> serde::Deserialize<'de> for Cursor {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
parse_cursor(&s).ok_or_else(|| serde::de::Error::custom("invalid cursor"))
}
}
/// Retrieve recorded conversation file paths with token pagination. The returned `next_cursor`
/// can be supplied on the next call to resume after the last returned item, resilient to
/// concurrent new sessions being appended. Ordering is stable by timestamp desc, then UUID desc.
pub(crate) async fn get_conversations(
llmx_home: &Path,
page_size: usize,
cursor: Option<&Cursor>,
allowed_sources: &[SessionSource],
model_providers: Option<&[String]>,
default_provider: &str,
) -> io::Result<ConversationsPage> {
let mut root = llmx_home.to_path_buf();
root.push(SESSIONS_SUBDIR);
if !root.exists() {
return Ok(ConversationsPage {
items: Vec::new(),
next_cursor: None,
num_scanned_files: 0,
reached_scan_cap: false,
});
}
let anchor = cursor.cloned();
let provider_matcher =
model_providers.and_then(|filters| ProviderMatcher::new(filters, default_provider));
let result = traverse_directories_for_paths(
root.clone(),
page_size,
anchor,
allowed_sources,
provider_matcher.as_ref(),
)
.await?;
Ok(result)
}
/// Load the full contents of a single conversation session file at `path`.
/// Returns the entire file contents as a String.
#[allow(dead_code)]
pub(crate) async fn get_conversation(path: &Path) -> io::Result<String> {
tokio::fs::read_to_string(path).await
}
/// Load conversation file paths from disk using directory traversal.
///
/// Directory layout: `~/.llmx/sessions/YYYY/MM/DD/rollout-YYYY-MM-DDThh-mm-ss-<uuid>.jsonl`
/// Returned newest (latest) first.
async fn traverse_directories_for_paths(
root: PathBuf,
page_size: usize,
anchor: Option<Cursor>,
allowed_sources: &[SessionSource],
provider_matcher: Option<&ProviderMatcher<'_>>,
) -> io::Result<ConversationsPage> {
let mut items: Vec<ConversationItem> = Vec::with_capacity(page_size);
let mut scanned_files = 0usize;
let mut anchor_passed = anchor.is_none();
let (anchor_ts, anchor_id) = match anchor {
Some(c) => (c.ts, c.id),
None => (OffsetDateTime::UNIX_EPOCH, Uuid::nil()),
};
let mut more_matches_available = false;
let year_dirs = collect_dirs_desc(&root, |s| s.parse::<u16>().ok()).await?;
'outer: for (_year, year_path) in year_dirs.iter() {
if scanned_files >= MAX_SCAN_FILES {
break;
}
let month_dirs = collect_dirs_desc(year_path, |s| s.parse::<u8>().ok()).await?;
for (_month, month_path) in month_dirs.iter() {
if scanned_files >= MAX_SCAN_FILES {
break 'outer;
}
let day_dirs = collect_dirs_desc(month_path, |s| s.parse::<u8>().ok()).await?;
for (_day, day_path) in day_dirs.iter() {
if scanned_files >= MAX_SCAN_FILES {
break 'outer;
}
let mut day_files = collect_files(day_path, |name_str, path| {
if !name_str.starts_with("rollout-") || !name_str.ends_with(".jsonl") {
return None;
}
parse_timestamp_uuid_from_filename(name_str)
.map(|(ts, id)| (ts, id, name_str.to_string(), path.to_path_buf()))
})
.await?;
// Stable ordering within the same second: (timestamp desc, uuid desc)
day_files.sort_by_key(|(ts, sid, _name_str, _path)| (Reverse(*ts), Reverse(*sid)));
for (ts, sid, _name_str, path) in day_files.into_iter() {
scanned_files += 1;
if scanned_files >= MAX_SCAN_FILES && items.len() >= page_size {
more_matches_available = true;
break 'outer;
}
if !anchor_passed {
if ts < anchor_ts || (ts == anchor_ts && sid < anchor_id) {
anchor_passed = true;
} else {
continue;
}
}
if items.len() == page_size {
more_matches_available = true;
break 'outer;
}
// Read head and simultaneously detect message events within the same
// first N JSONL records to avoid a second file read.
let summary = read_head_and_tail(&path, HEAD_RECORD_LIMIT, TAIL_RECORD_LIMIT)
.await
.unwrap_or_default();
if !allowed_sources.is_empty()
&& !summary
.source
.is_some_and(|source| allowed_sources.iter().any(|s| s == &source))
{
continue;
}
if let Some(matcher) = provider_matcher
&& !matcher.matches(summary.model_provider.as_deref())
{
continue;
}
// Apply filters: must have session meta and at least one user message event
if summary.saw_session_meta && summary.saw_user_event {
let HeadTailSummary {
head,
tail,
created_at,
mut updated_at,
..
} = summary;
updated_at = updated_at.or_else(|| created_at.clone());
items.push(ConversationItem {
path,
head,
tail,
created_at,
updated_at,
});
}
}
}
}
}
let reached_scan_cap = scanned_files >= MAX_SCAN_FILES;
if reached_scan_cap && !items.is_empty() {
more_matches_available = true;
}
let next = if more_matches_available {
build_next_cursor(&items)
} else {
None
};
Ok(ConversationsPage {
items,
next_cursor: next,
num_scanned_files: scanned_files,
reached_scan_cap,
})
}
/// Pagination cursor token format: "<file_ts>|<uuid>" where `file_ts` matches the
/// filename timestamp portion (YYYY-MM-DDThh-mm-ss) used in rollout filenames.
/// The cursor orders files by timestamp desc, then UUID desc.
pub fn parse_cursor(token: &str) -> Option<Cursor> {
let (file_ts, uuid_str) = token.split_once('|')?;
let Ok(uuid) = Uuid::parse_str(uuid_str) else {
return None;
};
let format: &[FormatItem] =
format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
let ts = PrimitiveDateTime::parse(file_ts, format).ok()?.assume_utc();
Some(Cursor::new(ts, uuid))
}
fn build_next_cursor(items: &[ConversationItem]) -> Option<Cursor> {
let last = items.last()?;
let file_name = last.path.file_name()?.to_string_lossy();
let (ts, id) = parse_timestamp_uuid_from_filename(&file_name)?;
Some(Cursor::new(ts, id))
}
/// Collects immediate subdirectories of `parent`, parses their (string) names with `parse`,
/// and returns them sorted descending by the parsed key.
async fn collect_dirs_desc<T, F>(parent: &Path, parse: F) -> io::Result<Vec<(T, PathBuf)>>
where
T: Ord + Copy,
F: Fn(&str) -> Option<T>,
{
let mut dir = tokio::fs::read_dir(parent).await?;
let mut vec: Vec<(T, PathBuf)> = Vec::new();
while let Some(entry) = dir.next_entry().await? {
if entry
.file_type()
.await
.map(|ft| ft.is_dir())
.unwrap_or(false)
&& let Some(s) = entry.file_name().to_str()
&& let Some(v) = parse(s)
{
vec.push((v, entry.path()));
}
}
vec.sort_by_key(|(v, _)| Reverse(*v));
Ok(vec)
}
/// Collects files in a directory and parses them with `parse`.
async fn collect_files<T, F>(parent: &Path, parse: F) -> io::Result<Vec<T>>
where
F: Fn(&str, &Path) -> Option<T>,
{
let mut dir = tokio::fs::read_dir(parent).await?;
let mut collected: Vec<T> = Vec::new();
while let Some(entry) = dir.next_entry().await? {
if entry
.file_type()
.await
.map(|ft| ft.is_file())
.unwrap_or(false)
&& let Some(s) = entry.file_name().to_str()
&& let Some(v) = parse(s, &entry.path())
{
collected.push(v);
}
}
Ok(collected)
}
fn parse_timestamp_uuid_from_filename(name: &str) -> Option<(OffsetDateTime, Uuid)> {
// Expected: rollout-YYYY-MM-DDThh-mm-ss-<uuid>.jsonl
let core = name.strip_prefix("rollout-")?.strip_suffix(".jsonl")?;
// Scan from the right for a '-' such that the suffix parses as a UUID.
let (sep_idx, uuid) = core
.match_indices('-')
.rev()
.find_map(|(i, _)| Uuid::parse_str(&core[i + 1..]).ok().map(|u| (i, u)))?;
let ts_str = &core[..sep_idx];
let format: &[FormatItem] =
format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
let ts = PrimitiveDateTime::parse(ts_str, format).ok()?.assume_utc();
Some((ts, uuid))
}
struct ProviderMatcher<'a> {
filters: &'a [String],
matches_default_provider: bool,
}
impl<'a> ProviderMatcher<'a> {
fn new(filters: &'a [String], default_provider: &'a str) -> Option<Self> {
if filters.is_empty() {
return None;
}
let matches_default_provider = filters.iter().any(|provider| provider == default_provider);
Some(Self {
filters,
matches_default_provider,
})
}
fn matches(&self, session_provider: Option<&str>) -> bool {
match session_provider {
Some(provider) => self.filters.iter().any(|candidate| candidate == provider),
None => self.matches_default_provider,
}
}
}
async fn read_head_and_tail(
path: &Path,
head_limit: usize,
tail_limit: usize,
) -> io::Result<HeadTailSummary> {
use tokio::io::AsyncBufReadExt;
let file = tokio::fs::File::open(path).await?;
let reader = tokio::io::BufReader::new(file);
let mut lines = reader.lines();
let mut summary = HeadTailSummary::default();
while summary.head.len() < head_limit {
let line_opt = lines.next_line().await?;
let Some(line) = line_opt else { break };
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
let parsed: Result<RolloutLine, _> = serde_json::from_str(trimmed);
let Ok(rollout_line) = parsed else { continue };
match rollout_line.item {
RolloutItem::SessionMeta(session_meta_line) => {
summary.source = Some(session_meta_line.meta.source.clone());
summary.model_provider = session_meta_line.meta.model_provider.clone();
summary.created_at = summary
.created_at
.clone()
.or_else(|| Some(rollout_line.timestamp.clone()));
if let Ok(val) = serde_json::to_value(session_meta_line) {
summary.head.push(val);
summary.saw_session_meta = true;
}
}
RolloutItem::ResponseItem(item) => {
summary.created_at = summary
.created_at
.clone()
.or_else(|| Some(rollout_line.timestamp.clone()));
if let Ok(val) = serde_json::to_value(item) {
summary.head.push(val);
}
}
RolloutItem::TurnContext(_) => {
// Not included in `head`; skip.
}
RolloutItem::Compacted(_) => {
// Not included in `head`; skip.
}
RolloutItem::EventMsg(ev) => {
if matches!(ev, EventMsg::UserMessage(_)) {
summary.saw_user_event = true;
}
}
}
}
if tail_limit != 0 {
let (tail, updated_at) = read_tail_records(path, tail_limit).await?;
summary.tail = tail;
summary.updated_at = updated_at;
}
Ok(summary)
}
/// Read up to `HEAD_RECORD_LIMIT` records from the start of the rollout file at `path`.
/// This should be enough to produce a summary including the session meta line.
pub async fn read_head_for_summary(path: &Path) -> io::Result<Vec<serde_json::Value>> {
let summary = read_head_and_tail(path, HEAD_RECORD_LIMIT, 0).await?;
Ok(summary.head)
}
async fn read_tail_records(
path: &Path,
max_records: usize,
) -> io::Result<(Vec<serde_json::Value>, Option<String>)> {
use std::io::SeekFrom;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncSeekExt;
if max_records == 0 {
return Ok((Vec::new(), None));
}
const CHUNK_SIZE: usize = 8192;
let mut file = tokio::fs::File::open(path).await?;
let mut pos = file.seek(SeekFrom::End(0)).await?;
if pos == 0 {
return Ok((Vec::new(), None));
}
let mut buffer: Vec<u8> = Vec::new();
let mut latest_timestamp: Option<String> = None;
loop {
let slice_start = match (pos > 0, buffer.iter().position(|&b| b == b'\n')) {
(true, Some(idx)) => idx + 1,
_ => 0,
};
let (tail, newest_ts) = collect_last_response_values(&buffer[slice_start..], max_records);
if latest_timestamp.is_none() {
latest_timestamp = newest_ts.clone();
}
if tail.len() >= max_records || pos == 0 {
return Ok((tail, latest_timestamp.or(newest_ts)));
}
let read_size = CHUNK_SIZE.min(pos as usize);
if read_size == 0 {
return Ok((tail, latest_timestamp.or(newest_ts)));
}
pos -= read_size as u64;
file.seek(SeekFrom::Start(pos)).await?;
let mut chunk = vec![0; read_size];
file.read_exact(&mut chunk).await?;
chunk.extend_from_slice(&buffer);
buffer = chunk;
}
}
fn collect_last_response_values(
buffer: &[u8],
max_records: usize,
) -> (Vec<serde_json::Value>, Option<String>) {
use std::borrow::Cow;
if buffer.is_empty() || max_records == 0 {
return (Vec::new(), None);
}
let text: Cow<'_, str> = String::from_utf8_lossy(buffer);
let mut collected_rev: Vec<serde_json::Value> = Vec::new();
let mut latest_timestamp: Option<String> = None;
for line in text.lines().rev() {
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
let parsed: serde_json::Result<RolloutLine> = serde_json::from_str(trimmed);
let Ok(rollout_line) = parsed else { continue };
let RolloutLine { timestamp, item } = rollout_line;
if let RolloutItem::ResponseItem(item) = item
&& let Ok(val) = serde_json::to_value(&item)
{
if latest_timestamp.is_none() {
latest_timestamp = Some(timestamp.clone());
}
collected_rev.push(val);
if collected_rev.len() == max_records {
break;
}
}
}
collected_rev.reverse();
(collected_rev, latest_timestamp)
}
/// Locate a recorded conversation rollout file by its UUID string using the existing
/// paginated listing implementation. Returns `Ok(Some(path))` if found, `Ok(None)` if not present
/// or the id is invalid.
pub async fn find_conversation_path_by_id_str(
llmx_home: &Path,
id_str: &str,
) -> io::Result<Option<PathBuf>> {
// Validate UUID format early.
if Uuid::parse_str(id_str).is_err() {
return Ok(None);
}
let mut root = llmx_home.to_path_buf();
root.push(SESSIONS_SUBDIR);
if !root.exists() {
return Ok(None);
}
// This is safe because we know the values are valid.
#[allow(clippy::unwrap_used)]
let limit = NonZero::new(1).unwrap();
// This is safe because we know the values are valid.
#[allow(clippy::unwrap_used)]
let threads = NonZero::new(2).unwrap();
let cancel = Arc::new(AtomicBool::new(false));
let exclude: Vec<String> = Vec::new();
let compute_indices = false;
let results = file_search::run(
id_str,
limit,
&root,
exclude,
threads,
cancel,
compute_indices,
false,
)
.map_err(|e| io::Error::other(format!("file search failed: {e}")))?;
Ok(results
.matches
.into_iter()
.next()
.map(|m| root.join(m.path)))
}

View File

@@ -0,0 +1,20 @@
//! Rollout module: persistence and discovery of session rollout files.
use llmx_protocol::protocol::SessionSource;
pub const SESSIONS_SUBDIR: &str = "sessions";
pub const ARCHIVED_SESSIONS_SUBDIR: &str = "archived_sessions";
pub const INTERACTIVE_SESSION_SOURCES: &[SessionSource] =
&[SessionSource::Cli, SessionSource::VSCode];
pub mod list;
pub(crate) mod policy;
pub mod recorder;
pub use list::find_conversation_path_by_id_str;
pub use llmx_protocol::protocol::SessionMeta;
pub use recorder::RolloutRecorder;
pub use recorder::RolloutRecorderParams;
#[cfg(test)]
pub mod tests;

View File

@@ -0,0 +1,86 @@
use crate::protocol::EventMsg;
use crate::protocol::RolloutItem;
use llmx_protocol::models::ResponseItem;
/// Whether a rollout `item` should be persisted in rollout files.
#[inline]
pub(crate) fn is_persisted_response_item(item: &RolloutItem) -> bool {
match item {
RolloutItem::ResponseItem(item) => should_persist_response_item(item),
RolloutItem::EventMsg(ev) => should_persist_event_msg(ev),
// Persist LLMX executive markers so we can analyze flows (e.g., compaction, API turns).
RolloutItem::Compacted(_) | RolloutItem::TurnContext(_) | RolloutItem::SessionMeta(_) => {
true
}
}
}
/// Whether a `ResponseItem` should be persisted in rollout files.
#[inline]
pub(crate) fn should_persist_response_item(item: &ResponseItem) -> bool {
match item {
ResponseItem::Message { .. }
| ResponseItem::Reasoning { .. }
| ResponseItem::LocalShellCall { .. }
| ResponseItem::FunctionCall { .. }
| ResponseItem::FunctionCallOutput { .. }
| ResponseItem::CustomToolCall { .. }
| ResponseItem::CustomToolCallOutput { .. }
| ResponseItem::WebSearchCall { .. }
| ResponseItem::GhostSnapshot { .. } => true,
ResponseItem::Other => false,
}
}
/// Whether an `EventMsg` should be persisted in rollout files.
#[inline]
pub(crate) fn should_persist_event_msg(ev: &EventMsg) -> bool {
match ev {
EventMsg::UserMessage(_)
| EventMsg::AgentMessage(_)
| EventMsg::AgentReasoning(_)
| EventMsg::AgentReasoningRawContent(_)
| EventMsg::TokenCount(_)
| EventMsg::EnteredReviewMode(_)
| EventMsg::ExitedReviewMode(_)
| EventMsg::UndoCompleted(_)
| EventMsg::TurnAborted(_) => true,
EventMsg::Error(_)
| EventMsg::Warning(_)
| EventMsg::TaskStarted(_)
| EventMsg::TaskComplete(_)
| EventMsg::AgentMessageDelta(_)
| EventMsg::AgentReasoningDelta(_)
| EventMsg::AgentReasoningRawContentDelta(_)
| EventMsg::AgentReasoningSectionBreak(_)
| EventMsg::RawResponseItem(_)
| EventMsg::SessionConfigured(_)
| EventMsg::McpToolCallBegin(_)
| EventMsg::McpToolCallEnd(_)
| EventMsg::WebSearchBegin(_)
| EventMsg::WebSearchEnd(_)
| EventMsg::ExecCommandBegin(_)
| EventMsg::ExecCommandOutputDelta(_)
| EventMsg::ExecCommandEnd(_)
| EventMsg::ExecApprovalRequest(_)
| EventMsg::ApplyPatchApprovalRequest(_)
| EventMsg::BackgroundEvent(_)
| EventMsg::StreamError(_)
| EventMsg::PatchApplyBegin(_)
| EventMsg::PatchApplyEnd(_)
| EventMsg::TurnDiff(_)
| EventMsg::GetHistoryEntryResponse(_)
| EventMsg::UndoStarted(_)
| EventMsg::McpListToolsResponse(_)
| EventMsg::ListCustomPromptsResponse(_)
| EventMsg::PlanUpdate(_)
| EventMsg::ShutdownComplete
| EventMsg::ViewImageToolCall(_)
| EventMsg::DeprecationNotice(_)
| EventMsg::ItemStarted(_)
| EventMsg::ItemCompleted(_)
| EventMsg::AgentMessageContentDelta(_)
| EventMsg::ReasoningContentDelta(_)
| EventMsg::ReasoningRawContentDelta(_) => false,
}
}

View File

@@ -0,0 +1,424 @@
//! Persist LLMX session rollouts (.jsonl) so sessions can be replayed or inspected later.
use std::fs::File;
use std::fs::{self};
use std::io::Error as IoError;
use std::path::Path;
use std::path::PathBuf;
use llmx_protocol::ConversationId;
use serde_json::Value;
use time::OffsetDateTime;
use time::format_description::FormatItem;
use time::macros::format_description;
use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc::Sender;
use tokio::sync::mpsc::{self};
use tokio::sync::oneshot;
use tracing::info;
use tracing::warn;
use super::SESSIONS_SUBDIR;
use super::list::ConversationsPage;
use super::list::Cursor;
use super::list::get_conversations;
use super::policy::is_persisted_response_item;
use crate::config::Config;
use crate::default_client::originator;
use crate::git_info::collect_git_info;
use llmx_protocol::protocol::InitialHistory;
use llmx_protocol::protocol::ResumedHistory;
use llmx_protocol::protocol::RolloutItem;
use llmx_protocol::protocol::RolloutLine;
use llmx_protocol::protocol::SessionMeta;
use llmx_protocol::protocol::SessionMetaLine;
use llmx_protocol::protocol::SessionSource;
/// Records all [`ResponseItem`]s for a session and flushes them to disk after
/// every update.
///
/// Rollouts are recorded as JSONL and can be inspected with tools such as:
///
/// ```ignore
/// $ jq -C . ~/.llmx/sessions/rollout-2025-05-07T17-24-21-5973b6c0-94b8-487b-a530-2aeb6098ae0e.jsonl
/// $ fx ~/.llmx/sessions/rollout-2025-05-07T17-24-21-5973b6c0-94b8-487b-a530-2aeb6098ae0e.jsonl
/// ```
#[derive(Clone)]
pub struct RolloutRecorder {
tx: Sender<RolloutCmd>,
pub(crate) rollout_path: PathBuf,
}
#[derive(Clone)]
pub enum RolloutRecorderParams {
Create {
conversation_id: ConversationId,
instructions: Option<String>,
source: SessionSource,
},
Resume {
path: PathBuf,
},
}
enum RolloutCmd {
AddItems(Vec<RolloutItem>),
/// Ensure all prior writes are processed; respond when flushed.
Flush {
ack: oneshot::Sender<()>,
},
Shutdown {
ack: oneshot::Sender<()>,
},
}
impl RolloutRecorderParams {
pub fn new(
conversation_id: ConversationId,
instructions: Option<String>,
source: SessionSource,
) -> Self {
Self::Create {
conversation_id,
instructions,
source,
}
}
pub fn resume(path: PathBuf) -> Self {
Self::Resume { path }
}
}
impl RolloutRecorder {
/// List conversations (rollout files) under the provided Llmx home directory.
pub async fn list_conversations(
llmx_home: &Path,
page_size: usize,
cursor: Option<&Cursor>,
allowed_sources: &[SessionSource],
model_providers: Option<&[String]>,
default_provider: &str,
) -> std::io::Result<ConversationsPage> {
get_conversations(
llmx_home,
page_size,
cursor,
allowed_sources,
model_providers,
default_provider,
)
.await
}
/// Attempt to create a new [`RolloutRecorder`]. If the sessions directory
/// cannot be created or the rollout file cannot be opened we return the
/// error so the caller can decide whether to disable persistence.
pub async fn new(config: &Config, params: RolloutRecorderParams) -> std::io::Result<Self> {
let (file, rollout_path, meta) = match params {
RolloutRecorderParams::Create {
conversation_id,
instructions,
source,
} => {
let LogFileInfo {
file,
path,
conversation_id: session_id,
timestamp,
} = create_log_file(config, conversation_id)?;
let timestamp_format: &[FormatItem] = format_description!(
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
);
let timestamp = timestamp
.to_offset(time::UtcOffset::UTC)
.format(timestamp_format)
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
(
tokio::fs::File::from_std(file),
path,
Some(SessionMeta {
id: session_id,
timestamp,
cwd: config.cwd.clone(),
originator: originator().value.clone(),
cli_version: env!("CARGO_PKG_VERSION").to_string(),
instructions,
source,
model_provider: Some(config.model_provider_id.clone()),
}),
)
}
RolloutRecorderParams::Resume { path } => (
tokio::fs::OpenOptions::new()
.append(true)
.open(&path)
.await?,
path,
None,
),
};
// Clone the cwd for the spawned task to collect git info asynchronously
let cwd = config.cwd.clone();
// A reasonably-sized bounded channel. If the buffer fills up the send
// future will yield, which is fine we only need to ensure we do not
// perform *blocking* I/O on the caller's thread.
let (tx, rx) = mpsc::channel::<RolloutCmd>(256);
// Spawn a Tokio task that owns the file handle and performs async
// writes. Using `tokio::fs::File` keeps everything on the async I/O
// driver instead of blocking the runtime.
tokio::task::spawn(rollout_writer(file, rx, meta, cwd));
Ok(Self { tx, rollout_path })
}
pub(crate) async fn record_items(&self, items: &[RolloutItem]) -> std::io::Result<()> {
let mut filtered = Vec::new();
for item in items {
// Note that function calls may look a bit strange if they are
// "fully qualified MCP tool calls," so we could consider
// reformatting them in that case.
if is_persisted_response_item(item) {
filtered.push(item.clone());
}
}
if filtered.is_empty() {
return Ok(());
}
self.tx
.send(RolloutCmd::AddItems(filtered))
.await
.map_err(|e| IoError::other(format!("failed to queue rollout items: {e}")))
}
/// Flush all queued writes and wait until they are committed by the writer task.
pub async fn flush(&self) -> std::io::Result<()> {
let (tx, rx) = oneshot::channel();
self.tx
.send(RolloutCmd::Flush { ack: tx })
.await
.map_err(|e| IoError::other(format!("failed to queue rollout flush: {e}")))?;
rx.await
.map_err(|e| IoError::other(format!("failed waiting for rollout flush: {e}")))
}
pub async fn get_rollout_history(path: &Path) -> std::io::Result<InitialHistory> {
info!("Resuming rollout from {path:?}");
let text = tokio::fs::read_to_string(path).await?;
if text.trim().is_empty() {
return Err(IoError::other("empty session file"));
}
let mut items: Vec<RolloutItem> = Vec::new();
let mut conversation_id: Option<ConversationId> = None;
for line in text.lines() {
if line.trim().is_empty() {
continue;
}
let v: Value = match serde_json::from_str(line) {
Ok(v) => v,
Err(e) => {
warn!("failed to parse line as JSON: {line:?}, error: {e}");
continue;
}
};
// Parse the rollout line structure
match serde_json::from_value::<RolloutLine>(v.clone()) {
Ok(rollout_line) => match rollout_line.item {
RolloutItem::SessionMeta(session_meta_line) => {
// Use the FIRST SessionMeta encountered in the file as the canonical
// conversation id and main session information. Keep all items intact.
if conversation_id.is_none() {
conversation_id = Some(session_meta_line.meta.id);
}
items.push(RolloutItem::SessionMeta(session_meta_line));
}
RolloutItem::ResponseItem(item) => {
items.push(RolloutItem::ResponseItem(item));
}
RolloutItem::Compacted(item) => {
items.push(RolloutItem::Compacted(item));
}
RolloutItem::TurnContext(item) => {
items.push(RolloutItem::TurnContext(item));
}
RolloutItem::EventMsg(_ev) => {
items.push(RolloutItem::EventMsg(_ev));
}
},
Err(e) => {
warn!("failed to parse rollout line: {v:?}, error: {e}");
}
}
}
info!(
"Resumed rollout with {} items, conversation ID: {:?}",
items.len(),
conversation_id
);
let conversation_id = conversation_id
.ok_or_else(|| IoError::other("failed to parse conversation ID from rollout file"))?;
if items.is_empty() {
return Ok(InitialHistory::New);
}
info!("Resumed rollout successfully from {path:?}");
Ok(InitialHistory::Resumed(ResumedHistory {
conversation_id,
history: items,
rollout_path: path.to_path_buf(),
}))
}
pub async fn shutdown(&self) -> std::io::Result<()> {
let (tx_done, rx_done) = oneshot::channel();
match self.tx.send(RolloutCmd::Shutdown { ack: tx_done }).await {
Ok(_) => rx_done
.await
.map_err(|e| IoError::other(format!("failed waiting for rollout shutdown: {e}"))),
Err(e) => {
warn!("failed to send rollout shutdown command: {e}");
Err(IoError::other(format!(
"failed to send rollout shutdown command: {e}"
)))
}
}
}
}
struct LogFileInfo {
/// Opened file handle to the rollout file.
file: File,
/// Full path to the rollout file.
path: PathBuf,
/// Session ID (also embedded in filename).
conversation_id: ConversationId,
/// Timestamp for the start of the session.
timestamp: OffsetDateTime,
}
fn create_log_file(
config: &Config,
conversation_id: ConversationId,
) -> std::io::Result<LogFileInfo> {
// Resolve ~/.llmx/sessions/YYYY/MM/DD and create it if missing.
let timestamp = OffsetDateTime::now_local()
.map_err(|e| IoError::other(format!("failed to get local time: {e}")))?;
let mut dir = config.llmx_home.clone();
dir.push(SESSIONS_SUBDIR);
dir.push(timestamp.year().to_string());
dir.push(format!("{:02}", u8::from(timestamp.month())));
dir.push(format!("{:02}", timestamp.day()));
fs::create_dir_all(&dir)?;
// Custom format for YYYY-MM-DDThh-mm-ss. Use `-` instead of `:` for
// compatibility with filesystems that do not allow colons in filenames.
let format: &[FormatItem] =
format_description!("[year]-[month]-[day]T[hour]-[minute]-[second]");
let date_str = timestamp
.format(format)
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
let filename = format!("rollout-{date_str}-{conversation_id}.jsonl");
let path = dir.join(filename);
let file = std::fs::OpenOptions::new()
.append(true)
.create(true)
.open(&path)?;
Ok(LogFileInfo {
file,
path,
conversation_id,
timestamp,
})
}
async fn rollout_writer(
file: tokio::fs::File,
mut rx: mpsc::Receiver<RolloutCmd>,
mut meta: Option<SessionMeta>,
cwd: std::path::PathBuf,
) -> std::io::Result<()> {
let mut writer = JsonlWriter { file };
// If we have a meta, collect git info asynchronously and write meta first
if let Some(session_meta) = meta.take() {
let git_info = collect_git_info(&cwd).await;
let session_meta_line = SessionMetaLine {
meta: session_meta,
git: git_info,
};
// Write the SessionMeta as the first item in the file, wrapped in a rollout line
writer
.write_rollout_item(RolloutItem::SessionMeta(session_meta_line))
.await?;
}
// Process rollout commands
while let Some(cmd) = rx.recv().await {
match cmd {
RolloutCmd::AddItems(items) => {
for item in items {
if is_persisted_response_item(&item) {
writer.write_rollout_item(item).await?;
}
}
}
RolloutCmd::Flush { ack } => {
// Ensure underlying file is flushed and then ack.
if let Err(e) = writer.file.flush().await {
let _ = ack.send(());
return Err(e);
}
let _ = ack.send(());
}
RolloutCmd::Shutdown { ack } => {
let _ = ack.send(());
}
}
}
Ok(())
}
struct JsonlWriter {
file: tokio::fs::File,
}
impl JsonlWriter {
async fn write_rollout_item(&mut self, rollout_item: RolloutItem) -> std::io::Result<()> {
let timestamp_format: &[FormatItem] = format_description!(
"[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond digits:3]Z"
);
let timestamp = OffsetDateTime::now_utc()
.format(timestamp_format)
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
let line = RolloutLine {
timestamp,
item: rollout_item,
};
self.write_line(&line).await
}
async fn write_line(&mut self, item: &impl serde::Serialize) -> std::io::Result<()> {
let mut json = serde_json::to_string(item)?;
json.push('\n');
self.file.write_all(json.as_bytes()).await?;
self.file.flush().await?;
Ok(())
}
}

File diff suppressed because it is too large Load Diff

245
llmx-rs/core/src/safety.rs Normal file
View File

@@ -0,0 +1,245 @@
use std::path::Component;
use std::path::Path;
use std::path::PathBuf;
use llmx_apply_patch::ApplyPatchAction;
use llmx_apply_patch::ApplyPatchFileChange;
use crate::exec::SandboxType;
use crate::protocol::AskForApproval;
use crate::protocol::SandboxPolicy;
#[cfg(target_os = "windows")]
use std::sync::atomic::AtomicBool;
#[cfg(target_os = "windows")]
use std::sync::atomic::Ordering;
#[cfg(target_os = "windows")]
static WINDOWS_SANDBOX_ENABLED: AtomicBool = AtomicBool::new(false);
#[cfg(target_os = "windows")]
pub fn set_windows_sandbox_enabled(enabled: bool) {
WINDOWS_SANDBOX_ENABLED.store(enabled, Ordering::Relaxed);
}
#[cfg(not(target_os = "windows"))]
#[allow(dead_code)]
pub fn set_windows_sandbox_enabled(_enabled: bool) {}
#[derive(Debug, PartialEq)]
pub enum SafetyCheck {
AutoApprove {
sandbox_type: SandboxType,
user_explicitly_approved: bool,
},
AskUser,
Reject {
reason: String,
},
}
pub fn assess_patch_safety(
action: &ApplyPatchAction,
policy: AskForApproval,
sandbox_policy: &SandboxPolicy,
cwd: &Path,
) -> SafetyCheck {
if action.is_empty() {
return SafetyCheck::Reject {
reason: "empty patch".to_string(),
};
}
match policy {
AskForApproval::OnFailure | AskForApproval::Never | AskForApproval::OnRequest => {
// Continue to see if this can be auto-approved.
}
// TODO(ragona): I'm not sure this is actually correct? I believe in this case
// we want to continue to the writable paths check before asking the user.
AskForApproval::UnlessTrusted => {
return SafetyCheck::AskUser;
}
}
// Even though the patch appears to be constrained to writable paths, it is
// possible that paths in the patch are hard links to files outside the
// writable roots, so we should still run `apply_patch` in a sandbox in that case.
if is_write_patch_constrained_to_writable_paths(action, sandbox_policy, cwd)
|| policy == AskForApproval::OnFailure
{
if matches!(sandbox_policy, SandboxPolicy::DangerFullAccess) {
// DangerFullAccess is intended to bypass sandboxing entirely.
SafetyCheck::AutoApprove {
sandbox_type: SandboxType::None,
user_explicitly_approved: false,
}
} else {
// Only autoapprove when we can actually enforce a sandbox. Otherwise
// fall back to asking the user because the patch may touch arbitrary
// paths outside the project.
match get_platform_sandbox() {
Some(sandbox_type) => SafetyCheck::AutoApprove {
sandbox_type,
user_explicitly_approved: false,
},
None => SafetyCheck::AskUser,
}
}
} else if policy == AskForApproval::Never {
SafetyCheck::Reject {
reason: "writing outside of the project; rejected by user approval settings"
.to_string(),
}
} else {
SafetyCheck::AskUser
}
}
pub fn get_platform_sandbox() -> Option<SandboxType> {
if cfg!(target_os = "macos") {
Some(SandboxType::MacosSeatbelt)
} else if cfg!(target_os = "linux") {
Some(SandboxType::LinuxSeccomp)
} else if cfg!(target_os = "windows") {
#[cfg(target_os = "windows")]
{
if WINDOWS_SANDBOX_ENABLED.load(Ordering::Relaxed) {
return Some(SandboxType::WindowsRestrictedToken);
}
}
None
} else {
None
}
}
fn is_write_patch_constrained_to_writable_paths(
action: &ApplyPatchAction,
sandbox_policy: &SandboxPolicy,
cwd: &Path,
) -> bool {
// Earlyexit if there are no declared writable roots.
let writable_roots = match sandbox_policy {
SandboxPolicy::ReadOnly => {
return false;
}
SandboxPolicy::DangerFullAccess => {
return true;
}
SandboxPolicy::WorkspaceWrite { .. } => sandbox_policy.get_writable_roots_with_cwd(cwd),
};
// Normalize a path by removing `.` and resolving `..` without touching the
// filesystem (works even if the file does not exist).
fn normalize(path: &Path) -> Option<PathBuf> {
let mut out = PathBuf::new();
for comp in path.components() {
match comp {
Component::ParentDir => {
out.pop();
}
Component::CurDir => { /* skip */ }
other => out.push(other.as_os_str()),
}
}
Some(out)
}
// Determine whether `path` is inside **any** writable root. Both `path`
// and roots are converted to absolute, normalized forms before the
// prefix check.
let is_path_writable = |p: &PathBuf| {
let abs = if p.is_absolute() {
p.clone()
} else {
cwd.join(p)
};
let abs = match normalize(&abs) {
Some(v) => v,
None => return false,
};
writable_roots
.iter()
.any(|writable_root| writable_root.is_path_writable(&abs))
};
for (path, change) in action.changes() {
match change {
ApplyPatchFileChange::Add { .. } | ApplyPatchFileChange::Delete { .. } => {
if !is_path_writable(path) {
return false;
}
}
ApplyPatchFileChange::Update { move_path, .. } => {
if !is_path_writable(path) {
return false;
}
if let Some(dest) = move_path
&& !is_path_writable(dest)
{
return false;
}
}
}
}
true
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
#[test]
fn test_writable_roots_constraint() {
// Use a temporary directory as our workspace to avoid touching
// the real current working directory.
let tmp = TempDir::new().unwrap();
let cwd = tmp.path().to_path_buf();
let parent = cwd.parent().unwrap().to_path_buf();
// Helper to build a singleentry patch that adds a file at `p`.
let make_add_change = |p: PathBuf| ApplyPatchAction::new_add_for_test(&p, "".to_string());
let add_inside = make_add_change(cwd.join("inner.txt"));
let add_outside = make_add_change(parent.join("outside.txt"));
// Policy limited to the workspace only; exclude system temp roots so
// only `cwd` is writable by default.
let policy_workspace_only = SandboxPolicy::WorkspaceWrite {
writable_roots: vec![],
network_access: false,
exclude_tmpdir_env_var: true,
exclude_slash_tmp: true,
};
assert!(is_write_patch_constrained_to_writable_paths(
&add_inside,
&policy_workspace_only,
&cwd,
));
assert!(!is_write_patch_constrained_to_writable_paths(
&add_outside,
&policy_workspace_only,
&cwd,
));
// With the parent dir explicitly added as a writable root, the
// outside write should be permitted.
let policy_with_parent = SandboxPolicy::WorkspaceWrite {
writable_roots: vec![parent],
network_access: false,
exclude_tmpdir_env_var: true,
exclude_slash_tmp: true,
};
assert!(is_write_patch_constrained_to_writable_paths(
&add_outside,
&policy_with_parent,
&cwd,
));
}
}

View File

@@ -0,0 +1,260 @@
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use crate::AuthManager;
use crate::ModelProviderInfo;
use crate::client::ModelClient;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::config::Config;
use crate::protocol::SandboxPolicy;
use askama::Template;
use futures::StreamExt;
use llmx_otel::otel_event_manager::OtelEventManager;
use llmx_protocol::ConversationId;
use llmx_protocol::models::ContentItem;
use llmx_protocol::models::ResponseItem;
use llmx_protocol::protocol::SandboxCommandAssessment;
use llmx_protocol::protocol::SessionSource;
use serde_json::json;
use tokio::time::timeout;
use tracing::warn;
const SANDBOX_ASSESSMENT_TIMEOUT: Duration = Duration::from_secs(5);
#[derive(Template)]
#[template(path = "sandboxing/assessment_prompt.md", escape = "none")]
struct SandboxAssessmentPromptTemplate<'a> {
platform: &'a str,
sandbox_policy: &'a str,
filesystem_roots: Option<&'a str>,
working_directory: &'a str,
command_argv: &'a str,
command_joined: &'a str,
sandbox_failure_message: Option<&'a str>,
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn assess_command(
config: Arc<Config>,
provider: ModelProviderInfo,
auth_manager: Arc<AuthManager>,
parent_otel: &OtelEventManager,
conversation_id: ConversationId,
session_source: SessionSource,
call_id: &str,
command: &[String],
sandbox_policy: &SandboxPolicy,
cwd: &Path,
failure_message: Option<&str>,
) -> Option<SandboxCommandAssessment> {
if !config.experimental_sandbox_command_assessment || command.is_empty() {
return None;
}
let command_json = serde_json::to_string(command).unwrap_or_else(|_| "[]".to_string());
let command_joined =
shlex::try_join(command.iter().map(String::as_str)).unwrap_or_else(|_| command.join(" "));
let failure = failure_message
.map(str::trim)
.filter(|msg| !msg.is_empty())
.map(str::to_string);
let cwd_str = cwd.to_string_lossy().to_string();
let sandbox_summary = summarize_sandbox_policy(sandbox_policy);
let mut roots = sandbox_roots_for_prompt(sandbox_policy, cwd);
roots.sort();
roots.dedup();
let platform = std::env::consts::OS;
let roots_formatted = roots.iter().map(|root| root.to_string_lossy().to_string());
let filesystem_roots = match roots_formatted.collect::<Vec<_>>() {
collected if collected.is_empty() => None,
collected => Some(collected.join(", ")),
};
let prompt_template = SandboxAssessmentPromptTemplate {
platform,
sandbox_policy: sandbox_summary.as_str(),
filesystem_roots: filesystem_roots.as_deref(),
working_directory: cwd_str.as_str(),
command_argv: command_json.as_str(),
command_joined: command_joined.as_str(),
sandbox_failure_message: failure.as_deref(),
};
let rendered_prompt = match prompt_template.render() {
Ok(rendered) => rendered,
Err(err) => {
warn!("failed to render sandbox assessment prompt: {err}");
return None;
}
};
let (system_prompt_section, user_prompt_section) = match rendered_prompt.split_once("\n---\n") {
Some(split) => split,
None => {
warn!("rendered sandbox assessment prompt missing separator");
return None;
}
};
let system_prompt = system_prompt_section
.strip_prefix("System Prompt:\n")
.unwrap_or(system_prompt_section)
.trim()
.to_string();
let user_prompt = user_prompt_section
.strip_prefix("User Prompt:\n")
.unwrap_or(user_prompt_section)
.trim()
.to_string();
let prompt = Prompt {
input: vec![ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText { text: user_prompt }],
}],
tools: Vec::new(),
parallel_tool_calls: false,
base_instructions_override: Some(system_prompt),
output_schema: Some(sandbox_assessment_schema()),
};
let child_otel =
parent_otel.with_model(config.model.as_str(), config.model_family.slug.as_str());
let client = ModelClient::new(
Arc::clone(&config),
Some(auth_manager),
child_otel,
provider,
config.model_reasoning_effort,
config.model_reasoning_summary,
conversation_id,
session_source,
);
let start = Instant::now();
let assessment_result = timeout(SANDBOX_ASSESSMENT_TIMEOUT, async move {
let mut stream = client.stream(&prompt).await?;
let mut last_json: Option<String> = None;
while let Some(event) = stream.next().await {
match event {
Ok(ResponseEvent::OutputItemDone(item)) => {
if let Some(text) = response_item_text(&item) {
last_json = Some(text);
}
}
Ok(ResponseEvent::RateLimits(_)) => {}
Ok(ResponseEvent::Completed { .. }) => break,
Ok(_) => continue,
Err(err) => return Err(err),
}
}
Ok(last_json)
})
.await;
let duration = start.elapsed();
parent_otel.sandbox_assessment_latency(call_id, duration);
match assessment_result {
Ok(Ok(Some(raw))) => match serde_json::from_str::<SandboxCommandAssessment>(raw.trim()) {
Ok(assessment) => {
parent_otel.sandbox_assessment(
call_id,
"success",
Some(assessment.risk_level),
duration,
);
return Some(assessment);
}
Err(err) => {
warn!("failed to parse sandbox assessment JSON: {err}");
parent_otel.sandbox_assessment(call_id, "parse_error", None, duration);
}
},
Ok(Ok(None)) => {
warn!("sandbox assessment response did not include any message");
parent_otel.sandbox_assessment(call_id, "no_output", None, duration);
}
Ok(Err(err)) => {
warn!("sandbox assessment failed: {err}");
parent_otel.sandbox_assessment(call_id, "model_error", None, duration);
}
Err(_) => {
warn!("sandbox assessment timed out");
parent_otel.sandbox_assessment(call_id, "timeout", None, duration);
}
}
None
}
fn summarize_sandbox_policy(policy: &SandboxPolicy) -> String {
match policy {
SandboxPolicy::DangerFullAccess => "danger-full-access".to_string(),
SandboxPolicy::ReadOnly => "read-only".to_string(),
SandboxPolicy::WorkspaceWrite { network_access, .. } => {
let network = if *network_access {
"network"
} else {
"no-network"
};
format!("workspace-write (network_access={network})")
}
}
}
fn sandbox_roots_for_prompt(policy: &SandboxPolicy, cwd: &Path) -> Vec<PathBuf> {
let mut roots = vec![cwd.to_path_buf()];
if let SandboxPolicy::WorkspaceWrite { writable_roots, .. } = policy {
roots.extend(writable_roots.iter().cloned());
}
roots
}
fn sandbox_assessment_schema() -> serde_json::Value {
json!({
"type": "object",
"required": ["description", "risk_level"],
"properties": {
"description": {
"type": "string",
"minLength": 1,
"maxLength": 500
},
"risk_level": {
"type": "string",
"enum": ["low", "medium", "high"]
},
},
"additionalProperties": false
})
}
fn response_item_text(item: &ResponseItem) -> Option<String> {
match item {
ResponseItem::Message { content, .. } => {
let mut buffers: Vec<&str> = Vec::new();
for segment in content {
match segment {
ContentItem::InputText { text } | ContentItem::OutputText { text } => {
if !text.is_empty() {
buffers.push(text);
}
}
ContentItem::InputImage { .. } => {}
}
}
if buffers.is_empty() {
None
} else {
Some(buffers.join("\n"))
}
}
ResponseItem::FunctionCallOutput { output, .. } => Some(output.content.clone()),
_ => None,
}
}

View File

@@ -0,0 +1,178 @@
/*
Module: sandboxing
Build platform wrappers and produce ExecEnv for execution. Owns lowlevel
sandbox placement and transformation of portable CommandSpec into a
readytospawn environment.
*/
pub mod assessment;
use crate::exec::ExecToolCallOutput;
use crate::exec::SandboxType;
use crate::exec::StdoutStream;
use crate::exec::execute_exec_env;
use crate::landlock::create_linux_sandbox_command_args;
use crate::protocol::SandboxPolicy;
#[cfg(target_os = "macos")]
use crate::seatbelt::MACOS_PATH_TO_SEATBELT_EXECUTABLE;
#[cfg(target_os = "macos")]
use crate::seatbelt::create_seatbelt_command_args;
#[cfg(target_os = "macos")]
use crate::spawn::LLMX_SANDBOX_ENV_VAR;
use crate::spawn::LLMX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
use crate::tools::sandboxing::SandboxablePreference;
use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
#[derive(Clone, Debug)]
pub struct CommandSpec {
pub program: String,
pub args: Vec<String>,
pub cwd: PathBuf,
pub env: HashMap<String, String>,
pub timeout_ms: Option<u64>,
pub with_escalated_permissions: Option<bool>,
pub justification: Option<String>,
}
#[derive(Clone, Debug)]
pub struct ExecEnv {
pub command: Vec<String>,
pub cwd: PathBuf,
pub env: HashMap<String, String>,
pub timeout_ms: Option<u64>,
pub sandbox: SandboxType,
pub with_escalated_permissions: Option<bool>,
pub justification: Option<String>,
pub arg0: Option<String>,
}
pub enum SandboxPreference {
Auto,
Require,
Forbid,
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum SandboxTransformError {
#[error("missing llmx-linux-sandbox executable path")]
MissingLinuxSandboxExecutable,
#[cfg(not(target_os = "macos"))]
#[error("seatbelt sandbox is only available on macOS")]
SeatbeltUnavailable,
}
#[derive(Default)]
pub struct SandboxManager;
impl SandboxManager {
pub fn new() -> Self {
Self
}
pub(crate) fn select_initial(
&self,
policy: &SandboxPolicy,
pref: SandboxablePreference,
) -> SandboxType {
match pref {
SandboxablePreference::Forbid => SandboxType::None,
SandboxablePreference::Require => {
// Require a platform sandbox when available; on Windows this
// respects the enable_experimental_windows_sandbox feature.
crate::safety::get_platform_sandbox().unwrap_or(SandboxType::None)
}
SandboxablePreference::Auto => match policy {
SandboxPolicy::DangerFullAccess => SandboxType::None,
_ => crate::safety::get_platform_sandbox().unwrap_or(SandboxType::None),
},
}
}
pub(crate) fn transform(
&self,
spec: &CommandSpec,
policy: &SandboxPolicy,
sandbox: SandboxType,
sandbox_policy_cwd: &Path,
llmx_linux_sandbox_exe: Option<&PathBuf>,
) -> Result<ExecEnv, SandboxTransformError> {
let mut env = spec.env.clone();
if !policy.has_full_network_access() {
env.insert(
LLMX_SANDBOX_NETWORK_DISABLED_ENV_VAR.to_string(),
"1".to_string(),
);
}
let mut command = Vec::with_capacity(1 + spec.args.len());
command.push(spec.program.clone());
command.extend(spec.args.iter().cloned());
let (command, sandbox_env, arg0_override) = match sandbox {
SandboxType::None => (command, HashMap::new(), None),
#[cfg(target_os = "macos")]
SandboxType::MacosSeatbelt => {
let mut seatbelt_env = HashMap::new();
seatbelt_env.insert(LLMX_SANDBOX_ENV_VAR.to_string(), "seatbelt".to_string());
let mut args =
create_seatbelt_command_args(command.clone(), policy, sandbox_policy_cwd);
let mut full_command = Vec::with_capacity(1 + args.len());
full_command.push(MACOS_PATH_TO_SEATBELT_EXECUTABLE.to_string());
full_command.append(&mut args);
(full_command, seatbelt_env, None)
}
#[cfg(not(target_os = "macos"))]
SandboxType::MacosSeatbelt => return Err(SandboxTransformError::SeatbeltUnavailable),
SandboxType::LinuxSeccomp => {
let exe = llmx_linux_sandbox_exe
.ok_or(SandboxTransformError::MissingLinuxSandboxExecutable)?;
let mut args =
create_linux_sandbox_command_args(command.clone(), policy, sandbox_policy_cwd);
let mut full_command = Vec::with_capacity(1 + args.len());
full_command.push(exe.to_string_lossy().to_string());
full_command.append(&mut args);
(
full_command,
HashMap::new(),
Some("llmx-linux-sandbox".to_string()),
)
}
// On Windows, the restricted token sandbox executes in-process via the
// llmx-windows-sandbox crate. We leave the command unchanged here and
// branch during execution based on the sandbox type.
#[cfg(target_os = "windows")]
SandboxType::WindowsRestrictedToken => (command, HashMap::new(), None),
// When building for non-Windows targets, this variant is never constructed.
#[cfg(not(target_os = "windows"))]
SandboxType::WindowsRestrictedToken => (command, HashMap::new(), None),
};
env.extend(sandbox_env);
Ok(ExecEnv {
command,
cwd: spec.cwd.clone(),
env,
timeout_ms: spec.timeout_ms,
sandbox,
with_escalated_permissions: spec.with_escalated_permissions,
justification: spec.justification.clone(),
arg0: arg0_override,
})
}
pub fn denied(&self, sandbox: SandboxType, out: &ExecToolCallOutput) -> bool {
crate::exec::is_likely_sandbox_denied(sandbox, out)
}
}
pub async fn execute_env(
env: &ExecEnv,
policy: &SandboxPolicy,
stdout_stream: Option<StdoutStream>,
) -> crate::error::Result<ExecToolCallOutput> {
execute_exec_env(env.clone(), policy, stdout_stream).await
}

View File

@@ -0,0 +1,368 @@
#![cfg(target_os = "macos")]
use std::collections::HashMap;
use std::ffi::CStr;
use std::path::Path;
use std::path::PathBuf;
use tokio::process::Child;
use crate::protocol::SandboxPolicy;
use crate::spawn::LLMX_SANDBOX_ENV_VAR;
use crate::spawn::StdioPolicy;
use crate::spawn::spawn_child_async;
const MACOS_SEATBELT_BASE_POLICY: &str = include_str!("seatbelt_base_policy.sbpl");
const MACOS_SEATBELT_NETWORK_POLICY: &str = include_str!("seatbelt_network_policy.sbpl");
/// When working with `sandbox-exec`, only consider `sandbox-exec` in `/usr/bin`
/// to defend against an attacker trying to inject a malicious version on the
/// PATH. If /usr/bin/sandbox-exec has been tampered with, then the attacker
/// already has root access.
pub(crate) const MACOS_PATH_TO_SEATBELT_EXECUTABLE: &str = "/usr/bin/sandbox-exec";
pub async fn spawn_command_under_seatbelt(
command: Vec<String>,
command_cwd: PathBuf,
sandbox_policy: &SandboxPolicy,
sandbox_policy_cwd: &Path,
stdio_policy: StdioPolicy,
mut env: HashMap<String, String>,
) -> std::io::Result<Child> {
let args = create_seatbelt_command_args(command, sandbox_policy, sandbox_policy_cwd);
let arg0 = None;
env.insert(LLMX_SANDBOX_ENV_VAR.to_string(), "seatbelt".to_string());
spawn_child_async(
PathBuf::from(MACOS_PATH_TO_SEATBELT_EXECUTABLE),
args,
arg0,
command_cwd,
sandbox_policy,
stdio_policy,
env,
)
.await
}
pub(crate) fn create_seatbelt_command_args(
command: Vec<String>,
sandbox_policy: &SandboxPolicy,
sandbox_policy_cwd: &Path,
) -> Vec<String> {
let (file_write_policy, file_write_dir_params) = {
if sandbox_policy.has_full_disk_write_access() {
// Allegedly, this is more permissive than `(allow file-write*)`.
(
r#"(allow file-write* (regex #"^/"))"#.to_string(),
Vec::new(),
)
} else {
let writable_roots = sandbox_policy.get_writable_roots_with_cwd(sandbox_policy_cwd);
let mut writable_folder_policies: Vec<String> = Vec::new();
let mut file_write_params = Vec::new();
for (index, wr) in writable_roots.iter().enumerate() {
// Canonicalize to avoid mismatches like /var vs /private/var on macOS.
let canonical_root = wr.root.canonicalize().unwrap_or_else(|_| wr.root.clone());
let root_param = format!("WRITABLE_ROOT_{index}");
file_write_params.push((root_param.clone(), canonical_root));
if wr.read_only_subpaths.is_empty() {
writable_folder_policies.push(format!("(subpath (param \"{root_param}\"))"));
} else {
// Add parameters for each read-only subpath and generate
// the `(require-not ...)` clauses.
let mut require_parts: Vec<String> = Vec::new();
require_parts.push(format!("(subpath (param \"{root_param}\"))"));
for (subpath_index, ro) in wr.read_only_subpaths.iter().enumerate() {
let canonical_ro = ro.canonicalize().unwrap_or_else(|_| ro.clone());
let ro_param = format!("WRITABLE_ROOT_{index}_RO_{subpath_index}");
require_parts
.push(format!("(require-not (subpath (param \"{ro_param}\")))"));
file_write_params.push((ro_param, canonical_ro));
}
let policy_component = format!("(require-all {} )", require_parts.join(" "));
writable_folder_policies.push(policy_component);
}
}
if writable_folder_policies.is_empty() {
("".to_string(), Vec::new())
} else {
let file_write_policy = format!(
"(allow file-write*\n{}\n)",
writable_folder_policies.join(" ")
);
(file_write_policy, file_write_params)
}
}
};
let file_read_policy = if sandbox_policy.has_full_disk_read_access() {
"; allow read-only file operations\n(allow file-read*)"
} else {
""
};
// TODO(mbolin): apply_patch calls must also honor the SandboxPolicy.
let network_policy = if sandbox_policy.has_full_network_access() {
MACOS_SEATBELT_NETWORK_POLICY
} else {
""
};
let full_policy = format!(
"{MACOS_SEATBELT_BASE_POLICY}\n{file_read_policy}\n{file_write_policy}\n{network_policy}"
);
let dir_params = [file_write_dir_params, macos_dir_params()].concat();
let mut seatbelt_args: Vec<String> = vec!["-p".to_string(), full_policy];
let definition_args = dir_params
.into_iter()
.map(|(key, value)| format!("-D{key}={value}", value = value.to_string_lossy()));
seatbelt_args.extend(definition_args);
seatbelt_args.push("--".to_string());
seatbelt_args.extend(command);
seatbelt_args
}
/// Wraps libc::confstr to return a String.
fn confstr(name: libc::c_int) -> Option<String> {
let mut buf = vec![0_i8; (libc::PATH_MAX as usize) + 1];
let len = unsafe { libc::confstr(name, buf.as_mut_ptr(), buf.len()) };
if len == 0 {
return None;
}
// confstr guarantees NUL-termination when len > 0.
let cstr = unsafe { CStr::from_ptr(buf.as_ptr()) };
cstr.to_str().ok().map(ToString::to_string)
}
/// Wraps confstr to return a canonicalized PathBuf.
fn confstr_path(name: libc::c_int) -> Option<PathBuf> {
let s = confstr(name)?;
let path = PathBuf::from(s);
path.canonicalize().ok().or(Some(path))
}
fn macos_dir_params() -> Vec<(String, PathBuf)> {
if let Some(p) = confstr_path(libc::_CS_DARWIN_USER_CACHE_DIR) {
return vec![("DARWIN_USER_CACHE_DIR".to_string(), p)];
}
vec![]
}
#[cfg(test)]
mod tests {
use super::MACOS_SEATBELT_BASE_POLICY;
use super::create_seatbelt_command_args;
use super::macos_dir_params;
use crate::protocol::SandboxPolicy;
use pretty_assertions::assert_eq;
use std::fs;
use std::path::Path;
use std::path::PathBuf;
use tempfile::TempDir;
#[test]
fn create_seatbelt_args_with_read_only_git_subpath() {
// Create a temporary workspace with two writable roots: one containing
// a top-level .git directory and one without it.
let tmp = TempDir::new().expect("tempdir");
let PopulatedTmp {
root_with_git,
root_without_git,
root_with_git_canon,
root_with_git_git_canon,
root_without_git_canon,
} = populate_tmpdir(tmp.path());
let cwd = tmp.path().join("cwd");
// Build a policy that only includes the two test roots as writable and
// does not automatically include defaults TMPDIR or /tmp.
let policy = SandboxPolicy::WorkspaceWrite {
writable_roots: vec![root_with_git, root_without_git],
network_access: false,
exclude_tmpdir_env_var: true,
exclude_slash_tmp: true,
};
let args = create_seatbelt_command_args(
vec!["/bin/echo".to_string(), "hello".to_string()],
&policy,
&cwd,
);
// Build the expected policy text using a raw string for readability.
// Note that the policy includes:
// - the base policy,
// - read-only access to the filesystem,
// - write access to WRITABLE_ROOT_0 (but not its .git) and WRITABLE_ROOT_1.
let expected_policy = format!(
r#"{MACOS_SEATBELT_BASE_POLICY}
; allow read-only file operations
(allow file-read*)
(allow file-write*
(require-all (subpath (param "WRITABLE_ROOT_0")) (require-not (subpath (param "WRITABLE_ROOT_0_RO_0"))) ) (subpath (param "WRITABLE_ROOT_1")) (subpath (param "WRITABLE_ROOT_2"))
)
"#,
);
let mut expected_args = vec![
"-p".to_string(),
expected_policy,
format!(
"-DWRITABLE_ROOT_0={}",
root_with_git_canon.to_string_lossy()
),
format!(
"-DWRITABLE_ROOT_0_RO_0={}",
root_with_git_git_canon.to_string_lossy()
),
format!(
"-DWRITABLE_ROOT_1={}",
root_without_git_canon.to_string_lossy()
),
format!("-DWRITABLE_ROOT_2={}", cwd.to_string_lossy()),
];
expected_args.extend(
macos_dir_params()
.into_iter()
.map(|(key, value)| format!("-D{key}={value}", value = value.to_string_lossy())),
);
expected_args.extend(vec![
"--".to_string(),
"/bin/echo".to_string(),
"hello".to_string(),
]);
assert_eq!(expected_args, args);
}
#[test]
fn create_seatbelt_args_for_cwd_as_git_repo() {
// Create a temporary workspace with two writable roots: one containing
// a top-level .git directory and one without it.
let tmp = TempDir::new().expect("tempdir");
let PopulatedTmp {
root_with_git,
root_with_git_canon,
root_with_git_git_canon,
..
} = populate_tmpdir(tmp.path());
// Build a policy that does not specify any writable_roots, but does
// use the default ones (cwd and TMPDIR) and verifies the `.git` check
// is done properly for cwd.
let policy = SandboxPolicy::WorkspaceWrite {
writable_roots: vec![],
network_access: false,
exclude_tmpdir_env_var: false,
exclude_slash_tmp: false,
};
let args = create_seatbelt_command_args(
vec!["/bin/echo".to_string(), "hello".to_string()],
&policy,
root_with_git.as_path(),
);
let tmpdir_env_var = std::env::var("TMPDIR")
.ok()
.map(PathBuf::from)
.and_then(|p| p.canonicalize().ok())
.map(|p| p.to_string_lossy().to_string());
let tempdir_policy_entry = if tmpdir_env_var.is_some() {
r#" (subpath (param "WRITABLE_ROOT_2"))"#
} else {
""
};
// Build the expected policy text using a raw string for readability.
// Note that the policy includes:
// - the base policy,
// - read-only access to the filesystem,
// - write access to WRITABLE_ROOT_0 (but not its .git) and WRITABLE_ROOT_1.
let expected_policy = format!(
r#"{MACOS_SEATBELT_BASE_POLICY}
; allow read-only file operations
(allow file-read*)
(allow file-write*
(require-all (subpath (param "WRITABLE_ROOT_0")) (require-not (subpath (param "WRITABLE_ROOT_0_RO_0"))) ) (subpath (param "WRITABLE_ROOT_1")){tempdir_policy_entry}
)
"#,
);
let mut expected_args = vec![
"-p".to_string(),
expected_policy,
format!(
"-DWRITABLE_ROOT_0={}",
root_with_git_canon.to_string_lossy()
),
format!(
"-DWRITABLE_ROOT_0_RO_0={}",
root_with_git_git_canon.to_string_lossy()
),
format!(
"-DWRITABLE_ROOT_1={}",
PathBuf::from("/tmp")
.canonicalize()
.expect("canonicalize /tmp")
.to_string_lossy()
),
];
if let Some(p) = tmpdir_env_var {
expected_args.push(format!("-DWRITABLE_ROOT_2={p}"));
}
expected_args.extend(
macos_dir_params()
.into_iter()
.map(|(key, value)| format!("-D{key}={value}", value = value.to_string_lossy())),
);
expected_args.extend(vec![
"--".to_string(),
"/bin/echo".to_string(),
"hello".to_string(),
]);
assert_eq!(expected_args, args);
}
struct PopulatedTmp {
root_with_git: PathBuf,
root_without_git: PathBuf,
root_with_git_canon: PathBuf,
root_with_git_git_canon: PathBuf,
root_without_git_canon: PathBuf,
}
fn populate_tmpdir(tmp: &Path) -> PopulatedTmp {
let root_with_git = tmp.join("with_git");
let root_without_git = tmp.join("no_git");
fs::create_dir_all(&root_with_git).expect("create with_git");
fs::create_dir_all(&root_without_git).expect("create no_git");
fs::create_dir_all(root_with_git.join(".git")).expect("create .git");
// Ensure we have canonical paths for -D parameter matching.
let root_with_git_canon = root_with_git.canonicalize().expect("canonicalize with_git");
let root_with_git_git_canon = root_with_git_canon.join(".git");
let root_without_git_canon = root_without_git
.canonicalize()
.expect("canonicalize no_git");
PopulatedTmp {
root_with_git,
root_without_git,
root_with_git_canon,
root_with_git_git_canon,
root_without_git_canon,
}
}
}

View File

@@ -0,0 +1,95 @@
(version 1)
; inspired by Chrome's sandbox policy:
; https://source.chromium.org/chromium/chromium/src/+/main:sandbox/policy/mac/common.sb;l=273-319;drc=7b3962fe2e5fc9e2ee58000dc8fbf3429d84d3bd
; https://source.chromium.org/chromium/chromium/src/+/main:sandbox/policy/mac/renderer.sb;l=64;drc=7b3962fe2e5fc9e2ee58000dc8fbf3429d84d3bd
; start with closed-by-default
(deny default)
; child processes inherit the policy of their parent
(allow process-exec)
(allow process-fork)
(allow signal (target same-sandbox))
; Allow cf prefs to work.
(allow user-preference-read)
; process-info
(allow process-info* (target same-sandbox))
(allow file-write-data
(require-all
(path "/dev/null")
(vnode-type CHARACTER-DEVICE)))
; sysctls permitted.
(allow sysctl-read
(sysctl-name "hw.activecpu")
(sysctl-name "hw.busfrequency_compat")
(sysctl-name "hw.byteorder")
(sysctl-name "hw.cacheconfig")
(sysctl-name "hw.cachelinesize_compat")
(sysctl-name "hw.cpufamily")
(sysctl-name "hw.cpufrequency_compat")
(sysctl-name "hw.cputype")
(sysctl-name "hw.l1dcachesize_compat")
(sysctl-name "hw.l1icachesize_compat")
(sysctl-name "hw.l2cachesize_compat")
(sysctl-name "hw.l3cachesize_compat")
(sysctl-name "hw.logicalcpu_max")
(sysctl-name "hw.machine")
(sysctl-name "hw.memsize")
(sysctl-name "hw.ncpu")
(sysctl-name "hw.nperflevels")
; Chrome locks these CPU feature detection down a bit more tightly,
; but mostly for fingerprinting concerns which isn't an issue for codex.
(sysctl-name-prefix "hw.optional.arm.")
(sysctl-name-prefix "hw.optional.armv8_")
(sysctl-name "hw.packages")
(sysctl-name "hw.pagesize_compat")
(sysctl-name "hw.pagesize")
(sysctl-name "hw.physicalcpu")
(sysctl-name "hw.physicalcpu_max")
(sysctl-name "hw.tbfrequency_compat")
(sysctl-name "hw.vectorunit")
(sysctl-name "kern.hostname")
(sysctl-name "kern.maxfilesperproc")
(sysctl-name "kern.maxproc")
(sysctl-name "kern.osproductversion")
(sysctl-name "kern.osrelease")
(sysctl-name "kern.ostype")
(sysctl-name "kern.osvariant_status")
(sysctl-name "kern.osversion")
(sysctl-name "kern.secure_kernel")
(sysctl-name "kern.usrstack64")
(sysctl-name "kern.version")
(sysctl-name "sysctl.proc_cputype")
(sysctl-name "vm.loadavg")
(sysctl-name-prefix "hw.perflevel")
(sysctl-name-prefix "kern.proc.pgrp.")
(sysctl-name-prefix "kern.proc.pid.")
(sysctl-name-prefix "net.routetable.")
)
; Allow Java to set CPU type grade when required
(allow sysctl-write
(sysctl-name "kern.grade_cputype"))
; IOKit
(allow iokit-open
(iokit-registry-entry-class "RootDomainUserClient")
)
; needed to look up user info, see https://crbug.com/792228
(allow mach-lookup
(global-name "com.apple.system.opendirectoryd.libinfo")
)
; Added on top of Chrome profile
; Needed for python multiprocessing on MacOS for the SemLock
(allow ipc-posix-sem)
(allow mach-lookup
(global-name "com.apple.PowerManagement.control")
)

View File

@@ -0,0 +1,30 @@
; when network access is enabled, these policies are added after those in seatbelt_base_policy.sbpl
; Ref https://source.chromium.org/chromium/chromium/src/+/main:sandbox/policy/mac/network.sb;drc=f8f264d5e4e7509c913f4c60c2639d15905a07e4
(allow network-outbound)
(allow network-inbound)
(allow system-socket)
(allow mach-lookup
; Used to look up the _CS_DARWIN_USER_CACHE_DIR in the sandbox.
(global-name "com.apple.bsd.dirhelper")
(global-name "com.apple.system.opendirectoryd.membership")
; Communicate with the security server for TLS certificate information.
(global-name "com.apple.SecurityServer")
(global-name "com.apple.networkd")
(global-name "com.apple.ocspd")
(global-name "com.apple.trustd.agent")
; Read network configuration.
(global-name "com.apple.SystemConfiguration.DNSConfiguration")
(global-name "com.apple.SystemConfiguration.configd")
)
(allow sysctl-read
(sysctl-name-regex #"^net.routetable")
)
(allow file-write*
(subpath (param "DARWIN_USER_CACHE_DIR"))
)

434
llmx-rs/core/src/shell.rs Normal file
View File

@@ -0,0 +1,434 @@
use serde::Deserialize;
use serde::Serialize;
use std::path::PathBuf;
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct ZshShell {
pub(crate) shell_path: String,
pub(crate) zshrc_path: String,
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct BashShell {
pub(crate) shell_path: String,
pub(crate) bashrc_path: String,
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct PowerShellConfig {
pub(crate) exe: String, // Executable name or path, e.g. "pwsh" or "powershell.exe".
pub(crate) bash_exe_fallback: Option<PathBuf>, // In case the model generates a bash command.
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub enum Shell {
Zsh(ZshShell),
Bash(BashShell),
PowerShell(PowerShellConfig),
Unknown,
}
impl Shell {
pub fn name(&self) -> Option<String> {
match self {
Shell::Zsh(zsh) => std::path::Path::new(&zsh.shell_path)
.file_name()
.map(|s| s.to_string_lossy().to_string()),
Shell::Bash(bash) => std::path::Path::new(&bash.shell_path)
.file_name()
.map(|s| s.to_string_lossy().to_string()),
Shell::PowerShell(ps) => Some(ps.exe.clone()),
Shell::Unknown => None,
}
}
}
#[cfg(unix)]
fn detect_default_user_shell() -> Shell {
use libc::getpwuid;
use libc::getuid;
use std::ffi::CStr;
unsafe {
let uid = getuid();
let pw = getpwuid(uid);
if !pw.is_null() {
let shell_path = CStr::from_ptr((*pw).pw_shell)
.to_string_lossy()
.into_owned();
let home_path = CStr::from_ptr((*pw).pw_dir).to_string_lossy().into_owned();
if shell_path.ends_with("/zsh") {
return Shell::Zsh(ZshShell {
shell_path,
zshrc_path: format!("{home_path}/.zshrc"),
});
}
if shell_path.ends_with("/bash") {
return Shell::Bash(BashShell {
shell_path,
bashrc_path: format!("{home_path}/.bashrc"),
});
}
}
}
Shell::Unknown
}
#[cfg(unix)]
pub async fn default_user_shell() -> Shell {
detect_default_user_shell()
}
#[cfg(target_os = "windows")]
pub async fn default_user_shell() -> Shell {
use tokio::process::Command;
// Prefer PowerShell 7+ (`pwsh`) if available, otherwise fall back to Windows PowerShell.
let has_pwsh = Command::new("pwsh")
.arg("-NoLogo")
.arg("-NoProfile")
.arg("-Command")
.arg("$PSVersionTable.PSVersion.Major")
.output()
.await
.map(|o| o.status.success())
.unwrap_or(false);
let bash_exe = if Command::new("bash.exe")
.arg("--version")
.stdin(std::process::Stdio::null())
.output()
.await
.ok()
.map(|o| o.status.success())
.unwrap_or(false)
{
which::which("bash.exe").ok()
} else {
None
};
if has_pwsh {
Shell::PowerShell(PowerShellConfig {
exe: "pwsh.exe".to_string(),
bash_exe_fallback: bash_exe,
})
} else {
Shell::PowerShell(PowerShellConfig {
exe: "powershell.exe".to_string(),
bash_exe_fallback: bash_exe,
})
}
}
#[cfg(all(not(target_os = "windows"), not(unix)))]
pub async fn default_user_shell() -> Shell {
Shell::Unknown
}
#[cfg(test)]
#[cfg(unix)]
mod tests {
use super::*;
use std::path::PathBuf;
use std::process::Command;
#[tokio::test]
async fn test_current_shell_detects_zsh() {
let shell = Command::new("sh")
.arg("-c")
.arg("echo $SHELL")
.output()
.unwrap();
let home = std::env::var("HOME").unwrap();
let shell_path = String::from_utf8_lossy(&shell.stdout).trim().to_string();
if shell_path.ends_with("/zsh") {
assert_eq!(
default_user_shell().await,
Shell::Zsh(ZshShell {
shell_path: shell_path.to_string(),
zshrc_path: format!("{home}/.zshrc",),
})
);
}
}
#[tokio::test]
async fn test_run_with_profile_bash_escaping_and_execution() {
let shell_path = "/bin/bash";
let cases = vec![
(
vec!["myecho"],
vec![shell_path, "-lc", "source BASHRC_PATH && (myecho)"],
Some("It works!\n"),
),
(
vec!["bash", "-lc", "echo 'single' \"double\""],
vec![
shell_path,
"-lc",
"source BASHRC_PATH && (echo 'single' \"double\")",
],
Some("single double\n"),
),
];
for (input, expected_cmd, expected_output) in cases {
use std::collections::HashMap;
use crate::exec::ExecParams;
use crate::exec::SandboxType;
use crate::exec::process_exec_tool_call;
use crate::protocol::SandboxPolicy;
let temp_home = tempfile::tempdir().unwrap();
let bashrc_path = temp_home.path().join(".bashrc");
std::fs::write(
&bashrc_path,
r#"
set -x
function myecho {
echo 'It works!'
}
"#,
)
.unwrap();
let command = expected_cmd
.iter()
.map(|s| s.replace("BASHRC_PATH", bashrc_path.to_str().unwrap()))
.collect::<Vec<_>>();
let output = process_exec_tool_call(
ExecParams {
command: command.clone(),
cwd: PathBuf::from(temp_home.path()),
timeout_ms: None,
env: HashMap::from([(
"HOME".to_string(),
temp_home.path().to_str().unwrap().to_string(),
)]),
with_escalated_permissions: None,
justification: None,
arg0: None,
},
SandboxType::None,
&SandboxPolicy::DangerFullAccess,
temp_home.path(),
&None,
None,
)
.await
.unwrap();
assert_eq!(output.exit_code, 0, "input: {input:?} output: {output:?}");
if let Some(expected) = expected_output {
assert_eq!(
output.stdout.text, expected,
"input: {input:?} output: {output:?}"
);
}
}
}
}
#[cfg(test)]
#[cfg(target_os = "macos")]
mod macos_tests {
use std::path::PathBuf;
#[tokio::test]
async fn test_run_with_profile_escaping_and_execution() {
let shell_path = "/bin/zsh";
let cases = vec![
(
vec!["myecho"],
vec![shell_path, "-lc", "source ZSHRC_PATH && (myecho)"],
Some("It works!\n"),
),
(
vec!["myecho"],
vec![shell_path, "-lc", "source ZSHRC_PATH && (myecho)"],
Some("It works!\n"),
),
(
vec!["bash", "-c", "echo 'single' \"double\""],
vec![
shell_path,
"-lc",
"source ZSHRC_PATH && (bash -c \"echo 'single' \\\"double\\\"\")",
],
Some("single double\n"),
),
(
vec!["bash", "-lc", "echo 'single' \"double\""],
vec![
shell_path,
"-lc",
"source ZSHRC_PATH && (echo 'single' \"double\")",
],
Some("single double\n"),
),
];
for (input, expected_cmd, expected_output) in cases {
use std::collections::HashMap;
use crate::exec::ExecParams;
use crate::exec::SandboxType;
use crate::exec::process_exec_tool_call;
use crate::protocol::SandboxPolicy;
let temp_home = tempfile::tempdir().unwrap();
let zshrc_path = temp_home.path().join(".zshrc");
std::fs::write(
&zshrc_path,
r#"
set -x
function myecho {
echo 'It works!'
}
"#,
)
.unwrap();
let command = expected_cmd
.iter()
.map(|s| s.replace("ZSHRC_PATH", zshrc_path.to_str().unwrap()))
.collect::<Vec<_>>();
let output = process_exec_tool_call(
ExecParams {
command: command.clone(),
cwd: PathBuf::from(temp_home.path()),
timeout_ms: None,
env: HashMap::from([(
"HOME".to_string(),
temp_home.path().to_str().unwrap().to_string(),
)]),
with_escalated_permissions: None,
justification: None,
arg0: None,
},
SandboxType::None,
&SandboxPolicy::DangerFullAccess,
temp_home.path(),
&None,
None,
)
.await
.unwrap();
assert_eq!(output.exit_code, 0, "input: {input:?} output: {output:?}");
if let Some(expected) = expected_output {
assert_eq!(
output.stdout.text, expected,
"input: {input:?} output: {output:?}"
);
}
}
}
}
#[cfg(test)]
#[cfg(target_os = "windows")]
mod tests_windows {
use super::*;
#[test]
fn test_format_default_shell_invocation_powershell() {
use std::path::PathBuf;
let cases = vec![
(
PowerShellConfig {
exe: "pwsh.exe".to_string(),
bash_exe_fallback: None,
},
vec!["bash", "-lc", "echo hello"],
vec!["pwsh.exe", "-NoProfile", "-Command", "echo hello"],
),
(
PowerShellConfig {
exe: "powershell.exe".to_string(),
bash_exe_fallback: None,
},
vec!["bash", "-lc", "echo hello"],
vec!["powershell.exe", "-NoProfile", "-Command", "echo hello"],
),
(
PowerShellConfig {
exe: "pwsh.exe".to_string(),
bash_exe_fallback: Some(PathBuf::from("bash.exe")),
},
vec!["bash", "-lc", "echo hello"],
vec!["bash.exe", "-lc", "echo hello"],
),
(
PowerShellConfig {
exe: "pwsh.exe".to_string(),
bash_exe_fallback: Some(PathBuf::from("bash.exe")),
},
vec![
"bash",
"-lc",
"apply_patch <<'EOF'\n*** Begin Patch\n*** Update File: destination_file.txt\n-original content\n+modified content\n*** End Patch\nEOF",
],
vec![
"bash.exe",
"-lc",
"apply_patch <<'EOF'\n*** Begin Patch\n*** Update File: destination_file.txt\n-original content\n+modified content\n*** End Patch\nEOF",
],
),
(
PowerShellConfig {
exe: "pwsh.exe".to_string(),
bash_exe_fallback: Some(PathBuf::from("bash.exe")),
},
vec!["echo", "hello"],
vec!["pwsh.exe", "-NoProfile", "-Command", "echo hello"],
),
(
PowerShellConfig {
exe: "pwsh.exe".to_string(),
bash_exe_fallback: Some(PathBuf::from("bash.exe")),
},
vec!["pwsh.exe", "-NoProfile", "-Command", "echo hello"],
vec!["pwsh.exe", "-NoProfile", "-Command", "echo hello"],
),
(
PowerShellConfig {
exe: "powershell.exe".to_string(),
bash_exe_fallback: Some(PathBuf::from("bash.exe")),
},
vec![
"llmx-mcp-server.exe",
"--llmx-run-as-apply-patch",
"*** Begin Patch\n*** Update File: C:\\Users\\person\\destination_file.txt\n-original content\n+modified content\n*** End Patch",
],
vec![
"llmx-mcp-server.exe",
"--llmx-run-as-apply-patch",
"*** Begin Patch\n*** Update File: C:\\Users\\person\\destination_file.txt\n-original content\n+modified content\n*** End Patch",
],
),
];
for (config, input, expected_cmd) in cases {
let command = expected_cmd
.iter()
.map(|s| (*s).to_string())
.collect::<Vec<_>>();
// These tests assert the final command for each scenario now that the helper
// has been removed. The inputs remain to document the original coverage.
let expected = expected_cmd
.iter()
.map(|s| (*s).to_string())
.collect::<Vec<_>>();
assert_eq!(command, expected, "input: {input:?} config: {config:?}");
}
}
}

117
llmx-rs/core/src/spawn.rs Normal file
View File

@@ -0,0 +1,117 @@
use std::collections::HashMap;
use std::path::PathBuf;
use std::process::Stdio;
use tokio::process::Child;
use tokio::process::Command;
use tracing::trace;
use crate::protocol::SandboxPolicy;
/// Experimental environment variable that will be set to some non-empty value
/// if both of the following are true:
///
/// 1. The process was spawned by Llmx as part of a shell tool call.
/// 2. SandboxPolicy.has_full_network_access() was false for the tool call.
///
/// We may try to have just one environment variable for all sandboxing
/// attributes, so this may change in the future.
pub const LLMX_SANDBOX_NETWORK_DISABLED_ENV_VAR: &str = "LLMX_SANDBOX_NETWORK_DISABLED";
/// Should be set when the process is spawned under a sandbox. Currently, the
/// value is "seatbelt" for macOS, but it may change in the future to
/// accommodate sandboxing configuration and other sandboxing mechanisms.
pub const LLMX_SANDBOX_ENV_VAR: &str = "LLMX_SANDBOX";
#[derive(Debug, Clone, Copy)]
pub enum StdioPolicy {
RedirectForShellTool,
Inherit,
}
/// Spawns the appropriate child process for the ExecParams and SandboxPolicy,
/// ensuring the args and environment variables used to create the `Command`
/// (and `Child`) honor the configuration.
///
/// For now, we take `SandboxPolicy` as a parameter to spawn_child() because
/// we need to determine whether to set the
/// `LLMX_SANDBOX_NETWORK_DISABLED_ENV_VAR` environment variable.
pub(crate) async fn spawn_child_async(
program: PathBuf,
args: Vec<String>,
#[cfg_attr(not(unix), allow(unused_variables))] arg0: Option<&str>,
cwd: PathBuf,
sandbox_policy: &SandboxPolicy,
stdio_policy: StdioPolicy,
env: HashMap<String, String>,
) -> std::io::Result<Child> {
trace!(
"spawn_child_async: {program:?} {args:?} {arg0:?} {cwd:?} {sandbox_policy:?} {stdio_policy:?} {env:?}"
);
let mut cmd = Command::new(&program);
#[cfg(unix)]
cmd.arg0(arg0.map_or_else(|| program.to_string_lossy().to_string(), String::from));
cmd.args(args);
cmd.current_dir(cwd);
cmd.env_clear();
cmd.envs(env);
if !sandbox_policy.has_full_network_access() {
cmd.env(LLMX_SANDBOX_NETWORK_DISABLED_ENV_VAR, "1");
}
// If this LLMX process dies (including being killed via SIGKILL), we want
// any child processes that were spawned as part of a `"shell"` tool call
// to also be terminated.
#[cfg(unix)]
unsafe {
#[cfg(target_os = "linux")]
let parent_pid = libc::getpid();
cmd.pre_exec(move || {
if libc::setpgid(0, 0) == -1 {
return Err(std::io::Error::last_os_error());
}
// This relies on prctl(2), so it only works on Linux.
#[cfg(target_os = "linux")]
{
// This prctl call effectively requests, "deliver SIGTERM when my
// current parent dies."
if libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGTERM) == -1 {
return Err(std::io::Error::last_os_error());
}
// Though if there was a race condition and this pre_exec() block is
// run _after_ the parent (i.e., the LLMX process) has already
// exited, then parent will be the closest configured "subreaper"
// ancestor process, or PID 1 (init). If the LLMX process has exited
// already, so should the child process.
if libc::getppid() != parent_pid {
libc::raise(libc::SIGTERM);
}
}
Ok(())
});
}
match stdio_policy {
StdioPolicy::RedirectForShellTool => {
// Do not create a file descriptor for stdin because otherwise some
// commands may hang forever waiting for input. For example, ripgrep has
// a heuristic where it may try to read from stdin as explained here:
// https://github.com/BurntSushi/ripgrep/blob/e2362d4d5185d02fa857bf381e7bd52e66fafc73/crates/core/flags/hiargs.rs#L1101-L1103
cmd.stdin(Stdio::null());
cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
}
StdioPolicy::Inherit => {
// Inherit stdin, stdout, and stderr from the parent process.
cmd.stdin(Stdio::inherit())
.stdout(Stdio::inherit())
.stderr(Stdio::inherit());
}
}
cmd.kill_on_drop(true).spawn()
}

View File

@@ -0,0 +1,9 @@
mod service;
mod session;
mod turn;
pub(crate) use service::SessionServices;
pub(crate) use session::SessionState;
pub(crate) use turn::ActiveTurn;
pub(crate) use turn::RunningTask;
pub(crate) use turn::TaskKind;

View File

@@ -0,0 +1,22 @@
use std::sync::Arc;
use crate::AuthManager;
use crate::RolloutRecorder;
use crate::mcp_connection_manager::McpConnectionManager;
use crate::tools::sandboxing::ApprovalStore;
use crate::unified_exec::UnifiedExecSessionManager;
use crate::user_notification::UserNotifier;
use llmx_otel::otel_event_manager::OtelEventManager;
use tokio::sync::Mutex;
pub(crate) struct SessionServices {
pub(crate) mcp_connection_manager: McpConnectionManager,
pub(crate) unified_exec_manager: UnifiedExecSessionManager,
pub(crate) notifier: UserNotifier,
pub(crate) rollout: Mutex<Option<RolloutRecorder>>,
pub(crate) user_shell: crate::shell::Shell,
pub(crate) show_raw_agent_reasoning: bool,
pub(crate) auth_manager: Arc<AuthManager>,
pub(crate) otel_event_manager: OtelEventManager,
pub(crate) tool_approvals: Mutex<ApprovalStore>,
}

View File

@@ -0,0 +1,71 @@
//! Session-wide mutable state.
use llmx_protocol::models::ResponseItem;
use crate::context_manager::ContextManager;
use crate::llmx::SessionConfiguration;
use crate::protocol::RateLimitSnapshot;
use crate::protocol::TokenUsage;
use crate::protocol::TokenUsageInfo;
/// Persistent, session-scoped state previously stored directly on `Session`.
pub(crate) struct SessionState {
pub(crate) session_configuration: SessionConfiguration,
pub(crate) history: ContextManager,
pub(crate) latest_rate_limits: Option<RateLimitSnapshot>,
}
impl SessionState {
/// Create a new session state mirroring previous `State::default()` semantics.
pub(crate) fn new(session_configuration: SessionConfiguration) -> Self {
Self {
session_configuration,
history: ContextManager::new(),
latest_rate_limits: None,
}
}
// History helpers
pub(crate) fn record_items<I>(&mut self, items: I)
where
I: IntoIterator,
I::Item: std::ops::Deref<Target = ResponseItem>,
{
self.history.record_items(items)
}
pub(crate) fn clone_history(&self) -> ContextManager {
self.history.clone()
}
pub(crate) fn replace_history(&mut self, items: Vec<ResponseItem>) {
self.history.replace(items);
}
// Token/rate limit helpers
pub(crate) fn update_token_info_from_usage(
&mut self,
usage: &TokenUsage,
model_context_window: Option<i64>,
) {
self.history.update_token_info(usage, model_context_window);
}
pub(crate) fn token_info(&self) -> Option<TokenUsageInfo> {
self.history.token_info()
}
pub(crate) fn set_rate_limits(&mut self, snapshot: RateLimitSnapshot) {
self.latest_rate_limits = Some(snapshot);
}
pub(crate) fn token_info_and_rate_limits(
&self,
) -> (Option<TokenUsageInfo>, Option<RateLimitSnapshot>) {
(self.token_info(), self.latest_rate_limits.clone())
}
pub(crate) fn set_token_usage_full(&mut self, context_window: i64) {
self.history.set_token_usage_full(context_window);
}
}

View File

@@ -0,0 +1,115 @@
//! Turn-scoped state and active turn metadata scaffolding.
use indexmap::IndexMap;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
use tokio_util::task::AbortOnDropHandle;
use llmx_protocol::models::ResponseInputItem;
use tokio::sync::oneshot;
use crate::llmx::TurnContext;
use crate::protocol::ReviewDecision;
use crate::tasks::SessionTask;
/// Metadata about the currently running turn.
pub(crate) struct ActiveTurn {
pub(crate) tasks: IndexMap<String, RunningTask>,
pub(crate) turn_state: Arc<Mutex<TurnState>>,
}
impl Default for ActiveTurn {
fn default() -> Self {
Self {
tasks: IndexMap::new(),
turn_state: Arc::new(Mutex::new(TurnState::default())),
}
}
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum TaskKind {
Regular,
Review,
Compact,
}
#[derive(Clone)]
pub(crate) struct RunningTask {
pub(crate) done: Arc<Notify>,
pub(crate) kind: TaskKind,
pub(crate) task: Arc<dyn SessionTask>,
pub(crate) cancellation_token: CancellationToken,
pub(crate) handle: Arc<AbortOnDropHandle<()>>,
pub(crate) turn_context: Arc<TurnContext>,
}
impl ActiveTurn {
pub(crate) fn add_task(&mut self, task: RunningTask) {
let sub_id = task.turn_context.sub_id.clone();
self.tasks.insert(sub_id, task);
}
pub(crate) fn remove_task(&mut self, sub_id: &str) -> bool {
self.tasks.swap_remove(sub_id);
self.tasks.is_empty()
}
pub(crate) fn drain_tasks(&mut self) -> Vec<RunningTask> {
self.tasks.drain(..).map(|(_, task)| task).collect()
}
}
/// Mutable state for a single turn.
#[derive(Default)]
pub(crate) struct TurnState {
pending_approvals: HashMap<String, oneshot::Sender<ReviewDecision>>,
pending_input: Vec<ResponseInputItem>,
}
impl TurnState {
pub(crate) fn insert_pending_approval(
&mut self,
key: String,
tx: oneshot::Sender<ReviewDecision>,
) -> Option<oneshot::Sender<ReviewDecision>> {
self.pending_approvals.insert(key, tx)
}
pub(crate) fn remove_pending_approval(
&mut self,
key: &str,
) -> Option<oneshot::Sender<ReviewDecision>> {
self.pending_approvals.remove(key)
}
pub(crate) fn clear_pending(&mut self) {
self.pending_approvals.clear();
self.pending_input.clear();
}
pub(crate) fn push_pending_input(&mut self, input: ResponseInputItem) {
self.pending_input.push(input);
}
pub(crate) fn take_pending_input(&mut self) -> Vec<ResponseInputItem> {
if self.pending_input.is_empty() {
Vec::with_capacity(0)
} else {
let mut ret = Vec::new();
std::mem::swap(&mut ret, &mut self.pending_input);
ret
}
}
}
impl ActiveTurn {
/// Clear any pending approvals and input buffered for the current turn.
pub(crate) async fn clear_pending(&self) {
let mut ts = self.turn_state.lock().await;
ts.clear_pending();
}
}

View File

@@ -0,0 +1,32 @@
use std::sync::Arc;
use async_trait::async_trait;
use tokio_util::sync::CancellationToken;
use crate::compact;
use crate::llmx::TurnContext;
use crate::state::TaskKind;
use llmx_protocol::user_input::UserInput;
use super::SessionTask;
use super::SessionTaskContext;
#[derive(Clone, Copy, Default)]
pub(crate) struct CompactTask;
#[async_trait]
impl SessionTask for CompactTask {
fn kind(&self) -> TaskKind {
TaskKind::Compact
}
async fn run(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
input: Vec<UserInput>,
_cancellation_token: CancellationToken,
) -> Option<String> {
compact::run_compact_task(session.clone_session(), ctx, input).await
}
}

View File

@@ -0,0 +1,110 @@
use crate::llmx::TurnContext;
use crate::state::TaskKind;
use crate::tasks::SessionTask;
use crate::tasks::SessionTaskContext;
use async_trait::async_trait;
use llmx_git::CreateGhostCommitOptions;
use llmx_git::GitToolingError;
use llmx_git::create_ghost_commit;
use llmx_protocol::models::ResponseItem;
use llmx_protocol::user_input::UserInput;
use llmx_utils_readiness::Readiness;
use llmx_utils_readiness::Token;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use tracing::info;
use tracing::warn;
pub(crate) struct GhostSnapshotTask {
token: Token,
}
#[async_trait]
impl SessionTask for GhostSnapshotTask {
fn kind(&self) -> TaskKind {
TaskKind::Regular
}
async fn run(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
_input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String> {
tokio::task::spawn(async move {
let token = self.token;
let ctx_for_task = Arc::clone(&ctx);
let cancelled = tokio::select! {
_ = cancellation_token.cancelled() => true,
_ = async {
let repo_path = ctx_for_task.cwd.clone();
// Required to run in a dedicated blocking pool.
match tokio::task::spawn_blocking(move || {
let options = CreateGhostCommitOptions::new(&repo_path);
create_ghost_commit(&options)
})
.await
{
Ok(Ok(ghost_commit)) => {
info!("ghost snapshot blocking task finished");
session
.session
.record_conversation_items(&ctx, &[ResponseItem::GhostSnapshot {
ghost_commit: ghost_commit.clone(),
}])
.await;
info!("ghost commit captured: {}", ghost_commit.id());
}
Ok(Err(err)) => {
warn!(
sub_id = ctx_for_task.sub_id.as_str(),
"failed to capture ghost snapshot: {err}"
);
let message = match err {
GitToolingError::NotAGitRepository { .. } => {
"Snapshots disabled: current directory is not a Git repository."
.to_string()
}
_ => format!("Snapshots disabled after ghost snapshot error: {err}."),
};
session
.session
.notify_background_event(&ctx_for_task, message)
.await;
}
Err(err) => {
warn!(
sub_id = ctx_for_task.sub_id.as_str(),
"ghost snapshot task panicked: {err}"
);
let message =
format!("Snapshots disabled after ghost snapshot panic: {err}.");
session
.session
.notify_background_event(&ctx_for_task, message)
.await;
}
}
} => false,
};
if cancelled {
info!("ghost snapshot task cancelled");
}
match ctx.tool_call_gate.mark_ready(token).await {
Ok(true) => info!("ghost snapshot gate marked ready"),
Ok(false) => warn!("ghost snapshot gate already ready"),
Err(err) => warn!("failed to mark ghost snapshot ready: {err}"),
}
});
None
}
}
impl GhostSnapshotTask {
pub(crate) fn new(token: Token) -> Self {
Self { token }
}
}

View File

@@ -0,0 +1,225 @@
mod compact;
mod ghost_snapshot;
mod regular;
mod review;
mod undo;
mod user_shell;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use tokio::select;
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
use tokio_util::task::AbortOnDropHandle;
use tracing::trace;
use tracing::warn;
use crate::AuthManager;
use crate::llmx::Session;
use crate::llmx::TurnContext;
use crate::protocol::EventMsg;
use crate::protocol::TaskCompleteEvent;
use crate::protocol::TurnAbortReason;
use crate::protocol::TurnAbortedEvent;
use crate::state::ActiveTurn;
use crate::state::RunningTask;
use crate::state::TaskKind;
use llmx_protocol::user_input::UserInput;
pub(crate) use compact::CompactTask;
pub(crate) use ghost_snapshot::GhostSnapshotTask;
pub(crate) use regular::RegularTask;
pub(crate) use review::ReviewTask;
pub(crate) use undo::UndoTask;
pub(crate) use user_shell::UserShellCommandTask;
const GRACEFULL_INTERRUPTION_TIMEOUT_MS: u64 = 100;
/// Thin wrapper that exposes the parts of [`Session`] task runners need.
#[derive(Clone)]
pub(crate) struct SessionTaskContext {
session: Arc<Session>,
}
impl SessionTaskContext {
pub(crate) fn new(session: Arc<Session>) -> Self {
Self { session }
}
pub(crate) fn clone_session(&self) -> Arc<Session> {
Arc::clone(&self.session)
}
pub(crate) fn auth_manager(&self) -> Arc<AuthManager> {
Arc::clone(&self.session.services.auth_manager)
}
}
/// Async task that drives a [`Session`] turn.
///
/// Implementations encapsulate a specific Llmx workflow (regular chat,
/// reviews, ghost snapshots, etc.). Each task instance is owned by a
/// [`Session`] and executed on a background Tokio task. The trait is
/// intentionally small: implementers identify themselves via
/// [`SessionTask::kind`], perform their work in [`SessionTask::run`], and may
/// release resources in [`SessionTask::abort`].
#[async_trait]
pub(crate) trait SessionTask: Send + Sync + 'static {
/// Describes the type of work the task performs so the session can
/// surface it in telemetry and UI.
fn kind(&self) -> TaskKind;
/// Executes the task until completion or cancellation.
///
/// Implementations typically stream protocol events using `session` and
/// `ctx`, returning an optional final agent message when finished. The
/// provided `cancellation_token` is cancelled when the session requests an
/// abort; implementers should watch for it and terminate quickly once it
/// fires. Returning [`Some`] yields a final message that
/// [`Session::on_task_finished`] will emit to the client.
async fn run(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String>;
/// Gives the task a chance to perform cleanup after an abort.
///
/// The default implementation is a no-op; override this if additional
/// teardown or notifications are required once
/// [`Session::abort_all_tasks`] cancels the task.
async fn abort(&self, session: Arc<SessionTaskContext>, ctx: Arc<TurnContext>) {
let _ = (session, ctx);
}
}
impl Session {
pub async fn spawn_task<T: SessionTask>(
self: &Arc<Self>,
turn_context: Arc<TurnContext>,
input: Vec<UserInput>,
task: T,
) {
self.abort_all_tasks(TurnAbortReason::Replaced).await;
let task: Arc<dyn SessionTask> = Arc::new(task);
let task_kind = task.kind();
let cancellation_token = CancellationToken::new();
let done = Arc::new(Notify::new());
let done_clone = Arc::clone(&done);
let handle = {
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
let ctx = Arc::clone(&turn_context);
let task_for_run = Arc::clone(&task);
let task_cancellation_token = cancellation_token.child_token();
tokio::spawn(async move {
let ctx_for_finish = Arc::clone(&ctx);
let last_agent_message = task_for_run
.run(
Arc::clone(&session_ctx),
ctx,
input,
task_cancellation_token.child_token(),
)
.await;
session_ctx.clone_session().flush_rollout().await;
if !task_cancellation_token.is_cancelled() {
// Emit completion uniformly from spawn site so all tasks share the same lifecycle.
let sess = session_ctx.clone_session();
sess.on_task_finished(ctx_for_finish, last_agent_message)
.await;
}
done_clone.notify_waiters();
})
};
let running_task = RunningTask {
done,
handle: Arc::new(AbortOnDropHandle::new(handle)),
kind: task_kind,
task,
cancellation_token,
turn_context: Arc::clone(&turn_context),
};
self.register_new_active_task(running_task).await;
}
pub async fn abort_all_tasks(self: &Arc<Self>, reason: TurnAbortReason) {
for task in self.take_all_running_tasks().await {
self.handle_task_abort(task, reason.clone()).await;
}
}
pub async fn on_task_finished(
self: &Arc<Self>,
turn_context: Arc<TurnContext>,
last_agent_message: Option<String>,
) {
let mut active = self.active_turn.lock().await;
if let Some(at) = active.as_mut()
&& at.remove_task(&turn_context.sub_id)
{
*active = None;
}
drop(active);
let event = EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message });
self.send_event(turn_context.as_ref(), event).await;
}
async fn register_new_active_task(&self, task: RunningTask) {
let mut active = self.active_turn.lock().await;
let mut turn = ActiveTurn::default();
turn.add_task(task);
*active = Some(turn);
}
async fn take_all_running_tasks(&self) -> Vec<RunningTask> {
let mut active = self.active_turn.lock().await;
match active.take() {
Some(mut at) => {
at.clear_pending().await;
at.drain_tasks()
}
None => Vec::new(),
}
}
async fn handle_task_abort(self: &Arc<Self>, task: RunningTask, reason: TurnAbortReason) {
let sub_id = task.turn_context.sub_id.clone();
if task.cancellation_token.is_cancelled() {
return;
}
trace!(task_kind = ?task.kind, sub_id, "aborting running task");
task.cancellation_token.cancel();
let session_task = task.task;
select! {
_ = task.done.notified() => {
},
_ = tokio::time::sleep(Duration::from_millis(GRACEFULL_INTERRUPTION_TIMEOUT_MS)) => {
warn!("task {sub_id} didn't complete gracefully after {}ms", GRACEFULL_INTERRUPTION_TIMEOUT_MS);
}
}
task.handle.abort();
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
session_task
.abort(session_ctx, Arc::clone(&task.turn_context))
.await;
let event = EventMsg::TurnAborted(TurnAbortedEvent { reason });
self.send_event(task.turn_context.as_ref(), event).await;
}
}
#[cfg(test)]
mod tests {}

View File

@@ -0,0 +1,33 @@
use std::sync::Arc;
use async_trait::async_trait;
use tokio_util::sync::CancellationToken;
use crate::llmx::TurnContext;
use crate::llmx::run_task;
use crate::state::TaskKind;
use llmx_protocol::user_input::UserInput;
use super::SessionTask;
use super::SessionTaskContext;
#[derive(Clone, Copy, Default)]
pub(crate) struct RegularTask;
#[async_trait]
impl SessionTask for RegularTask {
fn kind(&self) -> TaskKind {
TaskKind::Regular
}
async fn run(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String> {
let sess = session.clone_session();
run_task(sess, ctx, input, cancellation_token).await
}
}

View File

@@ -0,0 +1,210 @@
use std::sync::Arc;
use async_trait::async_trait;
use llmx_protocol::items::TurnItem;
use llmx_protocol::models::ContentItem;
use llmx_protocol::models::ResponseItem;
use llmx_protocol::protocol::AgentMessageContentDeltaEvent;
use llmx_protocol::protocol::AgentMessageDeltaEvent;
use llmx_protocol::protocol::Event;
use llmx_protocol::protocol::EventMsg;
use llmx_protocol::protocol::ExitedReviewModeEvent;
use llmx_protocol::protocol::ItemCompletedEvent;
use llmx_protocol::protocol::ReviewOutputEvent;
use tokio_util::sync::CancellationToken;
use crate::llmx::Session;
use crate::llmx::TurnContext;
use crate::llmx_delegate::run_llmx_conversation_one_shot;
use crate::review_format::format_review_findings_block;
use crate::state::TaskKind;
use llmx_protocol::user_input::UserInput;
use super::SessionTask;
use super::SessionTaskContext;
#[derive(Clone, Copy, Default)]
pub(crate) struct ReviewTask;
#[async_trait]
impl SessionTask for ReviewTask {
fn kind(&self) -> TaskKind {
TaskKind::Review
}
async fn run(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String> {
// Start sub-llmx conversation and get the receiver for events.
let output = match start_review_conversation(
session.clone(),
ctx.clone(),
input,
cancellation_token.clone(),
)
.await
{
Some(receiver) => process_review_events(session.clone(), ctx.clone(), receiver).await,
None => None,
};
if !cancellation_token.is_cancelled() {
exit_review_mode(session.clone_session(), output.clone(), ctx.clone()).await;
}
None
}
async fn abort(&self, session: Arc<SessionTaskContext>, ctx: Arc<TurnContext>) {
exit_review_mode(session.clone_session(), None, ctx).await;
}
}
async fn start_review_conversation(
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<async_channel::Receiver<Event>> {
let config = ctx.client.config();
let mut sub_agent_config = config.as_ref().clone();
// Run with only reviewer rubric — drop outer user_instructions
sub_agent_config.user_instructions = None;
// Avoid loading project docs; reviewer only needs findings
sub_agent_config.project_doc_max_bytes = 0;
// Carry over review-only feature restrictions so the delegate cannot
// re-enable blocked tools (web search, view image).
sub_agent_config
.features
.disable(crate::features::Feature::WebSearchRequest)
.disable(crate::features::Feature::ViewImageTool);
// Set explicit review rubric for the sub-agent
sub_agent_config.base_instructions = Some(crate::REVIEW_PROMPT.to_string());
(run_llmx_conversation_one_shot(
sub_agent_config,
session.auth_manager(),
input,
session.clone_session(),
ctx.clone(),
cancellation_token,
None,
)
.await)
.ok()
.map(|io| io.rx_event)
}
async fn process_review_events(
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
receiver: async_channel::Receiver<Event>,
) -> Option<ReviewOutputEvent> {
let mut prev_agent_message: Option<Event> = None;
while let Ok(event) = receiver.recv().await {
match event.clone().msg {
EventMsg::AgentMessage(_) => {
if let Some(prev) = prev_agent_message.take() {
session
.clone_session()
.send_event(ctx.as_ref(), prev.msg)
.await;
}
prev_agent_message = Some(event);
}
// Suppress ItemCompleted only for assistant messages: forwarding it
// would trigger legacy AgentMessage via as_legacy_events(), which this
// review flow intentionally hides in favor of structured output.
EventMsg::ItemCompleted(ItemCompletedEvent {
item: TurnItem::AgentMessage(_),
..
})
| EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { .. })
| EventMsg::AgentMessageContentDelta(AgentMessageContentDeltaEvent { .. }) => {}
EventMsg::TaskComplete(task_complete) => {
// Parse review output from the last agent message (if present).
let out = task_complete
.last_agent_message
.as_deref()
.map(parse_review_output_event);
return out;
}
EventMsg::TurnAborted(_) => {
// Cancellation or abort: consumer will finalize with None.
return None;
}
other => {
session
.clone_session()
.send_event(ctx.as_ref(), other)
.await;
}
}
}
// Channel closed without TaskComplete: treat as interrupted.
None
}
/// Parse a ReviewOutputEvent from a text blob returned by the reviewer model.
/// If the text is valid JSON matching ReviewOutputEvent, deserialize it.
/// Otherwise, attempt to extract the first JSON object substring and parse it.
/// If parsing still fails, return a structured fallback carrying the plain text
/// in `overall_explanation`.
fn parse_review_output_event(text: &str) -> ReviewOutputEvent {
if let Ok(ev) = serde_json::from_str::<ReviewOutputEvent>(text) {
return ev;
}
if let (Some(start), Some(end)) = (text.find('{'), text.rfind('}'))
&& start < end
&& let Some(slice) = text.get(start..=end)
&& let Ok(ev) = serde_json::from_str::<ReviewOutputEvent>(slice)
{
return ev;
}
ReviewOutputEvent {
overall_explanation: text.to_string(),
..Default::default()
}
}
/// Emits an ExitedReviewMode Event with optional ReviewOutput,
/// and records a developer message with the review output.
pub(crate) async fn exit_review_mode(
session: Arc<Session>,
review_output: Option<ReviewOutputEvent>,
ctx: Arc<TurnContext>,
) {
let user_message = if let Some(out) = review_output.clone() {
let mut findings_str = String::new();
let text = out.overall_explanation.trim();
if !text.is_empty() {
findings_str.push_str(text);
}
if !out.findings.is_empty() {
let block = format_review_findings_block(&out.findings, None);
findings_str.push_str(&format!("\n{block}"));
}
crate::client_common::REVIEW_EXIT_SUCCESS_TMPL.replace("{results}", &findings_str)
} else {
crate::client_common::REVIEW_EXIT_INTERRUPTED_TMPL.to_string()
};
session
.record_conversation_items(
&ctx,
&[ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputText { text: user_message }],
}],
)
.await;
session
.send_event(
ctx.as_ref(),
EventMsg::ExitedReviewMode(ExitedReviewModeEvent { review_output }),
)
.await;
}

View File

@@ -0,0 +1,117 @@
use std::sync::Arc;
use crate::llmx::TurnContext;
use crate::protocol::EventMsg;
use crate::protocol::UndoCompletedEvent;
use crate::protocol::UndoStartedEvent;
use crate::state::TaskKind;
use crate::tasks::SessionTask;
use crate::tasks::SessionTaskContext;
use async_trait::async_trait;
use llmx_git::restore_ghost_commit;
use llmx_protocol::models::ResponseItem;
use llmx_protocol::user_input::UserInput;
use tokio_util::sync::CancellationToken;
use tracing::error;
use tracing::info;
use tracing::warn;
pub(crate) struct UndoTask;
impl UndoTask {
pub(crate) fn new() -> Self {
Self
}
}
#[async_trait]
impl SessionTask for UndoTask {
fn kind(&self) -> TaskKind {
TaskKind::Regular
}
async fn run(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
ctx: Arc<TurnContext>,
_input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String> {
let sess = session.clone_session();
sess.send_event(
ctx.as_ref(),
EventMsg::UndoStarted(UndoStartedEvent {
message: Some("Undo in progress...".to_string()),
}),
)
.await;
if cancellation_token.is_cancelled() {
sess.send_event(
ctx.as_ref(),
EventMsg::UndoCompleted(UndoCompletedEvent {
success: false,
message: Some("Undo cancelled.".to_string()),
}),
)
.await;
return None;
}
let mut history = sess.clone_history().await;
let mut items = history.get_history();
let mut completed = UndoCompletedEvent {
success: false,
message: None,
};
let Some((idx, ghost_commit)) =
items
.iter()
.enumerate()
.rev()
.find_map(|(idx, item)| match item {
ResponseItem::GhostSnapshot { ghost_commit } => {
Some((idx, ghost_commit.clone()))
}
_ => None,
})
else {
completed.message = Some("No ghost snapshot available to undo.".to_string());
sess.send_event(ctx.as_ref(), EventMsg::UndoCompleted(completed))
.await;
return None;
};
let commit_id = ghost_commit.id().to_string();
let repo_path = ctx.cwd.clone();
let restore_result =
tokio::task::spawn_blocking(move || restore_ghost_commit(&repo_path, &ghost_commit))
.await;
match restore_result {
Ok(Ok(())) => {
items.remove(idx);
sess.replace_history(items).await;
let short_id: String = commit_id.chars().take(7).collect();
info!(commit_id = commit_id, "Undo restored ghost snapshot");
completed.success = true;
completed.message = Some(format!("Undo restored snapshot {short_id}."));
}
Ok(Err(err)) => {
let message = format!("Failed to restore snapshot {commit_id}: {err}");
warn!("{message}");
completed.message = Some(message);
}
Err(err) => {
let message = format!("Failed to restore snapshot {commit_id}: {err}");
error!("{message}");
completed.message = Some(message);
}
}
sess.send_event(ctx.as_ref(), EventMsg::UndoCompleted(completed))
.await;
None
}
}

View File

@@ -0,0 +1,211 @@
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use llmx_async_utils::CancelErr;
use llmx_async_utils::OrCancelExt;
use llmx_protocol::user_input::UserInput;
use tokio_util::sync::CancellationToken;
use tracing::error;
use uuid::Uuid;
use crate::exec::ExecToolCallOutput;
use crate::exec::SandboxType;
use crate::exec::StdoutStream;
use crate::exec::StreamOutput;
use crate::exec::execute_exec_env;
use crate::exec_env::create_env;
use crate::llmx::TurnContext;
use crate::parse_command::parse_command;
use crate::protocol::EventMsg;
use crate::protocol::ExecCommandBeginEvent;
use crate::protocol::ExecCommandEndEvent;
use crate::protocol::SandboxPolicy;
use crate::protocol::TaskStartedEvent;
use crate::sandboxing::ExecEnv;
use crate::state::TaskKind;
use crate::tools::format_exec_output_str;
use crate::user_shell_command::user_shell_command_record_item;
use super::SessionTask;
use super::SessionTaskContext;
#[derive(Clone)]
pub(crate) struct UserShellCommandTask {
command: String,
}
impl UserShellCommandTask {
pub(crate) fn new(command: String) -> Self {
Self { command }
}
}
#[async_trait]
impl SessionTask for UserShellCommandTask {
fn kind(&self) -> TaskKind {
TaskKind::Regular
}
async fn run(
self: Arc<Self>,
session: Arc<SessionTaskContext>,
turn_context: Arc<TurnContext>,
_input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String> {
let event = EventMsg::TaskStarted(TaskStartedEvent {
model_context_window: turn_context.client.get_model_context_window(),
});
let session = session.clone_session();
session.send_event(turn_context.as_ref(), event).await;
// Execute the user's script under their default shell when known; this
// allows commands that use shell features (pipes, &&, redirects, etc.).
// We do not source rc files or otherwise reformat the script.
let shell_invocation = match session.user_shell() {
crate::shell::Shell::Zsh(zsh) => vec![
zsh.shell_path.clone(),
"-lc".to_string(),
self.command.clone(),
],
crate::shell::Shell::Bash(bash) => vec![
bash.shell_path.clone(),
"-lc".to_string(),
self.command.clone(),
],
crate::shell::Shell::PowerShell(ps) => vec![
ps.exe.clone(),
"-NoProfile".to_string(),
"-Command".to_string(),
self.command.clone(),
],
crate::shell::Shell::Unknown => {
shlex::split(&self.command).unwrap_or_else(|| vec![self.command.clone()])
}
};
let call_id = Uuid::new_v4().to_string();
let raw_command = self.command.clone();
let parsed_cmd = parse_command(&shell_invocation);
session
.send_event(
turn_context.as_ref(),
EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
call_id: call_id.clone(),
command: shell_invocation.clone(),
cwd: turn_context.cwd.clone(),
parsed_cmd,
is_user_shell_command: true,
}),
)
.await;
let exec_env = ExecEnv {
command: shell_invocation,
cwd: turn_context.cwd.clone(),
env: create_env(&turn_context.shell_environment_policy),
timeout_ms: None,
sandbox: SandboxType::None,
with_escalated_permissions: None,
justification: None,
arg0: None,
};
let stdout_stream = Some(StdoutStream {
sub_id: turn_context.sub_id.clone(),
call_id: call_id.clone(),
tx_event: session.get_tx_event(),
});
let sandbox_policy = SandboxPolicy::DangerFullAccess;
let exec_result = execute_exec_env(exec_env, &sandbox_policy, stdout_stream)
.or_cancel(&cancellation_token)
.await;
match exec_result {
Err(CancelErr::Cancelled) => {
let aborted_message = "command aborted by user".to_string();
let exec_output = ExecToolCallOutput {
exit_code: -1,
stdout: StreamOutput::new(String::new()),
stderr: StreamOutput::new(aborted_message.clone()),
aggregated_output: StreamOutput::new(aborted_message.clone()),
duration: Duration::ZERO,
timed_out: false,
};
let output_items = [user_shell_command_record_item(&raw_command, &exec_output)];
session
.record_conversation_items(turn_context.as_ref(), &output_items)
.await;
session
.send_event(
turn_context.as_ref(),
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
call_id,
stdout: String::new(),
stderr: aborted_message.clone(),
aggregated_output: aborted_message.clone(),
exit_code: -1,
duration: Duration::ZERO,
formatted_output: aborted_message,
}),
)
.await;
}
Ok(Ok(output)) => {
session
.send_event(
turn_context.as_ref(),
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
call_id: call_id.clone(),
stdout: output.stdout.text.clone(),
stderr: output.stderr.text.clone(),
aggregated_output: output.aggregated_output.text.clone(),
exit_code: output.exit_code,
duration: output.duration,
formatted_output: format_exec_output_str(&output),
}),
)
.await;
let output_items = [user_shell_command_record_item(&raw_command, &output)];
session
.record_conversation_items(turn_context.as_ref(), &output_items)
.await;
}
Ok(Err(err)) => {
error!("user shell command failed: {err:?}");
let message = format!("execution error: {err:?}");
let exec_output = ExecToolCallOutput {
exit_code: -1,
stdout: StreamOutput::new(String::new()),
stderr: StreamOutput::new(message.clone()),
aggregated_output: StreamOutput::new(message.clone()),
duration: Duration::ZERO,
timed_out: false,
};
session
.send_event(
turn_context.as_ref(),
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
call_id,
stdout: exec_output.stdout.text.clone(),
stderr: exec_output.stderr.text.clone(),
aggregated_output: exec_output.aggregated_output.text.clone(),
exit_code: exec_output.exit_code,
duration: exec_output.duration,
formatted_output: format_exec_output_str(&exec_output),
}),
)
.await;
let output_items = [user_shell_command_record_item(&raw_command, &exec_output)];
session
.record_conversation_items(turn_context.as_ref(), &output_items)
.await;
}
}
None
}
}

View File

@@ -0,0 +1,72 @@
use std::sync::OnceLock;
static TERMINAL: OnceLock<String> = OnceLock::new();
pub fn user_agent() -> String {
TERMINAL.get_or_init(detect_terminal).to_string()
}
/// Sanitize a header value to be used in a User-Agent string.
///
/// This function replaces any characters that are not allowed in a User-Agent string with an underscore.
///
/// # Arguments
///
/// * `value` - The value to sanitize.
fn is_valid_header_value_char(c: char) -> bool {
c.is_ascii_alphanumeric() || c == '-' || c == '_' || c == '.' || c == '/'
}
fn sanitize_header_value(value: String) -> String {
value.replace(|c| !is_valid_header_value_char(c), "_")
}
fn detect_terminal() -> String {
sanitize_header_value(
if let Ok(tp) = std::env::var("TERM_PROGRAM")
&& !tp.trim().is_empty()
{
let ver = std::env::var("TERM_PROGRAM_VERSION").ok();
match ver {
Some(v) if !v.trim().is_empty() => format!("{tp}/{v}"),
_ => tp,
}
} else if let Ok(v) = std::env::var("WEZTERM_VERSION") {
if !v.trim().is_empty() {
format!("WezTerm/{v}")
} else {
"WezTerm".to_string()
}
} else if std::env::var("KITTY_WINDOW_ID").is_ok()
|| std::env::var("TERM")
.map(|t| t.contains("kitty"))
.unwrap_or(false)
{
"kitty".to_string()
} else if std::env::var("ALACRITTY_SOCKET").is_ok()
|| std::env::var("TERM")
.map(|t| t == "alacritty")
.unwrap_or(false)
{
"Alacritty".to_string()
} else if let Ok(v) = std::env::var("KONSOLE_VERSION") {
if !v.trim().is_empty() {
format!("Konsole/{v}")
} else {
"Konsole".to_string()
}
} else if std::env::var("GNOME_TERMINAL_SCREEN").is_ok() {
return "gnome-terminal".to_string();
} else if let Ok(v) = std::env::var("VTE_VERSION") {
if !v.trim().is_empty() {
format!("VTE/{v}")
} else {
"VTE".to_string()
}
} else if std::env::var("WT_SESSION").is_ok() {
return "WindowsTerminal".to_string();
} else {
std::env::var("TERM").unwrap_or_else(|_| "unknown".to_string())
},
)
}

View File

@@ -0,0 +1,195 @@
use base64::Engine;
use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Default)]
pub struct TokenData {
/// Flat info parsed from the JWT in auth.json.
#[serde(
deserialize_with = "deserialize_id_token",
serialize_with = "serialize_id_token"
)]
pub id_token: IdTokenInfo,
/// This is a JWT.
pub access_token: String,
pub refresh_token: String,
pub account_id: Option<String>,
}
/// Flat subset of useful claims in id_token from auth.json.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct IdTokenInfo {
pub email: Option<String>,
/// The ChatGPT subscription plan type
/// (e.g., "free", "plus", "pro", "business", "enterprise", "edu").
/// (Note: values may vary by backend.)
pub(crate) chatgpt_plan_type: Option<PlanType>,
/// Organization/workspace identifier associated with the token, if present.
pub chatgpt_account_id: Option<String>,
pub raw_jwt: String,
}
impl IdTokenInfo {
pub fn get_chatgpt_plan_type(&self) -> Option<String> {
self.chatgpt_plan_type.as_ref().map(|t| match t {
PlanType::Known(plan) => format!("{plan:?}"),
PlanType::Unknown(s) => s.clone(),
})
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub(crate) enum PlanType {
Known(KnownPlan),
Unknown(String),
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub(crate) enum KnownPlan {
Free,
Plus,
Pro,
Team,
Business,
Enterprise,
Edu,
}
#[derive(Deserialize)]
struct IdClaims {
#[serde(default)]
email: Option<String>,
#[serde(rename = "https://api.openai.com/auth", default)]
auth: Option<AuthClaims>,
}
#[derive(Deserialize)]
struct AuthClaims {
#[serde(default)]
chatgpt_plan_type: Option<PlanType>,
#[serde(default)]
chatgpt_account_id: Option<String>,
}
#[derive(Debug, Error)]
pub enum IdTokenInfoError {
#[error("invalid ID token format")]
InvalidFormat,
#[error(transparent)]
Base64(#[from] base64::DecodeError),
#[error(transparent)]
Json(#[from] serde_json::Error),
}
pub fn parse_id_token(id_token: &str) -> Result<IdTokenInfo, IdTokenInfoError> {
// JWT format: header.payload.signature
let mut parts = id_token.split('.');
let (_header_b64, payload_b64, _sig_b64) = match (parts.next(), parts.next(), parts.next()) {
(Some(h), Some(p), Some(s)) if !h.is_empty() && !p.is_empty() && !s.is_empty() => (h, p, s),
_ => return Err(IdTokenInfoError::InvalidFormat),
};
let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload_b64)?;
let claims: IdClaims = serde_json::from_slice(&payload_bytes)?;
match claims.auth {
Some(auth) => Ok(IdTokenInfo {
email: claims.email,
raw_jwt: id_token.to_string(),
chatgpt_plan_type: auth.chatgpt_plan_type,
chatgpt_account_id: auth.chatgpt_account_id,
}),
None => Ok(IdTokenInfo {
email: claims.email,
raw_jwt: id_token.to_string(),
chatgpt_plan_type: None,
chatgpt_account_id: None,
}),
}
}
fn deserialize_id_token<'de, D>(deserializer: D) -> Result<IdTokenInfo, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
parse_id_token(&s).map_err(serde::de::Error::custom)
}
fn serialize_id_token<S>(id_token: &IdTokenInfo, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_str(&id_token.raw_jwt)
}
#[cfg(test)]
mod tests {
use super::*;
use serde::Serialize;
#[test]
fn id_token_info_parses_email_and_plan() {
#[derive(Serialize)]
struct Header {
alg: &'static str,
typ: &'static str,
}
let header = Header {
alg: "none",
typ: "JWT",
};
let payload = serde_json::json!({
"email": "user@example.com",
"https://api.openai.com/auth": {
"chatgpt_plan_type": "pro"
}
});
fn b64url_no_pad(bytes: &[u8]) -> String {
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
}
let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
let signature_b64 = b64url_no_pad(b"sig");
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
let info = parse_id_token(&fake_jwt).expect("should parse");
assert_eq!(info.email.as_deref(), Some("user@example.com"));
assert_eq!(info.get_chatgpt_plan_type().as_deref(), Some("Pro"));
}
#[test]
fn id_token_info_handles_missing_fields() {
#[derive(Serialize)]
struct Header {
alg: &'static str,
typ: &'static str,
}
let header = Header {
alg: "none",
typ: "JWT",
};
let payload = serde_json::json!({ "sub": "123" });
fn b64url_no_pad(bytes: &[u8]) -> String {
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
}
let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
let signature_b64 = b64url_no_pad(b"sig");
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
let info = parse_id_token(&fake_jwt).expect("should parse");
assert!(info.email.is_none());
assert!(info.get_chatgpt_plan_type().is_none());
}
}

View File

@@ -0,0 +1,268 @@
use crate::llmx::Session;
use crate::llmx::TurnContext;
use crate::tools::TELEMETRY_PREVIEW_MAX_BYTES;
use crate::tools::TELEMETRY_PREVIEW_MAX_LINES;
use crate::tools::TELEMETRY_PREVIEW_TRUNCATION_NOTICE;
use crate::turn_diff_tracker::TurnDiffTracker;
use llmx_otel::otel_event_manager::OtelEventManager;
use llmx_protocol::models::FunctionCallOutputContentItem;
use llmx_protocol::models::FunctionCallOutputPayload;
use llmx_protocol::models::ResponseInputItem;
use llmx_protocol::models::ShellToolCallParams;
use llmx_protocol::protocol::FileChange;
use llmx_utils_string::take_bytes_at_char_boundary;
use mcp_types::CallToolResult;
use std::borrow::Cow;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Mutex;
pub type SharedTurnDiffTracker = Arc<Mutex<TurnDiffTracker>>;
#[derive(Clone)]
pub struct ToolInvocation {
pub session: Arc<Session>,
pub turn: Arc<TurnContext>,
pub tracker: SharedTurnDiffTracker,
pub call_id: String,
pub tool_name: String,
pub payload: ToolPayload,
}
#[derive(Clone)]
pub enum ToolPayload {
Function {
arguments: String,
},
Custom {
input: String,
},
LocalShell {
params: ShellToolCallParams,
},
UnifiedExec {
arguments: String,
},
Mcp {
server: String,
tool: String,
raw_arguments: String,
},
}
impl ToolPayload {
pub fn log_payload(&self) -> Cow<'_, str> {
match self {
ToolPayload::Function { arguments } => Cow::Borrowed(arguments),
ToolPayload::Custom { input } => Cow::Borrowed(input),
ToolPayload::LocalShell { params } => Cow::Owned(params.command.join(" ")),
ToolPayload::UnifiedExec { arguments } => Cow::Borrowed(arguments),
ToolPayload::Mcp { raw_arguments, .. } => Cow::Borrowed(raw_arguments),
}
}
}
#[derive(Clone)]
pub enum ToolOutput {
Function {
// Plain text representation of the tool output.
content: String,
// Some tool calls such as MCP calls may return structured content that can get parsed into an array of polymorphic content items.
content_items: Option<Vec<FunctionCallOutputContentItem>>,
success: Option<bool>,
},
Mcp {
result: Result<CallToolResult, String>,
},
}
impl ToolOutput {
pub fn log_preview(&self) -> String {
match self {
ToolOutput::Function { content, .. } => telemetry_preview(content),
ToolOutput::Mcp { result } => format!("{result:?}"),
}
}
pub fn success_for_logging(&self) -> bool {
match self {
ToolOutput::Function { success, .. } => success.unwrap_or(true),
ToolOutput::Mcp { result } => result.is_ok(),
}
}
pub fn into_response(self, call_id: &str, payload: &ToolPayload) -> ResponseInputItem {
match self {
ToolOutput::Function {
content,
content_items,
success,
} => {
if matches!(payload, ToolPayload::Custom { .. }) {
ResponseInputItem::CustomToolCallOutput {
call_id: call_id.to_string(),
output: content,
}
} else {
ResponseInputItem::FunctionCallOutput {
call_id: call_id.to_string(),
output: FunctionCallOutputPayload {
content,
content_items,
success,
},
}
}
}
ToolOutput::Mcp { result } => ResponseInputItem::McpToolCallOutput {
call_id: call_id.to_string(),
result,
},
}
}
}
fn telemetry_preview(content: &str) -> String {
let truncated_slice = take_bytes_at_char_boundary(content, TELEMETRY_PREVIEW_MAX_BYTES);
let truncated_by_bytes = truncated_slice.len() < content.len();
let mut preview = String::new();
let mut lines_iter = truncated_slice.lines();
for idx in 0..TELEMETRY_PREVIEW_MAX_LINES {
match lines_iter.next() {
Some(line) => {
if idx > 0 {
preview.push('\n');
}
preview.push_str(line);
}
None => break,
}
}
let truncated_by_lines = lines_iter.next().is_some();
if !truncated_by_bytes && !truncated_by_lines {
return content.to_string();
}
if preview.len() < truncated_slice.len()
&& truncated_slice
.as_bytes()
.get(preview.len())
.is_some_and(|byte| *byte == b'\n')
{
preview.push('\n');
}
if !preview.is_empty() && !preview.ends_with('\n') {
preview.push('\n');
}
preview.push_str(TELEMETRY_PREVIEW_TRUNCATION_NOTICE);
preview
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn custom_tool_calls_should_roundtrip_as_custom_outputs() {
let payload = ToolPayload::Custom {
input: "patch".to_string(),
};
let response = ToolOutput::Function {
content: "patched".to_string(),
content_items: None,
success: Some(true),
}
.into_response("call-42", &payload);
match response {
ResponseInputItem::CustomToolCallOutput { call_id, output } => {
assert_eq!(call_id, "call-42");
assert_eq!(output, "patched");
}
other => panic!("expected CustomToolCallOutput, got {other:?}"),
}
}
#[test]
fn function_payloads_remain_function_outputs() {
let payload = ToolPayload::Function {
arguments: "{}".to_string(),
};
let response = ToolOutput::Function {
content: "ok".to_string(),
content_items: None,
success: Some(true),
}
.into_response("fn-1", &payload);
match response {
ResponseInputItem::FunctionCallOutput { call_id, output } => {
assert_eq!(call_id, "fn-1");
assert_eq!(output.content, "ok");
assert!(output.content_items.is_none());
assert_eq!(output.success, Some(true));
}
other => panic!("expected FunctionCallOutput, got {other:?}"),
}
}
#[test]
fn telemetry_preview_returns_original_within_limits() {
let content = "short output";
assert_eq!(telemetry_preview(content), content);
}
#[test]
fn telemetry_preview_truncates_by_bytes() {
let content = "x".repeat(TELEMETRY_PREVIEW_MAX_BYTES + 8);
let preview = telemetry_preview(&content);
assert!(preview.contains(TELEMETRY_PREVIEW_TRUNCATION_NOTICE));
assert!(
preview.len()
<= TELEMETRY_PREVIEW_MAX_BYTES + TELEMETRY_PREVIEW_TRUNCATION_NOTICE.len() + 1
);
}
#[test]
fn telemetry_preview_truncates_by_lines() {
let content = (0..(TELEMETRY_PREVIEW_MAX_LINES + 5))
.map(|idx| format!("line {idx}"))
.collect::<Vec<_>>()
.join("\n");
let preview = telemetry_preview(&content);
let lines: Vec<&str> = preview.lines().collect();
assert!(lines.len() <= TELEMETRY_PREVIEW_MAX_LINES + 1);
assert_eq!(lines.last(), Some(&TELEMETRY_PREVIEW_TRUNCATION_NOTICE));
}
}
#[derive(Clone, Debug)]
#[allow(dead_code)]
pub(crate) struct ExecCommandContext {
pub(crate) turn: Arc<TurnContext>,
pub(crate) call_id: String,
pub(crate) command_for_display: Vec<String>,
pub(crate) cwd: PathBuf,
pub(crate) apply_patch: Option<ApplyPatchCommandContext>,
pub(crate) tool_name: String,
pub(crate) otel_event_manager: OtelEventManager,
// TODO(abhisek-oai): Find a better way to track this.
// https://github.com/valknar/llmx/pull/2471/files#r2470352242
pub(crate) is_user_shell_command: bool,
}
#[derive(Clone, Debug)]
#[allow(dead_code)]
pub(crate) struct ApplyPatchCommandContext {
pub(crate) user_explicitly_approved_this_action: bool,
pub(crate) changes: HashMap<PathBuf, FileChange>,
}

View File

@@ -0,0 +1,369 @@
use crate::error::LlmxErr;
use crate::error::SandboxErr;
use crate::exec::ExecToolCallOutput;
use crate::function_tool::FunctionCallError;
use crate::llmx::Session;
use crate::llmx::TurnContext;
use crate::parse_command::parse_command;
use crate::protocol::EventMsg;
use crate::protocol::ExecCommandBeginEvent;
use crate::protocol::ExecCommandEndEvent;
use crate::protocol::FileChange;
use crate::protocol::PatchApplyBeginEvent;
use crate::protocol::PatchApplyEndEvent;
use crate::protocol::TurnDiffEvent;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::sandboxing::ToolError;
use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
use std::time::Duration;
use super::format_exec_output_str;
#[derive(Clone, Copy)]
pub(crate) struct ToolEventCtx<'a> {
pub session: &'a Session,
pub turn: &'a TurnContext,
pub call_id: &'a str,
pub turn_diff_tracker: Option<&'a SharedTurnDiffTracker>,
}
impl<'a> ToolEventCtx<'a> {
pub fn new(
session: &'a Session,
turn: &'a TurnContext,
call_id: &'a str,
turn_diff_tracker: Option<&'a SharedTurnDiffTracker>,
) -> Self {
Self {
session,
turn,
call_id,
turn_diff_tracker,
}
}
}
pub(crate) enum ToolEventStage {
Begin,
Success(ExecToolCallOutput),
Failure(ToolEventFailure),
}
pub(crate) enum ToolEventFailure {
Output(ExecToolCallOutput),
Message(String),
}
pub(crate) async fn emit_exec_command_begin(
ctx: ToolEventCtx<'_>,
command: &[String],
cwd: &Path,
is_user_shell_command: bool,
) {
ctx.session
.send_event(
ctx.turn,
EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
call_id: ctx.call_id.to_string(),
command: command.to_vec(),
cwd: cwd.to_path_buf(),
parsed_cmd: parse_command(command),
is_user_shell_command,
}),
)
.await;
}
// Concrete, allocation-free emitter: avoid trait objects and boxed futures.
pub(crate) enum ToolEmitter {
Shell {
command: Vec<String>,
cwd: PathBuf,
is_user_shell_command: bool,
},
ApplyPatch {
changes: HashMap<PathBuf, FileChange>,
auto_approved: bool,
},
UnifiedExec {
command: String,
cwd: PathBuf,
// True for `exec_command` and false for `write_stdin`.
#[allow(dead_code)]
is_startup_command: bool,
},
}
impl ToolEmitter {
pub fn shell(command: Vec<String>, cwd: PathBuf, is_user_shell_command: bool) -> Self {
Self::Shell {
command,
cwd,
is_user_shell_command,
}
}
pub fn apply_patch(changes: HashMap<PathBuf, FileChange>, auto_approved: bool) -> Self {
Self::ApplyPatch {
changes,
auto_approved,
}
}
pub fn unified_exec(command: String, cwd: PathBuf, is_startup_command: bool) -> Self {
Self::UnifiedExec {
command,
cwd,
is_startup_command,
}
}
pub async fn emit(&self, ctx: ToolEventCtx<'_>, stage: ToolEventStage) {
match (self, stage) {
(
Self::Shell {
command,
cwd,
is_user_shell_command,
},
ToolEventStage::Begin,
) => {
emit_exec_command_begin(ctx, command, cwd.as_path(), *is_user_shell_command).await;
}
(Self::Shell { .. }, ToolEventStage::Success(output)) => {
emit_exec_end(
ctx,
output.stdout.text.clone(),
output.stderr.text.clone(),
output.aggregated_output.text.clone(),
output.exit_code,
output.duration,
format_exec_output_str(&output),
)
.await;
}
(Self::Shell { .. }, ToolEventStage::Failure(ToolEventFailure::Output(output))) => {
emit_exec_end(
ctx,
output.stdout.text.clone(),
output.stderr.text.clone(),
output.aggregated_output.text.clone(),
output.exit_code,
output.duration,
format_exec_output_str(&output),
)
.await;
}
(Self::Shell { .. }, ToolEventStage::Failure(ToolEventFailure::Message(message))) => {
emit_exec_end(
ctx,
String::new(),
(*message).to_string(),
(*message).to_string(),
-1,
Duration::ZERO,
message.clone(),
)
.await;
}
(
Self::ApplyPatch {
changes,
auto_approved,
},
ToolEventStage::Begin,
) => {
if let Some(tracker) = ctx.turn_diff_tracker {
let mut guard = tracker.lock().await;
guard.on_patch_begin(changes);
}
ctx.session
.send_event(
ctx.turn,
EventMsg::PatchApplyBegin(PatchApplyBeginEvent {
call_id: ctx.call_id.to_string(),
auto_approved: *auto_approved,
changes: changes.clone(),
}),
)
.await;
}
(Self::ApplyPatch { .. }, ToolEventStage::Success(output)) => {
emit_patch_end(
ctx,
output.stdout.text.clone(),
output.stderr.text.clone(),
output.exit_code == 0,
)
.await;
}
(
Self::ApplyPatch { .. },
ToolEventStage::Failure(ToolEventFailure::Output(output)),
) => {
emit_patch_end(
ctx,
output.stdout.text.clone(),
output.stderr.text.clone(),
output.exit_code == 0,
)
.await;
}
(
Self::ApplyPatch { .. },
ToolEventStage::Failure(ToolEventFailure::Message(message)),
) => {
emit_patch_end(ctx, String::new(), (*message).to_string(), false).await;
}
(Self::UnifiedExec { command, cwd, .. }, ToolEventStage::Begin) => {
emit_exec_command_begin(ctx, &[command.to_string()], cwd.as_path(), false).await;
}
(Self::UnifiedExec { .. }, ToolEventStage::Success(output)) => {
emit_exec_end(
ctx,
output.stdout.text.clone(),
output.stderr.text.clone(),
output.aggregated_output.text.clone(),
output.exit_code,
output.duration,
format_exec_output_str(&output),
)
.await;
}
(
Self::UnifiedExec { .. },
ToolEventStage::Failure(ToolEventFailure::Output(output)),
) => {
emit_exec_end(
ctx,
output.stdout.text.clone(),
output.stderr.text.clone(),
output.aggregated_output.text.clone(),
output.exit_code,
output.duration,
format_exec_output_str(&output),
)
.await;
}
(
Self::UnifiedExec { .. },
ToolEventStage::Failure(ToolEventFailure::Message(message)),
) => {
emit_exec_end(
ctx,
String::new(),
(*message).to_string(),
(*message).to_string(),
-1,
Duration::ZERO,
message.clone(),
)
.await;
}
}
}
pub async fn begin(&self, ctx: ToolEventCtx<'_>) {
self.emit(ctx, ToolEventStage::Begin).await;
}
pub async fn finish(
&self,
ctx: ToolEventCtx<'_>,
out: Result<ExecToolCallOutput, ToolError>,
) -> Result<String, FunctionCallError> {
let (event, result) = match out {
Ok(output) => {
let content = super::format_exec_output_for_model(&output);
let exit_code = output.exit_code;
let event = ToolEventStage::Success(output);
let result = if exit_code == 0 {
Ok(content)
} else {
Err(FunctionCallError::RespondToModel(content))
};
(event, result)
}
Err(ToolError::Llmx(LlmxErr::Sandbox(SandboxErr::Timeout { output })))
| Err(ToolError::Llmx(LlmxErr::Sandbox(SandboxErr::Denied { output }))) => {
let response = super::format_exec_output_for_model(&output);
let event = ToolEventStage::Failure(ToolEventFailure::Output(*output));
let result = Err(FunctionCallError::RespondToModel(response));
(event, result)
}
Err(ToolError::Llmx(err)) => {
let message = format!("execution error: {err:?}");
let event = ToolEventStage::Failure(ToolEventFailure::Message(message.clone()));
let result = Err(FunctionCallError::RespondToModel(message));
(event, result)
}
Err(ToolError::Rejected(msg)) => {
// Normalize common rejection messages for exec tools so tests and
// users see a clear, consistent phrase.
let normalized = if msg == "rejected by user" {
"exec command rejected by user".to_string()
} else {
msg
};
let event = ToolEventStage::Failure(ToolEventFailure::Message(normalized.clone()));
let result = Err(FunctionCallError::RespondToModel(normalized));
(event, result)
}
};
self.emit(ctx, event).await;
result
}
}
async fn emit_exec_end(
ctx: ToolEventCtx<'_>,
stdout: String,
stderr: String,
aggregated_output: String,
exit_code: i32,
duration: Duration,
formatted_output: String,
) {
ctx.session
.send_event(
ctx.turn,
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
call_id: ctx.call_id.to_string(),
stdout,
stderr,
aggregated_output,
exit_code,
duration,
formatted_output,
}),
)
.await;
}
async fn emit_patch_end(ctx: ToolEventCtx<'_>, stdout: String, stderr: String, success: bool) {
ctx.session
.send_event(
ctx.turn,
EventMsg::PatchApplyEnd(PatchApplyEndEvent {
call_id: ctx.call_id.to_string(),
stdout,
stderr,
success,
}),
)
.await;
if let Some(tracker) = ctx.turn_diff_tracker {
let unified_diff = {
let mut guard = tracker.lock().await;
guard.get_unified_diff()
};
if let Ok(Some(unified_diff)) = unified_diff {
ctx.session
.send_event(ctx.turn, EventMsg::TurnDiff(TurnDiffEvent { unified_diff }))
.await;
}
}
}

View File

@@ -0,0 +1,265 @@
use std::collections::BTreeMap;
use crate::apply_patch;
use crate::apply_patch::InternalApplyPatchInvocation;
use crate::apply_patch::convert_apply_patch_to_protocol;
use crate::client_common::tools::FreeformTool;
use crate::client_common::tools::FreeformToolFormat;
use crate::client_common::tools::ResponsesApiTool;
use crate::client_common::tools::ToolSpec;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::events::ToolEmitter;
use crate::tools::events::ToolEventCtx;
use crate::tools::orchestrator::ToolOrchestrator;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use crate::tools::runtimes::apply_patch::ApplyPatchRequest;
use crate::tools::runtimes::apply_patch::ApplyPatchRuntime;
use crate::tools::sandboxing::ToolCtx;
use crate::tools::spec::ApplyPatchToolArgs;
use crate::tools::spec::JsonSchema;
use async_trait::async_trait;
use serde::Deserialize;
use serde::Serialize;
pub struct ApplyPatchHandler;
const APPLY_PATCH_LARK_GRAMMAR: &str = include_str!("tool_apply_patch.lark");
#[async_trait]
impl ToolHandler for ApplyPatchHandler {
fn kind(&self) -> ToolKind {
ToolKind::Function
}
fn matches_kind(&self, payload: &ToolPayload) -> bool {
matches!(
payload,
ToolPayload::Function { .. } | ToolPayload::Custom { .. }
)
}
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
turn,
tracker,
call_id,
tool_name,
payload,
} = invocation;
let patch_input = match payload {
ToolPayload::Function { arguments } => {
let args: ApplyPatchToolArgs = serde_json::from_str(&arguments).map_err(|e| {
FunctionCallError::RespondToModel(format!(
"failed to parse function arguments: {e:?}"
))
})?;
args.input
}
ToolPayload::Custom { input } => input,
_ => {
return Err(FunctionCallError::RespondToModel(
"apply_patch handler received unsupported payload".to_string(),
));
}
};
// Re-parse and verify the patch so we can compute changes and approval.
// Avoid building temporary ExecParams/command vectors; derive directly from inputs.
let cwd = turn.cwd.clone();
let command = vec!["apply_patch".to_string(), patch_input.clone()];
match llmx_apply_patch::maybe_parse_apply_patch_verified(&command, &cwd) {
llmx_apply_patch::MaybeApplyPatchVerified::Body(changes) => {
match apply_patch::apply_patch(session.as_ref(), turn.as_ref(), &call_id, changes)
.await
{
InternalApplyPatchInvocation::Output(item) => {
let content = item?;
Ok(ToolOutput::Function {
content,
content_items: None,
success: Some(true),
})
}
InternalApplyPatchInvocation::DelegateToExec(apply) => {
let emitter = ToolEmitter::apply_patch(
convert_apply_patch_to_protocol(&apply.action),
!apply.user_explicitly_approved_this_action,
);
let event_ctx = ToolEventCtx::new(
session.as_ref(),
turn.as_ref(),
&call_id,
Some(&tracker),
);
emitter.begin(event_ctx).await;
let req = ApplyPatchRequest {
patch: apply.action.patch.clone(),
cwd: apply.action.cwd.clone(),
timeout_ms: None,
user_explicitly_approved: apply.user_explicitly_approved_this_action,
llmx_exe: turn.llmx_linux_sandbox_exe.clone(),
};
let mut orchestrator = ToolOrchestrator::new();
let mut runtime = ApplyPatchRuntime::new();
let tool_ctx = ToolCtx {
session: session.as_ref(),
turn: turn.as_ref(),
call_id: call_id.clone(),
tool_name: tool_name.to_string(),
};
let out = orchestrator
.run(&mut runtime, &req, &tool_ctx, &turn, turn.approval_policy)
.await;
let event_ctx = ToolEventCtx::new(
session.as_ref(),
turn.as_ref(),
&call_id,
Some(&tracker),
);
let content = emitter.finish(event_ctx, out).await?;
Ok(ToolOutput::Function {
content,
content_items: None,
success: Some(true),
})
}
}
}
llmx_apply_patch::MaybeApplyPatchVerified::CorrectnessError(parse_error) => {
Err(FunctionCallError::RespondToModel(format!(
"apply_patch verification failed: {parse_error}"
)))
}
llmx_apply_patch::MaybeApplyPatchVerified::ShellParseError(error) => {
tracing::trace!("Failed to parse apply_patch input, {error:?}");
Err(FunctionCallError::RespondToModel(
"apply_patch handler received invalid patch input".to_string(),
))
}
llmx_apply_patch::MaybeApplyPatchVerified::NotApplyPatch => {
Err(FunctionCallError::RespondToModel(
"apply_patch handler received non-apply_patch input".to_string(),
))
}
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum ApplyPatchToolType {
Freeform,
Function,
}
/// Returns a custom tool that can be used to edit files. Well-suited for GPT-5 models
/// https://platform.openai.com/docs/guides/function-calling#custom-tools
pub(crate) fn create_apply_patch_freeform_tool() -> ToolSpec {
ToolSpec::Freeform(FreeformTool {
name: "apply_patch".to_string(),
description: "Use the `apply_patch` tool to edit files. This is a FREEFORM tool, so do not wrap the patch in JSON.".to_string(),
format: FreeformToolFormat {
r#type: "grammar".to_string(),
syntax: "lark".to_string(),
definition: APPLY_PATCH_LARK_GRAMMAR.to_string(),
},
})
}
/// Returns a json tool that can be used to edit files. Should only be used with gpt-oss models
pub(crate) fn create_apply_patch_json_tool() -> ToolSpec {
let mut properties = BTreeMap::new();
properties.insert(
"input".to_string(),
JsonSchema::String {
description: Some(r#"The entire contents of the apply_patch command"#.to_string()),
},
);
ToolSpec::Function(ResponsesApiTool {
name: "apply_patch".to_string(),
description: r#"Use the `apply_patch` tool to edit files.
Your patch language is a strippeddown, fileoriented diff format designed to be easy to parse and safe to apply. You can think of it as a highlevel envelope:
*** Begin Patch
[ one or more file sections ]
*** End Patch
Within that envelope, you get a sequence of file operations.
You MUST include a header to specify the action you are taking.
Each operation starts with one of three headers:
*** Add File: <path> - create a new file. Every following line is a + line (the initial contents).
*** Delete File: <path> - remove an existing file. Nothing follows.
*** Update File: <path> - patch an existing file in place (optionally with a rename).
May be immediately followed by *** Move to: <new path> if you want to rename the file.
Then one or more “hunks”, each introduced by @@ (optionally followed by a hunk header).
Within a hunk each line starts with:
For instructions on [context_before] and [context_after]:
- By default, show 3 lines of code immediately above and 3 lines immediately below each change. If a change is within 3 lines of a previous change, do NOT duplicate the first changes [context_after] lines in the second changes [context_before] lines.
- If 3 lines of context is insufficient to uniquely identify the snippet of code within the file, use the @@ operator to indicate the class or function to which the snippet belongs. For instance, we might have:
@@ class BaseClass
[3 lines of pre-context]
- [old_code]
+ [new_code]
[3 lines of post-context]
- If a code block is repeated so many times in a class or function such that even a single `@@` statement and 3 lines of context cannot uniquely identify the snippet of code, you can use multiple `@@` statements to jump to the right context. For instance:
@@ class BaseClass
@@ def method():
[3 lines of pre-context]
- [old_code]
+ [new_code]
[3 lines of post-context]
The full grammar definition is below:
Patch := Begin { FileOp } End
Begin := "*** Begin Patch" NEWLINE
End := "*** End Patch" NEWLINE
FileOp := AddFile | DeleteFile | UpdateFile
AddFile := "*** Add File: " path NEWLINE { "+" line NEWLINE }
DeleteFile := "*** Delete File: " path NEWLINE
UpdateFile := "*** Update File: " path NEWLINE [ MoveTo ] { Hunk }
MoveTo := "*** Move to: " newPath NEWLINE
Hunk := "@@" [ header ] NEWLINE { HunkLine } [ "*** End of File" NEWLINE ]
HunkLine := (" " | "-" | "+") text NEWLINE
A full patch can combine several operations:
*** Begin Patch
*** Add File: hello.txt
+Hello world
*** Update File: src/app.py
*** Move to: src/main.py
@@ def greet():
-print("Hi")
+print("Hello, world!")
*** Delete File: obsolete.txt
*** End Patch
It is important to remember:
- You must include a header with your intended action (Add/Delete/Update)
- You must prefix new lines with `+` even when creating a new file
- File references can only be relative, NEVER ABSOLUTE.
"#
.to_string(),
strict: false,
parameters: JsonSchema::Object {
properties,
required: Some(vec!["input".to_string()]),
additional_properties: Some(false.into()),
},
})
}

View File

@@ -0,0 +1,274 @@
use std::path::Path;
use std::time::Duration;
use async_trait::async_trait;
use serde::Deserialize;
use tokio::process::Command;
use tokio::time::timeout;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
pub struct GrepFilesHandler;
const DEFAULT_LIMIT: usize = 100;
const MAX_LIMIT: usize = 2000;
const COMMAND_TIMEOUT: Duration = Duration::from_secs(30);
fn default_limit() -> usize {
DEFAULT_LIMIT
}
#[derive(Deserialize)]
struct GrepFilesArgs {
pattern: String,
#[serde(default)]
include: Option<String>,
#[serde(default)]
path: Option<String>,
#[serde(default = "default_limit")]
limit: usize,
}
#[async_trait]
impl ToolHandler for GrepFilesHandler {
fn kind(&self) -> ToolKind {
ToolKind::Function
}
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation { payload, turn, .. } = invocation;
let arguments = match payload {
ToolPayload::Function { arguments } => arguments,
_ => {
return Err(FunctionCallError::RespondToModel(
"grep_files handler received unsupported payload".to_string(),
));
}
};
let args: GrepFilesArgs = serde_json::from_str(&arguments).map_err(|err| {
FunctionCallError::RespondToModel(format!(
"failed to parse function arguments: {err:?}"
))
})?;
let pattern = args.pattern.trim();
if pattern.is_empty() {
return Err(FunctionCallError::RespondToModel(
"pattern must not be empty".to_string(),
));
}
if args.limit == 0 {
return Err(FunctionCallError::RespondToModel(
"limit must be greater than zero".to_string(),
));
}
let limit = args.limit.min(MAX_LIMIT);
let search_path = turn.resolve_path(args.path.clone());
verify_path_exists(&search_path).await?;
let include = args.include.as_deref().map(str::trim).and_then(|val| {
if val.is_empty() {
None
} else {
Some(val.to_string())
}
});
let search_results =
run_rg_search(pattern, include.as_deref(), &search_path, limit, &turn.cwd).await?;
if search_results.is_empty() {
Ok(ToolOutput::Function {
content: "No matches found.".to_string(),
content_items: None,
success: Some(false),
})
} else {
Ok(ToolOutput::Function {
content: search_results.join("\n"),
content_items: None,
success: Some(true),
})
}
}
}
async fn verify_path_exists(path: &Path) -> Result<(), FunctionCallError> {
tokio::fs::metadata(path).await.map_err(|err| {
FunctionCallError::RespondToModel(format!("unable to access `{}`: {err}", path.display()))
})?;
Ok(())
}
async fn run_rg_search(
pattern: &str,
include: Option<&str>,
search_path: &Path,
limit: usize,
cwd: &Path,
) -> Result<Vec<String>, FunctionCallError> {
let mut command = Command::new("rg");
command
.current_dir(cwd)
.arg("--files-with-matches")
.arg("--sortr=modified")
.arg("--regexp")
.arg(pattern)
.arg("--no-messages");
if let Some(glob) = include {
command.arg("--glob").arg(glob);
}
command.arg("--").arg(search_path);
let output = timeout(COMMAND_TIMEOUT, command.output())
.await
.map_err(|_| {
FunctionCallError::RespondToModel("rg timed out after 30 seconds".to_string())
})?
.map_err(|err| {
FunctionCallError::RespondToModel(format!(
"failed to launch rg: {err}. Ensure ripgrep is installed and on PATH."
))
})?;
match output.status.code() {
Some(0) => Ok(parse_results(&output.stdout, limit)),
Some(1) => Ok(Vec::new()),
_ => {
let stderr = String::from_utf8_lossy(&output.stderr);
Err(FunctionCallError::RespondToModel(format!(
"rg failed: {stderr}"
)))
}
}
}
fn parse_results(stdout: &[u8], limit: usize) -> Vec<String> {
let mut results = Vec::new();
for line in stdout.split(|byte| *byte == b'\n') {
if line.is_empty() {
continue;
}
if let Ok(text) = std::str::from_utf8(line) {
if text.is_empty() {
continue;
}
results.push(text.to_string());
if results.len() == limit {
break;
}
}
}
results
}
#[cfg(test)]
mod tests {
use super::*;
use std::process::Command as StdCommand;
use tempfile::tempdir;
#[test]
fn parses_basic_results() {
let stdout = b"/tmp/file_a.rs\n/tmp/file_b.rs\n";
let parsed = parse_results(stdout, 10);
assert_eq!(
parsed,
vec!["/tmp/file_a.rs".to_string(), "/tmp/file_b.rs".to_string()]
);
}
#[test]
fn parse_truncates_after_limit() {
let stdout = b"/tmp/file_a.rs\n/tmp/file_b.rs\n/tmp/file_c.rs\n";
let parsed = parse_results(stdout, 2);
assert_eq!(
parsed,
vec!["/tmp/file_a.rs".to_string(), "/tmp/file_b.rs".to_string()]
);
}
#[tokio::test]
async fn run_search_returns_results() -> anyhow::Result<()> {
if !rg_available() {
return Ok(());
}
let temp = tempdir().expect("create temp dir");
let dir = temp.path();
std::fs::write(dir.join("match_one.txt"), "alpha beta gamma").unwrap();
std::fs::write(dir.join("match_two.txt"), "alpha delta").unwrap();
std::fs::write(dir.join("other.txt"), "omega").unwrap();
let results = run_rg_search("alpha", None, dir, 10, dir).await?;
assert_eq!(results.len(), 2);
assert!(results.iter().any(|path| path.ends_with("match_one.txt")));
assert!(results.iter().any(|path| path.ends_with("match_two.txt")));
Ok(())
}
#[tokio::test]
async fn run_search_with_glob_filter() -> anyhow::Result<()> {
if !rg_available() {
return Ok(());
}
let temp = tempdir().expect("create temp dir");
let dir = temp.path();
std::fs::write(dir.join("match_one.rs"), "alpha beta gamma").unwrap();
std::fs::write(dir.join("match_two.txt"), "alpha delta").unwrap();
let results = run_rg_search("alpha", Some("*.rs"), dir, 10, dir).await?;
assert_eq!(results.len(), 1);
assert!(results.iter().all(|path| path.ends_with("match_one.rs")));
Ok(())
}
#[tokio::test]
async fn run_search_respects_limit() -> anyhow::Result<()> {
if !rg_available() {
return Ok(());
}
let temp = tempdir().expect("create temp dir");
let dir = temp.path();
std::fs::write(dir.join("one.txt"), "alpha one").unwrap();
std::fs::write(dir.join("two.txt"), "alpha two").unwrap();
std::fs::write(dir.join("three.txt"), "alpha three").unwrap();
let results = run_rg_search("alpha", None, dir, 2, dir).await?;
assert_eq!(results.len(), 2);
Ok(())
}
#[tokio::test]
async fn run_search_handles_no_matches() -> anyhow::Result<()> {
if !rg_available() {
return Ok(());
}
let temp = tempdir().expect("create temp dir");
let dir = temp.path();
std::fs::write(dir.join("one.txt"), "omega").unwrap();
let results = run_rg_search("alpha", None, dir, 5, dir).await?;
assert!(results.is_empty());
Ok(())
}
fn rg_available() -> bool {
StdCommand::new("rg")
.arg("--version")
.output()
.map(|output| output.status.success())
.unwrap_or(false)
}
}

View File

@@ -0,0 +1,477 @@
use std::collections::VecDeque;
use std::ffi::OsStr;
use std::fs::FileType;
use std::path::Path;
use std::path::PathBuf;
use async_trait::async_trait;
use llmx_utils_string::take_bytes_at_char_boundary;
use serde::Deserialize;
use tokio::fs;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
pub struct ListDirHandler;
const MAX_ENTRY_LENGTH: usize = 500;
const INDENTATION_SPACES: usize = 2;
fn default_offset() -> usize {
1
}
fn default_limit() -> usize {
25
}
fn default_depth() -> usize {
2
}
#[derive(Deserialize)]
struct ListDirArgs {
dir_path: String,
#[serde(default = "default_offset")]
offset: usize,
#[serde(default = "default_limit")]
limit: usize,
#[serde(default = "default_depth")]
depth: usize,
}
#[async_trait]
impl ToolHandler for ListDirHandler {
fn kind(&self) -> ToolKind {
ToolKind::Function
}
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation { payload, .. } = invocation;
let arguments = match payload {
ToolPayload::Function { arguments } => arguments,
_ => {
return Err(FunctionCallError::RespondToModel(
"list_dir handler received unsupported payload".to_string(),
));
}
};
let args: ListDirArgs = serde_json::from_str(&arguments).map_err(|err| {
FunctionCallError::RespondToModel(format!(
"failed to parse function arguments: {err:?}"
))
})?;
let ListDirArgs {
dir_path,
offset,
limit,
depth,
} = args;
if offset == 0 {
return Err(FunctionCallError::RespondToModel(
"offset must be a 1-indexed entry number".to_string(),
));
}
if limit == 0 {
return Err(FunctionCallError::RespondToModel(
"limit must be greater than zero".to_string(),
));
}
if depth == 0 {
return Err(FunctionCallError::RespondToModel(
"depth must be greater than zero".to_string(),
));
}
let path = PathBuf::from(&dir_path);
if !path.is_absolute() {
return Err(FunctionCallError::RespondToModel(
"dir_path must be an absolute path".to_string(),
));
}
let entries = list_dir_slice(&path, offset, limit, depth).await?;
let mut output = Vec::with_capacity(entries.len() + 1);
output.push(format!("Absolute path: {}", path.display()));
output.extend(entries);
Ok(ToolOutput::Function {
content: output.join("\n"),
content_items: None,
success: Some(true),
})
}
}
async fn list_dir_slice(
path: &Path,
offset: usize,
limit: usize,
depth: usize,
) -> Result<Vec<String>, FunctionCallError> {
let mut entries = Vec::new();
collect_entries(path, Path::new(""), depth, &mut entries).await?;
if entries.is_empty() {
return Ok(Vec::new());
}
let start_index = offset - 1;
if start_index >= entries.len() {
return Err(FunctionCallError::RespondToModel(
"offset exceeds directory entry count".to_string(),
));
}
let remaining_entries = entries.len() - start_index;
let capped_limit = limit.min(remaining_entries);
let end_index = start_index + capped_limit;
let mut selected_entries = entries[start_index..end_index].to_vec();
selected_entries.sort_unstable_by(|a, b| a.name.cmp(&b.name));
let mut formatted = Vec::with_capacity(selected_entries.len());
for entry in &selected_entries {
formatted.push(format_entry_line(entry));
}
if end_index < entries.len() {
formatted.push(format!("More than {capped_limit} entries found"));
}
Ok(formatted)
}
async fn collect_entries(
dir_path: &Path,
relative_prefix: &Path,
depth: usize,
entries: &mut Vec<DirEntry>,
) -> Result<(), FunctionCallError> {
let mut queue = VecDeque::new();
queue.push_back((dir_path.to_path_buf(), relative_prefix.to_path_buf(), depth));
while let Some((current_dir, prefix, remaining_depth)) = queue.pop_front() {
let mut read_dir = fs::read_dir(&current_dir).await.map_err(|err| {
FunctionCallError::RespondToModel(format!("failed to read directory: {err}"))
})?;
let mut dir_entries = Vec::new();
while let Some(entry) = read_dir.next_entry().await.map_err(|err| {
FunctionCallError::RespondToModel(format!("failed to read directory: {err}"))
})? {
let file_type = entry.file_type().await.map_err(|err| {
FunctionCallError::RespondToModel(format!("failed to inspect entry: {err}"))
})?;
let file_name = entry.file_name();
let relative_path = if prefix.as_os_str().is_empty() {
PathBuf::from(&file_name)
} else {
prefix.join(&file_name)
};
let display_name = format_entry_component(&file_name);
let display_depth = prefix.components().count();
let sort_key = format_entry_name(&relative_path);
let kind = DirEntryKind::from(&file_type);
dir_entries.push((
entry.path(),
relative_path,
kind,
DirEntry {
name: sort_key,
display_name,
depth: display_depth,
kind,
},
));
}
dir_entries.sort_unstable_by(|a, b| a.3.name.cmp(&b.3.name));
for (entry_path, relative_path, kind, dir_entry) in dir_entries {
if kind == DirEntryKind::Directory && remaining_depth > 1 {
queue.push_back((entry_path, relative_path, remaining_depth - 1));
}
entries.push(dir_entry);
}
}
Ok(())
}
fn format_entry_name(path: &Path) -> String {
let normalized = path.to_string_lossy().replace("\\", "/");
if normalized.len() > MAX_ENTRY_LENGTH {
take_bytes_at_char_boundary(&normalized, MAX_ENTRY_LENGTH).to_string()
} else {
normalized
}
}
fn format_entry_component(name: &OsStr) -> String {
let normalized = name.to_string_lossy();
if normalized.len() > MAX_ENTRY_LENGTH {
take_bytes_at_char_boundary(&normalized, MAX_ENTRY_LENGTH).to_string()
} else {
normalized.to_string()
}
}
fn format_entry_line(entry: &DirEntry) -> String {
let indent = " ".repeat(entry.depth * INDENTATION_SPACES);
let mut name = entry.display_name.clone();
match entry.kind {
DirEntryKind::Directory => name.push('/'),
DirEntryKind::Symlink => name.push('@'),
DirEntryKind::Other => name.push('?'),
DirEntryKind::File => {}
}
format!("{indent}{name}")
}
#[derive(Clone)]
struct DirEntry {
name: String,
display_name: String,
depth: usize,
kind: DirEntryKind,
}
#[derive(Clone, Copy, PartialEq, Eq)]
enum DirEntryKind {
Directory,
File,
Symlink,
Other,
}
impl From<&FileType> for DirEntryKind {
fn from(file_type: &FileType) -> Self {
if file_type.is_symlink() {
DirEntryKind::Symlink
} else if file_type.is_dir() {
DirEntryKind::Directory
} else if file_type.is_file() {
DirEntryKind::File
} else {
DirEntryKind::Other
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
#[tokio::test]
async fn lists_directory_entries() {
let temp = tempdir().expect("create tempdir");
let dir_path = temp.path();
let sub_dir = dir_path.join("nested");
tokio::fs::create_dir(&sub_dir)
.await
.expect("create sub dir");
let deeper_dir = sub_dir.join("deeper");
tokio::fs::create_dir(&deeper_dir)
.await
.expect("create deeper dir");
tokio::fs::write(dir_path.join("entry.txt"), b"content")
.await
.expect("write file");
tokio::fs::write(sub_dir.join("child.txt"), b"child")
.await
.expect("write child");
tokio::fs::write(deeper_dir.join("grandchild.txt"), b"grandchild")
.await
.expect("write grandchild");
#[cfg(unix)]
{
use std::os::unix::fs::symlink;
let link_path = dir_path.join("link");
symlink(dir_path.join("entry.txt"), &link_path).expect("create symlink");
}
let entries = list_dir_slice(dir_path, 1, 20, 3)
.await
.expect("list directory");
#[cfg(unix)]
let expected = vec![
"entry.txt".to_string(),
"link@".to_string(),
"nested/".to_string(),
" child.txt".to_string(),
" deeper/".to_string(),
" grandchild.txt".to_string(),
];
#[cfg(not(unix))]
let expected = vec![
"entry.txt".to_string(),
"nested/".to_string(),
" child.txt".to_string(),
" deeper/".to_string(),
" grandchild.txt".to_string(),
];
assert_eq!(entries, expected);
}
#[tokio::test]
async fn errors_when_offset_exceeds_entries() {
let temp = tempdir().expect("create tempdir");
let dir_path = temp.path();
tokio::fs::create_dir(dir_path.join("nested"))
.await
.expect("create sub dir");
let err = list_dir_slice(dir_path, 10, 1, 2)
.await
.expect_err("offset exceeds entries");
assert_eq!(
err,
FunctionCallError::RespondToModel("offset exceeds directory entry count".to_string())
);
}
#[tokio::test]
async fn respects_depth_parameter() {
let temp = tempdir().expect("create tempdir");
let dir_path = temp.path();
let nested = dir_path.join("nested");
let deeper = nested.join("deeper");
tokio::fs::create_dir(&nested).await.expect("create nested");
tokio::fs::create_dir(&deeper).await.expect("create deeper");
tokio::fs::write(dir_path.join("root.txt"), b"root")
.await
.expect("write root");
tokio::fs::write(nested.join("child.txt"), b"child")
.await
.expect("write nested");
tokio::fs::write(deeper.join("grandchild.txt"), b"deep")
.await
.expect("write deeper");
let entries_depth_one = list_dir_slice(dir_path, 1, 10, 1)
.await
.expect("list depth 1");
assert_eq!(
entries_depth_one,
vec!["nested/".to_string(), "root.txt".to_string(),]
);
let entries_depth_two = list_dir_slice(dir_path, 1, 20, 2)
.await
.expect("list depth 2");
assert_eq!(
entries_depth_two,
vec![
"nested/".to_string(),
" child.txt".to_string(),
" deeper/".to_string(),
"root.txt".to_string(),
]
);
let entries_depth_three = list_dir_slice(dir_path, 1, 30, 3)
.await
.expect("list depth 3");
assert_eq!(
entries_depth_three,
vec![
"nested/".to_string(),
" child.txt".to_string(),
" deeper/".to_string(),
" grandchild.txt".to_string(),
"root.txt".to_string(),
]
);
}
#[tokio::test]
async fn handles_large_limit_without_overflow() {
let temp = tempdir().expect("create tempdir");
let dir_path = temp.path();
tokio::fs::write(dir_path.join("alpha.txt"), b"alpha")
.await
.expect("write alpha");
tokio::fs::write(dir_path.join("beta.txt"), b"beta")
.await
.expect("write beta");
tokio::fs::write(dir_path.join("gamma.txt"), b"gamma")
.await
.expect("write gamma");
let entries = list_dir_slice(dir_path, 2, usize::MAX, 1)
.await
.expect("list without overflow");
assert_eq!(
entries,
vec!["beta.txt".to_string(), "gamma.txt".to_string(),]
);
}
#[tokio::test]
async fn indicates_truncated_results() {
let temp = tempdir().expect("create tempdir");
let dir_path = temp.path();
for idx in 0..40 {
let file = dir_path.join(format!("file_{idx:02}.txt"));
tokio::fs::write(file, b"content")
.await
.expect("write file");
}
let entries = list_dir_slice(dir_path, 1, 25, 1)
.await
.expect("list directory");
assert_eq!(entries.len(), 26);
assert_eq!(
entries.last(),
Some(&"More than 25 entries found".to_string())
);
}
#[tokio::test]
async fn bfs_truncation() -> anyhow::Result<()> {
let temp = tempdir()?;
let dir_path = temp.path();
let nested = dir_path.join("nested");
let deeper = nested.join("deeper");
tokio::fs::create_dir(&nested).await?;
tokio::fs::create_dir(&deeper).await?;
tokio::fs::write(dir_path.join("root.txt"), b"root").await?;
tokio::fs::write(nested.join("child.txt"), b"child").await?;
tokio::fs::write(deeper.join("grandchild.txt"), b"deep").await?;
let entries_depth_three = list_dir_slice(dir_path, 1, 3, 3).await?;
assert_eq!(
entries_depth_three,
vec![
"nested/".to_string(),
" child.txt".to_string(),
"root.txt".to_string(),
"More than 3 entries found".to_string()
]
);
Ok(())
}
}

View File

@@ -0,0 +1,75 @@
use async_trait::async_trait;
use crate::function_tool::FunctionCallError;
use crate::mcp_tool_call::handle_mcp_tool_call;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
pub struct McpHandler;
#[async_trait]
impl ToolHandler for McpHandler {
fn kind(&self) -> ToolKind {
ToolKind::Mcp
}
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
turn,
call_id,
payload,
..
} = invocation;
let payload = match payload {
ToolPayload::Mcp {
server,
tool,
raw_arguments,
} => (server, tool, raw_arguments),
_ => {
return Err(FunctionCallError::RespondToModel(
"mcp handler received unsupported payload".to_string(),
));
}
};
let (server, tool, raw_arguments) = payload;
let arguments_str = raw_arguments;
let response = handle_mcp_tool_call(
session.as_ref(),
turn.as_ref(),
call_id.clone(),
server,
tool,
arguments_str,
)
.await;
match response {
llmx_protocol::models::ResponseInputItem::McpToolCallOutput { result, .. } => {
Ok(ToolOutput::Mcp { result })
}
llmx_protocol::models::ResponseInputItem::FunctionCallOutput { output, .. } => {
let llmx_protocol::models::FunctionCallOutputPayload {
content,
content_items,
success,
} = output;
Ok(ToolOutput::Function {
content,
content_items,
success,
})
}
_ => Err(FunctionCallError::RespondToModel(
"mcp handler received unexpected response variant".to_string(),
)),
}
}
}

View File

@@ -0,0 +1,789 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use async_trait::async_trait;
use mcp_types::CallToolResult;
use mcp_types::ContentBlock;
use mcp_types::ListResourceTemplatesRequestParams;
use mcp_types::ListResourceTemplatesResult;
use mcp_types::ListResourcesRequestParams;
use mcp_types::ListResourcesResult;
use mcp_types::ReadResourceRequestParams;
use mcp_types::ReadResourceResult;
use mcp_types::Resource;
use mcp_types::ResourceTemplate;
use mcp_types::TextContent;
use serde::Deserialize;
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde_json::Value;
use crate::function_tool::FunctionCallError;
use crate::llmx::Session;
use crate::llmx::TurnContext;
use crate::protocol::EventMsg;
use crate::protocol::McpInvocation;
use crate::protocol::McpToolCallBeginEvent;
use crate::protocol::McpToolCallEndEvent;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
pub struct McpResourceHandler;
#[derive(Debug, Deserialize, Default)]
struct ListResourcesArgs {
/// Lists all resources from all servers if not specified.
#[serde(default)]
server: Option<String>,
#[serde(default)]
cursor: Option<String>,
}
#[derive(Debug, Deserialize, Default)]
struct ListResourceTemplatesArgs {
/// Lists all resource templates from all servers if not specified.
#[serde(default)]
server: Option<String>,
#[serde(default)]
cursor: Option<String>,
}
#[derive(Debug, Deserialize)]
struct ReadResourceArgs {
server: String,
uri: String,
}
#[derive(Debug, Serialize)]
struct ResourceWithServer {
server: String,
#[serde(flatten)]
resource: Resource,
}
impl ResourceWithServer {
fn new(server: String, resource: Resource) -> Self {
Self { server, resource }
}
}
#[derive(Debug, Serialize)]
struct ResourceTemplateWithServer {
server: String,
#[serde(flatten)]
template: ResourceTemplate,
}
impl ResourceTemplateWithServer {
fn new(server: String, template: ResourceTemplate) -> Self {
Self { server, template }
}
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ListResourcesPayload {
#[serde(skip_serializing_if = "Option::is_none")]
server: Option<String>,
resources: Vec<ResourceWithServer>,
#[serde(skip_serializing_if = "Option::is_none")]
next_cursor: Option<String>,
}
impl ListResourcesPayload {
fn from_single_server(server: String, result: ListResourcesResult) -> Self {
let resources = result
.resources
.into_iter()
.map(|resource| ResourceWithServer::new(server.clone(), resource))
.collect();
Self {
server: Some(server),
resources,
next_cursor: result.next_cursor,
}
}
fn from_all_servers(resources_by_server: HashMap<String, Vec<Resource>>) -> Self {
let mut entries: Vec<(String, Vec<Resource>)> = resources_by_server.into_iter().collect();
entries.sort_by(|a, b| a.0.cmp(&b.0));
let mut resources = Vec::new();
for (server, server_resources) in entries {
for resource in server_resources {
resources.push(ResourceWithServer::new(server.clone(), resource));
}
}
Self {
server: None,
resources,
next_cursor: None,
}
}
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ListResourceTemplatesPayload {
#[serde(skip_serializing_if = "Option::is_none")]
server: Option<String>,
resource_templates: Vec<ResourceTemplateWithServer>,
#[serde(skip_serializing_if = "Option::is_none")]
next_cursor: Option<String>,
}
impl ListResourceTemplatesPayload {
fn from_single_server(server: String, result: ListResourceTemplatesResult) -> Self {
let resource_templates = result
.resource_templates
.into_iter()
.map(|template| ResourceTemplateWithServer::new(server.clone(), template))
.collect();
Self {
server: Some(server),
resource_templates,
next_cursor: result.next_cursor,
}
}
fn from_all_servers(templates_by_server: HashMap<String, Vec<ResourceTemplate>>) -> Self {
let mut entries: Vec<(String, Vec<ResourceTemplate>)> =
templates_by_server.into_iter().collect();
entries.sort_by(|a, b| a.0.cmp(&b.0));
let mut resource_templates = Vec::new();
for (server, server_templates) in entries {
for template in server_templates {
resource_templates.push(ResourceTemplateWithServer::new(server.clone(), template));
}
}
Self {
server: None,
resource_templates,
next_cursor: None,
}
}
}
#[derive(Debug, Serialize)]
struct ReadResourcePayload {
server: String,
uri: String,
#[serde(flatten)]
result: ReadResourceResult,
}
#[async_trait]
impl ToolHandler for McpResourceHandler {
fn kind(&self) -> ToolKind {
ToolKind::Function
}
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
turn,
call_id,
tool_name,
payload,
..
} = invocation;
let arguments = match payload {
ToolPayload::Function { arguments } => arguments,
_ => {
return Err(FunctionCallError::RespondToModel(
"mcp_resource handler received unsupported payload".to_string(),
));
}
};
let arguments_value = parse_arguments(arguments.as_str())?;
match tool_name.as_str() {
"list_mcp_resources" => {
handle_list_resources(
Arc::clone(&session),
Arc::clone(&turn),
call_id.clone(),
arguments_value.clone(),
)
.await
}
"list_mcp_resource_templates" => {
handle_list_resource_templates(
Arc::clone(&session),
Arc::clone(&turn),
call_id.clone(),
arguments_value.clone(),
)
.await
}
"read_mcp_resource" => {
handle_read_resource(
Arc::clone(&session),
Arc::clone(&turn),
call_id,
arguments_value,
)
.await
}
other => Err(FunctionCallError::RespondToModel(format!(
"unsupported MCP resource tool: {other}"
))),
}
}
}
async fn handle_list_resources(
session: Arc<Session>,
turn: Arc<TurnContext>,
call_id: String,
arguments: Option<Value>,
) -> Result<ToolOutput, FunctionCallError> {
let args: ListResourcesArgs = parse_args_with_default(arguments.clone())?;
let ListResourcesArgs { server, cursor } = args;
let server = normalize_optional_string(server);
let cursor = normalize_optional_string(cursor);
let invocation = McpInvocation {
server: server.clone().unwrap_or_else(|| "llmx".to_string()),
tool: "list_mcp_resources".to_string(),
arguments: arguments.clone(),
};
emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await;
let start = Instant::now();
let payload_result: Result<ListResourcesPayload, FunctionCallError> = async {
if let Some(server_name) = server.clone() {
let params = cursor.clone().map(|value| ListResourcesRequestParams {
cursor: Some(value),
});
let result = session
.list_resources(&server_name, params)
.await
.map_err(|err| {
FunctionCallError::RespondToModel(format!("resources/list failed: {err:#}"))
})?;
Ok(ListResourcesPayload::from_single_server(
server_name,
result,
))
} else {
if cursor.is_some() {
return Err(FunctionCallError::RespondToModel(
"cursor can only be used when a server is specified".to_string(),
));
}
let resources = session
.services
.mcp_connection_manager
.list_all_resources()
.await;
Ok(ListResourcesPayload::from_all_servers(resources))
}
}
.await;
match payload_result {
Ok(payload) => match serialize_function_output(payload) {
Ok(output) => {
let ToolOutput::Function {
content, success, ..
} = &output
else {
unreachable!("MCP resource handler should return function output");
};
let duration = start.elapsed();
emit_tool_call_end(
&session,
turn.as_ref(),
&call_id,
invocation,
duration,
Ok(call_tool_result_from_content(content, *success)),
)
.await;
Ok(output)
}
Err(err) => {
let duration = start.elapsed();
let message = err.to_string();
emit_tool_call_end(
&session,
turn.as_ref(),
&call_id,
invocation,
duration,
Err(message.clone()),
)
.await;
Err(err)
}
},
Err(err) => {
let duration = start.elapsed();
let message = err.to_string();
emit_tool_call_end(
&session,
turn.as_ref(),
&call_id,
invocation,
duration,
Err(message.clone()),
)
.await;
Err(err)
}
}
}
async fn handle_list_resource_templates(
session: Arc<Session>,
turn: Arc<TurnContext>,
call_id: String,
arguments: Option<Value>,
) -> Result<ToolOutput, FunctionCallError> {
let args: ListResourceTemplatesArgs = parse_args_with_default(arguments.clone())?;
let ListResourceTemplatesArgs { server, cursor } = args;
let server = normalize_optional_string(server);
let cursor = normalize_optional_string(cursor);
let invocation = McpInvocation {
server: server.clone().unwrap_or_else(|| "llmx".to_string()),
tool: "list_mcp_resource_templates".to_string(),
arguments: arguments.clone(),
};
emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await;
let start = Instant::now();
let payload_result: Result<ListResourceTemplatesPayload, FunctionCallError> = async {
if let Some(server_name) = server.clone() {
let params = cursor
.clone()
.map(|value| ListResourceTemplatesRequestParams {
cursor: Some(value),
});
let result = session
.list_resource_templates(&server_name, params)
.await
.map_err(|err| {
FunctionCallError::RespondToModel(format!(
"resources/templates/list failed: {err:#}"
))
})?;
Ok(ListResourceTemplatesPayload::from_single_server(
server_name,
result,
))
} else {
if cursor.is_some() {
return Err(FunctionCallError::RespondToModel(
"cursor can only be used when a server is specified".to_string(),
));
}
let templates = session
.services
.mcp_connection_manager
.list_all_resource_templates()
.await;
Ok(ListResourceTemplatesPayload::from_all_servers(templates))
}
}
.await;
match payload_result {
Ok(payload) => match serialize_function_output(payload) {
Ok(output) => {
let ToolOutput::Function {
content, success, ..
} = &output
else {
unreachable!("MCP resource handler should return function output");
};
let duration = start.elapsed();
emit_tool_call_end(
&session,
turn.as_ref(),
&call_id,
invocation,
duration,
Ok(call_tool_result_from_content(content, *success)),
)
.await;
Ok(output)
}
Err(err) => {
let duration = start.elapsed();
let message = err.to_string();
emit_tool_call_end(
&session,
turn.as_ref(),
&call_id,
invocation,
duration,
Err(message.clone()),
)
.await;
Err(err)
}
},
Err(err) => {
let duration = start.elapsed();
let message = err.to_string();
emit_tool_call_end(
&session,
turn.as_ref(),
&call_id,
invocation,
duration,
Err(message.clone()),
)
.await;
Err(err)
}
}
}
async fn handle_read_resource(
session: Arc<Session>,
turn: Arc<TurnContext>,
call_id: String,
arguments: Option<Value>,
) -> Result<ToolOutput, FunctionCallError> {
let args: ReadResourceArgs = parse_args(arguments.clone())?;
let ReadResourceArgs { server, uri } = args;
let server = normalize_required_string("server", server)?;
let uri = normalize_required_string("uri", uri)?;
let invocation = McpInvocation {
server: server.clone(),
tool: "read_mcp_resource".to_string(),
arguments: arguments.clone(),
};
emit_tool_call_begin(&session, turn.as_ref(), &call_id, invocation.clone()).await;
let start = Instant::now();
let payload_result: Result<ReadResourcePayload, FunctionCallError> = async {
let result = session
.read_resource(&server, ReadResourceRequestParams { uri: uri.clone() })
.await
.map_err(|err| {
FunctionCallError::RespondToModel(format!("resources/read failed: {err:#}"))
})?;
Ok(ReadResourcePayload {
server,
uri,
result,
})
}
.await;
match payload_result {
Ok(payload) => match serialize_function_output(payload) {
Ok(output) => {
let ToolOutput::Function {
content, success, ..
} = &output
else {
unreachable!("MCP resource handler should return function output");
};
let duration = start.elapsed();
emit_tool_call_end(
&session,
turn.as_ref(),
&call_id,
invocation,
duration,
Ok(call_tool_result_from_content(content, *success)),
)
.await;
Ok(output)
}
Err(err) => {
let duration = start.elapsed();
let message = err.to_string();
emit_tool_call_end(
&session,
turn.as_ref(),
&call_id,
invocation,
duration,
Err(message.clone()),
)
.await;
Err(err)
}
},
Err(err) => {
let duration = start.elapsed();
let message = err.to_string();
emit_tool_call_end(
&session,
turn.as_ref(),
&call_id,
invocation,
duration,
Err(message.clone()),
)
.await;
Err(err)
}
}
}
fn call_tool_result_from_content(content: &str, success: Option<bool>) -> CallToolResult {
CallToolResult {
content: vec![ContentBlock::TextContent(TextContent {
annotations: None,
text: content.to_string(),
r#type: "text".to_string(),
})],
is_error: success.map(|value| !value),
structured_content: None,
}
}
async fn emit_tool_call_begin(
session: &Arc<Session>,
turn: &TurnContext,
call_id: &str,
invocation: McpInvocation,
) {
session
.send_event(
turn,
EventMsg::McpToolCallBegin(McpToolCallBeginEvent {
call_id: call_id.to_string(),
invocation,
}),
)
.await;
}
async fn emit_tool_call_end(
session: &Arc<Session>,
turn: &TurnContext,
call_id: &str,
invocation: McpInvocation,
duration: Duration,
result: Result<CallToolResult, String>,
) {
session
.send_event(
turn,
EventMsg::McpToolCallEnd(McpToolCallEndEvent {
call_id: call_id.to_string(),
invocation,
duration,
result,
}),
)
.await;
}
fn normalize_optional_string(input: Option<String>) -> Option<String> {
input.and_then(|value| {
let trimmed = value.trim().to_string();
if trimmed.is_empty() {
None
} else {
Some(trimmed)
}
})
}
fn normalize_required_string(field: &str, value: String) -> Result<String, FunctionCallError> {
match normalize_optional_string(Some(value)) {
Some(normalized) => Ok(normalized),
None => Err(FunctionCallError::RespondToModel(format!(
"{field} must be provided"
))),
}
}
fn serialize_function_output<T>(payload: T) -> Result<ToolOutput, FunctionCallError>
where
T: Serialize,
{
let content = serde_json::to_string(&payload).map_err(|err| {
FunctionCallError::RespondToModel(format!(
"failed to serialize MCP resource response: {err}"
))
})?;
Ok(ToolOutput::Function {
content,
content_items: None,
success: Some(true),
})
}
fn parse_arguments(raw_args: &str) -> Result<Option<Value>, FunctionCallError> {
if raw_args.trim().is_empty() {
Ok(None)
} else {
serde_json::from_str(raw_args).map(Some).map_err(|err| {
FunctionCallError::RespondToModel(format!("failed to parse function arguments: {err}"))
})
}
}
fn parse_args<T>(arguments: Option<Value>) -> Result<T, FunctionCallError>
where
T: DeserializeOwned,
{
match arguments {
Some(value) => serde_json::from_value(value).map_err(|err| {
FunctionCallError::RespondToModel(format!("failed to parse function arguments: {err}"))
}),
None => Err(FunctionCallError::RespondToModel(
"failed to parse function arguments: expected value".to_string(),
)),
}
}
fn parse_args_with_default<T>(arguments: Option<Value>) -> Result<T, FunctionCallError>
where
T: DeserializeOwned + Default,
{
match arguments {
Some(value) => parse_args(Some(value)),
None => Ok(T::default()),
}
}
#[cfg(test)]
mod tests {
use super::*;
use mcp_types::ListResourcesResult;
use mcp_types::ResourceTemplate;
use pretty_assertions::assert_eq;
use serde_json::json;
fn resource(uri: &str, name: &str) -> Resource {
Resource {
annotations: None,
description: None,
mime_type: None,
name: name.to_string(),
size: None,
title: None,
uri: uri.to_string(),
}
}
fn template(uri_template: &str, name: &str) -> ResourceTemplate {
ResourceTemplate {
annotations: None,
description: None,
mime_type: None,
name: name.to_string(),
title: None,
uri_template: uri_template.to_string(),
}
}
#[test]
fn resource_with_server_serializes_server_field() {
let entry = ResourceWithServer::new("test".to_string(), resource("memo://id", "memo"));
let value = serde_json::to_value(&entry).expect("serialize resource");
assert_eq!(value["server"], json!("test"));
assert_eq!(value["uri"], json!("memo://id"));
assert_eq!(value["name"], json!("memo"));
}
#[test]
fn list_resources_payload_from_single_server_copies_next_cursor() {
let result = ListResourcesResult {
next_cursor: Some("cursor-1".to_string()),
resources: vec![resource("memo://id", "memo")],
};
let payload = ListResourcesPayload::from_single_server("srv".to_string(), result);
let value = serde_json::to_value(&payload).expect("serialize payload");
assert_eq!(value["server"], json!("srv"));
assert_eq!(value["nextCursor"], json!("cursor-1"));
let resources = value["resources"].as_array().expect("resources array");
assert_eq!(resources.len(), 1);
assert_eq!(resources[0]["server"], json!("srv"));
}
#[test]
fn list_resources_payload_from_all_servers_is_sorted() {
let mut map = HashMap::new();
map.insert("beta".to_string(), vec![resource("memo://b-1", "b-1")]);
map.insert(
"alpha".to_string(),
vec![resource("memo://a-1", "a-1"), resource("memo://a-2", "a-2")],
);
let payload = ListResourcesPayload::from_all_servers(map);
let value = serde_json::to_value(&payload).expect("serialize payload");
let uris: Vec<String> = value["resources"]
.as_array()
.expect("resources array")
.iter()
.map(|entry| entry["uri"].as_str().unwrap().to_string())
.collect();
assert_eq!(
uris,
vec![
"memo://a-1".to_string(),
"memo://a-2".to_string(),
"memo://b-1".to_string()
]
);
}
#[test]
fn call_tool_result_from_content_marks_success() {
let result = call_tool_result_from_content("{}", Some(true));
assert_eq!(result.is_error, Some(false));
assert_eq!(result.content.len(), 1);
}
#[test]
fn parse_arguments_handles_empty_and_json() {
assert!(
parse_arguments(" \n\t").unwrap().is_none(),
"expected None for empty arguments"
);
let value = parse_arguments(r#"{"server":"figma"}"#)
.expect("parse json")
.expect("value present");
assert_eq!(value["server"], json!("figma"));
}
#[test]
fn template_with_server_serializes_server_field() {
let entry =
ResourceTemplateWithServer::new("srv".to_string(), template("memo://{id}", "memo"));
let value = serde_json::to_value(&entry).expect("serialize template");
assert_eq!(
value,
json!({
"server": "srv",
"uriTemplate": "memo://{id}",
"name": "memo"
})
);
}
}

View File

@@ -0,0 +1,25 @@
pub mod apply_patch;
mod grep_files;
mod list_dir;
mod mcp;
mod mcp_resource;
mod plan;
mod read_file;
mod shell;
mod test_sync;
mod unified_exec;
mod view_image;
pub use plan::PLAN_TOOL;
pub use apply_patch::ApplyPatchHandler;
pub use grep_files::GrepFilesHandler;
pub use list_dir::ListDirHandler;
pub use mcp::McpHandler;
pub use mcp_resource::McpResourceHandler;
pub use plan::PlanHandler;
pub use read_file::ReadFileHandler;
pub use shell::ShellHandler;
pub use test_sync::TestSyncHandler;
pub use unified_exec::UnifiedExecHandler;
pub use view_image::ViewImageHandler;

View File

@@ -0,0 +1,117 @@
use crate::client_common::tools::ResponsesApiTool;
use crate::client_common::tools::ToolSpec;
use crate::function_tool::FunctionCallError;
use crate::llmx::Session;
use crate::llmx::TurnContext;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use crate::tools::spec::JsonSchema;
use async_trait::async_trait;
use llmx_protocol::plan_tool::UpdatePlanArgs;
use llmx_protocol::protocol::EventMsg;
use std::collections::BTreeMap;
use std::sync::LazyLock;
pub struct PlanHandler;
pub static PLAN_TOOL: LazyLock<ToolSpec> = LazyLock::new(|| {
let mut plan_item_props = BTreeMap::new();
plan_item_props.insert("step".to_string(), JsonSchema::String { description: None });
plan_item_props.insert(
"status".to_string(),
JsonSchema::String {
description: Some("One of: pending, in_progress, completed".to_string()),
},
);
let plan_items_schema = JsonSchema::Array {
description: Some("The list of steps".to_string()),
items: Box::new(JsonSchema::Object {
properties: plan_item_props,
required: Some(vec!["step".to_string(), "status".to_string()]),
additional_properties: Some(false.into()),
}),
};
let mut properties = BTreeMap::new();
properties.insert(
"explanation".to_string(),
JsonSchema::String { description: None },
);
properties.insert("plan".to_string(), plan_items_schema);
ToolSpec::Function(ResponsesApiTool {
name: "update_plan".to_string(),
description: r#"Updates the task plan.
Provide an optional explanation and a list of plan items, each with a step and status.
At most one step can be in_progress at a time.
"#
.to_string(),
strict: false,
parameters: JsonSchema::Object {
properties,
required: Some(vec!["plan".to_string()]),
additional_properties: Some(false.into()),
},
})
});
#[async_trait]
impl ToolHandler for PlanHandler {
fn kind(&self) -> ToolKind {
ToolKind::Function
}
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
turn,
call_id,
payload,
..
} = invocation;
let arguments = match payload {
ToolPayload::Function { arguments } => arguments,
_ => {
return Err(FunctionCallError::RespondToModel(
"update_plan handler received unsupported payload".to_string(),
));
}
};
let content =
handle_update_plan(session.as_ref(), turn.as_ref(), arguments, call_id).await?;
Ok(ToolOutput::Function {
content,
content_items: None,
success: Some(true),
})
}
}
/// This function doesn't do anything useful. However, it gives the model a structured way to record its plan that clients can read and render.
/// So it's the _inputs_ to this function that are useful to clients, not the outputs and neither are actually useful for the model other
/// than forcing it to come up and document a plan (TBD how that affects performance).
pub(crate) async fn handle_update_plan(
session: &Session,
turn_context: &TurnContext,
arguments: String,
_call_id: String,
) -> Result<String, FunctionCallError> {
let args = parse_update_plan_arguments(&arguments)?;
session
.send_event(turn_context, EventMsg::PlanUpdate(args))
.await;
Ok("Plan updated".to_string())
}
fn parse_update_plan_arguments(arguments: &str) -> Result<UpdatePlanArgs, FunctionCallError> {
serde_json::from_str::<UpdatePlanArgs>(arguments).map_err(|e| {
FunctionCallError::RespondToModel(format!("failed to parse function arguments: {e}"))
})
}

View File

@@ -0,0 +1,999 @@
use std::collections::VecDeque;
use std::path::PathBuf;
use async_trait::async_trait;
use llmx_utils_string::take_bytes_at_char_boundary;
use serde::Deserialize;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
pub struct ReadFileHandler;
const MAX_LINE_LENGTH: usize = 500;
const TAB_WIDTH: usize = 4;
// TODO(jif) add support for block comments
const COMMENT_PREFIXES: &[&str] = &["#", "//", "--"];
/// JSON arguments accepted by the `read_file` tool handler.
#[derive(Deserialize)]
struct ReadFileArgs {
/// Absolute path to the file that will be read.
file_path: String,
/// 1-indexed line number to start reading from; defaults to 1.
#[serde(default = "defaults::offset")]
offset: usize,
/// Maximum number of lines to return; defaults to 2000.
#[serde(default = "defaults::limit")]
limit: usize,
/// Determines whether the handler reads a simple slice or indentation-aware block.
#[serde(default)]
mode: ReadMode,
/// Optional indentation configuration used when `mode` is `Indentation`.
#[serde(default)]
indentation: Option<IndentationArgs>,
}
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
enum ReadMode {
Slice,
Indentation,
}
/// Additional configuration for indentation-aware reads.
#[derive(Deserialize, Clone)]
struct IndentationArgs {
/// Optional explicit anchor line; defaults to `offset` when omitted.
#[serde(default)]
anchor_line: Option<usize>,
/// Maximum indentation depth to collect; `0` means unlimited.
#[serde(default = "defaults::max_levels")]
max_levels: usize,
/// Whether to include sibling blocks at the same indentation level.
#[serde(default = "defaults::include_siblings")]
include_siblings: bool,
/// Whether to include header lines above the anchor block. This made on a best effort basis.
#[serde(default = "defaults::include_header")]
include_header: bool,
/// Optional hard cap on returned lines; defaults to the global `limit`.
#[serde(default)]
max_lines: Option<usize>,
}
#[derive(Clone, Debug)]
struct LineRecord {
number: usize,
raw: String,
display: String,
indent: usize,
}
impl LineRecord {
fn trimmed(&self) -> &str {
self.raw.trim_start()
}
fn is_blank(&self) -> bool {
self.trimmed().is_empty()
}
fn is_comment(&self) -> bool {
COMMENT_PREFIXES
.iter()
.any(|prefix| self.raw.trim().starts_with(prefix))
}
}
#[async_trait]
impl ToolHandler for ReadFileHandler {
fn kind(&self) -> ToolKind {
ToolKind::Function
}
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation { payload, .. } = invocation;
let arguments = match payload {
ToolPayload::Function { arguments } => arguments,
_ => {
return Err(FunctionCallError::RespondToModel(
"read_file handler received unsupported payload".to_string(),
));
}
};
let args: ReadFileArgs = serde_json::from_str(&arguments).map_err(|err| {
FunctionCallError::RespondToModel(format!(
"failed to parse function arguments: {err:?}"
))
})?;
let ReadFileArgs {
file_path,
offset,
limit,
mode,
indentation,
} = args;
if offset == 0 {
return Err(FunctionCallError::RespondToModel(
"offset must be a 1-indexed line number".to_string(),
));
}
if limit == 0 {
return Err(FunctionCallError::RespondToModel(
"limit must be greater than zero".to_string(),
));
}
let path = PathBuf::from(&file_path);
if !path.is_absolute() {
return Err(FunctionCallError::RespondToModel(
"file_path must be an absolute path".to_string(),
));
}
let collected = match mode {
ReadMode::Slice => slice::read(&path, offset, limit).await?,
ReadMode::Indentation => {
let indentation = indentation.unwrap_or_default();
indentation::read_block(&path, offset, limit, indentation).await?
}
};
Ok(ToolOutput::Function {
content: collected.join("\n"),
content_items: None,
success: Some(true),
})
}
}
mod slice {
use crate::function_tool::FunctionCallError;
use crate::tools::handlers::read_file::format_line;
use std::path::Path;
use tokio::fs::File;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
pub async fn read(
path: &Path,
offset: usize,
limit: usize,
) -> Result<Vec<String>, FunctionCallError> {
let file = File::open(path).await.map_err(|err| {
FunctionCallError::RespondToModel(format!("failed to read file: {err}"))
})?;
let mut reader = BufReader::new(file);
let mut collected = Vec::new();
let mut seen = 0usize;
let mut buffer = Vec::new();
loop {
buffer.clear();
let bytes_read = reader.read_until(b'\n', &mut buffer).await.map_err(|err| {
FunctionCallError::RespondToModel(format!("failed to read file: {err}"))
})?;
if bytes_read == 0 {
break;
}
if buffer.last() == Some(&b'\n') {
buffer.pop();
if buffer.last() == Some(&b'\r') {
buffer.pop();
}
}
seen += 1;
if seen < offset {
continue;
}
if collected.len() == limit {
break;
}
let formatted = format_line(&buffer);
collected.push(format!("L{seen}: {formatted}"));
if collected.len() == limit {
break;
}
}
if seen < offset {
return Err(FunctionCallError::RespondToModel(
"offset exceeds file length".to_string(),
));
}
Ok(collected)
}
}
mod indentation {
use crate::function_tool::FunctionCallError;
use crate::tools::handlers::read_file::IndentationArgs;
use crate::tools::handlers::read_file::LineRecord;
use crate::tools::handlers::read_file::TAB_WIDTH;
use crate::tools::handlers::read_file::format_line;
use crate::tools::handlers::read_file::trim_empty_lines;
use std::collections::VecDeque;
use std::path::Path;
use tokio::fs::File;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
pub async fn read_block(
path: &Path,
offset: usize,
limit: usize,
options: IndentationArgs,
) -> Result<Vec<String>, FunctionCallError> {
let anchor_line = options.anchor_line.unwrap_or(offset);
if anchor_line == 0 {
return Err(FunctionCallError::RespondToModel(
"anchor_line must be a 1-indexed line number".to_string(),
));
}
let guard_limit = options.max_lines.unwrap_or(limit);
if guard_limit == 0 {
return Err(FunctionCallError::RespondToModel(
"max_lines must be greater than zero".to_string(),
));
}
let collected = collect_file_lines(path).await?;
if collected.is_empty() || anchor_line > collected.len() {
return Err(FunctionCallError::RespondToModel(
"anchor_line exceeds file length".to_string(),
));
}
let anchor_index = anchor_line - 1;
let effective_indents = compute_effective_indents(&collected);
let anchor_indent = effective_indents[anchor_index];
// Compute the min indent
let min_indent = if options.max_levels == 0 {
0
} else {
anchor_indent.saturating_sub(options.max_levels * TAB_WIDTH)
};
// Cap requested lines by guard_limit and file length
let final_limit = limit.min(guard_limit).min(collected.len());
if final_limit == 1 {
return Ok(vec![format!(
"L{}: {}",
collected[anchor_index].number, collected[anchor_index].display
)]);
}
// Cursors
let mut i: isize = anchor_index as isize - 1; // up (inclusive)
let mut j: usize = anchor_index + 1; // down (inclusive)
let mut i_counter_min_indent = 0;
let mut j_counter_min_indent = 0;
let mut out = VecDeque::with_capacity(limit);
out.push_back(&collected[anchor_index]);
while out.len() < final_limit {
let mut progressed = 0;
// Up.
if i >= 0 {
let iu = i as usize;
if effective_indents[iu] >= min_indent {
out.push_front(&collected[iu]);
progressed += 1;
i -= 1;
// We do not include the siblings (not applied to comments).
if effective_indents[iu] == min_indent && !options.include_siblings {
let allow_header_comment =
options.include_header && collected[iu].is_comment();
let can_take_line = allow_header_comment || i_counter_min_indent == 0;
if can_take_line {
i_counter_min_indent += 1;
} else {
// This line shouldn't have been taken.
out.pop_front();
progressed -= 1;
i = -1; // consider using Option<usize> or a control flag instead of a sentinel
}
}
// Short-cut.
if out.len() >= final_limit {
break;
}
} else {
// Stop moving up.
i = -1;
}
}
// Down.
if j < collected.len() {
let ju = j;
if effective_indents[ju] >= min_indent {
out.push_back(&collected[ju]);
progressed += 1;
j += 1;
// We do not include the siblings (applied to comments).
if effective_indents[ju] == min_indent && !options.include_siblings {
if j_counter_min_indent > 0 {
// This line shouldn't have been taken.
out.pop_back();
progressed -= 1;
j = collected.len();
}
j_counter_min_indent += 1;
}
} else {
// Stop moving down.
j = collected.len();
}
}
if progressed == 0 {
break;
}
}
// Trim empty lines
trim_empty_lines(&mut out);
Ok(out
.into_iter()
.map(|record| format!("L{}: {}", record.number, record.display))
.collect())
}
async fn collect_file_lines(path: &Path) -> Result<Vec<LineRecord>, FunctionCallError> {
let file = File::open(path).await.map_err(|err| {
FunctionCallError::RespondToModel(format!("failed to read file: {err}"))
})?;
let mut reader = BufReader::new(file);
let mut buffer = Vec::new();
let mut lines = Vec::new();
let mut number = 0usize;
loop {
buffer.clear();
let bytes_read = reader.read_until(b'\n', &mut buffer).await.map_err(|err| {
FunctionCallError::RespondToModel(format!("failed to read file: {err}"))
})?;
if bytes_read == 0 {
break;
}
if buffer.last() == Some(&b'\n') {
buffer.pop();
if buffer.last() == Some(&b'\r') {
buffer.pop();
}
}
number += 1;
let raw = String::from_utf8_lossy(&buffer).into_owned();
let indent = measure_indent(&raw);
let display = format_line(&buffer);
lines.push(LineRecord {
number,
raw,
display,
indent,
});
}
Ok(lines)
}
fn compute_effective_indents(records: &[LineRecord]) -> Vec<usize> {
let mut effective = Vec::with_capacity(records.len());
let mut previous_indent = 0usize;
for record in records {
if record.is_blank() {
effective.push(previous_indent);
} else {
previous_indent = record.indent;
effective.push(previous_indent);
}
}
effective
}
fn measure_indent(line: &str) -> usize {
line.chars()
.take_while(|c| matches!(c, ' ' | '\t'))
.map(|c| if c == '\t' { TAB_WIDTH } else { 1 })
.sum()
}
}
fn format_line(bytes: &[u8]) -> String {
let decoded = String::from_utf8_lossy(bytes);
if decoded.len() > MAX_LINE_LENGTH {
take_bytes_at_char_boundary(&decoded, MAX_LINE_LENGTH).to_string()
} else {
decoded.into_owned()
}
}
fn trim_empty_lines(out: &mut VecDeque<&LineRecord>) {
while matches!(out.front(), Some(line) if line.raw.trim().is_empty()) {
out.pop_front();
}
while matches!(out.back(), Some(line) if line.raw.trim().is_empty()) {
out.pop_back();
}
}
mod defaults {
use super::*;
impl Default for IndentationArgs {
fn default() -> Self {
Self {
anchor_line: None,
max_levels: max_levels(),
include_siblings: include_siblings(),
include_header: include_header(),
max_lines: None,
}
}
}
impl Default for ReadMode {
fn default() -> Self {
Self::Slice
}
}
pub fn offset() -> usize {
1
}
pub fn limit() -> usize {
2000
}
pub fn max_levels() -> usize {
0
}
pub fn include_siblings() -> bool {
false
}
pub fn include_header() -> bool {
true
}
}
#[cfg(test)]
mod tests {
use super::indentation::read_block;
use super::slice::read;
use super::*;
use pretty_assertions::assert_eq;
use tempfile::NamedTempFile;
#[tokio::test]
async fn reads_requested_range() -> anyhow::Result<()> {
let mut temp = NamedTempFile::new()?;
use std::io::Write as _;
write!(
temp,
"alpha
beta
gamma
"
)?;
let lines = read(temp.path(), 2, 2).await?;
assert_eq!(lines, vec!["L2: beta".to_string(), "L3: gamma".to_string()]);
Ok(())
}
#[tokio::test]
async fn errors_when_offset_exceeds_length() -> anyhow::Result<()> {
let mut temp = NamedTempFile::new()?;
use std::io::Write as _;
writeln!(temp, "only")?;
let err = read(temp.path(), 3, 1)
.await
.expect_err("offset exceeds length");
assert_eq!(
err,
FunctionCallError::RespondToModel("offset exceeds file length".to_string())
);
Ok(())
}
#[tokio::test]
async fn reads_non_utf8_lines() -> anyhow::Result<()> {
let mut temp = NamedTempFile::new()?;
use std::io::Write as _;
temp.as_file_mut().write_all(b"\xff\xfe\nplain\n")?;
let lines = read(temp.path(), 1, 2).await?;
let expected_first = format!("L1: {}{}", '\u{FFFD}', '\u{FFFD}');
assert_eq!(lines, vec![expected_first, "L2: plain".to_string()]);
Ok(())
}
#[tokio::test]
async fn trims_crlf_endings() -> anyhow::Result<()> {
let mut temp = NamedTempFile::new()?;
use std::io::Write as _;
write!(temp, "one\r\ntwo\r\n")?;
let lines = read(temp.path(), 1, 2).await?;
assert_eq!(lines, vec!["L1: one".to_string(), "L2: two".to_string()]);
Ok(())
}
#[tokio::test]
async fn respects_limit_even_with_more_lines() -> anyhow::Result<()> {
let mut temp = NamedTempFile::new()?;
use std::io::Write as _;
write!(
temp,
"first
second
third
"
)?;
let lines = read(temp.path(), 1, 2).await?;
assert_eq!(
lines,
vec!["L1: first".to_string(), "L2: second".to_string()]
);
Ok(())
}
#[tokio::test]
async fn truncates_lines_longer_than_max_length() -> anyhow::Result<()> {
let mut temp = NamedTempFile::new()?;
use std::io::Write as _;
let long_line = "x".repeat(MAX_LINE_LENGTH + 50);
writeln!(temp, "{long_line}")?;
let lines = read(temp.path(), 1, 1).await?;
let expected = "x".repeat(MAX_LINE_LENGTH);
assert_eq!(lines, vec![format!("L1: {expected}")]);
Ok(())
}
#[tokio::test]
async fn indentation_mode_captures_block() -> anyhow::Result<()> {
let mut temp = NamedTempFile::new()?;
use std::io::Write as _;
write!(
temp,
"fn outer() {{
if cond {{
inner();
}}
tail();
}}
"
)?;
let options = IndentationArgs {
anchor_line: Some(3),
include_siblings: false,
max_levels: 1,
..Default::default()
};
let lines = read_block(temp.path(), 3, 10, options).await?;
assert_eq!(
lines,
vec![
"L2: if cond {".to_string(),
"L3: inner();".to_string(),
"L4: }".to_string()
]
);
Ok(())
}
#[tokio::test]
async fn indentation_mode_expands_parents() -> anyhow::Result<()> {
let mut temp = NamedTempFile::new()?;
use std::io::Write as _;
write!(
temp,
"mod root {{
fn outer() {{
if cond {{
inner();
}}
}}
}}
"
)?;
let mut options = IndentationArgs {
anchor_line: Some(4),
max_levels: 2,
..Default::default()
};
let lines = read_block(temp.path(), 4, 50, options.clone()).await?;
assert_eq!(
lines,
vec![
"L2: fn outer() {".to_string(),
"L3: if cond {".to_string(),
"L4: inner();".to_string(),
"L5: }".to_string(),
"L6: }".to_string(),
]
);
options.max_levels = 3;
let expanded = read_block(temp.path(), 4, 50, options).await?;
assert_eq!(
expanded,
vec![
"L1: mod root {".to_string(),
"L2: fn outer() {".to_string(),
"L3: if cond {".to_string(),
"L4: inner();".to_string(),
"L5: }".to_string(),
"L6: }".to_string(),
"L7: }".to_string(),
]
);
Ok(())
}
#[tokio::test]
async fn indentation_mode_respects_sibling_flag() -> anyhow::Result<()> {
let mut temp = NamedTempFile::new()?;
use std::io::Write as _;
write!(
temp,
"fn wrapper() {{
if first {{
do_first();
}}
if second {{
do_second();
}}
}}
"
)?;
let mut options = IndentationArgs {
anchor_line: Some(3),
include_siblings: false,
max_levels: 1,
..Default::default()
};
let lines = read_block(temp.path(), 3, 50, options.clone()).await?;
assert_eq!(
lines,
vec![
"L2: if first {".to_string(),
"L3: do_first();".to_string(),
"L4: }".to_string(),
]
);
options.include_siblings = true;
let with_siblings = read_block(temp.path(), 3, 50, options).await?;
assert_eq!(
with_siblings,
vec![
"L2: if first {".to_string(),
"L3: do_first();".to_string(),
"L4: }".to_string(),
"L5: if second {".to_string(),
"L6: do_second();".to_string(),
"L7: }".to_string(),
]
);
Ok(())
}
#[tokio::test]
async fn indentation_mode_handles_python_sample() -> anyhow::Result<()> {
let mut temp = NamedTempFile::new()?;
use std::io::Write as _;
write!(
temp,
"class Foo:
def __init__(self, size):
self.size = size
def double(self, value):
if value is None:
return 0
result = value * self.size
return result
class Bar:
def compute(self):
helper = Foo(2)
return helper.double(5)
"
)?;
let options = IndentationArgs {
anchor_line: Some(7),
include_siblings: true,
max_levels: 1,
..Default::default()
};
let lines = read_block(temp.path(), 1, 200, options).await?;
assert_eq!(
lines,
vec![
"L2: def __init__(self, size):".to_string(),
"L3: self.size = size".to_string(),
"L4: def double(self, value):".to_string(),
"L5: if value is None:".to_string(),
"L6: return 0".to_string(),
"L7: result = value * self.size".to_string(),
"L8: return result".to_string(),
]
);
Ok(())
}
#[tokio::test]
#[ignore]
async fn indentation_mode_handles_javascript_sample() -> anyhow::Result<()> {
let mut temp = NamedTempFile::new()?;
use std::io::Write as _;
write!(
temp,
"export function makeThing() {{
const cache = new Map();
function ensure(key) {{
if (!cache.has(key)) {{
cache.set(key, []);
}}
return cache.get(key);
}}
const handlers = {{
init() {{
console.log(\"init\");
}},
run() {{
if (Math.random() > 0.5) {{
return \"heads\";
}}
return \"tails\";
}},
}};
return {{ cache, handlers }};
}}
export function other() {{
return makeThing();
}}
"
)?;
let options = IndentationArgs {
anchor_line: Some(15),
max_levels: 1,
..Default::default()
};
let lines = read_block(temp.path(), 15, 200, options).await?;
assert_eq!(
lines,
vec![
"L10: init() {".to_string(),
"L11: console.log(\"init\");".to_string(),
"L12: },".to_string(),
"L13: run() {".to_string(),
"L14: if (Math.random() > 0.5) {".to_string(),
"L15: return \"heads\";".to_string(),
"L16: }".to_string(),
"L17: return \"tails\";".to_string(),
"L18: },".to_string(),
]
);
Ok(())
}
fn write_cpp_sample() -> anyhow::Result<NamedTempFile> {
let mut temp = NamedTempFile::new()?;
use std::io::Write as _;
write!(
temp,
"#include <vector>
#include <string>
namespace sample {{
class Runner {{
public:
void setup() {{
if (enabled_) {{
init();
}}
}}
// Run the code
int run() const {{
switch (mode_) {{
case Mode::Fast:
return fast();
case Mode::Slow:
return slow();
default:
return fallback();
}}
}}
private:
bool enabled_ = false;
Mode mode_ = Mode::Fast;
int fast() const {{
return 1;
}}
}};
}} // namespace sample
"
)?;
Ok(temp)
}
#[tokio::test]
async fn indentation_mode_handles_cpp_sample_shallow() -> anyhow::Result<()> {
let temp = write_cpp_sample()?;
let options = IndentationArgs {
include_siblings: false,
anchor_line: Some(18),
max_levels: 1,
..Default::default()
};
let lines = read_block(temp.path(), 18, 200, options).await?;
assert_eq!(
lines,
vec![
"L15: switch (mode_) {".to_string(),
"L16: case Mode::Fast:".to_string(),
"L17: return fast();".to_string(),
"L18: case Mode::Slow:".to_string(),
"L19: return slow();".to_string(),
"L20: default:".to_string(),
"L21: return fallback();".to_string(),
"L22: }".to_string(),
]
);
Ok(())
}
#[tokio::test]
async fn indentation_mode_handles_cpp_sample() -> anyhow::Result<()> {
let temp = write_cpp_sample()?;
let options = IndentationArgs {
include_siblings: false,
anchor_line: Some(18),
max_levels: 2,
..Default::default()
};
let lines = read_block(temp.path(), 18, 200, options).await?;
assert_eq!(
lines,
vec![
"L13: // Run the code".to_string(),
"L14: int run() const {".to_string(),
"L15: switch (mode_) {".to_string(),
"L16: case Mode::Fast:".to_string(),
"L17: return fast();".to_string(),
"L18: case Mode::Slow:".to_string(),
"L19: return slow();".to_string(),
"L20: default:".to_string(),
"L21: return fallback();".to_string(),
"L22: }".to_string(),
"L23: }".to_string(),
]
);
Ok(())
}
#[tokio::test]
async fn indentation_mode_handles_cpp_sample_no_headers() -> anyhow::Result<()> {
let temp = write_cpp_sample()?;
let options = IndentationArgs {
include_siblings: false,
include_header: false,
anchor_line: Some(18),
max_levels: 2,
..Default::default()
};
let lines = read_block(temp.path(), 18, 200, options).await?;
assert_eq!(
lines,
vec![
"L14: int run() const {".to_string(),
"L15: switch (mode_) {".to_string(),
"L16: case Mode::Fast:".to_string(),
"L17: return fast();".to_string(),
"L18: case Mode::Slow:".to_string(),
"L19: return slow();".to_string(),
"L20: default:".to_string(),
"L21: return fallback();".to_string(),
"L22: }".to_string(),
"L23: }".to_string(),
]
);
Ok(())
}
#[tokio::test]
async fn indentation_mode_handles_cpp_sample_siblings() -> anyhow::Result<()> {
let temp = write_cpp_sample()?;
let options = IndentationArgs {
include_siblings: true,
include_header: false,
anchor_line: Some(18),
max_levels: 2,
..Default::default()
};
let lines = read_block(temp.path(), 18, 200, options).await?;
assert_eq!(
lines,
vec![
"L7: void setup() {".to_string(),
"L8: if (enabled_) {".to_string(),
"L9: init();".to_string(),
"L10: }".to_string(),
"L11: }".to_string(),
"L12: ".to_string(),
"L13: // Run the code".to_string(),
"L14: int run() const {".to_string(),
"L15: switch (mode_) {".to_string(),
"L16: case Mode::Fast:".to_string(),
"L17: return fast();".to_string(),
"L18: case Mode::Slow:".to_string(),
"L19: return slow();".to_string(),
"L20: default:".to_string(),
"L21: return fallback();".to_string(),
"L22: }".to_string(),
"L23: }".to_string(),
]
);
Ok(())
}
}

View File

@@ -0,0 +1,242 @@
use async_trait::async_trait;
use llmx_protocol::models::ShellToolCallParams;
use std::sync::Arc;
use crate::apply_patch;
use crate::apply_patch::InternalApplyPatchInvocation;
use crate::apply_patch::convert_apply_patch_to_protocol;
use crate::exec::ExecParams;
use crate::exec_env::create_env;
use crate::function_tool::FunctionCallError;
use crate::llmx::TurnContext;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::events::ToolEmitter;
use crate::tools::events::ToolEventCtx;
use crate::tools::orchestrator::ToolOrchestrator;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use crate::tools::runtimes::apply_patch::ApplyPatchRequest;
use crate::tools::runtimes::apply_patch::ApplyPatchRuntime;
use crate::tools::runtimes::shell::ShellRequest;
use crate::tools::runtimes::shell::ShellRuntime;
use crate::tools::sandboxing::ToolCtx;
pub struct ShellHandler;
impl ShellHandler {
fn to_exec_params(params: ShellToolCallParams, turn_context: &TurnContext) -> ExecParams {
ExecParams {
command: params.command,
cwd: turn_context.resolve_path(params.workdir.clone()),
timeout_ms: params.timeout_ms,
env: create_env(&turn_context.shell_environment_policy),
with_escalated_permissions: params.with_escalated_permissions,
justification: params.justification,
arg0: None,
}
}
}
#[async_trait]
impl ToolHandler for ShellHandler {
fn kind(&self) -> ToolKind {
ToolKind::Function
}
fn matches_kind(&self, payload: &ToolPayload) -> bool {
matches!(
payload,
ToolPayload::Function { .. } | ToolPayload::LocalShell { .. }
)
}
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
turn,
tracker,
call_id,
tool_name,
payload,
} = invocation;
match payload {
ToolPayload::Function { arguments } => {
let params: ShellToolCallParams =
serde_json::from_str(&arguments).map_err(|e| {
FunctionCallError::RespondToModel(format!(
"failed to parse function arguments: {e:?}"
))
})?;
let exec_params = Self::to_exec_params(params, turn.as_ref());
Self::run_exec_like(
tool_name.as_str(),
exec_params,
session,
turn,
tracker,
call_id,
false,
)
.await
}
ToolPayload::LocalShell { params } => {
let exec_params = Self::to_exec_params(params, turn.as_ref());
Self::run_exec_like(
tool_name.as_str(),
exec_params,
session,
turn,
tracker,
call_id,
true,
)
.await
}
_ => Err(FunctionCallError::RespondToModel(format!(
"unsupported payload for shell handler: {tool_name}"
))),
}
}
}
impl ShellHandler {
async fn run_exec_like(
tool_name: &str,
exec_params: ExecParams,
session: Arc<crate::llmx::Session>,
turn: Arc<TurnContext>,
tracker: crate::tools::context::SharedTurnDiffTracker,
call_id: String,
is_user_shell_command: bool,
) -> Result<ToolOutput, FunctionCallError> {
// Approval policy guard for explicit escalation in non-OnRequest modes.
if exec_params.with_escalated_permissions.unwrap_or(false)
&& !matches!(
turn.approval_policy,
llmx_protocol::protocol::AskForApproval::OnRequest
)
{
return Err(FunctionCallError::RespondToModel(format!(
"approval policy is {policy:?}; reject command — you should not ask for escalated permissions if the approval policy is {policy:?}",
policy = turn.approval_policy
)));
}
// Intercept apply_patch if present.
match llmx_apply_patch::maybe_parse_apply_patch_verified(
&exec_params.command,
&exec_params.cwd,
) {
llmx_apply_patch::MaybeApplyPatchVerified::Body(changes) => {
match apply_patch::apply_patch(session.as_ref(), turn.as_ref(), &call_id, changes)
.await
{
InternalApplyPatchInvocation::Output(item) => {
// Programmatic apply_patch path; return its result.
let content = item?;
return Ok(ToolOutput::Function {
content,
content_items: None,
success: Some(true),
});
}
InternalApplyPatchInvocation::DelegateToExec(apply) => {
let emitter = ToolEmitter::apply_patch(
convert_apply_patch_to_protocol(&apply.action),
!apply.user_explicitly_approved_this_action,
);
let event_ctx = ToolEventCtx::new(
session.as_ref(),
turn.as_ref(),
&call_id,
Some(&tracker),
);
emitter.begin(event_ctx).await;
let req = ApplyPatchRequest {
patch: apply.action.patch.clone(),
cwd: apply.action.cwd.clone(),
timeout_ms: exec_params.timeout_ms,
user_explicitly_approved: apply.user_explicitly_approved_this_action,
llmx_exe: turn.llmx_linux_sandbox_exe.clone(),
};
let mut orchestrator = ToolOrchestrator::new();
let mut runtime = ApplyPatchRuntime::new();
let tool_ctx = ToolCtx {
session: session.as_ref(),
turn: turn.as_ref(),
call_id: call_id.clone(),
tool_name: tool_name.to_string(),
};
let out = orchestrator
.run(&mut runtime, &req, &tool_ctx, &turn, turn.approval_policy)
.await;
let event_ctx = ToolEventCtx::new(
session.as_ref(),
turn.as_ref(),
&call_id,
Some(&tracker),
);
let content = emitter.finish(event_ctx, out).await?;
return Ok(ToolOutput::Function {
content,
content_items: None,
success: Some(true),
});
}
}
}
llmx_apply_patch::MaybeApplyPatchVerified::CorrectnessError(parse_error) => {
return Err(FunctionCallError::RespondToModel(format!(
"apply_patch verification failed: {parse_error}"
)));
}
llmx_apply_patch::MaybeApplyPatchVerified::ShellParseError(error) => {
tracing::trace!("Failed to parse shell command, {error:?}");
// Fall through to regular shell execution.
}
llmx_apply_patch::MaybeApplyPatchVerified::NotApplyPatch => {
// Fall through to regular shell execution.
}
}
// Regular shell execution path.
let emitter = ToolEmitter::shell(
exec_params.command.clone(),
exec_params.cwd.clone(),
is_user_shell_command,
);
let event_ctx = ToolEventCtx::new(session.as_ref(), turn.as_ref(), &call_id, None);
emitter.begin(event_ctx).await;
let req = ShellRequest {
command: exec_params.command.clone(),
cwd: exec_params.cwd.clone(),
timeout_ms: exec_params.timeout_ms,
env: exec_params.env.clone(),
with_escalated_permissions: exec_params.with_escalated_permissions,
justification: exec_params.justification.clone(),
};
let mut orchestrator = ToolOrchestrator::new();
let mut runtime = ShellRuntime::new();
let tool_ctx = ToolCtx {
session: session.as_ref(),
turn: turn.as_ref(),
call_id: call_id.clone(),
tool_name: tool_name.to_string(),
};
let out = orchestrator
.run(&mut runtime, &req, &tool_ctx, &turn, turn.approval_policy)
.await;
let event_ctx = ToolEventCtx::new(session.as_ref(), turn.as_ref(), &call_id, None);
let content = emitter.finish(event_ctx, out).await?;
Ok(ToolOutput::Function {
content,
content_items: None,
success: Some(true),
})
}
}

View File

@@ -0,0 +1,159 @@
use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Duration;
use async_trait::async_trait;
use serde::Deserialize;
use tokio::sync::Barrier;
use tokio::time::sleep;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
pub struct TestSyncHandler;
const DEFAULT_TIMEOUT_MS: u64 = 1_000;
static BARRIERS: OnceLock<tokio::sync::Mutex<HashMap<String, BarrierState>>> = OnceLock::new();
struct BarrierState {
barrier: Arc<Barrier>,
participants: usize,
}
#[derive(Debug, Deserialize)]
struct BarrierArgs {
id: String,
participants: usize,
#[serde(default = "default_timeout_ms")]
timeout_ms: u64,
}
#[derive(Debug, Deserialize)]
struct TestSyncArgs {
#[serde(default)]
sleep_before_ms: Option<u64>,
#[serde(default)]
sleep_after_ms: Option<u64>,
#[serde(default)]
barrier: Option<BarrierArgs>,
}
fn default_timeout_ms() -> u64 {
DEFAULT_TIMEOUT_MS
}
fn barrier_map() -> &'static tokio::sync::Mutex<HashMap<String, BarrierState>> {
BARRIERS.get_or_init(|| tokio::sync::Mutex::new(HashMap::new()))
}
#[async_trait]
impl ToolHandler for TestSyncHandler {
fn kind(&self) -> ToolKind {
ToolKind::Function
}
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation { payload, .. } = invocation;
let arguments = match payload {
ToolPayload::Function { arguments } => arguments,
_ => {
return Err(FunctionCallError::RespondToModel(
"test_sync_tool handler received unsupported payload".to_string(),
));
}
};
let args: TestSyncArgs = serde_json::from_str(&arguments).map_err(|err| {
FunctionCallError::RespondToModel(format!(
"failed to parse function arguments: {err:?}"
))
})?;
if let Some(delay) = args.sleep_before_ms
&& delay > 0
{
sleep(Duration::from_millis(delay)).await;
}
if let Some(barrier) = args.barrier {
wait_on_barrier(barrier).await?;
}
if let Some(delay) = args.sleep_after_ms
&& delay > 0
{
sleep(Duration::from_millis(delay)).await;
}
Ok(ToolOutput::Function {
content: "ok".to_string(),
content_items: None,
success: Some(true),
})
}
}
async fn wait_on_barrier(args: BarrierArgs) -> Result<(), FunctionCallError> {
if args.participants == 0 {
return Err(FunctionCallError::RespondToModel(
"barrier participants must be greater than zero".to_string(),
));
}
if args.timeout_ms == 0 {
return Err(FunctionCallError::RespondToModel(
"barrier timeout must be greater than zero".to_string(),
));
}
let barrier_id = args.id.clone();
let barrier = {
let mut map = barrier_map().lock().await;
match map.entry(barrier_id.clone()) {
Entry::Occupied(entry) => {
let state = entry.get();
if state.participants != args.participants {
let existing = state.participants;
return Err(FunctionCallError::RespondToModel(format!(
"barrier {barrier_id} already registered with {existing} participants"
)));
}
state.barrier.clone()
}
Entry::Vacant(entry) => {
let barrier = Arc::new(Barrier::new(args.participants));
entry.insert(BarrierState {
barrier: barrier.clone(),
participants: args.participants,
});
barrier
}
}
};
let timeout = Duration::from_millis(args.timeout_ms);
let wait_result = tokio::time::timeout(timeout, barrier.wait())
.await
.map_err(|_| {
FunctionCallError::RespondToModel("test_sync_tool barrier wait timed out".to_string())
})?;
if wait_result.is_leader() {
let mut map = barrier_map().lock().await;
if let Some(state) = map.get(&barrier_id)
&& Arc::ptr_eq(&state.barrier, &barrier)
{
map.remove(&barrier_id);
}
}
Ok(())
}

View File

@@ -0,0 +1,19 @@
start: begin_patch hunk+ end_patch
begin_patch: "*** Begin Patch" LF
end_patch: "*** End Patch" LF?
hunk: add_hunk | delete_hunk | update_hunk
add_hunk: "*** Add File: " filename LF add_line+
delete_hunk: "*** Delete File: " filename LF
update_hunk: "*** Update File: " filename LF change_move? change?
filename: /(.+)/
add_line: "+" /(.*)/ LF -> line
change_move: "*** Move to: " filename LF
change: (change_context | change_line)+ eof_line?
change_context: ("@@" | "@@ " /(.+)/) LF
change_line: ("+" | "-" | " ") /(.*)/ LF
eof_line: "*** End of File" LF
%import common.LF

View File

@@ -0,0 +1,209 @@
use std::path::PathBuf;
use async_trait::async_trait;
use serde::Deserialize;
use crate::function_tool::FunctionCallError;
use crate::protocol::EventMsg;
use crate::protocol::ExecCommandOutputDeltaEvent;
use crate::protocol::ExecOutputStream;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::events::ToolEmitter;
use crate::tools::events::ToolEventCtx;
use crate::tools::events::ToolEventStage;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use crate::unified_exec::ExecCommandRequest;
use crate::unified_exec::UnifiedExecContext;
use crate::unified_exec::UnifiedExecResponse;
use crate::unified_exec::UnifiedExecSessionManager;
use crate::unified_exec::WriteStdinRequest;
pub struct UnifiedExecHandler;
#[derive(Debug, Deserialize)]
struct ExecCommandArgs {
cmd: String,
#[serde(default)]
workdir: Option<String>,
#[serde(default = "default_shell")]
shell: String,
#[serde(default = "default_login")]
login: bool,
#[serde(default)]
yield_time_ms: Option<u64>,
#[serde(default)]
max_output_tokens: Option<usize>,
}
#[derive(Debug, Deserialize)]
struct WriteStdinArgs {
session_id: i32,
#[serde(default)]
chars: String,
#[serde(default)]
yield_time_ms: Option<u64>,
#[serde(default)]
max_output_tokens: Option<usize>,
}
fn default_shell() -> String {
"/bin/bash".to_string()
}
fn default_login() -> bool {
true
}
#[async_trait]
impl ToolHandler for UnifiedExecHandler {
fn kind(&self) -> ToolKind {
ToolKind::Function
}
fn matches_kind(&self, payload: &ToolPayload) -> bool {
matches!(
payload,
ToolPayload::Function { .. } | ToolPayload::UnifiedExec { .. }
)
}
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
turn,
call_id,
tool_name,
payload,
..
} = invocation;
let arguments = match payload {
ToolPayload::Function { arguments } => arguments,
ToolPayload::UnifiedExec { arguments } => arguments,
_ => {
return Err(FunctionCallError::RespondToModel(
"unified_exec handler received unsupported payload".to_string(),
));
}
};
let manager: &UnifiedExecSessionManager = &session.services.unified_exec_manager;
let context = UnifiedExecContext::new(session.clone(), turn.clone(), call_id.clone());
let response = match tool_name.as_str() {
"exec_command" => {
let args: ExecCommandArgs = serde_json::from_str(&arguments).map_err(|err| {
FunctionCallError::RespondToModel(format!(
"failed to parse exec_command arguments: {err:?}"
))
})?;
let workdir = args
.workdir
.as_deref()
.filter(|value| !value.is_empty())
.map(PathBuf::from);
let cwd = workdir.clone().unwrap_or_else(|| context.turn.cwd.clone());
let event_ctx = ToolEventCtx::new(
context.session.as_ref(),
context.turn.as_ref(),
&context.call_id,
None,
);
let emitter = ToolEmitter::unified_exec(args.cmd.clone(), cwd.clone(), true);
emitter.emit(event_ctx, ToolEventStage::Begin).await;
manager
.exec_command(
ExecCommandRequest {
command: &args.cmd,
shell: &args.shell,
login: args.login,
yield_time_ms: args.yield_time_ms,
max_output_tokens: args.max_output_tokens,
workdir,
},
&context,
)
.await
.map_err(|err| {
FunctionCallError::RespondToModel(format!("exec_command failed: {err:?}"))
})?
}
"write_stdin" => {
let args: WriteStdinArgs = serde_json::from_str(&arguments).map_err(|err| {
FunctionCallError::RespondToModel(format!(
"failed to parse write_stdin arguments: {err:?}"
))
})?;
manager
.write_stdin(WriteStdinRequest {
session_id: args.session_id,
input: &args.chars,
yield_time_ms: args.yield_time_ms,
max_output_tokens: args.max_output_tokens,
})
.await
.map_err(|err| {
FunctionCallError::RespondToModel(format!("write_stdin failed: {err:?}"))
})?
}
other => {
return Err(FunctionCallError::RespondToModel(format!(
"unsupported unified exec function {other}"
)));
}
};
// Emit a delta event with the chunk of output we just produced, if any.
if !response.output.is_empty() {
let delta = ExecCommandOutputDeltaEvent {
call_id: response.event_call_id.clone(),
stream: ExecOutputStream::Stdout,
chunk: response.output.as_bytes().to_vec(),
};
session
.send_event(turn.as_ref(), EventMsg::ExecCommandOutputDelta(delta))
.await;
}
let content = format_response(&response);
Ok(ToolOutput::Function {
content,
content_items: None,
success: Some(true),
})
}
}
fn format_response(response: &UnifiedExecResponse) -> String {
let mut sections = Vec::new();
if !response.chunk_id.is_empty() {
sections.push(format!("Chunk ID: {}", response.chunk_id));
}
let wall_time_seconds = response.wall_time.as_secs_f64();
sections.push(format!("Wall time: {wall_time_seconds:.4} seconds"));
if let Some(exit_code) = response.exit_code {
sections.push(format!("Process exited with code {exit_code}"));
}
if let Some(session_id) = response.session_id {
sections.push(format!("Process running with session ID {session_id}"));
}
if let Some(original_token_count) = response.original_token_count {
sections.push(format!("Original token count: {original_token_count}"));
}
sections.push("Output:".to_string());
sections.push(response.output.clone());
sections.join("\n")
}

View File

@@ -0,0 +1,92 @@
use async_trait::async_trait;
use serde::Deserialize;
use tokio::fs;
use crate::function_tool::FunctionCallError;
use crate::protocol::EventMsg;
use crate::protocol::ViewImageToolCallEvent;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use llmx_protocol::user_input::UserInput;
pub struct ViewImageHandler;
#[derive(Deserialize)]
struct ViewImageArgs {
path: String,
}
#[async_trait]
impl ToolHandler for ViewImageHandler {
fn kind(&self) -> ToolKind {
ToolKind::Function
}
async fn handle(&self, invocation: ToolInvocation) -> Result<ToolOutput, FunctionCallError> {
let ToolInvocation {
session,
turn,
payload,
call_id,
..
} = invocation;
let arguments = match payload {
ToolPayload::Function { arguments } => arguments,
_ => {
return Err(FunctionCallError::RespondToModel(
"view_image handler received unsupported payload".to_string(),
));
}
};
let args: ViewImageArgs = serde_json::from_str(&arguments).map_err(|e| {
FunctionCallError::RespondToModel(format!("failed to parse function arguments: {e:?}"))
})?;
let abs_path = turn.resolve_path(Some(args.path));
let metadata = fs::metadata(&abs_path).await.map_err(|error| {
FunctionCallError::RespondToModel(format!(
"unable to locate image at `{}`: {error}",
abs_path.display()
))
})?;
if !metadata.is_file() {
return Err(FunctionCallError::RespondToModel(format!(
"image path `{}` is not a file",
abs_path.display()
)));
}
let event_path = abs_path.clone();
session
.inject_input(vec![UserInput::LocalImage { path: abs_path }])
.await
.map_err(|_| {
FunctionCallError::RespondToModel(
"unable to attach image (no active task)".to_string(),
)
})?;
session
.send_event(
turn.as_ref(),
EventMsg::ViewImageToolCall(ViewImageToolCallEvent {
call_id,
path: event_path,
}),
)
.await;
Ok(ToolOutput::Function {
content: "attached local image path".to_string(),
content_items: None,
success: Some(true),
})
}
}

Some files were not shown because too many files have changed in this diff Show More