void runOnOperation()

in lib/Dialect/mhlo/transforms/mhlo_canonicalize_reduction.cc [117:246]


  void runOnOperation() override {
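    // Walk and rewrite every ReduceOp nested under the current operation.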
    getOperation().walk([&](ReduceOp op) {
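      // Collect the reduction dimensions both as an ordered list and as a set
      // for fast membership tests.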
      SmallVector<int64_t, 4> dims_to_reduce;
      DenseSet<int64_t> dims_to_reduce_set;
      for (auto dim : op.dimensions().getValues<APInt>()) {
        dims_to_reduce.push_back(dim.getSExtValue());
        dims_to_reduce_set.insert(dims_to_reduce.back());
      }

      // An empty reduction is just a no-op; no codegen is needed for it.
      if (dims_to_reduce.empty()) return;

      // The reduce input must be a ranked tensor; otherwise the pass fails.
      auto ty = op.getOperand(0).getType().dyn_cast<RankedTensorType>();
      if (!ty) return signalPassFailure();
      int rank = ty.getRank();
      int ndims_to_reduce = dims_to_reduce.size();
      auto elem_ty = ty.getElementType();
      llvm::sort(dims_to_reduce);

      // Skip case d): the reduced dimensions form neither a contiguous leading
      // nor a contiguous trailing range, which is not supported.
      if ((dims_to_reduce.back() - dims_to_reduce[0]) !=
              (ndims_to_reduce - 1) ||
          (dims_to_reduce[0] != 0 && dims_to_reduce.back() != (rank - 1))) {
        return;
      }

      // A rank-2 row/column reduction is already supported; leave it as-is.
      if (rank == 2 && ndims_to_reduce == 1) {
        return;
      }

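      // Collect the dimensions that are kept (i.e. not reduced), in order.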
      SmallVector<int64_t, 4> dims_to_keep;
      for (int i = 0; i < rank; ++i) {
        if (!dims_to_reduce_set.count(i)) dims_to_keep.push_back(i);
      }

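      // New ops are inserted right before the original reduce op.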
      OpBuilder b(op);
      auto loc = op.getLoc();
      // TODO(disc): unify shape_scalar_type with shape_derivation
      auto shape_scalar_type = b.getIntegerType(32);
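      // Constant 1 seeds the running dimension products computed below.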
      auto one = b.create<arith::ConstantIntOp>(loc, 1ll, shape_scalar_type);

      // Helper returning the total number of elements across the selected
      // dimensions, as a shape_scalar_type value.
      auto dim_prod = [&](ArrayRef<int64_t> dims) {
        Value nelems = one;
        for (int64_t v : dims) {
          Value dim_index = b.create<tensor::DimOp>(loc, op.getOperand(0), v);
          nelems = b.create<arith::MulIOp>(
              loc, nelems,
              b.create<arith::IndexCastOp>(loc, dim_index, shape_scalar_type));
        }
        return nelems;
      };

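      // Element counts of the reduced and of the kept dimensions; they become
      // the two extents of the rank-2 reshaped operand.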
      SmallVector<Value, 2> new_operand_dims;
      DenseIntElementsAttr attr;
      Value nelem_to_reduce = dim_prod(dims_to_reduce);
      Value nelem_to_keep = dim_prod(dims_to_keep);
      if (rank == ndims_to_reduce) {
        // case c) Reduce to scalar.
        // Currently we don't support reducing to a scalar directly.
        // As a workaround, convert the reduce-to-scalar into a rank-2 column
        // reduction of the following form:
        // let `nelems` be the product of the input's shape (its total element
        // count); reshape the input I to shape `[nelems, 1]` and reduce
        // dimension 0.
        // TODO(disc): this may have a performance issue. Implement a
        // reduce-to-scalar schedule if necessary.
        new_operand_dims.push_back(nelem_to_reduce);
        new_operand_dims.push_back(nelem_to_keep);
        attr = DenseIntElementsAttr::get(
            RankedTensorType::get({1}, b.getIntegerType(64)), {0ll});
      } else if (dims_to_reduce[0] == 0) {
        // case a) column reduction: the leading dims are reduced, so reshape
        // to `[nelem_to_reduce, nelem_to_keep]` and reduce dimension 0.
        new_operand_dims.push_back(nelem_to_reduce);
        new_operand_dims.push_back(nelem_to_keep);
        attr = DenseIntElementsAttr::get(
            RankedTensorType::get({1}, b.getIntegerType(64)), {0ll});
      } else {
        // case b) row reduction: the trailing dims are reduced, so reshape
        // to `[nelem_to_keep, nelem_to_reduce]` and reduce dimension 1.
        new_operand_dims.push_back(nelem_to_keep);
        new_operand_dims.push_back(nelem_to_reduce);
        attr = DenseIntElementsAttr::get(
            RankedTensorType::get({1}, b.getIntegerType(64)), {1ll});
      }

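      // Materialize the new operand shape as a rank-1 tensor, as required by
      // DynamicReshapeOp.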
      Value new_operand_shape =
          b.create<tensor::FromElementsOp>(loc, new_operand_dims);

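      // Reshape each input operand to the dynamically-shaped rank-2 form.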
      SmallVector<Value, 4> new_operands;
      for (Value operand : op.inputs()) {
        new_operands.push_back(b.create<DynamicReshapeOp>(
            loc,
            RankedTensorType::get(
                SmallVector<int64_t, 4>(new_operand_dims.size(),
                                        ShapedType::kDynamicSize),
                elem_ty),
            operand, new_operand_shape));
      }
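      // Create the rank-2 reduce with the new reduction dimension attribute
      // and move the original reduction body into it.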
      auto new_op =
          b.create<ReduceOp>(loc, new_operands, op.init_values(), attr);
      new_op.body().takeBody(op.body());

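      // Reshape the new results back to the shapes expected by the original
      // op's users.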
      SmallVector<Value, 4> new_results;
      if (dims_to_keep.empty()) {
        // case c) reduce to scalar:
        // reshape each rank-1 result of size 1 back to a rank-0 tensor.
        for (Value result : new_op.getResults()) {
          new_results.push_back(b.create<ReshapeOp>(
              loc, RankedTensorType::get({}, elem_ty), result));
        }
      } else {
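        // Cases a) and b): rebuild the result shape from the kept dimensions
        // of the original operand and reshape each result back to it.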
        SmallVector<Value, 4> result_dims;
        for (int64_t i : dims_to_keep) {
          Value dim_index = b.create<tensor::DimOp>(loc, op.getOperand(0), i);
          result_dims.push_back(
              b.create<arith::IndexCastOp>(loc, dim_index, shape_scalar_type));
        }
        Value result_shape = b.create<tensor::FromElementsOp>(loc, result_dims);
        for (auto&& e : llvm::zip(op.getResults(), new_op.getResults())) {
          new_results.push_back(b.create<DynamicReshapeOp>(
              loc, std::get<0>(e).getType(), std::get<1>(e), result_shape));
        }
      }
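      // Replace all uses of the original results and erase the old reduce op.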
      for (auto&& e : llvm::zip(op.getResults(), new_results)) {
        std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
      }
      op.erase();
    });
  }
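
For intuition, the shape bookkeeping above can be checked with static numbers. The sketch below is plain C++ with no MLIR dependencies; collapsedShape is a hypothetical helper written for this note (not a function from the pass or from MLIR), and it assumes a static shape whose reduced dimensions form a contiguous leading or trailing range.

#include <cstdint>
#include <utility>
#include <vector>

// Given a static input shape and the (contiguous leading or trailing) set of
// reduction dims, return the rank-2 shape the pass reshapes the operand to.
// Hypothetical illustration only; not part of the pass.
std::pair<int64_t, int64_t> collapsedShape(
    const std::vector<int64_t>& shape,
    const std::vector<int64_t>& dims_to_reduce) {
  int64_t nelem_to_reduce = 1, nelem_to_keep = 1;
  std::vector<bool> reduced(shape.size(), false);
  for (int64_t d : dims_to_reduce) reduced[d] = true;
  for (size_t i = 0; i < shape.size(); ++i)
    (reduced[i] ? nelem_to_reduce : nelem_to_keep) *= shape[i];
  // Cases a) and c): leading dims reduced -> column reduction, [reduce, keep].
  // Case b): trailing dims reduced -> row reduction, [keep, reduce].
  bool column = dims_to_reduce.front() == 0;
  return column ? std::make_pair(nelem_to_reduce, nelem_to_keep)
                : std::make_pair(nelem_to_keep, nelem_to_reduce);
}

For example, a shape of {8, 16, 32} reduced over dims {1, 2} collapses to {8, 512}, reduced over dimension 1 (case b, a row reduction).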