diff --git a/codex-rs/login/src/server.rs b/codex-rs/login/src/server.rs index bca3b164..cfd1db8b 100644 --- a/codex-rs/login/src/server.rs +++ b/codex-rs/login/src/server.rs @@ -1,9 +1,14 @@ use std::io::Cursor; +use std::io::Read; +use std::io::Write; use std::io::{self}; +use std::net::SocketAddr; +use std::net::TcpStream; use std::path::Path; use std::path::PathBuf; use std::sync::Arc; use std::thread; +use std::time::Duration; use crate::pkce::PkceCodes; use crate::pkce::generate_pkce; @@ -85,7 +90,7 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result { let pkce = generate_pkce(); let state = opts.force_state.clone().unwrap_or_else(generate_state); - let server = Server::http(format!("127.0.0.1:{}", opts.port)).map_err(io::Error::other)?; + let server = bind_server(opts.port)?; let actual_port = match server.server_addr().to_ip() { Some(addr) => addr.port(), None => { @@ -145,19 +150,24 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result { let response = process_request(&url_raw, &opts, &redirect_uri, &pkce, actual_port, &state).await; - let is_login_complete = matches!(response, HandledRequest::ResponseAndExit(_)); - match response { - HandledRequest::Response(r) | HandledRequest::ResponseAndExit(r) => { - let _ = tokio::task::spawn_blocking(move || req.respond(r)).await; + let exit_result = match response { + HandledRequest::Response(response) => { + let _ = tokio::task::spawn_blocking(move || req.respond(response)).await; + None + } + HandledRequest::ResponseAndExit { response, result } => { + let _ = tokio::task::spawn_blocking(move || req.respond(response)).await; + Some(result) } HandledRequest::RedirectWithHeader(header) => { let redirect = Response::empty(302).with_header(header); let _ = tokio::task::spawn_blocking(move || req.respond(redirect)).await; + None } - } + }; - if is_login_complete { - break Ok(()); + if let Some(result) = exit_result { + break result; } } } @@ -181,7 +191,10 @@ pub fn run_login_server(opts: ServerOptions) -> io::Result { enum HandledRequest { Response(Response>>), RedirectWithHeader(Header), - ResponseAndExit(Response>>), + ResponseAndExit { + response: Response>>, + result: io::Result<()>, + }, } async fn process_request( @@ -276,8 +289,18 @@ async fn process_request( ) { resp.add_header(h); } - HandledRequest::ResponseAndExit(resp) + HandledRequest::ResponseAndExit { + response: resp, + result: Ok(()), + } } + "/cancel" => HandledRequest::ResponseAndExit { + response: Response::from_string("Login cancelled"), + result: Err(io::Error::new( + io::ErrorKind::Interrupted, + "Login cancelled", + )), + }, _ => HandledRequest::Response(Response::from_string("Not Found").with_status_code(404)), } } @@ -316,6 +339,68 @@ fn generate_state() -> String { base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes) } +fn send_cancel_request(port: u16) -> io::Result<()> { + let addr: SocketAddr = format!("127.0.0.1:{port}") + .parse() + .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?; + let mut stream = TcpStream::connect_timeout(&addr, Duration::from_secs(2))?; + stream.set_read_timeout(Some(Duration::from_secs(2)))?; + stream.set_write_timeout(Some(Duration::from_secs(2)))?; + + stream.write_all(b"GET /cancel HTTP/1.1\r\n")?; + stream.write_all(format!("Host: 127.0.0.1:{port}\r\n").as_bytes())?; + stream.write_all(b"Connection: close\r\n\r\n")?; + + let mut buf = [0u8; 64]; + let _ = stream.read(&mut buf); + Ok(()) +} + +fn bind_server(port: u16) -> io::Result { + let bind_address = format!("127.0.0.1:{port}"); + let mut cancel_attempted = false; + let mut attempts = 0; + const MAX_ATTEMPTS: u32 = 10; + const RETRY_DELAY: Duration = Duration::from_millis(200); + + loop { + match Server::http(&bind_address) { + Ok(server) => return Ok(server), + Err(err) => { + attempts += 1; + let is_addr_in_use = err + .downcast_ref::() + .map(|io_err| io_err.kind() == io::ErrorKind::AddrInUse) + .unwrap_or(false); + + // If the address is in use, there is probably another instance of the login server + // running. Attempt to cancel it and retry. + if is_addr_in_use { + if !cancel_attempted { + cancel_attempted = true; + if let Err(cancel_err) = send_cancel_request(port) { + eprintln!("Failed to cancel previous login server: {cancel_err}"); + } + } + + thread::sleep(RETRY_DELAY); + + if attempts >= MAX_ATTEMPTS { + return Err(io::Error::new( + io::ErrorKind::AddrInUse, + format!("Port {bind_address} is already in use"), + )); + } + + continue; + } + + return Err(io::Error::other(err)); + } + } + } +} + struct ExchangedTokens { id_token: String, access_token: String, diff --git a/codex-rs/login/tests/suite/login_server_e2e.rs b/codex-rs/login/tests/suite/login_server_e2e.rs index ef6b80fb..5a600a4b 100644 --- a/codex-rs/login/tests/suite/login_server_e2e.rs +++ b/codex-rs/login/tests/suite/login_server_e2e.rs @@ -1,7 +1,9 @@ #![allow(clippy::unwrap_used)] +use std::io; use std::net::SocketAddr; use std::net::TcpListener; use std::thread; +use std::time::Duration; use base64::Engine; use codex_login::ServerOptions; @@ -177,3 +179,67 @@ async fn creates_missing_codex_home_dir() { "auth.json should be created even if parent dir was missing" ); } + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn cancels_previous_login_server_when_port_is_in_use() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + println!( + "Skipping test because it cannot execute when network is disabled in a Codex sandbox." + ); + return; + } + + let (issuer_addr, _issuer_handle) = start_mock_issuer(); + let issuer = format!("http://{}:{}", issuer_addr.ip(), issuer_addr.port()); + + let first_tmp = tempdir().unwrap(); + let first_codex_home = first_tmp.path().to_path_buf(); + + let first_opts = ServerOptions { + codex_home: first_codex_home, + client_id: codex_login::CLIENT_ID.to_string(), + issuer: issuer.clone(), + port: 0, + open_browser: false, + force_state: Some("cancel_state".to_string()), + originator: "test_originator".to_string(), + }; + + let first_server = run_login_server(first_opts).unwrap(); + let login_port = first_server.actual_port; + let first_server_task = tokio::spawn(async move { first_server.block_until_done().await }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let second_tmp = tempdir().unwrap(); + let second_codex_home = second_tmp.path().to_path_buf(); + + let second_opts = ServerOptions { + codex_home: second_codex_home, + client_id: codex_login::CLIENT_ID.to_string(), + issuer, + port: login_port, + open_browser: false, + force_state: Some("cancel_state_2".to_string()), + originator: "test_originator".to_string(), + }; + + let second_server = run_login_server(second_opts).unwrap(); + assert_eq!(second_server.actual_port, login_port); + + let cancel_result = first_server_task + .await + .expect("first login server task panicked") + .expect_err("login server should report cancellation"); + assert_eq!(cancel_result.kind(), io::ErrorKind::Interrupted); + + let client = reqwest::Client::new(); + let cancel_url = format!("http://127.0.0.1:{login_port}/cancel"); + let resp = client.get(cancel_url).send().await.unwrap(); + assert!(resp.status().is_success()); + + second_server + .block_until_done() + .await + .expect_err("second login server should report cancellation"); +}