void CallCublasLt()

in src/runtime/contrib/cublas/cublas.cc [139:306]
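
Executes C = op(A) * op(B) (optionally batched) through the cuBLASLt API, with an optional bias, optional per-operand A/B scale pointers, an optional fused epilogue, and an optional dequantization scale that is forwarded as the alpha parameter.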


void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
                  cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, const DLTensor* B,
                  const DLTensor* bias, const DLTensor* scaleA, const DLTensor* scaleB,
                  const DLTensor* C, bool transa, bool transb, void* workspace_ptr,
                  size_t workspace_size, cublasLtEpilogue_t epilogue,
                  std::optional<float> dq_scale) {
  ICHECK(TypeEqual(A->dtype, B->dtype));
  // Reversed strides indicate an in-place transpose operation.
  transa = IsInPlaceTransposed(A) ? !transa : transa;
  transb = IsInPlaceTransposed(B) ? !transb : transb;

  auto compute_type = CUBLAS_COMPUTE_32F;
  auto scale_type = CUDA_R_32F;
  cudaDataType_t ab_type = CUDA_R_32F;
  cudaDataType_t c_type = CUDA_R_32F;
  float one_fp32 = 1.0;
  float zero_fp32 = 0.0;
  int32_t one_i32 = 1;
  int32_t zero_i32 = 0;
  // Pass the dequantization scale through the "alpha" parameter. If there is no dequantization
  // after the matmul, then alpha == 1.0.
  float alpha_value = dq_scale.value_or(one_fp32);
  void* alpha = &alpha_value;
  void* beta = &zero_fp32;

  if (TypeMatch(A->dtype, kDLFloat, 16)) {
    ab_type = CUDA_R_16F;
  } else if (TypeMatch(A->dtype, kDLBfloat, 16)) {
    ab_type = CUDA_R_16BF;
  } else if (TypeMatch(A->dtype, kDLInt, 8)) {
    ab_type = CUDA_R_8I;
  } else if (TypeMatch(A->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)) {
    ICHECK(TypeMatch(B->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8));
    ab_type = CUDA_R_8F_E4M3;
  }

  if (TypeMatch(C->dtype, kDLFloat, 16)) {
    c_type = CUDA_R_16F;
  } else if (TypeMatch(C->dtype, kDLBfloat, 16)) {
    c_type = CUDA_R_16BF;
  } else if (TypeMatch(C->dtype, kDLInt, 32)) {
    c_type = CUDA_R_32I;
    compute_type = CUBLAS_COMPUTE_32I;
    scale_type = CUDA_R_32I;
    alpha = &one_i32;
    beta = &zero_i32;
  }

  cublasLtMatmulDesc_t op_desc;
  cublasOperation_t op_transa = CUBLASBooleanToTranspose(transa);
  cublasOperation_t op_transb = CUBLASBooleanToTranspose(transb);

  CHECK_CUBLAS_ERROR(cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type));
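  // cuBLASLt assumes column-major storage while the DLTensors here are row-major. The kernel is
  // therefore launched on the transposed problem C^T = op(B)^T * op(A)^T, which is why the
  // TRANSA/TRANSB attributes below take the swapped flags.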
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_TRANSA,
                                                    &op_transb, sizeof(op_transb)));
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_TRANSB,
                                                    &op_transa, sizeof(op_transa)));

  if (bias != nullptr) {
    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER,
                                                      &bias->data, sizeof(float*)));
  }

  if (scaleA != nullptr) {
    auto scaleA_data = static_cast<char*>(scaleA->data) + scaleA->byte_offset;
    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                      &scaleA_data, sizeof(float*)));
  }
  if (scaleB != nullptr) {
    auto scaleB_data = static_cast<char*>(scaleB->data) + scaleB->byte_offset;
    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                      &scaleB_data, sizeof(float*)));
  }

  if (epilogue != CUBLASLT_EPILOGUE_DEFAULT) {
    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                                      &epilogue, sizeof(epilogue)));
  }

  int batch_offset_A = A->ndim - 2;
  int batch_offset_B = B->ndim - 2;

  int M = ColumnCount(B, transb, batch_offset_B);
  int N = RowCount(A, transa, batch_offset_A);
  int K = ColumnCount(A, transa, batch_offset_A);
  bool use_batched_gemm = A->ndim > 2 || B->ndim > 2;
  // If A is batched but B is not, flatten all non-reduction axes of A so that the regular GEMM
  // path can be used. This trick is only applicable if the batch axes and the other spatial
  // axis (M or N) are adjacent in both the input and the output matrix. In particular, if A is
  // of shape (M, K) and B is of shape (Batch, N, K) with transb = true, the output shape is
  // (Batch, M, N). Since the Batch and N axes are not adjacent in the output, we cannot use
  // the regular GEMM when only B is batched.
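  // For example, with transa == transb == false, A of shape (2, 3, 4) times B of shape (4, 5)
  // is computed as a single (6, 4) x (4, 5) GEMM whose (6, 5) result is exactly the contiguous
  // (2, 3, 5) output.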
  if (A->ndim > 2 && B->ndim == 2 && transa == false) {
    N = 1;
    for (int i = 0; i < A->ndim - 1; ++i) {
      N *= A->shape[i];
    }
    use_batched_gemm = false;
  }

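  // Leading dimensions in the column-major view; since the inputs are densely packed, each one
  // equals the first (row) dimension of the corresponding layout descriptor created below.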
  int lda = transb ? K : M;
  int ldb = transa ? N : K;
  int ldc = M;

  cublasLtMatrixLayout_t A_desc, B_desc, C_desc;
  CHECK_CUBLAS_ERROR(
      cublasLtMatrixLayoutCreate(&A_desc, ab_type, !transb ? M : K, !transb ? K : M, lda));
  CHECK_CUBLAS_ERROR(
      cublasLtMatrixLayoutCreate(&B_desc, ab_type, !transa ? K : N, !transa ? N : K, ldb));
  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&C_desc, c_type, M, N, ldc));

  if (use_batched_gemm) {
    auto get_batch_count = [](int64_t* shape, int batch_offset) {
      int64_t count = 1;
      for (int i = 0; i < batch_offset; ++i) {
        count *= shape[i];
      }
      return count;
    };
    auto set_batch = [](cublasLtMatrixLayout_t mat_desc, int batch_count, int64_t batch_stride) {
      CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
          mat_desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count, sizeof(batch_count)));
      CHECK_CUBLAS_ERROR(
          cublasLtMatrixLayoutSetAttribute(mat_desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
                                           &batch_stride, sizeof(batch_stride)));
    };

    int batch_count_A = get_batch_count(A->shape, batch_offset_A);
    int batch_count_B = get_batch_count(B->shape, batch_offset_B);
    int batch_count_C = get_batch_count(C->shape, C->ndim - 2);
    int64_t batch_stride_A = M * K;
    int64_t batch_stride_B = K * N;
    int64_t batch_stride_C = M * N;

    // cuBLASLt does not seem to support batched GEMM where one of the matrices has a single
    // batch (i.e. batch_stride 0).
    ICHECK_EQ(batch_count_A, batch_count_B);

    set_batch(A_desc, batch_count_A, batch_stride_A);
    set_batch(B_desc, batch_count_B, batch_stride_B);
    set_batch(C_desc, batch_count_C, batch_stride_C);
  }

  auto A_data = static_cast<char*>(A->data) + A->byte_offset;
  auto B_data = static_cast<char*>(B->data) + B->byte_offset;
  auto C_data = static_cast<char*>(C->data) + C->byte_offset;

  CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
      matmul_pref_desc, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size,
      sizeof(size_t)));

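  // Ask the heuristic for a single algorithm that fits the layouts and the workspace limit, and
  // report the problem as unsupported if nothing is returned.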
  cublasLtMatmulHeuristicResult_t heuristic_result = {};
  int returned_result = 0;
  CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(hdl, op_desc, A_desc, B_desc, C_desc, C_desc,
                                                    matmul_pref_desc, 1, &heuristic_result,
                                                    &returned_result));
  if (returned_result == 0) {
    CHECK_CUBLAS_ERROR(CUBLAS_STATUS_NOT_SUPPORTED);
  }

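  // Note the operand swap: B is passed with A_desc and A with B_desc, matching the transposed
  // problem set up in the matmul descriptor.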
  CHECK_CUBLAS_ERROR(cublasLtMatmul(hdl, op_desc, alpha, B_data, A_desc, A_data, B_desc, beta,
                                    C_data, C_desc, C_data, C_desc, &heuristic_result.algo,
                                    workspace_ptr, workspace_size, stream));

  cublasLtMatmulDescDestroy(op_desc);
  cublasLtMatrixLayoutDestroy(A_desc);
  cublasLtMatrixLayoutDestroy(B_desc);
  cublasLtMatrixLayoutDestroy(C_desc);
}
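
For reference, below is a minimal sketch of how a caller might drive CallCublasLt for a plain FP32 GEMM with no bias, scaling, epilogue, or dequantization. The helper names, shapes, and handle/workspace management are illustrative assumptions rather than code from this file (TVM's packed-function glue performs the equivalent setup); the sketch assumes the same headers as cublas.cc (dlpack/dlpack.h, cublasLt.h, cuda_runtime.h, and the CHECK_CUBLAS_ERROR macro) and densely packed row-major device buffers.

// Illustrative sketch (not part of the source): a 2-D FP32 GEMM C = A * B on device
// buffers d_A (m x k), d_B (k x n), d_C (m x n) that were allocated with cudaMalloc.
DLTensor MakeMatrix(void* data, int64_t* shape) {
  DLTensor t{};
  t.data = data;
  t.device = {kDLCUDA, 0};
  t.ndim = 2;
  t.dtype.code = kDLFloat;
  t.dtype.bits = 32;
  t.dtype.lanes = 1;
  t.shape = shape;
  t.strides = nullptr;  // densely packed row-major
  t.byte_offset = 0;
  return t;
}

void ExampleGemm(void* d_A, void* d_B, void* d_C, int64_t m, int64_t n, int64_t k) {
  int64_t shape_A[2] = {m, k}, shape_B[2] = {k, n}, shape_C[2] = {m, n};
  DLTensor A = MakeMatrix(d_A, shape_A);
  DLTensor B = MakeMatrix(d_B, shape_B);
  DLTensor C = MakeMatrix(d_C, shape_C);

  cublasLtHandle_t handle;
  CHECK_CUBLAS_ERROR(cublasLtCreate(&handle));
  cublasLtMatmulPreference_t pref;
  CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref));

  size_t workspace_size = 32 << 20;  // 32 MiB of scratch space for the selected algorithm
  void* workspace = nullptr;
  cudaMalloc(&workspace, workspace_size);

  // No bias, no A/B scales, default epilogue, no dequantization scale.
  CallCublasLt(handle, /*stream=*/nullptr, pref, &A, &B, /*bias=*/nullptr,
               /*scaleA=*/nullptr, /*scaleB=*/nullptr, &C, /*transa=*/false,
               /*transb=*/false, workspace, workspace_size, CUBLASLT_EPILOGUE_DEFAULT,
               std::nullopt);

  cudaFree(workspace);
  CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref));
  CHECK_CUBLAS_ERROR(cublasLtDestroy(handle));
}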