Skip to content

Commit

Permalink
s2a: fix flake in FakeS2AServerTest (#11673)
Browse files Browse the repository at this point in the history
While here:
 * add an awaitTermination to after calling shutdown on server
 * don't use port picker

Fixes #11648
  • Loading branch information
rmehta19 authored Nov 8, 2024
1 parent 5081e60 commit 546efd7
Showing 1 changed file with 23 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
import static java.util.concurrent.TimeUnit.SECONDS;

import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.SettableFuture;
import com.google.protobuf.ByteString;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel;
import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.benchmarks.Utils;
import io.grpc.s2a.internal.handshaker.ValidatePeerCertificateChainReq.VerificationMode;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
Expand All @@ -49,51 +49,52 @@ public final class FakeS2AServerTest {

private static final ImmutableList<ByteString> FAKE_CERT_DER_CHAIN =
ImmutableList.of(ByteString.copyFrom("fake-der-chain".getBytes(StandardCharsets.US_ASCII)));
private int port;
private String serverAddress;
private SessionResp response = null;
private Server fakeS2AServer;

@Before
public void setUp() throws Exception {
port = Utils.pickUnusedPort();
fakeS2AServer = ServerBuilder.forPort(port).addService(new FakeS2AServer()).build();
fakeS2AServer = ServerBuilder.forPort(0).addService(new FakeS2AServer()).build();
fakeS2AServer.start();
serverAddress = String.format("localhost:%d", port);
serverAddress = String.format("localhost:%d", fakeS2AServer.getPort());
}

@After
public void tearDown() {
public void tearDown() throws Exception {
fakeS2AServer.shutdown();
fakeS2AServer.awaitTermination(10, SECONDS);
}

@Test
public void callS2AServerOnce_getTlsConfiguration_returnsValidResult()
throws InterruptedException, IOException {
throws InterruptedException, IOException, java.util.concurrent.ExecutionException {
ExecutorService executor = Executors.newSingleThreadExecutor();
logger.info("Client connecting to: " + serverAddress);
ManagedChannel channel =
Grpc.newChannelBuilder(serverAddress, InsecureChannelCredentials.create())
.executor(executor)
.build();

SettableFuture<SessionResp> respFuture = SettableFuture.create();
try {
S2AServiceGrpc.S2AServiceStub asyncStub = S2AServiceGrpc.newStub(channel);
StreamObserver<SessionReq> requestObserver =
asyncStub.setUpSession(
new StreamObserver<SessionResp>() {
SessionResp recvResp;
@Override
public void onNext(SessionResp resp) {
response = resp;
recvResp = resp;
}

@Override
public void onError(Throwable t) {
throw new RuntimeException(t);
respFuture.setException(t);
}

@Override
public void onCompleted() {}
public void onCompleted() {
respFuture.set(recvResp);
}
});
try {
requestObserver.onNext(
Expand Down Expand Up @@ -138,36 +139,39 @@ public void onCompleted() {}
.addCiphersuites(
Ciphersuite.CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256)))
.build();
assertThat(response).ignoringRepeatedFieldOrder().isEqualTo(expected);
assertThat(respFuture.get()).ignoringRepeatedFieldOrder().isEqualTo(expected);
}

@Test
public void callS2AServerOnce_validatePeerCertifiate_returnsValidResult()
throws InterruptedException {
throws InterruptedException, java.util.concurrent.ExecutionException {
ExecutorService executor = Executors.newSingleThreadExecutor();
logger.info("Client connecting to: " + serverAddress);
ManagedChannel channel =
Grpc.newChannelBuilder(serverAddress, InsecureChannelCredentials.create())
.executor(executor)
.build();

SettableFuture<SessionResp> respFuture = SettableFuture.create();
try {
S2AServiceGrpc.S2AServiceStub asyncStub = S2AServiceGrpc.newStub(channel);
StreamObserver<SessionReq> requestObserver =
asyncStub.setUpSession(
new StreamObserver<SessionResp>() {
private SessionResp recvResp;
@Override
public void onNext(SessionResp resp) {
response = resp;
recvResp = resp;
}

@Override
public void onError(Throwable t) {
throw new RuntimeException(t);
respFuture.setException(t);
}

@Override
public void onCompleted() {}
public void onCompleted() {
respFuture.set(recvResp);
}
});
try {
requestObserver.onNext(
Expand Down Expand Up @@ -200,7 +204,7 @@ public void onCompleted() {}
ValidatePeerCertificateChainResp.newBuilder()
.setValidationResult(ValidatePeerCertificateChainResp.ValidationResult.SUCCESS))
.build();
assertThat(response).ignoringRepeatedFieldOrder().isEqualTo(expected);
assertThat(respFuture.get()).ignoringRepeatedFieldOrder().isEqualTo(expected);
}

@Test
Expand Down

0 comments on commit 546efd7

Please sign in to comment.