A Brief Look at Apache Spark 1.0.0, Part 6: Resource Scheduling - Task Execution

As described in the previous part, the driver sends a LaunchTask(task) message to executorActor(task.executorId). On the executor side, CoarseGrainedExecutorBackend defines receive to handle the LaunchTask message and calls executor.launchTask:

override def receive = {
    case RegisteredExecutor(sparkProperties) =>
      logInfo("Successfully registered with driver")
      // Make this host instead of hostPort ?
      executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties,
        false)

    case RegisterExecutorFailed(message) =>
      logError("Slave registration failed: " + message)
      System.exit(1)

    case LaunchTask(taskDesc) =>
      logInfo("Got assigned task " + taskDesc.taskId)
      if (executor == null) {
        logError("Received LaunchTask command but executor was null")
        System.exit(1)
      } else {
        executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
      }

    case KillTask(taskId, _, interruptThread) =>
      if (executor == null) {
        logError("Received KillTask command but executor was null")
        System.exit(1)
      } else {
        executor.killTask(taskId, interruptThread)
      }

    case x: DisassociatedEvent =>
      logError(s"Driver $x disassociated! Shutting down.")
      System.exit(1)

    case StopExecutor =>
      logInfo("Driver commanded a shutdown")
      context.stop(self)
      context.system.shutdown()
  }
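
For context, the sending side of this exchange is CoarseGrainedSchedulerBackend's DriverActor, covered in the previous part. A paraphrased sketch of its launchTasks (names as in Spark 1.0.0; treat this as a sketch rather than the exact source):

// Driver side: for every scheduled TaskDescription, update the free-core
// bookkeeping for the chosen executor and send it a LaunchTask message.
def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
  for (task <- tasks.flatten) {
    freeCores(task.executorId) -= scheduler.CPUS_PER_TASK
    executorActor(task.executorId) ! LaunchTask(task)
  }
}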

launchTask first instantiates a TaskRunner. Because TaskRunner extends Runnable, once it is submitted to the thread pool threadPool, its run() method is executed automatically on a separate worker thread:

def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
    val tr = new TaskRunner(context, taskId, serializedTask)
    runningTasks.put(taskId, tr)
    threadPool.execute(tr)
  }
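
The thread pool used here is a cached daemon thread pool, so the mechanics are the standard java.util.concurrent ones. A minimal, self-contained sketch of the same Runnable-plus-pool pattern (hypothetical names, not Spark's code):

import java.util.concurrent.{Executors, ThreadFactory}

object RunnablePoolSketch {
  // Daemon threads, so the pool alone does not keep the JVM alive.
  private val factory = new ThreadFactory {
    def newThread(r: Runnable): Thread = {
      val t = new Thread(r)
      t.setDaemon(true)
      t
    }
  }
  private val pool = Executors.newCachedThreadPool(factory)

  // Like TaskRunner, this only needs to implement Runnable.run();
  // execute() schedules it and run() fires on a pool thread.
  class DemoTask(id: Long) extends Runnable {
    override def run(): Unit =
      println(s"task $id running on ${Thread.currentThread().getName}")
  }

  def main(args: Array[String]): Unit = {
    (1L to 3L).foreach(id => pool.execute(new DemoTask(id)))
    Thread.sleep(500) // give the daemon threads time to finish before the JVM exits
  }
}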

TaskRunner.run first deserializes the Task and then calls task.run. Along the way it records a number of metrics, and any Accumulator updates made during execution are collected and sent back with the result:

override def run() {
      val startTime = System.currentTimeMillis()
      SparkEnv.set(env)
      Thread.currentThread.setContextClassLoader(replClassLoader)
      val ser = SparkEnv.get.closureSerializer.newInstance()
      logInfo("Running task ID " + taskId)
      execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
      var attemptedTask: Option[Task[Any]] = None
      var taskStart: Long = 0
      def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
      val startGCTime = gcTime

      try {
        SparkEnv.set(env)
        Accumulators.clear()
        val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
        updateDependencies(taskFiles, taskJars)
        task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)

        // If this task has been killed before we deserialized it, let's quit now. Otherwise,
        // continue executing the task.
        if (killed) {
          // Throw an exception rather than returning, because returning within a try{} block
          // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
          // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
          // for the task.
          throw new TaskKilledException
        }

        attemptedTask = Some(task)
        logDebug("Task " + taskId + "'s epoch is " + task.epoch)
        env.mapOutputTracker.updateEpoch(task.epoch)

        // Run the actual task and measure its runtime.
        taskStart = System.currentTimeMillis()
        val value = task.run(taskId.toInt)
        val taskFinish = System.currentTimeMillis()

        // If the task has been killed, let's fail it.
        if (task.killed) {
          throw new TaskKilledException
        }

        val resultSer = SparkEnv.get.serializer.newInstance()
        val beforeSerialization = System.currentTimeMillis()
        val valueBytes = resultSer.serialize(value)
        val afterSerialization = System.currentTimeMillis()

        for (m <- task.metrics) {
          m.hostname = Utils.localHostName()
          m.executorDeserializeTime = taskStart - startTime
          m.executorRunTime = taskFinish - taskStart
          m.jvmGCTime = gcTime - startGCTime
          m.resultSerializationTime = afterSerialization - beforeSerialization
        }

        val accumUpdates = Accumulators.values

        val directResult = new DirectTaskResult(valueBytes, accumUpdates,
          task.metrics.getOrElse(null))
        val serializedDirectResult = ser.serialize(directResult)
        logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit)
        val serializedResult = {
          if (serializedDirectResult.limit >= akkaFrameSize - 1024) {
            logInfo("Storing result for " + taskId + " in local BlockManager")
            val blockId = TaskResultBlockId(taskId)
            env.blockManager.putBytes(
              blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
            ser.serialize(new IndirectTaskResult[Any](blockId))
          } else {
            logInfo("Sending result for " + taskId + " directly to driver")
            serializedDirectResult
          }
        }

        execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
        logInfo("Finished task ID " + taskId)
      } catch {
        case ffe: FetchFailedException => {
          val reason = ffe.toTaskEndReason
          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
        }

        case _: TaskKilledException | _: InterruptedException if task.killed => {
          logInfo("Executor killed task " + taskId)
          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
        }

        case t: Throwable => {
          // Attempt to exit cleanly by informing the driver of our failure.
          // If anything goes wrong (or this was a fatal exception), we will delegate to
          // the default uncaught exception handler, which will terminate the Executor.
          logError("Exception in task ID " + taskId, t)

          val serviceTime = System.currentTimeMillis() - taskStart
          val metrics = attemptedTask.flatMap(t => t.metrics)
          for (m <- metrics) {
            m.executorRunTime = serviceTime
            m.jvmGCTime = gcTime - startGCTime
          }
          val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics)
          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

          // Don't forcibly exit unless the exception was inherently fatal, to avoid
          // stopping other tasks unnecessarily.
          if (Utils.isFatalError(t)) {
            ExecutorUncaughtExceptionHandler.uncaughtException(t)
          }
        }
      } finally {
        // TODO: Unregister shuffle memory only for ResultTask
        val shuffleMemoryMap = env.shuffleMemoryMap
        shuffleMemoryMap.synchronized {
          shuffleMemoryMap.remove(Thread.currentThread().getId)
        }
        runningTasks.remove(taskId)
      }
    }
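
To illustrate the accumulator path mentioned above: updates made inside tasks are gathered via Accumulators.values and shipped back to the driver in the DirectTaskResult, where they are merged. From the user's side it looks like this (a minimal example, assuming a SparkContext named sc):

// Driver: create an accumulator, let tasks add to it, read the merged value back.
val errorCount = sc.accumulator(0)
sc.textFile("README.md").foreach { line =>
  if (line.contains("ERROR")) errorCount += 1   // happens inside the task on the executor
}
println(errorCount.value)                        // merged result, readable only on the driver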

Next, look at updateDependencies. Its job is to resolve external dependencies by downloading the files and JARs the task depends on. During Spark initialization, SparkEnv sets up the root directory for these files: in local mode it creates a temporary folder, while in distributed mode it uses the executor's current working directory.

/**
   * Download any missing dependencies if we receive a new set of files and JARs from the
   * SparkContext. Also adds any new JARs we fetched to the class loader.
   */
  private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
    synchronized {
      // Fetch missing dependencies
      for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
        logInfo("Fetching " + name + " with timestamp " + timestamp)
        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager)
        currentFiles(name) = timestamp
      }
      for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
        logInfo("Fetching " + name + " with timestamp " + timestamp)
        Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager)
        currentJars(name) = timestamp
        // Add it to our class loader
        val localName = name.split("/").last
        val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
        if (!urlClassLoader.getURLs.contains(url)) {
          logInfo("Adding " + url + " to class loader")
          urlClassLoader.addURL(url)
        }
      }
    }
  }
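
The newFiles and newJars maps originate on the driver from SparkContext.addFile and SparkContext.addJar; inside a task the downloaded copy is then resolved through SparkFiles. A small usage example (assuming a SparkContext named sc and an existing local file path):

import org.apache.spark.SparkFiles

sc.addFile("/path/to/lookup.txt")   // driver: register a file; executors fetch it on demand
sc.addJar("/path/to/extra.jar")     // driver: extra JAR, added to the executor class loader

val firstLines = sc.parallelize(1 to 2).map { _ =>
  // executor: resolve the local copy under SparkFiles.getRootDirectory
  scala.io.Source.fromFile(SparkFiles.get("lookup.txt")).getLines().next()
}.collect()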

Finally, look at task.run. It is declared final so it cannot be overridden; it instantiates a TaskContext and then calls runTask with that context object:

final def run(attemptId: Long): T = {
    context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
    taskThread = Thread.currentThread()
    if (_killed) {
      kill(interruptThread = false)
    }
    runTask(context)
  }
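
For reference, the contract both task types fulfil is the abstract Task class; a trimmed-down sketch of its shape in Spark 1.0.0 (several members omitted):

private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int)
  extends Serializable {

  // Implemented by ShuffleMapTask and ResultTask.
  def runTask(context: TaskContext): T

  def preferredLocations: Seq[TaskLocation] = Nil

  var epoch: Long = -1                     // map output tracker epoch, set by the TaskSetManager
  var metrics: Option[TaskMetrics] = None  // filled in by the runTask implementations
}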

As mentioned in the earlier post on task creation and distribution, the tasks created fall into two kinds, ShuffleMapTask and ResultTask, both of which extend the abstract class Task. Let's now look at the runTask implementations of ShuffleMapTask and ResultTask.

ResultTask.runTask is fairly simple: it records the task metrics obtained from the TaskContext, applies func (the job.func passed in when the ResultTask was instantiated) to the iterator of its partition, and finally invokes executeOnCompleteCallbacks, completing the task:

override def runTask(context: TaskContext): U = {
    metrics = Some(context.taskMetrics)
    try {
      func(context, rdd.iterator(split, context))
    } finally {
      context.executeOnCompleteCallbacks()
    }
  }
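
Concretely, func is the per-partition function installed by the action on the driver; for rdd.foreach(f), for instance, it is essentially a wrapper that applies f to every element of the partition's iterator. A tiny illustration (assuming a SparkContext named sc):

// The closure passed to foreach ends up, wrapped by the action, as the `func`
// that each ResultTask applies to its partition's iterator.
sc.parallelize(Seq("a", "b", "c"), 2).foreach(println)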

RDD.iterator is defined as follows. The final modifier means it cannot be overridden. iterator either reads the RDD partition from the cache or computes it: it first checks storageLevel, and if a storage level has been set (i.e. it is not StorageLevel.NONE) it calls CacheManager.getOrCompute to read the cached partition or compute and cache it; otherwise it calls computeOrReadCheckpoint to compute the partition or read it from a checkpoint:

/**
   * Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
   * This should ''not'' be called by users directly, but is available for implementors of custom
   * subclasses of RDD.
   */
  final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
    if (storageLevel != StorageLevel.NONE) {
      SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
    } else {
      computeOrReadCheckpoint(split, context)
    }
  }
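
Whether iterator takes the cache path therefore depends only on whether persist or cache has set a storage level on the RDD. For example (assuming a SparkContext named sc):

import org.apache.spark.storage.StorageLevel

val words = sc.textFile("README.md").flatMap(_.split(" "))

words.count()                            // storageLevel == NONE: computeOrReadCheckpoint
words.persist(StorageLevel.MEMORY_ONLY)  // sets the RDD's storage level
words.count()                            // now goes through CacheManager.getOrCompute
words.count()                            // this time the partition is served from the cache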

Note that getOrCompute also ultimately obtains the values by calling rdd.computeOrReadCheckpoint (the computedValues below):

def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext,
      storageLevel: StorageLevel): Iterator[T] = {
    val key = RDDBlockId(rdd.id, split.index)
    logDebug("Looking for partition " + key)
    blockManager.get(key) match {
      case Some(values) =>
        // Partition is already materialized, so just return its values
        new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])

      case None =>
        // Mark the split as loading (unless someone else marks it first)
        loading.synchronized {
          if (loading.contains(key)) {
            logInfo("Another thread is loading %s, waiting for it to finish...".format(key))
            while (loading.contains(key)) {
              try {
                loading.wait()
              } catch {
                case e: Exception =>
                  logWarning(s"Got an exception while waiting for another thread to load $key", e)
              }
            }
            logInfo("Finished waiting for %s".format(key))
            /* See whether someone else has successfully loaded it. The main way this would fail
             * is for the RDD-level cache eviction policy if someone else has loaded the same RDD
             * partition but we didn't want to make space for it. However, that case is unlikely
             * because it's unlikely that two threads would work on the same RDD partition. One
             * downside of the current code is that threads wait serially if this does happen. */
            blockManager.get(key) match {
              case Some(values) =>
                return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
              case None =>
                logInfo("Whoever was loading %s failed; we'll try it ourselves".format(key))
                loading.add(key)
            }
          } else {
            loading.add(key)
          }
        }
        try {
          // If we got here, we have to load the split
          logInfo("Partition %s not found, computing it".format(key))
          val computedValues = rdd.computeOrReadCheckpoint(split, context)

          // Persist the result, so long as the task is not running locally
          if (context.runningLocally) {
            return computedValues
          }

          // Keep track of blocks with updated statuses
          var updatedBlocks = Seq[(BlockId, BlockStatus)]()
          val returnValue: Iterator[T] = {
            if (storageLevel.useDisk && !storageLevel.useMemory) {
              /* In the case that this RDD is to be persisted using DISK_ONLY
               * the iterator will be passed directly to the blockManager (rather then
               * caching it to an ArrayBuffer first), then the resulting block data iterator
               * will be passed back to the user. If the iterator generates a lot of data,
               * this means that it doesn't all have to be held in memory at one time.
               * This could also apply to MEMORY_ONLY_SER storage, but we need to make sure
               * blocks aren't dropped by the block store before enabling that. */
              updatedBlocks = blockManager.put(key, computedValues, storageLevel, tellMaster = true)
              blockManager.get(key) match {
                case Some(values) =>
                  values.asInstanceOf[Iterator[T]]
                case None =>
                  logInfo("Failure to store %s".format(key))
                  throw new Exception("Block manager failed to return persisted valued")
              }
            } else {
              // In this case the RDD is cached to an array buffer. This will save the results
              // if we're dealing with a 'one-time' iterator
              val elements = new ArrayBuffer[Any]
              elements ++= computedValues
              updatedBlocks = blockManager.put(key, elements, storageLevel, tellMaster = true)
              elements.iterator.asInstanceOf[Iterator[T]]
            }
          }

          // Update task metrics to include any blocks whose storage status is updated
          val metrics = context.taskMetrics
          metrics.updatedBlocks = Some(updatedBlocks)

          new InterruptibleIterator(context, returnValue)

        } finally {
          loading.synchronized {
            loading.remove(key)
            loading.notifyAll()
          }
        }
    }
  }

In computeOrReadCheckpoint, if the RDD has been checkpointed it returns firstParent.iterator (checkpointing erases all previous parent information, so firstParent is the RDD as materialized at checkpoint time); otherwise it calls compute:

/**
   * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing.
   */
  private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
  {
    if (isCheckpointed) firstParent[T].iterator(split, context) else compute(split, context)
  }
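
Checkpointing is driven from the user side; once the checkpoint has been written, the lineage above it is truncated and iterator reads from the checkpointed parent. For example (assuming a SparkContext named sc and a writable checkpoint directory):

import org.apache.spark.SparkContext._          // pair RDD functions such as reduceByKey

sc.setCheckpointDir("/tmp/spark-checkpoints")   // assumption: any writable directory

val counts = sc.textFile("README.md")
  .flatMap(_.split(" "))
  .map((_, 1))
  .reduceByKey(_ + _)

counts.checkpoint()   // mark for checkpointing; the data is written at the end of the next job
counts.count()        // first job: computes the RDD and materializes the checkpoint
counts.count()        // later jobs: computeOrReadCheckpoint reads from the checkpoint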

StorageLevel is defined in StorageLevel.scala: the class takes four boolean flags plus a replication count, and the companion object defines the commonly used levels:

class StorageLevel private(
    private var useDisk_ : Boolean,
    private var useMemory_ : Boolean,
    private var useOffHeap_ : Boolean,
    private var deserialized_ : Boolean,
    private var replication_ : Int = 1)
  extends Externalizable

object StorageLevel {
  val NONE = new StorageLevel(false, false, false, false)
  val DISK_ONLY = new StorageLevel(true, false, false, false)
  val DISK_ONLY_2 = new StorageLevel(true, false, false, false, 2)
  val MEMORY_ONLY = new StorageLevel(false, true, false, true)
  val MEMORY_ONLY_2 = new StorageLevel(false, true, false, true, 2)
  val MEMORY_ONLY_SER = new StorageLevel(false, true, false, false)
  val MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, false, 2)
  val MEMORY_AND_DISK = new StorageLevel(true, true, false, true)
  val MEMORY_AND_DISK_2 = new StorageLevel(true, true, false, true, 2)
  val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, false)
  val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, false, 2)
  val OFF_HEAP = new StorageLevel(false, false, true, false)
}
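
The four boolean flags plus the replication count map directly onto the familiar persistence options; MEMORY_AND_DISK_SER, for example, is useDisk = true, useMemory = true, useOffHeap = false, deserialized = false with the default replication of 1. Usage is simply (assuming a SparkContext named sc):

import org.apache.spark.storage.StorageLevel

val pairs = sc.textFile("README.md").map(line => (line.length, line))
pairs.persist(StorageLevel.MEMORY_AND_DISK_SER)  // serialized in memory, spilling to disk
// pairs.cache() is shorthand for persist(StorageLevel.MEMORY_ONLY)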

Finally, RDD.compute computes the result for a given partition of the RDD; this method is implemented by each concrete RDD subclass:

/**
   * :: DeveloperApi ::
   * Implemented by subclasses to compute a given partition.
   */
  @DeveloperApi
  def compute(split: Partition, context: TaskContext): Iterator[T]
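
As an example of a concrete implementation, the RDD produced by map simply maps its parent partition's iterator; a paraphrased sketch of Spark 1.0.0's MappedRDD:

private[spark] class MappedRDD[U: ClassTag, T: ClassTag](prev: RDD[T], f: T => U)
  extends RDD[U](prev) {

  override def getPartitions: Array[Partition] = firstParent[T].partitions

  // compute: lazily apply f to the parent partition's iterator
  override def compute(split: Partition, context: TaskContext) =
    firstParent[T].iterator(split, context).map(f)
}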

Back to ShuffleMapTask.runTask. The main differences from ResultTask.runTask are that it first obtains the number of output partitions and asks the shuffle block manager for a ShuffleWriterGroup, which holds the block writers for all of this task's shuffle blocks; it then calls rdd.iterator on its split with the TaskContext and, for each element, uses the partitioner to pick a bucket and writes the pair into it via shuffle.writers(bucketId):

override def runTask(context: TaskContext): MapStatus = {
    val numOutputSplits = dep.partitioner.numPartitions
    metrics = Some(context.taskMetrics)

    val blockManager = SparkEnv.get.blockManager
    val shuffleBlockManager = blockManager.shuffleBlockManager
    var shuffle: ShuffleWriterGroup = null
    var success = false

    try {
      // Obtain all the block writers for shuffle blocks.
      val ser = Serializer.getSerializer(dep.serializer)
      shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)

      // Write the map output to its associated buckets.
      for (elem <- rdd.iterator(split, context)) {
        val pair = elem.asInstanceOf[Product2[Any, Any]]
        val bucketId = dep.partitioner.getPartition(pair._1)
        shuffle.writers(bucketId).write(pair)
      }

      // Commit the writes. Get the size of each bucket block (total block size).
      var totalBytes = 0L
      var totalTime = 0L
      val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
        writer.commit()
        writer.close()
        val size = writer.fileSegment().length
        totalBytes += size
        totalTime += writer.timeWriting()
        MapOutputTracker.compressSize(size)
      }

      // Update shuffle metrics.
      val shuffleMetrics = new ShuffleWriteMetrics
      shuffleMetrics.shuffleBytesWritten = totalBytes
      shuffleMetrics.shuffleWriteTime = totalTime
      metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)

      success = true
      new MapStatus(blockManager.blockManagerId, compressedSizes)
    } catch { case e: Exception =>
      // If there is an exception from running the task, revert the partial writes
      // and throw the exception upstream to Spark.
      if (shuffle != null && shuffle.writers != null) {
        for (writer <- shuffle.writers) {
          writer.revertPartialWrites()
          writer.close()
        }
      }
      throw e
    } finally {
      // Release the writers back to the shuffle block manager.
      if (shuffle != null && shuffle.writers != null) {
        try {
          shuffle.releaseWriters(success)
        } catch {
          case e: Exception => logError("Failed to release shuffle writers", e)
        }
      }
      // Execute the callbacks on task completion.
      context.executeOnCompleteCallbacks()
    }
  }
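
The bucketId above comes from the dependency's partitioner; for reduceByKey this is normally a HashPartitioner, which in essence does the following (a paraphrased sketch that inlines the non-negative-modulo helper):

import org.apache.spark.Partitioner

class SketchHashPartitioner(partitions: Int) extends Partitioner {
  def numPartitions: Int = partitions

  def getPartition(key: Any): Int = key match {
    case null => 0
    case _ =>
      // non-negative modulo of the key's hashCode, i.e. the target bucket/reducer
      val mod = key.hashCode % numPartitions
      if (mod < 0) mod + numPartitions else mod
  }
}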

With that, task execution is complete.

Appendix: the console output of the scheduling and execution process is shown below. It clearly shows that the job is split at reduceByKey into two stages: when foreach is submitted, the scheduler finds that the final stage depends on a parent stage, so it first submits that parent stage and runs its task, and then submits the second stage and runs the corresponding task. The lineage of the RDDs involved can be seen in the toDebugString output.

15/07/24 12:49:45 INFO spark.SparkContext: Starting job: foreach at LocalWordCount.scala:15
15/07/24 12:49:45 INFO scheduler.DAGScheduler: Registering RDD 4 (reduceByKey at LocalWordCount.scala:13)
15/07/24 12:49:45 INFO scheduler.DAGScheduler: Got job 0 (foreach at LocalWordCount.scala:15) with 1 output partitions (allowLocal=false)
15/07/24 12:49:45 INFO scheduler.DAGScheduler: Final stage: Stage 0(foreach at LocalWordCount.scala:15)
15/07/24 12:49:45 INFO scheduler.DAGScheduler: Parents of final stage: List(Stage 1)
15/07/24 12:49:45 INFO scheduler.DAGScheduler: Missing parents: List(Stage 1)
15/07/24 12:49:45 INFO scheduler.DAGScheduler: Submitting Stage 1 (MapPartitionsRDD[4] at reduceByKey at LocalWordCount.scala:13), which has no missing parents
15/07/24 12:49:45 INFO scheduler.DAGScheduler: Submitting 1 missing tasks from Stage 1 (MapPartitionsRDD[4] at reduceByKey at LocalWordCount.scala:13)
15/07/24 12:49:45 INFO scheduler.TaskSchedulerImpl: Adding task set 1.0 with 1 tasks
15/07/24 12:49:45 INFO scheduler.TaskSetManager: Starting task 1.0:0 as TID 0 on executor localhost: localhost (PROCESS_LOCAL)
15/07/24 12:49:45 INFO scheduler.TaskSetManager: Serialized task 1.0:0 as 2067 bytes in 3 ms
15/07/24 12:49:45 INFO executor.Executor: Running task ID 0
15/07/24 12:49:45 INFO storage.BlockManager: Found block broadcast_0 locally
15/07/24 12:49:45 INFO rdd.HadoopRDD: Input split: file:/D:/IdeaProjects/spark-1.0.0/README.md:0+4221
15/07/24 12:49:46 INFO executor.Executor: Serialized size of result for 0 is 784
15/07/24 12:49:46 INFO executor.Executor: Sending result for 0 directly to driver
15/07/24 12:49:46 INFO executor.Executor: Finished task ID 0
15/07/24 12:49:46 INFO scheduler.TaskSetManager: Finished TID 0 in 300 ms on localhost (progress: 1/1)
15/07/24 12:49:46 INFO scheduler.DAGScheduler: Completed ShuffleMapTask(1, 0)
15/07/24 12:49:46 INFO scheduler.TaskSchedulerImpl: Removed TaskSet 1.0, whose tasks have all completed, from pool 
15/07/24 12:49:46 INFO scheduler.DAGScheduler: Stage 1 (reduceByKey at LocalWordCount.scala:13) finished in 0.317 s
15/07/24 12:49:46 INFO scheduler.DAGScheduler: looking for newly runnable stages
15/07/24 12:49:46 INFO scheduler.DAGScheduler: running: Set()
15/07/24 12:49:46 INFO scheduler.DAGScheduler: waiting: Set(Stage 0)
15/07/24 12:49:46 INFO scheduler.DAGScheduler: failed: Set()
15/07/24 12:49:46 INFO scheduler.DAGScheduler: Missing parents for Stage 0: List()
15/07/24 12:49:46 INFO scheduler.DAGScheduler: Submitting Stage 0 (MapPartitionsRDD[6] at reduceByKey at LocalWordCount.scala:13), which is now runnable
15/07/24 12:49:46 INFO scheduler.DAGScheduler: Submitting 1 missing tasks from Stage 0 (MapPartitionsRDD[6] at reduceByKey at LocalWordCount.scala:13)
15/07/24 12:49:46 INFO scheduler.TaskSchedulerImpl: Adding task set 0.0 with 1 tasks
15/07/24 12:49:46 INFO scheduler.TaskSetManager: Starting task 0.0:0 as TID 1 on executor localhost: localhost (PROCESS_LOCAL)
15/07/24 12:49:46 INFO scheduler.TaskSetManager: Serialized task 0.0:0 as 1943 bytes in 0 ms
15/07/24 12:49:46 INFO executor.Executor: Running task ID 1
15/07/24 12:49:46 INFO storage.BlockManager: Found block broadcast_0 locally
15/07/24 12:49:46 INFO storage.BlockFetcherIterator$BasicBlockFetcherIterator: maxBytesInFlight: 50331648, targetRequestSize: 10066329
15/07/24 12:49:46 INFO storage.BlockFetcherIterator$BasicBlockFetcherIterator: Getting 1 non-empty blocks out of 1 blocks
15/07/24 12:49:46 INFO storage.BlockFetcherIterator$BasicBlockFetcherIterator: Started 0 remote fetches in 4 ms

The lineage of the counts RDD is as follows:

MapPartitionsRDD[6] at reduceByKey at LocalWordCount.scala:13 (1 partitions)
  ShuffledRDD[5] at reduceByKey at LocalWordCount.scala:13 (1 partitions)
    MapPartitionsRDD[4] at reduceByKey at LocalWordCount.scala:13 (1 partitions)
      MappedRDD[3] at map at LocalWordCount.scala:12 (1 partitions)
        FlatMappedRDD[2] at flatMap at LocalWordCount.scala:12 (1 partitions)
          MappedRDD[1] at textFile at LocalWordCount.scala:11 (1 partitions)
            HadoopRDD[0] at textFile at LocalWordCount.scala:11 (1 partitions)
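
For reference, a minimal word-count driver consistent with the line numbers and lineage above might look like the following (a reconstruction; the original LocalWordCount source is not reproduced in this post):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._               // pair RDD functions such as reduceByKey

object LocalWordCount {
  def main(args: Array[String]) {
    val conf = new SparkConf().setMaster("local").setAppName("LocalWordCount")
    val sc = new SparkContext(conf)

    val lines = sc.textFile("README.md")                 // HadoopRDD -> MappedRDD
    val words = lines.flatMap(_.split(" ")).map((_, 1))  // FlatMappedRDD -> MappedRDD
    val counts = words.reduceByKey(_ + _)                // MapPartitionsRDD -> ShuffledRDD -> MapPartitionsRDD

    counts.foreach(println)                              // action: triggers the two-stage job
    println(counts.toDebugString)                        // prints the lineage shown above

    sc.stop()
  }
}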

END