int main()

in src/codegen_fp16fp32.cc [83:609]


int main(int argc, const char* argv[]) {
  bool iaca = false;
  bool disable = false;
  unordered_set<string> enabledDataType;

  // Always generate FP16
  enabledDataType.insert("FP16");
  if (parseArgumentBool(argc, argv, "--fp32", false)) {
    enabledDataType.insert("FP32");
  }

  // Frefetch 8 cache lines ahead
  const int prefetch_a_len =
      parseArgumentInt(argc, argv, "--prefetch-a", 0, 128);

  // Frefetch 8 cache lines ahead
  const int prefetch_b_len =
      parseArgumentInt(argc, argv, "--prefetch-b", 0, 256);

  const int prefetch_c_len =
      parseArgumentInt(argc, argv, "--prefetch-c", 0, 1024);

  bool fixedA = true, fixedB = true, fixedC = true;

  int eax, ebx, ecx, edx;
  __cpuid(1 /* ecx = vendor string */, eax, ebx, ecx, edx);
  printf("FC16 is %s supported\n", ((ecx & bit_F16C) ? " " : "not"));

  static const string license =
      "/*\n"
      " * Copyright (c) Meta Platforms, Inc. and affiliates.\n"
      " * All rights reserved.\n"
      " * This source code is licensed under the BSD-style license found in the\n"
      " * LICENSE file in the root directory of this source tree.\n"
      " */\n";

  string comma = ",";

  enum class mult_type { fma, mul };

  vector<ISA> isa = {
      // {1, "AVX", {{4, 1, 0}, {4, 2, 0}, {4, 3, 0}, {3, 1, 0}, {3, 2, 0}, {3,
      // 3, 0}}},
      {ISA::isaType::avx2,
       "Avx2",
       {
           // 4x3 register layout
           // {1, 3, 0},
           // {2, 3, 0},
           // {3, 3, 0},
           // {4, 3, 0},

           // 6x2 register layout
           {1, 2, 0},
           {2, 2, 0},
           {3, 2, 0},
           {4, 2, 0},
           {5, 2, 0},
           {6, 2, 0},

           // 14x1 register layout
           // {1, 1, 0},
           // {2, 1, 0},
           // {3, 1, 0},
           // {4, 1, 0},
           // {5, 1, 0},
           // {6, 1, 0},
           // {7, 1, 0},
           // {8, 1, 0},
           // {9, 1, 0},
           // {10, 1, 0},
           // {11, 1, 0},
           // {12, 1, 0},
           // {13, 1, 0},
           // {14, 1, 0},
       }},
      {ISA::isaType::avx512,
       "Avx512",
       {
           // 14x2 register layout
           {1, 2, 0},
           {2, 2, 0},
           {3, 2, 0},
           {4, 2, 0},
           {5, 2, 0},
           {6, 2, 0},
           {7, 2, 0},
           {8, 2, 0},
           {9, 2, 0},
           {10, 2, 0},
           {11, 2, 0},
           {12, 2, 0},
           {13, 2, 0},
           {14, 2, 0},
       }},
      {ISA::isaType::avx512_256,
       "Avx512_256",
       {
           // 14x2 register layout
           // Implemented by AVX2
           //{1, 2, 0},
           //{2, 2, 0},
           //{3, 2, 0},
           //{4, 2, 0},
           //{5, 2, 0},
           //{6, 2, 0},
           {7, 2, 0},
           {8, 2, 0},
           {9, 2, 0},
           {10, 2, 0},
           {11, 2, 0},
           {12, 2, 0},
           {13, 2, 0},
           {14, 2, 0},
       }}};

  // Labels
  const string label_outer = "loop_outter%=";
  const string label_next_inner = "next_inner%=";
  const string label_inner = "loop_inner%=";
  const string label_zero = "zero_regs%=";
  const string label_dump_C = "dump_C%=";

  for (auto& d_type : types_to_gen) {
    if (enabledDataType.count(d_type.second) == 0) {
      continue;
    }
    for (auto s : isa) {
      bool const isFp16 = d_type.first != DataType::Float32;
      string const B_type = [&]() {
        if (d_type.first == DataType::Float32)
          return "fp32";
        if (d_type.first == DataType::Float16)
          return "fp16";
        throw std::runtime_error("Unknow DataType");
      }();

      string isa_file_name = "Fbgemm" + d_type.second + "UKernels" + s.name;

      // open all files
      ofstream srcfile;
      srcfile.open(isa_file_name + ".cc");
      srcfile << license;
      srcfile << "#include \"./" + isa_file_name + ".h\"\n\n";
      srcfile << "namespace fbgemm {\n\n";
      if (iaca) {
        srcfile << "#include \"iacaMarks.h\"\n";
      }

      ofstream hdrfile;
      hdrfile.open(isa_file_name + ".h");
      hdrfile << license;

      hdrfile << "#pragma once\n";
      hdrfile << "#include <cstdint>\n";
      hdrfile << "#include \"fbgemm/Types.h\"\n";
      hdrfile << "#include \"fbgemm/FbgemmBuild.h\"\n";
      hdrfile << "#include \"fbgemm/FbgemmFPCommon.h\"\n\n";
      hdrfile << "namespace fbgemm {\n\n";
      hdrfile << "using GemmParams" << d_type.second << " = GemmParams<float"
              << (isFp16 ? "16" : "") << ">;\n\n";

      unsigned labelId = 0;

      bool fixedA = false, fixedB = false, fixedC = false;

      vector<vector<unsigned>>& ukernel_shape = s.shapes;

      vector<string> funcname(ukernel_shape.size()),
          fheader(ukernel_shape.size());
      string fargs;

      string prefix = s.name + "_" + B_type + "_" + "fA" + to_string(fixedA) +
          "fB" + to_string(fixedB) + "fC" + to_string(fixedC);
      cout << "Generating code for " << s.name << " " << B_type << "\n";

      string vec_reg_prefix = s.iset == ISA::isaType::avx512 ? "zmm" : "ymm";
      int num_vec_regs = s.iset == ISA::isaType::avx2 ? 16 : 32;
      int vec_len_in_bytes = s.iset == ISA::isaType::avx512 ? 64 : 32;

      for (unsigned k = 0; k < ukernel_shape.size(); k++) {
        printf(
            "shape: %d x %d * 32\n", ukernel_shape[k][0], ukernel_shape[k][1]);

        const string A_stride = to_string(4 * ukernel_shape[k][0]);
        const string B_stride =
            to_string((vec_len_in_bytes >> (int)isFp16) * ukernel_shape[k][1]);

        const string p1 = "GemmParams" + d_type.second + "* gp";

        funcname[k] = "gemmkernel_" + to_string(ukernel_shape[k][0]) + "x" +
            to_string(ukernel_shape[k][1]) + "_";
        funcname[k] += prefix;

        fargs = "(" + p1 + ")";

        fheader[k] = "void NOINLINE " + funcname[k] + fargs;
        srcfile << fheader[k] << " {\n";

        unsigned last_free_vecreg = 0;
        // produce register block of C
        vector<vector<string>> vCtile(ukernel_shape[k][0]);
        for (auto r = 0; r < ukernel_shape[k][0]; r++)
          for (auto c = 0; c < ukernel_shape[k][1]; c++) {
            vCtile[r].push_back(vec_reg_prefix + to_string(last_free_vecreg));
            last_free_vecreg++;
          }
        assert(last_free_vecreg <= num_vec_regs - 2);

        string vAtmp = vec_reg_prefix + to_string(last_free_vecreg++);
        // produce register block of B col
        vector<string> vBcol(ukernel_shape[k][1]);

        for (auto c = 0; c < ukernel_shape[k][1]; c++) {
          vBcol[c] = (vec_reg_prefix + to_string(last_free_vecreg));
          last_free_vecreg++;
        }
        assert(last_free_vecreg <= num_vec_regs);
        string r_spare = vec_reg_prefix +
            to_string(num_vec_regs - (s.iset == ISA::isaType::avx ? 2 : 1));

        auto const A_load_mult = [&](int r, mult_type m_type) {
          if (prefetch_a_len && ((4 * r) % cache_line_size == 0)) {
            addi(
                srcfile,
                "prefetcht0 [r9 + " + to_string(prefetch_a_len) + "]",
                fixedC);
          }
          string mul = m_type == mult_type::mul ? "vmulps" : "vfmadd231ps";
          addi(
              srcfile,
              "vbroadcastss " + vAtmp + ",DWORD PTR [r9+" + to_string(4 * r) +
                  "]");
          for (int c = 0; c < vCtile[0].size(); c++) {
            addi(
                srcfile,
                mul + " " + vCtile[r][c] + "," + vBcol[c] + "," + vAtmp);
          }
        };

        // Generate Loads from Matrix B
        auto const B_load = [&](int c, const string& vBcol, int prefetch_len) {
          if (d_type.first == DataType::Float32) {
            addi(
                srcfile,
                "vmovups " + vBcol + "," +
                    (s.iset == ISA::isaType::avx512 ? "ZMM" : "YMM") +
                    "WORD PTR [r10 + " + to_string(vec_len_in_bytes * c) + "]");
          } else if (d_type.first == DataType::Float16) {
            addi(
                srcfile,
                "vcvtph2ps " + vBcol + "," +
                    (s.iset == ISA::isaType::avx512 ? "YMM" : "XMM") +
                    "WORD PTR [r10 + " + to_string(vec_len_in_bytes / 2 * c) +
                    "]");
          }
          if (prefetch_len && ((vec_len_in_bytes * c) % cache_line_size == 0)) {
            addi(
                srcfile,
                "prefetcht0 [r10 + " +
                    to_string(vec_len_in_bytes * c + prefetch_len) + "]",
                fixedC);
          }
        };

        auto const C_prefetch = [&](int r) {
          for (auto c = 0; prefetch_c_len && (c < vCtile[r].size()); c++) {
            if ((vec_len_in_bytes * c) % cache_line_size == 0) {
              addi(
                  srcfile,
                  "prefetcht1 [r12 + " +
                      to_string(
                          /*vec_len_in_bytes * ukernel_shape[k][1] +*/
                          c * cache_line_size + prefetch_c_len) +
                      "]",
                  fixedC);
            }
          }
        };

        // Generate Loads from Matrix C
        auto const C_load = [&](int r) {
          for (auto c = 0; c < vCtile[r].size(); ++c) {
            switch (s.iset) {
              case ISA::isaType::avx:
              case ISA::isaType::avx2:
              case ISA::isaType::avx512:
              case ISA::isaType::avx512_256:
                if (prefetch_c_len &&
                    ((vec_len_in_bytes * c) % cache_line_size == 0)) {
                  addi(
                      srcfile,
                      "prefetcht1 [r12 + " +
                          to_string(
                              /*vec_len_in_bytes * ukernel_shape[k][1] +*/
                              c * cache_line_size + prefetch_c_len) +
                          "]",
                      fixedC);
                }
                addi(
                    srcfile,
                    "vmulps " + vCtile[r][c] + ", " + r_spare + ", " +
                        "[r12 + " + to_string(vec_len_in_bytes * c) + "]",
                    fixedC);
                break;
              default:
                assert(0);
            }
          }
        };

        srcfile << "  asm volatile(\n";

        srcfile << "#if !defined(__clang__) || __clang_major__ >= 14"
                << "\n";
        addi(srcfile, "mov r14, %[gp]");
        srcfile << "#else\n";
        addi(srcfile, "mov %[gp], %%r14");
        addi(srcfile, ".intel_syntax noprefix");
        srcfile << "#endif\n";

        srcfile << "\n";
        srcfile << "      // Copy parameters\n";
        srcfile << "      // k\n";
        addi(srcfile, "mov r8, [r14 + 0]");
        // Assuming k >= 1
        addi(srcfile, "dec r8");
        srcfile << "      // A\n";
        addi(srcfile, "mov r9, [r14 + 8]");
        srcfile << "      // B\n";
        addi(srcfile, "mov r10, [r14 + 16]");
        srcfile << "      // beta\n";
        addi(srcfile, "lea r15, [r14 + 24]");
        srcfile << "      // C\n";
        addi(srcfile, "mov r12, [r14 + 32]");
        srcfile << "      // ldc\n";
        addi(srcfile, "mov r13, [r14 + 40]");
        srcfile << "      // b_block_cols\n";
        addi(srcfile, "mov rdi, [r14 + 48]");
        srcfile << "      // b_block_size\n";
        addi(srcfile, "mov rsi, [r14 + 56]");
        srcfile << "\n";
        srcfile << "      // Make copies of A and C\n";
        addi(srcfile, "mov rax, r9");
        addi(srcfile, "mov rcx, r12");
        srcfile << "\n";

        addi(srcfile, "xor ebx, ebx");
        addi(srcfile, label_outer + ":");
        addi(srcfile, "mov r14, r8");

        string r_spare_cmp = "xmm" +
            to_string(num_vec_regs - (s.iset == ISA::isaType::avx ? 2 : 1));

        addi(
            srcfile,
            "vbroadcastss " + r_spare + string(",DWORD PTR [r15]"),
            fixedC);
        // Generate first iteration which loads values from C  and interleavs
        // With loads from B and multiplication
        for (auto c = 0; c < vCtile[0].size(); ++c) {
          B_load(c, vBcol[c], prefetch_b_len);
        }
        addi(srcfile, "vxorps xmm0, xmm0, xmm0");
        addi(srcfile, "vcomiss " + r_spare_cmp + ", xmm0");
        addi(srcfile, "jz " + label_zero);

        srcfile << "\n";
        srcfile << "      // Setup values with beta multiplication\n";
        string r_last = vec_reg_prefix + to_string(num_vec_regs - 1);
        for (auto r = 0; r < vCtile.size(); r++) {
          if (r > 0) {
            addi(srcfile, "add r12, r13", fixedC); // move C ptr
          }
          C_load(r);
        }
        // Skip matrix B preload if k == 1 (may OutOfBound access)
        addi(srcfile, "test r14,r14");
        addi(srcfile, "jz skip_preload%=");
        // Preload B index and prefetch with the next iteration
        B_load(vCtile[0].size(), r_spare, prefetch_b_len);
        addi(srcfile, "skip_preload%=:");
        for (auto r = 0; r < vCtile.size(); r++) {
          A_load_mult(r, mult_type::fma);
        }
        if (vCtile.size() > 1) {
          addi(srcfile, "mov r12, rcx");
        }
        addi(srcfile, "test r14,r14"); // Decrease iterations
        addi(srcfile, "jnz " + label_next_inner);
        addi(srcfile, "add r10," + B_stride, fixedA); // B stride
        addi(srcfile, "jmp " + label_dump_C);

        //
        // Handle non-accumulate case, the values can be directly stored
        //
        srcfile << "\n";
        addi(srcfile, label_zero + ":");
        srcfile << "\n";
        // Skip matrix B preload if k == 1 (may OutOfBound access)
        addi(srcfile, "test r14,r14");
        addi(srcfile, "jz skip_preload_b_zero%=");
        // Preload B index and with the next iteration
        B_load(vCtile[0].size(), r_spare, prefetch_b_len);
        addi(srcfile, "skip_preload_b_zero%=:");
        // Consider all vCtile regs as zeros, do direct MUL into
        for (auto r = 0; r < vCtile.size(); r++) {
          if (r > 0) {
            addi(srcfile, "add r12, r13", fixedC); // move C ptr
          }
          C_prefetch(r);
          A_load_mult(r, mult_type::mul);
        }
        if (vCtile.size() > 1) {
          addi(srcfile, "mov r12, rcx");
        }
        addi(srcfile, "test r14,r14"); // Decrease iterations
        addi(srcfile, "jnz " + label_next_inner);
        addi(srcfile, "add r10," + B_stride, fixedA); // B stride
        addi(srcfile, "jmp " + label_dump_C);

        // start marker
        if (iaca) {
          addi(srcfile, "mov ebx, 111");
          addi(srcfile, ".byte 0x64, 0x67, 0x90");
        }

        //
        //  Inner iteration begin
        //
        srcfile << "\n";
        addi(srcfile, label_inner + ":");
        srcfile << "\n";

        // Store preloaded value
        addi(srcfile, "vmovaps " + vBcol[0] + "," + r_spare);
        for (int c = 1; c < vCtile[0].size(); c++) {
          B_load(c, vBcol[c], prefetch_b_len);
        }
        // Preload for next iteration
        B_load(vCtile[0].size(), r_spare, prefetch_b_len);
        for (int r = 0; r < vCtile.size(); r++) {
          A_load_mult(r, mult_type::fma);
        }

        // Finish inner iteration
        srcfile << "\n";
        addi(srcfile, label_next_inner + ":");
        addi(srcfile, "add r9," + A_stride, fixedA); // A stride
        addi(srcfile, "add r10," + B_stride, fixedA); // B stride
        addi(srcfile, "dec r14"); // Decrease iterations
        addi(srcfile, "jnz " + label_inner);
        srcfile << "\n";

        // end marker
        if (iaca) {
          addi(srcfile, "mov ebx, 222");
          addi(srcfile, ".byte 0x64, 0x67, 0x90");
        }

        // Perform last iteration without preloading B values
        // Store preloaded value
        addi(srcfile, "vmovaps " + vBcol[0] + "," + r_spare);
        for (int c = 1; c < vCtile[0].size(); c++) {
          B_load(c, vBcol[c], 0); // no prefetch
        }
        for (int r = 0; r < vCtile.size(); r++) {
          A_load_mult(r, mult_type::fma);
        }
        addi(srcfile, "add r9," + A_stride, fixedA); // A stride
        addi(srcfile, "add r10," + B_stride, fixedA); // B stride

        srcfile << "      // Dump C\n";
        addi(srcfile, label_dump_C + ":");
        for (auto r = 0; r < vCtile.size(); r++) {
          if (r > 0) {
            addi(srcfile, "add r12, r13", fixedC); // move C ptr
          }
          for (auto c = 0; c < vCtile[r].size(); c++) {
            addi(
                srcfile,
                "vmovups " + vec_reg_prefix + "word PTR [r12 + " +
                    to_string(vec_len_in_bytes * c) + "], " + vCtile[r][c],
                fixedC);
          }
        }

        srcfile << "\n      // next outer iteration\n";
        // C
        addi(
            srcfile,
            "add rcx, " + to_string(vec_len_in_bytes * ukernel_shape[k][1]),
            fixedC);
        addi(srcfile, "mov r12, rcx", fixedC);
        // A
        addi(srcfile, "mov r9, rax");

        addi(srcfile, "inc rbx");
        addi(srcfile, "cmp rbx, rdi");
        addi(srcfile, "jl " + label_outer);

        // output
        srcfile << "      :\n";
        // input
        srcfile << "      : [gp] \"rm\"(gp)\n";

        // clobbered
        srcfile << "      : \"r8\",\n        \"r9\",\n        \"r10\",\n"
                   "        \"r11\",\n        \"r13\",\n"
                   "        \"r14\",\n        \"rax\",\n        \"rcx\",\n"
                   "        \"rsi\",\n        \"rdi\",\n"
                   "        \"rbx\",\n        \"r12\",\n        \"r15\",\n"
                   "        \"memory\");\n";
        srcfile << "}\n";
      }

      for (unsigned k = 0; k < ukernel_shape.size(); k++) {
        hdrfile << fheader[k] << ";\n";
      }

      srcfile << "\n} // namespace fbgemm\n";
      srcfile.close();
      hdrfile << "\n} // namespace fbgemm\n";
      hdrfile.close();
    } // isa
  }
}