in lib/Dialect/mhlo/transforms/rank_specialization.cc [172:276]
LogicalResult matchAndRewrite(chlo::RankSpecializationClusterOp op,
                              PatternRewriter &rewriter) const override {
  // Merges `op` with the rank specialization cluster immediately preceding it
  // in the same block into a single cluster. Fails (no rewrite) if the
  // previous operation is not such a cluster.
  auto preceding_op =
      llvm::dyn_cast_or_null<chlo::RankSpecializationClusterOp>(
          op->getPrevNode());
  if (!preceding_op) return failure();
  Block *body = op.getBody();
  Block *preceding_body = preceding_op.getBody();
  // Both bodies are expected to be terminated by a cluster yield; `cast`
  // asserts this instead of silently yielding a null op that is dereferenced
  // below.
  auto yield_op =
      llvm::cast<chlo::RankSpecializationClusterYieldOp>(body->getTerminator());
  auto preceding_yield_op = llvm::cast<chlo::RankSpecializationClusterYieldOp>(
      preceding_body->getTerminator());

  // True iff every user of `v` is the second cluster `op`. Such results of the
  // preceding cluster become purely internal to the merged cluster. Note that a
  // value with no users at all also satisfies this predicate and is therefore
  // dropped. The same predicate must be used when collecting the merged
  // results and when building the replacements so that result indices agree.
  auto is_used_only_by_op = [&](Value v) {
    return llvm::all_of(v.getUsers(),
                        [&](Operation *user) { return user == op; });
  };

  // Merge cluster operands: all operands of the preceding cluster (kept
  // verbatim, as its block arguments are mapped by position), followed by those
  // operands of the second cluster that neither originate in the preceding
  // cluster nor are already present (this also de-duplicates values passed
  // more than once to the second cluster).
  SmallVector<Value, 8> new_operands;
  for (Value v : preceding_op.operands()) new_operands.push_back(v);
  for (Value v : op.operands()) {
    if (v.getDefiningOp() != preceding_op &&
        !llvm::is_contained(new_operands, v)) {
      new_operands.push_back(v);
    }
  }

  // Merge cluster results. Keep only those results of the preceding cluster
  // that are not exclusively used as operands to the second cluster, followed
  // by all results of the second cluster.
  SmallVector<Value, 8> new_unmapped_results;
  for (auto it :
       llvm::zip(preceding_op.results(), preceding_yield_op.results())) {
    Value result, inner_result;
    std::tie(result, inner_result) = it;
    if (!is_used_only_by_op(result)) new_unmapped_results.push_back(inner_result);
  }
  for (Value v : yield_op.results()) new_unmapped_results.push_back(v);

  // Create the merged cluster op in place of the preceding cluster. Since the
  // two clusters are adjacent, every operand of `op` not produced by
  // `preceding_op` already dominates this insertion point.
  rewriter.setInsertionPoint(preceding_op);
  auto loc = op.getLoc();
  auto result_types = llvm::to_vector<16>(llvm::map_range(
      new_unmapped_results, [](Value v) { return v.getType(); }));
  auto new_op = rewriter.create<chlo::RankSpecializationClusterOp>(
      loc, result_types, new_operands);
  auto operand_types = llvm::to_vector<16>(
      llvm::map_range(new_operands, [](Value v) { return v.getType(); }));
  Block *new_body =
      rewriter.createBlock(&new_op.body(), {}, operand_types,
                           SmallVector<Location>(operand_types.size(), loc));
  rewriter.setInsertionPointToStart(new_body);

  // Map operands and copy operations of the preceding cluster into the new
  // body. Its block arguments correspond positionally to the leading operands
  // of the merged cluster.
  BlockAndValueMapping bvm;
  for (const auto &it : llvm::enumerate(preceding_body->getArguments()))
    bvm.map(it.value(), new_body->getArgument(it.index()));
  for (Operation &nested_op : preceding_body->without_terminator())
    rewriter.clone(nested_op, bvm);

  // Map operands and copy operations of the second cluster. Operands that are
  // results of the preceding cluster map to the (already cloned) value that
  // the preceding cluster yielded for them; all others map to the matching
  // block argument of the merged cluster.
  auto preceding_results = preceding_op.results();
  auto merged_operands = new_op.operands();
  for (auto it : llvm::zip(body->getArguments(), op.operands())) {
    Value block_arg, operand;
    std::tie(block_arg, operand) = it;
    if (operand.getDefiningOp() == preceding_op) {
      auto where = llvm::find(preceding_results, operand);
      assert(where != preceding_results.end() &&
             "expected operand to be a result of the preceding cluster");
      bvm.map(block_arg,
              bvm.lookup(preceding_yield_op.getOperand(where.getIndex())));
    } else {
      auto where = llvm::find(merged_operands, operand);
      assert(where != merged_operands.end() &&
             "expected operand to be an operand of the merged cluster");
      bvm.map(block_arg, new_body->getArgument(where.getIndex()));
    }
  }
  for (Operation &nested_op : body->without_terminator())
    rewriter.clone(nested_op, bvm);

  // Yield inner results.
  rewriter.create<chlo::RankSpecializationClusterYieldOp>(
      loc,
      llvm::to_vector<16>(llvm::map_range(new_unmapped_results, [&](Value v) {
        return bvm.lookupOrDefault(v);
      })));

  // Replace the two cluster ops with the new corresponding results. Results of
  // the preceding cluster that were internalized get a null replacement; their
  // only users (if any) are inside `op`, which is replaced first.
  SmallVector<Value, 8> preceding_op_replacements;
  int64_t i = 0;
  for (Value result : preceding_op.results()) {
    Value replacement = nullptr;
    if (!is_used_only_by_op(result)) replacement = new_op->getResult(i++);
    preceding_op_replacements.push_back(replacement);
  }
  ValueRange op_replacements = new_op.results().take_back(op.getNumResults());
  rewriter.replaceOp(op, op_replacements);
  rewriter.replaceOp(preceding_op, preceding_op_replacements);
  return success();
}