Error OpenCLFunction::execute()

in lib/Backends/OpenCL/OpenCL.cpp [748:1674]


Error OpenCLFunction::execute(ExecutionContext *context) {
  auto clBindings = static_cast<runtime::OpenCLDeviceBindings *>(
      context->getDeviceBindings());

  auto deviceBuffer = clBindings->deviceBuffer;
  auto deviceId = clBindings->deviceId;
  auto commands = clBindings->commandQueue;
  auto program = clBindings->program;
  std::vector<KernelLaunch> &kernelLaunches = clBindings->kernelLaunches;

  kernelProfiling_ = clDoProfile || getTraceInfo().autoInstrumented;

  TRACE_EVENT_SCOPE_NAMED(context, TraceLevel::RUNTIME, "enqueueKernels",
                          enqueueEvent);
  for (const auto &I : F_->getInstrs()) {
    // Skip memory allocation instructions as they are NOPs.
    if (isa<AllocActivationInst>(I) || isa<DeallocActivationInst>(I) ||
        isa<TensorViewInst>(I) || isa<TouchInst>(I)) {
      continue;
    }
    // The kernels are named after the name of the instruction, plus the "W"
    // suffix to prevent name collisions for functions like 'tanh' that are also
    // a part of the OpenCL runtime.
    auto elemTy = I.getNumOperands() ? I.getOperand(0).first->getElementType()
                                     : ElemKind::FloatTy;

    // If ElementCmpLTEInst then the first operand is always bool, so instead
    // set the element type based on the LHS input.
    if (auto *LTE = dyn_cast<ElementCmpLTEInst>(&I)) {
      elemTy = LTE->getLHS()->getElementType();
    }

    std::string kernelName = getKernelName(I.getKindName(), elemTy);

    // Check if the instruction is quantized. Consider an instruction to be
    // quantized if its destination or source operands are quantized.
    bool isQuantized = I.getNumOperands() &&
                       (I.getOperand(0).first->getType()->isQuantizedType() ||
                        I.getOperand(I.getNumOperands() - 1)
                            .first->getType()
                            ->isQuantizedType());

    // Element-wise operations, except the copy instruction.
    if (I.isDataParallel() && !isa<CopyInst>(I)) {
      // Figure out how many elements there are to process:
      size_t global;
      if (I.isDataParallel()) {
        global = I.getOperand(0).first->getType()->size();
        // The check for quantization below is a temporary workaround until the
        // corresponding kernels are implemented for the quantized operations.
        if (!isQuantized) {
          if (getPreferredVectorWidth(deviceId, elemTy) == 1) {
            // If the device prefers not to use vector data types, let's not.
          } else if (global % 16 == 0) {
            // Launch fewer work-items and let each one do more work using
            // vector instructions.
            global /= 16;
            kernelName += "16";
          } else if (global % 8 == 0) {
            // Launch fewer work-items and let each one do more work using
            // vector instructions.
            global /= 8;
            kernelName += "8";
          }
        }
      } else {
        LOG(FATAL) << "Invalid instruction: " << I.getName().str();
      }
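      // Example: on a device that prefers vector types, a float tensor with
      // 4,000 elements is not divisible by 16 but is divisible by 8, so 500
      // work-items are launched with the "8"-suffixed kernel, each one
      // processing eight elements.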

      cl_kernel kernel = createKernel(kernelName, program);
      setKernelArg(kernel, 0, deviceBuffer);
      auto numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);
      auto numMandatoryArgs = numArgs;
      (void)numMandatoryArgs;

      if (auto *SI = dyn_cast<SplatInst>(&I)) {
        // Pass the splat as a parameter.
        if (!isQuantized) {
          setKernelArg(kernel, ++numArgs, SI->getValue());
        } else {
          auto *destTy = SI->getDest()->getType();
          TensorQuantizationParams destQ{destTy->getScale(),
                                         destTy->getOffset()};
          float val = SI->getValue();
          int8_t int8Val = quantization::quantize(val, destQ);
          setKernelArg<float>(kernel, ++numArgs, static_cast<float>(int8Val));
        }
      }

      if (isQuantized) {
        if (isa<ElementAddInst>(I) || isa<ElementSubInst>(I) ||
            isa<ElementMulInst>(I) || isa<ElementDivInst>(I) ||
            isa<ElementMinInst>(I) || isa<ElementMaxInst>(I)) {
          int32_t destOffset = I.getOperand(0).first->getType()->getOffset();
          float destScale = I.getOperand(0).first->getType()->getScale();

          auto LHSTy = I.getOperand(1).first->getType();
          auto RHSTy = I.getOperand(2).first->getType();

          auto lhsScaleParams = quantization::quantizeScaleOffset32To8(
              LHSTy->getScale() / destScale, LHSTy->getOffset());
          auto rhsScaleParams = quantization::quantizeScaleOffset32To8(
              RHSTy->getScale() / destScale, RHSTy->getOffset());
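          // lhsScaleParams and rhsScaleParams encode each operand's
          // scale-to-destination-scale ratio as an integer-only transform, so
          // the kernel can rescale the operands into the destination's
          // quantization domain without floating-point math.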
          setKernelArg(kernel, ++numArgs, destOffset);
          setKernelArg(kernel, ++numArgs, lhsScaleParams);
          setKernelArg(kernel, ++numArgs, rhsScaleParams);
          if (isa<ElementMulInst>(I) || isa<ElementDivInst>(I)) {
            float resultScale =
                isa<ElementMulInst>(I)
                    ? LHSTy->getScale() * RHSTy->getScale() / destScale
                    : LHSTy->getScale() / (RHSTy->getScale() * destScale);
            auto resultScaleParams =
                quantization::quantizeScaleOffset32To8(resultScale, 0);
            setKernelArg(kernel, ++numArgs, resultScaleParams);
          }
        } else if (auto *RI = dyn_cast<ReluInst>(&I)) {
          int32_t destOffset = RI->getDest()->getType()->getOffset();
          float destScale = RI->getDest()->getType()->getScale();

          auto srcTy = RI->getSrc()->getType();

          auto srcScaleParams = quantization::quantizeScaleOffset32To8(
              srcTy->getScale() / destScale, srcTy->getOffset());
          setKernelArg(kernel, ++numArgs, destOffset);
          setKernelArg(kernel, ++numArgs, srcScaleParams);
        }
        // Quantize floating point tensor. Scale and Offset are based on return
        // type of the instruction \p I.
        if (auto *QI = dyn_cast<QuantizeInst>(&I)) {
          float destTensorQuantizationScale =
              QI->getDest()->getType()->getScale();
          int32_t destTensorQuantizationOffset =
              QI->getDest()->getType()->getOffset();
          setKernelArg(kernel, ++numArgs, destTensorQuantizationScale);
          setKernelArg(kernel, ++numArgs, destTensorQuantizationOffset);
        }
        // Rescale quantized tensor. Scale and Offset are based on return type
        // of the instruction \p I.
        if (auto *RQI = dyn_cast<RescaleQuantizedInst>(&I)) {
          auto *dest = RQI->getDest();
          auto *src = RQI->getSrc();
          auto *destType = dest->getType();
          auto *srcType = src->getType();
          auto rescaleParams = quantization::quantizeScaleOffset32To8(
              srcType->getScale() / destType->getScale(), srcType->getOffset());

          setKernelArg(kernel, ++numArgs, destType->getOffset());
          setKernelArg(kernel, ++numArgs, srcType->getOffset());
          setKernelArg(kernel, ++numArgs, rescaleParams);
        }
        // Dequantize integer tensor. Scale and Offset are based
        // on the source tensor type.
        if (auto *QI = dyn_cast<DequantizeInst>(&I)) {
          float srcTensorQuantizationScale =
              QI->getSrc()->getType()->getScale();
          int32_t srcTensorQuantizationOffset =
              QI->getSrc()->getType()->getOffset();
          setKernelArg(kernel, ++numArgs, srcTensorQuantizationScale);
          setKernelArg(kernel, ++numArgs, srcTensorQuantizationOffset);
        }
      }

      if (isQuantized) {
        DCHECK_GT(numArgs, numMandatoryArgs) << "Not enough kernel arguments";
      }
      enqueueKernel(I.getName(), commands, kernel, deviceId, {global},
                    kernelLaunches);
      continue;
    }

    if (auto *SM = dyn_cast<SoftMaxInst>(&I)) {
      // Implement Softmax by parallelizing the batch dimension. Each sample in
      // the batch is processed by a different parallel 'thread'.
      cl_kernel kernel = createKernel(kernelName, program);
      setKernelArg(kernel, 0, deviceBuffer);
      auto numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);

      // The batch contains numSlices samples; each sample is processed by a
      // separate work-item.
      auto inputDims = SM->getSrc()->getType()->dims();
      size_t numSlices = inputDims[0];

      // Pass the slice size (size of each sample in the batch) as a parameter.
      setKernelArg<cl_uint>(kernel, numArgs + 1, flattenCdr(inputDims).second);

      enqueueKernel(I.getName(), commands, kernel, deviceId, {numSlices},
                    kernelLaunches);
      continue;
    }

    if (auto *SM = dyn_cast<SoftMaxGradInst>(&I)) {
      // Implement the softmax gradient by parallelizing the batch dimension.
      // Each sample in the batch is processed by a different parallel 'thread'.
      cl_kernel kernel = createKernel(kernelName, program);
      setKernelArg(kernel, 0, deviceBuffer);
      auto numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);

      // The batch contains numSlices samples; each sample is processed by a
      // separate work-item.
      auto inputDims = SM->getSrcGrad()->getType()->dims();
      size_t numSlices = inputDims[0];

      // Pass the slice size (size of each sample in the batch) as a parameter.
      setKernelArg<cl_uint>(kernel, numArgs + 1, flattenCdr(inputDims).second);

      enqueueKernel(I.getName(), commands, kernel, deviceId, {numSlices},
                    kernelLaunches);
      continue;
    }

    if (auto *ET = dyn_cast<ExtractTensorInst>(&I)) {
      cl_kernel kernel = createKernel(kernelName, program);
      setKernelArg(kernel, 0, deviceBuffer);
      auto numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);

      // Currently supports tensors of up to 4 dimensions.
      // TODO: Handle other dimensions.
      assert(ET->getDest()->getType()->dims().size() <= 4);

      ShapeNHWC odim = shapeFromDims(ET->getDest()->getType()->dims());
      ShapeNHWC idim = shapeFromDims(ET->getSrc()->getType()->dims());
      ShapeNHWC offset = shapeFromDims(ET->getOffsets());

      setKernelArg(kernel, numArgs + 1, odim);
      setKernelArg(kernel, numArgs + 2, idim);
      setKernelArg(kernel, numArgs + 3, offset);
      enqueueKernel(I.getName(), commands, kernel, deviceId, {odim.n, odim.h},
                    kernelLaunches);
      continue;
    }

    if (auto *IT = dyn_cast<InsertTensorInst>(&I)) {
      cl_kernel kernel = createKernel(kernelName, program);
      setKernelArg(kernel, 0, deviceBuffer);
      auto numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);

      // Currently supports tensors of up to 4 dimensions.
      // TODO: Handle other dimensions.
      assert(IT->getDest()->getType()->dims().size() <= 4);

      ShapeNHWC odim = shapeFromDims(IT->getDest()->getType()->dims());
      ShapeNHWC idim = shapeFromDims(IT->getSrc()->getType()->dims());
      ShapeNHWC offset = shapeFromDims(IT->getOffsets());

      setKernelArg(kernel, numArgs + 1, odim);
      setKernelArg(kernel, numArgs + 2, idim);
      setKernelArg(kernel, numArgs + 3, offset);
      setKernelArg<cl_uint>(kernel, numArgs + 4, IT->getCount());
      setKernelArg<cl_uint>(kernel, numArgs + 5, IT->getAxis());
      enqueueKernel(I.getName(), commands, kernel, deviceId, {idim.n, idim.h},
                    kernelLaunches);
      continue;
    }

    if (auto *BMM = dyn_cast<MatMulInst>(&I)) {
      // Size of the tile to be used for matrix multiplication.
      constexpr size_t TILE_DIM = 8;

      // Determine max work groups sizes.
      size_t WIS[3];
      cl_int err = clGetDeviceInfo(deviceId, CL_DEVICE_MAX_WORK_ITEM_SIZES,
                                   sizeof(WIS), &WIS, nullptr);
      CHECK_EQ(err, CL_SUCCESS) << "Could not execute clGetDeviceInfo";
      // True if the tiled matrix multiplication kernel can be used. This is
      // only possible if the device allows workgroups with sizes which are at
      // least as big as a tile.
      bool useTiledMatMul = (WIS[0] >= TILE_DIM && WIS[1] >= TILE_DIM);
      auto tiledKernelName = isQuantized ? "matmul_tiled_i8" : "matmul_tiled";
      cl_kernel kernel =
          createKernel(useTiledMatMul ? tiledKernelName : kernelName, program);
      setKernelArg(kernel, 0, deviceBuffer);
      auto numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);

      ShapeNHWC ddim = shapeFromDims(BMM->getDest()->getType()->dims());
      ShapeNHWC ldim = shapeFromDims(BMM->getLHS()->getType()->dims());
      ShapeNHWC rdim = shapeFromDims(BMM->getRHS()->getType()->dims());

      setKernelArg(kernel, numArgs + 1, ddim);
      setKernelArg(kernel, numArgs + 2, ldim);
      setKernelArg(kernel, numArgs + 3, rdim);
      if (isQuantized) {
        auto lhsTy = BMM->getLHS()->getType();
        auto rhsTy = BMM->getRHS()->getType();
        auto destTy = BMM->getDest()->getType();
        auto destScaleParams = quantization::quantizeScaleOffset32To8(
            lhsTy->getScale() * rhsTy->getScale() / destTy->getScale(), 0);
        setKernelArg(kernel, numArgs + 4, lhsTy->getOffset());
        setKernelArg(kernel, numArgs + 5, rhsTy->getOffset());
        setKernelArg(kernel, numArgs + 6, destTy->getOffset());
        setKernelArg(kernel, numArgs + 7, destScaleParams);
      }

      if (useTiledMatMul) {
        std::vector<size_t> local{TILE_DIM, TILE_DIM};
        std::vector<size_t> global{(ddim.n / local[0] + 1) * local[0],
                                   (ddim.h / local[1] + 1) * local[1]};
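        // The global size is padded up to the next multiple of TILE_DIM (for
        // example, ddim.n == 100 gives (100 / 8 + 1) * 8 == 104 work-items);
        // out-of-range work-items are presumably masked off inside the kernel.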

        enqueueKernel(I.getName(), commands, kernel, deviceId, global, local,
                      kernelLaunches);
      } else {
        enqueueKernel(I.getName(), commands, kernel, deviceId,
                      {ddim.n, ddim.h, ddim.w}, kernelLaunches);
      }
      continue;
    }

    if (auto *BA = dyn_cast<BatchedAddInst>(&I)) {
      if (isQuantized &&
          BA->getSlice()->getType()->getElementType() == ElemKind::Int32QTy) {
        kernelName += "_32";
      }
      cl_kernel kernel = createKernel(kernelName, program);
      setKernelArg(kernel, 0, deviceBuffer);
      auto numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);

      auto bdim = flattenCdr(BA->getBatch()->dims());
      setKernelArg<cl_uint>(kernel, numArgs + 1, bdim.first);
      setKernelArg<cl_uint>(kernel, numArgs + 2, bdim.second);

      if (isQuantized) {
        auto *destTy = BA->getDest()->getType();
        auto *batchTy = BA->getBatch()->getType();
        auto *sliceTy = BA->getSlice()->getType();

        setKernelArg(kernel, numArgs + 3, destTy->getOffset());

        float destScale = destTy->getScale();
        auto batchScaleParams = quantization::quantizeScaleOffset32To8(
            batchTy->getScale() / destScale, batchTy->getOffset());
        auto sliceScaleParams = quantization::quantizeScaleOffset32To8(
            sliceTy->getScale() / destScale, sliceTy->getOffset());

        setKernelArg(kernel, numArgs + 4, batchScaleParams);
        setKernelArg(kernel, numArgs + 5, sliceScaleParams);
      }

      // Parallelize on each element in the slice.
      enqueueKernel(I.getName(), commands, kernel, deviceId, {bdim.second},
                    kernelLaunches);
      continue;
    }

    if (auto *BRA = dyn_cast<OCLBatchedReduceAddInst>(&I)) {
      auto axis = BRA->getAxis();
      auto axisSrcSliceSize = BRA->getAxisSrcSliceSize();

      // The kernel needs the size of the reduce axis (batchDims[axis]) and the
      // slice size underneath it (axisSrcSliceSize) to index correctly into
      // the input buffer.
      auto batchDims = BRA->getSrc()->getType()->dims();

      // The output dimensions define the global work size. If the output has
      // zero dimensions (a scalar reduction), use a single dimension of size 1
      // so the kernel launch is still valid.
      auto destDims = BRA->getDest()->getType()->dims();
      std::vector<size_t> destDimsVec(destDims.begin(), destDims.end());
      if (destDims.empty()) {
        destDimsVec.emplace_back(1);
      }

      // Create kernel and set arguments.
      cl_kernel kernel = createKernel(kernelName, program);
      setKernelArg(kernel, 0, deviceBuffer);
      auto numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);

      setKernelArg<cl_uint>(kernel, numArgs + 1, batchDims[axis]);
      setKernelArg<cl_uint>(kernel, numArgs + 2, axisSrcSliceSize);

      // Parallelize over each element of the output.
      enqueueKernel(I.getName(), commands, kernel, deviceId, destDimsVec,
                    kernelLaunches);
      continue;
    }

    if (auto *LRN = dyn_cast<LocalResponseNormalizationGradInst>(&I)) {
      cl_kernel kernel = createKernel(kernelName, program);
      setKernelArg(kernel, 0, deviceBuffer);

      size_t numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);
      ShapeNHWC dim(LRN->getDest()->getType()->dims());

      uint32_t halfWindowSize = LRN->getHalfWindowSize();
      uint32_t windowSize = 2 * halfWindowSize + 1;
      setKernelArg(kernel, ++numArgs, dim);
      setKernelArg(kernel, ++numArgs, halfWindowSize);
      setKernelArg(kernel, ++numArgs, LRN->getK());
      setKernelArg(kernel, ++numArgs, LRN->getBeta());
      setKernelArg(kernel, ++numArgs, LRN->getAlpha() / windowSize);

      enqueueKernel(I.getName(), commands, kernel, deviceId,
                    {dim.n, dim.h, dim.w}, kernelLaunches);
      continue;
    }

    if (auto *LRN = dyn_cast<LocalResponseNormalizationInst>(&I)) {
      cl_kernel kernel = createKernel(kernelName, program);
      setKernelArg(kernel, 0, deviceBuffer);

      size_t numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);
      ShapeNHWC dim(LRN->getDest()->getType()->dims());

      uint32_t halfWindowSize = LRN->getHalfWindowSize();
      uint32_t windowSize = 2 * halfWindowSize + 1;
      setKernelArg(kernel, ++numArgs, dim);
      setKernelArg(kernel, ++numArgs, halfWindowSize);
      setKernelArg(kernel, ++numArgs, LRN->getK());
      setKernelArg(kernel, ++numArgs, LRN->getBeta());
      setKernelArg(kernel, ++numArgs, LRN->getAlpha() / windowSize);

      enqueueKernel(I.getName(), commands, kernel, deviceId,
                    {dim.h, dim.w, dim.c}, kernelLaunches);
      continue;
    }

    if (auto *CC = dyn_cast<ConvolutionInst>(&I)) {
      // For the OpenCL backend, only the NCHW convolution supports non-square
      // dilation.
      if (CC->getLayout() == NCHW) {
        executeNCHWConvolution(CC, context, clBindings);
        continue;
      }

      if (CC->getFusedActivation() == FusedActivation::RELU) {
        kernelName += "_ReLU";
      }

      // This is a naive implementation that parallelizes over three dimensions
      // of the output: height, width, and depth.
      cl_program prog = program;
      auto idim = ShapeNHWC(CC->getSrc()->getType()->dims());
      ShapeHW kdim(CC->getKernels());
      ShapeHW sdim(CC->getStrides());
      auto odim = ShapeNHWC(CC->getDest()->getType()->dims());
      ShapeNHWC kernelSize(CC->getFilter()->getType()->dims());
      auto pads = PaddingTLBR(CC->getPads());

      CHECK_EQ(CC->getDilation()[0], CC->getDilation()[1])
          << "Currently not support non-square dilation here";

      const bool specialize = clSpecializeConvolution && !isQuantized;
      std::string src;
      if (specialize) {
        // Specialize the kernel related to Conv node parameters to enable
        // aggressive constant propagation and other optimizations.
        std::vector<std::string> options;
        addIntOption(options, "CONVK_GROUP", CC->getGroup());
        addIntOption(options, "CONVK_BATCHES", idim.n);
        addIntOption(options, "CONVK_DILATION", CC->getDilation()[0]);
        addIntOption(options, "CONVK_KERNEL_W", kdim.width);
        addIntOption(options, "CONVK_KERNEL_H", kdim.height);
        addIntOption(options, "CONVK_STRIDES_W", sdim.width);
        addIntOption(options, "CONVK_STRIDES_H", sdim.height);
        addIntOption(options, "CONVK_IDIM_W", idim.w);
        addIntOption(options, "CONVK_IDIM_H", idim.h);
        addIntOption(options, "CONVK_IDIM_C", idim.c);
        addIntOption(options, "CONVK_ODIM_W", odim.w);
        addIntOption(options, "CONVK_ODIM_H", odim.h);
        addIntOption(options, "CONVK_ODIM_C", odim.c);
        addIntOption(options, "CONVK_PADS_TOP", pads.top);
        addIntOption(options, "CONVK_PADS_LEFT", pads.left);
        addIntOption(options, "CONVK_FILTER_W", kernelSize.w);
        addIntOption(options, "CONVK_FILTER_H", kernelSize.h);
        addIntOption(options, "CONVK_FILTER_C", kernelSize.c);
        src.append(reinterpret_cast<const char *>(
                       kernels_specialized_no_local_mem_conv_cl_src),
                   kernels_specialized_no_local_mem_conv_cl_src_size);
        prog = createProgram(src, options, commands);
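        // The CONVK_* values are baked into a freshly built program
        // (presumably as preprocessor defines), so the compiler can
        // constant-fold the loop bounds, strides, and padding.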
      }

      cl_kernel kernel = createKernel(kernelName, prog);
      setKernelArg(kernel, 0, deviceBuffer);
      auto numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);

      if (!specialize) {
        setKernelArg(kernel, numArgs + 1, kdim);
        setKernelArg(kernel, numArgs + 2, sdim);
        setKernelArg(kernel, numArgs + 3, pads);
        setKernelArg(kernel, numArgs + 4, CC->getGroup());
        setKernelArg(kernel, numArgs + 5, CC->getDilation()[0]);
        setKernelArg(kernel, numArgs + 6, odim);
        setKernelArg(kernel, numArgs + 7, idim);
        setKernelArg(kernel, numArgs + 8, kernelSize);

        if (isQuantized) {
          auto srcTy = CC->getSrc()->getType();
          auto destTy = CC->getDest()->getType();
          auto filterTy = CC->getFilter()->getType();
          auto biasTy = CC->getBias()->getType();
          setKernelArg(kernel, numArgs + 9, destTy->getOffset());
          setKernelArg(kernel, numArgs + 10, destTy->getScale());
          setKernelArg(kernel, numArgs + 11, srcTy->getOffset());
          setKernelArg(kernel, numArgs + 12, srcTy->getScale());
          setKernelArg(kernel, numArgs + 13, filterTy->getOffset());
          setKernelArg(kernel, numArgs + 14, filterTy->getScale());
          setKernelArg(kernel, numArgs + 15, biasTy->getOffset());
          setKernelArg(kernel, numArgs + 16, biasTy->getScale());
        }
      }

      // Use a 3D grid over the output height, width, and depth (channels).
      enqueueKernel(I.getName(), commands, kernel, deviceId,
                    {odim.h, odim.w, odim.c}, kernelLaunches);
      continue;
    }

    if (auto *CG = dyn_cast<ConvolutionGradInst>(&I)) {
      auto *src = CG->getSrc();
      auto *destGrad = CG->getDestGrad();
      auto *srcGrad = CG->getSrcGrad();
      auto *filterGrad = CG->getFilterGrad();
      auto *biasGrad = CG->getBiasGrad();
      cl_kernel kernel = createKernel(kernelName, program);
      setKernelArg(kernel, 0, deviceBuffer);
      auto numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);

      auto destGradDim = ShapeNHWC(destGrad->dims());
      auto srcDim = ShapeNHWC(src->dims());
      auto filterGradDim = ShapeNHWC(filterGrad->dims());
      auto pads = PaddingTLBR(CG->getPads());

      CHECK_EQ(CG->getDilation()[0], CG->getDilation()[1])
          << "Currently not support non-square dilation.";

      ShapeHW kdim(CG->getKernels());
      ShapeHW sdim(CG->getStrides());
      setKernelArg(kernel, numArgs + 1, kdim);
      setKernelArg(kernel, numArgs + 2, sdim);
      setKernelArg(kernel, numArgs + 3, pads);
      setKernelArg(kernel, numArgs + 4, CG->getGroup());
      setKernelArg(kernel, numArgs + 5, CG->getDilation()[0]);
      setKernelArg(kernel, numArgs + 6, srcDim);
      setKernelArg(kernel, numArgs + 7, destGradDim);
      setKernelArg(kernel, numArgs + 8, filterGradDim);
      // Zero memory for the output buffers.
      fillBuffer(deviceBuffer, runtimeBundle_.getValueOffset(srcGrad),
                 srcGrad->size(), 0, srcGrad->getElementType(), clBindings);
      fillBuffer(deviceBuffer, runtimeBundle_.getValueOffset(filterGrad),
                 filterGrad->size(), 0, filterGrad->getElementType(),
                 clBindings);
      fillBuffer(deviceBuffer, runtimeBundle_.getValueOffset(biasGrad),
                 biasGrad->size(), 0, biasGrad->getElementType(), clBindings);

      enqueueKernel(I.getName(), commands, kernel, deviceId,
                    {destGradDim.h, destGradDim.w, destGradDim.c},
                    kernelLaunches);
      continue;
    }

    if (auto *PM = dyn_cast<MaxPoolInst>(&I)) {
      bool isNCHW = PM->getLayout() == NCHW;

      if (isNCHW) {
        kernelName = "ocl" + kernelName;
      }

      cl_kernel kernel = createKernel(kernelName, program);
      setKernelArg(kernel, 0, deviceBuffer);
      auto numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);

      ShapeHW kdim(PM->getKernels());
      ShapeHW sdim(PM->getStrides());
      setKernelArg<cl_uint>(kernel, numArgs + 1, kdim.height);
      setKernelArg<cl_uint>(kernel, numArgs + 2, sdim.height);
      auto pads = PaddingTLBR(PM->getPads());
      setKernelArg(kernel, numArgs + 3, pads);

      std::array<size_t, 3> global;
      if (isNCHW) {
        ShapeNCHW odim(PM->getDest()->getType()->dims());
        ShapeNCHW idim(PM->getSrc()->getType()->dims());

        setKernelArg(kernel, numArgs + 4, odim);
        setKernelArg(kernel, numArgs + 5, idim);
        global = {{odim.h, odim.w, odim.c}};
      } else {
        ShapeNHWC odim(PM->getDest()->getType()->dims());
        ShapeNHWC idim(PM->getSrc()->getType()->dims());
        setKernelArg(kernel, numArgs + 4, odim);
        setKernelArg(kernel, numArgs + 5, idim);
        global = {{odim.h, odim.w, odim.c}};
      }

      enqueueKernel(I.getName(), commands, kernel, deviceId, global,
                    kernelLaunches);
      continue;
    }

    if (auto *PM = dyn_cast<MaxPoolWithArgmaxInst>(&I)) {
      // This is a naive implementation that parallelizes over the output
      // height, width, and depth.
      cl_kernel kernel = createKernel(kernelName, program);
      setKernelArg(kernel, 0, deviceBuffer);
      auto numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);

      auto odim = ShapeNHWC(PM->getDest()->getType()->dims());
      auto idim = ShapeNHWC(PM->getSrc()->getType()->dims());
      auto pads = PaddingTLBR(PM->getPads());
      ShapeHW kdim(PM->getKernels());
      ShapeHW sdim(PM->getStrides());
      setKernelArg<cl_uint>(kernel, numArgs + 1, kdim.height);
      setKernelArg<cl_uint>(kernel, numArgs + 2, sdim.height);
      setKernelArg(kernel, numArgs + 3, pads);
      setKernelArg(kernel, numArgs + 4, odim);
      setKernelArg(kernel, numArgs + 5, idim);

      enqueueKernel(I.getName(), commands, kernel, deviceId,
                    {odim.h, odim.w, odim.c}, kernelLaunches);
      continue;
    }

    if (auto *PMG = dyn_cast<MaxPoolWithArgmaxGradInst>(&I)) {
      cl_kernel kernel = createKernel(kernelName, program);
      setKernelArg(kernel, 0, deviceBuffer);
      auto numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);

      auto destGradDim = ShapeNHWC(PMG->getDestGrad()->dims());
      auto srcGradDim = ShapeNHWC(PMG->getSrcGrad()->dims());
      auto pads = PaddingTLBR(PMG->getPads());
      ShapeHW kdim(PMG->getKernels());
      ShapeHW sdim(PMG->getStrides());
      setKernelArg<cl_uint>(kernel, numArgs + 1, kdim.height);
      setKernelArg<cl_uint>(kernel, numArgs + 2, sdim.height);
      setKernelArg(kernel, numArgs + 3, pads);
      setKernelArg(kernel, numArgs + 4, srcGradDim);
      setKernelArg(kernel, numArgs + 5, destGradDim);

      enqueueKernel(I.getName(), commands, kernel, deviceId, {srcGradDim.n},
                    kernelLaunches);
      continue;
    }

    if (auto *PA = dyn_cast<AvgPoolInst>(&I)) {
      bool isNCHW = PA->getLayout() == NCHW;

      if (isNCHW) {
        kernelName = "ocl" + kernelName;
      }

      cl_kernel kernel = createKernel(kernelName, program);
      setKernelArg(kernel, 0, deviceBuffer);
      auto numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);

      ShapeHW kdim(PA->getKernels());
      ShapeHW sdim(PA->getStrides());
      setKernelArg<cl_uint>(kernel, numArgs + 1, kdim.height);
      setKernelArg<cl_uint>(kernel, numArgs + 2, sdim.height);
      auto pads = PaddingTLBR(PA->getPads());
      setKernelArg(kernel, numArgs + 3, pads);

      std::array<size_t, 3> global;
      if (isNCHW) {
        ShapeNCHW odim(PA->getDest()->getType()->dims());
        ShapeNCHW idim(PA->getSrc()->getType()->dims());

        setKernelArg(kernel, numArgs + 4, odim);
        setKernelArg(kernel, numArgs + 5, idim);
        global = {{odim.h, odim.w, odim.c}};
      } else {
        ShapeNHWC odim(PA->getDest()->getType()->dims());
        ShapeNHWC idim(PA->getSrc()->getType()->dims());
        setKernelArg(kernel, numArgs + 4, odim);
        setKernelArg(kernel, numArgs + 5, idim);
        global = {{odim.h, odim.w, odim.c}};
      }

      if (isNCHW && isQuantized) {
        auto srcTy = PA->getSrc()->getType();
        auto destTy = PA->getDest()->getType();
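        // Fold the 1/(k * k) averaging factor into the requantization scale
        // (a square pooling window is assumed, since kernels[0] is used for
        // both dimensions), so the kernel only accumulates integer sums.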
        auto destScaleParam = quantization::quantizeScaleOffset32To8(
            srcTy->getScale() / destTy->getScale() /
                (PA->getKernels()[0] * PA->getKernels()[0]),
            destTy->getOffset());
        setKernelArg(kernel, numArgs + 6, srcTy->getOffset());
        setKernelArg(kernel, numArgs + 7, destScaleParam);
      }

      enqueueKernel(I.getName(), commands, kernel, deviceId, global,
                    kernelLaunches);
      continue;
    }

    if (auto *TR = dyn_cast<TransposeInst>(&I)) {
      // This is a naive implementation that parallelizes over the first two
      // dimensions (N and H) of the padded shape.
      CHECK_LE(TR->getShuffle().size(), 4)
          << "This code supports only transposes of up to 4 dimensions";

      cl_kernel kernel = createKernel(kernelName, program);
      setKernelArg(kernel, 0, deviceBuffer);
      auto numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);

      // Temporary hack to support 3-dim transposes.
      // TODO: support any dimensional transposes.
      std::vector<dim_t> odim_vec = TR->getDest()->getType()->dims();
      std::vector<dim_t> idim_vec = TR->getSrc()->getType()->dims();
      std::vector<unsigned_t> mask = TR->getShuffle();
      while (mask.size() < 4) {
        odim_vec.push_back(1);
        idim_vec.push_back(1);
        mask.push_back(mask.size());
      }

      auto odim = ShapeNHWC(llvm::makeArrayRef(odim_vec));
      auto idim = ShapeNHWC(llvm::makeArrayRef(idim_vec));

      setKernelArg(kernel, numArgs + 1, odim);
      setKernelArg(kernel, numArgs + 2, idim);

      ShapeNHWC shuff(mask[0], mask[1], mask[2], mask[3]);
      setKernelArg(kernel, numArgs + 3, shuff);
      enqueueKernel(I.getName(), commands, kernel, deviceId, {idim.n, idim.h},
                    kernelLaunches);
      continue;
    }

    if (auto *C = dyn_cast<CopyInst>(&I)) {
      Value *dest, *src;
      dest = C->getDest();
      src = C->getSrc();
      if (src == dest) {
        continue;
      }
      size_t destOff = runtimeBundle_.getValueOffset(dest);
      size_t srcOff = runtimeBundle_.getValueOffset(src);
      size_t sizeInBytes = dest->getSizeInBytes();
      cl_event event{nullptr};
      cl_int err = clEnqueueCopyBuffer(commands, deviceBuffer, deviceBuffer,
                                       srcOff, destOff, sizeInBytes, 0, nullptr,
                                       kernelProfiling_ ? &event : nullptr);
      if (kernelProfiling_) {
        kernelLaunches.emplace_back(
            KernelLaunch(I.getName().str(), "copy", event));
      }
      CHECK_EQ(err, CL_SUCCESS) << "Error in clEnqueueCopyBuffer.";
      continue;
    }

    if (auto *GI = dyn_cast<GatherInst>(&I)) {
      cl_kernel kernel = createKernel(kernelName, program);
      setKernelArg(kernel, 0, deviceBuffer);
      auto numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);
      unsigned_t axis = GI->getBatchDims();

      auto *data = GI->getData();

      TypeRef dataType = data->getType();
      size_t numIndices = GI->getIndices()->size();

      // The size of each slice that we gather.
      size_t sliceSize = dataType->getSliceSize(axis + 1);
      // The size of each sample in the source batch.
      size_t srcSampleSize = dataType->getSliceSize(axis);
      // The size of each packed sample in the destination.
      size_t destSampleSize = numIndices * sliceSize;
      // The number of samples in the batch.
      size_t numSamples = dataType->size() / srcSampleSize;
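      // Example: gathering from data of shape {B, N, D} with batchDims == 1
      // gives sliceSize == D, srcSampleSize == N * D,
      // destSampleSize == numIndices * D, and numSamples == B.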

      setKernelArg<cl_uint>(kernel, numArgs + 1, numIndices);
      setKernelArg<cl_uint>(kernel, numArgs + 2, sliceSize);

      // Batch arguments:
      setKernelArg<cl_uint>(kernel, numArgs + 3, numSamples);
      setKernelArg<cl_uint>(kernel, numArgs + 4, destSampleSize);
      setKernelArg<cl_uint>(kernel, numArgs + 5, srcSampleSize);

      enqueueKernel(I.getName(), commands, kernel, deviceId, {numIndices},
                    kernelLaunches);
      continue;
    }

    if (auto *SDI = dyn_cast<ScatterDataInst>(&I)) {
      cl_kernel kernel = createKernel(kernelName, program);
      setKernelArg(kernel, 0, deviceBuffer);
      auto numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);

      auto *data = SDI->getData();
      size_t dataSliceSize = data->size() / data->dims()[0];
      size_t numIndices = SDI->getIndices()->size();
      setKernelArg<cl_uint>(kernel, numArgs + 1, dataSliceSize);

      enqueueKernel(I.getName(), commands, kernel, deviceId, {numIndices},
                    kernelLaunches);
      continue;
    }

    if (auto *SLWS = dyn_cast<SparseLengthsWeightedSumInst>(&I)) {
      cl_kernel kernel = createKernel(kernelName, program);
      // Set the device buffer as the first argument.
      setKernelArg(kernel, 0, deviceBuffer);
      // Set all buffer arguments from the instruction (data, dest, weights,
      // indices, lengths) as subsequent arguments.
      auto numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);

      // Set the size of one slice of data as the last argument.
      auto *data = SLWS->getData();
      size_t dataSliceSize = data->size() / data->dims()[0];
      setKernelArg<cl_uint>(kernel, numArgs + 1, dataSliceSize);

      // Zero the destination buffer so that the kernel can accumulate (+=) into
      // it.
      auto *dest = SLWS->getDest();
      fillBuffer(deviceBuffer, runtimeBundle_.getValueOffset(dest),
                 dest->size(), 0, dest->getElementType(), clBindings);

      // Get the number of segments. The output for each segment will be
      // computed in parallel by setting the global size equal to the number of
      // segments.
      size_t segments = SLWS->getLengths()->size();

      // Enqueue the kernel.
      enqueueKernel(I.getName(), commands, kernel, deviceId, {segments},
                    kernelLaunches);
      continue;
    }

    if (auto *SLWSG = dyn_cast<SparseLengthsWeightedSumGradInst>(&I)) {
      cl_kernel kernel = createKernel(kernelName, program);
      // Set the device buffer as the first argument.
      setKernelArg(kernel, 0, deviceBuffer);
      // Set all buffer arguments from the instruction (dataGrad, destGrad,
      // weights, indices, lengths) as subsequent arguments.
      auto numArgs = setKernelArgsForBuffers(kernel, I, 1, runtimeBundle_);

      // Set the number of segments as the second last argument.
      auto *lengths = SLWSG->getLengths();
      size_t segments = lengths->size();
      setKernelArg<cl_uint>(kernel, numArgs + 1, segments);

      // Set the size of one slice of destGrad as the last argument.
      auto *destGrad = SLWSG->getDestGrad();
      size_t destGradSliceSize = destGrad->size() / destGrad->dims()[0];
      setKernelArg<cl_uint>(kernel, numArgs + 2, destGradSliceSize);

      // Zero the data gradient buffer so that the kernel can accumulate (+=)
      // into it.
      auto *dataGrad = SLWSG->getDataGrad();
      fillBuffer(deviceBuffer, runtimeBundle_.getValueOffset(dataGrad),
                 dataGrad->size(), 0, dataGrad->getElementType(), clBindings);

      // Enqueue the kernel. Set the global size to 1 so that all segments are
      // processed sequentially to avoid two kernel instances accumulating into
      // the same data gradient slice. This could potentially be relaxed by
      // using an atomic add in the kernel.
      enqueueKernel(I.getName(), commands, kernel, deviceId, {1},
                    kernelLaunches);
      continue;
    }

    if (auto *DP = dyn_cast<DebugPrintInst>(&I)) {
      clFinish(commands);
      auto *V = DP->getSrc();
      // Allocate a temporary tensor to hold the value.
      Tensor T(V->getType());
      // Load the current value of the variable into host memory.
      copyValueFromDevice(V, clBindings, T.getUnsafePtr());
      clFinish(commands);
      llvm::outs() << I.getName() << ": ";
      // Dump the content of a value.
      V->dump();
      llvm::outs() << "\n";
      dumpImpl(&T);
      llvm::outs() << "\n";
      llvm::outs().flush();
      continue;
    }

    if (auto *TE = dyn_cast<TraceEventInst>(&I)) {
      cl_kernel kernel = createKernel("checkpoint", program);
      setKernelArg(kernel, 0, deviceBuffer);

      llvm::SmallVector<size_t, 1> global = {1};
      llvm::SmallVector<size_t, 4> local(global.size(), 0);
      getMaxLocalWorkgroupSize(kernel, deviceId, global, local);

      cl_event event;
      cl_int err =
          clEnqueueNDRangeKernel(commands, kernel, global.size(), nullptr,
                                 &global[0], &local[0], 0, nullptr, &event);
      CHECK_EQ(err, CL_SUCCESS) << "Error in clEnqueueNDRangeKernel.";
      kernelLaunches.push_back(
          KernelLaunch(kernel, TE->getName().str(), "checkpoint", event));
      continue;
    }

    // For TopKInst, we perform the computation on the host side, since sorting
    // on the GPU is complex and unlikely to bring much benefit. We copy the
    // tensor from GPU memory to host memory, perform the computation, and then
    // copy the results back to GPU memory.
    if (auto *TK = dyn_cast<TopKInst>(&I)) {
      clFinish(commands);
      auto *destDev = TK->getValues();
      auto *indDev = TK->getIndices();
      auto *srcDev = TK->getInput();
      Tensor destT(destDev->getType());
      Tensor indT(indDev->getType());
      Tensor srcT(srcDev->getType());
      size_t k = TK->getK();

      copyValueFromDevice(srcDev, clBindings, srcT.getUnsafePtr());
      clFinish(commands);

      if (srcDev->getType()->isQuantizedType() ||
          destDev->getType()->isQuantizedType()) {
        topK<int8_t>(destT, indT, srcT, k);
      } else {
        topK<float>(destT, indT, srcT, k);
      }
      copyValueToDevice(destDev, clBindings, destT.getUnsafePtr());
      copyValueToDevice(indDev, clBindings, indT.getUnsafePtr());
      clFinish(commands);
      continue;
    }

    LOG(FATAL) << "Compilation failed, cannot select: " << I.getKindName();
  }

  enqueueEvent.end();

  clFinish(commands);

  return Error::success();
}
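
The data-parallel path near the top of the loop picks a kernel-name suffix and
a global work size from the element count and the device's preferred vector
width. Below is a minimal, standalone sketch of that selection logic; it is not
part of the backend, and the helper name is illustrative only.

#include <cstddef>
#include <string>
#include <utility>

// Returns the kernel-name suffix and the number of work-items to launch for a
// non-quantized, data-parallel instruction over numElements elements, given
// the device's preferred vector width (1 means "do not vectorize").
static std::pair<std::string, size_t>
pickVectorizedLaunch(size_t numElements, unsigned preferredVectorWidth) {
  if (preferredVectorWidth == 1) {
    return {"", numElements}; // Scalar kernel: one element per work-item.
  }
  if (numElements % 16 == 0) {
    return {"16", numElements / 16}; // Each work-item handles 16 elements.
  }
  if (numElements % 8 == 0) {
    return {"8", numElements / 8}; // Each work-item handles 8 elements.
  }
  return {"", numElements}; // Fall back to the scalar kernel.
}

For example, pickVectorizedLaunch(4000, 4) returns {"8", 500}, matching the
branch taken above for a 4,000-element float tensor on a device whose preferred
vector width is greater than 1.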