std::pair<bool, std::string> OnnxConverter::IsNodeSupported()

in onnxruntime/core/providers/rknpu/onnx_converter.cc [322:543]
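
Per-node capability check for the RKNPU execution provider: returns a
{supported, reason} pair, where the reason names the first constraint
the node failed.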


std::pair<bool, std::string> OnnxConverter::IsNodeSupported(
    const ONNX_NAMESPACE::ModelProto& model_proto,
    const ONNX_NAMESPACE::NodeProto& node) const {
  NodeAttrHelper helper(node);
  const auto& op = node.op_type();
  const std::vector<std::string> supported_types{
      "Conv", "Relu", "Clip", "LeakyRelu",
      "MaxPool", "AveragePool", "GlobalAveragePool",
      "Concat", "Softmax", "BatchNormalization", "Gemm",
      "Add", "Mul", "Sub",
      "Reshape", "Squeeze", "Unsqueeze",
      "Flatten", "Transpose", /*"Gather", "Slice",*/
      "QLinearConv", /*"QuantizeLinear",*/ "DequantizeLinear"};
  if (std::find(supported_types.begin(), supported_types.end(), op) ==
      supported_types.end()) {
    return {false, "Unsupported operator"};
  }

  if (!TypeSupport(model_proto, node.input(0))) {
    return {false, "Type of input(" + node.input(0) + ") is unsupported"};
  }

  if (op == "Conv") {
    const auto strides = helper.get("strides", vector<int>{1, 1});
    const auto pads = helper.get("pads", vector<int>{0, 0, 0, 0});
    const auto dilations = helper.get("dilations", vector<int>{1, 1});
    const auto group = helper.get("group", 1);
    if (dilations != vector<int>{1, 1} && strides != vector<int>{1, 1}) {
      return {false, "Both dilations and strides > 1 is not supported for now"};
    }
    const auto weight = m(node.input(1));
    if (HAS(tensor_dims_, weight)) {
      const auto& dims = tensor_dims_.at(weight);
      // Check the rank before indexing dims[1].
      if (dims.size() != 4) {
        return {false, "Only conv 2d is supported."};
      }
      if (group != 1 && dims[1] != 1) {
        return {false, "group != 1 is only supported for depthwise convolution"};
      }
    } else {
      return {false, "The weight of convolution must be known"};
    }
  } else if (op == "AveragePool" || op == "MaxPool") {
    const auto count_include_pad = helper.get("count_include_pad", 0);
    if (count_include_pad == 1) {
      return {false, "count_include_pad == 1 is not supported"};
    }
    const auto storage_order = helper.get("storage_order", 0);
    if (storage_order == 1) {
      return {false, "storage_order == 1 is not supported"};
    }
    if (helper.get("auto_pad", "NOTSET") != "NOTSET") {
      return {false, "auto_pad is not supported"};
    }
    if (helper.get("kernel_shape", std::vector<int>{1, 1}).size() != 2) {
      return {false, "Only pooling 2d is supported"};
    }
    if (helper.get("ceil_mode", 0) == 1) {
      return {false, "ceil_mode == 1 is not supported for pooling"};
    }
    if (helper.get("dilations", std::vector<int>{1, 1}) !=
        std::vector<int>{1, 1}) {
      return {false, "Dilations of pooling is not supported"};
    }
    if (node.output_size() != 1) {
      return {false, "Argmax in maxpooling is not supported"};
    }
  } else if (op == "GlobalAveragePool" || op == "GlobalMaxPool") {
    const auto& input_shape = GetShape(model_proto,
                                       tensor_dims_, node.input(0));
    if (input_shape.size() != 4) {
      return {false, "Only rank-4 tensors are supported"};
    }
  } else if (op == "PRelu") {
    const auto slope = m(node.input(1));
    if (HAS(tensor_dims_, slope)) {
      if (tensor_dims_.at(slope) != Shaper::Shape{1}) {
        // TODO: support it
        return {false, "PRelu only support one element slope."};
      }
    } else {
      return {false, "PRelu slope must be known"};
    }
  } else if (op == "Gemm") {
    const auto transA = helper.get("transA", 0);
    const auto transB = helper.get("transB", 0);
    const auto alpha = helper.get("alpha", 1.0f);
    const auto beta = helper.get("beta", 1.0f);
    if (!(transA == 0 && transB == 1 && alpha == 1.f && beta == 1.f)) {
      return {false,
              "Only transA == 0, transB == 1, alpha == 1.0 and beta == "
              "1.0 is supported."};
    }
  } else if (op == "BatchNormalization") {
    if (node.output_size() != 1) {
      return {false,
              "Your ONNX model may be in training mode; please export "
              "it in inference mode."};
    }
    const auto scale = m(node.input(1));
    const auto b = m(node.input(2));
    const auto mean = m(node.input(3));
    const auto var = m(node.input(4));
    if (!HAS(tensor_dims_, scale)) {
      return {false, "Scale of BN must be known"};
    }
    if (!HAS(tensor_dims_, b)) {
      return {false, "B of BN must be known"};
    }
    if (!HAS(tensor_dims_, mean)) {
      return {false, "Mean of BN must be known"};
    }
    if (!HAS(tensor_dims_, var)) {
      return {false, "Var of BN must be known"};
    }
  } else if (op == "LRN") {
    const auto size = helper.get("size", 1);
    if (size % 2 == 0) {
      return {false, "NNAPI only support odd size for LRN"};
    }
  } else if (op == "Reshape") {
    const auto output = node.output(0);
    for (const auto& another_node : model_proto_.graph().node()) {
      for (const auto& input : another_node.input()) {
        if (input == output &&
            another_node.op_type() != "Gemm") {
          return {false,
                  "Reshape can only be the last layer or precede a "
                  "gemm layer for now"};
        }
      }
    }
    const auto& shape = GetShape(model_proto, tensor_dims_, node.input(1));
    if (shape.size() != 1) {
      return {false, "Wrong shape rank"};
    }
    if (shape[0] <= 1) {  // rknpu doesn't support a shape input with a
                          // single element, but the simulator does, why?
                          // (reproduce on "Reshape_699" of ssd.onnx)
      return {false, "Only shape dims > 1 is supported"};
    }
  } else if (op == "Softmax") {
    const auto axis = helper.get("axis", 1);
    if (axis != 1) {
      return {false, "Only axis == 1 is supported"};
    }
  } else if (op == "Flatten") {
    const auto axis = helper.get("axis", 1);
    const auto& input_shape = GetShape(model_proto,
                                       tensor_dims_, node.input(0));
    int rank = input_shape.size();
    if (rank != 0) {
      if (axis < 0 || axis > rank) {
        return {false, "Only 0 <= axis <= rank of input is supported"};
      }
    }
  } else if (op == "Squeeze") {
    const auto& input_shape = GetShape(model_proto,
                                       tensor_dims_, node.input(0));
    const auto axes = helper.get("axes", std::vector<int>{});
    int rank = input_shape.size();
    if (rank != 0) {
      for (auto axis : axes) {
        if (axis >= rank || axis < -rank) {
          return {false, "Only -rank <= axis < rank of input is supported"};
        } else {
          axis = (axis < 0) ? (axis + rank) : (axis);
          if (input_shape[axis] != 1) {
            return {false, "input_shape[axis] must equal 1"};
          }
        }
      }
    }
  } else if (op == "Unsqueeze") {
    const auto& input_shape = GetShape(model_proto,
                                       tensor_dims_, node.input(0));
    const auto axes = helper.get("axes", std::vector<int>{});
    int rank = input_shape.size();
    if (rank != 0) {
      for (const auto axis : axes) {
        if (axis < 0) {
          return {false, "Only axes >= 0 is supported"};
        } else if (axis >= (int)(rank + axes.size())) {
          return {false, "Only axes <= rank of (input + axes) is supported"};
        }
      }
    }
  } else if (op == "Gather") {
    const auto& input_shape = GetShape(model_proto,
                                       tensor_dims_, node.input(0));
    const auto axis = helper.get("axis", 1);
    int rank = input_shape.size();
    if (rank != 0 && (axis >= rank || axis < -rank)) {
      return {false, "Only axis <= rank of input is supported"};
    }
  } else if (op == "Concat") {
    const auto& input_shape = GetShape(model_proto,
                                       tensor_dims_, node.input(0));
    const auto axis = helper.get("axis", 1);
    int rank = input_shape.size();
    if (axis >= 4) {
      return {false, "Only axis < 4 is supported"};
    }
    if (rank != 0) {
      if (rank > 4) {
        return {false, "Only rank <= 4 of input is supported"};
      }
      if (axis >= rank || axis < -rank) {
        return {false, "Only -rank <= axis < rank of input is supported"};
      }
    }
  } else if (op == "Add") {
    if (!TypeSupport(model_proto, node.input(1))) {
      return {false, "Type of input1(" + node.input(1) + ") is unsupported"};
    }
  } else if (op == "Mul") {
    if (!TypeSupport(model_proto, node.input(1))) {
      return {false, "Type of input1(" + node.input(1) + ") is unsupported"};
    }
  }

  return {true, ""};
}
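
A minimal sketch of how a per-node check like this is typically consumed
when partitioning a graph. The driver loop below is hypothetical: only
IsNodeSupported and the ONNX protobuf types above come from the source;
the converter instance, the container, and the logging are illustrative.

// Hypothetical usage sketch (not from the source). Assumes <iostream>,
// <string>, and <vector> are included and an OnnxConverter instance
// `converter` is available.
std::vector<const ONNX_NAMESPACE::NodeProto*> supported_nodes;
for (const auto& node : model_proto.graph().node()) {
  // Structured bindings unpack the {supported, reason} pair.
  const auto [ok, reason] = converter.IsNodeSupported(model_proto, node);
  if (ok) {
    supported_nodes.push_back(&node);  // candidate for the RKNPU subgraph
  } else {
    // The reason string names the first constraint the node failed.
    std::cerr << node.op_type() << " rejected: " << reason << "\n";
  }
}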