in src/model/operation/batchnorm.cc [181:256]
const std::vector<Tensor> CpuBatchNormBackwardx(
const BatchNormHandle& bnh, const Tensor& y, const Tensor& dy,
const Tensor& x, const Tensor& bnScale, const Tensor& bnBias,
const Tensor& mean, const Tensor& var) {
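  // The DNNL path requires every tensor involved to live on a CPU device.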
CHECK_EQ(x.device()->lang(), kCpp);
CHECK_EQ(y.device()->lang(), kCpp);
CHECK_EQ(dy.device()->lang(), kCpp);
CHECK_EQ(mean.device()->lang(), kCpp);
CHECK_EQ(var.device()->lang(), kCpp);
CHECK_EQ(bnScale.device()->lang(), kCpp);
CHECK_EQ(bnBias.device()->lang(), kCpp);
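  // dx holds the gradient w.r.t. the input x and has the same shape as dy.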
Tensor dx;
dx.ResetLike(dy);
  // Combine scale and bias into the packed weight tensor that the DNNL
  // backward primitive expects.
Tensor w = get_bn_weight_from(bnScale, bnBias);
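  // Note: get_bn_weight_from is assumed here to pack bnScale and bnBias into
  // a single (2 x channels) tensor (row 0 = scale, row 1 = shift), matching
  // DNNL's use_scale_shift weight layout. A minimal sketch under that
  // assumption:
  //   Tensor w(Shape{2, bnScale.Size()});
  //   CopyDataToFrom(&w, bnScale, bnScale.Size(), 0, 0);             // row 0
  //   CopyDataToFrom(&w, bnBias, bnBias.Size(), bnScale.Size(), 0);  // row 1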
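  // dw receives the packed gradients of scale and shift, laid out like w.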
Tensor dw;
dw.ResetLike(w);
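  // Schedule the DNNL backward kernel on the device; the two block lists
  // declare the read (inputs) and write (outputs) dependencies.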
dx.device()->Exec(
[w, dw, dx, dy, x, y, mean, var, &bnh](Context* ctx) mutable {
auto eng = ctx->dnnl_engine;
using namespace dnnl;
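        // Wrap the tensors' raw buffers as DNNL memory objects; the mean,
        // variance and weight descriptors are reused from the cached
        // forward-training primitive descriptor.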
auto x_mem = memory(bnh.x_md, eng, x.block()->mutable_data());
auto dx_mem = memory(bnh.x_md, eng, dx.block()->mutable_data());
auto y_mem = memory(bnh.x_md, eng, y.block()->mutable_data());
auto dy_mem = memory(bnh.x_md, eng, dy.block()->mutable_data());
auto m_mem = memory(bnh.bn_fwd_training_pd->mean_desc(), eng,
mean.block()->mutable_data());
auto v_mem = memory(bnh.bn_fwd_training_pd->variance_desc(), eng,
var.block()->mutable_data());
auto w_mem = memory(bnh.bn_fwd_training_pd->weights_desc(), eng,
w.block()->mutable_data());
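        // Build the backward descriptor: prop_kind::backward (as opposed to
        // backward_data) also produces gradients for scale and shift. The
        // cached forward-training primitive descriptor serves as the hint.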
auto bn_bwd_d = batch_normalization_backward::desc(
prop_kind::backward, bnh.x_md, bnh.x_md, bnh.epsilon,
normalization_flags::use_scale_shift);
auto bn_bwd_pd = batch_normalization_backward::primitive_desc(
bn_bwd_d, eng, *bnh.bn_fwd_training_pd);
auto dw_mem = memory(bn_bwd_pd.diff_weights_desc(), eng,
dw.block()->mutable_data());
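        // Execute the backward primitive; DNNL_ARG_DIFF_SCALE_SHIFT receives
        // the packed gradients of scale and shift.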
batch_normalization_backward(bn_bwd_pd).execute(
ctx->dnnl_stream, {{DNNL_ARG_SRC, x_mem},
{DNNL_ARG_DIFF_SRC, dx_mem},
{DNNL_ARG_DIFF_DST, dy_mem},
{DNNL_ARG_MEAN, m_mem},
{DNNL_ARG_VARIANCE, v_mem},
{DNNL_ARG_DIFF_SCALE_SHIFT, dw_mem},
{DNNL_ARG_SCALE_SHIFT, w_mem}});
ctx->dnnl_stream.wait();
},
{x.block(), dy.block(), mean.block(), var.block(), w.block(), y.block()},
{dx.block(), dw.block()}, "CpuBatchNormBackwardx");
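  // Unpack dw: the first bnScale.Size() elements hold the scale gradient,
  // the following bnBias.Size() elements hold the shift (bias) gradient.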
singa::Tensor dbnScale(bnScale.shape());
CopyDataToFrom(&dbnScale, dw, bnScale.Size(), 0, 0);
singa::Tensor dbnBias(bnBias.shape());
CopyDataToFrom(&dbnBias, dw, bnBias.Size(), 0, bnScale.Size());
CHECK(dbnScale.nDim() == bnScale.nDim()) << "dbnScale ndim not match bnScale";
CHECK(dbnBias.nDim() == bnBias.nDim()) << "dbnBias ndim not match bnBias";
CHECK(dbnScale.shape()[0] == bnScale.shape()[0])
<< "dbnScale shape not match bnScale";
CHECK(dbnBias.shape()[0] == bnBias.shape()[0])
<< "dbnBias shape not match bnBias";
return {dx, dbnScale, dbnBias};
}