Skip to content

Commit

Permalink
feat(transport): Allow custom IO and UDS example (#184)
Browse files Browse the repository at this point in the history
Closes #136
  • Loading branch information
LucioFranco authored Dec 13, 2019
1 parent 7077d8d commit b90c340
Show file tree
Hide file tree
Showing 12 changed files with 306 additions and 115 deletions.
10 changes: 9 additions & 1 deletion examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,20 @@ path = "src/tracing/client.rs"
name = "tracing-server"
path = "src/tracing/server.rs"

[[bin]]
name = "uds-client"
path = "src/uds/client.rs"

[[bin]]
name = "uds-server"
path = "src/uds/server.rs"

[dependencies]
tonic = { path = "../tonic", features = ["tls"] }
bytes = "0.4"
prost = "0.5"

tokio = { version = "0.2", features = ["rt-threaded", "time", "stream", "fs", "macros"] }
tokio = { version = "0.2", features = ["rt-threaded", "time", "stream", "fs", "macros", "uds"] }
futures = { version = "0.3", default-features = false, features = ["alloc"]}
async-stream = "0.2"
http = "0.2"
Expand Down
39 changes: 39 additions & 0 deletions examples/src/uds/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#[cfg(unix)]

pub mod hello_world {
tonic::include_proto!("helloworld");
}

use hello_world::{greeter_client::GreeterClient, HelloRequest};
use http::Uri;
use std::convert::TryFrom;
use tokio::net::UnixStream;
use tonic::transport::Endpoint;
use tower::service_fn;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// We will ignore this uri because uds do not use it
// if your connector does use the uri it will be provided
// as the request to the `MakeConnection`.
let channel = Endpoint::try_from("lttp://[::]:50051")?
.connect_with_connector(service_fn(|_: Uri| {
let path = "/tmp/tonic/helloworld";

// Connect to a Uds socket
UnixStream::connect(path)
}))
.await?;

let mut client = GreeterClient::new(channel);

let request = tonic::Request::new(HelloRequest {
name: "Tonic".into(),
});

let response = client.say_hello(request).await?;

println!("RESPONSE={:?}", response);

Ok(())
}
48 changes: 48 additions & 0 deletions examples/src/uds/server.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use std::path::Path;
use tokio::net::UnixListener;
use tonic::{transport::Server, Request, Response, Status};

pub mod hello_world {
tonic::include_proto!("helloworld");
}

use hello_world::{
greeter_server::{Greeter, GreeterServer},
HelloReply, HelloRequest,
};

#[derive(Default)]
pub struct MyGreeter {}

#[tonic::async_trait]
impl Greeter for MyGreeter {
async fn say_hello(
&self,
request: Request<HelloRequest>,
) -> Result<Response<HelloReply>, Status> {
println!("Got a request: {:?}", request);

let reply = hello_world::HelloReply {
message: format!("Hello {}!", request.into_inner().name).into(),
};
Ok(Response::new(reply))
}
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let path = "/tmp/tonic/helloworld";

tokio::fs::create_dir_all(Path::new(path).parent().unwrap()).await?;

let mut uds = UnixListener::bind(path)?;

let greeter = MyGreeter::default();

Server::builder()
.add_service(GreeterServer::new(greeter))
.serve_with_incoming(uds.incoming())
.await?;

Ok(())
}
32 changes: 31 additions & 1 deletion tonic/src/transport/channel/endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::super::service;
use super::Channel;
#[cfg(feature = "tls")]
use super::ClientTlsConfig;
Expand All @@ -12,6 +13,7 @@ use std::{
sync::Arc,
time::Duration,
};
use tower_make::MakeConnection;

