in src/model/layer/cudnn_batchnorm.cc [64:141]
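// Forward pass of batch normalization via cuDNN. In training mode it
// computes per-batch statistics, updates the running statistics, and caches
// what Backward needs; in evaluation mode it normalizes with the stored
// running statistics.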
const Tensor CudnnBatchNorm::Forward(int flag, const Tensor& input) {
auto shape = input.shape();
auto dtype = input.data_type();
Tensor output;
Tensor x;
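// cuDNN batch normalization expects 4D NCHW input, so promote a 2D
// (batch, feature) tensor to shape (batch, feature, 1, 1).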
if (is_2d_)
x = Reshape(input, Shape{shape.at(0), shape.at(1), 1, 1});
else
x = input;
shape = x.shape();
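// Descriptors are created lazily on the first call; afterwards only a
// change in batch size triggers re-initialization.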
if (!has_init_cudnn_) {
InitCudnn(shape, dtype);
} else {
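// Read back the shape recorded in the cached descriptor; the four stride
// outputs all land in s and are ignored.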
int n, c, h, w, s;
cudnnDataType_t type;
CUDNN_CHECK(cudnnGetTensor4dDescriptor(shape_desc_, &type,
&n, &c, &h, &w, &s, &s, &s, &s));
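// A changed batch size only requires fresh descriptors; the per-sample
// (channel, height, width) shape must stay fixed.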
if (shape[0] != static_cast<size_t>(n))
InitCudnn(shape, dtype);
CHECK(shape[1] == static_cast<size_t>(c)
&& shape[2] == static_cast<size_t>(h)
&& shape[3] == static_cast<size_t>(w))
<< "input sample shape should not change; "
<< "previous shape: " << c << ", " << h << ", " << w
<< "; current shape: " << shape[1] << ", " << shape[2] << ", "
<< shape[3];
}
// TODO(wangji): check device id of input and params
output.ResetLike(x);
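// Training mode: compute per-batch mean/variance, fold them into the
// running statistics with exponential-average factor factor_, and save the
// batch statistics for Backward.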
if ((flag & kTrain) == kTrain) {
output.device()->Exec(
[=](Context* ctx) {
Block* inBlock = x.block();
Block* outBlock = output.block();
Block* saveMeanBlock = resultSaveMean_.block();
Block* saveVarBlock = resultSaveVariance_.block();
Block* runningMeanBlock = runningMean_.block();
Block* runningVarBlock = runningVariance_.block();
Block* bnScaleBlock = bnScale_.block();
Block* bnBiasBlock = bnBias_.block();
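// Blending factors: dst = alpha * result + beta * prior dst.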
const float alpha = 1.0f, beta = 0.0f;
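// cuDNN rejects epsilon values below CUDNN_BN_MIN_EPSILON.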
double epsilon = CUDNN_BN_MIN_EPSILON;
CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
ctx->cudnn_handle, this->mode_, &alpha, &beta, shape_desc_,
inBlock->data(), shape_desc_, outBlock->mutable_data(),
param_desc_, bnScaleBlock->data(), bnBiasBlock->data(), factor_,
runningMeanBlock->mutable_data(), runningVarBlock->mutable_data(),
epsilon, saveMeanBlock->mutable_data(),
saveVarBlock->mutable_data()));
},
{x.block(), bnScale_.block(), bnBias_.block()},
{output.block(), runningMean_.block(), runningVariance_.block(),
resultSaveMean_.block(), resultSaveVariance_.block()});
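// Cache the input; Backward pops it to compute the gradients.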
buf_.push(x);
} else {
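// Evaluation mode: normalize with the accumulated running statistics
// instead of per-batch statistics.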
output.device()->Exec(
[=](Context* ctx) {
Block* inBlock = x.block();
Block* outBlock = output.block();
Block* runningMeanBlock = runningMean_.block();
Block* runningVarBlock = runningVariance_.block();
Block* bnScaleBlock = bnScale_.block();
Block* bnBiasBlock = bnBias_.block();
const float alpha = 1.0f, beta = 0.0f;
double epsilon = CUDNN_BN_MIN_EPSILON;
CUDNN_CHECK(cudnnBatchNormalizationForwardInference(
ctx->cudnn_handle, this->mode_, &alpha, &beta, shape_desc_,
inBlock->data(), shape_desc_, outBlock->mutable_data(),
param_desc_, bnScaleBlock->data(), bnBiasBlock->data(),
runningMeanBlock->data(), runningVarBlock->data(), epsilon));
},
{x.block(), bnScale_.block(), bnBias_.block(), runningMean_.block(),
runningVariance_.block()},
{output.block()});
}
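// For 2D input, drop the two singleton spatial dims; shape here holds x's
// 4D shape (n, c, 1, 1), so this restores the original (batch, feature).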
if (is_2d_) output.Reshape(Shape{shape.at(0), shape.at(1)});
return output;
}
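// Usage sketch (illustrative only, not part of this file; assumes the layer
// has been set up with the per-sample shape and that x and the parameters
// live on the same CUDA device):
//   Tensor y = bn.Forward(kTrain, x);  // training: updates running stats
//   Tensor z = bn.Forward(kEval, x);   // inference: uses running stats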