Skip to content

Commit

Permalink
Add status method to RemoteProcess and lua module equivalent
Browse files Browse the repository at this point in the history
  • Loading branch information
chipsenkbeil committed Oct 14, 2021
1 parent f021869 commit a8b6f3e
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 23 deletions.
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

231 changes: 212 additions & 19 deletions distant-core/src/client/process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@ use crate::{
};
use derive_more::{Display, Error, From};
use log::*;
use std::sync::Arc;
use tokio::{
io,
sync::mpsc::{
self,
error::{TryRecvError, TrySendError},
sync::{
mpsc::{
self,
error::{TryRecvError, TrySendError},
},
RwLock,
},
task::{JoinError, JoinHandle},
};
Expand Down Expand Up @@ -40,12 +44,11 @@ pub struct RemoteProcess {
/// Id used to map back to mailbox
pub(crate) origin_id: usize,

/// Task that forwards stdin to the remote process by bundling it as stdin requests
req_task: JoinHandle<Result<(), RemoteProcessError>>,
// Sender to abort req task
abort_req_task_tx: mpsc::Sender<()>,

/// Task that reads in new responses, which returns the success and optional
/// exit code once the process has completed
res_task: JoinHandle<Result<(bool, Option<i32>), RemoteProcessError>>,
// Sender to abort res task
abort_res_task_tx: mpsc::Sender<()>,

/// Sender for stdin
pub stdin: Option<RemoteStdin>,
Expand All @@ -58,6 +61,12 @@ pub struct RemoteProcess {

/// Sender for kill events
kill: mpsc::Sender<()>,

/// Task that waits for the process to complete
wait_task: JoinHandle<()>,

/// Handles the success and exit code for a completed process
status: Arc<RwLock<Option<Result<(bool, Option<i32>), RemoteProcessError>>>>,
}

impl RemoteProcess {
Expand Down Expand Up @@ -125,28 +134,56 @@ impl RemoteProcess {
// Used to terminate request task, either explicitly by the process or internally
// by the response task when it terminates
let (kill_tx, kill_rx) = mpsc::channel(1);
let kill_tx_2 = kill_tx.clone();

// Now we spawn a task to handle future responses that are async
// such as ProcStdout, ProcStderr, and ProcDone
let kill_tx_2 = kill_tx.clone();
let (abort_res_task_tx, mut abort_res_task_rx) = mpsc::channel::<()>(1);
let res_task = tokio::spawn(async move {
process_incoming_responses(id, mailbox, stdout_tx, stderr_tx, kill_tx_2).await
tokio::select! {
_ = abort_res_task_rx.recv() => {
panic!("killed");
}
res = process_incoming_responses(id, mailbox, stdout_tx, stderr_tx, kill_tx_2) => {
res
}
}
});

// Spawn a task that takes stdin from our channel and forwards it to the remote process
let (abort_req_task_tx, mut abort_req_task_rx) = mpsc::channel::<()>(1);
let req_task = tokio::spawn(async move {
process_outgoing_requests(tenant, id, channel, stdin_rx, kill_rx).await
tokio::select! {
_ = abort_req_task_rx.recv() => {
panic!("killed");
}
res = process_outgoing_requests(tenant, id, channel, stdin_rx, kill_rx) => {
res
}
}
});

let status = Arc::new(RwLock::new(None));
let status_2 = Arc::clone(&status);
let wait_task = tokio::spawn(async move {
let res = match tokio::try_join!(req_task, res_task) {
Ok((_, res)) => res,
Err(x) => Err(RemoteProcessError::from(x)),
};
status_2.write().await.replace(res);
});

Ok(Self {
id,
origin_id,
req_task,
res_task,
abort_req_task_tx,
abort_res_task_tx,
stdin: Some(RemoteStdin(stdin_tx)),
stdout: Some(RemoteStdout(stdout_rx)),
stderr: Some(RemoteStderr(stderr_rx)),
kill: kill_tx,
wait_task,
status,
})
}

Expand All @@ -155,20 +192,36 @@ impl RemoteProcess {
self.id
}

/// Checks if the process has completed, returning the exit status if it has, without
/// consuming the process itself. Note that this does not include join errors that can
/// occur when aborting and instead converts any error to a status of false. To acquire
/// the actual error, you must call `wait`
pub async fn status(&self) -> Option<(bool, Option<i32>)> {
self.status.read().await.as_ref().map(|x| match x {
Ok((success, exit_code)) => (*success, *exit_code),
Err(_) => (false, None),
})
}

/// Waits for the process to terminate, returning the success status and an optional exit code
pub async fn wait(self) -> Result<(bool, Option<i32>), RemoteProcessError> {
match tokio::try_join!(self.req_task, self.res_task) {
Ok((_, res)) => res,
Err(x) => Err(RemoteProcessError::from(x)),
}
// Wait for the process to complete before we try to get the status
let _ = self.wait_task.await;

// NOTE: If we haven't received an exit status, this lines up with the UnexpectedEof error
self.status
.write()
.await
.take()
.unwrap_or_else(|| Err(RemoteProcessError::UnexpectedEof))
}

/// Aborts the process by forcing its response task to shutdown, which means that a call
/// to `wait` will return an error. Note that this does **not** send a kill request, so if
/// you want to be nice you should send the request before aborting.
pub fn abort(&self) {
self.req_task.abort();
self.res_task.abort();
let _ = self.abort_req_task_tx.try_send(());
let _ = self.abort_res_task_tx.try_send(());
}

/// Submits a kill request for the running process
Expand Down Expand Up @@ -352,6 +405,7 @@ mod tests {
data::{Error, ErrorKind, Response},
net::{InmemoryStream, PlainCodec, Transport},
};
use std::time::Duration;

fn make_session() -> (Transport<InmemoryStream, PlainCodec>, Session) {
let (t1, t2) = Transport::make_pair();
Expand Down Expand Up @@ -702,6 +756,145 @@ mod tests {
assert_eq!(out, "some err");
}

#[tokio::test]
async fn status_should_return_none_if_not_done() {
let (mut transport, session) = make_session();

// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(async move {
RemoteProcess::spawn(
String::from("test-tenant"),
session.clone_channel(),
String::from("cmd"),
vec![String::from("arg")],
false,
)
.await
});

// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();

// Send back a response through the session
let id = 12345;
transport
.send(Response::new(
"test-tenant",
req.id,
vec![ResponseData::ProcStart { id }],
))
.await
.unwrap();

// Receive the process and then check its status
let proc = spawn_task.await.unwrap().unwrap();

let result = proc.status().await;
assert_eq!(result, None, "Unexpectedly got proc status: {:?}", result);
}

#[tokio::test]
async fn status_should_return_false_for_success_if_internal_tasks_fail() {
let (mut transport, session) = make_session();

// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(async move {
RemoteProcess::spawn(
String::from("test-tenant"),
session.clone_channel(),
String::from("cmd"),
vec![String::from("arg")],
false,
)
.await
});

// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();

// Send back a response through the session
let id = 12345;
transport
.send(Response::new(
"test-tenant",
req.id,
vec![ResponseData::ProcStart { id }],
))
.await
.unwrap();

// Receive the process and then abort it to make internal tasks fail
let proc = spawn_task.await.unwrap().unwrap();
proc.abort();

// Wait a bit to ensure the other tasks abort
tokio::time::sleep(Duration::from_millis(100)).await;

// Peek at the status to confirm the result
let result = proc.status().await;
match result {
Some((false, None)) => {}
x => panic!("Unexpected result: {:?}", x),
}
}

