void build_blending_indices()

in src/nanotron/data/nemo_dataset/helpers.cpp [35:97]


/**
 * Build the per-sample dataset assignment for a weighted blend of datasets.
 *
 * Samples are assigned greedily: for each output position, the dataset whose
 * current sample count lags furthest behind its target
 * (weights[d] * max(sample_idx, 1)) receives the sample. This keeps the
 * running per-dataset ratios close to `weights`.
 *
 * @param dataset_index         Output; needs at least `size` entries. Receives
 *                              the dataset id chosen for each blended sample.
 *                              Ids are stored as int16_t, so at most 32768
 *                              datasets are supported.
 * @param dataset_sample_index  Output; needs at least `size` entries. Receives,
 *                              for each blended sample, the value of the
 *                              chosen dataset's counter in
 *                              `dataset_num_samples` at assignment time.
 * @param dataset_num_samples   In/out; needs at least `num_datasets` entries.
 *                              Per-dataset sample counters, incremented in
 *                              place. They are NOT reset here.
 *                              NOTE(review): callers presumably pass zeros;
 *                              confirm against the Python side.
 * @param weights               Needs at least `num_datasets` entries. Target
 *                              share per dataset. Presumably sums to 1; this
 *                              is not checked.
 * @param num_datasets          Number of datasets being blended.
 * @param size                  Number of blended samples to generate.
 * @param verbose               Print progress and achieved-vs-input ratios.
 *
 * @throws pybind11::value_error if `num_datasets` or any array length is
 *         inconsistent with `num_datasets`/`size`. Element access below is
 *         unchecked, so such input would otherwise read or write out of bounds.
 */
void build_blending_indices(py::array_t<int16_t>& dataset_index,
                            py::array_t<int64_t>& dataset_sample_index,
                            py::array_t<int64_t>& dataset_num_samples,
                            const py::array_t<double>& weights,
                            const int32_t num_datasets,
                            const int64_t size, const bool verbose) {
  if (verbose) {
    std::cout << "> building indices for blendable datasets ..." << std::endl;
  }

  // Element access without per-element bounds checks. Creating these proxies
  // still verifies that each array is 1-D; the mutable proxies also verify
  // that the array is writeable.
  auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
  auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
  auto dataset_num_samples_ptr = dataset_num_samples.mutable_unchecked<1>();
  auto weights_ptr = weights.unchecked<1>();

  // Because the accesses above are unchecked, validate all sizes up front.
  // Dataset ids are written as int16_t, so only ids 0 .. 32767 are representable.
  constexpr int64_t kMaxDatasets = 32768;
  if (size > 0 && num_datasets < 1) {
    throw py::value_error(
        "build_blending_indices: num_datasets must be >= 1 when size > 0");
  }
  if (num_datasets > kMaxDatasets) {
    throw py::value_error(
        "build_blending_indices: num_datasets exceeds the int16 dataset index range");
  }
  if (weights_ptr.shape(0) < num_datasets ||
      dataset_num_samples_ptr.shape(0) < num_datasets) {
    throw py::value_error(
        "build_blending_indices: weights and dataset_num_samples must have at "
        "least num_datasets entries");
  }
  if (dataset_index_ptr.shape(0) < size ||
      dataset_sample_index_ptr.shape(0) < size) {
    throw py::value_error(
        "build_blending_indices: dataset_index and dataset_sample_index must "
        "have at least size entries");
  }

  for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {
    // Find the dataset that is furthest below its target count. The first
    // sample uses 1.0 instead of 0.0, so the weights (rather than an all-zero
    // error vector) decide the first pick.
    const double sample_idx_double = std::max(static_cast<double>(sample_idx), 1.0);
    int64_t max_error_index = 0;
    double max_error = weights_ptr[0] * sample_idx_double -
                       static_cast<double>(dataset_num_samples_ptr[0]);
    for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
      const double error = weights_ptr[dataset_idx] * sample_idx_double -
                           static_cast<double>(dataset_num_samples_ptr[dataset_idx]);
      // Strict '>' means ties go to the lowest dataset id.
      if (error > max_error) {
        max_error = error;
        max_error_index = dataset_idx;
      }
    }

    // Record which dataset serves this sample and at what position within it.
    // The narrowing cast is safe: num_datasets was bounded by kMaxDatasets above.
    dataset_index_ptr[sample_idx] = static_cast<int16_t>(max_error_index);
    dataset_sample_index_ptr[sample_idx] = dataset_num_samples_ptr[max_error_index];

    // Advance the chosen dataset's counter.
    dataset_num_samples_ptr[max_error_index] += 1;
  }

  if (verbose) {
    std::cout << " > sample ratios:" << size << std::endl;
    for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
      // Guard against 0/0 (NaN) when no samples were requested.
      const double ratio =
          size > 0 ? static_cast<double>(dataset_num_samples_ptr[dataset_idx]) /
                         static_cast<double>(size)
                   : 0.0;
      std::cout << "   dataset " << dataset_idx << ", input: " <<
        weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl;
    }
  }
}