void DefineMisraGriesSketchClass()

in tfx_bsl/cc/sketches/sketches_submodule.cc [110:224]


// Registers the Python class `MisraGriesSketch` on `sketch_module`.
//
// Exposed Python API:
//   * __init__(num_buckets, invalid_utf8_placeholder=None,
//              large_string_threshold=None, large_string_placeholder=None)
//   * AddValues(items) / AddValues(items, weights)
//   * Merge(other)
//   * Estimate()
//   * Serialize() / static Deserialize(byte_string)
//   * pickle support (__getstate__ / __setstate__) built on
//     Serialize / Deserialize.
//
// Every non-OK absl::Status returned by the underlying sketch is rethrown as
// std::runtime_error carrying `Status::ToString()` as its message.
void DefineMisraGriesSketchClass(py::module sketch_module) {
  py::class_<MisraGriesSketch>(sketch_module, "MisraGriesSketch")
      // Constructor. `large_string_threshold` and `large_string_placeholder`
      // only make sense together, so supplying exactly one of them is
      // rejected before the sketch is built.
      .def(py::init([](int num_buckets,
                       absl::optional<std::string> invalid_utf8_placeholder,
                       absl::optional<int> large_string_threshold,
                       absl::optional<std::string> large_string_placeholder) {
             if (large_string_threshold.has_value() !=
                 large_string_placeholder.has_value()) {
               throw std::runtime_error(
                   "Must provide both or neither large_string_threshold and "
                   "large_string_placeholder.");
             }
             return absl::make_unique<MisraGriesSketch>(
                 num_buckets,
                 std::move(invalid_utf8_placeholder),
                 std::move(large_string_threshold),
                 std::move(large_string_placeholder));
           }),
           py::arg("num_buckets"),
           py::arg("invalid_utf8_placeholder") = absl::nullopt,
           py::arg("large_string_threshold") = absl::nullopt,
           py::arg("large_string_placeholder") = absl::nullopt)
      // Unweighted AddValues overload. The GIL is released for the duration
      // of the call via the call guard.
      .def(
          "AddValues",
          [](MisraGriesSketch& sketch,
             const std::shared_ptr<arrow::Array>& items) {
            absl::Status s = sketch.AddValues(*items);
            if (!s.ok()) {
              throw std::runtime_error(s.ToString());
            }
          },
          py::doc("Adds an array of items."),
          py::call_guard<py::gil_scoped_release>())
      // Weighted AddValues overload. The GIL is released for the duration of
      // the call via the call guard.
      .def(
          "AddValues",
          [](MisraGriesSketch& sketch,
             const std::shared_ptr<arrow::Array>& items,
             const std::shared_ptr<arrow::Array>& weights) {
            absl::Status s = sketch.AddValues(*items, *weights);
            if (!s.ok()) {
              throw std::runtime_error(s.ToString());
            }
          },
          py::doc("Adds an array of items with their associated weights. "
                  "Raises an error if the weights are not a FloatArray."),
          py::call_guard<py::gil_scoped_release>())
      // Merges `other` into `sketch`; the GIL is released during the call.
      .def(
          "Merge",
          [](MisraGriesSketch& sketch, MisraGriesSketch& other) {
            absl::Status s = sketch.Merge(other);
            if (!s.ok()) {
              throw std::runtime_error(s.ToString());
            }
          },
          py::doc("Merges another MisraGriesSketch into this sketch. Raises "
                  "an error if the sketches do not have the same number of "
                  "buckets."),
          py::call_guard<py::gil_scoped_release>())
      // Returns the array produced by MisraGriesSketch::Estimate. No GIL
      // release is requested for this binding.
      .def(
          "Estimate",
          [](MisraGriesSketch& sketch) {
            std::shared_ptr<arrow::Array> result;
            absl::Status s = sketch.Estimate(&result);
            if (!s.ok()) {
              throw std::runtime_error(s.ToString());
            }
            return result;
          },
          py::doc(
              "Creates a struct array <values, counts> of the top-k items."))
      // Returns the serialized sketch as Python `bytes`. The GIL is released
      // only around the native Serialize() call and is held again before the
      // py::bytes object is created.
      .def(
          "Serialize",
          [](MisraGriesSketch& sketch) {
            std::string serialized;
            {
              // Release the GIL during the call to Serialize
              py::gil_scoped_release release_gil;
              serialized = sketch.Serialize();
            }
            return py::bytes(serialized);
          },
          py::doc("Serializes the sketch into a string."))
      // Static factory: rebuilds a sketch from its serialized form. The GIL is
      // released during the call via the call guard.
      .def_static(
          "Deserialize",
          [](absl::string_view byte_string) {
            std::unique_ptr<MisraGriesSketch> result;
            absl::Status s =
                MisraGriesSketch::Deserialize(byte_string, &result);
            if (!s.ok()) throw std::runtime_error(s.ToString());
            return result;
          },
          py::doc("Deserializes the string to a MisraGries object."),
          py::call_guard<py::gil_scoped_release>())
      // Pickle support
      .def(py::pickle(
          [](MisraGriesSketch& sketch) {  // __getstate__
            std::string serialized;
            {
              // Release the GIL during the call to Serialize
              py::gil_scoped_release release_gil;
              serialized = sketch.Serialize();
            }
            return py::bytes(serialized);
          },
          [](py::bytes byte_string) {  // __setstate__
            char* data = nullptr;
            Py_ssize_t size = 0;
            // PyBytes_AsStringAndSize reports failure through its return
            // value and a pending Python exception. Surface that exception
            // instead of building a string_view from unset pointer/size.
            if (PyBytes_AsStringAndSize(byte_string.ptr(), &data, &size) != 0) {
              throw py::error_already_set();
            }
            std::unique_ptr<MisraGriesSketch> result;
            absl::Status s = MisraGriesSketch::Deserialize(
                absl::string_view(
                    data, static_cast<absl::string_view::size_type>(size)),
                &result);
            if (!s.ok()) throw std::runtime_error(s.ToString());
            return result;
          }));
}