Doc PrintBlock()

in src/script/printer/tir/block.cc [25:214]


/*!
 * \brief Print a TIR block, optionally together with the BlockRealize that wraps it,
 *        as a scope document whose header calls the TIR `block` builder.
 * \param d The IRDocsifier driving the printing; its frame stack is consulted to find
 *        enclosing loop vars.
 * \param block The block to print.
 * \param block_p The object path of `block`; sub-paths are attached to the emitted docs.
 * \param opt_realize The BlockRealize wrapping `block`, if any. When it is absent, iter vars
 *        are printed without bindings and the header gets `no_realize=True`.
 * \param opt_realize_p The object path of `opt_realize`. It must be defined exactly when
 *        `opt_realize` is defined (checked by ICHECK_EQ).
 * \return A ScopeDoc whose body contains the iter-var definitions, predicate, read/write
 *         regions, annotations, buffer allocations/matches, init block and block body.
 */
Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p,  //
               Optional<tir::BlockRealize> opt_realize, Optional<ObjectPath> opt_realize_p) {
  With<TIRFrame> frame(d, block);
  ICHECK_EQ(opt_realize.defined(), opt_realize_p.defined());
  const tir::BlockRealizeNode* realize =
      opt_realize.defined() ? opt_realize.value().get() : nullptr;
  const ObjectPathNode* realize_p = opt_realize_p.defined() ? opt_realize_p.get() : nullptr;
  // Step 1. Handle block var and block bindings
  // Step 1.1. Obtain all loop var defined along path
  // Collect the loop vars of every `For` chain opened by an enclosing TIR frame. A block iter
  // var bound directly to one of these loop vars (with the same domain) can use the more
  // compact `axis.remap` sugar below.
  std::unordered_map<const tir::VarNode*, tir::For> loop_vars;
  for (Frame f : d->frames) {
    if (const auto* tir_f = f.as<TIRFrameNode>()) {
      if (auto for_loop = tir_f->tir.as<tir::For>()) {
        for (Optional<tir::For> loop = for_loop; loop; loop = loop.value()->body.as<tir::For>()) {
          loop_vars.insert(std::make_pair(loop.value()->loop_var.get(), loop.value()));
        }
      }
    }
  }

  // Indices (into block->iter_vars) of a pending run of consecutive remappable iter vars.
  std::vector<int> remap_vars_indices;
  // Returns true and queues `i` for remapping if iter var `i` is spatial/reduce and its binding
  // is exactly an enclosing loop var whose [min, extent) matches the iter var's domain.
  // Only applies when there is a realize (i.e. bindings exist) and syntax sugar is enabled.
  auto add_remapped_iter_var = [&](int i) -> bool {
    if (realize && d->cfg->syntax_sugar) {
      tir::ExprDeepEqual expr_equal;
      tir::IterVar iter_var = block->iter_vars[i];
      PrimExpr value = realize->iter_values[i];
      if (iter_var->iter_type == tir::IterVarType::kDataPar ||
          iter_var->iter_type == tir::IterVarType::kCommReduce) {
        if (const auto* var = value.as<tir::VarNode>()) {
          if (loop_vars.count(var)) {
            tir::For for_loop = loop_vars.at(var);
            if (expr_equal(for_loop->min, iter_var->dom->min) &&
                expr_equal(for_loop->extent, iter_var->dom->extent)) {
              remap_vars_indices.push_back(i);
              return true;
            }
          }
        }
      }
    }
    return false;
  };

  // Emits `v = axis.<kind>(dom[, binding])` for iter var `i`. `dom` is printed as the extent
  // alone when min is zero, otherwise as the tuple (min, min + extent).
  auto print_single_iter_var = [&](int i) {
    tir::IterVar iter_var = block->iter_vars[i];
    // The Block field is `iter_vars`, matching the path used in print_remapped_iter_var.
    ObjectPath iter_var_p = block_p->Attr("iter_vars")->ArrayIndex(i);
    ExprDoc rhs = TIR(d, "axis");
    if (iter_var->iter_type == tir::IterVarType::kDataPar) {
      rhs = rhs->Attr("spatial");
    } else if (iter_var->iter_type == tir::IterVarType::kCommReduce) {
      rhs = rhs->Attr("reduce");
    } else if (iter_var->iter_type == tir::IterVarType::kOrdered) {
      rhs = rhs->Attr("scan");
    } else if (iter_var->iter_type == tir::IterVarType::kOpaque) {
      rhs = rhs->Attr("opaque");
    } else {
      LOG(FATAL) << "ValueError: Unknown IterVarType in block signature: "
                 << tir::IterVarType2String(iter_var->iter_type);
    }
    ExprDoc dom{nullptr};
    if (tir::is_zero(iter_var->dom->min)) {
      ExprDoc extent = d->AsDoc<ExprDoc>(iter_var->dom->extent,  //
                                         iter_var_p->Attr("dom")->Attr("extent"));
      dom = extent;
    } else {
      ExprDoc min = d->AsDoc<ExprDoc>(iter_var->dom->min, iter_var_p->Attr("dom")->Attr("min"));
      // `min + extent` is a synthesized expression; attribute it to the extent field.
      ExprDoc max = d->AsDoc<ExprDoc>(iter_var->dom->min + iter_var->dom->extent,
                                      iter_var_p->Attr("dom")->Attr("extent"));
      dom = TupleDoc({min, max});
    }
    if (realize) {
      ExprDoc binding = d->AsDoc<ExprDoc>(realize->iter_values[i],  //
                                          realize_p->Attr("iter_values")->ArrayIndex(i));
      rhs = rhs->Call({dom, binding});
    } else {
      rhs = rhs->Call({dom});
    }
    (*frame)->stmts.push_back(AssignDoc(DefineVar(iter_var->var, *frame, d), rhs, NullOpt));
  };

  // Flushes the pending run of remappable iter vars. A run of one falls back to the plain
  // single-var form; longer runs become `v0, v1, ... = axis.remap("SR...", [bindings])`.
  // Leaves remap_vars_indices empty.
  auto print_remapped_iter_var = [&]() {
    int m = remap_vars_indices.size();
    if (m == 0) {
      return;
    }
    if (m == 1) {
      print_single_iter_var(remap_vars_indices[0]);
      remap_vars_indices.clear();
      return;
    }
    Array<ExprDoc> lhs;
    Array<ExprDoc> loop_var_doc;
    lhs.reserve(m);
    loop_var_doc.reserve(m);
    std::string binding_type;
    Array<ObjectPath> binding_paths;
    for (int i : remap_vars_indices) {
      tir::IterVar iter_var = block->iter_vars[i];
      ObjectPath iter_var_p = block_p->Attr("iter_vars")->ArrayIndex(i);
      lhs.push_back(DefineVar(iter_var->var, *frame, d));
      loop_var_doc.push_back(d->AsDoc<ExprDoc>(realize->iter_values[i],
                                               realize_p->Attr("iter_values")->ArrayIndex(i)));
      binding_paths.push_back(iter_var_p->Attr("iter_type"));
      // Only kDataPar / kCommReduce vars are ever queued (see add_remapped_iter_var).
      binding_type += iter_var->iter_type == tir::IterVarType::kDataPar ? "S" : "R";
    }
    ExprDoc rhs = TIR(d, "axis")->Attr("remap");
    ExprDoc binding_str = LiteralDoc::Str(binding_type, NullOpt);
    binding_str->source_paths = std::move(binding_paths);
    rhs = rhs->Call({binding_str, ListDoc(loop_var_doc)});
    (*frame)->stmts.push_back(AssignDoc(TupleDoc(lhs), rhs, NullOpt));
    remap_vars_indices.clear();
  };

  // Step 1.2. Construct all block var bindings
  // Iter vars are emitted in declaration order: a non-remappable var first flushes any pending
  // remap run so that the printed order matches block->iter_vars.
  int n_vars = block->iter_vars.size();
  for (int i = 0; i < n_vars; ++i) {
    if (!add_remapped_iter_var(i)) {
      print_remapped_iter_var();
      print_single_iter_var(i);
    }
  }
  print_remapped_iter_var();

  // Step 2. Handle block predicate
  // The `where` call is emitted only when the predicate is not the constant true.
  if (realize) {
    ICHECK(realize->predicate.defined() && realize->predicate->dtype.is_bool());
    if (!tir::is_one(realize->predicate)) {
      (*frame)->stmts.push_back(ExprStmtDoc(
          TIR(d, "where")
              ->Call({d->AsDoc<ExprDoc>(realize->predicate, realize_p->Attr("predicate"))})));
    }
  }
  // Step 3. Handle block read/write regions
  // `reads` and `writes` are always emitted, even when the region lists are empty.
  {
    Array<ExprDoc> reads;
    for (int i = 0, n = block->reads.size(); i < n; ++i) {
      reads.push_back(d->AsDoc<ExprDoc>(block->reads[i], block_p->Attr("reads")->ArrayIndex(i)));
    }
    (*frame)->stmts.push_back(ExprStmtDoc(TIR(d, "reads")->Call(reads)));
    Array<ExprDoc> writes;
    for (int i = 0, n = block->writes.size(); i < n; ++i) {
      writes.push_back(d->AsDoc<ExprDoc>(block->writes[i], block_p->Attr("writes")->ArrayIndex(i)));
    }
    (*frame)->stmts.push_back(ExprStmtDoc(TIR(d, "writes")->Call(writes)));
  }
  // Step 4. Handle block attributes
  if (!block->annotations.empty()) {
    (*frame)->stmts.push_back(ExprStmtDoc(
        TIR(d, "block_attr")
            ->Call({d->AsDoc<ExprDoc>(block->annotations, block_p->Attr("annotations"))})));
  }
  // Step 5. Handle `alloc_buffer`
  for (int i = 0, n = block->alloc_buffers.size(); i < n; ++i) {
    tir::Buffer buffer = block->alloc_buffers[i];
    ObjectPath buffer_p = block_p->Attr("alloc_buffers")->ArrayIndex(i);
    IdDoc lhs = DefineBuffer(buffer, *frame, d);
    ExprDoc rhs = BufferDecl(buffer, "alloc_buffer", {}, buffer_p, *frame, d,
                             BufferVarDefinition::DataPointer);
    (*frame)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
  }
  // Step 6. Handle `match_buffer`
  for (int i = 0, n = block->match_buffers.size(); i < n; ++i) {
    tir::MatchBufferRegion buffer_region = block->match_buffers[i];
    ObjectPath buffer_region_p = block_p->Attr("match_buffers")->ArrayIndex(i);
    StmtDoc doc = d->AsDoc<StmtDoc>(buffer_region, buffer_region_p);
    (*frame)->stmts.push_back(doc);
  }
  // Step 7. Handle init block
  // The init statement is printed in its own frame and nested as a `with ...init():` scope.
  if (block->init.defined()) {
    tir::Stmt init = block->init.value();
    With<TIRFrame> init_frame(d, init);
    AsDocBody(init, block_p->Attr("init"), init_frame->get(), d);
    (*frame)->stmts.push_back(ScopeDoc(NullOpt, TIR(d, "init")->Call({}), (*init_frame)->stmts));
  }
  // Step 8. Handle block body
  AsDocBody(block->body, block_p->Attr("body"), frame->get(), d);
  // Without an enclosing BlockRealize, mark the block explicitly with `no_realize=True`.
  Array<String> kwargs_keys;
  Array<ExprDoc> kwargs_values;
  if (!realize) {
    kwargs_keys.push_back("no_realize");
    kwargs_values.push_back(LiteralDoc::Boolean(true, NullOpt));
  }
  return ScopeDoc(NullOpt,
                  TIR(d, "block")  //
                      ->Call({LiteralDoc::Str(block->name_hint, block_p->Attr("name_hint"))},
                             kwargs_keys, kwargs_values),
                  (*frame)->stmts);
}