src/types/tdigest.h (77 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. * */ #pragma once #include <fmt/format.h> #include <vector> #include "common/status.h" struct Centroid { double mean; double weight = 1.0; // merge with another centroid void Merge(const Centroid& centroid) { weight += centroid.weight; mean += (centroid.mean - mean) * centroid.weight / weight; } std::string ToString() const { return fmt::format("centroid<mean: {}, weight: {}>", mean, weight); } explicit Centroid() = default; explicit Centroid(double mean, double weight) : mean(mean), weight(weight) {} }; struct CentroidsWithDelta { std::vector<Centroid> centroids; uint64_t delta; double min; double max; double total_weight; }; StatusOr<CentroidsWithDelta> TDigestMerge(const std::vector<CentroidsWithDelta>& centroids_list); StatusOr<CentroidsWithDelta> TDigestMerge(const std::vector<double>& buffer, const CentroidsWithDelta& centroid_list); /** TD should looks like below: class TDSample { public: struct Iterator { Iterator* Clone() const; bool Next(); bool Valid() const; StatusOr<Centroid> GetCentroid() const; }; Iterator* Begin(); Iterator* End(); double TotalWeight(); double Min() const; double Max() const; }; **/ // a numerically stable lerp is unbelievably complex // but we are *approximating* the quantile, so let's keep it simple // reference: // https://github.com/apache/arrow/blob/27bbd593625122a4a25d9471c8aaf5df54a6dcf9/cpp/src/arrow/util/tdigest.cc#L38 static inline double Lerp(double a, double b, double t) { return a + t * (b - a); } template <typename TD> inline StatusOr<double> TDigestQuantile(TD&& td, double q) { if (q < 0 || q > 1 || td.Size() == 0) { return Status{Status::InvalidArgument, "invalid quantile or empty tdigest"}; } const double index = q * td.TotalWeight(); if (index <= 1) { return td.Min(); } else if (index >= td.TotalWeight() - 1) { return td.Max(); } // find centroid contains the index double weight_sum = 0; auto iter = td.Begin(); for (; iter->Valid(); iter->Next()) { weight_sum += GET_OR_RET(iter->GetCentroid()).weight; if (index <= weight_sum) { break; } } // since index is in (1, total_weight - 1), iter should be valid if (!iter->Valid()) { return Status{Status::InvalidArgument, "invalid iterator during decoding tdigest centroid"}; } auto centroid = GET_OR_RET(iter->GetCentroid()); // deviation of index from the centroid center double diff = index + centroid.weight / 2 - weight_sum; // index happen to be in a unit weight centroid if (centroid.weight == 1 && std::abs(diff) < 0.5) { return centroid.mean; } // find adjacent centroids for interpolation auto ci_left = iter->Clone(); auto ci_right = iter->Clone(); if (diff > 0) { if (ci_right == td.End()) { // index larger than center of last bin auto c = GET_OR_RET(ci_left->GetCentroid()); CHECK(c.weight >= 2); return Lerp(c.mean, td.Max(), diff / (c.weight / 2)); } ci_right->Next(); } else { if (ci_left == td.Begin()) { // index smaller than center of first bin auto c = GET_OR_RET(ci_left->GetCentroid()); CHECK(c.weight >= 2); return Lerp(td.Min(), c.mean, index / (c.weight / 2)); } ci_left->Prev(); auto lc = GET_OR_RET(ci_left->GetCentroid()); auto rc = GET_OR_RET(ci_right->GetCentroid()); diff += lc.weight / 2 + rc.weight / 2; } auto lc = GET_OR_RET(ci_left->GetCentroid()); auto rc = GET_OR_RET(ci_right->GetCentroid()); // interpolate from adjacent centroids diff /= (lc.weight / 2 + rc.weight / 2); return Lerp(lc.mean, rc.mean, diff); }