Error Caffe2ModelLoader::loadOperator()

in lib/Importer/Caffe2ModelLoader.cpp [812:2265]


Error Caffe2ModelLoader::loadOperator(const caffe2::OperatorDef &op) {
  ArgumentDictionaryTy dict = loadArgumentMap(op);
  const std::string &typeName = op.type();
  mod_.registerOriginalName(op.name());

  // Check if operator is supported in parent class, CommonOperatorLoader.
  bool loadCommonOperatorSuccess;
  ASSIGN_VALUE_OR_RETURN_ERR(loadCommonOperatorSuccess,
                             tryLoadCommonOperator(typeName, op, dict));
  if (loadCommonOperatorSuccess) {
    return Error::success();
  }
  const std::string &opName = loadOperatorName(op);

  if (typeName == "Gelu") {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    Node *node = G_->createGelu(opName, in);

    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  if (typeName == "Conv" || typeName == "ConvRelu") {
    return loadConv(op, dict);
  }

  if (typeName == "Softmax") {
    return loadSoftmax(op, dict);
  }

  if (typeName == "PRelu") {
    return loadPRelu(op, dict);
  }

  if (typeName == "ConvTranspose") {
    return loadConvTranspose(op, dict);
  }

  if (typeName == "Int8Conv" || typeName == "Int8ConvRelu") {
    return loadConvQuantized(op, dict);
  }

  if (typeName == "LayerNorm") {
    return loadLayerNorm(op, dict);
  }

  if (typeName == "Int8SumRelu") {
    RETURN_ERR_IF_NOT(op.input_size() == 2,
                      opErrMsg(op, "Only Sum of 2 inputs is supported."));
    RETURN_ERR_IF_NOT(
        dict.count("Y_zero_point"),
        opErrMsg(op, "missing zero point for quantized outout type"));
    RETURN_ERR_IF_NOT(
        dict.count("Y_scale"),
        opErrMsg(op, "missing Y_scale for quantized output type"));
    NodeValue in0;
    ASSIGN_VALUE_OR_RETURN_ERR(in0, getNodeValueByName(op.input(0)));
    NodeValue in1;
    ASSIGN_VALUE_OR_RETURN_ERR(in1, getNodeValueByName(op.input(1)));
    auto outDims = in0.getType()->dims();
    TypeRef outTy;
    ASSIGN_VALUE_OR_RETURN_ERR(
        outTy, loadQuantTy(opName, ElemKind::Int8QTy, outDims, dict));
    auto *add = G_->createAdd(opName + ".sum", outTy, in0, in1);
    auto *relu = G_->createRELU(opName + ".relu", add);
    RETURN_IF_ERR(addNodeAsOutput(op, relu));
    return Error::success();
  }

  if (typeName == "Int8Relu") {
    RETURN_ERR_IF_NOT(op.input_size() == 1,
                      opErrMsg(op, "Only one input is supported."));
    RETURN_ERR_IF_NOT(
        dict.count("Y_zero_point"),
        opErrMsg(op, "missing zero point for quantized outout type"));
    RETURN_ERR_IF_NOT(
        dict.count("Y_scale"),
        opErrMsg(op, "missing Y_scale for quantized output type"));
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    auto outDims = in.getType()->dims();
    TypeRef outTy;
    ASSIGN_VALUE_OR_RETURN_ERR(
        outTy, loadQuantTy(opName, ElemKind::Int8QTy, outDims, dict));
    auto *relu = G_->createRELU(opName, in, outTy);
    RETURN_IF_ERR(addNodeAsOutput(op, relu));
    return Error::success();
  }

  if (typeName == "Int8Quantize") {
    RETURN_ERR_IF_NOT(
        op.input_size() == 1,
        opErrMsg(op, "Glow only supports Int8Quantize with 1 input"));
    RETURN_ERR_IF_NOT(
        dict.count("Y_zero_point"),
        opErrMsg(op, "missing zero point for quantized output type"));
    RETURN_ERR_IF_NOT(
        dict.count("Y_scale"),
        opErrMsg(op, "missing Y_scale for quantized output type"));
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    auto outDims = in.getType()->dims();
    TypeRef outTy;
    ASSIGN_VALUE_OR_RETURN_ERR(
        outTy, loadQuantTy(opName, ElemKind::Int8QTy, outDims, dict));
    Node *N = G_->createQuantize(opName, in, outTy);
    RETURN_IF_ERR(addNodeAsOutput(op, N));
    return Error::success();
  }

  if (typeName == "Int8Dequantize") {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    auto *node = G_->createDequantize(opName, in, ElemKind::FloatTy);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  if (typeName == "MaxPool" || typeName == "AveragePool" ||
      typeName == "Int8MaxPool" || typeName == "Int8AveragePool") {
    // Load the inputs:
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    std::vector<unsigned_t> strides;
    ASSIGN_VALUE_OR_RETURN_ERR(strides, getSizeHW(dict, "stride", 1));
    std::vector<unsigned_t> kernels;
    ASSIGN_VALUE_OR_RETURN_ERR(kernels, getSizeHW(dict, "kernel", 0));
    std::vector<unsigned_t> pads;
    ASSIGN_VALUE_OR_RETURN_ERR(pads, getPads(dict));
    bool countIncludePads;
    ASSIGN_VALUE_OR_RETURN_ERR(
        countIncludePads, getCountIncludePads(dict, /* defaultValue */ true));
    std::string order = "NCHW";
    if (dict.count("order")) {
      ASSIGN_VALUE_OR_RETURN_ERR(order, loadStr(dict["order"]));
    }
    // We expect the input to be NHWC.
    NodeValue finalIn;
    if (order == "NCHW") {
      finalIn = G_->createTranspose(opName, in, NCHW2NHWC)->getResult();
    } else {
      finalIn = in;
    }

    // If 'global_pooling' is set then the operation will pool over the size
    // of the input by doing: kernels = {height, width}.
    if (dict.count("global_pooling")) {
      auto Ty = in.getType();
      kernels[0] = Ty->dims()[2];
      kernels[1] = Ty->dims()[3];
    }

    // Check the padding style.
    if (dict.count("legacy_pad")) {
      int mode;
      ASSIGN_VALUE_OR_RETURN_ERR(mode, loadInt(dict["legacy_pad"]));
      // Caffe1 (legacy) rounds up whereas Caffe2 rounds down.
      // This style is deprecated according to caffe2's caffe2_legacy.proto
      // definition.
      if (static_cast<LegacyPaddingMode>(mode) ==
          LegacyPaddingMode::CAFFE_LEGACY_POOLING) {
        return MAKE_ERR(opErrMsg(op,
                                 "MaxPool nodes with legacy caffe padding are "
                                 "deprecated and not supported."));
      }
    }

    Node *node = nullptr;

    if (typeName == "Int8MaxPool" || typeName == "Int8AveragePool") {
      // Create the node with quantized type.
      RETURN_ERR_IF_NOT(
          dict.count("Y_zero_point"),
          opErrMsg(op, "missing zero point for quantized output type"));
      RETURN_ERR_IF_NOT(
          dict.count("Y_scale"),
          opErrMsg(op, "missing Y_scale for quantized output type"));

      TypeRef finalInType = finalIn.getType();
      ShapeNHWC idim = ShapeNHWC(finalInType->dims());
      auto outSz =
          calculateConvPoolOutputDims(idim.h, idim.w, kernels, strides, pads);
      std::array<dim_t, 4> outDims = {
          {idim.n, outSz.first, outSz.second, idim.c}};
      if (typeName == "Int8MaxPool") {
        // Int8Maxpool output quantization should be same as the input, so
        // just ignore the given params.
        node = G_->createMaxPool(opName, finalIn, kernels, strides, pads);
      } else {
        TypeRef outTy;
        ASSIGN_VALUE_OR_RETURN_ERR(
            outTy, loadQuantTy(opName, ElemKind::Int8QTy, outDims, dict));
        node = G_->createAvgPool(opName, finalIn, outTy, kernels, strides, pads,
                                 NHWC, countIncludePads);
      }
    } else if (typeName == "MaxPool") {
      node = G_->createMaxPool(opName, finalIn, kernels, strides, pads);
    } else {
      node = G_->createAvgPool(opName, finalIn, kernels, strides, pads, NHWC,
                               countIncludePads);
    }
    if (order == "NCHW") {
      unsigned resIdx = 0;
      if (llvm::isa<MaxPoolNode>(node)) {
        resIdx = MaxPoolNode::ResultIdx;
      } else if (llvm::isa<AvgPoolNode>(node)) {
        resIdx = AvgPoolNode::ResultIdx;
      } else {
        return MAKE_ERR("Expected either Max or Avg Pool.");
      }
      // Transpose the output back.
      node = G_->createTranspose(opName, node->getNthResult(resIdx), NHWC2NCHW);
    }
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  if (typeName == "SpatialBN") {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    Constant *scale;
    ASSIGN_VALUE_OR_RETURN_ERR(scale, getConstantByName(op.input(1)));
    Constant *bias;
    ASSIGN_VALUE_OR_RETURN_ERR(bias, getConstantByName(op.input(2)));
    Constant *mean;
    ASSIGN_VALUE_OR_RETURN_ERR(mean, getConstantByName(op.input(3)));
    Constant *var;
    ASSIGN_VALUE_OR_RETURN_ERR(var, getConstantByName(op.input(4)));
    float epsilon = 1e-5f; // default
    auto epsilonIt = dict.find("epsilon");
    if (epsilonIt != dict.end()) {
      ASSIGN_VALUE_OR_RETURN_ERR(epsilon, loadFloat(epsilonIt->second));
    }

    unsigned_t channel;
    ASSIGN_VALUE_OR_RETURN_ERR(channel, getChannel(dict));
    auto *node = G_->createBatchNormalization(
        opName, in.getType(), in, bias, scale, mean, var, channel, epsilon);

    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  if (typeName == "Bucketize") {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    RETURN_ERR_IF_NOT(
        dict.count("boundaries"),
        opErrMsg(op, "Bucketize: Expected a boundaries member vector"));
    std::vector<float> boundaries;
    ASSIGN_VALUE_OR_RETURN_ERR(boundaries, getFloats(dict["boundaries"]));
    auto *node = G_->createBucketizeNode(opName, in, boundaries);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  if (typeName == "ResizeNearest") {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

    std::string order = "NCHW";
    if (dict.count("order")) {
      ASSIGN_VALUE_OR_RETURN_ERR(order, loadStr(dict["order"]));
    }
    // We expect the input to be NHWC.
    NodeValue finalIn;
    if (order == "NCHW") {
      finalIn = G_->createTranspose(opName, in, NCHW2NHWC)->getResult();
    } else {
      finalIn = in;
    }

    float heightScale;
    ASSIGN_VALUE_OR_RETURN_ERR(heightScale, loadFloat(dict["height_scale"]));
    float widthScale;
    ASSIGN_VALUE_OR_RETURN_ERR(widthScale, loadFloat(dict["width_scale"]));

    std::vector<float> scales;
    scales.push_back(1.0f);
    scales.push_back(heightScale);
    scales.push_back(widthScale);
    scales.push_back(1.0f);

    auto *node = G_->createResizeNearest(opName, finalIn, scales);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  if (typeName == "Concat") {
    const unsigned numInputs = op.input_size();
    llvm::SmallVector<NodeValue, 4> inputs;
    inputs.reserve(numInputs);
    for (unsigned i = 0; i < numInputs; i++) {
      NodeValue in;
      ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(i)));
      inputs.push_back(std::move(in));
    }

    // If axis exists it takes priority over channel.
    unsigned_t channel;
    if (dict.count("axis")) {
      ASSIGN_VALUE_OR_RETURN_ERR(channel, loadInt(dict["axis"]));
    } else {
      ASSIGN_VALUE_OR_RETURN_ERR(channel, getChannel(dict));
    }

    unsigned_t addAxis = 0;
    if (dict.count("add_axis")) {
      ASSIGN_VALUE_OR_RETURN_ERR(addAxis, loadInt(dict["add_axis"]));
    }

    Node *node{nullptr};

    if (addAxis) {
      // When add axis is used, this means we have to add a new dimension
      // before the axis, instead of merging on the axis.
      std::vector<dim_t> outputDims = inputs[0].dims();

      if (channel < outputDims.size()) {
        unsigned i = 0;
        for (const auto &input : inputs) {
          RETURN_ERR_IF_NOT(
              outputDims[channel] == input.dims()[channel],
              opErrMsg(op,
                       strFormat("inputs need all to have the same dims for "
                                 "concat with add_axis: input 0 (%s) vs "
                                 "input %u (%s), %u vs %u, channel = %u",
                                 op.input(0).c_str(), i, op.input(i).c_str(),
                                 static_cast<unsigned>(outputDims[channel]),
                                 static_cast<unsigned>(input.dims()[channel]),
                                 channel)));
          ++i;
        }
        outputDims.insert(outputDims.begin() + channel, numInputs);
        node = G_->createConcat(opName, inputs, channel);
        node = G_->createReshape(opName, node, outputDims);
      } else if (channel == outputDims.size()) {
        // We convert inputs into 2D arrays with single columns, thus the
        // number of rows will be equal to the product of all original dims.
        // Every converted input will look like a vertical line of numbers.
        const auto flatVerticalShape = flattenCdr(inputs[0].dims(), channel);
        llvm::SmallVector<NodeValue, 4> verticalInputs;
        for (auto &input : inputs) {
          verticalInputs.push_back(G_->createReshape(
              opName, input,
              {flatVerticalShape.first, flatVerticalShape.second}));
        }

        // We glue together the vertical lines, so, the number of columns
        // becomes equal to the number of original inputs.
        node = G_->createConcat(opName, verticalInputs, 1);

        // Reshape to convert to desired shape.
        outputDims.push_back(numInputs);
        node = G_->createReshape(opName, node, outputDims);
      } else {
        return MAKE_ERR(opErrMsg(
            op, strFormat("Invalid input: channel (=%u) > number of dims (=%u)",
                          channel, static_cast<unsigned>(outputDims.size()))));
      }
    } else {
      // In normal case (i.e. when we are not adding a new dimension)
      // plain createConcat() would suffice.
      node = G_->createConcat(opName, inputs, channel);
    }

    // If we add the axis then node is a Reshape, otherwise it should be
    // Concat.
    RETURN_ERR_IF_NOT(
        llvm::isa<ConcatNode>(node) || llvm::isa<ReshapeNode>(node),
        opErrMsg(op,
                 "Internal error: Node should either be a Concat or Reshape."));
    NodeValue finalNode = llvm::isa<ConcatNode>(node)
                              ? NodeValue(node, ConcatNode::ResultIdx)
                              : NodeValue(node, ReshapeNode::ResultIdx);
    nodeValueByName_[op.output(0)] = finalNode;
    // Concat may have a second output in Caffe2 (split_info), but we don't
    // use it for inference
    return Error::success();
  }

  if (typeName == "FC" || typeName == "FCTransposed" || typeName == "Int8FC" ||
      typeName == "FbFCPacked") {
    RETURN_ERR_IF_NOT(op.input_size() == 3,
                      "Glow only suports FC with 3 inputs");
    // Load the inputs:
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

    auto originalInputDims = in.getType()->dims();

    size_t axis = 1;
    if (dict.count("axis")) {
      ASSIGN_VALUE_OR_RETURN_ERR(axis, loadInt(dict["axis"]));
    }

    // Load weights.
    unsigned_t axis_w = 1;
    if (dict.count("axis_w")) {
      ASSIGN_VALUE_OR_RETURN_ERR(axis_w, loadInt(dict["axis_w"]));
    }

    NodeValue W;
    if (hasConstantByName(op.input(1))) {
      ASSIGN_VALUE_OR_RETURN_ERR(W, getConstantByName(op.input(1)));
    } else {
      ASSIGN_VALUE_OR_RETURN_ERR(W, getNodeValueByName(op.input(1)));
    }

    // Caffe2 stores the transposed W matrix. In here we first coerce W to a
    // 2D matrix size if necessary and then transpose it back.
    auto wDims = flattenCdr(W.dims(), axis_w);
    if (W.dims().size() > 2) {
      W = G_->createReshape(W.getNode()->getName(), W,
                            {wDims.first, wDims.second});
    }

    if (typeName == "FC" || typeName == "Int8FC" || typeName == "FbFCPacked") {
      W = G_->createTranspose(W.getNode()->getName(), W, {1, 0});
    }

    NodeValue B;
    if (hasConstantByName(op.input(2))) {
      ASSIGN_VALUE_OR_RETURN_ERR(B, getConstantByName(op.input(2)));
    } else {
      ASSIGN_VALUE_OR_RETURN_ERR(B, getNodeValueByName(op.input(2)));
    }

    Node *node = nullptr;
    if (typeName == "Int8FC") {
      // Create a node with quantized type.
      auto outputDims = flattenCdr(in.dims(), axis);
      TypeRef outTy;
      ASSIGN_VALUE_OR_RETURN_ERR(
          outTy, loadQuantTy(opName, ElemKind::Int8QTy,
                             {outputDims.first, B.dims()[0]}, dict));
      int dequantizeOutput = 0;
      if (dict.count("dequantize_output")) {
        ASSIGN_VALUE_OR_RETURN_ERR(dequantizeOutput,
                                   loadInt(dict["dequantize_output"]));
      }
      if (dequantizeOutput == 1) {
        node = G_->createDynamicQuantizedFullyConnected(opName, in, W, B);
      } else {
        node = G_->createFullyConnected(opName, in, W, B, outTy, axis);
      }
    } else if (typeName == "FbFCPacked") {
      RETURN_ERR_IF_NOT(W.getElementType() == ElemKind::Float16Ty,
                        opErrMsg(op, "Expected float16 weights."));
      auto fp16InputType =
          mod_.uniqueType(ElemKind::Float16Ty, in.getType()->dims());
      in = G_->createConvertTo(opName + ".ConvertInput", in, fp16InputType);

      auto fp16BiasType = mod_.uniqueType(ElemKind::Float16Ty, B.dims());
      auto *fp16Bias =
          G_->createConvertTo(opName + ".ConvertBias", B, fp16BiasType);

      auto outputDims = flattenCdr(in.dims(), axis);
      TypeRef OT =
          mod_.uniqueType(ElemKind::Float16Ty, {outputDims.first, B.dims()[0]});
      auto fc = G_->createFullyConnected(opName, in, W, fp16Bias, OT, axis);
      auto outputType =
          mod_.uniqueType(ElemKind::FloatTy, fc->getResult().dims());
      node = G_->createConvertTo(opName + ".ConvertOutput", fc, outputType);
    } else {
      auto outputDims = flattenCdr(in.dims(), axis);
      TypeRef outputType =
          mod_.uniqueType(ElemKind::FloatTy, {outputDims.first, B.dims()[0]});
      node = G_->createFullyConnected(opName, in, W, B, outputType, axis);
    }

    // If axis is not 1, reshape the output so that it keeps the first `axis`
    // dims of the original input, followed by the final dim taken from the
    // flattened weights dims.
    if (axis != 1) {
      llvm::SmallVector<dim_t, max_tensor_dimensions> reshapeDims;
      size_t totalReshapeSize = 1;
      for (size_t i = 0; i < axis; ++i) {
        auto d = originalInputDims[i];
        reshapeDims.push_back(d);
        totalReshapeSize *= static_cast<dim_t>(d);
      }

      size_t finalDim = typeName == "FCTransposed" ? wDims.second : wDims.first;

      reshapeDims.push_back(finalDim);
      totalReshapeSize *= finalDim;

      size_t totalOriginalOutputSize = node->getNthResult(0).getType()->size();
      RETURN_ERR_IF_NOT(
          totalReshapeSize == totalOriginalOutputSize,
          opErrMsg(op, strFormat("Cannot reshape from size %lu to size %lu",
                                 totalOriginalOutputSize, totalReshapeSize)));

      node = G_->createReshape(opName + ".fc.out", node, reshapeDims);
    }

    // Save the outputs:
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  if (typeName == "ChannelShuffle") {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

    size_t group;
    ASSIGN_VALUE_OR_RETURN_ERR(group, loadInt(dict["group"]));
    size_t kernel;
    ASSIGN_VALUE_OR_RETURN_ERR(kernel, loadInt(dict["kernel"]));

    Node *node = G_->createChannelShuffle(opName, in, group, kernel);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  if (typeName == "Squeeze") {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    std::vector<dim_t> dims;
    ASSIGN_VALUE_OR_RETURN_ERR(dims, getShape<dim_t>(dict["dims"]));
    Node *node = G_->createSqueeze(opName, in, dims);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  if (typeName == "Log") {
    // Load the inputs:
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    // Create the log:
    auto *R = G_->createLog(opName, in);
    RETURN_IF_ERR(addNodeAsOutput(op, R));
    return Error::success();
  }

  if (typeName == "Swish") {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    auto *S = G_->createSwish(opName, in);
    RETURN_IF_ERR(addNodeAsOutput(op, S));
    return Error::success();
  }

  if (typeName == "Logit") {
    // Load the input and (optional) epsilon clamping value:
    NodeValue input;
    ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));
    auto epsIt = dict.find("eps");
    // default: 1e-6 (as in Caffe2)
    float eps = 1E-6f;
    if (epsIt != dict.end()) {
      ASSIGN_VALUE_OR_RETURN_ERR(eps, loadFloat(epsIt->second));
    }

    auto *node = G_->createLogit(opName, input, eps);
    // Save the outputs:
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  if (typeName == "EQ") {
    NodeValue in0;
    ASSIGN_VALUE_OR_RETURN_ERR(in0, getNodeValueByName(op.input(0)));
    NodeValue in1;
    ASSIGN_VALUE_OR_RETURN_ERR(in1, getNodeValueByName(op.input(1)));
    auto *node = G_->createCmpEQ(opName, in0, in1);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  if (typeName == "Tile") {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    unsigned_t tiles;
    ASSIGN_VALUE_OR_RETURN_ERR(tiles, loadInt(dict["tiles"]));
    unsigned_t axis;
    ASSIGN_VALUE_OR_RETURN_ERR(axis, loadInt(dict["axis"]));

    auto *node = G_->createTile(opName, in, tiles, axis);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  if (typeName == "Free") {
    // Glow frees memory automatically.
    return Error::success();
  }
  if (typeName == "StopGradient" || typeName == "ScaleGradient") {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    // Currently Caffe2 importer only supports inference.
    RETURN_IF_ERR(addNodeAsOutput(op, in));
    return Error::success();
  }

  if (typeName == "Transpose") {
    RETURN_IF_ERR(loadTranspose(op, dict, "axes"));
    return Error::success();
  }

  if (typeName == "NCHW2NHWC") {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    auto *node = G_->createTranspose(opName, in, NCHW2NHWC);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  if (typeName == "CopyCPUToMKL" || typeName == "CopyMKLToCPU" ||
      typeName == "Copy" || typeName == "EnsureCPUOutput" ||
      typeName == "EnsureDense" || typeName == "Dropout") {
    // Glow does not support any of these ops now, so implement them as
    // no-ops. Note: Implement this as a no-op reshape because these ops may
    // have partition information, and we need a node to maintain the parent
    // Function partition they specified. This reshape will get eliminated
    // later on during graph optimizations.
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    ReshapeNode *RN = G_->createReshape(in.getNode()->getName(), in, in.dims());
    RETURN_IF_ERR(addNodeAsOutput(op, RN));
    return Error::success();
  }

  if (typeName == "Slice") {
    NodeValue data;
    ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));

    std::vector<ssize_t> starts;
    ASSIGN_VALUE_OR_RETURN_ERR(starts, getShape<ssize_t>(dict["starts"]));
    std::vector<ssize_t> ends;
    ASSIGN_VALUE_OR_RETURN_ERR(ends, getShape<ssize_t>(dict["ends"]));

    std::vector<dim_t> newStarts, newEnds;
    RETURN_ERR_IF_NOT(
        starts.size() == ends.size(),
        opErrMsg(op, strFormat(
                         "Slice starts %lu and %lu ends must be the same size.",
                         starts.size(), ends.size())));
    for (size_t i = 0; i < starts.size(); i++) {
      ssize_t newStart = starts[i];
      if (newStart == -1) {
        newStart = data.dims()[i];
      }
      RETURN_ERR_IF_NOT(
          newStart >= 0,
          opErrMsg(op,
                   strFormat("Indices should never be negative, but found %lu ",
                             newStart)));
      newStarts.push_back(newStart);

      ssize_t newEnd = ends[i];
      if (newEnd == -1) {
        newEnd = data.dims()[i];
      }
      RETURN_ERR_IF_NOT(
          newEnd >= 0,
          opErrMsg(op,
                   strFormat("Indices should never be negative, but found %lu ",
                             newEnd)));
      newEnds.push_back(newEnd);
    }

    Node *SN = G_->createSlice(opName, data, newStarts, newEnds);
    RETURN_IF_ERR(addNodeAsOutput(op, SN));
    return Error::success();
  }

  if (typeName == "Clip") {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    float cmin = std::numeric_limits<float>::lowest();
    if (dict.count("min")) {
      ASSIGN_VALUE_OR_RETURN_ERR(cmin, loadFloat(dict.find("min")->second));
    }

    float cmax = std::numeric_limits<float>::max();
    if (dict.count("max")) {
      ASSIGN_VALUE_OR_RETURN_ERR(cmax, loadFloat(dict.find("max")->second));
    }

    auto *node = G_->createClip(loadOperatorName(op), in, cmin, cmax);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  if (typeName == "MatMul") {
    RETURN_IF_ERR(loadMatMul(op, dict));
    return Error::success();
  }

  if (typeName == "Cast") {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    int to;
    ASSIGN_VALUE_OR_RETURN_ERR(to, loadInt(dict["to"]));

    switch (to) {
    case caffe2::TensorProto_DataType_FLOAT: {
      RETURN_ERR_IF_NOT(in.getElementType() == ElemKind::FloatTy,
                        opErrMsg(op, "Can only cast float to float."));
      break;
    }
    case caffe2::TensorProto_DataType_INT32: {
      RETURN_ERR_IF_NOT(in.getElementType() == ElemKind::Int32ITy,
                        opErrMsg(op, "Can only cast int32 to int32."));
      break;
    }
    case caffe2::TensorProto_DataType_INT64: {
      RETURN_ERR_IF_NOT(in.getElementType() == ElemKind::Int64ITy,
                        opErrMsg(op, "Can only cast int64 to int64."));
      break;
    }
    default:
      return MAKE_ERR(opErrMsg(op, "Unsupported Cast type."));
    }

    RETURN_IF_ERR(addNodeAsOutput(op, in));
    return Error::success();
  }

  if (typeName == "HalfToFloat") {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    auto convertedType =
        mod_.uniqueType(ElemKind::FloatTy, in.getType()->dims());
    auto *R = G_->createConvertTo(opName + ".ConvertInput", in, convertedType);
    RETURN_IF_ERR(addNodeAsOutput(op, R));
    return Error::success();
  }

  if (typeName == "ScatterAssign") {
    NodeValue data;
    ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));
    NodeValue indices;
    ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(1)));
    NodeValue slices;
    ASSIGN_VALUE_OR_RETURN_ERR(slices, getNodeValueByName(op.input(2)));

    assert(indices.dims().size() == 1 && "Indices should be 1-dimensional!");
    NodeValue indices2D = G_->createReshape(opName + ".indices.2d", indices,
                                            {indices.dims()[0], 1});
    Node *SAN = G_->createScatterData(opName, data, indices2D, slices);
    RETURN_IF_ERR(addNodeAsOutput(op, SAN));
    return Error::success();
  }

  if (typeName == "ConstantFill" || typeName == "GivenTensorIntFill" ||
      typeName == "GivenTensorInt64Fill" || typeName == "GaussianFill" ||
      typeName == "UniformFill") {
    RETURN_IF_ERR(loadWeight(op));
    return Error::success();
  }

  if (typeName == "SigmoidCrossEntropyWithLogits") {
    NodeValue logits;
    ASSIGN_VALUE_OR_RETURN_ERR(logits, getNodeValueByName(op.input(0)));
    NodeValue targets;
    ASSIGN_VALUE_OR_RETURN_ERR(targets, getNodeValueByName(op.input(1)));
    Node *SCEL =
        G_->createSigmoidCrossEntropyWithLogits(opName, logits, targets);
    RETURN_IF_ERR(addNodeAsOutput(op, SCEL));
    return Error::success();
  }

  if (typeName == "ElementwiseLinear") {
    NodeValue X, w, b;

    // If the axis argument does not exist in the protobuf, the default
    // value should be 1.
    unsigned axis = 1;

    ASSIGN_VALUE_OR_RETURN_ERR(X, getNodeValueByName(op.input(0)));
    ASSIGN_VALUE_OR_RETURN_ERR(w, getNodeValueByName(op.input(1)));
    ASSIGN_VALUE_OR_RETURN_ERR(b, getNodeValueByName(op.input(2)));

    if (dict.count("axis")) {
      ASSIGN_VALUE_OR_RETURN_ERR(axis, loadInt(dict["axis"]));
    }

    Node *EL = G_->createElementwiseLinear(opName, X, w, b, axis);
    RETURN_IF_ERR(addNodeAsOutput(op, EL));
    return Error::success();
  }

  if (typeName == "AveragedLoss") {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    auto *node = G_->createBatchedReduceMean(opName, in, 0);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  if (typeName == "Mod") {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    int64_t divisor;
    ASSIGN_VALUE_OR_RETURN_ERR(divisor, loadInt(dict["divisor"]));

    RETURN_ERR_IF_NOT(
        divisor >= 1,
        opErrMsg(op,
                 strFormat("Divisor must not be less than 1, but found %ld ",
                           divisor)));

    bool signFollowDivisor = false;
    if (dict.count("sign_follow_divisor")) {
      ASSIGN_VALUE_OR_RETURN_ERR(signFollowDivisor,
                                 loadInt(dict["sign_follow_divisor"]));
    }

    auto *node = G_->createModulo(opName, in, divisor, signFollowDivisor);
    RETURN_IF_ERR(addNodeAsOutput(op, node));

    return Error::success();
  }

  if (typeName == "Scale") {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    float scale = 1.0;
    if (dict.count("scale")) {
      ASSIGN_VALUE_OR_RETURN_ERR(scale, loadFloat(dict["scale"]));
    }
    auto scaleType = mod_.uniqueType(ElemKind::FloatTy, {in.dims()});
    auto scales = G_->createSplat(opName + ".scales", scaleType, scale);
    Node *node = G_->createMul(opName, in, scales);

    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  // Rowwise-quantized SparseLengths[Weighted]Sum family. The operator name
  // selects three independent properties:
  //   - weighted: a weights input is read from index 1, shifting the others;
  //   - fused:    no separate scales/biases input is read and the data storage
  //               is retyped through setFusedTy();
  //   - 4-bit:    the fused type is UInt4FusedFP16QTy and the result is
  //               converted to FloatTy.
  // Returns an error if the inputs are missing, are not found, or have
  // unexpected shapes/kinds.
  if (typeName == "SparseLengthsWeightedSum8BitsRowwise" ||
      typeName == "SparseLengthsSum8BitsRowwise" ||
      typeName == "SparseLengthsWeightedSumFused8BitRowwise" ||
      typeName == "SparseLengthsSumFused8BitRowwise" ||
      typeName == "SparseLengthsWeightedSumFused4BitRowwise" ||
      typeName == "SparseLengthsSumFused4BitRowwise") {
    const bool isWeighted =
        typeName == "SparseLengthsWeightedSum8BitsRowwise" ||
        typeName == "SparseLengthsWeightedSumFused8BitRowwise" ||
        typeName == "SparseLengthsWeightedSumFused4BitRowwise";
    const bool isFused =
        typeName == "SparseLengthsWeightedSumFused8BitRowwise" ||
        typeName == "SparseLengthsSumFused8BitRowwise" ||
        typeName == "SparseLengthsWeightedSumFused4BitRowwise" ||
        typeName == "SparseLengthsSumFused4BitRowwise";
    const bool is4Bit =
        typeName == "SparseLengthsWeightedSumFused4BitRowwise" ||
        typeName == "SparseLengthsSumFused4BitRowwise";
    // If weighted, then the weights are the second input and so we need to
    // shift indices/lengths/scalesBiases.
    size_t indicesIdx = 1;
    size_t lengthsIdx = 2;
    size_t scalesBiasesIdx = 3;
    if (isWeighted) {
      indicesIdx++;
      lengthsIdx++;
      scalesBiasesIdx++;
    }

    // Make sure every input index read below actually exists. The
    // scales/biases input is only read on the non-fused path.
    const int numRequiredInputs = int(scalesBiasesIdx) + (isFused ? 0 : 1);
    RETURN_ERR_IF_NOT(
        op.input_size() >= numRequiredInputs,
        opErrMsg(op, strFormat("Expected at least %d inputs, but found %d ",
                               numRequiredInputs, op.input_size())));

    NodeValue data;
    ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));
    NodeValue weights;
    if (isWeighted) {
      ASSIGN_VALUE_OR_RETURN_ERR(weights, getNodeValueByName(op.input(1)));
    }
    NodeValue indices;
    ASSIGN_VALUE_OR_RETURN_ERR(indices,
                               getNodeValueByName(op.input(indicesIdx)));
    NodeValue lengths;
    ASSIGN_VALUE_OR_RETURN_ERR(lengths,
                               getNodeValueByName(op.input(lengthsIdx)));

    // dyn_cast yields nullptr when data is not a Storage; every use of dataS
    // below needs a valid pointer, so fail with an error instead.
    Storage *dataS = llvm::dyn_cast<Storage>(data);
    RETURN_ERR_IF_NOT(dataS, opErrMsg(op, "data must be a Storage node."));

    // The first dimension is read as the row count, so it has to exist.
    RETURN_ERR_IF_NOT(!data.dims().empty(),
                      opErrMsg(op, "data must have at least one dimension."));
    const dim_t numRows = data.dims()[0];

    // Make sure all the shapes make sense.
    RETURN_ERR_IF_NOT(lengths.dims().size() == 1,
                      opErrMsg(op, "lengths must be a vector."));
    RETURN_ERR_IF_NOT(indices.dims().size() == 1,
                      opErrMsg(op, "indices must be a vector."));

    LengthsMode lengthsMode;
    ASSIGN_VALUE_OR_RETURN_ERR(lengthsMode, getLengthsMode(dict));

    float avgLength;
    ASSIGN_VALUE_OR_RETURN_ERR(avgLength, getAvgLength(dict));

    Node *node;
    if (isFused) {
      RETURN_IF_ERR(setFusedTy(dataS, is4Bit ? ElemKind::UInt4FusedFP16QTy
                                             : ElemKind::UInt8FusedQTy));

      // No other work to do, since the data is already loaded fused, so just
      // create the new node with its inputs.
      if (isWeighted) {
        node = G_->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
            opName, dataS, weights, indices, lengths,
            /* useFP16Accumulation */ false, lengthsMode, avgLength);
      } else {
        node = G_->createFusedRowwiseQuantizedSparseLengthsSum(
            opName, dataS, indices, lengths, /* useFP16Accumulation */ false,
            lengthsMode, avgLength);
      }

      // The 4-bit variants get an explicit conversion of the result to float.
      if (is4Bit) {
        node = G_->createConvertTo(opName, node, ElemKind::FloatTy);
      }
    } else {
      NodeValue scalesBiases;
      ASSIGN_VALUE_OR_RETURN_ERR(scalesBiases,
                                 getNodeValueByName(op.input(scalesBiasesIdx)));

      // scale_bias must be a constant [numRows, 2] matrix: column 0 holds the
      // scales and column 1 the biases.
      Constant *scalesBiasesC = llvm::dyn_cast<Constant>(scalesBiases);
      RETURN_ERR_IF_NOT(scalesBiasesC,
                        opErrMsg(op, "scales_biases must be Constant."));
      RETURN_ERR_IF_NOT(scalesBiases.dims().size() == 2,
                        opErrMsg(op, "scale_bias has to be a matrix."));
      RETURN_ERR_IF_NOT(
          scalesBiases.dims()[0] == numRows,
          opErrMsg(
              op,
              strFormat("scale_bias must have the same number of rows as data, "
                        "but found scale_bias %d and rows %d ",
                        int(scalesBiases.dims()[0]), int(numRows))));
      RETURN_ERR_IF_NOT(
          scalesBiases.dims()[1] == 2,
          opErrMsg(op,
                   strFormat("Second dim of scale_bias has to be equal to 2 "
                             "but found %d ",
                             int(scalesBiases.dims()[1]))));

      // Now strip out the scales and biases into their own tensors.
      NodeValue sliceScales =
          G_->createSlice(scalesBiasesC->getName().str() + "_scale",
                          scalesBiasesC, {0, 0}, {numRows, 1});
      NodeValue sliceBiases =
          G_->createSlice(scalesBiasesC->getName().str() + "_bias",
                          scalesBiasesC, {0, 1}, {numRows, 2});
      sliceScales =
          G_->createReshape(sliceScales.getNode()->getName().str() + "_1D",
                            sliceScales, {numRows});
      sliceBiases =
          G_->createReshape(sliceBiases.getNode()->getName().str() + "_1D",
                            sliceBiases, {numRows});

      // Now create the actual node.
      if (isWeighted) {
        node = G_->createRowwiseQuantizedSparseLengthsWeightedSum(
            opName, dataS, sliceScales, sliceBiases, weights, indices, lengths,
            /* precision */ ElemKind::FloatTy,
            /* useFP16Accumulation */ false, lengthsMode, avgLength);
      } else {
        node = G_->createRowwiseQuantizedSparseLengthsSum(
            opName, dataS, sliceScales, sliceBiases, indices, lengths,
            /* precision */ ElemKind::FloatTy,
            /* useFP16Accumulation */ false, lengthsMode, avgLength);
      }
    }

    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  // LengthsRangeFill: expects a 1D lengths input and a mandatory
  // "maxOutputSize" argument, both forwarded to createLengthsRangeFill.
  if (typeName == "LengthsRangeFill") {
    NodeValue lengthsNV;
    ASSIGN_VALUE_OR_RETURN_ERR(lengthsNV, getNodeValueByName(op.input(0)));
    RETURN_ERR_IF_NOT(lengthsNV.dims().size() == 1,
                      opErrMsg(op, "lengths must be a 1D vector."));

    // The argument has no default here; its absence is a load error.
    auto maxSizeArg = dict.find("maxOutputSize");
    RETURN_ERR_IF_NOT(
        maxSizeArg != dict.end(),
        opErrMsg(op, "Require maxOutputSize when loading LengthsRangeFill."));
    unsigned_t maxSize;
    ASSIGN_VALUE_OR_RETURN_ERR(maxSize, loadInt(maxSizeArg->second));

    auto *rangeFill = G_->createLengthsRangeFill(opName, lengthsNV, maxSize);
    RETURN_IF_ERR(addNodeAsOutput(op, rangeFill));

    return Error::success();
  }

  // ReduceBackSum: sums the input over its last dimension. Requires an input
  // of rank >= 2 and, like ReduceBackMean, only supports reducing a single
  // dimension.
  // TODO: add checks for number of inputs
  if (typeName == "ReduceBackSum") {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    RETURN_ERR_IF_NOT(in.dims().size() >= 2,
                      opErrMsg(op, "Input should be at least 2D."));

    // Only the last dimension is reduced below, so any other requested
    // number of reduced dimensions would silently load with the wrong
    // semantics. Reject it, matching the ReduceBackMean handling.
    int numReduceDim = 1;
    if (dict.count("num_reduce_dim")) {
      ASSIGN_VALUE_OR_RETURN_ERR(numReduceDim, loadInt(dict["num_reduce_dim"]));
    }
    RETURN_ERR_IF_NOT(numReduceDim == 1,
                      opErrMsg(op, "Supporting reducing only one dimension."));

    Node *node = G_->createBatchedReduceAdd(opName, in, in.dims().size() - 1);
    RETURN_IF_ERR(addNodeAsOutput(op, node));
    return Error::success();
  }

  // RMSNorm: takes a 2D input X plus 1D gamma and beta, and an optional "eps"
  // argument (0 when absent). The two results of createRMSNorm are registered
  // directly under the operator's first two output names.
  if (typeName == "RMSNorm") {
    NodeValue X;
    ASSIGN_VALUE_OR_RETURN_ERR(X, getNodeValueByName(op.input(0)));
    RETURN_ERR_IF_NOT(X.dims().size() == 2,
                      opErrMsg(op, "X should be a 2D tensor."));

    NodeValue gamma;
    ASSIGN_VALUE_OR_RETURN_ERR(gamma, getNodeValueByName(op.input(1)));
    RETURN_ERR_IF_NOT(gamma.dims().size() == 1,
                      opErrMsg(op, "gamma should be a 1D tensor."));

    NodeValue beta;
    ASSIGN_VALUE_OR_RETURN_ERR(beta, getNodeValueByName(op.input(2)));
    RETURN_ERR_IF_NOT(beta.dims().size() == 1,
                      opErrMsg(op, "beta should be a 1D tensor."));

    float eps = .0f;
    auto epsArg = dict.find("eps");
    if (epsArg != dict.end()) {
      ASSIGN_VALUE_OR_RETURN_ERR(eps, loadFloat(epsArg->second));
    }

    auto results = G_->createRMSNorm(opName, X, gamma, beta, eps);
    nodeValueByName_[op.output(0)] = results[0];
    nodeValueByName_[op.output(1)] = results[1];
    return Error::success();
  }

  // Mean: element-wise mean of N identically shaped inputs. A single input is
  // passed through unchanged; otherwise the inputs are concatenated along
  // dimension 0, reshaped to [N, shape...] and mean-reduced over dimension 0.
  if (typeName == "Mean") {
    const unsigned numInputs = op.input_size();
    RETURN_ERR_IF_NOT(numInputs > 0,
                      opErrMsg(op, "Expect at least one input."));

    // Resolve every operand first so lookup errors are reported before any
    // shape mismatch.
    std::vector<NodeValue> operands;
    operands.reserve(numInputs);
    for (unsigned idx = 0; idx < numInputs; ++idx) {
      NodeValue operand;
      ASSIGN_VALUE_OR_RETURN_ERR(operand, getNodeValueByName(op.input(idx)));
      operands.push_back(std::move(operand));
    }

    // The first operand's shape is the reference for all the others.
    const auto refDims = operands[0].dims();
    for (unsigned idx = 1; idx < numInputs; ++idx) {
      RETURN_ERR_IF_NOT(
          refDims == operands[idx].dims(),
          opErrMsg(op,
                   "All inputs should have the same shape, violating input " +
                       op.input(idx)));
    }

    if (numInputs == 1) {
      RETURN_IF_ERR(addNodeAsOutput(op, operands[0]));
      return Error::success();
    }

    Node *concat = G_->createConcat(opName + ".concat", operands, 0);

    // Split the concatenated leading dimension back out as its own axis.
    std::vector<dim_t> stackedDims;
    stackedDims.reserve(refDims.size() + 1);
    stackedDims.push_back(numInputs);
    stackedDims.insert(stackedDims.end(), refDims.begin(), refDims.end());
    Node *stacked =
        G_->createReshape(opName + ".reshape", concat, stackedDims);

    Node *mean = G_->createBatchedReduceMean(opName + ".reduceMean", stacked, 0);

    RETURN_IF_ERR(addNodeAsOutput(op, mean));
    return Error::success();
  }

  // Caffe2 "Negative" is delegated to the loadNeg helper; any error it
  // reports is propagated to the caller.
  if (typeName == "Negative") {
    RETURN_IF_ERR(loadNeg(op, dict));
    return Error::success();
  }

  // LpNorm: applies Abs (p == 1) or Pow(x, 2) (p == 2) element-wise, flattens
  // the result to 1D and sums it into a single-element tensor. The "average"
  // mode is rejected.
  if (typeName == "LpNorm") {
    NodeValue input;
    ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));

    // p defaults to 2; an explicit value must be 1 or 2.
    int p = 2;
    auto pArg = dict.find("p");
    if (pArg != dict.end()) {
      ASSIGN_VALUE_OR_RETURN_ERR(p, loadInt(pArg->second));
      RETURN_ERR_IF_NOT(p == 1 || p == 2,
                        opErrMsg(op, "p should be either 1 or 2."));
    }

    bool average = false;
    auto averageArg = dict.find("average");
    if (averageArg != dict.end()) {
      ASSIGN_VALUE_OR_RETURN_ERR(average, loadInt(averageArg->second));
    }
    RETURN_ERR_IF_NOT(!average, opErrMsg(op, "average is not supported."));

    // Element-wise term of the norm.
    Node *result = nullptr;
    if (p != 1) {
      result = G_->createPow(opName, input, 2);
    } else {
      result = G_->createAbs(opName, input);
    }

    // Flatten everything into one dimension before reducing.
    const auto flatDims = flattenCdr(input.dims(), input.dims().size());
    result = G_->createReshape(opName + ".reshape1D", result, flatDims.first);

    auto sumTy = mod_.uniqueType(input.getElementType(), {1});
    result = G_->createBatchedReduceAdd(opName + ".sum", sumTy, result, 0);

    RETURN_IF_ERR(addNodeAsOutput(op, result));
    return Error::success();
  }

  // ArgMin: optional "axis" (0 when absent) and "keepdims" (true when absent)
  // arguments are forwarded to createArgMin.
  if (typeName == "ArgMin") {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));

    int axis = 0;
    auto axisArg = dict.find("axis");
    if (axisArg != dict.end()) {
      ASSIGN_VALUE_OR_RETURN_ERR(axis, loadInt(axisArg->second));
    }

    bool keepDims = true;
    auto keepDimsArg = dict.find("keepdims");
    if (keepDimsArg != dict.end()) {
      ASSIGN_VALUE_OR_RETURN_ERR(keepDims, loadInt(keepDimsArg->second));
    }

    auto argMin = G_->createArgMin(opName, in, axis, keepDims);
    RETURN_IF_ERR(addNodeAsOutput(op, argMin));
    return Error::success();
  }

  // Sign: built from two comparisons against zero and two selects, yielding
  // 1 where the input is positive, -1 where it is negative and 0 elsewhere.
  if (typeName == "Sign") {
    NodeValue input;
    ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));

    Node *zeroSplat = G_->createSplat(opName + ".zeroes", input.getType(), 0.f);

    // Masks: 0 < x and x < 0.
    Node *positiveMask = G_->createCmpLT(opName + ".isPos", zeroSplat, input);
    Node *negativeMask = G_->createCmpLT(opName + ".isNeg", input, zeroSplat);

    Node *plusOne = G_->createSplat(opName + ".posOnes", input.getType(), 1);
    Node *minusOne = G_->createSplat(opName + ".negOnes", input.getType(), -1);

    // Start from zero, fill in +1 for positives, then -1 for negatives.
    Node *withPositives = G_->createSelect(opName + ".fillPos", positiveMask,
                                           plusOne, zeroSplat);
    Node *sign = G_->createSelect(opName + ".fillNeg", negativeMask, minusOne,
                                  withPositives);

    RETURN_IF_ERR(addNodeAsOutput(op, sign));
    return Error::success();
  }

  // Softplus maps onto a single SoftPlus node applied to the first input.
  if (typeName == "Softplus") {
    NodeValue input;
    ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));

    Node *softPlus = G_->createSoftPlus(opName, input);
    RETURN_IF_ERR(addNodeAsOutput(op, softPlus));
    return Error::success();
  }

  // TopK: k is taken either from an optional second input, which must be a
  // constant of type Int64, or from the "k" argument. Only the last dimension
  // is accepted as the axis. Indices are requested as Int32ITy.
  if (typeName == "TopK") {
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    RETURN_ERR_IF_NOT(
        op.input_size() <= 2,
        opErrMsg(
            op,
            strFormat(
                "TopK: Maximum number of inputs is 2, but found input size %d ",
                op.input_size())));
    unsigned_t k = 0;
    if (op.input_size() > 1) {
      Constant *kConst = getConstantByNameOrNull(op.input(1));
      RETURN_ERR_IF_NOT(
          kConst,
          opErrMsg(op, "TopK: Non-constant k is not supported by Glow."));
      RETURN_ERR_IF_NOT(
          kConst->getElementType() == ElemKind::Int64ITy,
          opErrMsg(op, strFormat(
                           "TopK: k input must be of type Int64, but found "
                           "input type '%s' ",
                           kConst->getType()->getElementName().str().c_str())));
      auto constH = kConst->getPayload().getHandle<int64_t>();
      k = constH.at({0});
    } else {
      // Indexing the dictionary with operator[] would default-insert an empty
      // entry when the argument is absent and hand it to loadInt, so require
      // the argument explicitly before reading it.
      RETURN_ERR_IF_NOT(
          dict.count("k"),
          opErrMsg(op, "TopK: k must be provided as an input or an argument."));
      ASSIGN_VALUE_OR_RETURN_ERR(k, loadInt(dict.at("k")));
    }

    int lastDim = in.dims().size() - 1;
    int axis = lastDim;
    if (dict.count("axis")) {
      ASSIGN_VALUE_OR_RETURN_ERR(axis,
                                 loadAxis<int>(dict["axis"], in.dims().size()));
    }

    RETURN_ERR_IF_NOT(
        axis == lastDim,
        opErrMsg(
            op,
            strFormat(
                "TopK: Currently only support axis %d being last dimension %d ",
                axis, lastDim)));

    TopKNode *R = G_->createTopK(opName, in, k, ElemKind::Int32ITy);
    RETURN_IF_ERR(addNodeAsOutput(op, R));
    return Error::success();
  }

  // FillExamplesWithIndicator: scatters the rows of data into a zero-filled
  // tensor with one row per indicator entry, at the positions where the
  // indicator is non-zero.
  if (typeName == "FillExamplesWithIndicator") {
    NodeValue data;
    ASSIGN_VALUE_OR_RETURN_ERR(data, getNodeValueByName(op.input(0)));
    NodeValue indicator;
    ASSIGN_VALUE_OR_RETURN_ERR(indicator, getNodeValueByName(op.input(1)));

    // Input validation: integer, one-dimensional indicator.
    const auto indicatorKind = indicator.getElementType();
    RETURN_ERR_IF_NOT(
        indicatorKind == ElemKind::Int32ITy ||
            indicatorKind == ElemKind::Int64ITy,
        opErrMsg(op, "Indicator should be of int32 or int64 type."));
    RETURN_ERR_IF_NOT(indicator.dims().size() == 1,
                      opErrMsg(op, "Indicator should be 1D tensor."));

    // Work on a 2D view of data: [rows, everything else flattened].
    const dim_t rowSize = flattenCdr(data.dims()).second;
    const dim_t numOutRows = indicator.dims()[0];
    const dim_t numDataRows = data.dims()[0];

    // Final output keeps data's trailing dims, with the indicator length as
    // the leading dim.
    ShapeVector resultDims{numOutRows};
    resultDims.insert(resultDims.end(), data.dims().begin() + 1,
                      data.dims().end());
    auto scatterTy =
        mod_.uniqueTypeWithNewShape(data.getType(), {numOutRows, rowSize});

    auto dataRows = G_->createReshape(opName + ".data2D", data,
                                      {numDataRows, rowSize});

    // Select only takes boolean indicators, and converting from int to bool
    // must go from int -> float -> bool. Due to fp16 clipping, since only
    // int32 -> fp16 conversions are available, there is an initial conversion
    // from int64 to int32 if necessary.
    if (indicatorKind == ElemKind::Int64ITy) {
      indicator = G_->createConvertTo(opName + ".int64ToInt32", indicator,
                                      ElemKind::Int32ITy);
    }
    auto indicatorAsFloat = G_->createConvertTo(opName + ".intToFloat",
                                                indicator, ElemKind::FloatTy);
    auto indicatorAsBool = G_->createConvertTo(
        opName + ".floatToBool", indicatorAsFloat, ElemKind::BoolTy);
    auto nonZero = G_->createNonZero(opName + ".nonzero", indicatorAsBool);

    auto nonZeroFixed = fixNonZero(G_, mod_, opName, nonZero);
    RETURN_ERR_IF_NOT(numDataRows <= nonZeroFixed->getNthResult(0).dims()[0],
                      opErrMsg(op,
                               "The number of "
                               "non-zero elements in the indicator must be at "
                               "least that of the first dimension of data"));

    // One target row index per row of data.
    auto rowIndices = G_->createSlice(opName + ".indices", nonZeroFixed, {0, 0},
                                      {numDataRows, 1});

    auto zeroFill = G_->createSplat(opName + ".zeros", scatterTy, 0);

    auto scattered = G_->createScatterData(opName + ".scatterData", zeroFill,
                                           rowIndices, dataRows, true);
    auto result = G_->createReshape(opName + ".result", scattered, resultDims);
    RETURN_IF_ERR(addNodeAsOutput(op, result));
    return Error::success();
  }

  // BatchSparseToDense, restricted to dense_last_dim == 1: the values are
  // scattered into a [lengths.size, 1] tensor pre-filled with default_value,
  // at the positions where lengths is non-zero.
  if (typeName == "BatchSparseToDense") {
    NodeValue lengths;
    ASSIGN_VALUE_OR_RETURN_ERR(lengths, getNodeValueByName(op.input(0)));
    NodeValue indices;
    ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(1)));
    NodeValue values;
    ASSIGN_VALUE_OR_RETURN_ERR(values, getNodeValueByName(op.input(2)));

    dim_t denseLastDim = 1;
    auto denseLastDimArg = dict.find("dense_last_dim");
    if (denseLastDimArg != dict.end()) {
      ASSIGN_VALUE_OR_RETURN_ERR(denseLastDim,
                                 loadInt(denseLastDimArg->second));
    }
    RETURN_ERR_IF_NOT(
        denseLastDim == 1,
        opErrMsg(op, "Only output second dimension = 1 supported"));

    // Input validation: element kinds and ranks.
    const auto lengthsKind = lengths.getElementType();
    RETURN_ERR_IF_NOT(
        lengthsKind == ElemKind::Int32ITy || lengthsKind == ElemKind::Int64ITy,
        opErrMsg(op, "Lengths should be of int32 or int64 type."));
    RETURN_ERR_IF_NOT(lengths.dims().size() == 1,
                      opErrMsg(op, "Lengths should be 1D tensor."));
    const auto indicesKind = indices.getElementType();
    RETURN_ERR_IF_NOT(
        indicesKind == ElemKind::Int32ITy || indicesKind == ElemKind::Int64ITy,
        opErrMsg(op, "Indices should be of int32 or int64 type."));
    RETURN_ERR_IF_NOT(indices.dims().size() == 1,
                      opErrMsg(op, "Indices should be 1D tensor."));
    RETURN_ERR_IF_NOT(values.getElementType() == ElemKind::FloatTy,
                      opErrMsg(op, "Values should be of float type."));
    RETURN_ERR_IF_NOT(
        indices.dims()[0] == values.dims()[0],
        opErrMsg(op, "There should be the same number of values as indices."));

    float fillValue = 0.0;
    auto defaultValueArg = dict.find("default_value");
    if (defaultValueArg != dict.end()) {
      ASSIGN_VALUE_OR_RETURN_ERR(fillValue, loadFloat(defaultValueArg->second));
    }

    // Select only takes boolean indicators, and converting from int to bool
    // must go from int -> float -> bool. Due to fp16 clipping, since only
    // int32 -> fp16 conversions are available, there is an initial conversion
    // from int64 to int32 if necessary.
    if (lengthsKind == ElemKind::Int64ITy) {
      lengths = G_->createConvertTo(opName + ".int64ToInt32", lengths,
                                    ElemKind::Int32ITy);
    }
    auto lengthsAsFloat =
        G_->createConvertTo(opName + ".intToFloat", lengths, ElemKind::FloatTy);
    auto lengthsAsBool = G_->createConvertTo(opName + ".floatToBool",
                                             lengthsAsFloat, ElemKind::BoolTy);
    auto nonZero = G_->createNonZero(opName + ".nonzero", lengthsAsBool);
    auto nonZeroFixed = fixNonZero(G_, mod_, opName, nonZero);

    // One destination row per provided value.
    const dim_t numValues = indices.dims()[0];
    auto targetRows = G_->createSlice(opName + ".indicesSlice", nonZeroFixed,
                                      {0, 0}, {numValues, 1});

    ShapeVector denseDims{lengths.dims()[0], 1};
    auto denseTy = mod_.uniqueTypeWithNewShape(values.getType(), denseDims);
    auto dense = G_->createSplat(opName + ".data", denseTy, fillValue);
    auto valueRows =
        G_->createReshape(opName + ".reshape", values, {numValues, 1});
    auto scattered = G_->createScatterData(opName + ".scatterData", dense,
                                           targetRows, valueRows, true);

    RETURN_IF_ERR(addNodeAsOutput(op, scattered));
    return Error::success();
  }

  // SparseLabelSplit: creates a SparseLabelSplit node from lengths, indices
  // and values, then splits its label-value and example-id results into
  // num_labels slices each. Outputs are registered in the order
  //   [0, numLabels)            label values,
  //   [numLabels, 2*numLabels)  example ids,
  //   2*numLabels               gradient offset map (only when
  //                             keep_gradient_offset_map is set).
  if (typeName == "SparseLabelSplit") {
    NodeValue lengths;
    ASSIGN_VALUE_OR_RETURN_ERR(lengths, getNodeValueByName(op.input(0)));
    NodeValue indices;
    ASSIGN_VALUE_OR_RETURN_ERR(indices, getNodeValueByName(op.input(1)));
    NodeValue values;
    ASSIGN_VALUE_OR_RETURN_ERR(values, getNodeValueByName(op.input(2)));

    RETURN_ERR_IF_NOT(dict.count("num_labels"),
                      opErrMsg(op, "num_labels was not provided."));
    // Load as a signed value first: num_labels is used as a divisor and as a
    // loop bound below, so zero or negative values must be rejected before
    // converting to the unsigned dim_t.
    int numLabelsArg = 0;
    ASSIGN_VALUE_OR_RETURN_ERR(numLabelsArg, loadInt(dict.at("num_labels")));
    RETURN_ERR_IF_NOT(numLabelsArg > 0,
                      opErrMsg(op, "num_labels should be positive."));
    const dim_t numLabels = numLabelsArg;

    bool keepGradientOffsetMap = false;
    if (dict.count("keep_gradient_offset_map")) {
      ASSIGN_VALUE_OR_RETURN_ERR(keepGradientOffsetMap,
                                 loadInt(dict.at("keep_gradient_offset_map")));
    }

    // Every output name indexed at the end of this block must exist.
    const dim_t numExpectedOutputs =
        2 * numLabels + (keepGradientOffsetMap ? 1 : 0);
    RETURN_ERR_IF_NOT(
        static_cast<dim_t>(op.output_size()) >= numExpectedOutputs,
        opErrMsg(op, strFormat("Expected at least %d outputs, but found %d ",
                               int(numExpectedOutputs), op.output_size())));

    // Validating input types and shapes
    RETURN_ERR_IF_NOT(lengths.getElementType() == ElemKind::Int32ITy,
                      opErrMsg(op, "Lengths should be of int32 type."));
    RETURN_ERR_IF_NOT(lengths.dims().size() == 1 || lengths.dims().size() == 2,
                      opErrMsg(op, "Lengths should be 1D or 2D tensor."));
    RETURN_ERR_IF_NOT(indices.getElementType() == ElemKind::Int64ITy,
                      opErrMsg(op, "Indices should be of int64 type."));
    RETURN_ERR_IF_NOT(indices.dims().size() == 1 || indices.dims().size() == 2,
                      opErrMsg(op, "Indices should be 1D or 2D tensor."));
    RETURN_ERR_IF_NOT(values.getElementType() == ElemKind::FloatTy,
                      opErrMsg(op, "Values should be of float type."));
    RETURN_ERR_IF_NOT(values.dims().size() == 1 || values.dims().size() == 2,
                      opErrMsg(op, "Values should be 1D or 2D tensor."));
    RETURN_ERR_IF_NOT(
        indices.dims() == values.dims(),
        opErrMsg(op, "Indices and values should have the same shape."));

    // Optional conversion from 2D to 1D inputs
    if (lengths.dims().size() == 2) {
      RETURN_ERR_IF_NOT(
          lengths.dims()[1] == 1,
          opErrMsg(op, "Second dimension should be 1 in lengths."));
      lengths = G_->createReshape(opName + ".lengths1D", lengths,
                                  {lengths.dims()[0]});
    }
    if (indices.dims().size() == 2) {
      RETURN_ERR_IF_NOT(
          indices.dims()[1] == 1,
          opErrMsg(op, "Second dimension should be 1 in indices."));
      indices = G_->createReshape(opName + ".indices1D", indices,
                                  {indices.dims()[0]});
    }
    if (values.dims().size() == 2) {
      RETURN_ERR_IF_NOT(
          values.dims()[1] == 1,
          opErrMsg(op, "Second dimension should be 1 in values."));
      values =
          G_->createReshape(opName + ".values1D", values, {values.dims()[0]});
    }

    SparseLabelSplitNode *node =
        G_->createSparseLabelSplit(opName, lengths, indices, values, numLabels);

    // Split both per-label results into numLabels slices along dimension 0.
    std::vector<SliceNode *> labelValueSlices;
    G_->createSplit(opName + ".splitLabelValues",
                    node->getNthResult(SparseLabelSplitNode::LabelValuesIdx),
                    numLabels, 0, {}, labelValueSlices);

    std::vector<SliceNode *> exampleIdSlices;
    G_->createSplit(opName + ".splitExampleIds",
                    node->getNthResult(SparseLabelSplitNode::ExampleIdsIdx),
                    numLabels, 0, {}, exampleIdSlices);

    // Each slice is reshaped to a 1D tensor of this many items.
    const auto numItems = indices.dims()[0] / numLabels;

    std::vector<Node *> labelValues;
    labelValues.reserve(labelValueSlices.size());
    for (auto slice : labelValueSlices) {
      labelValues.push_back(
          G_->createReshape(opName + ".reshapeLabelValue", slice, {numItems}));
    }

    std::vector<Node *> exampleIds;
    exampleIds.reserve(exampleIdSlices.size());
    for (auto slice : exampleIdSlices) {
      exampleIds.push_back(
          G_->createReshape(opName + ".reshapeExamplId", slice, {numItems}));
    }

    for (dim_t i = 0; i < numLabels; ++i) {
      nodeValueByName_[op.output(i)] = labelValues[i];
    }
    for (dim_t i = 0; i < numLabels; ++i) {
      nodeValueByName_[op.output(numLabels + i)] = exampleIds[i];
    }
    if (keepGradientOffsetMap) {
      nodeValueByName_[op.output(2 * numLabels)] =
          node->getNthResult(SparseLabelSplitNode::GradientOffsetMapIdx);
    }
    return Error::success();
  }

  // Log1p is expressed as Log(x + 1) using a splat of ones of the input type.
  if (typeName == "Log1p") {
    NodeValue input;
    ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));

    Node *oneSplat = G_->createSplat(opName + ".ones", input.getType(), 1.0f);
    Node *shifted = G_->createAdd(opName + ".add", input, oneSplat);
    Node *logNode = G_->createLog(opName + ".log", shifted);

    RETURN_IF_ERR(addNodeAsOutput(op, logNode));
    return Error::success();
  }

  // ReduceBackMean: mean over the last dimension of a single input of rank
  // >= 2. Only num_reduce_dim == 1 (also the default) is accepted.
  if (typeName == "ReduceBackMean") {
    const unsigned inputCount = op.input_size();
    RETURN_ERR_IF_NOT(inputCount == 1,
                      opErrMsg(op, "Only single input is supported."));

    NodeValue input;
    ASSIGN_VALUE_OR_RETURN_ERR(input, getNodeValueByName(op.input(0)));
    RETURN_ERR_IF_NOT(input.dims().size() >= 2,
                      opErrMsg(op, "Input should be at least 2D."));

    int reducedDims = 1;
    auto reducedDimsArg = dict.find("num_reduce_dim");
    if (reducedDimsArg != dict.end()) {
      ASSIGN_VALUE_OR_RETURN_ERR(reducedDims, loadInt(reducedDimsArg->second));
    }
    // TODO: check maybe we can support more dimensions to be reduced
    RETURN_ERR_IF_NOT(reducedDims == 1,
                      opErrMsg(op, "Supporting reducing only one dimension."));

    Node *mean =
        G_->createBatchedReduceMean(opName, input, input.dims().size() - 1);
    RETURN_IF_ERR(addNodeAsOutput(op, mean));
    return Error::success();
  }

  return MAKE_ERR(unexpectedNodeErrorMessage(op, "Unsupported operator."));
}