py::object CollateNestedImpl()

in rlmeta/cc/nested_utils.cc [81:109]


// Collates a batch of nested objects that share the same leaf layout.
//
// Every element of `src` is flattened with VisitNestedImpl. The k-th leaf of
// each element is stored at position i (the element's batch index) of the k-th
// tuple in `flattened`, so every tuple holds src.size() leaves. `func` is then
// applied to each of these tuples, and the results are placed back into the
// structure of src[0] via MapNestedImpl.
//
// Args:
//   func: called once per leaf position with the tuple of that leaf across
//         the whole batch; its return value becomes the collated leaf.
//   src:  the batch to collate.
//
// Returns: an object shaped like src[0] whose leaves are the outputs of `func`.
//
// Throws:
//   py::index_error if `src` is empty (there is no structure to rebuild).
//   py::value_error if the elements do not all have the same number of
//   leaves. Without this check, partially filled tuples (with NULL items)
//   would be passed to `func`.
//
// NOTE(review): assumes VisitNestedImpl and MapNestedImpl enumerate leaves in
// the same order — confirm against their definitions.
py::object CollateNestedImpl(std::function<py::object(const Sequence&)> func,
                             const Sequence& src) {
  const int64_t batch_size = src.size();
  if (batch_size == 0) {
    // src[0] is needed below to recover the output structure.
    throw py::index_error("CollateNested: cannot collate an empty sequence.");
  }

  std::vector<py::tuple> flattened;
  size_t index = 0;
  for (int64_t i = 0; i < batch_size; ++i) {
    index = 0;
    VisitNestedImpl(
        [batch_size, i, &flattened, &index](const py::object& obj) {
          if (i == 0) {
            // The first element defines the leaf layout: one tuple per leaf,
            // each sized to hold one entry per batch element.
            flattened.emplace_back(batch_size);
          }
          // Later elements only fill existing slots. Surplus leaves are just
          // counted so the mismatch can be reported after the traversal.
          if (index < flattened.size()) {
            py::tuple& cur = flattened[index];
            cur[i] = obj;
          }
          ++index;
        },
        py::reinterpret_borrow<py::object>(src[i]));
    if (index != flattened.size()) {
      // Too many or too few leaves: some tuple slots would stay NULL.
      throw py::value_error(
          "CollateNested: all elements must have the same number of leaves.");
    }
  }

  std::vector<py::object> collated;
  collated.reserve(flattened.size());
  for (const auto& leaves : flattened) {
    collated.push_back(func(leaves));
  }

  // Rebuild the structure of the first element, consuming collated leaves in
  // traversal order. .at() guards against a traversal-order disagreement.
  index = 0;
  return MapNestedImpl(
      [&collated, &index](const py::object& /* obj */) {
        return std::move(collated.at(index++));
      },
      src[0]);
}