in lib/Backends/NNPI/NNPI.cpp [121:843]
/// Determine the NNPI backend's level of support for the node described by
/// \p NI.
/// \returns PRECISION_SUPPORTED when the node kind is supported with the
/// exact element kinds (precisions) the node currently carries,
/// SUPPORTED when the node kind is handled by the backend but not with these
/// element kinds, and NOT_SUPPORTED when the node kind is not handled at all.
static NodeSupportLevels isNodeSupported(const NodeInfo &NI) {
  bool isNodePrecisionSupported = false;
  // Assume the node kind is known to the backend; the default case below
  // clears this for unrecognized kinds.
  bool hasAnySupport = true;
  switch (NI.getKind()) {
  // General math fp32/fp16/i8/i32/i64.
  case Kinded::Kind::AddNodeKind:
  case Kinded::Kind::SubNodeKind:
  case Kinded::Kind::MulNodeKind:
  case Kinded::Kind::MaxNodeKind:
  case Kinded::Kind::MinNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Int32ITy, ElemKind::Float16Ty,
         ElemKind::Int8QTy, ElemKind::Int64ITy});
    break;
  case Kinded::Kind::DivNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
         ElemKind::Int64ITy, ElemKind::Int32ITy});
    break;
  // General math fp32/fp16/i8/i32.
  case Kinded::Kind::PowNodeKind:
  case Kinded::Kind::ReluNodeKind:
  case Kinded::Kind::ReplaceNaNNodeKind:
  case Kinded::Kind::MatMulNodeKind:
  case Kinded::Kind::BatchedReduceAddNodeKind:
  case Kinded::Kind::BatchedReduceMeanNodeKind:
  case Kinded::Kind::BatchedAddNodeKind:
  case Kinded::Kind::BatchedMulNodeKind:
  case Kinded::Kind::TanhNodeKind:
  case Kinded::Kind::LogNodeKind:
  case Kinded::Kind::SigmoidNodeKind:
  case Kinded::Kind::NegNodeKind:
  case Kinded::Kind::AbsNodeKind:
  case Kinded::Kind::ExpNodeKind:
  case Kinded::Kind::SoftPlusNodeKind:
  // BatchedReduceMin/Max accept exactly the same element kinds as the group
  // above, so they are folded into the same case list.
  case Kinded::Kind::BatchedReduceMinNodeKind:
  case Kinded::Kind::BatchedReduceMaxNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
         ElemKind::Int32ITy});
    break;
  case Kinded::Kind::SplatNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
         ElemKind::Int32ITy, ElemKind::Int64ITy});
    break;
  case Kinded::Kind::LocalResponseNormalizationNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Float16Ty, ElemKind::Int8QTy});
    break;
  case Kinded::Kind::ModuloNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int32ITy});
    break;
#if NNPI_MAJOR_VERSION >= 1 && NNPI_MINOR_VERSION >= 1
  case Kinded::Kind::NNPILookupTableNodeKind:
  case Kinded::Kind::IntLookupTableNodeKind:
    // Lookup tables are backend-specific; any precision they carry is valid.
    isNodePrecisionSupported = true;
    break;
  case Kinded::Kind::BBoxTransformNodeKind:
    // RoiBatchSplits output should be FP16 in the Glow node and get
    // converted explicitly to FP32 in NNPI importer.
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Float16Ty});
    break;
  case Kinded::Kind::ROIAlignNodeKind:
    // All tensors are FP16 except the batch indices, which may be i32/i64.
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::Float16Ty}, {ROIAlignNode::BatchIndicesIdx}) &&
        (NI.getInElemTy(ROIAlignNode::BatchIndicesIdx) == ElemKind::Int32ITy ||
         NI.getInElemTy(ROIAlignNode::BatchIndicesIdx) == ElemKind::Int64ITy);
    break;
  case Kinded::Kind::LSTMUnitNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Float16Ty});
    break;
  case Kinded::Kind::ResizeNearestNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int32QTy,
         ElemKind::Int8QTy, ElemKind::UInt8QTy});
    break;
  case Kinded::Kind::SparseLabelSplitNodeKind: {
    // Values may be any of the listed kinds but must match the LabelValues
    // output; indices/lengths/ids/offset-map types are fixed.
    auto valuesIdxDataType = NI.getInElemTy(SparseLabelSplitNode::ValuesIdx);
    isNodePrecisionSupported =
        (NI.getInElemTy(SparseLabelSplitNode::LengthsIdx) ==
         ElemKind::Int32ITy) &&
        (NI.getInElemTy(SparseLabelSplitNode::IndicesIdx) ==
         ElemKind::Int64ITy) &&
        (NI.getInElemTy(SparseLabelSplitNode::ValuesIdx) ==
         NI.getOutElemTy(SparseLabelSplitNode::LabelValuesIdx)) &&
        (NI.getOutElemTy(SparseLabelSplitNode::ExampleIdsIdx) ==
         ElemKind::Int32ITy) &&
        (NI.getOutElemTy(SparseLabelSplitNode::GradientOffsetMapIdx) ==
         ElemKind::Int32ITy) &&
        (valuesIdxDataType == ElemKind::FloatTy ||
         valuesIdxDataType == ElemKind::Float16Ty ||
         valuesIdxDataType == ElemKind::Int8QTy ||
         valuesIdxDataType == ElemKind::UInt8QTy);
    break;
  }
