in rlmeta/cc/nested_utils.cc [125:165]
py::tuple UnbatchNestedImpl(std::function<py::tuple(const py::object&)> func,
const py::object& obj, int64_t batch_size) {
if (py::isinstance<py::tuple>(obj)) {
const py::tuple src = py::reinterpret_borrow<py::tuple>(obj);
const int64_t n = src.size();
std::vector<py::tuple> children(n);
for (int64_t i = 0; i < n; ++i) {
children[i] = UnbatchNestedImpl(func, src[i], batch_size);
}
return UnbatchSequence<py::tuple>(batch_size, children);
}
if (py::isinstance<py::list>(obj)) {
const py::list src = py::reinterpret_borrow<py::list>(obj);
const int64_t n = src.size();
std::vector<py::tuple> children(n);
for (int64_t i = 0; i < n; ++i) {
children[i] = UnbatchNestedImpl(func, src[i], batch_size);
}
return UnbatchSequence<py::list>(batch_size, children);
}
if (py::isinstance<py::dict>(obj)) {
const py::dict src = py::reinterpret_borrow<py::dict>(obj);
py::tuple dst(batch_size);
for (int64_t i = 0; i < batch_size; ++i) {
dst[i] = py::dict();
}
for (const auto [k, v] : src) {
py::tuple cur = UnbatchNestedImpl(
func, py::reinterpret_borrow<py::object>(v), batch_size);
for (int64_t i = 0; i < batch_size; ++i) {
py::dict y = py::reinterpret_borrow<py::dict>(dst[i]);
y[k] = cur[i];
}
}
return dst;
}
return func(obj);
}