Skip to content

Commit

Permalink
feat(noise): Add WebTransport certhashes extension
Browse files Browse the repository at this point in the history
  • Loading branch information
oblique committed May 26, 2023
1 parent d4c4078 commit 7965a1e
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 22 deletions.
8 changes: 6 additions & 2 deletions transports/noise/src/generated/payload.proto
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
syntax = "proto3";

package payload.proto;

// Payloads for Noise handshake messages.

message NoiseExtensions {
repeated bytes webtransport_certhashes = 1;
repeated string stream_muxers = 2;
}

message NoiseHandshakePayload {
bytes identity_key = 1;
bytes identity_sig = 2;
bytes data = 3;
optional NoiseExtensions extensions = 4;
}
67 changes: 52 additions & 15 deletions transports/noise/src/generated/payload/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,25 @@
#![cfg_attr(rustfmt, rustfmt_skip)]


use std::borrow::Cow;
use quick_protobuf::{MessageInfo, MessageRead, MessageWrite, BytesReader, Writer, WriterBackend, Result};
use quick_protobuf::sizeofs::*;
use super::super::*;

#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Debug, Default, PartialEq, Clone)]
pub struct NoiseHandshakePayload {
pub identity_key: Vec<u8>,
pub identity_sig: Vec<u8>,
pub data: Vec<u8>,
pub struct NoiseExtensions<'a> {
pub webtransport_certhashes: Vec<Cow<'a, [u8]>>,
pub stream_muxers: Vec<Cow<'a, str>>,
}

impl<'a> MessageRead<'a> for NoiseHandshakePayload {
impl<'a> MessageRead<'a> for NoiseExtensions<'a> {
fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result<Self> {
let mut msg = Self::default();
while !r.is_eof() {
match r.next_tag(bytes) {
Ok(10) => msg.identity_key = r.read_bytes(bytes)?.to_owned(),
Ok(18) => msg.identity_sig = r.read_bytes(bytes)?.to_owned(),
Ok(26) => msg.data = r.read_bytes(bytes)?.to_owned(),
Ok(10) => msg.webtransport_certhashes.push(r.read_bytes(bytes).map(Cow::Borrowed)?),
Ok(18) => msg.stream_muxers.push(r.read_string(bytes).map(Cow::Borrowed)?),
Ok(t) => { r.read_unknown(bytes, t)?; }
Err(e) => return Err(e),
}
Expand All @@ -37,18 +36,56 @@ impl<'a> MessageRead<'a> for NoiseHandshakePayload {
}
}

impl MessageWrite for NoiseHandshakePayload {
impl<'a> MessageWrite for NoiseExtensions<'a> {
fn get_size(&self) -> usize {
0
+ if self.identity_key.is_empty() { 0 } else { 1 + sizeof_len((&self.identity_key).len()) }
+ if self.identity_sig.is_empty() { 0 } else { 1 + sizeof_len((&self.identity_sig).len()) }
+ if self.data.is_empty() { 0 } else { 1 + sizeof_len((&self.data).len()) }
+ self.webtransport_certhashes.iter().map(|s| 1 + sizeof_len((s).len())).sum::<usize>()
+ self.stream_muxers.iter().map(|s| 1 + sizeof_len((s).len())).sum::<usize>()
}

fn write_message<W: WriterBackend>(&self, w: &mut Writer<W>) -> Result<()> {
if !self.identity_key.is_empty() { w.write_with_tag(10, |w| w.write_bytes(&**&self.identity_key))?; }
if !self.identity_sig.is_empty() { w.write_with_tag(18, |w| w.write_bytes(&**&self.identity_sig))?; }
if !self.data.is_empty() { w.write_with_tag(26, |w| w.write_bytes(&**&self.data))?; }
for s in &self.webtransport_certhashes { w.write_with_tag(10, |w| w.write_bytes(&**s))?; }
for s in &self.stream_muxers { w.write_with_tag(18, |w| w.write_string(&**s))?; }
Ok(())
}
}

#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Debug, Default, PartialEq, Clone)]
pub struct NoiseHandshakePayload<'a> {
pub identity_key: Cow<'a, [u8]>,
pub identity_sig: Cow<'a, [u8]>,
pub extensions: Option<payload::proto::NoiseExtensions<'a>>,
}

impl<'a> MessageRead<'a> for NoiseHandshakePayload<'a> {
fn from_reader(r: &mut BytesReader, bytes: &'a [u8]) -> Result<Self> {
let mut msg = Self::default();
while !r.is_eof() {
match r.next_tag(bytes) {
Ok(10) => msg.identity_key = r.read_bytes(bytes).map(Cow::Borrowed)?,
Ok(18) => msg.identity_sig = r.read_bytes(bytes).map(Cow::Borrowed)?,
Ok(34) => msg.extensions = Some(r.read_message::<payload::proto::NoiseExtensions>(bytes)?),
Ok(t) => { r.read_unknown(bytes, t)?; }
Err(e) => return Err(e),
}
}
Ok(msg)
}
}

impl<'a> MessageWrite for NoiseHandshakePayload<'a> {
fn get_size(&self) -> usize {
0
+ if self.identity_key == Cow::Borrowed(b"") { 0 } else { 1 + sizeof_len((&self.identity_key).len()) }
+ if self.identity_sig == Cow::Borrowed(b"") { 0 } else { 1 + sizeof_len((&self.identity_sig).len()) }
+ self.extensions.as_ref().map_or(0, |m| 1 + sizeof_len((m).get_size()))
}

fn write_message<W: WriterBackend>(&self, w: &mut Writer<W>) -> Result<()> {
if self.identity_key != Cow::Borrowed(b"") { w.write_with_tag(10, |w| w.write_bytes(&**&self.identity_key))?; }
if self.identity_sig != Cow::Borrowed(b"") { w.write_with_tag(18, |w| w.write_bytes(&**&self.identity_sig))?; }
if let Some(ref s) = self.extensions { w.write_with_tag(34, |w| w.write_message(s))?; }
Ok(())
}
}
Expand Down
4 changes: 4 additions & 0 deletions transports/noise/src/io/framed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ impl<T> NoiseFramed<T, snow::HandshakeState> {
}
}

