const std::vector<Tensor> CpuBatchNormBackwardx()

in src/model/operation/batchnorm.cc [181:256]


const std::vector<Tensor> CpuBatchNormBackwardx(
    const BatchNormHandle& bnh, const Tensor& y, const Tensor& dy,
    const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
    const Tensor& mean, const Tensor& var) {
  CHECK_EQ(x.device()->lang(), kCpp);
  CHECK_EQ(y.device()->lang(), kCpp);
  CHECK_EQ(dy.device()->lang(), kCpp);
  CHECK_EQ(mean.device()->lang(), kCpp);
  CHECK_EQ(var.device()->lang(), kCpp);
  CHECK_EQ(bnScale.device()->lang(), kCpp);
  CHECK_EQ(bnBias.device()->lang(), kCpp);

  Tensor dx;
  dx.ResetLike(dy);

  // pack scale and bias into the single weights tensor layout that DNNL
  // expects for the backward pass
  Tensor w = get_bn_weight_from(bnScale, bnBias);

  // dw holds the packed gradients of scale and bias, in the same layout as w
  Tensor dw;
  dw.ResetLike(w);

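  // Submit the backward pass to the device. The lambda reads x, y, dy, mean,
  // var and the packed weights w, and writes its results into dx and dw.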
  dx.device()->Exec(
      [w, dw, dx, dy, x, y, mean, var, &bnh](Context* ctx) mutable {
        auto eng = ctx->dnnl_engine;
        using namespace dnnl;

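        // Wrap the tensor buffers as DNNL memory objects, reusing the memory
        // descriptors stored in the handle and the forward-training primitive
        // descriptor.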
        auto x_mem = memory(bnh.x_md, eng, x.block()->mutable_data());
        auto dx_mem = memory(bnh.x_md, eng, dx.block()->mutable_data());
        auto y_mem = memory(bnh.x_md, eng, y.block()->mutable_data());
        auto dy_mem = memory(bnh.x_md, eng, dy.block()->mutable_data());

        auto m_mem = memory(bnh.bn_fwd_training_pd->mean_desc(), eng,
                            mean.block()->mutable_data());
        auto v_mem = memory(bnh.bn_fwd_training_pd->variance_desc(), eng,
                            var.block()->mutable_data());
        auto w_mem = memory(bnh.bn_fwd_training_pd->weights_desc(), eng,
                            w.block()->mutable_data());

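        // Build the backward primitive descriptor; the forward-training
        // primitive descriptor is passed as the hint DNNL requires.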
        auto bn_bwd_d = batch_normalization_backward::desc(
            prop_kind::backward, bnh.x_md, bnh.x_md, bnh.epsilon,
            normalization_flags::use_scale_shift);
        auto bn_bwd_pd = batch_normalization_backward::primitive_desc(
            bn_bwd_d, eng, *bnh.bn_fwd_training_pd);

        auto dw_mem = memory(bn_bwd_pd.diff_weights_desc(), eng,
                             dw.block()->mutable_data());

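        // Execute the backward primitive: DIFF_SRC receives dx and
        // DIFF_SCALE_SHIFT receives the packed scale/shift gradients.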
        batch_normalization_backward(bn_bwd_pd).execute(
            ctx->dnnl_stream, {{DNNL_ARG_SRC, x_mem},
                               {DNNL_ARG_DIFF_SRC, dx_mem},
                               {DNNL_ARG_DIFF_DST, dy_mem},
                               {DNNL_ARG_MEAN, m_mem},
                               {DNNL_ARG_VARIANCE, v_mem},
                               {DNNL_ARG_DIFF_SCALE_SHIFT, dw_mem},
                               {DNNL_ARG_SCALE_SHIFT, w_mem}});
        ctx->dnnl_stream.wait();
      },
      {x.block(), dy.block(), mean.block(), var.block(), w.block(), y.block()},
      {dx.block(), dw.block()}, "CpuBatchNormBackwardx");

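  // Unpack dw: the first bnScale.Size() values are the scale gradient, the
  // following bnBias.Size() values are the bias gradient.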
  singa::Tensor dbnScale(bnScale.shape());
  CopyDataToFrom(&dbnScale, dw, bnScale.Size(), 0, 0);
  singa::Tensor dbnBias(bnBias.shape());
  CopyDataToFrom(&dbnBias, dw, bnBias.Size(), 0, bnScale.Size());

  CHECK(dbnScale.nDim() == bnScale.nDim()) << "dbnScale ndim not match bnScale";
  CHECK(dbnBias.nDim() == bnBias.nDim()) << "dbnBias ndim not match bnBias";
  CHECK(dbnScale.shape()[0] == bnScale.shape()[0])
      << "dbnScale shape not match bnScale";
  CHECK(dbnBias.shape()[0] == bnBias.shape()[0])
      << "dbnBias shape not match bnBias";

  return {dx, dbnScale, dbnBias};
}
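
A minimal usage sketch, assuming the handle bnh and the tensors y, mean and
var come from the matching CPU forward-training pass and dy is the gradient
flowing back from the next layer (the variable names here are illustrative,
not taken from the file):

  // Sketch only: all tensors must live on the same kCpp device.
  auto grads = CpuBatchNormBackwardx(bnh, y, dy, x, bnScale, bnBias, mean, var);
  Tensor dx = grads[0];      // gradient w.r.t. the input x
  Tensor dScale = grads[1];  // gradient w.r.t. bnScale
  Tensor dBias = grads[2];   // gradient w.r.t. bnBias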