src/tir/ir/stmt.cc (575 lines of code) (raw):
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/tir/stmt.cc
*/
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt.h>

#include <algorithm>
#include <limits>

#include "buffer_common.h"
#include "utils.h"
namespace tvm {
namespace tir {
// LetStmt
/*!
 * \brief Construct a LetStmt binding `var` to `value` inside `body`.
 *
 * A var whose type annotation is a pointer type may be bound to a handle-typed
 * value; any other var must have exactly the value's dtype.
 */
LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) {
  ICHECK(value.defined());
  ICHECK(body.defined());
  const auto vdtype = value.dtype();
  const bool binds_pointer = var->type_annotation.as<PointerTypeNode>() != nullptr;
  if (binds_pointer) {
    ICHECK(vdtype.is_handle());
  } else {
    ICHECK_EQ(value.dtype(), var.dtype());
  }
  auto n = make_object<LetStmtNode>();
  n->var = std::move(var);
  n->value = std::move(value);
  n->body = std::move(body);
  n->span = std::move(span);
  data_ = std::move(n);
}
TVM_REGISTER_GLOBAL("tir.LetStmt")
    .set_body_typed([](Var v, PrimExpr bound_value, Stmt let_body, Span sp) {
      return LetStmt(std::move(v), std::move(bound_value), std::move(let_body), std::move(sp));
    });
TVM_REGISTER_NODE_TYPE(LetStmtNode);
// AttrStmt
/*!
 * \brief Construct an AttrStmt attaching attribute `attr_key` = `value` on `node`,
 *        scoped over `body`. No validation is performed.
 */
AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) {
  ObjectPtr<AttrStmtNode> attr = make_object<AttrStmtNode>();
  attr->node = std::move(node);
  attr->attr_key = std::move(attr_key);
  attr->value = std::move(value);
  attr->body = std::move(body);
  attr->span = std::move(span);
  data_ = std::move(attr);
}
TVM_REGISTER_GLOBAL("tir.AttrStmt")
    .set_body_typed([](Any node, String attr_key, PrimExpr value, Stmt body, Span span) {
      // POD arguments (int, bool, ...) are not objects; convert them to a PrimExpr first.
      const bool is_pod = node.type_index() < ffi::TypeIndex::kTVMFFIStaticObjectBegin;
      ObjectRef target =
          is_pod ? ObjectRef(node.as<PrimExpr>().value()) : node.as<ObjectRef>().value();
      return AttrStmt(target, attr_key, value, body, span);
    });
TVM_REGISTER_NODE_TYPE(AttrStmtNode);
// AssertStmt
/*!
 * \brief Construct an AssertStmt.
 * \param condition Condition to assert; must be defined and have a boolean dtype.
 * \param message Failure message; must be defined and either an int32 expression or a StringImm.
 * \param body The body statement.
 * \param span Source location.
 */
AssertStmt::AssertStmt(PrimExpr condition, PrimExpr message, Stmt body, Span span) {
  ICHECK(condition.defined());
  CHECK(condition.dtype().is_bool())
      << "AssertStmt should have boolean condition, "
      << "but received " << condition << " with dtype " << condition.dtype();
  // Without this check, message.dtype() below would dereference a null node and crash
  // instead of reporting a checked error.
  ICHECK(message.defined()) << "TypeError: AssertStmt message must be defined";
  ICHECK(message.dtype() == DataType::Int(32) || message.as<StringImmNode>())
      << "TypeError: AssertStmt message must be an int or string:" << message << "\n";
  ObjectPtr<AssertStmtNode> node = make_object<AssertStmtNode>();
  node->condition = std::move(condition);
  node->message = std::move(message);
  node->body = std::move(body);
  node->span = std::move(span);
  data_ = std::move(node);
}
TVM_REGISTER_NODE_TYPE(AssertStmtNode);
TVM_REGISTER_GLOBAL("tir.AssertStmt")
    .set_body_typed([](PrimExpr condition, StringImm message, Stmt body, Span span) {
      return AssertStmt(condition, message, body, span);
    });
// For
/*!
 * \brief Construct a For loop node.
 *
 * loop_var, min and extent must all be scalar integers (signed or unsigned).
 * An IntImm min/extent narrower than loop_var is widened to loop_var's dtype;
 * after that, all three dtypes must be identical.
 */
