in src/nanotron/data/nemo_dataset/helpers.cpp [35:97]
/**
 * Build the index arrays used to blend several datasets according to a
 * weighting array.
 *
 * For every output position the dataset whose sample count lags furthest
 * behind its weighted target (weight * position - samples drawn so far) is
 * selected; ties go to the lowest dataset index because the comparison is
 * strict.
 *
 * @param dataset_index        Output, at least `size` entries: the selected
 *                             dataset for each blended sample.
 * @param dataset_sample_index Output, at least `size` entries: the value of the
 *                             selected dataset's counter at selection time.
 * @param dataset_num_samples  In/out, at least `num_datasets` entries: running
 *                             per-dataset counters. They are read before being
 *                             incremented and are never reset here, so the
 *                             caller presumably passes zeros -- TODO confirm
 *                             against the Python caller.
 * @param weights              At least `num_datasets` entries: per-dataset
 *                             sampling weights.
 * @param num_datasets         Number of datasets; must be in [1, INT16_MAX]
 *                             because indices are stored as int16_t.
 * @param size                 Number of blended samples to generate (>= 0).
 * @param verbose              When true, progress and achieved ratios are
 *                             written to std::cout.
 *
 * @throws py::value_error if a scalar argument is out of range or an array is
 *         shorter than the number of elements that would be accessed.
 */
void build_blending_indices(py::array_t<int16_t>& dataset_index,
                            py::array_t<int64_t>& dataset_sample_index,
                            py::array_t<int64_t>& dataset_num_samples,
                            const py::array_t<double>& weights,
                            const int32_t num_datasets,
                            const int64_t size, const bool verbose) {
  if (verbose) {
    // Flushed on purpose so the banner shows up before the main loop runs.
    std::cout << "> building indices for blendable datasets ..." << std::endl;
  }

  // Validate the scalar arguments before touching any array: the accessors
  // used below perform no per-element bounds checking.
  if (num_datasets <= 0) {
    throw py::value_error(
        "build_blending_indices: num_datasets must be positive");
  }
  if (num_datasets > static_cast<int32_t>(INT16_MAX)) {
    throw py::value_error(
        "build_blending_indices: num_datasets does not fit in int16 indices");
  }
  if (size < 0) {
    throw py::value_error(
        "build_blending_indices: size must be non-negative");
  }

  // One-dimensional accessors without per-element bounds checks.
  auto dataset_index_ptr = dataset_index.mutable_unchecked<1>();
  auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>();
  auto dataset_num_samples_ptr = dataset_num_samples.mutable_unchecked<1>();
  auto weights_ptr = weights.unchecked<1>();

  // Make sure every index used below stays inside its array.
  if (dataset_num_samples_ptr.shape(0) < num_datasets ||
      weights_ptr.shape(0) < num_datasets) {
    throw py::value_error(
        "build_blending_indices: dataset_num_samples and weights must hold "
        "at least num_datasets entries");
  }
  if (dataset_index_ptr.shape(0) < size ||
      dataset_sample_index_ptr.shape(0) < size) {
    throw py::value_error(
        "build_blending_indices: dataset_index and dataset_sample_index must "
        "hold at least size entries");
  }

  // For each sample:
  for (int64_t sample_idx = 0; sample_idx < size; ++sample_idx) {
    // Clamp to 1 so that the first iteration compares the weights themselves
    // instead of products that would all be zero.
    const double sample_idx_double =
        std::max(static_cast<double>(sample_idx), 1.0);

    // Determine where the max error in sampling is happening.
    int64_t max_error_index = 0;
    double max_error = weights_ptr[0] * sample_idx_double -
                       static_cast<double>(dataset_num_samples_ptr[0]);
    for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) {
      const double error =
          weights_ptr[dataset_idx] * sample_idx_double -
          static_cast<double>(dataset_num_samples_ptr[dataset_idx]);
      if (error > max_error) {
        max_error = error;
        max_error_index = dataset_idx;
      }
    }

    // Populate the indices; the cast is safe thanks to the range check above.
    dataset_index_ptr[sample_idx] = static_cast<int16_t>(max_error_index);
    dataset_sample_index_ptr[sample_idx] =
        dataset_num_samples_ptr[max_error_index];

    // Update the total samples.
    dataset_num_samples_ptr[max_error_index] += 1;
  }

  // Print info.
  if (verbose) {
    std::cout << " > sample ratios:" << size << '\n';
    for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) {
      // Report 0 instead of dividing by zero when no samples were requested.
      const double ratio =
          size > 0 ? static_cast<double>(dataset_num_samples_ptr[dataset_idx]) /
                         static_cast<double>(size)
                   : 0.0;
      std::cout << " dataset " << dataset_idx << ", input: "
                << weights_ptr[dataset_idx] << ", achieved: " << ratio << '\n';
    }
    // Single flush for the whole report instead of one per line.
    std::cout << std::flush;
  }
}