void LLVMIRGen::generateLLVMIRForDataParallelInstr()

in lib/LLVMIRCodeGen/LLVMIRGen.cpp [1131:1893]


void LLVMIRGen::generateLLVMIRForDataParallelInstr(
    llvm::IRBuilder<> &builder, const glow::Instruction *I,
    llvm::Function *kernel, llvm::DenseMap<Value *, int> &bufferToArgNum,
    llvm::Value *loopCount) {
  setCurrentDebugLocation(builder, I);
  assert(canBePartOfDataParallelKernel(I) &&
         "Instruction cannot be part of a data parallel kernel");
  switch (I->getKind()) {

// Expands to the switch case for Kinded::Kind::INST_NAME_##InstKind, an
// instruction whose only operand besides the destination is an immediate
// obtained through get##VALUE_(). The expansion:
//   1. emits the destination buffer address;
//   2. turns the immediate into an LLVM constant (quantized with the
//      destination scale/offset for Int8QTy and Int16QTy destinations, any
//      other quantized element kind is unreachable);
//   3. calls the FUN_NAME_ "_kernel" function chosen for the destination
//      element kind, passing loopCount, the constant and two null pointers;
//   4. stores the call result into the destination at index loopCount.
#define ARITHMETIC_UNARY_OP_WITH_IMM_CASE(INST_NAME_, FUN_NAME_, VALUE_)       \
  case Kinded::Kind::INST_NAME_##InstKind: {                                   \
    auto *inst = cast<INST_NAME_##Inst>(I);                                    \
    auto *dest = inst->getDest();                                              \
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);  \
    auto *elementTy = getElementType(builder, dest);                           \
    auto immediate = inst->get##VALUE_();                                      \
    auto *F = getFunction(FUN_NAME_ "_kernel", dest->getElementType());        \
    auto *nullPtr =                                                            \
        llvm::ConstantPointerNull::get(elementTy->getPointerTo());             \
    /* Turn the immediate into an LLVM constant of the destination type. */    \
    llvm::Value *immConst = nullptr;                                           \
    if (!dest->getType()->isQuantizedType()) {                                 \
      immConst = emitConst(builder, immediate, dest->getElementType());        \
    } else {                                                                   \
      /* Quantize up front with the destination parameters so that the */      \
      /* jit library only ever sees an already quantized number. */            \
      auto *destTy = dest->getType();                                          \
      TensorQuantizationParams TQP{destTy->getScale(), destTy->getOffset()};   \
      if (destTy->getElementType() == ElemKind::Int8QTy) {                     \
        immConst = emitConstI8(                                                \
            builder, quantization::quantize<int8_t>(immediate, TQP));          \
      } else if (destTy->getElementType() == ElemKind::Int16QTy) {             \
        immConst = emitConstI16(                                               \
            builder, quantization::quantize<int16_t>(immediate, TQP));         \
      } else {                                                                 \
        llvm_unreachable("Quantization precision not supported.");             \
      }                                                                        \
    }                                                                          \
    /* Single emission point shared by every element type. */                  \
    auto *stackedOpCall = createUncheckedCall(                                 \
        builder, F, {loopCount, immConst, nullPtr, nullPtr});                  \
    auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount,          \
                                       "buffer.element.addr");                 \
    builder.CreateStore(stackedOpCall, destAddr);                              \
    break;                                                                     \
  }
    ARITHMETIC_UNARY_OP_WITH_IMM_CASE(Splat, "splat", Value);
