Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: feat: client rpc middleware #1521

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions client/http-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ publish = true
[dependencies]
async-trait = { workspace = true }
base64 = { workspace = true }
futures-util = { workspace = true }
hyper = { workspace = true, features = ["client", "http1", "http2"] }
hyper-rustls = { workspace = true, features = ["http1", "http2", "tls12", "logging", "ring"], optional = true }
hyper-util = { workspace = true, features = ["client", "client-legacy", "tokio", "http1", "http2"] }
Expand Down
162 changes: 70 additions & 92 deletions client/http-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,21 @@ use std::fmt;
use std::sync::Arc;
use std::time::Duration;

use crate::transport::{self, Error as TransportError, HttpBackend, HttpTransportClient, HttpTransportClientBuilder};
use crate::types::{NotificationSer, RequestSer, Response};
use crate::rpc_service::RpcService;
use crate::transport::{self, Error as TransportError, HttpBackend, HttpTransportClientBuilder};
use crate::types::Response;
use crate::{HttpRequest, HttpResponse};
use async_trait::async_trait;
use hyper::body::Bytes;
use hyper::http::HeaderMap;
use jsonrpsee_core::client::{
generate_batch_id_range, BatchResponse, ClientT, Error, IdKind, RequestIdManager, Subscription, SubscriptionClientT,
};
use jsonrpsee_core::middleware::{RpcServiceBuilder, RpcServiceT};
use jsonrpsee_core::params::BatchRequestBuilder;
use jsonrpsee_core::traits::ToRpcParams;
use jsonrpsee_core::{BoxError, JsonRawValue, TEN_MB_SIZE_BYTES};
use jsonrpsee_types::{ErrorObject, InvalidRequestId, ResponseSuccess, TwoPointZero};
use jsonrpsee_types::{InvalidRequestId, Notification, Request, ResponseSuccess, TwoPointZero};
use serde::de::DeserializeOwned;
use tokio::sync::Semaphore;
use tower::layer::util::Identity;
Expand Down Expand Up @@ -75,7 +77,7 @@ use crate::{CertificateStore, CustomCertStore};
/// }
/// ```
#[derive(Clone, Debug)]
pub struct HttpClientBuilder<L = Identity> {
pub struct HttpClientBuilder<HttpMiddleware = Identity, RpcMiddleware = Identity> {
max_request_size: u32,
max_response_size: u32,
request_timeout: Duration,
Expand All @@ -84,12 +86,13 @@ pub struct HttpClientBuilder<L = Identity> {
id_kind: IdKind,
max_log_length: u32,
headers: HeaderMap,
service_builder: tower::ServiceBuilder<L>,
service_builder: tower::ServiceBuilder<HttpMiddleware>,
rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
tcp_no_delay: bool,
max_concurrent_requests: Option<usize>,
}

impl<L> HttpClientBuilder<L> {
impl<HttpMiddleware, RpcMiddleware> HttpClientBuilder<HttpMiddleware, RpcMiddleware> {
/// Set the maximum size of a request body in bytes. Default is 10 MiB.
pub fn max_request_size(mut self, size: u32) -> Self {
self.max_request_size = size;
Expand Down Expand Up @@ -215,8 +218,29 @@ impl<L> HttpClientBuilder<L> {
self
}

/// Set the RPC middleware.
pub fn set_rpc_middleware<T>(self, rpc_builder: RpcServiceBuilder<T>) -> HttpClientBuilder<HttpMiddleware, T> {
HttpClientBuilder {
#[cfg(feature = "tls")]
certificate_store: self.certificate_store,
id_kind: self.id_kind,
headers: self.headers,
max_log_length: self.max_log_length,
max_request_size: self.max_request_size,
max_response_size: self.max_response_size,
service_builder: self.service_builder,
rpc_middleware: rpc_builder,
request_timeout: self.request_timeout,
tcp_no_delay: self.tcp_no_delay,
max_concurrent_requests: self.max_concurrent_requests,
}
}

/// Set custom tower middleware.
pub fn set_http_middleware<T>(self, service_builder: tower::ServiceBuilder<T>) -> HttpClientBuilder<T> {
pub fn set_http_middleware<T>(
self,
service_builder: tower::ServiceBuilder<T>,
) -> HttpClientBuilder<T, RpcMiddleware> {
HttpClientBuilder {
#[cfg(feature = "tls")]
certificate_store: self.certificate_store,
Expand All @@ -226,23 +250,26 @@ impl<L> HttpClientBuilder<L> {
max_request_size: self.max_request_size,
max_response_size: self.max_response_size,
service_builder,
rpc_middleware: self.rpc_middleware,
request_timeout: self.request_timeout,
tcp_no_delay: self.tcp_no_delay,
max_concurrent_requests: self.max_concurrent_requests,
}
}
}

impl<B, S, L> HttpClientBuilder<L>
impl<B, S, S2, HttpMiddleware, RpcMiddleware> HttpClientBuilder<HttpMiddleware, RpcMiddleware>
where
L: Layer<transport::HttpBackend, Service = S>,
RpcMiddleware: Layer<RpcService<S>, Service = S2>,
for<'a> <RpcMiddleware as Layer<RpcService<S>>>::Service: RpcServiceT<'a>,
HttpMiddleware: Layer<transport::HttpBackend, Service = S>,
S: Service<HttpRequest, Response = HttpResponse<B>, Error = TransportError> + Clone,
B: http_body::Body<Data = Bytes> + Send + Unpin + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
{
/// Build the HTTP client with target to connect to.
pub fn build(self, target: impl AsRef<str>) -> Result<HttpClient<S>, Error> {
pub fn build(self, target: impl AsRef<str>) -> Result<HttpClient<S2>, Error> {
let Self {
max_request_size,
max_response_size,
Expand All @@ -254,10 +281,11 @@ where
max_log_length,
service_builder,
tcp_no_delay,
rpc_middleware,
..
} = self;

let transport = HttpTransportClientBuilder {
let http = HttpTransportClientBuilder {
max_request_size,
max_response_size,
headers,
Expand All @@ -266,6 +294,7 @@ where
service_builder,
#[cfg(feature = "tls")]
certificate_store,
request_timeout,
}
.build(target)
.map_err(|e| Error::Transport(e.into()))?;
Expand All @@ -275,9 +304,8 @@ where
.map(|max_concurrent_requests| Arc::new(Semaphore::new(max_concurrent_requests)));

Ok(HttpClient {
transport,
transport: rpc_middleware.service(RpcService::new(http, max_response_size)),
id_manager: Arc::new(RequestIdManager::new(id_kind)),
request_timeout,
request_guard,
})
}
Expand All @@ -295,6 +323,7 @@ impl Default for HttpClientBuilder<Identity> {
max_log_length: 4096,
headers: HeaderMap::new(),
service_builder: tower::ServiceBuilder::new(),
rpc_middleware: RpcServiceBuilder::default(),
tcp_no_delay: true,
max_concurrent_requests: None,
}
Expand All @@ -310,11 +339,9 @@ impl HttpClientBuilder<Identity> {

/// JSON-RPC HTTP Client that provides functionality to perform method calls and notifications.
#[derive(Debug, Clone)]
pub struct HttpClient<S = HttpBackend> {
pub struct HttpClient<S> {
/// HTTP transport client.
transport: HttpTransportClient<S>,
/// Request timeout. Defaults to 60sec.
request_timeout: Duration,
transport: S,
/// Request ID manager.
id_manager: Arc<RequestIdManager>,
/// Concurrent requests limit guard.
Expand All @@ -329,13 +356,9 @@ impl HttpClient<HttpBackend> {
}

#[async_trait]
impl<B, S> ClientT for HttpClient<S>
impl<S> ClientT for HttpClient<S>
where
S: Service<HttpRequest, Response = HttpResponse<B>, Error = TransportError> + Send + Sync + Clone,
<S as Service<HttpRequest>>::Future: Send,
B: http_body::Body<Data = Bytes> + Send + Unpin + 'static,
B::Error: Into<BoxError>,
B::Data: Send,
for<'a> S: RpcServiceT<'a> + Send + Sync,
{
#[instrument(name = "notification", skip(self, params), level = "trace")]
async fn notification<Params>(&self, method: &str, params: Params) -> Result<(), Error>
Expand All @@ -347,16 +370,9 @@ where
None => None,
};
let params = params.to_rpc_params()?;
let notif =
serde_json::to_string(&NotificationSer::borrowed(&method, params.as_deref())).map_err(Error::ParseError)?;

let fut = self.transport.send(notif);

match tokio::time::timeout(self.request_timeout, fut).await {
Ok(Ok(ok)) => Ok(ok),
Err(_) => Err(Error::RequestTimeout),
Ok(Err(e)) => Err(Error::Transport(e.into())),
}
let n = Notification { jsonrpc: TwoPointZero, method: method.into(), params };
self.transport.notification(n).await;
Ok(())
}

#[instrument(name = "method_call", skip(self, params), level = "trace")]
Expand All @@ -372,23 +388,12 @@ where
let id = self.id_manager.next_request_id();
let params = params.to_rpc_params()?;

let request = RequestSer::borrowed(&id, &method, params.as_deref());
let raw = serde_json::to_string(&request).map_err(Error::ParseError)?;

let fut = self.transport.send_and_read_body(raw);
let body = match tokio::time::timeout(self.request_timeout, fut).await {
Ok(Ok(body)) => body,
Err(_e) => {
return Err(Error::RequestTimeout);
}
Ok(Err(e)) => {
return Err(Error::Transport(e.into()));
}
};
let request = Request::new(method.into(), params.as_deref(), id.clone());
let rp = self.transport.call(request).await;

// NOTE: it's decoded first to `JsonRawValue` and then to `R` below to get
// a better error message if `R` couldn't be decoded.
let response = ResponseSuccess::try_from(serde_json::from_slice::<Response<&JsonRawValue>>(&body)?)?;
let response = ResponseSuccess::try_from(serde_json::from_str::<Response<&JsonRawValue>>(&rp.as_result())?)?;

let result = serde_json::from_str(response.result.get()).map_err(Error::ParseError)?;

Expand All @@ -415,71 +420,44 @@ where
let mut batch_request = Vec::with_capacity(batch.len());
for ((method, params), id) in batch.into_iter().zip(id_range.clone()) {
let id = self.id_manager.as_id_kind().into_id(id);
batch_request.push(RequestSer {
batch_request.push(Request {
jsonrpc: TwoPointZero,
id,
method: method.into(),
params: params.map(StdCow::Owned),
id,
extensions: Default::default(),
});
}

let fut = self.transport.send_and_read_body(serde_json::to_string(&batch_request).map_err(Error::ParseError)?);

let body = match tokio::time::timeout(self.request_timeout, fut).await {
Ok(Ok(body)) => body,
Err(_e) => return Err(Error::RequestTimeout),
Ok(Err(e)) => return Err(Error::Transport(e.into())),
};

let json_rps: Vec<Response<&JsonRawValue>> = serde_json::from_slice(&body).map_err(Error::ParseError)?;

let mut responses = Vec::with_capacity(json_rps.len());
let mut successful_calls = 0;
let mut failed_calls = 0;
let batch = self.transport.batch(batch_request).await;
let responses: Vec<Response<&JsonRawValue>> = serde_json::from_str(&batch.as_result()).unwrap();

for _ in 0..json_rps.len() {
responses.push(Err(ErrorObject::borrowed(0, "", None)));
}
let mut x = Vec::new();
let mut success = 0;
let mut failed = 0;

for rp in json_rps {
let id = rp.id.try_parse_inner_as_number()?;

let res = match ResponseSuccess::try_from(rp) {
for rp in responses.into_iter() {
match ResponseSuccess::try_from(rp) {
Ok(r) => {
let result = serde_json::from_str(r.result.get())?;
successful_calls += 1;
Ok(result)
let v = serde_json::from_str(r.result.get()).map_err(Error::ParseError)?;
x.push(Ok(v));
success += 1;
}
Err(err) => {
failed_calls += 1;
Err(err)
x.push(Err(err));
failed += 1;
}
};

let maybe_elem = id
.checked_sub(id_range.start)
.and_then(|p| p.try_into().ok())
.and_then(|p: usize| responses.get_mut(p));

if let Some(elem) = maybe_elem {
*elem = res;
} else {
return Err(InvalidRequestId::NotPendingRequest(id.to_string()).into());
}
}

Ok(BatchResponse::new(successful_calls, responses, failed_calls))
Ok(BatchResponse::new(success, x, failed))
}
}

#[async_trait]
impl<B, S> SubscriptionClientT for HttpClient<S>
impl<S> SubscriptionClientT for HttpClient<S>
where
S: Service<HttpRequest, Response = HttpResponse<B>, Error = TransportError> + Send + Sync + Clone,
<S as Service<HttpRequest>>::Future: Send,
B: http_body::Body<Data = Bytes> + Send + Unpin + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
for<'a> S: RpcServiceT<'a> + Send + Sync,
{
/// Send a subscription request to the server. Not implemented for HTTP; will always return
/// [`Error::HttpNotImplemented`].
Expand Down
1 change: 1 addition & 0 deletions client/http-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#![cfg_attr(docsrs, feature(doc_cfg))]

mod client;
mod rpc_service;

/// HTTP transport.
pub mod transport;
Expand Down
Loading
Loading