From 71b511a0c9a3632a2afb69a885bce6453aede785 Mon Sep 17 00:00:00 2001 From: Oliver Gould Date: Mon, 25 Mar 2024 21:48:01 +0000 Subject: [PATCH] chore(http): Add a TimeoutBody middleware This commit implements a new TimeoutBody middleware that uses the new `Body::poll_progess` method to constrain the amount of time a stream waits for send capacity. This change does not yet wire up the middleware into the Linkerd server. --- Cargo.lock | 26 ++++ Cargo.toml | 1 + linkerd/http/body-timeout/Cargo.toml | 16 +++ linkerd/http/body-timeout/src/lib.rs | 174 +++++++++++++++++++++++++++ 4 files changed, 217 insertions(+) create mode 100644 linkerd/http/body-timeout/Cargo.toml create mode 100644 linkerd/http/body-timeout/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index cd8cd488ec..165ed4ed77 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -640,6 +640,17 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.30" @@ -661,6 +672,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -1444,6 +1456,20 @@ dependencies = [ "tracing", ] +[[package]] +name = "linkerd-http-body-timeout" +version = "0.1.0" +dependencies = [ + "futures", + "http", + "http-body", + "linkerd-error", + "pin-project", + "thiserror", + "tokio", + "tower-service", +] + [[package]] name = "linkerd-http-box" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index e6622eefeb..93cf404c1f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ members = [ "linkerd/error-respond", "linkerd/exp-backoff", "linkerd/http/access-log", + "linkerd/http/body-timeout", "linkerd/http/box", "linkerd/http/classify", "linkerd/http/metrics", diff --git a/linkerd/http/body-timeout/Cargo.toml b/linkerd/http/body-timeout/Cargo.toml new file mode 100644 index 0000000000..f3c46017d3 --- /dev/null +++ b/linkerd/http/body-timeout/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "linkerd-http-body-timeout" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +http = "0.2" +http-body = "0.4" +futures = "0.3" +pin-project = "1" +thiserror = "1" +tokio = { version = "1", features = ["time"] } +tower-service = "0.3" + +linkerd-error = { path = "../../error" } diff --git a/linkerd/http/body-timeout/src/lib.rs b/linkerd/http/body-timeout/src/lib.rs new file mode 100644 index 0000000000..29b10e0608 --- /dev/null +++ b/linkerd/http/body-timeout/src/lib.rs @@ -0,0 +1,174 @@ +#![deny(rust_2018_idioms, clippy::disallowed_methods, clippy::disallowed_types)] +#![forbid(unsafe_code)] + +use futures::prelude::*; +use http::{HeaderMap, HeaderValue}; +use http_body::Body; +use linkerd_error::Error; +use pin_project::pin_project; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::time; + +pub struct TimeoutRequestProgress { + inner: S, + timeout: time::Duration, +} + +pub struct TimeoutResponseProgress { + inner: S, + timeout: time::Duration, +} + +/// A [`Body`] that imposes a timeout on the amount of time the stream may be +/// stuck waiting for capacity. +#[derive(Debug)] +#[pin_project] +pub struct ProgressTimeoutBody { + #[pin] + inner: B, + sleep: Pin>, + timeout: time::Duration, + is_pending: bool, +} + +#[derive(Debug, thiserror::Error)] +#[error("body progress timeout after {0:?}")] +pub struct BodyProgressTimeoutError(time::Duration); + +// === impl TimeoutRequestProgress === + +impl TimeoutRequestProgress { + pub fn new(timeout: time::Duration, inner: S) -> Self { + Self { inner, timeout } + } +} + +impl tower_service::Service> for TimeoutRequestProgress +where + S: tower_service::Service>>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + #[inline] + fn call(&mut self, req: http::Request) -> Self::Future { + self.inner + .call(req.map(|b| ProgressTimeoutBody::new(self.timeout, b))) + } +} + +// === impl TimeoutResponseProgress === + +impl TimeoutResponseProgress { + pub fn new(timeout: time::Duration, inner: S) -> Self { + Self { inner, timeout } + } +} + +impl tower_service::Service for TimeoutResponseProgress +where + S: tower_service::Service>, + S::Future: Send + 'static, +{ + type Response = http::Response>; + type Error = S::Error; + type Future = + Pin> + Send>>; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + #[inline] + fn call(&mut self, req: Req) -> Self::Future { + let timeout = self.timeout; + self.inner + .call(req) + .map_ok(move |res| res.map(|b| ProgressTimeoutBody::new(timeout, b))) + .boxed() + } +} + +// === impl ProgressTimeoutBody === + +impl ProgressTimeoutBody { + pub fn new(timeout: time::Duration, inner: B) -> Self { + // Avoid overflows by capping MAX to roughly 30 years. + const MAX: time::Duration = time::Duration::from_secs(86400 * 365 * 30); + Self { + inner, + timeout: timeout.min(MAX), + is_pending: false, + sleep: Box::pin(time::sleep(MAX)), + } + } +} + +impl Body for ProgressTimeoutBody +where + B: Body + Send + 'static, + B::Data: Send + 'static, + B::Error: Into, +{ + type Data = B::Data; + type Error = Error; + + #[inline] + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + + #[inline] + fn poll_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + let this = self.project(); + *this.is_pending = false; + this.inner.poll_data(cx).map_err(Into::into) + } + + #[inline] + fn poll_trailers( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>, Self::Error>> { + let this = self.project(); + *this.is_pending = false; + this.inner.poll_trailers(cx).map_err(Into::into) + } + + fn poll_progress(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + let _ = this.inner.poll_progress(cx).map_err(Into::into)?; + + if !*this.is_pending { + this.sleep + .as_mut() + .reset(time::Instant::now() + *this.timeout); + *this.is_pending = true; + } + + match this.sleep.as_mut().poll(cx) { + Poll::Ready(()) => Poll::Ready(Err(BodyProgressTimeoutError(*this.timeout).into())), + Poll::Pending => Poll::Pending, + } + } + + #[inline] + fn size_hint(&self) -> http_body::SizeHint { + self.inner.size_hint() + } +}