static LogicalResult verifyReducerShape()

in lib/Dialect/mhlo/IR/hlo_ops.cc [3339:3502]


static LogicalResult verifyReducerShape(
    ReduceOp op, Block& block, ArrayRef<TensorType> inputArgTypes,
    ArrayRef<TensorType> initValueTypes, int64_t numInputs,
    ArrayRef<int64_t> outputShape, bool allInputsUnranked,
    SmallVectorImpl<TensorType>& accumulatorSubShapes) {
  // Check that the number of reduction-region arguments matches the number
  // of reduce-op operands: one block argument per input plus one per
  // init-value, i.e. numInputs * 2 in total.
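  // For example (hypothetical IR, for illustration only): a reduce-op with
  // one input and one init-value must carry a two-parameter reducer block:
  //
  //   "mhlo.reduce"(%input, %init) ({
  //   ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
  //     %0 = "mhlo.add"(%arg0, %arg1)
  //         : (tensor<f32>, tensor<f32>) -> tensor<f32>
  //     "mhlo.return"(%0) : (tensor<f32>) -> ()
  //   }) {dimensions = dense<0> : tensor<1xi64>}
  //       : (tensor<8xf32>, tensor<f32>) -> tensor<f32>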
  if (block.getArguments().size() != numInputs * 2)
    return op.emitError() << "Reduction-region must take " << numInputs * 2
                          << " parameters, but takes "
                          << block.getArguments().size() << " parameter(s)";

  // Check that the reduction-region returns at least one value.
  if (block.getTerminator()->getOperands().empty())
    return op.emitError()
           << "The reduction-region is expected to return some value(s)";

  // Check that the reduction-region returns either a tuple of tensors or a
  // list of tensors. In both cases the number of result tensors must match
  // `numInputs`.
  // TODO(b/171261845): Remove tuples from MHLO dialect.
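  // Illustrative sketch (hypothetical IR) of the two accepted terminator
  // forms for numInputs == 2: either a list of tensors,
  //
  //   "mhlo.return"(%r0, %r1) : (tensor<f32>, tensor<i32>) -> ()
  //
  // or a single tuple-typed operand,
  //
  //   %t = "mhlo.tuple"(%r0, %r1)
  //       : (tensor<f32>, tensor<i32>) -> tuple<tensor<f32>, tensor<i32>>
  //   "mhlo.return"(%t) : (tuple<tensor<f32>, tensor<i32>>) -> ()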
  auto tupleT =
      block.getTerminator()->getOperand(0).getType().dyn_cast<TupleType>();
  if (tupleT && block.getTerminator()->getOperands().size() == 1) {
    if (tupleT.size() != numInputs)
      return op.emitError()
             << "Reduction-region here must produce a tuple with " << numInputs
             << " tensors, but produces " << tupleT.size() << " instead";

    for (Type elementType : tupleT.getTypes()) {
      auto tensorTy = elementType.dyn_cast<TensorType>();
      if (!tensorTy)
        return op.emitError() << "Reduction-region here must produce tuple "
                                 "of tensor-typed results, but "
                                 "produces "
                              << elementType << " instead";

      accumulatorSubShapes.push_back(tensorTy);
    }
  } else {
    if (block.getTerminator()->getOperands().size() != numInputs)
      return op.emitError()
             << "Reduction-region here must produce " << numInputs
             << " tensors, but produces "
             << block.getTerminator()->getOperands().size() << " instead";

    for (Value retOperand : block.getTerminator()->getOperands()) {
      auto tensorTy = retOperand.getType().dyn_cast<TensorType>();
      if (!tensorTy)
        return op.emitError() << "Reduction-region here must produce "
                                 "tensor-typed result(s), but "
                                 "produces "
                              << retOperand.getType() << " instead";

      accumulatorSubShapes.push_back(tensorTy);
    }
  }

  // Consider typical reduce-op syntax:
  //
  //      reduce(I(i), V(j)):
  //       block(BI(i), BV(j)):
  //         ... some computation ...
  //         return(R(i))
  //
  // where
  //  I(i)  : i-th input of reduce-op
  //  V(j)  : j-th init-value of reduce-op
  //  BI(i) : i-th input of reducer-function
  //  BV(j) : j-th init-value of reducer-function
  //  R(i)  : i-th return-type
  //
  //  Note that: |I(i)| == |V(j)| == |BI(i)| == |BV(j)| == |R(i)|
  //
  //  Here are the type-constraints among V(j), BI(i), BV(j), and R(i).
  //    C1 : Check that BI(i) and R(i) have same shape and element-type.
  //    C2 : Check that BV(j) and R(i) have same shape and element-type.
  //    C3 : Check that V(j) and R(i) have same shape and element-type.
  //
  //  From C1, C2, and C3, we can infer that V(j), BI(i), BV(j), and R(i) all
  //  have compatible shapes and element-types.
  //  The next check, C4, adds constraints on how the type of I(i) is related
  //  to any_of(V(j), BI(i), BV(j), and R(i)), say BV(j):
  //
  //  C4.1 : Check that I(i) and BV(j) have same element-type.
  //  C4.2 : Check that shape of BV(j) is a 'sub-sequence' of the shape of
  //         output-array. The shape of output-array is determined from that
  //         of I(i) after removing the "dimensions-to-reduce" (as specified by
  //         the dimensions attribute of reduce-op).
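  //
  //  As a concrete (hypothetical) instance of the notation above: for
  //
  //      reduce(I(0) : tensor<4x8xf32>, V(0) : tensor<f32>), dimensions = [1]
  //
  //  the output-array has shape [4], and V(0), BI(0), BV(0), and R(0) would
  //  all be tensor<f32>, whose empty shape is trivially a sub-sequence of
  //  [4].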
  for (int64_t inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
    // Check C1.
    if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
                                       block.getArgument(inputIdx).getType()))
      return op.emitError()
             << "The type of reduction-region's parameter at index " << inputIdx
             << " is different than the corresponding result type: "
             << block.getArgument(inputIdx).getType() << " vs "
             << accumulatorSubShapes[inputIdx];

    // Check C2.
    if (!compatibleShapeAndElementType(
            accumulatorSubShapes[inputIdx],
            block.getArgument(numInputs + inputIdx).getType(),
            /*ignoreFpPrecision=*/true))
      return op.emitError()
             << "The type of reduction-region's parameter at index "
             << numInputs + inputIdx
             << " is different than the corresponding result type: "
             << block.getArgument(numInputs + inputIdx).getType() << " vs "
             << accumulatorSubShapes[inputIdx];

    // Check C3.
    if (!compatibleShapeAndElementType(accumulatorSubShapes[inputIdx],
                                       initValueTypes[inputIdx],
                                       /*ignoreFpPrecision=*/true))
      return op.emitError()
             << "The type of reduction-region's result type at index "
             << inputIdx
             << " differs from the reduce-op's corresponding init-value type: "
             << accumulatorSubShapes[inputIdx] << " vs "
             << initValueTypes[inputIdx];
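    // Note: C2 and C3 are checked with ignoreFpPrecision set, so (assuming
    // the helpers treat any two floating-point element-types as matching in
    // that mode) e.g. a tensor<f32> init-value would be accepted against a
    // tensor<f16> reduction-region result.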

    // Check C4.1.
    if (!tensorsHaveSameElType(
            inputArgTypes[inputIdx],
            block.getArgument(numInputs + inputIdx).getType(),
            /*ignoreFpPrecision=*/true))
      return op.emitError()
             << "The element-type of reduce-op's input-parameter at index "
             << inputIdx
             << " differs from that of reduction-region's argument at index "
             << numInputs + inputIdx << ": " << inputArgTypes[inputIdx]
             << " vs " << block.getArgument(numInputs + inputIdx).getType();

    // Check C4.2.
    Type blockArgType = block.getArgument(numInputs + inputIdx).getType();
    auto blockArgTensorTy = blockArgType.cast<TensorType>();

    // If ranks are not statically known, the sub-sequence property cannot be
    // verified; bail out successfully (note this also ends the checks for
    // any remaining inputs).
    if (allInputsUnranked || !blockArgTensorTy.hasRank()) return success();

    auto argShape = blockArgTensorTy.getShape();
    if (argShape.size() > outputShape.size())
      return op.emitError()
             << "The rank of reduction-region's argument at index "
             << numInputs + inputIdx
             << " is not compatible with that of reduce-op's result: "
             << argShape.size() << " vs " << outputShape.size()
             << " (expected)";

    int64_t argShapeIdx = 0;
    // Greedily scan `outputShape`, advancing `argShapeIdx` on every match;
    // `argShape` is a (not necessarily contiguous) sub-sequence of
    // `outputShape` iff every argument dimension gets matched.
    for (int64_t outputShapeIdx = 0;
         outputShapeIdx < outputShape.size() && argShapeIdx < argShape.size();
         outputShapeIdx++)
      if (outputShape[outputShapeIdx] == argShape[argShapeIdx]) argShapeIdx++;
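    // Example: for an input of shape [4, 5, 6] reduced over dimension 1, the
    // output shape is [4, 6]; argument shapes such as [4, 6], [4], [6], or
    // [] pass this scan, whereas [6, 4] or [5] do not.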

    if (argShapeIdx != argShape.size())
      return op.emitError()
             << "The shape of reduction-region's argument at index "
             << numInputs + inputIdx
             << " is not compatible with that of reduce-op's input-parameter "
                "at index "
             << inputIdx;
  }

  return success();
}
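
// End-to-end illustration (hypothetical IR, not taken from the source): a
// variadic reduce with two input/init-value pairs. Its reducer block takes
// 2 * 2 = 4 parameters and returns 2 tensors, satisfying C1-C4:
//
//   %res0, %res1 = "mhlo.reduce"(%in0, %in1, %init0, %init1) ({
//   ^bb0(%a0: tensor<f32>, %a1: tensor<i32>,
//        %b0: tensor<f32>, %b1: tensor<i32>):
//     %s = "mhlo.add"(%a0, %b0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
//     %m = "mhlo.maximum"(%a1, %b1)
//         : (tensor<i32>, tensor<i32>) -> tensor<i32>
//     "mhlo.return"(%s, %m) : (tensor<f32>, tensor<i32>) -> ()
//   }) {dimensions = dense<1> : tensor<1xi64>}
//       : (tensor<4x8xf32>, tensor<4x8xi32>, tensor<f32>, tensor<i32>)
//       -> (tensor<4xf32>, tensor<4xi32>)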