pub(crate) fn is_initiator(&self) -> bool {
self.session.is_initiator()
}

/// Converts the `NoiseFramed` into a `NoiseOutput` encrypted data stream
/// once the handshake is complete, including the static DH [`PublicKey`]
/// of the remote, if received.
Expand Down
72 changes: 69 additions & 3 deletions transports/noise/src/io/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
mod proto {
#![allow(unreachable_pub)]
include!("../generated/mod.rs");
pub use self::payload::proto::NoiseExtensions;
pub use self::payload::proto::NoiseHandshakePayload;
}

Expand All @@ -31,8 +32,10 @@ use crate::protocol::{KeypairIdentity, STATIC_KEY_DOMAIN};
use crate::{DecodeError, Error};
use bytes::Bytes;
use futures::prelude::*;
use libp2p_core::multihash::Multihash;
use libp2p_identity as identity;
use quick_protobuf::{BytesReader, MessageRead, MessageWrite, Writer};
use std::borrow::Cow;
use std::io;

//////////////////////////////////////////////////////////////////////////////
Expand All @@ -49,6 +52,15 @@ pub(crate) struct State<T> {
dh_remote_pubkey_sig: Option<Vec<u8>>,
/// The known or received public identity key of the remote, if any.
id_remote_pubkey: Option<identity::PublicKey>,
/// The WebTransport certhashes of the responder, if any.
responder_webtransport_certhashes: Option<Vec<Multihash>>,
/// The received extensions of the remote, if any.
remote_extensions: Option<Extensions>,
}

/// Extensions
struct Extensions {
webtransport_certhashes: Vec<Multihash>,
}

impl<T> State<T> {
Expand All @@ -63,12 +75,15 @@ impl<T> State<T> {
session: snow::HandshakeState,
identity: KeypairIdentity,
expected_remote_key: Option<identity::PublicKey>,
responder_webtransport_certhashes: Option<Vec<Multihash>>,
) -> Self {
Self {
identity,
io: NoiseFramed::new(io, session),
dh_remote_pubkey_sig: None,
id_remote_pubkey: expected_remote_key,
responder_webtransport_certhashes,
remote_extensions: None,
}
}
}
Expand All @@ -77,6 +92,7 @@ impl<T> State<T> {
/// Finish a handshake, yielding the established remote identity and the
/// [`Output`] for communicating on the encrypted channel.
pub(crate) fn finish(self) -> Result<(identity::PublicKey, Output<T>), Error> {
let is_initiator = self.io.is_initiator();
let (pubkey, io) = self.io.into_transport()?;

let id_pk = self
Expand All @@ -91,10 +107,42 @@ impl<T> State<T> {
return Err(Error::BadSignature);
}

// Check WebTransport certhashes that responder reported back to us
if is_initiator {
// We check only if we care (i.e. Config::with_webtransport_certhashes was used)
if let Some(valid_certhashes) = self.responder_webtransport_certhashes {
let ext = self
.remote_extensions
.ok_or(Error::BadWebTransportCerthashes)?;

// The known WebTransport certhashes must be a strict subset
// of the reported ones
let is_valid = valid_certhashes
.iter()
.all(|hash| ext.webtransport_certhashes.contains(hash));

if !is_valid {
return Err(Error::BadWebTransportCerthashes);
}
}
}

Ok((id_pk, io))
}
}

