Skip to content

Commit

Permalink
[client]: improve batch request API (#910)
Browse files Browse the repository at this point in the history
* better handling of batch requests

* add test for untagged enum

* remove annoying trait bounds

* cleanup

* more clear comments

* simplify batch request code

* bring back old API + a new one

* refactor batches to work with String IDs again

* refactor again: single batch response API

* fix tests + cleanup

* fix doc links

* address grumbles

* BatchRequestBuilder: add iterator API for the batch

* revert bench

* fix benches build

* address grumbles: ok and into_ok

* fix some nits

* fix nits
  • Loading branch information
niklasad1 authored Nov 7, 2022
1 parent 576d837 commit 824c369
Show file tree
Hide file tree
Showing 20 changed files with 746 additions and 211 deletions.
8 changes: 4 additions & 4 deletions benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ pub fn jsonrpsee_types_v2(crit: &mut Criterion) {
b.iter(|| {
let params = serde_json::value::RawValue::from_string("[1, 2]".to_string()).unwrap();

let request = RequestSer::new(&Id::Number(0), "say_hello", Some(params));
let request = RequestSer::borrowed(&Id::Number(0), &"say_hello", Some(&params));
v2_serialize(request);
})
});
Expand All @@ -124,7 +124,7 @@ pub fn jsonrpsee_types_v2(crit: &mut Criterion) {
builder.insert(1u64).unwrap();
builder.insert(2u32).unwrap();
let params = builder.to_rpc_params().expect("Valid params");
let request = RequestSer::new(&Id::Number(0), "say_hello", params);
let request = RequestSer::borrowed(&Id::Number(0), &"say_hello", params.as_deref());
v2_serialize(request);
})
});
Expand All @@ -134,7 +134,7 @@ pub fn jsonrpsee_types_v2(crit: &mut Criterion) {
b.iter(|| {
let params = serde_json::value::RawValue::from_string(r#"{"key": 1}"#.to_string()).unwrap();

let request = RequestSer::new(&Id::Number(0), "say_hello", Some(params));
let request = RequestSer::borrowed(&Id::Number(0), &"say_hello", Some(&params));
v2_serialize(request);
})
});
Expand All @@ -146,7 +146,7 @@ pub fn jsonrpsee_types_v2(crit: &mut Criterion) {
let mut builder = ObjectParams::new();
builder.insert("key", 1u32).unwrap();
let params = builder.to_rpc_params().expect("Valid params");
let request = RequestSer::new(&Id::Number(0), "say_hello", params);
let request = RequestSer::borrowed(&Id::Number(0), &"say_hello", params.as_deref());
v2_serialize(request);
})
});
Expand Down
1 change: 1 addition & 0 deletions client/http-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ tokio = { version = "1.16", features = ["time"] }
tracing = "0.1.34"

[dev-dependencies]
tracing-subscriber = { version = "0.3.3", features = ["env-filter"] }
jsonrpsee-test-utils = { path = "../../test-utils" }
tokio = { version = "1.16", features = ["net", "rt-multi-thread", "macros"] }

Expand Down
98 changes: 63 additions & 35 deletions client/http-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,24 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use std::borrow::Cow as StdCow;
use std::fmt;
use std::sync::Arc;
use std::time::Duration;

use crate::transport::HttpTransportClient;
use crate::types::{ErrorResponse, Id, NotificationSer, RequestSer, Response};
use crate::types::{ErrorResponse, NotificationSer, RequestSer, Response};
use async_trait::async_trait;
use hyper::http::HeaderMap;
use jsonrpsee_core::client::{CertificateStore, ClientT, IdKind, RequestIdManager, Subscription, SubscriptionClientT};
use jsonrpsee_core::client::{
generate_batch_id_range, BatchResponse, CertificateStore, ClientT, IdKind, RequestIdManager, Subscription,
SubscriptionClientT,
};
use jsonrpsee_core::params::BatchRequestBuilder;
use jsonrpsee_core::traits::ToRpcParams;
use jsonrpsee_core::{Error, JsonRawValue, TEN_MB_SIZE_BYTES};
use jsonrpsee_types::error::CallError;
use rustc_hash::FxHashMap;
use jsonrpsee_types::{ErrorObject, TwoPointZero};
use serde::de::DeserializeOwned;
use tracing::instrument;

