Spark shuffle 原理
关于 shuffle 是什么,本文就不再介绍了,想了解的自己百度。
简而言之:在分布式环境下,shuffle 是造成性能差的主要原因。所以,了解 shuffle 的原理,以及了解针对 shuffle 的优化,是学习分布式计算迈不过去的槛儿。
常见的分布式计算框架,主要是 hadoop 的 Map/Reduce 和 spark。在 spark 刚推出的时期,spark 运算速度号称最多能达到 M/R 的 100 倍,即使是现在,M/R 也主要 用在日志文件处理等时效性不强的应用场景中,而 spark 却可以在时效性很强的场景中甚至流式计算中表现出很好的适应性。造成这一差异的关键原因,网上有很多介绍, 但个人对 M/R 的原理并没有深入了解过,这里无法说明,但主要的原因,应该是 spark 的 DAG 计算模式可以减少 shuffle 次数。但必须知道, 无论是 M/R 还是 spark 都一直在持续对 shuffle 这一性能杀手进行了优化。本文主要是分析 spark 的 shuffle 原理,后续再分析 M/R 的 shuffle 原理和优化,并 比较 spark 与 M/R 之间的性能差异。
spark shuffle 的原理与优化其实是一个很复杂的逻辑,涉及了许多类以及相互转换,在阅读源码的过程中,个人并没有对这一过程完全理解,所以发现文中的错误,麻 烦发邮件到 mzl9039@163.com, 谢谢指正。
Spark shuffle 的历史
从 spark 最早推出开始,就带有了 shuffle 机制,但在 spark 的 shuffle 机制经过了多次优化,目前spark 2.1.X 默认采用 Sort-Based shuffle, 而非早期的 HashShuffle.
从配置项 spark.shuffle.manager 来看,spark 2.1.3-SNAPSHOT 版本中,支持 sort shuffle 和 tungsten-sort shuffle,但创建 ShuffleManager 时,两种方式 创建的 ShuffleManager 都是 SortShuffleManager,所以这两种配置并没有差别,即 spark 只支持一种 shuffle 方法:SortShuffleManager.
/** Let the user specify short names for shuffle managers */
val shortShuffleMgrNames = Map(
"sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName,
"tungsten-sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName)
val shuffleMgrName = conf.get("spark.shuffle.manager", "sort")
val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)
但是,这只代表 spark shuffle 全部是基于 sort 的,并不代表其它 spark shuffle 只有单一的这一种,例如:关于 spark shuffle writer 有 UnsafeShuffleWriter、 BypassMergeSortShuffleWriter、SortShuffleWriter 三种,针对不同的情况有不同的优化实现。
Spark shuffle 涉及的类
Spark shuffle 逻辑比较复杂,涉及的类也比较多,这里做简单的介绍:
- ShuffleManager:shuffle system 的插件式接口。shuffleManager 在 SparkEnv 中被创建,在 driver 端和 executor 端均存在,由配置项 spark.shuffle.manager 设置; driver 端用它来注册 shuffle, 而 executor 端通过它拉取或写入 shuffle 数据。
- SortShuffleManager: 一个 sort-based shuffle,所有的记录都根据它们的目标 partition 的 id 排过序,然后写入到一个单独的 map output file. Reducers 通过 拉取这个文件的连续的 region 来读取自己部分的 map output. 为防止 map output data 数据超过内存限制,output 的排序过的一部分会写入磁盘,而这些磁盘上的文件 会 merge 到一个最终的 output file 中去。
- ShuffleHandle: 一个不透明的 handle,shuffleManager 用它来传递 shuffle 信息给 task。
- BaseShuffleHandle: 一个基本的 ShuffleHandle 的实现,仅用来获取 registerShuffle 的信息。
- SerializedShuffleHandle: BaseShuffleHandle 的子类,用于唯一标识我们已经选择使用的 serialized shuffle.
- BypassMergeSortShuffleHandle: BaseShuffleHandle 的子类,用于唯一标识我们已经选择使用的 bypass merge sort shuffle path.
- ShuffleWriter: 在 map task 内获取,用于将 records 写入 shuffle system.
- SortShuffleWriter: 对应 ShuffleHandle 为 BaseShuffleHandle 的 writer.
- UnsafeShuffleWriter: 对应 ShuffleHandle 为 SerializedShuffleHandle 的 writer.
- BypassMergeSortShuffleWriter: 对应 ShuffleHandle 为 BypassMergeSortShuffleHandle 的 writer. 这个类是 spark 早期的 HashShuffle 衍生而来,而早期的 HashShuffleWriter 很像。
- ShuffleReader: 在 reduce task 内获取,用于读取 mapper 合并过的 records。
- BlockStoreShuffleReader: 通过向其它节点的 block store 发送请求,从 shuffle 中按 range 范围 [startPartition, endPartition) 拉取和读取 partitions.
- ShuffleBlockResolver: 这个接口的实现知道如何获取某个逻辑上的 shuffle block 的 block data. 实现可能需要使用 files 或 file segment 来封装 shuffle data. 这个主要在 BlockStore 中使用,当获取 shuffle data 时来抽象不同 shuffle 的实现。
- IndexShuffleBlockResolver: 用于创建和维护 shuffle block 的 logic block 和 physical file location 之间的映射关系。同一个 map task 的 shuffle block 数据 被存储到一个单独的 consolidated data file(合并的数据文件) 中。而这个 data block 在这个 data file 中的偏移量则被存储在一个另外的 index file 中。使用 shuffle data 的 shuffleBlockId 和 reduce ID 作用名字,”.data” 是文件后缀,”.index” 是索引文件后缀.
- ExternalSorter: 排序以及潜在地合并一些 (K,V) 类型的 key-value pair 来生成 (K,C) 类型的 key-combiner 类型的 pair. 使用 Partitioner, 首先将 key 分组 放到 partitions 中,其次可能利用常见的 Comparator 将每个 partition 中的 keys 进行排序。能输出一个单独的 partitioned file, 这个文件记录了每个 partition 的 不同 byte range, 适用于后面 fetch shuffle.
- PartitionedAppendOnlyMap: 接口 WritablePartitionedPairCollection 的一个实现,同时是类 AppendOnlyMap 的子类。它的 key 是 (partitionId, key) 组成的元组, 它的 value 项是 combiner.
- AppendOnlyMap: 一种开放的 hash table 的优化,仅针对只有追加的使用场景,即 key 不会被删除,但每个 key 的 value 可能会被更新。它不是用链表实现的,而是 用数组 Array 实现的,数据结构为:key1,value1,key2,value2…这样,详见引用3.
- DiskBlockObjectWriter: 将 JVM 中的对象直接写入文件的类。这个类允许把数据追加到已经存在的 block。为了效率,它持有底层的文件通道。这个通道会保持 open 状态, 直到 DiskBlockObjectWriter 的 close() 方法被调用。为了防止出现错误(如正在写的过程中出错了), 调用者需要调用方法 revertPartialWritesAndClose 而不是 close 方法, 来自动 revert 那些未提交的 write 操作。
- ShuffleBlockFetcherIterator: 一个获取多个 blocks 的迭代器。对于本地的 blocks, 它会从本地的 blockManager 拉取;对于远程的 blocks, 它使用 BlockTransferService 来拉取。它会创建一个 (BlockId, InputStream) 这样的 tuple 的迭代器,以保证调用者要接收到数据时像流水线一样。另外它能限制从远程拉取的速度,来保证拉取时不会 超过 maxBytesInFlight, 从而不会使用太多内存.
- ShuffleClient: 读取 shuffle file 的接口,即可以是 Executor 端也可以是外部 service 端
- NettyBlockTransferService: 使用 netty 来一次拉取多个 blocks 的一个 BlockTransferService, 是 ShuffleClient 的子类。
Spark shuffle 的流程
Spark shuffle 是在任务运行中发生的。从博客-Spark 任务分发与执行流程 中,我们知道,当 task 在 Executor 上执行时,任务 task 是通过逆序列化得到的,同样的还包括 taskFiles/taskJars/taskProps,然后调用得到的 task 的 run 方法执行方法, 而这个 run 方法则会在调用 runTask 方法来执行。
我们知道,Task 有两类:ShuffleMapTask 和 ResultTask. 它们有一个区别:ShuffleMapTask 的 runTask 方法返回值为 MapStatus,而 ResultTask 的 runTask 返回值为 当前 Task 执行结果的类型。我们关注 shuffle,因此我们关注 ShuffleMapTask 的 runTask 方法,因为这里是 shuffle 写入开始的地方。
/** ShuffleMapTask 的 runTask 方法 */
override def runTask(context: TaskContext): MapStatus = {
/** Deserialize the RDD using the broadcast variable. */
val threadMXBean = ManagementFactory.getThreadMXBean
val deserializeStartTime = System.currentTimeMillis()
val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
val ser = SparkEnv.get.closureSerializer.newInstance()
val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
_executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
_executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
} else 0L
var writer: ShuffleWriter[Any, Any] = null
try {
/** 获取到 ShuffleManager, 当前版本中只有 SortShuffleManager */
val manager = SparkEnv.get.shuffleManager
/** 根据 dep.shuffleHandle 获取 ShuffleWriter, 注意 shuffleHandle 是 dep 的属性 */
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
/** 这里开始写 shuffle. shuffle 写从这里开始,但需要首先分析 rdd.iterator */
writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
writer.stop(success = true).get
} catch {
case e: Exception =>
try {
if (writer != null) {
writer.stop(success = false)
}
} catch {
case e: Exception =>
log.debug("Could not stop writer", e)
}
throw e
}
}
Rdd 的 iterator 方法
由于不同的 RDD 类型不同,调用顺序不同等原因,对分析有很大影响,所以这里我们举个官方的例子:LogQuery.scala. 在这个类中,rdd 分别是:
- ShuffledRDD –> reduceByKey
- MapPartitionsRDD –> map
- ParallelCollectionRDD –> parallelize
对于方法调用 rdd.iterator(partition, context)
, iterator 方法如下:
/** Internal method to this RDD; will read from cache if applicable, or otherwise compute it. */
/** This should ''not'' be called by users directly, but is available for implementors of custom */
/** subclasses of RDD. */
/** 由于这几个 RDD 的 storageLevel 为 NONE, 所以这里调用的是 computeOrReadCheckpoint 方法 */
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
if (storageLevel != StorageLevel.NONE) {
getOrCompute(split, context)
} else {
computeOrReadCheckpoint(split, context)
}
}
/** Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing. */
/** 这几个 rdd 都没有做过 checkpoint,所以 isCheckpointedAndMaterialized 为 false, 调用方法 compute. */
private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
{
if (isCheckpointedAndMaterialized) {
firstParent[T].iterator(split, context)
} else {
compute(split, context)
}
}
从 compute 的调用顺序来看,LogQuery 的调用,是先调用 MapPartitionsRDD 的 compute,然后是 ParallelCollectionRDD 的 compute, 最后是 ShuffledRDD 的 compute, 因为 MapPartitionsRDD 和 ParallelCollectionRDD 在同一个 stage(ShuffleMapStage), 而 ShuffledRDD 则在第二个 stage(ResultStage),先提交的是第一个 stage.
在 MapPartitionsRDD 的 compute 中,调用了 firstParent[T].iterator,
由于 firstParent 即为 ParallelCollectionRDD, 即在这里调用了 ParallelCollectionRDD 的 iterator 方法,并在这个方法里调用了其 compute 方法。
而对于 ParallelCollectionRDD 的 iterator 方法,时面返回了一个 InterruptibleIterator,这个类其实只是一个委托,正在起作用的还是参数
s.asInstanceOf[ParallelCollectionPartition[T]].iterator
,这里的 s 是 Partition.
/** compute 方法在 RDD.scala 中是一个钩子函数(类似于抽象类中的抽象函数),是由子类实现的函数, 因此不同的 */
/** RDD 类型实现不同, 例如: */
/** ShuffledRDD.compute */
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
/** 注意这里的 compute 已经是 reducer task 阶段,去读取 shuffle 信息 */
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
.read()
.asInstanceOf[Iterator[(K, C)]]
}
/** MapPartitionsRDD.compute */
override def compute(split: Partition, context: TaskContext): Iterator[U] =
f(context, split.index, firstParent[T].iterator(split, context))
/** ParallelCollectionRDD.compute */
override def compute(s: Partition, context: TaskContext): Iterator[T] = {
/** s.asInstanceOf[ParallelCollectionPartition[T]].iterator 在这里,最终转为了初始数据的迭代器 values.iterator */
/** 对 LogQuery.scala 而言,这里的 iterator 就是最初的 exampleApacheLogs 经过 RDD 切分之后某个 partition 的 iterator */
new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
}
Spark shuffle 的写入
上一节我们以 LogQuery 为例,分析了其在 ShuffleMapTask 阶段,对 rdd 的 iterator 的调用顺序。这一节我们尝试分析 ShuffleWriter 在 write 方法中 对 shuffle 进行写入的逻辑。
在上面分析 spark shuffle 流程时,我们知道 spark 是调用 writer 的 write 方法对 shuffle 进行写入的。这里我们以 SortShuffleWriter 为例进行分析。
SortShuffleWriter 的 write 方法
SortShuffleWriter 的 write 方法,主要是初始化了一个 ExternalSorter, 并将 rdd 的 partition 信息 insert sorter 内部的 PartitionedAppendOnlyMap 里。 如果这个 Map/Buffer 太大,则会将部分信息写入磁盘,即 spill 操作。
注意:有些类型的 RDD 在 compute 方法中,会直接返回一个 ExternalAppendOnlyMap, 这个类型和 ExternalSorter 很像,用在 Aggregator 和 CoGroupedRDD 中
/** Write a bunch of records to this task's output */
override def write(records: Iterator[Product2[K, V]]): Unit = {
sorter = if (dep.mapSideCombine) {
require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
new ExternalSorter[K, V, C](
context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
} else {
/** In this case we pass neither an aggregator nor an ordering to the sorter, because we don't */
/** care whether the keys get sorted in each partition; that will be done on the reduce side */
/** if the operation being run is sortByKey. */
new ExternalSorter[K, V, V](
context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
}
/** 将 records 代表的 rdd 的 partition 信息保存到 sorter 内部的 map/buffer */
/** 若 sorter 内保存的 partition 信息过大,则会 spill 到磁盘 */
/** 这里如果写磁盘了,文件名则是 temp_shuffle_uuid */
sorter.insertAll(records)
/** Don't bother including the time to open the merged output file in the shuffle write time, */
/** because it just opens a single file, so is typically too fast to measure accurately */
/** (see SPARK-3570). */
/** 这里会根据 shuffleId 和 mapId 获取原来的数据文件或新创建数据文件, 从前面 insertAll 的逻辑来看,第一次运行到这里时,是新建数据文件的 */
/** 这里的文件名是 shuffleId_mapId_0_reduceId.data */
/** IndexShuffleBlockResolver 中写明了:disk store 计划存储与 (map, reduce) 相关的 pair, 但在基于排序的、用于多个 reduce 的 shuffle output 会被拼凑成单个文件? */
/** TODO:上面是翻译错了?这么做的原因是什么呢? */
val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
/** TODO: 这个 tempFileWith 方法要求的参数是一个 path,那 output 到底是一个路径还是一个文件呢? */
val tmp = Utils.tempFileWith(output)
try {
/** 根据 shuffleId mapId 拿到 blockId, 这一步是关键,前面获取到 output 内部也是一样的方式 */
val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
/** 这里是合并写入的地方, 详细内容见后面的分析 */
val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
} finally {
if (tmp.exists() && !tmp.delete()) {
logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
}
}
}
ExternalSorter 的 insertAll 方法
insertAll 方法会根据是否需要 combine, 把 records 加入 map/buffer 中。这里我们以 map 为例说明如何插入到 map 以及如何 spill 到磁盘。
这里的 changeValue 会根据情况判断:如果 key 第一次出现,则插入值;否则更新值。而且会自动扩展容量,每次扩展 1倍。
def insertAll(records: Iterator[Product2[K, V]]): Unit = {
/** TODO: stop combining if we find that the reduction factor isn't high */
val shouldCombine = aggregator.isDefined
if (shouldCombine) {
/** Combine values in-memory first using our AppendOnlyMap */
val mergeValue = aggregator.get.mergeValue
val createCombiner = aggregator.get.createCombiner
var kv: Product2[K, V] = null
val update = (hadValue: Boolean, oldValue: C) => {
if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
}
while (records.hasNext) {
addElementsRead()
kv = records.next()
map.changeValue((getPartition(kv._1), kv._1), update)
maybeSpillCollection(usingMap = true)
}
} else {
/** Stick values into our buffer */
while (records.hasNext) {
addElementsRead()
val kv = records.next()
buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
maybeSpillCollection(usingMap = false)
}
}
}
Spillable 类的 maySpill 方法
在 insertAll 方法中,每次添加元素都会调用 maybeSpillCollection 方法,并根据方法的参数 usingMap 决定当前 map/buffer 的大小, ,并根据当前 map/buffer 的大小,调用 maybeSpill 方法决定是否需要 spill 到磁盘
/** Spill the current in-memory collection to disk if needed. */
/** @param usingMap whether we're using a map or buffer as our current in-memory collection */
private def maybeSpillCollection(usingMap: Boolean): Unit = {
var estimatedSize = 0L
/** 参数 usingMap 决定使用 map 还是使用 buffer */
if (usingMap) {
estimatedSize = map.estimateSize()
/** 根据 estimatedSize 决定是否需要 spill, 使用 buffer 时也一样 */
/** 若需要 spill,则在 spill 之后新创建一个 map/buffer */
if (maybeSpill(map, estimatedSize)) {
map = new PartitionedAppendOnlyMap[K, C]
}
} else {
estimatedSize = buffer.estimateSize()
if (maybeSpill(buffer, estimatedSize)) {
buffer = new PartitionedPairBuffer[K, C]
}
}
if (estimatedSize > _peakMemoryUsedBytes) {
_peakMemoryUsedBytes = estimatedSize
}
}
在上面的方法中,我们看到每次都会调用 maybeSpill 方法,在这个方法中决定是否需要 spill, 但注意 Spillable 的 spill 方法是个 抽象函数,其具体实现在 ExternalSorter 中:
/** Spills the current in-memory collection to disk if needed. Attempts to acquire more */
/** memory before spilling. */
/** @param collection collection to spill to disk */
/** @param currentMemory estimated size of the collection in bytes */
/** @return true if `collection` was spilled to disk; false otherwise */
protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
var shouldSpill = false
/** 每次函数进来都要检查是否需要 spill,条件是 collection 中元素个数是 32 的倍数,且当前内存大于内存阈值 */
if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
/** Claim up to double our current memory from the shuffle memory pool */
/** 计算需要申请多少内存,申请完内存后应该是现在内存的 2 倍,所以要申请的内存是当前内存的 2 倍减去内存阈值 */
val amountToRequest = 2 * currentMemory - myMemoryThreshold
/** 计算能授权申请到的内存值, TODO:关于 Memory 相关的分析后续再分析 */
val granted = acquireMemory(amountToRequest)
/** 更新当前的内存阈值,若申请的内存都能拿到,则更新后内存阈值为当前内存的 2 倍; */
/** 若申请的内存不能完全拿到, 则内存阈值已经是最大可用内存值,但这个值有可能还小于当前内存 */
myMemoryThreshold += granted
/** If we were granted too little memory to grow further (either tryToAcquire returned 0, */
/** or we already had more memory than myMemoryThreshold), spill the current collection */
/** 若增加内存阈值后,内存阈值仍小于当前内存(注意不是当前内存的 2 倍) ,则需要 spill */
shouldSpill = currentMemory >= myMemoryThreshold
}
/** shouldSpill 为 true 的条件是:shouldSpill 为 true 或 当前 collection 中元素数量大于强制进行 spill 的集合元素数量阈值 */
/** 这里其实是从内存和数量两个条件来决定是否需要 spill. */
shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
/** Actually spill */
if (shouldSpill) {
_spillCount += 1
logSpillage(currentMemory)
/** 这里进行 spill 操作, 即将内存中的数据落盘, 可知这里是落盘的关键 */
spill(collection)
/** 落盘后集合中元素数量为0,更新已经落盘的数据的大小,并释放内存 */
_elementsRead = 0
_memoryBytesSpilled += currentMemory
releaseMemory()
}
shouldSpill
}
ExternalSorter 的 spill 方法
注意:在 ExternalAppendOnlyMap 类中也有一样的 spill 方法,因为那个类和 ExternalSorter 很像。
spill 方法如下
/** Spill our in-memory collection to a sorted file that we can merge later. */
/** We add this file into `spilledFiles` to find it later. */
/** @param collection whichever collection we're using (map or buffer) */
override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
/** 这里获取比较器 comparator, 并返回排序的可写 Partition 的迭代器 */
val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)
/** 根据迭代器,将内存中的数据写到磁盘 */
val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
spills += spillFile
}
下面先分析 spill 方法中排序的过程
ExternalSorter 的 comparator 方法
在 spill 方法的第一行, destructiveSortedWritablePartitionedIterator 方法的参数是比较器 comparator, 这是一个方法,它默认采用 ExternalSorter 的参数 ordering, 但若 ordering 为 None,则比较类型 K 的哈希值
/** A comparator for keys K that orders them within a partition to allow aggregation or sorting. */
/** Can be a partial ordering by hash code if a total ordering is not provided through by the */
/** user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some */
/** non-equal keys also have this, so we need to do a later pass to find truly equal keys). */
/** Note that we ignore this if no aggregator and no ordering are given. */
private val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] {
override def compare(a: K, b: K): Int = {
val h1 = if (a == null) 0 else a.hashCode()
val h2 = if (b == null) 0 else b.hashCode()
if (h1 < h2) -1 else if (h1 == h2) 0 else 1
}
})
private def comparator: Option[Comparator[K]] = {
if (ordering.isDefined || aggregator.isDefined) {
Some(keyComparator)
} else {
None
}
}
WritablePartitionedPairCollection 的 destructiveSortedWritablePartitionedIterator 方法
这个类是 map/buffer 的类型的父类,其部分抽象方法在子类中实现,这里以 map 为例说明.
/** Iterate through the data in order of partition ID and then the given comparator. This may */
/** destroy the underlying collection. */
def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
: Iterator[((Int, K), V)]
/** Iterate through the data and write out the elements instead of returning them. Records are */
/** returned in order of their partition ID and then the given comparator. */
/** This may destroy the underlying collection. */
def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
: WritablePartitionedIterator = {
/** partitionedDestructiveSortedIterator 是个抽象方法,见上。我们以 PartitionedAppendOnlyMap 的实现来说明 */
val it = partitionedDestructiveSortedIterator(keyComparator)
/** 我们看到,获取迭代器后,后面就是创建一个新的迭代器 WritablePartitionedIterator, 并实现了 writeNext 方法 */
new WritablePartitionedIterator {
/** 这里需要注意,cur 的类型是 Tuple2: ((Int, K), V) */
private[this] var cur = if (it.hasNext) it.next() else null
def writeNext(writer: DiskBlockObjectWriter): Unit = {
writer.write(cur._1._2, cur._2)
cur = if (it.hasNext) it.next() else null
}
def hasNext(): Boolean = cur != null
def nextPartition(): Int = cur._1._1
}
}
下面是 PartitionedAppendOnlyMap 的方法 partitionedDestructiveSortedIterator, 这个方法也只是调用了父类的方法 destructiveSortedIterator.
def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
: Iterator[((Int, K), V)] = {
/** 第一步也还是获取比较器 */
val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator)
/** 调用父类的方法 destructiveSortedIterator */
destructiveSortedIterator(comparator)
}
AppendOnlyMap 的 destructiveSortedIterator 方法
这个方法的关键是两个步骤:
- 整理数组中的数据,前数据整理到数组前面,使数据之间不存在 null, 即数据前面没有影响排序的索引
- 通过 Sorter 实现排序, Sorter 内部通过 TimSort 对象完成对数据中数据的排序
/** Return an iterator of the map in sorted order. This provides a way to sort the map without */
/** using additional memory, at the expense of destroying the validity of the map. */
def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] = {
destroyed = true
/** Pack KV pairs into the front of the underlying array */
/** 原来 map 中的数据,是通过 hash 得到的,可能不均匀地分散在 array 中(因为这个 map 的底层数据结构是 array) */
/** 下面的调整,是把 array 中所有数据整理到 array 的前面,即数据中间不再有 null */
var keyIndex, newIndex = 0
while (keyIndex < capacity) {
/** 若当前 keyIndx 对应的数据不为 null, 则把 keyIndx 放到 newIndex 的位置; */
/** 若为空则跳过,这样保证 keyIndex 一定大于 newIndex, 不会存在前面的数据被覆盖的情况 */
/** 由于用 array 实现的 map,所以数据结构里,偶数位存 key, 奇数位存 value */
if (data(2 * keyIndex) != null) {
data(2 * newIndex) = data(2 * keyIndex)
data(2 * newIndex + 1) = data(2 * keyIndex + 1)
newIndex += 1
}
keyIndex += 1
}
assert(curSize == newIndex + (if (haveNullValue) 1 else 0))
/** KVArraySortDataFormat 确定了需要排序的数据的数据格式,这是针对这里的 AppendOnlyMap 这种用 array 实现 map 的方式专门写的 */
/** Sorter 内部实例化了 TimSort 排序算法的实例,TODO:TimSort 算法后续再分析, 我们知道这里根据 key 对数据进行了排序 */
/** 这里执行完成后,即完成了 data 数据按 keyComparator 的排序结果 */
new Sorter(new KVArraySortDataFormat[K, AnyRef]).sort(data, 0, newIndex, keyComparator)
/** 返回 map 排序结果的迭代器, 至此,排序的分析完成了, 后面是写磁盘的分析 */
new Iterator[(K, V)] {
var i = 0
var nullValueReady = haveNullValue
def hasNext: Boolean = (i < newIndex || nullValueReady)
def next(): (K, V) = {
if (nullValueReady) {
nullValueReady = false
(null.asInstanceOf[K], nullValue)
} else {
val item = (data(2 * i).asInstanceOf[K], data(2 * i + 1).asInstanceOf[V])
i += 1
item
}
}
}
}
下面是对 spill 方法中写磁盘的分析
ExternalSorter 的 spillMemoryIteratorToDisk 方法
这里将内存中的 map/buffer 的信息写到磁盘,是 shuffle 中真正的磁盘 IO 写操作,每次 flush 产生一个 FileSegment, 而返回值 SpilledFile 则记录了 blockId 和物理机上的文件的路径等信息,把逻辑信息和物理信息之间的映射
/** Spill contents of in-memory iterator to a temporary file on disk. */
private[this] def spillMemoryIteratorToDisk(inMemoryIterator: WritablePartitionedIterator)
: SpilledFile = {
/** Because these files may be read during shuffle, their compression must be controlled by */
/** spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use */
/** createTempShuffleBlock here; see SPARK-3426 for more context. */
/** 由于在 shuffle 过程中,这些文件可能正在被读取,所以他们的压缩格式必须由 spark.shuffle.compress 控制, */
/** 所以这些文件必须通过方法 createTempShuffleBlock 创建, SPARK-3426 */
val (blockId, file) = diskBlockManager.createTempShuffleBlock()
/** These variables are reset after each flush */
/** 真正需要 reset 的只有 objectsWritten,是 this, 而不是 these */
var objectsWritten: Long = 0
val spillMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics
/** 获取到 writer 对象,由于初始化时已经指定了 file 等信息,所以直接像流一样写,然后循环 flush 即可 */
/** 注意这里的 blockId 和 file 就指定了逻辑位置与物理位置的关系 */
val writer: DiskBlockObjectWriter =
blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)
/** List of batch sizes (bytes) in the order they are written to disk */
val batchSizes = new ArrayBuffer[Long]
/** How many elements we have in each partition */
val elementsPerPartition = new Array[Long](numPartitions)
/** Flush the disk writer's contents to disk, and update relevant variables. */
/** The writer is committed at the end of this process. */
def flush(): Unit = {
/** 每次 commitAndGet 都会返回一个 FileSegment, 后面会分析到 */
val segment = writer.commitAndGet()
batchSizes += segment.length
_diskBytesSpilled += segment.length
objectsWritten = 0
}
var success = false
try {
while (inMemoryIterator.hasNext) {
val partitionId = inMemoryIterator.nextPartition()
require(partitionId >= 0 && partitionId < numPartitions,
s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})")
/** 这里把 writer 传进去,由于 inMemoryIterator 是排序后的 map/buffer 的迭代器,每个 Key-Value 对都对应一个 partition */
/** 所以这里可以理解为每次写一个 partition */
inMemoryIterator.writeNext(writer)
/** 记录每个 partition 写了几次, 其实是当前 partition 被分成了多少份 */
elementsPerPartition(partitionId) += 1
objectsWritten += 1
/** serializerBatchSize 为固定值 10000, 说明每写 10000 个 partition 的信息,flush 一交,即生成一个 FileSegment */
if (objectsWritten == serializerBatchSize) {
flush()
}
}
/** whilte 循环后,可能有部分已经写了,但没达到 serializerBatchSize, 这部分也需要 flush */
if (objectsWritten > 0) {
flush()
} else {
/** 但如果 objectsWritten 为 0 的话,需要 revertPartialWritesAndClose, 下面分析这个方法和 writer 的 write 方法 */
writer.revertPartialWritesAndClose()
}
success = true
} finally {
if (success) {
writer.close()
} else {
/** This code path only happens if an exception was thrown above before we set success; */
/** close our stuff and let the exception be thrown further */
/** 如果在 success 置为 true 之前抛出了异常,则需要关闭连接并将文件删除 */
writer.revertPartialWritesAndClose()
if (file.exists()) {
if (!file.delete()) {
logWarning(s"Error deleting ${file}")
}
}
}
}
/** 这个 SpilledFile 记录了逻辑上的 blockId 和物理上的 file 之间的对应关系,还有其它信息,如每个 partition 被分为了几部分, 每个 FileSegment 的大小等 */
SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)
}
DiskBlockObjectWriter 的相关方法
这个类的成员属性使用了缩写,为了知道这些缩写的意思,这里把成员属性也放到代码里, 这里几乎把整个类的代码都放过来了
/** The file channel, used for repositioning / truncating the file. */
/** 一些重要的成员属性 */
private var channel: FileChannel = null
private var mcs: ManualCloseOutputStream = null
private var bs: OutputStream = null
private var fos: FileOutputStream = null
private var ts: TimeTrackingOutputStream = null
private var objOut: SerializationStream = null
private var initialized = false
private var streamOpen = false
private var hasBeenClosed = false
/** Cursors used to represent positions in the file. */
/** 描述文件中当前位置的 cursor */
/** xxxxxxxxxx|----------|-----| */
/** ^ ^ ^ */
/** | | channel.position() */
/** | reportedPosition */
/** committedPosition */
/** */
/** reportedPosition: Position at the time of the last update to the write metrics. */
/** 上次更新后更新到 write metrics 的位置 */
/** committedPosition: Offset after last committed write. */
/** 已经提交 write 的 offset */
/** -----: Current writes to the underlying file. */
/** xxxxx: Committed contents of the file. */
private var committedPosition = file.length()
private var reportedPosition = committedPosition
/** Keep track of number of records written and also use this to periodically */
/** output bytes written since the latter is expensive to do for each record. */
private var numRecordsWritten = 0
private def initialize(): Unit = {
fos = new FileOutputStream(file, true)
channel = fos.getChannel()
ts = new TimeTrackingOutputStream(writeMetrics, fos)
class ManualCloseBufferedOutputStream
extends BufferedOutputStream(ts, bufferSize) with ManualCloseOutputStream
mcs = new ManualCloseBufferedOutputStream
}
/** 只能 open 一次,且不能 reopen */
def open(): DiskBlockObjectWriter = {
if (hasBeenClosed) {
throw new IllegalStateException("Writer already closed. Cannot be reopened.")
}
if (!initialized) {
initialize()
initialized = true
}
bs = serializerManager.wrapStream(blockId, mcs)
objOut = serializerInstance.serializeStream(bs)
streamOpen = true
this
}
/** Flush the partial writes and commit them as a single atomic block. */
/** A commit may write additional bytes to frame the atomic block. */
/** 指部分写入的内容提交到一个单独的原子 block */
/** @return file segment with previous offset and length committed on this call. */
/** 每次调用都会返回一个 FileSegment */
def commitAndGet(): FileSegment = {
if (streamOpen) {
/** NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the */
/** serializer stream and the lower level stream. */
objOut.flush()
bs.flush()
objOut.close()
streamOpen = false
/** syncWrites 黑夜为 false */
if (syncWrites) {
/** Force outstanding writes to disk and track how long it takes */
/** 强制将写入的内容落到磁盘,并记录落盘的时间长短 */
val start = System.nanoTime()
fos.getFD.sync()
writeMetrics.incWriteTime(System.nanoTime() - start)
}
val pos = channel.position()
val fileSegment = new FileSegment(file, committedPosition, pos - committedPosition)
committedPosition = pos
/** In certain compression codecs, more bytes are written after streams are closed */
writeMetrics.incBytesWritten(committedPosition - reportedPosition)
reportedPosition = committedPosition
fileSegment
} else {
new FileSegment(file, committedPosition, 0)
}
}
/** Reverts writes that haven't been committed yet. Callers should invoke this function */
/** when there are runtime exceptions. This method will not throw, though it may be */
/** unsuccessful in truncating written data. */
/** revert 那些还没提交的写入内容。 */
/** @return the file that this DiskBlockObjectWriter wrote to. */
def revertPartialWritesAndClose(): File = {
/** Discard current writes. We do this by flushing the outstanding writes and then */
/** truncating the file to its initial position. */
Utils.tryWithSafeFinally {
if (initialized) {
writeMetrics.decBytesWritten(reportedPosition - committedPosition)
writeMetrics.decRecordsWritten(numRecordsWritten)
streamOpen = false
closeResources()
}
} {
var truncateStream: FileOutputStream = null
try {
truncateStream = new FileOutputStream(file, true)
truncateStream.getChannel.truncate(committedPosition)
} catch {
case e: Exception =>
logError("Uncaught exception while reverting partial writes to file " + file, e)
} finally {
if (truncateStream != null) {
truncateStream.close()
truncateStream = null
}
}
}
file
}
/** Writes a key-value pair. */
def write(key: Any, value: Any) {
if (!streamOpen) {
open()
}
objOut.writeKey(key)
objOut.writeValue(value)
recordWritten()
}
override def write(b: Int): Unit = throw new UnsupportedOperationException()
override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = {
if (!streamOpen) {
open()
}
bs.write(kvBytes, offs, len)
}
/** Notify the writer that a record worth of bytes has been written with OutputStream#write. */
/** 通知 writer: 一条记录已经被写入 OutputStream */
def recordWritten(): Unit = {
/** 每次都会记录一次 numRecordsWritten */
numRecordsWritten += 1
writeMetrics.incRecordsWritten(1)
if (numRecordsWritten % 16384 == 0) {
updateBytesWritten()
}
}
/** Report the number of bytes written in this writer's shuffle write metrics. */
/** Note that this is only valid before the underlying streams are closed. */
private def updateBytesWritten() {
val pos = channel.position()
writeMetrics.incBytesWritten(pos - reportedPosition)
reportedPosition = pos
}
ExternalSorter 的 writePartitionedFile 方法
这里方法是把 ExternalSorter 里已经记录的所有数据写入到一个文件,注意的是,如果这些数据没有排序过,会先排序; 另外如果有部分数据因为内存不够而已经写到磁盘, 会把这些文件先合并
/** Write all the data added into this ExternalSorter into a file in the disk store. This is */
/** called by the SortShuffleWriter. */
/** 将所有加入 ExternalSorter 的数据写入一个文件。 */
/** @param blockId block ID to write to. The index file will be blockId.name + ".index". */
/** @return array of lengths, in bytes, of each partition of the file (used by map output tracker) */
/** 返回值是一个数组,数组的每一项都是这个方法中输出的文件中某个 partition 的长度 (in bytes), 后面被 map output tracker 使用 */
def writePartitionedFile(
blockId: BlockId,
outputFile: File): Array[Long] = {
/** Track location of each range in the output file */
/** 追踪文件中每个 partition 的 length */
val lengths = new Array[Long](numPartitions)
val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
context.taskMetrics().shuffleWriteMetrics)
/** 如果原来的 map/buffer 已经有一总分写磁盘了,则 spills 不为空,否则为空 */
if (spills.isEmpty) {
/** Case where we only have in-memory data */
/** 如果内存够大,所有数据都在内存中,没 spill 到磁盘过 */
val collection = if (aggregator.isDefined) map else buffer
/** 获取加入 ExternalSorter 的数据的集合的迭代器, 这个方法我们前面已经分析过 */
val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
while (it.hasNext) {
val partitionId = it.nextPartition()
/** while 里的判断,可能是为了避免多线程时,前后句之间有线程执行了 it.next */
while (it.hasNext && it.nextPartition() == partitionId) {
it.writeNext(writer)
}
/** 这里可以看出,每个 partition 提交一次, 注意这里没有 flush */
val segment = writer.commitAndGet()
lengths(partitionId) = segment.length
}
} else {
/** We must perform merge-sort; get an iterator by partition and write everything directly. */
/** 如果已经有部分数据写磁盘了 */
for ((id, elements) <- this.partitionedIterator) {
if (elements.hasNext) {
for (elem <- elements) {
writer.write(elem._1, elem._2)
}
val segment = writer.commitAndGet()
lengths(id) = segment.length
}
}
}
writer.close()
context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
lengths
}
ExternalSorter 的 partitionedIterator 方法
这个方法最终就是返回当前对象中的所有数据的迭代器,这个迭代器以 partition 为单位
/** Return an iterator over all the data written to this object, grouped by partition and */
/** aggregated by the requested aggregator. For each partition we then have an iterator over its */
/** contents, and these are expected to be accessed in order (you can't "skip ahead" to one */
/** partition without reading the previous one). Guaranteed to return a key-value pair for each */
/** partition, in order of partition ID. */
/** 返回所有写入这个对象的数据的迭代器,按 partition 分组,并根据请求的 aggregator 进行聚合。对每个 partition, 我们都有一个迭代器来遍历它, */
/** 这些 partition 将会被按顺序访问,不能跳跃。保证返回一个 key-value pair,安排好 partition ID 排序 */
/** For now, we just merge all the spilled files in once pass, but this can be modified to */
/** support hierarchical merging. */
/** Exposed for testing. */
def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
/** 这里只在决定当前对象使用的是 map 还是 buffer */
val usingMap = aggregator.isDefined
val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
/** 若 spills 为 empty, 说明之前没有数据被 spill 到磁盘; 否则有数据被 spill 到磁盘 */
if (spills.isEmpty) {
/** Special case: if we have only in-memory data, we don't need to merge streams, and perhaps */
/** we don't even need to sort by anything other than partition ID */
/** 特殊情况:如果我们只有内存中的数据,我们不需要把不同的 streams merge 到一起,甚至我们除了 partition ID 外,不需要按其它任何东西排序 */
/** 如果 ordering 定义过,除了需要按照 partition ID 排序外,还需要按照 key 排序;否则只需要按照 partition ID 排序 */
if (!ordering.isDefined) {
/** The user hasn't requested sorted keys, so only sort by partition ID, not key */
groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))
} else {
/** We do need to sort by both partition ID and key */
groupByPartition(destructiveIterator(
collection.partitionedDestructiveSortedIterator(Some(keyComparator))))
}
} else {
/** Merge spilled and in-memory data */
/** 当已经有数据 spill 到磁盘,则把 spill 到磁盘上的数据,和内存中的数据 merge 到一起 */
merge(spills, destructiveIterator(
collection.partitionedDestructiveSortedIterator(comparator)))
}
}
ExternalSorter 的 groupByPartition 方法
方法 groupByPartition 写的很明确,参数是一个迭代器,里面每个项都是 ((partition, key), combiner) 这样的类型,而且这些数据 已经按照 partition ID 完成排序. 这里要把这些数据里的每个 partition 组合为 (partition, (key, combiner)) 这样的类型
/** Given a stream of ((partition, key), combiner) pairs *assumed to be sorted by partition ID*, */
/** group together the pairs for each partition into a sub-iterator. */
/** @param data an iterator of elements, assumed to already be sorted by partition ID */
private def groupByPartition(data: Iterator[((Int, K), C)])
: Iterator[(Int, Iterator[Product2[K, C]])] =
{
val buffered = data.buffered
(0 until numPartitions).iterator.map(p => (p, new IteratorForPartition(p, buffered)))
}
/** An iterator that reads only the elements for a given partition ID from an underlying buffered */
/** stream, assuming this partition is the next one to be read. Used to make it easier to return */
/** partitioned iterators from our in-memory collection. */
private[this] class IteratorForPartition(partitionId: Int, data: BufferedIterator[((Int, K), C)])
extends Iterator[Product2[K, C]]
{
override def hasNext: Boolean = data.hasNext && data.head._1._1 == partitionId
override def next(): Product2[K, C] = {
if (!hasNext) {
throw new NoSuchElementException
}
val elem = data.next()
(elem._1._2, elem._2)
}
}
ExternalSorter 的 merge 方法
在 partitionedIterator 中,若 spills.isEmpty 为 false, 即已经有部分数据已经写入到磁盘,则需要调用 merge 方法, 将还在内存中的数据和磁盘上的数据进行 merge 操作,最终返回一个按 partition 分组的,对所有写入当前对象的数据 都可以访问的迭代器(注意这些数据可以同时在磁盘和内存中)
/** Merge a sequence of sorted files, giving an iterator over partitions and then over elements */
/** inside each partition. This can be used to either write out a new file or return data to */
/** the user. */
/** */
/** Returns an iterator over all the data written to this object, grouped by partition. For each */
/** partition we then have an iterator over its contents, and these are expected to be accessed */
/** in order (you can't "skip ahead" to one partition without reading the previous one). */
/** Guaranteed to return a key-value pair for each partition, in order of partition ID. */
private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
: Iterator[(Int, Iterator[Product2[K, C]])] = {
/** 磁盘上的文件中的数据的读取访问,通过 SpillReader 完成, 所以这里每个文件生成一个 reader */
val readers = spills.map(new SpillReader(_))
val inMemBuffered = inMemory.buffered
/** 以 partition 来分组,所以按 partition 的个数来 iterator */
(0 until numPartitions).iterator.map { p =>
/** 注意这里 p 是 partition 的 id, 即 partitionId */
/** 内存中的数据的遍历,通过 IteratorForPartition 完成,这里生成内存中数据访问的迭代器 */
val inMemIterator = new IteratorForPartition(p, inMemBuffered)
/** 文件中的数据的访问,通过 reader 的方法 readNextPartition 生成迭代器,和内存中的数据的迭代器一起,形成最终的迭代器 */
val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
/** 如果定义了聚合,则需要 mergeWithAggregation */
/** TODO:这里不会同时定义,对么? 或者说,只需要其中一个就可以了,比如聚合定义了,排序就没有意义了? */
if (aggregator.isDefined) {
/** Perform partial aggregation across partitions */
(p, mergeWithAggregation(
iterators, aggregator.get.mergeCombiners, keyComparator, ordering.isDefined))
} else if (ordering.isDefined) {
/** No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey); */
/** sort the elements without trying to merge them */
/** 如果定义了排序,则需要 mergeSort */
(p, mergeSort(iterators, ordering.get))
} else {
(p, iterators.iterator.flatten)
}
}
}
Spillable 的 readNextPartition 方法
这里返回当前 Spillable 对应的 SpilledFile 中所有的 partition 的 (key, combiner) 的迭代器
/** 由这个方法看出,这里尝试读取下一个 partition, 而方法的关键,在于 readNextItem, 这里返回的是下一个 partition 的 (k, c) pair */
var nextPartitionToRead = 0
def readNextPartition(): Iterator[Product2[K, C]] = new Iterator[Product2[K, C]] {
val myPartition = nextPartitionToRead
nextPartitionToRead += 1
override def hasNext: Boolean = {
if (nextItem == null) {
nextItem = readNextItem()
if (nextItem == null) {
return false
}
}
assert(lastPartitionId >= myPartition)
/** Check that we're still in the right partition; note that readNextItem will have returned */
/** null at EOF above so we would've returned false there */
lastPartitionId == myPartition
}
override def next(): Product2[K, C] = {
if (!hasNext) {
throw new NoSuchElementException
}
val item = nextItem
nextItem = null
item
}
}
/** Return the next (K, C) pair from the deserialization stream and update partitionId, */
/** indexInPartition, indexInBatch and such to match its location. */
/** 由于数据 spill 到磁盘上的时候,每个 SpilledFile 文件记录了这个 SpilledFile 文件的大小,及其在文件中的 offset(以 byte 为单位), */
/** 因为写磁盘的时候,每 flush 一次,会将先前的写入提交一次,从而生成一个 FileSegment,这个 FileSegment 记录了这次提交的数据量的大小(以 byte 为单位) */
/** 对应到物理机上,其实多个 SpilledFile 是同一个文件;所以可以根据 offset, 很容易地定义到需要获取的文件流的起始位置与结束位置, */
/** 这是 nextBatchStream 这个方法的底层原理 */
/** SpilledFile 的属性 elementsPerPartition 是同一个 SpilledFile 中,相同的 partition 被访问了几次,注意这里相同的 partition 可能进同一个 SpilledFile */
/** If the current batch is drained, construct a stream for the next batch and read from it. */
/** If no more pairs are left, return null. */
/** 这里从磁盘文件中读取一个 stream, 这个 stream 对应一个 batch, 一个 batch 在文件中对应一个 FileSegment, 由于一个 FileSegment 有多个 partition */
/** 这里在 indexInBatch 等于 serializerBatchSize 时,才读取下一个 batch, 否则一直在当前的 batch stream 中读取下一个 partition */
private def readNextItem(): (K, C) = {
if (finished || deserializeStream == null) {
return null
}
/** 从 stream 中读取一个 partition 的 key 和 combiner */
val k = deserializeStream.readKey().asInstanceOf[K]
val c = deserializeStream.readValue().asInstanceOf[C]
lastPartitionId = partitionId
/** Start reading the next batch if we're done with this one */
indexInBatch += 1
if (indexInBatch == serializerBatchSize) {
indexInBatch = 0
deserializeStream = nextBatchStream()
}
/** Update the partition location of the element we're reading */
indexInPartition += 1
skipToNextPartition()
/** If we've finished reading the last partition, remember that we're done */
if (partitionId == numPartitions) {
finished = true
if (deserializeStream != null) {
deserializeStream.close()
}
}
(k, c)
}
/** Construct a stream that only reads from the next batch */
def nextBatchStream(): DeserializationStream = {
/** Note that batchOffsets.length = numBatches + 1 since we did a scan above; check whether */
/** we're still in a valid batch. */
/** 由于上面调用 scanLeft(0)(_ + _), 所以 batchOffsets 要比 numBatches 大 1, 所以这里检查当前是否是个有效的 batch */
if (batchId < batchOffsets.length - 1) {
if (deserializeStream != null) {
deserializeStream.close()
fileStream.close()
deserializeStream = null
fileStream = null
}
/** 由于 batchOffsets 是由不同的 batch 的 size 这个数组逐渐累加的(scanLeft(0L)(_ + _)),类似于斐波那契数列一样, */
/** 所以根据 batchId 即可前面多个 batch size 相加后的和,即当前 batchId 的起始 offset */
val start = batchOffsets(batchId)
fileStream = new FileInputStream(spill.file)
/** 由于拿到了当前 batchId 的 start,因此能一次性定义到位置 */
fileStream.getChannel.position(start)
batchId += 1
val end = batchOffsets(batchId)
assert(end >= start, "start = " + start + ", end = " + end +
", batchOffsets = " + batchOffsets.mkString("[", ", ", "]"))
val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start))
val wrappedStream = serializerManager.wrapStream(spill.blockId, bufferedStream)
serInstance.deserializeStream(wrappedStream)
} else {
/** No more batches left */
cleanup()
null
}
}
/** Update partitionId if we have reached the end of our current partition, possibly skipping */
/** empty partitions on the way. */
/** 这个方法用在跳过空的 partition(如在 Spillable 初始化时调用过)时,以及用在到当前 partition 的尾部时 */
/** 更新并记录当前对象的 partitionId 和 indexInPartition 信息, 以便后续使用 */
private def skipToNextPartition() {
while (partitionId < numPartitions &&
indexInPartition == spill.elementsPerPartition(partitionId)) {
partitionId += 1
indexInPartition = 0L
}
}
ExternalSorter 的 mergeWithAggregation 方法
这个方法主要是对结果进行聚合,即根据参数 mergeCombiners 对相同 key 的 partition 执行 combiner 操作。 由于结果可能已经按 key 排序过,所以要区分是否已经 totalOrder.
由方法可知,这里返回的结果,都是通过 mergeSort 进行排序后的结果, 所以 mergeSort 方法决定了 next 的顺序
/** Merge a sequence of (K, C) iterators by aggregating values for each key, assuming that each */
/** iterator is sorted by key with a given comparator. If the comparator is not a total ordering */
/** (e.g. when we sort objects by hash code and different keys may compare as equal although */
/** they're not), we still merge them by doing equality tests for all keys that compare as equal. */
/** 根据要聚合的值,为每个 key (对应一个 partition),将一个序列的 iterator 聚合到一起. 假定每个 */
/** iterator 按给定的 comparator 对 key 进行排序 */
private def mergeWithAggregation(
iterators: Seq[Iterator[Product2[K, C]]],
mergeCombiners: (C, C) => C,
comparator: Comparator[K],
totalOrder: Boolean)
: Iterator[Product2[K, C]] =
{
/** totalOrder: orderging 是否定义过 */
if (!totalOrder) {
/** We only have a partial ordering, e.g. comparing the keys by hash code, which means that */
/** multiple distinct keys might be treated as equal by the ordering. To deal with this, we */
/** need to read all keys considered equal by the ordering at once and compare them. */
new Iterator[Iterator[Product2[K, C]]] {
/** 初始化时要根据 comparator 对 iterator 中的 key 进行排序 */
val sorted = mergeSort(iterators, comparator).buffered
/** Buffers reused across elements to decrease memory allocation */
/** 这里使用 ArrayBuffer 是为了减少内存分配 */
/** 其中 keys 用来存储 iterator 中的 key, combiners 用来存储其中的 combiner */
val keys = new ArrayBuffer[K]
val combiners = new ArrayBuffer[C]
override def hasNext: Boolean = sorted.hasNext
override def next(): Iterator[Product2[K, C]] = {
if (!hasNext) {
throw new NoSuchElementException
}
/** 为后面的 merge 做准备 */
keys.clear()
combiners.clear()
/** 获取第一个 pair */
val firstPair = sorted.next()
keys += firstPair._1
combiners += firstPair._2
val key = firstPair._1
/** 这里遍历 iterator,获取所有的 (K, C) */
while (sorted.hasNext && comparator.compare(sorted.head._1, key) == 0) {
/** 如果有下一个 pair,则获取下一个 pair */
val pair = sorted.next()
var i = 0
var foundKey = false
/** 对下一个 pair, 从 i=0 开始遍历 keys, 如果能找到和 pair 相同的 key, 则 merge, 否则继续遍历 */
while (i < keys.size && !foundKey) {
/** 如果 pair 的 key 与 keys(i) 相同,则进行 merge, 则设置 foundKey 为true, 即不再循环; 否则继续遍历 */
if (keys(i) == pair._1) {
/** 这里的 mergeCombiners 是参数,也是一个方法,这里是调用方法完成对相同 key 的 merge */
combiners(i) = mergeCombiners(combiners(i), pair._2)
foundKey = true
}
i += 1
}
/** 如果遍历 keys 都没有找到相同的 key, 则添加到 keys 和 combiners 中去 */
if (!foundKey) {
keys += pair._1
combiners += pair._2
}
}
/** Note that we return an iterator of elements since we could've had many keys marked */
/** equal by the partial order; we flatten this below to get a flat iterator of (K, C). */
keys.iterator.zip(combiners.iterator)
}
}.flatMap(i => i)
} else {
/** We have a total ordering, so the objects with the same key are sequential. */
/** 如果 totalOrder 为 True, 即已经排过序,相同的 key 已经是一个序列的了,则直接根据 comparator 对 partition 进行排序即可 */
new Iterator[Product2[K, C]] {
val sorted = mergeSort(iterators, comparator).buffered
override def hasNext: Boolean = sorted.hasNext
/** 这个方法的逻辑与 totalOrder 为 false 时很类似,但要简单一些,在此跳过 */
override def next(): Product2[K, C] = {
if (!hasNext) {
throw new NoSuchElementException
}
val elem = sorted.next()
val k = elem._1
var c = elem._2
while (sorted.hasNext && sorted.head._1 == k) {
val pair = sorted.next()
c = mergeCombiners(c, pair._2)
}
(k, c)
}
}
}
}
ExternalSorter 的 mergeSort 方法
这个方法主要是根据 comparator 按 key 对 iterator 进行归并排序
/** Merge-sort a sequence of (K, C) iterators using a given a comparator for the keys. */
private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K])
: Iterator[Product2[K, C]] =
{
val bufferedIters = iterators.filter(_.hasNext).map(_.buffered)
type Iter = BufferedIterator[Product2[K, C]]
val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
/** Use the reverse of comparator.compare because PriorityQueue dequeues the max */
override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1)
})
heap.enqueue(bufferedIters: _*) /** Will contain only the iterators with hasNext = true */
new Iterator[Product2[K, C]] {
override def hasNext: Boolean = !heap.isEmpty
override def next(): Product2[K, C] = {
if (!hasNext) {
throw new NoSuchElementException
}
val firstBuf = heap.dequeue()
val firstPair = firstBuf.next()
if (firstBuf.hasNext) {
heap.enqueue(firstBuf)
}
firstPair
}
}
}
Spark shuffle 的读取
我们在 RDD 的 iterator 方法中,已经介绍了,对于 ShuffledRDD 的 iterator 方法,是在 ResultTask 的 runTask 中触发的,该方法 这里不再介绍,但 ShuffledRDD 的 iterator 方法,会获取 ShuffleReader 的一个实例,并调用其 read 方法来读取已经 combine 过的 key-value 数据。compute 方法在 RDD 的 iterator 方法中已经有介绍,这里继续分析 ShuffleReader 的 read 方法。
BlockStoreShuffleReader 的 read 方法
当前的 spark 版本中,ShuffleReader 只有一个版本的实现: BlockStoreShuffleReader. 同时,由于数据的 shuffle 和 combine 在
shuffle 写入时已经完成,所以 shuffleReader 看起来并没有多少优化的空间,只需要将 combine 过后的数据拉到 reduce 执行的节点
进行最后的结果计算,即 ResultTask 的 runTask 最后一行:func(context, rdd.iterator(partition, context))
, 在这里,rdd.iterator
会调用 rdd 的 compute 方法,由于当前的 rdd 是 ShuffledRDD(LogQuery 的 reduceByKey), 所以在其 compute 方法中会实例化这个
BlockStoreShuffleReader 来获得 shuffleReader
/** Read the combined key-values for this reduce task */
/** 为当前的 reduce task 读取已经在 map task 中 combine 过的 key-value 值 */
override def read(): Iterator[Product2[K, C]] = {
/** 由于 map task 中的结果存储在 block 中,这里返回拉取 block 的迭代器, 以读取 map 端的结果 */
/** 关于这个类,后面详细分析 */
val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
/** 我们知道 shuffle 写入完成后,返回的是 MapStatus, 而 mapOutputTracker 就是用来追踪 mapStatus 的 */
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
/** Note: we use getSizeAsMb when no suffix is provided for backwards compatibility */
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))
/** Wrap the streams for compression and encryption based on configuration */
/** 如果需要压缩或数据加密需求,则在这里将输入流添加一层 wrap, 思想类似于装饰者模式 */
val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
serializerManager.wrapStream(blockId, inputStream)
}
val serializerInstance = dep.serializer.newInstance()
/** Create a key/value iterator for each stream */
/** 将输入流 wrappedStreams 逆序列化,map 成 key/value 形式的迭代器 */
val recordIter = wrappedStreams.flatMap { wrappedStream =>
/** Note: the asKeyValueIterator below wraps a key/value iterator inside of a */
/** NextIterator. The NextIterator makes sure that close() is called on the */
/** underlying InputStream when all records have been read. */
/** 由于前面的 wrappedStream 是流,所以读取完成需要关闭,这里 asKeyValueIterator 会在读取完成后关闭流 */
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}
/** Update the context task metrics for each record read. */
/** 这里定义了 metricIter, 其实是要在 recordIter 读取完成后自动调用 mergeShuffleReadMetrics 方法, 理解成是对 recordIter 的一种包装 */
val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
recordIter.map { record =>
readMetrics.incRecordsRead(1)
record
},
context.taskMetrics().mergeShuffleReadMetrics())
/** An interruptible iterator must be used here in order to support task cancellation */
/** 又对 metricIter 加了一层包装,支持了 interrupt */
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
/** 如果定义了聚合函数,则根据需要进行聚合;否则直接 asInstanceOf 即可 */
val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
/** 根据 map 端已经 combine, 将 interruptibleIter 转化为不同的类型,然后进行 reduce 端的 combine */
if (dep.mapSideCombine) {
/** We are reading values that are already combined */
/** 读取已经 combine 过的结果, 然后再对 combiners 进行 combine, 注意这里是对 combiner 进行 combine, 不是对 values */
val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
/** We don't know the value type, but also don't care -- the dependency *should* */
/** have made sure its compatible w/ this aggregator, which will convert the value */
/** type to the combined type C */
/** 如果 map 端没有 combine 过,则需要对 values 进行 combine */
val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}
/** Sort the output if there is a sort ordering defined. */
/** 如果需要对结果进行排序,则使用 ExternalSorter 进行排序,由于前面经过了 combine, map 端之前可能的排序已被打乱了 */
dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
/** Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled, */
/** the ExternalSorter won't spill to disk. */
/** 这里使用 ExternalSorter 对数据进行排序,前面对 ExternalSorter 有过较为详细地分析,这里的排序可能会 spill 到磁盘 */
val sorter =
new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
sorter.insertAll(aggregatedIter)
/** spark 会对这个外部排序的过程记录使用的内在/磁盘大小,所以只要能获取到 metrics, 就知道这个过程占用多大的空间 */
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
aggregatedIter
}
}
ShuffleBlockFetcherIterator 的 next 方法
ShuffleBlockFetcherIterator 类本身是一个迭代器,用来一次拉取一些 blocks. 它不只能一些拉取多个 blocks, 还会限制拉取 blocks 的最大值,从而保证拉取的 block 不会占用大量内存,即起到加速的效果,又有限制作用。这个类重要的方法主要有初始化方法 initialize 和读取下一批 block 的 next 方法
private[this] def initialize(): Unit = {
/** Add a task completion callback (called in both success case and failure case) to cleanup. */
/** 添加任务的监听事件,确保释放所有的 buffer, 不论是否成功地获取到了结果,都会释放 */
context.addTaskCompletionListener(_ => cleanup())
/** Split local and remote blocks. */
/** 这里会区别要拉取的 block 是本地还是远程,本地的通过本地的 blockManager 去拉; */
/** 远程的 block, 根据 block 的 size,确保每个请求要拉取的 blocks 的 size 总和超过 targetRequestSize */
/** 这里根据 size 大小相加,当 size 大小超过 targetRequestSize 时,封装成一个请求 */
/** 这就保证了一个请求要拉取的数据量不会太大,也不会太小,超到一个限制最大最小的作用 */
val remoteRequests = splitLocalRemoteBlocks()
/** Add the remote requests into our queue in a random order */
fetchRequests ++= Utils.randomize(remoteRequests)
assert ((0 == reqsInFlight) == (0 == bytesInFlight),
"expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)
/** Send out initial requests for blocks, up to our maxBytesInFlight */
/** 这个方法是要保证一次发送的请求,不超过 maxBytesInFlight(是 targetRequestSize 的 5 倍) */
/** 即保证每次发送的拉取数据的请求,拉回来的数据占用的内存不会太大 */
fetchUpToMaxBytes()
val numFetches = remoteRequests.size - fetchRequests.size
logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
/** Get Local Blocks */
fetchLocalBlocks()
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}
/** Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers */
/** underlying each InputStream will be freed by the cleanup() method registered with the */
/** TaskCompletionListener. However, callers should close() these InputStreams */
/** as soon as they are no longer needed, in order to release memory as early as possible. */
/** */
/** Throws a FetchFailedException if the next block could not be fetched. */
/** 这个方法比较简单,每次返回一个结果,并再调用 fetchUpToMaxBytes 以发送足够的请求 */
override def next(): (BlockId, InputStream) = {
if (!hasNext) {
throw new NoSuchElementException
}
numBlocksProcessed += 1
val startFetchWait = System.currentTimeMillis()
currentResult = results.take()
val result = currentResult
val stopFetchWait = System.currentTimeMillis()
shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
result match {
case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) =>
if (address != blockManager.blockManagerId) {
shuffleMetrics.incRemoteBytesRead(buf.size)
shuffleMetrics.incRemoteBlocksFetched(1)
}
bytesInFlight -= size
if (isNetworkReqDone) {
reqsInFlight -= 1
logDebug("Number of requests in flight " + reqsInFlight)
}
case _ =>
}
/** Send fetch requests up to maxBytesInFlight */ fetchUpToMaxBytes()
result match {
case FailureFetchResult(blockId, address, e) =>
throwFetchFailedException(blockId, address, e)
case SuccessFetchResult(blockId, address, _, buf, _) =>
try {
(result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
} catch {
case NonFatal(t) =>
throwFetchFailedException(blockId, address, t)
}
}
}
shuffle 过程的读取内容比较简单,主要是 reduce 端的 combine 和 block 的拉取过程的逻辑。所以也写的比较简单。
总结
至此,shuffle 的写入和读取的过程基本分析完了。由于用时比较长,且难度比较大,所以存在不少错误之处,后续理解更深入之后再慢慢改正.