in nestedtensor/csrc/masking.cpp [180:248]
c10::optional<Tensor> nt_from_tensor_mask(
Tensor tensor,
Tensor mask,
int64_t nested_dim) {
if (nested_dim == 0) {
if ((get_numel(mask) == 0) || (get_numel(mask) == 1 && mask.item<bool>())) {
return tensor;
}
if (get_dim(mask) == 1) {
std::vector<Tensor> tensors;
for (int64_t i = 0; i < mask.size(0); i++) {
if (mask[i].item<bool>()) {
tensors.push_back(tensor[i]);
}
}
if (tensors.size() == 0) {
return torch::tensor({}).to(tensor);
}
return at::stack(tensors);
}
if (get_dim(mask) > 1) {
std::vector<Tensor> tensors;
bool all_zero = true;
for (int64_t i = 0; i < mask.size(0); i++) {
Tensor tmp = *nt_from_tensor_mask(tensor[i], mask[i], nested_dim);
if (get_numel(tmp) > 0) {
all_zero = false;
tensors.push_back(tmp);
}
}
if (all_zero) {
for (int64_t i = 0; i < mask.size(0); i++) {
Tensor tmp = *nt_from_tensor_mask(tensor[i], mask[i], nested_dim);
tensors.push_back(tmp);
}
}
if (tensors.size() == 0) {
return torch::tensor({}).to(tensor);
}
return at::stack(tensors);
}
return c10::nullopt;
}
TORCH_CHECK(nested_dim == 1, "Only nested_dim of 1 is currently supported.");
std::vector<c10::optional<Tensor>> inner_tensors;
if ((get_numel(mask) == 0) || (get_numel(mask) == 1 && mask.item<bool>())) {
for (int64_t i = 0; i < tensor.size(0); i++) {
inner_tensors.push_back(
nt_from_tensor_mask(tensor[i], mask, nested_dim - 1));
}
} else if (get_numel(mask) == 1 && !mask.item<bool>()) {
inner_tensors.push_back(c10::nullopt);
} else {
for (int64_t i = 0; i < tensor.size(0); i++) {
inner_tensors.push_back(
nt_from_tensor_mask(tensor[i], mask[i], nested_dim - 1));
}
}
std::vector<TensorNode> inner_tensor_nodes;
for (size_t i = 0; i < inner_tensors.size(); i++) {
if (inner_tensors[i]) {
TensorNode node = get_nested_tensor_structure(*inner_tensors[i]);
inner_tensor_nodes.push_back(node);
}
}
return wrap_tensor_node(TensorNode(std::move(inner_tensor_nodes)));
}