Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

proto: Connection side enum #2084

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 108 additions & 31 deletions quinn-proto/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ use timer::{Timer, TimerTable};
/// events or timeouts with different instants must not be interleaved.
pub struct Connection {
endpoint_config: Arc<EndpointConfig>,
server_config: Option<Arc<ServerConfig>>,
config: Arc<TransportConfig>,
rng: StdRng,
crypto: Box<dyn crypto::Session>,
Expand All @@ -145,7 +144,7 @@ pub struct Connection {
allow_mtud: bool,
prev_path: Option<(ConnectionId, PathData)>,
state: State,
side: Side,
side: ConnectionSide,
/// Whether or not 0-RTT was enabled during the handshake. Does not imply acceptance.
zero_rtt_enabled: bool,
/// Set if 0-RTT is supported, then cleared when no longer needed.
Expand Down Expand Up @@ -191,9 +190,6 @@ pub struct Connection {
authentication_failures: u64,
/// Why the connection was lost, if it has been
error: Option<ConnectionError>,
/// Sent in every outgoing Initial packet. Always empty for servers and after Initial keys are
/// discarded.
retry_token: Bytes,
/// Identifies Data-space packet numbers to skip. Not used in earlier spaces.
packet_number_filter: PacketNumberFilter,

Expand Down Expand Up @@ -242,12 +238,10 @@ pub struct Connection {
impl Connection {
pub(crate) fn new(
endpoint_config: Arc<EndpointConfig>,
server_config: Option<Arc<ServerConfig>>,
config: Arc<TransportConfig>,
init_cid: ConnectionId,
loc_cid: ConnectionId,
rem_cid: ConnectionId,
pref_addr_cid: Option<ConnectionId>,
remote: SocketAddr,
local_ip: Option<IpAddr>,
crypto: Box<dyn crypto::Session>,
Expand All @@ -256,13 +250,12 @@ impl Connection {
version: u32,
allow_mtud: bool,
rng_seed: [u8; 32],
path_validated: bool,
side_args: SideArgs,
) -> Self {
let side = if server_config.is_some() {
Side::Server
} else {
Side::Client
};
let pref_addr_cid = side_args.pref_addr_cid();
let path_validated = side_args.path_validated();
let connection_side = ConnectionSide::from(side_args);
let side = connection_side.side();
let initial_space = PacketSpace {
crypto: Some(crypto.initial_keys(&init_cid, side)),
..PacketSpace::new(now)
Expand All @@ -275,7 +268,6 @@ impl Connection {
let mut rng = StdRng::from_seed(rng_seed);
let mut this = Self {
endpoint_config,
server_config,
crypto,
handshake_cid: loc_cid,
rem_handshake_cid: rem_cid,
Expand All @@ -289,8 +281,8 @@ impl Connection {
allow_mtud,
local_ip,
prev_path: None,
side,
state,
side: connection_side,
zero_rtt_enabled: false,
zero_rtt_crypto: None,
key_phase: false,
Expand Down Expand Up @@ -323,7 +315,6 @@ impl Connection {
timers: TimerTable::default(),
authentication_failures: 0,
error: None,
retry_token: Bytes::new(),
#[cfg(test)]
packet_number_filter: match config.deterministic_packet_numbers {
false => PacketNumberFilter::new(&mut rng),
Expand Down Expand Up @@ -420,7 +411,7 @@ impl Connection {
/// Provide control over streams
#[must_use]
pub fn recv_stream(&mut self, id: StreamId) -> RecvStream<'_> {
assert!(id.dir() == Dir::Bi || id.initiator() != self.side);
assert!(id.dir() == Dir::Bi || id.initiator() != self.side.side());
RecvStream {
id,
state: &mut self.streams,
Expand All @@ -431,7 +422,7 @@ impl Connection {
/// Provide control over streams
#[must_use]
pub fn send_stream(&mut self, id: StreamId) -> SendStream<'_> {
assert!(id.dir() == Dir::Bi || id.initiator() == self.side);
assert!(id.dir() == Dir::Bi || id.initiator() == self.side.side());
SendStream {
id,
state: &mut self.streams,
Expand Down Expand Up @@ -1075,9 +1066,7 @@ impl Connection {
// If this packet could initiate a migration and we're a client or a server that
// forbids migration, drop the datagram. This could be relaxed to heuristically
// permit NAT-rebinding-like migration.
if remote != self.path.remote
&& self.server_config.as_ref().map_or(true, |x| !x.migration)
{
if remote != self.path.remote && !self.side.remote_may_migrate() {
trace!("discarding packet from unrecognized peer {}", remote);
return;
}
Expand Down Expand Up @@ -1297,7 +1286,7 @@ impl Connection {

/// Look up whether we're the client or server of this Connection
pub fn side(&self) -> Side {
self.side
self.side.side()
}

/// The latest socket address for this connection's peer
Expand Down Expand Up @@ -2101,7 +2090,9 @@ impl Connection {
trace!("discarding {:?} keys", space_id);
if space_id == SpaceId::Initial {
// No longer needed
self.retry_token = Bytes::new();
if let ConnectionSide::Client { token, .. } = &mut self.side {
*token = Bytes::new();
}
}
let space = &mut self.spaces[space_id];
space.crypto = None;
Expand Down Expand Up @@ -2398,7 +2389,7 @@ impl Connection {

self.discard_space(now, SpaceId::Initial); // Make sure we clean up after any retransmitted Initials
self.spaces[SpaceId::Initial] = PacketSpace {
crypto: Some(self.crypto.initial_keys(&rem_cid, self.side)),
crypto: Some(self.crypto.initial_keys(&rem_cid, self.side.side())),
next_packet_number: self.spaces[SpaceId::Initial].next_packet_number,
crypto_offset: client_hello.len() as u64,
..PacketSpace::new(now)
Expand All @@ -2420,7 +2411,10 @@ impl Connection {
self.streams.retransmit_all_for_0rtt();

let token_len = packet.payload.len() - 16;
self.retry_token = packet.payload.freeze().split_to(token_len);
let ConnectionSide::Client { ref mut token, .. } = self.side else {
unreachable!("we already short-circuited if we're server");
};
*token = packet.payload.freeze().split_to(token_len);
self.state = State::Handshake(state::Handshake {
expected_token: Bytes::new(),
rem_cid_set: false,
Expand Down Expand Up @@ -2745,7 +2739,7 @@ impl Connection {
debug!(offset, "peer claims to be blocked at connection level");
}
Frame::StreamDataBlocked { id, offset } => {
if id.initiator() == self.side && id.dir() == Dir::Uni {
if id.initiator() == self.side.side() && id.dir() == Dir::Uni {
debug!("got STREAM_DATA_BLOCKED on send-only {}", id);
return Err(TransportError::STREAM_STATE_ERROR(
"STREAM_DATA_BLOCKED on send-only stream",
Expand All @@ -2768,7 +2762,7 @@ impl Connection {
);
}
Frame::StopSending(frame::StopSending { id, error_code }) => {
if id.initiator() != self.side {
if id.initiator() != self.side.side() {
if id.dir() == Dir::Uni {
debug!("got STOP_SENDING on recv-only {}", id);
return Err(TransportError::STREAM_STATE_ERROR(
Expand Down Expand Up @@ -2938,11 +2932,11 @@ impl Connection {
&& !is_probing_packet
&& number == self.spaces[SpaceId::Data].rx_packet
{
let ConnectionSide::Server { ref server_config } = self.side else {
panic!("packets from unknown remote should be dropped by clients");
};
debug_assert!(
self.server_config
.as_ref()
.expect("packets from unknown remote should be dropped by clients")
.migration,
server_config.migration,
"migration-initiating packets should have been dropped immediately"
);
self.migrate(now, remote);
Expand Down Expand Up @@ -3618,6 +3612,89 @@ impl fmt::Debug for Connection {
}
}

/// Fields of `Connection` specific to it being client-side or server-side
enum ConnectionSide {
Client {
/// Sent in every outgoing Initial packet. Always empty after Initial keys are discarded
token: Bytes,
},
Server {
server_config: Arc<ServerConfig>,
},
}

impl ConnectionSide {
fn side(&self) -> Side {
match *self {
Self::Client { .. } => Side::Client,
Self::Server { .. } => Side::Server,
}
}

fn is_client(&self) -> bool {
self.side().is_client()
}

fn is_server(&self) -> bool {
self.side().is_server()
}

fn remote_may_migrate(&self) -> bool {
match self {
Self::Server { server_config } => server_config.migration,
Self::Client { .. } => false,
}
}
Comment on lines +3627 to +3647
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ordering nit: let's have remote_may_migrate() first, then order is_client() and is_server() before side() (since they depend on it).

}

impl From<SideArgs> for ConnectionSide {
fn from(side: SideArgs) -> Self {
match side {
SideArgs::Client => Self::Client {
token: Bytes::new(),
},
SideArgs::Server {
server_config,
pref_addr_cid: _,
path_validated: _,
} => Self::Server { server_config },
}
}
}

/// Parameters to `Connection::new` specific to it being client-side or server-side
pub(crate) enum SideArgs {
Client,
Server {
server_config: Arc<ServerConfig>,
pref_addr_cid: Option<ConnectionId>,
path_validated: bool,
},
}

impl SideArgs {
pub(crate) fn side(&self) -> Side {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ordering nit: let's order this last.

match *self {
Self::Client { .. } => Side::Client,
Self::Server { .. } => Side::Server,
}
}

pub(crate) fn pref_addr_cid(&self) -> Option<ConnectionId> {
match *self {
Self::Client { .. } => None,
Self::Server { pref_addr_cid, .. } => pref_addr_cid,
}
}

pub(crate) fn path_validated(&self) -> bool {
match *self {
Self::Client { .. } => true,
Self::Server { path_validated, .. } => path_validated,
}
}
}

/// Reasons why a connection might be lost
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectionError {
Expand Down
6 changes: 5 additions & 1 deletion quinn-proto/src/connection/packet_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use tracing::{trace, trace_span};

use super::{spaces::SentPacket, Connection, SentFrames};
use crate::{
connection::ConnectionSide,
frame::{self, Close},
packet::{Header, InitialHeader, LongType, PacketNumber, PartialEncode, SpaceId, FIXED_BIT},
ConnectionId, Instant, TransportError, TransportErrorCode,
Expand Down Expand Up @@ -113,7 +114,10 @@ impl PacketBuilder {
SpaceId::Initial => Header::Initial(InitialHeader {
src_cid: conn.handshake_cid,
dst_cid,
token: conn.retry_token.clone(),
token: match &conn.side {
ConnectionSide::Client { token, .. } => token.clone(),
ConnectionSide::Server { .. } => Bytes::new(),
},
number,
version,
}),
Expand Down
28 changes: 11 additions & 17 deletions quinn-proto/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::{
cid_generator::ConnectionIdGenerator,
coding::BufMutExt,
config::{ClientConfig, EndpointConfig, ServerConfig},
connection::{Connection, ConnectionError},
connection::{Connection, ConnectionError, SideArgs},
crypto::{self, Keys, UnsupportedVersion},
frame,
packet::{
Expand Down Expand Up @@ -423,16 +423,14 @@ impl Endpoint {
remote_id,
loc_cid,
remote_id,
None,
FourTuple {
remote,
local_ip: None,
},
now,
tls,
None,
config.transport,
true,
SideArgs::Client,
);
Ok((ch, conn))
}
Expand Down Expand Up @@ -660,13 +658,15 @@ impl Endpoint {
dst_cid,
loc_cid,
src_cid,
pref_addr_cid,
incoming.addresses,
incoming.received_at,
tls,
Some(server_config),
transport_config,
remote_address_validated,
SideArgs::Server {
server_config,
pref_addr_cid,
path_validated: remote_address_validated,
},
);
self.index.insert_initial(dst_cid, ch);

Expand Down Expand Up @@ -829,28 +829,22 @@ impl Endpoint {
init_cid: ConnectionId,
loc_cid: ConnectionId,
rem_cid: ConnectionId,
pref_addr_cid: Option<ConnectionId>,
addresses: FourTuple,
now: Instant,
tls: Box<dyn crypto::Session>,
server_config: Option<Arc<ServerConfig>>,
transport_config: Arc<TransportConfig>,
path_validated: bool,
side_args: SideArgs,
) -> Connection {
let mut rng_seed = [0; 32];
self.rng.fill_bytes(&mut rng_seed);
let side = match server_config.is_some() {
true => Side::Server,
false => Side::Client,
};
let side = side_args.side();
let pref_addr_cid = side_args.pref_addr_cid();
let conn = Connection::new(
self.config.clone(),
server_config,
transport_config,
init_cid,
loc_cid,
rem_cid,
pref_addr_cid,
addresses.remote,
addresses.local_ip,
tls,
Expand All @@ -859,7 +853,7 @@ impl Endpoint {
version,
self.allow_mtud,
rng_seed,
path_validated,
side_args,
);

let mut cids_issued = 0;
Expand Down
Loading