include/stats/histogram_binner.h (72 lines of code) (raw):

//
// Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All rights reserved.
//

#ifndef NCCL_OFI_STATS_HISTOGRAM_BINNER
#define NCCL_OFI_STATS_HISTOGRAM_BINNER

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

//
// A linear binner creates `num_bins_arg` bins, each of size `bin_size_arg`,
// where the first bin's range is `min_val_arg` to `min_val_arg + bin_size_arg -
// 1` (for integral T; for floating-point T it is the half-open interval
// [min_val_arg, min_val_arg + bin_size_arg)).  Bin i starts at
// `min_val_arg + i * bin_size_arg`.  The last bin also absorbs every value past
// the end of the covered range, so it doubles as an overflow bin.
//
// get_bin() asserts that bin_size_arg > 0, num_bins_arg > 0 and that the
// binned value is >= min_val_arg.
//
template <typename T>
class histogram_linear_binner {
public:
	histogram_linear_binner(const T& min_val_arg, const T& bin_size_arg,
				std::size_t num_bins_arg)
		: min_val(min_val_arg), bin_size(bin_size_arg), num_bins(num_bins_arg)
	{
	}

	// Return the index of the bin `input_val` falls into, in [0, num_bins).
	// Values beyond the last bin boundary are clamped into the last bin.
	std::size_t get_bin(const T& input_val) const
	{
		// A zero bin size would divide by zero below, and zero bins would
		// make the `num_bins - 1` clamp underflow.
		assert(bin_size > T());
		assert(num_bins > 0);
		assert(input_val >= min_val);

		std::size_t bin = static_cast<std::size_t>((input_val - min_val) / bin_size);
		if (bin >= num_bins) {
			bin = num_bins - 1;
		}
		return bin;
	}

	std::size_t get_num_bins(void) const
	{
		return num_bins;
	}

	// Return the starting value of each bin (one entry per bin).  The labels
	// are computed on the first call and cached in `range_labels`.
	const std::vector<T> & get_bin_ranges(void)
	{
		if (range_labels.empty()) {
			range_labels.reserve(num_bins);
			for (std::size_t i = 0; i < num_bins; ++i) {
				range_labels.push_back(static_cast<T>(min_val + (i * bin_size)));
			}
		}
		return range_labels;
	}

protected:
	const T min_val;
	const T bin_size;
	const std::size_t num_bins;
	// Lazily-populated cache backing get_bin_ranges().
	std::vector<T> range_labels;
};

//
// Flexible binner where the user provides the list of starting points of the
// bins: bin i covers [ranges_arg[i], ranges_arg[i + 1]), and the last bin
// covers everything from the final starting point upwards.  The starting
// points are expected to be in ascending order (this is not checked).
//
// Slowest binner, since a linear search is currently used to find the right
// bin.  This could be made log(n) (e.g. std::upper_bound over the sorted
// starting points), but even that will be considerably slower than the direct
// computation used in the linear binner.
//
template <typename T>
class histogram_custom_binner {
public:
	histogram_custom_binner(const std::vector<T> &ranges_arg)
		: ranges(ranges_arg)
	{
	}

	// Return the index of the bin `input_val` falls into, in
	// [0, ranges.size()).  Requires at least one starting point and
	// input_val >= the first starting point.
	std::size_t get_bin(const T& input_val) const
	{
		assert(!ranges.empty());
		assert(input_val >= ranges[0]);

		// The first starting point strictly greater than input_val ends the
		// matching bin, so the number of starting points passed after
		// ranges[0] is the bin index.
		const auto first_candidate = ranges.begin() + 1;
		const auto past_bin = std::find_if(first_candidate, ranges.end(),
			[&input_val](const T& start) { return start > input_val; });
		return static_cast<std::size_t>(std::distance(first_candidate, past_bin));
	}

	std::size_t get_num_bins(void) const
	{
		return ranges.size();
	}

	// Return the starting value of each bin, as supplied to the constructor.
	const std::vector<T> & get_bin_ranges(void) const
	{
		return ranges;
	}

protected:
	std::vector<T> ranges;
};

#endif