Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Uniform API for custom headers between clients #814

Merged
merged 13 commits into from
Jul 13, 2022
1 change: 1 addition & 0 deletions client/http-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ thiserror = "1.0"
tokio = { version = "1.16", features = ["time"] }
tracing = "0.1.34"
tracing-futures = "0.2.5"
http = "0.2.0"
lexnv marked this conversation as resolved.
Show resolved Hide resolved

[dev-dependencies]
jsonrpsee-test-utils = { path = "../../test-utils" }
Expand Down
44 changes: 41 additions & 3 deletions client/http-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,29 @@ use serde::de::DeserializeOwned;
use tracing_futures::Instrument;

/// Http Client Builder.
///
/// # Examples
///
/// ```no_run
///
/// use jsonrpsee_http_client::HttpClientBuilder;
///
/// #[tokio::main]
/// async fn main() {
/// // Build custom headers used for every submitted request.
/// let mut headers = http::HeaderMap::new();
/// headers.insert("Any-Header-You-Like", http::HeaderValue::from_static("42"));
///
/// // Build client
/// let client = HttpClientBuilder::default()
/// .set_headers(headers)
/// .build("wss://localhost:443")
/// .unwrap();
///
/// // use client....
/// }
///
/// ```
#[derive(Debug)]
pub struct HttpClientBuilder {
max_request_body_size: u32,
Expand All @@ -47,6 +70,7 @@ pub struct HttpClientBuilder {
certificate_store: CertificateStore,
id_kind: IdKind,
max_log_length: u32,
headers: http::HeaderMap,
}

impl HttpClientBuilder {
Expand Down Expand Up @@ -88,11 +112,24 @@ impl HttpClientBuilder {
self
}

/// Set a custom header passed to the server with every request (default is none).
///
/// The caller is responsible for checking that the headers do not conflict or are duplicated.
pub fn set_headers(mut self, headers: http::HeaderMap) -> Self {
self.headers = headers;
self
}

/// Build the HTTP client with target to connect to.
pub fn build(self, target: impl AsRef<str>) -> Result<HttpClient, Error> {
let transport =
HttpTransportClient::new(target, self.max_request_body_size, self.certificate_store, self.max_log_length)
.map_err(|e| Error::Transport(e.into()))?;
let transport = HttpTransportClient::new(
target,
self.max_request_body_size,
self.certificate_store,
self.max_log_length,
self.headers,
)
.map_err(|e| Error::Transport(e.into()))?;
Ok(HttpClient {
transport,
id_manager: Arc::new(RequestIdManager::new(self.max_concurrent_requests, self.id_kind)),
Expand All @@ -110,6 +147,7 @@ impl Default for HttpClientBuilder {
certificate_store: CertificateStore::Native,
id_kind: IdKind::Number,
max_log_length: 4096,
headers: http::HeaderMap::new(),
}
}
}
Expand Down
87 changes: 70 additions & 17 deletions client/http-client/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ pub struct HttpTransportClient {
///
/// Logs bigger than this limit will be truncated.
max_log_length: u32,
/// Custom headers to pass with every request.
headers: http::HeaderMap,
}

impl HttpTransportClient {
Expand All @@ -57,6 +59,7 @@ impl HttpTransportClient {
max_request_body_size: u32,
cert_store: CertificateStore,
max_log_length: u32,
headers: http::HeaderMap,
) -> Result<Self, Error> {
let target: Uri = target.as_ref().parse().map_err(|e| Error::Url(format!("Invalid URL: {}", e)))?;
if target.port_u16().is_none() {
Expand Down Expand Up @@ -90,7 +93,20 @@ impl HttpTransportClient {
return Err(Error::Url(err.into()));
}
};
Ok(Self { target, client, max_request_body_size, max_log_length })

// Cache request headers: 2 default headers, followed by user custom headers.
// Maintain order for headers in case of duplicate keys:
// https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.2
let mut cached_headers = http::HeaderMap::with_capacity(2 + headers.len());
cached_headers.insert(hyper::header::CONTENT_TYPE, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON));
cached_headers.insert(hyper::header::ACCEPT, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON));
for (key, value) in headers.into_iter() {
if let Some(key) = key {
cached_headers.insert(key, value);
}
}

Ok(Self { target, client, max_request_body_size, max_log_length, headers: cached_headers })
}

async fn inner_send(&self, body: String) -> Result<hyper::Response<hyper::Body>, Error> {
Expand All @@ -100,11 +116,9 @@ impl HttpTransportClient {
return Err(Error::RequestTooLarge);
}

let req = hyper::Request::post(&self.target)
.header(hyper::header::CONTENT_TYPE, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON))
.header(hyper::header::ACCEPT, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON))
.body(From::from(body))
.expect("URI and request headers are valid; qed");
let mut req = hyper::Request::post(&self.target);
req.headers_mut().map(|headers| *headers = self.headers.clone());
let req = req.body(From::from(body)).expect("URI and request headers are valid; qed");

let response = self.client.request(req).await.map_err(|e| Error::Http(Box::new(e)))?;
if response.status().is_success() {
Expand Down Expand Up @@ -198,37 +212,67 @@ mod tests {

#[test]
fn invalid_http_url_rejected() {
let err = HttpTransportClient::new("ws://localhost:9933", 80, CertificateStore::Native, 80).unwrap_err();
let err =
HttpTransportClient::new("ws://localhost:9933", 80, CertificateStore::Native, 80, http::HeaderMap::new())
.unwrap_err();
assert!(matches!(err, Error::Url(_)));
}

#[cfg(feature = "tls")]
#[test]
fn https_works() {
let client = HttpTransportClient::new("https://localhost:9933", 80, CertificateStore::Native, 80).unwrap();
let client = HttpTransportClient::new(
"https://localhost:9933",
80,
CertificateStore::Native,
80,
http::HeaderMap::new(),
)
.unwrap();
assert_target(&client, "localhost", "https", "/", 9933, 80);
}

#[cfg(not(feature = "tls"))]
#[test]
fn https_fails_without_tls_feature() {
let err = HttpTransportClient::new("https://localhost:9933", 80, CertificateStore::Native, 80).unwrap_err();
let err = HttpTransportClient::new(
"https://localhost:9933",
80,
CertificateStore::Native,
80,
http::HeaderMap::new(),
)
.unwrap_err();
assert!(matches!(err, Error::Url(_)));
}

#[test]
fn faulty_port() {
let err = HttpTransportClient::new("http://localhost:-43", 80, CertificateStore::Native, 80).unwrap_err();
let err =
HttpTransportClient::new("http://localhost:-43", 80, CertificateStore::Native, 80, http::HeaderMap::new())
.unwrap_err();
assert!(matches!(err, Error::Url(_)));
let err = HttpTransportClient::new("http://localhost:-99999", 80, CertificateStore::Native, 80).unwrap_err();
let err = HttpTransportClient::new(
"http://localhost:-99999",
80,
CertificateStore::Native,
80,
http::HeaderMap::new(),
)
.unwrap_err();
assert!(matches!(err, Error::Url(_)));
}

#[test]
fn url_with_path_works() {
let client =
HttpTransportClient::new("http://localhost:9944/my-special-path", 1337, CertificateStore::Native, 80)
.unwrap();
let client = HttpTransportClient::new(
"http://localhost:9944/my-special-path",
1337,
CertificateStore::Native,
80,
http::HeaderMap::new(),
)
.unwrap();
assert_target(&client, "localhost", "http", "/my-special-path", 9944, 1337);
}

Expand All @@ -239,22 +283,31 @@ mod tests {
u32::MAX,
CertificateStore::WebPki,
80,
http::HeaderMap::new(),
)
.unwrap();
assert_target(&client, "127.0.0.1", "http", "/my?name1=value1&name2=value2", 9999, u32::MAX);
}

#[test]
fn url_with_fragment_is_ignored() {
let client =
HttpTransportClient::new("http://127.0.0.1:9944/my.htm#ignore", 999, CertificateStore::Native, 80).unwrap();
let client = HttpTransportClient::new(
"http://127.0.0.1:9944/my.htm#ignore",
999,
CertificateStore::Native,
80,
http::HeaderMap::new(),
)
.unwrap();
assert_target(&client, "127.0.0.1", "http", "/my.htm", 9944, 999);
}

#[tokio::test]
async fn request_limit_works() {
let eighty_bytes_limit = 80;
let client = HttpTransportClient::new("http://localhost:9933", 80, CertificateStore::WebPki, 99).unwrap();
let client =
HttpTransportClient::new("http://localhost:9933", 80, CertificateStore::WebPki, 99, http::HeaderMap::new())
.unwrap();
assert_eq!(client.max_request_body_size, eighty_bytes_limit);

let body = "a".repeat(81);
Expand Down
26 changes: 15 additions & 11 deletions client/transport/src/ws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,33 +59,32 @@ pub struct Receiver {

/// Builder for a WebSocket transport [`Sender`] and ['Receiver`] pair.
#[derive(Debug)]
pub struct WsTransportClientBuilder<'a> {
pub struct WsTransportClientBuilder {
/// What certificate store to use
pub certificate_store: CertificateStore,
/// Timeout for the connection.
pub connection_timeout: Duration,
/// Custom headers to pass during the HTTP handshake. If `None`, no
/// custom header is passed.
pub headers: Vec<Header<'a>>,
/// Custom headers to pass during the HTTP handshake.
pub headers: http::HeaderMap,
/// Max payload size
pub max_request_body_size: u32,
/// Max number of redirections.
pub max_redirections: usize,
}

impl<'a> Default for WsTransportClientBuilder<'a> {
impl Default for WsTransportClientBuilder {
fn default() -> Self {
Self {
certificate_store: CertificateStore::Native,
max_request_body_size: TEN_MB_SIZE_BYTES,
connection_timeout: Duration::from_secs(10),
headers: Vec::new(),
headers: http::HeaderMap::new(),
max_redirections: 5,
}
}
}

impl<'a> WsTransportClientBuilder<'a> {
impl WsTransportClientBuilder {
/// Set whether to use system certificates (default is native).
pub fn certificate_store(mut self, certificate_store: CertificateStore) -> Self {
self.certificate_store = certificate_store;
Expand All @@ -107,8 +106,8 @@ impl<'a> WsTransportClientBuilder<'a> {
/// Set a custom header passed to the server during the handshake (default is none).
///
/// The caller is responsible for checking that the headers do not conflict or are duplicated.
pub fn add_header(mut self, name: &'a str, value: &'a str) -> Self {
self.headers.push(Header { name, value: value.as_bytes() });
pub fn set_headers(mut self, headers: http::HeaderMap) -> Self {
self.headers = headers;
self
}

Expand Down Expand Up @@ -240,7 +239,7 @@ impl TransportReceiverT for Receiver {
}
}

impl<'a> WsTransportClientBuilder<'a> {
impl WsTransportClientBuilder {
/// Try to establish the connection.
pub async fn build(self, uri: Uri) -> Result<(Sender, Receiver), WsHandshakeError> {
let target: Target = uri.try_into()?;
Expand Down Expand Up @@ -289,7 +288,12 @@ impl<'a> WsTransportClientBuilder<'a> {
&target.path_and_query,
);

client.set_headers(&self.headers);
let headers: Vec<_> = self
.headers
.iter()
.map(|(key, value)| Header { name: key.as_str(), value: value.as_bytes() })
.collect();
client.set_headers(&headers);

// Perform the initial handshake.
match client.handshake().await {
Expand Down
1 change: 1 addition & 0 deletions client/ws-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ documentation = "https://docs.rs/jsonrpsee-ws-client"
jsonrpsee-types = { path = "../../types", version = "0.14.0" }
jsonrpsee-client-transport = { path = "../transport", version = "0.14.0", features = ["ws"] }
jsonrpsee-core = { path = "../../core", version = "0.14.0", features = ["async-client"] }
http = "0.2.0"

[dev-dependencies]
env_logger = "0.9"
Expand Down
Loading