bool Interpreter::isOpSupported()

in lib/Backends/Interpreter/Interpreter.cpp [79:966]


bool Interpreter::isOpSupported(const NodeInfo &NI) const {
  switch (NI.getKind()) {
  case Kinded::Kind::BatchedReduceProdNodeKind:
  case Kinded::Kind::FmodNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
         ElemKind::Int32ITy, ElemKind::Int64ITy});

  case Kinded::Kind::AddNodeKind:
  case Kinded::Kind::SubNodeKind:
  case Kinded::Kind::MulNodeKind:
  case Kinded::Kind::DivNodeKind:
  case Kinded::Kind::MaxNodeKind:
  case Kinded::Kind::MinNodeKind:
  case Kinded::Kind::ClipNodeKind:
  case Kinded::Kind::BatchedReduceMinNodeKind:
  case Kinded::Kind::BatchedReduceMaxNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
         ElemKind::Int8QTy, ElemKind::Int32ITy, ElemKind::Int64ITy});

  case Kinded::Kind::ResizeNearestNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
         ElemKind::Int8QTy, ElemKind::Int16QTy, ElemKind::Int32QTy,
         ElemKind::Int32ITy, ElemKind::Int64ITy});
  case Kinded::Kind::ResizeBilinearNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
         ElemKind::Int8QTy, ElemKind::Int16QTy, ElemKind::Int32QTy,
         ElemKind::Int32ITy, ElemKind::Int64ITy});

  case Kinded::Kind::BatchNormalizationNodeKind: {
    auto elemType = NI.getInElemTy(BatchNormalizationNode::InputIdx);

    // input can be int8, float16 or float32
    bool isNodePrecisionSupported =
        (elemType == ElemKind::Int8QTy || elemType == ElemKind::FloatTy ||
         elemType == ElemKind::Float16Ty);

    // parameters have to be float16 or float
    isNodePrecisionSupported = isNodePrecisionSupported &&
                               NI.allInputsAndOutputsHaveSameElemKind(
                                   {ElemKind::FloatTy, ElemKind::Float16Ty},
                                   {BatchNormalizationNode::InputIdx},
                                   {BatchNormalizationNode::ResultIdx});

    // input and output element types have to match
    isNodePrecisionSupported =
        isNodePrecisionSupported &&
        NI.allInputsAndOutputsHaveSameElemKind(
            {elemType},
            {BatchNormalizationNode::ScaleIdx, BatchNormalizationNode::BiasIdx,
             BatchNormalizationNode::MeanIdx, BatchNormalizationNode::VarIdx});
    return isNodePrecisionSupported;
  }

  case Kinded::Kind::AvgPoolNodeKind:
  case Kinded::Kind::AdaptiveAvgPoolNodeKind:
  case Kinded::Kind::BatchedReduceAddNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Int32ITy, ElemKind::FloatTy, ElemKind::Float16Ty,
         ElemKind::BFloat16Ty, ElemKind::Int8QTy});

  case Kinded::Kind::DynamicQuantizedFullyConnectedNodeKind:
    return (NI.getInElemTy(DynamicQuantizedFullyConnectedNode::InputIdx) ==
                ElemKind::Float16Ty ||
            NI.getInElemTy(DynamicQuantizedFullyConnectedNode::InputIdx) ==
                ElemKind::FloatTy) &&
           NI.getInElemTy(DynamicQuantizedFullyConnectedNode::WeightsIdx) ==
               ElemKind::Int8QTy &&
           NI.getInElemTy(DynamicQuantizedFullyConnectedNode::BiasIdx) ==
               ElemKind::FloatTy;

  case Kinded::Kind::DynamicRowwiseQuantizedFullyConnectedNodeKind:
    return (NI.getInElemTy(
                DynamicRowwiseQuantizedFullyConnectedNode::InputIdx) ==
                ElemKind::Float16Ty ||
            NI.getInElemTy(
                DynamicRowwiseQuantizedFullyConnectedNode::InputIdx) ==
                ElemKind::FloatTy) &&
           NI.getInElemTy(
               DynamicRowwiseQuantizedFullyConnectedNode::WeightsIdx) ==
               ElemKind::Int8QTy &&
           NI.getInElemTy(DynamicRowwiseQuantizedFullyConnectedNode::BiasIdx) ==
               ElemKind::FloatTy &&
           NI.getInElemTy(
               DynamicRowwiseQuantizedFullyConnectedNode::ScalesIdx) ==
               ElemKind::FloatTy &&
           NI.getInElemTy(
               DynamicRowwiseQuantizedFullyConnectedNode::OffsetsIdx) ==
               ElemKind::Int32ITy;

  case Kinded::Kind::MatMulNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
         ElemKind::Int8QTy, ElemKind::Int16QTy});

  case Kinded::Kind::BatchMatMulNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty});

  case Kinded::Kind::FullyConnectedNodeKind:
    if (!NI.getInTy(FullyConnectedNode::InputIdx)->isQuantizedType()) {
      return NI.allInputsAndOutputsHaveSameElemKind(
          {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty});
    }
    return (NI.allInputsAndOutputsHaveSameElemKind(
                {ElemKind::Int8QTy}, {FullyConnectedNode::BiasIdx}) &&
            (NI.getInElemTy(FullyConnectedNode::BiasIdx) == ElemKind::Int8QTy ||
             NI.getInElemTy(FullyConnectedNode::BiasIdx) ==
                 ElemKind::Int32QTy ||
             NI.getInElemTy(FullyConnectedNode::BiasIdx) ==
                 ElemKind::FloatTy)) ||
           (NI.allInputsAndOutputsHaveSameElemKind(
                {ElemKind::Int16QTy}, {FullyConnectedNode::BiasIdx}) &&
            (NI.getInElemTy(FullyConnectedNode::BiasIdx) ==
                 ElemKind::Int16QTy ||
             NI.getInElemTy(FullyConnectedNode::BiasIdx) ==
                 ElemKind::Int32QTy));

  case Kinded::Kind::MaxPoolNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
                ElemKind::Int8QTy},
               {}, {MaxPoolNode::ArgmaxIdx}) &&
           (NI.getOutElemTy(MaxPoolNode::ArgmaxIdx) == ElemKind::Int64ITy);

  case Kinded::Kind::ArgMaxNodeKind:
  case Kinded::Kind::ArgMinNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy}, {},
               {ArgMaxNode::ResultIdx}) &&
           (NI.getOutElemTy(ArgMaxNode::ResultIdx) == ElemKind::Int64ITy ||
            NI.getOutElemTy(ArgMaxNode::ResultIdx) == ElemKind::Int32ITy);

  case Kinded::Kind::AcosNodeKind:
  case Kinded::Kind::AsinNodeKind:
  case Kinded::Kind::AtanNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy});

  case Kinded::Kind::PowNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
         ElemKind::Int8QTy});
  case Kinded::Kind::LocalResponseNormalizationNodeKind:
  case Kinded::Kind::LayerNormalizationNodeKind:
  case Kinded::Kind::LogNodeKind:
  case Kinded::Kind::TanhNodeKind:
  case Kinded::Kind::ExpNodeKind:
  case Kinded::Kind::SigmoidNodeKind:
  case Kinded::Kind::SoftPlusNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty});
  case Kinded::Kind::SliceNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
         ElemKind::Int8QTy, ElemKind::Int32QTy, ElemKind::Int64ITy,
         ElemKind::Int32ITy, ElemKind::BoolTy});
  case Kinded::Kind::SpaceToDepthNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
         ElemKind::Int8QTy, ElemKind::Int64ITy});

  case Kinded::Kind::SplatNodeKind:
  case Kinded::Kind::TouchNodeKind:
  case Kinded::Kind::InsertTensorNodeKind:
  case Kinded::Kind::ConcatNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
         ElemKind::Int8QTy, ElemKind::Int16QTy, ElemKind::Int32ITy,
         ElemKind::Int64ITy, ElemKind::BoolTy});
  case Kinded::Kind::NonZeroNodeKind:
    return NI.getInElemTy(NonZeroNode::CondIdx) == ElemKind::BoolTy &&
           NI.getOutElemTy(NonZeroNode::ResultIdx) == ElemKind::Int32ITy;
  case Kinded::Kind::SelectNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
                ElemKind::Int8QTy},
               {SelectNode::CondIdx}) &&
           (NI.getInElemTy(SelectNode::CondIdx) == ElemKind::BoolTy);

  case Kinded::Kind::NotNodeKind:
  case Kinded::Kind::AndNodeKind:
  case Kinded::Kind::OrNodeKind:
  case Kinded::Kind::XorNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::BoolTy});

  case Kinded::Kind::SignNodeKind:
  case Kinded::Kind::CeilNodeKind:
  case Kinded::Kind::RoundNodeKind:
  case Kinded::Kind::SqrtNodeKind:
  case Kinded::Kind::RsqrtNodeKind:
  case Kinded::Kind::ReciprocalNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Int8QTy});

  case Kinded::Kind::SinNodeKind:
  case Kinded::Kind::CosNodeKind:
  case Kinded::Kind::ErfNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy});

  case Kinded::Kind::AbsNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy});

  case Kinded::Kind::NegNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Int32ITy, ElemKind::Int8QTy});

  case Kinded::Kind::FloorNodeKind:
  case Kinded::Kind::TruncateNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy});

  case Kinded::Kind::CmpEQNodeKind:
  case Kinded::Kind::CmpNEQNodeKind:
  case Kinded::Kind::CmpLTNodeKind:
  case Kinded::Kind::CmpLTENodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
                ElemKind::Int8QTy, ElemKind::Int16QTy, ElemKind::Int32ITy,
                ElemKind::Int64ITy},
               {}, {CmpEQNode::ResultIdx}) &&
           (NI.getOutElemTy(CmpEQNode::ResultIdx) == ElemKind::BoolTy);

  case Kinded::Kind::IsNaNNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
               {}, {IsNaNNode::ResultIdx}) &&
           (NI.getOutElemTy(IsNaNNode::ResultIdx) == ElemKind::BoolTy);

  case Kinded::Kind::BitwiseAndNodeKind:
  case Kinded::Kind::BitwiseOrNodeKind:
  case Kinded::Kind::BitwiseXorNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::BoolTy, ElemKind::Int32ITy, ElemKind::Int64ITy});

  case Kinded::Kind::BitwiseNotNodeKind:
  case Kinded::Kind::ModuloNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Int32ITy, ElemKind::Int64ITy});

  case Kinded::Kind::ConvolutionNodeKind:
    if (!NI.getInTy(ConvolutionNode::InputIdx)->isQuantizedType()) {
      return NI.allInputsAndOutputsHaveSameElemKind(
          {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty});
    }
    return (NI.allInputsAndOutputsHaveSameElemKind(
                {ElemKind::Int8QTy}, {ConvolutionNode::BiasIdx}) &&
            (NI.getInElemTy(ConvolutionNode::BiasIdx) == ElemKind::Int8QTy ||
             NI.getInElemTy(ConvolutionNode::BiasIdx) == ElemKind::Int32QTy)) ||
           (NI.allInputsAndOutputsHaveSameElemKind(
                {ElemKind::Int16QTy}, {ConvolutionNode::BiasIdx}) &&
            (NI.getInElemTy(ConvolutionNode::BiasIdx) == ElemKind::Int16QTy ||
             NI.getInElemTy(ConvolutionNode::BiasIdx) == ElemKind::Int32QTy));

  case Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind:
    return (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::InputIdx) ==
            ElemKind::Int8QTy) &&
           (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::FilterIdx) ==
            ElemKind::Int8QTy) &&
           ((NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::BiasIdx) ==
             ElemKind::Int8QTy) ||
            (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::BiasIdx) ==
             ElemKind::Int32QTy)) &&
           (NI.getInElemTy(
                ChannelwiseQuantizedConvolutionNode::FilterScalesIdx) ==
            ElemKind::FloatTy) &&
           (NI.getInElemTy(
                ChannelwiseQuantizedConvolutionNode::FilterOffsetsIdx) ==
            ElemKind::Int32ITy) &&
           (NI.getInElemTy(
                ChannelwiseQuantizedConvolutionNode::BiasScalesIdx) ==
            ElemKind::FloatTy) &&
           (NI.getInElemTy(
                ChannelwiseQuantizedConvolutionNode::BiasOffsetsIdx) ==
            ElemKind::Int32ITy) &&
           (NI.getOutElemTy(ChannelwiseQuantizedConvolutionNode::ResultIdx) ==
            ElemKind::Int8QTy);

  case Kinded::Kind::Convolution3DNodeKind:
    if (!NI.getInTy(Convolution3DNode::InputIdx)->isQuantizedType()) {
      return NI.allInputsAndOutputsHaveSameElemKind(
          {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty});
    }
    return (NI.allInputsAndOutputsHaveSameElemKind(
                {ElemKind::Int8QTy}, {Convolution3DNode::BiasIdx}) &&
            (NI.getInElemTy(Convolution3DNode::BiasIdx) == ElemKind::Int8QTy ||
             NI.getInElemTy(Convolution3DNode::BiasIdx) ==
                 ElemKind::Int32QTy)) ||
           (NI.allInputsAndOutputsHaveSameElemKind(
                {ElemKind::Int16QTy}, {Convolution3DNode::BiasIdx}) &&
            (NI.getInElemTy(Convolution3DNode::BiasIdx) == ElemKind::Int16QTy ||
             NI.getInElemTy(Convolution3DNode::BiasIdx) == ElemKind::Int32QTy));

  case Kinded::Kind::ConvTransposeNodeKind:
    // TODO - support other types.
    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});

  case Kinded::Kind::BatchedAddNodeKind:
    if (!NI.getInTy(BatchedAddNode::BatchIdx)->isQuantizedType()) {
      return NI.allInputsAndOutputsHaveSameElemKind(
          {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty});
    }
    return (NI.allInputsAndOutputsHaveSameElemKind(
                {ElemKind::Int8QTy}, {BatchedAddNode::SliceIdx}) &&
            (NI.getInElemTy(BatchedAddNode::SliceIdx) == ElemKind::Int8QTy ||
             NI.getInElemTy(BatchedAddNode::SliceIdx) == ElemKind::Int32QTy)) ||
           (NI.allInputsAndOutputsHaveSameElemKind(
                {ElemKind::Int16QTy}, {BatchedAddNode::SliceIdx}) &&
            (NI.getInElemTy(BatchedAddNode::SliceIdx) == ElemKind::Int16QTy ||
             NI.getInElemTy(BatchedAddNode::SliceIdx) == ElemKind::Int32QTy));

  case Kinded::Kind::RowwiseQuantizedFullyConnectedNodeKind:
    return (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::InputIdx) ==
            ElemKind::Int8QTy) &&
           (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::WeightsIdx) ==
            ElemKind::Int8QTy) &&
           (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::ScalesIdx) ==
            ElemKind::FloatTy) &&
           (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::OffsetsIdx) ==
            ElemKind::Int32ITy) &&
           (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::BiasIdx) ==
                ElemKind::Int8QTy ||
            NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::BiasIdx) ==
                ElemKind::Int32QTy) &&
           (NI.getOutElemTy(RowwiseQuantizedFullyConnectedNode::ResultIdx) ==
            ElemKind::Int8QTy);

  case Kinded::Kind::SparseLengthsSumNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
                ElemKind::Int8QTy},
               {SparseLengthsSumNode::IndicesIdx,
                SparseLengthsSumNode::LengthsIdx}) &&
           (NI.getInElemTy(SparseLengthsSumNode::IndicesIdx) ==
                ElemKind::Int64ITy ||
            NI.getInElemTy(SparseLengthsSumNode::IndicesIdx) ==
                ElemKind::Int32ITy) &&
           (NI.getInElemTy(SparseLengthsSumNode::LengthsIdx) ==
            ElemKind::Int32ITy);

  case Kinded::Kind::SparseLengthsWeightedSumNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
                ElemKind::Int8QTy},
               {SparseLengthsWeightedSumNode::IndicesIdx,
                SparseLengthsWeightedSumNode::LengthsIdx}) &&
           (NI.getInElemTy(SparseLengthsWeightedSumNode::IndicesIdx) ==
                ElemKind::Int64ITy ||
            NI.getInElemTy(SparseLengthsWeightedSumNode::IndicesIdx) ==
                ElemKind::Int32ITy) &&
           (NI.getInElemTy(SparseLengthsWeightedSumNode::LengthsIdx) ==
            ElemKind::Int32ITy);

  case Kinded::Kind::EmbeddingNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
               {EmbeddingNode::IndicesIdx}) &&
           (NI.getInElemTy(EmbeddingNode::IndicesIdx) == ElemKind::Int64ITy ||
            NI.getInElemTy(EmbeddingNode::IndicesIdx) == ElemKind::Int32ITy);

  case Kinded::Kind::EmbeddingBagNodeKind:
    return (NI.allInputsAndOutputsHaveSameElemKind(
                {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
                {EmbeddingBagNode::IndicesIdx, EmbeddingBagNode::OffsetsIdx}) &&
            (((NI.getInElemTy(EmbeddingBagNode::IndicesIdx) ==
               ElemKind::Int64ITy) &&
              (NI.getInElemTy(EmbeddingBagNode::OffsetsIdx) ==
               ElemKind::Int64ITy)) ||
             ((NI.getInElemTy(EmbeddingBagNode::IndicesIdx) ==
               ElemKind::Int32ITy) &&
              (NI.getInElemTy(EmbeddingBagNode::OffsetsIdx) ==
               ElemKind::Int32ITy))));

  case Kinded::Kind::SparseLengthsWeightedSumGradNodeKind:
    // GradOfInputNamedIndicesIdx and GradOfInputNamedLengthsIdx do not need to
    // be checked because they are not used.
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy},
               {SparseLengthsWeightedSumGradNode::IndicesIdx,
                SparseLengthsWeightedSumGradNode::LengthsIdx},
               {SparseLengthsWeightedSumGradNode::GradOfInputNamedIndicesIdx,
                SparseLengthsWeightedSumGradNode::
                    GradOfInputNamedLengthsIdx}) &&
           (NI.getInElemTy(SparseLengthsWeightedSumGradNode::IndicesIdx) ==
            ElemKind::Int64ITy) &&
           (NI.getInElemTy(SparseLengthsWeightedSumGradNode::LengthsIdx) ==
            ElemKind::Int32ITy);

  case Kinded::Kind::RowwiseQuantizedSparseLengthsWeightedSumNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
               {RowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx,
                RowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx,
                RowwiseQuantizedSparseLengthsWeightedSumNode::LengthsIdx}) &&
           (NI.getInElemTy(
                RowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx) ==
            ElemKind::UInt8QTy) &&
           (NI.getInElemTy(
                RowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) ==
                ElemKind::Int64ITy ||
            NI.getInElemTy(
                RowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) ==
                ElemKind::Int32ITy) &&
           (NI.getInElemTy(
                RowwiseQuantizedSparseLengthsWeightedSumNode::LengthsIdx) ==
            ElemKind::Int32ITy);

  case Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind: {
    if (!((NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::IndicesIdx) ==
               ElemKind::Int32ITy &&
           NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::OffsetsIdx) ==
               ElemKind::Int32ITy) ||
          (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::IndicesIdx) ==
               ElemKind::Int64ITy &&
           NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::OffsetsIdx) ==
               ElemKind::Int64ITy))) {
      return false;
    }

    switch (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::DataIdx)) {
    case ElemKind::UInt4FusedFP16QTy:
    case ElemKind::UInt8FusedFP16QTy:
      return (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::WeightsIdx) ==
              ElemKind::Float16Ty) &&
             (NI.getOutElemTy(EmbeddingBagByteRowwiseOffsetsNode::ResultIdx) ==
              ElemKind::Float16Ty);
    case ElemKind::UInt8FusedQTy:
      return (
          (((NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
                                WeightsIdx) == ElemKind::FloatTy) &&
            (NI.getOutElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
                                 ResultIdx) == ElemKind::FloatTy))) ||
          ((NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
                               WeightsIdx) == ElemKind::Float16Ty) &&
           (NI.getOutElemTy(
                FusedRowwiseQuantizedSparseLengthsWeightedSumNode::ResultIdx) ==
            ElemKind::Float16Ty)));
    default:
      return false;
    }
  }

  case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind: {
    if ((NI.getInElemTy(
             FusedRowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) !=
             ElemKind::Int64ITy &&
         NI.getInElemTy(
             FusedRowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) !=
             ElemKind::Int32ITy) ||
        NI.getInElemTy(
            FusedRowwiseQuantizedSparseLengthsWeightedSumNode::LengthsIdx) !=
            ElemKind::Int32ITy) {
      return false;
    }

    switch (NI.getInElemTy(
        FusedRowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx)) {
    case ElemKind::UInt4FusedFP16QTy:
    case ElemKind::UInt8FusedFP16QTy:
      return (NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
                                 WeightsIdx) == ElemKind::Float16Ty ||
              NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
                                 WeightsIdx) == ElemKind::FloatTy) &&
             (NI.getOutElemTy(
                  FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
                      ResultIdx) == ElemKind::Float16Ty);
    case ElemKind::UInt4FusedQTy:
    case ElemKind::UInt8FusedQTy:
      if ((NI.getInElemTy(
               FusedRowwiseQuantizedSparseLengthsWeightedSumNode::WeightsIdx) ==
           ElemKind::FloatTy) &&
          (NI.getOutElemTy(
               FusedRowwiseQuantizedSparseLengthsWeightedSumNode::ResultIdx) ==
           ElemKind::FloatTy)) {
        return true;
      }
      return (
          (NI.getInElemTy(
               FusedRowwiseQuantizedSparseLengthsWeightedSumNode::WeightsIdx) ==
           ElemKind::Float16Ty) &&
          (NI.getOutElemTy(
               FusedRowwiseQuantizedSparseLengthsWeightedSumNode::ResultIdx) ==
           ElemKind::Float16Ty));
    default:
      return false;
    }
  }

  case Kinded::Kind::LengthsRangeFillNodeKind:
  case Kinded::Kind::LengthsToRangesNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int32ITy});

  case Kinded::Kind::GatherNodeKind:
    // Note: Data and Result can be any data type, but must match.
    return (NI.getInElemTy(GatherNode::DataIdx) ==
            NI.getOutElemTy(GatherNode::ResultIdx)) &&
           ((NI.getInElemTy(GatherNode::IndicesIdx) == ElemKind::Int32ITy) ||
            (NI.getInElemTy(GatherNode::IndicesIdx) == ElemKind::Int64ITy));

  case Kinded::Kind::GatherNDNodeKind:
    // Note: Data and Result can be any data type, but must match.
    return ((NI.getInElemTy(GatherNDNode::IndicesIdx) == ElemKind::Int32ITy) ||
            (NI.getInElemTy(GatherNDNode::IndicesIdx) == ElemKind::Int64ITy));

  case Kinded::Kind::GatherElementsNodeKind:
    // Note: Data and Result can be any data type, but must match.
    return (NI.getInElemTy(GatherNode::DataIdx) ==
            NI.getOutElemTy(GatherNode::ResultIdx)) &&
           ((NI.getInElemTy(GatherNode::IndicesIdx) == ElemKind::Int32ITy) ||
            (NI.getInElemTy(GatherNode::IndicesIdx) == ElemKind::Int64ITy));

  case Kinded::Kind::BatchOneHotNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
                ElemKind::Int8QTy, ElemKind::Int32ITy, ElemKind::Int64ITy},
               {BatchOneHotNode::LengthsIdx}) &&
           (NI.getInElemTy(BatchOneHotNode::LengthsIdx) == ElemKind::Int32ITy);

  case Kinded::Kind::QuantizationProfileNodeKind:
  case Kinded::Kind::AvgPoolGradNodeKind:
  case Kinded::Kind::AdaptiveAvgPoolGradNodeKind:
  case Kinded::Kind::LocalResponseNormalizationGradNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});

  case Kinded::Kind::MaxPoolGradNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy},
               {MaxPoolGradNode::OriginalOutputForArgmaxIdx,
                MaxPoolGradNode::GradOfOriginalOutputNamedArgmaxIdx}) &&
           (NI.getInElemTy(MaxPoolGradNode::OriginalOutputForArgmaxIdx) ==
            ElemKind::Int64ITy) &&
           (NI.getInElemTy(
                MaxPoolGradNode::GradOfOriginalOutputNamedArgmaxIdx) ==
            ElemKind::Int64ITy);

  case Kinded::Kind::QuantizeNodeKind:
    return ((NI.getInElemTy(QuantizeNode::InputIdx) == ElemKind::FloatTy) ||
            (NI.getInElemTy(QuantizeNode::InputIdx) == ElemKind::Float16Ty) ||
            (NI.getInElemTy(QuantizeNode::InputIdx) == ElemKind::BFloat16Ty)) &&
           ((NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::Int8QTy) ||
            (NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::UInt8QTy) ||
            (NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::Int16QTy) ||
            (NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::Int32QTy));

  case Kinded::Kind::DequantizeNodeKind:
    return ((NI.getInElemTy(DequantizeNode::InputIdx) == ElemKind::Int8QTy) ||
            (NI.getInElemTy(DequantizeNode::InputIdx) == ElemKind::UInt8QTy) ||
            (NI.getInElemTy(DequantizeNode::InputIdx) == ElemKind::Int16QTy) ||
            (NI.getInElemTy(DequantizeNode::InputIdx) == ElemKind::Int32QTy) ||
            (NI.getInElemTy(DequantizeNode::InputIdx) ==
             ElemKind::UInt8FusedQTy)) &&
           ((NI.getOutElemTy(DequantizeNode::ResultIdx) == ElemKind::FloatTy) ||
            (NI.getOutElemTy(DequantizeNode::ResultIdx) ==
             ElemKind::Float16Ty) ||
            (NI.getOutElemTy(DequantizeNode::ResultIdx) ==
             ElemKind::BFloat16Ty));

  case Kinded::Kind::RescaleQuantizedNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Int8QTy, ElemKind::Int16QTy, ElemKind::Int32QTy});

  case Kinded::Kind::IntLookupTableNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Int8QTy, ElemKind::Int16QTy});

  case Kinded::Kind::ConvertToNodeKind: {
    auto isConversionSupportedFor = [](ElemKind kind) {
      switch (kind) {
      case ElemKind::Float16Ty:
      case ElemKind::BFloat16Ty:
      case ElemKind::FloatTy:
      case ElemKind::Int32ITy:
      case ElemKind::Int64ITy:
      case ElemKind::BoolTy:
        return true;
      default:
        return false;
      }
    };
    return (isConversionSupportedFor(NI.getInElemTy(ConvertToNode::InputIdx)) &&
            isConversionSupportedFor(
                NI.getOutElemTy(ConvertToNode::ResultIdx))) ||
           (NI.getInElemTy(ConvertToNode::InputIdx) ==
                ElemKind::UInt8FusedQTy &&
            NI.getOutElemTy(ConvertToNode::ResultIdx) ==
                ElemKind::UInt8FusedFP16QTy);
  }

  case Kinded::Kind::TopKNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
                ElemKind::Int8QTy},
               {}, {TopKNode::IndicesIdx}) &&
           ((NI.getOutElemTy(TopKNode::IndicesIdx) == ElemKind::Int64ITy) ||
            (NI.getOutElemTy(TopKNode::IndicesIdx) == ElemKind::Int32ITy));

  case Kinded::Kind::ScatterDataNodeKind:
    return ((NI.getInElemTy(ScatterDataNode::IndicesIdx) ==
             ElemKind::Int32ITy) ||
            (NI.getInElemTy(ScatterDataNode::IndicesIdx) ==
             ElemKind::Int64ITy)) &&
           (NI.getOutElemTy(ScatterDataNode::ResultIdx) ==
            NI.getInElemTy(ScatterDataNode::DataIdx)) &&
           (NI.getOutElemTy(ScatterDataNode::ResultIdx) ==
            NI.getInElemTy(ScatterDataNode::SlicesIdx));

  // We just clip 64 to 32 SelectedIdx silently with the SoftMax
  // SelectedIdx in case dim_t is 32b.
  case Kinded::Kind::SoftMaxNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
               {SoftMaxNode::SelectedIdx}) &&
           (NI.getInElemTy(SoftMaxNode::SelectedIdx) == ElemKind::Int32ITy ||
            NI.getInElemTy(SoftMaxNode::SelectedIdx) == ElemKind::Int64ITy);

  case Kinded::Kind::LogSoftMaxNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
               {LogSoftMaxNode::SelectedIdx}) &&
           (NI.getInElemTy(LogSoftMaxNode::SelectedIdx) == ElemKind::Int32ITy ||
            NI.getInElemTy(LogSoftMaxNode::SelectedIdx) == ElemKind::Int64ITy);

  case Kinded::Kind::GatherRangesNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::Int32ITy, ElemKind::Int64ITy},
               {GatherRangesNode::DataIdx}, {GatherRangesNode::OutputIdx}) &&
           (NI.getOutElemTy(GatherRangesNode::OutputIdx) ==
            NI.getInElemTy(GatherRangesNode::DataIdx));

  case Kinded::Kind::CrossEntropyLossNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
               {CrossEntropyLossNode::LabelsIdx}) &&
           (NI.getInElemTy(CrossEntropyLossNode::LabelsIdx) ==
            ElemKind::Int64ITy);

  case Kinded::Kind::CumSumNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
         ElemKind::Int32ITy, ElemKind::Int64ITy});

  case Kinded::Kind::LengthsSumNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
               {LengthsSumNode::LengthsIdx}) &&
           (NI.getInElemTy(LengthsSumNode::LengthsIdx) == ElemKind::Int32ITy);

  case Kinded::Kind::BatchSparseToDenseNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
               {BatchSparseToDenseNode::LengthsIdx,
                BatchSparseToDenseNode::IndicesIdx}) &&
           (NI.getInElemTy(BatchSparseToDenseNode::LengthsIdx) ==
                ElemKind::Int64ITy ||
            NI.getInElemTy(BatchSparseToDenseNode::LengthsIdx) ==
                ElemKind::Int32ITy) &&
           (NI.getInElemTy(BatchSparseToDenseNode::IndicesIdx) ==
                ElemKind::Int64ITy ||
            NI.getInElemTy(BatchSparseToDenseNode::IndicesIdx) ==
                ElemKind::Int32ITy);

  case Kinded::Kind::FillExamplesWithIndicatorNodeKind:
    return (NI.getInElemTy(FillExamplesWithIndicatorNode::DataIdx) ==
            NI.getOutElemTy(FillExamplesWithIndicatorNode::ResultIdx)) &&
           ((NI.getInElemTy(FillExamplesWithIndicatorNode::IndicatorIdx) ==
             ElemKind::Int32ITy) ||
            (NI.getInElemTy(FillExamplesWithIndicatorNode::IndicatorIdx) ==
             ElemKind::Int64ITy) ||
            (NI.getInElemTy(FillExamplesWithIndicatorNode::IndicatorIdx) ==
             ElemKind::BoolTy));

  case Kinded::Kind::SparseToDenseMaskNodeKind:
    return (NI.getInElemTy(SparseToDenseMaskNode::IndicesIdx) ==
            ElemKind::Int64ITy) &&
           (NI.getInElemTy(SparseToDenseMaskNode::LengthsIdx) ==
            ElemKind::Int32ITy) &&
           (NI.getInElemTy(SparseToDenseMaskNode::ValuesIdx) ==
            NI.getInElemTy(SparseToDenseMaskNode::DefaultValueIdx)) &&
           (NI.getInElemTy(SparseToDenseMaskNode::ValuesIdx) ==
            NI.getOutElemTy(SparseToDenseMaskNode::ResultIdx));

  case Kinded::Kind::SparseLabelSplitNodeKind:
    return (NI.getInElemTy(SparseLabelSplitNode::LengthsIdx) ==
            ElemKind::Int32ITy) &&
           (NI.getInElemTy(SparseLabelSplitNode::IndicesIdx) ==
            ElemKind::Int64ITy) &&
           (NI.getInElemTy(SparseLabelSplitNode::ValuesIdx) ==
            NI.getOutElemTy(SparseLabelSplitNode::LabelValuesIdx)) &&
           (NI.getOutElemTy(SparseLabelSplitNode::ExampleIdsIdx) ==
            ElemKind::Int32ITy) &&
           (NI.getOutElemTy(SparseLabelSplitNode::GradientOffsetMapIdx) ==
            ElemKind::Int32ITy);

  case Kinded::Kind::TraceEventNodeKind:
    return NI.getInElemTy(TraceEventNode::DataIdx) == ElemKind::Int64ITy;

  case Kinded::Kind::TransposeNodeKind:
  case Kinded::Kind::ReshapeNodeKind:
  case Kinded::Kind::SaveNodeKind:
  case Kinded::Kind::FlipNodeKind:
    // These work regardless of the underlying type.
    return true;

  case Kinded::Kind::GaussianFillNodeKind:
    return NI.getOutElemTy(GaussianFillNode::ResultIdx) == ElemKind::Float16Ty;

  case Kinded::Kind::NonMaxSuppressionNodeKind:
    return NI.getInElemTy(NonMaxSuppressionNode::BoxesIdx) ==
               ElemKind::FloatTy &&
           NI.getInElemTy(NonMaxSuppressionNode::ScoresIdx) ==
               ElemKind::FloatTy &&
           (NI.getOutElemTy(NonMaxSuppressionNode::IndicesIdx) ==
                ElemKind::Int32ITy ||
            NI.getOutElemTy(NonMaxSuppressionNode::IndicesIdx) ==
                ElemKind::Int64ITy) &&
           (NI.getOutElemTy(
                NonMaxSuppressionNode::NumberOfSelectedIndicesIdx) ==
                ElemKind::Int32ITy ||
            NI.getOutElemTy(
                NonMaxSuppressionNode::NumberOfSelectedIndicesIdx) ==
                ElemKind::Int64ITy) &&
           (NI.getOutElemTy(
                NonMaxSuppressionNode::NumberOfSelectedIndicesIdx) ==
            NI.getOutElemTy(NonMaxSuppressionNode::IndicesIdx));

  case Kinded::Kind::TFLiteDetectionPostProcessNodeKind:
    return NI.getInElemTy(TFLiteDetectionPostProcessNode::BoxesIdx) ==
               ElemKind::FloatTy &&
           NI.getInElemTy(TFLiteDetectionPostProcessNode::ScoresIdx) ==
               ElemKind::FloatTy &&
           NI.getInElemTy(TFLiteDetectionPostProcessNode::AnchorsIdx) ==
               ElemKind::FloatTy &&
           NI.getOutElemTy(TFLiteDetectionPostProcessNode::DetectionBoxesIdx) ==
               ElemKind::FloatTy &&
           NI.getOutElemTy(
               TFLiteDetectionPostProcessNode::DetectionClassesIdx) ==
               ElemKind::Int32ITy &&
           NI.getOutElemTy(
               TFLiteDetectionPostProcessNode::DetectionScoresIdx) ==
               ElemKind::FloatTy &&
           NI.getOutElemTy(TFLiteDetectionPostProcessNode::NumDetectionsIdx) ==
               ElemKind::Int32ITy;

  case Kinded::Kind::AudioSpectrogramNodeKind:
    return NI.getInElemTy(AudioSpectrogramNode::InputIdx) ==
               ElemKind::FloatTy &&
           NI.getOutElemTy(AudioSpectrogramNode::SpectrogramIdx) ==
               ElemKind::FloatTy;

  case Kinded::Kind::MFCCNodeKind:
    return NI.getInElemTy(MFCCNode::SpectrogramIdx) == ElemKind::FloatTy &&
           NI.getOutElemTy(MFCCNode::CoefficientsIdx) == ElemKind::FloatTy;

  case Kinded::Kind::ROIAlignNodeKind:
    return (NI.getInElemTy(ROIAlignNode::BatchIndicesIdx) ==
                ElemKind::Int32ITy ||
            NI.getInElemTy(ROIAlignNode::BatchIndicesIdx) ==
                ElemKind::Int64ITy) &&
           NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty},
               /*ignoreIn*/ {ROIAlignNode::BatchIndicesIdx});

  case Kinded::Kind::CollectRpnProposalsNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty});

  case Kinded::Kind::BBoxTransformNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty});

  case Kinded::Kind::SoftMaxGradNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy}, {SoftMaxGradNode::SelectedIdx},
               {SoftMaxGradNode::GradOfInputNamedSelectedIdx}) &&
           (NI.getInElemTy(SoftMaxGradNode::SelectedIdx) == ElemKind::Int64ITy);

  case Kinded::Kind::ConvolutionGradNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy}, {},
        {ConvolutionGradNode::GradOfInputNamedInputIdx});

  case Kinded::Kind::CrossEntropyLossGradNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy}, {CrossEntropyLossGradNode::LabelsIdx},
               {CrossEntropyLossGradNode::GradOfInputNamedLabelsIdx}) &&
           (NI.getInElemTy(CrossEntropyLossGradNode::LabelsIdx) ==
            ElemKind::Int64ITy) &&
           (NI.getOutElemTy(
                CrossEntropyLossGradNode::GradOfInputNamedLabelsIdx) ==
            ElemKind::Int64ITy);

  case Kinded::Kind::BatchedPairwiseDotProductNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});

  case Kinded::Kind::BatchedPairwiseDotProductGradNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});

  case Kinded::Kind::BucketizeNodeKind:
    return NI.getInElemTy(BucketizeNode::InputIdx) == ElemKind::FloatTy &&
           NI.getOutElemTy(BucketizeNode::ResultIdx) == ElemKind::Int32ITy;

  case Kinded::Kind::BatchedUnaryEmbeddingsBagsNodeKind:
    return (
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
            {BatchedUnaryEmbeddingsBagsNode::TableOffsetsIdx,
             BatchedUnaryEmbeddingsBagsNode::IndicesIdx,
             BatchedUnaryEmbeddingsBagsNode::OffsetsIdx}) &&
        (((NI.getInElemTy(BatchedUnaryEmbeddingsBagsNode::TableOffsetsIdx) ==
           ElemKind::Int64ITy) &&
          (NI.getInElemTy(BatchedUnaryEmbeddingsBagsNode::IndicesIdx) ==
           ElemKind::Int64ITy) &&
          (NI.getInElemTy(BatchedUnaryEmbeddingsBagsNode::OffsetsIdx) ==
           ElemKind::Int64ITy)) ||
         ((NI.getInElemTy(BatchedUnaryEmbeddingsBagsNode::TableOffsetsIdx) ==
           ElemKind::Int32ITy) &&
          (NI.getInElemTy(BatchedUnaryEmbeddingsBagsNode::IndicesIdx) ==
           ElemKind::Int32ITy) &&
          (NI.getInElemTy(BatchedUnaryEmbeddingsBagsNode::OffsetsIdx) ==
           ElemKind::Int32ITy))));

  case Kinded::Kind::IntNBitSplitEmbeddingBagsNodeKind:
    return (((NI.getInElemTy(IntNBitSplitEmbeddingBagsNode::IndicesIdx) ==
              ElemKind::Int64ITy) &&
             (NI.getInElemTy(IntNBitSplitEmbeddingBagsNode::OffsetsIdx) ==
              ElemKind::Int64ITy)) ||
            ((NI.getInElemTy(IntNBitSplitEmbeddingBagsNode::IndicesIdx) ==
              ElemKind::Int32ITy) &&
             (NI.getInElemTy(IntNBitSplitEmbeddingBagsNode::OffsetsIdx) ==
              ElemKind::Int32ITy))) &&
           NI.getInElemTy(IntNBitSplitEmbeddingBagsNode::DevWeightsIdx) ==
               ElemKind::UInt8ITy &&
           NI.getInElemTy(IntNBitSplitEmbeddingBagsNode::UvmWeightsIdx) ==
               ElemKind::UInt8ITy &&
           NI.getInElemTy(IntNBitSplitEmbeddingBagsNode::WeightsTysIdx) ==
               ElemKind::UInt8ITy &&
           NI.getInElemTy(
               IntNBitSplitEmbeddingBagsNode::WeightsPlacementsIdx) ==
               ElemKind::Int32ITy &&
           NI.getInElemTy(IntNBitSplitEmbeddingBagsNode::WeightsOffsetsIdx) ==
               ElemKind::Int32ITy &&
           NI.getInElemTy(IntNBitSplitEmbeddingBagsNode::DimOffsetsIdx) ==
               ElemKind::Int32ITy;

  case Kinded::Kind::IntNBitSplitEmbeddingWeightedBagsNodeKind:
    return (((NI.getInElemTy(
                  IntNBitSplitEmbeddingWeightedBagsNode::IndicesIdx) ==
              ElemKind::Int64ITy) &&
             (NI.getInElemTy(
                  IntNBitSplitEmbeddingWeightedBagsNode::OffsetsIdx) ==
              ElemKind::Int64ITy)) ||
            ((NI.getInElemTy(
                  IntNBitSplitEmbeddingWeightedBagsNode::IndicesIdx) ==
              ElemKind::Int32ITy) &&
             (NI.getInElemTy(
                  IntNBitSplitEmbeddingWeightedBagsNode::OffsetsIdx) ==
              ElemKind::Int32ITy))) &&
           NI.getInElemTy(
               IntNBitSplitEmbeddingWeightedBagsNode::DevWeightsIdx) ==
               ElemKind::UInt8ITy &&
           NI.getInElemTy(
               IntNBitSplitEmbeddingWeightedBagsNode::UvmWeightsIdx) ==
               ElemKind::UInt8ITy &&
           NI.getInElemTy(
               IntNBitSplitEmbeddingWeightedBagsNode::WeightsTysIdx) ==
               ElemKind::UInt8ITy &&
           NI.getInElemTy(
               IntNBitSplitEmbeddingWeightedBagsNode::WeightsPlacementsIdx) ==
               ElemKind::Int32ITy &&
           NI.getInElemTy(
               IntNBitSplitEmbeddingWeightedBagsNode::WeightsOffsetsIdx) ==
               ElemKind::Int32ITy &&
           NI.getInElemTy(
               IntNBitSplitEmbeddingWeightedBagsNode::DimOffsetsIdx) ==
               ElemKind::Int32ITy &&
           NI.getInElemTy(
               IntNBitSplitEmbeddingWeightedBagsNode::IndiceWeightIdx) ==
               ElemKind::FloatTy;

  default:
    return false;
  }
}