c10::optional<Tensor> nt_from_tensor_mask()

in nestedtensor/csrc/masking.cpp [180:248]


// Selects entries of `tensor` according to `mask`, producing either a dense
// Tensor (nested_dim == 0) or a NestedTensor with one nesting level
// (nested_dim == 1).
//
// Parameters:
//   tensor     - input; its leading dimension(s) are indexed in lockstep with
//                `mask`.
//   mask       - selection mask; individual values are read with
//                item<bool>(). A mask with no elements, or a single true
//                value, selects everything.
//   nested_dim - 0 or 1; any other value fails the TORCH_CHECK below.
//
// Returns (nested_dim == 0):
//   - `tensor` unchanged if the mask selects everything;
//   - c10::nullopt if `mask` is a 0-dim false value;
//   - for a 1-dim mask, at::stack of the selected tensor[i] (an empty tensor
//     converted to `tensor`'s dtype/device if none are selected);
//   - for a higher-dim mask, the per-row results stacked. Empty rows are
//     dropped unless every row is empty, in which case all rows are kept.
// Returns (nested_dim == 1):
//   a NestedTensor built from the per-row results; rows whose result is
//   c10::nullopt are omitted. A single false mask value yields a NestedTensor
//   with no constituents.
c10::optional<Tensor> nt_from_tensor_mask(
    Tensor tensor,
    Tensor mask,
    int64_t nested_dim) {
  // Cached: the element count is consulted by several branches below.
  const auto mask_numel = get_numel(mask);

  if (nested_dim == 0) {
    if (mask_numel == 0 || (mask_numel == 1 && mask.item<bool>())) {
      return tensor;
    }
    const auto mask_dim = get_dim(mask);

    if (mask_dim == 1) {
      std::vector<Tensor> tensors;
      for (int64_t i = 0; i < mask.size(0); i++) {
        if (mask[i].item<bool>()) {
          tensors.push_back(tensor[i]);
        }
      }
      if (tensors.empty()) {
        return torch::tensor({}).to(tensor);
      }
      return at::stack(tensors);
    }

    if (mask_dim > 1) {
      // Recurse into each row exactly once and remember every result, so the
      // "all rows empty" fallback below does not have to recompute them.
      std::vector<Tensor> all_rows;
      std::vector<Tensor> non_empty_rows;
      all_rows.reserve(mask.size(0));
      for (int64_t i = 0; i < mask.size(0); i++) {
        // mask[i] has at least one dimension, so with nested_dim == 0 this
        // call never yields c10::nullopt and the dereference is safe.
        Tensor row = *nt_from_tensor_mask(tensor[i], mask[i], nested_dim);
        if (get_numel(row) > 0) {
          non_empty_rows.push_back(row);
        }
        all_rows.push_back(std::move(row));
      }
      // Drop empty rows, unless every row is empty: then keep them all so the
      // stacked result still has one (empty) entry per row.
      const std::vector<Tensor>& tensors =
          non_empty_rows.empty() ? all_rows : non_empty_rows;
      if (tensors.empty()) {
        // Defensive: a mask with zero rows was already handled above.
        return torch::tensor({}).to(tensor);
      }
      return at::stack(tensors);
    }
    // Only a 0-dim mask holding a false value reaches this point.
    return c10::nullopt;
  }

  TORCH_CHECK(nested_dim == 1, "Only nested_dim of 1 is currently supported.");
  std::vector<TensorNode> inner_tensor_nodes;
  if (mask_numel == 0 || (mask_numel == 1 && mask.item<bool>())) {
    // Keep every row. With such a mask the nested_dim == 0 path returns its
    // input unchanged, so use tensor[i] directly instead of recursing (which
    // would re-read the mask once per row).
    inner_tensor_nodes.reserve(tensor.size(0));
    for (int64_t i = 0; i < tensor.size(0); i++) {
      inner_tensor_nodes.push_back(get_nested_tensor_structure(tensor[i]));
    }
  } else if (mask_numel != 1) {
    // Per-row mask. Rows whose mask is a 0-dim false value come back as
    // c10::nullopt and are omitted from the result.
    for (int64_t i = 0; i < tensor.size(0); i++) {
      c10::optional<Tensor> row =
          nt_from_tensor_mask(tensor[i], mask[i], nested_dim - 1);
      if (row) {
        inner_tensor_nodes.push_back(get_nested_tensor_structure(*row));
      }
    }
  }
  // A single false mask value falls through both branches and yields a
  // NestedTensor with no constituents.
  return wrap_tensor_node(TensorNode(std::move(inner_tensor_nodes)));
}