bool HasReshapePattern()

in src/relax/analysis/tir_op_pattern_kind.cc [353:538]


// Determine whether `func` matches the reshape pattern recognized by the local
// ReshapeDetector defined below.
//
// The buffers bound to `func`'s parameters are collected in parameter order; the first one
// is treated as the source buffer and the last one as the destination buffer. The function
// returns false when fewer than two parameters have an entry in `func->buffer_map`.
// Otherwise it returns whether the detector found a block whose body is a store into the
// destination buffer of a value loaded from the source buffer, and for which one of the two
// flattened-index checks in VisitStmt_(const BlockNode*) succeeded.
//
// Precondition: `func->body` is a BlockRealize. This is enforced with ICHECK rather than
// being reported as `false`.
bool HasReshapePattern(const PrimFunc& func) {
  // Visitor that walks a nest of For / BlockRealize / Block statements and records in
  // `is_reshape_` whether a block matching the reshape pattern was found.
  class ReshapeDetector : public StmtVisitor {
   public:
    // Visit `stmt` with a fresh detector for the given source/destination buffers and
    // return the resulting `is_reshape_` flag.
    static bool Detect(const Buffer& src_buffer, const Buffer& dst_buffer, Stmt stmt) {
      ReshapeDetector detector(src_buffer, dst_buffer);
      detector(stmt);
      return detector.is_reshape_;
    }

   private:
    // The detector stores references to both buffers (see the members below), so the
    // arguments must stay alive for as long as the detector is used.
    explicit ReshapeDetector(const Buffer& src_buffer, const Buffer& dst_buffer)
        : is_reshape_(false), src_buffer_(src_buffer), dst_buffer_(dst_buffer) {}

    // Record the loop variable's range in the analyzer, then descend into the loop body
    // only when it has one of the accepted shapes. Any other body ends the traversal of
    // this loop and leaves `is_reshape_` untouched.
    void VisitStmt_(const ForNode* loop) final {
      ana_.Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
      // To detect the reshape pattern, we require each For to have
      // either another For or a BlockRealize as body.
      if (!(loop->body->IsInstance<ForNode>() || loop->body->IsInstance<BlockRealizeNode>())) {
        return;
      }
      this->VisitStmt(loop->body);
    }

    // Validate the block's iterators and, if they are all acceptable, visit the block.
    void VisitStmt_(const BlockRealizeNode* block_realize) final {
      // Every block iterator must come with exactly one binding value.
      // Note that the binding values are only used for this size check here; no
      // iterator-to-binding mapping is built in this visitor.
      const Block& block = block_realize->block;
      const Array<IterVar>& block_iter = block->iter_vars;
      const Array<PrimExpr>& iter_values = block_realize->iter_values;
      ICHECK_EQ(block_iter.size(), iter_values.size());
      int n_iter = block_iter.size();
      for (int i = 0; i < n_iter; ++i) {
        // To detect the reshape pattern, we require each block iter to be data-parallel.
        if (block_iter[i]->iter_type != tir::IterVarType::kDataPar) {
          return;
        }
      }

      // Recurse into the block.
      this->VisitStmt(block);
    }

    // Core of the detection. For a block whose body is a For, keep descending. For any
    // other block, check that the body is `dst[...] = src[...]` and then try two
    // independent checks on the flattened access indices; `is_reshape_` is set to true as
    // soon as either check succeeds and is never reset to false.
    void VisitStmt_(const BlockNode* block) final {
      // Step 0. If the block body is a ForNode, recurse into it.
      if (block->body->IsInstance<ForNode>()) {
        this->VisitStmt(block->body);
        return;
      }

      // Make the block iterators' domains known both to the analyzer and to the
      // `var_range` map that is handed to IterMapSimplify in `f_calc_flattened_idx`.
      Map<tir::Var, Range> var_range;
      for (const IterVar& v : block->iter_vars) {
        ana_.Bind(v->var, Range::FromMinExtent(v->dom->min, v->dom->extent));
        var_range.Set(v->var, Range::FromMinExtent(v->dom->min, v->dom->extent));
      }

      // Step 1. Get the load/store pattern of the block body.
      // To detect the reshape pattern, we require the block body to be a
      // BufferStore, which has a BufferLoad as value.
      const auto* buffer_store = block->body.as<BufferStoreNode>();
      if (buffer_store == nullptr) {
        return;
      }
      const auto* buffer_load = buffer_store->value.as<BufferLoadNode>();
      if (buffer_load == nullptr) {
        return;
      }
      // Further, we require the buffer being stored and being loaded to
      // match the parameter of the PrimFunc, namely `dst_buffer_` and `src_buffer_`.
      if (!(buffer_store->buffer.same_as(dst_buffer_) &&
            buffer_load->buffer.same_as(src_buffer_))) {
        return;
      }

      // Apply check 1: use iter_map_simplify
      // This check requires at least one of the src/dst side is a trivial buffer
      // access (e.g., buf[ax0, ax1, ax2]).

      // Compute the row-major flattened offset of `indices` into `buffer`, i.e.
      // ((i0 * s1 + i1) * s2 + i2) ..., simplify it with the analyzer, and then pass it
      // through IterMapSimplify over the block iterators' ranges. Requires exactly one
      // index per buffer dimension (ICHECK).
      auto f_calc_flattened_idx = [&](const Buffer& buffer, const Array<PrimExpr>& indices) {
        ICHECK_EQ(indices.size(), buffer->shape.size());
        int ndim = indices.size();
        PrimExpr idx = 0;
        for (int i = 0; i < ndim; ++i) {
          idx = idx * buffer->shape[i] + indices[i];
        }
        idx = ana_.Simplify(idx);
        return arith::IterMapSimplify(
            /*indices=*/{idx},
            /*input_iters=*/var_range,
            /*input_pred=*/Bool(true),
            /*check_level=*/arith::IterMapLevel::Surjective,
            /*analyzer=*/&ana_,
            /*simplify_trivial_iterators=*/true)[0];
      };

      // Return true iff the access has the form buf[v0, v1, ...]: there is one index per
      // block iterator, the i-th index is the i-th block iterator variable itself, that
      // iterator's domain provably starts at 0, and its extent provably equals the i-th
      // dimension of `buffer`.
      auto f_is_trivial_indices = [block, this](const Buffer& buffer,
                                                const Array<PrimExpr>& indices) {
        if (indices.size() != block->iter_vars.size()) {
          return false;
        }
        for (int i = 0; i < static_cast<int>(block->iter_vars.size()); ++i) {
          if (!(indices[i].same_as(block->iter_vars[i]->var) &&
                this->ana_.CanProveEqual(block->iter_vars[i]->dom->min,
                                         IntImm(DataType::Int(64), /*value=*/0)) &&
                this->ana_.CanProveEqual(buffer->shape[i], block->iter_vars[i]->dom->extent))) {
            return false;
          }
        }
        return true;
      };

      // Select the side that still needs to be analyzed: if the store is trivial, analyze
      // the load; otherwise, if the load is trivial, analyze the store. When neither side
      // is trivial, `nontrivial_indices` stays undefined and check 1 is skipped.
      Array<PrimExpr> nontrivial_indices{nullptr};
      Buffer nontrivial_buffer{nullptr};
      if (f_is_trivial_indices(dst_buffer_, buffer_store->indices)) {
        nontrivial_indices = buffer_load->indices;
        nontrivial_buffer = src_buffer_;
      } else if (f_is_trivial_indices(src_buffer_, buffer_load->indices)) {
        nontrivial_indices = buffer_store->indices;
        nontrivial_buffer = dst_buffer_;
      }

      if (nontrivial_indices.defined()) {
        // The fused variable and its constants use the dtype of the first block iterator;
        // int64 is the fallback for a block without iterators.
        DataType dtype =
            !block->iter_vars.empty() ? block->iter_vars[0]->var->dtype : DataType::Int(64);
        tir::Var fused_var("fused", dtype);
        // Rewrite every block iterator in terms of a single fused variable, with the last
        // iterator varying fastest:
        //   v_i = floormod(floordiv(fused, stride_i), extent_i)
        // where stride_i is the product of the extents of all iterators after i.
        Map<tir::Var, PrimExpr> inverse_indices_map;
        PrimExpr stride = IntImm(dtype, /*value=*/1);
        for (int i = static_cast<int>(block->iter_vars.size()) - 1; i >= 0; --i) {
          inverse_indices_map.Set(
              block->iter_vars[i]->var,
              floormod(floordiv(fused_var, stride), block->iter_vars[i]->dom->extent));
          stride *= block->iter_vars[i]->dom->extent;
        }
        // At this point `stride` is the product of all block iterator extents.
        PrimExpr flattened_idx = f_calc_flattened_idx(nontrivial_buffer, nontrivial_indices);
        flattened_idx = Substitute(std::move(flattened_idx), inverse_indices_map);

        // Simplify the substituted offset with the fused variable as the only iterator.
        Array<PrimExpr> simplify_res = arith::IterMapSimplify(
            /*indices=*/{flattened_idx},
            /*input_iters=*/{{fused_var, Range(IntImm(dtype, /*value=*/0), stride)}},
            /*input_pred=*/Bool(true),
            /*check_level=*/arith::IterMapLevel::Surjective,
            /*analyzer=*/&this->ana_,
            /*simplify_trivial_iterators=*/true);
        ICHECK_EQ(simplify_res.size(), 1);

        // The pattern is accepted when the nontrivial side's flattened offset reduces to
        // exactly the fused variable.
        if (simplify_res[0].same_as(fused_var)) {
          this->is_reshape_ = true;
          return;
        }
      }

      // Apply check 2 as followup when check 1 is not satisfied.
      // Calculate the flattened access index according to the load/store pattern.
      PrimExpr src_idx = f_calc_flattened_idx(src_buffer_, buffer_load->indices);
      PrimExpr dst_idx = f_calc_flattened_idx(dst_buffer_, buffer_store->indices);
      // Check if we can prove the equality of flattened indices.
      if (ana_.CanProveEqual(src_idx, dst_idx)) {
        this->is_reshape_ = true;
        return;
      }
    }

    // Result flag: starts as false and is set to true once a matching block is found.
    bool is_reshape_;
    // Source and destination buffers, held by reference (not copied); they refer to the
    // arguments passed to the constructor.
    const Buffer& src_buffer_;
    const Buffer& dst_buffer_;
    // Analyzer shared by all visits; it accumulates the ranges bound for loop variables
    // and block iterators and is used for every simplification and equality proof above.
    arith::Analyzer ana_;
  };

  // Collect, in parameter order, the buffers of all parameters that have an entry in the
  // buffer map; parameters without one are skipped.
  Array<Buffer> buffer_args;
  for (const auto& param : func->params) {
    if (auto buffer = func->buffer_map.Get(param)) {
      buffer_args.push_back(buffer.value());
    }
  }

  // A reshape needs at least a source and a destination buffer.
  if (buffer_args.size() < 2) {
    return false;
  }
  // The first buffer argument is taken as the source and the last one as the destination;
  // any buffer arguments in between are ignored.
  Buffer src_buffer = buffer_args.front();
  Buffer dst_buffer = buffer_args.back();

  // The detector is started on the function body, which is required to be a BlockRealize.
  // NOTE(review): a PrimFunc whose body is not a BlockRealize fails this ICHECK instead of
  // yielding `false` -- confirm that this is the intended behavior for all callers.
  ICHECK(func->body->IsInstance<BlockRealizeNode>());
  return ReshapeDetector::Detect(src_buffer, dst_buffer, func->body);
}