Skip to content

Commit

Permalink
Fix left shift overflow when preparing endpoints
Browse files Browse the repository at this point in the history
There is an overflowing left shift due to the interplay between the
driver and the state module. The endpoint addresses produced by the
driver could result in a left-shift overflow, and subsequent panic, back
in the state module (`check_allocated`). The fault is the usage of the
`capacity()` call, which returns the number of endpoints (up to 16), not
the endpoint indices (0-7, inclusive). The error happens when
`capacity() >= 8`:

1. When enabling endpoints, `all_ep_addrs` produces an invalid
   `EndpointAddress` with index 8 (or more).
2. `check_allocated` caller produces a raw index (using `index`),
    calculated to be at least `2 * 8 = 16`.
3. `check_allocated` left shifts `1 << 16` and panics.

The commit refactors the modules, moving the iterator construction into
the endpoint allocator. The allocator ensures that the accessed
endpoints are valid and won't cause a panic on access. The commit
updates the driver accordingly. The approach is more amenable to unit
testing, and the commit includes extra asserts to show endpoint
allocation and iteration does not panic.

The same overflow could have happened during endpoint allocation,
depending on the address supplied by the caller. This commit refactors
that method, too.

Note that this commit slightly changes the `enable_endpoints` behavior.
Specifically, we won't enable the control endpoint. This is intentional;
EP0 control IN / OUT are always enabled in hardware, so we don't need
the software call.

The Teensy 4 examples should have demonstrated this panic. And I was
able to demonstrate the issue with a debug build just before this
commit. I likely only tested release builds when preparing the release,
and release builds hide this defect. To prevent this in the future, I'm
enabling overflow checks in release builds.

I tested this by building debug builds of the serial example on the
Teensy 4, and ensuring that it did not panic. The test_class continues
to pass the usb-device test suite.
  • Loading branch information
mciantyre committed Mar 8, 2023
1 parent d5a40a1 commit d2a49c3
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 47 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ Changelog
[Unreleased]
------------

Fix an overflowing left shift that could occur when enabling and allocating
endpoints.

[0.2.0] 2022-11-30
------------------

