From a15a707f1d845dce29a55d25a6ce49055eacb009 Mon Sep 17 00:00:00 2001
From: Chip Senkbeil
Date: Wed, 4 Aug 2021 22:47:20 -0500
Subject: [PATCH] Fix processes lingering for unix socket proxy when a
 connection closes

---
 src/cli/subcommand/launch.rs         | 107 +++++++++++++++++++++------
 src/cli/subcommand/listen/handler.rs |  22 ++----
 src/cli/subcommand/listen/mod.rs     |  52 ++-----------
 src/core/mod.rs                      |   1 +
 src/core/state.rs                    |  82 ++++++++++++++++++++
 5 files changed, 179 insertions(+), 85 deletions(-)
 create mode 100644 src/core/state.rs

diff --git a/src/cli/subcommand/launch.rs b/src/cli/subcommand/launch.rs
index b4a7c6fd..70bbd7f5 100644
--- a/src/cli/subcommand/launch.rs
+++ b/src/cli/subcommand/launch.rs
@@ -2,7 +2,7 @@ use crate::{
     cli::opt::{CommonOpt, LaunchSubcommand, Mode, SessionOutput},
     core::{
         constants::CLIENT_BROADCAST_CHANNEL_CAPACITY,
-        data::{Request, Response},
+        data::{Request, RequestPayload, Response, ResponsePayload},
         net::{Client, Transport, TransportReadHalf, TransportWriteHalf},
         session::{Session, SessionFile},
         utils,
@@ -13,11 +13,11 @@ use fork::{daemon, Fork};
 use hex::FromHexError;
 use log::*;
 use orion::errors::UnknownCryptoError;
-use std::{marker::Unpin, path::Path, string::FromUtf8Error};
+use std::{marker::Unpin, path::Path, string::FromUtf8Error, sync::Arc};
 use tokio::{
     io::{self, AsyncRead, AsyncWrite},
     process::Command,
-    sync::{broadcast, mpsc, oneshot},
+    sync::{broadcast, mpsc, oneshot, Mutex},
 };

 #[derive(Debug, Display, Error, From)]
@@ -32,6 +32,12 @@ pub enum Error {
     Utf8Error(FromUtf8Error),
 }

+/// Represents state associated with a connection
+#[derive(Default)]
+struct ConnState {
+    processes: Vec<usize>,
+}
+
 pub fn run(cmd: LaunchSubcommand, opt: CommonOpt) -> Result<(), Error> {
     let rt = tokio::runtime::Runtime::new()?;
     let session_output = cmd.session;
@@ -147,12 +153,16 @@ async fn socket_loop(socket_path: impl AsRef<Path>, session: Session) -> io::Res
     debug!("Binding to unix socket: {:?}", socket_path.as_ref());
     let listener = tokio::net::UnixListener::bind(socket_path)?;

-    while let Ok((conn, addr)) = listener.accept().await {
+    while let Ok((conn, _)) = listener.accept().await {
+        // Create a unique id to associate with the connection since its address
+        // is not guaranteed to have an identifiable string
+        let conn_id: usize = rand::random();
+
         // Establish a proper connection via a handshake, discarding the connection otherwise
         let transport = match Transport::from_handshake(conn, None).await {
             Ok(transport) => transport,
             Err(x) => {
-                error!("<Conn @ {:?}> Failed handshake: {}", addr, x);
+                error!("<Conn {}> Failed handshake: {}", conn_id, x);
                 continue;
             }
         };
@@ -162,19 +172,24 @@ async fn socket_loop(socket_path: impl AsRef<Path>, session: Session) -> io::Res
         // based on the first
         let (tenant_tx, tenant_rx) = oneshot::channel();

+        // Create a state we use to keep track of connection-specific data
+        debug!("<Conn {}> Initializing internal state", conn_id);
+        let state = Arc::new(Mutex::new(ConnState::default()));
+
         // Spawn task to continually receive responses from the client that
         // may or may not be relevant to the connection, which will filter
         // by tenant and then pass along any response that matches
         let res_rx = broadcaster.subscribe();
+        let state_2 = Arc::clone(&state);
         tokio::spawn(async move {
-            handle_conn_outgoing(addr, t_write, tenant_rx, res_rx).await;
+            handle_conn_outgoing(conn_id, state_2, t_write, tenant_rx, res_rx).await;
         });

         // Spawn task to continually read requests from connection and forward
         // them along to be sent via the client
         let req_tx = req_tx.clone();
         tokio::spawn(async move {
-            handle_conn_incoming(t_read, tenant_tx, req_tx).await;
+            handle_conn_incoming(conn_id, state, t_read, tenant_tx, req_tx).await;
         });
     }

@@ -183,6 +198,8 @@ async fn socket_loop(socket_path: impl AsRef<Path>, session: Session) -> io::Res
 /// Conn::Request -> Client::Fire
 async fn handle_conn_incoming<T>(
+    conn_id: usize,
+    state: Arc<Mutex<ConnState>>,
     mut reader: TransportReadHalf<T>,
     tenant_tx: oneshot::Sender<String>,
     req_tx: mpsc::Sender<Request>,
@@ -190,7 +207,7 @@ async fn handle_conn_incoming<T>(
     T: AsyncRead + Unpin,
 {
     macro_rules! process_req {
-        ($on_success:expr) => {
+        ($on_success:expr; $done:expr) => {
             match reader.receive::<Request>().await {
                 Ok(Some(req)) => {
                     $on_success(&req);
@@ -199,34 +216,65 @@ async fn handle_conn_incoming<T>(
                             "Failed to pass along request received on unix socket: {:?}",
                             x
                         );
-                        return;
+                        $done;
                     }
                 }
-                Ok(None) => return,
+                Ok(None) => $done,
                 Err(x) => {
                     error!("Failed to receive request from unix stream: {:?}", x);
-                    return;
+                    $done;
                 }
             }
         };
     }

+    let mut tenant = None;
+
     // NOTE: Have to acquire our first request outside our loop since the oneshot
     // sender of the tenant's name is consuming
-    process_req!(|req: &Request| {
-        if let Err(x) = tenant_tx.send(req.tenant.clone()) {
-            error!("Failed to send along acquired tenant name: {:?}", x);
-            return;
-        }
-    });
+    process_req!(
+        |req: &Request| {
+            tenant = Some(req.tenant.clone());
+            if let Err(x) = tenant_tx.send(req.tenant.clone()) {
+                error!("Failed to send along acquired tenant name: {:?}", x);
+                return;
+            }
+        };
+        return
+    );

     // Loop and process all additional requests
     loop {
-        process_req!(|_| {});
+        process_req!(|_| {}; break);
    }
+
+    // At this point, we have processed at least one request successfully
+    // and should have the tenant populated. If we had a failure at the
+    // beginning, we exit the function early via return.
+    let tenant = tenant.unwrap();
+
+    // Perform cleanup now that the connection has closed
+    for id in state.lock().await.processes.as_slice() {
+        debug!("Cleaning conn {} :: killing process {}", conn_id, id);
+        if let Err(x) = req_tx
+            .send(Request::new(
+                tenant.clone(),
+                RequestPayload::ProcKill { id: *id },
+            ))
+            .await
+        {
+            error!(
+                "<Conn {}> Failed to send kill signal for process {}: {}",
+                conn_id, id, x
+            );
+            break;
+        }
+    }
 }

 async fn handle_conn_outgoing<T>(
-    addr: tokio::net::unix::SocketAddr,
+    conn_id: usize,
+    state: Arc<Mutex<ConnState>>,
     mut writer: TransportWriteHalf<T>,
     tenant_rx: oneshot::Receiver<String>,
     mut res_rx: broadcast::Receiver<Response>,
@@ -238,16 +286,27 @@
     // to implement and yields the same result as we would be dropping
     // all responses before we know the tenant
     if let Ok(tenant) = tenant_rx.await {
-        debug!("Associated tenant {} with conn {:?}", tenant, addr);
+        debug!("Associated tenant {} with conn {}", tenant, conn_id);
         loop {
             match res_rx.recv().await {
                 // Forward along responses that are for our connection
                 Ok(res) if res.tenant == tenant => {
                     debug!(
-                        "Conn {:?} being sent response of type {}",
-                        addr,
+                        "Conn {} being sent response of type {}",
+                        conn_id,
                         res.payload.as_ref()
                     );
+
+                    // If a new process was started, we want to capture the id and
+                    // associate it with the connection
+                    match &res.payload {
+                        ResponsePayload::ProcStart { id } => {
+                            debug!("Tracking proc {} for conn {}", id, conn_id);
+                            state.lock().await.processes.push(*id);
+                        }
+                        _ => {}
+                    }
+
                     if let Err(x) = writer.send(res).await {
                         error!("Failed to send response through unix connection: {}", x);
                         break;
@@ -257,8 +316,8 @@
                 Ok(_) => {}
                 Err(x) => {
                     error!(
-                        "Conn {:?} failed to receive broadcast response: {}",
-                        addr, x
+                        "Conn {} failed to receive broadcast response: {}",
+                        conn_id, x
                     );
                     break;
                 }
diff --git a/src/cli/subcommand/listen/handler.rs b/src/cli/subcommand/listen/handler.rs
index 55b8eff9..5a87d18d 100644
--- a/src/cli/subcommand/listen/handler.rs
+++ b/src/cli/subcommand/listen/handler.rs
@@ -1,7 +1,9 @@
-use super::{Process, State};
-use crate::core::data::{
-    self, DirEntry, FileType, Metadata, Request, RequestPayload, Response, ResponsePayload,
-    RunningProcess,
+use crate::core::{
+    data::{
+        self, DirEntry, FileType, Metadata, Request, RequestPayload, Response, ResponsePayload,
+        RunningProcess,
+    },
+    state::{Process, ServerState},
 };
 use log::*;
 use std::{
@@ -20,7 +22,7 @@ use tokio::{
 use walkdir::WalkDir;

 pub type Reply = mpsc::Sender<Response>;
-type HState = Arc<Mutex<State>>;
+type HState = Arc<Mutex<ServerState<SocketAddr>>>;

 /// Processes the provided request, sending replies using the given sender
 pub(super) async fn process(
@@ -472,15 +474,7 @@ async fn proc_run(
         stdin_tx,
         kill_tx,
     };
-    state.lock().await.processes.insert(id, process);
-
-    state
-        .lock()
-        .await
-        .client_processes
-        .entry(addr)
-        .or_insert(Vec::new())
-        .push(id);
+    state.lock().await.push_process(addr, process);

     Ok(ResponsePayload::ProcStart { id })
 }
diff --git a/src/cli/subcommand/listen/mod.rs b/src/cli/subcommand/listen/mod.rs
index bfb28816..2c80ec95 100644
--- a/src/cli/subcommand/listen/mod.rs
+++ b/src/cli/subcommand/listen/mod.rs
@@ -4,17 +4,18 @@ use crate::{
         data::{Request, Response},
         net::{Transport, TransportReadHalf, TransportWriteHalf},
         session::Session,
+        state::ServerState,
     },
 };
 use derive_more::{Display, Error, From};
 use fork::{daemon, Fork};
 use log::*;
 use orion::aead::SecretKey;
-use std::{collections::HashMap, net::SocketAddr, sync::Arc};
+use std::{net::SocketAddr, sync::Arc};
 use tokio::{
     io,
     net::{tcp, TcpListener},
-    sync::{mpsc, oneshot, Mutex},
+    sync::{mpsc, Mutex},
 };

 mod handler;
@@ -26,49 +27,6 @@ pub enum Error {
     IoError(io::Error),
 }

-/// Holds state relevant to the server
-#[derive(Default)]
-struct State {
-    /// Map of all processes running on the server
-    processes: HashMap<usize, Process>,
-
-    /// List of processes that will be killed when a client drops
-    client_processes: HashMap<SocketAddr, Vec<usize>>,
-}
-
-impl State {
-    /// Cleans up state associated with a particular client
-    pub async fn cleanup_client(&mut self, addr: SocketAddr) {
-        debug!("<Client @ {}> Cleaning up state", addr);
-        if let Some(ids) = self.client_processes.remove(&addr) {
-            for id in ids {
-                if let Some(process) = self.processes.remove(&id) {
-                    trace!(
-                        "<Client @ {}> Requesting proc {} be killed",
-                        addr,
-                        process.id
-                    );
-                    if let Err(_) = process.kill_tx.send(()) {
-                        error!(
-                            "Client {} failed to send process {} kill signal",
-                            id, process.id
-                        );
-                    }
-                }
-            }
-        }
-    }
-}
-
-/// Represents an actively-running process maintained by the server
-struct Process {
-    pub id: usize,
-    pub cmd: String,
-    pub args: Vec<String>,
-    pub stdin_tx: mpsc::Sender<String>,
-    pub kill_tx: oneshot::Sender<()>,
-}
-
 pub fn run(cmd: ListenSubcommand, opt: CommonOpt) -> Result<(), Error> {
     if cmd.daemon {
         // NOTE: We keep the stdin, stdout, stderr open so we can print out the pid with the parent
@@ -130,7 +88,7 @@ async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> R
     }

     // Build our state for the server
-    let state = Arc::new(Mutex::new(State::default()));
+    let state: Arc<Mutex<ServerState<SocketAddr>>> = Arc::new(Mutex::new(ServerState::default()));

     // Wait for a client connection, then spawn a new task to handle
     // receiving data from the client
@@ -171,7 +129,7 @@ async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> R
 /// response loop
 async fn request_loop(
     addr: SocketAddr,
-    state: Arc<Mutex<State>>,
+    state: Arc<Mutex<ServerState<SocketAddr>>>,
     mut transport: TransportReadHalf<tcp::OwnedReadHalf>,
     tx: mpsc::Sender<Response>,
 ) {
diff --git a/src/core/mod.rs b/src/core/mod.rs
index 3f40a62d..b1de7b9e 100644
--- a/src/core/mod.rs
+++ b/src/core/mod.rs
@@ -2,4 +2,5 @@ pub mod constants;
 pub mod data;
 pub mod net;
 pub mod session;
+pub mod state;
 pub mod utils;
diff --git a/src/core/state.rs b/src/core/state.rs
new file mode 100644
index 00000000..9f9504ab
--- /dev/null
+++ b/src/core/state.rs
@@ -0,0 +1,82 @@
+use log::*;
+use std::{collections::HashMap, fmt::Debug, hash::Hash};
+use tokio::sync::{mpsc, oneshot};
+
+/// Holds state related to multiple clients managed by a server
+pub struct ServerState<ClientId>
+where
+    ClientId: Debug + Hash + PartialEq + Eq,
+{
+    /// Map of all processes running on the server
+    pub processes: HashMap<usize, Process>,
+
+    /// List of processes that will be killed when a client drops
+    client_processes: HashMap<ClientId, Vec<usize>>,
+}
+
+impl<ClientId> ServerState<ClientId>
+where
+    ClientId: Debug + Hash + PartialEq + Eq,
+{
+    /// Pushes a new process associated with a client
+    pub fn push_process(&mut self, client_id: ClientId, process: Process) {
+        self.client_processes
+            .entry(client_id)
+            .or_insert(Vec::new())
+            .push(process.id);
+        self.processes.insert(process.id, process);
+    }
+
+    /// Cleans up state associated with a particular client
+    pub async fn cleanup_client(&mut self, client_id: ClientId) {
+        debug!("<Client {:?}> Cleaning up state", client_id);
+        if let Some(ids) = self.client_processes.remove(&client_id) {
+            for id in ids {
+                if let Some(process) = self.processes.remove(&id) {
+                    trace!(
+                        "<Client {:?}> Requesting proc {} be killed",
+                        client_id,
+                        process.id
+                    );
+                    if let Err(_) = process.kill_tx.send(()) {
+                        error!(
+                            "Client {:?} failed to send process {} kill signal",
+                            client_id, process.id
+                        );
+                    }
+                }
+            }
+        }
+    }
+}
+
+impl<ClientId> Default for ServerState<ClientId>
+where
+    ClientId: Debug + Hash + PartialEq + Eq,
+{
+    fn default() -> Self {
+        Self {
+            processes: HashMap::new(),
+            client_processes: HashMap::new(),
+        }
+    }
+}
+
+/// Represents an actively-running process
+pub struct Process {
+    /// Id of the process
+    pub id: usize,
+
+    /// Command used to start the process
+    pub cmd: String,
+
+    /// Arguments associated with the process
+    pub args: Vec<String>,
+
+    /// Transport channel to send new input to the stdin of the process,
+    /// one line at a time
+    pub stdin_tx: mpsc::Sender<String>,
+
+    /// Transport channel to report that the process should be killed
+    pub kill_tx: oneshot::Sender<()>,
+}
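
Note: the sketch below is not part of the patch; it is a minimal, self-contained illustration of the lifecycle the new src/core/state.rs enables, where a process registered via push_process is signaled through its kill_tx oneshot channel once cleanup_client runs for the owning client. The String client id and the main harness are assumptions made for the example only; the real server keys clients by SocketAddr and also tracks cmd, args, and stdin_tx per process.

    // Cargo.toml (assumed): tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync"] }
    use std::collections::HashMap;
    use tokio::sync::oneshot;

    // Trimmed-down stand-ins for the patch's Process/ServerState types
    struct Process {
        id: usize,
        kill_tx: oneshot::Sender<()>,
    }

    #[derive(Default)]
    struct ServerState {
        processes: HashMap<usize, Process>,
        client_processes: HashMap<String, Vec<usize>>,
    }

    impl ServerState {
        // Mirrors ServerState::push_process: record which client owns the
        // process, then index the process itself by its id
        fn push_process(&mut self, client_id: String, process: Process) {
            self.client_processes
                .entry(client_id)
                .or_default()
                .push(process.id);
            self.processes.insert(process.id, process);
        }

        // Mirrors ServerState::cleanup_client: consume each kill_tx to signal
        // every process the dropped client still owns
        fn cleanup_client(&mut self, client_id: &str) {
            for id in self.client_processes.remove(client_id).unwrap_or_default() {
                if let Some(process) = self.processes.remove(&id) {
                    // Ignore the error case where the process already exited
                    let _ = process.kill_tx.send(());
                }
            }
        }
    }

    #[tokio::main]
    async fn main() {
        let (kill_tx, kill_rx) = oneshot::channel();
        let mut state = ServerState::default();
        state.push_process("client-1".to_string(), Process { id: 123, kill_tx });

        // Simulate the client connection dropping
        state.cleanup_client("client-1");

        // The task managing the process observes the kill signal and exits
        assert!(kill_rx.await.is_ok());
        println!("process 123 received its kill signal");
    }

The unix socket proxy side of the patch applies the same idea without direct access to this state: handle_conn_outgoing records process ids as ProcStart responses stream back through the connection, and handle_conn_incoming sends a ProcKill request for each recorded id once the connection's request loop ends.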