in src/te/schedule/schedule_dataflow_rewrite.cc [731:948]
// Factor a reduction axis of a ComputeOp out into a separate stage.
//
// A new op named "<name>.rf" is created in which `axis` appears as a
// data-parallel axis (inserted at position `factor_axis`) and the remaining
// reduction axes are kept as reduction axes. The stage that held the original
// op is then rewritten to hold a new op named "<name>.repl" that reduces the
// outputs of the ".rf" op over a fresh reduce axis covering the domain of
// `axis`.
//
// Parameters:
//   tensor       Output tensor of the reduction to factor. Its op must be a
//                ComputeOp, and the body selected by tensor->value_index must
//                be a ReduceNode.
//   axis         The iter var to factor. It must have iter_type kCommReduce,
//                be one of the stage's leaf iter vars, and its domain (as
//                found in dom_map) must start at 0.
//   factor_axis  Position of the factored axis among the output axes of the
//                ".rf" op. A negative value is offset by
//                compute_op->axis.size() + 1, so -1 places it last. The
//                resulting position is checked to be <= compute_op->axis.size().
//
// Returns: the output tensors of the newly created ".rf" op.
//
// All preconditions are enforced with ICHECK / LOG(FATAL).
Array<Tensor> Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis) {
// Called before any of the schedule state below is modified.
(*this)->InvalidateCache();
using tir::ReduceNode;
ICHECK_EQ(axis->iter_type, kCommReduce) << "Can only factor reduction axis";
Stage reduce_stage = operator[](tensor->op);
const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>();
ICHECK(compute_op) << "Can only factor ComputeOp";
// leaf_vars is only read in this function (by the FindNodeRef lookup below).
ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite();
{
// The check treats a position equal to leaf_vars->size() as "not found".
size_t axis_pos = FindNodeRef(leaf_vars, axis);
ICHECK_NE(axis_pos, leaf_vars->size())
<< "Cannot find IterVar " << axis << " in leaf iter vars";
}
// Find touched reduction axis.
// touch_map is seeded with `axis`; the two passes below presumably propagate
// the mark to the iter vars related to it through the stage's relations
// (TODO confirm against PassUpBitMaskOr / PassDownBitMaskOr).
std::unordered_map<IterVar, int> touch_map;
touch_map[axis] = 1;
te::PassUpBitMaskOr(reduce_stage, &touch_map, true);
te::PassDownBitMaskOr(reduce_stage, &touch_map, true);
// Iter vars collected here are handed to MakeBoundCheck below as the set to
// skip: every data-parallel axis and every untouched reduce axis.
std::unordered_set<IterVar> skip_bound_check;
// Verify that no data-parallel (normal) axis is touched.
for (IterVar iv : compute_op->axis) {
ICHECK(!touch_map.count(iv)) << "Factor axis touches normal axis.";
skip_bound_check.insert(iv);
}
// get analyzer.
arith::Analyzer analyzer;
// Get the replace index
// dom_map: domain of each touched iter var.
// value_map: expression that stands for each touched iter var in the new op.
std::unordered_map<IterVar, Range> dom_map;
std::unordered_map<IterVar, PrimExpr> value_map;
for (IterVar iv : compute_op->reduce_axis) {
if (touch_map.count(iv)) {
// Touched root reduce axes seed dom_map with their own domain.
dom_map[iv] = iv->dom;
} else {
skip_bound_check.insert(iv);
}
// Every reduce axis, touched or not, is bound in the analyzer.
analyzer.Bind(iv->var, iv->dom);
}
// The dom_map.at(iv) lookups below require an entry for every touched leaf
// iter var; those entries are expected to be filled in by this call.
te::PassDownDomain(reduce_stage, &dom_map, &analyzer, true);
for (IterVar iv : reduce_stage->leaf_iter_vars) {
if (touch_map.count(iv)) {
Range dom = dom_map.at(iv);
// An extent-one leaf is represented by its constant minimum; any other
// leaf is represented by its own variable.
if (is_one(dom->extent)) {
value_map[iv] = dom->min;
} else {
value_map[iv] = iv->var;
}
}
}
// Expected to extend value_map to the touched root reduce axes; the
// ICHECK(value_map.count(iv)) further down relies on that.
te::PassUpIndex(reduce_stage, dom_map, &value_map, true);
std::vector<PrimExpr> predicates =
MakeBoundCheck(reduce_stage, dom_map, value_map, true, skip_bound_check);
// Get the factored op node.
// A negative factor_axis is offset by (number of output axes + 1).
const int factor_axis_pos =
factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis;
// NOTE(review): this compares an int against a size_t.
ICHECK_LE(factor_axis_pos, compute_op->axis.size());
auto n = make_object<ComputeOpNode>();
n->name = compute_op->name + ".rf";
{
// axis replacement.
// Build a data-parallel iter var that reuses the variable of `axis` and
// takes its domain from dom_map.
auto iv_node = make_object<IterVarNode>();
iv_node->dom = dom_map.at(axis);
ICHECK(is_zero(iv_node->dom->min)) << "Can only factor reduction domain starting from 0";
iv_node->var = axis->var;
iv_node->iter_type = kDataPar;
const int size = compute_op->axis.size();
// Copy the original output axes, inserting the new axis in front of the
// one at index factor_axis_pos.
for (int idx = 0; idx < size; ++idx) {
if (factor_axis_pos == idx) {
n->axis.push_back(IterVar(iv_node));
}
n->axis.push_back(compute_op->axis[idx]);
}
// factor_axis_pos == size means the new axis goes last.
if (factor_axis_pos == size) {
n->axis.push_back(IterVar(iv_node));
}
}
// predicate generation, copy not touched axis.
int idx = tensor->value_index;
const ReduceNode* reduce = compute_op->body[idx].as<ReduceNode>();
ICHECK(reduce) << "Can only rfactor non-inline reductions";
// The original reduce condition is AND-ed together with the bound-check
// predicates (starting from const_true(1)) and the result is wrapped in
// likely().
predicates.push_back(reduce->condition);
PrimExpr predicate =
likely(foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a, b, span); },
const_true(1), predicates));
// vsub maps the variable of each touched root reduce axis to its
// replacement expression from value_map.
std::unordered_map<const VarNode*, PrimExpr> vsub;
for (IterVar iv : compute_op->reduce_axis) {
if (!touch_map.count(iv)) {
// Untouched reduce axes are carried over to the new op unchanged.
n->reduce_axis.push_back(iv);
} else {
ICHECK(value_map.count(iv));
PrimExpr index = value_map.at(iv);
vsub[iv->var.get()] = index;
}
}
// Copy touched axis.
// Every touched leaf iter var other than `axis` becomes a reduce axis of the
// new op, as a copy of the iter var node with its domain taken from dom_map.
for (IterVar iv : reduce_stage->leaf_iter_vars) {
if (touch_map.count(iv) && !iv.same_as(axis)) {
ICHECK_EQ(iv->iter_type, kCommReduce);
auto ncpy = make_object<IterVarNode>(*iv.operator->());
ncpy->dom = dom_map.at(iv);
n->reduce_axis.push_back(IterVar(ncpy));
}
}
// Apply the vsub substitution to the reduce sources and to the predicate.
VarReplacer replacer(vsub);
Array<PrimExpr> new_source =
tir::UpdateArray(reduce->source, [&replacer](const PrimExpr& e) { return replacer(e); });
PrimExpr new_pred = replacer(predicate);
// One Reduce per source, differing only in the value index; the init
// argument is passed as an empty list here.
std::vector<PrimExpr> body;
for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
body.emplace_back(Reduce(reduce->combiner, new_source, n->reduce_axis, new_pred, idx, {}));
}
n->body = Array<PrimExpr>(body);
// refresh relations, keep the un-touched relations.
// A relation counts as touched when its parent (Split, Rebase) or its fused
// iter var (Fuse) is in touch_map; any other relation type is a fatal error.
Array<IterVarRelation> rels;
for (IterVarRelation rel : reduce_stage->relations) {
bool touched = false;
if (const SplitNode* r = rel.as<SplitNode>()) {
if (touch_map.count(r->parent)) touched = true;
} else if (const FuseNode* r = rel.as<FuseNode>()) {
if (touch_map.count(r->fused)) touched = true;
} else if (const RebaseNode* r = rel.as<RebaseNode>()) {
if (touch_map.count(r->parent)) touched = true;
} else {
LOG(FATAL) << "unknown relation type";
}
if (!touched) {
rels.push_back(rel);
}
}
// initialize the factored stage.
Operation factor_op(n);
Array<Stage>& stages = (*this)->stages;
size_t stage_pos = FindNodeRef(stages.GetArrayNode(), reduce_stage);
Stage factor_stage = Stage(factor_op);
factor_stage->relations = rels;
ICHECK_LT(stage_pos, stages.size());
// The factored stage is inserted at the index currently held by reduce_stage
// and registered in stage_map.
stages.insert(stages.begin() + stage_pos, factor_stage);
(*this)->stage_map.Set(factor_op, factor_stage);
// The factored stage joins the same group as the reduce stage (if any), and
// that group's child-stage count is incremented.
factor_stage->group = reduce_stage->group;
if (factor_stage->group.defined()) {
++factor_stage->group->num_child_stages;
}
// Replace the old reduction.
// Fresh reduce axis over the domain of `axis`, named "<axis name>.v".
IterVar repl_red_axis = reduce_axis(dom_map.at(axis), axis->var->name_hint + ".v");
Array<Tensor> factor_tensors;
Array<Tensor> old_tensors;
// NOTE(review): the loop indexes reduce_stage->op outputs with the output
// count of factor_op, i.e. it assumes both ops have the same number of
// outputs - confirm.
int size = factor_op->num_outputs();
for (int idx = 0; idx < size; ++idx) {
factor_tensors.push_back(factor_op.output(idx));
old_tensors.push_back(reduce_stage->op.output(idx));
}
// Build the replacement op with the shape of the first old output tensor.
// The lambda receives the replacement op's index variables `i`.
Array<Tensor> repl_tensors = compute(
old_tensors[0]->shape,
[&](const Array<Var>& i) {
// Indices into the factored tensors: the variables in `i` with the new
// reduce variable inserted at factor_axis_pos.
Array<PrimExpr> indices;
const int idx_size = static_cast<int>(i.size());
for (int idx = 0; idx < idx_size; ++idx) {
if (factor_axis_pos == idx) {
indices.push_back(repl_red_axis->var);
}
indices.push_back(i[idx]);
}
// When the original reduction has an init, rewrite it in terms of the
// replacement op's variables. Substitutions are only registered when an
// init entry is a ProducerLoadNode, and they map the original output axis
// variables to the corresponding entries of `i`.
Array<PrimExpr> new_init = reduce->init;
if (!reduce->init.empty()) {
std::unordered_map<const VarNode*, PrimExpr> init_vsub;
for (const auto& init : reduce->init) {
if (init->IsInstance<ProducerLoadNode>()) {
// NOTE(review): this compares a size_t against an int.
ICHECK_EQ(compute_op->axis.size(), idx_size)
<< "'init' should have the number of dimensions as output when using with "
"rfactor";
for (int idx = 0; idx < idx_size; idx++) {
init_vsub[compute_op->axis[idx]->var.get()] = i[idx];
}
}
}
VarReplacer init_replacer(init_vsub);
new_init = tir::UpdateArray(
reduce->init, [&init_replacer](const PrimExpr& e) { return init_replacer(e); });
}
// factor_axis_pos == idx_size means the new reduce variable goes last.
if (factor_axis_pos == idx_size) {
indices.push_back(repl_red_axis->var);
}
// Load every output of the factored op at the computed indices.
Array<PrimExpr> factor_exprs;
for (int idx = 0; idx < size; ++idx) {
factor_exprs.push_back(factor_tensors[idx](indices));
}
Array<PrimExpr> reductions;
// NOTE: this local `axis` shadows the function parameter `axis` for the
// rest of the lambda; it holds only the new reduce axis.
Array<IterVar> axis = {repl_red_axis};
PrimExpr cond = const_true();
// One Reduce per output, over the new axis, with a constant-true condition
// and the (possibly rewritten) init.
for (int idx = 0; idx < size; ++idx) {
reductions.push_back(Reduce(reduce->combiner, factor_exprs, axis, cond, idx, new_init));
}
return reductions;
},
reduce_stage->op->name + ".repl");
// vmap: old tensor -> replacement tensor; rvmap: the reverse mapping.
// Both are handed to ReplaceDataFlow together with the schedule's stages.
std::unordered_map<Tensor, Tensor> vmap;
std::unordered_map<Tensor, Tensor> rvmap;
for (int idx = 0; idx < size; ++idx) {
vmap[old_tensors[idx]] = repl_tensors[idx];
rvmap[repl_tensors[idx]] = old_tensors[idx];
}
ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
// revamp the reduction stage.
// The stage now holds the replacement op: its iter vars are reset to that
// op's root iter vars and its relations are cleared.
reduce_stage->op = repl_tensors[0]->op;
reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars();
reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
reduce_stage->relations = Array<IterVarRelation>();
return factor_tensors;
}