For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
         Optional<IterVar> thread_binding, Map<String, Any> annotations, Span span) {
  ICHECK(loop_var.defined());
  ICHECK(min.defined());
  ICHECK(extent.defined());
  ICHECK(body.defined());

  auto require_scalar_int_dtype = [](const PrimExpr& checked, const char* field_name) {
    auto dtype = checked.dtype();
    CHECK(dtype.is_scalar() && (dtype.is_int() || dtype.is_uint()))
        << "TIR For nodes require a scalar integer as the " << field_name << ", but received "
        << checked << " with dtype " << dtype;
  };
  require_scalar_int_dtype(loop_var, "loop_var");
  require_scalar_int_dtype(min, "min");
  require_scalar_int_dtype(extent, "extent");

  // Narrow integer immediates are silently widened to loop_var's dtype. Narrow
  // non-immediates are returned unchanged and rejected by the equality checks below.
  auto try_promote_imm_dtype = [&loop_var](const PrimExpr& e) -> PrimExpr {
    ICHECK(e.dtype().bits() <= loop_var.dtype().bits())
        << " Loop variable's dtype (" << loop_var.dtype()
        << ") is narrower than that of `min` or `extent` (" << e.dtype() << ")";
    const IntImmNode* a = e.as<IntImmNode>();
    if (a == nullptr || e.dtype().bits() >= loop_var.dtype().bits()) {
      return e;
    }
    return make_const(loop_var.dtype(), a->value);
  };
  min = try_promote_imm_dtype(min);
  extent = try_promote_imm_dtype(extent);
  ICHECK(loop_var.dtype() == min.dtype()) << loop_var.dtype() << " vs " << min.dtype();
  ICHECK(loop_var.dtype() == extent.dtype()) << loop_var.dtype() << " vs " << extent.dtype();

  auto loop = make_object<ForNode>();
  loop->loop_var = std::move(loop_var);
  loop->min = std::move(min);
  loop->extent = std::move(extent);
  loop->kind = kind;
  loop->body = std::move(body);
  loop->thread_binding = std::move(thread_binding);
  loop->annotations = std::move(annotations);
  loop->span = std::move(span);
  data_ = std::move(loop);
}
TVM_REGISTER_GLOBAL("tir.For").set_body_typed(
    [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, Stmt body,
       Optional<IterVar> thread_binding, Optional<Map<String, Any>> annotations, Span span) {
      // A missing annotation map from the frontend means "no annotations".
      Map<String, Any> attrs = annotations.value_or(Map<String, Any>());
      return For(loop_var, min, extent, static_cast<ForKind>(kind), body, thread_binding, attrs,
                 span);
    });
TVM_REGISTER_NODE_TYPE(ForNode);
/*!
 * \brief Print the textual keyword for a ForKind.
 *
 * Values outside the enumerated kinds print nothing.
 */
std::ostream& operator<<(std::ostream& out, ForKind type) {  // NOLINT(*)
  const char* keyword = nullptr;
  switch (type) {
    case ForKind::kSerial:
      keyword = "for";
      break;
    case ForKind::kParallel:
      keyword = "parallel";
      break;
    case ForKind::kUnrolled:
      keyword = "unrolled";
      break;
    case ForKind::kVectorized:
      keyword = "vectorized";
      break;
    case ForKind::kThreadBinding:
      keyword = "launch_thread";
      break;
  }
  if (keyword != nullptr) {
    out << keyword;
  }
  return out;
}
// While
/*!
 * \brief Construct a While loop node.
 *
 * The condition must be a defined scalar and must not be an IntImm constant.
 */
While::While(PrimExpr condition, Stmt body, Span span) {
  ICHECK(condition.defined());
  ICHECK(condition.dtype().is_scalar());
  ICHECK(condition.as<tir::IntImmNode>() == nullptr) << "The condition should not be trivial.";
  ICHECK(body.defined());
  auto loop = make_object<WhileNode>();
  loop->condition = std::move(condition);
  loop->body = std::move(body);
  loop->span = std::move(span);
  data_ = std::move(loop);
}
TVM_REGISTER_GLOBAL("tir.While").set_body_typed([](PrimExpr cond, Stmt loop_body, Span sp) {
  return While(std::move(cond), std::move(loop_body), std::move(sp));
});
TVM_REGISTER_NODE_TYPE(WhileNode);
// ProducerStore
/*!
 * \brief Construct a ProducerStore of `value` into `producer` at `indices`.
 *        Fields are stored as given; no validation is performed.
 */
ProducerStore::ProducerStore(DataProducer producer, PrimExpr value, Array<PrimExpr> indices,
                             Span span) {
  auto store = make_object<ProducerStoreNode>();
  store->producer = std::move(producer);
  store->value = std::move(value);
  store->indices = std::move(indices);
  store->span = std::move(span);
  data_ = std::move(store);
}
TVM_REGISTER_GLOBAL("tir.ProducerStore")
    .set_body_typed([](DataProducer target, PrimExpr stored, Array<PrimExpr> idx, Span sp) {
      return ProducerStore(std::move(target), std::move(stored), std::move(idx), std::move(sp));
    });
