in nestedtensor/csrc/py_init.cpp [47:98]
// Applies a multi-entry Python index `key` (e.g. the tuple from `t[a, b, c]`)
// to `tensor`.
//
// - An empty key returns `tensor` unchanged.
// - If `tensor` is not a NestedTensor, the whole key is forwarded to torch's
//   regular `Tensor.__getitem__` (THPVariable_getitem); any Python error it
//   raises is rethrown as py::error_already_set.
// - Otherwise the first entry (None, int or slice) is applied through the
//   matching single-entry `get_item` overload. The remaining entries are then
//   applied recursively: directly if that result is a plain tensor, or to each
//   unbound constituent if it is still nested, after which the constituents
//   are reassembled into a NestedTensor.
//
// Throws (TORCH_CHECK) if the first entry is not None, an int or a slice.
at::Tensor get_item(Tensor tensor, std::vector<py::object> key) {
  if (key.empty()) {
    return tensor;
  }

  if (!is_nested_tensor_impl(tensor)) {
    // Delegate to torch's own indexing. Owning py::object handles release the
    // references even if something below throws.
    auto wrapped_key = py::tuple(py::cast(key));
    py::object wrapped_tensor =
        py::reinterpret_steal<py::object>(THPVariable_Wrap(tensor));
    if (!wrapped_tensor) {
      throw py::error_already_set();
    }
    py::object wrapped_result = py::reinterpret_steal<py::object>(
        torch::autograd::THPVariable_getitem(
            wrapped_tensor.ptr(), wrapped_key.ptr()));
    if (!wrapped_result) {
      // NULL follows the CPython convention of "exception is set" (e.g. an
      // out-of-range index); surface it instead of unpacking a NULL object.
      throw py::error_already_set();
    }
    // Copy the tensor out before wrapped_result drops its reference.
    at::Tensor result = THPVariable_Unpack(wrapped_result.ptr());
    return result;
  }

  // From here on `tensor` is a NestedTensor.
  const py::object& head = key[0];
  // Entries still to be applied after `head`.
  std::vector<py::object> rest(key.begin() + 1, key.end());

  c10::optional<at::Tensor> first;
  if (py::isinstance<py::none>(head)) {
    first = get_item(tensor, py::cast<py::none>(head));
  } else if (py::isinstance<py::int_>(head)) {
    first = get_item(tensor, py::cast<int64_t>(head));
  } else if (py::isinstance<py::slice>(head)) {
    first = get_item(tensor, py::cast<py::slice>(head));
  }
  TORCH_CHECK(
      first,
      "First entry of tuple doesn't have accepted type. ",
      py::str(head));

  if (!is_nested_tensor_impl(*first)) {
    return get_item(*first, rest);
  }

  // Apply the remaining entries to every constituent of the nested result.
  std::vector<at::Tensor> result;
  for (const at::Tensor& t : first->unbind()) {
    result.push_back(get_item(t, rest));
  }

  // Rebuild the nesting: with nested_dim == 1 each constituent is wrapped
  // directly in a TensorNode (presumably a leaf); otherwise the constituent's
  // own nested structure is reused.
  int64_t nested_dim = get_nested_tensor_impl(*first)->nested_dim();
  std::vector<TensorNode> result_nodes;
  result_nodes.reserve(result.size());
  if (nested_dim == 1) {
    for (at::Tensor& t : result) {
      result_nodes.push_back(TensorNode(std::move(t)));
    }
  } else {
    for (at::Tensor& t : result) {
      result_nodes.push_back(get_nested_tensor_structure(t));
    }
  }
  return wrap_tensor_node(TensorNode(std::move(result_nodes)));
}