in src/relax/analysis/tir_op_pattern_kind.cc [353:538]
/*!
 * \brief Check whether a PrimFunc copies its first buffer parameter into its last
 *  buffer parameter so that the row-major flattened offsets of every loaded and
 *  stored element coincide (i.e. the function is a reshape).
 *
 * The detector walks a chain of For / BlockRealize nodes starting at the root
 * BlockRealize and inspects the innermost data-parallel block, whose body must be
 * exactly `dst[...] = src[...]`.
 *
 * \param func The PrimFunc to analyze.
 * \return True if the reshape pattern is detected. False if the function has fewer
 *  than two buffer parameters, if its body is not a BlockRealize, or if the access
 *  pattern cannot be proven to be a reshape.
 */
bool HasReshapePattern(const PrimFunc& func) {
  class ReshapeDetector : public StmtVisitor {
   public:
    /*! \brief Run the detector on `stmt` and report whether a reshape was found. */
    static bool Detect(const Buffer& src_buffer, const Buffer& dst_buffer, Stmt stmt) {
      ReshapeDetector detector(src_buffer, dst_buffer);
      detector(stmt);
      return detector.is_reshape_;
    }

   private:
    explicit ReshapeDetector(const Buffer& src_buffer, const Buffer& dst_buffer)
        : is_reshape_(false), src_buffer_(src_buffer), dst_buffer_(dst_buffer) {}

    void VisitStmt_(const ForNode* loop) final {
      // Record the loop variable's range so later simplification and equality
      // proofs know its bounds.
      ana_.Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
      // To detect the reshape pattern, we require each For to have
      // either another For or a BlockRealize as body.
      if (!(loop->body->IsInstance<ForNode>() || loop->body->IsInstance<BlockRealizeNode>())) {
        return;
      }
      this->VisitStmt(loop->body);
    }

    void VisitStmt_(const BlockRealizeNode* block_realize) final {
      // Only blocks whose iterators are all data-parallel are analyzed. Any other
      // block, together with everything nested inside it, is skipped.
      const Block& block = block_realize->block;
      const Array<IterVar>& block_iter = block->iter_vars;
      const Array<PrimExpr>& iter_values = block_realize->iter_values;
      ICHECK_EQ(block_iter.size(), iter_values.size());
      int n_iter = block_iter.size();
      for (int i = 0; i < n_iter; ++i) {
        if (block_iter[i]->iter_type != tir::IterVarType::kDataPar) {
          return;
        }
      }
      // Recurse into the block.
      this->VisitStmt(block);
    }

    void VisitStmt_(const BlockNode* block) final {
      // Step 0. If the block body is a ForNode, recurse into it. The iterators of
      // this (outer) block are not bound in that case.
      if (block->body->IsInstance<ForNode>()) {
        this->VisitStmt(block->body);
        return;
      }
      Map<tir::Var, Range> var_range;
      for (const IterVar& v : block->iter_vars) {
        ana_.Bind(v->var, Range::FromMinExtent(v->dom->min, v->dom->extent));
        var_range.Set(v->var, Range::FromMinExtent(v->dom->min, v->dom->extent));
      }
      // Step 1. Get the load/store pattern of the block body.
      // To detect the reshape pattern, we require the block body to be a
      // BufferStore, which has a BufferLoad as value.
      const auto* buffer_store = block->body.as<BufferStoreNode>();
      if (buffer_store == nullptr) {
        return;
      }
      const auto* buffer_load = buffer_store->value.as<BufferLoadNode>();
      if (buffer_load == nullptr) {
        return;
      }
      // Further, we require the buffer being stored and being loaded to
      // match the parameter of the PrimFunc, namely `dst_buffer_` and `src_buffer_`.
      if (!(buffer_store->buffer.same_as(dst_buffer_) &&
            buffer_load->buffer.same_as(src_buffer_))) {
        return;
      }
      // NOTE(review): neither check below compares the total element counts of
      // `src_buffer_` and `dst_buffer_`. For example, `dst[i] = src[i]` with a
      // smaller `dst` also passes check 2. Confirm that callers validate sizes.

      // Row-major flattened offset of `indices` into `buffer`, simplified with the
      // bound ranges and then normalized with IterMapSimplify over the block iterators.
      auto f_calc_flattened_idx = [&](const Buffer& buffer, const Array<PrimExpr>& indices) {
        ICHECK_EQ(indices.size(), buffer->shape.size());
        int ndim = indices.size();
        PrimExpr idx = 0;
        for (int i = 0; i < ndim; ++i) {
          idx = idx * buffer->shape[i] + indices[i];
        }
        idx = ana_.Simplify(idx);
        return arith::IterMapSimplify(
            /*indices=*/{idx},
            /*input_iters=*/var_range,
            /*input_pred=*/Bool(true),
            /*check_level=*/arith::IterMapLevel::Surjective,
            /*analyzer=*/&ana_,
            /*simplify_trivial_iterators=*/true)[0];
      };
      // An access is "trivial" when its indices are exactly the block iterators in
      // order, each starting at 0 with an extent equal to the matching buffer dimension
      // (e.g. buf[ax0, ax1, ax2]).
      auto f_is_trivial_indices = [block, this](const Buffer& buffer,
                                                const Array<PrimExpr>& indices) {
        if (indices.size() != block->iter_vars.size()) {
          return false;
        }
        for (int i = 0; i < static_cast<int>(block->iter_vars.size()); ++i) {
          if (!(indices[i].same_as(block->iter_vars[i]->var) &&
                this->ana_.CanProveEqual(block->iter_vars[i]->dom->min,
                                         IntImm(DataType::Int(64), /*value=*/0)) &&
                this->ana_.CanProveEqual(buffer->shape[i], block->iter_vars[i]->dom->extent))) {
            return false;
          }
        }
        return true;
      };

      // Apply check 1: use iter_map_simplify.
      // This check requires at least one of the src/dst sides to be a trivial buffer
      // access. The trivial side's flattened offset is then the row-major fusion of
      // the block iterators. Rewrite every iterator in terms of a single `fused`
      // variable and test whether the other side's flattened offset simplifies back
      // to `fused`.
      Array<PrimExpr> nontrivial_indices{nullptr};
      Buffer nontrivial_buffer{nullptr};
      if (f_is_trivial_indices(dst_buffer_, buffer_store->indices)) {
        nontrivial_indices = buffer_load->indices;
        nontrivial_buffer = src_buffer_;
      } else if (f_is_trivial_indices(src_buffer_, buffer_load->indices)) {
        nontrivial_indices = buffer_store->indices;
        nontrivial_buffer = dst_buffer_;
      }
      if (nontrivial_indices.defined()) {
        DataType dtype =
            !block->iter_vars.empty() ? block->iter_vars[0]->var->dtype : DataType::Int(64);
        tir::Var fused_var("fused", dtype);
        // iter_i = (fused / prod(extents after i)) % extent_i
        Map<tir::Var, PrimExpr> inverse_indices_map;
        PrimExpr stride = IntImm(dtype, /*value=*/1);
        for (int i = static_cast<int>(block->iter_vars.size()) - 1; i >= 0; --i) {
          inverse_indices_map.Set(
              block->iter_vars[i]->var,
              floormod(floordiv(fused_var, stride), block->iter_vars[i]->dom->extent));
          stride *= block->iter_vars[i]->dom->extent;
        }
        PrimExpr flattened_idx = f_calc_flattened_idx(nontrivial_buffer, nontrivial_indices);
        flattened_idx = Substitute(std::move(flattened_idx), inverse_indices_map);
        // After the loop, `stride` is the product of all iterator extents, i.e. the
        // range of `fused`.
        Array<PrimExpr> simplify_res = arith::IterMapSimplify(
            /*indices=*/{flattened_idx},
            /*input_iters=*/{{fused_var, Range(IntImm(dtype, /*value=*/0), stride)}},
            /*input_pred=*/Bool(true),
            /*check_level=*/arith::IterMapLevel::Surjective,
            /*analyzer=*/&this->ana_,
            /*simplify_trivial_iterators=*/true);
        ICHECK_EQ(simplify_res.size(), 1);
        if (simplify_res[0].same_as(fused_var)) {
          this->is_reshape_ = true;
          return;
        }
      }

      // Apply check 2 as a follow-up when check 1 is not satisfied.
      // Calculate the flattened access index according to the load/store pattern.
      PrimExpr src_idx = f_calc_flattened_idx(src_buffer_, buffer_load->indices);
      PrimExpr dst_idx = f_calc_flattened_idx(dst_buffer_, buffer_store->indices);
      // Check if we can prove the equality of flattened indices.
      if (ana_.CanProveEqual(src_idx, dst_idx)) {
        this->is_reshape_ = true;
        return;
      }
    }

    /*! \brief Set once any visited block is proven to have the reshape pattern. */
    bool is_reshape_;
    // Held by value: Buffer is a reference-counted handle, so copying is cheap and
    // avoids tying the detector's lifetime to the caller's objects.
    Buffer src_buffer_;
    Buffer dst_buffer_;
    /*! \brief Analyzer carrying the bound ranges of loop and block iterator variables. */
    arith::Analyzer ana_;
  };

  // Collect the buffer parameters in parameter order; the first is treated as the
  // reshape source and the last as the reshape destination.
  Array<Buffer> buffer_args;
  for (const auto& param : func->params) {
    if (auto buffer = func->buffer_map.Get(param)) {
      buffer_args.push_back(buffer.value());
    }
  }
  if (buffer_args.size() < 2) {
    return false;
  }
  Buffer src_buffer = buffer_args.front();
  Buffer dst_buffer = buffer_args.back();

  // The detector expects the function body to be the root BlockRealize. For any
  // other body, conservatively report that no reshape pattern was found instead of
  // aborting, since this is a predicate.
  if (!func->body->IsInstance<BlockRealizeNode>()) {
    return false;
  }
  return ReshapeDetector::Detect(src_buffer, dst_buffer, func->body);
}