virtual void Forward()

in src/operator/cudnn_batch_norm-inl.h [48:154]


  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_states) {
    using namespace mshadow;
    using namespace mshadow::expr;
    CHECK_EQ(in_data.size(), 3U);
    CHECK_EQ(aux_states.size(), 2U);
    if (ctx.is_train) {
      CHECK_EQ(out_data.size(), 3U);
      CHECK_EQ(req.size(), 3U);
    } else {
      CHECK_GE(out_data.size(), 1U);
      CHECK_GE(req.size(), 1U);
    }
    CHECK_EQ(req[cudnnbatchnorm::kOut], kWriteTo);
    CHECK_GE(in_data[cudnnbatchnorm::kData].ndim(), 2);
    CHECK_LE(in_data[cudnnbatchnorm::kData].ndim(), 4);

    if (!init_cudnn_) {
      for (int i = 0; i < 4; ++i) {
        if (i < in_data[cudnnbatchnorm::kData].ndim()) {
          shape_[i] = in_data[cudnnbatchnorm::kData].shape_[i];
        } else {
          shape_[i] = 1;
        }
      }
      CUDNN_CALL(cudnnCreateTensorDescriptor(&io_desc_));
      CUDNN_CALL(cudnnCreateTensorDescriptor(&mean_desc_));
      CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_,
                                            CUDNN_TENSOR_NCHW,
                                            dtype_,
                                            shape_[0],
                                            shape_[1],
                                            shape_[2],
                                            shape_[3]));
      CUDNN_CALL(cudnnDeriveBNTensorDescriptor(mean_desc_,
                                               io_desc_,
                                               CUDNN_BATCHNORM_SPATIAL));
      init_cudnn_  = true;
    }

    Stream<gpu> *s = ctx.get_stream<gpu>();
    Tensor<gpu, 4, DType> x =
      in_data[cudnnbatchnorm::kData].get_with_shape<gpu, 4, DType>(shape_, s);

    Tensor<gpu, 4, DType> y =
      out_data[cudnnbatchnorm::kOut].get_with_shape<gpu, 4, DType>(shape_, s);

    MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, {
      Tensor<gpu, 1, DTypeParam> gamma =
        in_data[cudnnbatchnorm::kGamma].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
      Tensor<gpu, 1, DTypeParam> beta =
        in_data[cudnnbatchnorm::kBeta].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
      Tensor<gpu, 1, DTypeParam> moving_mean =
        aux_states[cudnnbatchnorm::kMovingMean]
        .get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
      Tensor<gpu, 1, DTypeParam> moving_inv_var =
        aux_states[cudnnbatchnorm::kMovingInvVar]
        .get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
      typename DataType<DType>::ScaleType a = 1.0f;
      typename DataType<DType>::ScaleType b = 0.0f;

      if (param_.fix_gamma) gamma = 1.f;

      if (ctx.is_train) {
        Tensor<gpu, 1, DTypeParam> save_mean =
          out_data[cudnnbatchnorm::kMean].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
        Tensor<gpu, 1, DTypeParam> save_inv_var =
          out_data[cudnnbatchnorm::kInvVar]
          .get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
        CUDNN_CALL(cudnnBatchNormalizationForwardTraining(s->dnn_handle_,
                                                          CUDNN_BATCHNORM_SPATIAL,
                                                          &a,
                                                          &b,
                                                          io_desc_,
                                                          x.dptr_,
                                                          io_desc_,
                                                          y.dptr_,
                                                          mean_desc_,
                                                          gamma.dptr_,
                                                          beta.dptr_,
                                                          1 - param_.momentum,
                                                          moving_mean.dptr_,
                                                          moving_inv_var.dptr_,
                                                          param_.eps,
                                                          save_mean.dptr_,
                                                          save_inv_var.dptr_));
      } else {
        CUDNN_CALL(cudnnBatchNormalizationForwardInference(s->dnn_handle_,
                                                           CUDNN_BATCHNORM_SPATIAL,
                                                           &a,
                                                           &b,
                                                           io_desc_,
                                                           x.dptr_,
                                                           io_desc_,
                                                           y.dptr_,
                                                           mean_desc_,
                                                           gamma.dptr_,
                                                           beta.dptr_,
                                                           moving_mean.dptr_,
                                                           moving_inv_var.dptr_,
                                                           param_.eps));
      }
    })
  }