in src/model/layer/batchnorm.cc [68:156]
// Forward pass of batch normalization.
//
// `flag`: bitmask; the kTrain bit selects training-mode statistics
//         (batch mean/variance, with running-stat updates) vs. inference-mode
//         statistics (the stored running mean/variance).
// `input`: input tensor; it is flattened to 2-D (batch, features) before
//          normalization and restored to (batch, C, H, W) on return for the
//          spatial (4-D) case.
// Returns bnScale_ * (x - mean) / (sqrt(var) + 1e-6) + bnBias_, computed
// per feature column.
//
// In training mode the intermediates needed by Backward are pushed onto
// buf_ in the order: x, mean, var, xnorm (Backward presumably pops them
// in reverse — confirm against Backward's implementation).
const Tensor BatchNorm::Forward(int flag, const Tensor& input) {
  Tensor x = input.Clone();
  // Flatten to (batch, features) so row-wise ops normalize per activation.
  x.Reshape(Shape{input.shape(0), input.Size() / input.shape(0)});
  Tensor output;
  output.ResetLike(x);
  // TODO(wangwei) input sample shape check
  if ((flag & kTrain) == kTrain) {  // forward for train
    if (is_2d_) {  // batchnorm_per_activation mode
      auto mean = Average(x, 0);
      // Exponential moving average: running = (1 - factor_) * running
      //                                       + factor_ * batch_stat.
      runningMean_ *= 1.0f - factor_;
      Axpy(factor_, mean, &runningMean_);
      auto xnorm = x.Clone();
      SubRow(mean, &xnorm);
      xnorm = Square(xnorm);
      // Biased (population) variance: mean of squared deviations.
      auto var = Average(xnorm, 0);
      runningVariance_ *= 1.0f - factor_;
      Axpy(factor_, var, &runningVariance_);
      // NOTE(review): the Clone() looks redundant if Sqrt() returns a fresh
      // tensor — left as-is pending confirmation of Sqrt's semantics.
      Tensor tmp = var.Clone();
      tmp = Sqrt(tmp);
      // Epsilon added to the std-dev (not the variance) for numerical
      // stability; keeps the divide below well-defined for zero variance.
      tmp += 1e-6f;
      // Recompute x - mean (xnorm was consumed by Square above), then
      // normalize in place.
      xnorm = x.Clone();
      SubRow(mean, &xnorm);
      DivRow(tmp, &xnorm);
      // Affine transform: output = xnorm * scale + bias, row-broadcast.
      output = xnorm.Clone();
      MultRow(bnScale_, &output);
      AddRow(bnBias_, &output);
      // Stash intermediates for Backward; order matters.
      buf_.push(x);
      buf_.push(mean);
      buf_.push(var);
      buf_.push(xnorm);
    } else {  // batchnorm_spatial mode
      LOG(FATAL) << "Training SpatialBatchNormalization has not been "
                    "implemented yet...";
    }
  } else {  // forward for test
    if (is_2d_) {  // batchnorm_per_activation mode
      // Same normalization as training, but using the stored running
      // statistics instead of batch statistics; nothing is buffered.
      auto xnorm = x.Clone();
      SubRow(runningMean_, &xnorm);
      Tensor tmp = runningVariance_.Clone();
      tmp = Sqrt(tmp);
      tmp += 1e-6f;
      DivRow(tmp, &xnorm);
      output = xnorm.Clone();
      MultRow(bnScale_, &output);
      AddRow(bnBias_, &output);
    } else {  // batchnorm_spatial mode
      // Spatial BN shares one mean/var/scale/bias per channel across all
      // H*W positions. Tile each per-channel vector H*W times column-wise,
      // then flatten to length C*H*W so the row-wise ops above apply.
      runningMean_.Reshape(Shape{channels_, 1});
      runningVariance_.Reshape(Shape{channels_, 1});
      bnScale_.Reshape(Shape{channels_, 1});
      bnBias_.Reshape(Shape{channels_, 1});
      std::vector<Tensor> mean_stack, var_stack, scale_stack, bias_stack;
      for (unsigned i = 0; i < height_ * width_; ++i) {
        mean_stack.push_back(runningMean_);
        var_stack.push_back(runningVariance_);
        scale_stack.push_back(bnScale_);
        bias_stack.push_back(bnBias_);
      }
      auto mean = ConcatenateColumns(mean_stack);
      auto var = ConcatenateColumns(var_stack);
      auto scale = ConcatenateColumns(scale_stack);
      auto bias = ConcatenateColumns(bias_stack);
      mean.Reshape(Shape{channels_ * height_ * width_});
      var.Reshape(Shape{channels_ * height_ * width_});
      scale.Reshape(Shape{channels_ * height_ * width_});
      bias.Reshape(Shape{channels_ * height_ * width_});
      auto xnorm = x.Clone();
      SubRow(mean, &xnorm);
      var = Sqrt(var);
      var += 1e-6f;
      DivRow(var, &xnorm);
      output = xnorm.Clone();
      MultRow(scale, &output);
      AddRow(bias, &output);
      // Restore the parameter tensors to their canonical 1-D shape so the
      // temporary reshapes above don't leak out of this call.
      runningMean_.Reshape(Shape{channels_});
      runningVariance_.Reshape(Shape{channels_});
      bnScale_.Reshape(Shape{channels_});
      bnBias_.Reshape(Shape{channels_});
    }
  }
  // Restore the 4-D layout for spatial inputs (2-D inputs stay flat).
  if (!is_2d_)
    output.Reshape(Shape{output.shape(0), channels_, height_, width_});
  return output;
}