in src/auto_scheduler/compute_dag.cc [893:1141]
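// Rewrite the data layout of the layout-free placeholders of this DAG according to the
// loop structure described by `transform_steps`, and return the rewritten DAG. Depending
// on `layout_rewrite`, either the placeholders themselves are re-declared in the new
// layout, or explicit transform stages (plus matching schedule steps) are inserted.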
ComputeDAG ComputeDAG::RewriteLayout(Array<Step>* transform_steps,
LayoutRewriteOption layout_rewrite) const {
CHECK(layout_rewrite != LayoutRewriteOption::NoRewrite)
<< "Call ComputeDAG::RewriteLayout with NoRewrite.";
ComputeDAG new_dag = *this;
ComputeDAGNode* p_dag = new_dag.CopyOnWrite();
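  // Replay the given transform steps on a fresh state and run InferBound so that the
  // iterator extents after tiling/fusion are known; the new layouts are derived from
  // this final loop structure.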
auto node = make_object<StateNode>();
node->transform_steps = *transform_steps;
node->concrete = true;
const State& state = InferBound(State(node));
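  // Scan all stages for compute ops that declare layout-free placeholders through the
  // attribute named by `layout_free_placeholders_key`. `handled_ops` ensures that each
  // placeholder op is rewritten at most once even if several stages consume it.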
OperationSet handled_ops;
for (size_t stage_id = 0; stage_id < state->stages.size(); stage_id++) {
const auto& stage = state->stages[stage_id];
const te::Operation& op = stage->op;
if (!op->IsInstance<te::ComputeOpNode>()) {
continue;
}
const Map<String, ObjectRef>& attrs = op->attrs;
if (attrs.count(layout_free_placeholders_key) == 0) {
continue;
}
const ObjectRef& attr_value = attrs[layout_free_placeholders_key];
for (const auto& placeholder : Downcast<Array<te::Tensor>>(attr_value)) {
const auto& placeholder_op = placeholder->op;
// Check whether this placeholder has already been handled
if (handled_ops.count(placeholder_op)) {
continue;
}
      // Skip ops that are not direct consumers of this placeholder.
// This is usually caused by cache read/write.
bool direct_consumer = false;
for (auto& t : op->InputTensors()) {
if (t->op == placeholder_op) {
direct_consumer = true;
break;
}
}
if (!direct_consumer) {
continue;
}
handled_ops.insert(placeholder_op);
// Process original layout
std::set<std::string> placeholder_axis_names;
std::string origin_layout = GetOrigLayout(&placeholder_axis_names, op, placeholder);
Array<PrimExpr> origin_shape;
std::vector<std::string> origin_axes;
ParseKernelLayout(origin_layout, &origin_shape, &origin_axes);
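      // Both layout strings encode each axis as its length followed by its name (see
      // GetOrigLayout and GetNewLayout); ParseKernelLayout splits such a string back
      // into a shape array and the corresponding list of axis names.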
// Process new layout
std::string new_layout =
GetNewLayout(state, stage_id, stage, op, placeholder, placeholder_axis_names);
Array<PrimExpr> new_shape;
std::vector<std::string> new_axes;
ParseKernelLayout(new_layout, &new_shape, &new_axes);
// Process op updates
te::Operation new_op_to_update;
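      // Two rewrite modes are handled here:
      //  - RewriteForPreTransformed: the input is assumed to be supplied already in the
      //    new layout, so the placeholder itself is re-declared with the new shape.
      //  - InsertTransformStage: the original placeholder is kept and an explicit
      //    transform stage is inserted to convert it to the new layout at runtime.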
if (layout_rewrite == LayoutRewriteOption::RewriteForPreTransformed) {
// Create new placeholder
new_op_to_update = te::PlaceholderOp(placeholder_op->name, new_shape,
placeholder_op.as<te::PlaceholderOpNode>()->dtype);
} else if (layout_rewrite == LayoutRewriteOption::InsertTransformStage) {
// Process index strides
std::unordered_map<std::string, PrimExpr> axes_stride;
for (const auto& i : origin_axes) {
axes_stride[i] = Integer(1);
}
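        // Walk the new layout from the innermost axis outwards: new_stride[i] accumulates
        // the extents of the inner split pieces that came from the same original axis, so
        // the original index of that axis can be reconstructed below as
        // sum(indices[i] * new_stride[i]) over all matching pieces.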
Array<PrimExpr> new_stride(new_shape.size(), PrimExpr());
for (int i = new_shape.size() - 1; i >= 0; i--) {
new_stride.Set(i, axes_stride[new_axes[i]]);
axes_stride[new_axes[i]] *= new_shape[i];
}
// Add an extra layout transform stage
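        // The transform stage reads the original placeholder and writes the data out in the
        // new layout: each output coordinate is mapped back to the original access indices
        // by combining the split pieces of every original axis with the strides above.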
const auto& layout_transform_tensor = te::compute(
new_shape,
[&new_stride, &placeholder_op, &origin_shape, &new_shape, &origin_axes,
&new_axes](const tvm::runtime::Array<tvm::tir::Var>& indices) -> tvm::PrimExpr {
Array<PrimExpr> access_indices;
for (size_t indice_index = 0; indice_index < origin_shape.size(); indice_index++) {
PrimExpr temp = Integer(0);
for (size_t i = 0; i < new_shape.size(); i++) {
if (origin_axes[indice_index].compare(new_axes[i]) == 0) {
temp += indices[i] * new_stride[i];
}
}
access_indices.push_back(temp);
}
return placeholder_op.output(0)(access_indices);
},
"auto_scheduler_layout_transform");
new_op_to_update = layout_transform_tensor->op;
        // Update the transform steps: adding a stage shifts the ids of this and all later
        // stages by one, so the recorded steps must be adjusted accordingly.
for (size_t i = 0; i < transform_steps->size(); i++) {
Step step = (*transform_steps)[i];
if (step->stage_id >= static_cast<int>(stage_id)) {
step.CopyOnWrite()->stage_id++;
}
if (step->IsInstance<ComputeAtStepNode>()) {
auto compute_at_step = tvm::Downcast<ComputeAtStep>(step);
if (compute_at_step->target_stage_id >= static_cast<int>(stage_id)) {
dynamic_cast<ComputeAtStepNode*>(compute_at_step.CopyOnWrite())->target_stage_id++;
}
transform_steps->Set(i, std::move(compute_at_step));
} else {
transform_steps->Set(i, std::move(step));
}
}
        // Add a naive schedule for the newly added transform stage: fuse the leading loops
        // when the new layout has enough dimensions, then parallelize the outermost loop.
Array<Integer> to_fuse;
if (new_shape.size() >= 5) {
to_fuse.push_back(0);
to_fuse.push_back(1);
to_fuse.push_back(2);
transform_steps->push_back(FuseStep(stage_id, to_fuse));
} else if (new_shape.size() >= 3) {
to_fuse.push_back(0);
to_fuse.push_back(1);
transform_steps->push_back(FuseStep(stage_id, to_fuse));
}
transform_steps->push_back(AnnotationStep(stage_id, 0, IteratorAnnotation::kParallel));
}
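      // Rewrite the body of the compute op that directly consumes this placeholder so that
      // it indexes the tensor in the new layout, and store the original and new layout
      // strings in its attrs so that later passes can recover the transformation.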
te::Operation new_compute_op, original_compute_op;
Array<PrimExpr> new_body;
IndexRewriter index_rewriter(placeholder_op, new_layout);
for (const auto& op : p_dag->ops) {
if (auto* pop = op.as<te::ComputeOpNode>()) {
bool need_update = false;
for (auto& t : op->InputTensors()) {
if (t->op == placeholder_op) {
need_update = true;
break;
}
}
if (need_update) {
for (const auto& body : pop->body) {
new_body.push_back(index_rewriter.Rewrite(body));
}
original_compute_op = op;
CHECK(!new_compute_op.defined());
auto new_attrs = pop->attrs;
new_attrs.Set("ori_placeholder_layout", tvm::String(origin_layout));
new_attrs.Set("new_placeholder_layout", tvm::String(new_layout));
new_compute_op = te::ComputeOp(pop->name, pop->tag, new_attrs, pop->axis, new_body);
}
}
}
      // Construct the map from original ops to new ops and rebuild the op list with the
      // rewritten ops.
std::unordered_map<te::Operation, te::Operation> updated_ops;
Array<te::Operation> original_ops = p_dag->ops;
p_dag->ops.clear();
for (size_t i = 0; i < original_ops.size(); ++i) {
const auto& original_op = original_ops[i];
if (original_op == placeholder_op) {
if (layout_rewrite == LayoutRewriteOption::InsertTransformStage) {
p_dag->ops.push_back(placeholder_op);
}
p_dag->ops.push_back(new_op_to_update);
updated_ops[placeholder_op] = new_op_to_update;
} else if (original_op == original_compute_op) {
p_dag->ops.push_back(new_compute_op);
updated_ops[original_compute_op] = new_compute_op;
} else {
p_dag->ops.push_back(original_op);
}
}
ArrayNode* pops = p_dag->ops.CopyOnWrite();
      // Because ops is sorted in topo-order, a single linear scan is enough to propagate
      // the replacements to every downstream consumer.
for (size_t i = 0; i < pops->size(); ++i) {
const auto& original_op = Downcast<te::Operation>(pops->at(i));
if (auto* pop = original_op.as<te::ComputeOpNode>()) {
if (original_op == new_op_to_update) {
continue;
}
auto inputs = pop->InputTensors();
std::unordered_map<te::Tensor, te::Tensor> rmap;
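          // For every input, follow the chain of replacements: an op recorded in
          // updated_ops may itself have been replaced in a later iteration, so keep
          // looking until the most recent version is found.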
for (auto input : inputs) {
auto it = updated_ops.find(input->op);
te::Operation new_op;
while (it != updated_ops.end()) {
new_op = it->second;
it = updated_ops.find(new_op);
}
if (new_op.defined()) {
int index = input->value_index;
rmap[input] = new_op.output(index);
}
}
if (!rmap.empty()) {
te::Operation new_op = pop->ReplaceInputs(original_op, rmap);
updated_ops[original_op] = new_op;
pops->SetItem(i, new_op);
}
}
}
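      // Update the DAG's tensor list so it refers to the rewritten ops. For
      // InsertTransformStage the original placeholder tensors stay as the DAG inputs,
      // while RewriteForPreTransformed also replaces them with the new placeholders.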
Array<te::Tensor> old_tensors = p_dag->tensors;
ArrayNode* p_tensors = p_dag->tensors.CopyOnWrite();
for (size_t i = 0; i < old_tensors.size(); ++i) {
const auto& old_tensor = old_tensors[i];
if (layout_rewrite != LayoutRewriteOption::RewriteForPreTransformed &&
old_tensor->op->IsInstance<te::PlaceholderOpNode>()) {
continue;
}
auto it = updated_ops.find(old_tensor->op);
te::Operation new_op;
while (it != updated_ops.end()) {
new_op = it->second;
it = updated_ops.find(new_op);
}
if (new_op.defined()) {
auto index = old_tensor->value_index;
p_tensors->SetItem(i, new_op.output(index));
}
}
} // end for placeholder
} // end for stage
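  // Rebuild the derived information of the DAG from the updated tensors: the access
  // analyzer, the op list (taken from a freshly created schedule), the flop estimate,
  // and the initial state.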
p_dag->access_analyzer = AccessAnalyzer(p_dag->tensors);
Array<te::Operation> out_ops;
for (const auto& op : p_dag->access_analyzer->ops_topo_order) {
if (p_dag->access_analyzer.IsOutput(op)) {
out_ops.push_back(op);
}
}
p_dag->ops.clear();
te::Schedule sch = te::create_schedule(out_ops);
for (auto stage : sch->stages) {
p_dag->ops.push_back(stage->op);
}
p_dag->flop_ct = FlopEstimator().EstimateFlop(p_dag->ops);
p_dag->init_state = State(p_dag->ops);
return new_dag;
}