TVM_REGISTER_NODE_TYPE(ProducerStoreNode);
// Allocate
/*!
 * \brief Construct an Allocate node.
 *
 * `buffer_var` must be annotated as a pointer to `dtype`. A bool allocation may
 * also be backed by an int8 pointer. Every extent must be a defined scalar, and
 * `condition` must be a defined boolean expression.
 */
Allocate::Allocate(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition,
                   Stmt body, Map<String, Any> annotations, Span span) {
  CHECK(IsPointerType(buffer_var->type_annotation, dtype) ||
        (dtype.is_bool() && IsPointerType(buffer_var->type_annotation, DataType::Int(8))))
      << "The allocated data type (" << dtype
      << ") does not match the type annotation of the buffer " << buffer_var << " ("
      << buffer_var->type_annotation
      << "). The data type should be an element of the pointer type.";
  for (size_t i = 0, num_dims = extents.size(); i < num_dims; ++i) {
    ICHECK(extents[i].defined());
    ICHECK(extents[i].dtype().is_scalar());
  }
  ICHECK(body.defined());
  ICHECK(condition.defined());
  ICHECK(condition.dtype().is_bool());

  auto alloc = make_object<AllocateNode>();
  alloc->buffer_var = std::move(buffer_var);
  alloc->dtype = dtype;
  alloc->extents = std::move(extents);
  alloc->condition = std::move(condition);
  alloc->body = std::move(body);
  alloc->annotations = std::move(annotations);
  alloc->span = std::move(span);
  data_ = std::move(alloc);
}
int64_t AllocateNode::ConstantAllocationSize(const Array<PrimExpr>& extents) {
int64_t result = 1;
for (size_t i = 0; i < extents.size(); ++i) {
if (const IntImmNode* int_size = extents[i].as<IntImmNode>()) {
result *= int_size->value;
if (result > std::numeric_limits<int64_t>::max()) {
return 0;
}
} else {
return 0;
}
}
return static_cast<int64_t>(result);
}
// FFI entry point for constructing an Allocate node.
TVM_REGISTER_GLOBAL("tir.Allocate")
    .set_body_typed([](Var var, DataType elem_type, Array<PrimExpr> shape, PrimExpr cond,
                       Stmt alloc_body, Map<String, Any> attrs, Span sp) {
      return Allocate(std::move(var), elem_type, std::move(shape), std::move(cond),
                      std::move(alloc_body), std::move(attrs), std::move(sp));
    });
TVM_REGISTER_NODE_TYPE(AllocateNode);
// AllocateConst
// Constructor for an IR node holding constant data. Depending on the runtime
// type of `data_or_idx`, the node stores either an NDArray (`data`) or an
// integer index given as an IntImm (`irmod_storage_idx`).
/*!
 * \brief Construct an AllocateConst node.
 * \param buffer_var Var annotated as a pointer to `dtype`.
 * \param dtype Element data type.
 * \param extents Per-dimension extents; each must be a defined scalar.
 * \param data_or_idx Either an NDArray (stored in `data`) or an IntImm (stored in
 *        `irmod_storage_idx`); any other type is a fatal error.
 * \param body The body statement; must be defined.
 * \param annotations Additional annotations.
 * \param span Source location.
 */
AllocateConst::AllocateConst(Var buffer_var, DataType dtype, Array<PrimExpr> extents,
                             ObjectRef data_or_idx, Stmt body, Map<String, Any> annotations,
                             Span span) {
  ICHECK(IsPointerType(buffer_var->type_annotation, dtype))
      << "The allocated data type (" << dtype
      << ") does not match the type annotation of the buffer " << buffer_var << " ("
      << buffer_var->type_annotation
      << "). The data type should be an element of the pointer type.";
  for (size_t i = 0; i < extents.size(); ++i) {
    ICHECK(extents[i].defined());
    ICHECK(extents[i].dtype().is_scalar());
  }
  ICHECK(body.defined());
  ICHECK(data_or_idx.defined());
  ObjectPtr<AllocateConstNode> node = make_object<AllocateConstNode>();
  node->buffer_var = std::move(buffer_var);
  node->dtype = dtype;
  node->extents = std::move(extents);
  node->body = std::move(body);
  // Moved like every other by-value parameter, avoiding a redundant refcount round-trip.
  node->annotations = std::move(annotations);
  node->span = std::move(span);
  if (data_or_idx->IsInstance<runtime::NDArray::ContainerType>()) {
    node->data = Optional<tvm::runtime::NDArray>(Downcast<runtime::NDArray>(data_or_idx));
    node->irmod_storage_idx = Optional<Integer>();
  } else if (data_or_idx->IsInstance<IntImmNode>()) {
    node->data = Optional<tvm::runtime::NDArray>();
    node->irmod_storage_idx = Optional<Integer>(Downcast<Integer>(data_or_idx));
  } else {
    LOG(FATAL) << "Data type not supported: " << data_or_idx->GetTypeKey();
  }
  data_ = std::move(node);
}
int64_t AllocateConstNode::ConstantAllocationSize(const Array<PrimExpr>& extents) {
int64_t result = 1;
for (size_t i = 0; i < extents.size(); ++i) {
if (const IntImmNode* int_size = extents[i].as<IntImmNode>()) {
result *= int_size->value;
if (result > std::numeric_limits<int64_t>::max()) {
return 0;
}
} else {
return 0;
}
}
return static_cast<int64_t>(result);
}
// FFI entry point for AllocateConst; a missing annotation map becomes an empty map.
TVM_REGISTER_GLOBAL("tir.AllocateConst")
    .set_body_typed([](Var var, DataType elem_type, Array<PrimExpr> shape, ObjectRef payload,
                       Stmt alloc_body, Optional<Map<String, Any>> attrs, Span sp) {
      Map<String, Any> resolved_attrs = attrs.value_or({});
      return AllocateConst(var, elem_type, shape, payload, alloc_body, resolved_attrs, sp);
    });
