diff --git a/src/main.rs b/src/main.rs index e1df3bc..6f4b4a1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,8 @@ mod utils; use clap::{App, Arg}; use utils::split_and_copy_binary_file; +use std::env; +use std::process; fn main() { let matches = App::new("Zap — Fast single file copy") @@ -14,30 +16,17 @@ fn main() { .help("The input file path") .required(true) .index(1)) + .arg(Arg::new("user_host_path") + .help("Specifies user@host:remote_path") + .required(true) + .value_name("user@host:remote_path") + .index(2)) .arg(Arg::new("streams") .short('c') .long("streams") .help("The number of parallel streams") .default_value("20") .takes_value(true)) - .arg(Arg::new("user") - .short('u') - .long("user") - .help("The username for the remote server") - .takes_value(true) - .required(true)) - .arg(Arg::new("server") - .short('s') - .long("server") - .help("The hostname of the remote server") - .takes_value(true) - .required(true)) - .arg(Arg::new("remote_path") - .short('p') - .long("remote-path") - .help("The remote path where streams will be stored") - .takes_value(true) - .required(true)) .arg(Arg::new("ssh_key_path") .short('i') .long("ssh-key-path") @@ -52,11 +41,36 @@ fn main() { .get_matches(); let input_file_path = matches.value_of("input_file").unwrap(); + + let user_host_path = matches.value_of("user_host_path").unwrap(); + let parts: Vec<&str> = user_host_path.splitn(2, ':').collect(); + let user_host = parts[0]; + + // Use remote ssh CWD "." if no path provided after user@host: + let remote_path = if let Some(path) = parts.get(1) { + if path.is_empty() { "." } else { path } + } else { + "." + }; + + let host_parts: Vec<&str> = user_host.split('@').collect(); + let (remote_user, remote_host) = match host_parts.as_slice() { + [user, host] => (user.to_string(), host.to_string()), + [host] => { + let user_env = env::var("USER").unwrap_or_else(|_| { + eprintln!("$USER environment variable is not set"); + process::exit(1); + }); + (user_env, host.to_string()) + }, + _ => { + eprintln!("Invalid format for user@hostname"); + process::exit(1); + } + }; + let num_streams: usize = matches.value_of("streams").unwrap().parse() .expect("num_streams must be an integer"); - let remote_user = matches.value_of("user").unwrap(); - let remote_host = matches.value_of("server").unwrap(); - let remote_path = matches.value_of("remote_path").unwrap(); let ssh_key_path = matches.value_of("ssh_key_path"); let retries: u32 = matches.value_of("retries").unwrap().parse() .expect("retries must be an integer"); @@ -65,11 +79,12 @@ fn main() { split_and_copy_binary_file( input_file_path, num_streams, - remote_user, - remote_host, - remote_path, + &remote_user, + &remote_host, + remote_path, ssh_key_path.as_deref(), max_threads, retries ); } + diff --git a/src/ssh_comm.rs b/src/ssh_comm.rs index 4aa013c..2452833 100644 --- a/src/ssh_comm.rs +++ b/src/ssh_comm.rs @@ -17,13 +17,12 @@ pub fn stream_stream_to_remote( remote_host: &str, remote_path: &str, ssh_key_path: Option<&str>, - retries: u32, // Added retries as a parameter + retries: u32, ) -> Result<(), String> { let mut attempt = 0; while attempt <= retries { let user_host = format!("{}@{}", remote_user, remote_host); let stream_command = format!("cat > {}/stream_{}.bin", remote_path, stream_num); - let mut ssh_args = vec![ "-o", "StrictHostKeyChecking=no", &user_host,