#undef ARITHMETIC_UNARY_OP_WITH_IMM_CASE

  case Kinded::Kind::TouchInstKind:
    // do nothing;
    break;

  case Kinded::Kind::ElementSelectInstKind: {
    auto *ES = cast<ElementSelectInst>(I);
    auto *dest = ES->getDest();
    auto *cond = ES->getCond();
    auto *lhs = ES->getLHS();
    auto *rhs = ES->getRHS();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *condPtr = emitBufferAddress(builder, cond, kernel, bufferToArgNum);
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);

    // Need _kernel suffix since these operations are implemented as
    // "data-parallel" kernels in libjit.
    auto *F = getFunction("elementselect_kernel", lhs->getElementType());

    if (lhs->getType()->isQuantizedType()) {
      auto *destTy = dest->getType();
      auto *lhsTy = lhs->getType();
      auto *rhsTy = rhs->getType();

      auto *destOffset = emitConstI32(builder, destTy->getOffset());
      auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset());
      auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset());

      // The selected value will be either lhs = s_l * (i_l - o_l) or
      // rhs = s_r * (i_r - o_r); the stored result that must be computed is
      // therefore one of:
      // (i)  i_d = (s_l / s_d) * (i_l - o_l) + o_d
      // (ii) i_d = (s_r / s_d) * (i_r - o_r) + o_d
      float destScale = destTy->getScale();
      auto lhsScaleParams = quantization::quantizeScaleOffset32To8(
          lhsTy->getScale() / destScale, lhsTy->getOffset());
      auto rhsScaleParams = quantization::quantizeScaleOffset32To8(
          rhsTy->getScale() / destScale, rhsTy->getOffset());

      auto *lhsPre = emitConstI32(builder, lhsScaleParams.pre);
      auto *lhsPost = emitConstI32(builder, lhsScaleParams.post);
      auto *lhsScale = emitConstI32(builder, lhsScaleParams.scale);
      auto *rhsPre = emitConstI32(builder, rhsScaleParams.pre);
      auto *rhsPost = emitConstI32(builder, rhsScaleParams.post);
      auto *rhsScale = emitConstI32(builder, rhsScaleParams.scale);

      auto *stackedOpCall = createUncheckedCall(
          builder, F,
          {loopCount, condPtr, lhsPtr, rhsPtr, destOffset, lhsOffset, rhsOffset,
           lhsPre, lhsPost, lhsScale, rhsPre, rhsPost, rhsScale});
      auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr,
                                         loopCount, "buffer.element.addr");
      builder.CreateStore(stackedOpCall, destAddr);
    } else {
      auto *stackedOpCall =
          createUncheckedCall(builder, F, {loopCount, condPtr, lhsPtr, rhsPtr});
      auto *destAddr = builder.CreateGEP(builder.getFloatTy(), destPtr,
                                         loopCount, "buffer.element.addr");
      builder.CreateStore(stackedOpCall, destAddr);
    }
    break;
  }
  case Kinded::Kind::IntLookupTableInstKind: {
    auto *lookupTable = cast<IntLookupTableInst>(I);
    auto *dest = lookupTable->getDest();
    auto *src = lookupTable->getSrc();
    auto *mapping = lookupTable->getMapping();

    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    auto *mappingPtr =
        emitBufferAddress(builder, mapping, kernel, bufferToArgNum);

    auto *F = getFunction("intlookuptable_kernel", dest->getElementType());
    auto *stackedOpCall =
        builder.CreateCall(F, {loopCount, srcPtr, mappingPtr});
    auto *destType = getElementType(builder, dest);
    auto *destAddr =
        builder.CreateGEP(destType, destPtr, loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);

    break;
  }
// Expands to the switch case for Kinded::Kind::INST_NAME_##InstKind, a
// data-parallel instruction with one source buffer and one destination
// buffer. The expansion emits the destination and source buffer addresses,
// calls the FUN_NAME_ "_kernel" function selected for the destination element
// kind with (loopCount, srcPtr, null, null), and stores the call result into
// the destination at index loopCount.
#define ARITHMETIC_UNARY_OP_CASE(INST_NAME_, FUN_NAME_)                        \
  case Kinded::Kind::INST_NAME_##InstKind: {                                   \
    auto *AN = cast<INST_NAME_##Inst>(I);                                      \
    auto *dest = AN->getDest();                                                \
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);  \
    auto *srcPtr =                                                             \
        emitBufferAddress(builder, AN->getSrc(), kernel, bufferToArgNum);      \
    auto *F = getFunction(FUN_NAME_ "_kernel", dest->getElementType());        \
    auto *elementTy = getElementType(builder, dest);                           \
    auto *pointerNull =                                                        \
        llvm::ConstantPointerNull::get(elementTy->getPointerTo());             \
    auto *stackedOpCall = createUncheckedCall(                                 \
        builder, F, {loopCount, srcPtr, pointerNull, pointerNull});            \
    /* Index the destination with its own element type: the kernel is */       \
    /* selected by that type, so a fixed float stride would be wrong */        \
    /* for any non-float destination. */                                       \
    auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount,          \
                                       "buffer.element.addr");                 \
    builder.CreateStore(stackedOpCall, destAddr);                              \
    break;                                                                     \
  }

    ARITHMETIC_UNARY_OP_CASE(Sigmoid, "sigmoid");
    ARITHMETIC_UNARY_OP_CASE(Tanh, "tanh");
    ARITHMETIC_UNARY_OP_CASE(ElementLog, "element_log");
    ARITHMETIC_UNARY_OP_CASE(ElementExp, "element_exp");
    ARITHMETIC_UNARY_OP_CASE(ElementAbs, "element_abs");
    ARITHMETIC_UNARY_OP_CASE(ElementNeg, "element_neg");
    ARITHMETIC_UNARY_OP_CASE(ElementFloor, "element_floor");
    ARITHMETIC_UNARY_OP_CASE(ElementCeil, "element_ceil");
    ARITHMETIC_UNARY_OP_CASE(ElementRound, "element_round");
    ARITHMETIC_UNARY_OP_CASE(ElementSqrt, "element_sqrt");
    ARITHMETIC_UNARY_OP_CASE(ElementErf, "element_erf");
    ARITHMETIC_UNARY_OP_CASE(ElementRsqrt, "element_rsqrt");
    ARITHMETIC_UNARY_OP_CASE(ElementReciprocal, "element_reciprocal");
    ARITHMETIC_UNARY_OP_CASE(ElementSin, "element_sin");
    ARITHMETIC_UNARY_OP_CASE(ElementCos, "element_cos");
