src/relax/ir/expr_functor.cc (628 lines of code) (raw):

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relax/ir/expr_functor.cc
 * \brief A wrapper around ExprFunctor which functionally updates the AST.
 *
 * ExprMutator uses memoization and self return in order to amortize
 * the cost of using functional updates.
 */
#include <tvm/ir/type_functor.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/type.h>

// functions to be overridden.
/*!
 * \brief Register one VarBinding dispatch entry for value node type OP.
 *
 * The registered entry downcasts the binding's value to `const OP*` and
 * forwards to the matching `VisitBinding_(binding, value)` overload.
 * The names `vtable` and `TSelf` are resolved at the expansion site.
 */
#define RELAX_VISIT_BINDING_DISPATCH(OP)                                   \
  vtable.template set_dispatch<OP>(                                        \
      [](const ObjectRef& n, TSelf* self, const VarBindingNode* binding) { \
        self->VisitBinding_(binding, static_cast<const OP*>(n.get()));     \
      });

/*!
 * \brief Define `Type::InitVisitBindingVTable` and
 *        `Type::VisitBinding_(const VarBindingNode*)` for a visitor class.
 *
 * `InitVisitBindingVTable` registers one entry per supported binding-value
 * node type. `VisitBinding_` builds that table once (function-local static)
 * and dispatches on the dynamic type of `binding->value`. It fails an ICHECK
 * when the value is null or when its node type has no registered entry.
 */
#define RELAX_VAR_BINDING_DISPATCH_IMPL(Type)                                             \
  Type::VisitBindingVTable Type::InitVisitBindingVTable() {                               \
    VisitBindingVTable vtable;                                                            \
    RELAX_VISIT_BINDING_DISPATCH(ConstantNode);                                           \
    RELAX_VISIT_BINDING_DISPATCH(TupleNode);                                              \
    RELAX_VISIT_BINDING_DISPATCH(VarNode);                                                \
    RELAX_VISIT_BINDING_DISPATCH(DataflowVarNode);                                        \
    RELAX_VISIT_BINDING_DISPATCH(ShapeExprNode);                                          \
    RELAX_VISIT_BINDING_DISPATCH(ExternFuncNode);                                         \
    RELAX_VISIT_BINDING_DISPATCH(GlobalVarNode);                                          \
    RELAX_VISIT_BINDING_DISPATCH(FunctionNode);                                           \
    RELAX_VISIT_BINDING_DISPATCH(CallNode);                                               \
    RELAX_VISIT_BINDING_DISPATCH(SeqExprNode);                                            \
    RELAX_VISIT_BINDING_DISPATCH(IfNode);                                                 \
    RELAX_VISIT_BINDING_DISPATCH(OpNode);                                                 \
    RELAX_VISIT_BINDING_DISPATCH(TupleGetItemNode);                                       \
    RELAX_VISIT_BINDING_DISPATCH(PrimValueNode);                                          \
    RELAX_VISIT_BINDING_DISPATCH(StringImmNode);                                          \
    RELAX_VISIT_BINDING_DISPATCH(DataTypeImmNode);                                        \
    return vtable;                                                                        \
  }                                                                                       \
  void Type::VisitBinding_(const VarBindingNode* binding) {                               \
    static VisitBindingVTable vtable = InitVisitBindingVTable();                          \
    const Expr& value = binding->value;                                                   \
    ICHECK(value.defined()) << "Found null pointer node while traversing AST.";           \
    ICHECK(vtable.can_dispatch(value))                                                    \
        << "VisitVarBinding does not allow binding value type: " << value->GetTypeKey(); \
    vtable(value, this, binding);                                                         \
  }

// Default ExprVisitor handler for a VarBinding whose value is an OP:
// visit the bound value first, then the variable definition.
// Functions to be overridden.
#define RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(OP)                                     \
  void ExprVisitor::VisitBinding_(const VarBindingNode* binding, const OP* value) { \
    this->VisitExpr(binding->value);                                                \
    this->VisitVarDef(binding->var);                                                \
  }
