Skip to content

Commit

Permalink
fix(server): make ws::on_connect work again
Browse files Browse the repository at this point in the history
  • Loading branch information
niklasad1 committed Jun 25, 2024
1 parent 60c4fed commit 5831042
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 12 deletions.
25 changes: 13 additions & 12 deletions server/src/transport/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ async fn graceful_shutdown<S>(
/// async fn handle_websocket_conn<L>(
/// req: HttpRequest,
/// server_cfg: ServerConfig,
/// methods: impl Into<Methods> + 'static,
/// methods: impl Into<Methods> + Send + 'static,
/// conn: ConnectionState,
/// rpc_middleware: RpcServiceBuilder<L>,
/// mut disconnect: tokio::sync::mpsc::Receiver<()>
Expand Down Expand Up @@ -435,17 +435,6 @@ where

match server.receive_request(&req) {
Ok(response) => {
let extensions = req.extensions().clone();

let upgraded = match hyper::upgrade::on(req).await {
Ok(u) => u,
Err(e) => {
tracing::debug!(target: LOG_TARGET, "WS upgrade handshake failed: {}", e);
return Err(HttpResponse::new(HttpBody::from(format!("WS upgrade handshake failed {e}"))));
}
};

let io = TokioIo::new(upgraded);
let (tx, rx) = mpsc::channel::<String>(server_cfg.message_buffer_capacity as usize);
let sink = MethodSink::new(tx);

Expand Down Expand Up @@ -473,6 +462,18 @@ where
// Note: This can't possibly be fulfilled until the HTTP response
// is returned below, so that's why it's a separate async block
let fut = async move {
let extensions = req.extensions().clone();

let upgraded = match hyper::upgrade::on(req).await {
Ok(upgraded) => upgraded,
Err(e) => {
tracing::debug!(target: LOG_TARGET, "WS upgrade handshake failed: {}", e);
return;
}
};

let io = TokioIo::new(upgraded);

let stream = BufReader::new(BufWriter::new(io.compat()));
let mut ws_builder = server.into_builder(stream);
ws_builder.set_max_message_size(server_cfg.max_response_body_size as usize);
Expand Down
90 changes: 90 additions & 0 deletions tests/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

mod helpers;

use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::time::Duration;
Expand Down Expand Up @@ -1428,3 +1429,92 @@ async fn run_shutdown_test(transport: &str) {
_ => unreachable!("Only `http` and `ws` supported"),
}
}

#[tokio::test]
async fn server_ws_low_api_works() {
let local_addr = run_server().await.unwrap();

let client = WsClientBuilder::default().build(&format!("ws://{}", local_addr)).await.unwrap();
assert!(matches!(client.request::<String, _>("say_hello", rpc_params![]).await, Ok(r) if r == "hello"));

async fn run_server() -> anyhow::Result<SocketAddr> {
use futures_util::future::FutureExt;
use jsonrpsee::core::BoxError;
use jsonrpsee::server::{
http, middleware::rpc::RpcServiceBuilder, serve_with_graceful_shutdown, stop_channel, ws, ConnectionGuard,
ConnectionState, Methods, ServerConfig, StopHandle,
};

let listener = tokio::net::TcpListener::bind(std::net::SocketAddr::from(([127, 0, 0, 1], 0))).await?;
let local_addr = listener.local_addr()?;
let (stop_handle, server_handle) = stop_channel();

let mut methods = RpcModule::new(());

methods.register_async_method("say_hello", |_, _, _| async { "hello" }).unwrap();

#[derive(Clone)]
struct PerConnection {
methods: Methods,
stop_handle: StopHandle,
conn_guard: ConnectionGuard,
}

let per_conn = PerConnection {
methods: methods.into(),
stop_handle: stop_handle.clone(),
conn_guard: ConnectionGuard::new(100),
};

tokio::spawn(async move {
loop {
let (sock, _) = tokio::select! {
res = listener.accept() => {
match res {
Ok(sock) => sock,
Err(e) => {
tracing::error!("Failed to accept v4 connection: {:?}", e);
continue;
}
}
}
_ = per_conn.stop_handle.clone().shutdown() => break,
};
let per_conn = per_conn.clone();

let stop_handle2 = per_conn.stop_handle.clone();
let per_conn = per_conn.clone();
let svc = tower::service_fn(move |req| {
let PerConnection { methods, stop_handle, conn_guard } = per_conn.clone();
let conn_permit =
conn_guard.try_acquire().expect("Connection limit is 100 must be work for two connections");

if ws::is_upgrade_request(&req) {
let rpc_service = RpcServiceBuilder::new();

let conn = ConnectionState::new(stop_handle, 0, conn_permit);

async move {
match ws::connect(req, ServerConfig::default(), methods, conn, rpc_service).await {
Ok((rp, conn_fut)) => {
tokio::spawn(conn_fut);
Ok(rp)
}
Err(rp) => Ok(rp),
}
}
.boxed()
} else {
async { Ok::<_, BoxError>(http::response::denied()) }.boxed()
}
});

tokio::spawn(serve_with_graceful_shutdown(sock, svc, stop_handle2.shutdown()));
}
});

tokio::spawn(server_handle.stopped());

Ok(local_addr)
}
}

0 comments on commit 5831042

Please sign in to comment.