in src/meta_schedule/schedule_rule/auto_inline.cc [99:166]
/*!
 * \brief Decide whether, and in which direction, a block should be auto-inlined.
 *
 * The checks run in order and the first decisive one wins:
 *  1. The block must write exactly one buffer.
 *  2. If `inline_const_tensor` is set and the block reads nothing, it is inlined into its
 *     consumers whenever it has consumers and compute-inline is legal, ignoring the rest.
 *  3. A `meta_schedule_inline_rule` annotation equal to "disable" forbids inlining.
 *  4. Unless IsInSpatialPrimFunc() holds for the block, it must not contain any operator in
 *     `disallow_op`, must not contain if-then-else constructs when `disallow_if_then_else`
 *     is set, and every read->write index mapping must be injective / ordered when
 *     `require_injective` / `require_ordered` is set.
 *  5. Finally, inlining into the consumers (`into_consumer`) is tried first, then
 *     reverse-inlining into a single producer (`into_producer`) that is a complete block in
 *     the scope, passes CanReverseComputeInline, and is not annotated for auto-tensorize.
 *
 * \param sch The schedule that owns the block.
 * \param block_rv The block to inspect.
 * \return kInlineIntoConsumer, kInlineIntoProducer, or kNoInline.
 */
inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
                                              const tir::BlockRV& block_rv) {
  using namespace tvm::tir;
  StmtSRef block_sref = sch->GetSRef(block_rv);
  ScheduleState state = sch->state();
  const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
  // Cond 1. The block has only one write buffer
  if (block->writes.size() != 1) {
    return InlineType::kNoInline;
  }
  // Cond 2. For a block that generates a constant tensor, ignore all other conditions.
  // NOTE(review): because this runs before the "disable" annotation check below, a
  // constant-producing block is inlined even when annotated with
  // `meta_schedule_inline_rule = "disable"` — confirm this precedence is intended.
  if (inline_const_tensor && block->reads.empty()) {
    Array<tir::StmtSRef> consumer_srefs = GetConsumers(state, block_sref);
    if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) {
      return InlineType::kInlineIntoConsumer;
    }
  }
  // Cond 3. The block is disallowed for auto inline.
  // This is a cheap annotation lookup, so it is done before the structural analyses below;
  // every one of those can only yield kNoInline as well, so the result is unchanged.
  if (Optional<String> ann =
          tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_inline_rule)) {
    if (ann.value() == "disable") {
      return InlineType::kNoInline;
    }
  }
  // Conds 4-6 are skipped when the block lives in a purely spatial PrimFunc.
  // The spatial check and the BlockRealize lookup are deferred until here so the early
  // returns above do not pay for them (presumably IsInSpatialPrimFunc is a whole-function
  // analysis — TODO confirm its cost).
  const bool is_pure_spatial = IsInSpatialPrimFunc(sch, block_sref);
  if (!is_pure_spatial) {
    BlockRealize realize = GetBlockRealize(state, block_sref);
    // Cond 4. The block doesn't contain any disallowed operators
    if (!disallow_op.empty() && HasOp(realize, disallow_op)) {
      return InlineType::kNoInline;
    }
    // Cond 5. The block doesn't have any if-then-else-like constructs
    if (disallow_if_then_else && HasIfThenElse(realize)) {
      return InlineType::kNoInline;
    }
    // Cond 6. The mapping from read indices to write indices is injective and ordered
    if (require_injective || require_ordered) {
      const BufferRegion& write_region = block->writes[0];
      for (const BufferRegion& read_region : block->reads) {
        bool injective = false;
        bool ordered = false;
        std::tie(/*exists=*/std::ignore, /*surjective=*/std::ignore, injective, ordered,
                 /*no_const_read=*/std::ignore, /*no_shift_read=*/std::ignore) =
            AnalyzeReadWritePattern(read_region, write_region);
        if ((require_injective && !injective) || (require_ordered && !ordered)) {
          return InlineType::kNoInline;
        }
      }
    }
  }
  // Last cond: Check inline into the consumers or the producer
  if (into_consumer) {
    Array<tir::StmtSRef> consumer_srefs = GetConsumers(state, block_sref);
    if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) {
      return InlineType::kInlineIntoConsumer;
    }
  }
  if (into_producer) {
    // The scope root is only needed for the producer completeness check, so it is computed
    // only on this path.
    tir::StmtSRef scope_block = tir::GetScopeRoot(state, block_sref,
                                                  /*require_stage_pipeline=*/false);
    Array<tir::StmtSRef> producer_srefs = GetProducers(state, block_sref);
    if (producer_srefs.size() == 1 &&
        tir::IsCompleteBlock(state, producer_srefs[0], scope_block) &&
        CanReverseComputeInline(state, block_sref) &&
        !GetAnn<String>(producer_srefs[0], tir::attr::meta_schedule_auto_tensorize).defined()) {
      return InlineType::kInlineIntoProducer;
    }
  }
  return InlineType::kNoInline;
}