// AdjustParallelVectorize()
// From src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc [174:336]

void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv,
                             const Array<LoopRV>& loop_rvs, ParsedAnnotation* parsed) {
  StmtSRef block_sref = sch->GetSRef(block_rv);
  if (parsed->max_parallel_extent == -1 && parsed->max_vectorize_extent == -1) {
    return;
  }
  const int n_loops = loop_rvs.size();
  if (n_loops == 0) {
    parsed->max_parallel_extent = -1;
    parsed->max_vectorize_extent = -1;
    return;
  }
  // Extract loop_srefs, and calculate the iterator types
  Array<StmtSRef> loop_srefs;
  std::vector<int> loop_types;
  {
    loop_srefs.reserve(n_loops);
    loop_types.reserve(n_loops);
    for (const LoopRV& loop_rv : loop_rvs) {
      loop_srefs.push_back(sch->GetSRef(loop_rv));
      loop_types.push_back(GetLoopIterType(loop_srefs.back()));
    }
  }
  // check the maximal number of axes that are vectorizable (contiguous memory access)
  BlockRealize realize = GetBlockRealize(sch->state(), block_sref);
  Array<BufferRegion> buffer_access(realize->block->reads);
  buffer_access.insert(buffer_access.end(), realize->block->writes.begin(),
                       realize->block->writes.end());
  std::unordered_map<const VarNode*, PrimExpr> binding_map;
  for (size_t i = 0; i < realize->iter_values.size(); i++) {
    binding_map[realize->block->iter_vars[i]->var.get()] = realize->iter_values[i];
  }
  int max_fusible = INT32_MAX;
  // for each block read/write, get the strides of the loop vars and find the fusible
  // (vectorizable) axes
  for (const BufferRegion& access : buffer_access) {
    int fusible = 0;
    std::vector<int64_t> strides;
    // get strides for each loop var
    for (const StmtSRef& loop_sref : loop_srefs) {
      int64_t stride = 0, buffer_stride = 1;
      const auto* var = loop_sref->StmtAs<ForNode>();
      arith::Analyzer analyzer;
      for (int i = access->region.size() - 1; i >= 0; i--) {
        PrimExpr idx = analyzer.Simplify(Substitute(access->region[i]->min, binding_map));
        int64_t coef = StrideExtractor::Extract(idx, var->loop_var);
        if (coef != 0) {
          stride = coef * buffer_stride;
          break;
        }
        buffer_stride *= access->buffer->shape[i].as<IntImmNode>()->value;
      }
      strides.push_back(stride);
    }
    int prev_used_iter = -1;
    // check the number of fusible loops
    for (int i = strides.size() - 1; i >= 0; i--) {
      if (strides[i] == 0) {
        // not used in the buffer access, safe to fuse
        fusible++;
        continue;
      } else if (prev_used_iter == -1) {
        // the stride of last axis is not 1 means the memory access is not contiguous
        if (strides[i] != 1 && fusible != 0) {
          break;
        }
        fusible++;
        prev_used_iter = i;
      } else {
        // contiguous memory access
        const auto* prev_loop = loop_srefs[prev_used_iter]->StmtAs<ForNode>();
        int64_t prev_used_iter_extent = prev_loop->extent.as<IntImmNode>()->value;
        if (strides[i] == strides[prev_used_iter] * prev_used_iter_extent) {
          fusible++;
          prev_used_iter = i;
        } else {
          break;
        }
      }
    }
    max_fusible = std::min(max_fusible, fusible);
  }

  // Calculate how many loops are rewritable, i.e. valid for vectorization and parallelization.
  int max_rw_loops = CalculateNumRewritableLoops(loop_srefs, loop_types);

  // Calculate the parallelize extent
  if (parsed->max_parallel_extent != -1) {
    int max_extent = parsed->max_parallel_extent;
    int& num_fusible = parsed->num_parallel_loops = 0;
    int64_t prod_extent = 1;
    for (int i = 0; i < n_loops && loop_types[i] == IterVarType::kDataPar; ++i) {
      const StmtSRef& loop_sref = loop_srefs[i];
      const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
      if (HasAnnOrBinding(loop)) {
        break;
      }
      // Check if the loop extent is valid
      const int64_t* extent = GetLoopIntExtent(loop_sref);
      if (extent == nullptr) {
        break;
      }
      // Then we can fuse it in
      ++num_fusible;
      // Check if we need to break
      prod_extent *= *extent;
      if (prod_extent > max_extent || !IsSingleStmt(loop->body)) {
        break;
      }
    }
    if (prod_extent == 1) {
      num_fusible = -1;
    }
  }
  // Calculate the vectorize extent
  if (parsed->max_vectorize_extent != -1) {
    int max_extent = parsed->max_vectorize_extent;
    int& num_fusible = parsed->num_vectorize_loops = 0;
    int64_t prod_extent = 1;
    for (int i = n_loops - 1;
         i >= 0 && loop_types[i] == IterVarType::kDataPar && num_fusible < max_fusible; --i) {
      const StmtSRef& loop_sref = loop_srefs[i];
      const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
      if (HasAnnOrBinding(loop)) {
        break;
      }
      // Cannot vectorize reduce axis
      if (GetLoopIterType(loop_sref) != IterVarType::kDataPar) {
        break;
      }
      // Cannot fuse with a loop with multiple children
      if (!IsSingleStmt(loop->body)) {
        break;
      }
      // Check if the loop extent is valid
      const int64_t* extent = GetLoopIntExtent(loop_sref);
      if (extent == nullptr) {
        break;
      }
      // Check if the extent is still in a good range
      prod_extent *= *extent;
      if (prod_extent > max_extent) {
        break;
      }
      ++num_fusible;
    }
    if (prod_extent == 1) {
      num_fusible = -1;
    }
  }

  if (parsed->num_parallel_loops != -1 && parsed->num_vectorize_loops != -1) {
    if (max_rw_loops == n_loops && max_fusible == n_loops) {
      // All loops can be fused, parallelized and vectorized
      parsed->num_parallel_loops = n_loops;
      parsed->num_vectorize_loops = n_loops;
    } else {
      // Prefer num_vectorize to num_parallel
      parsed->num_parallel_loops =
          std::min(parsed->num_parallel_loops, n_loops - parsed->num_vectorize_loops);
    }
  }
}