diff --git a/tonic/src/transport/server.rs b/tonic/src/transport/server.rs index 02701b44a..a16657440 100644 --- a/tonic/src/transport/server.rs +++ b/tonic/src/transport/server.rs @@ -261,7 +261,6 @@ impl Server { } }, ); - let svc = MakeSvc { inner: svc, interceptor, @@ -280,6 +279,71 @@ impl Server { Ok(()) } + + pub(crate) async fn serve_with_shutdown( + self, + addr: SocketAddr, + svc: S, + signal: F, + ) -> Result<(), super::Error> + where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into + Send, + F: Future, + { + let interceptor = self.interceptor.clone(); + let concurrency_limit = self.concurrency_limit; + let init_connection_window_size = self.init_connection_window_size; + let init_stream_window_size = self.init_stream_window_size; + let max_concurrent_streams = self.max_concurrent_streams; + // let timeout = self.timeout.clone(); + + let incoming = hyper::server::accept::from_stream::<_, _, crate::Error>( + async_stream::try_stream! { + let mut tcp = TcpIncoming::bind(addr)? + .set_nodelay(self.tcp_nodelay) + .set_keepalive(self.tcp_keepalive); + + while let Some(stream) = tcp.try_next().await? { + #[cfg(feature = "tls")] + { + if let Some(tls) = &self.tls { + let io = match tls.connect(stream.into_inner()).await { + Ok(io) => io, + Err(error) => { + error!(message = "Unable to accept incoming connection.", %error); + continue + }, + }; + yield BoxedIo::new(io); + continue; + } + } + + yield BoxedIo::new(stream); + } + }, + ); + + let svc = MakeSvc { + inner: svc, + interceptor, + concurrency_limit, + // timeout, + }; + hyper::Server::builder(incoming) + .http2_only(true) + .http2_initial_connection_window_size(init_connection_window_size) + .http2_initial_stream_window_size(init_stream_window_size) + .http2_max_concurrent_streams(max_concurrent_streams) + .serve(svc) + .with_graceful_shutdown(signal) + .await + .map_err(map_err)?; + + Ok(()) + } } impl Router { @@ -348,6 +412,19 @@ where pub async fn serve(self, addr: SocketAddr) -> Result<(), super::Error> { self.server.serve(addr, self.routes).await } + + /// Consume this [`Server`] creating a future that will execute the server + /// on [`tokio`]'s default executor. And shutdown when the provided signal + /// is received. + /// + /// [`Server`]: struct.Server.html + pub async fn serve_with_shutdown>( + self, + addr: SocketAddr, + f: F, + ) -> Result<(), super::Error> { + self.server.serve_with_shutdown(addr, self.routes, f).await + } } fn map_err(e: impl Into) -> super::Error {