in mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala [290:537]
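/**
 * Internal method for performing regression using trees as base learners.
 * (Doc comment reconstructed to summarize the code below; parameter descriptions
 * are inferred from usage in this method.)
 *
 * @param input training dataset
 * @param validationInput validation dataset, ignored if validate is false
 * @param boostingStrategy boosting parameters
 * @param validate whether to use the validation dataset for early stopping
 * @param seed random seed
 * @param featureSubsetStrategy strategy for the number of features considered per split
 * @param instr optional instrumentation for logging
 * @return tuple of the ensemble models and weights:
 *         (array of decision tree models, array of model weights)
 */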
def boost(
    input: RDD[Instance],
    validationInput: RDD[Instance],
    boostingStrategy: OldBoostingStrategy,
    validate: Boolean,
    seed: Long,
    featureSubsetStrategy: String,
    instr: Option[Instrumentation] = None):
  (Array[DecisionTreeRegressionModel], Array[Double]) = {
  val earlyStopModelSizeThresholdInBytes = TreeConfig.trainingEarlyStopModelSizeThresholdInBytes
  lastEarlyStoppedModelSize = 0
  val timer = new TimeTracker()
  timer.start("total")
  timer.start("init")

  val sc = input.sparkContext
  boostingStrategy.assertValid()

  // Initialize gradient boosting parameters
  val numIterations = boostingStrategy.numIterations
  val baseLearners = new Array[DecisionTreeRegressionModel](numIterations)
  val baseLearnerWeights = new Array[Double](numIterations)
  val loss = boostingStrategy.loss
  val learningRate = boostingStrategy.learningRate

  // Prepare strategy for individual trees, which use regression with variance impurity.
  val treeStrategy = boostingStrategy.treeStrategy.copy
  val validationTol = boostingStrategy.validationTol
  treeStrategy.algo = OldAlgo.Regression
  treeStrategy.impurity = OldVariance
  require(!treeStrategy.bootstrap, "GradientBoostedTrees does not need bootstrap sampling")
  treeStrategy.assertValid()

  // Prepare periodic checkpointers.
  // Note: this is checkpointing the unweighted training error.
  val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
    treeStrategy.getCheckpointInterval(), sc, StorageLevel.MEMORY_AND_DISK)
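  // Periodic checkpointing truncates the lineage of the (prediction, error) RDD,
  // which otherwise grows by one stage per boosting iteration.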
  timer.stop("init")

  logDebug("##########")
  logDebug("Building tree 0")
  logDebug("##########")

  // Initialize tree
  timer.start("building tree 0")
  val retaggedInput = input.retag(classOf[Instance])

  timer.start("buildMetadata")
  val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, treeStrategy,
    numTrees = 1, featureSubsetStrategy)
  timer.stop("buildMetadata")

  timer.start("findSplits")
  val splits = RandomForest.findSplits(retaggedInput, metadata, seed)
  timer.stop("findSplits")
  val bcSplits = sc.broadcast(splits)
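  // Broadcast the splits once; they are reused by every tree trained below.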
  // Bin feature values (TreePoint representation).
  // Cache input RDD for speedup during multiple passes.
  val treePoints = TreePoint.convertToTreeRDD(
    retaggedInput, splits, metadata)
    .persist(StorageLevel.MEMORY_AND_DISK)
    .setName("binned tree points")
  val firstCounts = BaggedPoint
    .convertToBaggedRDD(treePoints, treeStrategy.subsamplingRate, numSubsamples = 1,
      treeStrategy.bootstrap, (tp: TreePoint) => tp.weight, seed = seed)
    .map { bagged =>
      require(bagged.subsampleCounts.length == 1)
      require(bagged.sampleWeight == bagged.datum.weight)
      bagged.subsampleCounts.head
    }.persist(StorageLevel.MEMORY_AND_DISK)
    .setName("firstCounts at iter=0")

  val firstBagged = treePoints.zip(firstCounts)
    .map { case (treePoint, count) =>
      // According to the current design, treePoint.weight == baggedPoint.sampleWeight.
      new BaggedPoint[TreePoint](treePoint, Array(count), treePoint.weight)
    }
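  // The first tree is fit on the raw labels (no residuals yet); with weight 1.0 it
  // serves as the initial prediction of the ensemble.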
  val firstTreeModel = RandomForest.runBagged(baggedInput = firstBagged,
    metadata = metadata, bcSplits = bcSplits, strategy = treeStrategy, numTrees = 1,
    featureSubsetStrategy = featureSubsetStrategy, seed = seed, instr = instr,
    parentUID = None,
    earlyStopModelSizeThresholdInBytes = earlyStopModelSizeThresholdInBytes)
    .head.asInstanceOf[DecisionTreeRegressionModel]
  firstCounts.unpersist()

  val firstTreeWeight = 1.0
  baseLearners(0) = firstTreeModel
  baseLearnerWeights(0) = firstTreeWeight
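  // predError holds per-instance (prediction, error) pairs for the current ensemble.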
  var predError = computeInitialPredictionAndError(
    treePoints, firstTreeWeight, firstTreeModel, loss, bcSplits)
  predErrorCheckpointer.update(predError)
  logDebug(s"error of gbt = ${computeWeightedError(treePoints, predError)}")
  // Note: A model of type regression is used since we require raw predictions.
  timer.stop("building tree 0")
  var validationTreePoints: RDD[TreePoint] = null
  var validatePredError: RDD[(Double, Double)] = null
  var validatePredErrorCheckpointer: PeriodicRDDCheckpointer[(Double, Double)] = null
  var bestValidateError = 0.0
  if (validate) {
    timer.start("init validation")
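    // Bin the validation instances with the same splits and metadata as the training
    // data so predictions line up with the training-time binning.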
    validationTreePoints = TreePoint.convertToTreeRDD(
      validationInput.retag(classOf[Instance]), splits, metadata)
      .persist(StorageLevel.MEMORY_AND_DISK)
    validatePredError = computeInitialPredictionAndError(
      validationTreePoints, firstTreeWeight, firstTreeModel, loss, bcSplits)
    validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
      treeStrategy.getCheckpointInterval(), sc, StorageLevel.MEMORY_AND_DISK)
    validatePredErrorCheckpointer.update(validatePredError)
    bestValidateError = computeWeightedError(validationTreePoints, validatePredError)
    timer.stop("init validation")
  }
  var accTreeSize = firstTreeModel.estimatedSize
  var validM = 1
  var m = 1
  var earlyStop = false
  modelSizeHistory.clear()
  modelSizeHistory.append(accTreeSize)
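  // If the first tree alone already exceeds the size budget, skip the boosting loop.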
  if (earlyStopModelSizeThresholdInBytes > 0 &&
    accTreeSize > earlyStopModelSizeThresholdInBytes) {
    lastEarlyStoppedModelSize = accTreeSize
    earlyStop = true
  }
  while (m < numIterations && !earlyStop) {
    timer.start(s"building tree $m")
    logDebug("###################################################")
    logDebug("Gradient boosting tree iteration " + m)
    logDebug("###################################################")

    // Element type: (pseudo-residual label: Double, subsample count: Int)
    val labelWithCounts = BaggedPoint
      .convertToBaggedRDD(treePoints, treeStrategy.subsamplingRate, numSubsamples = 1,
        treeStrategy.bootstrap, (tp: TreePoint) => tp.weight, seed = seed + m)
      .zip(predError)
      .map { case (bagged, (pred, _)) =>
        require(bagged.subsampleCounts.length == 1)
        require(bagged.sampleWeight == bagged.datum.weight)
        // Update labels with pseudo-residuals: the negative gradient of the loss
        // at the current prediction.
        val newLabel = -loss.gradient(pred, bagged.datum.label)
        (newLabel, bagged.subsampleCounts.head)
      }.persist(StorageLevel.MEMORY_AND_DISK)
      .setName(s"labelWithCounts at iter=$m")
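    // E.g. for squared error the pseudo-residual is proportional to (label - prediction),
    // so each new tree fits the residuals of the current ensemble. The binned features
    // are reused as-is; only the label changes, so no re-binning pass is needed.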
    val bagged = treePoints.zip(labelWithCounts)
      .map { case (treePoint, (newLabel, count)) =>
        val newTreePoint = new TreePoint(newLabel, treePoint.binnedFeatures, treePoint.weight)
        // According to the current design, treePoint.weight == baggedPoint.sampleWeight.
        new BaggedPoint[TreePoint](newTreePoint, Array(count), treePoint.weight)
      }
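    // Pass the remaining size budget (threshold minus accumulated tree size), presumably
    // so that tree training can stop growing once the budget is exhausted.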
    val model = RandomForest.runBagged(baggedInput = bagged,
      metadata = metadata, bcSplits = bcSplits, strategy = treeStrategy,
      numTrees = 1, featureSubsetStrategy = featureSubsetStrategy,
      seed = seed + m, instr = None, parentUID = None,
      earlyStopModelSizeThresholdInBytes = earlyStopModelSizeThresholdInBytes - accTreeSize)
      .head.asInstanceOf[DecisionTreeRegressionModel]
    labelWithCounts.unpersist()
    timer.stop(s"building tree $m")

    // Update partial model
    baseLearners(m) = model
    accTreeSize += model.estimatedSize
    // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
    // Technically, the weight should be optimized for the particular loss.
    // However, the behavior should be reasonable, though not optimal.
    baseLearnerWeights(m) = learningRate

    predError = updatePredictionError(
      treePoints, predError, baseLearnerWeights(m),
      baseLearners(m), loss, bcSplits)
    predErrorCheckpointer.update(predError)
    logDebug(s"error of gbt = ${computeWeightedError(treePoints, predError)}")
    if (validate) {
      // Stop training early if:
      // 1. the reduction in error is less than validationTol, or
      // 2. the error increases, i.e. the model overfits.
      // We want the returned model to correspond to the best validation error.
      validatePredError = updatePredictionError(
        validationTreePoints, validatePredError, baseLearnerWeights(m),
        baseLearners(m), loss, bcSplits)
      validatePredErrorCheckpointer.update(validatePredError)
      val currentValidateError = computeWeightedError(validationTreePoints, validatePredError)
      // The tolerance is relative to the current error, floored at 0.01 so the
      // threshold does not vanish when the error is near zero.
      if (bestValidateError - currentValidateError < validationTol * Math.max(
          currentValidateError, 0.01)) {
        earlyStop = true
      } else if (currentValidateError < bestValidateError) {
        bestValidateError = currentValidateError
        validM = m + 1
      }
    }
    if (!earlyStop) {
      modelSizeHistory.append(accTreeSize)
      if (earlyStopModelSizeThresholdInBytes > 0 &&
        accTreeSize > earlyStopModelSizeThresholdInBytes) {
        earlyStop = true
        validM = m + 1
        lastEarlyStoppedModelSize = accTreeSize
      }
    }
    m += 1
  }
  timer.stop("total")

  logInfo("Internal timing for DecisionTree:")
  logInfo(log"${MDC(TIMER, timer)}")

  bcSplits.destroy()
  treePoints.unpersist()
  predErrorCheckpointer.unpersistDataSet()
  predErrorCheckpointer.deleteAllCheckpoints()
  if (validate) {
    validationTreePoints.unpersist()
    validatePredErrorCheckpointer.unpersistDataSet()
    validatePredErrorCheckpointer.deleteAllCheckpoints()
  }
  if (earlyStop) {
    // Early stopping occurs in one of two cases:
    // - the validation error stops improving (or increases), or
    // - the accumulated size of the trees exceeds `earlyStopModelSizeThresholdInBytes`.
    // Guard the warning with the same `> 0` check used above, so a validation-triggered
    // stop does not log a spurious size warning when the size threshold is disabled.
    if (earlyStopModelSizeThresholdInBytes > 0 &&
      accTreeSize > earlyStopModelSizeThresholdInBytes) {
      logWarning("Gradient boosted tree training stopped early because the model size " +
        "exceeds the configured threshold.")
    }
    (baseLearners.slice(0, validM), baseLearnerWeights.slice(0, validM))
  } else {
    (baseLearners, baseLearnerWeights)
  }
}