bool LLVMBackend::isOpSupported()

in lib/LLVMIRCodeGen/LLVMBackend.cpp [52:580]


bool LLVMBackend::isOpSupported(const NodeInfo &NI) const {
  // Note: For brevity below, "X ==> Y, Z" signifes that Node X is IRGen'd into
  // Instructions Y and Z.
  switch (NI.getKind()) {
  case Kinded::Kind::BatchedReduceMaxNodeKind:
  case Kinded::Kind::BatchedReduceMinNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Int32ITy, ElemKind::Int64ITy});

  case Kinded::Kind::BatchedReduceProdNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});

  case Kinded::Kind::AddNodeKind:
  case Kinded::Kind::MulNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int32ITy,
         ElemKind::Int64ITy});

  case Kinded::Kind::ReluNodeKind:
  case Kinded::Kind::ClipNodeKind:
  case Kinded::Kind::LeakyReluNodeKind:
  case Kinded::Kind::SubNodeKind:
  case Kinded::Kind::MaxNodeKind:
  case Kinded::Kind::MinNodeKind:
  case Kinded::Kind::BatchedReduceAddNodeKind:
  case Kinded::Kind::MatMulNodeKind:
  case Kinded::Kind::AvgPoolNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Int8QTy});

  case Kinded::Kind::AdaptiveAvgPoolNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});

  case Kinded::Kind::MaxPoolNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Int8QTy}, {},
               {MaxPoolNode::ArgmaxIdx}) &&
           (NI.getOutElemTy(MaxPoolNode::ArgmaxIdx) == ElemKind::Int64ITy ||
            NI.getOutElemTy(MaxPoolNode::ArgmaxIdx) == ElemKind::Int32ITy);

  case Kinded::Kind::ArgMaxNodeKind:
  case Kinded::Kind::ArgMinNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Int8QTy}, {},
               {ArgMaxNode::ResultIdx}) &&
           (NI.getOutElemTy(ArgMaxNode::ResultIdx) == ElemKind::Int64ITy ||
            NI.getOutElemTy(ArgMaxNode::ResultIdx) == ElemKind::Int32ITy);

  case Kinded::Kind::ResizeNearestNodeKind:
  case Kinded::Kind::ResizeBilinearNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int32QTy,
         ElemKind::Int32ITy, ElemKind::Int64ITy});

  case Kinded::Kind::SaveNodeKind:
  case Kinded::Kind::ReshapeNodeKind:
    // These are implemented via a Copy Instruction.
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int16QTy,
         ElemKind::Int32QTy, ElemKind::Int32ITy, ElemKind::Int64ITy,
         ElemKind::BoolTy});

    // InsertTensor ==> Copy + InsertTensor. Copy supports everything
    // ReshapeNode above supports, so InsertTensor is the limiting factor.
  case Kinded::Kind::InsertTensorNodeKind:
    // Concat ==> Splat + Insert. Both only support the following.
  case Kinded::Kind::ConcatNodeKind:
  case Kinded::Kind::SplatNodeKind:
  case Kinded::Kind::TouchNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy,
         ElemKind::Int32ITy, ElemKind::BoolTy});
  case Kinded::Kind::SliceNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int32QTy,
         ElemKind::Int32ITy, ElemKind::Int64ITy});
  case Kinded::Kind::FmodNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Int32ITy, ElemKind::Int64ITy});
  case Kinded::Kind::SpaceToDepthNodeKind:
  case Kinded::Kind::DivNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy,
         ElemKind::Int32ITy});

  case Kinded::Kind::TransposeNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy,
         ElemKind::BoolTy});

  case Kinded::Kind::FlipNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int16QTy,
         ElemKind::Int32QTy, ElemKind::Int32ITy, ElemKind::Int64ITy,
         ElemKind::BoolTy});

  case Kinded::Kind::SparseLengthsSumNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy}, {SparseLengthsSumNode::IndicesIdx,
                                     SparseLengthsSumNode::LengthsIdx}) &&
           (NI.getInElemTy(SparseLengthsSumNode::IndicesIdx) ==
                ElemKind::Int64ITy ||
            NI.getInElemTy(SparseLengthsSumNode::IndicesIdx) ==
                ElemKind::Int32ITy) &&
           (NI.getInElemTy(SparseLengthsSumNode::LengthsIdx) ==
            ElemKind::Int32ITy);

  case Kinded::Kind::SparseLengthsWeightedSumNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy},
               {SparseLengthsWeightedSumNode::IndicesIdx,
                SparseLengthsWeightedSumNode::LengthsIdx}) &&
           (NI.getInElemTy(SparseLengthsWeightedSumNode::IndicesIdx) ==
                ElemKind::Int64ITy ||
            NI.getInElemTy(SparseLengthsWeightedSumNode::IndicesIdx) ==
                ElemKind::Int32ITy) &&
           (NI.getInElemTy(SparseLengthsWeightedSumNode::LengthsIdx) ==
            ElemKind::Int32ITy);

  case Kinded::Kind::EmbeddingNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy}, {EmbeddingNode::IndicesIdx}) &&
           (NI.getInElemTy(EmbeddingNode::IndicesIdx) == ElemKind::Int32ITy);

  case Kinded::Kind::EmbeddingBagNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy},
               {EmbeddingBagNode::IndicesIdx, EmbeddingBagNode::OffsetsIdx}) &&
           (NI.getInElemTy(EmbeddingBagNode::IndicesIdx) ==
            ElemKind::Int32ITy) &&
           (NI.getInElemTy(EmbeddingBagNode::OffsetsIdx) == ElemKind::Int32ITy);

  case Kinded::Kind::SparseLengthsWeightedSumGradNodeKind:
    // GradOfInputNamedIndicesIdx and GradOfInputNamedLengthsIdx do not need to
    // be checked because they are not used.
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy},
               {SparseLengthsWeightedSumGradNode::IndicesIdx,
                SparseLengthsWeightedSumGradNode::LengthsIdx},
               {SparseLengthsWeightedSumGradNode::GradOfInputNamedIndicesIdx,
                SparseLengthsWeightedSumGradNode::
                    GradOfInputNamedLengthsIdx}) &&
           (NI.getInElemTy(SparseLengthsWeightedSumGradNode::IndicesIdx) ==
                ElemKind::Int32ITy ||
            NI.getInElemTy(SparseLengthsWeightedSumGradNode::IndicesIdx) ==
                ElemKind::Int32ITy) &&
           (NI.getInElemTy(SparseLengthsWeightedSumGradNode::LengthsIdx) ==
            ElemKind::Int32ITy);

  case Kinded::Kind::RowwiseQuantizedSparseLengthsWeightedSumNodeKind:
    return (NI.getInElemTy(
                RowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx) ==
            ElemKind::UInt8QTy) &&
           (NI.getInElemTy(
                RowwiseQuantizedSparseLengthsWeightedSumNode::ScalesIdx) ==
            ElemKind::FloatTy) &&
           (NI.getInElemTy(
                RowwiseQuantizedSparseLengthsWeightedSumNode::OffsetsIdx) ==
            ElemKind::FloatTy) &&
           (NI.getInElemTy(
                RowwiseQuantizedSparseLengthsWeightedSumNode::WeightsIdx) ==
            ElemKind::FloatTy) &&
           (NI.getInElemTy(
                RowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) ==
                ElemKind::Int64ITy ||
            NI.getInElemTy(
                RowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) ==
                ElemKind::Int32ITy) &&
           (NI.getInElemTy(
                RowwiseQuantizedSparseLengthsWeightedSumNode::LengthsIdx) ==
            ElemKind::Int32ITy) &&
           (NI.getOutElemTy(
                RowwiseQuantizedSparseLengthsWeightedSumNode::ResultIdx) ==
            ElemKind::FloatTy);

  case Kinded::Kind::LengthsRangeFillNodeKind:
  case Kinded::Kind::LengthsToRangesNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int32ITy});

  case Kinded::Kind::IntLookupTableNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Int8QTy, ElemKind::Int16QTy});

  case Kinded::Kind::RescaleQuantizedNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Int8QTy, ElemKind::Int16QTy, ElemKind::Int32QTy});

  case Kinded::Kind::PowNodeKind:
  case Kinded::Kind::AvgPoolGradNodeKind:
  case Kinded::Kind::QuantizationProfileNodeKind:
  case Kinded::Kind::LocalResponseNormalizationNodeKind:
  case Kinded::Kind::LocalResponseNormalizationGradNodeKind:
  case Kinded::Kind::LogNodeKind:
  case Kinded::Kind::TanhNodeKind:
  case Kinded::Kind::SigmoidNodeKind:
  case Kinded::Kind::ExpNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});

  case Kinded::Kind::ModuloNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Int32ITy, ElemKind::Int64ITy});

  case Kinded::Kind::MaxPoolGradNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy},
               {MaxPoolGradNode::OriginalOutputForArgmaxIdx,
                MaxPoolGradNode::GradOfOriginalOutputNamedArgmaxIdx}) &&
           (NI.getInElemTy(MaxPoolGradNode::OriginalOutputForArgmaxIdx) ==
                ElemKind::Int64ITy ||
            NI.getInElemTy(MaxPoolGradNode::OriginalOutputForArgmaxIdx) ==
                ElemKind::Int32ITy) &&
           (NI.getInElemTy(
                MaxPoolGradNode::GradOfOriginalOutputNamedArgmaxIdx) ==
                ElemKind::Int64ITy ||
            NI.getInElemTy(
                MaxPoolGradNode::GradOfOriginalOutputNamedArgmaxIdx) ==
                ElemKind::Int32ITy);

  case Kinded::Kind::ConvolutionNodeKind:
    if (!NI.getInTy(ConvolutionNode::InputIdx)->isQuantizedType()) {
      return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});
    }

    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int8QTy},
                                                  {ConvolutionNode::BiasIdx}) &&
           (NI.getInElemTy(ConvolutionNode::BiasIdx) == ElemKind::Int8QTy ||
            NI.getInElemTy(ConvolutionNode::BiasIdx) == ElemKind::Int32QTy);

  case Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind:
    return (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::InputIdx) ==
            ElemKind::Int8QTy) &&
           (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::FilterIdx) ==
            ElemKind::Int8QTy) &&
           ((NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::BiasIdx) ==
             ElemKind::Int8QTy) ||
            (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::BiasIdx) ==
             ElemKind::Int32QTy)) &&
           (NI.getInElemTy(
                ChannelwiseQuantizedConvolutionNode::FilterScalesIdx) ==
            ElemKind::FloatTy) &&
           (NI.getInElemTy(
                ChannelwiseQuantizedConvolutionNode::FilterOffsetsIdx) ==
            ElemKind::Int32ITy) &&
           (NI.getInElemTy(
                ChannelwiseQuantizedConvolutionNode::BiasScalesIdx) ==
            ElemKind::FloatTy) &&
           (NI.getInElemTy(
                ChannelwiseQuantizedConvolutionNode::BiasOffsetsIdx) ==
            ElemKind::Int32ITy) &&
           (NI.getOutElemTy(ChannelwiseQuantizedConvolutionNode::ResultIdx) ==
            ElemKind::Int8QTy);

  case Kinded::Kind::ConvTransposeNodeKind:
    // TODO - not quantized support yet in libjit.
    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});

  case Kinded::Kind::BatchedAddNodeKind:
    if (!NI.getInTy(BatchedAddNode::BatchIdx)->isQuantizedType()) {
      return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});
    }
    // Allow for Int8QTy or Int32QTy for the Slice input.
    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int8QTy},
                                                  {BatchedAddNode::SliceIdx}) &&
           ((NI.getInElemTy(BatchedAddNode::SliceIdx) == ElemKind::Int8QTy) ||
            (NI.getInElemTy(BatchedAddNode::SliceIdx) == ElemKind::Int32QTy));

  case Kinded::Kind::GatherNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy,
                ElemKind::Int32ITy},
               {GatherNode::IndicesIdx}) &&
           ((NI.getInElemTy(GatherNode::IndicesIdx) == ElemKind::Int32ITy) ||
            (NI.getInElemTy(GatherNode::IndicesIdx) == ElemKind::Int64ITy));

  case Kinded::Kind::GatherNDNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy,
                ElemKind::Int32ITy},
               {GatherNDNode::IndicesIdx}) &&
           ((NI.getInElemTy(GatherNDNode::IndicesIdx) == ElemKind::Int32ITy) ||
            (NI.getInElemTy(GatherNDNode::IndicesIdx) == ElemKind::Int64ITy));

  case Kinded::Kind::GatherRangesNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy,
                ElemKind::Int32ITy},
               {GatherRangesNode::RangesIdx}, {GatherRangesNode::LengthsIdx}) &&
           ((NI.getInElemTy(GatherRangesNode::RangesIdx) ==
             NI.getOutElemTy(GatherRangesNode::LengthsIdx)) &&
            ((NI.getOutElemTy(GatherRangesNode::LengthsIdx) ==
              ElemKind::Int32ITy) ||
             (NI.getOutElemTy(GatherRangesNode::LengthsIdx) ==
              ElemKind::Int64ITy)));

  case Kinded::Kind::ScatterDataNodeKind:
    // ScatterData ==> Copy + ScatterData. Copy supports everything
    // ReshapeNode above supports, however ScatterData only supports the
    // following.
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Int8QTy},
               {ScatterDataNode::IndicesIdx}) &&
           (NI.getInElemTy(ScatterDataNode::IndicesIdx) == ElemKind::Int64ITy ||
            NI.getInElemTy(ScatterDataNode::IndicesIdx) == ElemKind::Int32ITy);

  case Kinded::Kind::SelectNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int32ITy},
               {SelectNode::CondIdx}) &&
           ((NI.getInElemTy(SelectNode::CondIdx) == ElemKind::BoolTy));

  case Kinded::Kind::NotNodeKind:
  case Kinded::Kind::AndNodeKind:
  case Kinded::Kind::OrNodeKind:
  case Kinded::Kind::XorNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::BoolTy});

  case Kinded::Kind::AbsNodeKind:
  case Kinded::Kind::NegNodeKind:
  case Kinded::Kind::FloorNodeKind:
  case Kinded::Kind::CeilNodeKind:
  case Kinded::Kind::RoundNodeKind:
  case Kinded::Kind::SqrtNodeKind:
  case Kinded::Kind::ErfNodeKind:
  case Kinded::Kind::RsqrtNodeKind:
  case Kinded::Kind::ReciprocalNodeKind:
  case Kinded::Kind::SinNodeKind:
  case Kinded::Kind::CosNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});

  case Kinded::Kind::CmpEQNodeKind:
  case Kinded::Kind::CmpNEQNodeKind:
  case Kinded::Kind::CmpLTNodeKind:
  case Kinded::Kind::CmpLTENodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int32ITy,
                ElemKind::Int64ITy},
               {}, {CmpEQNode::ResultIdx}) &&
           (NI.getOutElemTy(CmpEQNode::ResultIdx) == ElemKind::BoolTy);

  case Kinded::Kind::IsNaNNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy}, {},
                                                  {IsNaNNode::ResultIdx}) &&
           (NI.getOutElemTy(IsNaNNode::ResultIdx) == ElemKind::BoolTy);

  case Kinded::Kind::TopKNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Int8QTy}, {},
               {TopKNode::IndicesIdx}) &&
           (NI.getOutElemTy(TopKNode::IndicesIdx) == ElemKind::Int64ITy ||
            NI.getOutElemTy(TopKNode::IndicesIdx) == ElemKind::Int32ITy);

  case Kinded::Kind::QuantizeNodeKind:
    return (NI.getInElemTy(QuantizeNode::InputIdx) == ElemKind::FloatTy) &&
           ((NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::Int8QTy) ||
            (NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::Int16QTy) ||
            (NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::Int32QTy));

  case Kinded::Kind::DequantizeNodeKind:
    return ((NI.getInElemTy(DequantizeNode::InputIdx) == ElemKind::Int8QTy) ||
            (NI.getInElemTy(DequantizeNode::InputIdx) == ElemKind::Int16QTy) ||
            (NI.getInElemTy(DequantizeNode::InputIdx) == ElemKind::Int32QTy)) &&
           (NI.getOutElemTy(DequantizeNode::ResultIdx) == ElemKind::FloatTy);

  case Kinded::Kind::SoftMaxNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Int8QTy},
               {SoftMaxNode::SelectedIdx}) &&
           (NI.getInElemTy(SoftMaxNode::SelectedIdx) == ElemKind::Int64ITy ||
            NI.getInElemTy(SoftMaxNode::SelectedIdx) == ElemKind::Int32ITy);

  case Kinded::Kind::CrossEntropyLossNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy}, {CrossEntropyLossNode::LabelsIdx}) &&
           (NI.getInElemTy(CrossEntropyLossNode::LabelsIdx) ==
                ElemKind::Int64ITy ||
            NI.getInElemTy(CrossEntropyLossNode::LabelsIdx) ==
                ElemKind::Int32ITy);

  case Kinded::Kind::LengthsSumNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy}, {LengthsSumNode::LengthsIdx}) &&
           (NI.getInElemTy(LengthsSumNode::LengthsIdx) == ElemKind::Int32ITy);

  case Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind:
    return (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::DataIdx) ==
            ElemKind::UInt8FusedQTy) &&
           (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::WeightsIdx) ==
            ElemKind::FloatTy) &&
           (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::IndicesIdx) ==
            ElemKind::Int32ITy) &&
           (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::OffsetsIdx) ==
            ElemKind::Int32ITy) &&
           (NI.getOutElemTy(EmbeddingBagByteRowwiseOffsetsNode::ResultIdx) ==
            ElemKind::FloatTy);

  case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind:
    return (NI.getInElemTy(
                FusedRowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx) ==
            ElemKind::UInt8FusedQTy) &&
           (NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
                               WeightsIdx) == ElemKind::FloatTy) &&
           ((NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
                                IndicesIdx) == ElemKind::Int64ITy ||
             NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
                                IndicesIdx) == ElemKind::Int32ITy)) &&
           (NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
                               LengthsIdx) == ElemKind::Int32ITy) &&
           (NI.getOutElemTy(
                FusedRowwiseQuantizedSparseLengthsWeightedSumNode::ResultIdx) ==
            ElemKind::FloatTy);

  case Kinded::Kind::FullyConnectedNodeKind:
    if (!NI.getInTy(FullyConnectedNode::InputIdx)->isQuantizedType()) {
      return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});
    }
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::Int8QTy}, {FullyConnectedNode::BiasIdx}) &&
           (NI.getInElemTy(FullyConnectedNode::BiasIdx) == ElemKind::Int8QTy ||
            NI.getInElemTy(FullyConnectedNode::BiasIdx) == ElemKind::Int32QTy);

  case Kinded::Kind::RowwiseQuantizedFullyConnectedNodeKind:
    return (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::InputIdx) ==
            ElemKind::Int8QTy) &&
           (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::WeightsIdx) ==
            ElemKind::Int8QTy) &&
           (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::ScalesIdx) ==
            ElemKind::FloatTy) &&
           (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::OffsetsIdx) ==
            ElemKind::Int32ITy) &&
           (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::BiasIdx) ==
                ElemKind::Int8QTy ||
            NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::BiasIdx) ==
                ElemKind::Int32QTy) &&
           (NI.getOutElemTy(RowwiseQuantizedFullyConnectedNode::ResultIdx) ==
            ElemKind::Int8QTy);

  case Kinded::Kind::SoftMaxGradNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy}, {SoftMaxGradNode::SelectedIdx},
               {SoftMaxGradNode::GradOfInputNamedSelectedIdx}) &&
           (NI.getInElemTy(SoftMaxGradNode::SelectedIdx) ==
                ElemKind::Int64ITy ||
            NI.getInElemTy(SoftMaxGradNode::SelectedIdx) == ElemKind::Int32ITy);

  case Kinded::Kind::ConvolutionGradNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy}, {},
        {ConvolutionGradNode::GradOfInputNamedInputIdx});

  case Kinded::Kind::CrossEntropyLossGradNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy}, {CrossEntropyLossGradNode::LabelsIdx},
               {CrossEntropyLossGradNode::GradOfInputNamedLabelsIdx}) &&
           (NI.getInElemTy(CrossEntropyLossGradNode::LabelsIdx) ==
            ElemKind::Int64ITy) &&
           (NI.getOutElemTy(
                CrossEntropyLossGradNode::GradOfInputNamedLabelsIdx) ==
            ElemKind::Int64ITy);

  case Kinded::Kind::TraceEventNodeKind:
    return NI.getInElemTy(TraceEventNode::DataIdx) == ElemKind::Int64ITy;

  case Kinded::Kind::NonMaxSuppressionNodeKind:
    return NI.getInElemTy(NonMaxSuppressionNode::BoxesIdx) ==
               ElemKind::FloatTy &&
           NI.getInElemTy(NonMaxSuppressionNode::ScoresIdx) ==
               ElemKind::FloatTy &&
           (NI.getOutElemTy(NonMaxSuppressionNode::IndicesIdx) ==
                ElemKind::Int32ITy ||
            NI.getOutElemTy(NonMaxSuppressionNode::IndicesIdx) ==
                ElemKind::Int64ITy) &&
           (NI.getOutElemTy(
                NonMaxSuppressionNode::NumberOfSelectedIndicesIdx) ==
                ElemKind::Int32ITy ||
            NI.getOutElemTy(
                NonMaxSuppressionNode::NumberOfSelectedIndicesIdx) ==
                ElemKind::Int64ITy);

  case Kinded::Kind::TFLiteDetectionPostProcessNodeKind:
    return NI.getInElemTy(TFLiteDetectionPostProcessNode::BoxesIdx) ==
               ElemKind::FloatTy &&
           NI.getInElemTy(TFLiteDetectionPostProcessNode::ScoresIdx) ==
               ElemKind::FloatTy &&
           NI.getInElemTy(TFLiteDetectionPostProcessNode::AnchorsIdx) ==
               ElemKind::FloatTy &&
           NI.getOutElemTy(TFLiteDetectionPostProcessNode::DetectionBoxesIdx) ==
               ElemKind::FloatTy &&
           NI.getOutElemTy(
               TFLiteDetectionPostProcessNode::DetectionClassesIdx) ==
               ElemKind::Int32ITy &&
           NI.getOutElemTy(
               TFLiteDetectionPostProcessNode::DetectionScoresIdx) ==
               ElemKind::FloatTy &&
           NI.getOutElemTy(TFLiteDetectionPostProcessNode::NumDetectionsIdx) ==
               ElemKind::Int32ITy;

  case Kinded::Kind::AudioSpectrogramNodeKind:
    return NI.getInElemTy(AudioSpectrogramNode::InputIdx) ==
               ElemKind::FloatTy &&
           NI.getOutElemTy(AudioSpectrogramNode::SpectrogramIdx) ==
               ElemKind::FloatTy;

  case Kinded::Kind::MFCCNodeKind:
    return NI.getInElemTy(MFCCNode::SpectrogramIdx) == ElemKind::FloatTy &&
           NI.getOutElemTy(MFCCNode::CoefficientsIdx) == ElemKind::FloatTy;

  case Kinded::Kind::ConvertToNodeKind:
    return ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::Int32ITy) &&
            (NI.getOutElemTy(ConvertToNode::ResultIdx) == ElemKind::FloatTy)) ||
           ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::BoolTy) &&
            (NI.getOutElemTy(ConvertToNode::ResultIdx) == ElemKind::FloatTy)) ||
           ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::Int64ITy) &&
            (NI.getOutElemTy(ConvertToNode::ResultIdx) ==
             ElemKind::Int32ITy)) ||
           ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::Int32ITy) &&
            (NI.getOutElemTy(ConvertToNode::ResultIdx) ==
             ElemKind::Int64ITy)) ||
           ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::FloatTy) &&
            (NI.getOutElemTy(ConvertToNode::ResultIdx) == ElemKind::BoolTy)) ||
           ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::FloatTy) &&
            (NI.getOutElemTy(ConvertToNode::ResultIdx) ==
             ElemKind::Int32ITy)) ||
           ((NI.getInElemTy(ConvertToNode::InputIdx) == ElemKind::BoolTy) &&
            (NI.getOutElemTy(ConvertToNode::ResultIdx) == ElemKind::Int32ITy));

  default:
    return false;
  }
}