Speed up streamed-proto query output by distributing work to multiple threads #24305
base: master
Changes from all commits
6bd312b
41fbacf
0c9a1d2
913d4a3
9843a5e
9a0efa0
1852be0
89e8b3b
5fc8b13
@@ -13,18 +13,26 @@
// limitations under the License.
package com.google.devtools.build.lib.query2.query.output;

import com.google.common.collect.Iterables;
import com.google.devtools.build.lib.packages.LabelPrinter;
import com.google.devtools.build.lib.packages.Target;
import com.google.devtools.build.lib.query2.engine.OutputFormatterCallback;
import com.google.devtools.build.lib.query2.proto.proto2api.Build;
import com.google.protobuf.CodedOutputStream;

import java.io.IOException;
import java.io.OutputStream;
import java.util.List;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * An output formatter that outputs a protocol buffer representation of a query result and outputs
 * the proto bytes to the output print stream. By taking the bytes and calling {@code mergeFrom()}
 * on a {@code Build.QueryResult} object the full result can be reconstructed.
 */
public class StreamedProtoOutputFormatter extends ProtoOutputFormatter {

  @Override
  public String getName() {
    return "streamed_proto";
@@ -34,13 +42,107 @@ public String getName() {
  public OutputFormatterCallback<Target> createPostFactoStreamCallback(
      final OutputStream out, final QueryOptions options, LabelPrinter labelPrinter) {
    return new OutputFormatterCallback<Target>() {
      private static final int MAX_CHUNKS_IN_QUEUE = Runtime.getRuntime().availableProcessors() * 2;
      private static final int TARGETS_PER_CHUNK = 500;

      private final LabelPrinter ourLabelPrinter = labelPrinter;

      @Override
      public void processOutput(Iterable<Target> partialResult)
          throws IOException, InterruptedException {
-       for (Target target : partialResult) {
-         toTargetProtoBuffer(target, labelPrinter).writeDelimitedTo(out);
        ForkJoinTask<?> writeAllTargetsFuture;
        try (ForkJoinPool executor =
            new ForkJoinPool(
                Runtime.getRuntime().availableProcessors(),
                ForkJoinPool.defaultForkJoinWorkerThreadFactory,
                null,
                // we use asyncMode to ensure the queue is processed FIFO, which maximizes
                // throughput
                true)) {
          var targetQueue = new LinkedBlockingQueue<Future<List<byte[]>>>(MAX_CHUNKS_IN_QUEUE);
          var stillAddingTargetsToQueue = new AtomicBoolean(true);
          writeAllTargetsFuture =
              executor.submit(
                  () -> {
                    try {
                      while (stillAddingTargetsToQueue.get() || !targetQueue.isEmpty()) {
                        Future<List<byte[]>> targets = targetQueue.take();
                        for (byte[] target : targets.get()) {
                          out.write(target);
                        }
                      }
                    } catch (InterruptedException e) {
                      throw new WrappedInterruptedException(e);
                    } catch (IOException e) {
                      throw new WrappedIOException(e);
                    } catch (ExecutionException e) {
                      // TODO: figure out what might be in here and propagate
                      throw new RuntimeException(e);
                    }
                  });
          try {
            for (List<Target> targets : Iterables.partition(partialResult, TARGETS_PER_CHUNK)) {
              targetQueue.put(executor.submit(() -> writeTargetsDelimitedToByteArrays(targets)));
            }
          } finally {
            stillAddingTargetsToQueue.set(false);
          }
        }
        try {
          writeAllTargetsFuture.get();
        } catch (ExecutionException e) {
          // TODO: propagate
          throw new RuntimeException(e);
        }
      }

      private List<byte[]> writeTargetsDelimitedToByteArrays(List<Target> targets) {
        return targets.stream().map(target -> writeDelimited(toProto(target))).toList();
      }

      private Build.Target toProto(Target target) {
        try {
          return toTargetProtoBuffer(target, ourLabelPrinter);
        } catch (InterruptedException e) {
          throw new WrappedInterruptedException(e);
        }
      }
    };
  }
  private static byte[] writeDelimited(Build.Target targetProtoBuffer) {

Inline review thread on writeDelimited:
- Is there a significant performance benefit to converting them to byte arrays, if most of the benefit we gain comes from parallelizing the proto conversion?
- Just noting precedent: https://cs.opensource.google/bazel/bazel/+/master:src/main/java/com/google/devtools/build/lib/runtime/ExecutionGraphModule.java;l=638. The byte representation probably takes up less memory while in the queue?
- Happy to leave it as is, since this is probably quite memory intensive.
- (Do leave a comment here with regard to that.)
    try {
      var serializedSize = targetProtoBuffer.getSerializedSize();
      var headerSize = CodedOutputStream.computeUInt32SizeNoTag(serializedSize);
      var output = new byte[headerSize + serializedSize];
      var codedOut = CodedOutputStream.newInstance(output);
      // Write the varint length prefix into the reserved header bytes, then the message body,
      // mirroring what writeDelimitedTo() produces.
      codedOut.writeUInt32NoTag(serializedSize);
      targetProtoBuffer.writeTo(codedOut);
      codedOut.flush();
      return output;
    } catch (IOException e) {
      throw new WrappedIOException(e);
    }
  }

  private static class WrappedIOException extends RuntimeException {
    private WrappedIOException(IOException cause) {
      super(cause);
    }

    @Override
    public IOException getCause() {
      return (IOException) super.getCause();
    }
  }

  private static class WrappedInterruptedException extends RuntimeException {
    private WrappedInterruptedException(InterruptedException cause) {
      super(cause);
    }

    @Override
    public InterruptedException getCause() {
      return (InterruptedException) super.getCause();
    }
  }
}
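The writeDelimited helper above hand-rolls the varint length prefix that protobuf's writeDelimitedTo would produce, and the review thread on it asks whether pre-serializing to byte arrays on the worker threads is worthwhile. Below is a small self-contained sketch, not part of the PR, that checks the two encodings produce identical bytes; it assumes the generated Java bindings for build.proto are on the classpath, and the class name DelimitedEquivalenceCheck and the sample field values are purely illustrative.

import com.google.devtools.build.lib.query2.proto.proto2api.Build;
import com.google.protobuf.CodedOutputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Arrays;

/** Sketch: verify that manual length-prefixed serialization matches writeDelimitedTo(). */
public final class DelimitedEquivalenceCheck {

  // Same shape as the writeDelimited helper in the diff: reserve headerSize bytes for the
  // varint length prefix, then write the message body after it.
  static byte[] writeDelimited(Build.Target msg) throws IOException {
    int serializedSize = msg.getSerializedSize();
    int headerSize = CodedOutputStream.computeUInt32SizeNoTag(serializedSize);
    byte[] output = new byte[headerSize + serializedSize];
    CodedOutputStream codedOut = CodedOutputStream.newInstance(output);
    codedOut.writeUInt32NoTag(serializedSize); // the length prefix
    msg.writeTo(codedOut); // the message body
    codedOut.flush();
    return output;
  }

  public static void main(String[] args) throws IOException {
    // Illustrative target; the field values are made up just for the comparison.
    Build.Target target =
        Build.Target.newBuilder()
            .setType(Build.Target.Discriminator.RULE)
            .setRule(Build.Rule.newBuilder().setName("//foo:bar").setRuleClass("java_library"))
            .build();
    ByteArrayOutputStream reference = new ByteArrayOutputStream();
    target.writeDelimitedTo(reference); // protobuf's own delimited encoding
    System.out.println(Arrays.equals(reference.toByteArray(), writeDelimited(target)));
  }
}

If the bytes match, the practical difference between the two designs is where the serialization work happens (on the worker threads, as in the PR, or on the single writer thread) and how much memory each queued chunk occupies, which is the point raised in the thread.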
Inline review comment on the MAX_CHUNKS_IN_QUEUE constant:
I used ×2 to be safe, but I believe this actually just needs to be Runtime.getRuntime().availableProcessors(). Basically we just need to know that, each time the consumer pulls a chunk of byte arrays, some CPU is working on producing one to fill that spot.
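As the class javadoc in the diff notes, streamed_proto output is a stream of proto bytes from which the full result can be reconstructed; since each record is a length-delimited Build.Target, reading targets back one at a time with parseDelimitedFrom is one way to consume it. A minimal reader sketch follows, again assuming the generated build.proto Java bindings; the class name StreamedProtoReader is illustrative, and the input would be a file produced by something like bazel query --output=streamed_proto //... > targets.streamed_proto.

import com.google.devtools.build.lib.query2.proto.proto2api.Build;
import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.InputStream;

/** Sketch: read streamed_proto output back, one length-delimited Build.Target at a time. */
public final class StreamedProtoReader {
  public static void main(String[] args) throws Exception {
    try (InputStream in = new BufferedInputStream(new FileInputStream(args[0]))) {
      Build.Target target;
      // parseDelimitedFrom reads one varint-prefixed message and returns null at end of stream.
      while ((target = Build.Target.parseDelimitedFrom(in)) != null) {
        if (target.getType() == Build.Target.Discriminator.RULE) {
          System.out.println(target.getRule().getName());
        }
      }
    }
  }
}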