void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits(const Tree* tree)

in src/treelearner/voting_parallel_tree_learner.cpp [243:391]


void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits(const Tree* tree) {
  // Voting-parallel split search for the current smaller/larger leaf pair:
  //   1. build local histograms and find each feature's local best split;
  //   2. keep this machine's top_k_ candidates per leaf and Allgather them;
  //   3. vote globally on which features to aggregate (GlobalVoting);
  //   4. copy the local histograms of the voted features into the send
  //      buffer and ReduceScatter them across machines;
  //   5. derive the final best splits from the aggregated histograms.
  //
  // @param tree  tree being grown; forwarded to GetParentOutput and
  //              FindBestSplitsFromHistograms.

  // use local data to find local best splits
  std::vector<int8_t> is_feature_used(this->num_features_, 0);
#pragma omp parallel for schedule(static)
  for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
    if (!this->col_sampler_.is_feature_used_bytree()[feature_index]) continue;
    // A feature marked unsplittable in the parent is also marked unsplittable
    // for the smaller leaf and skipped.
    if (this->parent_leaf_histogram_array_ != nullptr
      && !this->parent_leaf_histogram_array_[feature_index].is_splittable()) {
      this->smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
      continue;
    }
    is_feature_used[feature_index] = 1;
  }
  // Without a parent histogram array the larger-leaf histograms cannot be
  // obtained by subtraction; they are fixed up directly further below.
  const bool use_subtract = this->parent_leaf_histogram_array_ != nullptr;
  TREELEARNER_T::ConstructHistograms(is_feature_used, use_subtract);

  // Leaf bookkeeping, computed once and reused below. larger_leaf_splits_ is
  // treated as nullable, and a negative larger-leaf index means only the root
  // leaf exists.
  const int smaller_leaf_index = this->smaller_leaf_splits_->leaf_index();
  const bool has_larger_splits = this->larger_leaf_splits_ != nullptr;
  const int larger_leaf_index = has_larger_splits ? this->larger_leaf_splits_->leaf_index() : -1;
  const bool has_larger_leaf = larger_leaf_index >= 0;

  using HistogramArray = decltype(this->smaller_leaf_histogram_array_);
  // Zeroes the histogram buffer of every used feature in `histograms`.
  // Needed when this machine holds no local rows for a leaf: otherwise the
  // histogram contents from the previous iteration would be sent when
  // synchronizing.
  auto clear_used_histograms = [this, &is_feature_used](const HistogramArray& histograms) {
    OMP_INIT_EX();
#pragma omp parallel for schedule(static)
    for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
      OMP_LOOP_EX_BEGIN();
      if (!is_feature_used[feature_index]) { continue; }
      const BinMapper* feature_bin_mapper = this->train_data_->FeatureBinMapper(feature_index);
      const int num_bin = feature_bin_mapper->num_bin();
      // NOTE(review): presumably the histogram omits one bin when the most
      // frequent bin is bin 0 — TODO confirm against FeatureHistogram.
      const int offset = static_cast<int>(feature_bin_mapper->GetMostFreqBin() == 0);
      hist_t* hist_ptr = histograms[feature_index].RawData();
      std::memset(reinterpret_cast<void*>(hist_ptr), 0, (num_bin - offset) * kHistEntrySize);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
  };

  if (this->data_partition_->leaf_count(smaller_leaf_index) <= 0) {
    clear_used_histograms(this->smaller_leaf_histogram_array_);
  }
  if (has_larger_leaf && this->data_partition_->leaf_count(larger_leaf_index) <= 0) {
    clear_used_histograms(this->larger_leaf_histogram_array_);
  }

  std::vector<SplitInfo> smaller_bestsplit_per_features(this->num_features_);
  std::vector<SplitInfo> larger_bestsplit_per_features(this->num_features_);
  const double smaller_leaf_parent_output = this->GetParentOutput(tree, this->smaller_leaf_splits_.get());
  // The larger-leaf output is only consumed when has_larger_leaf holds, so
  // never hand GetParentOutput a null pointer.
  const double larger_leaf_parent_output = has_larger_splits
      ? this->GetParentOutput(tree, this->larger_leaf_splits_.get())
      : 0.0;
  OMP_INIT_EX();
  // find local best splits for every used feature
