Spark Task Dispatch and Execution Flow
The previous blog post covered how jobs are submitted in Spark, but how are the resulting tasks dispatched to the cluster and executed? This post walks through the dispatch and execution flow. Since the flow is fairly involved, we first sketch it with a couple of diagrams:
Initializing the TaskScheduler and SchedulerBackend
The task submission flow
Flows left for later posts:
- How are task results returned to the driver?
- How are executors started?
A Spark job is submitted through DAGScheduler's submitMissingTasks, which turns the stage into tasks and hands them to TaskScheduler's submitTasks method. Before submitTasks can be called, the TaskScheduler must be initialized. TaskSchedulerImpl extends TaskScheduler; it is created inside SparkContext by the createTaskScheduler method, which parses the spark.master setting (local, spark://, local-cluster, yarn, and so on) and, based on it, instantiates both a TaskScheduler and a SchedulerBackend.
Since this post cares about how tasks reach the cluster, we focus on the cluster case and use standalone mode as the example. Here the master URL is spark://ip1:port1,ip2:port2,ip3:port3, the TaskScheduler is instantiated as a TaskSchedulerImpl, and the SchedulerBackend as a StandaloneSchedulerBackend. If the master is yarn, the TaskScheduler and SchedulerBackend are created through YarnClusterManager instead.
With that, we have found the classes that the TaskScheduler and SchedulerBackend are instantiated as.
/** Create a task scheduler based on a given master URL. */
/** Return a 2-tuple of the scheduler backend and the task scheduler. */
private def createTaskScheduler(
sc: SparkContext,
master: String,
deployMode: String): (SchedulerBackend, TaskScheduler) = {
import SparkMasterRegex._
/** When running locally, don't try to re-execute tasks on failure. */
val MAX_LOCAL_TASK_FAILURES = 1
master match {
/** local */
case "local" =>
val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
val backend = new LocalSchedulerBackend(sc.getConf, scheduler, 1)
scheduler.initialize(backend)
(backend, scheduler)
/** local[*] */
case LOCAL_N_REGEX(threads) =>
def localCpuCount: Int = Runtime.getRuntime.availableProcessors()
/** local[*] estimates the number of cores on the machine; local[N] uses exactly N threads. */
val threadCount = if (threads == "*") localCpuCount else threads.toInt
if (threadCount <= 0) {
throw new SparkException(s"Asked to run locally with $threadCount threads")
}
val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
val backend = new LocalSchedulerBackend(sc.getConf, scheduler, threadCount)
scheduler.initialize(backend)
(backend, scheduler)
/** local[*, M] */
case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
def localCpuCount: Int = Runtime.getRuntime.availableProcessors()
/** local[*, M] means the number of cores on the computer with M failures */
/** local[N, M] means exactly N threads with M failures */
val threadCount = if (threads == "*") localCpuCount else threads.toInt
val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true)
val backend = new LocalSchedulerBackend(sc.getConf, scheduler, threadCount)
scheduler.initialize(backend)
(backend, scheduler)
/** spark://... the master URL of a standalone cluster */
case SPARK_REGEX(sparkUrl) =>
val scheduler = new TaskSchedulerImpl(sc)
val masterUrls = sparkUrl.split(",").map("spark://" + _)
/** A subclass of CoarseGrainedSchedulerBackend that also implements the StandaloneAppClientListener interface */
val backend = new StandaloneSchedulerBackend(scheduler, sc, masterUrls)
scheduler.initialize(backend)
(backend, scheduler)
/** local-cluster */
case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
/** Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. */
val memoryPerSlaveInt = memoryPerSlave.toInt
if (sc.executorMemory > memoryPerSlaveInt) {
throw new SparkException(
"Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format(
memoryPerSlaveInt, sc.executorMemory))
}
val scheduler = new TaskSchedulerImpl(sc)
val localCluster = new LocalSparkCluster(
numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt, sc.conf)
val masterUrls = localCluster.start()
val backend = new StandaloneSchedulerBackend(scheduler, sc, masterUrls)
scheduler.initialize(backend)
backend.shutdownCallback = (backend: StandaloneSchedulerBackend) => {
localCluster.stop()
}
(backend, scheduler)
/** Everything else, e.g. yarn, yarn-client, yarn-cluster */
case masterUrl =>
val cm = getClusterManager(masterUrl) match {
case Some(clusterMgr) => clusterMgr
case None => throw new SparkException("Could not parse Master URL: '" + master + "'")
}
try {
val scheduler = cm.createTaskScheduler(sc, masterUrl)
val backend = cm.createSchedulerBackend(sc, masterUrl, scheduler)
cm.initialize(scheduler, backend)
(backend, scheduler)
} catch {
case se: SparkException => throw se
case NonFatal(e) =>
throw new SparkException("External scheduler cannot be instantiated", e)
}
}
}
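For reference, the SPARK_REGEX branch above is selected by whatever master URL the application was started with. A minimal sketch of pointing an application at a standalone cluster with several masters (host names and ports are placeholders):
import org.apache.spark.{SparkConf, SparkContext}

// All standalone masters go into one URL; createTaskScheduler splits it on ","
// and prefixes each entry with "spark://" to build masterUrls.
val conf = new SparkConf()
  .setAppName("dispatch-demo")
  .setMaster("spark://master1:7077,master2:7077")

val sc = new SparkContext(conf)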
After instantiation, SparkContext starts the TaskScheduler via _taskScheduler.start(). That start call in turn starts the SchedulerBackend: in standalone mode this means StandaloneSchedulerBackend's start function, which first invokes the start method of its parent class CoarseGrainedSchedulerBackend. During this start, StandaloneSchedulerBackend creates and starts a StandaloneAppClient, while CoarseGrainedSchedulerBackend mainly creates the driverEndpoint member. Starting the StandaloneAppClient, which is what eventually creates executors on the compute nodes, will be analyzed in a later post. In short, this is where the driver-side state is initialized.
/** start method of StandaloneSchedulerBackend */
override def start() {
super.start()
/** SPARK-21159. The scheduler backend should only try to connect to the launcher when in client */
/** mode. In cluster mode, the code that submits the application to the Master needs to connect */
/** to the launcher instead. */
if (sc.deployMode == "client") {
launcherBackend.connect()
}
/** The endpoint for executors to talk to us */
val driverUrl = RpcEndpointAddress(
sc.conf.get("spark.driver.host"),
sc.conf.get("spark.driver.port").toInt,
CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString
val args = Seq(
"--driver-url", driverUrl,
"--executor-id", "",
"--hostname", "",
"--cores", "",
"--app-id", "",
"--worker-url", "")
val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions")
.map(Utils.splitCommandString).getOrElse(Seq.empty)
val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath")
.map(_.split(java.io.File.pathSeparator).toSeq).getOrElse(Nil)
val libraryPathEntries = sc.conf.getOption("spark.executor.extraLibraryPath")
.map(_.split(java.io.File.pathSeparator).toSeq).getOrElse(Nil)
/** When testing, expose the parent class path to the child. This is processed by */
/** compute-classpath.{cmd,sh} and makes all needed jars available to child processes */
/** when the assembly is built with the "*-provided" profiles enabled. */
val testingClassPath =
if (sys.props.contains("spark.testing")) {
sys.props("java.class.path").split(java.io.File.pathSeparator).toSeq
} else {
Nil
}
/** Start executors with a few necessary configs for registering with the scheduler */
val sparkJavaOpts = Utils.sparkJavaOpts(conf, SparkConf.isExecutorStartupConf)
val javaOpts = sparkJavaOpts ++ extraJavaOpts
/** CoarseGrainedExecutorBackend is the class that hosts the executor. In its onStart method it sends a */
/** RegisterExecutor message to the scheduler backend's receiveAndReply handler, which eventually calls makeOffers */
val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend",
args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts)
val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("")
val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt)
/** If we're using dynamic allocation, set our initial executor limit to 0 for now. */
/** ExecutorAllocationManager will send the real initial limit to the Master later. */
val initialExecutorLimit =
if (Utils.isDynamicAllocationEnabled(conf)) {
Some(0)
} else {
None
}
val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command,
appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit)
client = new StandaloneAppClient(sc.env.rpcEnv, masters, appDesc, this, conf)
client.start()
launcherBackend.setState(SparkAppHandle.State.SUBMITTED)
waitForRegistration()
launcherBackend.setState(SparkAppHandle.State.RUNNING)
}
/** start method of CoarseGrainedSchedulerBackend */
override def start() {
val properties = new ArrayBuffer[(String, String)]
for ((key, value) <- scheduler.sc.conf.getAll) {
if (key.startsWith("spark.")) {
properties += ((key, value))
}
}
/** TODO (prashant) send conf instead of properties */
driverEndpoint = createDriverEndpointRef(properties)
}
protected def createDriverEndpointRef(
properties: ArrayBuffer[(String, String)]): RpcEndpointRef = {
rpcEnv.setupEndpoint(ENDPOINT_NAME, createDriverEndpoint(properties))
}
protected def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = {
new DriverEndpoint(rpcEnv, properties)
}
/** Make fake resource offers on all executors */
/** This only gathers the resources that could be offered; nothing is actually reserved yet */
/** Based on the executor information kept locally in executorDataMap, try to launchTasks */
private def makeOffers() {
// Filter out executors under killing
val activeExecutors = executorDataMap.filterKeys(executorIsAlive)
val workOffers = activeExecutors.map { case (id, executorData) =>
new WorkerOffer(id, executorData.executorHost, executorData.freeCores)
}.toIndexedSeq
/** Calls the scheduler's resourceOffers method, discussed below */
launchTasks(scheduler.resourceOffers(workOffers))
}
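For context, the WorkerOffer objects built above are just small value records; their shape in Spark 2.x is roughly the following (treat the exact field list as an assumption):
// Roughly what makeOffers produces per live executor (assumed shape, Spark 2.x)
case class WorkerOffer(executorId: String, host: String, cores: Int)

// e.g. executor "3" on host "node-7" with 2 free cores
val sampleOffer = WorkerOffer("3", "node-7", 2)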
Preparation
Initializing TaskSchedulerImpl and the SchedulerBackend
When TaskSchedulerImpl is constructed, the SchedulerBackend is created along with it, and TaskSchedulerImpl's initialize method is then called to set up the member variables shown below (note this is not the complete list of members):
/** Listener object to pass upcalls into */
var dagScheduler: DAGScheduler = null
/** SchedulerBackend is the driver-side interface used to locate executors and send them task-launch requests */
var backend: SchedulerBackend = null
/** Tracks the map output locations produced by shuffle map tasks */
val mapOutputTracker = SparkEnv.get.mapOutputTracker
/** Scheduling order of task sets: FIFO (the default) or FAIR */
var schedulableBuilder: SchedulableBuilder = null
/** Root pool holding the schedulable task set managers */
var rootPool: Pool = null
def initialize(backend: SchedulerBackend) {
this.backend = backend
/** temporarily set rootPool name to empty */
rootPool = new Pool("", schedulingMode, 0, 0)
schedulableBuilder = {
schedulingMode match {
case SchedulingMode.FIFO =>
/** FIFO is simple: it is just instantiated here, and its buildPools below does nothing more */
new FIFOSchedulableBuilder(rootPool)
case SchedulingMode.FAIR =>
/** FAIR is more involved: buildPools reads and parses the fairscheduler.xml config file, */
/** adds each pool defined there to rootPool, and also adds a default pool */
new FairSchedulableBuilder(rootPool, conf)
case _ =>
throw new IllegalArgumentException(s"Unsupported spark.scheduler.mode: $schedulingMode")
}
}
schedulableBuilder.buildPools()
}
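As an aside, switching to FAIR scheduling is purely a matter of configuration; a hedged sketch (property names as documented for Spark 2.x, pool contents illustrative only):
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.scheduler.mode", "FAIR")                                  // picked up as schedulingMode above
  .set("spark.scheduler.allocation.file", "/etc/spark/fairscheduler.xml")

// /etc/spark/fairscheduler.xml (illustrative):
// <allocations>
//   <pool name="production">
//     <schedulingMode>FAIR</schedulingMode>
//     <weight>2</weight>
//     <minShare>4</minShare>
//   </pool>
// </allocations>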
Message passing
The message passing between the driver and the compute nodes is covered at the end of this post.
TaskSchedulerImpl's submitTasks method
As described in the earlier post on submitMissingTasks, once the DAGScheduler has turned a stage into tasks it hands them to TaskSchedulerImpl's submitTasks method. This method registers the task set and, when the job is not local and no task set has been received before, starts a starvation timer that keeps warning until hasLaunchedTask becomes true (hasLaunchedTask is set in resourceOffers); it then calls backend.reviveOffers.
override def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = createTaskSetManager(taskSet, maxTaskFailures)
val stage = taskSet.stageId
val stageTaskSets =
taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
stageTaskSets(taskSet.stageAttemptId) = manager
val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
ts.taskSet != taskSet && !ts.isZombie
}
if (conflictingTaskSet) {
throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
}
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
/** If the job is not running locally and no task set has been received before, start a timer */
/** that keeps checking hasLaunchedTask (which is set in the resourceOffers method) */
if (!isLocal && !hasReceivedTask) {
/** scheduleAtFixedRate runs this TimerTask repeatedly until hasLaunchedTask becomes true, then it cancels itself */
starvationTimer.scheduleAtFixedRate(new TimerTask() {
override def run() {
if (!hasLaunchedTask) {
logWarning("Initial job has not accepted any resources; " +
"check your cluster UI to ensure that workers are registered " +
"and have sufficient resources")
} else {
/** Once tasks have been launched, cancel the timer */
this.cancel()
}
}
}, STARVATION_TIMEOUT_MS, STARVATION_TIMEOUT_MS)
}
hasReceivedTask = true
}
backend.reviveOffers()
}
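Note that maxTaskFailures, handed to createTaskSetManager above, comes from configuration; a hedged one-liner (property name as documented for Spark):
// Number of failures tolerated for any single task before the task set (and hence the job) is aborted; default is 4.
val conf = new org.apache.spark.SparkConf().set("spark.task.maxFailures", "8")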
TaskSchedulerImpl's resourceOffers method
As shown earlier, when a CoarseGrainedExecutorBackend registers itself (via the RegisterExecutor message sent from its onStart), the driver-side makeOffers is invoked, which calls this resourceOffers method. It takes the WorkerOffer list, builds the task descriptions that fit the offered resources, and sets hasLaunchedTask to true once there is at least one offer.
/** Called by cluster manager to offer resources on slaves. We respond by asking our active task */
/** sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so */
/** that tasks are balanced across the cluster. */
def resourceOffers(offers: IndexedSeq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {
/** Mark each slave as alive and remember its hostname */
/** Also track if new executor is added */
var newExecAvail = false
/** Walk the offers and record any newly seen executor in the local hash maps */
for (o <- offers) {
if (!hostToExecutors.contains(o.host)) {
hostToExecutors(o.host) = new HashSet[String]()
}
if (!executorIdToRunningTaskIds.contains(o.executorId)) {
hostToExecutors(o.host) += o.executorId
executorAdded(o.executorId, o.host)
executorIdToHost(o.executorId) = o.host
executorIdToRunningTaskIds(o.executorId) = HashSet[Long]()
newExecAvail = true
}
for (rack <- getRackForHost(o.host)) {
hostsByRack.getOrElseUpdate(rack, new HashSet[String]()) += o.host
}
}
/** Randomly shuffle offers to avoid always placing tasks on the same set of workers. */
val shuffledOffers = Random.shuffle(offers)
/** Build a list of tasks to assign to each worker. */
/** Note shuffledOffers is just the input offers reshuffled; each offer is mapped to a task buffer sized by its core count */
val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores))
val availableCpus = shuffledOffers.map(o => o.cores).toArray
val sortedTaskSets = rootPool.getSortedTaskSetQueue
for (taskSet <- sortedTaskSets) {
logDebug("parentName: %s, name: %s, runningTasks: %s".format(
taskSet.parent.name, taskSet.name, taskSet.runningTasks))
if (newExecAvail) {
/** Recompute the task set's locality levels now that a new executor is available */
taskSet.executorAdded()
}
}
/** Take each TaskSet in our scheduling order, and then offer it each node in increasing order */
/** of locality levels so that it gets a chance to launch local tasks on all of them.*/
/** With the locality levels computed above, try to launch tasks level by level, from most to least local */
/** NOTE: the preferredLocality order: PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY*/
for (taskSet <- sortedTaskSets) {
var launchedAnyTask = false
var launchedTaskAtCurrentMaxLocality = false
for (currentMaxLocality <- taskSet.myLocalityLevels) {
do {
/** resourceOfferSingleTaskSet returns true if it launched at least one task at this locality level */
launchedTaskAtCurrentMaxLocality = resourceOfferSingleTaskSet(
taskSet, currentMaxLocality, shuffledOffers, availableCpus, tasks)
launchedAnyTask |= launchedTaskAtCurrentMaxLocality
} while (launchedTaskAtCurrentMaxLocality)
}
/** If no task could be launched for this task set at all, abort it in case it is completely blacklisted */
if (!launchedAnyTask) {
taskSet.abortIfCompletelyBlacklisted(hostToExecutors)
}
}
/** Set hasLaunchedTask so the starvation timer in submitTasks stops waiting */
if (tasks.size > 0) {
hasLaunchedTask = true
}
return tasks
}
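To make the round-robin idea concrete, here is a toy sketch (deliberately not Spark code, and it ignores locality entirely): keep sweeping the offers, handing out one task per qualifying offer per pass, until no offer has at least CPUS_PER_TASK free cores left:
// Toy model only: "tasks" are plain strings, locality and blacklisting are ignored.
case class Offer(executorId: String, var freeCores: Int)

def roundRobinAssign(tasks: List[String], offers: Seq[Offer], cpusPerTask: Int = 1)
    : Map[String, Seq[String]] = {
  val assigned = offers.map(o => o.executorId -> scala.collection.mutable.Buffer.empty[String]).toMap
  var remaining = tasks
  var progress = true
  while (remaining.nonEmpty && progress) {
    progress = false
    // one pass over the offers: every offer with spare cores receives at most one task
    for (o <- offers if remaining.nonEmpty && o.freeCores >= cpusPerTask) {
      assigned(o.executorId) += remaining.head
      remaining = remaining.tail
      o.freeCores -= cpusPerTask
      progress = true
    }
  }
  assigned.map { case (k, v) => k -> v.toSeq }
}

// 5 tasks over two executors with 2 and 1 free cores:
// exec-0 gets 2 tasks, exec-1 gets 1 task, the remaining two wait for the next offer round.
val result = roundRobinAssign(List("t0", "t1", "t2", "t3", "t4"),
  Seq(Offer("exec-0", 2), Offer("exec-1", 1)))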
/** Try to launch tasks from a single task set on the offered executors */
private def resourceOfferSingleTaskSet(
taskSet: TaskSetManager,
maxLocality: TaskLocality,
shuffledOffers: Seq[WorkerOffer],
availableCpus: Array[Int],
tasks: IndexedSeq[ArrayBuffer[TaskDescription]]) : Boolean = {
var launchedTask = false
for (i <- 0 until shuffledOffers.size) {
val execId = shuffledOffers(i).executorId
val host = shuffledOffers(i).host
if (availableCpus(i) >= CPUS_PER_TASK) {
try {
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
tasks(i) += task
val tid = task.taskId
taskIdToTaskSetManager(tid) = taskSet
taskIdToExecutorId(tid) = execId
executorIdToRunningTaskIds(execId).add(tid)
availableCpus(i) -= CPUS_PER_TASK
assert(availableCpus(i) >= 0)
launchedTask = true
}
} catch {
case e: TaskNotSerializableException =>
logError(s"Resource offer failed, task set ${taskSet.name} was not serializable")
/** Do not offer resources for this task, but don't throw an error to allow other */
/** task sets to be submitted. */
return launchedTask
}
}
}
return launchedTask
}
CoarseGrainedSchedulerBackend's launchTasks method
The task descriptions returned by TaskSchedulerImpl's resourceOffers are passed to CoarseGrainedSchedulerBackend's launchTasks, which updates the locally stored resource information for each executor (mainly free cores) and then sends each task to the node hosting its executor as a LaunchTask message. Note that launchTasks is a method of the DriverEndpoint, so it is invoked through the local driverEndpoint member.
/** Launch tasks returned by a set of resource offers */
private def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
for (task <- tasks.flatten) {
val serializedTask = ser.serialize(task)
if (serializedTask.limit >= maxRpcMessageSize) {
scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
try {
var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
"spark.rpc.message.maxSize (%d bytes). Consider increasing " +
"spark.rpc.message.maxSize or using broadcast variables for large values."
msg = msg.format(task.taskId, task.index, serializedTask.limit, maxRpcMessageSize)
taskSetMgr.abort(msg)
} catch {
case e: Exception => logError("Exception in error callback", e)
}
}
}
else {
/** executorDataMap keeps each executor's available resources locally; deduct the cores this task will consume */
val executorData = executorDataMap(task.executorId)
executorData.freeCores -= scheduler.CPUS_PER_TASK
logDebug(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " +
s"${executorData.executorHost}.")
/** Send a LaunchTask message to the remote executor node; this is where the task actually gets launched */
executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
}
}
}
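If a serialized task trips the size check above, the stage is aborted with the message shown. Two common remedies, sketched under the assumption of the Spark 2.x property name used in that message:
import org.apache.spark.{SparkConf, SparkContext}

// Option 1: raise the RPC limit (value in MiB); property name taken from the abort message above.
val conf = new SparkConf()
  .setAppName("broadcast-demo")
  .setMaster("local[2]")
  .set("spark.rpc.message.maxSize", "256")
val sc = new SparkContext(conf)

// Option 2 (usually better): ship the large value once per executor as a broadcast variable,
// so each serialized TaskDescription only carries a small handle instead of the whole object.
val lookup: Map[String, Int] = (1 to 100000).map(i => s"key-$i" -> i).toMap // stand-in for a big table
val bcLookup = sc.broadcast(lookup)
val resolved = sc.parallelize(Seq("key-1", "key-42", "missing"))
  .map(k => bcLookup.value.getOrElse(k, -1))
  .collect()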
CoarseGrainedExecutorBackend's handling of LaunchTask
override def receive: PartialFunction[Any, Unit] = {
case RegisteredExecutor =>
logInfo("Successfully registered with driver")
try {
executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)
} catch {
case NonFatal(e) =>
exitExecutor(1, "Unable to create executor due to " + e.getMessage, e)
}
case RegisterExecutorFailed(message) =>
exitExecutor(1, "Slave registration failed: " + message)
case LaunchTask(data) =>
if (executor == null) {
exitExecutor(1, "Received LaunchTask command but executor was null")
} else {
val taskDesc = ser.deserialize[TaskDescription](data.value)
logInfo("Got assigned task " + taskDesc.taskId)
/** Delegate to the local executor's launchTask method; at this point we are already on the executor node */
executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
taskDesc.name, taskDesc.serializedTask)
}
case KillTask(taskId, _, interruptThread) =>
if (executor == null) {
exitExecutor(1, "Received KillTask command but executor was null")
} else {
executor.killTask(taskId, interruptThread)
}
case StopExecutor =>
stopping.set(true)
logInfo("Driver commanded a shutdown")
/** Cannot shutdown here because an ack may need to be sent back to the caller. So send */
/** a message to self to actually do the shutdown. */
self.send(Shutdown)
case Shutdown =>
stopping.set(true)
new Thread("CoarseGrainedExecutorBackend-stop-executor") {
override def run(): Unit = {
/** executor.stop() will call `SparkEnv.stop()` which waits until RpcEnv stops totally.*/
/** However, if `executor.stop()` runs in some thread of RpcEnv, RpcEnv won't be able to*/
/** stop until `executor.stop()` returns, which becomes a dead-lock (See SPARK-14180).*/
/** Therefore, we put this line in a new thread.*/
executor.stop()
}
}.start()
}
The Executor class's launchTask method
As shown in the previous section, the LaunchTask case in CoarseGrainedExecutorBackend's receive method calls the executor's launchTask to start the task.
def launchTask(
context: ExecutorBackend,
taskId: Long,
attemptNumber: Int,
taskName: String,
serializedTask: ByteBuffer): Unit = {
/** A TaskRunner is created and handed to the executor's thread pool, which is what actually runs the task, */
/** so the core of the work is defined in the TaskRunner class */
val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
serializedTask)
runningTasks.put(taskId, tr)
threadPool.execute(tr)
}
The TaskRunner class
The whole TaskRunner class is pasted here, which is a lot of code, but its most important part, the run method, accounts for most of it anyway.
class TaskRunner(
execBackend: ExecutorBackend,
val taskId: Long,
val attemptNumber: Int,
taskName: String,
serializedTask: ByteBuffer)
extends Runnable {
val threadName = s"Executor task launch worker for task $taskId"
/** Whether this task has been killed. */
@volatile private var killed = false
@volatile private var threadId: Long = -1
def getThreadId: Long = threadId
/** Whether this task has been finished. */
@GuardedBy("TaskRunner.this")
private var finished = false
def isFinished: Boolean = synchronized { finished }
/** How much the JVM process has spent in GC when the task starts to run. */
@volatile var startGCTime: Long = _
/** The task to run. This will be set in run() by deserializing the task binary coming */
/** from the driver. Once it is set, it will never be changed. */
@volatile var task: Task[Any] = _
def kill(interruptThread: Boolean): Unit = {
logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
killed = true
if (task != null) {
synchronized {
if (!finished) {
task.kill(interruptThread)
}
}
}
}
/** Set the finished flag to true and clear the current thread's interrupt status */
private def setTaskFinishedAndClearInterruptStatus(): Unit = synchronized {
this.finished = true
/** SPARK-14234 - Reset the interrupted status of the thread to avoid the */
/** ClosedByInterruptException during execBackend.statusUpdate which causes */
/** Executor to crash */
Thread.interrupted()
/** Notify any waiting TaskReapers. Generally there will only be one reaper per task but there */
/** is a rare corner-case where one task can have two reapers in case cancel(interrupt=False) */
/** is followed by cancel(interrupt=True). Thus we use notifyAll() to avoid a lost wakeup: */
notifyAll()
}
override def run(): Unit = {
threadId = Thread.currentThread.getId
Thread.currentThread.setName(threadName)
val threadMXBean = ManagementFactory.getThreadMXBean
val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
val deserializeStartTime = System.currentTimeMillis()
val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = env.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
/** The task is starting: update its state on the executor side, which also notifies the driver; that path is analyzed later */
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var taskStart: Long = 0
var taskStartCpu: Long = 0
/** GC time accumulated so far, used later to compute this task's GC time */
startGCTime = computeTotalGcTime()
try {
/** Deserialize the task metadata from serializedTask: the files and jars it needs, its properties, and the task's own binary payload */
val (taskFiles, taskJars, taskProps, taskBytes) =
Task.deserializeWithDependencies(serializedTask)
/** Must be set before updateDependencies() is called, in case fetching dependencies */
/** requires access to properties contained within (e.g. for access control). */
Executor.taskDeserializationProps.set(taskProps)
/** In case any file has been updated, compare name and timestamp to check whether the local copy is current and re-fetch it if not; */
/** jars are handled the same way, except that a newly fetched jar also has to be reloaded into the classloader */
updateDependencies(taskFiles, taskJars)
task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
task.localProperties = taskProps
task.setTaskMemoryManager(taskMemoryManager)
/** If this task has been killed before we deserialized it, let's quit now. Otherwise, */
/** continue executing the task. */
if (killed) {
/** Throw an exception rather than returning, because returning within a try{} block */
/** causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl */
/** exception will be caught by the catch block, leading to an incorrect ExceptionFailure */
/** for the task. */
throw new TaskKilledException
}
logDebug("Task " + taskId + "'s epoch is " + task.epoch)
env.mapOutputTracker.updateEpoch(task.epoch)
/** Run the actual task and measure its runtime. */
taskStart = System.currentTimeMillis()
taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
var threwException = true
/** The task's result value */
val value = try {
/** Run the task; Task.run is analyzed below */
val res = task.run(
taskAttemptId = taskId,
attemptNumber = attemptNumber,
metricsSystem = env.metricsSystem)
threwException = false
res
} finally {
/** Release the block locks and memory held by this task */
val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
/** If any managed memory had to be freed here, treat it as a leak and warn (or throw, depending on config) */
if (freedMemory > 0 && !threwException) {
val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
throw new SparkException(errMsg)
} else {
logWarning(errMsg)
}
}
/** Likewise, report any block locks the task failed to release */
if (releasedLocks.nonEmpty && !threwException) {
val errMsg =
s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" +
releasedLocks.mkString("[", ", ", "]")
if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false)) {
throw new SparkException(errMsg)
} else {
logWarning(errMsg)
}
}
}
val taskFinish = System.currentTimeMillis()
val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
/** If the task has been killed, let's fail it. */
if (task.killed) {
throw new TaskKilledException
}
/** Serialize the value returned by task.run */
val resultSer = env.serializer.newInstance()
val beforeSerialization = System.currentTimeMillis()
val valueBytes = resultSer.serialize(value)
val afterSerialization = System.currentTimeMillis()
/** Deserialization happens in two parts: first, we deserialize a Task object, which */
/** includes the Partition. Second, Task.run() deserializes the RDD and function to be run. */
/** Record the task's metrics: run time, (de)serialization time, GC time, and so on */
task.metrics.setExecutorDeserializeTime(
(taskStart - deserializeStartTime) + task.executorDeserializeTime)
task.metrics.setExecutorDeserializeCpuTime(
(taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime)
/** We need to subtract Task.run()'s deserialization time to avoid double-counting */
task.metrics.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
task.metrics.setExecutorCpuTime(
(taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime)
task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
task.metrics.setResultSerializationTime(afterSerialization - beforeSerialization)
/** Note: accumulator updates must be collected after TaskMetrics is updated */
val accumUpdates = task.collectAccumulatorUpdates()
/** TODO: do not serialize value twice */
/** DirectTaskResult wraps the result that will be sent back to the driver */
val directResult = new DirectTaskResult(valueBytes, accumUpdates)
val serializedDirectResult = ser.serialize(directResult)
val resultSize = serializedDirectResult.limit
/** directSend = sending directly back to the driver */
val serializedResult: ByteBuffer = {
if (maxResultSize > 0 && resultSize > maxResultSize) {
logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
s"dropping it.")
/** If the result exceeds maxResultSize, warn, drop the actual result, and send the driver only an IndirectTaskResult placeholder */
ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
} else if (resultSize > maxDirectResultSize) {
/** If the result exceeds maxDirectResultSize, store it through the BlockManager instead and */
/** return an IndirectTaskResult that carries the blockId for the driver to fetch */
val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
blockId,
new ChunkedByteBuffer(serializedDirectResult.duplicate()),
StorageLevel.MEMORY_AND_DISK_SER)
logInfo(
s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
} else {
logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
serializedDirectResult
}
}
/** The task has finished on the executor; ExecutorBackend sends the final status update to the SchedulerBackend. */
/** What the SchedulerBackend does when it receives it is analyzed later */
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
} catch {
case ffe: FetchFailedException =>
val reason = ffe.toTaskFailedReason
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
case _: TaskKilledException =>
logInfo(s"Executor killed $taskName (TID $taskId)")
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
case _: InterruptedException if task.killed =>
logInfo(s"Executor interrupted and killed $taskName (TID $taskId)")
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
case CausedBy(cDE: CommitDeniedException) =>
val reason = cDE.toTaskFailedReason
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
case t: Throwable =>
/** Attempt to exit cleanly by informing the driver of our failure. */
/** If anything goes wrong (or this was a fatal exception), we will delegate to*/
/** the default uncaught exception handler, which will terminate the Executor. */
logError(s"Exception in $taskName (TID $taskId)", t)
/** Collect latest accumulator values to report back to the driver */
val accums: Seq[AccumulatorV2[_, _]] =
if (task != null) {
task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart)
task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
task.collectAccumulatorUpdates(taskFailed = true)
} else {
Seq.empty
}
val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None))
val serializedTaskEndReason = {
try {
ser.serialize(new ExceptionFailure(t, accUpdates).withAccums(accums))
} catch {
case _: NotSerializableException =>
/** t is not serializable so just send the stacktrace */
ser.serialize(new ExceptionFailure(t, accUpdates, false).withAccums(accums))
}
}
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
/** Don't forcibly exit unless the exception was inherently fatal, to avoid */
/** stopping other tasks unnecessarily. */
if (Utils.isFatalError(t)) {
SparkUncaughtExceptionHandler.uncaughtException(t)
}
} finally {
runningTasks.remove(taskId)
}
}
}
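The three-way branch on result size near the end of run() is driven by two size limits; a hedged sketch of tuning them (property names and defaults as documented for Spark 2.x, so treat the values as assumptions):
import org.apache.spark.SparkConf

val conf = new SparkConf()
  // Per-task threshold above which the result is written to the BlockManager and only an
  // IndirectTaskResult (block id + size) travels over RPC; default is about 1 MiB.
  .set("spark.task.maxDirectResultSize", "4m")
  // Upper bound on the total result size a single action may bring back to the driver;
  // beyond this the result is dropped, as in the first branch above. Default is 1g.
  .set("spark.driver.maxResultSize", "2g")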
Task's run method
As noted in the previous section, inside TaskRunner's run method the task itself is actually executed by calling the run method of the Task object.
/** Called by [[org.apache.spark.executor.Executor]] to run this task. */
/** @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext. */
/** @param attemptNumber how many times this task has been attempted (0 for the first attempt) */
/** @return the result of the task along with updates of Accumulators. */
final def run(
taskAttemptId: Long,
attemptNumber: Int,
metricsSystem: MetricsSystem): T = {
SparkEnv.get.blockManager.registerTask(taskAttemptId)
context = new TaskContextImpl(
stageId,
partitionId,
taskAttemptId,
attemptNumber,
taskMemoryManager,
localProperties,
metricsSystem,
metrics)
/** Register the TaskContext so that runTask (and user code) can access it */
TaskContext.setTaskContext(context)
taskThread = Thread.currentThread()
if (_killed) {
kill(interruptThread = false)
}
new CallerContext("TASK", appId, appAttemptId, jobId, Option(stageId), Option(stageAttemptId),
Option(taskAttemptId), Option(attemptNumber)).setCurrentContext()
try {
/** Run the task; the two concrete task types, ShuffleMapTask and ResultTask, each implement runTask */
runTask(context)
} catch {
case e: Throwable =>
/** Catch all errors; run task failure callbacks, and rethrow the exception. */
try {
context.markTaskFailed(e)
} catch {
case t: Throwable =>
e.addSuppressed(t)
}
throw e
} finally {
/** Call the task completion callbacks. */
/** Fire the task completion callbacks; judging from the code, these listeners are registered while the RDD is being computed (and can also be added from user code) */
context.markTaskCompleted()
try {
Utils.tryLogNonFatalError {
/** Release memory used by this thread for unrolling blocks */
SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP)
/** Notify any tasks waiting for execution memory to be freed to wake up and try to */
/** acquire memory again. This makes impossible the scenario where a task sleeps forever */
/** because there are no other tasks left to notify it. Since this is safe to do but may */
/** not be strictly necessary, we should revisit whether we can remove this in the future. */
val memoryManager = SparkEnv.get.memoryManager
memoryManager.synchronized { memoryManager.notifyAll() }
}
} finally {
TaskContext.unset()
}
}
}
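The completion callbacks fired by markTaskCompleted() above can also be registered from user code via the TaskContext API; a minimal sketch (assuming the Spark 2.x TaskContext API and an already-created SparkContext named sc):
import org.apache.spark.TaskContext

val withHook = sc.parallelize(1 to 10, 2).mapPartitions { iter =>
  val ctx = TaskContext.get()
  ctx.addTaskCompletionListener { _ =>
    // Runs on the executor when markTaskCompleted() fires for this task.
    println(s"partition ${ctx.partitionId()} of stage ${ctx.stageId()} finished")
  }
  iter
}
withHook.count()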
/** ShuffleMapTask's implementation of runTask: no extra threads are involved here; it deserializes the RDD and */
/** shuffle dependency from the broadcast task binary and writes the partition out through a ShuffleWriter */
override def runTask(context: TaskContext): MapStatus = {
// Deserialize the RDD using the broadcast variable.
val threadMXBean = ManagementFactory.getThreadMXBean
val deserializeStartTime = System.currentTimeMillis()
val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
val ser = SparkEnv.get.closureSerializer.newInstance()
val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
_executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
_executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
} else 0L
var writer: ShuffleWriter[Any, Any] = null
try {
val manager = SparkEnv.get.shuffleManager
/** How shuffleManager builds a ShuffleWriter from rdd and dep is not analyzed here */
writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
writer.stop(success = true).get
} catch {
case e: Exception =>
try {
if (writer != null) {
writer.stop(success = false)
}
} catch {
case e: Exception =>
log.debug("Could not stop writer", e)
}
throw e
}
}
/** ResultTask's implementation of runTask */
override def runTask(context: TaskContext): U = {
/** Deserialize the RDD and the func using the broadcast variables. */
val threadMXBean = ManagementFactory.getThreadMXBean
val deserializeStartTime = System.currentTimeMillis()
val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
val ser = SparkEnv.get.closureSerializer.newInstance()
/** Deserialize the rdd and func from the broadcast variable taskBinary */
val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
_executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
_executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
} else 0L
/** Run the user code. This is where the task is truly executed; everything else was just Executor.launchTask */
/** running this TaskRunner on its thread pool */
func(context, rdd.iterator(partition, context))
}
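To tie this back to user code: for a simple action like the one below, the func deserialized above is essentially the per-partition closure the action installs (for count, consuming the iterator and returning its size), and the call func(context, rdd.iterator(partition, context)) is where it finally runs on the executor:
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setAppName("result-task-demo").setMaster("local[2]"))

// Four partitions -> four ResultTasks; each returns its partition's element count,
// and the driver sums the four partial counts.
val n = sc.parallelize(1 to 100, 4).map(_ * 2).count()
assert(n == 100L)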
The remaining section fills in a flow we skipped earlier.
Message passing between ExecutorBackend and SchedulerBackend
From the analysis so far we can tell that the ExecutorBackend lives on the compute node and is responsible for launching tasks, while the SchedulerBackend handles the driver-side work related to stages and tasks. The executor startup process will be covered in a later post; here we only look at the message passing.
When an executor starts, the ExecutorBackend sends a RegisterExecutor message to the driver from its onStart method. On receiving it, the SchedulerBackend first sends a RegisteredExecutor message back to the ExecutorBackend (which then creates the Executor), then posts a SparkListenerExecutorAdded event to the listenerBus, and finally calls makeOffers (inside which, as long as there are active executors, TaskSchedulerImpl's resourceOffers sets hasLaunchedTask to true). makeOffers is also where launchTasks happens.
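The SparkListenerExecutorAdded event posted to listenerBus in the handler below can be observed from application code with a SparkListener; a minimal sketch (assuming the Spark 2.x listener API and an existing SparkContext named sc):
import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded}

sc.addSparkListener(new SparkListener {
  override def onExecutorAdded(event: SparkListenerExecutorAdded): Unit = {
    // Posted by the RegisterExecutor handler shown below, just before makeOffers() is called.
    println(s"executor ${event.executorId} added on ${event.executorInfo.executorHost} " +
      s"with ${event.executorInfo.totalCores} cores")
  }
})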
/** CoarseGrainedSchedulerBackend's receive-and-reply handler */
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
/** Handles the RegisterExecutor message sent from the ExecutorBackend's onStart method */
case RegisterExecutor(executorId, executorRef, hostname, cores, logUrls) =>
if (executorDataMap.contains(executorId)) {
executorRef.send(RegisterExecutorFailed("Duplicate executor ID: " + executorId))
context.reply(true)
} else {
/** If the executor's rpc env is not listening for incoming connections, `hostPort` */
/** will be null, and the client connection should be used to contact the executor. */
val executorAddress = if (executorRef.address != null) {
executorRef.address
} else {
context.senderAddress
}
logInfo(s"Registered executor $executorRef ($executorAddress) with ID $executorId")
addressToExecutorId(executorAddress) = executorId
totalCoreCount.addAndGet(cores)
totalRegisteredExecutors.addAndGet(1)
val data = new ExecutorData(executorRef, executorRef.address, hostname,
cores, cores, logUrls)
/** This must be synchronized because variables mutated */
/** in this block are read when requesting executors */
CoarseGrainedSchedulerBackend.this.synchronized {
executorDataMap.put(executorId, data)
if (currentExecutorIdCounter < executorId.toInt) {
currentExecutorIdCounter = executorId.toInt
}
if (numPendingExecutors > 0) {
numPendingExecutors -= 1
logDebug(s"Decremented number of pending executors ($numPendingExecutors left)")
}
}
/** Send RegisteredExecutor back to the ExecutorBackend, which then creates the Executor */
executorRef.send(RegisteredExecutor)
/** Note: some tests expect the reply to come after we put the executor in the map */
context.reply(true)
listenerBus.post(
SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data))
/** LaunchTask is triggered from here */
makeOffers()
}
case StopDriver =>
context.reply(true)
stop()
case StopExecutors =>
logInfo("Asking each executor to shut down")
for ((_, executorData) <- executorDataMap) {
executorData.executorEndpoint.send(StopExecutor)
}
context.reply(true)
case RemoveExecutor(executorId, reason) =>
/** We will remove the executor's state and cannot restore it. However, the connection */
/** between the driver and the executor may be still alive so that the executor won't exit */
/** automatically, so try to tell the executor to stop itself. See SPARK-13519. */
executorDataMap.get(executorId).foreach(_.executorEndpoint.send(StopExecutor))
removeExecutor(executorId, reason)
context.reply(true)
case RetrieveSparkAppConfig =>
val reply = SparkAppConfig(sparkProperties,
SparkEnv.get.securityManager.getIOEncryptionKey())
context.reply(reply)
}
/** CoarseGrainedSchedulerBackend's receive handler */
override def receive: PartialFunction[Any, Unit] = {
/** Handles the StatusUpdate message sent by the ExecutorBackend's statusUpdate method */
case StatusUpdate(executorId, taskId, state, data) =>
scheduler.statusUpdate(taskId, state, data.value)
if (TaskState.isFinished(state)) {
executorDataMap.get(executorId) match {
case Some(executorInfo) =>
executorInfo.freeCores += scheduler.CPUS_PER_TASK
makeOffers(executorId)
case None =>
/** Ignoring the update since we don't know about the executor. */
logWarning(s"Ignored task status update ($taskId state $state) " +
s"from unknown executor with ID $executorId")
}
}
/** TaskSchedulerImpl's submitTasks calls the SchedulerBackend's reviveOffers, which sends a */
/** ReviveOffers message to this endpoint; handling it here triggers makeOffers */
case ReviveOffers =>
makeOffers()
case KillTask(taskId, executorId, interruptThread) =>
executorDataMap.get(executorId) match {
case Some(executorInfo) =>
executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread))
case None =>
/** Ignoring the task kill since the executor is not registered. */
logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.")
}
}
/** CoarseGrainedExecutorBackend's receive handler */
override def receive: PartialFunction[Any, Unit] = {
/** On the RegisteredExecutor message from the SchedulerBackend, create the Executor */
case RegisteredExecutor =>
logInfo("Successfully registered with driver")
try {
executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)
} catch {
case NonFatal(e) =>
exitExecutor(1, "Unable to create executor due to " + e.getMessage, e)
}
case RegisterExecutorFailed(message) =>
exitExecutor(1, "Slave registration failed: " + message)
/** On the LaunchTask message from the SchedulerBackend, launch the task */
case LaunchTask(data) =>
if (executor == null) {
exitExecutor(1, "Received LaunchTask command but executor was null")
} else {
val taskDesc = ser.deserialize[TaskDescription](data.value)
logInfo("Got assigned task " + taskDesc.taskId)
executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
taskDesc.name, taskDesc.serializedTask)
}
case KillTask(taskId, _, interruptThread) =>
if (executor == null) {
exitExecutor(1, "Received KillTask command but executor was null")
} else {
executor.killTask(taskId, interruptThread)
}
case StopExecutor =>
stopping.set(true)
logInfo("Driver commanded a shutdown")
/** Cannot shutdown here because an ack may need to be sent back to the caller. So send */
/** a message to self to actually do the shutdown. */
self.send(Shutdown)
case Shutdown =>
stopping.set(true)
new Thread("CoarseGrainedExecutorBackend-stop-executor") {
override def run(): Unit = {
/** executor.stop() will call `SparkEnv.stop()` which waits until RpcEnv stops totally. */
/** However, if `executor.stop()` runs in some thread of RpcEnv, RpcEnv won't be able to */
/** stop until `executor.stop()` returns, which becomes a dead-lock (See SPARK-14180). */
/** Therefore, we put this line in a new thread. */
executor.stop()
}
}.start()
}