Tensor binary_cross_entropy_backward_plumbing()

in functorch/csrc/BatchRulesLoss.cpp [112:158]


// Batching-rule plumbing for at::binary_cross_entropy_backward at the current
// functorch dynamic layer.
//
// Fast path: if none of grad, input, target, or a defined weight is batched at
// the current level, the native op is called once with the caller's weight
// and reduction, with kBatchedKey excluded from dispatch.
//
// Otherwise the op is decomposed:
//   1. For reduction != None, grad is expanded to input's shape, so the
//      kernel can always run element-wise with Reduction::None.
//   2. Batch dims of grad/input/target are moved to the front, and a batch dim
//      of the common size is materialized on any operand lacking one. The
//      kernel runs with no weight and Reduction::None, and the result is
//      re-wrapped as batched at dim 0 of the current level.
//   3. A defined weight is applied with a plain multiply. The multiply sits
//      outside the exclusion guard, so a batched weight goes through normal
//      dispatch.
//   4. For Reduction::Mean the result is divided by input.numel(). `input`
//      is the tensor as seen at this level, so this is presumably the
//      per-example element count (TODO confirm).
//
// @param grad       incoming gradient of the loss.
// @param input      the forward's `input` argument.
// @param target     the forward's `target` argument.
// @param weight_opt optional element-wise weight (ignored when undefined).
// @param reduction  a Reduction value (None / Mean / Sum).
// @return gradient with respect to `input`.
Tensor binary_cross_entropy_backward_plumbing(
    const Tensor& grad, const Tensor& input, const Tensor& target,
    const c10::optional<Tensor>& weight_opt, int64_t reduction) {
  auto maybe_layer = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
  int64_t cur_level = maybe_layer->layerId();

  // With a reduction, grad does not have input's shape. Expand it so the
  // decomposed path can always call the kernel with Reduction::None.
  Tensor grad_value;
  optional<int64_t> grad_bdim;
  std::tie(grad_value, grad_bdim) = unwrapTensorAtLevel(
      reduction == Reduction::None ? grad : grad.expand_as(input), cur_level);
  Tensor input_value;
  optional<int64_t> input_bdim;
  std::tie(input_value, input_bdim) = unwrapTensorAtLevel(input, cur_level);
  Tensor target_value;
  optional<int64_t> target_bdim;
  std::tie(target_value, target_bdim) = unwrapTensorAtLevel(target, cur_level);

  // Only the weight's batch dim is needed here. The weight itself is used
  // as-is (possibly batched) by the multiply further down.
  const bool has_weight = weight_opt.has_value() && weight_opt->defined();
  optional<int64_t> weight_bdim;
  if (has_weight) {
    std::tie(std::ignore, weight_bdim) =
        unwrapTensorAtLevel(*weight_opt, cur_level);
  }

  // Nothing is batched at this level: dispatch once to the native kernel.
  // It applies weight and reduction itself, so no decomposition is needed.
  // The original (unexpanded) grad is passed on purpose.
  if (!grad_bdim && !input_bdim && !target_bdim && !weight_bdim) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::binary_cross_entropy_backward(
        grad, input, target, weight_opt, reduction);
  }

  Tensor grad_input;
  if (grad_bdim || input_bdim || target_bdim) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    const auto bdim_size = get_bdim_size3(
        grad_value, grad_bdim, input_value, input_bdim, target_value, target_bdim);

    auto grad_ = moveBatchDimToFront(grad_value, grad_bdim);
    auto input_ = moveBatchDimToFront(input_value, input_bdim);
    auto target_ = moveBatchDimToFront(target_value, target_bdim);

    // Give every operand a leading batch dim of the same size so the
    // element-wise kernel sees matching shapes.
    grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), bdim_size);
    input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size);
    target_ = ensure_has_bdim(target_, target_bdim.has_value(), bdim_size);

    grad_input = at::binary_cross_entropy_backward(
        grad_, input_, target_, nullopt, Reduction::None);
    grad_input = makeBatched(grad_input, 0, cur_level);
  } else {
    // Only the weight is batched at this level. Compute the unweighted,
    // unreduced gradient directly; the weight is applied below.
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    grad_input = at::binary_cross_entropy_backward(
        grad_value, input_value, target_value, nullopt, Reduction::None);
  }
  // Runs outside the exclusion guard, so a batched weight is dispatched
  // through the batching machinery.
  if (has_weight) {
    grad_input = grad_input * weight_opt.value();
  }
  if (reduction == Reduction::Mean) {
    grad_input.div_(input.numel());
  }
  return grad_input;
}