llvm::Error operator()

in backends/gpu/lib/ops/tf/dnn_ops.cu.cc [502:596]


  llvm::Error operator()(wrapper::CurrentContext current,
                         const wrapper::Stream& stream,
                         ChannelOrder channel_order,
                         const DenseGpuTensor& input,
                         const DenseGpuTensor& scale,
                         const DenseGpuTensor& bias, const DenseGpuTensor& mean,
                         const DenseGpuTensor& variance,
                         const DenseGpuTensor* side_input, float epsilon,
                         FusedBatchNormActivationMode activation_mode,
                         const GpuBuffer& output_buffer) {
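    // Early out: an empty input requires no kernel launch.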
    int32_t count = input.NumElements();
    if (count == 0) return llvm::Error::success();

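    // Pick a launch configuration (block count, threads per block) for `count`
    // elements.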
    TFRT_ASSIGN_OR_RETURN(GpuLaunchConfig config,
                          GetGpuLaunchConfig(current, count));

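    // The side input (e.g. a residual tensor to add) is optional; capture its
    // presence as flags for the dispatch below.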
    const bool no_side_input = side_input == nullptr;
    const bool add_side_input = !no_side_input;
    const T* side_input_ptr =
        no_side_input ? nullptr : GetRawPointer<T>(*side_input);

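    // Only identity and ReLU activations have instantiations below; any other
    // mode falls through to the error at the end.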
    const bool no_activation =
        activation_mode == FusedBatchNormActivationMode::kIdentity;
    const bool relu_activation =
        activation_mode == FusedBatchNormActivationMode::kRelu;

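    // Launches one kernel instantiation with the shared argument list; only the
    // channel/inner-dimension geometry and the template arguments vary.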
    auto launch = [&](auto* kernel, int channel_size, int inner_dim_size) {
      return wrapper::CudaLaunchKernel(
          current, kernel, config.block_count, config.thread_per_block, 0,
          stream, count, channel_size, inner_dim_size, GetRawPointer<T>(input),
          GetRawPointer<U>(scale), GetRawPointer<U>(bias),
          GetRawPointer<U>(mean), GetRawPointer<U>(variance), side_input_ptr,
          epsilon, GetRawPointer<T>(output_buffer));
    };

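    // Select the instantiation matching layout, side input, and activation.
    // The input is expected to be 4-D: NCHW for ChannelFirst, NHWC for
    // ChannelLast.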
    auto input_shape = GetDimensions(input.shape());
    if (channel_order == ChannelOrder::ChannelFirst) {
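      // NCHW: channels are dimension 1 and each channel spans H * W contiguous
      // elements.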
      const int channel_size = input_shape[1];
      const int inner_dim_size = input_shape[2] * input_shape[3];
      if (no_activation && no_side_input) {
        return launch(&FusedBatchNormInferenceMetaKernel<
                          T, U, ChannelOrder::ChannelFirst,
                          /*add_side_input=*/false,
                          FusedBatchNormActivationMode::kIdentity>,
                      channel_size, inner_dim_size);
      } else if (relu_activation && no_side_input) {
        return launch(
            &FusedBatchNormInferenceMetaKernel<
                T, U, ChannelOrder::ChannelFirst,
                /*add_side_input=*/false, FusedBatchNormActivationMode::kRelu>,
            channel_size, inner_dim_size);
      } else if (no_activation && add_side_input) {
        return launch(&FusedBatchNormInferenceMetaKernel<
                          T, U, ChannelOrder::ChannelFirst,
                          /*add_side_input=*/true,
                          FusedBatchNormActivationMode::kIdentity>,
                      channel_size, inner_dim_size);
      } else if (relu_activation && add_side_input) {
        return launch(
            &FusedBatchNormInferenceMetaKernel<
                T, U, ChannelOrder::ChannelFirst,
                /*add_side_input=*/true, FusedBatchNormActivationMode::kRelu>,
            channel_size, inner_dim_size);
      }
    } else if (channel_order == ChannelOrder::ChannelLast) {
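      // NHWC: channels are the innermost dimension, so adjacent elements belong
      // to adjacent channels.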
      const int channel_size = input_shape[3];
      const int inner_dim_size = 1;
      if (no_activation && no_side_input) {
        return launch(&FusedBatchNormInferenceMetaKernel<
                          T, U, ChannelOrder::ChannelLast,
                          /*add_side_input=*/false,
                          FusedBatchNormActivationMode::kIdentity>,
                      channel_size, inner_dim_size);
      } else if (relu_activation && no_side_input) {
        return launch(
            &FusedBatchNormInferenceMetaKernel<
                T, U, ChannelOrder::ChannelLast,
                /*add_side_input=*/false, FusedBatchNormActivationMode::kRelu>,
            channel_size, inner_dim_size);
      } else if (no_activation && add_side_input) {
        return launch(&FusedBatchNormInferenceMetaKernel<
                          T, U, ChannelOrder::ChannelLast,
                          /*add_side_input=*/true,
                          FusedBatchNormActivationMode::kIdentity>,
                      channel_size, inner_dim_size);
      } else if (relu_activation && add_side_input) {
        return launch(
            &FusedBatchNormInferenceMetaKernel<
                T, U, ChannelOrder::ChannelLast,
                /*add_side_input=*/true, FusedBatchNormActivationMode::kRelu>,
            channel_size, inner_dim_size);
      }
    }
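    // Reached only for an unsupported channel order or activation mode.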
    return MakeStringError("no fused batch norm kernel was launched");
  }
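
For orientation, below is a minimal, hypothetical sketch of what the dispatched kernel plausibly computes. It is not the real FusedBatchNormInferenceMetaKernel (which bakes the layout, side-input, and activation choices into template parameters rather than taking runtime flags), and it assumes the standard batch-norm inference formula y = scale * (x - mean) / sqrt(variance + epsilon) + bias, with the side input added before the ReLU. Its main point is to show how channel_size and inner_dim_size recover the channel index from a flat element index for both layouts.

#include <stdint.h>

// Hypothetical illustration only; not the kernel used by the dispatcher above.
__global__ void FusedBatchNormInferenceSketch(
    int32_t count, int32_t channel_size, int32_t inner_dim_size,
    const float* input, const float* scale, const float* bias,
    const float* mean, const float* variance, const float* side_input,
    float epsilon, bool add_side_input, bool apply_relu, float* output) {
  // Grid-stride loop over all output elements.
  for (int32_t index = blockIdx.x * blockDim.x + threadIdx.x; index < count;
       index += gridDim.x * blockDim.x) {
    // For NCHW, inner_dim_size == H * W; for NHWC, inner_dim_size == 1.
    // Either way this recovers the channel of flat element `index`.
    const int32_t channel = (index / inner_dim_size) % channel_size;
    float result = (input[index] - mean[channel]) *
                       rsqrtf(variance[channel] + epsilon) * scale[channel] +
                   bias[channel];
    if (add_side_input) result += side_input[index];  // residual-style add
    if (apply_relu) result = fmaxf(result, 0.0f);     // ReLU activation
    output[index] = result;
  }
}

The runtime `add_side_input` and `apply_relu` flags here stand in for the compile-time template parameters of the real kernel; templating them, as the dispatcher does, removes the per-element branches at the cost of one instantiation per combination.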