TVM_REGISTER_NODE_TYPE(AllocateConstNode);
// DeclBuffer
/*!
 * \brief Construct a DeclBuffer declaring `buffer` over `body`.
 *        Fields are stored as given; no validation is performed.
 */
DeclBuffer::DeclBuffer(Buffer buffer, Stmt body, Span span) {
  auto decl = make_object<DeclBufferNode>();
  decl->buffer = std::move(buffer);
  decl->body = std::move(body);
  decl->span = std::move(span);
  data_ = std::move(decl);
}
TVM_REGISTER_GLOBAL("tir.DeclBuffer").set_body_typed([](Buffer buf, Stmt decl_body, Span sp) {
  return DeclBuffer(std::move(buf), std::move(decl_body), std::move(sp));
});
TVM_REGISTER_NODE_TYPE(DeclBufferNode);
// ProducerRealize
/*!
 * \brief Construct a ProducerRealize node.
 *
 * Every bound must have a defined scalar min and extent. `body` must be defined,
 * and `condition` must be a defined boolean expression.
 */
ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition,
                                 Stmt body, String storage_scope, Span span) {
  for (size_t i = 0, num_bounds = bounds.size(); i < num_bounds; ++i) {
    ICHECK(bounds[i]->min.defined());
    ICHECK(bounds[i]->extent.defined());
    ICHECK(bounds[i]->min.dtype().is_scalar());
    ICHECK(bounds[i]->extent.dtype().is_scalar());
  }
  ICHECK(body.defined());
  ICHECK(condition.defined());
  ICHECK(condition.dtype().is_bool());

  auto realize = make_object<ProducerRealizeNode>();
  realize->producer = std::move(producer);
  realize->bounds = std::move(bounds);
  realize->condition = std::move(condition);
  realize->body = std::move(body);
  realize->span = std::move(span);
  realize->storage_scope = std::move(storage_scope);
  data_ = std::move(realize);
}
TVM_REGISTER_GLOBAL("tir.ProducerRealize")
    .set_body_typed([](DataProducer target, Region region, PrimExpr cond, Stmt realize_body,
                       String scope, Span sp) {
      return ProducerRealize(std::move(target), std::move(region), std::move(cond),
                             std::move(realize_body), std::move(scope), std::move(sp));
    });
TVM_REGISTER_NODE_TYPE(ProducerRealizeNode);
// Prefetch
/*!
 * \brief Construct a Prefetch of `buffer` over `bounds`. The node is built directly
 *        through PrefetchNode's constructor.
 */
Prefetch::Prefetch(Buffer buffer, Array<Range> bounds, Span span) {
  ObjectPtr<PrefetchNode> prefetch =
      make_object<PrefetchNode>(std::move(buffer), std::move(bounds), std::move(span));
  data_ = std::move(prefetch);
}
TVM_REGISTER_GLOBAL("tir.Prefetch")
    .set_body_typed([](Buffer buf, Array<Range> region, Span sp) {
      return Prefetch(std::move(buf), std::move(region), std::move(sp));
    });
TVM_REGISTER_NODE_TYPE(PrefetchNode);
// SeqStmt
/*!
 * \brief Construct a SeqStmt, flattening any directly nested SeqStmt children.
 * \param seq The child statements. After flattening, there must be at least two.
 * \param span Source location.
 */
