in backends/gpu/lib/ops/tf/dnn_ops.cu.cc [502:596]
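// Launches a fused batch norm inference kernel that computes, per element,
//
//   y = scale * (x - mean) / sqrt(variance + epsilon) + bias,
//
// optionally adds a side input, and optionally applies a ReLU, all fused into
// a single kernel launch.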
llvm::Error operator()(wrapper::CurrentContext current,
const wrapper::Stream& stream,
ChannelOrder channel_order,
const DenseGpuTensor& input,
const DenseGpuTensor& scale,
const DenseGpuTensor& bias, const DenseGpuTensor& mean,
const DenseGpuTensor& variance,
const DenseGpuTensor* side_input, float epsilon,
FusedBatchNormActivationMode activation_mode,
const GpuBuffer& output_buffer) {
// An empty input produces an empty output; there is nothing to launch.
int32_t count = input.NumElements();
if (count == 0) return llvm::Error::success();
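// Size a 1-D launch so that the grid covers all `count` elements.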
TFRT_ASSIGN_OR_RETURN(GpuLaunchConfig config,
GetGpuLaunchConfig(current, count));
// The side input is optional; the kernels receive a null pointer when it is
// absent.
const bool no_side_input = side_input == nullptr;
const bool add_side_input = !no_side_input;
const T* side_input_ptr =
    no_side_input ? nullptr : GetRawPointer<T>(*side_input);
const bool no_activation =
activation_mode == FusedBatchNormActivationMode::kIdentity;
const bool relu_activation =
activation_mode == FusedBatchNormActivationMode::kRelu;
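// Only kIdentity and kRelu have kernel instantiations below; any other
// activation mode falls through to the error at the end.
// The lambda binds the launch configuration and tensor pointers, so each
// dispatch branch only has to name the kernel instantiation and the shape
// split for its layout.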
auto launch = [&](auto* kernel, int channel_size, int inner_dim_size) {
return wrapper::CudaLaunchKernel(
current, kernel, config.block_count, config.thread_per_block, 0,
stream, count, channel_size, inner_dim_size, GetRawPointer<T>(input),
GetRawPointer<U>(scale), GetRawPointer<U>(bias),
GetRawPointer<U>(mean), GetRawPointer<U>(variance), side_input_ptr,
epsilon, GetRawPointer<T>(output_buffer));
};
// The input is assumed to be 4-D (NCHW or NHWC); the dispatch below selects
// the kernel instantiation matching the compile-time template parameters.
auto input_shape = GetDimensions(input.shape());
if (channel_order == ChannelOrder::ChannelFirst) {
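// NCHW: the H * W elements of each channel are contiguous, so the kernel
// can presumably recover the channel index as
// (i / inner_dim_size) % channel_size.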
const int channel_size = input_shape[1];
const int inner_dim_size = input_shape[2] * input_shape[3];
if (no_activation && no_side_input) {
return launch(&FusedBatchNormInferenceMetaKernel<
T, U, ChannelOrder::ChannelFirst,
/*add_side_input=*/false,
FusedBatchNormActivationMode::kIdentity>,
channel_size, inner_dim_size);
} else if (relu_activation && no_side_input) {
return launch(
&FusedBatchNormInferenceMetaKernel<
T, U, ChannelOrder::ChannelFirst,
/*add_side_input=*/false, FusedBatchNormActivationMode::kRelu>,
channel_size, inner_dim_size);
} else if (no_activation && add_side_input) {
return launch(&FusedBatchNormInferenceMetaKernel<
T, U, ChannelOrder::ChannelFirst,
/*add_side_input=*/true,
FusedBatchNormActivationMode::kIdentity>,
channel_size, inner_dim_size);
} else if (relu_activation && add_side_input) {
return launch(
&FusedBatchNormInferenceMetaKernel<
T, U, ChannelOrder::ChannelFirst,
/*add_side_input=*/true, FusedBatchNormActivationMode::kRelu>,
channel_size, inner_dim_size);
}
} else if (channel_order == ChannelOrder::ChannelLast) {
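// NHWC: channels are innermost, so inner_dim_size is 1 and the channel
// index is presumably just i % channel_size.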
const int channel_size = input_shape[3];
const int inner_dim_size = 1;
if (no_activation && no_side_input) {
return launch(&FusedBatchNormInferenceMetaKernel<
T, U, ChannelOrder::ChannelLast,
/*add_side_input=*/false,
FusedBatchNormActivationMode::kIdentity>,
channel_size, inner_dim_size);
} else if (relu_activation && no_side_input) {
return launch(
&FusedBatchNormInferenceMetaKernel<
T, U, ChannelOrder::ChannelLast,
/*add_side_input=*/false, FusedBatchNormActivationMode::kRelu>,
channel_size, inner_dim_size);
} else if (no_activation && add_side_input) {
return launch(&FusedBatchNormInferenceMetaKernel<
T, U, ChannelOrder::ChannelLast,
/*add_side_input=*/true,
FusedBatchNormActivationMode::kIdentity>,
channel_size, inner_dim_size);
} else if (relu_activation && add_side_input) {
return launch(
&FusedBatchNormInferenceMetaKernel<
T, U, ChannelOrder::ChannelLast,
/*add_side_input=*/true, FusedBatchNormActivationMode::kRelu>,
channel_size, inner_dim_size);
}
}
// Reached only when no branch above launched a kernel, i.e. for an
// unsupported activation mode.
return MakeStringError("no fused batch norm kernel was launched");
}
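
// A minimal sketch of the per-element math the dispatched kernels are
// presumed to perform. This is NOT the actual
// FusedBatchNormInferenceMetaKernel: the name, the fixed U = float, and the
// plain grid-stride loop are illustrative assumptions only. The layout
// difference is already fully encoded in inner_dim_size (H * W for NCHW,
// 1 for NHWC), so the sketch needs no channel-order parameter.
template <typename T, bool add_side_input,
          FusedBatchNormActivationMode activation_mode>
__global__ void FusedBatchNormInferenceSketch(
    int32_t count, int channel_size, int inner_dim_size, const T* input,
    const float* scale, const float* bias, const float* mean,
    const float* variance, const T* side_input, float epsilon, T* output) {
  // Grid-stride loop over all elements.
  for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count;
       i += gridDim.x * blockDim.x) {
    // Recover the channel index from the flat element index.
    int c = (i / inner_dim_size) % channel_size;
    float x = static_cast<float>(input[i]);
    // y = scale * (x - mean) / sqrt(variance + epsilon) + bias
    float y =
        scale[c] * (x - mean[c]) * rsqrtf(variance[c] + epsilon) + bias[c];
    // side_input is only dereferenced when the template flag is set, i.e.
    // when the caller passed a non-null pointer.
    if (add_side_input) y += static_cast<float>(side_input[i]);
    if (activation_mode == FusedBatchNormActivationMode::kRelu) {
      y = fmaxf(y, 0.0f);
    }
    output[i] = static_cast<T>(y);
  }
}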