in src/operator/mkl/mkl_batch_norm-inl.h [116:259]
virtual void Forward(const OpContext &ctx,
                     const std::vector<TBlob> &in_data,
                     const std::vector<OpReqType> &req,
                     const std::vector<TBlob> &out_data,
                     const std::vector<TBlob> &aux_states) {
  using namespace mshadow;
  using namespace mshadow::expr;
  CHECK_EQ(in_data.size(), 3);
  CHECK_EQ(aux_states.size(), 2);
  if (ctx.is_train) {
    CHECK_EQ(out_data.size(), 3);
    CHECK_EQ(req.size(), 3);
  } else {
    CHECK_GE(out_data.size(), 1);
    CHECK_GE(req.size(), 1);
    CHECK_EQ(req[batchnorm::kOut], kWriteTo);
  }
  Stream<xpu> *s = ctx.get_stream<xpu>();
  Tensor<xpu, 4, DType> data;
  Tensor<xpu, 4, DType> out;
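  // MKL batch norm works on 4D (NCHW) tensors, so a 2D (N, C) input is viewed
  // as N x C x 1 x 1 before the data/output buffers are fetched.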
  if (in_data[batchnorm::kData].ndim() == 2) {
    Shape<4> dshape = Shape4(in_data[batchnorm::kData].shape_[0],
                             in_data[batchnorm::kData].shape_[1], 1, 1);
    data = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
      in_data[batchnorm::kData], dshape, s);
    out = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
      out_data[batchnorm::kOut], dshape, s);
  } else {
    data = mkl_experimental_direct_get<xpu, 4, DType>(in_data[batchnorm::kData], s);
    out = mkl_experimental_direct_get<xpu, 4, DType>(out_data[batchnorm::kOut], s);
  }
  // const real_t scale = static_cast<real_t>(in_data[batchnorm::kData].shape_[1]) /
  //   static_cast<real_t>(in_data[batchnorm::kData].shape_.Size());
  Tensor<xpu, 1, DType> slope = in_data[batchnorm::kGamma].get<xpu, 1, DType>(s);
  Tensor<xpu, 1, DType> bias = in_data[batchnorm::kBeta].get<xpu, 1, DType>(s);
  Tensor<xpu, 1, DType> moving_mean = aux_states[batchnorm::kMovingMean].get<xpu, 1, DType>(s);
  Tensor<xpu, 1, DType> moving_var = aux_states[batchnorm::kMovingVar].get<xpu, 1, DType>(s);
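  // When fix_gamma is set, force the scale (gamma) to 1 so only the shift (beta) applies.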
  if (param_.fix_gamma)
    slope = 1.f;
  dnnError_t e;
  if (!init_mkldnn_) {
    LayerSetUp(data, out);
    init_mkldnn_ = true;
  }
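  // MKL_EXPERIMENTAL builds probe whether the input already lives in an MKL
  // private (internal) layout; bottom_data stays NULL when it does not.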
  void* bottom_data = NULL;
#if MKL_EXPERIMENTAL == 1
  bottom_data =
    reinterpret_cast<void *>(mkl_prv_data<DType>(in_data[batchnorm::kData]));
#endif
  int bwd_flags = dnnUseScaleShift;
  if (param_.use_global_stats)
    bwd_flags = dnnUseScaleShift | dnnUseInputMeanVariance;
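  // First pass with a private-layout input: recover the MKL memory descriptor from
  // the input blob and create the inference, training, and backward primitives
  // directly from its internal layout.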
#if MKL_EXPERIMENTAL == 1
  if (NULL != bottom_data) {
    // Is it the first pass? Create a primitive.
    if (batchNormFwdInference == NULL) {
      std::shared_ptr<MKLMemHolder> bottom_data_mem = in_data[batchnorm::kData].Mkl_mem_;
      std::shared_ptr<PrvMemDescr> bottom_prv_desc = bottom_data_mem->get_prv_descriptor();
      CHECK(bottom_prv_desc->get_descr_type() == PrvMemDescr::PRV_DESCR_MKL2017);
      std::shared_ptr<MKLData<DType> > mem_descr
        = std::static_pointer_cast<MKLData<DType>>(bottom_prv_desc);
      CHECK(mem_descr != NULL);
      fwd_bottom_data = mem_descr;
      e = dnnBatchNormalizationCreateForward_v2<DType>(
        &batchNormFwdInference, NULL, mem_descr->layout_int, eps_,
        dnnUseInputMeanVariance | dnnUseScaleShift);
      CHECK_EQ(e, E_SUCCESS);
      e = dnnBatchNormalizationCreateForward_v2<DType>(
        &batchNormFwdTraining, NULL, mem_descr->layout_int, eps_,
        dnnUseScaleShift);
      CHECK_EQ(e, E_SUCCESS);
      fwd_top_data->create_internal_layout(batchNormFwdInference, dnnResourceDst);
      bwd_top_diff->create_internal_layout(batchNormFwdInference, dnnResourceDst);
      bwd_bottom_diff->create_internal_layout(batchNormFwdInference, dnnResourceSrc);
      e = dnnBatchNormalizationCreateBackward_v2<DType>(
        &batchNormBwdScaleShift, NULL, mem_descr->layout_int, eps_, bwd_flags);
      CHECK_EQ(e, E_SUCCESS);
    }
  }
#endif
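  // No private layout available (or MKL_EXPERIMENTAL disabled): create the same
  // primitives from the plain user layout and read the input straight from the
  // mshadow tensor.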
  if (NULL == bottom_data) {
    if (batchNormFwdInference == NULL) {
      e = dnnBatchNormalizationCreateForward_v2<DType>(
        &batchNormFwdInference, NULL, layout_usr_, eps_,
        dnnUseInputMeanVariance | dnnUseScaleShift);
      CHECK_EQ(e, E_SUCCESS);
      e = dnnBatchNormalizationCreateForward_v2<DType>(
        &batchNormFwdTraining, NULL, layout_usr_, eps_, dnnUseScaleShift);
      CHECK_EQ(e, E_SUCCESS);
      e = dnnBatchNormalizationCreateBackward_v2<DType>(
        &batchNormBwdScaleShift, NULL, layout_usr_, eps_, bwd_flags);
      CHECK_EQ(e, E_SUCCESS);
    }
    bottom_data = reinterpret_cast<void *>(data.dptr_);
  }
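  // Pack gamma followed by beta into the single scaleShift buffer used by the
  // dnnUseScaleShift resource: entries [0, channels_) hold the scales,
  // entries [channels_, 2 * channels_) hold the shifts.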
  DType * scaleShift_buf = reinterpret_cast<DType*>(scaleShift_space.dptr);
  // use_weight_bias_
  for (int i = 0; i < channels_; i++) {
    scaleShift_buf[i] = (slope.dptr_)[i];
  }
  for (int i = 0; i < channels_; i++) {
    scaleShift_buf[channels_ + i] = (bias.dptr_)[i];
  }
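  // Gather the source, scale/shift, and destination pointers that dnnExecute will use.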
  void* BatchNorm_res[dnnResourceNumber];
  BatchNorm_res[dnnResourceSrc] = bottom_data;
  BatchNorm_res[dnnResourceScaleShift] = scaleShift_space.dptr;
  BatchNorm_res[dnnResourceDst] = fwd_top_data->get_output_ptr(out.dptr_,
    fwd_top_data, out_data[batchnorm::kOut]);
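  // Training with batch statistics: run the training primitive, which writes the
  // batch mean and variance into the kMean/kVar outputs. Otherwise run the
  // inference primitive against the stored moving mean and variance.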
  if (ctx.is_train && !param_.use_global_stats) {
    Tensor<xpu, 1, DType> mean = out_data[batchnorm::kMean].get<xpu, 1, DType>(s);
    Tensor<xpu, 1, DType> var = out_data[batchnorm::kVar].get<xpu, 1, DType>(s);
    CHECK(req[batchnorm::kMean] == kNullOp || req[batchnorm::kMean] == kWriteTo);
    CHECK(req[batchnorm::kVar] == kNullOp || req[batchnorm::kVar] == kWriteTo);
    BatchNorm_res[dnnResourceMean] = mean.dptr_;
    BatchNorm_res[dnnResourceVariance] = var.dptr_;
    e = dnnExecute<DType>(batchNormFwdTraining, BatchNorm_res);
    CHECK_EQ(e, E_SUCCESS);
  } else {
    BatchNorm_res[dnnResourceMean] = moving_mean.dptr_;
    BatchNorm_res[dnnResourceVariance] = moving_var.dptr_;
    e = dnnExecute<DType>(batchNormFwdInference, BatchNorm_res);
    CHECK_EQ(e, E_SUCCESS);
  }
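  // Without the experimental memory holders the result may still sit in the MKL
  // private layout; convert it back into the user-visible output buffer if needed.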
#if MKL_EXPERIMENTAL == 0
  if (fwd_top_data->conversion_needed()) {
    fwd_top_data->convert_from_prv(out.dptr_);
  }
#endif
}