Tensor NestedTensor_var_dim()

in nestedtensor/csrc/ReduceOps.cpp [253:317]


// Variance of a nested tensor over the requested dimensions.
//
// `dims` is split by make_split_dims() into `tensordims` and `nesteddims`,
// and one of three strategies is chosen:
//   * no tensor dims  -> the input is converted with NestedTensor_to_tensor()
//     and at::var() is taken over dimension 0;
//   * no nested dims  -> at::var() over `tensordims` is applied through
//     map_nested_tensor();
//   * both            -> the statistics returned by _make_m2() for the
//     flattened structure are combined by _merge_m2() and divided by the
//     element count.
//
// A nested reduction is only accepted when it targets nested dimension 0 of
// an input whose nested_dim() is 1 and every construct_size() entry from
// index 1 onwards is set; otherwise TORCH_CHECK fails.
//
// `unbiased` selects division by (count - 1) instead of count in the combined
// strategy and is forwarded to at::var() in the other two. `keepdims` is
// forwarded to at::var() and to squeeze().
Tensor NestedTensor_var_dim(
    const Tensor& self,
    IntArrayRef dims,
    bool unbiased,
    bool keepdims) {
  std::vector<int64_t> tensordims;
  std::vector<int64_t> nesteddims;
  std::tie(tensordims, nesteddims) = make_split_dims(self, dims);
  const bool reduces_tensor_dims = !tensordims.empty();
  const bool reduces_nested_dims = !nesteddims.empty();

  auto input_sizes = get_nested_size(self);
  int64_t nested_dim = get_nested_tensor_impl(self)->nested_dim();

  // Size structure of the result: every size vector loses the positions that
  // are listed in tensordims.
  // NOTE(review): the lambda is kept inline with a by-value parameter, as the
  // call site had it; map() may depend on the exact signature — confirm before
  // changing it.
  auto result_sizes = map(
      [&tensordims](std::vector<int64_t> sizes) {
        std::vector<int64_t> kept;
        for (size_t pos = 0; pos < sizes.size(); ++pos) {
          const bool is_reduced =
              std::find(tensordims.begin(), tensordims.end(), pos) !=
              tensordims.end();
          if (is_reduced) {
            continue;
          }
          kept.push_back(sizes[pos]);
        }
        return kept;
      },
      input_sizes);

  if (reduces_nested_dims) {
    TORCH_CHECK(
        nesteddims.size() == 1 && nesteddims[0] == 0,
        "Can only reduce across nested dimension 0.");
    TORCH_CHECK(
        nested_dim == 1,
        "Can only reduce across nested dimensions if given nested tensor is of nested dimension 1.");
    // Entry 0 is skipped; all later entries must hold a value.
    auto opt_sizes = construct_size(result_sizes);
    for (size_t i = 1; i < opt_sizes.size(); i++) {
      TORCH_CHECK(
          opt_sizes[i],
          "Can only reduce across nested dimensions of Tensor compliant shapes.");
    }
    result_sizes = squeeze(result_sizes, 0, keepdims);
  }

  // Nothing to reduce inside the constituents: reduce dimension 0 of the
  // converted tensor.
  if (!reduces_tensor_dims) {
    at::Tensor dense = NestedTensor_to_tensor(self, c10::nullopt);
    return wrap_buffer(
        at::var(dense, 0, unbiased, keepdims).reshape({-1}), result_sizes);
  }

  // Nothing to reduce across the nesting: delegate to at::var() via
  // map_nested_tensor().
  if (!reduces_nested_dims) {
    return map_nested_tensor(
        [tensordims, unbiased, keepdims](at::Tensor t) {
          return at::var(t, tensordims, unbiased, keepdims);
        },
        self);
  }

  // Both kinds of dimension are reduced: build the statistics and merge them.
  // NOTE(review): on this path keepdims only reaches squeeze() above, while
  // the tensor dims are always removed from result_sizes — confirm that this
  // is intended.
  at::Tensor m2_acc;
  at::Tensor mean_acc;
  at::Tensor n_elems;
  std::vector<at::Tensor> leaves = flatten(get_nested_tensor_structure(self));
  std::tie(m2_acc, mean_acc, n_elems) =
      _make_m2(leaves, IntArrayRef(tensordims));
  std::tie(m2_acc, mean_acc, n_elems) = _merge_m2(m2_acc, mean_acc, n_elems);

  const at::Tensor divisor = unbiased ? (n_elems - 1) : n_elems;
  return wrap_buffer((m2_acc / divisor).reshape({-1}), result_sizes);
}