in rlmeta/cc/nested_utils.cc [81:109]
// Collates a batch of nested objects leaf-by-leaf.
//
// Every element src[i] is traversed with VisitNestedImpl. The k-th leaf
// visited in src[i] is stored at slot i of the k-th tuple; each tuple has
// length src.size(). `func` is applied once per leaf position to that tuple.
// The results are then substituted, in visitation order, for the leaves of
// src[0] via MapNestedImpl.
//
// NOTE(review): this assumes VisitNestedImpl and MapNestedImpl visit leaves
// in the same order. That contract lives with those helpers.
//
// Args:
//   func: called with a tuple holding the same leaf position from every
//     batch element; its return value becomes the collated leaf.
//   src: the batch to collate. It must be non-empty, and every element must
//     have the same number of leaves as src[0].
//
// Returns:
//   The result of MapNestedImpl over src[0], with its leaves replaced by the
//   outputs of `func`.
//
// Throws:
//   py::index_error (IndexError) if `src` is empty.
//   py::value_error (ValueError) if an element's leaf count differs from
//     src[0]'s.
py::object CollateNestedImpl(std::function<py::object(const Sequence&)> func,
                             const Sequence& src) {
  const int64_t batch_size = src.size();
  if (batch_size == 0) {
    // Report this explicitly rather than failing later inside src[0].
    throw py::index_error("CollateNested: cannot collate an empty sequence");
  }

  // flattened[k][i] holds the k-th leaf of src[i]. The tuples are created
  // while visiting element 0 and filled slot by slot for later elements.
  std::vector<py::tuple> flattened;
  size_t index = 0;
  for (int64_t i = 0; i < batch_size; ++i) {
    index = 0;
    VisitNestedImpl(
        [batch_size, i, &flattened, &index](const py::object& obj) {
          if (i == 0) {
            flattened.emplace_back(batch_size);
          }
          // Extra leaves in a later element are counted but not stored, so
          // no tuple is created with unfilled slots. The mismatch is
          // reported after the visit.
          if (index < flattened.size()) {
            flattened[index][i] = obj;
          }
          ++index;
        },
        py::reinterpret_borrow<py::object>(src[i]));
    // Tuples from PyTuple_New start with NULL items. Any slot left unfilled
    // here would reach `func` (and Python) as a null element.
    if (index != flattened.size()) {
      throw py::value_error("CollateNested: element " + std::to_string(i) +
                            " has " + std::to_string(index) +
                            " leaves, but element 0 has " +
                            std::to_string(flattened.size()));
    }
  }

  std::vector<py::object> collated;
  collated.reserve(flattened.size());
  for (const auto& leaves : flattened) {
    collated.push_back(func(leaves));
  }

  // Hand out the collated values in visitation order. .at() turns any
  // traversal disagreement between the helpers into an exception instead of
  // an out-of-bounds read.
  size_t next = 0;
  return MapNestedImpl(
      [&collated, &next](const py::object& /* obj */) {
        return std::move(collated.at(next++));
      },
      src[0]);
}