Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up streamed-proto query output by distributing work to multiple threads #24305

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,26 @@
// limitations under the License.
package com.google.devtools.build.lib.query2.query.output;

import com.google.common.collect.Iterables;
import com.google.devtools.build.lib.packages.LabelPrinter;
import com.google.devtools.build.lib.packages.Target;
import com.google.devtools.build.lib.query2.engine.OutputFormatterCallback;
import com.google.devtools.build.lib.query2.proto.proto2api.Build;
import com.google.protobuf.CodedOutputStream;

import java.io.IOException;
import java.io.OutputStream;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;

/**
* An output formatter that outputs a protocol buffer representation of a query result and outputs
* the proto bytes to the output print stream. By taking the bytes and calling {@code mergeFrom()}
* on a {@code Build.QueryResult} object the full result can be reconstructed.
*/
public class StreamedProtoOutputFormatter extends ProtoOutputFormatter {

@Override
public String getName() {
return "streamed_proto";
Expand All @@ -34,13 +42,107 @@ public String getName() {
public OutputFormatterCallback<Target> createPostFactoStreamCallback(
final OutputStream out, final QueryOptions options, LabelPrinter labelPrinter) {
return new OutputFormatterCallback<Target>() {
private static final int MAX_CHUNKS_IN_QUEUE = Runtime.getRuntime().availableProcessors() * 2;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used ×2 to be safe, but I believe this actually just needs to be Runtime.getRuntime().availableProcessors(). Basically we just need to know that, each time the consumer pulls a chunk of byte arrays, some CPU is working on producing one to fill that spot.

private static final int TARGETS_PER_CHUNK = 500;

private final LabelPrinter ourLabelPrinter = labelPrinter;

@Override
public void processOutput(Iterable<Target> partialResult)
throws IOException, InterruptedException {
for (Target target : partialResult) {
toTargetProtoBuffer(target, labelPrinter).writeDelimitedTo(out);
ForkJoinTask<?> writeAllTargetsFuture;
try (ForkJoinPool executor =
new ForkJoinPool(
Runtime.getRuntime().availableProcessors(),
ForkJoinPool.defaultForkJoinWorkerThreadFactory,
null,
// we use asyncMode to ensure the queue is processed FIFO, which maximizes
// throughput
true)) {
var targetQueue = new LinkedBlockingQueue<Future<List<byte[]>>>(MAX_CHUNKS_IN_QUEUE);
var stillAddingTargetsToQueue = new AtomicBoolean(true);
writeAllTargetsFuture =
executor.submit(
() -> {
try {
while (stillAddingTargetsToQueue.get() || !targetQueue.isEmpty()) {
Future<List<byte[]>> targets = targetQueue.take();
for (byte[] target : targets.get()) {
out.write(target);
}
}
} catch (InterruptedException e) {
throw new WrappedInterruptedException(e);
} catch (IOException e) {
throw new WrappedIOException(e);
} catch (ExecutionException e) {
// TODO: figure out what might be in here and propagate
throw new RuntimeException(e);
}
});
try {
for (List<Target> targets : Iterables.partition(partialResult, TARGETS_PER_CHUNK)) {
targetQueue.put(executor.submit(() -> writeTargetsDelimitedToByteArrays(targets)));
}
} finally {
stillAddingTargetsToQueue.set(false);
}
}
try {
writeAllTargetsFuture.get();
} catch (ExecutionException e) {
// TODO: propagate
throw new RuntimeException(e);
}
}

private List<byte[]> writeTargetsDelimitedToByteArrays(List<Target> targets) {
return targets.stream().map(target -> writeDelimited(toProto(target))).toList();
}

private Build.Target toProto(Target target) {
try {
return toTargetProtoBuffer(target, ourLabelPrinter);
} catch (InterruptedException e) {
throw new WrappedInterruptedException(e);
}
}
};
}

private static byte[] writeDelimited(Build.Target targetProtoBuffer) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a significant performance benefit to converting them to byte[] instead of just leaving them as Build.Target protos for the consumer to write?

If most of the benefit we gain comes from parallelizing toTargetProtoBuffer(), then perhaps we could reduce the complexity here and just deal with writing protos delimited to the output stream?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just noting precedent: https://cs.opensource.google/bazel/bazel/+/master:src/main/java/com/google/devtools/build/lib/runtime/ExecutionGraphModule.java;l=638. The byte representation probably takes up less memory while in the queue?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

happy to leave it as is since this probably quite memory intensive

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(do leave a comment here with regards to that)

try {
var serializedSize = targetProtoBuffer.getSerializedSize();
var headerSize = CodedOutputStream.computeUInt32SizeNoTag(serializedSize);
var output = new byte[headerSize + serializedSize];
var codedOut = CodedOutputStream.newInstance(output, headerSize, output.length - headerSize);
targetProtoBuffer.writeTo(codedOut);
codedOut.flush();
return output;
} catch (IOException e) {
throw new WrappedIOException(e);
}
}

private static class WrappedIOException extends RuntimeException {
private WrappedIOException(IOException cause) {
super(cause);
}

@Override
public IOException getCause() {
return (IOException) super.getCause();
}
}

private static class WrappedInterruptedException extends RuntimeException {
private WrappedInterruptedException(InterruptedException cause) {
super(cause);
}

@Override
public InterruptedException getCause() {
return (InterruptedException) super.getCause();
}
}
}