in lib/Dialect/mhlo/IR/hlo_ops.cc [3339:3502]
// Verifies the shape/type constraints between a reduce-op's operands and its
// reduction-region (reducer function) `block`.
//
// Args:
//   op:                  the reduce-op being verified (used for diagnostics).
//   block:               the reduction-region entry block.
//   inputArgTypes:       types of the reduce-op's input operands, I(i).
//   initValueTypes:      types of the reduce-op's init-values, V(j).
//   numInputs:           number of input/init-value pairs.
//   outputShape:         shape of the reduce-op's result (inputs' shape with
//                        the reduced dimensions removed).
//   allInputsUnranked:   true if every input operand is unranked, in which
//                        case all shape (rank-dependent) checks are vacuous.
//   accumulatorSubShapes: out-param; filled with the reducer's result tensor
//                        types, whether returned as a tuple or as a list.
//
// Returns success() iff all constraints (C1-C4.2 below) hold; otherwise emits
// a diagnostic on `op` and returns failure.
static LogicalResult verifyReducerShape(
    ReduceOp op, Block& block, ArrayRef<TensorType> inputArgTypes,
    ArrayRef<TensorType> initValueTypes, int64_t numInputs,
    ArrayRef<int64_t> outputShape, bool allInputsUnranked,
    SmallVectorImpl<TensorType>& accumulatorSubShapes) {
  // Check that the number of reduction-region arguments matches with that of
  // reduce-op's arguments. (Cast to int64_t to avoid a signed/unsigned
  // comparison.)
  if (static_cast<int64_t>(block.getArguments().size()) != numInputs * 2)
    return op.emitError() << "Reduction-region must take " << numInputs * 2
                          << " parameters, but takes "
                          << block.getArguments().size() << " parameter(s)";

  // Check if the reduction-region produces non-zero outputs.
  if (block.getTerminator()->getOperands().empty())
    return op.emitError()
           << "The reduction-region expected to return some value(s)";

  // Check that the reduction-region returns a tuple- OR list- of tensors.
  // The number of result-tensors must match the `numInputs`.
  // TODO(b/171261845): Remove tuples from MHLO dialect.
  auto tupleT =
      block.getTerminator()->getOperand(0).getType().dyn_cast<TupleType>();
  if (tupleT && block.getTerminator()->getOperands().size() == 1) {
    if (static_cast<int64_t>(tupleT.size()) != numInputs)
      return op.emitError()
             << "Reduction-region here must produce a tuple with " << numInputs
             << " tensors, but produces " << tupleT.size() << " instead";

    for (Type elementType : tupleT.getTypes()) {
      auto tensorTy = elementType.dyn_cast<TensorType>();
      if (!tensorTy)
        return op.emitError() << "Reduction-region here must produce tuple "
                                 "of tensor-typed results, but "
                                 "produces "
                              << elementType << " instead";

      accumulatorSubShapes.push_back(tensorTy);
    }
  } else {
    if (static_cast<int64_t>(block.getTerminator()->getOperands().size()) !=
        numInputs)
      return op.emitError()
             << "Reduction-region here must produce " << numInputs
             << " tensors, but produces "
             << block.getTerminator()->getOperands().size() << " instead";

    for (Value retOperand : block.getTerminator()->getOperands()) {
      auto tensorTy = retOperand.getType().dyn_cast<TensorType>();
      if (!tensorTy)
        return op.emitError() << "Reduction-region here must produce "
                                 "tensor-typed result(s), but "
                                 "produces "
                              << retOperand.getType() << " instead";

      accumulatorSubShapes.push_back(tensorTy);
    }
  }

  // Consider typical reduce-op syntax:
  //
  //      reduce(I(i), V(j)):
  //       block(BI(i), BV(j)):
  //         ... some computation ...
  //         return(R(i))
  //
  // where
  //  I(i)  : i-th input of reduce-op
  //  V(j)  : j-th init-value of reduce-op
  //  BI(i) : i-th input of reducer-function
  //  BV(j) : j-th init-value of reducer-function
  //  R(i)  : i-th return-type
  //
  //  Note that: |I(i)| == |V(j)| == |BI(i)| == |BV(j)| == |R(i)|
  //
  //  Here are the type-constraints among V(j), BI(i), BV(j), and R(i).
  //    C1 : Check that BI(i) and R(i) have same shape and element-type.
  //    C2 : Check that BV(j) and R(i) have same shape and element-type.
  //    C3 : Check that V(j) and R(i) have same shape and element-type.
  //
  //  From C1, C2, and C3, we can infer that V(j), BI(i), BV(j), and R(i) all
  //  have compatible shapes and element-types.
  //  The next check, C4, adds constraints on how the type if I(i) is related
  //  to any_of(V(j), BI(i), BV(j), and R(i)), say BV(j);
  //
  //  C4.1 : Check that I(i) and BV(j) have same element-type.
  //  C4.2 : Check that shape of BV(j) is a 'sub-sequence' of the shape of
  //         output-array. The shape of output-array is determined from that
  //         of I(i) after removing the "dimensions-to-reduce" (as specified by
  //         the dimensions attribute of reduce-op).
  for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
    // Check C1.
    if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
                                       block.getArgument(inputIdx).getType()))
      return op.emitError()
             << "The type of reduction-region's parameter at index " << inputIdx
             << " is different than the corresponding result type: "
             << block.getArgument(inputIdx).getType() << " vs "
             << accumulatorSubShapes[inputIdx];

    // Check C2.
    if (!compatibleShapeAndElementType(
            accumulatorSubShapes[inputIdx],
            block.getArgument(numInputs + inputIdx).getType(),
            /*ignoreFpPrecision=*/true))
      return op.emitError()
             << "The type of reduction-region's parameter at index "
             << numInputs + inputIdx
             << " is different than the corresponding result type: "
             << block.getArgument(numInputs + inputIdx).getType() << " vs "
             << accumulatorSubShapes[inputIdx];

    // Check C3.
    if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
                                       initValueTypes[inputIdx],
                                       /*ignoreFpPrecision=*/true))
      return op.emitError()
             << "The type of reduction-region's result type at index "
             << inputIdx
             << " differs from the reduce-op's corresponding init-value type: "
             << accumulatorSubShapes[inputIdx] << " vs "
             << initValueTypes[inputIdx];

    // Check C4.1.
    if (!tensorsHaveSameElType(
            inputArgTypes[inputIdx],
            block.getArgument(numInputs + inputIdx).getType(), true))
      return op.emitError()
             << "The element-type of reduce-op's input-parameter at index "
             << inputIdx
             << " differs from that of reduction-region's argument at index "
             << numInputs + inputIdx << ": " << inputArgTypes[inputIdx]
             << " vs " << block.getArgument(numInputs + inputIdx).getType();

    // Check C4.2.
    Type blockArgType = block.getArgument(numInputs + inputIdx).getType();
    auto blockArgTensorTy = blockArgType.cast<TensorType>();

    // C4.2 is a rank-dependent check, hence vacuous for unranked types. Use
    // `continue` rather than returning success so that the remaining inputs
    // are still verified against C1-C4.1 (which are rank-independent); an
    // early `return success()` here would silently skip them.
    if (allInputsUnranked || !blockArgTensorTy.hasRank()) continue;

    auto argShape = blockArgTensorTy.getShape();
    if (argShape.size() > outputShape.size())
      return op.emitError()
             << "The rank of reduction-region's argument at index "
             << numInputs + inputIdx
             << " is not compatible with that of reduce-op's result: "
             << argShape.size() << " vs " << outputShape.size()
             << " (expected)";

    // Greedily match each dimension of the reducer argument against the
    // output shape; every argument dimension must be consumed for the
    // argument's shape to be a sub-sequence of the output's shape.
    int64_t argShapeIdx = 0;
    for (int64_t outputShapeIdx = 0;
         outputShapeIdx < static_cast<int64_t>(outputShape.size()) &&
         argShapeIdx < static_cast<int64_t>(argShape.size());
         outputShapeIdx++)
      if (outputShape[outputShapeIdx] == argShape[argShapeIdx]) argShapeIdx++;

    if (argShapeIdx != static_cast<int64_t>(argShape.size()))
      return op.emitError()
             << "The shape of reduction-region's argument at index "
             << numInputs + inputIdx
             << " is not compatible with that of reduce-op's input-parameter "
                "at index "
             << inputIdx;
  }

  return success();
}