From 4467b512cbd34f5be9c995de9fd225248217b8a6 Mon Sep 17 00:00:00 2001 From: Grzegorz Piwowarek Date: Wed, 21 Aug 2024 09:52:04 +0200 Subject: [PATCH] Implement BatchingSpliterator#trySplit --- .../collectors/BatchingSpliterator.java | 37 ++++++++++++++++++- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/pivovarit/collectors/BatchingSpliterator.java b/src/main/java/com/pivovarit/collectors/BatchingSpliterator.java index 8910795c..d0f9dc66 100644 --- a/src/main/java/com/pivovarit/collectors/BatchingSpliterator.java +++ b/src/main/java/com/pivovarit/collectors/BatchingSpliterator.java @@ -6,6 +6,7 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Stream; +import java.util.stream.StreamSupport; import static java.util.stream.Stream.empty; import static java.util.stream.Stream.of; @@ -16,8 +17,8 @@ */ final class BatchingSpliterator implements Spliterator> { - private final List source; - private final int maxChunks; + private List source; + private int maxChunks; private int chunks; private int chunkSize; @@ -80,6 +81,17 @@ public boolean tryAdvance(Consumer> action) { @Override public Spliterator> trySplit() { + if (actualBatchCount(source, chunks) > 1 || consumed == 0 ) { + var first = source.subList(0, source.size() / 2); + var second = source.subList(source.size() / 2, source.size()); + var originalChunks = chunks; + + source = first; + chunks = originalChunks % 2 == 0 ? originalChunks / 2 : originalChunks / 2 + 1; + maxChunks = Math.min(source.size(), chunks); + chunkSize = (int) Math.ceil(((double) source.size()) / chunks); + return new BatchingSpliterator<>(second, originalChunks / 2); + } return null; } @@ -92,4 +104,25 @@ public long estimateSize() { public int characteristics() { return ORDERED | SIZED; } + + private static int actualBatchCount(List list, int numberOfBatches) { + int batchSize = list.size() / numberOfBatches; + int remainder = list.size() % numberOfBatches; + + int batches = 0; + int currentIndex = 0; + + for (int i = 0; i < numberOfBatches; i++) { + int currentBatchSize = batchSize + (remainder > 0 ? 1 : 0); + remainder--; + + int nextIndex = Math.min(currentIndex + currentBatchSize, list.size()); + if (currentIndex < nextIndex) { + batches++; + } + currentIndex = nextIndex; + } + + return batches; + } }