in functorch/csrc/BatchRulesLoss.cpp [160:242]
// Computes the NLL loss forward pass out of plain tensor operations
// (mul / gather / where / sum) instead of calling the fused kernel.
//
// Arguments:
//   self         - input scores, either [N, C, ...] or [C]
//   target       - class indices, either [N, ...] or []
//   weight       - optional per-class rescaling weights of shape [C]
//   reduction    - one of the Reduction:: values (None / Mean / Sum)
//   ignore_index - target value whose entries contribute neither to the loss
//                  nor to the normalisation weight
//
// Returns (loss, total_weight). total_weight is the number of non-ignored
// targets, or the sum of their class weights for the weighted Mean reduction.
std::tuple<Tensor, Tensor> nll_loss_forward_decomposition(
    const Tensor & self,
    const Tensor & target,
    const c10::optional<Tensor> & weight,
    int64_t reduction, int64_t ignore_index) {
  // self can be [N, C, ...] or [C]
  // target can be [N, ...] or []
  const int64_t channel_dim = self.dim() < 2 ? 0 : 1;
  const bool has_weight = weight && weight->defined();

  auto self_ = self;
  Tensor weight_;
  if (has_weight) {
    // Here is a specific case with reduction mean and non-batched tensors
    // https://github.com/pytorch/pytorch/issues/61309
    // In this case weight is cancelled: w * x[t] / w -> x[t]
    if (!(reduction == Reduction::Mean && self_.dim() < 2)) {
      // reshape weights to [1, C, 1, ..., 1]
      VmapDimVector new_shape(self_.dim(), 1);
      new_shape[channel_dim] = weight->sizes()[0];
      weight_ = weight->reshape(new_shape);
      self_ = self_ * weight_;
    }
  }

  // True where the target takes part in the loss. The mask is always built:
  // ignore_index may be negative (the usual default is -100) and targets are
  // allowed to be equal to it.
  const Tensor ignore_index_mask = target != ignore_index;

  // Ignored targets may hold values outside [0, C) (e.g. -100 or 255), which
  // gather would reject as out-of-range indices. Multiplying by the boolean
  // mask redirects them to class 0; the gathered value is discarded below.
  const Tensor safe_target = target * ignore_index_mask;
  auto target_ = safe_target.unsqueeze(channel_dim);
  // target_ can be [N, 1, ...] or [1]

  auto result = -at::gather(self_, channel_dim, target_).squeeze(channel_dim);
  // Select instead of multiplying by the mask: inf * 0 would produce NaN,
  // whereas an ignored entry has to contribute exactly zero.
  result = at::where(ignore_index_mask, result, at::zeros_like(result));

  // Number of non-ignored targets (equals result.numel() if none is ignored)
  auto total_weight = ignore_index_mask.sum().to(self_);

  // Apply the reduction
  if (result.dim() > 0) {
    if (reduction == Reduction::Sum) {
      result = result.sum();
    } else if (reduction == Reduction::Mean) {
      if (!has_weight) {
        // total_weight is ignore_index_mask.sum()
        result = result.sum() / total_weight;
      } else {
        TORCH_INTERNAL_ASSERT(weight_.defined());
        // Normalise by the summed weights of the classes that were selected
        weight_ = weight_.expand(self_.sizes());
        auto wsum = at::gather(weight_, channel_dim, target_).squeeze(channel_dim);
        wsum = (wsum * ignore_index_mask).sum();
        result = result.sum() / wsum;
        total_weight = wsum;
      }
    }
  } else if (reduction == Reduction::Mean && has_weight) {
    // here weight is [C] and target is [1]
    auto wsum = at::gather(*weight, channel_dim, target_).squeeze(channel_dim);
    total_weight = (wsum * ignore_index_mask).sum();
  }
  return std::make_tuple(result, total_weight);
}