in src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc [174:336]
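// Adjust the parsed parallel/vectorize annotations of a block to what its loop nest can
// actually support: decide how many outer data-parallel loops can be fused and parallelized,
// and how many innermost loops can be fused and vectorized, writing the results into `parsed`.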
void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv,
const Array<LoopRV>& loop_rvs, ParsedAnnotation* parsed) {
StmtSRef block_sref = sch->GetSRef(block_rv);
if (parsed->max_parallel_extent == -1 && parsed->max_vectorize_extent == -1) {
return;
}
const int n_loops = loop_rvs.size();
if (n_loops == 0) {
parsed->max_parallel_extent = -1;
parsed->max_vectorize_extent = -1;
return;
}
// Extract the loop srefs and compute each loop's iterator type
Array<StmtSRef> loop_srefs;
std::vector<int> loop_types;
{
loop_srefs.reserve(n_loops);
loop_types.reserve(n_loops);
for (const LoopRV& loop_rv : loop_rvs) {
loop_srefs.push_back(sch->GetSRef(loop_rv));
loop_types.push_back(GetLoopIterType(loop_srefs.back()));
}
}
// Check the maximal number of innermost axes that are vectorizable (i.e. give contiguous memory access)
BlockRealize realize = GetBlockRealize(sch->state(), block_sref);
Array<BufferRegion> buffer_access(realize->block->reads);
buffer_access.insert(buffer_access.end(), realize->block->writes.begin(),
realize->block->writes.end());
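// Map each block iterator variable to the expression it is bound to in the block realize,
// so that buffer access regions can be rewritten in terms of the surrounding loop variables.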
std::unordered_map<const VarNode*, PrimExpr> binding_map;
for (size_t i = 0; i < realize->iter_values.size(); i++) {
binding_map[realize->block->iter_vars[i]->var.get()] = realize->iter_values[i];
}
int max_fusible = INT32_MAX;
// For each buffer read/write, compute the stride of every loop variable in the flattened
// buffer layout and count how many innermost axes are fusible (vectorizable)
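// Illustrative example (hypothetical shapes): with loops (i, j, k) and a read of A[i, j]
// where A has shape (M, N), the strides of (i, j, k) are (N, 1, 0). Loop k is unused,
// loop j is contiguous, and stride(i) == stride(j) * extent(j) when extent(j) == N,
// so all three loops would be fusible.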
for (const BufferRegion& access : buffer_access) {
int fusible = 0;
std::vector<int64_t> strides;
// get strides for each loop var
for (const StmtSRef& loop_sref : loop_srefs) {
int64_t stride = 0, buffer_stride = 1;
const auto* var = loop_sref->StmtAs<ForNode>();
arith::Analyzer analyzer;
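// Walk the buffer dimensions from innermost to outermost: the first dimension whose index
// depends on this loop var determines the loop var's stride in the flattened buffer.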
for (int i = access->region.size() - 1; i >= 0; i--) {
PrimExpr idx = analyzer.Simplify(Substitute(access->region[i]->min, binding_map));
int64_t coef = StrideExtractor::Extract(idx, var->loop_var);
if (coef != 0) {
stride = coef * buffer_stride;
break;
}
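// This dimension does not involve the loop var; fold its extent into the running stride
// (the buffer shape is expected to be a compile-time constant here).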
buffer_stride *= access->buffer->shape[i].as<IntImmNode>()->value;
}
strides.push_back(stride);
}
int prev_used_iter = -1;
// Count, starting from the innermost loop, how many loops are fusible into one contiguous access
for (int i = strides.size() - 1; i >= 0; i--) {
if (strides[i] == 0) {
// not used in the buffer access, safe to fuse
fusible++;
continue;
} else if (prev_used_iter == -1) {
// if the innermost used axis does not have stride 1, the memory access is not contiguous
if (strides[i] != 1 && fusible != 0) {
break;
}
fusible++;
prev_used_iter = i;
} else {
// fusible only if this loop continues the contiguous access, i.e. its stride equals the
// previous used loop's stride times that loop's extent
const auto* prev_loop = loop_srefs[prev_used_iter]->StmtAs<ForNode>();
int64_t prev_used_iter_extent = prev_loop->extent.as<IntImmNode>()->value;
if (strides[i] == strides[prev_used_iter] * prev_used_iter_extent) {
fusible++;
prev_used_iter = i;
} else {
break;
}
}
}
max_fusible = std::min(max_fusible, fusible);
}
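// max_fusible now holds the number of innermost loops that form a contiguous access for
// every read/write region of the block.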
// Calculate how many loops are rewritable, i.e. valid for vectorization and parallelization.
int max_rw_loops = CalculateNumRewritableLoops(loop_srefs, loop_types);
// Calculate how many outer loops to fuse and parallelize
if (parsed->max_parallel_extent != -1) {
int max_extent = parsed->max_parallel_extent;
int& num_fusible = parsed->num_parallel_loops = 0;
int64_t prod_extent = 1;
for (int i = 0; i < n_loops && loop_types[i] == IterVarType::kDataPar; ++i) {
const StmtSRef& loop_sref = loop_srefs[i];
const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
if (HasAnnOrBinding(loop)) {
break;
}
// The loop extent must be a known constant
const int64_t* extent = GetLoopIntExtent(loop_sref);
if (extent == nullptr) {
break;
}
// Then we can fuse it in
++num_fusible;
// Stop once the fused extent exceeds the target extent or the loop body contains
// more than one statement
prod_extent *= *extent;
if (prod_extent > max_extent || !IsSingleStmt(loop->body)) {
break;
}
}
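// A fused extent of 1 means there is nothing to parallelize, so mark it with -1.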
if (prod_extent == 1) {
num_fusible = -1;
}
}
// Calculate how many innermost loops to fuse and vectorize
if (parsed->max_vectorize_extent != -1) {
int max_extent = parsed->max_vectorize_extent;
int& num_fusible = parsed->num_vectorize_loops = 0;
int64_t prod_extent = 1;
for (int i = n_loops - 1;
i >= 0 && loop_types[i] == IterVarType::kDataPar && num_fusible < max_fusible; --i) {
const StmtSRef& loop_sref = loop_srefs[i];
const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
if (HasAnnOrBinding(loop)) {
break;
}
// Cannot vectorize reduce axis
if (GetLoopIterType(loop_sref) != IterVarType::kDataPar) {
break;
}
// Cannot fuse a loop whose body contains more than one statement
if (!IsSingleStmt(loop->body)) {
break;
}
// The loop extent must be a known constant
const int64_t* extent = GetLoopIntExtent(loop_sref);
if (extent == nullptr) {
break;
}
// Stop before the fused extent exceeds the maximum vectorization extent
prod_extent *= *extent;
if (prod_extent > max_extent) {
break;
}
++num_fusible;
}
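// Likewise, a fused extent of 1 means there is nothing to vectorize, so mark it with -1.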
if (prod_extent == 1) {
num_fusible = -1;
}
}
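// Reconcile the two results so that the parallelized outer loops and the vectorized inner
// loops do not overlap, unless every loop qualifies for both.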
if (parsed->num_parallel_loops != -1 && parsed->num_vectorize_loops != -1) {
if (max_rw_loops == n_loops && max_fusible == n_loops) {
// All loops can be fused, parallelized and vectorized
parsed->num_parallel_loops = n_loops;
parsed->num_vectorize_loops = n_loops;
} else {
// Prefer vectorization to parallelization: the innermost loops are reserved for
// vectorization and only the remaining outer loops may be parallelized
parsed->num_parallel_loops =
std::min(parsed->num_parallel_loops, n_loops - parsed->num_vectorize_loops);
}
}
}