/// Channel builder.
///
Expand Down Expand Up @@ -182,7 +184,35 @@ impl Endpoint {

/// Create a channel from this config.
pub async fn connect(&self) -> Result<Channel, Error> {
Channel::connect(self.clone()).await
let mut http = hyper::client::connect::HttpConnector::new();
http.enforce_http(false);
http.set_nodelay(self.tcp_nodelay);
http.set_keepalive(self.tcp_keepalive);

#[cfg(feature = "tls")]
let connector = service::connector(http, self.tls.clone());

#[cfg(not(feature = "tls"))]
let connector = service::connector(http);

Channel::connect(connector, self.clone()).await
}

/// Connect with a custom connector.
pub async fn connect_with_connector<C>(&self, connector: C) -> Result<Channel, Error>
where
C: MakeConnection<Uri> + Send + 'static,
C::Connection: Unpin + Send + 'static,
C::Future: Send + 'static,
crate::Error: From<C::Error> + Send + 'static,
{
#[cfg(feature = "tls")]
let connector = service::connector(connector, self.tls.clone());

#[cfg(not(feature = "tls"))]
let connector = service::connector(connector);

Channel::connect(connector, self.clone()).await
}
}

Expand Down
12 changes: 10 additions & 2 deletions tonic/src/transport/channel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ use http::{
uri::{InvalidUri, Uri},
Request, Response,
};
use hyper::client::connect::Connection as HyperConnection;
use std::{
fmt,
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite};
use tower::{
buffer::{self, Buffer},
discover::Discover,
Expand Down Expand Up @@ -121,11 +123,17 @@ impl Channel {
Self::balance(discover, buffer_size, interceptor_headers)
}

pub(crate) async fn connect(endpoint: Endpoint) -> Result<Self, super::Error> {
pub(crate) async fn connect<C>(connector: C, endpoint: Endpoint) -> Result<Self, super::Error>
where
C: Service<Uri> + Send + 'static,
C::Error: Into<crate::Error> + Send,
C::Future: Unpin + Send,
C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static,
{
let buffer_size = endpoint.buffer_size.clone().unwrap_or(DEFAULT_BUFFER_SIZE);
let interceptor_headers = endpoint.interceptor_headers.clone();

let svc = Connection::new(endpoint)
let svc = Connection::new(connector, endpoint)
.await
.map_err(|e| super::Error::from_source(super::ErrorKind::Client, e))?;

Expand Down
75 changes: 75 additions & 0 deletions tonic/src/transport/server/incoming.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use super::Server;
use crate::transport::service::BoxedIo;
use futures_core::Stream;
use futures_util::stream::TryStreamExt;
use hyper::server::{
accept::Accept,
conn::{AddrIncoming, AddrStream},
};
use std::{
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use tokio::io::{AsyncRead, AsyncWrite};
#[cfg(feature = "tls")]
use tracing::error;

#[cfg_attr(not(feature = "tls"), allow(unused_variables))]
pub(crate) fn tcp_incoming<IO, IE>(
incoming: impl Stream<Item = Result<IO, IE>>,
server: Server,
) -> impl Stream<Item = Result<BoxedIo, crate::Error>>
where
IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
IE: Into<crate::Error>,
{
async_stream::try_stream! {
futures_util::pin_mut!(incoming);

while let Some(stream) = incoming.try_next().await? {
#[cfg(feature = "tls")]
{
if let Some(tls) = &server.tls {
let io = match tls.accept(stream).await {
Ok(io) => io,
Err(error) => {
error!(message = "Unable to accept incoming connection.", %error);
continue
},
};
yield BoxedIo::new(io);
continue;
}
}

yield BoxedIo::new(stream);
}
}
}

pub(crate) struct TcpIncoming {
inner: AddrIncoming,
}

impl TcpIncoming {
pub(crate) fn new(
addr: SocketAddr,
nodelay: bool,
keepalive: Option<Duration>,
) -> Result<Self, crate::Error> {
let mut inner = AddrIncoming::bind(&addr)?;
inner.set_nodelay(nodelay);
inner.set_keepalive(keepalive);
Ok(TcpIncoming { inner })
}
}

impl Stream for TcpIncoming {
type Item = Result<AddrStream, std::io::Error>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.inner).poll_accept(cx)
}
}
Loading

0 comments on commit b90c340

Please sign in to comment.