SeqStmt::SeqStmt(Array<Stmt> seq, Span span) {
  bool requires_flattening = std::any_of(
      seq.begin(), seq.end(), [](const Stmt& stmt) { return stmt->IsInstance<SeqStmtNode>(); });
  if (requires_flattening) {
    auto flattened = SeqStmt::Flatten(seq);
    if (auto* ptr = flattened.as<SeqStmtNode>()) {
      seq = ptr->seq;
    } else {
      // Flatten collapsed everything to a single statement; the length checks below reject it.
      seq = {flattened};
    }
  }
  ICHECK_NE(seq.size(), 0) << "An empty SeqStmt is prohibited. "
                           << "To write a no-op, use Evaluate(0), "
                           << "or the result of SeqStmt::Flatten()";
  ICHECK_NE(seq.size(), 1) << "A SeqStmt of length 1 is prohibited. "
                           << "Use the node " << seq[0] << " directly, "
                           << "or for dynamic usage, normalize using SeqStmt::Flatten()";
  auto node = make_object<SeqStmtNode>();
  node->seq = std::move(seq);
  node->span = std::move(span);
  data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.SeqStmt").set_body_typed([](Array<Stmt> seq, Span span) {
  return SeqStmt(std::move(seq), span);
});
TVM_REGISTER_NODE_TYPE(SeqStmtNode);
// IfThenElse
/*!
 * \brief Construct an IfThenElse node.
 *
 * `condition` and `then_case` must be defined; `else_case` may be absent.
 */
IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Optional<Stmt> else_case, Span span) {
  ICHECK(condition.defined());
  ICHECK(then_case.defined());
  auto branch = make_object<IfThenElseNode>();
  branch->condition = std::move(condition);
  branch->then_case = std::move(then_case);
  branch->else_case = std::move(else_case);
  branch->span = std::move(span);
  data_ = std::move(branch);
}
TVM_REGISTER_NODE_TYPE(IfThenElseNode);
TVM_REGISTER_GLOBAL("tir.IfThenElse")
    .set_body_typed([](PrimExpr cond, Stmt then_stmt, Stmt else_stmt, Span sp) {
      return IfThenElse(cond, then_stmt, else_stmt, sp);
    });
// Evaluate
/*!
 * \brief Construct an Evaluate statement wrapping expression `value`, which must be defined.
 */
Evaluate::Evaluate(PrimExpr value, Span span) {
  ICHECK(value.defined());
  auto eval = make_object<EvaluateNode>();
  eval->value = std::move(value);
  eval->span = std::move(span);
  data_ = std::move(eval);
}
TVM_REGISTER_GLOBAL("tir.Evaluate").set_body_typed([](PrimExpr expr, Span sp) {
  return Evaluate(std::move(expr), std::move(sp));
});
TVM_REGISTER_NODE_TYPE(EvaluateNode);
// BufferStore
/*!
 * \brief Construct a BufferStore writing `value` into `buffer` at `indices`.
 * \param buffer Destination buffer; its rank must equal indices.size().
 * \param value Value to store. Its dtype must equal the buffer dtype widened by the
 *        lane count of the last index (as a scalable vector if either the last index
 *        or the buffer dtype is scalable).
 * \param indices One index per buffer dimension; only the last may be a vector.
 * \param predicate Optional store mask. When defined, it must have boolean elements,
 *        the same lane count as `value`, and the same scalability as `value`.
 * \param span Source location.
 *
 * The checks run in a fixed order; the first failing check determines the error reported.
 */
BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices,
Optional<PrimExpr> predicate, Span span) {
ICHECK_EQ(buffer->shape.size(), indices.size())
<< "Buffer " << buffer->name << " is " << buffer->shape.size()
<< "-dimensional, cannot be indexed with the " << indices.size()
<< "-dimensional indices provided.";
// All indices except the last must be scalar.
for (int i = 0; i < static_cast<int>(indices.size()) - 1; i++) {
ICHECK(indices[i].dtype().is_scalar())
<< "Only the last index of a buffer access may be a vector type.";
}
// Scalability rules: the last index and the buffer dtype cannot both be scalable,
// and a scalable index or buffer requires a scalable value.
bool is_index_scalable = indices.empty() ? false : indices.back().dtype().is_scalable_vector();
bool is_buffer_dtype_scalable = buffer->dtype.is_scalable_vector();
bool is_value_dtype_scalable = value.dtype().is_scalable_vector();
ICHECK(!(is_index_scalable && is_buffer_dtype_scalable))
<< "Index dtype and buffer dtype can't both be scalable.";
if (predicate.defined()) {
// NOTE: this requires equality (both scalable or both fixed-length), although the
// message only mentions the "both scalable" case.
bool is_predicate_dtype_scalable = predicate.value().dtype().is_scalable_vector();
ICHECK_EQ(is_value_dtype_scalable, is_predicate_dtype_scalable)
<< "Predicate mask dtype and value dtype must both be scalable.";
}
if (is_index_scalable || is_buffer_dtype_scalable) {
ICHECK(is_value_dtype_scalable) << "Can't store non-scalable data into scalable buffer";
}
// Lane accounting: value lanes == (last-index lanes) * (buffer element lanes).
// For scalable vectors, the vscale factor stands in for the lane count.
int index_lanes = indices.empty() ? 1 : indices.back().dtype().get_lanes_or_vscale_factor();
int buffer_lanes = buffer->dtype.get_lanes_or_vscale_factor();
int value_dtype_lanes = value.dtype().get_lanes_or_vscale_factor();
ICHECK_EQ(index_lanes * buffer_lanes, value_dtype_lanes)
<< "Cannot store value with " << value_dtype_lanes << ", expected value with "
<< index_lanes * buffer_lanes << " (" << index_lanes << " index lanes * " << buffer_lanes
<< " buffer element lanes)";
if (predicate.defined()) {
DataType predicate_dtype = predicate.value().dtype();
int predicate_dtype_lanes = predicate_dtype.get_lanes_or_vscale_factor();
ICHECK_EQ(value_dtype_lanes, predicate_dtype_lanes)
<< "Got a predicate mask with " << predicate_dtype_lanes
<< " lanes, but trying to store a value with " << value_dtype_lanes
<< " lanes. The number of lanes must match.";
DataType predicate_element_dtype = predicate_dtype.element_of();
ICHECK(predicate_element_dtype.is_bool())
<< "Predicate mask elements must be boolean values, but got " << predicate_element_dtype
<< ".";
}
// Rebuild the dtype the stored value must have, then require an exact match.
runtime::DataType buffer_dtype;
if (is_index_scalable || is_buffer_dtype_scalable) {
buffer_dtype = buffer->dtype.with_scalable_vscale_factor(buffer_lanes * index_lanes);
} else {
buffer_dtype = buffer->dtype.with_lanes(buffer_lanes * index_lanes);
}
if (buffer_dtype != value.dtype()) {
LOG(FATAL) << "TypeError: dtype mismatch on BufferStore: " //
<< "buffer's dtype is `" << buffer->dtype //
<< "`, the lanes of indexing are: `" << index_lanes //
<< "`, the scalability is: `" << buffer_dtype.is_scalable_vector()
<< "`, but RHS's dtype is `" << value.dtype() << "`";
}
ObjectPtr<BufferStoreNode> node = make_object<BufferStoreNode>();
node->buffer = std::move(buffer);
node->value = std::move(value);
node->indices = std::move(indices);
node->predicate = std::move(predicate);
node->span = std::move(span);
data_ = std::move(node);
}
// FFI entry point for BufferStore; the predicate is optional.
TVM_REGISTER_GLOBAL("tir.BufferStore")
    .set_body_typed([](Buffer target, PrimExpr stored, Array<PrimExpr> idx,
                       Optional<PrimExpr> mask, Span sp) {
      return BufferStore(std::move(target), std::move(stored), std::move(idx), std::move(mask),
                         std::move(sp));
    });
TVM_REGISTER_NODE_TYPE(BufferStoreNode);
// BufferRealize
/*!
 * \brief Construct a BufferRealize node. The node is built directly through
 *        BufferRealizeNode's constructor.
 */
BufferRealize::BufferRealize(Buffer buffer, Array<Range> bounds, PrimExpr condition, Stmt body,
                             Span span) {
  ObjectPtr<BufferRealizeNode> realize =
      make_object<BufferRealizeNode>(std::move(buffer), std::move(bounds), std::move(condition),
                                     std::move(body), std::move(span));
  data_ = std::move(realize);
}
TVM_REGISTER_GLOBAL("tir.BufferRealize")
    .set_body_typed([](Buffer buf, Array<Range> region, PrimExpr cond, Stmt realize_body,
                       Span sp) { return BufferRealize(buf, region, cond, realize_body, sp); });
TVM_REGISTER_NODE_TYPE(BufferRealizeNode);
// BufferRegion
/*!
 * \brief Convert this region to a BufferLoad.
 *
 * Each dimension with extent 1 becomes a scalar index (its min). Each dimension with
 * another IntImm extent becomes a stride-1 Ramp starting at min. Any non-constant
 * extent is a fatal error.
 */
