Expected DnnBuildFusedConvolution()

in backends/gpu/lib/kernels/dnn_kernels.cc [479:679]


// Builds a cuDNN frontend ExecutionPlan for the fused pattern
// Conv + Add + BiasAdd + Activation (activation may be identity or ReLU).
//
// All tensor shapes/strides and enum values arrive as opaque kernel
// attributes; they are decoded into cuDNN frontend descriptors, wired into an
// operation graph, and finalized into an execution plan selected by
// `engine_id` with the given autotuning knobs.
//
// Returns the finalized plan, or an Error if any cuDNN builder fails or the
// activation mode is unsupported.
Expected<cudnn_frontend::ExecutionPlan> DnnBuildFusedConvolution(
    const GpuDnnHandle& handle, double alpha, double alpha2,
    // Needs to be sorted alphabetically by attribute name!
    Attribute<int32_t> activation_mode_attr,
    Attribute<int32_t> backend_type_attr, ArrayAttr bias_dims_attr,
    ArrayAttr bias_strides, Attribute<int32_t> bias_type_attr,
    ArrayAttr conv_dilations, Attribute<int32_t> conv_dim_attr,
    Attribute<int32_t> conv_mode_attr, ArrayAttr conv_padding,
    ArrayAttr conv_strides, Attribute<int32_t> engine_id,
    ArrayAttr filter_dims_attr, ArrayAttr filter_strides,
    ArrayAttr input_dims_attr, ArrayAttr input_strides,
    Attribute<int32_t> input_type_attr, ArrayAttr output_dims_attr,
    ArrayAttr output_strides_attr, Attribute<int32_t> output_type_attr,
    ArrayAttr tuning_knob_ids, ArrayAttr tuning_knob_values) {
  auto input_type = wrapper::DnnDataType::FromOpaqueValue(*input_type_attr);
  auto output_type = wrapper::DnnDataType::FromOpaqueValue(*output_type_attr);
  auto bias_type = wrapper::DnnDataType::FromOpaqueValue(*bias_type_attr);

  // Vectorized types (e.g. int8x32) are unvectorized for the per-tensor
  // "data type" argument expected by BuildTensor.
  auto unvectorized_input_type =
      wrapper::GetUnvectorizedDnnDataType(input_type);
  auto unvectorized_output_type =
      wrapper::GetUnvectorizedDnnDataType(output_type);
  auto unvectorized_bias_type = wrapper::GetUnvectorizedDnnDataType(bias_type);

  // Compute precision for the convolution and the pointwise ops; may be
  // promoted to fp32 for fp16 inputs depending on the environment variable.
  auto accumulator_type = GetConvAccumulatorType(
      unvectorized_input_type,
      DnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled());
  auto activation_type = GetConvActivationType(
      unvectorized_input_type,
      DnnEnvVar<ConvDoFP32ComputationFP16Input>::IsEnabled());

  // CUDNN fused operation supports the pattern in the form of
  // Conv + Add + BiasAdd + Act. Therefore, we need to build a graph of the
  // four ops with their input/output tensor edges:
  // Conv   : input: tensor_x, tensor_w;    output: tensor_conv (virtual)
  // Add    : input: tensor_conv, tensor_z; output: tensor_add (virtual)
  // BiasAdd: input: tensor_add, tensor_b;  output: tensor_bias (virtual)
  // Act    : input: tensor_bias;           output: tensor_y
  auto input_dims = input_dims_attr.GetValue<int64_t>();
  auto tensor_x = BuildTensor(input_type, input_dims.size(), input_dims.data(),
                              input_strides.GetValue<int64_t>().data(), 'x',
                              unvectorized_input_type);
  if (!tensor_x) return tensor_x.takeError();

  auto output_dims = output_dims_attr.GetValue<int64_t>();
  auto output_strides = output_strides_attr.GetValue<int64_t>();
  auto tensor_y =
      BuildTensor(output_type, output_dims.size(), output_dims.data(),
                  output_strides.data(), 'y', unvectorized_output_type);
  if (!tensor_y) return tensor_y.takeError();

  // Side input (residual) tensor; same shape/strides as the output.
  auto tensor_z =
      BuildTensor(output_type, output_dims.size(), &output_dims[0],
                  &output_strides[0], 'z', unvectorized_output_type);
  if (!tensor_z) return tensor_z.takeError();

  auto filter_dims = filter_dims_attr.GetValue<int64_t>();
  auto tensor_w = BuildTensor(
      input_type, filter_dims.size(), filter_dims.data(),
      filter_strides.GetValue<int64_t>().data(), 'w', unvectorized_input_type);
  if (!tensor_w) return tensor_w.takeError();

  auto bias_dims = bias_dims_attr.GetValue<int64_t>();
  auto tensor_b = BuildTensor(input_type, bias_dims.size(), bias_dims.data(),
                              bias_strides.GetValue<int64_t>().data(), 'b',
                              unvectorized_bias_type);
  if (!tensor_b) return tensor_b.takeError();

  // Virtual (intermediate) tensors: never materialized in memory, they just
  // connect the ops within the fused graph.
  auto tensor_conv =
      BuildTensor(output_type, output_dims.size(), &output_dims[0],
                  &output_strides[0], 'C', accumulator_type,
                  /*set_virtual=*/true);
  if (!tensor_conv) return tensor_conv.takeError();

  auto tensor_add =
      BuildTensor(output_type, output_dims.size(), &output_dims[0],
                  &output_strides[0], 'A', activation_type,
                  /*set_virtual=*/true);
  if (!tensor_add) return tensor_add.takeError();

  auto tensor_bias =
      BuildTensor(output_type, output_dims.size(), &output_dims[0],
                  &output_strides[0], 'B', activation_type,
                  /*set_virtual=*/true);
  if (!tensor_bias) return tensor_bias.takeError();

  // conv_desc.
  cudnnConvolutionMode_t conv_mode =
      wrapper::DnnConvolutionMode::FromOpaqueValue(*conv_mode_attr);
  int conv_dim = *conv_dim_attr;
  auto conv_desc =
      cudnn_frontend::ConvDescBuilder()
          .setComputePrecision(accumulator_type)
          .setMathMode(conv_mode)
          .setNDims(conv_dim)
          .setStrides(conv_dim, conv_strides.GetValue<int64_t>().data())
          .setPrePadding(conv_dim, conv_padding.GetValue<int64_t>().data())
          .setPostPadding(conv_dim, conv_padding.GetValue<int64_t>().data())
          .setDilation(conv_dim, conv_dilations.GetValue<int64_t>().data())
          .build();
  if (conv_desc.get_status())
    return wrapper::MakeError(conv_desc.get_status(),
                              conv_desc.describe().c_str());

  // CUDNN Operation
  auto backend_type =
      static_cast<cudnnBackendDescriptorType_t>(*backend_type_attr);
  auto conv_op = cudnn_frontend::OperationBuilder(backend_type)
                     .setxDesc(*tensor_x)
                     .setyDesc(*tensor_conv)
                     .setwDesc(*tensor_w)
                     .setcDesc(conv_desc)
                     .setAlpha(1.0f)
                     .setBeta(0.0f)
                     .build();
  if (conv_op.get_status())
    return wrapper::MakeError(conv_op.get_status(), conv_op.describe().c_str());

  // Pointwise add of the side input: y = alpha * conv + alpha2 * z.
  auto add_desc = cudnn_frontend::PointWiseDescBuilder()
                      .setMode(CUDNN_POINTWISE_ADD)
                      .setMathPrecision(activation_type)
                      .build();
  // Check the descriptor before wiring it into an op (matches the handling
  // of conv_desc above; previously this status went unchecked).
  if (add_desc.get_status())
    return wrapper::MakeError(add_desc.get_status(),
                              add_desc.describe().c_str());
  auto add_op = cudnn_frontend::OperationBuilder(
                    CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                    .setxDesc(conv_op.getOutputTensor())
                    .setbDesc(*tensor_z)
                    .setyDesc(*tensor_add)
                    .setpwDesc(add_desc)
                    .setAlpha(alpha)
                    .setAlpha2(alpha2)
                    .build();
  if (add_op.get_status())
    return wrapper::MakeError(add_op.get_status(), add_op.describe().c_str());

  auto bias_add_desc = cudnn_frontend::PointWiseDescBuilder()
                           .setMode(CUDNN_POINTWISE_ADD)
                           .setMathPrecision(activation_type)
                           .build();
  // Check the descriptor before wiring it into an op (matches the handling
  // of conv_desc above; previously this status went unchecked).
  if (bias_add_desc.get_status())
    return wrapper::MakeError(bias_add_desc.get_status(),
                              bias_add_desc.describe().c_str());

  // If the activation is the identity function, then the bias-add is the last
  // op, and it writes to the output, tensor_y.  Otherwise, it writes to the
  // "virtual tensor" (temp buffer) tensor_bias, to which we apply the
  // activation.
  auto activation_mode =
      static_cast<cudnnActivationMode_t>(*activation_mode_attr);
  auto& bias_out_desc =
      activation_mode == CUDNN_ACTIVATION_IDENTITY ? *tensor_y : *tensor_bias;
  auto bias_add_op = cudnn_frontend::OperationBuilder(
                         CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                         .setxDesc(add_op.getOutputTensor())
                         .setbDesc(*tensor_b)
                         .setyDesc(bias_out_desc)
                         .setpwDesc(bias_add_desc)
                         .build();
  if (bias_add_op.get_status())
    return wrapper::MakeError(bias_add_op.get_status(),
                              bias_add_op.describe().c_str());

  // CUDNN OperationGraph
  llvm::SmallVector<cudnn_frontend::Operation const*, 4> ops = {
      &conv_op, &add_op, &bias_add_op};

  // The activation op is only appended for non-identity modes; the optionals
  // keep the descriptors alive until the graph is built.
  llvm::Optional<cudnn_frontend::PointWiseDesc_v8> act_desc;
  llvm::Optional<cudnn_frontend::Operation_v8> act_op;
  switch (activation_mode) {
    case CUDNN_ACTIVATION_IDENTITY:
      break;
    case CUDNN_ACTIVATION_RELU:
      act_desc.emplace(cudnn_frontend::PointWiseDescBuilder()
                           .setMode(CUDNN_POINTWISE_RELU_FWD)
                           .setMathPrecision(activation_type)
                           .build());
      if (act_desc->get_status())
        return wrapper::MakeError(act_desc->get_status(),
                                  act_desc->describe().c_str());
      act_op.emplace(cudnn_frontend::OperationBuilder(
                         CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
                         .setxDesc(bias_add_op.getOutputTensor())
                         .setyDesc(*tensor_y)
                         .setpwDesc(*act_desc)
                         .build());
      if (act_op->get_status())
        return wrapper::MakeError(act_op->get_status(),
                                  act_op->describe().c_str());
      ops.push_back(&*act_op);
      break;
    default:
      return MakeStringError("Unimplemented activation mode");
  }

  auto op_graph = cudnn_frontend::OperationGraphBuilder()
                      .setHandle(handle.get())
                      .setOperationGraph(ops.size(), ops.data())
                      .build();
  if (op_graph.get_status())
    return wrapper::MakeError(op_graph.get_status(),
                              op_graph.describe().c_str());
  return BuildExecutionPlan(handle, std::move(op_graph), *engine_id,
                            tuning_knob_ids.GetValue<int64_t>(),
                            tuning_knob_values.GetValue<int64_t>());
}