Skip to content

Commit

Permalink
feat(transport): provide generic access to connect info
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpdrsn committed May 17, 2021
1 parent f613386 commit f1c5257
Show file tree
Hide file tree
Showing 8 changed files with 380 additions and 87 deletions.
15 changes: 14 additions & 1 deletion examples/src/tls/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use futures::Stream;
use pb::{EchoRequest, EchoResponse};
use std::pin::Pin;
use tonic::{
transport::{Identity, Server, ServerTlsConfig},
transport::{
server::{TcpConnectInfo, TlsConnectInfo},
Identity, Server, ServerTlsConfig,
},
Request, Response, Status, Streaming,
};

Expand All @@ -19,6 +22,16 @@ pub struct EchoServer;
#[tonic::async_trait]
impl pb::echo_server::Echo for EchoServer {
async fn unary_echo(&self, request: Request<EchoRequest>) -> EchoResult<EchoResponse> {
let conn_info = request
.extensions()
.get::<TlsConnectInfo<TcpConnectInfo>>()
.unwrap();
println!(
"Got a request from {:?} with info {:?}",
request.remote_addr(),
conn_info
);

let message = request.into_inner().message;
Ok(Response::new(EchoResponse { message }))
}
Expand Down
24 changes: 22 additions & 2 deletions examples/src/uds/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ impl Greeter for MyGreeter {
&self,
request: Request<HelloRequest>,
) -> Result<Response<HelloReply>, Status> {
println!("Got a request: {:?}", request);
#[cfg(unix)]
{
let conn_info = request.extensions().get::<unix::UdsConnectInfo>().unwrap();
println!("Got a request {:?} with info {:?}", request, conn_info);
}

let reply = hello_world::HelloReply {
message: format!("Hello {}!", request.into_inner().name),
Expand Down Expand Up @@ -64,6 +68,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
mod unix {
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};

Expand All @@ -73,7 +78,22 @@ mod unix {
#[derive(Debug)]
pub struct UnixStream(pub tokio::net::UnixStream);

impl Connected for UnixStream {}
impl Connected for UnixStream {
type ConnectInfo = UdsConnectInfo;

fn connect_info(&self) -> Self::ConnectInfo {
UdsConnectInfo {
peer_addr: self.0.peer_addr().ok().map(Arc::new),
peer_cred: self.0.peer_cred().ok(),
}
}
}

#[derive(Clone, Debug)]
pub struct UdsConnectInfo {
pub peer_addr: Option<Arc<tokio::net::unix::SocketAddr>>,
pub peer_cred: Option<tokio::net::unix::UCred>,
}

impl AsyncRead for UnixStream {
fn poll_read(
Expand Down
50 changes: 50 additions & 0 deletions tests/integration_tests/tests/connect_info.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use futures_util::FutureExt;
use integration_tests::pb::{test_client, test_server, Input, Output};
use std::time::Duration;
use tokio::sync::oneshot;
use tonic::{
transport::{server::TcpConnectInfo, Endpoint, Server},
Request, Response, Status,
};

#[tokio::test]
async fn getting_connect_info() {
struct Svc;

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
assert!(req.remote_addr().is_some());
assert!(req.extensions().get::<TcpConnectInfo>().is_some());

Ok(Response::new(Output {}))
}
}

let svc = test_server::TestServer::new(Svc);

let (tx, rx) = oneshot::channel::<()>();

let jh = tokio::spawn(async move {
Server::builder()
.add_service(svc)
.serve_with_shutdown("127.0.0.1:1400".parse().unwrap(), rx.map(drop))
.await
.unwrap();
});

tokio::time::sleep(Duration::from_millis(100)).await;

let channel = Endpoint::from_static("http://127.0.0.1:1400")
.connect()
.await
.unwrap();

let mut client = test_client::TestClient::new(channel);

client.unary_call(Input {}).await.unwrap();

tx.send(()).unwrap();

jh.await.unwrap();
}
52 changes: 39 additions & 13 deletions tonic/src/request.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::metadata::{MetadataMap, MetadataValue};
#[cfg(all(feature = "transport", feature = "tls"))]
use crate::transport::server::TlsConnectInfo;
#[cfg(feature = "transport")]
use crate::transport::Certificate;
use crate::transport::{server::TcpConnectInfo, Certificate};
use crate::Extensions;
use futures_core::Stream;
#[cfg(feature = "transport")]
Expand All @@ -15,13 +17,6 @@ pub struct Request<T> {
extensions: Extensions,
}

#[derive(Clone)]
pub(crate) struct ConnectionInfo {
pub(crate) remote_addr: Option<SocketAddr>,
#[cfg(feature = "transport")]
pub(crate) peer_certs: Option<Arc<Vec<Certificate>>>,
}

/// Trait implemented by RPC request types.
///
/// Types implementing this trait can be used as arguments to client RPC
Expand Down Expand Up @@ -203,7 +198,32 @@ impl<T> Request<T> {
/// does not implement `Connected`. This currently,
/// only works on the server side.
pub fn remote_addr(&self) -> Option<SocketAddr> {
self.get::<ConnectionInfo>()?.remote_addr
#[cfg(feature = "transport")]
{
#[cfg(feature = "tls")]
{
self.extensions()
.get::<TcpConnectInfo>()
.and_then(|i| i.remote_addr())
.or_else(|| {
self.extensions()
.get::<TlsConnectInfo<TcpConnectInfo>>()
.and_then(|i| i.get_ref().remote_addr())
})
}

#[cfg(not(feature = "tls"))]
{
self.extensions()
.get::<TcpConnectInfo>()
.and_then(|i| i.remote_addr())
}
}

#[cfg(not(feature = "transport"))]
{
None
}
}

/// Get the peer certificates of the connected client.
Expand All @@ -215,11 +235,17 @@ impl<T> Request<T> {
#[cfg(feature = "transport")]
#[cfg_attr(docsrs, doc(cfg(feature = "transport")))]
pub fn peer_certs(&self) -> Option<Arc<Vec<Certificate>>> {
self.get::<ConnectionInfo>()?.peer_certs.clone()
}
#[cfg(feature = "tls")]
{
self.extensions()
.get::<TlsConnectInfo<TcpConnectInfo>>()
.and_then(|i| i.peer_certs())
}

pub(crate) fn get<I: Send + Sync + 'static>(&self) -> Option<&I> {
self.extensions.get::<I>()
#[cfg(not(feature = "tls"))]
{
None
}
}

/// Set the max duration the request is allowed to take.
Expand Down
148 changes: 124 additions & 24 deletions tonic/src/transport/server/conn.rs
Original file line number Diff line number Diff line change
@@ -1,58 +1,158 @@
use crate::transport::Certificate;
use hyper::server::conn::AddrStream;
use std::net::SocketAddr;
use tokio::net::TcpStream;

#[cfg(feature = "tls")]
use crate::transport::Certificate;
#[cfg(feature = "tls")]
use std::sync::Arc;
#[cfg(feature = "tls")]
use tokio_rustls::{rustls::Session, server::TlsStream};

/// Trait that connected IO resources implement.
/// Trait that connected IO resources implement and use to produce info about the connection.
///
/// The goal for this trait is to allow users to implement
/// custom IO types that can still provide the same connection
/// metadata.
///
/// # Example
///
/// The `ConnectInfo` returned will be accessible through [request extensions][ext]:
///
/// ```
/// use tonic::{Request, transport::server::Connected};
///
/// // A `Stream` that yields connections
/// struct MyConnector {}
///
/// // Return metadata about the connection as `MyConnectInfo`
/// impl Connected for MyConnector {
/// type ConnectInfo = MyConnectInfo;
///
/// fn connect_info(&self) -> Self::ConnectInfo {
/// MyConnectInfo {}
/// }
/// }
///
/// #[derive(Clone)]
/// struct MyConnectInfo {
/// // Metadata about your connection
/// }
///
/// // The connect info can be accessed through request extensions:
/// # fn foo(request: Request<()>) {
/// let connect_info: &MyConnectInfo = request
/// .extensions()
/// .get::<MyConnectInfo>()
/// .expect("bug in tonic");
/// # }
/// ```
///
/// [ext]: crate::Request::extensions
pub trait Connected {
/// Return the remote address this IO resource is connected too.
fn remote_addr(&self) -> Option<SocketAddr> {
None
}
/// The connection info type the IO resources generates.
// all these bounds are necessary to set this as a request extension
type ConnectInfo: Clone + Send + Sync + 'static;

/// Return the set of connected peer TLS certificates.
fn peer_certs(&self) -> Option<Vec<Certificate>> {
None
/// Create type holding information about the connection.
fn connect_info(&self) -> Self::ConnectInfo;
}

/// Connection info for standard TCP streams.
///
/// This type will be accessible through [request extensions][ext] if you're using the default
/// non-TLS connector.
///
/// See [`Connected`] for more details.
///
/// [ext]: crate::Request::extensions
#[derive(Debug, Clone)]
pub struct TcpConnectInfo {
remote_addr: Option<SocketAddr>,
}

impl TcpConnectInfo {
/// Return the remote address the IO resource is connected too.
pub fn remote_addr(&self) -> Option<SocketAddr> {
self.remote_addr
}
}

impl Connected for AddrStream {
fn remote_addr(&self) -> Option<SocketAddr> {
Some(self.remote_addr())
type ConnectInfo = TcpConnectInfo;

fn connect_info(&self) -> Self::ConnectInfo {
TcpConnectInfo {
remote_addr: Some(self.remote_addr()),
}
}
}

impl Connected for TcpStream {
fn remote_addr(&self) -> Option<SocketAddr> {
self.peer_addr().ok()
type ConnectInfo = TcpConnectInfo;

fn connect_info(&self) -> Self::ConnectInfo {
TcpConnectInfo {
remote_addr: self.peer_addr().ok(),
}
}
}

#[cfg(feature = "tls")]
impl<T: Connected> Connected for TlsStream<T> {
fn remote_addr(&self) -> Option<SocketAddr> {
let (inner, _) = self.get_ref();

inner.remote_addr()
}
impl<T> Connected for TlsStream<T>
where
T: Connected,
{
type ConnectInfo = TlsConnectInfo<T::ConnectInfo>;

fn peer_certs(&self) -> Option<Vec<Certificate>> {
let (_, session) = self.get_ref();
fn connect_info(&self) -> Self::ConnectInfo {
let (inner, session) = self.get_ref();
let inner = inner.connect_info();

if let Some(certs) = session.get_peer_certificates() {
let certs = if let Some(certs) = session.get_peer_certificates() {
let certs = certs
.into_iter()
.map(|c| Certificate::from_pem(c.0))
.collect();
Some(certs)
Some(Arc::new(certs))
} else {
None
}
};

TlsConnectInfo { inner, certs }
}
}

/// Connection info for TLS streams.
///
/// This type will be accessible through [request extensions][ext] if you're using a TLS connector.
///
/// See [`Connected`] for more details.
///
/// [ext]: crate::Request::extensions
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
#[derive(Debug, Clone)]
pub struct TlsConnectInfo<T> {
inner: T,
certs: Option<Arc<Vec<Certificate>>>,
}

#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
impl<T> TlsConnectInfo<T> {
/// Get a reference to the underlying connection info.
pub fn get_ref(&self) -> &T {
&self.inner
}

/// Get a mutable reference to the underlying connection info.
pub fn get_mut(&mut self) -> &mut T {
&mut self.inner
}

/// Return the set of connected peer TLS certificates.
pub fn peer_certs(&self) -> Option<Arc<Vec<Certificate>>> {
self.certs.clone()
}
}
Loading

0 comments on commit f1c5257

Please sign in to comment.