#endif // NNPI > 1.1
  case Kinded::Kind::LayerNormalizationNodeKind: {
    // Scale and bias must share one kind (fp16 or i8), independently of the
    // input/result kind.
    auto scaleType = NI.getInElemTy(LayerNormalizationNode::ScaleIdx);
    auto biasType = NI.getInElemTy(LayerNormalizationNode::BiasIdx);
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::Float16Ty, ElemKind::Int8QTy},
            {LayerNormalizationNode::ScaleIdx,
             LayerNormalizationNode::BiasIdx}) &&
        scaleType == biasType &&
        (scaleType == ElemKind::Float16Ty || scaleType == ElemKind::Int8QTy);
    break;
  }
  case Kinded::Kind::SwishNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Float16Ty, ElemKind::Int8QTy, ElemKind::UInt8QTy});
    break;
  case Kinded::Kind::GeluNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Float16Ty});
    break;
  case Kinded::Kind::BatchNormalizationNodeKind: {
    // Input may be i8/fp32/fp16; result must be fp32/fp16; scale, bias,
    // mean and variance must all match the input's kind.
    auto elemType = NI.getInElemTy(BatchNormalizationNode::InputIdx);
    isNodePrecisionSupported =
        (elemType == ElemKind::Int8QTy || elemType == ElemKind::FloatTy ||
         elemType == ElemKind::Float16Ty);
    isNodePrecisionSupported = isNodePrecisionSupported &&
                               NI.allInputsAndOutputsHaveSameElemKind(
                                   {ElemKind::FloatTy, ElemKind::Float16Ty},
                                   {BatchNormalizationNode::InputIdx},
                                   {BatchNormalizationNode::ResultIdx});
    isNodePrecisionSupported =
        isNodePrecisionSupported &&
        NI.allInputsAndOutputsHaveSameElemKind(
            {elemType},
            {BatchNormalizationNode::ScaleIdx, BatchNormalizationNode::BiasIdx,
             BatchNormalizationNode::MeanIdx, BatchNormalizationNode::VarIdx});
    break;
  }
  case Kinded::Kind::VectorNormNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Float16Ty, ElemKind::Int8QTy, ElemKind::UInt8QTy});
    break;
  case Kinded::Kind::AvgPoolNodeKind:
  case Kinded::Kind::AdaptiveAvgPoolNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy});
    break;
  case Kinded::Kind::BatchMatMulNodeKind:
  case Kinded::Kind::PReluNodeKind:
  case Kinded::Kind::ClipNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Int8QTy, ElemKind::Float16Ty});
    break;
  case Kinded::Kind::FmodNodeKind:
#if NNPI_MAJOR_VERSION >= 1 && NNPI_MINOR_VERSION >= 7
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Float16Ty});
#else
    // Supporting these two for now because for fp inputs NNPI returns result
    // with the same sign as the divisor instead of the dividend.
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Int64ITy, ElemKind::Int32ITy});
#endif // NNPI >= 1.7
    break;
  // Data transfer fp32/fp16/i8/i32/i64/bool.
  case Kinded::Kind::SaveNodeKind:
  case Kinded::Kind::ConcatNodeKind:
  case Kinded::Kind::TileNodeKind:
  case Kinded::Kind::TransposeNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
         ElemKind::Int32ITy, ElemKind::Int64ITy, ElemKind::BoolTy});
    break;
  case Kinded::Kind::ConvolutionNodeKind: {
    if (!NI.getInTy(ConvolutionNode::InputIdx)->isQuantizedType()) {
      isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
          {ElemKind::FloatTy, ElemKind::Float16Ty});
    } else {
      // Quantized conv: everything i8 except the bias, which may be i32q or
      // fp32.
      isNodePrecisionSupported =
          NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int8QTy},
                                                 {ConvolutionNode::BiasIdx}) &&
          ((NI.getInElemTy(ConvolutionNode::BiasIdx) == ElemKind::Int32QTy) ||
           (NI.getInElemTy(ConvolutionNode::BiasIdx) == ElemKind::FloatTy));
    }
    break;
  }
  case Kinded::Kind::Convolution3DNodeKind:
    if (!NI.getInTy(Convolution3DNode::InputIdx)->isQuantizedType()) {
      isNodePrecisionSupported =
          NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Float16Ty});
    } else {
      // Fixed: the fp32-bias check previously read ConvolutionNode::BiasIdx
      // (a copy/paste from the 2D conv case) instead of
      // Convolution3DNode::BiasIdx.
      isNodePrecisionSupported =
          NI.allInputsAndOutputsHaveSameElemKind(
              {ElemKind::Int8QTy}, {Convolution3DNode::BiasIdx}) &&
          ((NI.getInElemTy(Convolution3DNode::BiasIdx) == ElemKind::Int32QTy) ||
           (NI.getInElemTy(Convolution3DNode::BiasIdx) == ElemKind::FloatTy));
    }
    break;
  case Kinded::Kind::QuantizeNodeKind:
    isNodePrecisionSupported =
        (NI.getInElemTy(QuantizeNode::InputIdx) == ElemKind::FloatTy ||
         NI.getInElemTy(QuantizeNode::InputIdx) == ElemKind::Float16Ty) &&
        (NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::Int8QTy);
    break;
  case Kinded::Kind::DequantizeNodeKind:
    isNodePrecisionSupported =
        (NI.getInElemTy(DequantizeNode::InputIdx) == ElemKind::Int8QTy) &&
        (NI.getOutElemTy(DequantizeNode::ResultIdx) == ElemKind::FloatTy ||
         NI.getOutElemTy(DequantizeNode::ResultIdx) == ElemKind::Float16Ty);
    break;
  case Kinded::Kind::RescaleQuantizedNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int8QTy});
    break;
  case Kinded::Kind::ConvertToNodeKind: {
    // Supported (from, to) element-kind pairs for ConvertTo; pairs not
    // listed are rejected.
    auto isConversionSupportedFor = [](ElemKind kindFrom, ElemKind kindTo) {
      switch (kindFrom) {
      case ElemKind::Float16Ty:
        switch (kindTo) {
        case ElemKind::FloatTy:
        case ElemKind::Int8QTy:
        case ElemKind::UInt8QTy:
        case ElemKind::BoolTy:
          return true;
        case ElemKind::Int32ITy:
#if NNPI_MAJOR_VERSION >= 1 && NNPI_MINOR_VERSION >= 7
          return true;
#else
          return glow::nnpi::flags::EnableCustomIAKernels;
#endif // NNPI >= 1.7
        default:
          return false;
        }
        return false;
      case ElemKind::FloatTy:
        switch (kindTo) {
        case ElemKind::Float16Ty:
        case ElemKind::Int8QTy:
        case ElemKind::UInt8QTy:
        case ElemKind::BoolTy:
          return true;
        case ElemKind::Int32ITy:
          return glow::nnpi::flags::EnableCustomIAKernels;
        default:
          return false;
        }
        return false;
      case ElemKind::Int64ITy:
        switch (kindTo) {
        case ElemKind::Int32ITy:
        case ElemKind::FloatTy:
        case ElemKind::Int8QTy:
          return true;
        default:
          return false;
        }
        return false;
      // NOTE: this is supported by a custom kernel
      case ElemKind::BoolTy:
        switch (kindTo) {
        case ElemKind::Int32ITy:
          return true;
        default:
          return false;
        }
        return false;
      case ElemKind::Int32ITy:
        switch (kindTo) {
        case ElemKind::Int64ITy:
        case ElemKind::Float16Ty:
        case ElemKind::FloatTy:
        case ElemKind::Int8QTy:
          return true;
        case ElemKind::BoolTy:
          return glow::nnpi::flags::EnableCustomIAKernels;
        default:
          return false;
        }
        return false;
      case ElemKind::Int32QTy:
        switch (kindTo) {
        case ElemKind::Float16Ty:
          return true;
        default:
          return false;
        }
        return false;
      case ElemKind::UInt8QTy:
      case ElemKind::Int8QTy:
        // i8/u8 sources convert to any destination kind.
        return true;
      case ElemKind::UInt8FusedQTy:
        return (kindTo == ElemKind::Float16Ty ||
                kindTo == ElemKind::UInt8FusedFP16QTy);
      case ElemKind::UInt8FusedFP16QTy:
        return (kindTo == ElemKind::Float16Ty);
      default:
        return false;
      }
      return false;
    };
    isNodePrecisionSupported =
        isConversionSupportedFor(NI.getInElemTy(ConvertToNode::InputIdx),
                                 NI.getOutElemTy(ConvertToNode::ResultIdx));
    break;
  }
  case Kinded::Kind::DynamicQuantizedFullyConnectedNodeKind:
    isNodePrecisionSupported =
        (NI.getInElemTy(DynamicQuantizedFullyConnectedNode::InputIdx) ==
             ElemKind::Float16Ty ||
         NI.getInElemTy(DynamicQuantizedFullyConnectedNode::InputIdx) ==
             ElemKind::FloatTy) &&
        NI.getInElemTy(DynamicQuantizedFullyConnectedNode::WeightsIdx) ==
            ElemKind::Int8QTy &&
        NI.getInElemTy(DynamicQuantizedFullyConnectedNode::BiasIdx) ==
            ElemKind::FloatTy;
    break;
  case Kinded::Kind::DynamicRowwiseQuantizedFullyConnectedNodeKind:
    isNodePrecisionSupported =
        (NI.getInElemTy(DynamicRowwiseQuantizedFullyConnectedNode::InputIdx) ==
             ElemKind::Float16Ty ||
         NI.getInElemTy(DynamicRowwiseQuantizedFullyConnectedNode::InputIdx) ==
             ElemKind::FloatTy) &&
        NI.getInElemTy(DynamicRowwiseQuantizedFullyConnectedNode::WeightsIdx) ==
            ElemKind::Int8QTy &&
        NI.getInElemTy(DynamicRowwiseQuantizedFullyConnectedNode::BiasIdx) ==
            ElemKind::FloatTy &&
        NI.getInElemTy(DynamicRowwiseQuantizedFullyConnectedNode::ScalesIdx) ==
            ElemKind::FloatTy &&
        NI.getInElemTy(DynamicRowwiseQuantizedFullyConnectedNode::OffsetsIdx) ==
            ElemKind::Int32ITy;
    break;
  case Kinded::Kind::FullyConnectedNodeKind:
    if (!NI.getInTy(FullyConnectedNode::InputIdx)->isQuantizedType()) {
      isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
          {ElemKind::FloatTy, ElemKind::Float16Ty});
    } else {
      // Quantized FC: everything i8 except the bias (i32q or fp32).
      isNodePrecisionSupported =
          NI.allInputsAndOutputsHaveSameElemKind(
              {ElemKind::Int8QTy}, {FullyConnectedNode::BiasIdx}) &&
          ((NI.getInElemTy(FullyConnectedNode::BiasIdx) ==
            ElemKind::Int32QTy) ||
           (NI.getInElemTy(FullyConnectedNode::BiasIdx) == ElemKind::FloatTy));
    }
    break;
  case Kinded::Kind::MaxPoolNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy}, {},
            {MaxPoolNode::ArgmaxIdx}) &&
        (NI.getOutElemTy(MaxPoolNode::ArgmaxIdx) == ElemKind::Int64ITy);
    break;
  case Kinded::Kind::TopKNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy}, {},
            {TopKNode::IndicesIdx}) &&
        (NI.getOutElemTy(TopKNode::IndicesIdx) == ElemKind::Int64ITy ||
         NI.getOutElemTy(TopKNode::IndicesIdx) == ElemKind::Int32ITy);
    break;
  case Kinded::Kind::GatherNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int64ITy,
             ElemKind::Int8QTy},
            {GatherNode::IndicesIdx}) &&
        ((NI.getInElemTy(GatherNode::IndicesIdx) == ElemKind::Int32ITy) ||
         (NI.getInElemTy(GatherNode::IndicesIdx) == ElemKind::Int64ITy));
    break;
  case Kinded::Kind::GatherRangesNodeKind:
    // Ranges are i32/i64; the output must carry the same kind as the data.
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::Int32ITy, ElemKind::Int64ITy},
            {GatherRangesNode::DataIdx}, {GatherRangesNode::OutputIdx}) &&
        ((NI.getInElemTy(GatherRangesNode::DataIdx) == ElemKind::FloatTy) ||
         (NI.getInElemTy(GatherRangesNode::DataIdx) == ElemKind::Float16Ty) ||
         (NI.getInElemTy(GatherRangesNode::DataIdx) == ElemKind::Int8QTy) ||
         (NI.getInElemTy(GatherRangesNode::DataIdx) == ElemKind::Int32ITy) ||
         (NI.getInElemTy(GatherRangesNode::DataIdx) == ElemKind::Int64ITy)) &&
        (NI.getOutElemTy(GatherRangesNode::OutputIdx) ==
         NI.getInElemTy(GatherRangesNode::DataIdx));
    break;
  case Kinded::Kind::SliceNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
         ElemKind::Int32ITy, ElemKind::Int64ITy, ElemKind::BoolTy});
    break;
  case Kinded::Kind::ReshapeNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
         ElemKind::Int32ITy, ElemKind::Int64ITy});
    break;
  case Kinded::Kind::CmpLTENodeKind:
  case Kinded::Kind::CmpLTNodeKind:
  case Kinded::Kind::CmpEQNodeKind:
  case Kinded::Kind::CmpNEQNodeKind:
    // All compare nodes share CmpEQNode's ResultIdx layout; the result is
    // always boolean.
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
             ElemKind::Int32ITy},
            {}, {CmpEQNode::ResultIdx}) &&
        (NI.getOutElemTy(CmpEQNode::ResultIdx) == ElemKind::BoolTy);
    break;
  case Kinded::Kind::NonZeroNodeKind:
    // Fixed: previously queried CmpEQNode::ResultIdx (copy/paste from the
    // compare cases above) instead of NonZeroNode::ResultIdx.
    isNodePrecisionSupported =
        (NI.getOutElemTy(NonZeroNode::ResultIdx) == ElemKind::Int32ITy);
    break;
  case Kinded::Kind::SelectNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy},
            {SelectNode::CondIdx}) &&
        (NI.getInElemTy(SelectNode::CondIdx) == ElemKind::BoolTy);
    break;
  case Kinded::Kind::GaussianFillNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
             ElemKind::Int32ITy, ElemKind::Int64ITy},
            {}, {GaussianFillNode::ResultIdx}) &&
        (NI.getOutElemTy(GaussianFillNode::ResultIdx)) == ElemKind::Float16Ty;
    break;
  case Kinded::Kind::RowwiseQuantizedFullyConnectedNodeKind:
    isNodePrecisionSupported =
        (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::InputIdx) ==
         ElemKind::Int8QTy) &&
        (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::WeightsIdx) ==
         ElemKind::Int8QTy) &&
        (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::ScalesIdx) ==
         ElemKind::FloatTy) &&
        (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::OffsetsIdx) ==
         ElemKind::Int32ITy) &&
        ((NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::BiasIdx) ==
          ElemKind::Int32QTy) ||
         (NI.getInElemTy(RowwiseQuantizedFullyConnectedNode::BiasIdx) ==
          ElemKind::FloatTy)) &&
        (NI.getOutElemTy(RowwiseQuantizedFullyConnectedNode::ResultIdx) ==
         ElemKind::Int8QTy);
    break;
  case Kinded::Kind::ChannelwiseQuantizedConvolutionNodeKind:
    isNodePrecisionSupported =
        (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::InputIdx) ==
         ElemKind::Int8QTy) &&
        (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::FilterIdx) ==
         ElemKind::Int8QTy) &&
        ((NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::BiasIdx) ==
          ElemKind::Int32QTy) ||
         (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::BiasIdx) ==
          ElemKind::FloatTy)) &&
        (NI.getInElemTy(ChannelwiseQuantizedConvolutionNode::FilterScalesIdx) ==
         ElemKind::FloatTy) &&
        (NI.getInElemTy(
             ChannelwiseQuantizedConvolutionNode::FilterOffsetsIdx) ==
         ElemKind::Int32ITy) &&
        (NI.getOutElemTy(ChannelwiseQuantizedConvolutionNode::ResultIdx) ==
         ElemKind::Int8QTy);
    break;
  case Kinded::Kind::SparseLengthsSumNodeKind:
    isNodePrecisionSupported =
        isSLSIndicesValid(NI.getInTy(SparseLengthsSumNode::IndicesIdx)) &&
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy},
            {SparseLengthsSumNode::IndicesIdx,
             SparseLengthsSumNode::LengthsIdx}) &&
        (NI.getInElemTy(SparseLengthsSumNode::IndicesIdx) ==
             ElemKind::Int64ITy ||
         NI.getInElemTy(SparseLengthsSumNode::IndicesIdx) ==
             ElemKind::Int32ITy) &&
        (NI.getInElemTy(SparseLengthsSumNode::LengthsIdx) ==
         ElemKind::Int32ITy);
    break;
  case Kinded::Kind::SparseLengthsWeightedSumNodeKind:
    isNodePrecisionSupported =
        isSLSIndicesValid(
            NI.getInTy(SparseLengthsWeightedSumNode::IndicesIdx)) &&
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy},
            {SparseLengthsWeightedSumNode::IndicesIdx,
             SparseLengthsWeightedSumNode::LengthsIdx}) &&
        (NI.getInElemTy(SparseLengthsWeightedSumNode::IndicesIdx) ==
             ElemKind::Int64ITy ||
         NI.getInElemTy(SparseLengthsWeightedSumNode::IndicesIdx) ==
             ElemKind::Int32ITy) &&
        (NI.getInElemTy(SparseLengthsWeightedSumNode::LengthsIdx) ==
         ElemKind::Int32ITy);
    break;
  case Kinded::Kind::EmbeddingNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy},
            {EmbeddingNode::IndicesIdx}) &&
        (NI.getInElemTy(EmbeddingNode::IndicesIdx) == ElemKind::Int64ITy ||
         NI.getInElemTy(EmbeddingNode::IndicesIdx) == ElemKind::Int32ITy);
    break;
  case Kinded::Kind::EmbeddingBagNodeKind:
    isNodePrecisionSupported =
        isSLSIndicesValid(NI.getInTy(EmbeddingBagNode::IndicesIdx)) &&
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy},
            {EmbeddingBagNode::IndicesIdx, EmbeddingBagNode::OffsetsIdx}) &&
        (NI.getInElemTy(EmbeddingBagNode::IndicesIdx) == ElemKind::Int64ITy ||
         NI.getInElemTy(EmbeddingBagNode::IndicesIdx) == ElemKind::Int32ITy) &&
        (NI.getInElemTy(EmbeddingBagNode::OffsetsIdx) == ElemKind::Int64ITy ||
         NI.getInElemTy(EmbeddingBagNode::OffsetsIdx) == ElemKind::Int32ITy);
    break;
  case Kinded::Kind::EmbeddingBagByteRowwiseOffsetsNodeKind: {
    auto dataK = NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::DataIdx);
    auto offsetsK =
        NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::OffsetsIdx);
    auto indicesK =
        NI.getInElemTy(EmbeddingBagByteRowwiseOffsetsNode::IndicesIdx);
    auto resultK =
        NI.getOutElemTy(EmbeddingBagByteRowwiseOffsetsNode::ResultIdx);
    isNodePrecisionSupported =
        isSLSIndicesValid(
            NI.getInTy(EmbeddingBagByteRowwiseOffsetsNode::IndicesIdx)) &&
        (dataK == ElemKind::UInt8FusedQTy ||
         dataK == ElemKind::UInt8FusedFP16QTy ||
         dataK == ElemKind::UInt4FusedFP16QTy) &&
        (resultK == ElemKind::FloatTy || resultK == ElemKind::Float16Ty) &&
        (offsetsK == ElemKind::Int64ITy || offsetsK == ElemKind::Int32ITy) &&
        (indicesK == ElemKind::Int64ITy || indicesK == ElemKind::Int32ITy);
    break;
  }
  case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsSumNodeKind: {
    auto dataK =
        NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsSumNode::DataIdx);
    auto lengthsK =
        NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsSumNode::LengthsIdx);
    auto indicesK =
        NI.getInElemTy(FusedRowwiseQuantizedSparseLengthsSumNode::IndicesIdx);
    auto resultK =
        NI.getOutElemTy(FusedRowwiseQuantizedSparseLengthsSumNode::ResultIdx);
    isNodePrecisionSupported =
        isSLSIndicesValid(NI.getInTy(
            FusedRowwiseQuantizedSparseLengthsSumNode::IndicesIdx)) &&
        (dataK == ElemKind::UInt8FusedQTy ||
         dataK == ElemKind::UInt8FusedFP16QTy ||
         dataK == ElemKind::UInt4FusedFP16QTy) &&
        (resultK == ElemKind::FloatTy || resultK == ElemKind::Float16Ty) &&
        (indicesK == ElemKind::Int64ITy || indicesK == ElemKind::Int32ITy) &&
        (lengthsK == ElemKind::Int32ITy);
    break;
  }
  case Kinded::Kind::FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind: {
    auto dataK = NI.getInElemTy(
        FusedRowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx);
    auto weightsK = NI.getInElemTy(
        FusedRowwiseQuantizedSparseLengthsWeightedSumNode::WeightsIdx);
    auto lengthsK = NI.getInElemTy(
        FusedRowwiseQuantizedSparseLengthsWeightedSumNode::LengthsIdx);
    auto indicesK = NI.getInElemTy(
        FusedRowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx);
    auto resultK = NI.getOutElemTy(
        FusedRowwiseQuantizedSparseLengthsWeightedSumNode::ResultIdx);
    isNodePrecisionSupported =
        isSLSIndicesValid(NI.getInTy(
            FusedRowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx)) &&
        (dataK == ElemKind::UInt8FusedQTy ||
         dataK == ElemKind::UInt8FusedFP16QTy ||
         dataK == ElemKind::UInt4FusedFP16QTy) &&
        (weightsK == ElemKind::FloatTy || weightsK == ElemKind::Float16Ty) &&
        (resultK == ElemKind::FloatTy || resultK == ElemKind::Float16Ty) &&
        (indicesK == ElemKind::Int64ITy || indicesK == ElemKind::Int32ITy) &&
        (lengthsK == ElemKind::Int32ITy);
    break;
  }
  case Kinded::Kind::RowwiseQuantizedSparseLengthsWeightedSumNodeKind:
    isNodePrecisionSupported =
        isSLSIndicesValid(NI.getInTy(
            RowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx)) &&
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty},
            {RowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx,
             RowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx,
             RowwiseQuantizedSparseLengthsWeightedSumNode::LengthsIdx}) &&
        (NI.getInElemTy(
             RowwiseQuantizedSparseLengthsWeightedSumNode::DataIdx) ==
         ElemKind::UInt8QTy) &&
        (NI.getInElemTy(
             RowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) ==
             ElemKind::Int64ITy ||
         NI.getInElemTy(
             RowwiseQuantizedSparseLengthsWeightedSumNode::IndicesIdx) ==
             ElemKind::Int32ITy) &&
        (NI.getInElemTy(
             RowwiseQuantizedSparseLengthsWeightedSumNode::LengthsIdx) ==
         ElemKind::Int32ITy);
    break;
  case Kinded::Kind::ScatterDataNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
             ElemKind::UInt8QTy},
            {ScatterDataNode::IndicesIdx}) &&
        (NI.getInElemTy(ScatterDataNode::IndicesIdx) == ElemKind::Int32ITy ||
         NI.getInElemTy(ScatterDataNode::IndicesIdx) == ElemKind::Int64ITy);
    break;
  case Kinded::Kind::BucketizeNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::Float16Ty, ElemKind::Int8QTy, ElemKind::UInt8QTy}, {},
            {BucketizeNode::ResultIdx}) &&
        (NI.getOutElemTy(BucketizeNode::ResultIdx) == ElemKind::Int32ITy);
    break;
  case Kinded::Kind::SoftMaxNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy},
            {SoftMaxNode::SelectedIdx}) &&
        (NI.getInElemTy(SoftMaxNode::SelectedIdx) == ElemKind::Int64ITy ||
         NI.getInElemTy(SoftMaxNode::SelectedIdx) == ElemKind::Int32ITy);
    break;
  case Kinded::Kind::LengthsRangeFillNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int32ITy});
    break;
  case Kinded::Kind::BatchOneHotNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
             ElemKind::Int32ITy, ElemKind::Int64ITy},
            {BatchOneHotNode::LengthsIdx}) &&
        (NI.getInElemTy(BatchOneHotNode::LengthsIdx) == ElemKind::Int32ITy);
    break;
  case Kinded::Kind::NNPICustomDSPNodeKind:
  case Kinded::Kind::NNPICustomIANodeKind:
    // Custom backend kernels validate their own precisions.
    isNodePrecisionSupported = true;
    break;
  case Kinded::Kind::SpaceToDepthNodeKind:
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::FloatTy, ElemKind::Float16Ty, ElemKind::Int8QTy,
         ElemKind::Int32ITy, ElemKind::Int64ITy});
    break;
  case Kinded::Kind::ArgMaxNodeKind:
    // Fixed: the i32-result check previously read ArgMinNode::ResultIdx
    // (copy/paste from the ArgMin case) instead of ArgMaxNode::ResultIdx.
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::Float16Ty, ElemKind::Int8QTy, ElemKind::Int32ITy,
             ElemKind::Int64ITy, ElemKind::BoolTy},
            {}, {ArgMaxNode::ResultIdx}) &&
        (NI.getOutElemTy(ArgMaxNode::ResultIdx) == ElemKind::Int64ITy ||
         NI.getOutElemTy(ArgMaxNode::ResultIdx) == ElemKind::Int32ITy);
    break;
  case Kinded::Kind::ArgMinNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::Float16Ty, ElemKind::FloatTy, ElemKind::Int8QTy,
             ElemKind::Int32ITy, ElemKind::Int64ITy, ElemKind::BoolTy},
            {}, {ArgMinNode::ResultIdx}) &&
        (NI.getOutElemTy(ArgMinNode::ResultIdx) == ElemKind::Int64ITy ||
         NI.getOutElemTy(ArgMinNode::ResultIdx) == ElemKind::Int32ITy);
    break;
  case Kinded::Kind::LogitNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Float16Ty});
    break;
  case Kinded::Kind::CumSumNodeKind:
#if NNPI_MAJOR_VERSION >= 1 && NNPI_MINOR_VERSION >= 7
    isNodePrecisionSupported = NI.allInputsAndOutputsHaveSameElemKind(
        {ElemKind::Int32ITy, ElemKind::Int8QTy, ElemKind::UInt8QTy});
#else
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int32ITy});
#endif // NNPI >= 1.7
    break;
#if NNPI_MAJOR_VERSION >= 1 && NNPI_MINOR_VERSION >= 9
  case Kinded::Kind::BatchSparseToDenseNodeKind:
    isNodePrecisionSupported =
        NI.allInputsAndOutputsHaveSameElemKind(
            {ElemKind::Float16Ty, ElemKind::UInt8QTy, ElemKind::Int8QTy},
            {BatchSparseToDenseNode::LengthsIdx,
             BatchSparseToDenseNode::IndicesIdx}) &&
        ((NI.getInElemTy(BatchSparseToDenseNode::LengthsIdx) ==
              ElemKind::Int64ITy ||
          NI.getInElemTy(BatchSparseToDenseNode::LengthsIdx) ==
              ElemKind::Int32ITy)) &&
        ((NI.getInElemTy(BatchSparseToDenseNode::IndicesIdx) ==
              ElemKind::Int64ITy ||
          NI.getInElemTy(BatchSparseToDenseNode::IndicesIdx) ==
              ElemKind::Int32ITy));
    break;
  case Kinded::Kind::FillExamplesWithIndicatorNodeKind:
    // Data and result kinds must match; the indicator is i32/i64.
    isNodePrecisionSupported =
        (NI.getInElemTy(FillExamplesWithIndicatorNode::DataIdx) ==
         NI.getOutElemTy(FillExamplesWithIndicatorNode::ResultIdx)) &&
        ((NI.getInElemTy(FillExamplesWithIndicatorNode::IndicatorIdx) ==
          ElemKind::Int32ITy) ||
         (NI.getInElemTy(FillExamplesWithIndicatorNode::IndicatorIdx) ==
          ElemKind::Int64ITy));
    break;
#endif // NNPI >= 1.9
  default:
    // Unknown node kind: neither the kind nor its precisions are supported.
    hasAnySupport = false;
    isNodePrecisionSupported = false;
  }
  if (isNodePrecisionSupported) {
    return NodeSupportLevels::PRECISION_SUPPORTED;
  } else if (hasAnySupport) {
    return NodeSupportLevels::SUPPORTED;
  }
  return NodeSupportLevels::NOT_SUPPORTED;
}