diff --git a/Cargo.lock b/Cargo.lock index 6a1c4f3d..bcce6097 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -427,7 +427,7 @@ dependencies = [ [[package]] name = "distant" -version = "0.15.0-alpha.14" +version = "0.15.0-alpha.15" dependencies = [ "assert_cmd", "assert_fs", @@ -451,7 +451,7 @@ dependencies = [ [[package]] name = "distant-core" -version = "0.15.0-alpha.14" +version = "0.15.0-alpha.15" dependencies = [ "assert_fs", "bytes", @@ -476,7 +476,7 @@ dependencies = [ [[package]] name = "distant-lua" -version = "0.15.0-alpha.14" +version = "0.15.0-alpha.15" dependencies = [ "distant-core", "distant-ssh2", @@ -510,7 +510,7 @@ dependencies = [ [[package]] name = "distant-ssh2" -version = "0.15.0-alpha.14" +version = "0.15.0-alpha.15" dependencies = [ "assert_cmd", "assert_fs", diff --git a/distant-core/src/client/process.rs b/distant-core/src/client/process.rs index fbb83b13..7752b91b 100644 --- a/distant-core/src/client/process.rs +++ b/distant-core/src/client/process.rs @@ -6,11 +6,15 @@ use crate::{ }; use derive_more::{Display, Error, From}; use log::*; +use std::sync::Arc; use tokio::{ io, - sync::mpsc::{ - self, - error::{TryRecvError, TrySendError}, + sync::{ + mpsc::{ + self, + error::{TryRecvError, TrySendError}, + }, + RwLock, }, task::{JoinError, JoinHandle}, }; @@ -40,12 +44,11 @@ pub struct RemoteProcess { /// Id used to map back to mailbox pub(crate) origin_id: usize, - /// Task that forwards stdin to the remote process by bundling it as stdin requests - req_task: JoinHandle>, + // Sender to abort req task + abort_req_task_tx: mpsc::Sender<()>, - /// Task that reads in new responses, which returns the success and optional - /// exit code once the process has completed - res_task: JoinHandle), RemoteProcessError>>, + // Sender to abort res task + abort_res_task_tx: mpsc::Sender<()>, /// Sender for stdin pub stdin: Option, @@ -58,6 +61,12 @@ pub struct RemoteProcess { /// Sender for kill events kill: mpsc::Sender<()>, + + /// Task that waits for the process to complete + wait_task: JoinHandle<()>, + + /// Handles the success and exit code for a completed process + status: Arc), RemoteProcessError>>>>, } impl RemoteProcess { @@ -125,28 +134,56 @@ impl RemoteProcess { // Used to terminate request task, either explicitly by the process or internally // by the response task when it terminates let (kill_tx, kill_rx) = mpsc::channel(1); + let kill_tx_2 = kill_tx.clone(); // Now we spawn a task to handle future responses that are async // such as ProcStdout, ProcStderr, and ProcDone - let kill_tx_2 = kill_tx.clone(); + let (abort_res_task_tx, mut abort_res_task_rx) = mpsc::channel::<()>(1); let res_task = tokio::spawn(async move { - process_incoming_responses(id, mailbox, stdout_tx, stderr_tx, kill_tx_2).await + tokio::select! { + _ = abort_res_task_rx.recv() => { + panic!("killed"); + } + res = process_incoming_responses(id, mailbox, stdout_tx, stderr_tx, kill_tx_2) => { + res + } + } }); // Spawn a task that takes stdin from our channel and forwards it to the remote process + let (abort_req_task_tx, mut abort_req_task_rx) = mpsc::channel::<()>(1); let req_task = tokio::spawn(async move { - process_outgoing_requests(tenant, id, channel, stdin_rx, kill_rx).await + tokio::select! { + _ = abort_req_task_rx.recv() => { + panic!("killed"); + } + res = process_outgoing_requests(tenant, id, channel, stdin_rx, kill_rx) => { + res + } + } + }); + + let status = Arc::new(RwLock::new(None)); + let status_2 = Arc::clone(&status); + let wait_task = tokio::spawn(async move { + let res = match tokio::try_join!(req_task, res_task) { + Ok((_, res)) => res, + Err(x) => Err(RemoteProcessError::from(x)), + }; + status_2.write().await.replace(res); }); Ok(Self { id, origin_id, - req_task, - res_task, + abort_req_task_tx, + abort_res_task_tx, stdin: Some(RemoteStdin(stdin_tx)), stdout: Some(RemoteStdout(stdout_rx)), stderr: Some(RemoteStderr(stderr_rx)), kill: kill_tx, + wait_task, + status, }) } @@ -155,20 +192,36 @@ impl RemoteProcess { self.id } + /// Checks if the process has completed, returning the exit status if it has, without + /// consuming the process itself. Note that this does not include join errors that can + /// occur when aborting and instead converts any error to a status of false. To acquire + /// the actual error, you must call `wait` + pub async fn status(&self) -> Option<(bool, Option)> { + self.status.read().await.as_ref().map(|x| match x { + Ok((success, exit_code)) => (*success, *exit_code), + Err(_) => (false, None), + }) + } + /// Waits for the process to terminate, returning the success status and an optional exit code pub async fn wait(self) -> Result<(bool, Option), RemoteProcessError> { - match tokio::try_join!(self.req_task, self.res_task) { - Ok((_, res)) => res, - Err(x) => Err(RemoteProcessError::from(x)), - } + // Wait for the process to complete before we try to get the status + let _ = self.wait_task.await; + + // NOTE: If we haven't received an exit status, this lines up with the UnexpectedEof error + self.status + .write() + .await + .take() + .unwrap_or_else(|| Err(RemoteProcessError::UnexpectedEof)) } /// Aborts the process by forcing its response task to shutdown, which means that a call /// to `wait` will return an error. Note that this does **not** send a kill request, so if /// you want to be nice you should send the request before aborting. pub fn abort(&self) { - self.req_task.abort(); - self.res_task.abort(); + let _ = self.abort_req_task_tx.try_send(()); + let _ = self.abort_res_task_tx.try_send(()); } /// Submits a kill request for the running process @@ -352,6 +405,7 @@ mod tests { data::{Error, ErrorKind, Response}, net::{InmemoryStream, PlainCodec, Transport}, }; + use std::time::Duration; fn make_session() -> (Transport, Session) { let (t1, t2) = Transport::make_pair(); @@ -702,6 +756,145 @@ mod tests { assert_eq!(out, "some err"); } + #[tokio::test] + async fn status_should_return_none_if_not_done() { + let (mut transport, session) = make_session(); + + // Create a task for process spawning as we need to handle the request and a response + // in a separate async block + let spawn_task = tokio::spawn(async move { + RemoteProcess::spawn( + String::from("test-tenant"), + session.clone_channel(), + String::from("cmd"), + vec![String::from("arg")], + false, + ) + .await + }); + + // Wait until we get the request from the session + let req = transport.receive::().await.unwrap().unwrap(); + + // Send back a response through the session + let id = 12345; + transport + .send(Response::new( + "test-tenant", + req.id, + vec![ResponseData::ProcStart { id }], + )) + .await + .unwrap(); + + // Receive the process and then check its status + let proc = spawn_task.await.unwrap().unwrap(); + + let result = proc.status().await; + assert_eq!(result, None, "Unexpectedly got proc status: {:?}", result); + } + + #[tokio::test] + async fn status_should_return_false_for_success_if_internal_tasks_fail() { + let (mut transport, session) = make_session(); + + // Create a task for process spawning as we need to handle the request and a response + // in a separate async block + let spawn_task = tokio::spawn(async move { + RemoteProcess::spawn( + String::from("test-tenant"), + session.clone_channel(), + String::from("cmd"), + vec![String::from("arg")], + false, + ) + .await + }); + + // Wait until we get the request from the session + let req = transport.receive::().await.unwrap().unwrap(); + + // Send back a response through the session + let id = 12345; + transport + .send(Response::new( + "test-tenant", + req.id, + vec![ResponseData::ProcStart { id }], + )) + .await + .unwrap(); + + // Receive the process and then abort it to make internal tasks fail + let proc = spawn_task.await.unwrap().unwrap(); + proc.abort(); + + // Wait a bit to ensure the other tasks abort + tokio::time::sleep(Duration::from_millis(100)).await; + + // Peek at the status to confirm the result + let result = proc.status().await; + match result { + Some((false, None)) => {} + x => panic!("Unexpected result: {:?}", x), + } + } + + #[tokio::test] + async fn status_should_return_process_status_when_done() { + let (mut transport, session) = make_session(); + + // Create a task for process spawning as we need to handle the request and a response + // in a separate async block + let spawn_task = tokio::spawn(async move { + RemoteProcess::spawn( + String::from("test-tenant"), + session.clone_channel(), + String::from("cmd"), + vec![String::from("arg")], + false, + ) + .await + }); + + // Wait until we get the request from the session + let req = transport.receive::().await.unwrap().unwrap(); + + // Send back a response through the session + let id = 12345; + transport + .send(Response::new( + "test-tenant", + req.id, + vec![ResponseData::ProcStart { id }], + )) + .await + .unwrap(); + + // Receive the process and then spawn a task for it to complete + let proc = spawn_task.await.unwrap().unwrap(); + + // Send a process completion response to pass along exit status and conclude wait + transport + .send(Response::new( + "test-tenant", + req.id, + vec![ResponseData::ProcDone { + id, + success: true, + code: Some(123), + }], + )) + .await + .unwrap(); + + // Wait a bit to ensure the status gets transmitted + tokio::time::sleep(Duration::from_millis(100)).await; + + // Finally, verify that we complete and get the expected results + assert_eq!(proc.status().await, Some((true, Some(123)))); + } + #[tokio::test] async fn wait_should_return_error_if_internal_tasks_fail() { let (mut transport, session) = make_session(); diff --git a/distant-lua/src/session/proc.rs b/distant-lua/src/session/proc.rs index 0d6a8883..328efb3a 100644 --- a/distant-lua/src/session/proc.rs +++ b/distant-lua/src/session/proc.rs @@ -155,6 +155,19 @@ macro_rules! impl_process { }) } + fn status(id: usize) -> LuaResult> { + runtime::block_on(Self::status_async(id)) + } + + async fn status_async(id: usize) -> LuaResult> { + with_proc_async!($map_name, id, proc -> { + Ok(proc.status().await.map(|(success, exit_code)| Status { + success, + exit_code, + })) + }) + } + fn wait(id: usize) -> LuaResult<(bool, Option)> { runtime::block_on(Self::wait_async(id)) } @@ -238,6 +251,10 @@ macro_rules! impl_process { methods.add_async_method("read_stderr_async", |_, this, ()| { runtime::spawn(Self::read_stderr_async(this.id)) }); + methods.add_method("status", |_, this, ()| Self::status(this.id)); + methods.add_async_method("status_async", |_, this, ()| { + runtime::spawn(Self::status_async(this.id)) + }); methods.add_method("wait", |_, this, ()| Self::wait(this.id)); methods.add_async_method("wait_async", |_, this, ()| { runtime::spawn(Self::wait_async(this.id)) @@ -256,6 +273,29 @@ macro_rules! impl_process { }; } +/// Represents process status +#[derive(Clone, Debug)] +pub struct Status { + pub success: bool, + pub exit_code: Option, +} + +impl UserData for Status { + fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("success", |_, this| Ok(this.success)); + fields.add_field_method_get("exit_code", |_, this| Ok(this.exit_code)); + } + + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_method("to_tbl", |lua, this, ()| { + let tbl = lua.create_table()?; + tbl.set("success", this.success)?; + tbl.set("exit_code", this.exit_code)?; + Ok(tbl) + }); + } +} + /// Represents process output #[derive(Clone, Debug)] pub struct Output {