#[tokio::test]
async fn status_should_return_process_status_when_done() {
let (mut transport, session) = make_session();

// Create a task for process spawning as we need to handle the request and a response
// in a separate async block
let spawn_task = tokio::spawn(async move {
RemoteProcess::spawn(
String::from("test-tenant"),
session.clone_channel(),
String::from("cmd"),
vec![String::from("arg")],
false,
)
.await
});

// Wait until we get the request from the session
let req = transport.receive::<Request>().await.unwrap().unwrap();

// Send back a response through the session
let id = 12345;
transport
.send(Response::new(
"test-tenant",
req.id,
vec![ResponseData::ProcStart { id }],
))
.await
.unwrap();

// Receive the process and then spawn a task for it to complete
let proc = spawn_task.await.unwrap().unwrap();

// Send a process completion response to pass along exit status and conclude wait
transport
.send(Response::new(
"test-tenant",
req.id,
vec![ResponseData::ProcDone {
id,
success: true,
code: Some(123),
}],
))
.await
.unwrap();

// Wait a bit to ensure the status gets transmitted
tokio::time::sleep(Duration::from_millis(100)).await;

// Finally, verify that we complete and get the expected results
assert_eq!(proc.status().await, Some((true, Some(123))));
}

#[tokio::test]
async fn wait_should_return_error_if_internal_tasks_fail() {
let (mut transport, session) = make_session();
Expand Down
Loading

0 comments on commit a8b6f3e

Please sign in to comment.