at::Tensor get_item()

in nestedtensor/csrc/py_init.cpp [47:98]


// Indexes `tensor` with a sequence of Python index objects, one per
// dimension, outermost first (e.g. the elements of a tuple index such as
// `nt[0, 1:3, None]`).
//
// - An empty `key` returns `tensor` unchanged.
// - A regular (non-nested) tensor is indexed by handing the whole key, as a
//   tuple, to torch's Python `__getitem__` implementation
//   (THPVariable_getitem).
// - For a nested tensor, `key[0]` is applied through the single-index
//   `get_item` overloads (None / int / slice). The remaining entries are then
//   applied either to that result directly (if it is no longer nested) or to
//   each of its constituents, whose results are re-wrapped into a nested
//   tensor.
//
// Throws c10::Error (via TORCH_CHECK) if `key[0]` is not None, an int or a
// slice when indexing a nested tensor. Any Python exception raised while
// indexing a regular tensor (e.g. IndexError) is propagated as
// py::error_already_set instead of dereferencing a null result.
// Must be called with the GIL held, since it creates and releases Python
// objects.
at::Tensor get_item(Tensor tensor, std::vector<py::object> key) {
  if (key.empty()) {
    return tensor;
  }
  if (!is_nested_tensor_impl(tensor)) {
    // Owning handles: references are released on every exit path,
    // including when an exception is thrown below.
    auto wrapped_key = py::tuple(py::cast(key));
    py::object wrapped_tensor =
        py::reinterpret_steal<py::object>(THPVariable_Wrap(tensor));
    if (!wrapped_tensor) {
      throw py::error_already_set();
    }
    // THPVariable_getitem follows the CPython convention of returning
    // nullptr with an error set on failure; unpacking that would crash.
    py::object wrapped_result = py::reinterpret_steal<py::object>(
        torch::autograd::THPVariable_getitem(
            wrapped_tensor.ptr(), wrapped_key.ptr()));
    if (!wrapped_result) {
      throw py::error_already_set();
    }
    // Copy the tensor out while `wrapped_result` still holds its reference.
    at::Tensor result = THPVariable_Unpack(wrapped_result.ptr());
    return result;
  }

  const py::object& head = key[0];
  std::vector<py::object> rest(key.begin() + 1, key.end());

  // Apply the leading index through the single-index overloads.
  // NOTE(review): Python's bool is a subclass of int, so True/False take the
  // int branch and act as indices 1/0 here -- unlike torch.Tensor indexing.
  // Confirm whether that is intended.
  c10::optional<at::Tensor> first;
  if (py::isinstance<py::none>(head)) {
    first = get_item(tensor, py::cast<py::none>(head));
  } else if (py::isinstance<py::int_>(head)) {
    first = get_item(tensor, py::cast<int64_t>(head));
  } else if (py::isinstance<py::slice>(head)) {
    first = get_item(tensor, py::cast<py::slice>(head));
  }
  TORCH_CHECK(
      first.has_value(),
      "First entry of tuple doesn't have accepted type. ",
      py::str(head));

  if (!is_nested_tensor_impl(*first)) {
    return get_item(*first, rest);
  }

  // Apply the remaining indices to every constituent and rebuild the nesting.
  // Whether a constituent's result is still nested depends on `rest` (an int
  // index can collapse a nested level), so decide per result rather than from
  // the outer nested_dim.
  std::vector<at::Tensor> constituents = first->unbind();
  std::vector<TensorNode> result_nodes;
  result_nodes.reserve(constituents.size());
  for (const at::Tensor& constituent : constituents) {
    at::Tensor indexed = get_item(constituent, rest);
    if (is_nested_tensor_impl(indexed)) {
      result_nodes.push_back(get_nested_tensor_structure(indexed));
    } else {
      result_nodes.push_back(TensorNode(std::move(indexed)));
    }
  }
  return wrap_tensor_node(TensorNode(std::move(result_nodes)));
}