Skip to content

Commit

Permalink
feat: avoid lock and add check when complete (#198)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattisonchao authored Jan 15, 2025
1 parent 7c12b39 commit 8f317e8
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 33 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright © 2022-2024 StreamNative Inc.
* Copyright © 2022-2025 StreamNative Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -21,11 +21,12 @@
import java.util.concurrent.ConcurrentHashMap;

public final class OxiaWriteStreamManager {
private final Map<Long, WriteStreamWrapper> writeStreams = new ConcurrentHashMap<>();
private final Map<Long, WriteStreamWrapper> writeStreams;
private final OxiaStubProvider provider;

public OxiaWriteStreamManager(OxiaStubProvider provider) {
this.provider = provider;
this.writeStreams = new ConcurrentHashMap<>();
}

private static final Metadata.Key<String> NAMESPACE_KEY =
Expand All @@ -34,18 +35,28 @@ public OxiaWriteStreamManager(OxiaStubProvider provider) {
Metadata.Key.of("shard-id", Metadata.ASCII_STRING_MARSHALLER);

public WriteStreamWrapper getWriteStream(long shardId) {
return writeStreams.compute(
shardId,
(key, stream) -> {
if (stream == null || !stream.isValid()) {
Metadata headers = new Metadata();
headers.put(NAMESPACE_KEY, provider.getNamespace());
headers.put(SHARD_ID_KEY, String.format("%d", shardId));
final var asyncStub = provider.getStubForShard(shardId).async();
return new WriteStreamWrapper(
asyncStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(headers)));
}
return stream;
});
WriteStreamWrapper wrapper = null;
for (int i = 0; i < 2; i++) {
wrapper = writeStreams.get(shardId); // lock free first
if (wrapper == null) {
wrapper =
writeStreams.computeIfAbsent(
shardId,
(__) -> {
Metadata headers = new Metadata();
headers.put(NAMESPACE_KEY, provider.getNamespace());
headers.put(SHARD_ID_KEY, String.format("%d", shardId));
final var asyncStub = provider.getStubForShard(shardId).async();
return new WriteStreamWrapper(
asyncStub.withInterceptors(
MetadataUtils.newAttachHeadersInterceptor(headers)));
});
}
if (wrapper.isValid()) {
break;
}
writeStreams.remove(shardId, wrapper);
}
return wrapper;
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright © 2022-2024 StreamNative Inc.
* Copyright © 2022-2025 StreamNative Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.streamnative.oxia.client.grpc;

import io.grpc.stub.StreamObserver;
Expand All @@ -22,6 +21,7 @@
import io.streamnative.oxia.proto.WriteResponse;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Optional;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import lombok.extern.slf4j.Slf4j;
Expand All @@ -30,15 +30,20 @@
public final class WriteStreamWrapper implements StreamObserver<WriteResponse> {

private final StreamObserver<WriteRequest> clientStream;
private final Deque<CompletableFuture<WriteResponse>> pendingWrites = new ArrayDeque<>();
private volatile Throwable failed = null;
private final Deque<CompletableFuture<WriteResponse>> pendingWrites;

private volatile boolean completed;
private volatile Throwable completedException;

public WriteStreamWrapper(OxiaClientGrpc.OxiaClientStub stub) {
this.clientStream = stub.writeStream(this);
this.pendingWrites = new ArrayDeque<>();
this.completed = false;
this.completedException = null;
}

public boolean isValid() {
return failed == null;
return !completed;
}

@Override
Expand All @@ -52,34 +57,45 @@ public void onNext(WriteResponse value) {
}

@Override
public void onError(Throwable t) {
public void onError(Throwable error) {
synchronized (WriteStreamWrapper.this) {
completedException = error;
completed = true;
if (!pendingWrites.isEmpty()) {
log.warn("Got Error", t);
log.warn(
"Receive error when writing data to server through the stream, prepare to fail pending requests. pendingWrites={}",
pendingWrites.size(),
completedException);
}
pendingWrites.forEach(f -> f.completeExceptionally(t));
pendingWrites.forEach(f -> f.completeExceptionally(completedException));
pendingWrites.clear();
failed = t;
}
}

@Override
public void onCompleted() {
synchronized (WriteStreamWrapper.this) {
// complete pending request if the server close stream without any response
pendingWrites.forEach(
f -> {
if (!f.isDone()) {
f.completeExceptionally(new CancellationException());
}
});
completed = true;
if (!pendingWrites.isEmpty()) {
log.warn(
"Receive stream close signal when writing data to server through the stream, prepare to cancel pending requests. pendingWrites={}",
pendingWrites.size(),
completedException);
}
pendingWrites.forEach(f -> f.completeExceptionally(new CancellationException()));
pendingWrites.clear();
}
}

public CompletableFuture<WriteResponse> send(WriteRequest request) {
if (completed) {
return CompletableFuture.failedFuture(
Optional.ofNullable(completedException).orElseGet(CancellationException::new));
}
synchronized (WriteStreamWrapper.this) {
if (failed != null) {
return CompletableFuture.failedFuture(failed);
if (completed) {
return CompletableFuture.failedFuture(
Optional.ofNullable(completedException).orElseGet(CancellationException::new));
}
final CompletableFuture<WriteResponse> future = new CompletableFuture<>();
try {
Expand Down

0 comments on commit 8f317e8

Please sign in to comment.