Skip to content

Commit

Permalink
Merge pull request #5 from b123400/rust-rewrite
Browse files Browse the repository at this point in the history
Options to connect with certs and key
  • Loading branch information
akiroz authored Jan 3, 2024
2 parents 1cc28c4 + 704e202 commit ef47dcd
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 48 deletions.
79 changes: 73 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@ ipnetwork = "0.20.0"
log = "0.4.20"
lru = "0.12.0"
rand = "0.8.5"
rumqttc = "0.23.0"
rumqttc = { version = "0.23.0", features = ["use-rustls"] }
rustls = "0.21.8"
rustls-pemfile = "2.0.0"
serde = { version = "1.0.193", features = ["derive"] }
serde_json = "1.0.108"
tokio = { version = "1.34.0", features = ["full"] }
tokio-rustls = "0.25.0"
tokio-util = "0.7.10"
tun = { version = "0.6.1", features = ["async"] }
57 changes: 50 additions & 7 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use core::time::Duration;
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use rustls::{Certificate, PrivateKey, RootCertStore};
use serde::Deserialize;
use std::fs;
use std::fs::{read_to_string, File};
use std::io::BufReader;
use std::net::Ipv4Addr;

#[derive(Deserialize, Clone, Debug)]
Expand All @@ -13,8 +15,6 @@ pub struct MqttOptions {
pub password: Option<String>,

pub ca_file: Option<String>,

pub tls_insecure: Option<bool>,
pub key_file: Option<String>,
pub cert_file: Option<String>,

Expand Down Expand Up @@ -76,7 +76,7 @@ pub enum ConfigError {
}

pub fn read_from_default_location() -> Result<Config, ConfigError> {
let text = fs::read_to_string("zika_config.json").map_err(ConfigError::IOError)?;
let text = read_to_string("zika_config.json").map_err(ConfigError::IOError)?;
let deserialized: Config = serde_json::from_str(&text).map_err(ConfigError::DecodeError)?;
return Ok(deserialized);
}
Expand All @@ -92,8 +92,6 @@ impl MqttOptions {
password: another_option.password.or(self.password.clone()),

ca_file: another_option.ca_file.or(self.ca_file.clone()),

tls_insecure: another_option.tls_insecure.or(self.tls_insecure),
key_file: another_option.key_file.or(self.key_file.clone()),
cert_file: another_option.cert_file.or(self.cert_file.clone()),

Expand All @@ -119,14 +117,59 @@ impl MqttBroker {
(None, o) => o,
(o, None) => o.clone(),
};
// TODO: more options

if let Some(opts) = mqtt_options {
options.set_keep_alive(Duration::new(opts.keepalive_interval.unwrap_or(60), 0));
options.set_topic_alias_max(opts.topic_alias_max);

if let (Some(u), Some(p)) = (&opts.username, &opts.password) {
options.set_credentials(u, p);
};

if opts.ca_file.is_some() || (opts.cert_file.is_some() && opts.key_file.is_some()) {
let mut root_cert_store = RootCertStore::empty();
if let Some(ca_file_path) = opts.ca_file {
let ca_file = File::open(ca_file_path).unwrap();
let mut ca_buf_read = BufReader::new(ca_file);
let cas = rustls_pemfile::certs(&mut ca_buf_read);
for ca in cas.into_iter() {
root_cert_store
.add(&Certificate(ca.unwrap().to_vec()))
.unwrap();
}
}

let tls_config = match (opts.cert_file, opts.key_file) {
(Some(cert_file_path), Some(key_file_path)) => {
let cert_file = File::open(cert_file_path).unwrap();
let key_file = File::open(key_file_path).unwrap();
let mut cert_buf_read = BufReader::new(cert_file);
let mut key_buf_read = BufReader::new(key_file);

let certs = rustls_pemfile::certs(&mut cert_buf_read);
let certs_vec = certs
.into_iter()
.map(|c| Certificate(c.unwrap().to_vec()))
.collect();

let mut keys = rustls_pemfile::ec_private_keys(&mut key_buf_read);
let first_key = keys.next().unwrap();
let key = PrivateKey(first_key.unwrap().secret_sec1_der().to_vec());

rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_cert_store)
.with_client_auth_cert(certs_vec, key)
.unwrap()
}
_ => rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_cert_store)
.with_no_client_auth(),
};

options.set_transport(rumqttc::Transport::tls_with_config(tls_config.into()));
}
};
return options;
}
Expand Down
32 changes: 1 addition & 31 deletions src/remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ struct RemoteIncomingContext {
nth: usize,
sender: mpsc::Sender<(String, Bytes)>,
subs: Arc<Mutex<Vec<String>>>,
alias_pool: Option<LookupPool<Bytes, u16, Range<u16>>>, // alias they created
}

struct RemoteClient {
Expand Down Expand Up @@ -63,7 +62,6 @@ impl Remote {
nth: idx,
sender: sender.clone(),
subs: subs.clone(),
alias_pool: None, // TODO: Don't send with topic alias by default, until getting topic_alias_max from remote
};
task::spawn(async move {
loop {
Expand All @@ -90,19 +88,9 @@ impl Remote {
match pkt {
Packet::ConnAck(ConnAck {
code: Success,
properties: Some(prop),
properties: _,
session_present,
}) => {
if let Some(alias_max) = prop.topic_alias_max.filter(|n| *n > 0) {
let range = 1..alias_max;
if let Some(ref mut alias_pool) = context.alias_pool {
alias_pool.resize(range);
} else {
context.alias_pool = Some(LookupPool::new(range));
}
} else {
context.alias_pool = None
}
if !session_present {
log::info!("broker[{}] !session_present", context.nth);
let subs_v = context.subs.lock().await;
Expand Down Expand Up @@ -131,26 +119,8 @@ impl Remote {
let topic_str = String::from_utf8(topic.to_vec())
.ok()
.filter(|n| n.len() > 0);
if let (Some(alias), Some(_)) = (topic_alias, &topic_str) {
if let Some(ref mut pool) = context.alias_pool {
pool.insert_forward(&topic, alias);
}
}
if let Some(topic) = topic_str {
_ = context.sender.send((topic, payload)).await; // What if it's not ok?
} else if let Some(alias) = topic_alias {
// No topic but we have alias
if let Some(ref mut pool) = context.alias_pool {
if let Some(t) = pool
.get_reverse(&alias)
.and_then(|t| String::from_utf8(t.to_vec()).ok())
{
log::debug!("Received message, Alias: {:?}, topic: {:?}", alias, t);
_ = context.sender.send((t, payload)).await;
} else {
log::error!("Cannot find topic for alias {:?}", alias);
}
}
} else {
log::debug!("drop packet, non utf8 topic: {:?}", topic);
}
Expand Down
5 changes: 2 additions & 3 deletions zika_config.sample.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
"ca_file": "ca.pem",
"key_file": "key.pem",
"cert_file": "cert.pem",
"tls_insecure": false,

"// Unit: seconds (int), Optonal: defaults shown": "",
"keepalive_interval": 60,

"// Setting to 0 disable topic alias": "",
"topic_alias_max": 5,
"topic_alias_max": 5
},

"// Config can be overwritten on broker level": "",
Expand All @@ -32,7 +31,7 @@
"// In bytes (Optional, default shown, must match client)": "",
"id_length": 4,
"topic": "zika/OjFcZWEAGy2E3Vkh",
"bind_cidr": "172.20.0.0/24",
"bind_cidr": "172.20.0.0/24"
},

"client": {
Expand Down

0 comments on commit ef47dcd

Please sign in to comment.