void register_tensors()

in src/mlio-py/mlio/core/tensor.cc [88:164]


/// Exposes the tensor types to Python.
///
/// Registers, on module `m`, the `DataType` enumeration followed by the
/// `Tensor` base class and its `DenseTensor` and `CooTensor` subclasses.
/// `Tensor` is registered before its subclasses because pybind11 needs a
/// base type to be known before any class that lists it as a base.
void register_tensors(py::module &m)
{
    // ---- DataType ---------------------------------------------------------
    // Python-visible label paired with the native enumerator, in
    // registration order.
    struct Data_type_entry {
        const char *label;
        Data_type value;
    };

    const Data_type_entry data_type_entries[] = {
        {"SIZE", Data_type::size},
        {"FLOAT16", Data_type::float16},
        {"FLOAT32", Data_type::float32},
        {"FLOAT64", Data_type::float64},
        {"INT8", Data_type::int8},
        {"INT16", Data_type::int16},
        {"INT32", Data_type::int32},
        {"INT64", Data_type::int64},
        {"UINT8", Data_type::uint8},
        {"UINT16", Data_type::uint16},
        {"UINT32", Data_type::uint32},
        {"UINT64", Data_type::uint64},
        {"STRING", Data_type::string},
    };

    py::enum_<Data_type> data_type_enum{m, "DataType"};
    for (const Data_type_entry &entry : data_type_entries) {
        data_type_enum.value(entry.label, entry.value);
    }

    // ---- Tensor (abstract base) ------------------------------------------
    // The raw-string contents (including leading newline and indentation)
    // become the Python docstring verbatim.
    const char *const tensor_doc = R"(
        Represents a multi-dimensional array.

        This is an abstract class that only defines the data type and shape
        of a Tensor. Derived types specify how the Tensor data is laid out
        in memory.
        )";

    py::class_<Tensor, Intrusive_ptr<Tensor>> tensor_cls{m, "Tensor", tensor_doc};

    // Shape and strides are handed to Python as tuples rather than lists.
    auto tensor_shape = [](Tensor &self) -> py::tuple {
        return py::cast(self.shape());
    };
    auto tensor_strides = [](Tensor &self) -> py::tuple {
        return py::cast(self.strides());
    };

    tensor_cls.def("__repr__", &Tensor::repr);
    tensor_cls.def_property_readonly(
        "data_type", &Tensor::data_type, "Gets the data type of the Tensor.");
    tensor_cls.def_property_readonly("shape", tensor_shape, "Gets the shape of the Tensor.");
    tensor_cls.def_property_readonly(
        "strides", tensor_strides, "Gets the strides of the Tensor.");

    // ---- DenseTensor ------------------------------------------------------
    py::class_<Dense_tensor, Tensor, Intrusive_ptr<Dense_tensor>> dense_cls{
        m,
        "DenseTensor",
        py::buffer_protocol(),
        "Represents a Tensor that stores its data in a contiguous memory "
        "block."};

    dense_cls.def(py::init<>(&make_dense_tensor),
                  "shape"_a,
                  "data"_a,
                  "strides"_a = std::nullopt,
                  "copy"_a = true);

    // The returned array is paired with an intrusive reference to `self`;
    // presumably this keeps the tensor alive while Python holds its data
    // (NOTE(review): confirm against Py_device_array).
    auto dense_data = [](Dense_tensor &self) {
        return Py_device_array{wrap_intrusive(&self), self.data()};
    };
    dense_cls.def_property_readonly("data", dense_data, "Gets the data of the Tensor.");
    dense_cls.def_buffer(&to_py_buffer);

    // ---- CooTensor --------------------------------------------------------
    py::class_<Coo_tensor, Tensor, Intrusive_ptr<Coo_tensor>> coo_cls{
        m, "CooTensor", "Represents a Tensor that stores its data in coordinate format."};

    coo_cls.def(py::init<>(&make_coo_tensor), "shape"_a, "data"_a, "coords"_a, "copy"_a = true);

    auto coo_data = [](Coo_tensor &self) {
        return Py_device_array{wrap_intrusive(&self), self.data()};
    };
    coo_cls.def_property_readonly("data", coo_data, "Gets the data of the Tensor.");

    // Exposed as a method rather than a property because it takes the
    // dimension to look up.
    auto coo_indices = [](Coo_tensor &self, std::size_t dim) {
        return Py_device_array{wrap_intrusive(&self), self.indices(dim)};
    };
    coo_cls.def("indices", coo_indices, "dim"_a, "Gets the indices for the specified dimension.");
}