LogicalResult matchAndRewrite()

in lib/Dialect/mhlo/transforms/rank_specialization.cc [172:276]


  /// Merges `op` with the `RankSpecializationClusterOp` that immediately
  /// precedes it in the same block.
  ///
  /// Merged cluster:
  ///   - Created right before the preceding cluster.
  ///   - Operands: the preceding cluster's operands (unchanged, in order),
  ///     followed by those operands of `op` that are neither results of the
  ///     preceding cluster nor already operands of the merged cluster.
  ///   - Results: the preceding cluster's results that have at least one user
  ///     other than `op`, followed by all of `op`'s results.
  ///   - Body: clones of both bodies (minus terminators) and a new yield.
  ///
  /// Both original clusters are then replaced by the matching results.
  ///
  /// Returns failure() when the previous operation is not a rank
  /// specialization cluster; otherwise rewrites the IR and returns success().
  LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                                PatternRewriter &rewriter) const override {
    auto preceding_op =
        llvm::dyn_cast_or_null<chlo::RankSpecializationClusterOp>(
            op->getPrevNode());
    if (!preceding_op) return failure();
    Block *body = op.getBody();
    Block *preceding_body = preceding_op.getBody();
    // The cluster terminator is expected to be the cluster yield op; `cast`
    // asserts this instead of yielding a null handle that is used below.
    auto yield_op =
        llvm::cast<chlo::RankSpecializationClusterYieldOp>(body->getTerminator());
    auto preceding_yield_op =
        llvm::cast<chlo::RankSpecializationClusterYieldOp>(
            preceding_body->getTerminator());

    // Merge cluster operands. The preceding cluster's operands come first and
    // unchanged, so its block arguments keep their positions. Of the second
    // cluster's operands, skip those produced by the preceding cluster (they
    // become internal values). Also skip those already present, which avoids
    // duplicate operands and unused block arguments.
    SmallVector<Value, 8> new_operands;
    for (Value v : preceding_op.operands()) new_operands.push_back(v);
    for (Value v : op.operands()) {
      if (v.getDefiningOp() != preceding_op &&
          !llvm::is_contained(new_operands, v)) {
        new_operands.push_back(v);
      }
    }

    // Merge cluster results. A result of the preceding cluster is kept only if
    // some user other than the second cluster needs it. The decision is made
    // once and recorded here, so the replacement step at the end agrees with
    // the merged result list even though the IR is modified in between.
    auto has_user_other_than_op = [&](Value result) {
      return llvm::any_of(result.getUsers(), [&](Operation *user) {
        return user != op.getOperation();
      });
    };
    SmallVector<bool, 8> keep_preceding_result;
    SmallVector<Value, 8> new_unmapped_results;
    for (auto it :
         llvm::zip(preceding_op.results(), preceding_yield_op.results())) {
      Value result = std::get<0>(it);
      Value inner_result = std::get<1>(it);
      bool keep = has_user_other_than_op(result);
      keep_preceding_result.push_back(keep);
      if (keep) new_unmapped_results.push_back(inner_result);
    }
    for (Value v : yield_op.results()) new_unmapped_results.push_back(v);

    // Create merged cluster op.
    rewriter.setInsertionPoint(preceding_op);
    auto loc = op.getLoc();
    auto result_types = llvm::to_vector<16>(llvm::map_range(
        new_unmapped_results, [](Value v) { return v.getType(); }));
    auto new_op = rewriter.create<chlo::RankSpecializationClusterOp>(
        loc, result_types, new_operands);
    auto operand_types = llvm::to_vector<16>(
        llvm::map_range(new_operands, [](Value v) { return v.getType(); }));
    Block *new_body =
        rewriter.createBlock(&new_op.body(), {}, operand_types,
                             SmallVector<Location>(operand_types.size(), loc));
    rewriter.setInsertionPointToStart(new_body);

    // Map operands and copy operations of the preceding cluster into the new
    // body. Its block arguments correspond 1:1 to the leading new operands.
    BlockAndValueMapping bvm;
    for (const auto &it : llvm::enumerate(preceding_body->getArguments()))
      bvm.map(it.value(), new_body->getArgument(it.index()));
    for (Operation &nested_op : preceding_body->without_terminator())
      rewriter.clone(nested_op, bvm);

    // Map operands and copy operations of the second cluster. An operand that
    // is a result of the preceding cluster maps to the (cloned) value the
    // preceding cluster yielded for it. Every other operand maps to its block
    // argument in the merged cluster.
    for (auto it : llvm::zip(body->getArguments(), op.operands())) {
      Value block_arg = std::get<0>(it);
      Value operand = std::get<1>(it);
      if (operand.getDefiningOp() == preceding_op) {
        // The defining-op check guarantees `operand` is an OpResult of
        // `preceding_op`, so its result number indexes the yield operands.
        unsigned result_number = operand.cast<OpResult>().getResultNumber();
        bvm.map(block_arg, bvm.lookupOrDefault(
                               preceding_yield_op.getOperand(result_number)));
      } else {
        auto *where = llvm::find(new_operands, operand);
        assert(where != new_operands.end() &&
               "expected operand to be an operand of the merged cluster");
        bvm.map(block_arg, new_body->getArgument(static_cast<unsigned>(
                               std::distance(new_operands.begin(), where))));
      }
    }
    for (Operation &nested_op : body->without_terminator())
      rewriter.clone(nested_op, bvm);

    // Yield inner results.
    rewriter.create<chlo::RankSpecializationClusterYieldOp>(
        loc,
        llvm::to_vector<16>(llvm::map_range(new_unmapped_results, [&](Value v) {
          return bvm.lookupOrDefault(v);
        })));

    // Replace the two cluster ops with the new corresponding results. The
    // second cluster's results are the trailing results of the merged op.
    // A dropped preceding result has no user other than `op`, and `op` is
    // replaced first, so it gets a null replacement.
    ValueRange op_replacements = new_op.results().take_back(op.getNumResults());
    SmallVector<Value, 8> preceding_op_replacements;
    unsigned next_result = 0;
    for (bool keep : keep_preceding_result) {
      Value replacement = nullptr;
      if (keep) replacement = new_op->getResult(next_result++);
      preceding_op_replacements.push_back(replacement);
    }
    rewriter.replaceOp(op, op_replacements);
    rewriter.replaceOp(preceding_op, preceding_op_replacements);

    return success();
  }