in mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala [487:765]
private[tree] def findBestSplits(
input: RDD[BaggedPoint[TreePoint]],
metadata: DecisionTreeMetadata,
topNodesForGroup: Map[Int, LearningNode],
nodesForGroup: Map[Int, Array[LearningNode]],
treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]],
bcSplits: Broadcast[Array[Array[Split]]],
nodeStack: mutable.ListBuffer[(Int, LearningNode)],
timer: TimeTracker = new TimeTracker,
nodeIds: RDD[Array[Int]] = null,
outputBestSplits: Boolean = false): Array[Map[Int, Split]] = {
/*
* The high-level optimizations used to find the best splits are described here.
*
* *Group-wise training*
* We perform bin calculations for groups of nodes to reduce the number of
* passes over the data. Each iteration requires more computation and storage,
* but saves several iterations over the data.
*
* *Bin-wise computation*
* We use a bin-wise strategy to compute best splits rather than evaluating every split
* directly. Instead of analyzing each sample's contribution to the left/right child
* impurity of every split, we first categorize each feature value of a sample into a
* bin. We exploit this structure to calculate aggregates per bin and then use these
* aggregates to calculate the information gain of each split (see the sketch below).
*
* *Aggregation over partitions*
* Instead of performing a flatMap/reduceByKey operation over per-(node, feature, bin)
* records, we exploit the fact that we know the number of splits in advance. Thus, we
* store the aggregates (at the appropriate indices) in a single array per node and rely
* on a mapPartitions/reduceByKey over these arrays to drastically reduce the
* communication overhead.
*/
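/*
 * Illustration only: a hedged, self-contained sketch of the bin-wise idea for one
 * ordered feature with Gini impurity. `binCounts`, `gini`, and `tot` are hypothetical
 * names; the real logic lives in DTStatsAggregator and binsToBestSplit.
 *
 *   val binCounts = Array(Array(8.0, 2.0), Array(3.0, 7.0), Array(1.0, 9.0)) // (bin)(class)
 *   def gini(c: Array[Double]) = { val n = c.sum; 1.0 - c.map(x => (x / n) * (x / n)).sum }
 *   def tot(bs: Array[Array[Double]]) = bs.transpose.map(_.sum)
 *   // A split after bin s sends bins 0..s left and the rest right, so the child stats of
 *   // every candidate split are prefix/suffix sums over the per-bin aggregates:
 *   val gains = (0 until binCounts.length - 1).map { s =>
 *     val (lc, rc) = (tot(binCounts.take(s + 1)), tot(binCounts.drop(s + 1)))
 *     val n = lc.sum + rc.sum
 *     gini(tot(binCounts)) - (lc.sum / n) * gini(lc) - (rc.sum / n) * gini(rc)
 *   }
 */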
val useNodeIdCache = nodeIds != null
// numNodes: Number of nodes in this group
val numNodes = nodesForGroup.values.map(_.length).sum
logDebug(s"numNodes = $numNodes")
logDebug(s"numFeatures = ${metadata.numFeatures}")
logDebug(s"numClasses = ${metadata.numClasses}")
logDebug(s"isMulticlass = ${metadata.isMulticlass}")
logDebug(s"isMulticlassWithCategoricalFeatures = " +
s"${metadata.isMulticlassWithCategoricalFeatures}")
logDebug(s"using nodeIdCache = $useNodeIdCache")
/*
* Performs a sequential aggregation over a partition for a particular tree and node.
*
* For each feature, the aggregate sufficient statistics are updated for the relevant
* bins.
*
* @param treeIndex Index of the tree that we want to perform aggregation for.
* @param nodeInfo The node info for the tree node.
* @param agg Array storing aggregate calculation, with a set of sufficient statistics
* for each (node, feature, bin).
* @param baggedPoint Data point being aggregated.
* @param splits Possible splits for all features, indexed (numFeatures)(numSplits).
*/
def nodeBinSeqOp(
treeIndex: Int,
nodeInfo: NodeIndexInfo,
agg: Array[DTStatsAggregator],
baggedPoint: BaggedPoint[TreePoint],
splits: Array[Array[Split]]): Unit = {
if (nodeInfo != null) {
val aggNodeIndex = nodeInfo.nodeIndexInGroup
val featuresForNode = nodeInfo.featureSubset
val numSamples = baggedPoint.subsampleCounts(treeIndex)
val sampleWeight = baggedPoint.sampleWeight
if (metadata.unorderedFeatures.isEmpty) {
orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, numSamples, sampleWeight,
featuresForNode)
} else {
mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
metadata.unorderedFeatures, numSamples, sampleWeight, featuresForNode)
}
agg(aggNodeIndex).updateParent(baggedPoint.datum.label, numSamples, sampleWeight)
}
}
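// Illustration only: for ordered features the per-point work reduces to one stats
// update per selected feature, at the precomputed bin for that feature. A sketch of
// that shape, where `allOrderedFeatureIndices` is hypothetical and the exact
// DTStatsAggregator.update signature is assumed:
//   featuresForNode.getOrElse(allOrderedFeatureIndices).foreach { featureIndex =>
//     val bin = baggedPoint.datum.binnedFeatures(featureIndex) // binned by TreePoint
//     agg.update(featureIndex, bin, baggedPoint.datum.label, numSamples, sampleWeight)
//   }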
/*
* Performs a sequential aggregation over a partition.
*
* Each data point contributes to one node. For each feature,
* the aggregate sufficient statistics are updated for the relevant bins.
*
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
* each (node, feature, bin).
* @param baggedPoint Data point being aggregated.
* @param splits Possible splits for all features, indexed (numFeatures)(numSplits).
* @return agg
*/
def binSeqOp(
agg: Array[DTStatsAggregator],
baggedPoint: BaggedPoint[TreePoint],
splits: Array[Array[Split]]): Array[DTStatsAggregator] = {
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
val nodeIndex =
topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)
nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null),
agg, baggedPoint, splits)
}
agg
}
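// Note: predictImpl routes each point down the tree as grown so far to find the node
// it currently falls into; points landing in nodes outside this group resolve to a
// null NodeIndexInfo via getOrElse and are skipped by nodeBinSeqOp.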
/**
* Does the same thing as binSeqOp, but reads each point's current node index per tree
* from the node ID cache instead of recomputing it.
*/
def binSeqOpWithNodeIdCache(
agg: Array[DTStatsAggregator],
dataPoint: (BaggedPoint[TreePoint], Array[Int]),
splits: Array[Array[Split]]): Array[DTStatsAggregator] = {
val (baggedPoint, nodeIdCache) = dataPoint
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
val nodeIndex = nodeIdCache(treeIndex)
nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null),
agg, baggedPoint, splits)
}
agg
}
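// Note: here the point's node index per tree comes straight from the cached array
// (kept in lockstep with `input` via the zip below), avoiding a root-to-node
// traversal per point at the cost of maintaining the cache.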
/**
* Get the (node index in group) --> (feature indices) map,
* a shortcut for finding the feature indices for a node given its node index in the group.
*/
def getNodeToFeatures(
treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): Option[Map[Int, Array[Int]]] = {
if (!metadata.subsamplingFeatures) {
None
} else {
val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]()
treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo =>
nodeIdToNodeInfo.values.foreach { nodeIndexInfo =>
assert(nodeIndexInfo.featureSubset.isDefined)
mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get
}
}
Some(mutableNodeToFeatures.toMap)
}
}
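// Illustration only: with feature subsampling the resulting map might look like
//   Some(Map(0 -> Array(1, 3), 1 -> Array(0, 2)))
// for two nodes in the group, each drawing two of four features (hypothetical
// indices). None means every node uses all features, so no per-node subsets are kept.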
// array of nodes to train indexed by node index in group
val nodes = new Array[LearningNode](numNodes)
nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
nodesForTree.foreach { node =>
nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
}
}
// Calculate best splits for all nodes in the group
timer.start("chooseSplits")
// In each partition, iterate over all instances and compute the aggregate stats for
// each node, yielding a (nodeIndex, nodeAggregateStats) pair per node.
// After a `reduceByKey` operation,
// the stats for a node are shuffled to a particular partition and combined together,
// and the best splits for the nodes are found there.
// Finally, only the best split per node is collected to the driver to construct the trees.
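// Illustration only: the shape of the pattern described above, with hypothetical
// `Agg` (mergeable stats), `nodeOf` (point -> node index in group), and
// `bestSplitFromStats` (gain computation):
//   val best = data.mapPartitions { iter =>
//     val agg = Array.fill(numNodes)(new Agg)     // one aggregator per node
//     iter.foreach(p => agg(nodeOf(p)).add(p))    // single pass over the partition
//     agg.iterator.zipWithIndex.map(_.swap)       // (nodeIndex, partitionStats) pairs
//   }.reduceByKey(_.merge(_))                     // one merged stats object per node
//     .mapValues(bestSplitFromStats)              // per-node winner
//     .collectAsMap()                             // only per-node results reach the driver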
val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
val partitionAggregates = if (useNodeIdCache) {
input.zip(nodeIds).mapPartitions { points =>
// Construct a nodeStatsAggregators array to hold the node aggregate stats;
// each node gets its own DTStatsAggregator.
val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
nodeToFeatures(nodeIndex)
}
new DTStatsAggregator(metadata, featuresForNode)
}
// iterate over all instances in the current partition and update the aggregate stats
points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _, bcSplits.value))
// transform the nodeStatsAggregators array into (nodeIndex, nodeAggregateStats) pairs,
// which can be combined with those from other partitions using `reduceByKey`
nodeStatsAggregators.iterator.zipWithIndex.map(_.swap)
}
} else {
input.mapPartitions { points =>
// Construct a nodeStatsAggregators array to hold the node aggregate stats;
// each node gets its own DTStatsAggregator.
val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
nodeToFeatures(nodeIndex)
}
new DTStatsAggregator(metadata, featuresForNode)
}
// iterate over all instances in the current partition and update the aggregate stats
points.foreach(binSeqOp(nodeStatsAggregators, _, bcSplits.value))
// transform the nodeStatsAggregators array into (nodeIndex, nodeAggregateStats) pairs,
// which can be combined with those from other partitions using `reduceByKey`
nodeStatsAggregators.iterator.zipWithIndex.map(_.swap)
}
}
val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map {
case (nodeIndex, aggStats) =>
val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
nodeToFeatures(nodeIndex)
}
// find best split for each node
val (split: Split, stats: ImpurityStats) =
binsToBestSplit(aggStats, bcSplits.value, featuresForNode, nodes(nodeIndex))
(nodeIndex, (split, stats))
}.collectAsMap()
nodeToFeaturesBc.destroy()
timer.stop("chooseSplits")
val bestSplits = if (outputBestSplits) {
Array.ofDim[mutable.Map[Int, Split]](metadata.numTrees)
} else {
null
}
// Iterate over all nodes in this group.
nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
nodesForTree.foreach { node =>
val nodeIndex = node.id
val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
val aggNodeIndex = nodeInfo.nodeIndexInGroup
val (split: Split, stats: ImpurityStats) =
nodeToBestSplits(aggNodeIndex)
logDebug(s"best split = $split")
// Extract info for this node. Create children if not leaf.
val isLeaf =
(stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth)
node.isLeaf = isLeaf
node.stats = stats
logDebug(s"Node = $node")
if (!isLeaf) {
node.split = Some(split)
val childIsLeaf = (LearningNode.indexToLevel(nodeIndex) + 1) == metadata.maxDepth
val leftChildIsLeaf = childIsLeaf || (math.abs(stats.leftImpurity) < Utils.EPSILON)
val rightChildIsLeaf = childIsLeaf || (math.abs(stats.rightImpurity) < Utils.EPSILON)
node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator)))
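// Note: LearningNode ids use 1-based binary-heap indexing, so (as assumed here)
// leftChildIndex(i) = 2 * i, rightChildIndex(i) = 2 * i + 1, and
// indexToLevel(i) = floor(log2(i)); e.g. node 5 sits at level 2 with children 10 and 11.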
if (outputBestSplits) {
val bestSplitsInTree = bestSplits(treeIndex)
if (bestSplitsInTree == null) {
bestSplits(treeIndex) = mutable.Map[Int, Split](nodeIndex -> split)
} else {
bestSplitsInTree.update(nodeIndex, split)
}
}
// enqueue left child and right child if they are not leaves
if (!leftChildIsLeaf) {
nodeStack.prepend((treeIndex, node.leftChild.get))
}
if (!rightChildIsLeaf) {
nodeStack.prepend((treeIndex, node.rightChild.get))
}
logDebug(s"leftChildIndex = ${node.leftChild.get.id}" +
s", impurity = ${stats.leftImpurity}")
logDebug(s"rightChildIndex = ${node.rightChild.get.id}" +
s", impurity = ${stats.rightImpurity}")
}
}
}
if (outputBestSplits) {
bestSplits.map { m => if (m == null) null else m.toMap }
} else {
null
}
}