in lib/Backends/Interpreter/Interpreter.cpp [79:966]
/// \returns true if the node described by \p NI has a kind listed in the
/// switch below AND its input/output element types satisfy the constraints
/// spelled out for that kind. Any kind that is not listed is reported as
/// unsupported (the default case returns false).
///
/// Most cases rely on NodeInfo::allInputsAndOutputsHaveSameElemKind(). As used
/// here, its first argument is the set of accepted element kinds and its second
/// argument lists input indices excluded from the check (see the /*ignoreIn*/
/// annotation in the ROIAlign case).
/// NOTE(review): the third argument appears to be the matching list of output
/// indices to exclude -- confirm against the NodeInfo declaration.
/// Operands excluded that way are then checked individually with
/// getInElemTy()/getOutElemTy().
bool Interpreter::isOpSupported(const NodeInfo &NI) const {
  switch (NI.getKind()) {
  case Kinded::Kind::BatchedReduceProdNodeKind:
  case Kinded::Kind::FmodNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
         ElemKind::Int32ITy, ElemKind::Int64ITy});

  case Kinded::Kind::AddNodeKind:
  case Kinded::Kind::SubNodeKind:
  case Kinded::Kind::MulNodeKind:
  case Kinded::Kind::DivNodeKind:
  case Kinded::Kind::MaxNodeKind:
  case Kinded::Kind::MinNodeKind:
  case Kinded::Kind::ClipNodeKind:
  case Kinded::Kind::BatchedReduceMinNodeKind:
  case Kinded::Kind::BatchedReduceMaxNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
         ElemKind::Int8QTy, ElemKind::Int32ITy, ElemKind::Int64ITy});

  case Kinded::Kind::ResizeNearestNodeKind:
  case Kinded::Kind::ResizeBilinearNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
         ElemKind::Int8QTy, ElemKind::Int16QTy, ElemKind::Int32QTy,
         ElemKind::Int32ITy, ElemKind::Int64ITy});

  case Kinded::Kind::BatchNormalizationNodeKind: {
    auto elemType = NI.getInElemTy(BatchNormalizationNode::InputIdx);
    // input can be int8, float16 or float32
    bool isNodePrecisionSupported =
        (elemType == ElemKind::Int8QTy || elemType == ElemKind::FloatTy ||
         elemType == ElemKind::Float16Ty);
    // parameters have to be float16 or float (Input and Result are excluded
    // from this check).
    isNodePrecisionSupported = isNodePrecisionSupported &&
                               NI.allInputsAndOutputsHaveSameElemKind(
                                   {ElemKind::FloatTy, ElemKind::Float16Ty},
                                   {BatchNormalizationNode::InputIdx},
                                   {BatchNormalizationNode::ResultIdx});
    // input and output element types have to match (the parameters are
    // excluded from this check).
    isNodePrecisionSupported =
        isNodePrecisionSupported &&
        NI.allInputsAndOutputsHaveSameElemKind(
            {elemType},
            {BatchNormalizationNode::ScaleIdx, BatchNormalizationNode::BiasIdx,
             BatchNormalizationNode::MeanIdx, BatchNormalizationNode::VarIdx});
    return isNodePrecisionSupported;
  }

  case Kinded::Kind::AvgPoolNodeKind:
  case Kinded::Kind::AdaptiveAvgPoolNodeKind:
  case Kinded::Kind::BatchedReduceAddNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Int32ITy, ElemKind::FloatTy, ElemKind::Float16Ty,
         ElemKind::BFloat16Ty, ElemKind::Int8QTy});

  case Kinded::Kind::DynamicQuantizedFullyConnectedNodeKind:
    return (NI.getInElemTy(DynamicQuantizedFullyConnectedNode::InputIdx) ==
                ElemKind::Float16Ty ||
            NI.getInElemTy(DynamicQuantizedFullyConnectedNode::InputIdx) ==
                ElemKind::FloatTy) &&
           NI.getInElemTy(DynamicQuantizedFullyConnectedNode::WeightsIdx) ==
               ElemKind::Int8QTy &&
           NI.getInElemTy(DynamicQuantizedFullyConnectedNode::BiasIdx) ==
               ElemKind::FloatTy;

  case Kinded::Kind::DynamicRowwiseQuantizedFullyConnectedNodeKind:
    return (NI.getInElemTy(
                DynamicRowwiseQuantizedFullyConnectedNode::InputIdx) ==
                ElemKind::Float16Ty ||
            NI.getInElemTy(
                DynamicRowwiseQuantizedFullyConnectedNode::InputIdx) ==
                ElemKind::FloatTy) &&
           NI.getInElemTy(
               DynamicRowwiseQuantizedFullyConnectedNode::WeightsIdx) ==
               ElemKind::Int8QTy &&
           NI.getInElemTy(DynamicRowwiseQuantizedFullyConnectedNode::BiasIdx) ==
               ElemKind::FloatTy &&
           NI.getInElemTy(
               DynamicRowwiseQuantizedFullyConnectedNode::ScalesIdx) ==
               ElemKind::FloatTy &&
           NI.getInElemTy(
               DynamicRowwiseQuantizedFullyConnectedNode::OffsetsIdx) ==
               ElemKind::Int32ITy;

  case Kinded::Kind::MatMulNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
         ElemKind::Int8QTy, ElemKind::Int16QTy});

  case Kinded::Kind::BatchMatMulNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty});

  case Kinded::Kind::FullyConnectedNodeKind:
    if (!NI.getInTy(FullyConnectedNode::InputIdx)->isQuantizedType()) {
      return NI.allInputsAndOutputsHaveSameElemKind(
          {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty});
    }
    // Quantized: everything except Bias is Int8Q (or Int16Q); Bias is checked
    // separately against its own list of accepted kinds.
    return (NI.allInputsAndOutputsHaveSameElemKind(
                {ElemKind::Int8QTy}, {FullyConnectedNode::BiasIdx}) &&
            (NI.getInElemTy(FullyConnectedNode::BiasIdx) == ElemKind::Int8QTy ||
             NI.getInElemTy(FullyConnectedNode::BiasIdx) ==
                 ElemKind::Int32QTy ||
             NI.getInElemTy(FullyConnectedNode::BiasIdx) ==
                 ElemKind::FloatTy)) ||
           (NI.allInputsAndOutputsHaveSameElemKind(
                {ElemKind::Int16QTy}, {FullyConnectedNode::BiasIdx}) &&
            (NI.getInElemTy(FullyConnectedNode::BiasIdx) ==
                 ElemKind::Int16QTy ||
             NI.getInElemTy(FullyConnectedNode::BiasIdx) ==
                 ElemKind::Int32QTy));

  case Kinded::Kind::MaxPoolNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
                ElemKind::Int8QTy},
               {}, {MaxPoolNode::ArgmaxIdx}) &&
           (NI.getOutElemTy(MaxPoolNode::ArgmaxIdx) == ElemKind::Int64ITy);

  case Kinded::Kind::ArgMaxNodeKind:
  case Kinded::Kind::ArgMinNodeKind:
    // Both kinds share this body, so ArgMaxNode::ResultIdx is used for ArgMin
    // as well.
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy}, {},
               {ArgMaxNode::ResultIdx}) &&
           (NI.getOutElemTy(ArgMaxNode::ResultIdx) == ElemKind::Int64ITy ||
            NI.getOutElemTy(ArgMaxNode::ResultIdx) == ElemKind::Int32ITy);

  case Kinded::Kind::AcosNodeKind:
  case Kinded::Kind::AsinNodeKind:
  case Kinded::Kind::AtanNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy});

  case Kinded::Kind::PowNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
         ElemKind::Int8QTy});

  case Kinded::Kind::LocalResponseNormalizationNodeKind:
  case Kinded::Kind::LayerNormalizationNodeKind:
  case Kinded::Kind::LogNodeKind:
  case Kinded::Kind::TanhNodeKind:
  case Kinded::Kind::ExpNodeKind:
  case Kinded::Kind::SigmoidNodeKind:
  case Kinded::Kind::SoftPlusNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty});

  case Kinded::Kind::SliceNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
         ElemKind::Int8QTy, ElemKind::Int32QTy, ElemKind::Int64ITy,
         ElemKind::Int32ITy, ElemKind::BoolTy});

  case Kinded::Kind::SpaceToDepthNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
         ElemKind::Int8QTy, ElemKind::Int64ITy});

  case Kinded::Kind::SplatNodeKind:
  case Kinded::Kind::TouchNodeKind:
  case Kinded::Kind::InsertTensorNodeKind:
  case Kinded::Kind::ConcatNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
         ElemKind::Int8QTy, ElemKind::Int16QTy, ElemKind::Int32ITy,
         ElemKind::Int64ITy, ElemKind::BoolTy});

  case Kinded::Kind::NonZeroNodeKind:
    return NI.getInElemTy(NonZeroNode::CondIdx) == ElemKind::BoolTy &&
           NI.getOutElemTy(NonZeroNode::ResultIdx) == ElemKind::Int32ITy;

  case Kinded::Kind::SelectNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
                ElemKind::Int8QTy},
               {SelectNode::CondIdx}) &&
           (NI.getInElemTy(SelectNode::CondIdx) == ElemKind::BoolTy);

  case Kinded::Kind::NotNodeKind:
  case Kinded::Kind::AndNodeKind:
  case Kinded::Kind::OrNodeKind:
  case Kinded::Kind::XorNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::BoolTy});

  case Kinded::Kind::SignNodeKind:
  case Kinded::Kind::CeilNodeKind:
  case Kinded::Kind::RoundNodeKind:
  case Kinded::Kind::SqrtNodeKind:
  case Kinded::Kind::RsqrtNodeKind:
  case Kinded::Kind::ReciprocalNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Int8QTy});

  case Kinded::Kind::SinNodeKind:
  case Kinded::Kind::CosNodeKind:
  case Kinded::Kind::ErfNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy});

  case Kinded::Kind::AbsNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy});

  case Kinded::Kind::NegNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Int32ITy, ElemKind::Int8QTy});

  case Kinded::Kind::FloorNodeKind:
  case Kinded::Kind::TruncateNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy});

  case Kinded::Kind::CmpEQNodeKind:
  case Kinded::Kind::CmpNEQNodeKind:
  case Kinded::Kind::CmpLTNodeKind:
  case Kinded::Kind::CmpLTENodeKind:
    // All comparison kinds share this body, so CmpEQNode::ResultIdx is used
    // for each of them.
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
                ElemKind::Int8QTy, ElemKind::Int16QTy, ElemKind::Int32ITy,
                ElemKind::Int64ITy},
               {}, {CmpEQNode::ResultIdx}) &&
           (NI.getOutElemTy(CmpEQNode::ResultIdx) == ElemKind::BoolTy);

  case Kinded::Kind::IsNaNNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
               {}, {IsNaNNode::ResultIdx}) &&
           (NI.getOutElemTy(IsNaNNode::ResultIdx) == ElemKind::BoolTy);

  case Kinded::Kind::BitwiseAndNodeKind:
  case Kinded::Kind::BitwiseOrNodeKind:
  case Kinded::Kind::BitwiseXorNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::BoolTy, ElemKind::Int32ITy, ElemKind::Int64ITy});

  case Kinded::Kind::BitwiseNotNodeKind:
  case Kinded::Kind::ModuloNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Int32ITy, ElemKind::Int64ITy});

  case Kinded::Kind::ConvolutionNodeKind:
    if (!NI.getInTy(ConvolutionNode::InputIdx)->isQuantizedType()) {
      return NI.allInputsAndOutputsHaveSameElemKind(
          {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty});
    }
    return (NI.allInputsAndOutputsHaveSameElemKind(
                {ElemKind::Int8QTy}, {ConvolutionNode::BiasIdx}) &&
            (NI.getInElemTy(ConvolutionNode::BiasIdx) == ElemKind::Int8QTy ||
             NI.getInElemTy(ConvolutionNode::BiasIdx) == ElemKind::Int32QTy)) ||
           (NI.allInputsAndOutputsHaveSameElemKind(
                {ElemKind::Int16QTy}, {ConvolutionNode::BiasIdx}) &&
            (NI.getInElemTy(ConvolutionNode::BiasIdx) == ElemKind::Int16QTy ||
             NI.getInElemTy(ConvolutionNode::BiasIdx) == ElemKind::Int32QTy));

  case Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind:
    return (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::InputIdx) ==
            ElemKind::Int8QTy) &&
           (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::FilterIdx) ==
            ElemKind::Int8QTy) &&
           ((NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::BiasIdx) ==
             ElemKind::Int8QTy) ||
            (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::BiasIdx) ==
             ElemKind::Int32QTy)) &&
           (NI.getInElemTy(
                ChannelwiseQuantizedConvolutionNode::FilterScalesIdx) ==
            ElemKind::FloatTy) &&
           (NI.getInElemTy(
                ChannelwiseQuantizedConvolutionNode::FilterOffsetsIdx) ==
            ElemKind::Int32ITy) &&
           (NI.getInElemTy(
                ChannelwiseQuantizedConvolutionNode::BiasScalesIdx) ==
            ElemKind::FloatTy) &&
           (NI.getInElemTy(
                ChannelwiseQuantizedConvolutionNode::BiasOffsetsIdx) ==
            ElemKind::Int32ITy) &&
           (NI.getOutElemTy(ChannelwiseQuantizedConvolutionNode::ResultIdx) ==
            ElemKind::Int8QTy);

  case Kinded::Kind::Convolution3DNodeKind:
    if (!NI.getInTy(Convolution3DNode::InputIdx)->isQuantizedType()) {
      return NI.allInputsAndOutputsHaveSameElemKind(
          {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty});
    }
    return (NI.allInputsAndOutputsHaveSameElemKind(
                {ElemKind::Int8QTy}, {Convolution3DNode::BiasIdx}) &&
            (NI.getInElemTy(Convolution3DNode::BiasIdx) == ElemKind::Int8QTy ||
             NI.getInElemTy(Convolution3DNode::BiasIdx) ==
                 ElemKind::Int32QTy)) ||
           (NI.allInputsAndOutputsHaveSameElemKind(
                {ElemKind::Int16QTy}, {Convolution3DNode::BiasIdx}) &&
            (NI.getInElemTy(Convolution3DNode::BiasIdx) == ElemKind::Int16QTy ||
             NI.getInElemTy(Convolution3DNode::BiasIdx) == ElemKind::Int32QTy));

  case Kinded::Kind::ConvTransposeNodeKind:
    // TODO - support other types.
    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});

  case Kinded::Kind::BatchedAddNodeKind:
    if (!NI.getInTy(BatchedAddNode::BatchIdx)->isQuantizedType()) {
      return NI.allInputsAndOutputsHaveSameElemKind(
          {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty});
    }
    return (NI.allInputsAndOutputsHaveSameElemKind(
                {ElemKind::Int8QTy}, {BatchedAddNode::SliceIdx}) &&
            (NI.getInElemTy(BatchedAddNode::SliceIdx) == ElemKind::Int8QTy ||
             NI.getInElemTy(BatchedAddNode::SliceIdx) == ElemKind::Int32QTy)) ||
           (NI.allInputsAndOutputsHaveSameElemKind(
                {ElemKind::Int16QTy}, {BatchedAddNode::SliceIdx}) &&
            (NI.getInElemTy(BatchedAddNode::SliceIdx) == ElemKind::Int16QTy ||
             NI.getInElemTy(BatchedAddNode::SliceIdx) == ElemKind::Int32QTy));

  case Kinded::Kind::RowwiseQuantizedFullyConnectedNodeKind:
    return (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::InputIdx) ==
            ElemKind::Int8QTy) &&
           (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::WeightsIdx) ==
            ElemKind::Int8QTy) &&
           (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::ScalesIdx) ==
            ElemKind::FloatTy) &&
           (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::OffsetsIdx) ==
            ElemKind::Int32ITy) &&
           (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::BiasIdx) ==
                ElemKind::Int8QTy ||
            NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::BiasIdx) ==
                ElemKind::Int32QTy) &&
           (NI.getOutElemTy(RowwiseQuantizedFullyConnectedNode::ResultIdx) ==
            ElemKind::Int8QTy);

  case Kinded::Kind::SparseLengthsSumNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
                ElemKind::Int8QTy},
               {SparseLengthsSumNode::IndicesIdx,
                SparseLengthsSumNode::LengthsIdx}) &&
           (NI.getInElemTy(SparseLengthsSumNode::IndicesIdx) ==
                ElemKind::Int64ITy ||
            NI.getInElemTy(SparseLengthsSumNode::IndicesIdx) ==
                ElemKind::Int32ITy) &&
           (NI.getInElemTy(SparseLengthsSumNode::LengthsIdx) ==
            ElemKind::Int32ITy);

  case Kinded::Kind::SparseLengthsWeightedSumNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
                ElemKind::Int8QTy},
               {SparseLengthsWeightedSumNode::IndicesIdx,
                SparseLengthsWeightedSumNode::LengthsIdx}) &&
           (NI.getInElemTy(SparseLengthsWeightedSumNode::IndicesIdx) ==
                ElemKind::Int64ITy ||
            NI.getInElemTy(SparseLengthsWeightedSumNode::IndicesIdx) ==
                ElemKind::Int32ITy) &&
           (NI.getInElemTy(SparseLengthsWeightedSumNode::LengthsIdx) ==
            ElemKind::Int32ITy);

  case Kinded::Kind::EmbeddingNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
               {EmbeddingNode::IndicesIdx}) &&
           (NI.getInElemTy(EmbeddingNode::IndicesIdx) == ElemKind::Int64ITy ||
            NI.getInElemTy(EmbeddingNode::IndicesIdx) == ElemKind::Int32ITy);

  case Kinded::Kind::EmbeddingBagNodeKind:
    // Indices and Offsets must share the same integer width.
    return (NI.allInputsAndOutputsHaveSameElemKind(
                {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
                {EmbeddingBagNode::IndicesIdx, EmbeddingBagNode::OffsetsIdx}) &&
            (((NI.getInElemTy(EmbeddingBagNode::IndicesIdx) ==
               ElemKind::Int64ITy) &&
              (NI.getInElemTy(EmbeddingBagNode::OffsetsIdx) ==
               ElemKind::Int64ITy)) ||
             ((NI.getInElemTy(EmbeddingBagNode::IndicesIdx) ==
               ElemKind::Int32ITy) &&
              (NI.getInElemTy(EmbeddingBagNode::OffsetsIdx) ==
               ElemKind::Int32ITy))));

  case Kinded::Kind::SparseLengthsWeightedSumGradNodeKind:
    // GradOfInputNamedIndicesIdx and GradOfInputNamedLengthsIdx do not need to
    // be checked because they are not used.
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy},
               {SparseLengthsWeightedSumGradNode::IndicesIdx,
                SparseLengthsWeightedSumGradNode::LengthsIdx},
               {SparseLengthsWeightedSumGradNode::GradOfInputNamedIndicesIdx,
                SparseLengthsWeightedSumGradNode::
                    GradOfInputNamedLengthsIdx}) &&
           (NI.getInElemTy(SparseLengthsWeightedSumGradNode::IndicesIdx) ==
            ElemKind::Int64ITy) &&
           (NI.getInElemTy(SparseLengthsWeightedSumGradNode::LengthsIdx) ==
            ElemKind::Int32ITy);

  case Kinded::Kind::RowwiseQuantizedSparseLengthsWeightedSumNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
               {RowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx,
                RowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx,
                RowwiseQuantizedSparseLengthsWeightedSumNode::LengthsIdx}) &&
           (NI.getInElemTy(
                RowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx) ==
            ElemKind::UInt8QTy) &&
           (NI.getInElemTy(
                RowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) ==
                ElemKind::Int64ITy ||
            NI.getInElemTy(
                RowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) ==
                ElemKind::Int32ITy) &&
           (NI.getInElemTy(
                RowwiseQuantizedSparseLengthsWeightedSumNode::LengthsIdx) ==
            ElemKind::Int32ITy);

  case Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind: {
    // Indices and Offsets must both be Int32ITy or both be Int64ITy.
    if (!((NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::IndicesIdx) ==
               ElemKind::Int32ITy &&
           NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::OffsetsIdx) ==
               ElemKind::Int32ITy) ||
          (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::IndicesIdx) ==
               ElemKind::Int64ITy &&
           NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::OffsetsIdx) ==
               ElemKind::Int64ITy))) {
      return false;
    }
    switch (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::DataIdx)) {
    case ElemKind::UInt4FusedFP16QTy:
    case ElemKind::UInt8FusedFP16QTy:
      return (NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::WeightsIdx) ==
              ElemKind::Float16Ty) &&
             (NI.getOutElemTy(EmbeddingBagByteRowwiseOffsetsNode::ResultIdx) ==
              ElemKind::Float16Ty);
    case ElemKind::UInt8FusedQTy:
      // Weights and Result must both be FloatTy or both be Float16Ty. The
      // indices are taken from this node's own class rather than from
      // FusedRowwiseQuantizedSparseLengthsWeightedSumNode.
      return ((NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::WeightsIdx) ==
               ElemKind::FloatTy) &&
              (NI.getOutElemTy(EmbeddingBagByteRowwiseOffsetsNode::ResultIdx) ==
               ElemKind::FloatTy)) ||
             ((NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::WeightsIdx) ==
               ElemKind::Float16Ty) &&
              (NI.getOutElemTy(EmbeddingBagByteRowwiseOffsetsNode::ResultIdx) ==
               ElemKind::Float16Ty));
    default:
      return false;
    }
  }

  case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind: {
    // Indices must be Int64ITy or Int32ITy, and Lengths must be Int32ITy.
    if ((NI.getInElemTy(
             FusedRowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) !=
             ElemKind::Int64ITy &&
         NI.getInElemTy(
             FusedRowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) !=
             ElemKind::Int32ITy) ||
        NI.getInElemTy(
            FusedRowwiseQuantizedSparseLengthsWeightedSumNode::LengthsIdx) !=
            ElemKind::Int32ITy) {
      return false;
    }
    switch (NI.getInElemTy(
        FusedRowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx)) {
    case ElemKind::UInt4FusedFP16QTy:
    case ElemKind::UInt8FusedFP16QTy:
      return (NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
                                 WeightsIdx) == ElemKind::Float16Ty ||
              NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
                                 WeightsIdx) == ElemKind::FloatTy) &&
             (NI.getOutElemTy(
                  FusedRowwiseQuantizedSparseLengthsWeightedSumNode::
                      ResultIdx) == ElemKind::Float16Ty);
    case ElemKind::UInt4FusedQTy:
    case ElemKind::UInt8FusedQTy:
      // Weights and Result must both be FloatTy or both be Float16Ty.
      if ((NI.getInElemTy(
               FusedRowwiseQuantizedSparseLengthsWeightedSumNode::WeightsIdx) ==
           ElemKind::FloatTy) &&
          (NI.getOutElemTy(
               FusedRowwiseQuantizedSparseLengthsWeightedSumNode::ResultIdx) ==
           ElemKind::FloatTy)) {
        return true;
      }
      return (
          (NI.getInElemTy(
               FusedRowwiseQuantizedSparseLengthsWeightedSumNode::WeightsIdx) ==
           ElemKind::Float16Ty) &&
          (NI.getOutElemTy(
               FusedRowwiseQuantizedSparseLengthsWeightedSumNode::ResultIdx) ==
           ElemKind::Float16Ty));
    default:
      return false;
    }
  }

  case Kinded::Kind::LengthsRangeFillNodeKind:
  case Kinded::Kind::LengthsToRangesNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int32ITy});

  case Kinded::Kind::GatherNodeKind:
    // Note: Data and Result can be any data type, but must match.
    return (NI.getInElemTy(GatherNode::DataIdx) ==
            NI.getOutElemTy(GatherNode::ResultIdx)) &&
           ((NI.getInElemTy(GatherNode::IndicesIdx) == ElemKind::Int32ITy) ||
            (NI.getInElemTy(GatherNode::IndicesIdx) == ElemKind::Int64ITy));

  case Kinded::Kind::GatherNDNodeKind:
    // Note: Data and Result can be any data type, but must match.
    return (NI.getInElemTy(GatherNDNode::DataIdx) ==
            NI.getOutElemTy(GatherNDNode::ResultIdx)) &&
           ((NI.getInElemTy(GatherNDNode::IndicesIdx) == ElemKind::Int32ITy) ||
            (NI.getInElemTy(GatherNDNode::IndicesIdx) == ElemKind::Int64ITy));

  case Kinded::Kind::GatherElementsNodeKind:
    // Note: Data and Result can be any data type, but must match.
    return (NI.getInElemTy(GatherElementsNode::DataIdx) ==
            NI.getOutElemTy(GatherElementsNode::ResultIdx)) &&
           ((NI.getInElemTy(GatherElementsNode::IndicesIdx) ==
             ElemKind::Int32ITy) ||
            (NI.getInElemTy(GatherElementsNode::IndicesIdx) ==
             ElemKind::Int64ITy));

  case Kinded::Kind::BatchOneHotNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
                ElemKind::Int8QTy, ElemKind::Int32ITy, ElemKind::Int64ITy},
               {BatchOneHotNode::LengthsIdx}) &&
           (NI.getInElemTy(BatchOneHotNode::LengthsIdx) == ElemKind::Int32ITy);

  case Kinded::Kind::QuantizationProfileNodeKind:
  case Kinded::Kind::AvgPoolGradNodeKind:
  case Kinded::Kind::AdaptiveAvgPoolGradNodeKind:
  case Kinded::Kind::LocalResponseNormalizationGradNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});

  case Kinded::Kind::MaxPoolGradNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy},
               {MaxPoolGradNode::OriginalOutputForArgmaxIdx,
                MaxPoolGradNode::GradOfOriginalOutputNamedArgmaxIdx}) &&
           (NI.getInElemTy(MaxPoolGradNode::OriginalOutputForArgmaxIdx) ==
            ElemKind::Int64ITy) &&
           (NI.getInElemTy(
                MaxPoolGradNode::GradOfOriginalOutputNamedArgmaxIdx) ==
            ElemKind::Int64ITy);

  case Kinded::Kind::QuantizeNodeKind:
    return ((NI.getInElemTy(QuantizeNode::InputIdx) == ElemKind::FloatTy) ||
            (NI.getInElemTy(QuantizeNode::InputIdx) == ElemKind::Float16Ty) ||
            (NI.getInElemTy(QuantizeNode::InputIdx) == ElemKind::BFloat16Ty)) &&
           ((NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::Int8QTy) ||
            (NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::UInt8QTy) ||
            (NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::Int16QTy) ||
            (NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::Int32QTy));

  case Kinded::Kind::DequantizeNodeKind:
    return ((NI.getInElemTy(DequantizeNode::InputIdx) == ElemKind::Int8QTy) ||
            (NI.getInElemTy(DequantizeNode::InputIdx) == ElemKind::UInt8QTy) ||
            (NI.getInElemTy(DequantizeNode::InputIdx) == ElemKind::Int16QTy) ||
            (NI.getInElemTy(DequantizeNode::InputIdx) == ElemKind::Int32QTy) ||
            (NI.getInElemTy(DequantizeNode::InputIdx) ==
             ElemKind::UInt8FusedQTy)) &&
           ((NI.getOutElemTy(DequantizeNode::ResultIdx) == ElemKind::FloatTy) ||
            (NI.getOutElemTy(DequantizeNode::ResultIdx) ==
             ElemKind::Float16Ty) ||
            (NI.getOutElemTy(DequantizeNode::ResultIdx) ==
             ElemKind::BFloat16Ty));

  case Kinded::Kind::RescaleQuantizedNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Int8QTy, ElemKind::Int16QTy, ElemKind::Int32QTy});

  case Kinded::Kind::IntLookupTableNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Int8QTy, ElemKind::Int16QTy});

  case Kinded::Kind::ConvertToNodeKind: {
    // Element kinds accepted on either side of a plain conversion.
    auto isConversionSupportedFor = [](ElemKind kind) {
      switch (kind) {
      case ElemKind::Float16Ty:
      case ElemKind::BFloat16Ty:
      case ElemKind::FloatTy:
      case ElemKind::Int32ITy:
      case ElemKind::Int64ITy:
      case ElemKind::BoolTy:
        return true;
      default:
        return false;
      }
    };
    // Additionally accept the UInt8FusedQTy -> UInt8FusedFP16QTy conversion.
    return (isConversionSupportedFor(NI.getInElemTy(ConvertToNode::InputIdx)) &&
            isConversionSupportedFor(
                NI.getOutElemTy(ConvertToNode::ResultIdx))) ||
           (NI.getInElemTy(ConvertToNode::InputIdx) ==
                ElemKind::UInt8FusedQTy &&
            NI.getOutElemTy(ConvertToNode::ResultIdx) ==
                ElemKind::UInt8FusedFP16QTy);
  }

  case Kinded::Kind::TopKNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
                ElemKind::Int8QTy},
               {}, {TopKNode::IndicesIdx}) &&
           ((NI.getOutElemTy(TopKNode::IndicesIdx) == ElemKind::Int64ITy) ||
            (NI.getOutElemTy(TopKNode::IndicesIdx) == ElemKind::Int32ITy));

  case Kinded::Kind::ScatterDataNodeKind:
    return ((NI.getInElemTy(ScatterDataNode::IndicesIdx) ==
             ElemKind::Int32ITy) ||
            (NI.getInElemTy(ScatterDataNode::IndicesIdx) ==
             ElemKind::Int64ITy)) &&
           (NI.getOutElemTy(ScatterDataNode::ResultIdx) ==
            NI.getInElemTy(ScatterDataNode::DataIdx)) &&
           (NI.getOutElemTy(ScatterDataNode::ResultIdx) ==
            NI.getInElemTy(ScatterDataNode::SlicesIdx));

  case Kinded::Kind::SoftMaxNodeKind:
    // We just clip 64 to 32 SelectedIdx silently with the SoftMax
    // SelectedIdx in case dim_t is 32b.
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
               {SoftMaxNode::SelectedIdx}) &&
           (NI.getInElemTy(SoftMaxNode::SelectedIdx) == ElemKind::Int32ITy ||
            NI.getInElemTy(SoftMaxNode::SelectedIdx) == ElemKind::Int64ITy);

  case Kinded::Kind::LogSoftMaxNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
               {LogSoftMaxNode::SelectedIdx}) &&
           (NI.getInElemTy(LogSoftMaxNode::SelectedIdx) == ElemKind::Int32ITy ||
            NI.getInElemTy(LogSoftMaxNode::SelectedIdx) == ElemKind::Int64ITy);

  case Kinded::Kind::GatherRangesNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::Int32ITy, ElemKind::Int64ITy},
               {GatherRangesNode::DataIdx}, {GatherRangesNode::OutputIdx}) &&
           (NI.getOutElemTy(GatherRangesNode::OutputIdx) ==
            NI.getInElemTy(GatherRangesNode::DataIdx));

  case Kinded::Kind::CrossEntropyLossNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
               {CrossEntropyLossNode::LabelsIdx}) &&
           (NI.getInElemTy(CrossEntropyLossNode::LabelsIdx) ==
            ElemKind::Int64ITy);

  case Kinded::Kind::CumSumNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty,
         ElemKind::Int32ITy, ElemKind::Int64ITy});

  case Kinded::Kind::LengthsSumNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
               {LengthsSumNode::LengthsIdx}) &&
           (NI.getInElemTy(LengthsSumNode::LengthsIdx) == ElemKind::Int32ITy);

  case Kinded::Kind::BatchSparseToDenseNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
               {BatchSparseToDenseNode::LengthsIdx,
                BatchSparseToDenseNode::IndicesIdx}) &&
           (NI.getInElemTy(BatchSparseToDenseNode::LengthsIdx) ==
                ElemKind::Int64ITy ||
            NI.getInElemTy(BatchSparseToDenseNode::LengthsIdx) ==
                ElemKind::Int32ITy) &&
           (NI.getInElemTy(BatchSparseToDenseNode::IndicesIdx) ==
                ElemKind::Int64ITy ||
            NI.getInElemTy(BatchSparseToDenseNode::IndicesIdx) ==
                ElemKind::Int32ITy);

  case Kinded::Kind::FillExamplesWithIndicatorNodeKind:
    return (NI.getInElemTy(FillExamplesWithIndicatorNode::DataIdx) ==
            NI.getOutElemTy(FillExamplesWithIndicatorNode::ResultIdx)) &&
           ((NI.getInElemTy(FillExamplesWithIndicatorNode::IndicatorIdx) ==
             ElemKind::Int32ITy) ||
            (NI.getInElemTy(FillExamplesWithIndicatorNode::IndicatorIdx) ==
             ElemKind::Int64ITy) ||
            (NI.getInElemTy(FillExamplesWithIndicatorNode::IndicatorIdx) ==
             ElemKind::BoolTy));

  case Kinded::Kind::SparseToDenseMaskNodeKind:
    return (NI.getInElemTy(SparseToDenseMaskNode::IndicesIdx) ==
            ElemKind::Int64ITy) &&
           (NI.getInElemTy(SparseToDenseMaskNode::LengthsIdx) ==
            ElemKind::Int32ITy) &&
           (NI.getInElemTy(SparseToDenseMaskNode::ValuesIdx) ==
            NI.getInElemTy(SparseToDenseMaskNode::DefaultValueIdx)) &&
           (NI.getInElemTy(SparseToDenseMaskNode::ValuesIdx) ==
            NI.getOutElemTy(SparseToDenseMaskNode::ResultIdx));

  case Kinded::Kind::SparseLabelSplitNodeKind:
    return (NI.getInElemTy(SparseLabelSplitNode::LengthsIdx) ==
            ElemKind::Int32ITy) &&
           (NI.getInElemTy(SparseLabelSplitNode::IndicesIdx) ==
            ElemKind::Int64ITy) &&
           (NI.getInElemTy(SparseLabelSplitNode::ValuesIdx) ==
            NI.getOutElemTy(SparseLabelSplitNode::LabelValuesIdx)) &&
           (NI.getOutElemTy(SparseLabelSplitNode::ExampleIdsIdx) ==
            ElemKind::Int32ITy) &&
           (NI.getOutElemTy(SparseLabelSplitNode::GradientOffsetMapIdx) ==
            ElemKind::Int32ITy);

  case Kinded::Kind::TraceEventNodeKind:
    return NI.getInElemTy(TraceEventNode::DataIdx) == ElemKind::Int64ITy;

  case Kinded::Kind::TransposeNodeKind:
  case Kinded::Kind::ReshapeNodeKind:
  case Kinded::Kind::SaveNodeKind:
  case Kinded::Kind::FlipNodeKind:
    // These work regardless of the underlying type.
    return true;

  case Kinded::Kind::GaussianFillNodeKind:
    return NI.getOutElemTy(GaussianFillNode::ResultIdx) == ElemKind::Float16Ty;

  case Kinded::Kind::NonMaxSuppressionNodeKind:
    // Both index outputs must be Int32ITy or Int64ITy and must agree with
    // each other.
    return NI.getInElemTy(NonMaxSuppressionNode::BoxesIdx) ==
               ElemKind::FloatTy &&
           NI.getInElemTy(NonMaxSuppressionNode::ScoresIdx) ==
               ElemKind::FloatTy &&
           (NI.getOutElemTy(NonMaxSuppressionNode::IndicesIdx) ==
                ElemKind::Int32ITy ||
            NI.getOutElemTy(NonMaxSuppressionNode::IndicesIdx) ==
                ElemKind::Int64ITy) &&
           (NI.getOutElemTy(
                NonMaxSuppressionNode::NumberOfSelectedIndicesIdx) ==
                ElemKind::Int32ITy ||
            NI.getOutElemTy(
                NonMaxSuppressionNode::NumberOfSelectedIndicesIdx) ==
                ElemKind::Int64ITy) &&
           (NI.getOutElemTy(
                NonMaxSuppressionNode::NumberOfSelectedIndicesIdx) ==
            NI.getOutElemTy(NonMaxSuppressionNode::IndicesIdx));

  case Kinded::Kind::TFLiteDetectionPostProcessNodeKind:
    return NI.getInElemTy(TFLiteDetectionPostProcessNode::BoxesIdx) ==
               ElemKind::FloatTy &&
           NI.getInElemTy(TFLiteDetectionPostProcessNode::ScoresIdx) ==
               ElemKind::FloatTy &&
           NI.getInElemTy(TFLiteDetectionPostProcessNode::AnchorsIdx) ==
               ElemKind::FloatTy &&
           NI.getOutElemTy(TFLiteDetectionPostProcessNode::DetectionBoxesIdx) ==
               ElemKind::FloatTy &&
           NI.getOutElemTy(
               TFLiteDetectionPostProcessNode::DetectionClassesIdx) ==
               ElemKind::Int32ITy &&
           NI.getOutElemTy(
               TFLiteDetectionPostProcessNode::DetectionScoresIdx) ==
               ElemKind::FloatTy &&
           NI.getOutElemTy(TFLiteDetectionPostProcessNode::NumDetectionsIdx) ==
               ElemKind::Int32ITy;

  case Kinded::Kind::AudioSpectrogramNodeKind:
    return NI.getInElemTy(AudioSpectrogramNode::InputIdx) ==
               ElemKind::FloatTy &&
           NI.getOutElemTy(AudioSpectrogramNode::SpectrogramIdx) ==
               ElemKind::FloatTy;

  case Kinded::Kind::MFCCNodeKind:
    return NI.getInElemTy(MFCCNode::SpectrogramIdx) == ElemKind::FloatTy &&
           NI.getOutElemTy(MFCCNode::CoefficientsIdx) == ElemKind::FloatTy;

  case Kinded::Kind::ROIAlignNodeKind:
    return (NI.getInElemTy(ROIAlignNode::BatchIndicesIdx) ==
                ElemKind::Int32ITy ||
            NI.getInElemTy(ROIAlignNode::BatchIndicesIdx) ==
                ElemKind::Int64ITy) &&
           NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy, ElemKind::Float16Ty},
               /*ignoreIn*/ {ROIAlignNode::BatchIndicesIdx});

  case Kinded::Kind::CollectRpnProposalsNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty});

  case Kinded::Kind::BBoxTransformNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty});

  case Kinded::Kind::SoftMaxGradNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy}, {SoftMaxGradNode::SelectedIdx},
               {SoftMaxGradNode::GradOfInputNamedSelectedIdx}) &&
           (NI.getInElemTy(SoftMaxGradNode::SelectedIdx) == ElemKind::Int64ITy);

  case Kinded::Kind::ConvolutionGradNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy}, {},
        {ConvolutionGradNode::GradOfInputNamedInputIdx});

  case Kinded::Kind::CrossEntropyLossGradNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind(
               {ElemKind::FloatTy}, {CrossEntropyLossGradNode::LabelsIdx},
               {CrossEntropyLossGradNode::GradOfInputNamedLabelsIdx}) &&
           (NI.getInElemTy(CrossEntropyLossGradNode::LabelsIdx) ==
            ElemKind::Int64ITy) &&
           (NI.getOutElemTy(
                CrossEntropyLossGradNode::GradOfInputNamedLabelsIdx) ==
            ElemKind::Int64ITy);

  case Kinded::Kind::BatchedPairwiseDotProductNodeKind:
  case Kinded::Kind::BatchedPairwiseDotProductGradNodeKind:
    return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});

  case Kinded::Kind::BucketizeNodeKind:
    return NI.getInElemTy(BucketizeNode::InputIdx) == ElemKind::FloatTy &&
           NI.getOutElemTy(BucketizeNode::ResultIdx) == ElemKind::Int32ITy;

  case Kinded::Kind::BatchedUnaryEmbeddingsBagsNodeKind:
    // TableOffsets, Indices and Offsets must all be Int64ITy or all be
    // Int32ITy.
    return (
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::BFloat16Ty},
            {BatchedUnaryEmbeddingsBagsNode::TableOffsetsIdx,
             BatchedUnaryEmbeddingsBagsNode::IndicesIdx,
             BatchedUnaryEmbeddingsBagsNode::OffsetsIdx}) &&
        (((NI.getInElemTy(BatchedUnaryEmbeddingsBagsNode::TableOffsetsIdx) ==
           ElemKind::Int64ITy) &&
          (NI.getInElemTy(BatchedUnaryEmbeddingsBagsNode::IndicesIdx) ==
           ElemKind::Int64ITy) &&
          (NI.getInElemTy(BatchedUnaryEmbeddingsBagsNode::OffsetsIdx) ==
           ElemKind::Int64ITy)) ||
         ((NI.getInElemTy(BatchedUnaryEmbeddingsBagsNode::TableOffsetsIdx) ==
           ElemKind::Int32ITy) &&
          (NI.getInElemTy(BatchedUnaryEmbeddingsBagsNode::IndicesIdx) ==
           ElemKind::Int32ITy) &&
          (NI.getInElemTy(BatchedUnaryEmbeddingsBagsNode::OffsetsIdx) ==
           ElemKind::Int32ITy))));

  case Kinded::Kind::IntNBitSplitEmbeddingBagsNodeKind:
    return (((NI.getInElemTy(IntNBitSplitEmbeddingBagsNode::IndicesIdx) ==
              ElemKind::Int64ITy) &&
             (NI.getInElemTy(IntNBitSplitEmbeddingBagsNode::OffsetsIdx) ==
              ElemKind::Int64ITy)) ||
            ((NI.getInElemTy(IntNBitSplitEmbeddingBagsNode::IndicesIdx) ==
              ElemKind::Int32ITy) &&
             (NI.getInElemTy(IntNBitSplitEmbeddingBagsNode::OffsetsIdx) ==
              ElemKind::Int32ITy))) &&
           NI.getInElemTy(IntNBitSplitEmbeddingBagsNode::DevWeightsIdx) ==
               ElemKind::UInt8ITy &&
           NI.getInElemTy(IntNBitSplitEmbeddingBagsNode::UvmWeightsIdx) ==
               ElemKind::UInt8ITy &&
           NI.getInElemTy(IntNBitSplitEmbeddingBagsNode::WeightsTysIdx) ==
               ElemKind::UInt8ITy &&
           NI.getInElemTy(
               IntNBitSplitEmbeddingBagsNode::WeightsPlacementsIdx) ==
               ElemKind::Int32ITy &&
           NI.getInElemTy(IntNBitSplitEmbeddingBagsNode::WeightsOffsetsIdx) ==
               ElemKind::Int32ITy &&
           NI.getInElemTy(IntNBitSplitEmbeddingBagsNode::DimOffsetsIdx) ==
               ElemKind::Int32ITy;

  case Kinded::Kind::IntNBitSplitEmbeddingWeightedBagsNodeKind:
    return (((NI.getInElemTy(
                  IntNBitSplitEmbeddingWeightedBagsNode::IndicesIdx) ==
              ElemKind::Int64ITy) &&
             (NI.getInElemTy(
                  IntNBitSplitEmbeddingWeightedBagsNode::OffsetsIdx) ==
              ElemKind::Int64ITy)) ||
            ((NI.getInElemTy(
                  IntNBitSplitEmbeddingWeightedBagsNode::IndicesIdx) ==
              ElemKind::Int32ITy) &&
             (NI.getInElemTy(
                  IntNBitSplitEmbeddingWeightedBagsNode::OffsetsIdx) ==
              ElemKind::Int32ITy))) &&
           NI.getInElemTy(
               IntNBitSplitEmbeddingWeightedBagsNode::DevWeightsIdx) ==
               ElemKind::UInt8ITy &&
           NI.getInElemTy(
               IntNBitSplitEmbeddingWeightedBagsNode::UvmWeightsIdx) ==
               ElemKind::UInt8ITy &&
           NI.getInElemTy(
               IntNBitSplitEmbeddingWeightedBagsNode::WeightsTysIdx) ==
               ElemKind::UInt8ITy &&
           NI.getInElemTy(
               IntNBitSplitEmbeddingWeightedBagsNode::WeightsPlacementsIdx) ==
               ElemKind::Int32ITy &&
           NI.getInElemTy(
               IntNBitSplitEmbeddingWeightedBagsNode::WeightsOffsetsIdx) ==
               ElemKind::Int32ITy &&
           NI.getInElemTy(
               IntNBitSplitEmbeddingWeightedBagsNode::DimOffsetsIdx) ==
               ElemKind::Int32ITy &&
           NI.getInElemTy(
               IntNBitSplitEmbeddingWeightedBagsNode::IndiceWeightIdx) ==
               ElemKind::FloatTy;

  default:
    // Any node kind not listed above is reported as unsupported.
    return false;
  }
}