virtual void Forward()

in src/operator/mkl/mkl_batch_norm-inl.h [116:259]
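
Forward pass of the MKL2017 batch-normalization operator: it reshapes 2-D input to 4-D NCHW, lazily creates the MKL forward/backward primitives, packs gamma/beta into a scale-shift buffer, and executes either the training primitive (which computes batch mean/variance) or the inference primitive (which uses the stored moving statistics).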


  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_states) {
    using namespace mshadow;
    using namespace mshadow::expr;
    CHECK_EQ(in_data.size(), 3);
    CHECK_EQ(aux_states.size(), 2);
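    // Training produces out, mean, and var; inference only requires out.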
    if (ctx.is_train) {
      CHECK_EQ(out_data.size(), 3);
      CHECK_EQ(req.size(), 3);
    } else {
      CHECK_GE(out_data.size(), 1);
      CHECK_GE(req.size(), 1);
      CHECK_EQ(req[batchnorm::kOut], kWriteTo);
    }

    Stream<xpu> *s = ctx.get_stream<xpu>();
    Tensor<xpu, 4, DType>  data;
    Tensor<xpu, 4, DType>  out;
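    // MKL's batch-norm primitive works on 4-D NCHW tensors, so flatten a
    // 2-D (N, C) input to (N, C, 1, 1).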
    if (in_data[batchnorm::kData].ndim() == 2) {
      Shape<4> dshape = Shape4(in_data[batchnorm::kData].shape_[0],
                               in_data[batchnorm::kData].shape_[1], 1, 1);
      data = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
        in_data[batchnorm::kData], dshape, s);
      out = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
        out_data[batchnorm::kOut], dshape, s);
    } else {
      data = mkl_experimental_direct_get<xpu, 4, DType>(in_data[batchnorm::kData], s);
      out = mkl_experimental_direct_get<xpu, 4, DType>(out_data[batchnorm::kOut], s);
    }

    // const real_t scale = static_cast<real_t>(in_data[batchnorm::kData].shape_[1]) /
    //   static_cast<real_t>(in_data[batchnorm::kData].shape_.Size());

    // gamma (slope) and beta (bias) are regular inputs; the moving mean and
    // variance arrive as auxiliary states.
    Tensor<xpu, 1, DType> slope = in_data[batchnorm::kGamma].get<xpu, 1, DType>(s);
    Tensor<xpu, 1, DType> bias = in_data[batchnorm::kBeta].get<xpu, 1, DType>(s);
    Tensor<xpu, 1, DType> moving_mean = aux_states[batchnorm::kMovingMean].get<xpu, 1, DType>(s);
    Tensor<xpu, 1, DType> moving_var = aux_states[batchnorm::kMovingVar].get<xpu, 1, DType>(s);

    // fix_gamma pins the scale (gamma) to 1 so that only beta is learned.
    if (param_.fix_gamma)
      slope = 1.f;

    dnnError_t e;
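    // One-time setup: record shapes and create the user layout and
    // scale-shift workspace.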
    if (!init_mkldnn_) {
      LayerSetUp(data, out);
      init_mkldnn_ = true;
    }
    void* bottom_data = NULL;
#if MKL_EXPERIMENTAL == 1
    bottom_data =
      reinterpret_cast<void *>(mkl_prv_data<DType>(in_data[batchnorm::kData]));
#endif
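    // With use_global_stats the backward primitive takes the provided
    // mean/variance as inputs instead of recomputing them.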
    int bwd_flags = dnnUseScaleShift;
    if (param_.use_global_stats)
      bwd_flags = dnnUseScaleShift | dnnUseInputMeanVariance;
#if MKL_EXPERIMENTAL == 1
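    // The input already lives in an MKL private layout: build the forward and
    // backward primitives directly on that internal layout to avoid extra
    // conversions.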
    if (NULL != bottom_data) {
      // Is it the first pass? Create a primitive.
      if (batchNormFwdInference == NULL) {
        std::shared_ptr<MKLMemHolder> bottom_data_mem = in_data[batchnorm::kData].Mkl_mem_;
        std::shared_ptr<PrvMemDescr> bottom_prv_desc = bottom_data_mem->get_prv_descriptor();
        CHECK(bottom_prv_desc->get_descr_type() == PrvMemDescr::PRV_DESCR_MKL2017);
        std::shared_ptr<MKLData<DType> > mem_descr
          = std::static_pointer_cast<MKLData<DType>>(bottom_prv_desc);
        CHECK(mem_descr != NULL);
        fwd_bottom_data = mem_descr;

        e = dnnBatchNormalizationCreateForward_v2<DType>(
             &batchNormFwdInference, NULL, mem_descr->layout_int, eps_,
             dnnUseInputMeanVariance | dnnUseScaleShift);
        CHECK_EQ(e, E_SUCCESS);

        e = dnnBatchNormalizationCreateForward_v2<DType>(
              &batchNormFwdTraining, NULL, mem_descr->layout_int, eps_,
              dnnUseScaleShift);
        CHECK_EQ(e, E_SUCCESS);

        fwd_top_data->create_internal_layout(batchNormFwdInference, dnnResourceDst);
        bwd_top_diff->create_internal_layout(batchNormFwdInference, dnnResourceDst);
        bwd_bottom_diff->create_internal_layout(batchNormFwdInference, dnnResourceSrc);

        e = dnnBatchNormalizationCreateBackward_v2<DType>(
                &batchNormBwdScaleShift, NULL, mem_descr->layout_int, eps_, bwd_flags);
        CHECK_EQ(e, E_SUCCESS);
      }
    }
#endif
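    // No MKL private data available: fall back to the plain user (NCHW) layout.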
    if (NULL == bottom_data) {
      if (batchNormFwdInference == NULL) {
        e = dnnBatchNormalizationCreateForward_v2<DType>(
          &batchNormFwdInference, NULL, layout_usr_, eps_,
          dnnUseInputMeanVariance | dnnUseScaleShift);
        CHECK_EQ(e, E_SUCCESS);

        e = dnnBatchNormalizationCreateForward_v2<DType>(
              &batchNormFwdTraining, NULL, layout_usr_, eps_, dnnUseScaleShift);
        CHECK_EQ(e, E_SUCCESS);

        e = dnnBatchNormalizationCreateBackward_v2<DType>(
              &batchNormBwdScaleShift, NULL, layout_usr_, eps_, bwd_flags);
        CHECK_EQ(e, E_SUCCESS);
      }
      bottom_data = reinterpret_cast<void *>(data.dptr_);
    }

    // Pack gamma (scale) followed by beta (shift) into the contiguous
    // scale-shift buffer expected by the MKL batch-norm primitive.
    DType *scaleShift_buf = reinterpret_cast<DType*>(scaleShift_space.dptr);
    for (int i = 0; i < channels_; i++) {
      scaleShift_buf[i] = (slope.dptr_)[i];
    }
    for (int i = 0; i < channels_; i++) {
      scaleShift_buf[channels_ + i] = (bias.dptr_)[i];
    }

    void* BatchNorm_res[dnnResourceNumber];
    BatchNorm_res[dnnResourceSrc] = bottom_data;
    BatchNorm_res[dnnResourceScaleShift] = scaleShift_space.dptr;

    BatchNorm_res[dnnResourceDst] = fwd_top_data->get_output_ptr(out.dptr_,
      fwd_top_data, out_data[batchnorm::kOut]);
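    // Training with local statistics: compute batch mean/variance and expose
    // them through the kMean/kVar outputs. Otherwise run inference with the
    // stored moving statistics.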
    if (ctx.is_train && !param_.use_global_stats) {
      Tensor<xpu, 1, DType> mean = out_data[batchnorm::kMean].get<xpu, 1, DType>(s);
      Tensor<xpu, 1, DType> var = out_data[batchnorm::kVar].get<xpu, 1, DType>(s);
      CHECK(req[batchnorm::kMean] == kNullOp || req[batchnorm::kMean] == kWriteTo);
      CHECK(req[batchnorm::kVar] == kNullOp || req[batchnorm::kVar] == kWriteTo);
      BatchNorm_res[dnnResourceMean] = mean.dptr_;
      BatchNorm_res[dnnResourceVariance] = var.dptr_;
      e = dnnExecute<DType>(batchNormFwdTraining, BatchNorm_res);
      CHECK_EQ(e, E_SUCCESS);
    } else {
      BatchNorm_res[dnnResourceMean] = moving_mean.dptr_;
      BatchNorm_res[dnnResourceVariance] = moving_var.dptr_;
      e = dnnExecute<DType>(batchNormFwdInference, BatchNorm_res);
      CHECK_EQ(e, E_SUCCESS);
    }

#if MKL_EXPERIMENTAL == 0
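    // Convert the result from the MKL private layout back to the user layout
    // when the experimental memory manager is not handling conversions.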
    if (fwd_top_data->conversion_needed()) {
      fwd_top_data->convert_from_prv(out.dptr_);
    }
#endif
  }