From 4d2667d1cb1b938756d20dafa3cccae1db23a831 Mon Sep 17 00:00:00 2001 From: David Pedersen Date: Wed, 19 May 2021 09:19:58 +0200 Subject: [PATCH] feat(tonic): make it easier to add tower middleware to servers (#651) --- examples/src/tower/client.rs | 18 +- examples/src/tower/server.rs | 68 ++-- tests/integration_tests/Cargo.toml | 3 + .../tests/complex_tower_middleware.rs | 113 +++++++ tonic-build/src/client.rs | 20 +- tonic-build/src/server.rs | 43 +-- tonic/Cargo.toml | 2 + tonic/src/client/grpc.rs | 23 +- tonic/src/codegen.rs | 1 + tonic/src/extensions.rs | 4 +- tonic/src/interceptor.rs | 86 ----- tonic/src/lib.rs | 4 +- tonic/src/request.rs | 4 +- tonic/src/server/grpc.rs | 43 +-- tonic/src/service/interceptor.rs | 183 +++++++++++ tonic/src/service/mod.rs | 6 + tonic/src/transport/server/incoming.rs | 8 +- tonic/src/transport/server/mod.rs | 302 ++++++++++++++---- tonic/src/transport/server/recover_error.rs | 81 ++++- tonic/src/transport/service/grpc_timeout.rs | 7 +- tonic/src/transport/service/router.rs | 3 +- tonic/src/util.rs | 13 + 22 files changed, 734 insertions(+), 301 deletions(-) create mode 100644 tests/integration_tests/tests/complex_tower_middleware.rs delete mode 100644 tonic/src/interceptor.rs create mode 100644 tonic/src/service/interceptor.rs create mode 100644 tonic/src/service/mod.rs create mode 100644 tonic/src/util.rs diff --git a/examples/src/tower/client.rs b/examples/src/tower/client.rs index aa3135dba..eab48ba5c 100644 --- a/examples/src/tower/client.rs +++ b/examples/src/tower/client.rs @@ -1,8 +1,9 @@ use hello_world::greeter_client::GreeterClient; use hello_world::HelloRequest; use service::AuthSvc; +use tower::ServiceBuilder; -use tonic::transport::Channel; +use tonic::{transport::Channel, Request, Status}; pub mod hello_world { tonic::include_proto!("helloworld"); @@ -11,9 +12,14 @@ pub mod hello_world { #[tokio::main] async fn main() -> Result<(), Box> { let channel = Channel::from_static("http://[::1]:50051").connect().await?; - let auth = AuthSvc::new(channel); - let mut client = GreeterClient::new(auth); + let channel = ServiceBuilder::new() + // Interceptors can be also be applied as middleware + .layer(tonic::service::interceptor_fn(intercept)) + .layer_fn(AuthSvc::new) + .service(channel); + + let mut client = GreeterClient::new(channel); let request = tonic::Request::new(HelloRequest { name: "Tonic".into(), @@ -26,6 +32,12 @@ async fn main() -> Result<(), Box> { Ok(()) } +// An interceptor function. +fn intercept(req: Request<()>) -> Result, Status> { + println!("received {:?}", req); + Ok(req) +} + mod service { use http::{Request, Response}; use std::future::Future; diff --git a/examples/src/tower/server.rs b/examples/src/tower/server.rs index 574b76b7f..2f3e36fcb 100644 --- a/examples/src/tower/server.rs +++ b/examples/src/tower/server.rs @@ -1,11 +1,10 @@ -use hyper::{Body, Request as HyperRequest, Response as HyperResponse}; -use std::task::{Context, Poll}; -use tonic::{ - body::BoxBody, - transport::{NamedService, Server}, - Request, Response, Status, +use hyper::Body; +use std::{ + task::{Context, Poll}, + time::Duration, }; -use tower::Service; +use tonic::{body::BoxBody, transport::Server, Request, Response, Status}; +use tower::{Layer, Service}; use hello_world::greeter_server::{Greeter, GreeterServer}; use hello_world::{HelloReply, HelloRequest}; @@ -39,27 +38,52 @@ async fn main() -> Result<(), Box> { println!("GreeterServer listening on {}", addr); - let svc = InterceptedService { - inner: GreeterServer::new(greeter), - }; - - Server::builder().add_service(svc).serve(addr).await?; + let svc = GreeterServer::new(greeter); + + // The stack of middleware that our service will be wrapped in + let layer = tower::ServiceBuilder::new() + // Apply middleware from tower + .timeout(Duration::from_secs(30)) + // Apply our own middleware + .layer(MyMiddlewareLayer::default()) + // Interceptors can be also be applied as middleware + .layer(tonic::service::interceptor_fn(intercept)) + .into_inner(); + + Server::builder() + // Wrap all services in the middleware stack + .layer(layer) + .add_service(svc) + .serve(addr) + .await?; Ok(()) } +// An interceptor function. +fn intercept(req: Request<()>) -> Result, Status> { + Ok(req) +} + +#[derive(Debug, Clone, Default)] +struct MyMiddlewareLayer; + +impl Layer for MyMiddlewareLayer { + type Service = MyMiddleware; + + fn layer(&self, service: S) -> Self::Service { + MyMiddleware { inner: service } + } +} + #[derive(Debug, Clone)] -struct InterceptedService { +struct MyMiddleware { inner: S, } -impl Service> for InterceptedService +impl Service> for MyMiddleware where - S: Service, Response = HyperResponse> - + NamedService - + Clone - + Send - + 'static, + S: Service, Response = hyper::Response> + Clone + Send + 'static, S::Future: Send + 'static, { type Response = S::Response; @@ -70,7 +94,7 @@ where self.inner.poll_ready(cx) } - fn call(&mut self, req: HyperRequest) -> Self::Future { + fn call(&mut self, req: hyper::Request) -> Self::Future { // This is necessary because tonic internally uses `tower::buffer::Buffer`. // See https://github.com/tower-rs/tower/issues/547#issuecomment-767629149 // for details on why this is necessary @@ -85,7 +109,3 @@ where }) } } - -impl NamedService for InterceptedService { - const NAME: &'static str = S::NAME; -} diff --git a/tests/integration_tests/Cargo.toml b/tests/integration_tests/Cargo.toml index 1b4acef2a..fc33b0d37 100644 --- a/tests/integration_tests/Cargo.toml +++ b/tests/integration_tests/Cargo.toml @@ -20,6 +20,9 @@ tokio-stream = { version = "0.1.5", features = ["net"] } tower-service = "0.3" hyper = "0.14" futures = "0.3" +tower = { version = "0.4", features = [] } +http-body = "0.4" +http = "0.2" [build-dependencies] tonic-build = { path = "../../tonic-build" } diff --git a/tests/integration_tests/tests/complex_tower_middleware.rs b/tests/integration_tests/tests/complex_tower_middleware.rs new file mode 100644 index 000000000..5d7690be3 --- /dev/null +++ b/tests/integration_tests/tests/complex_tower_middleware.rs @@ -0,0 +1,113 @@ +#![allow(unused_variables, dead_code)] + +use http_body::Body; +use integration_tests::pb::{test_server, Input, Output}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tonic::{transport::Server, Request, Response, Status}; +use tower::{layer::Layer, BoxError, Service}; + +// all we care about is that this compiles +async fn complex_tower_layers_work() { + struct Svc; + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, req: Request) -> Result, Status> { + unimplemented!() + } + } + + let svc = test_server::TestServer::new(Svc); + + Server::builder() + .layer(MyServiceLayer::new()) + .add_service(svc) + .serve("127.0.0.1:1322".parse().unwrap()) + .await + .unwrap(); +} + +#[derive(Debug, Clone)] +struct MyServiceLayer {} + +impl MyServiceLayer { + fn new() -> Self { + unimplemented!() + } +} + +impl Layer for MyServiceLayer { + type Service = MyService; + + fn layer(&self, inner: S) -> Self::Service { + unimplemented!() + } +} + +#[derive(Debug, Clone)] +struct MyService { + inner: S, +} + +impl Service for MyService +where + S: Service>, +{ + type Response = http::Response>; + type Error = BoxError; + type Future = MyFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + unimplemented!() + } + + fn call(&mut self, req: R) -> Self::Future { + unimplemented!() + } +} + +struct MyFuture { + inner: F, + body: B, +} + +impl Future for MyFuture +where + F: Future, E>>, +{ + type Output = Result>, BoxError>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + unimplemented!() + } +} + +struct MyBody { + inner: B, +} + +impl Body for MyBody +where + B: Body, +{ + type Data = B::Data; + type Error = BoxError; + + fn poll_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + unimplemented!() + } + + fn poll_trailers( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + unimplemented!() + } +} diff --git a/tonic-build/src/client.rs b/tonic-build/src/client.rs index e65248908..67b4792c5 100644 --- a/tonic-build/src/client.rs +++ b/tonic-build/src/client.rs @@ -36,18 +36,24 @@ pub fn generate( #connect impl #service_ident - where T: tonic::client::GrpcService, - T::ResponseBody: Body + Send + Sync + 'static, - T::Error: Into, - ::Error: Into + Send, { + where + T: tonic::client::GrpcService, + T::ResponseBody: Body + Send + Sync + 'static, + T::Error: Into, + ::Error: Into + Send, + { pub fn new(inner: T) -> Self { let inner = tonic::client::Grpc::new(inner); Self { inner } } - pub fn with_interceptor(inner: T, interceptor: impl Into) -> Self { - let inner = tonic::client::Grpc::with_interceptor(inner, interceptor); - Self { inner } + pub fn with_interceptor(inner: T, interceptor: F) -> #service_ident> + where + F: FnMut(tonic::Request<()>) -> Result, tonic::Status>, + T: Service, Response = http::Response>, + >>::Error: Into + Send + Sync, + { + #service_ident::new(InterceptedService::new(inner, interceptor)) } #methods diff --git a/tonic-build/src/server.rs b/tonic-build/src/server.rs index 170c79995..517917540 100644 --- a/tonic-build/src/server.rs +++ b/tonic-build/src/server.rs @@ -50,19 +50,20 @@ pub fn generate( inner: _Inner, } - struct _Inner(Arc, Option); + struct _Inner(Arc); impl #server_service { pub fn new(inner: T) -> Self { let inner = Arc::new(inner); - let inner = _Inner(inner, None); + let inner = _Inner(inner); Self { inner } } - pub fn with_interceptor(inner: T, interceptor: impl Into) -> Self { - let inner = Arc::new(inner); - let inner = _Inner(inner, Some(interceptor.into())); - Self { inner } + pub fn with_interceptor(inner: T, interceptor: F) -> InterceptedService + where + F: FnMut(tonic::Request<()>) -> Result, tonic::Status>, + { + InterceptedService::new(Self::new(inner), interceptor) } } @@ -107,7 +108,7 @@ pub fn generate( impl Clone for _Inner { fn clone(&self) -> Self { - Self(self.0.clone(), self.1.clone()) + Self(self.0.clone()) } } @@ -336,16 +337,11 @@ fn generate_unary( let inner = self.inner.clone(); let fut = async move { - let interceptor = inner.1.clone(); let inner = inner.0; let method = #service_ident(inner); let codec = #codec_name::default(); - let mut grpc = if let Some(interceptor) = interceptor { - tonic::server::Grpc::with_interceptor(codec, interceptor) - } else { - tonic::server::Grpc::new(codec) - }; + let mut grpc = tonic::server::Grpc::new(codec); let res = grpc.unary(method, req).await; Ok(res) @@ -391,16 +387,11 @@ fn generate_server_streaming( let inner = self.inner.clone(); let fut = async move { - let interceptor = inner.1; let inner = inner.0; let method = #service_ident(inner); let codec = #codec_name::default(); - let mut grpc = if let Some(interceptor) = interceptor { - tonic::server::Grpc::with_interceptor(codec, interceptor) - } else { - tonic::server::Grpc::new(codec) - }; + let mut grpc = tonic::server::Grpc::new(codec); let res = grpc.server_streaming(method, req).await; Ok(res) @@ -443,16 +434,11 @@ fn generate_client_streaming( let inner = self.inner.clone(); let fut = async move { - let interceptor = inner.1; let inner = inner.0; let method = #service_ident(inner); let codec = #codec_name::default(); - let mut grpc = if let Some(interceptor) = interceptor { - tonic::server::Grpc::with_interceptor(codec, interceptor) - } else { - tonic::server::Grpc::new(codec) - }; + let mut grpc = tonic::server::Grpc::new(codec); let res = grpc.client_streaming(method, req).await; Ok(res) @@ -498,16 +484,11 @@ fn generate_streaming( let inner = self.inner.clone(); let fut = async move { - let interceptor = inner.1; let inner = inner.0; let method = #service_ident(inner); let codec = #codec_name::default(); - let mut grpc = if let Some(interceptor) = interceptor { - tonic::server::Grpc::with_interceptor(codec, interceptor) - } else { - tonic::server::Grpc::new(codec) - }; + let mut grpc = tonic::server::Grpc::new(codec); let res = grpc.streaming(method, req).await; Ok(res) diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index eb023a959..f078d5dbd 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -52,6 +52,7 @@ base64 = "0.13" percent-encoding = "2.1" tower-service = "0.3" +tower-layer = "0.3" tokio-util = { version = "0.6", features = ["codec"] } async-stream = "0.3" http-body = "0.4.2" @@ -83,6 +84,7 @@ rand = "0.8" bencher = "0.1.5" quickcheck = "1.0" quickcheck_macros = "1.0" +tower = { version = "0.4.7", features = ["full"] } [package.metadata.docs.rs] all-features = true diff --git a/tonic/src/client/grpc.rs b/tonic/src/client/grpc.rs index a249fb757..287949914 100644 --- a/tonic/src/client/grpc.rs +++ b/tonic/src/client/grpc.rs @@ -2,7 +2,6 @@ use crate::{ body::BoxBody, client::GrpcService, codec::{encode_client, Codec, Streaming}, - interceptor::Interceptor, Code, Request, Response, Status, }; use futures_core::Stream; @@ -29,25 +28,12 @@ use std::fmt; /// [gRPC protocol definition]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests pub struct Grpc { inner: T, - interceptor: Option, } impl Grpc { /// Creates a new gRPC client with the provided [`GrpcService`]. pub fn new(inner: T) -> Self { - Self { - inner, - interceptor: None, - } - } - - /// Creates a new gRPC client with the provided [`GrpcService`] and will apply - /// the provided interceptor on each request. - pub fn with_interceptor(inner: T, interceptor: impl Into) -> Self { - Self { - inner, - interceptor: Some(interceptor.into()), - } + Self { inner } } /// Check if the inner [`GrpcService`] is able to accept a new request. @@ -153,12 +139,6 @@ impl Grpc { M1: Send + Sync + 'static, M2: Send + Sync + 'static, { - let request = if let Some(interceptor) = &self.interceptor { - interceptor.call(request)? - } else { - request - }; - let mut parts = Parts::default(); parts.path_and_query = Some(path); @@ -217,7 +197,6 @@ impl Clone for Grpc { fn clone(&self) -> Self { Self { inner: self.inner.clone(), - interceptor: self.interceptor.clone(), } } } diff --git a/tonic/src/codegen.rs b/tonic/src/codegen.rs index 5cb4221b4..dd83f2c4a 100644 --- a/tonic/src/codegen.rs +++ b/tonic/src/codegen.rs @@ -10,6 +10,7 @@ pub use std::sync::Arc; pub use std::task::{Context, Poll}; pub use tower_service::Service; pub type StdError = Box; +pub use crate::service::interceptor::InterceptedService; pub use http_body::Body; pub type BoxFuture = self::Pin> + Send + 'static>>; diff --git a/tonic/src/extensions.rs b/tonic/src/extensions.rs index a42a4c276..c7a40ffb6 100644 --- a/tonic/src/extensions.rs +++ b/tonic/src/extensions.rs @@ -2,10 +2,10 @@ use std::fmt; /// A type map of protocol extensions. /// -/// `Extensions` can be used by [`Interceptor`] and [`Request`] to store extra data derived from +/// `Extensions` can be used by [`interceptor_fn`] and [`Request`] to store extra data derived from /// the underlying protocol. /// -/// [`Interceptor`]: crate::Interceptor +/// [`interceptor_fn`]: crate::service::interceptor_fn /// [`Request`]: crate::Request pub struct Extensions { inner: http::Extensions, diff --git a/tonic/src/interceptor.rs b/tonic/src/interceptor.rs deleted file mode 100644 index f99d27cc2..000000000 --- a/tonic/src/interceptor.rs +++ /dev/null @@ -1,86 +0,0 @@ -use crate::{Request, Status}; -use std::panic::{RefUnwindSafe, UnwindSafe}; -use std::{fmt, sync::Arc}; - -type InterceptorFn = Arc< - dyn Fn(Request<()>) -> Result, Status> - + Send - + Sync - + UnwindSafe - + RefUnwindSafe - + 'static, ->; - -/// Represents a gRPC interceptor. -/// -/// gRPC interceptors are similar to middleware but have much less -/// flexibility. This interceptor allows you to do two main things, -/// one is to add/remove/check items in the `MetadataMap` of each -/// request. Two, cancel a request with any `Status`. -/// -/// An interceptor can be used on both the server and client side through -/// the `tonic-build` crate's generated structs. -/// -/// These interceptors do not allow you to modify the `Message` of the request -/// but allow you to check for metadata. If you would like to apply middleware like -/// features to the body of the request, going through the `tower` abstraction is recommended. -#[derive(Clone)] -pub struct Interceptor { - f: InterceptorFn, -} - -impl Interceptor { - /// Create a new `Interceptor` from the provided function. - pub fn new( - f: impl Fn(Request<()>) -> Result, Status> - + Send - + Sync - + UnwindSafe - + RefUnwindSafe - + 'static, - ) -> Self { - Interceptor { f: Arc::new(f) } - } - - pub(crate) fn call(&self, req: Request) -> Result, Status> { - let (metadata, ext, message) = req.into_parts(); - - let temp_req = Request::from_parts(metadata, ext, ()); - - let (metadata, ext, _) = (self.f)(temp_req)?.into_parts(); - - Ok(Request::from_parts(metadata, ext, message)) - } -} - -impl From for Interceptor -where - F: Fn(Request<()>) -> Result, Status> - + Send - + Sync - + UnwindSafe - + RefUnwindSafe - + 'static, -{ - fn from(f: F) -> Self { - Interceptor::new(f) - } -} - -impl fmt::Debug for Interceptor { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Interceptor").finish() - } -} - -#[cfg(test)] -mod tests { - #[allow(unused_imports)] - use super::*; - - #[test] - fn interceptor_fn_is_unwind_safe() { - fn is_unwind_safe() {} - is_unwind_safe::(); - } -} diff --git a/tonic/src/lib.rs b/tonic/src/lib.rs index d5eb61ec3..815c0b1a9 100644 --- a/tonic/src/lib.rs +++ b/tonic/src/lib.rs @@ -83,17 +83,18 @@ pub mod client; pub mod codec; pub mod metadata; pub mod server; +pub mod service; #[cfg(feature = "transport")] #[cfg_attr(docsrs, doc(cfg(feature = "transport")))] pub mod transport; mod extensions; -mod interceptor; mod macros; mod request; mod response; mod status; +mod util; /// A re-export of [`async-trait`](https://docs.rs/async-trait) for use with codegen. #[cfg(feature = "codegen")] @@ -103,7 +104,6 @@ pub use async_trait::async_trait; #[doc(inline)] pub use codec::Streaming; pub use extensions::Extensions; -pub use interceptor::Interceptor; pub use request::{IntoRequest, IntoStreamingRequest, Request}; pub use response::Response; pub use status::{Code, Status}; diff --git a/tonic/src/request.rs b/tonic/src/request.rs index e0a033ef6..c7780ebbd 100644 --- a/tonic/src/request.rs +++ b/tonic/src/request.rs @@ -267,13 +267,13 @@ impl Request { /// Extensions can be set in interceptors: /// /// ```no_run - /// use tonic::{Request, Interceptor}; + /// use tonic::{Request, service::interceptor_fn}; /// /// struct MyExtension { /// some_piece_of_data: String, /// } /// - /// Interceptor::new(|mut request: Request<()>| { + /// interceptor_fn(|mut request: Request<()>| { /// request.extensions_mut().insert(MyExtension { /// some_piece_of_data: "foo".to_string(), /// }); diff --git a/tonic/src/server/grpc.rs b/tonic/src/server/grpc.rs index 7852d7fe7..e640ac8ed 100644 --- a/tonic/src/server/grpc.rs +++ b/tonic/src/server/grpc.rs @@ -1,7 +1,6 @@ use crate::{ body::BoxBody, codec::{encode_server, Codec, Streaming}, - interceptor::Interceptor, server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService}, Code, Request, Status, }; @@ -10,16 +9,6 @@ use futures_util::{future, stream, TryStreamExt}; use http_body::Body; use std::fmt; -// A try! type macro for intercepting requests -macro_rules! t { - ($expr : expr) => { - match $expr { - Ok(request) => request, - Err(res) => return res, - } - }; -} - /// A gRPC Server handler. /// /// This will wrap some inner [`Codec`] and provide utilities to handle @@ -31,7 +20,6 @@ macro_rules! t { /// implements some [`Body`]. pub struct Grpc { codec: T, - interceptor: Option, } impl Grpc @@ -41,19 +29,7 @@ where { /// Creates a new gRPC server with the provided [`Codec`]. pub fn new(codec: T) -> Self { - Self { - codec, - interceptor: None, - } - } - - /// Creates a new gRPC server with the provided [`Codec`] and will apply the provided - /// interceptor on each inbound request. - pub fn with_interceptor(codec: T, interceptor: impl Into) -> Self { - Self { - codec, - interceptor: Some(interceptor.into()), - } + Self { codec } } /// Handle a single unary gRPC request. @@ -77,8 +53,6 @@ where } }; - let request = t!(self.intercept_request(request)); - let response = service .call(request) .await @@ -106,8 +80,6 @@ where } }; - let request = t!(self.intercept_request(request)); - let response = service.call(request).await; self.map_response(response) @@ -125,7 +97,6 @@ where B::Error: Into + Send + 'static, { let request = self.map_request_streaming(req); - let request = t!(self.intercept_request(request)); let response = service .call(request) .await @@ -146,7 +117,6 @@ where B::Error: Into + Send, { let request = self.map_request_streaming(req); - let request = t!(self.intercept_request(request)); let response = service.call(request).await; self.map_response(response) } @@ -213,17 +183,6 @@ where Err(status) => status.to_http(), } } - - fn intercept_request(&self, req: Request) -> Result, http::Response> { - if let Some(interceptor) = &self.interceptor { - match interceptor.call(req) { - Ok(req) => Ok(req), - Err(status) => Err(status.to_http()), - } - } else { - Ok(req) - } - } } impl fmt::Debug for Grpc { diff --git a/tonic/src/service/interceptor.rs b/tonic/src/service/interceptor.rs new file mode 100644 index 000000000..7131bfcbb --- /dev/null +++ b/tonic/src/service/interceptor.rs @@ -0,0 +1,183 @@ +//! gRPC interceptors which are a kind of middleware. + +use crate::Status; +use pin_project::pin_project; +use std::{ + fmt, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tower_layer::Layer; +use tower_service::Service; + +/// Create a new interceptor from a function. +/// +/// gRPC interceptors are similar to middleware but have less flexibility. This interceptor allows +/// you to do two main things, one is to add/remove/check items in the `MetadataMap` of each +/// request. Two, cancel a request with any `Status`. +/// +/// An interceptor can be used on both the server and client side through the `tonic-build` crate's +/// generated structs. +/// +/// These interceptors do not allow you to modify the `Message` of the request but allow you to +/// check for metadata. If you would like to apply middleware like features to the body of the +/// request, going through the [tower] abstraction is recommended. +/// +/// Interceptors is not recommend should not be used to add logging to your service. For that a +/// [tower] middleware is more appropriate since it can also act on the response. +/// +/// See the [interceptor example][example] for more details. +/// +/// [tower]: https://crates.io/crates/tower +/// [example]: https://github.com/hyperium/tonic/tree/master/examples/src/interceptor +// TODO: when tower-http is shipped update the docs to mention its `Trace` middleware which has +// support for gRPC and is an easy to add logging +pub fn interceptor_fn(f: F) -> InterceptorFn +where + F: FnMut(crate::Request<()>) -> Result, Status>, +{ + InterceptorFn { f } +} + +/// An interceptor created from a function. +/// +/// See [`interceptor_fn`] for more details. +#[derive(Debug, Clone, Copy)] +pub struct InterceptorFn { + f: F, +} + +impl Layer for InterceptorFn +where + F: FnMut(crate::Request<()>) -> Result, Status> + Clone, +{ + type Service = InterceptedService; + + fn layer(&self, service: S) -> Self::Service { + InterceptedService::new(service, self.f.clone()) + } +} + +/// A service wrapped in an interceptor middleware. +/// +/// See [`interceptor_fn`] for more details. +#[derive(Clone, Copy)] +pub struct InterceptedService { + inner: S, + f: F, +} + +impl InterceptedService { + /// Create a new `InterceptedService` thats wraps `S` and intercepts each request with the + /// function `F`. + pub fn new(service: S, f: F) -> Self + where + F: FnMut(crate::Request<()>) -> Result, Status>, + { + Self { inner: service, f } + } +} + +impl fmt::Debug for InterceptedService +where + S: fmt::Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("InterceptedService") + .field("inner", &self.inner) + .field("f", &format_args!("{}", std::any::type_name::())) + .finish() + } +} + +impl Service> for InterceptedService +where + F: FnMut(crate::Request<()>) -> Result, Status>, + S: Service, Response = http::Response>, + S::Error: Into, +{ + type Response = http::Response; + type Error = crate::Error; + type Future = ResponseFuture; + + #[inline] + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: http::Request) -> Self::Future { + let uri = req.uri().clone(); + let req = crate::Request::from_http(req); + let (metadata, extensions, msg) = req.into_parts(); + + match (self.f)(crate::Request::from_parts(metadata, extensions, ())) { + Ok(req) => { + let (metadata, extensions, _) = req.into_parts(); + let req = crate::Request::from_parts(metadata, extensions, msg); + let req = req.into_http(uri); + ResponseFuture::future(self.inner.call(req)) + } + Err(status) => ResponseFuture::error(status), + } + } +} + +// required to use `InterceptedService` with `Router` +#[cfg(feature = "transport")] +impl crate::transport::NamedService for InterceptedService +where + S: crate::transport::NamedService, +{ + const NAME: &'static str = S::NAME; +} + +/// Response future for [`InterceptedService`]. +#[pin_project] +#[derive(Debug)] +pub struct ResponseFuture { + #[pin] + kind: Kind, +} + +impl ResponseFuture { + fn future(future: F) -> Self { + Self { + kind: Kind::Future(future), + } + } + + fn error(status: Status) -> Self { + Self { + kind: Kind::Error(Some(status)), + } + } +} + +#[pin_project(project = KindProj)] +#[derive(Debug)] +enum Kind { + Future(#[pin] F), + Error(Option), +} + +impl Future for ResponseFuture +where + F: Future, E>>, + E: Into, +{ + type Output = Result, crate::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.project().kind.project() { + KindProj::Future(future) => { + let response = futures_core::ready!(future.poll(cx).map_err(Into::into)?); + Poll::Ready(Ok(response)) + } + KindProj::Error(status) => { + let error = status.take().unwrap().into(); + Poll::Ready(Err(error)) + } + } + } +} diff --git a/tonic/src/service/mod.rs b/tonic/src/service/mod.rs new file mode 100644 index 000000000..2a6ab273a --- /dev/null +++ b/tonic/src/service/mod.rs @@ -0,0 +1,6 @@ +//! Utilities for using Tower services with Tonic. + +pub mod interceptor; + +#[doc(inline)] +pub use self::interceptor::interceptor_fn; diff --git a/tonic/src/transport/server/incoming.rs b/tonic/src/transport/server/incoming.rs index 4ead628f0..f9a21a0a9 100644 --- a/tonic/src/transport/server/incoming.rs +++ b/tonic/src/transport/server/incoming.rs @@ -15,9 +15,9 @@ use std::{ use tokio::io::{AsyncRead, AsyncWrite}; #[cfg(not(feature = "tls"))] -pub(crate) fn tcp_incoming( +pub(crate) fn tcp_incoming( incoming: impl Stream>, - _server: Server, + _server: Server, ) -> impl Stream> where IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, @@ -35,9 +35,9 @@ where } #[cfg(feature = "tls")] -pub(crate) fn tcp_incoming( +pub(crate) fn tcp_incoming( incoming: impl Stream>, - server: Server, + server: Server, ) -> impl Stream> where IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index e4beece0f..4ec198237 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -25,26 +25,33 @@ use crate::transport::Error; use self::recover_error::RecoverError; use super::service::{GrpcTimeout, Or, Routes, ServerIo}; use crate::{body::BoxBody, request::ConnectionInfo}; +use bytes::Bytes; use futures_core::Stream; use futures_util::{ - future::{self, Either as FutureEither, MapErr}, - TryFutureExt, + future::{self, MapErr}, + ready, TryFutureExt, }; use http::{Request, Response}; +use http_body::Body as _; use hyper::{server::accept, Body}; +use pin_project::pin_project; use std::{ fmt, future::Future, net::SocketAddr, + pin::Pin, sync::Arc, task::{Context, Poll}, time::Duration, }; use tokio::io::{AsyncRead, AsyncWrite}; -use tower::{limit::concurrency::ConcurrencyLimitLayer, util::Either, Service, ServiceBuilder}; -use tracing_futures::{Instrument, Instrumented}; +use tower::{ + layer::util::Identity, layer::Layer, limit::concurrency::ConcurrencyLimitLayer, util::Either, + Service, ServiceBuilder, +}; -type BoxService = tower::util::BoxService, Response, crate::Error>; +type BoxHttpBody = http_body::combinators::BoxBody; +type BoxService = tower::util::BoxService, Response, crate::Error>; type TraceInterceptor = Arc) -> tracing::Span + Send + Sync + 'static>; const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS: u64 = 20; @@ -58,7 +65,7 @@ const DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS: u64 = 20; /// reference implementation that should be a good starting point for anyone /// wanting to create a more complex and/or specific implementation. #[derive(Default, Clone)] -pub struct Server { +pub struct Server { trace_interceptor: Option, concurrency_limit: Option, timeout: Option, @@ -73,12 +80,13 @@ pub struct Server { http2_keepalive_timeout: Option, max_frame_size: Option, accept_http1: bool, + layer: L, } /// A stack based `Service` router. #[derive(Debug)] -pub struct Router { - server: Server, +pub struct Router { + server: Server, routes: Routes>, } @@ -88,35 +96,29 @@ pub struct Router { /// gRPC endpoints and can be consumed with the rest of the `tower` /// ecosystem. #[derive(Debug)] -pub struct RouterService { - router: Router, +pub struct RouterService { + inner: S, } -impl Service> for RouterService +impl Service> for RouterService where - A: Service, Response = Response> + Clone + Send + 'static, - A::Future: Send + 'static, - A::Error: Into + Send, - B: Service, Response = Response> + Clone + Send + 'static, - B::Future: Send + 'static, - B::Error: Into + Send, + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into + Send, { type Response = Response; type Error = crate::Error; #[allow(clippy::type_complexity)] - type Future = FutureEither< - MapErr crate::Error>, - MapErr crate::Error>, - >; + type Future = MapErr crate::Error>; + #[inline] fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - #[inline] fn call(&mut self, req: Request) -> Self::Future { - self.router.routes.call(req) + self.inner.call(req).map_err(Into::into) } } @@ -144,7 +146,7 @@ impl Server { } } -impl Server { +impl Server { /// Configure TLS for this server. #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] @@ -319,7 +321,7 @@ impl Server { /// /// This will clone the `Server` builder and create a router that will /// route around different services. - pub fn add_service(&mut self, svc: S) -> Router + pub fn add_service(&mut self, svc: S) -> Router where S: Service, Response = Response> + NamedService @@ -328,6 +330,7 @@ impl Server { + 'static, S::Future: Send + 'static, S::Error: Into + Send, + L: Clone, { Router::new(self.clone(), svc) } @@ -343,7 +346,7 @@ impl Server { pub fn add_optional_service( &mut self, svc: Option, - ) -> Router, Unimplemented> + ) -> Router, Unimplemented, L> where S: Service, Response = Response> + NamedService @@ -352,6 +355,7 @@ impl Server { + 'static, S::Future: Send + 'static, S::Error: Into + Send, + L: Clone, { let svc = match svc { Some(some) => Either::A(some), @@ -360,20 +364,104 @@ impl Server { Router::new(self.clone(), svc) } - pub(crate) async fn serve_with_shutdown( + /// Set the [Tower] [`Layer`] all services will be wrapped in. + /// + /// This enables using middleware from the [Tower ecosystem][eco]. + /// + /// # Example + /// + /// ``` + /// # use tonic::transport::Server; + /// # use tower_service::Service; + /// use tower::timeout::TimeoutLayer; + /// use std::time::Duration; + /// + /// # let mut builder = Server::builder(); + /// builder.layer(TimeoutLayer::new(Duration::from_secs(30))); + /// ``` + /// + /// Note that timeouts should be set using [`Server::timeout`]. `TimeoutLayer` is only used + /// here as an example. + /// + /// You can build more complex layers using [`ServiceBuilder`]. Those layers can include + /// [interceptors]: + /// + /// ``` + /// # use tonic::transport::Server; + /// # use tower_service::Service; + /// use tower::ServiceBuilder; + /// use std::time::Duration; + /// use tonic::{Request, Status, service::interceptor_fn}; + /// + /// fn auth_interceptor(request: Request<()>) -> Result, Status> { + /// if valid_credentials(&request) { + /// Ok(request) + /// } else { + /// Err(Status::unauthenticated("invalid credentials")) + /// } + /// } + /// + /// fn valid_credentials(request: &Request<()>) -> bool { + /// // ... + /// # true + /// } + /// + /// fn some_other_interceptor(request: Request<()>) -> Result, Status> { + /// Ok(request) + /// } + /// + /// let layer = ServiceBuilder::new() + /// .load_shed() + /// .timeout(Duration::from_secs(30)) + /// .layer(interceptor_fn(auth_interceptor)) + /// .layer(interceptor_fn(some_other_interceptor)) + /// .into_inner(); + /// + /// Server::builder().layer(layer); + /// ``` + /// + /// [Tower]: https://github.com/tower-rs/tower + /// [`Layer`]: tower::layer::Layer + /// [eco]: https://github.com/tower-rs + /// [`ServiceBuilder`]: tower::ServiceBuilder + /// [interceptors]: crate::service::interceptor_fn + pub fn layer(self, new_layer: NewLayer) -> Server { + Server { + layer: new_layer, + trace_interceptor: self.trace_interceptor, + concurrency_limit: self.concurrency_limit, + timeout: self.timeout, + #[cfg(feature = "tls")] + tls: self.tls, + init_stream_window_size: self.init_stream_window_size, + init_connection_window_size: self.init_connection_window_size, + max_concurrent_streams: self.max_concurrent_streams, + tcp_keepalive: self.tcp_keepalive, + tcp_nodelay: self.tcp_nodelay, + http2_keepalive_interval: self.http2_keepalive_interval, + http2_keepalive_timeout: self.http2_keepalive_timeout, + max_frame_size: self.max_frame_size, + accept_http1: self.accept_http1, + } + } + + pub(crate) async fn serve_with_shutdown( self, svc: S, incoming: I, signal: Option, ) -> Result<(), super::Error> where - S: Service, Response = Response> + Clone + Send + 'static, - S::Future: Send + 'static, - S::Error: Into + Send, + L: Layer, + L::Service: Service, Response = Response> + Clone + Send + 'static, + <>::Service as Service>>::Future: Send + 'static, + <>::Service as Service>>::Error: Into + Send, I: Stream>, IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, IE: Into, F: Future, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into, { let trace_interceptor = self.trace_interceptor.clone(); let concurrency_limit = self.concurrency_limit; @@ -387,7 +475,9 @@ impl Server { let http2_keepalive_interval = self.http2_keepalive_interval; let http2_keepalive_timeout = self .http2_keepalive_timeout - .unwrap_or(Duration::new(DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS, 0)); + .unwrap_or_else(|| Duration::new(DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS, 0)); + + let svc = self.layer.layer(svc); let tcp = incoming::tcp_incoming(incoming, self); let incoming = accept::from_stream::<_, _, crate::Error>(tcp); @@ -422,8 +512,8 @@ impl Server { } } -impl Router { - pub(crate) fn new(server: Server, svc: S) -> Self +impl Router { + pub(crate) fn new(server: Server, svc: S) -> Self where S: Service, Response = Response> + NamedService @@ -447,7 +537,7 @@ impl Router { } } -impl Router +impl Router where A: Service, Response = Response> + Clone + Send + 'static, A::Future: Send + 'static, @@ -457,7 +547,7 @@ where B::Error: Into + Send, { /// Add a new service to this router. - pub fn add_service(self, svc: S) -> Router>> + pub fn add_service(self, svc: S) -> Router>, L> where S: Service, Response = Response> + NamedService @@ -486,10 +576,11 @@ where /// # Note /// Even when the argument given is `None` this will capture *all* requests to this service name. /// As a result, one cannot use this to toggle between two identically named implementations. + #[allow(clippy::type_complexity)] pub fn add_optional_service( self, svc: Option, - ) -> Router, Or>> + ) -> Router, Or>, L> where S: Service, Response = Response> + NamedService @@ -518,27 +609,53 @@ where } /// Consume this [`Server`] creating a future that will execute the server - /// on [`tokio`]'s default executor. + /// on [tokio]'s default executor. /// /// [`Server`]: struct.Server.html - pub async fn serve(self, addr: SocketAddr) -> Result<(), super::Error> { + /// [tokio]: https://docs.rs/tokio + pub async fn serve(self, addr: SocketAddr) -> Result<(), super::Error> + where + L: Layer>>, + L::Service: Service, Response = Response> + Clone + Send + 'static, + <>>>::Service as Service>>::Future: + Send + 'static, + <>>>::Service as Service>>::Error: + Into + Send, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into, + { let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive) .map_err(super::Error::from_source)?; self.server - .serve_with_shutdown::<_, _, future::Ready<()>, _, _>(self.routes, incoming, None) + .serve_with_shutdown::<_, _, future::Ready<()>, _, _, ResBody>( + self.routes, + incoming, + None, + ) .await } /// Consume this [`Server`] creating a future that will execute the server - /// on [`tokio`]'s default executor. And shutdown when the provided signal + /// on [tokio]'s default executor. And shutdown when the provided signal /// is received. /// /// [`Server`]: struct.Server.html - pub async fn serve_with_shutdown>( + /// [tokio]: https://docs.rs/tokio + pub async fn serve_with_shutdown, ResBody>( self, addr: SocketAddr, signal: F, - ) -> Result<(), super::Error> { + ) -> Result<(), super::Error> + where + L: Layer>>, + L::Service: Service, Response = Response> + Clone + Send + 'static, + <>>>::Service as Service>>::Future: + Send + 'static, + <>>>::Service as Service>>::Error: + Into + Send, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into, + { let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive) .map_err(super::Error::from_source)?; self.server @@ -550,14 +667,29 @@ where /// the provided incoming stream of `AsyncRead + AsyncWrite`. /// /// [`Server`]: struct.Server.html - pub async fn serve_with_incoming(self, incoming: I) -> Result<(), super::Error> + pub async fn serve_with_incoming( + self, + incoming: I, + ) -> Result<(), super::Error> where I: Stream>, IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, IE: Into, + L: Layer>>, + L::Service: Service, Response = Response> + Clone + Send + 'static, + <>>>::Service as Service>>::Future: + Send + 'static, + <>>>::Service as Service>>::Error: + Into + Send, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into, { self.server - .serve_with_shutdown::<_, _, future::Ready<()>, _, _>(self.routes, incoming, None) + .serve_with_shutdown::<_, _, future::Ready<()>, _, _, ResBody>( + self.routes, + incoming, + None, + ) .await } @@ -567,7 +699,7 @@ where /// gracefully shutdown the server. /// /// [`Server`]: struct.Server.html - pub async fn serve_with_incoming_shutdown( + pub async fn serve_with_incoming_shutdown( self, incoming: I, signal: F, @@ -577,6 +709,14 @@ where IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, IE: Into, F: Future, + L: Layer>>, + L::Service: Service, Response = Response> + Clone + Send + 'static, + <>>>::Service as Service>>::Future: + Send + 'static, + <>>>::Service as Service>>::Error: + Into + Send, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into, { self.server .serve_with_shutdown(self.routes, incoming, Some(signal)) @@ -584,12 +724,23 @@ where } /// Create a tower service out of a router. - pub fn into_service(self) -> RouterService { - RouterService { router: self } + pub fn into_service(self) -> RouterService + where + L: Layer>>, + L::Service: Service, Response = Response> + Clone + Send + 'static, + <>>>::Service as Service>>::Future: + Send + 'static, + <>>>::Service as Service>>::Error: + Into + Send, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into, + { + let inner = self.server.layer.layer(self.routes); + RouterService { inner } } } -impl fmt::Debug for Server { +impl fmt::Debug for Server { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Builder").finish() } @@ -601,16 +752,16 @@ struct Svc { conn_info: ConnectionInfo, } -impl Service> for Svc +impl Service> for Svc where - S: Service, Response = Response>, + S: Service, Response = Response>, S::Error: Into, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into, { - type Response = Response; + type Response = Response; type Error = crate::Error; - - #[allow(clippy::type_complexity)] - type Future = MapErr, fn(S::Error) -> crate::Error>; + type Future = SvcFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx).map_err(Into::into) @@ -633,7 +784,36 @@ where req.extensions_mut().insert(self.conn_info.clone()); - self.inner.call(req).instrument(span).map_err(|e| e.into()) + SvcFuture { + inner: self.inner.call(req), + span, + } + } +} + +#[pin_project] +struct SvcFuture { + #[pin] + inner: F, + span: tracing::Span, +} + +impl Future for SvcFuture +where + F: Future, E>>, + E: Into, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into, +{ + type Output = Result, crate::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + let _guard = this.span.enter(); + + let response: Response = ready!(this.inner.poll(cx)).map_err(Into::into)?; + let response = response.map(|body| body.map_err(Into::into).boxed()); + Poll::Ready(Ok(response)) } } @@ -650,11 +830,13 @@ struct MakeSvc { trace_interceptor: Option, } -impl Service<&ServerIo> for MakeSvc +impl Service<&ServerIo> for MakeSvc where - S: Service, Response = Response> + Clone + Send + 'static, + S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, + ResBody: http_body::Body + Send + Sync + 'static, + ResBody::Error: Into, { type Response = BoxService; type Error = crate::Error; @@ -681,11 +863,13 @@ where .layer_fn(|s| GrpcTimeout::new(s, timeout)) .service(svc); - let svc = BoxService::new(Svc { + let svc = Svc { inner: svc, trace_interceptor, conn_info, - }); + }; + + let svc = BoxService::new(svc); future::ready(Ok(svc)) } diff --git a/tonic/src/transport/server/recover_error.rs b/tonic/src/transport/server/recover_error.rs index 2004560c2..72b6bf82b 100644 --- a/tonic/src/transport/server/recover_error.rs +++ b/tonic/src/transport/server/recover_error.rs @@ -1,4 +1,7 @@ -use crate::{body::BoxBody, Status}; +use crate::{ + util::{OptionPin, OptionPinProj}, + Status, +}; use futures_util::ready; use http::Response; use pin_project::pin_project; @@ -22,12 +25,12 @@ impl RecoverError { } } -impl Service for RecoverError +impl Service for RecoverError where - S: Service>, + S: Service>, S::Error: Into, { - type Response = Response; + type Response = Response>; type Error = crate::Error; type Future = ResponseFuture; @@ -48,22 +51,25 @@ pub(crate) struct ResponseFuture { inner: F, } -impl Future for ResponseFuture +impl Future for ResponseFuture where - F: Future, E>>, + F: Future, E>>, E: Into, { - type Output = Result, crate::Error>; + type Output = Result>, crate::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let result: Result, crate::Error> = + let result: Result, crate::Error> = ready!(self.project().inner.poll(cx)).map_err(Into::into); match result { - Ok(res) => Poll::Ready(Ok(res)), + Ok(response) => { + let response = response.map(MaybeEmptyBody::full); + Poll::Ready(Ok(response)) + } Err(err) => { if let Some(status) = Status::try_from_error(&*err) { - let mut res = Response::new(crate::body::empty_body()); + let mut res = Response::new(MaybeEmptyBody::empty()); status.add_header(res.headers_mut()).unwrap(); Poll::Ready(Ok(res)) } else { @@ -73,3 +79,58 @@ where } } } + +#[pin_project] +pub(crate) struct MaybeEmptyBody { + #[pin] + inner: OptionPin, +} + +impl MaybeEmptyBody { + fn full(inner: B) -> Self { + Self { + inner: OptionPin::Some(inner), + } + } + + fn empty() -> Self { + Self { + inner: OptionPin::None, + } + } +} + +impl http_body::Body for MaybeEmptyBody +where + B: http_body::Body + Send, +{ + type Data = B::Data; + type Error = B::Error; + + fn poll_data( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + match self.project().inner.project() { + OptionPinProj::Some(b) => b.poll_data(cx), + OptionPinProj::None => Poll::Ready(None), + } + } + + fn poll_trailers( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>> { + match self.project().inner.project() { + OptionPinProj::Some(b) => b.poll_trailers(cx), + OptionPinProj::None => Poll::Ready(Ok(None)), + } + } + + fn is_end_stream(&self) -> bool { + match &self.inner { + OptionPin::Some(b) => b.is_end_stream(), + OptionPin::None => true, + } + } +} diff --git a/tonic/src/transport/service/grpc_timeout.rs b/tonic/src/transport/service/grpc_timeout.rs index 580addbac..17cea077e 100644 --- a/tonic/src/transport/service/grpc_timeout.rs +++ b/tonic/src/transport/service/grpc_timeout.rs @@ -1,4 +1,5 @@ use crate::metadata::GRPC_TIMEOUT_HEADER; +use crate::util::{OptionPin, OptionPinProj}; use http::{HeaderMap, HeaderValue, Request}; use pin_project::pin_project; use std::{ @@ -97,12 +98,6 @@ where } } -#[pin_project(project = OptionPinProj)] -enum OptionPin { - Some(#[pin] T), - None, -} - const SECONDS_IN_HOUR: u64 = 60 * 60; const SECONDS_IN_MINUTE: u64 = 60; diff --git a/tonic/src/transport/service/router.rs b/tonic/src/transport/service/router.rs index 76f0d0c2b..2eff8cafe 100644 --- a/tonic/src/transport/service/router.rs +++ b/tonic/src/transport/service/router.rs @@ -9,8 +9,9 @@ use std::{ }; use tower_service::Service; +#[doc(hidden)] #[derive(Debug)] -pub(crate) struct Routes { +pub struct Routes { routes: Or, } diff --git a/tonic/src/util.rs b/tonic/src/util.rs new file mode 100644 index 000000000..2dd303df7 --- /dev/null +++ b/tonic/src/util.rs @@ -0,0 +1,13 @@ +//! Various utilities used throughout tonic. + +// some combinations of features might cause things here not to be used +#![allow(dead_code)] + +use pin_project::pin_project; + +/// A pin-project compatible `Option` +#[pin_project(project = OptionPinProj)] +pub(crate) enum OptionPin { + Some(#[pin] T), + None, +}