#define RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(OP) \ void ExprMutator::VisitBinding_(const VarBindingNode* binding, const OP* value) { \ Expr new_value = this->VisitExpr(binding->value); \ this->ReEmitBinding(binding, new_value); \ } namespace tvm { namespace relax { // ================== // ExprVisitor void ExprVisitor::VisitExprDepStructInfoField(const StructInfo& struct_info) { // recurse into struct info in case they depend on value // under the current scope. default_struct_info_field_visitor_.VisitStructInfo(struct_info); } ExprVisitor::DefaultStructInfoFieldVisitor::DefaultStructInfoFieldVisitor(ExprVisitor* parent) : parent_(parent) {} void ExprVisitor::DefaultStructInfoFieldVisitor::VisitStructInfoExprField(const Expr& expr) { parent_->VisitExpr(expr); } void ExprVisitor::DefaultStructInfoFieldVisitor::VisitStructInfoExprField(const PrimExpr& expr) { parent_->VisitPrimExpr(expr); } void ExprVisitor::DefaultStructInfoFieldVisitor::VisitStructInfo_(const FuncStructInfoNode* op) { // Do not recurse into function struct info // as they won't contain ref to values in current scope. } void ExprVisitor::VisitExpr(const Expr& expr) { ExprFunctor::VisitExpr(expr); } void ExprVisitor::VisitExpr_(const ConstantNode* op) { this->VisitSpan(op->span); // Constant's StructInfo does not depend on Expr. 
} void ExprVisitor::VisitExpr_(const GlobalVarNode* op) { this->VisitSpan(op->span); // FuncStructInfo is not value-dep } void ExprVisitor::VisitExpr_(const TupleNode* op) { this->VisitSpan(op->span); for (Expr field : op->fields) { this->VisitExpr(field); } if (auto* sinfo = op->struct_info_.as<StructInfoNode>()) { this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo)); } } // Visit the use-site of a defined Var void ExprVisitor::VisitExpr_(const VarNode* op) { this->VisitSpan(op->span); if (auto* sinfo = op->struct_info_.as<StructInfoNode>()) { this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo)); } } // Visit the use-site of a defined DataflowVar void ExprVisitor::VisitExpr_(const DataflowVarNode* op) { VisitExpr_(static_cast<const VarNode*>(op)); } void ExprVisitor::VisitExpr_(const FunctionNode* op) { this->VisitSpan(op->span); for (Var param : op->params) { this->VisitVarDef(param); } this->VisitExpr(op->body); // FuncStructInfo does not depend on Expr. } void ExprVisitor::VisitExpr_(const CallNode* op) { this->VisitSpan(op->span); this->VisitExpr(op->op); for (StructInfo sinfo_arg : op->sinfo_args) { this->VisitExprDepStructInfoField(sinfo_arg); } for (Expr arg : op->args) { this->VisitExpr(arg); } if (auto* sinfo = op->struct_info_.as<StructInfoNode>()) { this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo)); } } void ExprVisitor::VisitExpr_(const IfNode* op) { this->VisitSpan(op->span); this->VisitExpr(op->cond); this->VisitExpr(op->true_branch); this->VisitExpr(op->false_branch); if (auto* sinfo = op->struct_info_.as<StructInfoNode>()) { this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo)); } } void ExprVisitor::VisitExpr_(const OpNode* op) { this->VisitSpan(op->span); } void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { this->VisitSpan(op->span); this->VisitExpr(op->tuple); if (auto* sinfo = op->struct_info_.as<StructInfoNode>()) { this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo)); } } void 
ExprVisitor::VisitExpr_(const ShapeExprNode* op) { for (PrimExpr val : op->values) { this->VisitPrimExpr(val); } this->VisitSpan(op->span); if (auto* sinfo = op->struct_info_.as<StructInfoNode>()) { this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo)); } } void ExprVisitor::VisitExpr_(const ExternFuncNode* op) { this->VisitSpan(op->span); // FuncStructInfo does not depend on Expr. } void ExprVisitor::VisitExpr_(const SeqExprNode* op) { this->VisitSpan(op->span); for (BindingBlock block : op->blocks) { this->VisitBindingBlock(block); } this->VisitExpr(op->body); if (auto* sinfo = op->struct_info_.as<StructInfoNode>()) { this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo)); } } void ExprVisitor::VisitExpr_(const PrimValueNode* op) { this->VisitPrimExpr(op->value); if (auto* sinfo = op->struct_info_.as<StructInfoNode>()) { this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo)); } this->VisitSpan(op->span); } void ExprVisitor::VisitExpr_(const StringImmNode* op) { this->VisitSpan(op->span); } void ExprVisitor::VisitExpr_(const DataTypeImmNode* op) { this->VisitSpan(op->span); } void ExprVisitor::VisitSpan(const Span& span) {} void ExprVisitor::VisitPrimExpr(const PrimExpr& expr) {} // implementations of binding visitor dispatch RELAX_VAR_BINDING_DISPATCH_IMPL(ExprVisitor); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(ConstantNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(TupleNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(VarNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(DataflowVarNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(ShapeExprNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(ExternFuncNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(GlobalVarNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(FunctionNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(CallNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(SeqExprNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(IfNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(OpNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(TupleGetItemNode); 
RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(PrimValueNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(StringImmNode); RELAX_EXPR_VISITOR_VISIT_BINDING_IMPL(DataTypeImmNode); void ExprVisitor::VisitBinding_(const MatchCastNode* binding) { this->VisitExpr(binding->value); this->VisitExprDepStructInfoField(binding->struct_info); this->VisitVarDef(binding->var); } void ExprVisitor::VisitBindingBlock_(const BindingBlockNode* block) { for (Binding binding : block->bindings) { this->VisitBinding(binding); } } void ExprVisitor::VisitBindingBlock_(const DataflowBlockNode* block) { for (Binding binding : block->bindings) { this->VisitBinding(binding); } } void ExprVisitor::VisitVarDef_(const DataflowVarNode* var) { VisitVarDef_(static_cast<const VarNode*>(var)); } void ExprVisitor::VisitVarDef_(const VarNode* var) { this->VisitSpan(var->span); } void ExprVisitor::VisitBinding(const Binding& binding) { if (const auto* node = binding.as<VarBindingNode>()) { VisitBinding_(node); } else if (const auto* node = binding.as<MatchCastNode>()) { VisitBinding_(node); } else { LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); } } void ExprVisitor::VisitBindingBlock(const BindingBlock& block) { if (const auto* node = block.as<DataflowBlockNode>()) { VisitBindingBlock_(node); } else if (const auto* node = block.as<BindingBlockNode>()) { VisitBindingBlock_(node); } else { LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); } } void ExprVisitor::VisitVarDef(const Var& var) { if (const auto* node = var.as<DataflowVarNode>()) { VisitVarDef_(node); } else if (const auto* node = var.as<VarNode>()) { VisitVarDef_(node); } else { LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey(); } } class ExprApplyVisit : public ExprVisitor { public: explicit ExprApplyVisit(std::function<void(const Expr&)> f) : f_(f) {} void VisitExpr(const Expr& e) final { ExprVisitor::VisitExpr(e); f_(e); } private: std::function<void(const Expr&)> f_; }; void PostOrderVisit(const Expr& e, 
std::function<void(const Expr&)> fvisit) { ExprApplyVisit(fvisit).VisitExpr(e); } TVM_REGISTER_GLOBAL("relax.analysis.post_order_visit").set_body_typed([](Expr expr, PackedFunc f) { PostOrderVisit(expr, [f](const Expr& n) { f(n); }); }); // ================== // ExprMutatorBase StructInfo ExprMutatorBase::VisitExprDepStructInfoField(const StructInfo& struct_info) { // recurse into struct info in case they depend on value // under the current scope. return default_struct_info_field_mutator_.VisitStructInfo(struct_info); } ExprMutatorBase::DefaultStructInfoFieldMutator::DefaultStructInfoFieldMutator( ExprMutatorBase* parent) : parent_(parent) {} Expr ExprMutatorBase::DefaultStructInfoFieldMutator::VisitStructInfoExprField(const Expr& expr) { return parent_->VisitExpr(expr); } PrimExpr ExprMutatorBase::DefaultStructInfoFieldMutator::VisitStructInfoExprField( const PrimExpr& expr) { return parent_->VisitPrimExpr(expr); } StructInfo ExprMutatorBase::DefaultStructInfoFieldMutator::VisitStructInfo_( const FuncStructInfoNode* op) { // Do not recurse into function struct info // as they won't contain ref to values in current scope. return GetRef<StructInfo>(op); } Expr ExprMutatorBase::VisitExpr(const Expr& expr) { return ExprFunctor::VisitExpr(expr); } Expr ExprMutatorBase::VisitExpr_(const ConstantNode* op) { // Constant' struct info won't be affected by Expr/PrimExpr change. return GetRef<Expr>(op); } Expr ExprMutatorBase::VisitExpr_(const GlobalVarNode* op) { // FuncStructInfo won't be affected by Expr/PrimExpr change. 
return GetRef<Expr>(op); } Expr ExprMutatorBase::VisitExpr_(const TupleNode* op) { bool unchanged = true; tvm::Array<Expr> fields; for (Expr field : op->fields) { Expr new_field = this->VisitExpr(field); fields.push_back(new_field); unchanged &= new_field.same_as(field); } if (unchanged) { // If tuple's struct info change it means that // one of its fields' struct info will change // so un-changed already implies that struct info won't change return GetRef<Expr>(op); } else { // when there is a change return a new tuple node return Tuple(fields, op->span); } } // Visit the use-site of a defined Var Expr ExprMutatorBase::VisitExpr_(const VarNode* op) { // struct info of var-use should remain stable // or the var itself will get replaced return GetRef<Expr>(op); } // Visit the use-site of a defined DataflowVar Expr ExprMutatorBase::VisitExpr_(const DataflowVarNode* op) { return VisitExpr_(static_cast<const VarNode*>(op)); } Expr ExprMutatorBase::VisitExpr_(const FunctionNode* op) { // struct info of function is not value dependent // so no need to check struct_info field Expr body = this->VisitExpr(op->body); if (body.same_as(op->body)) { return GetRef<Expr>(op); } else { return Function(op->params, body, op->ret_struct_info, op->is_pure, op->attrs); } } Expr ExprMutatorBase::VisitExpr_(const CallNode* call_node) { Expr new_op = this->VisitExpr(call_node->op); bool unchanged = call_node->op.same_as(new_op); Array<StructInfo> sinfo_args; for (StructInfo sinfo_arg : call_node->sinfo_args) { StructInfo new_sinfo_arg = this->VisitExprDepStructInfoField(sinfo_arg); sinfo_args.push_back(new_sinfo_arg); unchanged &= new_sinfo_arg.same_as(sinfo_arg); } tvm::Array<Expr> call_args; for (Expr arg : call_node->args) { Expr new_arg = this->VisitExpr(arg); call_args.push_back(new_arg); unchanged &= new_arg.same_as(arg); } if (unchanged && VisitAndCheckStructInfoFieldUnchanged(call_node->struct_info_)) { return GetRef<Expr>(call_node); } else { return Call(new_op, call_args, 
call_node->attrs, sinfo_args, call_node->span); } } Expr ExprMutatorBase::VisitExpr_(const IfNode* op) { Expr guard = this->VisitExpr(op->cond); Expr true_b = this->VisitExpr(op->true_branch); Expr false_b = this->VisitExpr(op->false_branch); if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && op->false_branch.same_as(false_b) && VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { return GetRef<Expr>(op); } else { return If(guard, true_b, false_b, op->span); } } Expr ExprMutatorBase::VisitExpr_(const OpNode* op) { return GetRef<Expr>(op); } Expr ExprMutatorBase::VisitExpr_(const TupleGetItemNode* op) { auto t = this->VisitExpr(op->tuple); if (op->tuple.same_as(t)) { // struct info can be deterministically derived by tuple and index // if t does not change, then struct info won't change. return GetRef<Expr>(op); } else { return TupleGetItem(t, op->index, op->span); } } Expr ExprMutatorBase::VisitExpr_(const PrimValueNode* op) { auto value = this->VisitPrimExpr(op->value); if (op->value.same_as(value)) { // struct info can be deterministically derived by value // if value does not change, then struct info won't change. return GetRef<Expr>(op); } return PrimValue(value, op->span); } Expr ExprMutatorBase::VisitExpr_(const StringImmNode* op) { return GetRef<Expr>(op); } Expr ExprMutatorBase::VisitExpr_(const DataTypeImmNode* op) { return GetRef<Expr>(op); } Expr ExprMutatorBase::VisitExpr_(const ShapeExprNode* op) { auto values = op->values.Map([this](const PrimExpr& e) { return this->VisitPrimExpr(e); }); if (values.same_as(op->values)) { // If values does not change, struct info won't change. return GetRef<Expr>(op); } else { return ShapeExpr(values, op->span); } } Expr ExprMutatorBase::VisitExpr_(const ExternFuncNode* op) { // StructInfo of function remains value independent. 
return GetRef<Expr>(op); } Expr ExprMutatorBase::VisitExpr_(const SeqExprNode* op) { bool all_blocks_unchanged = true; Array<BindingBlock> blocks; for (auto block : op->blocks) { BindingBlock new_block = this->VisitBindingBlock(block); if (!new_block->bindings.empty()) { blocks.push_back(new_block); } all_blocks_unchanged &= block.same_as(new_block); } Expr body = this->VisitExpr(op->body); if (all_blocks_unchanged && body.same_as(op->body) && VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { return GetRef<Expr>(op); } return SeqExpr(blocks, body); } BindingBlock ExprMutatorBase::VisitBindingBlock(const BindingBlock& block) { Array<Binding> bindings; if (const auto* node = block.as<BindingBlockNode>()) { for (auto binding : node->bindings) { if (auto var_binding = binding.as<VarBindingNode>()) { Expr new_value = this->VisitExpr(var_binding->value); bindings.push_back(VarBinding(var_binding->var, new_value)); } else if (auto match_cast = binding.as<MatchCastNode>()) { Expr new_value = this->VisitExpr(match_cast->value); bindings.push_back(MatchCast(match_cast->var, new_value, match_cast->struct_info)); } else { LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); } } } else { LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); } if (block.as<DataflowBlockNode>()) { return DataflowBlock(bindings); } else { return BindingBlock(bindings); } } PrimExpr ExprMutatorBase::VisitPrimExpr(const PrimExpr& expr) { return expr; } // ================== // ExprMutator Expr ExprMutator::VisitExpr(const Expr& expr) { return builder_->Normalize(ExprFunctor::VisitExpr(expr)); } // Visit the use-site of a defined Var Expr ExprMutator::VisitExpr_(const VarNode* op) { auto it = var_remap_.find(op->vid); if (it != var_remap_.end()) { return it->second; } // default case return self. 
return GetRef<Expr>(op); } // Visit the use-site of a defined DataflowVar Expr ExprMutator::VisitExpr_(const DataflowVarNode* op) { return VisitExpr_(static_cast<const VarNode*>(op)); } Expr ExprMutator::VisitExpr_(const FunctionNode* op) { tvm::Array<Var> params; bool all_params_unchanged = true; for (Var param : op->params) { Var new_param = this->VisitVarDef(param); params.push_back(new_param); if (!param.same_as(new_param)) { var_remap_[param->vid] = new_param; all_params_unchanged = false; } } Expr body = this->VisitWithNewScope(op->body, params); if (all_params_unchanged && body.same_as(op->body)) { // No changes to the function, return the original object return GetRef<Expr>(op); } else if (IsBaseOf(GetStructInfo(body), op->ret_struct_info)) { // If the function was mutated into a form that can no longer // propagate shape information all the way to the return value, we // may keep the return struct info. This is only allowed when the // body produces a return value that is the same as, or more // specific than, the pre-mutation struct info. For example, if // the previous return value was `TensorStructInfo(shape=[16,16])` // but the body only produced `TensorStructInfo(ndim=2)`, we can // keep the more specific information. return Function(params, body, op->ret_struct_info, op->is_pure, op->attrs); } else { // If the function was mutated such that the body produces an // output that is incompatible with the original return struct // info, the original return struct info should not be used. For // example, if the previous return value was // `TensorStructInfo(shape=[16,16])`, but the new return value is // `TensorStructInfo(shape=[8,8])`. 
return Function(params, body, NullOpt, op->is_pure, op->attrs); } } Expr ExprMutator::VisitExpr_(const IfNode* op) { Expr guard = this->VisitExpr(op->cond); Expr true_b = this->VisitWithInnerScope(op->true_branch); Expr false_b = this->VisitWithInnerScope(op->false_branch); if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && op->false_branch.same_as(false_b) && VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { return GetRef<Expr>(op); } else { return If(guard, true_b, false_b, op->span); } } Expr ExprMutator::VisitExpr_(const SeqExprNode* op) { bool all_blocks_unchanged = true; Array<BindingBlock> blocks; for (auto block : op->blocks) { BindingBlock new_block = this->VisitBindingBlock(block); if (!new_block->bindings.empty()) { blocks.push_back(new_block); } all_blocks_unchanged &= block.same_as(new_block); } builder_->BeginBindingBlock(); Expr body = this->VisitExpr(op->body); BindingBlock prologue = builder_->EndBlock(); if (!prologue->bindings.empty()) { blocks.push_back(prologue); all_blocks_unchanged = false; } if (all_blocks_unchanged && body.same_as(op->body) && VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { return GetRef<Expr>(op); } else { return SeqExpr(blocks, body); } } RELAX_VAR_BINDING_DISPATCH_IMPL(ExprMutator); RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(ConstantNode); RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(TupleNode); RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(VarNode); RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(DataflowVarNode); RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(ShapeExprNode); RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(ExternFuncNode); RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(GlobalVarNode); RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(FunctionNode); RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(CallNode); RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(SeqExprNode); RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(IfNode); RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(OpNode); RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(TupleGetItemNode); 
RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(PrimValueNode); RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(StringImmNode); RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(DataTypeImmNode); void ExprMutator::ReEmitBinding(const VarBindingNode* binding, Expr new_value) { Var new_var = this->VisitVarDef(binding->var); // fast path: re-emit binding if nothing changes if (new_var.same_as(binding->var) && new_value.same_as(binding->value)) { builder_->EmitNormalized(GetRef<VarBinding>(binding)); return; } auto new_sinfo = new_value->struct_info_.as<StructInfo>(); ICHECK(new_sinfo) << "InternalError: " << "In binding of variable " << binding->var << ", the value " << new_value << " does not have StructInfo. " << "This typically occurs when ReEmitBinding is called without first calling Normalize."; Var temp = WithStructInfo(new_var, new_sinfo.value()); if (!temp.same_as(new_var)) { new_var = temp; } this->var_remap_[binding->var->vid] = new_var; this->var_remap_[new_var->vid] = new_var; builder_->EmitNormalized(VarBinding(new_var, new_value)); } void ExprMutator::VisitBinding_(const MatchCastNode* binding) { Expr new_value = this->VisitExpr(binding->value); StructInfo new_struct_info = this->VisitExprDepStructInfoField(binding->struct_info); Var new_var = this->VisitVarDef(binding->var); MatchCast new_binding = [&]() -> MatchCast { if (new_var.same_as(binding->var) && new_value.same_as(binding->value) && new_struct_info.same_as(binding->struct_info)) { // re-emit old binding if nothing changes return GetRef<MatchCast>(binding); } else { new_value = builder_->NormalizeArgument(new_value); new_var = WithStructInfo(new_var, new_struct_info); var_remap_[binding->var->vid] = new_var; var_remap_[new_var->vid] = new_var; return MatchCast(new_var, new_value, new_struct_info, binding->span); } }(); builder_->EmitNormalized(new_binding); builder_->AddDefinitionToScope(new_binding->var); } BindingBlock ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) { builder_->BeginBindingBlock(); for 
(Binding binding : block->bindings) { this->VisitBinding(binding); } return builder_->EndBlock(); } BindingBlock ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) { builder_->BeginDataflowBlock(); for (auto binding : block->bindings) { this->VisitBinding(binding); } return builder_->EndBlock(); } Var ExprMutator::VisitVarDef_(const DataflowVarNode* var) { Var output = VisitVarDef_(static_cast<const VarNode*>(var)); // Because we delegate from DataflowVar visitor to Var visitor to // provide default behavior in subclasses, we may produce a Var // where we should produce a DataflowVar. if (!output->IsInstance<DataflowVarNode>()) { output = DataflowVar(output->vid, GetStructInfo(output), output->span); } return output; } Var ExprMutator::VisitVarDef_(const VarNode* var) { if (auto* sinfo = var->struct_info_.as<StructInfoNode>()) { StructInfo struct_info = this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo)); if (struct_info.same_as(var->struct_info_)) { return GetRef<Var>(var); } else { return Var(var->vid, struct_info, var->span); } } else { return GetRef<Var>(var); } } void ExprMutator::VisitBinding(const Binding& binding) { if (const auto* node = binding.as<VarBindingNode>()) { VisitBinding_(node); } else if (const auto* node = binding.as<MatchCastNode>()) { VisitBinding_(node); } else { LOG(FATAL) << "TypeError: Invalid type: " << binding->GetTypeKey(); } } BindingBlock ExprMutator::VisitBindingBlock(const BindingBlock& block) { BindingBlock ret; if (const auto* node = block.as<DataflowBlockNode>()) { ret = VisitBindingBlock_(node); } else if (const auto* node = block.as<BindingBlockNode>()) { ret = VisitBindingBlock_(node); } else { LOG(FATAL) << "TypeError: Invalid type: " << block->GetTypeKey(); } return ret; } Var ExprMutator::VisitVarDef(const Var& var) { Var ret; if (const auto* node = var.as<DataflowVarNode>()) { ret = VisitVarDef_(node); } else if (const auto* node = var.as<VarNode>()) { ret = VisitVarDef_(node); } else { LOG(FATAL) 
<< "TypeError: Invalid type: " << var->GetTypeKey(); } return ret; } Expr ExprMutator::VisitWithNewScope(const Expr& expr, Optional<Array<Var>> params) { ICHECK(expr->IsInstance<SeqExprNode>()) << "Normal form requires all new scope is stored as SeqExpr"; PrimExpr constraint = Bool(true); if (params.defined()) { auto non_negative_expressions = CollectNonNegativeExpressions(TupleStructInfo(params.value().Map(GetStructInfo))); for (const auto& expr : non_negative_expressions) { constraint = constraint && (expr >= 0); } } builder_->BeginScope(params); // Outer scope only includes TIR variables that can be inferred from // the function parameters. With<arith::ConstraintContext> context(builder_->GetAnalyzer(), constraint); builder_->BeginInnerScope(); // Inner scope also includes any TIR variables that are defined by // MatchCast nodes, and are internal to the scope. Expr ret = this->VisitExpr(expr); builder_->EndScope(); // Normalization (and the resulting StructInfo inference) of the // expr occurs outside of the body's parameters, but inside the // function signature's scope. This keeps variables that are // inferable based on the function signature, to allow callers to // propagate StructInfo across the function. 
ret = builder_->Normalize(ret); builder_->EndScope(); return ret; } Expr ExprMutator::VisitWithInnerScope(const Expr& expr) { ICHECK(expr->IsInstance<SeqExprNode>()) << "Normal form requires all new scope is stored as SeqExpr"; builder_->BeginInnerScope(); Expr ret = this->VisitExpr(expr); builder_->EndScope(); return ret; } Optional<Expr> ExprMutator::LookupBinding(const Var& var) { return builder_->LookupBinding(var); } Var ExprMutator::WithStructInfo(Var var, StructInfo struct_info) { ICHECK(struct_info.defined()); // TODO(relax-team) add StructInfoEqual check if (var->struct_info_.defined()) { // use same-as as a quick path if (var->struct_info_.same_as(struct_info) || StructuralEqual()(var->struct_info_, struct_info)) { return var; } else { Var new_var = var.as<DataflowVarNode>() ? DataflowVar(var->vid, struct_info, var->span) : Var(var->vid, struct_info, var->span); return new_var; } } else { UpdateStructInfo(var, struct_info); return var; } } } // namespace relax } // namespace tvm