Error ONNXModelLoader::loadOperator()

in lib/Importer/ONNXModelLoader.cpp [5700:6000]


/// Loads a single ONNX node \p op into the Glow graph.
///
/// Dispatch order:
///  1. When useGlowCustomOps_ is set, the Glow custom-op loader is tried first.
///     - If it produces a node, per-node options (if any) are applied.
///     - Const-fold-subgraph and static-placeholder dummy nodes are skipped;
///       they were already handled while loading initializers and inputs.
///     - Every other op except "Identity" is rejected.
///  2. Operators supported by the parent class, CommonOperatorLoader.
///  3. The operators handled directly by this loader, matched by op_type.
///
/// \returns Error::success() once a loader handles the node. Returns the
/// error of the chosen loader if it fails. Returns a
/// MODEL_LOADER_UNSUPPORTED_OPERATOR error if no loader matches op_type.
Error ONNXModelLoader::loadOperator(const ONNX_NAMESPACE::NodeProto &op) {
  ArgumentDictionaryTy dict = loadArgumentMap(op);
  const std::string &typeName = op.op_type();
  mod_.registerOriginalName(op.name());

  if (useGlowCustomOps_) {
    Node *loadedNode = nullptr;
    ASSIGN_VALUE_OR_RETURN_ERR(loadedNode,
                               tryLoadGlowCustomOp(typeName, op, dict));
    if (loadedNode) {
      if (!perNodeOpts_) {
        return Error::success();
      }
      return loadPerNodeOptions(loadedNode, *perNodeOpts_, dict);
    }

    // These are handled earlier when loading initializers and inputs and so can
    // be safely ignored here.
    if (typeName == constFoldSubgraphNodeName ||
        typeName == staticPHDummyNodeName) {
      return Error::success();
    }

    // Identity is the only official ONNX op used with useGlowCustomOps. Let it
    // fall through to logic to handle below, otherwise return error.
    if (typeName != "Identity") {
      return MAKE_ERR("Failed to load operator " + typeName + " .",
                      ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_OPERATOR);
    }
  }

  // Check if operator is supported in parent class, CommonOperatorLoader.
  bool tryLoadCommonOperatorResult;
  ASSIGN_VALUE_OR_RETURN_ERR(tryLoadCommonOperatorResult,
                             tryLoadCommonOperator(typeName, op, dict));
  if (tryLoadCommonOperatorResult) {
    return Error::success();
  }

  // Shared by Conv and pooling. Decides whether the node goes to its
  // tensorwise-quantized loader: the first input must be quantized AND the
  // node must carry both "out_scale" and "out_offset".
  // The input count is validated first. The model file is untrusted, and
  // protobuf's indexed accessor op.input(0) is not bounds-checked in release
  // builds.
  auto useTensorwiseQuantizedLoader = [&]() -> Expected<bool> {
    RETURN_ERR_IF_NOT(op.input_size() > 0,
                      "Operator " + typeName + " requires at least one input.");
    NodeValue in;
    ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(0)));
    return in.getType()->isQuantizedType() && dict.count("out_scale") &&
           dict.count("out_offset");
  };

  if (typeName == "Loop") {
    return loadLoop(op, dict);
  }

  if (typeName == "Constant") {
    return loadConstant(op, dict);
  }
  if (typeName == "Range") {
    return loadRange(op, dict);
  }
  if (typeName == "PRelu") {
    return loadPRelu(op, dict);
  }
  if (typeName == "Slice") {
    return loadSlice(op, dict);
  }
  if (typeName == "Sin" || typeName == "Cos") {
    return loadTrigonometricOps(typeName, op, dict);
  }
  if (typeName == "Erf") {
    return loadErf(op, dict);
  }
  if (typeName == "Conv") {
    bool tensorwiseQuantized;
    ASSIGN_VALUE_OR_RETURN_ERR(tensorwiseQuantized,
                               useTensorwiseQuantizedLoader());
    return tensorwiseQuantized ? loadTensorwiseQuantizedConvolution(op, dict)
                               : loadConv(op, dict);
  }
  if (typeName == "ChannelwiseQuantizedConvolution") {
    return loadChannelwiseQuantizedConvolution(op, dict);
  }
  if (typeName == "MaxPool" || typeName == "AveragePool") {
    bool tensorwiseQuantized;
    ASSIGN_VALUE_OR_RETURN_ERR(tensorwiseQuantized,
                               useTensorwiseQuantizedLoader());
    return tensorwiseQuantized
               ? loadTensorwiseQuantizedPool(op, dict, typeName)
               : loadPool(op, dict, typeName);
  }
  if (typeName == "GlobalAveragePool") {
    return loadGlobalAveragePool(op, dict);
  }
  if (typeName == "Squeeze") {
    return loadSqueeze(op, dict);
  }
  if (typeName == "Unsqueeze") {
    return loadUnsqueeze(op, dict);
  }
  if (typeName == "BatchNormalization") {
    return loadBatchNormalization(op, dict);
  }
  if (typeName == "InstanceNormalization") {
    return loadInstanceNormalization(op, dict);
  }
  if (typeName == "Concat") {
    return loadConcat(op, dict);
  }
  if (typeName == "FCTransposed") {
    return loadFCTransposed(op, dict);
  }
  if (typeName == "Gemm") {
    return loadGemm(op, dict);
  }
  if (typeName == "Transpose") {
    return loadTranspose(op, dict, "perm");
  }
  if (typeName == "ReduceSumSquare") {
    return loadReduceOp(typeName, op, dict);
  }
  if (typeName == "MatMul") {
    return loadMatMul(op, dict);
  }
  if (typeName == "Pad") {
    return loadPad(op, dict);
  }
  if (typeName == "Cast") {
    return loadCast(op, dict);
  }
  if (typeName == "HardSigmoid") {
    return loadHardSigmoid(op, dict);
  }
  if (typeName == "LeakyRelu") {
    return loadLeakyRelu(op, dict);
  }
  if (typeName == "SpaceToDepth") {
    return loadSpaceToDepth(op, dict);
  }
  if (typeName == "DepthToSpace") {
    return loadDepthToSpace(op, dict);
  }
  if (typeName == "ReduceL2") {
    return loadReduceL2(op, dict);
  }
  if (typeName == "ConstantOfShape") {
    return loadConstantOfShape(op, dict, false /* isSplat */);
  }
  if (typeName == "Tile") {
    return loadTile(op, dict);
  }
  if (typeName == "Expand") {
    return loadExpand(op, dict);
  }
  if (typeName == "Where") {
    return loadWhere(op, dict);
  }
  if (typeName == "RNN") {
    return loadRNN(op, dict);
  }
  if (typeName == "GRU") {
    return loadGRU(op, dict);
  }
  if (typeName == "LSTM") {
    return loadLSTM(op, dict);
  }
  if (typeName == "Clip") {
    return loadClip(op, dict);
  }
  if (typeName == "Equal") {
    return loadCmpEQ(op, dict);
  }
  if (typeName == "CmpLTE") {
    return loadCmpLTE(op, dict);
  }
  if (typeName == "Mean") {
    return loadMean(op, dict);
  }
  if (typeName == "Select") {
    return loadSelect(op, dict);
  }
  if (typeName == "Quantize") {
    return loadQuantize(op, dict);
  }
  if (typeName == "QuantizeLinear") {
    return loadQuantizeLinear(op, dict);
  }
  if (typeName == "ConvertTo") {
    return loadConvertTo(op, dict);
  }
  if ((typeName == "Dequantize") || (typeName == "DequantizeLinear")) {
    return loadDequantize(op, dict);
  }
  if (typeName == "Regression") {
    return loadRegression(op, dict);
  }
  if (typeName == "BatchedAdd") {
    return loadBatchedAdd(op, dict);
  }
  if (typeName == "CumSum") {
    return loadCumSum(op, dict);
  }
  if ((typeName == "ScatterAssign") || (typeName == "ScatterND")) {
    return loadScatterAssign(op, dict);
  }
  if (typeName == "IntLookupTable") {
    return loadIntLookupTable(op, dict);
  }
  if (typeName == "LengthsRangeFill") {
    return loadLengthsRangeFill(op, dict);
  }
  if (typeName == "RescaleQuantized") {
    return loadRescaleQuantized(op, dict);
  }
  if (typeName == "RowwiseQuantizedSparseLengthsWeightedSum") {
    return loadRowwiseQuantizedSparseLengthsWeightedSum(op, dict);
  }
  if (typeName == "FusedRowwiseQuantizedSparseLengthsWeightedSum") {
    return loadFusedRowwiseQuantizedSparseLengthsWeightedSum(op, dict);
  }
  if (typeName == "FusedRowwiseQuantizedSparseLengthsSum") {
    return loadFusedRowwiseQuantizedSparseLengthsSum(op, dict);
  }
  if (typeName == "FullyConnected") {
    return loadFullyConnected(op, dict);
  }
  if (typeName == "RowwiseQuantizedFullyConnected") {
    return loadRowwiseQuantizedFullyConnected(op, dict);
  }
  if (typeName == "Splat") {
    return loadSplat(op, dict);
  }
  if (typeName == "InsertTensor") {
    return loadInsertTensor(op, dict);
  }
  if (typeName == "ArgMin") {
    return loadArgMinMax(op, dict, true);
  }
  if (typeName == "ArgMax") {
    return loadArgMinMax(op, dict, false);
  }
  if (typeName == "NonMaxSuppressionV4") {
    return loadNonMaxSuppression(op, dict, true);
  }
  if (typeName == "NonMaxSuppression") {
    return loadNonMaxSuppression(op, dict, false);
  }
  if (typeName == "ConvTranspose") {
    return loadConvTranspose(op, dict);
  }
  if (typeName == "If") {
    return loadIf(op, dict);
  }
  if (typeName == "AdaptiveAvgPool") {
    return loadAdaptiveAvgPool(op, dict);
  }
  if (typeName == "Flip") {
    return loadFlip(op, dict);
  }
  if (typeName == "AudioSpectrogram") {
    return loadAudioSpectrogram(op, dict);
  }
  if (typeName == "RoiAlign") {
    return loadROIAlign(op, dict);
  }
  if (typeName == "MFCC") {
    return loadMFCC(op, dict);
  }
  if (typeName == "Identity") {
    return loadIdentity(op, dict);
  }
  if (typeName == "Upsample") {
    return loadUpsample(op, dict);
  }
  if (typeName == "Resize") {
    return loadResize(op, dict);
  }
  if (typeName == "NonZero") {
    return loadNonZero(op, dict);
  }
  if (typeName == "Acos") {
    return loadAcos(op, dict);
  }
  if (typeName == "Asin") {
    return loadAsin(op, dict);
  }
  if (typeName == "Atan") {
    return loadAtan(op, dict);
  }
  if (typeName == "Sign") {
    return loadSign(op, dict);
  }
  if (typeName == "Softmax") {
    return loadSoftmax(op, dict);
  }
  if (typeName == "LogSoftmax") {
    return loadLogSoftmax(op, dict);
  }
  if (typeName == "ScatterData") {
    return loadScatterData(op, dict);
  }
  if (typeName == "TopK") {
    return loadTopK(op, dict);
  }

  return MAKE_ERR("Failed to load operator " + typeName + " .",
                  ErrorValue::ErrorCode::MODEL_LOADER_UNSUPPORTED_OPERATOR);
}