std::tuple nll_loss_forward_decomposition()

in functorch/csrc/BatchRulesLoss.cpp [160:242]


std::tuple<Tensor, Tensor> nll_loss_forward_decomposition(
    const Tensor & self,
    const Tensor & target,
    const c10::optional<Tensor> & weight,
    int64_t reduction, int64_t ignore_index) {

  // Decomposes nll_loss_forward into primitives (gather/mul/sum) that already
  // have batching rules. Returns (result, total_weight) like the ATen op.
  //
  // Shapes: self is [N, C, ...] or [C]; target is [N, ...] or [].

  // Class dimension: dim 1 for batched input, dim 0 for the 1-D [C] case.
  const int64_t channel_dim = (self.dim() < 2) ? 0 : 1;

  auto input = self;
  Tensor reshaped_weight;

  const bool has_weight = weight && weight->defined();
  if (has_weight) {
    // Special case: mean reduction on non-batched tensors
    // (https://github.com/pytorch/pytorch/issues/61309).
    // There the weight cancels out: w * x[t] / w -> x[t], so skip applying it.
    const bool weight_cancels =
        (reduction == Reduction::Mean) && input.dim() < 2;
    if (!weight_cancels) {
      // Broadcast weight as [1, C, 1, ..., 1] against the input.
      VmapDimVector weight_shape(input.dim(), 1);
      weight_shape[channel_dim] = weight->sizes()[0];
      reshaped_weight = weight->reshape(weight_shape);
      input = input * reshaped_weight;
    }
  }

  // Unsqueeze target to [N, 1, ...] (or [1]) so it can index the channel dim.
  auto target_ = target.unsqueeze(channel_dim);

  auto result = -at::gather(input, channel_dim, target_).squeeze(channel_dim);
  auto total_weight = at::full(
      {}, result.numel(), input.scalar_type(),
      input.layout(), input.device(), nullopt);

  // NOTE: a negative ignore_index is treated as "no index ignored".
  const bool has_ignore_index = ignore_index >= 0;
  Tensor ignore_index_mask;
  if (has_ignore_index) {
    ignore_index_mask = target != ignore_index;
    // Zero out the ignored entries and count the surviving ones.
    result = result * ignore_index_mask;
    total_weight = ignore_index_mask.sum().to(input);
  }

  // Apply the reduction.
  if (result.dim() > 0) {
    switch (reduction) {
      case Reduction::Sum:
        result = result.sum();
        break;
      case Reduction::Mean:
        if (!has_weight) {
          if (has_ignore_index) {
            TORCH_INTERNAL_ASSERT(ignore_index_mask.defined());
            // total_weight already holds ignore_index_mask.sum().
            result = result.sum() / total_weight;
          } else {
            result = result.mean();
          }
        } else {
          TORCH_INTERNAL_ASSERT(reshaped_weight.defined());
          // Per-element weights: w[target] for each entry, minus ignored ones.
          auto wsum =
              at::gather(reshaped_weight.expand(input.sizes()), channel_dim,
                         target_)
                  .squeeze(channel_dim);
          if (has_ignore_index) {
            TORCH_INTERNAL_ASSERT(ignore_index_mask.defined());
            wsum = wsum * ignore_index_mask;
          }
          wsum = wsum.sum();
          result = result.sum() / wsum;
          total_weight = wsum;
        }
        break;
      default:
        break;
    }
  } else if (reduction == Reduction::Mean && has_weight) {
    // Scalar result with weights: weight is [C], target_ is [1]. The weight
    // cancelled out of `result` above, so only total_weight needs it.
    auto wsum = at::gather(*weight, channel_dim, target_).squeeze(channel_dim);
    if (has_ignore_index) {
      TORCH_INTERNAL_ASSERT(ignore_index_mask.defined());
      wsum = wsum * ignore_index_mask;
    }
    total_weight = wsum.sum();
  }

  return std::make_tuple(result, total_weight);
}