Map<String, ExprDoc> BufferAttrs()

in src/script/printer/tir/buffer.cc [27:183]


/*!
 * \brief Build the keyword arguments used to print a TIR buffer declaration.
 *
 * Each buffer field is turned into a keyword argument when it needs to be printed
 * (shape, dtype, data, strides, elem_offset, scope, align, offset_factor, buffer_type,
 * axis_separators). Bare variables found in `shape`, `data`, `strides` or `elem_offset`
 * that are not yet defined in `d` are defined here, in one of two ways:
 *  - inline: the variable is bound (via `d->Define`) to an attribute access on the
 *    printed buffer, e.g. `buf.data`, `buf.strides[i]` or `buf.elem_offset`. This is only
 *    attempted for data (when allowed by `var_definitions`), strides and elem_offset, and
 *    only when the variable occurs exactly once across the buffer's fields.
 *  - out of line: the variable is defined in `frame` and an assignment
 *    `lhs = <PrintVarCreation(...)>` is appended to `frame->stmts`. Multiple such
 *    definitions are combined into a single tuple assignment.
 *
 * \param buffer The buffer whose declaration is being printed.
 * \param buffer_p Object path of `buffer`; attribute paths derived from it are attached
 *        to the generated docs as source paths.
 * \param frame Frame in which new variables are defined and to which out-of-line
 *        definitions are appended.
 * \param d The IR docsifier.
 * \param var_definitions The data pointer may only be defined inline when this is
 *        `>= BufferVarDefinition::DataPointer`.
 * \return Map from keyword name to the printed value of that keyword.
 */
