const Tensor BatchNorm::Forward()

in src/model/layer/batchnorm.cc [68:156]


const Tensor BatchNorm::Forward(int flag, const Tensor& input) {
  Tensor x = input.Clone();
  x.Reshape(Shape{input.shape(0), input.Size() / input.shape(0)});
  Tensor output;
  output.ResetLike(x);
  // TODO(wangwei) input sample shape check
  if ((flag & kTrain) == kTrain) {  // forward for train
    if (is_2d_) {                   // batchnorm_per_activation mode
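      // Mini-batch mean of each activation (column), averaged over the batch.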
      auto mean = Average(x, 0);
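      // Moving-average update of the running mean:
      // runningMean_ = (1 - factor_) * runningMean_ + factor_ * mean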
      runningMean_ *= 1.0f - factor_;
      Axpy(factor_, mean, &runningMean_);
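      // Biased mini-batch variance per activation: var = E[(x - mean)^2].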
      auto xnorm = x.Clone();
      SubRow(mean, &xnorm);
      xnorm = Square(xnorm);
      auto var = Average(xnorm, 0);
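      // Moving-average update of the running variance, same form as above.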
      runningVariance_ *= 1.0f - factor_;
      Axpy(factor_, var, &runningVariance_);
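      // Per-activation standard deviation; the epsilon keeps the divisor
      // strictly positive.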
      Tensor tmp = var.Clone();
      tmp = Sqrt(tmp);
      tmp += 1e-6f;
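      // Normalize: xnorm = (x - mean) / (sqrt(var) + eps).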
      xnorm = x.Clone();
      SubRow(mean, &xnorm);
      DivRow(tmp, &xnorm);
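      // Affine transform, broadcast row-wise:
      // output = bnScale_ * xnorm + bnBias_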
      output = xnorm.Clone();
      MultRow(bnScale_, &output);
      AddRow(bnBias_, &output);
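      // Cache the intermediates consumed by Backward().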
      buf_.push(x);
      buf_.push(mean);
      buf_.push(var);
      buf_.push(xnorm);
    } else {  // batchnorm_spatial mode
      LOG(FATAL) << "Training SpatialBatchNormalization has not been "
                    "implemented yet...";
    }
  } else {         // forward for test
    if (is_2d_) {  // batchnorm_per_activation mode
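      // At test time, normalize with the accumulated running statistics
      // rather than mini-batch statistics.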
      auto xnorm = x.Clone();
      SubRow(runningMean_, &xnorm);
      Tensor tmp = runningVariance_.Clone();
      tmp = Sqrt(tmp);
      tmp += 1e-6f;
      DivRow(tmp, &xnorm);
      output = xnorm.Clone();
      MultRow(bnScale_, &output);
      AddRow(bnBias_, &output);
    } else {  // batchnorm_spatial mode
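      // Reshape the per-channel statistics and parameters into column
      // vectors so they can be tiled across the spatial positions.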
      runningMean_.Reshape(Shape{channels_, 1});
      runningVariance_.Reshape(Shape{channels_, 1});
      bnScale_.Reshape(Shape{channels_, 1});
      bnBias_.Reshape(Shape{channels_, 1});

      std::vector<Tensor> mean_stack, var_stack, scale_stack, bias_stack;
      for (unsigned i = 0; i < height_ * width_; ++i) {
        mean_stack.push_back(runningMean_);
        var_stack.push_back(runningVariance_);
        scale_stack.push_back(bnScale_);
        bias_stack.push_back(bnBias_);
      }
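      // Concatenate the height_ * width_ copies column-wise, producing
      // channels_ x (height_ * width_) matrices.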
      auto mean = ConcatenateColumns(mean_stack);
      auto var = ConcatenateColumns(var_stack);
      auto scale = ConcatenateColumns(scale_stack);
      auto bias = ConcatenateColumns(bias_stack);

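      // Flatten row-major to length channels_ * height_ * width_, matching
      // the layout of a flattened NCHW sample.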
      mean.Reshape(Shape{channels_ * height_ * width_});
      var.Reshape(Shape{channels_ * height_ * width_});
      scale.Reshape(Shape{channels_ * height_ * width_});
      bias.Reshape(Shape{channels_ * height_ * width_});

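      // Normalize with the running statistics:
      // xnorm = (x - mean) / (sqrt(var) + eps).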
      auto xnorm = x.Clone();
      SubRow(mean, &xnorm);
      var = Sqrt(var);
      var += 1e-6f;
      DivRow(var, &xnorm);
      output = xnorm.Clone();

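      // Per-channel scale and shift.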
      MultRow(scale, &output);
      AddRow(bias, &output);

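      // Restore the original per-channel parameter shapes.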
      runningMean_.Reshape(Shape{channels_});
      runningVariance_.Reshape(Shape{channels_});
      bnScale_.Reshape(Shape{channels_});
      bnBias_.Reshape(Shape{channels_});
    }
  }

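  // Spatial mode flattened the input above; restore the 4-D (N, C, H, W)
  // output shape.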
  if (!is_2d_)
    output.Reshape(Shape{output.shape(0), channels_, height_, width_});
  return output;
}
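
For reference, a minimal sketch of how this Forward() is typically driven in the two phases, modeled on the style of the layer's unit tests. The include paths, configuration values, and shapes are illustrative assumptions, and kEval is assumed to be the test-phase counterpart of the kTrain flag checked above.

#include "singa/core/tensor.h"
#include "../src/model/layer/batchnorm.h"  // path relative to test dir; illustrative

void RunBatchNorm() {
  using singa::Shape;
  using singa::Tensor;

  // Configure a per-activation (2-D) batchnorm; factor_ is the moving-average
  // coefficient used in the running-statistics updates above (value assumed).
  singa::LayerConf conf;
  conf.mutable_batchnorm_conf()->set_factor(0.9f);

  singa::BatchNorm bn;
  bn.Setup(Shape{2}, conf);  // sample shape: 2 activations per example

  Tensor in(Shape{4, 2});  // batch of 4 samples
  in.SetValue(1.0f);       // placeholder data

  // Training pass: normalizes with mini-batch statistics and updates
  // runningMean_ / runningVariance_ via the moving average.
  Tensor train_out = bn.Forward(singa::kTrain, in);

  // Test pass: normalizes with the accumulated running statistics.
  Tensor eval_out = bn.Forward(singa::kEval, in);
}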