in lib/LLVMIRCodeGen/LLVMIRGen.cpp [1131:1893]
void LLVMIRGen::generateLLVMIRForDataParallelInstr(
llvm::IRBuilder<> &builder, const glow::Instruction *I,
llvm::Function *kernel, llvm::DenseMap<Value *, int> &bufferToArgNum,
llvm::Value *loopCount) {
setCurrentDebugLocation(builder, I);
assert(canBePartOfDataParallelKernel(I) &&
"Instruction cannot be part of a data parallel kernel");
switch (I->getKind()) {
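// Generates a case for a unary data-parallel instruction that takes an
// immediate operand (e.g. Splat). The immediate is materialized as an LLVM
// constant; for quantized destinations it is quantized up front so that the
// "<FUN_NAME_>_kernel" libjit stub only sees values in the destination's
// quantized representation. The null pointers fill unused operand slots.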
#define ARITHMETIC_UNARY_OP_WITH_IMM_CASE(INST_NAME_, FUN_NAME_, VALUE_) \
case Kinded::Kind::INST_NAME_##InstKind: { \
auto *AN = cast<INST_NAME_##Inst>(I); \
auto *dest = AN->getDest(); \
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum); \
auto *elementTy = getElementType(builder, dest); \
auto value = AN->get##VALUE_(); \
auto *F = getFunction(FUN_NAME_ "_kernel", dest->getElementType()); \
auto *pointerNull = \
llvm::ConstantPointerNull::get(elementTy->getPointerTo()); \
if (dest->getType()->isQuantizedType()) { \
auto *destTy = dest->getType(); \
      /* Quantize the value based on the output type. Do this early so */     \
      /* the libjit kernel only works with already-quantized numbers. */      \
TensorQuantizationParams TQP{destTy->getScale(), destTy->getOffset()}; \
if (destTy->getElementType() == ElemKind::Int8QTy) { \
auto quantizedValue = quantization::quantize<int8_t>(value, TQP); \
auto *val = emitConstI8(builder, quantizedValue); \
auto *stackedOpCall = createUncheckedCall( \
builder, F, {loopCount, val, pointerNull, pointerNull}); \
auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount, \
"buffer.element.addr"); \
builder.CreateStore(stackedOpCall, destAddr); \
} else if (destTy->getElementType() == ElemKind::Int16QTy) { \
auto quantizedValue = quantization::quantize<int16_t>(value, TQP); \
auto *val = emitConstI16(builder, quantizedValue); \
auto *stackedOpCall = createUncheckedCall( \
builder, F, {loopCount, val, pointerNull, pointerNull}); \
auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount, \
"buffer.element.addr"); \
builder.CreateStore(stackedOpCall, destAddr); \
} else { \
llvm_unreachable("Quantization precision not supported."); \
} \
} else { \
auto *val = emitConst(builder, value, dest->getElementType()); \
auto *stackedOpCall = createUncheckedCall( \
builder, F, {loopCount, val, pointerNull, pointerNull}); \
auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount, \
"buffer.element.addr"); \
builder.CreateStore(stackedOpCall, destAddr); \
} \
break; \
}
ARITHMETIC_UNARY_OP_WITH_IMM_CASE(Splat, "splat", Value);
#undef ARITHMETIC_UNARY_OP_WITH_IMM_CASE
  case Kinded::Kind::TouchInstKind:
    // Touch performs no per-element computation, so no code is emitted here.
    break;
case Kinded::Kind::ElementSelectInstKind: {
auto *ES = cast<ElementSelectInst>(I);
auto *dest = ES->getDest();
auto *cond = ES->getCond();
auto *lhs = ES->getLHS();
auto *rhs = ES->getRHS();
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
auto *condPtr = emitBufferAddress(builder, cond, kernel, bufferToArgNum);
auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);
// Need _kernel suffix since these operations are implemented as
// "data-parallel" kernels in libjit.
auto *F = getFunction("elementselect_kernel", lhs->getElementType());
if (lhs->getType()->isQuantizedType()) {
auto *destTy = dest->getType();
auto *lhsTy = lhs->getType();
auto *rhsTy = rhs->getType();
auto *destOffset = emitConstI32(builder, destTy->getOffset());
auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset());
auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset());
// The selected value will be either lhs = s_l * (i_l - o_l) or
// rhs = s_r * (i_r - o_r); the stored result that must be computed is
// therefore one of:
// (i) i_d = (s_l / s_d) * (i_l - o_l) + o_d
// (ii) i_d = (s_r / s_d) * (i_r - o_r) + o_d
float destScale = destTy->getScale();
auto lhsScaleParams = quantization::quantizeScaleOffset32To8(
lhsTy->getScale() / destScale, lhsTy->getOffset());
auto rhsScaleParams = quantization::quantizeScaleOffset32To8(
rhsTy->getScale() / destScale, rhsTy->getOffset());
auto *lhsPre = emitConstI32(builder, lhsScaleParams.pre);
auto *lhsPost = emitConstI32(builder, lhsScaleParams.post);
auto *lhsScale = emitConstI32(builder, lhsScaleParams.scale);
auto *rhsPre = emitConstI32(builder, rhsScaleParams.pre);
auto *rhsPost = emitConstI32(builder, rhsScaleParams.post);
auto *rhsScale = emitConstI32(builder, rhsScaleParams.scale);
auto *stackedOpCall = createUncheckedCall(
builder, F,
{loopCount, condPtr, lhsPtr, rhsPtr, destOffset, lhsOffset, rhsOffset,
lhsPre, lhsPost, lhsScale, rhsPre, rhsPost, rhsScale});
auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr,
loopCount, "buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
} else {
auto *stackedOpCall =
createUncheckedCall(builder, F, {loopCount, condPtr, lhsPtr, rhsPtr});
auto *destAddr = builder.CreateGEP(builder.getFloatTy(), destPtr,
loopCount, "buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
}
break;
}
case Kinded::Kind::IntLookupTableInstKind: {
auto *lookupTable = cast<IntLookupTableInst>(I);
auto *dest = lookupTable->getDest();
auto *src = lookupTable->getSrc();
auto *mapping = lookupTable->getMapping();
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
auto *mappingPtr =
emitBufferAddress(builder, mapping, kernel, bufferToArgNum);
auto *F = getFunction("intlookuptable_kernel", dest->getElementType());
auto *stackedOpCall =
builder.CreateCall(F, {loopCount, srcPtr, mappingPtr});
auto *destType = getElementType(builder, dest);
auto *destAddr =
builder.CreateGEP(destType, destPtr, loopCount, "buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
break;
}
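// Generates a case for a simple element-wise unary instruction. The emitted
// code calls the "<FUN_NAME_>_kernel" libjit stub with the element index and
// the source buffer; the two trailing null pointers are placeholders for
// unused operands.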
#define ARITHMETIC_UNARY_OP_CASE(INST_NAME_, FUN_NAME_) \
case Kinded::Kind::INST_NAME_##InstKind: { \
auto *AN = cast<INST_NAME_##Inst>(I); \
auto *dest = AN->getDest(); \
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum); \
auto *srcPtr = \
emitBufferAddress(builder, AN->getSrc(), kernel, bufferToArgNum); \
auto *F = getFunction(FUN_NAME_ "_kernel", dest->getElementType()); \
auto *elementTy = getElementType(builder, dest); \
auto *pointerNull = \
llvm::ConstantPointerNull::get(elementTy->getPointerTo()); \
auto *stackedOpCall = createUncheckedCall( \
builder, F, {loopCount, srcPtr, pointerNull, pointerNull}); \
    auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount,         \
                                       "buffer.element.addr");                \
builder.CreateStore(stackedOpCall, destAddr); \
break; \
}
ARITHMETIC_UNARY_OP_CASE(Sigmoid, "sigmoid");
ARITHMETIC_UNARY_OP_CASE(Tanh, "tanh");
ARITHMETIC_UNARY_OP_CASE(ElementLog, "element_log");
ARITHMETIC_UNARY_OP_CASE(ElementExp, "element_exp");
ARITHMETIC_UNARY_OP_CASE(ElementAbs, "element_abs");
ARITHMETIC_UNARY_OP_CASE(ElementNeg, "element_neg");
ARITHMETIC_UNARY_OP_CASE(ElementFloor, "element_floor");
ARITHMETIC_UNARY_OP_CASE(ElementCeil, "element_ceil");
ARITHMETIC_UNARY_OP_CASE(ElementRound, "element_round");
ARITHMETIC_UNARY_OP_CASE(ElementSqrt, "element_sqrt");
ARITHMETIC_UNARY_OP_CASE(ElementErf, "element_erf");
ARITHMETIC_UNARY_OP_CASE(ElementRsqrt, "element_rsqrt");
ARITHMETIC_UNARY_OP_CASE(ElementReciprocal, "element_reciprocal");
ARITHMETIC_UNARY_OP_CASE(ElementSin, "element_sin");
ARITHMETIC_UNARY_OP_CASE(ElementCos, "element_cos");
#undef ARITHMETIC_UNARY_OP_CASE
case Kinded::Kind::ReluInstKind: {
auto *RI = cast<ReluInst>(I);
auto *src = RI->getSrc();
auto *dest = RI->getDest();
auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
auto srcTy = src->getType();
auto destTy = dest->getType();
auto *F = getFunction("element_relu", dest->getElementType());
llvm::CallInst *stackedOpCall = nullptr;
if (dest->getElementType() == ElemKind::Int8QTy) {
auto *srcOffset =
emitConstI8(builder, static_cast<int8_t>(srcTy->getOffset()));
auto *destOffset =
emitConstI8(builder, static_cast<int8_t>(destTy->getOffset()));
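      // Rescale factor that maps the source scale onto the destination scale.
      // quantizeScaleOffset32To8 encodes it as the pre-shift, post-shift and
      // 32-bit integer scale that are passed to the kernel.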
auto destScaleParams = quantization::quantizeScaleOffset32To8(
srcTy->getScale() / destTy->getScale(), 0);
auto *destPre = emitConstI32(builder, destScaleParams.pre);
auto *destPost = emitConstI32(builder, destScaleParams.post);
auto *destScale = emitConstI32(builder, destScaleParams.scale);
stackedOpCall = createCall(builder, F,
{loopCount, srcPtr, srcOffset, destOffset,
destPre, destPost, destScale});
} else if (dest->getElementType() == ElemKind::FloatTy) {
stackedOpCall = createCall(builder, F, {loopCount, srcPtr});
} else {
LOG(FATAL) << "Type is not supported";
}
auto *elementTy = getElementType(builder, dest);
auto *destAddr =
builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
break;
}
case Kinded::Kind::ClipInstKind: {
auto *CI = cast<ClipInst>(I);
auto *src = CI->getSrc();
auto *dest = CI->getDest();
auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
auto srcTy = src->getType();
auto destTy = dest->getType();
float clipMinF = CI->getMin();
float clipMaxF = CI->getMax();
auto *F = getFunction("element_clip", dest->getElementType());
llvm::CallInst *stackedOpCall = nullptr;
if (dest->getElementType() == ElemKind::Int8QTy) {
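      // The clip limits are given as floats; quantize them with the source
      // quantization parameters so the kernel can work directly on quantized
      // values.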
TensorQuantizationParams srcTQP{src->getType()->getScale(),
src->getType()->getOffset()};
int8_t clipMinQ = quantization::quantize<int8_t>(clipMinF, srcTQP);
int8_t clipMaxQ = quantization::quantize<int8_t>(clipMaxF, srcTQP);
auto *clipMin = emitConstI8(builder, clipMinQ);
auto *clipMax = emitConstI8(builder, clipMaxQ);
auto *srcOffset =
emitConstI8(builder, static_cast<int8_t>(srcTy->getOffset()));
auto *destOffset =
emitConstI8(builder, static_cast<int8_t>(destTy->getOffset()));
auto destScaleParams = quantization::quantizeScaleOffset32To8(
srcTy->getScale() / destTy->getScale(), 0);
auto *destPre = emitConstI32(builder, destScaleParams.pre);
auto *destPost = emitConstI32(builder, destScaleParams.post);
auto *destScale = emitConstI32(builder, destScaleParams.scale);
stackedOpCall =
createCall(builder, F,
{loopCount, srcPtr, clipMin, clipMax, srcOffset,
destOffset, destPre, destPost, destScale});
} else if (dest->getElementType() == ElemKind::FloatTy) {
auto *clipMin = emitConstF32(builder, clipMinF);
auto *clipMax = emitConstF32(builder, clipMaxF);
stackedOpCall =
createCall(builder, F, {loopCount, srcPtr, clipMin, clipMax});
} else {
LOG(FATAL) << "Type is not supported";
}
auto *elementTy = getElementType(builder, dest);
auto *destAddr =
builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
break;
}
case Kinded::Kind::LeakyReluInstKind: {
auto *LI = cast<LeakyReluInst>(I);
auto *src = LI->getSrc();
auto *dest = LI->getDest();
auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
auto srcTy = src->getType();
auto destTy = dest->getType();
auto *F = getFunction("element_leaky_relu", dest->getElementType());
llvm::CallInst *stackedOpCall = nullptr;
if (dest->getElementType() == ElemKind::Int8QTy) {
auto *srcOffset =
emitConstI8(builder, static_cast<int8_t>(srcTy->getOffset()));
auto *destOffset =
emitConstI8(builder, static_cast<int8_t>(destTy->getOffset()));
// Scale parameters for the positive input domain.
auto posParams = quantization::quantizeScaleOffset32To8(
srcTy->getScale() / destTy->getScale(), 0);
auto *posPre = emitConstI32(builder, posParams.pre);
auto *posPost = emitConstI32(builder, posParams.post);
auto *posScale = emitConstI32(builder, posParams.scale);
// Scale parameters for the negative input domain.
auto negParams = quantization::quantizeScaleOffset32To8(
srcTy->getScale() * LI->getAlpha() / destTy->getScale(), 0);
auto *negPre = emitConstI32(builder, negParams.pre);
auto *negPost = emitConstI32(builder, negParams.post);
auto *negScale = emitConstI32(builder, negParams.scale);
stackedOpCall =
createCall(builder, F,
{loopCount, srcPtr, srcOffset, destOffset, posPre, posPost,
posScale, negPre, negPost, negScale});
} else if (dest->getElementType() == ElemKind::FloatTy) {
auto *alpha = emitConstF32(builder, LI->getAlpha());
stackedOpCall = createCall(builder, F, {loopCount, srcPtr, alpha});
} else {
LOG(FATAL) << "Type is not supported";
}
auto *elementTy = getElementType(builder, dest);
auto *destAddr =
builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
break;
}
case Kinded::Kind::ElementIsNaNInstKind: {
auto *AN = cast<ElementIsNaNInst>(I);
auto *src = AN->getSrc();
auto *dest = AN->getDest();
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
auto *F = getFunction("element_is_nan_kernel", src->getElementType());
auto *stackedOpCall = createUncheckedCall(builder, F, {loopCount, srcPtr});
auto *elementTy = getElementType(builder, dest);
auto *destAddr =
builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
break;
}
case Kinded::Kind::QuantizeInstKind: {
auto *QI = cast<QuantizeInst>(I);
auto *src = QI->getSrc();
auto *dest = QI->getDest();
auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
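    // The destination scale and offset are passed as constants; the kernel's
    // return value is stored at the element index given by 'loopCount'.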
auto *destTy = dest->getType();
auto *destScale = emitConstF32(builder, destTy->getScale());
auto *destOffset = emitConstI32(builder, destTy->getOffset());
auto *F = getFunction("element_quantize_kernel", dest->getElementType());
auto *stackedOpCall = createUncheckedCall(
builder, F, {loopCount, srcPtr, destScale, destOffset});
auto *destType = getElementType(builder, dest);
auto *destAddr =
builder.CreateGEP(destType, destPtr, loopCount, "buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
break;
}
case Kinded::Kind::DequantizeInstKind: {
auto *DI = cast<DequantizeInst>(I);
auto *src = DI->getSrc();
auto *dest = DI->getDest();
auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
auto *srcTy = src->getType();
auto *srcScale = emitConstF32(builder, srcTy->getScale());
auto *srcOffset = emitConstI32(builder, srcTy->getOffset());
auto *F = getFunction("element_dequantize_kernel", src->getElementType());
auto *stackedOpCall = createUncheckedCall(
builder, F, {loopCount, srcPtr, srcScale, srcOffset});
auto *destAddr = builder.CreateGEP(builder.getFloatTy(), destPtr, loopCount,
"buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
break;
}
case Kinded::Kind::RescaleQuantizedInstKind: {
auto *RQI = cast<RescaleQuantizedInst>(I);
auto *dest = RQI->getDest();
auto *src = RQI->getSrc();
auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
auto *destType = dest->getType();
auto *srcType = src->getType();
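    // Rescaling preserves the represented value:
    //   s_d * (i_d - o_d) = s_s * (i_s - o_s)
    //   => i_d = (s_s / s_d) * (i_s - o_s) + o_d
    // The ratio s_s / s_d is encoded by quantizeScaleOffset32To8 as a
    // pre-shift, post-shift and integer scale for the kernel.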
auto rescaleParams = quantization::quantizeScaleOffset32To8(
srcType->getScale() / destType->getScale(), srcType->getOffset());
auto *destOffset = emitConstI32(builder, destType->getOffset());
auto *srcOffset = emitConstI32(builder, srcType->getOffset());
auto *preShift = emitConstI32(builder, rescaleParams.pre);
auto *postShift = emitConstI32(builder, rescaleParams.post);
auto *scale = emitConstI32(builder, rescaleParams.scale);
auto *F = getFunction("element_rescale_kernel", dest->getElementType());
auto *stackedOpCall = createUncheckedCall(
builder, F,
{loopCount, srcPtr, destOffset, srcOffset, preShift, postShift, scale});
auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr, loopCount,
"buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
break;
}
case Kinded::Kind::CopyInstKind: {
auto *CI = cast<CopyInst>(I);
auto *dest = CI->getDest();
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
auto *srcPtr =
emitBufferAddress(builder, CI->getSrc(), kernel, bufferToArgNum);
auto *F = getFunction("copy_kernel", dest->getElementType());
auto *elementTy = getElementType(builder, dest);
auto *pointerNull =
llvm::ConstantPointerNull::get(elementTy->getPointerTo());
auto *stackedOpCall = createUncheckedCall(
builder, F, {loopCount, srcPtr, pointerNull, pointerNull});
    auto *destAddr =
        builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
break;
}
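// Generates a case for a binary element-wise instruction. Quantized operands
// are rescaled into the destination scale via per-operand pre/post/scale
// parameters; for the non-quantized path, matchPair checks that the
// destination element type is one of the ElemKinds listed in the macro
// invocation.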
#define ARITHMETIC_BINARY_OP_CASE(INST_NAME_, FUN_NAME_, ...) \
case Kinded::Kind::INST_NAME_##InstKind: { \
auto *AN = cast<INST_NAME_##Inst>(I); \
auto *dest = AN->getDest(); \
auto *lhs = AN->getLHS(); \
auto *rhs = AN->getRHS(); \
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum); \
auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum); \
auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum); \
\
auto *F = getFunction(FUN_NAME_ "_kernel", dest->getElementType()); \
auto *elementTy = getElementType(builder, dest); \
auto *pointerNull = \
llvm::ConstantPointerNull::get(elementTy->getPointerTo()); \
bool typesMatched = matchPair(dest->getElementType(), __VA_ARGS__); \
if (lhs->getType()->isQuantizedType()) { \
auto *destTy = dest->getType(); \
auto *lhsTy = lhs->getType(); \
auto *rhsTy = rhs->getType(); \
\
auto *destOffset = emitConstI32(builder, destTy->getOffset()); \
auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset()); \
auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset()); \
\
float destScale = destTy->getScale(); \
\
auto lhsScaleParams = quantization::quantizeScaleOffset32To8( \
lhsTy->getScale() / destScale, lhsTy->getOffset()); \
auto rhsScaleParams = quantization::quantizeScaleOffset32To8( \
rhsTy->getScale() / destScale, rhsTy->getOffset()); \
\
auto *lhsPre = emitConstI32(builder, lhsScaleParams.pre); \
auto *lhsPost = emitConstI32(builder, lhsScaleParams.post); \
auto *lhsScale = emitConstI32(builder, lhsScaleParams.scale); \
auto *rhsPre = emitConstI32(builder, rhsScaleParams.pre); \
auto *rhsPost = emitConstI32(builder, rhsScaleParams.post); \
auto *rhsScale = emitConstI32(builder, rhsScaleParams.scale); \
\
auto *stackedOpCall = createUncheckedCall( \
builder, F, \
{loopCount, lhsPtr, rhsPtr, destOffset, lhsOffset, rhsOffset, \
lhsPre, lhsPost, lhsScale, rhsPre, rhsPost, rhsScale}); \
auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr, \
loopCount, "buffer.element.addr"); \
builder.CreateStore(stackedOpCall, destAddr); \
} else if (typesMatched) { \
auto *stackedOpCall = createUncheckedCall( \
builder, F, {loopCount, lhsPtr, rhsPtr, pointerNull}); \
auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount, \
"buffer.element.addr"); \
builder.CreateStore(stackedOpCall, destAddr); \
} else { \
llvm_unreachable("Unsupported Type in " #INST_NAME_); \
} \
break; \
}
ARITHMETIC_BINARY_OP_CASE(ElementAdd, "element_add", ElemKind::FloatTy,
ElemKind::Int32ITy, ElemKind::Int64ITy);
ARITHMETIC_BINARY_OP_CASE(ElementSub, "element_sub", ElemKind::FloatTy);
ARITHMETIC_BINARY_OP_CASE(ElementMax, "element_max", ElemKind::FloatTy);
ARITHMETIC_BINARY_OP_CASE(ElementMin, "element_min", ElemKind::FloatTy);
ARITHMETIC_BINARY_OP_CASE(ElementPow, "element_pow", ElemKind::FloatTy);
#undef ARITHMETIC_BINARY_OP_CASE
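  // The logical element-wise instructions (Not, And, Or, Xor) map directly to
  // boolean libjit kernels and need no quantization handling.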
case Kinded::Kind::ElementNotInstKind: {
auto *NI = cast<ElementNotInst>(I);
auto *dest = NI->getDest();
auto *src = NI->getSrc();
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
auto *F = getFunction("element_not_kernel", src->getElementType());
auto *elementTy = getElementType(builder, dest);
auto *stackedOpCall = createUncheckedCall(builder, F, {loopCount, srcPtr});
auto *destAddr =
builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
break;
}
case Kinded::Kind::ElementAndInstKind: {
auto *AI = cast<ElementAndInst>(I);
auto *dest = AI->getDest();
auto *lhs = AI->getLHS();
auto *rhs = AI->getRHS();
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);
auto *F = getFunction("element_and_kernel", lhs->getElementType());
auto *elementTy = getElementType(builder, dest);
auto *stackedOpCall =
createUncheckedCall(builder, F, {loopCount, lhsPtr, rhsPtr});
auto *destAddr =
builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
break;
}
case Kinded::Kind::ElementOrInstKind: {
auto *OI = cast<ElementOrInst>(I);
auto *dest = OI->getDest();
auto *lhs = OI->getLHS();
auto *rhs = OI->getRHS();
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);
auto *F = getFunction("element_or_kernel", lhs->getElementType());
auto *elementTy = getElementType(builder, dest);
auto *stackedOpCall =
createUncheckedCall(builder, F, {loopCount, lhsPtr, rhsPtr});
auto *destAddr =
builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
break;
}
case Kinded::Kind::ElementXorInstKind: {
auto *XI = cast<ElementXorInst>(I);
auto *dest = XI->getDest();
auto *lhs = XI->getLHS();
auto *rhs = XI->getRHS();
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);
auto *F = getFunction("element_xor_kernel", lhs->getElementType());
auto *elementTy = getElementType(builder, dest);
auto *stackedOpCall =
createUncheckedCall(builder, F, {loopCount, lhsPtr, rhsPtr});
auto *destAddr =
builder.CreateGEP(elementTy, destPtr, loopCount, "buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
break;
}
case Kinded::Kind::ElementCmpEQInstKind:
case Kinded::Kind::ElementCmpNEQInstKind:
case Kinded::Kind::ElementCmpLTInstKind:
case Kinded::Kind::ElementCmpLTEInstKind: {
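    // All four comparison kinds share the same lowering; they differ only in
    // the name of the libjit kernel that is called.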
Value *dest = nullptr;
Value *lhs = nullptr;
Value *rhs = nullptr;
std::string kernelName;
if (auto *CEQI = dyn_cast<ElementCmpEQInst>(I)) {
dest = CEQI->getDest();
lhs = CEQI->getLHS();
rhs = CEQI->getRHS();
kernelName = "element_cmp_eq_kernel";
} else if (auto *CNEQI = dyn_cast<ElementCmpNEQInst>(I)) {
dest = CNEQI->getDest();
lhs = CNEQI->getLHS();
rhs = CNEQI->getRHS();
kernelName = "element_cmp_neq_kernel";
} else if (auto *CLTEI = dyn_cast<ElementCmpLTEInst>(I)) {
dest = CLTEI->getDest();
lhs = CLTEI->getLHS();
rhs = CLTEI->getRHS();
kernelName = "element_cmp_lte_kernel";
} else if (auto *CLTI = dyn_cast<ElementCmpLTInst>(I)) {
dest = CLTI->getDest();
lhs = CLTI->getLHS();
rhs = CLTI->getRHS();
kernelName = "element_cmp_lt_kernel";
} else {
      llvm_unreachable(
          "Mismatch between Instruction Kind and instruction instance.");
}
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);
// Need _kernel suffix since these operations are implemented as
// "data-parallel" kernels in libjit.
auto *F = getFunction(kernelName.c_str(), lhs->getElementType());
if (lhs->getType()->isQuantizedType()) {
auto *lhsTy = lhs->getType();
auto *rhsTy = rhs->getType();
auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset());
auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset());
// We can divide both sides of the comparison by the rhs scale since it is
// strictly positive; this saves one rescale within the backend. The
// inequalities are:
// s_l * (i_l - o_l) <= s_r * (i_r - o_r)
// <=> (s_l / s_r) * (i_l - o_l) <= i_r - o_r
float scale = lhsTy->getScale() / rhsTy->getScale();
auto scaleParams = quantization::quantizeScaleOffset32To8(scale, 0);
auto *cmpPre = emitConstI32(builder, scaleParams.pre);
auto *cmpPost = emitConstI32(builder, scaleParams.post);
auto *cmpScale = emitConstI32(builder, scaleParams.scale);
auto *stackedOpCall =
createUncheckedCall(builder, F,
{loopCount, lhsPtr, rhsPtr, lhsOffset, rhsOffset,
cmpPre, cmpPost, cmpScale});
auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr,
loopCount, "buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
} else {
auto *stackedOpCall =
createUncheckedCall(builder, F, {loopCount, lhsPtr, rhsPtr});
auto *elementTy = getElementType(builder, dest);
auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount,
"buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
}
break;
}
case Kinded::Kind::ElementMulInstKind: {
auto *MI = cast<ElementMulInst>(I);
auto *dest = MI->getDest();
auto *lhs = MI->getLHS();
auto *rhs = MI->getRHS();
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);
// Need _kernel suffix since these operations are implemented as
// "data-parallel" kernels in libjit.
auto *F = getFunction("element_mul_kernel", dest->getElementType());
auto *elementTy = getElementType(builder, dest);
auto *pointerNull =
llvm::ConstantPointerNull::get(elementTy->getPointerTo());
if (lhs->getType()->isQuantizedType()) {
auto *destTy = dest->getType();
auto *lhsTy = lhs->getType();
auto *rhsTy = rhs->getType();
auto *destOffset = emitConstI32(builder, destTy->getOffset());
auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset());
auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset());
// The multiplicative scale factor is s_l * s_r / s_d due to the equation
// s_d * (i_d - o_d) = s_l * (i_l - o_l) * s_r * (i_r - o_r)
// => i_d = (s_l * s_r / s_d) * (i_l - o_l) * (i_r - o_r) + o_d
float scale = lhsTy->getScale() * rhsTy->getScale() / destTy->getScale();
auto scaleParams = quantization::quantizeScaleOffset32To8(scale, 0);
auto *mulPre = emitConstI32(builder, scaleParams.pre);
auto *mulPost = emitConstI32(builder, scaleParams.post);
auto *mulScale = emitConstI32(builder, scaleParams.scale);
auto *stackedOpCall =
createUncheckedCall(builder, F,
{loopCount, lhsPtr, rhsPtr, destOffset, lhsOffset,
rhsOffset, mulPre, mulPost, mulScale});
auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr,
loopCount, "buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
} else if (lhs->getType()->getElementType() == ElemKind::Int64ITy ||
lhs->getType()->getElementType() == ElemKind::Int32ITy ||
lhs->getType()->getElementType() == ElemKind::FloatTy) {
auto *stackedOpCall = createUncheckedCall(
builder, F, {loopCount, lhsPtr, rhsPtr, pointerNull});
auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount,
"buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
} else {
LOG_ASSERT(false) << "Unsupported element type for Mul.";
}
break;
}
case Kinded::Kind::ElementDivInstKind: {
auto *MI = cast<ElementDivInst>(I);
auto *dest = MI->getDest();
auto *lhs = MI->getLHS();
auto *rhs = MI->getRHS();
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
auto *lhsPtr = emitBufferAddress(builder, lhs, kernel, bufferToArgNum);
auto *rhsPtr = emitBufferAddress(builder, rhs, kernel, bufferToArgNum);
// Need _kernel suffix since these operations are implemented as
// "data-parallel" kernels in libjit.
auto *F = getFunction("element_div_kernel", dest->getElementType());
auto *elementTy = getElementType(builder, dest);
auto *pointerNull =
llvm::ConstantPointerNull::get(elementTy->getPointerTo());
if (lhs->getType()->isQuantizedType()) {
auto *destTy = dest->getType();
auto *lhsTy = lhs->getType();
auto *rhsTy = rhs->getType();
auto *destOffset = emitConstI32(builder, destTy->getOffset());
auto *lhsOffset = emitConstI32(builder, lhsTy->getOffset());
auto *rhsOffset = emitConstI32(builder, rhsTy->getOffset());
// The division scale factor is s_l / (s_r * s_d) due to the equation
// s_d * (i_d - o_d) = (s_l * (i_l - o_l)) / (s_r * (i_r - o_r))
// => i_d = (s_l / (s_r * s_d)) * ((i_l - o_l) / (i_r - o_r)) + o_d
float scale =
lhsTy->getScale() / (rhsTy->getScale() * destTy->getScale());
auto scaleParams = quantization::quantizeScaleOffset32To8(scale, 0);
auto *divPre = emitConstI32(builder, scaleParams.pre);
auto *divPost = emitConstI32(builder, scaleParams.post);
auto *divScale = emitConstI32(builder, scaleParams.scale);
auto *stackedOpCall =
createUncheckedCall(builder, F,
{loopCount, lhsPtr, rhsPtr, destOffset, lhsOffset,
rhsOffset, divPre, divPost, divScale});
auto *destAddr = builder.CreateGEP(builder.getInt8Ty(), destPtr,
loopCount, "buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
} else {
auto *elementTy = getElementType(builder, dest);
auto *stackedOpCall = createUncheckedCall(
builder, F, {loopCount, lhsPtr, rhsPtr, pointerNull});
auto *destAddr = builder.CreateGEP(elementTy, destPtr, loopCount,
"buffer.element.addr");
builder.CreateStore(stackedOpCall, destAddr);
}
break;
}
case Kinded::Kind::ModuloInstKind: {
auto *MI = cast<ModuloInst>(I);
auto *dest = MI->getDest();
auto *src = MI->getSrc();
auto *destPtr = emitBufferAddress(builder, dest, kernel, bufferToArgNum);
auto *srcPtr = emitBufferAddress(builder, src, kernel, bufferToArgNum);
auto *divisor = emitConst(builder, MI->getDivisor(), ElemKind::Int64ITy);
llvm::Function *F = nullptr;
// Need _kernel suffix since these operations are implemented as
// "data-parallel" kernels in libjit.
if (MI->getSignFollowDivisor()) {
F = getFunction("element_modulo_kernel_sign_follow",
dest->getElementType());
} else {
F = getFunction("element_modulo_kernel_no_sign_follow",
dest->getElementType());
}
auto *stackedOpCall =
createUncheckedCall(builder, F, {loopCount, divisor, srcPtr});
llvm::Value *destAddr = nullptr;
if (dest->getElementType() == ElemKind::Int64ITy) {
destAddr = builder.CreateGEP(builder.getInt64Ty(), destPtr, loopCount,
"buffer.element.addr");
} else {
destAddr = builder.CreateGEP(builder.getInt32Ty(), destPtr, loopCount,
"buffer.element.addr");
}
builder.CreateStore(stackedOpCall, destAddr);
break;
}
default:
std::string sBuf;
llvm::raw_string_ostream s(sBuf);
I->dump(s);
LOG(FATAL) << "Cannot select the instruction: " << s.str();
}
}