LogicalResult TryLowerTo1DOr2DReduction()

in lib/Dialect/mhlo/transforms/group_reduction_dimensions.cc [130:255]


// Rewrites `op` as an equivalent 1D or 2D `ReduceOp`.
//
// Strategy:
//   1. Group the argument dimensions by kind (parallel vs. reduction) and
//      collapse each multi-dimension group into a single dimension.
//   2. If more than two groups remain, transpose so that all groups of one
//      kind lead and all groups of the other kind trail. Then collapse the
//      result to rank 2.
//   3. Emit the inner reduction over exactly one dimension (0 or 1).
//   4. Restore the original result shape:
//      - with a `DynamicReshapeOp` when more than one parallel dimension is
//        dynamic, or
//      - with a `tensor::ExpandShapeOp` when only the rank differs.
//
// Parameters:
//   op                        The reduction to rewrite; replaced on success.
//   arg_ty                    Ranked tensor type of `arg`.
//   arg                       The (single) reduction argument.
//   ordered_reduction_dims    The reduction dimensions; presumably sorted
//                             ascending by the caller (TODO confirm).
//   prefer_columns_reductions If a transpose is required, put the reduction
//                             dimensions first (reduce dim 0). Otherwise put
//                             them last (reduce dim 1).
//   rewriter                  Rewriter used to create and replace ops.
//
// Returns failure, leaving `op` in place, when:
//   - the dimensions are already fully collapsed,
//   - there is no reduction dimension, or
//   - the result shape cannot be reified.
LogicalResult TryLowerTo1DOr2DReduction(
    ReduceOp op, RankedTensorType arg_ty, Value arg,
    SmallVector<int64_t>& ordered_reduction_dims,
    bool prefer_columns_reductions, PatternRewriter& rewriter) {
  // Group the argument dimensions by their kind.
  SmallVector<DimensionGroup> dim_groups;
  GroupDimensions(arg_ty, ordered_reduction_dims, dim_groups);

  // Do not (re-)apply if the dimensions are already fully collapsed.
  if (dim_groups.size() <= 2 &&
      llvm::all_of(dim_groups, [](const auto& g) { return g.size() == 1; })) {
    return failure();
  }

  // Without any reduction dimension there is nothing to reduce over. The
  // inner reduction below would otherwise target dimension 1 of a rank-1
  // tensor, which is invalid. Bail out before touching the IR.
  if (llvm::none_of(dim_groups, [](const auto& g) {
        return g.kind == DimensionKind::kReduction;
      })) {
    return failure();
  }

  // Determine whether or not a dynamic reshape is needed for the final
  // result. With at most one dynamic parallel dimension, the result shape is
  // restored by shape expansion further below instead.
  int64_t num_dyn_parallel_dims = 0;
  for (const auto& group : dim_groups) {
    if (group.kind != DimensionKind::kParallel) continue;
    for (int64_t i = group.begin; i < group.end; i++) {
      if (arg_ty.isDynamicDim(i)) num_dyn_parallel_dims++;
    }
  }
  bool requires_dynamic_reshape = num_dyn_parallel_dims > 1;

  // Reify the result shape early so that the pattern can fail without
  // replacing the op.
  Optional<Value> result_shape;
  if (requires_dynamic_reshape) {
    llvm::SmallVector<Value, 1> reified_shapes;
    if (failed(llvm::cast<InferShapedTypeOpInterface>(op.getOperation())
                   .reifyReturnTypeShapes(rewriter, op->getOperands(),
                                          reified_shapes))) {
      return failure();
    }
    assert(reified_shapes.size() == 1 && "expect exactly one shape");
    result_shape = reified_shapes.front();
  }

  // Collapse dimension groups so that all adjacent dimensions of the
  // intermediate result are of a different kind.
  Value interm_result = arg;
  auto loc = op.getLoc();
  bool requires_collapse =
      llvm::any_of(dim_groups, [](const auto& g) { return g.size() > 1; });
  if (requires_collapse) {
    auto reassociation =
        llvm::to_vector(llvm::map_range(dim_groups, [](const auto& g) {
          return llvm::to_vector<2>(llvm::seq<int64_t>(g.begin, g.end));
        }));
    interm_result = rewriter.create<tensor::CollapseShapeOp>(loc, interm_result,
                                                             reassociation);
  }

  // If required, transpose the intermediate result so that dimension kinds
  // form two partitions, which can be collapsed to a 2D intermediate result.
  bool requires_transpose = dim_groups.size() > 2;
  if (requires_transpose) {
    // Materialize transpose.
    DimensionKind leading_dim_kind = prefer_columns_reductions
                                         ? DimensionKind::kReduction
                                         : DimensionKind::kParallel;
    DimensionKind trailing_dim_kind = prefer_columns_reductions
                                          ? DimensionKind::kParallel
                                          : DimensionKind::kReduction;
    // Permutation: all leading-kind groups first, then all trailing-kind
    // groups, each in their original relative order.
    SmallVector<int64_t> perm;
    for (int64_t i = 0, e = dim_groups.size(); i < e; ++i) {
      if (dim_groups[i].kind == leading_dim_kind) perm.push_back(i);
    }
    int64_t num_leading_dims = perm.size();
    for (int64_t i = 0, e = dim_groups.size(); i < e; ++i) {
      if (dim_groups[i].kind == trailing_dim_kind) perm.push_back(i);
    }
    auto perm_attr = hlo::GetI64ElementsAttr(perm, &rewriter);
    interm_result = rewriter.create<TransposeOp>(loc, interm_result, perm_attr)
                        ->getResults()
                        .front();

    // Collapse the intermediate result to rank 2.
    SmallVector<ReassociationIndices, 2> reassociation = {
        llvm::to_vector<2>(llvm::seq<int64_t>(0, num_leading_dims)),
        llvm::to_vector<2>(llvm::seq<int64_t>(num_leading_dims, perm.size()))};
    interm_result = rewriter.create<tensor::CollapseShapeOp>(loc, interm_result,
                                                             reassociation);
  }

  // Materialize inner 1D or 2D reduction.
  bool leading_reduction =
      requires_transpose ? prefer_columns_reductions
                         : dim_groups.front().kind == DimensionKind::kReduction;
  int64_t reduction_dim = leading_reduction ? 0 : 1;
  auto reduction_dim_attr = hlo::GetI64ElementsAttr({reduction_dim}, &rewriter);
  Value init_val = op.init_values().front();
  auto reduction_op = rewriter.create<ReduceOp>(loc, interm_result, init_val,
                                                reduction_dim_attr);
  // Move (not clone) the original reduction body into the new op.
  rewriter.inlineRegionBefore(op.body(), reduction_op.body(),
                              reduction_op.body().begin());
  interm_result = reduction_op->getResults().front();

  // Restore the expected shape by dynamic reshape, if required.
  auto result_ty = op->getResultTypes().front().cast<RankedTensorType>();
  if (requires_dynamic_reshape) {
    assert(result_shape && "expect to have reified the result shape");
    interm_result = rewriter.create<DynamicReshapeOp>(
        loc, result_ty, interm_result, *result_shape);
  }

  // Otherwise, restore the expected shape by shape expansion, if required.
  int64_t result_rank = result_ty.getRank();
  int64_t interm_result_rank =
      interm_result.getType().cast<RankedTensorType>().getRank();
  bool requires_expand =
      !requires_dynamic_reshape && result_rank != interm_result_rank;
  if (requires_expand) {
    assert(interm_result_rank <= 1 &&
           "expect intermediate result to be of rank 0 or 1 before expansion");
    // Expanding a rank-0 value uses an empty reassociation. A rank-1 value
    // maps its single dimension onto all result dimensions.
    SmallVector<ReassociationIndices, 1> reassociation;
    bool is_scalar_expansion = interm_result_rank == 0;
    if (!is_scalar_expansion)
      reassociation = {llvm::to_vector(llvm::seq<int64_t>(0, result_rank))};
    interm_result = rewriter.create<tensor::ExpandShapeOp>(
        loc, result_ty, interm_result, reassociation);
  }

  rewriter.replaceOp(op, interm_result);
  return success();
}