in src/script/printer/tir/buffer.cc [27:183]
/*!
 * \brief Collect the keyword arguments used to print the declaration of `buffer`.
 *
 * Besides building the returned map, this function updates the printer state. Any variable
 * that appears directly (as a bare Var) in `buffer->shape`, `buffer->data`, `buffer->strides`
 * or `buffer->elem_offset` and is not yet defined in `d` gets defined here:
 *  - shape variables are always defined out of line;
 *  - data / stride / elem_offset variables are defined inline (bound to an expression on the
 *    printed buffer itself, e.g. `<buffer>.data`) when the variable occurs exactly once across
 *    the buffer's elem_offset, data, strides and shape, and out of line otherwise. The data
 *    pointer is only considered for inlining when `var_definitions` is at least
 *    `BufferVarDefinition::DataPointer`.
 * All out-of-line definitions are defined in `frame` and emitted as a single assignment
 * (a tuple assignment when there are several) appended to `frame->stmts`.
 *
 * \param buffer The buffer whose attributes are printed.
 * \param buffer_p The object path of `buffer`; sub-paths of it are attached as source paths.
 * \param frame The frame in which newly encountered variables are defined.
 * \param d The IRDocsifier.
 * \param var_definitions Controls whether the data pointer variable may be defined inline.
 * \return A map from keyword name ("shape", "dtype", "data", "strides", "elem_offset", "scope",
 *         "align", "offset_factor", "buffer_type", "axis_separators") to its printed value.
 *         Several entries are omitted when the attribute holds the value checked in the
 *         corresponding step below, or when the attribute's variable was defined inline.
 */
