def boost()

in mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala [290:537]


  /**
   * Trains a gradient-boosted ensemble of regression trees, optionally using a validation
   * set to stop early.
   *
   * The first tree is fit on the original labels and gets weight 1.0. Every later tree is fit
   * on the negative gradient of `loss` at the current predictions and gets the learning rate
   * as its weight. Training stops early when either
   *  - `validate` is true and the validation error does not improve on the best error so far
   *    by at least `validationTol * max(currentError, 0.01)`, or
   *  - the size threshold read from `TreeConfig` is positive and the accumulated estimated
   *    size of the trees exceeds it.
   *
   * Side effects: resets `lastEarlyStoppedModelSize` and `modelSizeHistory`, and records the
   * accumulated size after each retained tree.
   *
   * @param input training dataset
   * @param validationInput validation dataset; only read when `validate` is true
   * @param boostingStrategy boosting parameters (iterations, loss, learning rate, tree strategy)
   * @param validate whether to evaluate on `validationInput` for early stopping
   * @param seed random seed; the first tree uses `seed`, tree `m` uses `seed + m`
   * @param featureSubsetStrategy feature subset strategy forwarded to tree training
   * @param instr optional instrumentation, forwarded to the training of the first tree only
   * @return tuple of (trees, tree weights); on early stop both arrays are truncated to the
   *         retained prefix, otherwise they hold all `numIterations` entries
   */
  def boost(
      input: RDD[Instance],
      validationInput: RDD[Instance],
      boostingStrategy: OldBoostingStrategy,
      validate: Boolean,
      seed: Long,
      featureSubsetStrategy: String,
      instr: Option[Instrumentation] = None):
        (Array[DecisionTreeRegressionModel], Array[Double]) = {
    // A non-positive threshold means the model-size limit is disabled (see the `> 0` checks).
    val earlyStopModelSizeThresholdInBytes = TreeConfig.trainingEarlyStopModelSizeThresholdInBytes
    lastEarlyStoppedModelSize = 0
    val timer = new TimeTracker()
    timer.start("total")
    timer.start("init")

    val sc = input.sparkContext

    boostingStrategy.assertValid()

    // Initialize gradient boosting parameters
    val numIterations = boostingStrategy.numIterations
    val baseLearners = new Array[DecisionTreeRegressionModel](numIterations)
    val baseLearnerWeights = new Array[Double](numIterations)
    val loss = boostingStrategy.loss
    val learningRate = boostingStrategy.learningRate

    // Prepare strategy for individual trees, which use regression with variance impurity.
    // Work on a copy so the caller's strategy is not mutated.
    val treeStrategy = boostingStrategy.treeStrategy.copy
    val validationTol = boostingStrategy.validationTol
    treeStrategy.algo = OldAlgo.Regression
    treeStrategy.impurity = OldVariance
    require(!treeStrategy.bootstrap, "GradientBoostedTrees does not need bootstrap sampling")
    treeStrategy.assertValid()

    // Prepare periodic checkpointers
    // Note: this is checkpointing the unweighted training error
    val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
      treeStrategy.getCheckpointInterval(), sc, StorageLevel.MEMORY_AND_DISK)

    timer.stop("init")

    logDebug("##########")
    logDebug("Building tree 0")
    logDebug("##########")

    // Initialize tree
    timer.start("building tree 0")
    val retaggedInput = input.retag(classOf[Instance])
    timer.start("buildMetadata")
    val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, treeStrategy,
      numTrees = 1, featureSubsetStrategy)
    timer.stop("buildMetadata")

    timer.start("findSplits")
    val splits = RandomForest.findSplits(retaggedInput, metadata, seed)
    timer.stop("findSplits")
    val bcSplits = sc.broadcast(splits)

    // Bin feature values (TreePoint representation).
    // Cache input RDD for speedup during multiple passes.
    val treePoints = TreePoint.convertToTreeRDD(
      retaggedInput, splits, metadata)
      .persist(StorageLevel.MEMORY_AND_DISK)
      .setName("binned tree points")

    // Only the per-point subsample counts are cached; the bagged points are rebuilt from
    // `treePoints` below.
    val firstCounts = BaggedPoint
      .convertToBaggedRDD(treePoints, treeStrategy.subsamplingRate, numSubsamples = 1,
        treeStrategy.bootstrap, (tp: TreePoint) => tp.weight, seed = seed)
      .map { bagged =>
        require(bagged.subsampleCounts.length == 1)
        require(bagged.sampleWeight == bagged.datum.weight)
        bagged.subsampleCounts.head
      }.persist(StorageLevel.MEMORY_AND_DISK)
      .setName("firstCounts at iter=0")

    val firstBagged = treePoints.zip(firstCounts)
      .map { case (treePoint, count) =>
        // according to current design, treePoint.weight == baggedPoint.sampleWeight
        new BaggedPoint[TreePoint](treePoint, Array(count), treePoint.weight)
    }

    val firstTreeModel = RandomForest.runBagged(baggedInput = firstBagged,
      metadata = metadata, bcSplits = bcSplits, strategy = treeStrategy, numTrees = 1,
      featureSubsetStrategy = featureSubsetStrategy, seed = seed, instr = instr,
      parentUID = None,
      earlyStopModelSizeThresholdInBytes = earlyStopModelSizeThresholdInBytes)
      .head.asInstanceOf[DecisionTreeRegressionModel]

    firstCounts.unpersist()

    val firstTreeWeight = 1.0
    baseLearners(0) = firstTreeModel
    baseLearnerWeights(0) = firstTreeWeight

    var predError = computeInitialPredictionAndError(
      treePoints, firstTreeWeight, firstTreeModel, loss, bcSplits)
    predErrorCheckpointer.update(predError)
    logDebug(s"error of gbt = ${computeWeightedError(treePoints, predError)}")

    // Note: A model of type regression is used since we require raw prediction
    timer.stop("building tree 0")

    // Validation state; only initialized and used when `validate` is true.
    var validationTreePoints: RDD[TreePoint] = null
    var validatePredError: RDD[(Double, Double)] = null
    var validatePredErrorCheckpointer: PeriodicRDDCheckpointer[(Double, Double)] = null
    var bestValidateError = 0.0
    if (validate) {
      timer.start("init validation")
      validationTreePoints = TreePoint.convertToTreeRDD(
        validationInput.retag(classOf[Instance]), splits, metadata)
        .persist(StorageLevel.MEMORY_AND_DISK)
      validatePredError = computeInitialPredictionAndError(
        validationTreePoints, firstTreeWeight, firstTreeModel, loss, bcSplits)
      validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
        treeStrategy.getCheckpointInterval(), sc, StorageLevel.MEMORY_AND_DISK)
      validatePredErrorCheckpointer.update(validatePredError)
      bestValidateError = computeWeightedError(validationTreePoints, validatePredError)
      timer.stop("init validation")
    }

    // Accumulated estimated size of all trees trained so far.
    var accTreeSize = firstTreeModel.estimatedSize

    // Number of leading trees to return if training stops early.
    var validM = 1

    var m = 1
    var earlyStop = false
    // Records whether the early stop was triggered by the model-size limit (as opposed to the
    // validation criterion), so the final warning reports the real cause.
    var stoppedOnModelSize = false
    modelSizeHistory.clear()
    modelSizeHistory.append(accTreeSize)
    if (
        earlyStopModelSizeThresholdInBytes > 0
        && accTreeSize > earlyStopModelSizeThresholdInBytes
    ) {
      // The first tree alone already exceeds the limit: skip boosting entirely.
      lastEarlyStoppedModelSize = accTreeSize
      earlyStop = true
      stoppedOnModelSize = true
    }
    while (m < numIterations && !earlyStop) {
      timer.start(s"building tree $m")
      logDebug("###################################################")
      logDebug(s"Gradient boosting tree iteration $m")
      logDebug("###################################################")

      // (label: Double, count: Int)
      val labelWithCounts = BaggedPoint
        .convertToBaggedRDD(treePoints, treeStrategy.subsamplingRate, numSubsamples = 1,
          treeStrategy.bootstrap, (tp: TreePoint) => tp.weight, seed = seed + m)
        .zip(predError)
        .map { case (bagged, (pred, _)) =>
          require(bagged.subsampleCounts.length == 1)
          require(bagged.sampleWeight == bagged.datum.weight)
          // Update labels with pseudo-residuals
          val newLabel = -loss.gradient(pred, bagged.datum.label)
          (newLabel, bagged.subsampleCounts.head)
        }.persist(StorageLevel.MEMORY_AND_DISK)
        .setName(s"labelWithCounts at iter=$m")

      val bagged = treePoints.zip(labelWithCounts)
        .map { case (treePoint, (newLabel, count)) =>
          val newTreePoint = new TreePoint(newLabel, treePoint.binnedFeatures, treePoint.weight)
          // according to current design, treePoint.weight == baggedPoint.sampleWeight
          new BaggedPoint[TreePoint](newTreePoint, Array(count), treePoint.weight)
        }

      // The tree is given the budget remaining after the trees trained so far.
      // NOTE(review): if accTreeSize equals the threshold exactly, the value passed here is 0;
      // confirm how runBagged treats a non-positive threshold.
      val model = RandomForest.runBagged(baggedInput = bagged,
        metadata = metadata, bcSplits = bcSplits, strategy = treeStrategy,
        numTrees = 1, featureSubsetStrategy = featureSubsetStrategy,
        seed = seed + m, instr = None, parentUID = None,
        earlyStopModelSizeThresholdInBytes = earlyStopModelSizeThresholdInBytes - accTreeSize)
        .head.asInstanceOf[DecisionTreeRegressionModel]

      labelWithCounts.unpersist()

      timer.stop(s"building tree $m")
      // Update partial model
      baseLearners(m) = model
      accTreeSize += model.estimatedSize
      // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
      //       Technically, the weight should be optimized for the particular loss.
      //       However, the behavior should be reasonable, though not optimal.
      baseLearnerWeights(m) = learningRate

      predError = updatePredictionError(
        treePoints, predError, baseLearnerWeights(m),
        baseLearners(m), loss, bcSplits)
      predErrorCheckpointer.update(predError)
      logDebug(s"error of gbt = ${computeWeightedError(treePoints, predError)}")

      if (validate) {
        // Stop training early if
        // 1. Reduction in error is less than the validationTol or
        // 2. If the error increases, that is if the model is overfit.
        // We want the model returned corresponding to the best validation error.

        validatePredError = updatePredictionError(
          validationTreePoints, validatePredError, baseLearnerWeights(m),
          baseLearners(m), loss, bcSplits)
        validatePredErrorCheckpointer.update(validatePredError)
        val currentValidateError = computeWeightedError(validationTreePoints, validatePredError)
        if (bestValidateError - currentValidateError < validationTol * Math.max(
          currentValidateError, 0.01)) {
          // validM is left untouched, so the tree just trained is not part of the result.
          earlyStop = true
        } else if (currentValidateError < bestValidateError) {
          bestValidateError = currentValidateError
          validM = m + 1
        }
      }
      if (!earlyStop) {
        // The tree at index m is retained; record the size and enforce the size limit.
        modelSizeHistory.append(accTreeSize)
        if (
            earlyStopModelSizeThresholdInBytes > 0
            && accTreeSize > earlyStopModelSizeThresholdInBytes
        ) {
          earlyStop = true
          stoppedOnModelSize = true
          validM = m + 1
          lastEarlyStoppedModelSize = accTreeSize
        }
      }
      m += 1
    }

    timer.stop("total")

    logInfo("Internal timing for DecisionTree:")
    logInfo(log"${MDC(TIMER, timer)}")

    // Release the broadcast, cached RDDs and checkpoints created above.
    bcSplits.destroy()
    treePoints.unpersist()
    predErrorCheckpointer.unpersistDataSet()
    predErrorCheckpointer.deleteAllCheckpoints()
    if (validate) {
      validationTreePoints.unpersist()
      validatePredErrorCheckpointer.unpersistDataSet()
      validatePredErrorCheckpointer.deleteAllCheckpoints()
    }

    if (earlyStop) {
      // Early stop occurs in one of the 2 cases:
      //  - the validation error stops improving
      //  - the accumulated size of trees exceeds the value of `earlyStopModelSizeThresholdInBytes`
      // Only warn about the size limit when it actually triggered the stop. Comparing
      // accTreeSize with the threshold here would also fire on a validation stop when the
      // limit is disabled (threshold <= 0).
      if (stoppedOnModelSize) {
        logWarning(
          "The boosting tree training stops early because the model size exceeds threshold."
        )
      }
      (baseLearners.slice(0, validM), baseLearnerWeights.slice(0, validM))
    } else {
      (baseLearners, baseLearnerWeights)
    }
  }