Array<Tensor> Schedule::rfactor()

in src/te/schedule/schedule_dataflow_rewrite.cc [731:948]


// Factor a reduction along `axis` into a separate ".rf" compute stage.
//
// Builds a new ComputeOp named "<op name>.rf" in which `axis` becomes a data-parallel
// axis placed at position `factor_axis` among the original output axes, inserts a stage
// for it immediately before the original reduction stage, and then replaces the original
// reduction op with a new compute op named "<op name>.repl" that reduces the factored
// tensors over a fresh reduce axis.
//
// \param tensor      Tensor produced by the reduction to factor. Its op must be a ComputeOp
//                    and the body entry at tensor->value_index must be a Reduce.
// \param axis        Reduction IterVar to factor. Must have iter_type kCommReduce, must be
//                    one of the stage's leaf iter vars, and its propagated domain must
//                    start at 0.
// \param factor_axis Position of the new axis among the output axes. A negative value is
//                    offset by (number of output axes + 1), so -1 places it last.
// \return The output tensors of the new ".rf" op, one per output of that op.
Array<Tensor> Schedule::rfactor(const Tensor& tensor, const IterVar& axis, int factor_axis) {
  (*this)->InvalidateCache();
  using tir::ReduceNode;
  ICHECK_EQ(axis->iter_type, kCommReduce) << "Can only factor reduction axis";
  Stage reduce_stage = operator[](tensor->op);
  const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>();
  ICHECK(compute_op) << "Can only factor ComputeOp";
  // `axis` must be a leaf iter var of the stage; a FindNodeRef result equal to size()
  // is treated as "not found" by the check below.
  ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite();
  {
    size_t axis_pos = FindNodeRef(leaf_vars, axis);
    ICHECK_NE(axis_pos, leaf_vars->size())
        << "Cannot find IterVar " << axis << " in leaf iter vars";
  }
  // Find touched reduction axis.
  // Seeded with `axis` only; the two passes below extend the map in place.
  // NOTE(review): presumably they propagate the mark through the stage's iter var
  // relations (up to ancestors, then down to descendants) -- confirm in the helpers.
  std::unordered_map<IterVar, int> touch_map;
  touch_map[axis] = 1;
  te::PassUpBitMaskOr(reduce_stage, &touch_map, true);
  te::PassDownBitMaskOr(reduce_stage, &touch_map, true);
  // Iter vars collected here are handed to MakeBoundCheck as the set to skip.
  std::unordered_set<IterVar> skip_bound_check;
  // Verify normal axis are not touched.
  for (IterVar iv : compute_op->axis) {
    ICHECK(!touch_map.count(iv)) << "Factor axis touches normal axis.";
    skip_bound_check.insert(iv);
  }
  // get analyzer.
  arith::Analyzer analyzer;
  // Get the replace index
  // dom_map starts with the declared domains of the touched root reduce axes; untouched
  // reduce axes are excluded from bound checks. Every reduce var is bound in the analyzer.
  std::unordered_map<IterVar, Range> dom_map;
  std::unordered_map<IterVar, PrimExpr> value_map;
  for (IterVar iv : compute_op->reduce_axis) {
    if (touch_map.count(iv)) {
      dom_map[iv] = iv->dom;
    } else {
      skip_bound_check.insert(iv);
    }
    analyzer.Bind(iv->var, iv->dom);
  }
  // The dom_map.at(...) lookups below rely on this call having filled in entries for the
  // touched leaf iter vars (including `axis`).
  te::PassDownDomain(reduce_stage, &dom_map, &analyzer, true);
  // Value of each touched leaf iter var: its domain minimum when the extent is one,
  // otherwise the iter var's own variable.
  for (IterVar iv : reduce_stage->leaf_iter_vars) {
    if (touch_map.count(iv)) {
      Range dom = dom_map.at(iv);
      if (is_one(dom->extent)) {
        value_map[iv] = dom->min;
      } else {
        value_map[iv] = iv->var;
      }
    }
  }
  // The value_map lookups for touched root reduce axes further down rely on this call
  // having filled in their entries from the leaf values.
  te::PassUpIndex(reduce_stage, dom_map, &value_map, true);
  std::vector<PrimExpr> predicates =
      MakeBoundCheck(reduce_stage, dom_map, value_map, true, skip_bound_check);

  // Get the factored op node.
  // Normalize a negative factor_axis by adding (number of output axes + 1); the result
  // must not exceed the number of output axes.
  const int factor_axis_pos =
      factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis;
  ICHECK_LE(factor_axis_pos, compute_op->axis.size());
  auto n = make_object<ComputeOpNode>();
  n->name = compute_op->name + ".rf";
  {
    // Axis replacement: the factored axis reuses axis->var with the propagated domain,
    // but is marked data-parallel in the new op.
    auto iv_node = make_object<IterVarNode>();
    iv_node->dom = dom_map.at(axis);
    ICHECK(is_zero(iv_node->dom->min)) << "Can only factor reduction domain starting from 0";
    iv_node->var = axis->var;
    iv_node->iter_type = kDataPar;

    // Copy the original output axes, inserting the factored axis at factor_axis_pos
    // (or appending it when factor_axis_pos equals the number of original axes).
    const int size = compute_op->axis.size();
    for (int idx = 0; idx < size; ++idx) {
      if (factor_axis_pos == idx) {
        n->axis.push_back(IterVar(iv_node));
      }
      n->axis.push_back(compute_op->axis[idx]);
    }
    if (factor_axis_pos == size) {
      n->axis.push_back(IterVar(iv_node));
    }
  }
  // predicate generation, copy not touched axis.
  // The Reduce node is taken from the body entry selected by the tensor's value_index;
  // its own condition is added to the bound-check predicates.
  int idx = tensor->value_index;
  const ReduceNode* reduce = compute_op->body[idx].as<ReduceNode>();
  ICHECK(reduce) << "Can only rfactor non-inline reductions";
  predicates.push_back(reduce->condition);

  // Fold all predicates with logical_and starting from const_true(1), wrapped in likely().
  PrimExpr predicate =
      likely(foldl([](PrimExpr a, PrimExpr b, Span span) { return logical_and(a, b, span); },
                   const_true(1), predicates));

  // Substitution from each touched root reduce var to its index expression in value_map.
  std::unordered_map<const VarNode*, PrimExpr> vsub;

  // Untouched reduce axes are carried over unchanged as reduce axes of the new op;
  // touched ones are recorded for substitution instead.
  for (IterVar iv : compute_op->reduce_axis) {
    if (!touch_map.count(iv)) {
      n->reduce_axis.push_back(iv);
    } else {
      ICHECK(value_map.count(iv));
      PrimExpr index = value_map.at(iv);
      vsub[iv->var.get()] = index;
    }
  }

  // Copy touched axis.
  // Every touched leaf iter var except `axis` itself is copied (with its propagated
  // domain) and appended as a reduce axis of the new op.
  for (IterVar iv : reduce_stage->leaf_iter_vars) {
    if (touch_map.count(iv) && !iv.same_as(axis)) {
      ICHECK_EQ(iv->iter_type, kCommReduce);
      auto ncpy = make_object<IterVarNode>(*iv.operator->());
      ncpy->dom = dom_map.at(iv);
      n->reduce_axis.push_back(IterVar(ncpy));
    }
  }
  // Apply the substitution to the reduction sources and to the combined predicate.
  VarReplacer replacer(vsub);
  Array<PrimExpr> new_source =
      tir::UpdateArray(reduce->source, [&replacer](const PrimExpr& e) { return replacer(e); });

  PrimExpr new_pred = replacer(predicate);

  // One Reduce per source, differing only in the value index. The trailing `{}` is passed
  // in the argument position where the replacement reduction below passes `new_init`.
  // Note: this loop's `idx` shadows the outer `idx` declared above.
  std::vector<PrimExpr> body;
  for (size_t idx = 0; idx < reduce->source.size(); ++idx) {
    body.emplace_back(Reduce(reduce->combiner, new_source, n->reduce_axis, new_pred, idx, {}));
  }
  n->body = Array<PrimExpr>(body);
  // refresh relations, keep the un-touched relations.
  // A relation counts as touched when its parent (Split, Rebase) or its fused var (Fuse)
  // is in touch_map; any other relation type is a fatal error.
  Array<IterVarRelation> rels;
  for (IterVarRelation rel : reduce_stage->relations) {
    bool touched = false;
    if (const SplitNode* r = rel.as<SplitNode>()) {
      if (touch_map.count(r->parent)) touched = true;
    } else if (const FuseNode* r = rel.as<FuseNode>()) {
      if (touch_map.count(r->fused)) touched = true;
    } else if (const RebaseNode* r = rel.as<RebaseNode>()) {
      if (touch_map.count(r->parent)) touched = true;
    } else {
      LOG(FATAL) << "unknown relation type";
    }
    if (!touched) {
      rels.push_back(rel);
    }
  }
  // initialize the factored stage.
  // The new stage is inserted at the position of reduce_stage (i.e. directly before it),
  // registered in stage_map, and joins the same group as reduce_stage, whose child-stage
  // count is incremented when a group is defined.
  Operation factor_op(n);
  Array<Stage>& stages = (*this)->stages;
  size_t stage_pos = FindNodeRef(stages.GetArrayNode(), reduce_stage);
  Stage factor_stage = Stage(factor_op);
  factor_stage->relations = rels;
  ICHECK_LT(stage_pos, stages.size());
  stages.insert(stages.begin() + stage_pos, factor_stage);
  (*this)->stage_map.Set(factor_op, factor_stage);
  factor_stage->group = reduce_stage->group;
  if (factor_stage->group.defined()) {
    ++factor_stage->group->num_child_stages;
  }
  // Replace the old reduction.
  // The replacement reduces over a fresh reduce axis with the factored axis' domain,
  // named "<axis name>.v".
  IterVar repl_red_axis = reduce_axis(dom_map.at(axis), axis->var->name_hint + ".v");
  Array<Tensor> factor_tensors;
  Array<Tensor> old_tensors;
  // `size` is the number of outputs of the factored op; the same count is used to index
  // the outputs of the original op.
  int size = factor_op->num_outputs();
  for (int idx = 0; idx < size; ++idx) {
    factor_tensors.push_back(factor_op.output(idx));
    old_tensors.push_back(reduce_stage->op.output(idx));
  }
  // Build the replacement op with the shape of the first original output.
  Array<Tensor> repl_tensors = compute(
      old_tensors[0]->shape,
      [&](const Array<Var>& i) {
        // Indices into the factored tensors: the output indices `i` with the new reduce
        // var inserted at factor_axis_pos (appended further down when it goes last).
        Array<PrimExpr> indices;
        const int idx_size = static_cast<int>(i.size());
        for (int idx = 0; idx < idx_size; ++idx) {
          if (factor_axis_pos == idx) {
            indices.push_back(repl_red_axis->var);
          }
          indices.push_back(i[idx]);
        }
        // When the original reduction has an init, rewrite it for the replacement op:
        // if any init expression is a ProducerLoad, the original output axis vars are
        // mapped to this compute's index vars (requires matching dimensionality).
        Array<PrimExpr> new_init = reduce->init;
        if (!reduce->init.empty()) {
          std::unordered_map<const VarNode*, PrimExpr> init_vsub;
          for (const auto& init : reduce->init) {
            if (init->IsInstance<ProducerLoadNode>()) {
              ICHECK_EQ(compute_op->axis.size(), idx_size)
                  << "'init' should have the number of dimensions as output when using with "
                     "rfactor";
              for (int idx = 0; idx < idx_size; idx++) {
                init_vsub[compute_op->axis[idx]->var.get()] = i[idx];
              }
            }
          }
          VarReplacer init_replacer(init_vsub);
          new_init = tir::UpdateArray(
              reduce->init, [&init_replacer](const PrimExpr& e) { return init_replacer(e); });
        }
        if (factor_axis_pos == idx_size) {
          indices.push_back(repl_red_axis->var);
        }
        // Load every factored output at `indices`; `size` is the output count captured
        // from the enclosing function scope.
        Array<PrimExpr> factor_exprs;
        for (int idx = 0; idx < size; ++idx) {
          factor_exprs.push_back(factor_tensors[idx](indices));
        }
        // One Reduce per output over the single new reduce axis, with an always-true
        // condition and the (possibly rewritten) init. The local `axis` shadows the
        // function parameter of the same name.
        Array<PrimExpr> reductions;
        Array<IterVar> axis = {repl_red_axis};
        PrimExpr cond = const_true();
        for (int idx = 0; idx < size; ++idx) {
          reductions.push_back(Reduce(reduce->combiner, factor_exprs, axis, cond, idx, new_init));
        }
        return reductions;
      },
      reduce_stage->op->name + ".repl");

  // Map each original output tensor to its replacement (and back), then rewrite the
  // schedule's stages via ReplaceDataFlow.
  std::unordered_map<Tensor, Tensor> vmap;
  std::unordered_map<Tensor, Tensor> rvmap;
  for (int idx = 0; idx < size; ++idx) {
    vmap[old_tensors[idx]] = repl_tensors[idx];
    rvmap[repl_tensors[idx]] = old_tensors[idx];
  }
  ReplaceDataFlow((*this)->stages, &vmap, &rvmap);
  // revamp the reduction stage.
  // The original stage now holds the replacement op; its iter vars are reset to the new
  // op's root iter vars and all of its previous relations are dropped.
  reduce_stage->op = repl_tensors[0]->op;
  reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars();
  reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars;
  reduce_stage->relations = Array<IterVarRelation>();
  return factor_tensors;
}