From 5831042c7d9564c1f0299db49ef97a80a7535869 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Tue, 25 Jun 2024 14:21:08 +0200 Subject: [PATCH] fix(server): make `ws::on_connect` work again --- server/src/transport/ws.rs | 25 ++++----- tests/tests/integration_tests.rs | 90 ++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+), 12 deletions(-) diff --git a/server/src/transport/ws.rs b/server/src/transport/ws.rs index 9fc24cdccf..bf6c573211 100644 --- a/server/src/transport/ws.rs +++ b/server/src/transport/ws.rs @@ -393,7 +393,7 @@ async fn graceful_shutdown( /// async fn handle_websocket_conn( /// req: HttpRequest, /// server_cfg: ServerConfig, -/// methods: impl Into + 'static, +/// methods: impl Into + Send + 'static, /// conn: ConnectionState, /// rpc_middleware: RpcServiceBuilder, /// mut disconnect: tokio::sync::mpsc::Receiver<()> @@ -435,17 +435,6 @@ where match server.receive_request(&req) { Ok(response) => { - let extensions = req.extensions().clone(); - - let upgraded = match hyper::upgrade::on(req).await { - Ok(u) => u, - Err(e) => { - tracing::debug!(target: LOG_TARGET, "WS upgrade handshake failed: {}", e); - return Err(HttpResponse::new(HttpBody::from(format!("WS upgrade handshake failed {e}")))); - } - }; - - let io = TokioIo::new(upgraded); let (tx, rx) = mpsc::channel::(server_cfg.message_buffer_capacity as usize); let sink = MethodSink::new(tx); @@ -473,6 +462,18 @@ where // Note: This can't possibly be fulfilled until the HTTP response // is returned below, so that's why it's a separate async block let fut = async move { + let extensions = req.extensions().clone(); + + let upgraded = match hyper::upgrade::on(req).await { + Ok(upgraded) => upgraded, + Err(e) => { + tracing::debug!(target: LOG_TARGET, "WS upgrade handshake failed: {}", e); + return; + } + }; + + let io = TokioIo::new(upgraded); + let stream = BufReader::new(BufWriter::new(io.compat())); let mut ws_builder = server.into_builder(stream); ws_builder.set_max_message_size(server_cfg.max_response_body_size as usize); diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 5136cefa41..4efc894231 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -29,6 +29,7 @@ mod helpers; +use std::net::SocketAddr; use std::sync::atomic::AtomicBool; use std::sync::Arc; use std::time::Duration; @@ -1428,3 +1429,92 @@ async fn run_shutdown_test(transport: &str) { _ => unreachable!("Only `http` and `ws` supported"), } } + +#[tokio::test] +async fn server_ws_low_api_works() { + let local_addr = run_server().await.unwrap(); + + let client = WsClientBuilder::default().build(&format!("ws://{}", local_addr)).await.unwrap(); + assert!(matches!(client.request::("say_hello", rpc_params![]).await, Ok(r) if r == "hello")); + + async fn run_server() -> anyhow::Result { + use futures_util::future::FutureExt; + use jsonrpsee::core::BoxError; + use jsonrpsee::server::{ + http, middleware::rpc::RpcServiceBuilder, serve_with_graceful_shutdown, stop_channel, ws, ConnectionGuard, + ConnectionState, Methods, ServerConfig, StopHandle, + }; + + let listener = tokio::net::TcpListener::bind(std::net::SocketAddr::from(([127, 0, 0, 1], 0))).await?; + let local_addr = listener.local_addr()?; + let (stop_handle, server_handle) = stop_channel(); + + let mut methods = RpcModule::new(()); + + methods.register_async_method("say_hello", |_, _, _| async { "hello" }).unwrap(); + + #[derive(Clone)] + struct PerConnection { + methods: Methods, + stop_handle: StopHandle, + conn_guard: ConnectionGuard, + } + + let per_conn = PerConnection { + methods: methods.into(), + stop_handle: stop_handle.clone(), + conn_guard: ConnectionGuard::new(100), + }; + + tokio::spawn(async move { + loop { + let (sock, _) = tokio::select! { + res = listener.accept() => { + match res { + Ok(sock) => sock, + Err(e) => { + tracing::error!("Failed to accept v4 connection: {:?}", e); + continue; + } + } + } + _ = per_conn.stop_handle.clone().shutdown() => break, + }; + let per_conn = per_conn.clone(); + + let stop_handle2 = per_conn.stop_handle.clone(); + let per_conn = per_conn.clone(); + let svc = tower::service_fn(move |req| { + let PerConnection { methods, stop_handle, conn_guard } = per_conn.clone(); + let conn_permit = + conn_guard.try_acquire().expect("Connection limit is 100 must be work for two connections"); + + if ws::is_upgrade_request(&req) { + let rpc_service = RpcServiceBuilder::new(); + + let conn = ConnectionState::new(stop_handle, 0, conn_permit); + + async move { + match ws::connect(req, ServerConfig::default(), methods, conn, rpc_service).await { + Ok((rp, conn_fut)) => { + tokio::spawn(conn_fut); + Ok(rp) + } + Err(rp) => Ok(rp), + } + } + .boxed() + } else { + async { Ok::<_, BoxError>(http::response::denied()) }.boxed() + } + }); + + tokio::spawn(serve_with_graceful_shutdown(sock, svc, stop_handle2.shutdown())); + } + }); + + tokio::spawn(server_handle.stopped()); + + Ok(local_addr) + } +}