in lib/Dialect/mhlo/transforms/mhlo_canonicalize_reduction.cc [117:246]
void runOnOperation() override {
getOperation().walk([&](ReduceOp op) {
SmallVector<int64_t, 4> dims_to_reduce;
DenseSet<int64_t> dims_to_reduce_set;
for (auto dim : op.dimensions().getValues<APInt>()) {
dims_to_reduce.push_back(dim.getSExtValue());
dims_to_reduce_set.insert(dims_to_reduce.back());
}
// An empty reduction is just a no-op, so there is no need to do codegen.
if (dims_to_reduce.empty()) return;
// The reduce input must be a ranked tensor; otherwise fail the pass.
auto ty = op.getOperand(0).getType().dyn_cast<RankedTensorType>();
if (!ty) return signalPassFailure();
int rank = ty.getRank();
int ndims_to_reduce = dims_to_reduce.size();
auto elem_ty = ty.getElementType();
llvm::sort(dims_to_reduce);
// Skip the case d) form since we don't support it.
if ((dims_to_reduce.back() - dims_to_reduce[0]) !=
(ndims_to_reduce - 1) ||
(dims_to_reduce[0] != 0 && dims_to_reduce.back() != (rank - 1))) {
return;
}
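// E.g. for rank == 4, dims {1, 2} are contiguous but interior (neither
// starting at 0 nor ending at rank - 1), and dims {0, 2} are
// non-contiguous; both are case d) and are left untouched.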
// Rank-2 row/column reductions are already supported.
if (rank == 2 && ndims_to_reduce == 1) {
return;
}
SmallVector<int64_t, 4> dims_to_keep;
for (int i = 0; i < rank; ++i) {
if (!dims_to_reduce_set.count(i)) dims_to_keep.push_back(i);
}
OpBuilder b(op);
auto loc = op.getLoc();
// TODO(disc): unify shape_scalar_type with shape_derivation
auto shape_scalar_type = b.getIntegerType(32);
auto one = b.create<arith::ConstantIntOp>(loc, 1ll, shape_scalar_type);
// Function computing the total number of elements in the selected dimensions.
auto dim_prod = [&](ArrayRef<int64_t> dims) {
Value nelems = one;
for (int64_t v : dims) {
Value dim_index = b.create<tensor::DimOp>(loc, op.getOperand(0), v);
nelems = b.create<arith::MulIOp>(
loc, nelems,
b.create<arith::IndexCastOp>(loc, shape_scalar_type, dim_index));
}
return nelems;
};
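// E.g. dim_prod({0, 1}) emits a chain of arith.muli over the index_cast'd
// tensor.dim values of dims 0 and 1, seeded with the constant 1 above.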
SmallVector<Value, 2> new_operand_dims;
DenseIntElementsAttr attr;
Value nelem_to_reduce = dim_prod(dims_to_reduce);
Value nelem_to_keep = dim_prod(dims_to_keep);
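// Note: when all dims are reduced (case c), dims_to_keep is empty and
// nelem_to_keep is just the constant 1.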
if (rank == ndims_to_reduce) {
// case c) Reduce to scalar.
// Currently we don't support reducing to a scalar directly.
// As a workaround, we convert the `reduce to scalar` into a rank-2
// column reduction of the following form:
// Suppose nelems = Product(Shape(I)); we convert I into
// shape `[nelems, 1]`.
// TODO(disc): this may have a performance issue. Implement a
// reduce-to-scalar schedule if necessary.
new_operand_dims.push_back(nelem_to_reduce);
new_operand_dims.push_back(nelem_to_keep);
attr = DenseIntElementsAttr::get(
RankedTensorType::get({1}, b.getIntegerType(64)), {0ll});
} else if (dims_to_reduce[0] == 0) {
// case a) column reduction
new_operand_dims.push_back(nelem_to_reduce);
new_operand_dims.push_back(nelem_to_keep);
attr = DenseIntElementsAttr::get(
RankedTensorType::get({1}, b.getIntegerType(64)), {0ll});
} else {
// case b) row reduction
new_operand_dims.push_back(nelem_to_keep);
new_operand_dims.push_back(nelem_to_reduce);
attr = DenseIntElementsAttr::get(
RankedTensorType::get({1}, b.getIntegerType(64)), {1ll});
}
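// At this point `new_operand_dims` holds the target rank-2 shape and `attr`
// the single dimension that will be reduced:
//   cases a)/c): [nelem_to_reduce, nelem_to_keep], reducing dim 0;
//   case b):     [nelem_to_keep, nelem_to_reduce], reducing dim 1.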
Value new_operand_shape =
b.create<tensor::FromElementsOp>(loc, new_operand_dims);
SmallVector<Value, 4> new_operands;
for (Value operand : op.inputs()) {
new_operands.push_back(b.create<DynamicReshapeOp>(
loc,
RankedTensorType::get(
SmallVector<int64_t, 4>(new_operand_dims.size(),
ShapedType::kDynamicSize),
elem_ty),
operand, new_operand_shape));
}
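// Schematically, each operand is reshaped to a fully dynamic rank-2 tensor:
//   %shape = tensor.from_elements %d0, %d1 : tensor<2xi32>
//   %r = "mhlo.dynamic_reshape"(%operand, %shape)
//       : (tensor<?x?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
// (element type and input rank vary; shown here for a rank-3 input).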
auto new_op =
b.create<ReduceOp>(loc, new_operands, op.init_values(), attr);
new_op.body().takeBody(op.body());
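// The original reduction body is moved (not cloned) into the new op,
// leaving the original op with an empty region; it is erased below.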
SmallVector<Value, 4> new_results;
if (dims_to_keep.empty()) {
// case c) reduce to scalar
// Reshape the rank-1 result tensor of size 1 into a rank-0 tensor.
for (Value result : new_op.getResults()) {
new_results.push_back(b.create<ReshapeOp>(
loc, RankedTensorType::get({}, elem_ty), result));
}
} else {
SmallVector<Value, 4> result_dims;
for (int64_t i : dims_to_keep) {
Value dim_index = b.create<tensor::DimOp>(loc, op.getOperand(0), i);
result_dims.push_back(
b.create<arith::IndexCastOp>(loc, shape_scalar_type, dim_index));
}
Value result_shape = b.create<tensor::FromElementsOp>(loc, result_dims);
for (auto&& e : llvm::zip(op.getResults(), new_op.getResults())) {
new_results.push_back(b.create<DynamicReshapeOp>(
loc, std::get<0>(e).getType(), std::get<1>(e), result_shape));
}
}
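// Replace all uses of the original results with the reshaped results of the
// new reduce, then erase the original op.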
for (auto&& e : llvm::zip(op.getResults(), new_results)) {
std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
}
op.erase();
});
}