static NodeSupportLevels isNodeSupported()

in lib/Backends/NNPI/NNPI.cpp [121:843]


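The per-kind checks below lean on NodeInfo::allInputsAndOutputsHaveSameElemKind(kinds, ignoredInputIdxs, ignoredOutputIdxs): every input and output operand not listed in one of the optional ignore sets must share a single element kind drawn from `kinds`; the ignored operands are then validated by the explicit comparisons in each case. The standalone sketch below illustrates those assumed semantics with simplified types (it is not Glow's implementation):

// Minimal sketch of the element-kind check the cases below rely on. ElemKind
// here is a reduced stand-in; the real helper lives on glow::NodeInfo.
#include <algorithm>
#include <cstddef>
#include <optional>
#include <set>
#include <vector>

enum class ElemKind { FloatTy, Float16Ty, Int8QTy, Int32ITy, Int64ITy, BoolTy };

static bool allSameAndAllowed(const std::vector<ElemKind> &inputKinds,
                              const std::vector<ElemKind> &outputKinds,
                              const std::vector<ElemKind> &allowed,
                              const std::set<std::size_t> &ignoredInputs = {},
                              const std::set<std::size_t> &ignoredOutputs = {}) {
  std::optional<ElemKind> common;
  // Walk one operand list, skipping ignored indices and tracking the single
  // element kind shared by all checked operands so far.
  auto scan = [&common](const std::vector<ElemKind> &kinds,
                        const std::set<std::size_t> &ignored) {
    for (std::size_t i = 0; i < kinds.size(); ++i) {
      if (ignored.count(i)) {
        continue; // Ignored operands are checked separately by the caller.
      }
      if (!common) {
        common = kinds[i];
      } else if (*common != kinds[i]) {
        return false; // Mixed element kinds among the checked operands.
      }
    }
    return true;
  };
  if (!scan(inputKinds, ignoredInputs) || !scan(outputKinds, ignoredOutputs)) {
    return false;
  }
  return common.has_value() &&
         std::find(allowed.begin(), allowed.end(), *common) != allowed.end();
}

int main() {
  // An FP16 Add: both inputs and the result are Float16Ty, which is allowed.
  return allSameAndAllowed({ElemKind::Float16Ty, ElemKind::Float16Ty},
                           {ElemKind::Float16Ty},
                           {ElemKind::FloatTy, ElemKind::Float16Ty})
             ? 0
             : 1;
}

The function itself:
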
static NodeSupportLevels isNodeSupported(const NodeInfo &NI) {
  bool isNodePrecisionSupported = false;
  bool isNodeHasAnySupport = true;
  switch (NI.getKind()) {
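  // Each case only constrains element kinds (precision). Where a case passes
  // extra index lists to allInputsAndOutputsHaveSameElemKind(), those operands
  // are excluded from the shared-kind check and validated by the explicit
  // comparisons that follow.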
  // General math fp32/fp16/i8/i32/i64.
  case Kinded::Kind::AddNodeKind:
  case Kinded::Kind::SubNodeKind:
  case Kinded::Kind::MulNodeKind:
  case Kinded::Kind::MaxNodeKind:
  case Kinded::Kind::MinNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Int32ITy, ElemKind::Float16Ty,
         ElemKind::Int8QTy, ElemKind::Int64ITy});
    break;
  case Kinded::Kind::DivNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
         ElemKind::Int64ITy, ElemKind::Int32ITy});
    break;

  // General math fp32/fp16/i8/i32.
  case Kinded::Kind::PowNodeKind:
  case Kinded::Kind::ReluNodeKind:
  case Kinded::Kind::ReplaceNaNNodeKind:
  case Kinded::Kind::MatMulNodeKind:
  case Kinded::Kind::BatchedReduceAddNodeKind:
  case Kinded::Kind::BatchedReduceMeanNodeKind:
  case Kinded::Kind::BatchedAddNodeKind:
  case Kinded::Kind::BatchedMulNodeKind:
  case Kinded::Kind::TanhNodeKind:
  case Kinded::Kind::LogNodeKind:
  case Kinded::Kind::SigmoidNodeKind:
  case Kinded::Kind::NegNodeKind:
  case Kinded::Kind::AbsNodeKind:
  case Kinded::Kind::ExpNodeKind:
  case Kinded::Kind::SoftPlusNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
         ElemKind::Int32ITy});
    break;
  case Kinded::Kind::BatchedReduceMinNodeKind:
  case Kinded::Kind::BatchedReduceMaxNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
         ElemKind::Int32ITy});
    break;
  case Kinded::Kind::SplatNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
         ElemKind::Int32ITy, ElemKind::Int64ITy});
    break;
  case Kinded::Kind::LocalResponseNormalizationNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Float16Ty, ElemKind::Int8QTy});
    break;
  case Kinded::Kind::ModuloNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int32ITy});
    break;
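  // The node kinds in the next block are only reported as supported when
  // building against NNPI 1.1 or newer; otherwise they fall through to the
  // default case.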
#if NNPI_MAJOR_VERSION >= 1 && NNPI_MINOR_VERSION >= 1
  case Kinded::Kind::NNPILookupTableNodeKind:
  case Kinded::Kind::IntLookupTableNodeKind:
    isNodePrecisionSupported = true;
    break;
  case Kinded::Kind::BBoxTransformNodeKind:
    // The RoiBatchSplits output should be FP16 in the Glow node and is
    // converted explicitly to FP32 in the NNPI importer.
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Float16Ty});
    break;
  case Kinded::Kind::ROIAlignNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::Float16Ty}, {ROIAlignNode::BatchIndicesIdx}) &&
        (NI.getInElemTy(ROIAlignNode::BatchIndicesIdx) == ElemKind::Int32ITy ||
         NI.getInElemTy(ROIAlignNode::BatchIndicesIdx) == ElemKind::Int64ITy);
    break;
  case Kinded::Kind::LSTMUnitNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Float16Ty});
    break;
  case Kinded::Kind::ResizeNearestNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int32QTy,
         ElemKind::Int8QTy, ElemKind::UInt8QTy});
    break;
  case Kinded::Kind::SparseLabelSplitNodeKind: {
    auto valuesIdxDataType = NI.getInElemTy(SparseLabelSplitNode::ValuesIdx);
    isNodePrecisionSupported =
        (NI.getInElemTy(SparseLabelSplitNode::LengthsIdx) ==
         ElemKind::Int32ITy) &&
        (NI.getInElemTy(SparseLabelSplitNode::IndicesIdx) ==
         ElemKind::Int64ITy) &&
        (NI.getInElemTy(SparseLabelSplitNode::ValuesIdx) ==
         NI.getOutElemTy(SparseLabelSplitNode::LabelValuesIdx)) &&
        (NI.getOutElemTy(SparseLabelSplitNode::ExampleIdsIdx) ==
         ElemKind::Int32ITy) &&
        (NI.getOutElemTy(SparseLabelSplitNode::GradientOffsetMapIdx) ==
         ElemKind::Int32ITy) &&
        (valuesIdxDataType == ElemKind::FloatTy ||
         valuesIdxDataType == ElemKind::Float16Ty ||
         valuesIdxDataType == ElemKind::Int8QTy ||
         valuesIdxDataType == ElemKind::UInt8QTy);
    break;
  }
#endif // NNPI >= 1.1
  case Kinded::Kind::LayerNormalizationNodeKind: {
    auto scaleType = NI.getInElemTy(LayerNormalizationNode::ScaleIdx);
    auto biasType = NI.getInElemTy(LayerNormalizationNode::BiasIdx);
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::Float16Ty, ElemKind::Int8QTy},
            {LayerNormalizationNode::ScaleIdx,
             LayerNormalizationNode::BiasIdx}) &&
        scaleType == biasType &&
        (scaleType == ElemKind::Float16Ty || scaleType == ElemKind::Int8QTy);
    break;
  }
  case Kinded::Kind::SwishNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Float16Ty, ElemKind::Int8QTy, ElemKind::UInt8QTy});
    break;
  case Kinded::Kind::GeluNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Float16Ty});
    break;
  case Kinded::Kind::BatchNormalizationNodeKind: {
    auto elemType = NI.getInElemTy(BatchNormalizationNode::InputIdx);
    isNodePrecisionSupported =
        (elemType == ElemKind::Int8QTy || elemType == ElemKind::FloatTy ||
         elemType == ElemKind::Float16Ty);

    isNodePrecisionSupported = isNodePrecisionSupported &&
                               NI.allInputsAndOutputsHaveSameElemKind(
                                   {ElemKind::FloatTy, ElemKind::Float16Ty},
                                   {BatchNormalizationNode::InputIdx},
                                   {BatchNormalizationNode::ResultIdx});

    isNodePrecisionSupported =
        isNodePrecisionSupported &&
        NI.allInputsAndOutputsHaveSameElemKind(
            {elemType},
            {BatchNormalizationNode::ScaleIdx, BatchNormalizationNode::BiasIdx,
             BatchNormalizationNode::MeanIdx, BatchNormalizationNode::VarIdx});
    break;
  }
  case Kinded::Kind::VectorNormNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Float16Ty, ElemKind::Int8QTy, ElemKind::UInt8QTy});
    break;
  case Kinded::Kind::AvgPoolNodeKind:
  case Kinded::Kind::AdaptiveAvgPoolNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy});
    break;
  case Kinded::Kind::BatchMatMulNodeKind:
  case Kinded::Kind::PReluNodeKind:
  case Kinded::Kind::ClipNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Int8QTy, ElemKind::Float16Ty});
    break;
  case Kinded::Kind::FmodNodeKind:
#if NNPI_MAJOR_VERSION >= 1 && NNPI_MINOR_VERSION >= 7
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Float16Ty});
#else
    // Supporting only these two integer types for now because, for
    // floating-point inputs, NNPI returns a result with the same sign as the
    // divisor instead of the dividend.
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Int64ITy, ElemKind::Int32ITy});
#endif // NNPI >= 1.7
    break;
  // Data transfer fp32/fp16/i8/i32/i64/bool.
  case Kinded::Kind::SaveNodeKind:
  case Kinded::Kind::ConcatNodeKind:
  case Kinded::Kind::TileNodeKind:
  case Kinded::Kind::TransposeNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
         ElemKind::Int32ITy, ElemKind::Int64ITy, ElemKind::BoolTy});
    break;
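  // For the quantized convolution and fully-connected variants below, the bias
  // operand is excluded from the shared Int8 check: it may be either Int32QTy
  // or FloatTy.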
  case Kinded::Kind::ConvolutionNodeKind: {
    if (!NI.getInTy(ConvolutionNode::InputIdx)->isQuantizedType()) {
      isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
          {ElemKind::FloatTy, ElemKind::Float16Ty});
    } else {
      isNodePrecisionSupported =
          NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int8QTy},
                                                 {ConvolutionNode::BiasIdx}) &&
          ((NI.getInElemTy(ConvolutionNode::BiasIdx) == ElemKind::Int32QTy) ||
           (NI.getInElemTy(ConvolutionNode::BiasIdx) == ElemKind::FloatTy));
    }
    break;
  }
  case Kinded::Kind::Convolution3DNodeKind:
    if (!NI.getInTy(Convolution3DNode::InputIdx)->isQuantizedType()) {
      isNodePrecisionSupported =
          NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Float16Ty});
    } else {
      isNodePrecisionSupported =
          NI.allInputsAndOutputsHaveSameElemKind(
              {ElemKind::Int8QTy}, {Convolution3DNode::BiasIdx}) &&
          ((NI.getInElemTy(Convolution3DNode::BiasIdx) == ElemKind::Int32QTy) ||
           (NI.getInElemTy(Convolution3DNode::BiasIdx) ==
            ElemKind::FloatTy));
    }
    break;
  case Kinded::Kind::QuantizeNodeKind:
    isNodePrecisionSupported =
        (NI.getInElemTy(QuantizeNode::InputIdx) == ElemKind::FloatTy ||
         NI.getInElemTy(QuantizeNode::InputIdx) == ElemKind::Float16Ty) &&
        (NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::Int8QTy);
    break;
  case Kinded::Kind::DequantizeNodeKind:
    isNodePrecisionSupported =
        (NI.getInElemTy(DequantizeNode::InputIdx) == ElemKind::Int8QTy) &&
        (NI.getOutElemTy(DequantizeNode::ResultIdx) == ElemKind::FloatTy ||
         NI.getOutElemTy(DequantizeNode::ResultIdx) == ElemKind::Float16Ty);
    break;
  case Kinded::Kind::RescaleQuantizedNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int8QTy});
    break;
  case Kinded::Kind::ConvertToNodeKind: {
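    // The lambda below enumerates the source-to-destination element-kind pairs
    // that ConvertTo can be imported with; a few conversions are only allowed
    // when custom IA kernels are enabled (EnableCustomIAKernels) or when
    // building against a new enough NNPI version.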
    auto isConversionSupportedFor = [](ElemKind kindFrom, ElemKind kindTo) {
      switch (kindFrom) {
      case ElemKind::Float16Ty:
        switch (kindTo) {
        case ElemKind::FloatTy:
        case ElemKind::Int8QTy:
        case ElemKind::UInt8QTy:
        case ElemKind::BoolTy:
          return true;
        case ElemKind::Int32ITy:
#if NNPI_MAJOR_VERSION >= 1 && NNPI_MINOR_VERSION >= 7
          return true;
#else
          return glow::nnpi::flags::EnableCustomIAKernels;
#endif // NNPI >= 1.7
        default:
          return false;
        }
        return false;

      case ElemKind::FloatTy:
        switch (kindTo) {
        case ElemKind::Float16Ty:
        case ElemKind::Int8QTy:
        case ElemKind::UInt8QTy:
        case ElemKind::BoolTy:
          return true;
        case ElemKind::Int32ITy:
          return glow::nnpi::flags::EnableCustomIAKernels;
        default:
          return false;
        }
        return false;

      case ElemKind::Int64ITy:
        switch (kindTo) {
        case ElemKind::Int32ITy:
        case ElemKind::FloatTy:
        case ElemKind::Int8QTy:
          return true;
        default:
          return false;
        }
        return false;

      // NOTE: this is supported by a custom kernel
      case ElemKind::BoolTy:
        switch (kindTo) {
        case ElemKind::Int32ITy:
          return true;
        default:
          return false;
        }
        return false;

      case ElemKind::Int32ITy:
        switch (kindTo) {
        case ElemKind::Int64ITy:
        case ElemKind::Float16Ty:
        case ElemKind::FloatTy:
        case ElemKind::Int8QTy:
          return true;
        case ElemKind::BoolTy:
          return glow::nnpi::flags::EnableCustomIAKernels;
        default:
          return false;
        }
        return false;

      case ElemKind::Int32QTy:
        switch (kindTo) {
        case ElemKind::Float16Ty:
          return true;
        default:
          return false;
        }
        return false;

      case ElemKind::UInt8QTy:
      case ElemKind::Int8QTy:
        return true;

      case ElemKind::UInt8FusedQTy:
        return (kindTo == ElemKind::Float16Ty ||
                kindTo == ElemKind::UInt8FusedFP16QTy);
      case ElemKind::UInt8FusedFP16QTy:
        return (kindTo == ElemKind::Float16Ty);
      default:
        return false;
      }
      return false;
    };
    isNodePrecisionSupported =
        isConversionSupportedFor(NI.getInElemTy(ConvertToNode::InputIdx),
                                 NI.getOutElemTy(ConvertToNode::ResultIdx));
    break;
  }

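  // Dynamically quantized FC variants take a floating-point input (FP16 or
  // FP32), Int8 weights, and an FP32 bias; the rowwise variant additionally
  // requires FP32 per-row scales and Int32 offsets.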
  case Kinded::Kind::DynamicQuantizedFullyConnectedNodeKind:
    isNodePrecisionSupported =
        (NI.getInElemTy(DynamicQuantizedFullyConnectedNode::InputIdx) ==
             ElemKind::Float16Ty ||
         NI.getInElemTy(DynamicQuantizedFullyConnectedNode::InputIdx) ==
             ElemKind::FloatTy) &&
        NI.getInElemTy(DynamicQuantizedFullyConnectedNode::WeightsIdx) ==
            ElemKind::Int8QTy &&
        NI.getInElemTy(DynamicQuantizedFullyConnectedNode::BiasIdx) ==
            ElemKind::FloatTy;
    break;

  case Kinded::Kind::DynamicRowwiseQuantizedFullyConnectedNodeKind:
    isNodePrecisionSupported =
        (NI.getInElemTy(DynamicRowwiseQuantizedFullyConnectedNode::InputIdx) ==
             ElemKind::Float16Ty ||
         NI.getInElemTy(DynamicRowwiseQuantizedFullyConnectedNode::InputIdx) ==
             ElemKind::FloatTy) &&
        NI.getInElemTy(DynamicRowwiseQuantizedFullyConnectedNode::WeightsIdx) ==
            ElemKind::Int8QTy &&
        NI.getInElemTy(DynamicRowwiseQuantizedFullyConnectedNode::BiasIdx) ==
            ElemKind::FloatTy &&
        NI.getInElemTy(DynamicRowwiseQuantizedFullyConnectedNode::ScalesIdx) ==
            ElemKind::FloatTy &&
        NI.getInElemTy(DynamicRowwiseQuantizedFullyConnectedNode::OffsetsIdx) ==
            ElemKind::Int32ITy;
    break;
  case Kinded::Kind::FullyConnectedNodeKind:
    if (!NI.getInTy(FullyConnectedNode::InputIdx)->isQuantizedType()) {
      isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
          {ElemKind::FloatTy, ElemKind::Float16Ty});
    } else {
      isNodePrecisionSupported =
          NI.allInputsAndOutputsHaveSameElemKind(
              {ElemKind::Int8QTy}, {FullyConnectedNode::BiasIdx}) &&
          ((NI.getInElemTy(FullyConnectedNode::BiasIdx) ==
            ElemKind::Int32QTy) ||
           (NI.getInElemTy(FullyConnectedNode::BiasIdx) == ElemKind::FloatTy));
    }
    break;
  case Kinded::Kind::MaxPoolNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy}, {},
            {MaxPoolNode::ArgmaxIdx}) &&
        (NI.getOutElemTy(MaxPoolNode::ArgmaxIdx) == ElemKind::Int64ITy);
    break;
  case Kinded::Kind::TopKNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy}, {},
            {TopKNode::IndicesIdx}) &&
        (NI.getOutElemTy(TopKNode::IndicesIdx) == ElemKind::Int64ITy ||
         NI.getOutElemTy(TopKNode::IndicesIdx) == ElemKind::Int32ITy);
    break;
  case Kinded::Kind::GatherNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int64ITy,
             ElemKind::Int8QTy},
            {GatherNode::IndicesIdx}) &&
        ((NI.getInElemTy(GatherNode::IndicesIdx) == ElemKind::Int32ITy) ||
         (NI.getInElemTy(GatherNode::IndicesIdx) == ElemKind::Int64ITy));
    break;
  case Kinded::Kind::GatherRangesNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::Int32ITy, ElemKind::Int64ITy},
            {GatherRangesNode::DataIdx}, {GatherRangesNode::OutputIdx}) &&
        ((NI.getInElemTy(GatherRangesNode::DataIdx) == ElemKind::FloatTy) ||
         (NI.getInElemTy(GatherRangesNode::DataIdx) == ElemKind::Float16Ty) ||
         (NI.getInElemTy(GatherRangesNode::DataIdx) == ElemKind::Int8QTy) ||
         (NI.getInElemTy(GatherRangesNode::DataIdx) == ElemKind::Int32ITy) ||
         (NI.getInElemTy(GatherRangesNode::DataIdx) == ElemKind::Int64ITy)) &&
        (NI.getOutElemTy(GatherRangesNode::OutputIdx) ==
         NI.getInElemTy(GatherRangesNode::DataIdx));
    break;
  case Kinded::Kind::SliceNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
         ElemKind::Int32ITy, ElemKind::Int64ITy, ElemKind::BoolTy});
    break;
  case Kinded::Kind::ReshapeNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
         ElemKind::Int32ITy, ElemKind::Int64ITy});
    break;
  case Kinded::Kind::CmpLTENodeKind:
  case Kinded::Kind::CmpLTNodeKind:
  case Kinded::Kind::CmpEQNodeKind:
  case Kinded::Kind::CmpNEQNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
             ElemKind::Int32ITy},
            {}, {CmpEQNode::ResultIdx}) &&
        (NI.getOutElemTy(CmpEQNode::ResultIdx) == ElemKind::BoolTy);
    break;
  case Kinded::Kind::NonZeroNodeKind:
    isNodePrecisionSupported =
        (NI.getOutElemTy(NonZeroNode::ResultIdx) == ElemKind::Int32ITy);
    break;
  case Kinded::Kind::SelectNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy},
            {SelectNode::CondIdx}) &&
        (NI.getInElemTy(SelectNode::CondIdx) == ElemKind::BoolTy);
    break;
  case Kinded::Kind::GaussianFillNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
             ElemKind::Int32ITy, ElemKind::Int64ITy},
            {}, {GaussianFillNode::ResultIdx}) &&
        (NI.getOutElemTy(GaussianFillNode::ResultIdx)) == ElemKind::Float16Ty;
    break;
  case Kinded::Kind::RowwiseQuantizedFullyConnectedNodeKind:
    isNodePrecisionSupported =
        (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::InputIdx) ==
         ElemKind::Int8QTy) &&
        (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::WeightsIdx) ==
         ElemKind::Int8QTy) &&
        (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::ScalesIdx) ==
         ElemKind::FloatTy) &&
        (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::OffsetsIdx) ==
         ElemKind::Int32ITy) &&
        ((NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::BiasIdx) ==
          ElemKind::Int32QTy) ||
         (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::BiasIdx) ==
          ElemKind::FloatTy)) &&
        (NI.getOutElemTy(RowwiseQuantizedFullyConnectedNode::ResultIdx) ==
         ElemKind::Int8QTy);
    break;
  case Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind:
    isNodePrecisionSupported =
        (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::InputIdx) ==
         ElemKind::Int8QTy) &&
        (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::FilterIdx) ==
         ElemKind::Int8QTy) &&
        ((NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::BiasIdx) ==
          ElemKind::Int32QTy) ||
         (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::BiasIdx) ==
          ElemKind::FloatTy)) &&
        (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::FilterScalesIdx) ==
         ElemKind::FloatTy) &&
        (NI.getInElemTy(
             ChannelwiseQuantizedConvolutionNode::FilterOffsetsIdx) ==
         ElemKind::Int32ITy) &&
        (NI.getOutElemTy(ChannelwiseQuantizedConvolutionNode::ResultIdx) ==
         ElemKind::Int8QTy);
    break;
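  // Embedding / sparse-lengths (SLS) lookups: in addition to the element-kind
  // checks, isSLSIndicesValid() (defined earlier in NNPI.cpp) applies further
  // constraints on the indices tensor itself.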
  case Kinded::Kind::SparseLengthsSumNodeKind:
    isNodePrecisionSupported =
        isSLSIndicesValid(NI.getInTy(SparseLengthsSumNode::IndicesIdx)) &&
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy},
            {SparseLengthsSumNode::IndicesIdx,
             SparseLengthsSumNode::LengthsIdx}) &&
        (NI.getInElemTy(SparseLengthsSumNode::IndicesIdx) ==
             ElemKind::Int64ITy ||
         NI.getInElemTy(SparseLengthsSumNode::IndicesIdx) ==
             ElemKind::Int32ITy) &&
        (NI.getInElemTy(SparseLengthsSumNode::LengthsIdx) ==
         ElemKind::Int32ITy);
    break;
  case Kinded::Kind::SparseLengthsWeightedSumNodeKind:
    isNodePrecisionSupported =
        isSLSIndicesValid(
            NI.getInTy(SparseLengthsWeightedSumNode::IndicesIdx)) &&
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy},
            {SparseLengthsWeightedSumNode::IndicesIdx,
             SparseLengthsWeightedSumNode::LengthsIdx}) &&
        (NI.getInElemTy(SparseLengthsWeightedSumNode::IndicesIdx) ==
             ElemKind::Int64ITy ||
         NI.getInElemTy(SparseLengthsWeightedSumNode::IndicesIdx) ==
             ElemKind::Int32ITy) &&
        (NI.getInElemTy(SparseLengthsWeightedSumNode::LengthsIdx) ==
         ElemKind::Int32ITy);
    break;
  case Kinded::Kind::EmbeddingNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy},
            {EmbeddingNode::IndicesIdx}) &&
        (NI.getInElemTy(EmbeddingNode::IndicesIdx) == ElemKind::Int64ITy ||
         NI.getInElemTy(EmbeddingNode::IndicesIdx) == ElemKind::Int32ITy);
    break;
  case Kinded::Kind::EmbeddingBagNodeKind:
    isNodePrecisionSupported =
        isSLSIndicesValid(NI.getInTy(EmbeddingBagNode::IndicesIdx)) &&
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy},
            {EmbeddingBagNode::IndicesIdx, EmbeddingBagNode::OffsetsIdx}) &&
        (NI.getInElemTy(EmbeddingBagNode::IndicesIdx) == ElemKind::Int64ITy ||
         NI.getInElemTy(EmbeddingBagNode::IndicesIdx) == ElemKind::Int32ITy) &&
        (NI.getInElemTy(EmbeddingBagNode::OffsetsIdx) == ElemKind::Int64ITy ||
         NI.getInElemTy(EmbeddingBagNode::OffsetsIdx) == ElemKind::Int32ITy);
    break;
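  // For the fused rowwise-quantized lookups below, each row of the data tensor
  // carries its own scale and offset inline, so the data kind is one of the
  // fused types (UInt8FusedQTy, UInt8FusedFP16QTy, UInt4FusedFP16QTy) rather
  // than a plain quantized kind.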
  case Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind: {
    auto dataK = NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::DataIdx);
    auto offsetsK =
        NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::OffsetsIdx);
    auto indicesK =
        NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::IndicesIdx);
    auto resultK =
        NI.getOutElemTy(EmbeddingBagByteRowwiseOffsetsNode::ResultIdx);
    isNodePrecisionSupported =
        isSLSIndicesValid(
            NI.getInTy(EmbeddingBagByteRowwiseOffsetsNode::IndicesIdx)) &&
        (dataK == ElemKind::UInt8FusedQTy ||
         dataK == ElemKind::UInt8FusedFP16QTy ||
         dataK == ElemKind::UInt4FusedFP16QTy) &&
        (resultK == ElemKind::FloatTy || resultK == ElemKind::Float16Ty) &&
        (offsetsK == ElemKind::Int64ITy || offsetsK == ElemKind::Int32ITy) &&
        (indicesK == ElemKind::Int64ITy || indicesK == ElemKind::Int32ITy);

    break;
  }
  case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsSumNodeKind: {
    auto dataK =
        NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsSumNode::DataIdx);
    auto lengthsK =
        NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsSumNode::LengthsIdx);
    auto indicesK =
        NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsSumNode::IndicesIdx);
    auto resultK =
        NI.getOutElemTy(FusedRowwiseQuantizedSparseLengthsSumNode::ResultIdx);
    isNodePrecisionSupported =
        isSLSIndicesValid(NI.getInTy(
            FusedRowwiseQuantizedSparseLengthsSumNode::IndicesIdx)) &&
        (dataK == ElemKind::UInt8FusedQTy ||
         dataK == ElemKind::UInt8FusedFP16QTy ||
         dataK == ElemKind::UInt4FusedFP16QTy) &&
        (resultK == ElemKind::FloatTy || resultK == ElemKind::Float16Ty) &&
        (indicesK == ElemKind::Int64ITy || indicesK == ElemKind::Int32ITy) &&
        (lengthsK == ElemKind::Int32ITy);
    break;
  }
  case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind: {
    auto dataK = NI.getInElemTy(
        FusedRowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx);
    auto weightsK = NI.getInElemTy(
        FusedRowwiseQuantizedSparseLengthsWeightedSumNode::WeightsIdx);
    auto lengthsK = NI.getInElemTy(
        FusedRowwiseQuantizedSparseLengthsWeightedSumNode::LengthsIdx);
    auto indicesK = NI.getInElemTy(
        FusedRowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx);
    auto resultK = NI.getOutElemTy(
        FusedRowwiseQuantizedSparseLengthsWeightedSumNode::ResultIdx);
    isNodePrecisionSupported =
        isSLSIndicesValid(NI.getInTy(
            FusedRowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx)) &&
        (dataK == ElemKind::UInt8FusedQTy ||
         dataK == ElemKind::UInt8FusedFP16QTy ||
         dataK == ElemKind::UInt4FusedFP16QTy) &&
        (weightsK == ElemKind::FloatTy || weightsK == ElemKind::Float16Ty) &&
        (resultK == ElemKind::FloatTy || resultK == ElemKind::Float16Ty) &&
        (indicesK == ElemKind::Int64ITy || indicesK == ElemKind::Int32ITy) &&
        (lengthsK == ElemKind::Int32ITy);
  } break;
  case Kinded::Kind::RowwiseQuantizedSparseLengthsWeightedSumNodeKind:
    isNodePrecisionSupported =
        isSLSIndicesValid(NI.getInTy(
            RowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx)) &&
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty},
            {RowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx,
             RowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx,
             RowwiseQuantizedSparseLengthsWeightedSumNode::LengthsIdx}) &&
        (NI.getInElemTy(
             RowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx) ==
         ElemKind::UInt8QTy) &&
        (NI.getInElemTy(
             RowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) ==
             ElemKind::Int64ITy ||
         NI.getInElemTy(
             RowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) ==
             ElemKind::Int32ITy) &&
        (NI.getInElemTy(
             RowwiseQuantizedSparseLengthsWeightedSumNode::LengthsIdx) ==
         ElemKind::Int32ITy);
    break;
  case Kinded::Kind::ScatterDataNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
             ElemKind::UInt8QTy},
            {ScatterDataNode::IndicesIdx}) &&
        (NI.getInElemTy(ScatterDataNode::IndicesIdx) == ElemKind::Int32ITy ||
         NI.getInElemTy(ScatterDataNode::IndicesIdx) == ElemKind::Int64ITy);
    break;
  case Kinded::Kind::BucketizeNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::Float16Ty, ElemKind::Int8QTy, ElemKind::UInt8QTy}, {},
            {BucketizeNode::ResultIdx}) &&
        (NI.getOutElemTy(BucketizeNode::ResultIdx) == ElemKind::Int32ITy);
    break;
  case Kinded::Kind::SoftMaxNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy},
            {SoftMaxNode::SelectedIdx}) &&
        (NI.getInElemTy(SoftMaxNode::SelectedIdx) == ElemKind::Int64ITy ||
         NI.getInElemTy(SoftMaxNode::SelectedIdx) == ElemKind::Int32ITy);
    break;
  case Kinded::Kind::LengthsRangeFillNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int32ITy});
    break;
  case Kinded::Kind::BatchOneHotNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
             ElemKind::Int32ITy, ElemKind::Int64ITy},
            {BatchOneHotNode::LengthsIdx}) &&
        (NI.getInElemTy(BatchOneHotNode::LengthsIdx) == ElemKind::Int32ITy);
    break;
  case Kinded::Kind::NNPICustomDSPNodeKind:
  case Kinded::Kind::NNPICustomIANodeKind:
    isNodePrecisionSupported = true;
    break;
  case Kinded::Kind::SpaceToDepthNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
         ElemKind::Int32ITy, ElemKind::Int64ITy});
    break;
  case Kinded::Kind::ArgMaxNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::Float16Ty, ElemKind::Int8QTy, ElemKind::Int32ITy,
             ElemKind::Int64ITy, ElemKind::BoolTy},
            {}, {ArgMaxNode::ResultIdx}) &&
        (NI.getOutElemTy(ArgMaxNode::ResultIdx) == ElemKind::Int64ITy ||
         NI.getOutElemTy(ArgMaxNode::ResultIdx) == ElemKind::Int32ITy);
    break;
  case Kinded::Kind::ArgMinNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::Float16Ty, ElemKind::FloatTy, ElemKind::Int8QTy,
             ElemKind::Int32ITy, ElemKind::Int64ITy, ElemKind::BoolTy},
            {}, {ArgMinNode::ResultIdx}) &&
        (NI.getOutElemTy(ArgMinNode::ResultIdx) == ElemKind::Int64ITy ||
         NI.getOutElemTy(ArgMinNode::ResultIdx) == ElemKind::Int32ITy);
    break;
  case Kinded::Kind::LogitNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Float16Ty});
    break;
  case Kinded::Kind::CumSumNodeKind:
#if NNPI_MAJOR_VERSION >= 1 && NNPI_MINOR_VERSION >= 7
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Int32ITy, ElemKind::Int8QTy, ElemKind::UInt8QTy});
#else
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int32ITy});
#endif // NNPI >= 1.7
    break;
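  // The cases below are only compiled in when building against NNPI 1.9 or
  // newer.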
#if NNPI_MAJOR_VERSION >= 1 && NNPI_MINOR_VERSION >= 9
  case Kinded::Kind::BatchSparseToDenseNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::Float16Ty, ElemKind::UInt8QTy, ElemKind::Int8QTy},
            {BatchSparseToDenseNode::LengthsIdx,
             BatchSparseToDenseNode::IndicesIdx}) &&
        ((NI.getInElemTy(BatchSparseToDenseNode::LengthsIdx) ==
              ElemKind::Int64ITy ||
          NI.getInElemTy(BatchSparseToDenseNode::LengthsIdx) ==
              ElemKind::Int32ITy)) &&
        ((NI.getInElemTy(BatchSparseToDenseNode::IndicesIdx) ==
              ElemKind::Int64ITy ||
          NI.getInElemTy(BatchSparseToDenseNode::IndicesIdx) ==
              ElemKind::Int32ITy));
    break;
  case Kinded::Kind::FillExamplesWithIndicatorNodeKind:
    isNodePrecisionSupported =
        (NI.getInElemTy(FillExamplesWithIndicatorNode::DataIdx) ==
         NI.getOutElemTy(FillExamplesWithIndicatorNode::ResultIdx)) &&
        ((NI.getInElemTy(FillExamplesWithIndicatorNode::IndicatorIdx) ==
          ElemKind::Int32ITy) ||
         (NI.getInElemTy(FillExamplesWithIndicatorNode::IndicatorIdx) ==
          ElemKind::Int64ITy));
    break;
#endif // NNPI >= 1.9
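  // Any node kind not handled above is reported as having no NNPI support.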
  default:
    isNodeHasAnySupport = false;
    isNodePrecisionSupported = false;
  }

  if (isNodePrecisionSupported) {
    return NodeSupportLevels::PRECISION_SUPPORTED;
  } else if (isNodeHasAnySupport) {
    return NodeSupportLevels::SUPPORTED;
  }
  return NodeSupportLevels::NOT_SUPPORTED;
}
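
How the three-way result is consumed is outside this excerpt; the sketch below is an assumption about typical callers (the wrapper names are hypothetical), separating "the node kind is known to the backend at all" from "the node can run with its element kinds exactly as they appear in the graph":

// Hypothetical consumers of NodeSupportLevels. The enumerator names mirror the
// ones returned above; the declaration itself and both helpers are assumptions
// for illustration, not Glow call sites.
#include <cassert>

enum class NodeSupportLevels { NOT_SUPPORTED, SUPPORTED, PRECISION_SUPPORTED };

// True when the node kind is handled at all, even if its current element kinds
// would first need a precision conversion.
inline bool hasAnySupport(NodeSupportLevels level) {
  return level != NodeSupportLevels::NOT_SUPPORTED;
}

// True only when the node can be imported with its element kinds as-is.
inline bool supportedAtCurrentPrecision(NodeSupportLevels level) {
  return level == NodeSupportLevels::PRECISION_SUPPORTED;
}

int main() {
  // A node whose kind is listed above but whose element kinds failed the
  // precision checks maps to SUPPORTED: importable in principle, but not at
  // the current precision.
  assert(hasAnySupport(NodeSupportLevels::SUPPORTED));
  assert(!supportedAtCurrentPrecision(NodeSupportLevels::SUPPORTED));
  return 0;
}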