PrimExpr BufferRegionNode::ToPrimExpr() const {
  Array<PrimExpr> load_indices;
  load_indices.reserve(this->region.size());
  for (const Range& dim : this->region) {
    if (tvm::tir::is_one(dim->extent)) {
      load_indices.push_back(dim->min);
      continue;
    }
    if (dim->extent.as<IntImmNode>() == nullptr) {
      LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " << GetRef<BufferRegion>(this);
    }
    load_indices.push_back(
        tir::Ramp(dim->min, tvm::tir::make_const(dim->min->dtype, 1), dim->extent));
  }
  return tir::BufferLoad(this->buffer, load_indices);
}
/*!
 * \brief Construct a BufferRegion; `region` must have one Range per buffer dimension.
 */
BufferRegion::BufferRegion(Buffer buffer, Array<Range> region) {
  CHECK_EQ(buffer->shape.size(), region.size())
      << "The dimension between " << buffer << " and region " << region
      << " mismatched, the buffer is " << buffer;
  auto buffer_region = make_object<BufferRegionNode>();
  buffer_region->buffer = std::move(buffer);
  buffer_region->region = std::move(region);
  data_ = std::move(buffer_region);
}
/*!
 * \brief Build a region covering every dimension of `buffer`, from 0 to its full extent.
 */
BufferRegion BufferRegion::FullRegion(Buffer buffer) {
  Array<Range> whole;
  whole.reserve(buffer->shape.size());
  for (const PrimExpr& dim_extent : buffer->shape) {
    whole.push_back(Range::FromMinExtent(0, dim_extent));
  }
  return BufferRegion(buffer, whole);
}
/*!
 * \brief Build a region from a point access.
 *
 * A Ramp index yields a range starting at its base with extent stride * lanes.
 * Any other index yields a unit range starting at the index.
 * NOTE(review): for stride > 1 this extent looks larger than the span actually
 * touched (stride * (lanes - 1) + 1). Confirm whether the over-approximation is intended.
 */
BufferRegion BufferRegion::FromPoint(Buffer buffer, Array<PrimExpr> indices) {
  Array<Range> point_region;
  for (const PrimExpr& index : indices) {
    const RampNode* ramp = index.as<RampNode>();
    Range dim = (ramp != nullptr)
                    ? Range::FromMinExtent(ramp->base, ramp->stride * ramp->lanes)
                    : Range::FromMinExtent(index, make_const(index.dtype(), 1));
    point_region.push_back(dim);
  }
  return BufferRegion(buffer, point_region);
}
TVM_REGISTER_GLOBAL("tir.BufferRegion").set_body_typed([](Buffer buf, Array<Range> ranges) {
  return BufferRegion(std::move(buf), std::move(ranges));
});
TVM_REGISTER_NODE_TYPE(BufferRegionNode);
// MatchBufferRegion
/*!
 * \brief Construct a MatchBufferRegion pairing `buffer` with the region `source`.
 *
 * Validation, in order, each failure being a CHECK error:
 *  - `buffer` and `source->buffer` have the same scope and dtype;
 *  - the source's data_alignment is a multiple of the target's;
 *  - both buffers use BufferType::kDefault (no AutoBroadcast);
 *  - the source region has at least as many dimensions as `buffer`;
 *  - the extra leading source dimensions have extent provably equal to 1;
 *  - each remaining source extent is provably equal to the matching target shape
 *    entry (skipped when that entry is a Var).
 * elem_offset and strides are deliberately not checked here.
 */
MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) {
const Buffer& source_buffer = source->buffer;
arith::Analyzer analyzer;
// Check scope and dtype
CHECK_EQ(buffer.scope(), source_buffer.scope())
<< "MatchBuffer " << buffer << " scope mismatch:" << buffer.scope() << " vs. "
<< source_buffer.scope();
CHECK_EQ(buffer->dtype, source_buffer->dtype)
<< "MatchBuffer " << buffer << " data type mismatch:" << buffer->dtype << " vs. "
<< source_buffer->dtype;
// Check data_alignment
CHECK(source_buffer->data_alignment % buffer->data_alignment == 0)
<< "Trying to match buffer to another one with lower alignment requirement "
<< " required_alignment=" << buffer->data_alignment
<< ", provided_alignment=" << source_buffer->data_alignment;
// Check BufferType. AutoBroadcast is not allowed for now.
CHECK(buffer->buffer_type == BufferType::kDefault &&
source_buffer->buffer_type == BufferType::kDefault)
<< "AutoBroadcast is not allowed in MatchBuffer";
// Validate shape
CHECK(source->region.size() >= buffer->shape.size())
<< "Dimension of source Region expected to be larger or equal than target buffer shape, but "
"got "
<< source->region.size() << " vs. " << buffer->shape.size();
// The source may have extra leading dimensions; each must have extent provably 1.
size_t offset = source->region.size() - buffer->shape.size();
for (size_t i = 0; i < offset; ++i) {
CHECK(analyzer.CanProve(source->region[i]->extent == 1))
<< "The higher dimension should be 1, but got " << source->region[i]->extent << ".";
}
// The trailing source dimensions align with the target's dimensions one-to-one.
for (size_t i = 0; i < buffer->shape.size(); ++i) {
const Range& source_range = source->region[i + offset];
const PrimExpr& buffer_shape = buffer->shape[i];
// A Var shape entry is not checked here (presumably it is bound from the source
// extent elsewhere — TODO confirm).
if (!buffer_shape->IsInstance<VarNode>()) {
CHECK(analyzer.CanProve(source_range->extent == buffer_shape))
<< "The dimension mismatched between source region and target buffer shape, got "
<< source_range->extent << " vs. " << buffer_shape << ".";
}
}
// Note that we do not check elem_offset and strides in this function
// Construction
ObjectPtr<MatchBufferRegionNode> node = make_object<MatchBufferRegionNode>();
node->buffer = std::move(buffer);
node->source = std::move(source);
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.MatchBufferRegion").set_body_typed([](Buffer buffer, BufferRegion source) {
return MatchBufferRegion(buffer, source);
});
TVM_REGISTER_NODE_TYPE(MatchBufferRegionNode);
// Block
/*!
 * \brief Construct a Block node. All fields are stored as given; no validation is performed.
 */
Block::Block(Array<IterVar> iter_vars, Array<BufferRegion> reads, Array<BufferRegion> writes,
             String name_hint, Stmt body, Optional<Stmt> init, Array<Buffer> alloc_buffers,
             Array<MatchBufferRegion> match_buffers, Map<String, Any> annotations, Span span) {
  auto blk = make_object<BlockNode>();
  blk->iter_vars = std::move(iter_vars);
  blk->reads = std::move(reads);
  blk->writes = std::move(writes);
  blk->name_hint = std::move(name_hint);
  blk->body = std::move(body);
  blk->init = std::move(init);
  blk->alloc_buffers = std::move(alloc_buffers);
  blk->match_buffers = std::move(match_buffers);
  blk->annotations = std::move(annotations);
  blk->span = std::move(span);
  data_ = std::move(blk);
}
TVM_REGISTER_GLOBAL("tir.Block")
    .set_body_typed([](Array<IterVar> ivars, Array<BufferRegion> read_regions,
                       Array<BufferRegion> write_regions, String name, Stmt blk_body,
                       Optional<Stmt> blk_init, Array<Buffer> allocs,
                       Array<MatchBufferRegion> matches, Map<String, Any> attrs, Span sp) {
      return Block(ivars, read_regions, write_regions, name, blk_body, blk_init, allocs, matches,
                   attrs, sp);
    });
TVM_REGISTER_NODE_TYPE(BlockNode);
// BlockRealize
/*!
 * \brief Construct a BlockRealize.
 *
 * `values` must supply one binding value per iter_var of `block`, and `predicate`
 * must be a boolean expression.
 */
BlockRealize::BlockRealize(Array<PrimExpr> values, PrimExpr predicate, Block block, Span span) {
  CHECK_EQ(block->iter_vars.size(), values.size())
      << "ValueError: BlockRealize needs to have the same number of iter_vars and binding values";
  CHECK(predicate.dtype().is_bool()) << "TypeError: Expect Block.predicate to be a bool expression";
  auto realize = make_object<BlockRealizeNode>();
  realize->iter_values = std::move(values);
  realize->predicate = std::move(predicate);
  realize->block = std::move(block);
  realize->span = std::move(span);
  data_ = std::move(realize);
}
TVM_REGISTER_GLOBAL("tir.BlockRealize")
    .set_body_typed([](Array<PrimExpr> bindings, PrimExpr pred, Block blk, Span sp) {
      return BlockRealize(std::move(bindings), std::move(pred), std::move(blk), std::move(sp));
    });
TVM_REGISTER_NODE_TYPE(BlockRealizeNode);
/*!
 * \brief Build an argument-less call to `tir.type_annotation` whose result dtype is `dtype`.
 */
PrimExpr TypeAnnotation(DataType dtype, Span span) {
  // Looked up once and cached in a function-local static.
  static const Op type_annotation_op = Op::Get("tir.type_annotation");
  return tir::Call(dtype, type_annotation_op, {}, span);
}
// Registered as pure; TVMScript print location for its dtype is kFirst.
TVM_TIR_REGISTER_OP("type_annotation")
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
    .set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
                                         Integer(ScriptDtypePrintLocation::kFirst));
} // namespace tir
} // namespace tvm