#undef ARITHMETIC_UNARY_OP_CASE

  case Kinded::Kind::ReluInstKind: {
    auto *RI = cast<ReluInst>(I);
    auto *src = RI->getSrc();
    auto *dest = RI->getDest();
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto srcTy = src->getType();
    auto destTy = dest->getType();

    auto *F = getFunction("element_relu", dest->getElementType());
    llvm::CallInst *stackedOpCall = nullptr;
    if (dest->getElementType() == ElemKind::Int8QTy) {
      auto *srcOffset =
          emitConstI8(builder, static_cast<int8_t>(srcTy->getOffset()));
      auto *destOffset =
          emitConstI8(builder, static_cast<int8_t>(destTy->getOffset()));
      auto destScaleParams = quantization::quantizeScaleOffset32To8(
          srcTy->getScale() / destTy->getScale(), 0);
      auto *destPre = emitConstI32(builder, destScaleParams.pre);
      auto *destPost = emitConstI32(builder, destScaleParams.post);
      auto *destScale = emitConstI32(builder, destScaleParams.scale);
      stackedOpCall = createCall(builder, F,
                                 {loopCount, srcPtr, srcOffset, destOffset,
                                  destPre, destPost, destScale});
    } else if (dest->getElementType() == ElemKind::FloatTy) {
      stackedOpCall = createCall(builder, F, {loopCount, srcPtr});
    } else {
      LOG(FATAL) << "Type is not supported";
    }
    auto *elementTy = getElementType(builder, dest);
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::ClipInstKind: {
    auto *CI = cast<ClipInst>(I);
    auto *src = CI->getSrc();
    auto *dest = CI->getDest();
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto srcTy = src->getType();
    auto destTy = dest->getType();
    float clipMinF = CI->getMin();
    float clipMaxF = CI->getMax();

    auto *F = getFunction("element_clip", dest->getElementType());
    llvm::CallInst *stackedOpCall = nullptr;
    if (dest->getElementType() == ElemKind::Int8QTy) {
      TensorQuantizationParams srcTQP{src->getType()->getScale(),
                                      src->getType()->getOffset()};
      int8_t clipMinQ = quantization::quantize<int8_t>(clipMinF, srcTQP);
      int8_t clipMaxQ = quantization::quantize<int8_t>(clipMaxF, srcTQP);
      auto *clipMin = emitConstI8(builder, clipMinQ);
      auto *clipMax = emitConstI8(builder, clipMaxQ);
      auto *srcOffset =
          emitConstI8(builder, static_cast<int8_t>(srcTy->getOffset()));
      auto *destOffset =
          emitConstI8(builder, static_cast<int8_t>(destTy->getOffset()));
      auto destScaleParams = quantization::quantizeScaleOffset32To8(
          srcTy->getScale() / destTy->getScale(), 0);
      auto *destPre = emitConstI32(builder, destScaleParams.pre);
      auto *destPost = emitConstI32(builder, destScaleParams.post);
      auto *destScale = emitConstI32(builder, destScaleParams.scale);
      stackedOpCall =
          createCall(builder, F,
                     {loopCount, srcPtr, clipMin, clipMax, srcOffset,
                      destOffset, destPre, destPost, destScale});
    } else if (dest->getElementType() == ElemKind::FloatTy) {
      auto *clipMin = emitConstF32(builder, clipMinF);
      auto *clipMax = emitConstF32(builder, clipMaxF);
      stackedOpCall =
          createCall(builder, F, {loopCount, srcPtr, clipMin, clipMax});
    } else {
      LOG(FATAL) << "Type is not supported";
    }
    auto *elementTy = getElementType(builder, dest);
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::LeakyReluInstKind: {
    auto *LI = cast<LeakyReluInst>(I);
    auto *src = LI->getSrc();
    auto *dest = LI->getDest();
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto srcTy = src->getType();
    auto destTy = dest->getType();

    auto *F = getFunction("element_leaky_relu", dest->getElementType());
    llvm::CallInst *stackedOpCall = nullptr;
    if (dest->getElementType() == ElemKind::Int8QTy) {
      auto *srcOffset =
          emitConstI8(builder, static_cast<int8_t>(srcTy->getOffset()));
      auto *destOffset =
          emitConstI8(builder, static_cast<int8_t>(destTy->getOffset()));
      // Scale parameters for the positive input domain.
      auto posParams = quantization::quantizeScaleOffset32To8(
          srcTy->getScale() / destTy->getScale(), 0);
      auto *posPre = emitConstI32(builder, posParams.pre);
      auto *posPost = emitConstI32(builder, posParams.post);
      auto *posScale = emitConstI32(builder, posParams.scale);
      // Scale parameters for the negative input domain.
      auto negParams = quantization::quantizeScaleOffset32To8(
          srcTy->getScale() * LI->getAlpha() / destTy->getScale(), 0);
      auto *negPre = emitConstI32(builder, negParams.pre);
      auto *negPost = emitConstI32(builder, negParams.post);
      auto *negScale = emitConstI32(builder, negParams.scale);
      stackedOpCall =
          createCall(builder, F,
                     {loopCount, srcPtr, srcOffset, destOffset, posPre, posPost,
                      posScale, negPre, negPost, negScale});
    } else if (dest->getElementType() == ElemKind::FloatTy) {
      auto *alpha = emitConstF32(builder, LI->getAlpha());
      stackedOpCall = createCall(builder, F, {loopCount, srcPtr, alpha});
    } else {
      LOG(FATAL) << "Type is not supported";
    }
    auto *elementTy = getElementType(builder, dest);
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::ElementIsNaNInstKind: {
    auto *AN = cast<ElementIsNaNInst>(I);
    auto *src = AN->getSrc();
    auto *dest = AN->getDest();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    auto *F = getFunction("element_is_nan_kernel", src->getElementType());
    auto *stackedOpCall = createUncheckedCall(builder, F, {loopCount, srcPtr});
    auto *elementTy = getElementType(builder, dest);
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::QuantizeInstKind: {
    auto *QI = cast<QuantizeInst>(I);
    auto *src = QI->getSrc();
    auto *dest = QI->getDest();
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *destTy = dest->getType();
    auto *destScale = emitConstF32(builder, destTy->getScale());
    auto *destOffset = emitConstI32(builder, destTy->getOffset());
    auto *F = getFunction("element_quantize_kernel", dest->getElementType());

    auto *stackedOpCall = createUncheckedCall(
        builder, F, {loopCount, srcPtr, destScale, destOffset});
    auto *destType = getElementType(builder, dest);
    auto *destAddr =
        builder.CreateGEP(destType, destPtr, loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::DequantizeInstKind: {
    auto *DI = cast<DequantizeInst>(I);
    auto *src = DI->getSrc();
    auto *dest = DI->getDest();
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *srcTy = src->getType();
    auto *srcScale = emitConstF32(builder, srcTy->getScale());
    auto *srcOffset = emitConstI32(builder, srcTy->getOffset());
    auto *F = getFunction("element_dequantize_kernel", src->getElementType());

    auto *stackedOpCall = createUncheckedCall(
        builder, F, {loopCount, srcPtr, srcScale, srcOffset});
    auto *destAddr = builder.CreateGEP(builder.getFloatTy(), destPtr, loopCount,
                                       "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::RescaleQuantizedInstKind: {
    auto *RQI = cast<RescaleQuantizedInst>(I);
    auto *dest = RQI->getDest();
    auto *src = RQI->getSrc();
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);

    auto *destType = dest->getType();
    auto *srcType = src->getType();

    auto rescaleParams = quantization::quantizeScaleOffset32To8(
        srcType->getScale() / destType->getScale(), srcType->getOffset());

    auto *destOffset = emitConstI32(builder, destType->getOffset());
    auto *srcOffset = emitConstI32(builder, srcType->getOffset());
    auto *preShift = emitConstI32(builder, rescaleParams.pre);
    auto *postShift = emitConstI32(builder, rescaleParams.post);
    auto *scale = emitConstI32(builder, rescaleParams.scale);
    auto *F = getFunction("element_rescale_kernel", dest->getElementType());

    auto *stackedOpCall = createUncheckedCall(
        builder, F,
        {loopCount, srcPtr, destOffset, srcOffset, preShift, postShift, scale});
    auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr, loopCount,
                                       "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::CopyInstKind: {
    auto *CI = cast<CopyInst>(I);
    auto *dest = CI->getDest();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *srcPtr =
        emitBufferAddress(builder, CI->getSrc(), kernel, bufferToArgNum);
    auto *F = getFunction("copy_kernel", dest->getElementType());
    auto *elementTy = getElementType(builder, dest);
    auto *pointerNull =
        llvm::ConstantPointerNull::get(elementTy->getPointerTo());
    auto *stackedOpCall = createUncheckedCall(
        builder, F, {loopCount, srcPtr, pointerNull, pointerNull});
    auto *destAddr = builder.CreateGEP(getElementType(builder, dest), destPtr,
                                       loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

// Expands to the switch case for Kinded::Kind::INST_NAME_##InstKind, a
// data-parallel instruction with two source buffers (LHS, RHS) and one
// destination buffer. The variadic arguments list the non-quantized element
// kinds accepted for the destination (checked with matchPair).
//   - Quantized LHS: the FUN_NAME_ "_kernel" function receives the three
//     offsets plus, for each input, the pre/post/scale values produced by
//     quantizeScaleOffset32To8 for (input scale / dest scale); the result is
//     stored through an i8 element address.
//   - Accepted non-quantized kind: the kernel receives (loopCount, lhsPtr,
//     rhsPtr, null) and the result is stored using the destination element
//     type.
//   - Anything else is unreachable.
#define ARITHMETIC_BINARY_OP_CASE(INST_NAME_, FUN_NAME_, ...)                  \
  case Kinded::Kind::INST_NAME_##InstKind: {                                   \
    auto *inst = cast<INST_NAME_##Inst>(I);                                    \
    auto *dest = inst->getDest();                                              \
    auto *lhs = inst->getLHS();                                                \
    auto *rhs = inst->getRHS();                                                \
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);  \
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);    \
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);    \
    auto *F = getFunction(FUN_NAME_ "_kernel", dest->getElementType());        \
    auto *elementTy = getElementType(builder, dest);                           \
    auto *nullPtr =                                                            \
        llvm::ConstantPointerNull::get(elementTy->getPointerTo());             \
    bool typesMatched = matchPair(dest->getElementType(), __VA_ARGS__);        \
    /* Each branch picks the call and the element type used to store. */       \
    llvm::Value *stackedOpCall = nullptr;                                      \
    llvm::Type *storeElemTy = nullptr;                                         \
    if (lhs->getType()->isQuantizedType()) {                                   \
      auto *destTy = dest->getType();                                          \
      auto *lhsTy = lhs->getType();                                            \
      auto *rhsTy = rhs->getType();                                            \
      auto *destOffset = emitConstI32(builder, destTy->getOffset());           \
      auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset());             \
      auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset());             \
      /* Parameters encoding the ratio (input scale / dest scale). */          \
      float destScale = destTy->getScale();                                    \
      auto lhsParams = quantization::quantizeScaleOffset32To8(                 \
          lhsTy->getScale() / destScale, lhsTy->getOffset());                  \
      auto rhsParams = quantization::quantizeScaleOffset32To8(                 \
          rhsTy->getScale() / destScale, rhsTy->getOffset());                  \
      auto *lhsPre = emitConstI32(builder, lhsParams.pre);                     \
      auto *lhsPost = emitConstI32(builder, lhsParams.post);                   \
      auto *lhsScale = emitConstI32(builder, lhsParams.scale);                 \
      auto *rhsPre = emitConstI32(builder, rhsParams.pre);                     \
      auto *rhsPost = emitConstI32(builder, rhsParams.post);                   \
      auto *rhsScale = emitConstI32(builder, rhsParams.scale);                 \
      stackedOpCall = createUncheckedCall(                                     \
          builder, F,                                                          \
          {loopCount, lhsPtr, rhsPtr, destOffset, lhsOffset, rhsOffset,        \
           lhsPre, lhsPost, lhsScale, rhsPre, rhsPost, rhsScale});             \
      storeElemTy = builder.getInt8Ty();                                       \
    } else if (typesMatched) {                                                 \
      stackedOpCall = createUncheckedCall(                                     \
          builder, F, {loopCount, lhsPtr, rhsPtr, nullPtr});                   \
      storeElemTy = elementTy;                                                 \
    } else {                                                                   \
      llvm_unreachable("Unsupported Type in " #INST_NAME_);                    \
    }                                                                          \
    /* Store the kernel result into the destination at loopCount. */           \
    auto *destAddr = builder.CreateGEP(storeElemTy, destPtr, loopCount,        \
                                       "buffer.element.addr");                 \
    builder.CreateStore(stackedOpCall, destAddr);                              \
    break;                                                                     \
  }
    ARITHMETIC_BINARY_OP_CASE(ElementAdd, "element_add", ElemKind::FloatTy,
                              ElemKind::Int32ITy, ElemKind::Int64ITy);
    ARITHMETIC_BINARY_OP_CASE(ElementSub, "element_sub", ElemKind::FloatTy);
    ARITHMETIC_BINARY_OP_CASE(ElementMax, "element_max", ElemKind::FloatTy);
    ARITHMETIC_BINARY_OP_CASE(ElementMin, "element_min", ElemKind::FloatTy);
    ARITHMETIC_BINARY_OP_CASE(ElementPow, "element_pow", ElemKind::FloatTy);
#undef ARITHMETIC_BINARY_OP_CASE

  case Kinded::Kind::ElementNotInstKind: {
    auto *NI = cast<ElementNotInst>(I);
    auto *dest = NI->getDest();
    auto *src = NI->getSrc();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
    auto *F = getFunction("element_not_kernel", src->getElementType());
    auto *elementTy = getElementType(builder, dest);
    auto *stackedOpCall = createUncheckedCall(builder, F, {loopCount, srcPtr});
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::ElementAndInstKind: {
    auto *AI = cast<ElementAndInst>(I);
    auto *dest = AI->getDest();
    auto *lhs = AI->getLHS();
    auto *rhs = AI->getRHS();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);
    auto *F = getFunction("element_and_kernel", lhs->getElementType());
    auto *elementTy = getElementType(builder, dest);
    auto *stackedOpCall =
        createUncheckedCall(builder, F, {loopCount, lhsPtr, rhsPtr});
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::ElementOrInstKind: {
    auto *OI = cast<ElementOrInst>(I);
    auto *dest = OI->getDest();
    auto *lhs = OI->getLHS();
    auto *rhs = OI->getRHS();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);
    auto *F = getFunction("element_or_kernel", lhs->getElementType());
    auto *elementTy = getElementType(builder, dest);
    auto *stackedOpCall =
        createUncheckedCall(builder, F, {loopCount, lhsPtr, rhsPtr});
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::ElementXorInstKind: {
    auto *XI = cast<ElementXorInst>(I);
    auto *dest = XI->getDest();
    auto *lhs = XI->getLHS();
    auto *rhs = XI->getRHS();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);
    auto *F = getFunction("element_xor_kernel", lhs->getElementType());
    auto *elementTy = getElementType(builder, dest);
    auto *stackedOpCall =
        createUncheckedCall(builder, F, {loopCount, lhsPtr, rhsPtr});
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
    builder.CreateStore(stackedOpCall, destAddr);
    break;
  }

  case Kinded::Kind::ElementCmpEQInstKind:
  case Kinded::Kind::ElementCmpNEQInstKind:
  case Kinded::Kind::ElementCmpLTInstKind:
  case Kinded::Kind::ElementCmpLTEInstKind: {
    Value *dest = nullptr;
    Value *lhs = nullptr;
    Value *rhs = nullptr;
    std::string kernelName;

    if (auto *CEQI = dyn_cast<ElementCmpEQInst>(I)) {
      dest = CEQI->getDest();
      lhs = CEQI->getLHS();
      rhs = CEQI->getRHS();
      kernelName = "element_cmp_eq_kernel";
    } else if (auto *CNEQI = dyn_cast<ElementCmpNEQInst>(I)) {
      dest = CNEQI->getDest();
      lhs = CNEQI->getLHS();
      rhs = CNEQI->getRHS();
      kernelName = "element_cmp_neq_kernel";
    } else if (auto *CLTEI = dyn_cast<ElementCmpLTEInst>(I)) {
      dest = CLTEI->getDest();
      lhs = CLTEI->getLHS();
      rhs = CLTEI->getRHS();
      kernelName = "element_cmp_lte_kernel";
    } else if (auto *CLTI = dyn_cast<ElementCmpLTInst>(I)) {
      dest = CLTI->getDest();
      lhs = CLTI->getLHS();
      rhs = CLTI->getRHS();
      kernelName = "element_cmp_lt_kernel";
    } else {
      llvm_unreachable(
          "Missmatch between Instruction Kind and instruction instance.");
    }

    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);

    // Need _kernel suffix since these operations are implemented as
    // "data-parallel" kernels in libjit.
    auto *F = getFunction(kernelName.c_str(), lhs->getElementType());

    if (lhs->getType()->isQuantizedType()) {
      auto *lhsTy = lhs->getType();
      auto *rhsTy = rhs->getType();

      auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset());
      auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset());

      // We can divide both sides of the comparison by the rhs scale since it is
      // strictly positive; this saves one rescale within the backend. The
      // inequalities are:
      //     s_l * (i_l - o_l) <= s_r * (i_r - o_r)
      // <=> (s_l / s_r) * (i_l - o_l) <= i_r - o_r
      float scale = lhsTy->getScale() / rhsTy->getScale();
      auto scaleParams = quantization::quantizeScaleOffset32To8(scale, 0);
      auto *cmpPre = emitConstI32(builder, scaleParams.pre);
      auto *cmpPost = emitConstI32(builder, scaleParams.post);
      auto *cmpScale = emitConstI32(builder, scaleParams.scale);

      auto *stackedOpCall =
          createUncheckedCall(builder, F,
                              {loopCount, lhsPtr, rhsPtr, lhsOffset, rhsOffset,
                               cmpPre, cmpPost, cmpScale});
      auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr,
                                         loopCount, "buffer.element.addr");
      builder.CreateStore(stackedOpCall, destAddr);
    } else {
      auto *stackedOpCall =
          createUncheckedCall(builder, F, {loopCount, lhsPtr, rhsPtr});
      auto *elementTy = getElementType(builder, dest);
      auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount,
                                         "buffer.element.addr");
      builder.CreateStore(stackedOpCall, destAddr);
    }
    break;
  }

  case Kinded::Kind::ElementMulInstKind: {
    // Element-wise multiply: call the libjit kernel once for the element at
    // index loopCount and store the value it returns into the same index of
    // the destination buffer.
    auto *mulInst = cast<ElementMulInst>(I);
    auto *out = mulInst->getDest();
    auto *left = mulInst->getLHS();
    auto *right = mulInst->getRHS();
    auto *outPtr = emitBufferAddress(builder, out, kernel, bufferToArgNum);
    auto *leftPtr = emitBufferAddress(builder, left, kernel, bufferToArgNum);
    auto *rightPtr = emitBufferAddress(builder, right, kernel, bufferToArgNum);

    // The "_kernel" suffix is required because libjit implements these
    // operations as "data-parallel" kernels.
    auto *mulKernel = getFunction("element_mul_kernel", out->getElementType());
    auto *outElemTy = getElementType(builder, out);
    auto *nullPtrArg =
        llvm::ConstantPointerNull::get(outElemTy->getPointerTo());

    auto *leftTy = left->getType();

    // Quantized operands: pass the offsets and the integer rescale
    // parameters along with the operand pointers.
    if (leftTy->isQuantizedType()) {
      auto *outTy = out->getType();
      auto *rightTy = right->getType();

      auto *outOffset = emitConstI32(builder, outTy->getOffset());
      auto *leftOffset = emitConstI32(builder, leftTy->getOffset());
      auto *rightOffset = emitConstI32(builder, rightTy->getOffset());

      // Rescale factor is s_l * s_r / s_d, which follows from
      //    s_d * (i_d - o_d) = s_l * (i_l - o_l) * s_r * (i_r - o_r)
      // => i_d = (s_l * s_r / s_d) * (i_l - o_l) * (i_r - o_r) + o_d
      float rescale = leftTy->getScale() * rightTy->getScale() / outTy->getScale();
      auto rescaleParams = quantization::quantizeScaleOffset32To8(rescale, 0);
      auto *preShift = emitConstI32(builder, rescaleParams.pre);
      auto *postShift = emitConstI32(builder, rescaleParams.post);
      auto *intScale = emitConstI32(builder, rescaleParams.scale);

      auto *product = createUncheckedCall(
          builder, mulKernel,
          {loopCount, leftPtr, rightPtr, outOffset, leftOffset, rightOffset,
           preShift, postShift, intScale});
      // The quantized destination is addressed as i8 elements.
      auto *outSlot = builder.CreateGEP(builder.getInt8Ty(), outPtr, loopCount,
                                        "buffer.element.addr");
      builder.CreateStore(product, outSlot);
      break;
    }

    // Non-quantized operands: only these element kinds are handled.
    const auto leftKind = leftTy->getElementType();
    const bool isSupportedKind = leftKind == ElemKind::Int64ITy ||
                                 leftKind == ElemKind::Int32ITy ||
                                 leftKind == ElemKind::FloatTy;
    if (!isSupportedKind) {
      LOG_ASSERT(false) << "Unsupported element type for Mul.";
      break;
    }

    auto *product = createUncheckedCall(
        builder, mulKernel, {loopCount, leftPtr, rightPtr, nullPtrArg});
    auto *outSlot = builder.CreateGEP(outElemTy, outPtr, loopCount,
                                      "buffer.element.addr");
    builder.CreateStore(product, outSlot);
    break;
  }

  case Kinded::Kind::ElementDivInstKind: {
    // Element-wise division: call the libjit kernel once for the element at
    // index loopCount and store the value it returns into the same index of
    // the destination buffer.
    auto *MI = cast<ElementDivInst>(I);
    auto *dest = MI->getDest();
    auto *lhs = MI->getLHS();
    auto *rhs = MI->getRHS();
    auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
    auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
    auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);

    // Need _kernel suffix since these operations are implemented as
    // "data-parallel" kernels in libjit.
    auto *F = getFunction("element_div_kernel", dest->getElementType());
    // LLVM type of one destination element; used both to build the typed null
    // pointer below and to address the destination in the non-quantized path.
    auto *elementTy = getElementType(builder, dest);
    auto *pointerNull =
        llvm::ConstantPointerNull::get(elementTy->getPointerTo());

    if (lhs->getType()->isQuantizedType()) {
      auto *destTy = dest->getType();
      auto *lhsTy = lhs->getType();
      auto *rhsTy = rhs->getType();

      auto *destOffset = emitConstI32(builder, destTy->getOffset());
      auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset());
      auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset());

      // The division scale factor is s_l / (s_r * s_d) due to the equation
      //    s_d * (i_d - o_d) = (s_l * (i_l - o_l)) / (s_r * (i_r - o_r))
      // => i_d = (s_l / (s_r * s_d)) * ((i_l - o_l) / (i_r - o_r)) + o_d
      float scale =
          lhsTy->getScale() / (rhsTy->getScale() * destTy->getScale());
      auto scaleParams = quantization::quantizeScaleOffset32To8(scale, 0);
      auto *divPre = emitConstI32(builder, scaleParams.pre);
      auto *divPost = emitConstI32(builder, scaleParams.post);
      auto *divScale = emitConstI32(builder, scaleParams.scale);

      auto *stackedOpCall =
          createUncheckedCall(builder, F,
                              {loopCount, lhsPtr, rhsPtr, destOffset, lhsOffset,
                               rhsOffset, divPre, divPost, divScale});
      // The quantized destination is addressed as i8 elements.
      auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr,
                                         loopCount, "buffer.element.addr");
      builder.CreateStore(stackedOpCall, destAddr);
    } else {
      // Reuse elementTy computed above rather than redeclaring it here; the
      // former inner declaration shadowed the outer one with the same value.
      // NOTE(review): the null pointer is passed as the kernel's fourth
      // argument, presumably an unused pointer parameter -- confirm against
      // the libjit kernel signature.
      auto *stackedOpCall = createUncheckedCall(
          builder, F, {loopCount, lhsPtr, rhsPtr, pointerNull});
      auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount,
                                         "buffer.element.addr");
      builder.CreateStore(stackedOpCall, destAddr);
    }
    break;
  }

  case Kinded::Kind::ModuloInstKind: {
    // Element-wise modulo by a constant divisor: call the libjit kernel once
    // for the element at index loopCount and store the returned value into
    // the same index of the destination buffer.
    auto *modInst = cast<ModuloInst>(I);
    auto *out = modInst->getDest();
    auto *input = modInst->getSrc();
    auto *outPtr = emitBufferAddress(builder, out, kernel, bufferToArgNum);
    auto *inputPtr = emitBufferAddress(builder, input, kernel, bufferToArgNum);
    auto *divisorConst =
        emitConst(builder, modInst->getDivisor(), ElemKind::Int64ITy);

    // The "_kernel" part of the name is required because libjit implements
    // these operations as "data-parallel" kernels. The variant is chosen by
    // the instruction's sign-follow-divisor flag.
    const char *kernelName = modInst->getSignFollowDivisor()
                                 ? "element_modulo_kernel_sign_follow"
                                 : "element_modulo_kernel_no_sign_follow";
    llvm::Function *modKernel = getFunction(kernelName, out->getElementType());

    auto *remainder = createUncheckedCall(
        builder, modKernel, {loopCount, divisorConst, inputPtr});

    // Address the destination as i64 elements for Int64ITy and as i32
    // elements for every other element kind.
    llvm::Type *slotTy = out->getElementType() == ElemKind::Int64ITy
                             ? builder.getInt64Ty()
                             : builder.getInt32Ty();
    llvm::Value *outSlot =
        builder.CreateGEP(slotTy, outPtr, loopCount, "buffer.element.addr");
    builder.CreateStore(remainder, outSlot);
    break;
  }

  default: {
    // No data-parallel lowering exists for this instruction kind: dump the
    // instruction into a string and abort with a fatal log message.
    std::string dumpText;
    llvm::raw_string_ostream dumpStream(dumpText);
    I->dump(dumpStream);
    LOG(FATAL) << "Cannot select the instruction: " << dumpStream.str();
  }
  }
}