in functorch/csrc/BatchRulesLoss.cpp [112:158]
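// Plumbing for the binary_cross_entropy_backward batching rule: unwrap the
// BatchedTensor inputs at the current vmap level, run the backward on the
// unwrapped values, and re-wrap the result at that level.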
Tensor binary_cross_entropy_backward_plumbing(
    const Tensor& grad, const Tensor& input, const Tensor& target,
    const c10::optional<Tensor>& weight_opt, int64_t reduction) {
  auto maybe_layer = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
  int64_t cur_level = maybe_layer->layerId();
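
  // Unwrap each argument at the current level into its underlying value plus
  // an optional batch dimension. For Mean/Sum reductions grad is a scalar, so
  // expand it to input's shape and treat the backward as if reduction were None.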
  Tensor grad_value;
  optional<int64_t> grad_bdim;
  std::tie(grad_value, grad_bdim) = unwrapTensorAtLevel(
      reduction == Reduction::None ? grad : grad.expand_as(input), cur_level);
  Tensor input_value;
  optional<int64_t> input_bdim;
  std::tie(input_value, input_bdim) = unwrapTensorAtLevel(input, cur_level);
  Tensor target_value;
  optional<int64_t> target_bdim;
  std::tie(target_value, target_bdim) = unwrapTensorAtLevel(target, cur_level);
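
  // If any argument carries a batch dim, move it to the front, materialize a
  // matching batch dim on the others, and call the un-batched kernel with
  // Reduction::None; otherwise fall through to a plain un-batched call.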
  Tensor grad_input;
  if (grad_bdim || input_bdim || target_bdim) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    const auto bdim_size = get_bdim_size3(
        grad_value, grad_bdim, input_value, input_bdim, target_value, target_bdim);
    auto grad_ = moveBatchDimToFront(grad_value, grad_bdim);
    auto input_ = moveBatchDimToFront(input_value, input_bdim);
    auto target_ = moveBatchDimToFront(target_value, target_bdim);
    grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), bdim_size);
    input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size);
    target_ = ensure_has_bdim(target_, target_bdim.has_value(), bdim_size);
    grad_input = at::binary_cross_entropy_backward(
        grad_, input_, target_, nullopt, Reduction::None);
    grad_input = makeBatched(grad_input, 0, cur_level);
  } else {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    grad_input = at::binary_cross_entropy_backward(
        grad_value, input_value, target_value, nullopt, Reduction::None);
  }
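  // The kernel above ran with Reduction::None, so apply the optional
  // per-element weight and the Mean scaling (divide by input.numel()) here.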
  if (weight_opt.has_value() && weight_opt->defined()) {
    grad_input = grad_input * weight_opt.value();
  }
  if (reduction == Reduction::Mean) {
    grad_input.div_(input.numel());
  }
  return grad_input;
}