in tfx_bsl/cc/sketches/sketches_submodule.cc [110:224]
void DefineMisraGriesSketchClass(py::module sketch_module) {
py::class_<MisraGriesSketch>(sketch_module, "MisraGriesSketch")
.def(py::init([](const int& num_buckets,
absl::optional<std::string> invalid_utf8_placeholder,
absl::optional<int> large_string_thresehold,
absl::optional<std::string> large_string_placeholder) {
if (large_string_thresehold.has_value() !=
large_string_placeholder.has_value()) {
throw std::runtime_error(
"Must provide both or neither large_string_threshold and "
"large_string_placeholder.");
}
return absl::make_unique<MisraGriesSketch>(
num_buckets,
std::move(invalid_utf8_placeholder),
std::move(large_string_thresehold),
std::move(large_string_placeholder));
}),
py::arg("num_buckets"),
py::arg("invalid_utf8_placeholder") = absl::nullopt,
py::arg("large_string_threshold") = absl::nullopt,
py::arg("large_string_placeholder") = absl::nullopt)
.def(
"AddValues",
[](MisraGriesSketch& sketch,
const std::shared_ptr<arrow::Array>& items) {
absl::Status s = sketch.AddValues(*items);
if (!s.ok()) {
throw std::runtime_error(s.ToString());
}
},
py::doc("Adds an array of items."),
py::call_guard<py::gil_scoped_release>())
.def(
"AddValues",
[](MisraGriesSketch& sketch,
const std::shared_ptr<arrow::Array>& items,
const std::shared_ptr<arrow::Array>& weights) {
absl::Status s = sketch.AddValues(*items, *weights);
if (!s.ok()) {
throw std::runtime_error(s.ToString());
}
},
py::doc("Adds an array of items with their associated weights. "
"Raises an error if the weights are not a FloatArray."),
py::call_guard<py::gil_scoped_release>())
.def(
"Merge",
[](MisraGriesSketch& sketch, MisraGriesSketch& other) {
absl::Status s = sketch.Merge(other);
if (!s.ok()) {
throw std::runtime_error(s.ToString());
}
},
py::doc("Merges another MisraGriesSketch into this sketch. Raises "
"an error if the sketches do not have the same number of "
"buckets."),
py::call_guard<py::gil_scoped_release>())
.def(
"Estimate",
[](MisraGriesSketch& sketch) {
std::shared_ptr<arrow::Array> result;
absl::Status s = sketch.Estimate(&result);
if (!s.ok()) {
throw std::runtime_error(s.ToString());
}
return result;
},
py::doc(
"Creates a struct array <values, counts> of the top-k items."))
.def(
"Serialize",
[](MisraGriesSketch& sketch) {
std::string serialized;
{
// Release the GIL during the call to Serialize
py::gil_scoped_release release_gil;
serialized = sketch.Serialize();
}
return py::bytes(serialized);
},
py::doc("Serializes the sketch into a string."))
.def_static(
"Deserialize",
[](absl::string_view byte_string) {
std::unique_ptr<MisraGriesSketch> result;
absl::Status s =
MisraGriesSketch::Deserialize(byte_string, &result);
if (!s.ok()) throw std::runtime_error(s.ToString());
return result;
},
py::doc("Deserializes the string to a MisraGries object."),
py::call_guard<py::gil_scoped_release>())
// Pickle support
.def(py::pickle(
[](MisraGriesSketch& sketch) { // __getstate__
std::string serialized;
{
// Release the GIL during the call to Serialize
py::gil_scoped_release release_gil;
serialized = sketch.Serialize();
}
return py::bytes(serialized);
},
[](py::bytes byte_string) { // __setstate__
char* data;
Py_ssize_t size;
PyBytes_AsStringAndSize(byte_string.ptr(), &data, &size);
std::unique_ptr<MisraGriesSketch> result;
absl::Status s = MisraGriesSketch::Deserialize(
absl::string_view(data, size), &result);
if (!s.ok()) throw std::runtime_error(s.ToString());
return result;
}));
}