Map<String, ExprDoc> BufferAttrs(tir::Buffer buffer, const ObjectPath& buffer_p, const Frame& frame,
                                 const IRDocsifier& d, BufferVarDefinition var_definitions) {
  using tvm::tir::Var;
  using tvm::tir::VarNode;
  Map<String, ExprDoc> kwargs;
  // Out-of-line variable definitions, flushed into `frame->stmts` as one assignment at the end.
  Array<ExprDoc> var_def_lhs;
  Array<ExprDoc> var_def_rhs;
  // Step 0. Set up statistics: count how often each variable occurs in the buffer's
  // elem_offset, data, strides and shape. Only variables used exactly once may be inlined.
  std::unordered_map<const Object*, int> use_count;
  auto update_use_count = [&](const PrimExpr& e) {
    tir::PostOrderVisit(e, [&](const ObjectRef& n) {
      if (const VarNode* var = n.as<VarNode>()) {
        ++use_count[var];
      }
    });
  };
  update_use_count(buffer->elem_offset);
  update_use_count(buffer->data);
  for (const PrimExpr& e : buffer->strides) {
    update_use_count(e);
  }
  for (const PrimExpr& e : buffer->shape) {
    update_use_count(e);
  }
  // True iff `e` is itself a variable that `d` has not defined yet.
  auto is_new_var = [&](const PrimExpr& e) {
    return e->IsInstance<VarNode>() && !d->IsVarDefined(e);
  };
  // Define `var` in `frame` and queue `<var> = <creation expr>` for emission at the end.
  auto add_out_of_line_var_def = [&](const Var& var, const ObjectPath& var_p) {
    ICHECK(!d->IsVarDefined(var));
    ExprDoc lhs = DefineVar(var, frame, d);
    lhs->source_paths.push_back(var_p);
    var_def_lhs.push_back(lhs);
    var_def_rhs.push_back(PrintVarCreation(var, var_p, d));
  };
  // Define the new variable `e` through `inline_f` if it is used exactly once; otherwise fall
  // back to an out-of-line definition. Returns true iff the definition was inlined.
  auto try_inline_def = [&](const PrimExpr& e, const ObjectPath& e_p,
                            std::function<ExprDoc()> inline_f) {
    ICHECK(is_new_var(e));
    Var var = Downcast<Var>(e);
    if (use_count[var.get()] == 1) {
      d->Define(e, frame, inline_f);
      return true;
    } else {
      add_out_of_line_var_def(var, e_p);
      return false;
    }
  };
  // NOTE: the `inline_f` callbacks below capture by value because they are handed to
  // `d->Define` as std::function (presumably invoked after this function returns).

  // Step 1. Handle `buffer.shape`: new shape variables are always defined out of line.
  {
    const Array<PrimExpr>& shape = buffer->shape;
    ObjectPath shape_p = buffer_p->Attr("shape");
    int n = static_cast<int>(shape.size());
    Array<ExprDoc> results;
    results.reserve(n);
    for (int i = 0; i < n; ++i) {
      PrimExpr e = shape[i];
      ObjectPath e_p = shape_p->ArrayIndex(i);
      if (is_new_var(e)) {
        add_out_of_line_var_def(Downcast<Var>(e), e_p);
      }
      results.push_back(d->AsDoc<ExprDoc>(e, e_p));
    }
    kwargs.Set("shape", TupleDoc(results));
  }
  // Step 2. Handle `buffer.dtype`: omitted when it matches the configured default.
  if (buffer->dtype != d->cfg->buffer_dtype) {
    kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype, buffer_p->Attr("dtype")));
  }
  // Step 3. Handle `buffer.data`: when defined inline it refers back to `<buffer>.data`, so
  // it must not also be passed as a keyword argument.
  bool is_inline_data = false;
  if (is_new_var(buffer->data)) {
    if (var_definitions >= BufferVarDefinition::DataPointer) {
      is_inline_data = try_inline_def(buffer->data, buffer_p->Attr("data"), [=]() {
        return d->AsDoc<ExprDoc>(buffer, buffer_p)->Attr("data");
      });
    } else {
      add_out_of_line_var_def(buffer->data, buffer_p->Attr("data"));
    }
  }
  if (!is_inline_data) {
    kwargs.Set("data", d->AsDoc<ExprDoc>(buffer->data, buffer_p->Attr("data")));
  }
  // Step 4. Handle `buffer.strides`: an inlined stride variable is bound to
  // `<buffer>.strides[i]` and printed by name as a string literal.
  if (!buffer->strides.empty()) {
    const Array<PrimExpr>& strides = buffer->strides;
    ObjectPath strides_p = buffer_p->Attr("strides");
    int n = static_cast<int>(strides.size());
    Array<ExprDoc> results;
    results.reserve(n);
    for (int i = 0; i < n; ++i) {
      PrimExpr e = strides[i];
      ObjectPath e_p = strides_p->ArrayIndex(i);
      if (is_new_var(e)) {
        if (try_inline_def(e, e_p, [=]() {
              return d->AsDoc<ExprDoc>(buffer, buffer_p)
                  ->Attr("strides")[{LiteralDoc::Int(i, NullOpt)}];
            })) {
          results.push_back(LiteralDoc::Str(Downcast<Var>(e)->name_hint, e_p));
          continue;
        }
      }
      results.push_back(d->AsDoc<ExprDoc>(e, e_p));
    }
    kwargs.Set("strides", TupleDoc(results));
  }
  // Step 5. Handle `buffer.elem_offset`: a constant zero is omitted; a new variable is
  // defined inline when possible, in which case `offset_factor` is forced to be printed
  // (presumably so the parser recreates a symbolic elem_offset — confirm).
  bool needs_print_factor = false;
  if (const auto* int_imm = buffer->elem_offset.as<IntImmNode>()) {
    if (int_imm->value != 0) {
      kwargs.Set("elem_offset",
                 d->AsDoc<ExprDoc>(buffer->elem_offset,  //
                                   buffer_p->Attr("elem_offset")));
    }
  } else if (is_new_var(buffer->elem_offset)) {
    bool is_inline_elem_offset =
        try_inline_def(buffer->elem_offset, buffer_p->Attr("elem_offset"),
                       [=]() { return d->AsDoc<ExprDoc>(buffer, buffer_p)->Attr("elem_offset"); });
    if (!is_inline_elem_offset) {
      // The variable was defined out of line instead; pass it explicitly (as Steps 3 and 4
      // do) so that the emitted definition is actually tied to this buffer.
      kwargs.Set("elem_offset",
                 d->AsDoc<ExprDoc>(buffer->elem_offset,  //
                                   buffer_p->Attr("elem_offset")));
    }
    needs_print_factor = true;
  } else {
    kwargs.Set("elem_offset",
               d->AsDoc<ExprDoc>(buffer->elem_offset,  //
                                 buffer_p->Attr("elem_offset")));
  }
  // Step 6. Handle `buffer.scope`: omitted for "global". The source path points at the
  // storage scope stored on the data variable's type annotation.
  {
    String scope = buffer.scope();
    if (scope != "global") {
      kwargs.Set(
          "scope",
          LiteralDoc::Str(scope,
                          buffer_p->Attr("data")->Attr("type_annotation")->Attr("storage_scope")));
    }
  }
  // Step 7. Handle `buffer.data_alignment`: omitted when equal to runtime::kAllocAlignment.
  if (buffer->data_alignment != runtime::kAllocAlignment) {
    kwargs.Set("align", LiteralDoc::Int(buffer->data_alignment, buffer_p->Attr("data_alignment")));
  }
  // Step 8. Handle `buffer.offset_factor`: omitted when 1, unless Step 5 requires it.
  if (needs_print_factor || buffer->offset_factor != 1) {
    kwargs.Set("offset_factor",
               LiteralDoc::Int(buffer->offset_factor, buffer_p->Attr("offset_factor")));
  }
  // Step 9. Handle `buffer.buffer_type`: any non-default type is printed as "auto".
  if (buffer->buffer_type != tir::BufferType::kDefault) {
    kwargs.Set("buffer_type", LiteralDoc::Str("auto", buffer_p->Attr("buffer_type")));
  }
  // Step 10. Handle `buffer.axis_separator`: omitted when empty.
  if (!buffer->axis_separators.empty()) {
    kwargs.Set("axis_separators",
               d->AsDoc<ExprDoc>(buffer->axis_separators, buffer_p->Attr("axis_separators")));
  }
  // Emit all out-of-line definitions as a single statement: `a = ...` or `a, b = ..., ...`.
  if (var_def_lhs.size() == 1) {
    frame->stmts.push_back(AssignDoc(var_def_lhs[0], var_def_rhs[0], NullOpt));
  } else if (var_def_lhs.size() > 1) {
    frame->stmts.push_back(AssignDoc(TupleDoc(var_def_lhs), TupleDoc(var_def_rhs), NullOpt));
  }
  return kwargs;
}