py::tuple UnbatchNestedImpl()

in rlmeta/cc/nested_utils.cc [125:165]


// Recursive worker behind UnbatchNestedImpl (same contract, see below).
// It takes `func` by const reference so the std::function is not copied once
// per node of the nested structure; the public entry point keeps its original
// by-value signature and delegates here.
static py::tuple UnbatchNestedRecursive(
    const std::function<py::tuple(const py::object&)>& func,
    const py::object& obj, int64_t batch_size) {
  if (py::isinstance<py::tuple>(obj)) {
    const py::tuple src = py::reinterpret_borrow<py::tuple>(obj);
    const int64_t n = src.size();
    // Unbatch each element independently, then let UnbatchSequence recombine
    // the per-element results into per-row tuples.
    std::vector<py::tuple> children(n);
    for (int64_t i = 0; i < n; ++i) {
      children[i] = UnbatchNestedRecursive(func, src[i], batch_size);
    }
    return UnbatchSequence<py::tuple>(batch_size, children);
  }

  if (py::isinstance<py::list>(obj)) {
    const py::list src = py::reinterpret_borrow<py::list>(obj);
    const int64_t n = src.size();
    std::vector<py::tuple> children(n);
    for (int64_t i = 0; i < n; ++i) {
      children[i] = UnbatchNestedRecursive(func, src[i], batch_size);
    }
    return UnbatchSequence<py::list>(batch_size, children);
  }

  if (py::isinstance<py::dict>(obj)) {
    const py::dict src = py::reinterpret_borrow<py::dict>(obj);
    // Allocate the result tuple first so an invalid batch_size fails here,
    // exactly as before, before any other work is done.
    py::tuple dst(batch_size);
    // Keep our own references to the per-row dicts. dst holds the very same
    // dict objects, so filling `rows` fills dst without a tuple lookup and
    // reinterpret_borrow for every (key, row) pair.
    std::vector<py::dict> rows;
    rows.reserve(batch_size);
    for (int64_t i = 0; i < batch_size; ++i) {
      rows.emplace_back();
      dst[i] = rows.back();
    }
    for (const auto& [k, v] : src) {
      const py::tuple cur = UnbatchNestedRecursive(
          func, py::reinterpret_borrow<py::object>(v), batch_size);
      // cur[i] raises if the unbatched value has fewer than batch_size items.
      for (int64_t i = 0; i < batch_size; ++i) {
        rows[i][k] = cur[i];
      }
    }
    return dst;
  }

  // Leaf: anything that is not a tuple/list/dict is unbatched by the caller's
  // callback.
  return func(obj);
}

// Splits a batched nested Python structure into `batch_size` per-row
// structures.
//
// - tuple / list: each element is unbatched recursively and the per-element
//   results are recombined by UnbatchSequence<py::tuple> / <py::list>
//   (defined elsewhere in this file).
// - dict: returns a tuple of `batch_size` new plain dicts; the i-th dict maps
//   every key of `obj` to the i-th entry of that key's unbatched value.
// - anything else: returns func(obj).
//
// Args:
//   func: leaf callback; expected to return a tuple with (at least)
//         batch_size entries, one per row.
//   obj: the batched (possibly nested) Python object.
//   batch_size: number of rows to split into.
// Returns:
//   A tuple whose i-th entry is the i-th row (the dict branch builds it as
//   shown above; the tuple/list shape comes from UnbatchSequence).
//
// NOTE(review): py::isinstance also accepts subclasses (e.g. namedtuple,
// OrderedDict). The dict branch always emits plain dicts, so subclass types
// are not preserved there — confirm this is intended.
py::tuple UnbatchNestedImpl(std::function<py::tuple(const py::object&)> func,
                            const py::object& obj, int64_t batch_size) {
  return UnbatchNestedRecursive(func, obj, batch_size);
}