impl From<proto::NoiseExtensions<'_>> for Extensions {
fn from(value: proto::NoiseExtensions<'_>) -> Self {
Extensions {
webtransport_certhashes: value
.webtransport_certhashes
.into_iter()
.filter_map(|bytes| Multihash::read(&bytes[..]).ok())
.collect(),
}
}
}

//////////////////////////////////////////////////////////////////////////////
// Handshake Message Futures

Expand Down Expand Up @@ -146,7 +194,11 @@ where
state.id_remote_pubkey = Some(identity::PublicKey::try_decode_protobuf(&pb.identity_key)?);

if !pb.identity_sig.is_empty() {
state.dh_remote_pubkey_sig = Some(pb.identity_sig);
state.dh_remote_pubkey_sig = Some(pb.identity_sig.into_owned());
}

if let Some(extensions) = pb.extensions {
state.remote_extensions = Some(extensions.into());
}

Ok(())
Expand All @@ -158,11 +210,25 @@ where
T: AsyncWrite + Unpin,
{
let mut pb = proto::NoiseHandshakePayload {
identity_key: state.identity.public.encode_protobuf(),
identity_key: state.identity.public.encode_protobuf().into(),
..Default::default()
};

pb.identity_sig = state.identity.signature.clone();
pb.identity_sig = Cow::Borrowed(&state.identity.signature);

// If this is the responder then send WebTransport certhashes to initiator, if any
if !state.io.is_initiator() {
if let Some(ref certhashes) = state.responder_webtransport_certhashes {
let ext = pb
.extensions
.get_or_insert_with(proto::NoiseExtensions::default);

ext.webtransport_certhashes = certhashes
.iter()
.map(|hash| Cow::Owned(hash.to_bytes()))
.collect();
}
}

let mut msg = Vec::with_capacity(pb.get_size());

Expand Down
31 changes: 29 additions & 2 deletions transports/noise/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ use crate::handshake::State;
use crate::io::handshake;
use crate::protocol::{noise_params_into_builder, AuthenticKeypair, Keypair, PARAMS_XX};
use futures::prelude::*;
use libp2p_core::multihash::Multihash;
use libp2p_core::{InboundUpgrade, OutboundUpgrade, UpgradeInfo};
use libp2p_identity as identity;
use libp2p_identity::PeerId;
Expand All @@ -76,6 +77,7 @@ use std::pin::Pin;
pub struct Config {
dh_keys: AuthenticKeypair,
params: NoiseParams,
webtransport_certhashes: Option<Vec<Multihash>>,

/// Prologue to use in the noise handshake.
///
Expand All @@ -94,14 +96,25 @@ impl Config {
Ok(Self {
dh_keys: noise_keys,
params: PARAMS_XX.clone(),
webtransport_certhashes: None,
prologue: vec![],
})
}

/// Set the noise prologue.
pub fn with_prologue(mut self, prologue: Vec<u8>) -> Self {
self.prologue = prologue;
self
}

/// Set WebTransport certhashes extension
///
/// In case of initiator, these certhashes will be used to validate the ones reported by
/// responder.
///
/// In case of responder, these certhashes will be reported to initiator.
pub fn with_webtransport_certhashes(mut self, certhashes: Vec<Multihash>) -> Self {
self.webtransport_certhashes = Some(certhashes);
self
}

Expand All @@ -114,7 +127,13 @@ impl Config {
)
.build_responder()?;

let state = State::new(socket, session, self.dh_keys.identity, None);
let state = State::new(
socket,
session,
self.dh_keys.identity,
None,
self.webtransport_certhashes,
);

Ok(state)
}
Expand All @@ -128,7 +147,13 @@ impl Config {
)
.build_initiator()?;

let state = State::new(socket, session, self.dh_keys.identity, None);
let state = State::new(
socket,
session,
self.dh_keys.identity,
None,
self.webtransport_certhashes,
);

Ok(state)
}
Expand Down Expand Up @@ -213,6 +238,8 @@ pub enum Error {
InvalidPayload(#[from] DecodeError),
#[error(transparent)]
SigningError(#[from] libp2p_identity::SigningError),
#[error("Bad WebTransport certhashes")]
BadWebTransportCerthashes,
}

#[derive(Debug, thiserror::Error)]
Expand Down

0 comments on commit 7965a1e

Please sign in to comment.