const Tensor CudnnBatchNorm::Forward(int flag, const Tensor& input)

in src/model/layer/cudnn_batchnorm.cc [64:141]

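Applies cuDNN batch normalization to input. A 2D input of shape (N, C) is viewed as a 4D NCHW tensor (N, C, 1, 1) so cuDNN can operate on it. When flag contains kTrain, the layer computes per-batch statistics, updates the running mean/variance, saves the batch statistics for the backward pass, and caches the input; otherwise it normalizes with the stored running statistics.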

const Tensor CudnnBatchNorm::Forward(int flag, const Tensor& input) {
  auto shape = input.shape();
  auto dtype = input.data_type();
  Tensor output;
  Tensor x;
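  // cuDNN batch-norm operates on 4D NCHW tensors, so a 2D (N, C) input is
  // viewed as (N, C, 1, 1).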
  if (is_2d_)
    x = Reshape(input, Shape{shape.at(0), shape.at(1), 1, 1});
  else
    x = input;
  shape = x.shape();
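  // Create the cuDNN descriptors lazily on first use and re-create them if
  // the batch size changes; the per-sample shape (C, H, W) must stay fixed.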
  if (!has_init_cudnn_) {
    InitCudnn(shape, dtype);
  } else {
    int n, c, h, w, s;
    cudnnDataType_t type;
    CUDNN_CHECK(cudnnGetTensor4dDescriptor(shape_desc_, &type,
          &n, &c, &h, &w, &s, &s, &s, &s));
    if (shape[0] != static_cast<size_t>(n))
      InitCudnn(shape, dtype);
    CHECK(shape[1] == static_cast<size_t>(c)
        && shape[2] == static_cast<size_t>(h)
        && shape[3] == static_cast<size_t>(w))
      << "input sample shape should not change; "
      << "previous shape " << c << ", " << h << ", " << w
      << ", current shape " << shape[1] << ", " << shape[2] << ", "
      << shape[3];
  }


  // TODO(wangji): check device id of input and params
  output.ResetLike(x);
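  // Training: compute per-batch statistics, blend them into the running
  // mean/variance with momentum factor_, and save the batch statistics for
  // the backward pass.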
  if ((flag & kTrain) == kTrain) {
    output.device()->Exec(
        [=](Context* ctx) {
          Block* inBlock = x.block(), * outBlock = output.block(),
                 * saveMeanBlock = resultSaveMean_.block(),
                 * saveVarBlock = resultSaveVariance_.block(),
                 * runningMeanBlock = runningMean_.block(),
                 * runningVarBlock = runningVariance_.block(),
                 * bnScaleBlock = bnScale_.block(),
                 * bnBiasBlock = bnBias_.block();
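          // cuDNN blends the result as y = alpha * result + beta * y;
          // epsilon is set to cuDNN's minimum allowed value.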
          const float alpha = 1.0f, beta = 0.0f;
          double epsilon = CUDNN_BN_MIN_EPSILON;
          CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
              ctx->cudnn_handle, this->mode_, &alpha, &beta, shape_desc_,
              inBlock->data(), shape_desc_, outBlock->mutable_data(),
              param_desc_, bnScaleBlock->data(), bnBiasBlock->data(), factor_,
              runningMeanBlock->mutable_data(), runningVarBlock->mutable_data(),
              epsilon, saveMeanBlock->mutable_data(),
              saveVarBlock->mutable_data()));
        },
        {x.block(), bnScale_.block(), bnBias_.block()},
        {output.block(), runningMean_.block(), runningVariance_.block(),
         resultSaveMean_.block(), resultSaveVariance_.block()});
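    // Cache the (possibly reshaped) input; Backward() consumes it.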
    buf_.push(x);
  } else {
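    // Inference: normalize with the stored running mean/variance; no
    // statistics are updated and nothing is cached.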
    output.device()->Exec(
        [=](Context* ctx) {
          Block* inBlock = x.block(), * outBlock = output.block(),
                 * runningMeanBlock = runningMean_.block(),
                 * runningVarBlock = runningVariance_.block(),
                 * bnScaleBlock = bnScale_.block(),
                 * bnBiasBlock = bnBias_.block();
          const float alpha = 1.0f, beta = 0.0f;
          double epsilon = CUDNN_BN_MIN_EPSILON;
          CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
              ctx->cudnn_handle, this->mode_, &alpha, &beta, shape_desc_,
              inBlock->data(), shape_desc_, outBlock->mutable_data(),
              param_desc_, bnScaleBlock->data(), bnBiasBlock->data(),
              runningMeanBlock->data(), runningVarBlock->data(), epsilon));
        },
        {x.block(), bnScale_.block(), bnBias_.block(), runningMean_.block(),
         runningVariance_.block()},
        {output.block()});
  }
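  // Restore the original 2D shape before returning.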
  if (is_2d_) output.Reshape(Shape{shape.at(0), shape.at(1)});
  return output;
}
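
For orientation, here is a minimal call-site sketch. It is not part of this file: the device setup (CudaGPU), the Tensor and Gaussian helpers, the LayerConf batchnorm_conf factor field, the kTrain/kEval flags, the Setup() signature, and the include paths are all assumptions about SINGA's API of the same vintage and may differ between versions.

#include <memory>

#include "singa/core/device.h"
#include "singa/core/tensor.h"
#include "singa/model/layer.h"
#include "../src/model/layer/cudnn_batchnorm.h"  // assumed include path for CudnnBatchNorm

void BatchNormForwardExample() {
  auto dev = std::make_shared<singa::CudaGPU>();
  singa::Tensor x(singa::Shape{16, 3, 32, 32}, dev);  // NCHW mini-batch
  singa::Gaussian(0.0f, 1.0f, &x);                    // fill with random values

  singa::LayerConf conf;
  conf.mutable_batchnorm_conf()->set_factor(0.9f);    // running-stat momentum (assumed field)
  singa::CudnnBatchNorm bn;
  bn.Setup(singa::Shape{3, 32, 32}, conf);            // per-sample shape (C, H, W)
  bn.ToDevice(dev);

  // Training pass: updates running mean/variance and caches x for Backward().
  singa::Tensor y_train = bn.Forward(singa::kTrain, x);
  // Inference pass: uses only the stored running statistics.
  singa::Tensor y_eval = bn.Forward(singa::kEval, x);
}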