Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use bidirectional streaming RPCs for scan operations in DistributedStorageService #235

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package com.scalar.db.storage.rpc;

import com.scalar.db.api.Result;
import com.scalar.db.api.Scan;
import com.scalar.db.api.TableMetadata;
import com.scalar.db.exception.storage.ExecutionException;
import com.scalar.db.rpc.DistributedStorageGrpc;
import com.scalar.db.rpc.ScanRequest;
import com.scalar.db.rpc.ScanResponse;
import com.scalar.db.util.ProtoUtil;
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import javax.annotation.concurrent.NotThreadSafe;

@NotThreadSafe
public class GrpcScanOnBidirectionalStream implements StreamObserver<ScanResponse> {

private final StreamObserver<ScanRequest> requestObserver;
private final TableMetadata metadata;
private final BlockingQueue<ResponseOrError> queue = new LinkedBlockingQueue<>();
private final AtomicBoolean hasMoreResults = new AtomicBoolean(true);

public GrpcScanOnBidirectionalStream(
DistributedStorageGrpc.DistributedStorageStub stub, TableMetadata metadata) {
this.metadata = metadata;
requestObserver = stub.scan(this);
}

@Override
public void onNext(ScanResponse response) {
try {
queue.put(new ResponseOrError(response));
} catch (InterruptedException ignored) {
// InterruptedException should not be thrown
}
}

@Override
public void onError(Throwable t) {
try {
queue.put(new ResponseOrError(t));
} catch (InterruptedException ignored) {
// InterruptedException should not be thrown
}
}

@Override
public void onCompleted() {}

private ResponseOrError sendRequest(ScanRequest request) {
requestObserver.onNext(request);
try {
return queue.take();
} catch (InterruptedException ignored) {
// InterruptedException should not be thrown
return null;
}
}

private void throwIfScannerHasNoMoreResults() {
if (!hasMoreResults.get()) {
throw new IllegalStateException("the scan operation has no more results");
}
}

private void throwIfError(ResponseOrError responseOrError) throws ExecutionException {
if (responseOrError.isError()) {
hasMoreResults.set(false);
Throwable error = responseOrError.getError();
if (error instanceof StatusRuntimeException) {
StatusRuntimeException e = (StatusRuntimeException) error;
if (e.getStatus().getCode() == Code.INVALID_ARGUMENT) {
throw new IllegalArgumentException(e.getMessage());
}
throw new ExecutionException(e.getMessage());
}
if (error instanceof Error) {
throw (Error) error;
}
throw new ExecutionException(error.getMessage());
}
}

private List<Result> getResults(ScanResponse response) {
if (!response.getHasMoreResults()) {
hasMoreResults.set(false);
}
return response.getResultList().stream()
.map(r -> ProtoUtil.toResult(r, metadata))
.collect(Collectors.toList());
}

public List<Result> openScanner(Scan scan) throws ExecutionException {
throwIfScannerHasNoMoreResults();
ResponseOrError responseOrError =
sendRequest(ScanRequest.newBuilder().setScan(ProtoUtil.toScan(scan)).build());
throwIfError(responseOrError);
return getResults(responseOrError.getResponse());
}

public List<Result> next() throws ExecutionException {
throwIfScannerHasNoMoreResults();
ResponseOrError responseOrError = sendRequest(ScanRequest.getDefaultInstance());
throwIfError(responseOrError);
return getResults(responseOrError.getResponse());
}

public List<Result> next(int fetchCount) throws ExecutionException {
throwIfScannerHasNoMoreResults();
ResponseOrError responseOrError =
sendRequest(ScanRequest.newBuilder().setFetchCount(fetchCount).build());
throwIfError(responseOrError);
return getResults(responseOrError.getResponse());
}

public void closeScanner() {
if (!hasMoreResults.get()) {
return;
}
requestObserver.onCompleted();
hasMoreResults.set(false);
}

public boolean hasMoreResults() {
return hasMoreResults.get();
}

private static class ResponseOrError {
private final ScanResponse response;
private final Throwable error;

public ResponseOrError(ScanResponse response) {
this.response = response;
this.error = null;
}

public ResponseOrError(Throwable error) {
this.response = null;
this.error = error;
}

private boolean isError() {
return error != null;
}

public ScanResponse getResponse() {
return response;
}

public Throwable getError() {
return error;
}
}
}
31 changes: 16 additions & 15 deletions core/src/main/java/com/scalar/db/storage/rpc/GrpcStorage.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.scalar.db.rpc.GetResponse;
import com.scalar.db.rpc.MutateRequest;
import com.scalar.db.util.ProtoUtil;
import com.scalar.db.util.ThrowableSupplier;
import com.scalar.db.util.Utility;
import io.grpc.ManagedChannel;
import io.grpc.Status.Code;
Expand All @@ -28,7 +29,6 @@
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import javax.annotation.concurrent.ThreadSafe;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -38,7 +38,8 @@ public class GrpcStorage implements DistributedStorage {
private static final Logger LOGGER = LoggerFactory.getLogger(GrpcStorage.class);

private final ManagedChannel channel;
private final DistributedStorageGrpc.DistributedStorageBlockingStub stub;
private final DistributedStorageGrpc.DistributedStorageStub stub;
private final DistributedStorageGrpc.DistributedStorageBlockingStub blockingStub;
private final GrpcTableMetadataManager metadataManager;

private Optional<String> namespace;
Expand All @@ -50,7 +51,8 @@ public GrpcStorage(DatabaseConfig config) {
NettyChannelBuilder.forAddress(config.getContactPoints().get(0), config.getContactPort())
.usePlaintext()
.build();
stub = DistributedStorageGrpc.newBlockingStub(channel);
stub = DistributedStorageGrpc.newStub(channel);
blockingStub = DistributedStorageGrpc.newBlockingStub(channel);
metadataManager =
new GrpcTableMetadataManager(DistributedStorageAdminGrpc.newBlockingStub(channel));
namespace = Optional.empty();
Expand All @@ -59,10 +61,12 @@ public GrpcStorage(DatabaseConfig config) {

@VisibleForTesting
GrpcStorage(
DistributedStorageGrpc.DistributedStorageBlockingStub stub,
DistributedStorageGrpc.DistributedStorageStub stub,
DistributedStorageGrpc.DistributedStorageBlockingStub blockingStub,
GrpcTableMetadataManager metadataManager) {
channel = null;
this.stub = stub;
this.blockingStub = blockingStub;
this.metadataManager = metadataManager;
namespace = Optional.empty();
tableName = Optional.empty();
Expand Down Expand Up @@ -101,7 +105,7 @@ public Optional<Result> get(Get get) throws ExecutionException {
Utility.setTargetToIfNot(get, namespace, tableName);

GetResponse response =
stub.get(GetRequest.newBuilder().setGet(ProtoUtil.toGet(get)).build());
blockingStub.get(GetRequest.newBuilder().setGet(ProtoUtil.toGet(get)).build());
if (response.hasResult()) {
TableMetadata tableMetadata = metadataManager.getTableMetadata(get);
return Optional.of(ProtoUtil.toResult(response.getResult(), tableMetadata));
Expand All @@ -112,13 +116,9 @@ public Optional<Result> get(Get get) throws ExecutionException {

@Override
public Scanner scan(Scan scan) throws ExecutionException {
return execute(
() -> {
Utility.setTargetToIfNot(scan, namespace, tableName);

TableMetadata tableMetadata = metadataManager.getTableMetadata(scan);
return new ScannerImpl(scan, stub, tableMetadata);
});
Utility.setTargetToIfNot(scan, namespace, tableName);
TableMetadata tableMetadata = metadataManager.getTableMetadata(scan);
return new ScannerImpl(scan, stub, tableMetadata);
}

@Override
Expand Down Expand Up @@ -146,7 +146,7 @@ private void mutate(Mutation mutation) throws ExecutionException {
() -> {
Utility.setTargetToIfNot(mutation, namespace, tableName);

stub.mutate(
blockingStub.mutate(
MutateRequest.newBuilder().addMutation(ProtoUtil.toMutation(mutation)).build());
return null;
});
Expand All @@ -160,12 +160,13 @@ public void mutate(List<? extends Mutation> mutations) throws ExecutionException

MutateRequest.Builder builder = MutateRequest.newBuilder();
mutations.forEach(m -> builder.addMutation(ProtoUtil.toMutation(m)));
stub.mutate(builder.build());
blockingStub.mutate(builder.build());
return null;
});
}

static <T> T execute(Supplier<T> supplier) throws ExecutionException {
static <T> T execute(ThrowableSupplier<T, ExecutionException> supplier)
throws ExecutionException {
try {
return supplier.get();
} catch (StatusRuntimeException e) {
Expand Down
47 changes: 9 additions & 38 deletions core/src/main/java/com/scalar/db/storage/rpc/ScannerImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,47 +5,28 @@
import com.scalar.db.api.Scanner;
import com.scalar.db.api.TableMetadata;
import com.scalar.db.exception.storage.ExecutionException;
import com.scalar.db.rpc.CloseScannerRequest;
import com.scalar.db.rpc.DistributedStorageGrpc;
import com.scalar.db.rpc.OpenScannerRequest;
import com.scalar.db.rpc.OpenScannerResponse;
import com.scalar.db.rpc.ScanNextRequest;
import com.scalar.db.rpc.ScanNextResponse;
import com.scalar.db.storage.common.ScannerIterator;
import com.scalar.db.util.ProtoUtil;
import io.grpc.StatusRuntimeException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import javax.annotation.concurrent.NotThreadSafe;

@NotThreadSafe
public class ScannerImpl implements Scanner {
private final DistributedStorageGrpc.DistributedStorageBlockingStub stub;
private final TableMetadata metadata;
private final String scannerId;

private final GrpcScanOnBidirectionalStream stream;

private List<Result> results;
private boolean hasMoreResults;

public ScannerImpl(
Scan scan,
DistributedStorageGrpc.DistributedStorageBlockingStub stub,
TableMetadata metadata) {
this.stub = stub;
this.metadata = metadata;

OpenScannerResponse response =
stub.openScanner(OpenScannerRequest.newBuilder().setScan(ProtoUtil.toScan(scan)).build());
scannerId = response.getScannerId();
results =
response.getResultList().stream()
.map(r -> ProtoUtil.toResult(r, metadata))
.collect(Collectors.toList());
hasMoreResults = response.getHasMoreResults();
Scan scan, DistributedStorageGrpc.DistributedStorageStub stub, TableMetadata metadata)
throws ExecutionException {
stream = new GrpcScanOnBidirectionalStream(stub, metadata);
results = stream.openScanner(scan);
}

@Override
Expand All @@ -56,14 +37,8 @@ public Optional<Result> one() throws ExecutionException {
return Optional.empty();
}
Result result = results.remove(0);
if (results.isEmpty() && hasMoreResults) {
ScanNextResponse response =
stub.scanNext(ScanNextRequest.newBuilder().setScannerId(scannerId).build());
results =
response.getResultList().stream()
.map(r -> ProtoUtil.toResult(r, metadata))
.collect(Collectors.toList());
hasMoreResults = response.getHasMoreResults();
if (results.isEmpty() && stream.hasMoreResults()) {
results = stream.next();
}
return Optional.of(result);
});
Expand All @@ -86,11 +61,7 @@ public List<Result> all() throws ExecutionException {
@Override
public void close() throws IOException {
try {
// if hasMoreResult is false, the scanner should already be closed. So we don't need to close
// it here
if (hasMoreResults) {
stub.closeScanner(CloseScannerRequest.newBuilder().setScannerId(scannerId).build());
}
stream.closeScanner();
} catch (StatusRuntimeException e) {
throw new IOException("failed to close the scanner", e);
}
Expand Down
Loading