in lib/Backends/OpenCL/OpenCL.cpp [1785:1942]
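/// \returns whether the OpenCL backend can execute the node described by
/// \p NI, based on the element kinds of its inputs and outputs.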
bool OCLBackend::isOpSupported(const NodeInfo &NI) const {
switch (NI.getKind()) {
case Kinded::Kind::SplatNodeKind:
case Kinded::Kind::TouchNodeKind:
case Kinded::Kind::TransposeNodeKind:
return NI.allInputsAndOutputsHaveSameElemKind(
{ElemKind::FloatTy, ElemKind::Int8QTy, ElemKind::Int64ITy});
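  // These nodes are only supported with float operands.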
case Kinded::Kind::PowNodeKind:
case Kinded::Kind::LocalResponseNormalizationNodeKind:
case Kinded::Kind::LocalResponseNormalizationGradNodeKind:
case Kinded::Kind::BatchedReduceAddNodeKind:
case Kinded::Kind::TanhNodeKind:
case Kinded::Kind::SigmoidNodeKind:
return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});
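  // These nodes support both float and Int8 quantized operands.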
case Kinded::Kind::AddNodeKind:
case Kinded::Kind::SubNodeKind:
case Kinded::Kind::MulNodeKind:
case Kinded::Kind::DivNodeKind:
case Kinded::Kind::MaxNodeKind:
case Kinded::Kind::MinNodeKind:
case Kinded::Kind::MatMulNodeKind:
case Kinded::Kind::ConcatNodeKind:
case Kinded::Kind::SliceNodeKind:
case Kinded::Kind::InsertTensorNodeKind:
case Kinded::Kind::AvgPoolNodeKind:
case Kinded::Kind::ReluNodeKind:
return NI.allInputsAndOutputsHaveSameElemKind(
{ElemKind::FloatTy, ElemKind::Int8QTy});
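  // MaxPool data can be float or Int8 quantized; the Argmax output must be
  // Int64.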
case Kinded::Kind::MaxPoolNodeKind:
return NI.allInputsAndOutputsHaveSameElemKind(
{ElemKind::FloatTy, ElemKind::Int8QTy}, {},
{MaxPoolNode::ArgmaxIdx}) &&
(NI.getOutElemTy(MaxPoolNode::ArgmaxIdx) == ElemKind::Int64ITy);
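  // Float convolutions must be float throughout; quantized convolutions use
  // Int8 operands with an Int32 quantized bias.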
case Kinded::Kind::ConvolutionNodeKind:
if (!NI.getInTy(ConvolutionNode::InputIdx)->isQuantizedType()) {
return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});
}
return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int8QTy},
{ConvolutionNode::BiasIdx}) &&
(NI.getInElemTy(ConvolutionNode::BiasIdx) == ElemKind::Int32QTy);
case Kinded::Kind::TopKNodeKind:
return NI.allInputsAndOutputsHaveSameElemKind(
{ElemKind::FloatTy, ElemKind::Int8QTy}, {},
{TopKNode::IndicesIdx}) &&
(NI.getOutElemTy(TopKNode::IndicesIdx) == ElemKind::Int64ITy);
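  // Quantized BatchedAdd accepts either an Int8 or an Int32 quantized slice.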
case Kinded::Kind::BatchedAddNodeKind:
if (!NI.getInTy(BatchedAddNode::BatchIdx)->isQuantizedType()) {
return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy});
}
return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int8QTy},
{BatchedAddNode::SliceIdx}) &&
((NI.getInElemTy(BatchedAddNode::SliceIdx) == ElemKind::Int8QTy) ||
(NI.getInElemTy(BatchedAddNode::SliceIdx) == ElemKind::Int32QTy));
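  // Gather and ScatterData indices must be Int64.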
case Kinded::Kind::GatherNodeKind:
return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy},
{GatherNode::IndicesIdx}) &&
(NI.getInElemTy(GatherNode::IndicesIdx) == ElemKind::Int64ITy);
case Kinded::Kind::ScatterDataNodeKind:
return NI.allInputsAndOutputsHaveSameElemKind(
{ElemKind::FloatTy}, {ScatterDataNode::IndicesIdx}) &&
(NI.getInElemTy(ScatterDataNode::IndicesIdx) == ElemKind::Int64ITy);
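  // SparseLengthsWeightedSum expects Int64 indices and Int32 lengths.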
case Kinded::Kind::SparseLengthsWeightedSumNodeKind:
return NI.allInputsAndOutputsHaveSameElemKind(
{ElemKind::FloatTy},
{SparseLengthsWeightedSumNode::IndicesIdx,
SparseLengthsWeightedSumNode::LengthsIdx}) &&
(NI.getInElemTy(SparseLengthsWeightedSumNode::IndicesIdx) ==
ElemKind::Int64ITy) &&
(NI.getInElemTy(SparseLengthsWeightedSumNode::LengthsIdx) ==
ElemKind::Int32ITy);
case Kinded::Kind::MaxPoolGradNodeKind:
return NI.allInputsAndOutputsHaveSameElemKind(
{ElemKind::FloatTy},
{MaxPoolGradNode::OriginalOutputForArgmaxIdx,
MaxPoolGradNode::GradOfOriginalOutputNamedArgmaxIdx}) &&
(NI.getInElemTy(MaxPoolGradNode::OriginalOutputForArgmaxIdx) ==
ElemKind::Int64ITy) &&
(NI.getInElemTy(
MaxPoolGradNode::GradOfOriginalOutputNamedArgmaxIdx) ==
ElemKind::Int64ITy);
case Kinded::Kind::SparseLengthsWeightedSumGradNodeKind:
// GradOfInputNamedIndicesIdx and GradOfInputNamedLengthsIdx do not need to
// be checked because they are not used.
return NI.allInputsAndOutputsHaveSameElemKind(
{ElemKind::FloatTy},
{SparseLengthsWeightedSumGradNode::IndicesIdx,
SparseLengthsWeightedSumGradNode::LengthsIdx},
{SparseLengthsWeightedSumGradNode::GradOfInputNamedIndicesIdx,
SparseLengthsWeightedSumGradNode::
GradOfInputNamedLengthsIdx}) &&
(NI.getInElemTy(SparseLengthsWeightedSumGradNode::IndicesIdx) ==
ElemKind::Int64ITy) &&
(NI.getInElemTy(SparseLengthsWeightedSumGradNode::LengthsIdx) ==
ElemKind::Int32ITy);
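  // Quantize produces Int8 or Int32 quantized results from float input;
  // Dequantize only supports Int8 quantized input.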
case Kinded::Kind::QuantizeNodeKind:
return (NI.getInElemTy(QuantizeNode::InputIdx) == ElemKind::FloatTy) &&
((NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::Int8QTy) ||
(NI.getOutElemTy(QuantizeNode::ResultIdx) == ElemKind::Int32QTy));
case Kinded::Kind::DequantizeNodeKind:
return (NI.getInElemTy(DequantizeNode::InputIdx) == ElemKind::Int8QTy) &&
(NI.getOutElemTy(DequantizeNode::ResultIdx) == ElemKind::FloatTy);
case Kinded::Kind::RescaleQuantizedNodeKind:
return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::Int8QTy});
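  // Select's condition and CmpLTE's result must be boolean.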
case Kinded::Kind::SelectNodeKind:
return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy},
{SelectNode::CondIdx}) &&
(NI.getInElemTy(SelectNode::CondIdx) == ElemKind::BoolTy);
case Kinded::Kind::CmpLTENodeKind:
return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy}, {},
{CmpLTENode::ResultIdx}) &&
(NI.getOutElemTy(CmpLTENode::ResultIdx) == ElemKind::BoolTy);
  // SoftMax accepts either an Int32 or an Int64 SelectedIdx; a 64-bit
  // SelectedIdx is silently clipped to 32 bits when dim_t is 32-bit.
case Kinded::Kind::SoftMaxNodeKind:
return NI.allInputsAndOutputsHaveSameElemKind({ElemKind::FloatTy},
{SoftMaxNode::SelectedIdx}) &&
(NI.getInElemTy(SoftMaxNode::SelectedIdx) == ElemKind::Int32ITy ||
NI.getInElemTy(SoftMaxNode::SelectedIdx) == ElemKind::Int64ITy);
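  // ConvolutionGrad's GradOfInputNamedInputIdx output is excluded from the
  // element-kind check.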
case Kinded::Kind::ConvolutionGradNodeKind:
return NI.allInputsAndOutputsHaveSameElemKind(
{ElemKind::FloatTy}, {},
{ConvolutionGradNode::GradOfInputNamedInputIdx});
case Kinded::Kind::SoftMaxGradNodeKind:
return NI.allInputsAndOutputsHaveSameElemKind(
{ElemKind::FloatTy}, {SoftMaxGradNode::SelectedIdx},
{SoftMaxGradNode::GradOfInputNamedSelectedIdx}) &&
(NI.getInElemTy(SoftMaxGradNode::SelectedIdx) == ElemKind::Int64ITy);
case Kinded::Kind::SaveNodeKind:
case Kinded::Kind::ReshapeNodeKind:
case Kinded::Kind::OCLBatchedReduceAddNodeKind:
case Kinded::Kind::TraceEventNodeKind:
// These work regardless of the underlying type.
return true;
default:
return false;
}
}