Expand Down
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,6 @@ members = [

[workspace.package]
edition = "2021"

[profile.release]
overflow-checks = true
38 changes: 8 additions & 30 deletions src/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,6 @@ fn ctrl_ep0_in() -> EndpointAddress {
EndpointAddress::from_parts(0, UsbDirection::In)
}

/// Produce an iterator over `count` endpoint addresses.
fn all_ep_addrs(count: usize) -> impl Iterator<Item = EndpointAddress> {
(0..count).flat_map(|index| {
let ep_out = EndpointAddress::from_parts(index, UsbDirection::Out);
let ep_in = EndpointAddress::from_parts(index, UsbDirection::In);
[ep_out, ep_in]
})
}

/// Produce an iterator over all endpoint addresses with a non-zero index.
///
/// This skips control endpoints.
fn non_zero_ep_addrs(count: usize) -> impl Iterator<Item = EndpointAddress> {
all_ep_addrs(count).skip(2)
}

/// USB low / full / high speed setting.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum Speed {
Expand Down Expand Up @@ -391,31 +375,25 @@ impl Driver {
///
/// This should only be called when the device is configured
fn enable_endpoints(&mut self) {
for addr in all_ep_addrs(self.ep_allocator.capacity()) {
if let Some(ep) = self.ep_allocator.endpoint_mut(addr) {
ep.enable(&self.usb);
}
for ep in self.ep_allocator.nonzero_endpoints_iter_mut() {
ep.enable(&self.usb);
}
}

/// Prime all non-zero, enabled OUT endpoints
fn prime_endpoints(&mut self) {
for addr in non_zero_ep_addrs(self.ep_allocator.capacity()) {
if let Some(ep) = self.ep_allocator.endpoint_mut(addr) {
if ep.is_enabled(&self.usb) && ep.address().direction() == UsbDirection::Out {
let max_packet_len = ep.max_packet_len();
ep.schedule_transfer(&self.usb, max_packet_len);
}
for ep in self.ep_allocator.nonzero_endpoints_iter_mut() {
if ep.is_enabled(&self.usb) && ep.address().direction() == UsbDirection::Out {
let max_packet_len = ep.max_packet_len();
ep.schedule_transfer(&self.usb, max_packet_len);
}
}
}

/// Initialize (or reinitialize) all non-zero endpoints
fn initialize_endpoints(&mut self) {
for addr in non_zero_ep_addrs(self.ep_allocator.capacity()) {
if let Some(ep) = self.ep_allocator.endpoint_mut(addr) {
ep.initialize(&self.usb);
}
for ep in self.ep_allocator.nonzero_endpoints_iter_mut() {
ep.initialize(&self.usb);
}
}

Expand Down
106 changes: 89 additions & 17 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ impl EndpointAllocator<'_> {

/// Returns `Some` if the endpoint is allocated.
fn check_allocated(&self, index: usize) -> Option<()> {
let mask = (index < self.qh_list.len()).then_some(1u16 << index)?;
(index < self.qh_list.len()).then_some(())?;
let mask = 1u16 << index;
(mask & self.alloc_mask.load(Ordering::SeqCst) as u16 != 0).then_some(())
}

Expand All @@ -172,11 +173,6 @@ impl EndpointAllocator<'_> {
self.qh_list.as_ptr().cast()
}

/// Returns the total number of endpoints that could be allocated.
pub fn capacity(&self) -> usize {
self.ep_list.len()
}

/// Acquire the endpoint.
///
/// Returns `None` if the endpoint isn't allocated.
Expand All @@ -193,22 +189,53 @@ impl EndpointAllocator<'_> {
Some(unsafe { ep.assume_init_ref() })
}

/// Aquire the mutable endpoint.
/// Implementation detail to permit endpoint iteration.
///
/// Returns `None` if the endpoint isn't allocated.
pub fn endpoint_mut(&mut self, addr: EndpointAddress) -> Option<&mut Endpoint> {
/// # Safety
///
/// This can only be called from a method that takes a mutable receiver.
/// Otherwise, you could reach the same mutable endpoint more than once.
unsafe fn endpoint_mut_inner(&self, addr: EndpointAddress) -> Option<&mut Endpoint> {
let index = index(addr);
self.check_allocated(index)?;

// Safety: there's no other immutable or mutable access at this call site.
// Perceived lifetime is tied to the EndpointAllocator, which has a
// mutable receiver.
// Safety: the caller ensures that we actually have a mutable reference.
// Once we have a mutable reference, this is equivalent to calling the
// safe UnsafeCell::get_mut method.
let ep = unsafe { &mut *self.ep_list[index].get() };

// Safety: endpoint is allocated. Checked above.
Some(unsafe { ep.assume_init_mut() })
}

/// Aquire the mutable endpoint.
///
/// Returns `None` if the endpoint isn't allocated.
pub fn endpoint_mut(&mut self, addr: EndpointAddress) -> Option<&mut Endpoint> {
// Safety: call from method with mutable receiver.
unsafe { self.endpoint_mut_inner(addr) }
}

/// Return an iterator of all allocated endpoints.
pub fn endpoints_iter_mut(&mut self) -> impl Iterator<Item = &mut Endpoint> {
(0..8)
.flat_map(|index| {
let ep_out = EndpointAddress::from_parts(index, UsbDirection::Out);
let ep_in = EndpointAddress::from_parts(index, UsbDirection::In);
[ep_out, ep_in]
})
// Safety: call from method with mutable receiver.
.flat_map(|ep| unsafe { self.endpoint_mut_inner(ep) })
}

/// Returns an iterator for all non-zero, allocated endpoints.
///
/// "Non-zero" excludes the first two control endpoints.
pub fn nonzero_endpoints_iter_mut(&mut self) -> impl Iterator<Item = &mut Endpoint> {
self.endpoints_iter_mut()
.filter(|ep| ep.address().index() != 0)
}

/// Allocate the endpoint for the specified address.
///
/// Returns `None` if any are true:
Expand All @@ -222,7 +249,8 @@ impl EndpointAllocator<'_> {
kind: EndpointType,
) -> Option<&mut Endpoint> {
let index = index(addr);
let mask = (index < self.qh_list.len()).then_some(1u16 << index)?;
(index < self.qh_list.len()).then_some(())?;
let mask = 1u16 << index;

// If we pass this call, we're the only caller able to observe mutable
// QHs, TDs, and EPs at index.
Expand Down Expand Up @@ -274,16 +302,23 @@ mod tests {
assert!(ep_alloc.endpoint_mut(addr).is_none());

let ep = ep_alloc
.allocate_endpoint(addr, buffer_alloc.allocate(2).unwrap(), EndpointType::Bulk)
.allocate_endpoint(
addr,
buffer_alloc.allocate(2).unwrap(),
EndpointType::Control,
)
.unwrap();
assert_eq!(ep.address(), addr);

assert!(ep_alloc.endpoint(addr).is_some());
assert!(ep_alloc.endpoint_mut(addr).is_some());

// Double-allocate existing endpoint.
let ep =
ep_alloc.allocate_endpoint(addr, buffer_alloc.allocate(2).unwrap(), EndpointType::Bulk);
let ep = ep_alloc.allocate_endpoint(
addr,
buffer_alloc.allocate(2).unwrap(),
EndpointType::Control,
);
assert!(ep.is_none());

assert!(ep_alloc.endpoint(addr).is_some());
Expand All @@ -296,8 +331,45 @@ mod tests {
assert!(ep_alloc.endpoint_mut(addr).is_none());

let ep = ep_alloc
.allocate_endpoint(addr, buffer_alloc.allocate(2).unwrap(), EndpointType::Bulk)
.allocate_endpoint(
addr,
buffer_alloc.allocate(2).unwrap(),
EndpointType::Control,
)
.unwrap();
assert_eq!(ep.address(), addr);

// Allocate a non-zero endpoint

let addr = EndpointAddress::from(3);
assert!(ep_alloc.endpoint(addr).is_none());
assert!(ep_alloc.endpoint_mut(addr).is_none());

let ep = ep_alloc
.allocate_endpoint(addr, buffer_alloc.allocate(4).unwrap(), EndpointType::Bulk)
.unwrap();
assert_eq!(ep.address(), addr);

assert_eq!(ep_alloc.endpoints_iter_mut().count(), 3);
assert_eq!(ep_alloc.nonzero_endpoints_iter_mut().count(), 1);

for (actual, expected) in ep_alloc.endpoints_iter_mut().zip([0usize, 0, 3]) {
assert_eq!(actual.address().index(), expected, "{:?}", actual.address());
}

for (actual, expected) in ep_alloc.nonzero_endpoints_iter_mut().zip([3]) {
assert_eq!(actual.address().index(), expected, "{:?}", actual.address());
}

// Try to allocate an invalid endpoint.
let addr = EndpointAddress::from(42);
let ep = ep_alloc.allocate_endpoint(
addr,
buffer_alloc.allocate(4).unwrap(),
EndpointType::Interrupt,
);
assert!(ep.is_none());

assert_eq!(ep_alloc.endpoints_iter_mut().count(), 3);
}
}

0 comments on commit d2a49c3

Please sign in to comment.