static bool IsUnsupportedOpMode()

in onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc [406:668]


// Reports whether `node` uses an operator configuration that is rejected by
// one of the op-specific checks below.
//
// Parameters:
//   graph_viewer, logger - forwarded unchanged to can_eval_node_argument().
//   node                 - node under inspection; dereferenced without a null
//                          check, so callers must pass a valid pointer.
//
// Returns true as soon as any check for the node's op type fails, and false
// when every check passes or the op type has no dedicated check here.
static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, const Node* node, const logging::Logger& logger) {
  // Output argument required by can_eval_node_argument(); never read here.
  std::vector<NodeIndex> input_nodes;

  // Thin wrapper over can_eval_node_argument() for the given input indices.
  // NOTE(review): judging by the comments at the call sites, it presumably
  // reports whether those inputs can be evaluated as constants - confirm.
  const auto can_eval_inputs = [&](const std::vector<std::size_t>& indices) -> bool {
    return can_eval_node_argument(graph_viewer, node, indices, logger, input_nodes);
  };

  const auto& optype = node->OpType();
  if (optype == "ArgMax" || optype == "ArgMin") {
    // we do not support select_last_index = 1 for now
    const auto& attributes = node->GetAttributes();
    auto sli_attr = attributes.find("select_last_index");
    if (sli_attr != attributes.end() && (*sli_attr).second.i() != 0) {
      return true;
    }
  } else if (optype == "ConstantOfShape" || optype == "NonZero") {
    if (!can_eval_inputs({0})) {
      return true;
    }
  } else if (optype == "ConvInteger") {
    const auto& args = node->InputDefs();

    // migraphx can handle only two inputs; checked first so that indexing
    // input 0 below is known to be in range
    if (args.size() != 2) {
      return true;
    }

    // only 4-D input is accepted; Shape() can be null (the Split branch below
    // guards the same call), so an unknown shape is treated as unsupported
    // instead of being dereferenced
    const auto input_shape = args[0]->Shape();
    if (input_shape == nullptr || input_shape->dim_size() != 4) {
      return true;
    }

    // only support int8 type
    const auto& input_type = args[0]->TypeAsProto();
    if (input_type == nullptr) {
      return true;
    }

    if (input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) {
      return true;
    }
  } else if (optype == "Expand" || optype == "OneHot" || optype == "Tile" || optype == "TopK") {
    // MIGraphX only supports constant values for the second input
    // (e.g. the shape input of Expand)
    if (!can_eval_inputs({1})) {
      return true;
    }
  } else if (optype == "MaxPool") {
    // MaxPool "indices" output is not currently supported.
    if (node->OutputDefs().size() > 1) {
      return true;
    }

    // any dilation other than 1 is rejected
    // (ceil_mode is not inspected by this function)
    const auto& attributes = node->GetAttributes();
    auto dila_attr = attributes.find("dilations");
    if (dila_attr != attributes.end()) {
      auto dilas = to_vector((*dila_attr).second.ints());
      if (!std::all_of(dilas.begin(), dilas.end(), [](auto i) { return i == 1; })) {
        return true;
      }
    }

    // storage order 1 (column major format) is not supported
    auto storage_order_attr = attributes.find("storage_order");
    if (storage_order_attr != attributes.end() && (*storage_order_attr).second.i() != 0) {
      return true;
    }

    // do not support int8 and uint8 type
    const auto& input_type = node->InputDefs()[0]->TypeAsProto();
    if (input_type == nullptr) {
      return true;
    }
    auto data_type = input_type->tensor_type().elem_type();
    if (data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8 ||
        data_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8) {
      return true;
    }
  } else if (optype == "MatMulInteger") {
    // migraphx can handle only two inputs
    if (node->InputDefs().size() != 2) {
      return true;
    }

    // only support int8 type
    const auto& input_type = node->InputDefs()[0]->TypeAsProto();
    if (input_type == nullptr) {
      return true;
    }

    if (input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) {
      return true;
    }
  } else if (optype == "Pad") {
    const auto& args = node->InputDefs();
    // if pad size is not constant, migraphx cannot support
    if (args.size() >= 2 && !can_eval_inputs({1})) {
      return true;
    }

    // only the "constant" and "reflect" modes are accepted; a missing "mode"
    // attribute is treated as "constant"
    const auto& attributes = node->GetAttributes();
    std::string mode = "constant";
    auto mode_attr = attributes.find("mode");
    if (mode_attr != attributes.end()) {
      mode = (*mode_attr).second.s();
    }
    if (mode != "constant" && mode != "reflect") {
      return true;
    }

    // the third input is only checked in constant mode, and only when the
    // node has exactly three inputs
    if (mode == "constant" && args.size() == 3 && !can_eval_inputs({2})) {
      return true;
    }
  } else if (optype == "Range") {
    // every input (indices 0..N-1) is checked
    std::vector<std::size_t> indices(node->InputDefs().size());
    std::iota(indices.begin(), indices.end(), 0);
    if (!can_eval_inputs(indices)) {
      return true;
    }
  } else if (optype == "Reshape" || optype == "ReduceSum" || optype == "Unsqueeze" || optype == "Squeeze") {
    // with exactly two inputs, the verdict depends solely on input 1
    if (node->InputDefs().size() == 2) {
      return !can_eval_inputs({1});
    }
  } else if (optype == "Resize") {
    const auto& attributes = node->GetAttributes();
    auto ct_attr = attributes.find("coordinate_transformation_mode");
    if (ct_attr != attributes.end()) {
      const auto& ct = (*ct_attr).second.s();
      if (ct == "tf_crop_and_resize") {
        return true;
      }
    }

    auto mode_attr = attributes.find("mode");
    if (mode_attr != attributes.end()) {
      const auto& mode = (*mode_attr).second.s();
      if (mode == "cubic") {
        return true;
      }
    }

    // every input after the first (indices 1..N-1) decides the verdict
    const auto& args = node->InputDefs();
    if (args.size() > 1) {
      std::vector<std::size_t> indices(args.size() - 1);
      std::iota(indices.begin(), indices.end(), 1);
      return !can_eval_inputs(indices);
    }
  } else if (optype == "Slice") {
    // every input after the first (indices 1..N-1) is checked; building the
    // list directly avoids erasing from an empty vector when there are no inputs
    const auto arg_num = node->InputDefs().size();
    std::vector<std::size_t> indices(arg_num > 0 ? arg_num - 1 : 0);
    std::iota(indices.begin(), indices.end(), 1);
    if (!can_eval_inputs(indices)) {
      return true;
    }

    // MIGraphX does not properly handle the situation where any
    // value of the "starts" attribute is higher than a corresponding
    // value in the "ends"
    const auto& attributes = node->GetAttributes();
    if (attributes.count("starts") > 0 && attributes.count("ends") > 0) {
      auto starts = to_vector((*attributes.find("starts")).second.ints());
      auto ends = to_vector((*attributes.find("ends")).second.ints());
      // attribute lists of different lengths cannot be compared pairwise;
      // report the node as unsupported rather than letting at() throw
      if (starts.size() != ends.size()) {
        return true;
      }
      for (std::size_t i = 0; i < starts.size(); ++i) {
        if (starts[i] > ends[i]) {
          return true;
        }
      }
    }
  } else if (optype == "Split") {
    // cannot process input dim of 0 size: a rank-1 input whose only dimension
    // is 0 is rejected. A dimension without a concrete value is mapped to 0
    // as well, so it is rejected in the same way.
    const auto arg_s = node->InputDefs()[0]->Shape();
    if (arg_s != nullptr) {
      const auto& tensor_dims = arg_s->dim();
      std::vector<std::size_t> dims;
      std::transform(tensor_dims.begin(),
                     tensor_dims.end(),
                     std::back_inserter(dims),
                     [&](auto&& d) -> std::size_t {
                       if (d.has_dim_value()) {
                         return d.dim_value();
                       } else {
                         return 0;
                       }
                     });
      if (dims == std::vector<std::size_t>{0}) {
        return true;
      }
    }

    // with exactly two inputs, the verdict depends solely on input 1
    if (node->InputDefs().size() == 2) {
      return !can_eval_inputs({1});
    }
  }

  // Op doesn't fall into any of the known unsupported modes.
  return false;
}