Expand Down Expand Up @@ -173,7 +178,8 @@ impl ClientT for HttpClient {
Params: ToRpcParams + Send,
{
let params = params.to_rpc_params()?;
let notif = serde_json::to_string(&NotificationSer::new(method, params)).map_err(Error::ParseError)?;
let notif =
serde_json::to_string(&NotificationSer::borrowed(&method, params.as_deref())).map_err(Error::ParseError)?;

let fut = self.transport.send(notif);

Expand All @@ -196,7 +202,7 @@ impl ClientT for HttpClient {
let id = guard.inner();
let params = params.to_rpc_params()?;

let request = RequestSer::new(&id, method, params);
let request = RequestSer::borrowed(&id, &method, params.as_deref());
let raw = serde_json::to_string(&request).map_err(Error::ParseError)?;

let fut = self.transport.send_and_read_body(raw);
Expand Down Expand Up @@ -230,23 +236,23 @@ impl ClientT for HttpClient {
}

#[instrument(name = "batch", skip(self, batch), level = "trace")]
async fn batch_request<'a, R>(&self, batch: BatchRequestBuilder<'a>) -> Result<Vec<R>, Error>
async fn batch_request<'a, R>(&self, batch: BatchRequestBuilder<'a>) -> Result<BatchResponse<'a, R>, Error>
where
R: DeserializeOwned + Default + Clone,
R: DeserializeOwned + fmt::Debug + 'a,
{
let batch = batch.build();
let guard = self.id_manager.next_request_ids(batch.len())?;
let ids: Vec<Id> = guard.inner();
let batch = batch.build()?;
let guard = self.id_manager.next_request_id()?;
let id_range = generate_batch_id_range(&guard, batch.len() as u64)?;

let mut batch_request = Vec::with_capacity(batch.len());
// NOTE(niklasad1): `ID` is not necessarily monotonically increasing.
let mut ordered_requests = Vec::with_capacity(batch.len());
let mut request_set = FxHashMap::with_capacity_and_hasher(batch.len(), Default::default());

for (pos, (method, params)) in batch.into_iter().enumerate() {
batch_request.push(RequestSer::new(&ids[pos], method, params));
ordered_requests.push(&ids[pos]);
request_set.insert(&ids[pos], pos);
for ((method, params), id) in batch.into_iter().zip(id_range.clone()) {
let id = self.id_manager.as_id_kind().into_id(id);
batch_request.push(RequestSer {
jsonrpc: TwoPointZero,
id,
method: method.into(),
params: params.map(StdCow::Owned),
});
}

let fut = self.transport.send_and_read_body(serde_json::to_string(&batch_request).map_err(Error::ParseError)?);
Expand All @@ -257,26 +263,48 @@ impl ClientT for HttpClient {
Ok(Err(e)) => return Err(Error::Transport(e.into())),
};

// NOTE: it's decoded first to `JsonRawValue` and then to `R` below to get
// a better error message if `R` couldn't be decoded.
let rps: Vec<Response<&JsonRawValue>> =
serde_json::from_slice(&body).map_err(|_| match serde_json::from_slice::<ErrorResponse>(&body) {
Ok(e) => Error::Call(CallError::Custom(e.error_object().clone().into_owned())),
Err(e) => Error::ParseError(e),
})?;

// NOTE: `R::default` is placeholder and will be replaced in loop below.
let mut responses = vec![R::default(); ordered_requests.len()];
for rp in rps {
let pos = match request_set.get(&rp.id) {
Some(pos) => *pos,
None => return Err(Error::InvalidRequestId),
let json_rps: Vec<&JsonRawValue> = serde_json::from_slice(&body).map_err(Error::ParseError)?;

let mut responses = Vec::with_capacity(json_rps.len());
let mut successful_calls = 0;
let mut failed_calls = 0;

for _ in 0..json_rps.len() {
responses.push(Err(ErrorObject::borrowed(0, &"", None)));
}

for rp in json_rps {
let (id, res) = match serde_json::from_str::<Response<R>>(rp.get()).map_err(Error::ParseError) {
Ok(r) => {
let id = r.id.try_parse_inner_as_number().ok_or(Error::InvalidRequestId)?;
successful_calls += 1;
(id, Ok(r.result))
}
Err(err) => match serde_json::from_str::<ErrorResponse>(rp.get()).map_err(Error::ParseError) {
Ok(err) => {
let id = err.id().try_parse_inner_as_number().ok_or(Error::InvalidRequestId)?;
failed_calls += 1;
(id, Err(err.error_object().clone().into_owned()))
}
Err(_) => {
return Err(err);
}
},
};
let result = serde_json::from_str(rp.result.get()).map_err(Error::ParseError)?;
responses[pos] = result;

let maybe_elem = id
.checked_sub(id_range.start)
.and_then(|p| p.try_into().ok())
.and_then(|p: usize| responses.get_mut(p));

if let Some(elem) = maybe_elem {
*elem = res;
} else {
return Err(Error::InvalidRequestId);
}
}

Ok(responses)
Ok(BatchResponse::new(successful_calls, responses, failed_calls))
}
}

Expand Down
118 changes: 109 additions & 9 deletions client/http-client/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,21 @@
use crate::types::error::{ErrorCode, ErrorObject};

use crate::HttpClientBuilder;
use jsonrpsee_core::client::{ClientT, IdKind};
use jsonrpsee_core::client::{BatchResponse, ClientT, IdKind};
use jsonrpsee_core::params::BatchRequestBuilder;
use jsonrpsee_core::rpc_params;
use jsonrpsee_core::Error;
use jsonrpsee_core::{rpc_params, DeserializeOwned};
use jsonrpsee_test_utils::helpers::*;
use jsonrpsee_test_utils::mocks::Id;
use jsonrpsee_test_utils::TimeoutFutureExt;
use jsonrpsee_types::error::{CallError, ErrorObjectOwned};

fn init_logger() {
let _ = tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.try_init();
}

#[tokio::test]
async fn method_call_works() {
let result = run_request_with_response(ok_response("hello".into(), Id::Num(0)))
Expand Down Expand Up @@ -141,9 +147,95 @@ async fn batch_request_works() {
batch_request.insert("say_goodbye", rpc_params![0_u64, 1, 2]).unwrap();
batch_request.insert("get_swag", rpc_params![]).unwrap();
let server_response = r#"[{"jsonrpc":"2.0","result":"hello","id":0}, {"jsonrpc":"2.0","result":"goodbye","id":1}, {"jsonrpc":"2.0","result":"here's your swag","id":2}]"#.to_string();
let response =
run_batch_request_with_response(batch_request, server_response).with_default_timeout().await.unwrap().unwrap();
assert_eq!(response, vec!["hello".to_string(), "goodbye".to_string(), "here's your swag".to_string()]);
let batch_response = run_batch_request_with_response::<String>(batch_request, server_response)
.with_default_timeout()
.await
.unwrap()
.unwrap();
assert_eq!(batch_response.num_successful_calls(), 3);
let results: Vec<String> = batch_response.into_ok().unwrap().collect();
assert_eq!(results, vec!["hello".to_string(), "goodbye".to_string(), "here's your swag".to_string()]);
}

#[tokio::test]
async fn batch_request_with_failed_call_works() {
let mut batch_request = BatchRequestBuilder::new();
batch_request.insert("say_hello", rpc_params![]).unwrap();
batch_request.insert("say_goodbye", rpc_params![0_u64, 1, 2]).unwrap();
batch_request.insert("get_swag", rpc_params![]).unwrap();
let server_response = r#"[{"jsonrpc":"2.0","result":"hello","id":0}, {"jsonrpc":"2.0","error":{"code":-32601,"message":"Method not found"},"id":1}, {"jsonrpc":"2.0","result":"here's your swag","id":2}]"#.to_string();
let res = run_batch_request_with_response::<String>(batch_request, server_response)
.with_default_timeout()
.await
.unwrap()
.unwrap();
assert_eq!(res.num_successful_calls(), 2);
assert_eq!(res.num_failed_calls(), 1);
assert_eq!(res.len(), 3);

let successful_calls: Vec<_> = res.iter().filter_map(|r| r.as_ref().ok()).collect();
let failed_calls: Vec<_> = res
.iter()
.filter_map(|r| match r {
Err(e) => Some(e),
_ => None,
})
.collect();

assert_eq!(successful_calls, vec!["hello", "here's your swag"]);
assert_eq!(failed_calls, vec![&ErrorObject::from(ErrorCode::MethodNotFound)]);
}

#[tokio::test]
async fn batch_request_with_failed_call_gives_proper_error() {
let mut batch_request = BatchRequestBuilder::new();
batch_request.insert("say_hello", rpc_params![]).unwrap();
batch_request.insert("say_goodbye", rpc_params![0_u64, 1, 2]).unwrap();
batch_request.insert("get_swag", rpc_params![]).unwrap();
let server_response = r#"[{"jsonrpc":"2.0","result":"hello","id":0}, {"jsonrpc":"2.0","error":{"code":-32601,"message":"Method not found"},"id":1}, {"jsonrpc":"2.0","error":{"code":-32602,"message":"foo"},"id":2}]"#.to_string();
let res = run_batch_request_with_response::<String>(batch_request, server_response)
.with_default_timeout()
.await
.unwrap()
.unwrap();
let err: Vec<_> = res.into_ok().unwrap_err().collect();
assert_eq!(err, vec![ErrorObject::from(ErrorCode::MethodNotFound), ErrorObject::borrowed(-32602, &"foo", None)]);
}

#[tokio::test]
async fn batch_request_with_untagged_enum_works() {
init_logger();

#[derive(serde::Deserialize, Clone, Debug, PartialEq)]
#[serde(untagged)]
enum Custom {
Text(String),
Number(u8),
}

impl Default for Custom {
fn default() -> Self {
Self::Number(0)
}
}

let mut batch_request = BatchRequestBuilder::new();
batch_request.insert("text", rpc_params![]).unwrap();
batch_request.insert("binary", rpc_params![0_u64, 1, 2]).unwrap();
let server_response =
r#"[{"jsonrpc":"2.0","result":"hello","id":0}, {"jsonrpc":"2.0","result":13,"id":1}]"#.to_string();
let res = run_batch_request_with_response::<Custom>(batch_request, server_response)
.with_default_timeout()
.await
.unwrap()
.unwrap();

assert_eq!(res.num_successful_calls(), 2);
assert_eq!(res.num_failed_calls(), 0);
assert_eq!(res.len(), 2);
let response: Vec<_> = res.into_ok().unwrap().collect();

assert_eq!(response, vec![Custom::Text("hello".to_string()), Custom::Number(13)]);
}

#[tokio::test]
Expand All @@ -153,15 +245,23 @@ async fn batch_request_out_of_order_response() {
batch_request.insert("say_goodbye", rpc_params![0_u64, 1, 2]).unwrap();
batch_request.insert("get_swag", rpc_params![]).unwrap();
let server_response = r#"[{"jsonrpc":"2.0","result":"here's your swag","id":2}, {"jsonrpc":"2.0","result":"hello","id":0}, {"jsonrpc":"2.0","result":"goodbye","id":1}]"#.to_string();
let response =
run_batch_request_with_response(batch_request, server_response).with_default_timeout().await.unwrap().unwrap();
let res = run_batch_request_with_response::<String>(batch_request, server_response)
.with_default_timeout()
.await
.unwrap()
.unwrap();
assert_eq!(res.num_successful_calls(), 3);
assert_eq!(res.num_failed_calls(), 0);
assert_eq!(res.len(), 3);
let response: Vec<_> = res.into_ok().unwrap().collect();

assert_eq!(response, vec!["hello".to_string(), "goodbye".to_string(), "here's your swag".to_string()]);
}

async fn run_batch_request_with_response(
async fn run_batch_request_with_response<T: Send + DeserializeOwned + std::fmt::Debug + 'static>(
batch: BatchRequestBuilder<'_>,
response: String,
) -> Result<Vec<String>, Error> {
) -> Result<BatchResponse<T>, Error> {
let server_addr = http_server_with_hardcoded_response(response).with_default_timeout().await.unwrap();
let uri = format!("http://{}", server_addr);
let client = HttpClientBuilder::default().build(&uri).unwrap();
Expand Down
1 change: 1 addition & 0 deletions client/ws-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ tracing-subscriber = { version = "0.3.3", features = ["env-filter"] }
jsonrpsee-test-utils = { path = "../../test-utils" }
tokio = { version = "1.16", features = ["macros"] }
serde_json = "1"
serde = "1"

[features]
tls = ["jsonrpsee-client-transport/tls"]
Expand Down
Loading

0 comments on commit 824c369

Please sign in to comment.