Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(transport): provide generic access to connect info #647

Merged
merged 3 commits into from
May 28, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion examples/src/tls/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use futures::Stream;
use pb::{EchoRequest, EchoResponse};
use std::pin::Pin;
use tonic::{
transport::{Identity, Server, ServerTlsConfig},
transport::{
server::{TcpConnectInfo, TlsConnectInfo},
Identity, Server, ServerTlsConfig,
},
Request, Response, Status, Streaming,
};

Expand All @@ -19,6 +22,16 @@ pub struct EchoServer;
#[tonic::async_trait]
impl pb::echo_server::Echo for EchoServer {
async fn unary_echo(&self, request: Request<EchoRequest>) -> EchoResult<EchoResponse> {
let conn_info = request
.extensions()
.get::<TlsConnectInfo<TcpConnectInfo>>()
.unwrap();
println!(
"Got a request from {:?} with info {:?}",
request.remote_addr(),
conn_info
);

let message = request.into_inner().message;
Ok(Response::new(EchoResponse { message }))
}
Expand Down
24 changes: 22 additions & 2 deletions examples/src/uds/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ impl Greeter for MyGreeter {
&self,
request: Request<HelloRequest>,
) -> Result<Response<HelloReply>, Status> {
println!("Got a request: {:?}", request);
#[cfg(unix)]
{
let conn_info = request.extensions().get::<unix::UdsConnectInfo>().unwrap();
println!("Got a request {:?} with info {:?}", request, conn_info);
}

let reply = hello_world::HelloReply {
message: format!("Hello {}!", request.into_inner().name),
Expand Down Expand Up @@ -64,6 +68,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
mod unix {
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};

Expand All @@ -73,7 +78,22 @@ mod unix {
#[derive(Debug)]
pub struct UnixStream(pub tokio::net::UnixStream);

impl Connected for UnixStream {}
impl Connected for UnixStream {
type ConnectInfo = UdsConnectInfo;

fn connect_info(&self) -> Self::ConnectInfo {
UdsConnectInfo {
peer_addr: self.0.peer_addr().ok().map(Arc::new),
peer_cred: self.0.peer_cred().ok(),
}
}
}

#[derive(Clone, Debug)]
pub struct UdsConnectInfo {
pub peer_addr: Option<Arc<tokio::net::unix::SocketAddr>>,
pub peer_cred: Option<tokio::net::unix::UCred>,
}

impl AsyncRead for UnixStream {
fn poll_read(
Expand Down
50 changes: 50 additions & 0 deletions tests/integration_tests/tests/connect_info.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use futures_util::FutureExt;
use integration_tests::pb::{test_client, test_server, Input, Output};
use std::time::Duration;
use tokio::sync::oneshot;
use tonic::{
transport::{server::TcpConnectInfo, Endpoint, Server},
Request, Response, Status,
};

#[tokio::test]
async fn getting_connect_info() {
struct Svc;

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
assert!(req.remote_addr().is_some());
assert!(req.extensions().get::<TcpConnectInfo>().is_some());

Ok(Response::new(Output {}))
}
}

let svc = test_server::TestServer::new(Svc);

let (tx, rx) = oneshot::channel::<()>();

let jh = tokio::spawn(async move {
Server::builder()
.add_service(svc)
.serve_with_shutdown("127.0.0.1:1400".parse().unwrap(), rx.map(drop))
.await
.unwrap();
});

tokio::time::sleep(Duration::from_millis(100)).await;

let channel = Endpoint::from_static("http://127.0.0.1:1400")
.connect()
.await
.unwrap();

let mut client = test_client::TestClient::new(channel);

client.unary_call(Input {}).await.unwrap();

tx.send(()).unwrap();

jh.await.unwrap();
}
52 changes: 39 additions & 13 deletions tonic/src/request.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::metadata::{MetadataMap, MetadataValue};
#[cfg(all(feature = "transport", feature = "tls"))]
use crate::transport::server::TlsConnectInfo;
#[cfg(feature = "transport")]
use crate::transport::Certificate;
use crate::transport::{server::TcpConnectInfo, Certificate};
use crate::Extensions;
use futures_core::Stream;
#[cfg(feature = "transport")]
Expand All @@ -15,13 +17,6 @@ pub struct Request<T> {
extensions: Extensions,
}

#[derive(Clone)]
pub(crate) struct ConnectionInfo {
pub(crate) remote_addr: Option<SocketAddr>,
#[cfg(feature = "transport")]
pub(crate) peer_certs: Option<Arc<Vec<Certificate>>>,
}

/// Trait implemented by RPC request types.
///
/// Types implementing this trait can be used as arguments to client RPC
Expand Down Expand Up @@ -203,7 +198,32 @@ impl<T> Request<T> {
/// does not implement `Connected`. This currently,
/// only works on the server side.
pub fn remote_addr(&self) -> Option<SocketAddr> {
self.get::<ConnectionInfo>()?.remote_addr
#[cfg(feature = "transport")]
{
#[cfg(feature = "tls")]
{
self.extensions()
.get::<TcpConnectInfo>()
.and_then(|i| i.remote_addr())
.or_else(|| {
self.extensions()
.get::<TlsConnectInfo<TcpConnectInfo>>()
.and_then(|i| i.get_ref().remote_addr())
})
}

#[cfg(not(feature = "tls"))]
{
self.extensions()
.get::<TcpConnectInfo>()
.and_then(|i| i.remote_addr())
}
}

#[cfg(not(feature = "transport"))]
{
None
}
}

/// Get the peer certificates of the connected client.
Expand All @@ -215,11 +235,17 @@ impl<T> Request<T> {
#[cfg(feature = "transport")]
#[cfg_attr(docsrs, doc(cfg(feature = "transport")))]
pub fn peer_certs(&self) -> Option<Arc<Vec<Certificate>>> {
self.get::<ConnectionInfo>()?.peer_certs.clone()
}
#[cfg(feature = "tls")]
{
self.extensions()
.get::<TlsConnectInfo<TcpConnectInfo>>()
.and_then(|i| i.peer_certs())
}

pub(crate) fn get<I: Send + Sync + 'static>(&self) -> Option<&I> {
self.extensions.get::<I>()
#[cfg(not(feature = "tls"))]
{
None
}
}

/// Set the max duration the request is allowed to take.
Expand Down
148 changes: 124 additions & 24 deletions tonic/src/transport/server/conn.rs
Original file line number Diff line number Diff line change
@@ -1,58 +1,158 @@
use crate::transport::Certificate;
use hyper::server::conn::AddrStream;
use std::net::SocketAddr;
use tokio::net::TcpStream;

#[cfg(feature = "tls")]
use crate::transport::Certificate;
#[cfg(feature = "tls")]
use std::sync::Arc;
#[cfg(feature = "tls")]
use tokio_rustls::{rustls::Session, server::TlsStream};

/// Trait that connected IO resources implement.
/// Trait that connected IO resources implement and use to produce info about the connection.
///
/// The goal for this trait is to allow users to implement
/// custom IO types that can still provide the same connection
/// metadata.
///
/// # Example
///
/// The `ConnectInfo` returned will be accessible through [request extensions][ext]:
///
/// ```
/// use tonic::{Request, transport::server::Connected};
///
/// // A `Stream` that yields connections
/// struct MyConnector {}
///
/// // Return metadata about the connection as `MyConnectInfo`
/// impl Connected for MyConnector {
/// type ConnectInfo = MyConnectInfo;
///
/// fn connect_info(&self) -> Self::ConnectInfo {
/// MyConnectInfo {}
/// }
/// }
///
/// #[derive(Clone)]
/// struct MyConnectInfo {
/// // Metadata about your connection
/// }
///
/// // The connect info can be accessed through request extensions:
/// # fn foo(request: Request<()>) {
/// let connect_info: &MyConnectInfo = request
/// .extensions()
/// .get::<MyConnectInfo>()
/// .expect("bug in tonic");
/// # }
/// ```
///
/// [ext]: crate::Request::extensions
pub trait Connected {
/// Return the remote address this IO resource is connected too.
fn remote_addr(&self) -> Option<SocketAddr> {
None
}
/// The connection info type the IO resources generates.
// all these bounds are necessary to set this as a request extension
type ConnectInfo: Clone + Send + Sync + 'static;

/// Return the set of connected peer TLS certificates.
fn peer_certs(&self) -> Option<Vec<Certificate>> {
None
/// Create type holding information about the connection.
fn connect_info(&self) -> Self::ConnectInfo;
}

/// Connection info for standard TCP streams.
///
/// This type will be accessible through [request extensions][ext] if you're using the default
/// non-TLS connector.
///
/// See [`Connected`] for more details.
///
/// [ext]: crate::Request::extensions
#[derive(Debug, Clone)]
pub struct TcpConnectInfo {
remote_addr: Option<SocketAddr>,
}

impl TcpConnectInfo {
/// Return the remote address the IO resource is connected too.
pub fn remote_addr(&self) -> Option<SocketAddr> {
self.remote_addr
}
}

impl Connected for AddrStream {
fn remote_addr(&self) -> Option<SocketAddr> {
Some(self.remote_addr())
type ConnectInfo = TcpConnectInfo;

fn connect_info(&self) -> Self::ConnectInfo {
TcpConnectInfo {
remote_addr: Some(self.remote_addr()),
}
}
}

impl Connected for TcpStream {
fn remote_addr(&self) -> Option<SocketAddr> {
self.peer_addr().ok()
type ConnectInfo = TcpConnectInfo;

fn connect_info(&self) -> Self::ConnectInfo {
TcpConnectInfo {
remote_addr: self.peer_addr().ok(),
}
}
}

#[cfg(feature = "tls")]
impl<T: Connected> Connected for TlsStream<T> {
fn remote_addr(&self) -> Option<SocketAddr> {
let (inner, _) = self.get_ref();

inner.remote_addr()
}
impl<T> Connected for TlsStream<T>
where
T: Connected,
{
type ConnectInfo = TlsConnectInfo<T::ConnectInfo>;

fn peer_certs(&self) -> Option<Vec<Certificate>> {
let (_, session) = self.get_ref();
fn connect_info(&self) -> Self::ConnectInfo {
let (inner, session) = self.get_ref();
let inner = inner.connect_info();

if let Some(certs) = session.get_peer_certificates() {
let certs = if let Some(certs) = session.get_peer_certificates() {
let certs = certs
.into_iter()
.map(|c| Certificate::from_pem(c.0))
.collect();
Some(certs)
Some(Arc::new(certs))
} else {
None
}
};

TlsConnectInfo { inner, certs }
}
}

/// Connection info for TLS streams.
///
/// This type will be accessible through [request extensions][ext] if you're using a TLS connector.
///
/// See [`Connected`] for more details.
///
/// [ext]: crate::Request::extensions
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
#[derive(Debug, Clone)]
pub struct TlsConnectInfo<T> {
inner: T,
certs: Option<Arc<Vec<Certificate>>>,
}

#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
impl<T> TlsConnectInfo<T> {
/// Get a reference to the underlying connection info.
pub fn get_ref(&self) -> &T {
&self.inner
}

/// Get a mutable reference to the underlying connection info.
pub fn get_mut(&mut self) -> &mut T {
&mut self.inner
}

/// Return the set of connected peer TLS certificates.
pub fn peer_certs(&self) -> Option<Arc<Vec<Certificate>>> {
self.certs.clone()
}
}
Loading