in src/tir/analysis/identify_memcpy.cc [42:273]
// Attempt to recognize a loop nest that is equivalent to a memcpy.
//
// The nest is walked from `loop` inward until a non-loop statement is
// reached.  That statement must be a BufferStore whose value is a
// BufferLoad.  Three checks are then applied to the flattened source
// and destination indices:
//
//   1. Both indices must be accepted by arith::DetectIterMap at the
//      Bijective level over the collected loop ranges.
//   2. The span (max - min + 1) of each index must provably equal the
//      total number of loop iterations, i.e. the touched region is
//      contiguous.
//   3. After sorting, the IterSplitExpr terms of the source and the
//      destination must provably match pairwise in source,
//      lower_factor, extent and scale.
//
// \param loop      Outermost loop of the candidate loop nest.
// \param analyzer  Analyzer used for simplification, interval
//                  analysis and proofs.
//
// \return MemCpyDetails holding the source and destination regions on
//         success; otherwise a std::string describing the first check
//         that failed.
std::variant<MemCpyDetails, std::string> IdentifyMemCpyImpl(const For& loop,
                                                            arith::Analyzer* analyzer) {
  Map<Var, arith::IntSet> loop_intervals;
  Map<Var, Range> loop_ranges;
  PrimExpr total_loop_iterations = 1;

  // Walk through the loop nest, stopping at the first loop whose body
  // is not a loop.  Each loop variable is recorded both as a Range
  // (for DetectIterMap) and as an IntSet (for interval analysis).
  Stmt stmt = loop;
  while (auto* for_node = stmt.as<ForNode>()) {
    loop_ranges.Set(for_node->loop_var, Range::FromMinExtent(for_node->min, for_node->extent));
    loop_intervals.Set(for_node->loop_var,
                       arith::IntSet::FromMinExtent(for_node->min, for_node->extent));
    total_loop_iterations = total_loop_iterations * for_node->extent;

    stmt = for_node->body;
  }

  BufferStore store;
  if (auto opt = stmt.as<BufferStore>()) {
    store = opt.value();
  } else {
    return static_cast<const std::stringstream&>(
               std::stringstream()
               << "Expected innermost loop to have BufferStore body, but instead found " << stmt)
        .str();
  }

  BufferLoad load;
  if (auto opt = store->value.as<BufferLoad>()) {
    load = opt.value();
  } else {
    return static_cast<const std::stringstream&>(
               std::stringstream()
               << "Expected BufferStore's value to be BufferLoad, but instead found "
               << store->value)
        .str();
  }

  // Now, we have a BufferStore whose value is a BufferLoad.  Because
  // non-flat physical indices are target-dependent, only handle cases
  // where the buffer will be flattened to a 1-d physical buffer.
  Array<PrimExpr> flattened_dst = store->buffer.OffsetOf(store->indices);
  Array<PrimExpr> flattened_src = load->buffer.OffsetOf(load->indices);

  if (flattened_dst.size() != 1 || flattened_src.size() != 1) {
    return static_cast<const std::stringstream&>(
               std::stringstream()
               << "Expected flattened dimension of src/dest to be 1, but found "
               << flattened_src.size() << "-d src and " << flattened_dst.size() << "-d dst")
        .str();
  }
  PrimExpr src_index = flattened_src[0];
  PrimExpr dst_index = flattened_dst[0];

  // First check, do the input/output form affine subsets of their
  // respective buffers?
  //
  // For example, should exclude the following, indices are not affine
  //
  // for i in T.serial(16):
  //     B[i] = A[T.abs(i-8)]

  auto src_iter_map = arith::DetectIterMap({src_index}, loop_ranges, Bool(true),
                                           arith::IterMapLevel::Bijective, analyzer);
  if (src_iter_map->errors.size()) {
    return static_cast<const std::stringstream&>(std::stringstream()
                                                 << "arith::DetectIterMap(src) returned "
                                                 << src_iter_map->errors.size() << " errors: ["
                                                 << src_iter_map->errors << "]"
                                                 << " for src_index = " << src_index)
        .str();
  }
  auto dst_iter_map = arith::DetectIterMap({dst_index}, loop_ranges, Bool(true),
                                           arith::IterMapLevel::Bijective, analyzer);
  if (dst_iter_map->errors.size()) {
    return static_cast<const std::stringstream&>(std::stringstream()
                                                 << "arith::DetectIterMap(dst) returned "
                                                 << dst_iter_map->errors.size() << " errors: ["
                                                 << dst_iter_map->errors << "]"
                                                 << " for dst_index = " << dst_index)
        .str();
  }

  // Second check, are those affine subsets contiguous?  If so, then
  // the index expressions will visit every location between the min
  // and the max.  This checks surjectivity over a linear region,
  // which may not be the same as DetectIterMap's check of
  // surjectivity over the affine subset.
  //
  // For example, should exclude the following, doesn't touch all
  // output locations within the output region touched.
  //
  // for i in T.serial(16):
  //     B[2*i] = A[i]
  //
  // Similarly, should exclude the following, doesn't touch all
  // input locations within the input region touched.
  //
  // for i in T.serial(16):
  //     B[i] = A[2*i]
  total_loop_iterations = analyzer->Simplify(total_loop_iterations);
  auto src_interval = analyzer->int_set(src_index, loop_intervals);
  auto dst_interval = analyzer->int_set(dst_index, loop_intervals);

  if (!src_interval.HasLowerBound() || !src_interval.HasUpperBound()) {
    return static_cast<const std::stringstream&>(std::stringstream()
                                                 << "Expected known bounds for src, but found "
                                                 << src_interval << " for expression " << src_index)
        .str();
  }
  if (!dst_interval.HasLowerBound() || !dst_interval.HasUpperBound()) {
    return static_cast<const std::stringstream&>(std::stringstream()
                                                 << "Expected known bounds for dst, but found "
                                                 << dst_interval << " for expression " << dst_index)
        .str();
  }

  {
    PrimExpr must_prove = total_loop_iterations == src_interval.max() - src_interval.min() + 1;
    PrimExpr simplified = analyzer->Simplify(must_prove);
    if (!analyzer->CanProve(simplified)) {
      return static_cast<const std::stringstream&>(
                 std::stringstream()
                 << "Mismatch between loop iterations (" << total_loop_iterations
                 << ") and number of src indices touched (" << src_interval
                 << "). Equality to prove simplified to " << simplified)
          .str();
    }
  }
  {
    PrimExpr must_prove = total_loop_iterations == dst_interval.max() - dst_interval.min() + 1;
    PrimExpr simplified = analyzer->Simplify(must_prove);
    if (!analyzer->CanProve(simplified)) {
      return static_cast<const std::stringstream&>(
                 std::stringstream()
                 << "Mismatch between loop iterations (" << total_loop_iterations
                 << ") and number of dst indices touched (" << dst_interval
                 << "). Equality to prove simplified to " << simplified)
          .str();
    }
  }

  // Third check, is there a transformation applied between the input
  // and output iterators?
  //
  // For example, the following would pass all checks so far, but
  // converts between row-major and column-major layouts, and could
  // not be specified as a memcpy.
  //
  // for i,j in T.grid(4,4):
  //     B[i,j] = A[j,i]

  auto src_iter_sum = src_iter_map->indices[0];
  auto dst_iter_sum = dst_iter_map->indices[0];

  if (src_iter_sum->args.size() != dst_iter_sum->args.size()) {
    return static_cast<const std::stringstream&>(
               std::stringstream()
               << "IterMap for src/dst unpacked to different number of IterSplitExpr: "
               << src_iter_sum->args.size() << " for src, " << dst_iter_sum->args.size()
               << " for dst.  "
               << "IterMaps were detected as src = " << src_iter_sum << ", dst = " << dst_iter_sum)
        .str();
  }
  std::vector<arith::IterSplitExpr> src_iter_terms(src_iter_sum->args.begin(),
                                                   src_iter_sum->args.end());
  std::vector<arith::IterSplitExpr> dst_iter_terms(dst_iter_sum->args.begin(),
                                                   dst_iter_sum->args.end());

  // Sort key used to bring the src and dst terms into a common order
  // before the pairwise comparison below.  Each of scale, extent and
  // lower_factor contributes a (is-constant, constant-value-or-zero)
  // pair, so that every field's own value takes part in the ordering.
  auto make_comparison_tuple = [](const arith::IterSplitExpr& expr) {
    auto as_int_or_zero = [](auto& val) -> int64_t {
      if (auto* as_int = val.template as<IntImmNode>()) {
        return as_int->value;
      } else {
        return 0;
      }
    };
    return std::tuple{
        static_cast<bool>(expr->scale.as<IntImmNode>()),
        as_int_or_zero(expr->scale),
        static_cast<bool>(expr->extent.as<IntImmNode>()),
        as_int_or_zero(expr->extent),
        static_cast<bool>(expr->lower_factor.as<IntImmNode>()),
        as_int_or_zero(expr->lower_factor),
    };
  };
  auto sorting_function = [&make_comparison_tuple](const arith::IterSplitExpr& lhs,
                                                   const arith::IterSplitExpr& rhs) -> bool {
    return make_comparison_tuple(lhs) < make_comparison_tuple(rhs);
  };
  std::sort(src_iter_terms.begin(), src_iter_terms.end(), sorting_function);
  std::sort(dst_iter_terms.begin(), dst_iter_terms.end(), sorting_function);

  // The sizes were verified equal above, so indexing dst_iter_terms
  // with the src index is in range.
  for (size_t i = 0; i < src_iter_terms.size(); i++) {
    const arith::IterSplitExpr& src_term = src_iter_terms[i];
    const arith::IterSplitExpr& dst_term = dst_iter_terms[i];

    if (!analyzer->CanProve(
            arith::NormalizeIterMapToExpr(src_term->source->source == dst_term->source->source))) {
      return static_cast<const std::stringstream&>(
                 std::stringstream()
                 << "Term " << i << " had different source, src_term->source = " << src_term->source
                 << ", dst_term->source = " << dst_term->source)
          .str();
    }
    if (!analyzer->CanProve(src_term->lower_factor == dst_term->lower_factor)) {
      return static_cast<const std::stringstream&>(
                 std::stringstream()
                 << "Term " << i << " had different lower_factor, src_term->lower_factor = "
                 << src_term->lower_factor
                 << ", dst_term->lower_factor = " << dst_term->lower_factor)
          .str();
    }
    if (!analyzer->CanProve(src_term->extent == dst_term->extent)) {
      return static_cast<const std::stringstream&>(
                 std::stringstream()
                 << "Term " << i << " had different extent, src_term->extent = " << src_term->extent
                 << ", dst_term->extent = " << dst_term->extent)
          .str();
    }
    if (!analyzer->CanProve(src_term->scale == dst_term->scale)) {
      return static_cast<const std::stringstream&>(
                 std::stringstream()
                 << "Term " << i << " had different scale, src_term->scale = " << src_term->scale
                 << ", dst_term->scale = " << dst_term->scale)
          .str();
    }
  }

  // All checks passed: report the regions of the source and
  // destination buffers touched by the whole loop nest.
  BufferRegion src_region(load->buffer, arith::DomainTouched(loop, load->buffer, true, true));
  BufferRegion dst_region(store->buffer, arith::DomainTouched(loop, store->buffer, true, true));

  return MemCpyDetails{src_region, dst_region};
}