[SPARK-50292] Add MapStatus RowCount optimize skewed job #48825

Open · wants to merge 1 commit into base: master
core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -22,6 +22,7 @@
import java.io.IOException;
import java.io.OutputStream;
import java.nio.channels.FileChannel;
import java.util.Arrays;
import java.util.Optional;
import java.util.zip.Checksum;
import javax.annotation.Nullable;
@@ -112,6 +113,8 @@ final class BypassMergeSortShuffleWriter<K, V>
*/
private boolean stopping = false;

private boolean enableRowCountOptimize = false;
private boolean enableMetricsRowCountCheck = true;

BypassMergeSortShuffleWriter(
BlockManager blockManager,
BypassMergeSortShuffleHandle<K, V> handle,
@@ -131,6 +134,8 @@ final class BypassMergeSortShuffleWriter<K, V>
this.writeMetrics = writeMetrics;
this.serializer = dep.serializer();
this.shuffleExecutorComponents = shuffleExecutorComponents;
this.enableRowCountOptimize = (boolean) conf.get(package$.MODULE$.SHUFFLE_MAP_STATUS_ROW_COUNT_OPTIMIZE_SKEWED_JOB());
this.enableMetricsRowCountCheck = (boolean) conf.get(package$.MODULE$.SHUFFLE_MAP_STATUS_ROW_COUNT_METRICS_CHECK());
this.partitionChecksums = createPartitionChecksums(numPartitions, conf);
}

@@ -168,10 +173,15 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
// included in the shuffle write time.
writeMetrics.incWriteTime(System.nanoTime() - openStartTime);

long[] partitionRecords = new long[numPartitions];
while (records.hasNext()) {
final Product2<K, V> record = records.next();
final K key = record._1();
int partitionIndex = partitioner.getPartition(key);
partitionWriters[partitionIndex].write(key, record._2());
if (enableRowCountOptimize) {
partitionRecords[partitionIndex] += 1;
}
}

for (int i = 0; i < numPartitions; i++) {
@@ -181,8 +191,32 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
}

partitionLengths = writePartitionedData(mapOutputWriter);
boolean checkMetricsRowCountResult = true;
if (enableRowCountOptimize && enableMetricsRowCountCheck) {
long partitionRecordsSum = Arrays.stream(partitionRecords).sum();
long metricsRecordsSum = writeMetrics.recordsWritten();
if (logger.isDebugEnabled()) {
long partitionRecordsMax = Arrays.stream(partitionRecords).max().getAsLong();
long partitionRecordsMin = Arrays.stream(partitionRecords).min().getAsLong();
logger.debug("Partition records for shuffle {}: max={}, min={}, sum={}; metrics records sum={}.",
shuffleId, partitionRecordsMax, partitionRecordsMin, partitionRecordsSum, metricsRecordsSum);
} else {
logger.info("Partition records for shuffle {}: sum={}; metrics records sum={}.",
shuffleId, partitionRecordsSum, metricsRecordsSum);
}
if (partitionRecordsSum != metricsRecordsSum) {
checkMetricsRowCountResult = false;
logger.info("Shuffle {}: metrics records sum {} does not equal partition records sum {}.",
shuffleId, metricsRecordsSum, partitionRecordsSum);
}
}
if (enableRowCountOptimize && checkMetricsRowCountResult) {
mapStatus = MapStatus$.MODULE$.apply(
blockManager.shuffleServerId(), partitionLengths, mapId, Option.apply(partitionRecords));
} else {
mapStatus = MapStatus$.MODULE$.apply(
blockManager.shuffleServerId(), partitionLengths, mapId);
}
} catch (Exception e) {
try {
mapOutputWriter.abort(e);
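For context, the rule this writer (and UnsafeShuffleWriter below) applies can be summarized in a few lines. This is an illustrative sketch only — `validatedRowCounts` and its parameters are hypothetical names, not code from this patch:

```scala
// Hypothetical helper mirroring the check above: attach per-partition row counts
// to the MapStatus only when their sum agrees with the task's shuffle write metrics.
def validatedRowCounts(
    partitionRecords: Array[Long],
    metricsRecordsWritten: Long,
    checkEnabled: Boolean): Option[Array[Long]] = {
  if (!checkEnabled || partitionRecords.sum == metricsRecordsWritten) {
    Some(partitionRecords) // counts are trusted and shipped with the MapStatus
  } else {
    None // mismatch: fall back to a MapStatus without row counts
  }
}
```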
core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -103,6 +103,9 @@ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream
private MyByteArrayOutputStream serBuffer;
private SerializationStream serOutputStream;

private boolean enableRowCountOptimize = false;
private boolean enableRowCountMetricsCheck = true;

/**
* Are we in the process of stopping? Because map tasks can call stop() with success = true
* and then call stop() with success = false if they get an exception, we want to make sure
@@ -138,6 +141,8 @@ public UnsafeShuffleWriter(
this.taskContext = taskContext;
this.sparkConf = sparkConf;
this.transferToEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_MERGE_PREFER_NIO());
this.enableRowCountOptimize = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_MAP_STATUS_ROW_COUNT_OPTIMIZE_SKEWED_JOB());
this.enableRowCountMetricsCheck = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_MAP_STATUS_ROW_COUNT_METRICS_CHECK());
this.initialSortBufferSize =
(int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE());
this.mergeBufferSizeInBytes =
@@ -178,10 +183,22 @@ public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOEx
// generic throwables.
boolean success = false;
try {
long[] partitionRecords = new long[partitioner.numPartitions()];

while (records.hasNext()) {
Product2<K, V> record = records.next();
insertRecordIntoSorter(record);
if (enableRowCountOptimize) {
int partitionId = partitioner.getPartition(record._1());
partitionRecords[partitionId] += 1;
}
}

if (enableRowCountOptimize) {
closeAndWriteOutput(partitionRecords);
} else {
closeAndWriteOutput();
}
success = true;
} finally {
if (sorter != null) {
@@ -217,6 +234,11 @@ private void open() throws SparkException {

@VisibleForTesting
void closeAndWriteOutput() throws IOException {
closeAndWriteOutput(null);
}

@VisibleForTesting
void closeAndWriteOutput(long[] partitionRecords) throws IOException {
assert(sorter != null);
updatePeakMemoryUsed();
serBuffer = null;
@@ -234,7 +256,31 @@ void closeAndWriteOutput() throws IOException {
}
}
mapStatus = MapStatus$.MODULE$.apply(
blockManager.shuffleServerId(), partitionLengths, mapId);
if (enableRowCountOptimize && partitionRecords != null) {
boolean metricsCheckResult = true;
if (enableRowCountMetricsCheck) {
long partitionRecordsSum = Arrays.stream(partitionRecords).sum();
long metricsRecordsSum = taskContext.taskMetrics().shuffleWriteMetrics().recordsWritten();
if (logger.isDebugEnabled()) {
long partitionRecordsMax = Arrays.stream(partitionRecords).max().getAsLong();
long partitionRecordsMin = Arrays.stream(partitionRecords).min().getAsLong();
logger.debug("Partition records for shuffle {}: max={}, min={}, sum={}; metrics records sum={}.",
shuffleId, partitionRecordsMax, partitionRecordsMin, partitionRecordsSum, metricsRecordsSum);
} else {
logger.info("Partition records for shuffle {}: sum={}; metrics records sum={}.",
shuffleId, partitionRecordsSum, metricsRecordsSum);
}
if (partitionRecordsSum != metricsRecordsSum) {
metricsCheckResult = false;
logger.info("Shuffle {}: metrics records sum {} does not equal partition records sum {}.",
shuffleId, metricsRecordsSum, partitionRecordsSum);
}
}
if (metricsCheckResult) {
mapStatus.updateRecordsArray(partitionRecords);
}
}
}

@VisibleForTesting
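The diff calls several MapStatus methods (`hasRowCount`, `getRecordsArray`, `getRecordForBlock`, `updateRecordsArray`) whose definitions live in MapStatus.scala and are not shown here. A minimal, hypothetical sketch of what that surface presumably looks like, for readability only:

```scala
// Hypothetical reconstruction of the row-count surface added to MapStatus;
// the real definitions are in MapStatus.scala, which this diff does not show.
trait RowCountStatus {
  private var recordsArray: Array[Long] = _

  /** Whether this status carries per-reduce-partition row counts. */
  def hasRowCount(): Boolean = recordsArray != null

  /** Row counts indexed by reduce partition id (null when absent). */
  def getRecordsArray(): Array[Long] = recordsArray

  /** Row count for a single reduce partition, 0 when counts are absent. */
  def getRecordForBlock(reduceId: Int): Long =
    if (hasRowCount()) recordsArray(reduceId) else 0L

  /** Attach validated per-partition row counts after construction. */
  def updateRecordsArray(records: Array[Long]): Unit = { recordsArray = records }
}
```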
core/src/main/scala/org/apache/spark/MapOutputStatistics.scala
@@ -24,4 +24,5 @@ package org.apache.spark
* @param bytesByPartitionId approximate number of output bytes for each map output partition
* (may be inexact due to use of compressed map statuses)
* @param recordsByPartitionId approximate number of output records for each map output
*                             partition, when row-count tracking is enabled
*/
private[spark] class MapOutputStatistics(val shuffleId: Int, val bytesByPartitionId: Array[Long],
val recordsByPartitionId: Option[Array[Long]] = None)
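With `recordsByPartitionId` available, a consumer such as AQE-style skew handling could detect skew by row count rather than by bytes alone. A hedged sketch — the `factor` threshold and the method itself are illustrative, not part of this patch:

```scala
// Flag reduce partitions whose row count exceeds a multiple of the median count.
def skewedPartitions(stats: MapOutputStatistics, factor: Double = 5.0): Seq[Int] = {
  stats.recordsByPartitionId match {
    case Some(records) if records.nonEmpty =>
      val sorted = records.sorted
      val median = math.max(sorted(sorted.length / 2), 1L)
      records.indices.filter(i => records(i) > median * factor)
    case _ => Seq.empty // no row counts: fall back to byte-size-based detection
  }
}
```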
core/src/main/scala/org/apache/spark/MapOutputTracker.scala (72 additions & 6 deletions)
@@ -17,35 +17,34 @@

package org.apache.spark

import java.io.{ByteArrayInputStream, IOException, InputStream, ObjectInputStream, ObjectOutputStream}
import java.nio.ByteBuffer
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.atomic.AtomicLongArray
import java.util.concurrent.locks.ReentrantReadWriteLock

import scala.collection
import scala.collection.mutable.{HashMap, ListBuffer, Map}
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag
import scala.util.control.NonFatal

import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOutputStream}
import org.roaringbitmap.RoaringBitmap

import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
import org.apache.spark.internal.{Logging, MDC, MessageWithContext}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.internal.config._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus, MapStatus, MergeStatus, ShuffleOutputStatus}
import org.apache.spark.shuffle.MetadataFetchFailedException
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId}
import org.apache.spark.util._
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}

/**
* Helper class used by the [[MapOutputTrackerMaster]] to perform bookkeeping for a single
* ShuffleMapStage.
@@ -698,6 +697,9 @@ private[spark] class MapOutputTrackerMaster(
/** Whether to compute locality preferences for reduce tasks */
private val shuffleLocalityEnabled = conf.get(SHUFFLE_REDUCE_LOCALITY_ENABLE)

private lazy val enableOptimizeMapStatusRowCount = conf.get(
SHUFFLE_MAP_STATUS_ROW_COUNT_OPTIMIZE_SKEWED_JOB)

private val shuffleMigrationEnabled = conf.get(DECOMMISSION_ENABLED) &&
conf.get(STORAGE_DECOMMISSION_ENABLED) && conf.get(STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED)

@@ -718,6 +720,8 @@ private[spark] class MapOutputTrackerMaster(
// Exposed for testing
val shuffleStatuses = new ConcurrentHashMap[Int, ShuffleStatus]().asScala

val mapStatusRowCount = new ConcurrentHashMap[Int, AtomicLongArray]()

private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)

// requests for MapOutputTrackerMasterMessages
@@ -815,6 +819,9 @@ private[spark] class MapOutputTrackerMaster(
}

def registerShuffle(shuffleId: Int, numMaps: Int, numReduces: Int): Unit = {
if (enableOptimizeMapStatusRowCount) {
mapStatusRowCount.put(shuffleId, new AtomicLongArray(numReduces))
}
if (pushBasedShuffleEnabled) {
if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps, numReduces)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
@@ -839,7 +846,66 @@ private[spark] class MapOutputTrackerMaster(
}

def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Unit = {
if (enableOptimizeMapStatusRowCount && status != null && status.hasRowCount()) {
val array = mapStatusRowCount.get(shuffleId)
for (i <- 0 until array.length()) {
val rowCount = status.getRecordForBlock(i)
array.addAndGet(i, rowCount)
}
}
if (enableOptimizeCompressedMapStatus && status != null) {
val array = mapOutputStatisticsCache.get(shuffleId)
val uncompressedSizes = new Array[Long](array.length())
for (i <- 0 until array.length()) {
val blockSize = status.getSizeForBlock(i)
array.getAndAdd(i, blockSize)
uncompressedSizes(i) = blockSize
}
if (enableOptimizeCompressedConvertHighly && status.isInstanceOf[CompressedMapStatus]) {
// Convert to HighlyCompressedMapStatus to reduce driver memory usage
val compressedMapStatus = status.asInstanceOf[CompressedMapStatus]
val recordsArray = Option(compressedMapStatus.getRecordsArray())
val highlyStatus = HighlyCompressedMapStatus(status.location, uncompressedSizes,
status.mapId, recordsArray)
shuffleStatuses(shuffleId).addMapOutput(mapIndex, highlyStatus)
} else {
shuffleStatuses(shuffleId).addMapOutput(mapIndex, status)
}
} else {
shuffleStatuses(shuffleId).addMapOutput(mapIndex, status)
}
}

def getCacheSizeArray(dep: ShuffleDependency[_, _, _]): Array[Long] = {
val numReducer = dep.partitioner.numPartitions
val statistics = mapOutputStatisticsCache.get(dep.shuffleId)
val newArray = new Array[Long](numReducer)
for (i <- 0 until numReducer) {
newArray(i) = statistics.get(i)
}
newArray
}

def getCacheRecordArray(dep: ShuffleDependency[_, _, _]): Array[Long] = {
val allStatusesHaveRowCount = shuffleStatuses(dep.shuffleId).mapStatuses
.forall(status => status != null && status.hasRowCount())
logInfo(s"Shuffle ${dep.shuffleId}: all map statuses have row counts: " +
s"$allStatusesHaveRowCount")
if (allStatusesHaveRowCount) {
val atomicArray = mapStatusRowCount.get(dep.shuffleId)
val totalRecords = new Array[Long](atomicArray.length)
for (i <- 0 until atomicArray.length) {
totalRecords(i) = atomicArray.get(i)
}
logInfo(s"Shuffle ${dep.shuffleId}: collected per-partition row counts from MapStatuses")
totalRecords
} else {
Array.empty[Long]
}
}

/** Unregister map output information of the given shuffle, mapper and block manager */
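How the two new accessors are consumed is not shown in this diff; presumably the scheduler combines them into a `MapOutputStatistics` once a shuffle map stage completes. A hypothetical sketch of that glue code — the method name and call site are assumptions, not taken from the patch:

```scala
// Combine cached byte sizes and (optional) row counts into MapOutputStatistics.
def buildStatistics(
    tracker: MapOutputTrackerMaster,
    dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
  val bytes = tracker.getCacheSizeArray(dep)
  val records = tracker.getCacheRecordArray(dep)
  // getCacheRecordArray returns an empty array when any map output lacked row counts
  val recordsOpt = if (records.isEmpty) None else Some(records)
  new MapOutputStatistics(dep.shuffleId, bytes, recordsOpt)
}
```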
core/src/main/scala/org/apache/spark/internal/config/package.scala (16 additions)
@@ -1594,6 +1594,22 @@
.checkValue(v => v > 0, "The threshold should be positive.")
.createWithDefault(10000000)

private[spark] val SHUFFLE_MAP_STATUS_ROW_COUNT_OPTIMIZE_SKEWED_JOB =
ConfigBuilder("spark.shuffle.mapStatus.rowCount.optimize.skewed.job")
.internal()
.doc("When true, MapStatus also tracks the number of output rows per partition, " +
"which can be used to detect and optimize skewed jobs.")
.version("2.3.0")
.booleanConf
.createWithDefault(false)

private[spark] val SHUFFLE_MAP_STATUS_ROW_COUNT_METRICS_CHECK =
ConfigBuilder("spark.shuffle.mapStatus.rowCount.metrics.check")
.internal()
.doc("When true, cross-check the per-partition row counts recorded in MapStatus " +
"against the shuffle write metrics and drop the counts if the sums disagree.")
.version("2.3.0")
.booleanConf
.createWithDefault(true)

private[spark] val MAX_RESULT_SIZE = ConfigBuilder("spark.driver.maxResultSize")
.doc("Size limit for results.")
.version("1.2.0")
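Both flags are internal, but for completeness, here is how they would be set on a session. The config keys are taken from the definitions above; the application code itself is illustrative:

```scala
import org.apache.spark.sql.SparkSession

// Enable per-partition row counting in MapStatus, keeping the metrics cross-check
// (which already defaults to true) explicit for clarity.
val spark = SparkSession.builder()
  .appName("row-count-skew-demo")
  .config("spark.shuffle.mapStatus.rowCount.optimize.skewed.job", "true")
  .config("spark.shuffle.mapStatus.rowCount.metrics.check", "true")
  .getOrCreate()
```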