ComputeDAG ComputeDAG::RewriteLayout()

in src/auto_scheduler/compute_dag.cc [893:1141]
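
Rewrites the layout of "layout free" placeholders (typically the weight tensors of conv2d/dense ops) so that their data layout matches the loop tiling chosen by the transform steps, and updates the consuming compute ops and the rest of the DAG accordingly.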


ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
                                     LayoutRewriteOption layout_rewrite) const {
  CHECK(layout_rewrite != LayoutRewriteOption::NoRewrite)
      << "ComputeDAG::RewriteLayout should not be called with LayoutRewriteOption::NoRewrite.";
  ComputeDAG new_dag = *this;
  ComputeDAGNode* p_dag = new_dag.CopyOnWrite();

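  // Replay the given transform steps on a fresh state and infer iterator bounds;
  // the resulting loop structure is what drives the choice of the new layout below.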
  auto node = make_object<StateNode>();
  node->transform_steps = *transform_steps;
  node->concrete = true;
  const State& state = InferBound(State(node));

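  // Scan every compute stage for placeholders marked as layout free (via the attribute
  // keyed by layout_free_placeholders_key); each such placeholder is rewritten at most once.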
  OperationSet handled_ops;
  for (size_t stage_id = 0; stage_id < state->stages.size(); stage_id++) {
    const auto& stage = state->stages[stage_id];

    const te::Operation& op = stage->op;
    if (!op->IsInstance<te::ComputeOpNode>()) {
      continue;
    }
    const Map<String, ObjectRef>& attrs = op->attrs;
    if (attrs.count(layout_free_placeholders_key) == 0) {
      continue;
    }
    const ObjectRef& attr_value = attrs[layout_free_placeholders_key];
    for (const auto& placeholder : Downcast<Array<te::Tensor>>(attr_value)) {
      const auto& placeholder_op = placeholder->op;

      // Check whether this placeholder has already been handled
      if (handled_ops.count(placeholder_op)) {
        continue;
      }
      // Skip if the current op is not a direct consumer of this placeholder.
      // This is usually caused by cache read/write stages.
      bool direct_consumer = false;
      for (auto& t : op->InputTensors()) {
        if (t->op == placeholder_op) {
          direct_consumer = true;
          break;
        }
      }
      if (!direct_consumer) {
        continue;
      }
      handled_ops.insert(placeholder_op);

      // Process original layout
      std::set<std::string> placeholder_axis_names;
      std::string origin_layout = GetOrigLayout(&placeholder_axis_names, op, placeholder);
      Array<PrimExpr> origin_shape;
      std::vector<std::string> origin_axes;
      ParseKernelLayout(origin_layout, &origin_shape, &origin_axes);

      // Process new layout
      std::string new_layout =
          GetNewLayout(state, stage_id, stage, op, placeholder, placeholder_axis_names);
      Array<PrimExpr> new_shape;
      std::vector<std::string> new_axes;
      ParseKernelLayout(new_layout, &new_shape, &new_axes);

      // Process op updates
      te::Operation new_op_to_update;
      if (layout_rewrite == LayoutRewriteOption::RewriteForPreTransformed) {
        // Create new placeholder
        new_op_to_update = te::PlaceholderOp(placeholder_op->name, new_shape,
                                             placeholder_op.as<te::PlaceholderOpNode>()->dtype);
      } else if (layout_rewrite == LayoutRewriteOption::InsertTransformStage) {
        // Process index strides
        std::unordered_map<std::string, PrimExpr> axes_stride;
        for (const auto& i : origin_axes) {
          axes_stride[i] = Integer(1);
        }
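        // Compute the stride each new axis contributes to its original axis: walking
        // from the innermost new axis outward, take the stride accumulated so far under
        // that axis name, then multiply the accumulator by the axis extent. The lambda
        // below uses these strides to map new-layout indices back to the original layout.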
        Array<PrimExpr> new_stride(new_shape.size(), PrimExpr());
        for (int i = new_shape.size() - 1; i >= 0; i--) {
          new_stride.Set(i, axes_stride[new_axes[i]]);
          axes_stride[new_axes[i]] *= new_shape[i];
        }

        // Add an extra layout transform stage
        const auto& layout_transform_tensor = te::compute(
            new_shape,
            [&new_stride, &placeholder_op, &origin_shape, &new_shape, &origin_axes,
             &new_axes](const tvm::runtime::Array<tvm::tir::Var>& indices) -> tvm::PrimExpr {
              Array<PrimExpr> access_indices;
              for (size_t indice_index = 0; indice_index < origin_shape.size(); indice_index++) {
                PrimExpr temp = Integer(0);
                for (size_t i = 0; i < new_shape.size(); i++) {
                  if (origin_axes[indice_index].compare(new_axes[i]) == 0) {
                    temp += indices[i] * new_stride[i];
                  }
                }
                access_indices.push_back(temp);
              }
              return placeholder_op.output(0)(access_indices);
            },
            "auto_scheduler_layout_transform");
        new_op_to_update = layout_transform_tensor->op;

        // Update the transform steps
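        // Inserting the transform stage shifts the ids of this and all later stages by
        // one, so bump the stage ids recorded in the existing steps (including the
        // target of compute_at steps) to keep them pointing at the right stages.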
        for (size_t i = 0; i < transform_steps->size(); i++) {
          Step step = (*transform_steps)[i];
          if (step->stage_id >= static_cast<int>(stage_id)) {
            step.CopyOnWrite()->stage_id++;
          }
          if (step->IsInstance<ComputeAtStepNode>()) {
            auto compute_at_step = tvm::Downcast<ComputeAtStep>(step);
            if (compute_at_step->target_stage_id >= static_cast<int>(stage_id)) {
              dynamic_cast<ComputeAtStepNode*>(compute_at_step.CopyOnWrite())->target_stage_id++;
            }
            transform_steps->Set(i, std::move(compute_at_step));
          } else {
            transform_steps->Set(i, std::move(step));
          }
        }

        // Add a schedule for the newly added layout transform stage:
        // fuse a few outer iterators and parallelize the fused result.
        Array<Integer> to_fuse;

        if (new_shape.size() >= 5) {
          to_fuse.push_back(0);
          to_fuse.push_back(1);
          to_fuse.push_back(2);
          transform_steps->push_back(FuseStep(stage_id, to_fuse));
        } else if (new_shape.size() >= 3) {
          to_fuse.push_back(0);
          to_fuse.push_back(1);
          transform_steps->push_back(FuseStep(stage_id, to_fuse));
        }
        transform_steps->push_back(AnnotationStep(stage_id, 0, IteratorAnnotation::kParallel));
      }

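      // Rewrite the body of the compute op that reads this placeholder so that its
      // access indices follow the new layout, and record both layouts in its attrs.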
      te::Operation new_compute_op, original_compute_op;
      Array<PrimExpr> new_body;
      IndexRewriter index_rewriter(placeholder_op, new_layout);
      for (const auto& op : p_dag->ops) {
        if (auto* pop = op.as<te::ComputeOpNode>()) {
          bool need_update = false;
          for (auto& t : op->InputTensors()) {
            if (t->op == placeholder_op) {
              need_update = true;
              break;
            }
          }
          if (need_update) {
            for (const auto& body : pop->body) {
              new_body.push_back(index_rewriter.Rewrite(body));
            }
            original_compute_op = op;
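            // Only a single compute op is expected to consume this placeholder directly.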
            CHECK(!new_compute_op.defined());
            auto new_attrs = pop->attrs;
            new_attrs.Set("ori_placeholder_layout", tvm::String(origin_layout));
            new_attrs.Set("new_placeholder_layout", tvm::String(new_layout));
            new_compute_op = te::ComputeOp(pop->name, pop->tag, new_attrs, pop->axis, new_body);
          }
        }
      }

      // Construct the map from original_op to new_op
      std::unordered_map<te::Operation, te::Operation> updated_ops;

      Array<te::Operation> original_ops = p_dag->ops;
      p_dag->ops.clear();
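      // Rebuild the op list. With InsertTransformStage the original placeholder stays
      // as the DAG input and the transform op is added right after it; otherwise the
      // placeholder itself is replaced by the pre-transformed one.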
      for (size_t i = 0; i < original_ops.size(); ++i) {
        const auto& original_op = original_ops[i];
        if (original_op == placeholder_op) {
          if (layout_rewrite == LayoutRewriteOption::InsertTransformStage) {
            p_dag->ops.push_back(placeholder_op);
          }
          p_dag->ops.push_back(new_op_to_update);
          updated_ops[placeholder_op] = new_op_to_update;
        } else if (original_op == original_compute_op) {
          p_dag->ops.push_back(new_compute_op);
          updated_ops[original_compute_op] = new_compute_op;
        } else {
          p_dag->ops.push_back(original_op);
        }
      }

      ArrayNode* pops = p_dag->ops.CopyOnWrite();
      // Because the ops are sorted in topological order, a single linear scan is enough here.
      for (size_t i = 0; i < pops->size(); ++i) {
        const auto& original_op = Downcast<te::Operation>(pops->at(i));
        if (auto* pop = original_op.as<te::ComputeOpNode>()) {
          if (original_op == new_op_to_update) {
            continue;
          }
          auto inputs = pop->InputTensors();
          std::unordered_map<te::Tensor, te::Tensor> rmap;
          for (auto input : inputs) {
            auto it = updated_ops.find(input->op);
            te::Operation new_op;
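            // Follow the chain of replacements: an input op may map to an op that was
            // itself replaced in an earlier iteration of this scan.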
            while (it != updated_ops.end()) {
              new_op = it->second;
              it = updated_ops.find(new_op);
            }
            if (new_op.defined()) {
              int index = input->value_index;
              rmap[input] = new_op.output(index);
            }
          }
          if (!rmap.empty()) {
            te::Operation new_op = pop->ReplaceInputs(original_op, rmap);
            updated_ops[original_op] = new_op;
            pops->SetItem(i, new_op);
          }
        }
      }

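      // Update the DAG's I/O tensor list to point at the updated ops. Placeholder
      // tensors are only replaced for the pre-transformed option; otherwise the
      // original placeholders remain the DAG inputs.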
      Array<te::Tensor> old_tensors = p_dag->tensors;
      ArrayNode* p_tensors = p_dag->tensors.CopyOnWrite();
      for (size_t i = 0; i < old_tensors.size(); ++i) {
        const auto& old_tensor = old_tensors[i];
        if (layout_rewrite != LayoutRewriteOption::RewriteForPreTransformed &&
            old_tensor->op->IsInstance<te::PlaceholderOpNode>()) {
          continue;
        }
        auto it = updated_ops.find(old_tensor->op);
        te::Operation new_op;
        while (it != updated_ops.end()) {
          new_op = it->second;
          it = updated_ops.find(new_op);
        }
        if (new_op.defined()) {
          auto index = old_tensor->value_index;
          p_tensors->SetItem(i, new_op.output(index));
        }
      }
    }  // end for placeholder
  }    // end for stage
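  // Rebuild the derived data of the DAG: access analysis, the op list re-ordered by a
  // fresh default schedule, the flop estimate, and the initial state.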
  p_dag->access_analyzer = AccessAnalyzer(p_dag->tensors);

  Array<te::Operation> out_ops;
  for (const auto& op : p_dag->access_analyzer->ops_topo_order) {
    if (p_dag->access_analyzer.IsOutput(op)) {
      out_ops.push_back(op);
    }
  }

  p_dag->ops.clear();
  te::Schedule sch = te::create_schedule(out_ops);
  for (auto stage : sch->stages) {
    p_dag->ops.push_back(stage->op);
  }
  p_dag->flop_ct = FlopEstimator().EstimateFlop(p_dag->ops);
  p_dag->init_state = State(p_dag->ops);

  return new_dag;
}
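
A minimal usage sketch follows, under assumptions not taken from this section: `dag` is an existing ComputeDAG, `state` is a State whose transform steps have already been generated for it, and the usual TVM headers are available. It only illustrates the call pattern of the function shown above.

// Sketch only: `dag` and `state` are assumed to exist in scope.
using tvm::Array;
using namespace tvm::auto_scheduler;

Array<Step> steps = state->transform_steps;  // local copy, since RewriteLayout may append steps
ComputeDAG rewritten = dag.RewriteLayout(&steps, LayoutRewriteOption::InsertTransformStage);
// `steps` now also carries the FuseStep/AnnotationStep entries added for the new
// layout-transform stage, and the stage ids of the original steps have been shifted
// to account for the inserted stage.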