Skip to content

Commit

Permalink
Fix processes lingering for unix socket proxy when a connection closes
Browse files Browse the repository at this point in the history
  • Loading branch information
chipsenkbeil committed Aug 5, 2021
1 parent 638638f commit a15a707
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 85 deletions.
107 changes: 83 additions & 24 deletions src/cli/subcommand/launch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
cli::opt::{CommonOpt, LaunchSubcommand, Mode, SessionOutput},
core::{
constants::CLIENT_BROADCAST_CHANNEL_CAPACITY,
data::{Request, Response},
data::{Request, RequestPayload, Response, ResponsePayload},
net::{Client, Transport, TransportReadHalf, TransportWriteHalf},
session::{Session, SessionFile},
utils,
Expand All @@ -13,11 +13,11 @@ use fork::{daemon, Fork};
use hex::FromHexError;
use log::*;
use orion::errors::UnknownCryptoError;
use std::{marker::Unpin, path::Path, string::FromUtf8Error};
use std::{marker::Unpin, path::Path, string::FromUtf8Error, sync::Arc};
use tokio::{
io::{self, AsyncRead, AsyncWrite},
process::Command,
sync::{broadcast, mpsc, oneshot},
sync::{broadcast, mpsc, oneshot, Mutex},
};

#[derive(Debug, Display, Error, From)]
Expand All @@ -32,6 +32,12 @@ pub enum Error {
Utf8Error(FromUtf8Error),
}

/// Represents state associated with a connection
///
/// Shared between the incoming and outgoing handlers of a single unix socket
/// connection so that process ids observed on the outgoing side can be
/// cleaned up by the incoming side once the connection is done.
#[derive(Default)]
struct ConnState {
/// Ids of processes reported via `ResponsePayload::ProcStart` responses
/// forwarded to this connection; a `RequestPayload::ProcKill` request is
/// sent for each of them when the connection's request loop ends
processes: Vec<usize>,
}

pub fn run(cmd: LaunchSubcommand, opt: CommonOpt) -> Result<(), Error> {
let rt = tokio::runtime::Runtime::new()?;
let session_output = cmd.session;
Expand Down Expand Up @@ -147,12 +153,16 @@ async fn socket_loop(socket_path: impl AsRef<Path>, session: Session) -> io::Res
debug!("Binding to unix socket: {:?}", socket_path.as_ref());
let listener = tokio::net::UnixListener::bind(socket_path)?;

while let Ok((conn, addr)) = listener.accept().await {
while let Ok((conn, _)) = listener.accept().await {
// Create a unique id to associate with the connection since its address
// is not guaranteed to have an identifiable string
let conn_id: usize = rand::random();

// Establish a proper connection via a handshake, discarding the connection otherwise
let transport = match Transport::from_handshake(conn, None).await {
Ok(transport) => transport,
Err(x) => {
error!("<Client @ {:?}> Failed handshake: {}", addr, x);
error!("<Client @ {:?}> Failed handshake: {}", conn_id, x);
continue;
}
};
Expand All @@ -162,19 +172,24 @@ async fn socket_loop(socket_path: impl AsRef<Path>, session: Session) -> io::Res
// based on the first
let (tenant_tx, tenant_rx) = oneshot::channel();

// Create a state we use to keep track of connection-specific data
debug!("<Client @ {}> Initializing internal state", conn_id);
let state = Arc::new(Mutex::new(ConnState::default()));

// Spawn task to continually receive responses from the client that
// may or may not be relevant to the connection, which will filter
// by tenant and then pass along any response that matches
let res_rx = broadcaster.subscribe();
let state_2 = Arc::clone(&state);
tokio::spawn(async move {
handle_conn_outgoing(addr, t_write, tenant_rx, res_rx).await;
handle_conn_outgoing(conn_id, state_2, t_write, tenant_rx, res_rx).await;
});

// Spawn task to continually read requests from connection and forward
// them along to be sent via the client
let req_tx = req_tx.clone();
tokio::spawn(async move {
handle_conn_incoming(t_read, tenant_tx, req_tx).await;
handle_conn_incoming(conn_id, state, t_read, tenant_tx, req_tx).await;
});
}

Expand All @@ -183,14 +198,16 @@ async fn socket_loop(socket_path: impl AsRef<Path>, session: Session) -> io::Res

/// Conn::Request -> Client::Fire
async fn handle_conn_incoming<T>(
conn_id: usize,
state: Arc<Mutex<ConnState>>,
mut reader: TransportReadHalf<T>,
tenant_tx: oneshot::Sender<String>,
req_tx: mpsc::Sender<Request>,
) where
T: AsyncRead + Unpin,
{
macro_rules! process_req {
($on_success:expr) => {
($on_success:expr; $done:expr) => {
match reader.receive::<Request>().await {
Ok(Some(req)) => {
$on_success(&req);
Expand All @@ -199,34 +216,65 @@ async fn handle_conn_incoming<T>(
"Failed to pass along request received on unix socket: {:?}",
x
);
return;
$done;
}
}
Ok(None) => return,
Ok(None) => $done,
Err(x) => {
error!("Failed to receive request from unix stream: {:?}", x);
return;
$done;
}
}
};
}

let mut tenant = None;

// NOTE: Have to acquire our first request outside our loop since sending
// the tenant's name through the oneshot sender consumes the sender
process_req!(|req: &Request| {
if let Err(x) = tenant_tx.send(req.tenant.clone()) {
error!("Failed to send along acquired tenant name: {:?}", x);
return;
}
});
process_req!(
|req: &Request| {
tenant = Some(req.tenant.clone());
if let Err(x) = tenant_tx.send(req.tenant.clone()) {
error!("Failed to send along acquired tenant name: {:?}", x);
return;
}
};
return
);

// Loop and process all additional requests
loop {
process_req!(|_| {});
process_req!(|_| {}; break);
}

// At this point, we have processed at least one request successfully
// and should have the tenant populated. If we had a failure at the
// beginning, we exit the function early via return.
let tenant = tenant.unwrap();

// Perform cleanup if done
for id in state.lock().await.processes.as_slice() {
debug!("Cleaning conn {} :: killing process {}", conn_id, id);
if let Err(x) = req_tx
.send(Request::new(
tenant.clone(),
RequestPayload::ProcKill { id: *id },
))
.await
{
error!(
"<Client @ {}> Failed to send kill signal for process {}: {}",
conn_id, id, x
);
break;
}
}
}

async fn handle_conn_outgoing<T>(
addr: tokio::net::unix::SocketAddr,
conn_id: usize,
state: Arc<Mutex<ConnState>>,
mut writer: TransportWriteHalf<T>,
tenant_rx: oneshot::Receiver<String>,
mut res_rx: broadcast::Receiver<Response>,
Expand All @@ -238,16 +286,27 @@ async fn handle_conn_outgoing<T>(
// to implement and yields the same result as we would be dropping
// all responses before we know the tenant
if let Ok(tenant) = tenant_rx.await {
debug!("Associated tenant {} with conn {:?}", tenant, addr);
debug!("Associated tenant {} with conn {}", tenant, conn_id);
loop {
match res_rx.recv().await {
// Forward along responses that are for our connection
Ok(res) if res.tenant == tenant => {
debug!(
"Conn {:?} being sent response of type {}",
addr,
"Conn {} being sent response of type {}",
conn_id,
res.payload.as_ref()
);

// If a new process was started, we want to capture the id and
// associate it with the connection
match &res.payload {
ResponsePayload::ProcStart { id } => {
debug!("Tracking proc {} for conn {}", id, conn_id);
state.lock().await.processes.push(*id);
}
_ => {}
}

if let Err(x) = writer.send(res).await {
error!("Failed to send response through unix connection: {}", x);
break;
Expand All @@ -257,8 +316,8 @@ async fn handle_conn_outgoing<T>(
Ok(_) => {}
Err(x) => {
error!(
"Conn {:?} failed to receive broadcast response: {}",
addr, x
"Conn {} failed to receive broadcast response: {}",
conn_id, x
);
break;
}
Expand Down
22 changes: 8 additions & 14 deletions src/cli/subcommand/listen/handler.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use super::{Process, State};
use crate::core::data::{
self, DirEntry, FileType, Metadata, Request, RequestPayload, Response, ResponsePayload,
RunningProcess,
use crate::core::{
data::{
self, DirEntry, FileType, Metadata, Request, RequestPayload, Response, ResponsePayload,
RunningProcess,
},
state::{Process, ServerState},
};
use log::*;
use std::{
Expand All @@ -20,7 +22,7 @@ use tokio::{
use walkdir::WalkDir;

pub type Reply = mpsc::Sender<Response>;
type HState = Arc<Mutex<State>>;
type HState = Arc<Mutex<ServerState<SocketAddr>>>;

/// Processes the provided request, sending replies using the given sender
pub(super) async fn process(
Expand Down Expand Up @@ -472,15 +474,7 @@ async fn proc_run(
stdin_tx,
kill_tx,
};
state.lock().await.processes.insert(id, process);

state
.lock()
.await
.client_processes
.entry(addr)
.or_insert(Vec::new())
.push(id);
state.lock().await.push_process(addr, process);

Ok(ResponsePayload::ProcStart { id })
}
Expand Down
52 changes: 5 additions & 47 deletions src/cli/subcommand/listen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@ use crate::{
data::{Request, Response},
net::{Transport, TransportReadHalf, TransportWriteHalf},
session::Session,
state::ServerState,
},
};
use derive_more::{Display, Error, From};
use fork::{daemon, Fork};
use log::*;
use orion::aead::SecretKey;
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
use std::{net::SocketAddr, sync::Arc};
use tokio::{
io,
net::{tcp, TcpListener},
sync::{mpsc, oneshot, Mutex},
sync::{mpsc, Mutex},
};

mod handler;
Expand All @@ -26,49 +27,6 @@ pub enum Error {
IoError(io::Error),
}

/// Holds state relevant to the server
#[derive(Default)]
struct State {
/// Map of all processes running on the server
processes: HashMap<usize, Process>,

/// List of processes that will be killed when a client drops
client_processes: HashMap<SocketAddr, Vec<usize>>,
}

impl State {
/// Cleans up state associated with a particular client
pub async fn cleanup_client(&mut self, addr: SocketAddr) {
debug!("<Client @ {}> Cleaning up state", addr);
if let Some(ids) = self.client_processes.remove(&addr) {
for id in ids {
if let Some(process) = self.processes.remove(&id) {
trace!(
"<Client @ {}> Requesting proc {} be killed",
addr,
process.id
);
if let Err(_) = process.kill_tx.send(()) {
error!(
"Client {} failed to send process {} kill signal",
id, process.id
);
}
}
}
}
}
}

/// Represents an actively-running process maintained by the server
struct Process {
pub id: usize,
pub cmd: String,
pub args: Vec<String>,
pub stdin_tx: mpsc::Sender<String>,
pub kill_tx: oneshot::Sender<()>,
}

pub fn run(cmd: ListenSubcommand, opt: CommonOpt) -> Result<(), Error> {
if cmd.daemon {
// NOTE: We keep the stdin, stdout, stderr open so we can print out the pid with the parent
Expand Down Expand Up @@ -130,7 +88,7 @@ async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> R
}

// Build our state for the server
let state = Arc::new(Mutex::new(State::default()));
let state: Arc<Mutex<ServerState<SocketAddr>>> = Arc::new(Mutex::new(ServerState::default()));

// Wait for a client connection, then spawn a new task to handle
// receiving data from the client
Expand Down Expand Up @@ -171,7 +129,7 @@ async fn run_async(cmd: ListenSubcommand, _opt: CommonOpt, is_forked: bool) -> R
/// response loop
async fn request_loop(
addr: SocketAddr,
state: Arc<Mutex<State>>,
state: Arc<Mutex<ServerState<SocketAddr>>>,
mut transport: TransportReadHalf<tcp::OwnedReadHalf>,
tx: mpsc::Sender<Response>,
) {
Expand Down
1 change: 1 addition & 0 deletions src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ pub mod constants;
pub mod data;
pub mod net;
pub mod session;
pub mod state;
pub mod utils;
Loading

0 comments on commit a15a707

Please sign in to comment.