#pragma omp parallel for schedule(static)
  for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
    OMP_LOOP_EX_BEGIN();
    if (!is_feature_used[feature_index]) { continue; }
    const int real_feature_index = this->train_data_->RealFeatureIndex(feature_index);
    this->train_data_->FixHistogram(feature_index,
      this->smaller_leaf_splits_->sum_gradients(), this->smaller_leaf_splits_->sum_hessians(),
      this->smaller_leaf_histogram_array_[feature_index].RawData());

    this->ComputeBestSplitForFeature(
        this->smaller_leaf_histogram_array_, feature_index, real_feature_index,
        true, this->smaller_leaf_splits_->num_data_in_leaf(),
        this->smaller_leaf_splits_.get(),
        &smaller_bestsplit_per_features[feature_index],
        smaller_leaf_parent_output);
    // only has root leaf
    if (!has_larger_leaf) { continue; }

    if (use_subtract) {
      this->larger_leaf_histogram_array_[feature_index].Subtract(this->smaller_leaf_histogram_array_[feature_index]);
    } else {
      this->train_data_->FixHistogram(feature_index, this->larger_leaf_splits_->sum_gradients(), this->larger_leaf_splits_->sum_hessians(),
        this->larger_leaf_histogram_array_[feature_index].RawData());
    }
    this->ComputeBestSplitForFeature(
        this->larger_leaf_histogram_array_, feature_index, real_feature_index,
        true, this->larger_leaf_splits_->num_data_in_leaf(),
        this->larger_leaf_splits_.get(),
        &larger_bestsplit_per_features[feature_index],
        larger_leaf_parent_output);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();

  // local voting: keep this machine's top_k_ candidates per leaf
  std::vector<SplitInfo> smaller_top_k_splits, larger_top_k_splits;
  ArrayArgs<SplitInfo>::MaxK(smaller_bestsplit_per_features, top_k_, &smaller_top_k_splits);
  ArrayArgs<SplitInfo>::MaxK(larger_bestsplit_per_features, top_k_, &larger_top_k_splits);

  // Every machine must contribute exactly top_k_ entries per leaf, because the
  // unpacking below reads num_machines_ * top_k_ entries per leaf. Only read
  // entries MaxK actually produced. Any remaining slots keep a
  // default-constructed LightSplitInfo (NOTE(review): presumably an empty
  // candidate that GlobalVoting ignores — TODO confirm).
  std::vector<LightSplitInfo> smaller_top_k_light_splits(top_k_);
  std::vector<LightSplitInfo> larger_top_k_light_splits(top_k_);
  for (int i = 0; i < top_k_; ++i) {
    if (static_cast<size_t>(i) < smaller_top_k_splits.size()) {
      smaller_top_k_light_splits[i].CopyFrom(smaller_top_k_splits[i]);
    }
    if (static_cast<size_t>(i) < larger_top_k_splits.size()) {
      larger_top_k_light_splits[i].CopyFrom(larger_top_k_splits[i]);
    }
  }

  // gather: interleave smaller/larger candidates and exchange with all machines
  int offset = 0;
  for (int i = 0; i < top_k_; ++i) {
    std::memcpy(input_buffer_.data() + offset, &smaller_top_k_light_splits[i], sizeof(LightSplitInfo));
    offset += sizeof(LightSplitInfo);
    std::memcpy(input_buffer_.data() + offset, &larger_top_k_light_splits[i], sizeof(LightSplitInfo));
    offset += sizeof(LightSplitInfo);
  }
  Network::Allgather(input_buffer_.data(), offset, output_buffer_.data());

  // Unpack every machine's candidates. Entry k = machine * top_k_ + j follows
  // the same interleaved smaller/larger layout written above.
  const size_t num_global_candidates = static_cast<size_t>(num_machines_) * static_cast<size_t>(top_k_);
  std::vector<LightSplitInfo> smaller_top_k_splits_global(num_global_candidates);
  std::vector<LightSplitInfo> larger_top_k_splits_global(num_global_candidates);
  offset = 0;
  for (size_t k = 0; k < num_global_candidates; ++k) {
    std::memcpy(&smaller_top_k_splits_global[k], output_buffer_.data() + offset, sizeof(LightSplitInfo));
    offset += sizeof(LightSplitInfo);
    std::memcpy(&larger_top_k_splits_global[k], output_buffer_.data() + offset, sizeof(LightSplitInfo));
    offset += sizeof(LightSplitInfo);
  }

  // global voting; without larger-leaf splits, larger_top_features stays empty
  std::vector<int> smaller_top_features, larger_top_features;
  GlobalVoting(smaller_leaf_index, smaller_top_k_splits_global, &smaller_top_features);
  if (has_larger_splits) {
    GlobalVoting(larger_leaf_index, larger_top_k_splits_global, &larger_top_features);
  }
  // copy local histograms of the voted features to the buffer used by
  // ReduceScatter below
  CopyLocalHistogram(smaller_top_features, larger_top_features);

  // Reduce scatter for histogram
  Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, sizeof(hist_t), block_start_.data(), block_len_.data(),
                         output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &HistogramSumReducer);

  this->FindBestSplitsFromHistograms(is_feature_used, false, tree);
}