GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate()

in src/GenerateI8Depthwise.cc [176:576]


// Returns an AVX2 int8 depthwise-convolution micro-kernel JIT-compiled for the
// given parameters. The result is looked up in / stored into codeCache_ under a
// key made of every argument; the lambda below generates the machine code when
// the cache needs it.
//
// Parameters (as consumed by the generator below):
//   D             Spatial dimensionality. D == 2 walks a single temporal tap;
//                 D >= 3 walks F[0] temporal taps and adds a frame stride
//                 between them.
//   F             Filter extents, walked as (f_t, f_h, f_w) over
//                 (F[0], F[1], F[2]). K = F[0] * F[1] * F[2] is the tap count.
//   oc_per_g      Output channels per group. Only 1 and 2 are supported
//                 (asserted while emitting loads); with 2 each loaded input
//                 byte is duplicated.
//   compute_a_sum When true, a_sum vectors are accumulated and stored through
//                 the kernel's a_sum pointer argument.
//   remainder     Bytes handled by the final (tail) iteration over input
//                 channels; 32 means the tail is a full vector and no mask is
//                 used.
//   prev_skip, next_skip, top_skip, bottom_skip, left_skip, right_skip
//                 Number of filter taps at each border (temporal, height,
//                 width) that fall outside the input; those taps are fed the
//                 broadcast A zero point instead of a memory load.
//
// The generated function has the signature
//   void(const uint8_t* A, const int8_t* B, int32_t* C, int32_t* a_sum,
//        int h, int w, int ic, const int* mask, int a_zero_point,
//        const int32_t* b_zero_point)
// and the lambda returns nullptr if the runtime rejects the generated code.
GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate(
    int D,
    std::array<int, 3> F,
    int oc_per_g,
    bool compute_a_sum,
    int remainder,
    int prev_skip,
    int next_skip,
    int top_skip,
    int bottom_skip,
    int left_skip,
    int right_skip) {
  // Every argument changes the emitted code, so every argument is in the key.
  std::tuple<int, int, int, int, int, bool, int, int, int, int, int, int, int>
      kernelSig = std::make_tuple(
          D,
          F[0],
          F[1],
          F[2],
          oc_per_g,
          compute_a_sum,
          remainder,
          prev_skip,
          next_skip,
          top_skip,
          bottom_skip,
          left_skip,
          right_skip);

  return codeCache_.getOrCreate(kernelSig, [&]() -> jit_kernel_signature {
    asmjit::CodeHolder code;
    code.init(runtime().environment());
    x86::Assembler assembler(&code);
    x86::Emitter* e = assembler.as<x86::Emitter>();
#ifdef FBGEMM_LOG_CODE
    // Descriptive log name, e.g. "dwconv_2d_3x3_1_remainder16.txt".
    std::string filename = "dwconv_" + std::to_string(D) + "d_";
    for (int i = 3 - D; i < 3; ++i) {
      // Filter extents come from F; K (the tap count) is computed later.
      filename += std::to_string(F[i]);
      if (i < 2) {
        filename += "x";
      }
    }
    filename += "_" + std::to_string(oc_per_g);
    if (compute_a_sum) {
      filename += "_asum";
    }
    if (remainder) {
      filename += "_remainder" + std::to_string(remainder);
    }
    if (prev_skip) {
      filename += "_prev_skip" + std::to_string(prev_skip);
    }
    if (next_skip) {
      filename += "_next_skip" + std::to_string(next_skip);
    }
    if (top_skip) {
      filename += "_top_skip" + std::to_string(top_skip);
    }
    if (bottom_skip) {
      filename += "_bottom_skip" + std::to_string(bottom_skip);
    }
    if (left_skip) {
      filename += "_left_skip" + std::to_string(left_skip);
    }
    if (right_skip) {
      filename += "_right_skip" + std::to_string(right_skip);
    }
    filename += ".txt";
    FILE* codeLogFile = fopen(filename.c_str(), "w");
    asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogFile);
    code.setLogger(codeLogger);
#endif

    // Registers receiving the kernel arguments (bound by assignAll below),
    // followed by two scratch registers used by the generated loop.
    x86::Gp a_addr = e->zdi();
    x86::Gp b_addr = e->zsi();
    x86::Gp c_addr = e->zdx();
    x86::Gp a_sum_addr = e->zcx();
    x86::Gp h = e->gpz(8);
    x86::Gp w = e->gpz(9);
    x86::Gp ic = e->gpz(10);
    x86::Gp mask_addr = e->gpz(11);
    x86::Gp a_zero_point = e->gpz(12);
    // NOTE(review): b_zero_point_addr is bound to the last argument but is
    // not referenced anywhere in this generator -- confirm it is intended.
    x86::Gp b_zero_point_addr = e->gpz(13);
    x86::Gp ic_loop_count = e->gpz(14);
    x86::Gp a_addr_save = e->gpz(15);

    asmjit::FuncDetail func;
    func.init(
        asmjit::FuncSignatureT<
            void,
            const std::uint8_t*,
            const std::int8_t*,
            std::int32_t*,
            std::int32_t*,
            int,
            int,
            int,
            const int*,
            int,
            const std::int32_t*>(asmjit::CallConv::kIdHost),
        e->environment());

    asmjit::FuncFrame frame;
    frame.init(func);

    // All 16 vector registers and r8-r15 are written by the kernel; marking
    // them dirty lets emitProlog/emitEpilog preserve whichever of them the
    // host calling convention treats as callee-saved.
    frame.setDirtyRegs(
        x86::Reg::kGroupVec,
        asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
            asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
    frame.setDirtyRegs(
        x86::Reg::kGroupGp,
        asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

    asmjit::FuncArgsAssignment args(&func);
    args.assignAll(
        a_addr,
        b_addr,
        c_addr,
        a_sum_addr,
        h,
        w,
        ic,
        mask_addr,
        a_zero_point,
        b_zero_point_addr);

    args.updateFuncFrame(frame);
    frame.finalize();

    e->emitProlog(frame);
    e->emitArgsAssignment(frame, args);

    // Vector register allocation. ymm0 and ymm1 are kept free as temporaries
    // (the oc_per_g == 2 byte duplication below uses ymm(i % 2)).
    x86::Ymm a[4];
    x86::Ymm c[4];
    x86::Ymm a_sum[2];

    int vreg_id = 2; // reserve 2 for temp vreg
    for (int i = 0; i < 4; ++i, ++vreg_id) {
      a[i] = x86::Ymm(vreg_id);
    }
    for (int i = 0; i < 4; ++i, ++vreg_id) {
      c[i] = x86::Ymm(vreg_id);
    }
    if (compute_a_sum) {
      a_sum[0] = x86::Ymm(vreg_id);
      ++vreg_id;
      a_sum[1] = x86::Ymm(vreg_id);
      ++vreg_id;
    }
    // Optional registers below only consume an id when they are actually
    // initialized; otherwise the id is reused by the next one.
    x86::Ymm mask_vreg(vreg_id);
    constexpr int vlen = simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
    if (remainder != simd_info<inst_set_t::avx2>::WIDTH_BYTES) {
      ++vreg_id;
      // Load the tail mask starting at the offset that leaves exactly
      // remainder / 4 / oc_per_g enabled 32-bit lanes.
      e->vmovups(
          mask_vreg,
          x86::ymmword_ptr(
              mask_addr,
              (vlen - remainder / 4 / oc_per_g) % vlen * sizeof(int32_t)));
    }
    x86::Ymm one_epi8(vreg_id);
    if (compute_a_sum) {
      ++vreg_id;
      gen8BitVectorOne(e, one_epi8);
    }

    int K = std::accumulate(F.begin(), F.end(), 1, std::multiplies<int>());
    x86::Ymm one_epi16(vreg_id);
    if (K > 2) {
      ++vreg_id;
      gen16BitVectorOne<inst_set_t::avx2, x86::Ymm>(e, one_epi16);
    }

    bool has_pad = prev_skip || next_skip || top_skip || bottom_skip ||
        left_skip || right_skip;
    // A zero vector is passed to genMaddEpi16xNPacked; it is only set up when
    // K % 4 is odd (a final group of 1 or 3 taps).
    bool need_zero = K % 4 == 3 || K % 4 == 1;
    // When out of registers, zero and A_zero_point_vreg need to share.
    bool recompute_zero = vreg_id == 15 && need_zero;

    x86::Ymm a_zero_point_vreg(vreg_id);
    if (!recompute_zero && has_pad) {
      // Broadcast the low byte of a_zero_point to all 32 bytes.
      e->movq(a_zero_point_vreg.half(), a_zero_point);
      e->vpbroadcastb(a_zero_point_vreg, a_zero_point_vreg.half());
    }
    if (vreg_id < 15) {
      ++vreg_id;
    }
    x86::Ymm zero(vreg_id);
    if (need_zero && (!recompute_zero || !has_pad)) {
      e->vpxor(zero.xmm(), zero.xmm(), zero.xmm());
    }

    // Turn the h/w/ic arguments into byte strides for walking A:
    //   ic: step to the next pixel in a row,
    //   w : w * ic - F[2] * ic, step from the end of a filter row to the
    //       start of the next,
    //   h : h * w * ic - F[1] * w * ic (D >= 3 only), step from the end of a
    //       filter frame to the start of the next.
    e->imul(w, ic);
    e->imul(h, w);
    if (D >= 3) {
      e->mov(a_addr_save, w);
      e->imul(a_addr_save, F[1]);
      e->sub(h, a_addr_save); // h * w * ic - F[1] * w * ic
    }
    e->mov(a_addr_save, ic);
    e->imul(a_addr_save, F[2]);
    e->sub(w, a_addr_save); // w * ic - F[2] * ic

    // ic_loop_count = ceil(ic / (32 / oc_per_g)): number of channel blocks.
    e->mov(ic_loop_count, ic);
    e->add(ic_loop_count, asmjit::Imm(32 / oc_per_g - 1));
    e->sar(ic_loop_count, asmjit::Imm(oc_per_g == 1 ? 5 : 4));

    e->mov(a_addr_save, a_addr);
    asmjit::Label ic_loop_begin = e->newLabel(), ic_loop_end = e->newLabel();

    // The body is emitted twice: once as the main loop, which runs
    // ic_loop_count - 1 times (dec + jle at the top), and once as the
    // straight-line tail for the last channel block.
    // main_loop == false: the last vector iteration across input channels
    for (bool main_loop : {true, false}) {
      if (main_loop) {
        e->bind(ic_loop_begin);
        e->dec(ic_loop_count);
        e->jle(ic_loop_end);
      }

      if (recompute_zero && has_pad) {
        // The shared register may hold `zero` from the previous iteration.
        e->movq(a_zero_point_vreg.half(), a_zero_point);
        e->vpbroadcastb(a_zero_point_vreg, a_zero_point_vreg.half());
      }

      int i = 0;
      // Iterate across the reduction (filter) dimension
      for (int f_t = 0; f_t < ((D == 2) ? 1 : F[0]); ++f_t) {
        for (int f_h = 0; f_h < F[1]; ++f_h) {
          for (int f_w = 0; f_w < F[2]; ++f_w, ++i) {
            // A tap is padding when it falls into any of the skipped borders.
            bool pad = false;
            if (D > 2) {
              if (f_t < prev_skip || f_t >= F[0] - next_skip) {
                pad = true;
              }
            }
            if (f_h < top_skip || f_h >= F[1] - bottom_skip ||
                f_w < left_skip || f_w >= F[2] - right_skip) {
              pad = true;
            }

            // Load A
            if (pad) {
              e->vmovups(a[i % 4], a_zero_point_vreg);
            } else {
              if (oc_per_g == 1) {
                if (!main_loop && remainder != 32) {
                  e->vmaskmovps(a[i % 4], mask_vreg, x86::ymmword_ptr(a_addr));
                } else {
                  e->vmovups(a[i % 4], x86::ymmword_ptr(a_addr));
                }
              } else {
                assert(oc_per_g == 2);
                if (!main_loop && remainder != 32) {
                  e->vmaskmovps(
                      a[i % 4].half(),
                      mask_vreg.half(),
                      x86::xmmword_ptr(a_addr));
                } else {
                  e->vmovups(a[i % 4].half(), x86::xmmword_ptr(a_addr));
                }
                // Duplicate each byte.
                e->vpmovzxbw(a[i % 4], a[i % 4].half());
                e->vpsllw(x86::ymm(i % 2), a[i % 4], asmjit::Imm(8));
                e->vpaddw(a[i % 4], a[i % 4], x86::ymm(i % 2));
              }
            }

            // Compute when we have 4 inputs or this is the last iteration
            if (i % 4 == 3 || i == K - 1) {
              // First tap of the current group and the number of taps from
              // there to the end of the filter.
              const int group_begin = i / 4 * 4;
              const int taps_left = K - group_begin;

              if (i == K - 1 && (taps_left == 3 || taps_left == 1)) {
                if (recompute_zero && has_pad) {
                  e->vpxor(zero.xmm(), zero.xmm(), zero.xmm());
                }
              }

              genMaddEpi16xNPacked(
                  e,
                  a,
                  b_addr,
                  c,
                  compute_a_sum ? a_sum : nullptr,
                  /*n=*/std::min(taps_left, 4),
                  main_loop ? 32 : remainder,
                  /*accumulation=*/i / 4 > 0,
                  one_epi8,
                  one_epi16,
                  zero);

              if (i != K - 1) {
                // A full group of 4 taps consumes 4 vectors of B.
                e->add(b_addr, asmjit::Imm(32 * 4));
              } else if (main_loop) {
                // Final group of taps_left (1..4) taps: advance by taps_left
                // rounded up to even, since an odd final tap appears to be
                // paired with padding in B (NOTE(review): confirm against the
                // B packing code). The parentheses matter: without them
                // `32 * (n + 1) / 2 * 2` is simply 32 * (n + 1), which
                // over-advances by one vector when taps_left is even.
                e->add(b_addr, asmjit::Imm(32 * ((taps_left + 1) / 2 * 2)));
              }

              // Once per channel block -- after the last group when it has
              // 3-4 taps, or after the group preceding a final 1-2 tap group
              // -- recombine 128-bit lanes of the accumulators into output
              // order (see genMaddEpi16xNPacked for the layout details).
              if (taps_left >= 3 && taps_left <= 6) {
                for (int r = 0; r < (main_loop ? 4 : remainder / 8); ++r) {
                  e->vperm2f128(
                      a[r],
                      c[r % 2 * 2],
                      c[r % 2 * 2 + 1],
                      asmjit::Imm(r < 2 ? 0x20 : 0x31));
                }
                for (int r = 0; r < (main_loop ? 4 : remainder / 8); ++r) {
                  e->vmovdqa(c[r], a[r]);
                }
              }
            }
            if (i != K - 1) {
              e->add(a_addr, ic); // advance to next pixel
            }
          }
          if (i != K - 1) {
            e->add(a_addr, w); // advance to next row
          }
        }
        if (D >= 3 && i != K - 1) {
          e->add(a_addr, h); // advance to next frame
        }
      }

      // Store the int32 accumulators (4 vectors, or remainder / 8 in the tail).
      for (int r = 0; r < (main_loop ? 4 : remainder / 8); ++r) {
        e->vmovups(x86::ymmword_ptr(c_addr, r * 32), c[r]);
      }

      if (compute_a_sum) {
        // a_sum[0..1] hold 16-bit lanes; each 128-bit half is widened to
        // int32 (oc_per_g == 1) or de-duplicated (oc_per_g == 2) and stored.
        if (oc_per_g == 1) {
          e->vpmovsxwd(a[0], a_sum[0].half());
          e->vmovups(x86::ymmword_ptr(a_sum_addr), a[0]);
        } else {
          // Rollback duplication
          e->vpsrld(a_sum[0], a_sum[0], asmjit::Imm(16));
          e->vmovups(x86::xmmword_ptr(a_sum_addr), a_sum[0].half());
        }

        if (main_loop || remainder >= 8) {
          if (oc_per_g == 1) {
            e->vpmovsxwd(a[1], a_sum[1].half());
            e->vmovups(x86::ymmword_ptr(a_sum_addr, 32), a[1]);
          } else {
            // Rollback duplication
            e->vpsrld(a_sum[1], a_sum[1], asmjit::Imm(16));
            e->vmovups(x86::xmmword_ptr(a_sum_addr, 16), a_sum[1].half());
          }
        }

        if (main_loop || remainder >= 16) {
          e->vextracti128(a_sum[0].half(), a_sum[0], asmjit::Imm(1));
          if (oc_per_g == 1) {
            e->vpmovsxwd(a_sum[0], a_sum[0].half());
            e->vmovups(x86::ymmword_ptr(a_sum_addr, 64), a_sum[0]);
          } else {
            e->vmovups(x86::xmmword_ptr(a_sum_addr, 32), a_sum[0].half());
          }
        }

        if (main_loop || remainder >= 24) {
          e->vextracti128(a_sum[1].half(), a_sum[1], asmjit::Imm(1));
          if (oc_per_g == 1) {
            e->vpmovsxwd(a_sum[1], a_sum[1].half());
            e->vmovups(x86::ymmword_ptr(a_sum_addr, 96), a_sum[1]);
          } else {
            e->vmovups(x86::xmmword_ptr(a_sum_addr, 48), a_sum[1].half());
          }
        }

        if (main_loop) {
          e->add(a_sum_addr, asmjit::Imm(128 / oc_per_g));
        }
      }

      if (main_loop) {
        // Move to the next channel block: C by 32 int32 outputs, A by
        // 32 / oc_per_g input bytes from the block's starting address.
        e->add(c_addr, asmjit::Imm(128));
        e->add(a_addr_save, asmjit::Imm(32 / oc_per_g));
        e->mov(a_addr, a_addr_save);
        e->jmp(ic_loop_begin);

        e->bind(ic_loop_end);
      }
    }

    e->emitEpilog(frame);

    jit_kernel_signature fn = nullptr;
    asmjit::Error err;
    {
      std::unique_lock<std::mutex> lock(rtMutex_);
      err = runtime().add(&fn, &code);
    }

#ifdef FBGEMM_LOG_CODE
    // Code emission is finished whether or not add() succeeded, so release
    // the logger here; otherwise the error return below would leak it.
    code.setLogger(nullptr);
    delete codeLogger;
    if (codeLogFile) {
      fclose(codeLogFile);
    }
#endif

    if (err) {
      std::cout << "Error: in fn add" << std::endl;
      return nullptr;
    }

    return fn;
  });
}