Skip to content

Commit

Permalink
[grid] limit the number of websocket connections per session (#14410)
Browse files Browse the repository at this point in the history
Co-authored-by: Viet Nguyen Duc <[email protected]>
  • Loading branch information
joerg1985 and VietND96 authored Oct 29, 2024
1 parent b01041f commit e9e684d
Show file tree
Hide file tree
Showing 11 changed files with 154 additions and 5 deletions.
11 changes: 11 additions & 0 deletions java/src/org/openqa/selenium/grid/node/Node.java
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@
* by {@code sessionId}. This returns a boolean.</td>
* </tr>
* <tr>
* <td>POST</td>
* <td>/se/grid/node/connection/{sessionId}</td>
* <td>Allows the node to be ask about whether or not new websocket connections are allowed for the {@link Session}
* identified by {@code sessionId}. This returns a boolean.</td>
* </tr>
* <tr>
* <td>*</td>
* <td>/session/{sessionId}/*</td>
* <td>The request is forwarded to the {@link Session} identified by {@code sessionId}. When the
Expand Down Expand Up @@ -172,6 +178,9 @@ protected Node(
get("/se/grid/node/owner/{sessionId}")
.to(params -> new IsSessionOwner(this, sessionIdFrom(params)))
.with(spanDecorator("node.is_session_owner").andThen(requiresSecret)),
post("/se/grid/node/connection/{sessionId}")
.to(params -> new TryAcquireConnection(this, sessionIdFrom(params)))
.with(spanDecorator("node.is_session_owner").andThen(requiresSecret)),
delete("/se/grid/node/session/{sessionId}")
.to(params -> new StopNodeSession(this, sessionIdFrom(params)))
.with(spanDecorator("node.stop_session").andThen(requiresSecret)),
Expand Down Expand Up @@ -244,6 +253,8 @@ public TemporaryFilesystem getDownloadsFilesystem(UUID uuid) throws IOException

public abstract boolean isSessionOwner(SessionId id);

public abstract boolean tryAcquireConnection(SessionId id);

public abstract boolean isSupporting(Capabilities capabilities);

public abstract NodeStatus getStatus();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ public Optional<Consumer<Message>> apply(String uri, Consumer<Message> downstrea
return Optional.empty();
}

// ensure one session does not open to many connections, this might have a negative impact on
// the grid health
if (!node.tryAcquireConnection(id)) {
LOG.warning("Too many websocket connections initiated by " + id);
return Optional.empty();
}

Session session = node.getSession(id);
Capabilities caps = session.getCapabilities();
LOG.fine("Scanning for endpoint: " + caps);
Expand Down
45 changes: 45 additions & 0 deletions java/src/org/openqa/selenium/grid/node/TryAcquireConnection.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Licensed to the Software Freedom Conservancy (SFC) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The SFC licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.openqa.selenium.grid.node;

import static org.openqa.selenium.remote.http.Contents.asJson;

import com.google.common.collect.ImmutableMap;
import java.io.UncheckedIOException;
import org.openqa.selenium.internal.Require;
import org.openqa.selenium.remote.SessionId;
import org.openqa.selenium.remote.http.HttpHandler;
import org.openqa.selenium.remote.http.HttpRequest;
import org.openqa.selenium.remote.http.HttpResponse;

class TryAcquireConnection implements HttpHandler {

private final Node node;
private final SessionId id;

TryAcquireConnection(Node node, SessionId id) {
this.node = Require.nonNull("Node", node);
this.id = Require.nonNull("Session id", id);
}

@Override
public HttpResponse execute(HttpRequest req) throws UncheckedIOException {
return new HttpResponse()
.setContent(asJson(ImmutableMap.of("value", node.tryAcquireConnection(id))));
}
}
9 changes: 9 additions & 0 deletions java/src/org/openqa/selenium/grid/node/config/NodeFlags.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.openqa.selenium.grid.node.config;

import static org.openqa.selenium.grid.config.StandardGridRoles.NODE_ROLE;
import static org.openqa.selenium.grid.node.config.NodeOptions.DEFAULT_CONNECTION_LIMIT;
import static org.openqa.selenium.grid.node.config.NodeOptions.DEFAULT_DETECT_DRIVERS;
import static org.openqa.selenium.grid.node.config.NodeOptions.DEFAULT_DRAIN_AFTER_SESSION_COUNT;
import static org.openqa.selenium.grid.node.config.NodeOptions.DEFAULT_ENABLE_BIDI;
Expand Down Expand Up @@ -77,6 +78,14 @@ public class NodeFlags implements HasRoles {
@ConfigValue(section = NODE_SECTION, name = "session-timeout", example = "60")
public int sessionTimeout = DEFAULT_SESSION_TIMEOUT;

@Parameter(
names = {"--connection-limit-per-session"},
description =
"Let X be the maximum number of websocket connections per session.This will ensure one"
+ " session is not able to exhaust the connection limit of the host")
@ConfigValue(section = NODE_SECTION, name = "connection-limit-per-session", example = "8")
public int connectionLimitPerSession = DEFAULT_CONNECTION_LIMIT;

@Parameter(
names = {"--detect-drivers"},
arity = 1,
Expand Down
10 changes: 10 additions & 0 deletions java/src/org/openqa/selenium/grid/node/config/NodeOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ public class NodeOptions {
public static final int DEFAULT_HEARTBEAT_PERIOD = 60;
public static final int DEFAULT_SESSION_TIMEOUT = 300;
public static final int DEFAULT_DRAIN_AFTER_SESSION_COUNT = 0;
public static final int DEFAULT_CONNECTION_LIMIT = 10;
public static final boolean DEFAULT_ENABLE_CDP = true;
public static final boolean DEFAULT_ENABLE_BIDI = true;
static final String NODE_SECTION = "node";
Expand Down Expand Up @@ -262,6 +263,15 @@ public int getMaxSessions() {
return Math.min(maxSessions, DEFAULT_MAX_SESSIONS);
}

public int getConnectionLimitPerSession() {
int connectionLimit =
config
.getInt(NODE_SECTION, "connection-limit-per-session")
.orElse(DEFAULT_CONNECTION_LIMIT);
Require.positive("Session connection limit", connectionLimit);
return connectionLimit;
}

public Duration getSessionTimeout() {
// If the user sets 10s or less, we default to 10s.
int seconds =
Expand Down
15 changes: 13 additions & 2 deletions java/src/org/openqa/selenium/grid/node/k8s/OneShotNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.util.Optional;
import java.util.ServiceLoader;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Logger;
import java.util.stream.StreamSupport;
import org.openqa.selenium.Capabilities;
Expand Down Expand Up @@ -98,6 +99,8 @@ public class OneShotNode extends Node {
private final Duration heartbeatPeriod;
private final URI gridUri;
private final UUID slotId = UUID.randomUUID();
private final int connectionLimitPerSession;
private final AtomicInteger connectionCounter = new AtomicInteger();
private RemoteWebDriver driver;
private SessionId sessionId;
private HttpClient client;
Expand All @@ -114,14 +117,16 @@ private OneShotNode(
URI uri,
URI gridUri,
Capabilities stereotype,
WebDriverInfo driverInfo) {
WebDriverInfo driverInfo,
int connectionLimitPerSession) {
super(tracer, id, uri, registrationSecret, Require.positive(sessionTimeout));

this.heartbeatPeriod = heartbeatPeriod;
this.events = Require.nonNull("Event bus", events);
this.gridUri = Require.nonNull("Public Grid URI", gridUri);
this.stereotype = ImmutableCapabilities.copyOf(Require.nonNull("Stereotype", stereotype));
this.driverInfo = Require.nonNull("Driver info", driverInfo);
this.connectionLimitPerSession = connectionLimitPerSession;

new JMXHelper().register(this);
}
Expand Down Expand Up @@ -177,7 +182,8 @@ public static Node create(Config config) {
.getPublicGridUri()
.orElseThrow(() -> new ConfigException("Unable to determine public grid address")),
stereotype,
driverInfo);
driverInfo,
nodeOptions.getConnectionLimitPerSession());
}

@Override
Expand Down Expand Up @@ -357,6 +363,11 @@ public boolean isSessionOwner(SessionId id) {
return driver != null && sessionId.equals(id);
}

@Override
public boolean tryAcquireConnection(SessionId id) {
return sessionId.equals(id) && connectionLimitPerSession > connectionCounter.getAndIncrement();
}

@Override
public boolean isSupporting(Capabilities capabilities) {
return driverInfo.isSupporting(capabilities);
Expand Down
33 changes: 31 additions & 2 deletions java/src/org/openqa/selenium/grid/node/local/LocalNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -127,6 +128,7 @@ public class LocalNode extends Node {
private final int configuredSessionCount;
private final boolean cdpEnabled;
private final boolean managedDownloadsEnabled;
private final int connectionLimitPerSession;

private final boolean bidiEnabled;
private final AtomicBoolean drainAfterSessions = new AtomicBoolean();
Expand All @@ -153,7 +155,8 @@ protected LocalNode(
Duration heartbeatPeriod,
List<SessionSlot> factories,
Secret registrationSecret,
boolean managedDownloadsEnabled) {
boolean managedDownloadsEnabled,
int connectionLimitPerSession) {
super(
tracer,
new NodeId(UUID.randomUUID()),
Expand All @@ -176,6 +179,7 @@ protected LocalNode(
this.cdpEnabled = cdpEnabled;
this.bidiEnabled = bidiEnabled;
this.managedDownloadsEnabled = managedDownloadsEnabled;
this.connectionLimitPerSession = connectionLimitPerSession;

this.healthCheck =
healthCheck == null
Expand Down Expand Up @@ -579,6 +583,24 @@ public boolean isSessionOwner(SessionId id) {
return currentSessions.getIfPresent(id) != null;
}

@Override
public boolean tryAcquireConnection(SessionId id) throws NoSuchSessionException {
SessionSlot slot = currentSessions.getIfPresent(id);

if (slot == null) {
return false;
}

if (connectionLimitPerSession == -1) {
// no limit
return true;
}

AtomicLong counter = slot.getConnectionCounter();

return connectionLimitPerSession > counter.getAndIncrement();
}

@Override
public Session getSession(SessionId id) throws NoSuchSessionException {
Require.nonNull("Session ID", id);
Expand Down Expand Up @@ -987,6 +1009,7 @@ public static class Builder {
private HealthCheck healthCheck;
private Duration heartbeatPeriod = Duration.ofSeconds(NodeOptions.DEFAULT_HEARTBEAT_PERIOD);
private boolean managedDownloadsEnabled = false;
private int connectionLimitPerSession = -1;

private Builder(Tracer tracer, EventBus bus, URI uri, URI gridUri, Secret registrationSecret) {
this.tracer = Require.nonNull("Tracer", tracer);
Expand Down Expand Up @@ -1041,6 +1064,11 @@ public Builder enableManagedDownloads(boolean enable) {
return this;
}

public Builder connectionLimitPerSession(int connectionLimitPerSession) {
this.connectionLimitPerSession = connectionLimitPerSession;
return this;
}

public LocalNode build() {
return new LocalNode(
tracer,
Expand All @@ -1057,7 +1085,8 @@ public LocalNode build() {
heartbeatPeriod,
factories.build(),
registrationSecret,
managedDownloadsEnabled);
managedDownloadsEnabled,
connectionLimitPerSession);
}

public Advanced advanced() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ public static Node create(Config config) {
.enableCdp(nodeOptions.isCdpEnabled())
.enableBiDi(nodeOptions.isBiDiEnabled())
.enableManagedDownloads(nodeOptions.isManagedDownloadsEnabled())
.heartbeatPeriod(nodeOptions.getHeartbeatPeriod());
.heartbeatPeriod(nodeOptions.getHeartbeatPeriod())
.connectionLimitPerSession(nodeOptions.getConnectionLimitPerSession());

List<DriverService.Builder<?, ?>> builders = new ArrayList<>();
ServiceLoader.load(DriverService.Builder.class).forEach(builders::add);
Expand Down
9 changes: 9 additions & 0 deletions java/src/org/openqa/selenium/grid/node/local/SessionSlot.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.ServiceLoader;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.logging.Level;
Expand Down Expand Up @@ -59,6 +60,7 @@ public class SessionSlot
private final AtomicBoolean reserved = new AtomicBoolean(false);
private final boolean supportingCdp;
private final boolean supportingBiDi;
private final AtomicLong connectionCounter;
private ActiveSession currentSession;

public SessionSlot(EventBus bus, Capabilities stereotype, SessionFactory factory) {
Expand All @@ -68,6 +70,7 @@ public SessionSlot(EventBus bus, Capabilities stereotype, SessionFactory factory
this.factory = Require.nonNull("Session factory", factory);
this.supportingCdp = isSlotSupportingCdp(this.stereotype);
this.supportingBiDi = isSlotSupportingBiDi(this.stereotype);
this.connectionCounter = new AtomicLong();
}

public UUID getId() {
Expand Down Expand Up @@ -112,6 +115,7 @@ public void stop() {
LOG.log(Level.WARNING, "Unable to cleanly close session", e);
}
currentSession = null;
connectionCounter.set(0);
release();
bus.fire(new SessionClosedEvent(id));
LOG.info(String.format("Stopping session %s", id));
Expand Down Expand Up @@ -148,6 +152,7 @@ public Either<WebDriverException, ActiveSession> apply(CreateSessionRequest sess
if (possibleSession.isRight()) {
ActiveSession session = possibleSession.right();
currentSession = session;
connectionCounter.set(0);
return Either.right(session);
} else {
return Either.left(possibleSession.left());
Expand Down Expand Up @@ -185,4 +190,8 @@ public boolean hasRelayFactory() {
public boolean isRelayServiceUp() {
return hasRelayFactory() && ((RelaySessionFactory) factory).isServiceUp();
}

public AtomicLong getConnectionCounter() {
return connectionCounter;
}
}
12 changes: 12 additions & 0 deletions java/src/org/openqa/selenium/grid/node/remote/RemoteNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,18 @@ public boolean isSessionOwner(SessionId id) {
return Boolean.TRUE.equals(Values.get(res, Boolean.class));
}

@Override
public boolean tryAcquireConnection(SessionId id) {
Require.nonNull("Session ID", id);

HttpRequest req = new HttpRequest(POST, "/se/grid/node/connection/" + id);
HttpTracing.inject(tracer, tracer.getCurrentContext(), req);

HttpResponse res = client.with(addSecret).execute(req);

return Boolean.TRUE.equals(Values.get(res, Boolean.class));
}

@Override
public Session getSession(SessionId id) throws NoSuchSessionException {
Require.nonNull("Session ID", id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,11 @@ public boolean isSessionOwner(SessionId id) {
return running != null && running.getId().equals(id);
}

@Override
public boolean tryAcquireConnection(SessionId id) {
return false;
}

@Override
public boolean isSupporting(Capabilities capabilities) {
return Objects.equals("cake", capabilities.getCapability("cheese"));
Expand Down

0 comments on commit e9e684d

Please sign in to comment.