void cblas_gemm_i64_i64acc()

in src/FbgemmI64.cc [403:540]


/// Computes C = op(A) * op(B), or C += op(A) * op(B) when `accumulate` is
/// true, for row-major int64 matrices with int64 accumulation.
///
/// op(A) is M x K and op(B) is K x N; C is M x N with leading dimension ldc.
/// Without AVX-512 support the call is forwarded unchanged to
/// cblas_gemm_i64_i64acc_ref. Otherwise the product is computed block by block
/// (MCB x KCB panels of A times KCB x NCB panels of B) with JIT-generated
/// AVX-512 micro-kernels.
///
/// @param transa     If Transpose, A is stored as a K x M matrix (leading
///                   dimension lda) and op(A) is its transpose.
/// @param transb     If Transpose, B is stored as an N x K matrix (leading
///                   dimension ldb) and op(B) is its transpose.
/// @param M,N,K      Dimensions of op(A) (M x K) and op(B) (K x N).
/// @param A,lda      Left operand and its leading dimension.
/// @param B,ldb      Right operand and its leading dimension.
/// @param accumulate If true, add the product to the existing contents of C;
///                   otherwise overwrite C.
/// @param C,ldc      Output matrix and its leading dimension.
void cblas_gemm_i64_i64acc(
    matrix_op_t transa,
    matrix_op_t transb,
    int M,
    int N,
    int K,
    const int64_t* A,
    int lda,
    const int64_t* B,
    int ldb,
    bool accumulate,
    int64_t* C,
    int ldc) {
  cpuinfo_initialize();
  if (!fbgemmHasAvx512Support()) {
    cblas_gemm_i64_i64acc_ref(
        transa, transb, M, N, K, A, lda, B, ldb, accumulate, C, ldc);
    return;
  }

  // An empty output has nothing to compute; skip the JIT and packing work.
  if (M <= 0 || N <= 0) {
    return;
  }
  // With an empty inner dimension the kc loop below would never run and C
  // would be left untouched. The product of an M x 0 and a 0 x N matrix is
  // the zero matrix, so a non-accumulating call must still clear C.
  if (K <= 0) {
    if (!accumulate) {
      for (int i = 0; i < M; ++i) {
        memset(C + static_cast<int64_t>(i) * ldc, 0, N * sizeof(int64_t));
      }
    }
    return;
  }

  constexpr int MCB = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::MCB;
  constexpr int NCB = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::NCB;
  constexpr int KCB = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::KCB;
  constexpr int MR = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::MR;
  constexpr int NR = PackingTraits<int64_t, int64_t, inst_set_t::avx512>::NR;
  static_assert(MCB % MR == 0, "MR must divide MCB");
  static_assert(NCB % NR == 0, "NR must divide NCB");
  constexpr int VLEN =
      simd_info<inst_set_t::avx512>::WIDTH_BYTES / sizeof(int64_t);
  static_assert(NR % VLEN == 0, "VLEN must divide NR");

  using CodeGenType = CodeGenBase<int64_t, int64_t, int64_t, int64_t>;
  CodeGenType codeObj;
  // Accumulating kernel: used for every K block after the first, and for all
  // K blocks when the caller asked to accumulate into C.
  CodeGenType::jit_micro_kernel_fp fn =
      codeObj.getOrCreate<inst_set_t::avx512>(true /* accum */, MCB, NCB, KCB);
  // Overwriting kernel: only used for the first K block of a non-accumulating
  // call. It stays null otherwise so it never holds an indeterminate value.
  CodeGenType::jit_micro_kernel_fp fn_noacc = nullptr;
  if (!accumulate) {
    fn_noacc = codeObj.getOrCreate<inst_set_t::avx512>(
        false /* accum */, MCB, NCB, KCB);
  }

  // Transposed operands are materialized in row-major form up front, so the
  // packing code below only handles the non-transposed layout.
  // TODO: handle transpose during packing
  vector<int64_t> At, Bt;
  if (transa == matrix_op_t::Transpose) {
    // Stored A is K x M (leading dim lda); At is op(A) as a dense M x K.
    At.resize(static_cast<size_t>(M) * K);
    for (int i = 0; i < M; ++i) {
      int64_t* dst = At.data() + static_cast<int64_t>(i) * K;
      for (int k = 0; k < K; ++k) {
        dst[k] = A[i + static_cast<int64_t>(k) * lda];
      }
    }
    A = At.data();
    lda = K;
  }
  if (transb == matrix_op_t::Transpose) {
    // Stored B is N x K (leading dim ldb); Bt is op(B) as a dense K x N.
    Bt.resize(static_cast<size_t>(K) * N);
    for (int k = 0; k < K; ++k) {
      int64_t* dst = Bt.data() + static_cast<int64_t>(k) * N;
      for (int j = 0; j < N; ++j) {
        dst[j] = B[k + static_cast<int64_t>(j) * ldb];
      }
    }
    B = Bt.data();
    ldb = N;
  }

  // Packed panels handed to the micro-kernels. The kernel takes no M/N
  // extent, so it presumably always computes a full MCB x NCB tile. Edge tiles
  // are therefore computed into the packC scratch buffer and only the valid
  // part is copied back to C. Rows/columns of packA/packB beyond the valid
  // extent hold stale data; the results they produce in packC are discarded.
  // NOTE(review): these buffers live on the stack
  // ((MCB*KCB + KCB*NCB + MCB*NCB) int64 elements); confirm this fits the
  // stack budget of every thread that calls this function.
  alignas(64) array<int64_t, MCB * KCB> packA;
  alignas(64) array<int64_t, KCB * NCB> packB;
  alignas(64) array<int64_t, MCB * NCB> packC;

  for (int ic = 0; ic < M; ic += MCB) {
    const int mb = std::min(MCB, M - ic); // valid rows in this A panel
    for (int kc = 0; kc < K; kc += KCB) {
      const int kb = std::min(KCB, K - kc); // valid depth of this K block

      // Pack the mb x kb block of A; each row occupies a KCB-wide slot.
      for (int i = 0; i < mb; ++i) {
        memcpy(
            &packA[i * KCB],
            A + static_cast<int64_t>(ic + i) * lda + kc,
            kb * sizeof(int64_t));
      }

      // Only the first K block of a non-accumulating call overwrites C; every
      // other block adds its partial product to what is already there.
      const bool overwrite = kc == 0 && !accumulate;
      const CodeGenType::jit_micro_kernel_fp kernel =
          overwrite ? fn_noacc : fn;

      for (int jc = 0; jc < N; jc += NCB) {
        const int nb = std::min(NCB, N - jc); // valid columns in this B panel

        // Pack the kb x nb block of B; each row occupies an NCB-wide slot.
        for (int k = 0; k < kb; ++k) {
          memcpy(
              &packB[k * NCB],
              B + static_cast<int64_t>(kc + k) * ldb + jc,
              nb * sizeof(int64_t));
        }

        if (mb == MCB && nb == NCB) {
          // Full tile: the kernel reads/writes C in place.
          kernel(
              packA.data(),
              packB.data(),
              packB.data(),
              C + static_cast<int64_t>(ic) * ldc + jc,
              kb,
              ldc);
          continue;
        }

        // Edge tile: stage C in packC (only needed when accumulating), run the
        // kernel on the full-size scratch tile, then copy the mb x nb result
        // back.
        if (!overwrite) {
          for (int i = 0; i < mb; ++i) {
            memcpy(
                &packC[i * NCB],
                C + static_cast<int64_t>(ic + i) * ldc + jc,
                nb * sizeof(int64_t));
          }
        }
        kernel(
            packA.data(), packB.data(), packB.data(), packC.data(), kb, NCB);
        for (int i = 0; i < mb; ++i) {
          memcpy(
              C + static_cast<int64_t>(ic + i) * ldc + jc,
              &packC[i * NCB],
              nb * sizeof(int64_t));
        }
      } // jc
    } // kc
  } // ic
}