Map<String, ExprDoc> BufferAttrs(tir::Buffer buffer, const ObjectPath& buffer_p, const Frame& frame,
                                 const IRDocsifier& d, BufferVarDefinition var_definitions) {
  using tvm::tir::Var;
  using tvm::tir::VarNode;
  Map<String, ExprDoc> kwargs;
  // Left/right-hand sides of variables that must be defined out of line; flushed into a
  // single assignment statement at the end.
  Array<ExprDoc> var_def_lhs;
  Array<ExprDoc> var_def_rhs;

  // Step 0. Set up statistics: how many times each variable occurs across the buffer's
  // elem_offset, data, strides and shape.
  std::unordered_map<const Object*, int> use_count;
  auto update_use_count = [&](const PrimExpr& e) {
    tir::PostOrderVisit(e, [&](const ObjectRef& n) {
      if (const VarNode* var = n.as<VarNode>()) {
        ++use_count[var];
      }
    });
  };
  update_use_count(buffer->elem_offset);
  update_use_count(buffer->data);
  for (const PrimExpr& e : buffer->strides) {
    update_use_count(e);
  }
  for (const PrimExpr& e : buffer->shape) {
    update_use_count(e);
  }
  // True when `e` is a bare variable that has not been defined yet.
  auto is_new_var = [&](const PrimExpr& e) {
    return e->IsInstance<VarNode>() && !d->IsVarDefined(e);
  };
  // Define `var` in `frame` and queue an explicit `var = <creation>` assignment.
  auto add_out_of_line_var_def = [&](const Var& var, const ObjectPath& var_p) {
    ICHECK(!d->IsVarDefined(var));
    ExprDoc lhs = DefineVar(var, frame, d);
    lhs->source_paths.push_back(var_p);
    var_def_lhs.push_back(lhs);
    var_def_rhs.push_back(PrintVarCreation(var, var_p, d));
  };
  // Try to define `e` inline as `inline_f()` (an attribute of the printed buffer).
  // This only happens when the variable occurs exactly once in the buffer's fields.
  // Presumably a variable referenced by another field of the same declaration must exist
  // before the buffer does, so it cannot be bound to the buffer's attribute.
  // Otherwise the variable is defined out of line. Returns whether the inline definition
  // was used.
  auto try_inline_def = [&](const PrimExpr& e, const ObjectPath& e_p,
                            std::function<ExprDoc()> inline_f) {
    ICHECK(is_new_var(e));
    Var var = Downcast<Var>(e);
    if (use_count[var.get()] == 1) {
      d->Define(e, frame, inline_f);
      return true;
    } else {
      add_out_of_line_var_def(var, e_p);
      return false;
    }
  };
  // Step 1. Handle `buffer.shape`; new shape variables are always defined out of line.
  {
    const Array<PrimExpr>& shape = buffer->shape;
    ObjectPath shape_p = buffer_p->Attr("shape");
    int n = shape.size();
    Array<ExprDoc> results;
    results.reserve(n);
    for (int i = 0; i < n; ++i) {
      PrimExpr e = shape[i];
      ObjectPath e_p = shape_p->ArrayIndex(i);
      if (is_new_var(e)) {
        add_out_of_line_var_def(Downcast<Var>(e), e_p);
      }
      results.push_back(d->AsDoc<ExprDoc>(e, e_p));
    }
    kwargs.Set("shape", TupleDoc(results));
  }
  // Step 2. Handle `buffer.dtype`; omitted when it equals the configured default.
  if (buffer->dtype != d->cfg->buffer_dtype) {
    kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype, buffer_p->Attr("dtype")));
  }
  // Step 3. Handle `buffer.data`
  bool is_inline_data = false;
  if (is_new_var(buffer->data)) {
    if (var_definitions >= BufferVarDefinition::DataPointer) {
      is_inline_data = try_inline_def(buffer->data, buffer_p->Attr("data"), [=]() {
        return d->AsDoc<ExprDoc>(buffer, buffer_p)->Attr("data");
      });
    } else {
      add_out_of_line_var_def(buffer->data, buffer_p->Attr("data"));
    }
  }
  if (!is_inline_data) {
    kwargs.Set("data", d->AsDoc<ExprDoc>(buffer->data, buffer_p->Attr("data")));
  }
  // Step 4. Handle `buffer.strides`
  if (!buffer->strides.empty()) {
    const Array<PrimExpr>& strides = buffer->strides;
    ObjectPath strides_p = buffer_p->Attr("strides");
    int n = strides.size();
    Array<ExprDoc> results;
    results.reserve(n);
    for (int i = 0; i < n; ++i) {
      PrimExpr e = strides[i];
      ObjectPath e_p = strides_p->ArrayIndex(i);
      if (is_new_var(e)) {
        // `i` is captured by value: the callback may run after this loop has finished.
        if (try_inline_def(e, e_p, [=]() {
              return d->AsDoc<ExprDoc>(buffer, buffer_p)
                  ->Attr("strides")[{LiteralDoc::Int(i, NullOpt)}];
            })) {
          // An inline-defined stride is printed as a string holding the variable's name.
          results.push_back(LiteralDoc::Str(Downcast<Var>(e)->name_hint, e_p));
          continue;
        }
      }
      results.push_back(d->AsDoc<ExprDoc>(e, e_p));
    }
    kwargs.Set("strides", TupleDoc(results));
  }
  // Step 5. Handle `buffer.elem_offset`
  bool needs_print_factor = false;
  if (const auto* int_imm = buffer->elem_offset.as<IntImmNode>()) {
    // A constant zero offset is omitted.
    if (int_imm->value != 0) {
      kwargs.Set("elem_offset",
                 d->AsDoc<ExprDoc>(buffer->elem_offset,  //
                                   buffer_p->Attr("elem_offset")));
    }
  } else if (is_new_var(buffer->elem_offset)) {
    bool is_inline_elem_offset =
        try_inline_def(buffer->elem_offset, buffer_p->Attr("elem_offset"),
                       [=]() { return d->AsDoc<ExprDoc>(buffer, buffer_p)->Attr("elem_offset"); });
    if (!is_inline_elem_offset) {
      // The variable was just defined out of line, so the declaration must reference it
      // explicitly (same as `data` in Step 3); otherwise that definition would be
      // disconnected from this buffer.
      kwargs.Set("elem_offset",
                 d->AsDoc<ExprDoc>(buffer->elem_offset,  //
                                   buffer_p->Attr("elem_offset")));
    }
    // NOTE(review): offset_factor is always printed for a freshly defined elem_offset
    // variable, presumably so the parser recreates a symbolic offset; confirm against the
    // TVMScript parser.
    needs_print_factor = true;
  } else {
    kwargs.Set("elem_offset",
               d->AsDoc<ExprDoc>(buffer->elem_offset,  //
                                 buffer_p->Attr("elem_offset")));
  }
  // Step 6. Handle `buffer.scope`; omitted for "global". The source path points at the
  // storage scope carried by the data pointer's type annotation.
  {
    String scope = buffer.scope();
    if (scope != "global") {
      kwargs.Set(
          "scope",
          LiteralDoc::Str(scope,
                          buffer_p->Attr("data")->Attr("type_annotation")->Attr("storage_scope")));
    }
  }
  // Step 7. Handle `buffer.data_alignment`; omitted when equal to runtime::kAllocAlignment.
  if (buffer->data_alignment != runtime::kAllocAlignment) {
    kwargs.Set("align", LiteralDoc::Int(buffer->data_alignment, buffer_p->Attr("data_alignment")));
  }
  // Step 8. Handle `buffer.offset_factor`
  if (needs_print_factor || buffer->offset_factor != 1) {
    kwargs.Set("offset_factor",
               LiteralDoc::Int(buffer->offset_factor, buffer_p->Attr("offset_factor")));
  }
  // Step 9. Handle `buffer.buffer_type`; any non-default type is printed as "auto".
  if (buffer->buffer_type != tir::BufferType::kDefault) {
    kwargs.Set("buffer_type", LiteralDoc::Str("auto", buffer_p->Attr("buffer_type")));
  }
  // Step 10. Handle `buffer.axis_separator`
  if (!buffer->axis_separators.empty()) {
    kwargs.Set("axis_separators",
               d->AsDoc<ExprDoc>(buffer->axis_separators, buffer_p->Attr("axis_separators")));
  }
  // Flush out-of-line definitions: `a = x` for one variable, `a, b = x, y` for several.
  if (var_def_lhs.size() == 1) {
    frame->stmts.push_back(AssignDoc(var_def_lhs[0], var_def_rhs[0], NullOpt));
  } else if (var_def_lhs.size() > 1) {
    frame->stmts.push_back(AssignDoc(TupleDoc(var_def_lhs), TupleDoc(var_def_rhs), NullOpt));
  }
  return kwargs;
}