inline InlineType AutoInlineNode::CheckInline()

in src/meta_schedule/schedule_rule/auto_inline.cc [99:166]


/*!
 * \brief Decide whether, and in which direction, the auto-inline rule should inline a block.
 * \param sch The schedule that owns the block.
 * \param block_rv The block to check.
 * \return kInlineIntoConsumer if the block should be compute-inlined into its consumers,
 *         kInlineIntoProducer if it should be reverse-compute-inlined into its single
 *         producer, kNoInline otherwise.
 *
 * The checks run in order and the first decisive one wins:
 *  1. the block must write exactly one buffer;
 *  2. (inline_const_tensor) a block with no reads is inlined into its consumers whenever
 *     that is legal, bypassing every later condition;
 *  3. a block annotated with meta_schedule_inline_rule == "disable" is never inlined;
 *  4-6. unless IsInSpatialPrimFunc holds for the block, it must not use an op from
 *     disallow_op, must not contain if-then-else (when disallow_if_then_else), and every
 *     read->write index mapping must be injective / ordered (when required);
 *  7. finally try inlining into consumers (into_consumer), then into the sole producer
 *     (into_producer).
 */
inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
                                              const tir::BlockRV& block_rv) {
  using namespace tvm::tir;
  StmtSRef block_sref = sch->GetSRef(block_rv);
  ScheduleState state = sch->state();
  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);

  // True when the block has at least one consumer and compute-inline is legal for it.
  // Shared by the constant-tensor fast path and the final into-consumer check.
  auto can_inline_into_consumer = [&]() -> bool {
    Array<tir::StmtSRef> consumer_srefs = GetConsumers(state, block_sref);
    return !consumer_srefs.empty() && CanComputeInline(state, block_sref);
  };

  // Cond 1. The block has only one write buffer
  if (block->writes.size() != 1) {
    return InlineType::kNoInline;
  }
  // Cond 2. For a block that generates a constant tensor, ignore all other conditions.
  // NOTE(review): because this runs before Cond 3, a constant-tensor block annotated with
  // meta_schedule_inline_rule = "disable" can still be inlined — confirm this is intended.
  if (inline_const_tensor && block->reads.empty() && can_inline_into_consumer()) {
    return InlineType::kInlineIntoConsumer;
  }
  // Cond 3. The block is disallowed for auto inline. This is a cheap annotation lookup, so
  // it runs before the IR analyses below; those can only reject the block, so the order does
  // not change the result.
  if (Optional<String> ann =
          tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_inline_rule)) {
    if (ann.value() == "disable") return InlineType::kNoInline;
  }
  // Conds 4-6 are waived for blocks inside a pure-spatial PrimFunc.
  const bool is_pure_spatial = IsInSpatialPrimFunc(sch, block_sref);
  if (!is_pure_spatial) {
    BlockRealize realize = GetBlockRealize(state, block_sref);
    // Cond 4. The block doesn't contain any disallowed operators
    if (!disallow_op.empty() && HasOp(realize, disallow_op)) {
      return InlineType::kNoInline;
    }
    // Cond 5. The block doesn't have any if-then-else-like constructs
    if (disallow_if_then_else && HasIfThenElse(realize)) {
      return InlineType::kNoInline;
    }
    // Cond 6. The mapping from read indices to write indices are injective and ordered
    if (require_injective || require_ordered) {
      const BufferRegion& write_region = block->writes[0];
      for (const BufferRegion& read_region : block->reads) {
        bool injective = false;
        bool ordered = false;
        std::tie(/*exists=*/std::ignore, /*surjective=*/std::ignore, injective, ordered,
                 /*no_const_read=*/std::ignore, /*no_shift_read=*/std::ignore) =
            AnalyzeReadWritePattern(read_region, write_region);
        if ((require_injective && !injective) || (require_ordered && !ordered)) {
          return InlineType::kNoInline;
        }
      }
    }
  }
  // Last cond: Check inline into the consumers or the spatial producer
  if (into_consumer && can_inline_into_consumer()) {
    return InlineType::kInlineIntoConsumer;
  }
  if (into_producer) {
    Array<tir::StmtSRef> producer_srefs = GetProducers(state, block_sref);
    if (producer_srefs.size() == 1) {
      // The scope root is only needed for the producer completeness check.
      tir::StmtSRef scope_block = tir::GetScopeRoot(state, block_sref,
                                                    /*require_stage_pipeline=*/false);
      if (tir::IsCompleteBlock(state, producer_srefs[0], scope_block) &&
          CanReverseComputeInline(state, block_sref) &&
          !GetAnn<String>(producer_srefs[0], tir::attr::meta_schedule_auto_tensorize)
               .defined()) {
        return InlineType::kInlineIntoProducer;
      }
    }
  }
  return InlineType::kNoInline;
}