in onnxruntime/core/optimizer/layer_norm_fusion.cc [364:609]
Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
  GraphViewer graph_viewer(graph);
  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
  std::vector<std::reference_wrapper<Node>> nodes_to_remove;
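
  // Match the simplified LayerNorm subgraph (no mean subtraction):
  //   X --> Pow --> ReduceMean --> Add(epsilon) --> Sqrt --> Div(X / sqrt(...)) --> [Cast] --> Mul(scale)
  // optionally with a Cast feeding both Pow and Div, and fuse it into a single
  // SimplifiedLayerNormalization node.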
  for (auto node_index : node_topology_list) {
    nodes_to_remove.clear();
    auto* p_pow = graph.GetNode(node_index);
    if (p_pow == nullptr)
      continue;  // we removed the node as part of an earlier fusion

    Node& pow_node = *p_pow;
    ORT_RETURN_IF_ERROR(Recurse(pow_node, modified, graph_level, logger));
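
    // The fusion is anchored on the Pow node: it must run on a compatible execution
    // provider, have exactly one consumer, and must not produce a graph output.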
    if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow_node, "Pow", {7, 12, 13}) ||
        !graph_utils::IsSupportedProvider(pow_node, GetCompatibleExecutionProviders()) ||
        !optimizer_utils::CheckOutputEdges(graph, pow_node, 1) ||
        graph.NodeProducesGraphOutput(pow_node) ||
        !IsSupportedDataType(pow_node)) {
      continue;
    }
    nodes_to_remove.push_back(pow_node);
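
    // Pow must feed a ReduceMean on the same provider with a single consumer.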
    const Node* p_reduce_mean = graph_utils::FirstChildByType(pow_node, "ReduceMean");
    if (p_reduce_mean == nullptr) {
      continue;
    }
    Node& reduce_mean_node = *graph.GetNode(p_reduce_mean->Index());
    if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13}) ||
        reduce_mean_node.GetExecutionProviderType() != pow_node.GetExecutionProviderType() ||
        !optimizer_utils::CheckOutputEdges(graph, reduce_mean_node, 1) ||
        !IsSupportedDataType(reduce_mean_node) ||
        reduce_mean_node.GetInputEdgesCount() == 0) {
      continue;
    }
    nodes_to_remove.push_back(reduce_mean_node);
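
    // ReduceMean --> Add (the Add supplies the epsilon constant).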
    const Node* p_add = graph_utils::FirstChildByType(reduce_mean_node, "Add");
    if (p_add == nullptr) {
      continue;
    }
    Node& add_node = *graph.GetNode(p_add->Index());
    if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7, 13, 14}) ||
        add_node.GetExecutionProviderType() != pow_node.GetExecutionProviderType() ||
        !optimizer_utils::CheckOutputEdges(graph, add_node, 1) ||
        !IsSupportedDataType(add_node)) {
      continue;
    }
    nodes_to_remove.push_back(add_node);
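
    // Add --> Sqrt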
    const Node* p_sqrt = graph_utils::FirstChildByType(add_node, "Sqrt");
    if (p_sqrt == nullptr) {
      continue;
    }
    Node& sqrt_node = *graph.GetNode(p_sqrt->Index());
    if (!graph_utils::IsSupportedOptypeVersionAndDomain(sqrt_node, "Sqrt", {6, 13}) ||
        sqrt_node.GetExecutionProviderType() != pow_node.GetExecutionProviderType() ||
        !optimizer_utils::CheckOutputEdges(graph, sqrt_node, 1) ||
        !IsSupportedDataType(sqrt_node) ||
        sqrt_node.GetInputEdgesCount() == 0) {
      continue;
    }
    nodes_to_remove.push_back(sqrt_node);
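
    // Sqrt --> Div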
    const Node* p_div = graph_utils::FirstChildByType(sqrt_node, "Div");
    if (p_div == nullptr) {
      continue;
    }
    Node& div_node = *graph.GetNode(p_div->Index());
    if (!graph_utils::IsSupportedOptypeVersionAndDomain(div_node, "Div", {7, 13, 14}) ||
        div_node.GetExecutionProviderType() != pow_node.GetExecutionProviderType() ||
        !optimizer_utils::CheckOutputEdges(graph, div_node, 1) ||
        !IsSupportedDataType(div_node)) {
      continue;
    }
    nodes_to_remove.push_back(div_node);
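
    // Div's first input must be the same NodeArg as Pow's input (the activation
    // being normalized); otherwise this is not the expected subgraph.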
    const NodeArg* p_div_input = div_node.MutableInputDefs()[0];
    const NodeArg* p_pow_input = pow_node.MutableInputDefs()[0];
    if (p_pow_input == nullptr || p_div_input == nullptr) {
      continue;
    }

    bool cast_1_present = false;
    int64_t cast_1_to_attr;
    // check if there are Casts as input to the Pow and Div
    if (p_div_input == p_pow_input) {
      const Node* p_pow_input_node = graph_utils::GetInputNode(pow_node, 0);
      if (allow_precision_change_ && p_pow_input_node != nullptr) {
        Node& pow_input_node = *graph.GetNode(p_pow_input_node->Index());
        // If input to Pow is a Cast, and the Cast has 2 consumers only (Pow, Div)
        if (graph_utils::IsSupportedOptypeVersionAndDomain(pow_input_node, "Cast", {9, 13}) &&
            pow_input_node.GetExecutionProviderType() == pow_node.GetExecutionProviderType() &&
            optimizer_utils::CheckOutputEdges(graph, pow_input_node, 2)) {
          // get the 'to' attribute of Cast
          int64_t pcast_to;
          const onnxruntime::NodeAttributes& pcast_attributes = pow_input_node.GetAttributes();
          NodeAttributes::const_iterator pcast_to_attr = pcast_attributes.find("to");
          if (pcast_to_attr != pcast_attributes.end()) {
            pcast_to = static_cast<int64_t>(pcast_to_attr->second.i());
          } else {
            continue;
          }
          cast_1_present = true;
          cast_1_to_attr = pcast_to;
        }  // end Cast check
      }    // end allow_precision_change_
    } else {
      continue;
    }
    // div --> mul or div --> cast --> mul
    Node* next_node = graph.GetNode(div_node.OutputNodesBegin()->Index());
    Node* p_cast_2 = nullptr;
    if (allow_precision_change_ &&
        graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Cast", {9, 13}) &&
        optimizer_utils::CheckOutputEdges(graph, *next_node, 1)) {
      p_cast_2 = next_node;
      next_node = graph.GetNode(p_cast_2->OutputNodesBegin()->Index());
      nodes_to_remove.push_back(*p_cast_2);
    }

    Node& mul_node = *next_node;
    if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) ||
        mul_node.GetExecutionProviderType() != pow_node.GetExecutionProviderType() ||
        !IsSupportedDataType(mul_node)) {
      continue;
    }
    nodes_to_remove.push_back(mul_node);
    // get axes attributes
    const onnxruntime::NodeAttributes& attributes = reduce_mean_node.GetAttributes();
    std::vector<int64_t> axes_values;
    if (attributes.find("axes") != attributes.end()) {
      axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
    }
    // Get the inputs for the new SimplifiedLayerNormalization node.
    // Scale and bias could be multi-dimensional; we only support that for training at the
    // moment because the SkipLayerNorm kernel, for example, depends on a single dim size.
    NodeArg* scale = nullptr;
    for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) {
      if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) ||
          graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) {
#ifdef ENABLE_TRAINING
        if (axes_values.empty() ||
            mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast<int>(axes_values.size())) {
          scale = mul_node.MutableInputDefs()[i];
        }
#else
        // Scale must be 1d.
        if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) {
          scale = mul_node.MutableInputDefs()[i];
        }
#endif
      }
    }
    if (scale == nullptr) {
      continue;
    }
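
    // Inputs to the fused node: the activation that feeds Pow (the Cast output when
    // cast_1 is present), followed by the scale.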
    std::vector<NodeArg*> layer_norm_input_defs{pow_node.MutableInputDefs()[0]};

    // There was a Cast at the input, so make sure the Cast's 'to' type is the same as
    // the type of the scale input. If not, add a Cast for the scale.
    if (allow_precision_change_ && cast_1_present) {
      // get the type of the activation input
      ONNX_NAMESPACE::TensorProto_DataType cast_1_type = gsl::narrow_cast<ONNX_NAMESPACE::TensorProto_DataType>(cast_1_to_attr);
      const ONNX_NAMESPACE::TypeProto* casted_type = DataTypeImpl::TensorTypeFromONNXEnum(cast_1_type)->GetTypeProto();
      // get the type of the scale input and compare it to the activation input type
      if (scale->Type() != nullptr &&
          DataTypeImpl::TypeFromProto(*scale->TypeAsProto()) != DataTypeImpl::TypeFromProto(*casted_type)) {
        std::string node_name = graph.GenerateNodeName("Cast_Scale");
        auto* casted_scale = &graph.GetOrCreateNodeArg(node_name, casted_type);
        std::vector<NodeArg*> input_defs = {scale};
        std::vector<NodeArg*> output_defs = {casted_scale};
        auto& cast_node = graph.AddNode(node_name, "Cast", "cast scale of layer norm", input_defs, output_defs);
        cast_node.AddAttribute("to", cast_1_to_attr);
        cast_node.SetExecutionProviderType(pow_node.GetExecutionProviderType());
        layer_norm_input_defs.push_back(casted_scale);
      } else {  // scale type is the same as the casted type
        layer_norm_input_defs.push_back(scale);
      }
    } else {  // cast_1 is not present or allow_precision_change_ is false
      layer_norm_input_defs.push_back(scale);
    }
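
    // Create the fused SimplifiedLayerNormalization node. Its outputs are wired up
    // below: either directly by FinalizeNodeFusion, or via an explicit output def
    // plus a trailing Cast when the matched subgraph ended with Div --> Cast --> Mul.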
    Node& layer_norm_node = graph.AddNode(graph.GenerateNodeName("SimplifiedLayerNormalization"),
                                          "SimplifiedLayerNormalization",
                                          "fused LayerNorm subgraphs ",
                                          layer_norm_input_defs,
                                          {}, {}, kOnnxDomain);
    // Get the constant "epsilon" from the Add node if available; otherwise the default value is used.
    const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, add_node.MutableInputDefs()[1]->Name());
    if (tensor_proto != nullptr &&
        tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
      Initializer initializer{*tensor_proto, graph.ModelPath()};
      layer_norm_node.AddAttribute("epsilon", initializer.data<float>()[0]);
    } else {
      layer_norm_node.AddAttribute("epsilon", DEFAULT_LAYERNORM_EPSILON);
    }

    // Assign the execution provider to this new node; it should be the same as the provider of the old nodes.
    layer_norm_node.SetExecutionProviderType(reduce_mean_node.GetExecutionProviderType());
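
    // If the subgraph ended with Div --> Cast --> Mul, re-create an equivalent Cast
    // after the fused node so its consumers still see the original output type.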
    if (allow_precision_change_ && p_cast_2 != nullptr) {
      ONNX_NAMESPACE::TensorProto_DataType cast_1_type = gsl::narrow_cast<ONNX_NAMESPACE::TensorProto_DataType>(cast_1_to_attr);
      const ONNX_NAMESPACE::TypeProto* casted_type = DataTypeImpl::TensorTypeFromONNXEnum(cast_1_type)->GetTypeProto();
      NodeArg* LN_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("layer_norm_out"), casted_type);
      layer_norm_node.MutableOutputDefs().push_back(LN_output);

      Node& cast_ln_node = graph.AddNode(graph.GenerateNodeName("Cast"),
                                         "Cast",
                                         "cast output of layer norm",
                                         {LN_output},
                                         {});
      auto cast_2_to_attr = p_cast_2->GetAttributes().find("to")->second.i();
      cast_ln_node.AddAttribute("to", cast_2_to_attr);
      cast_ln_node.SetExecutionProviderType(pow_node.GetExecutionProviderType());
      graph_utils::FinalizeNodeFusion(graph, nodes_to_remove, layer_norm_node, cast_ln_node);
    } else {
      // Move the input edges of the first node in the list (Pow) across to the layer_norm_node.
      // Move the output definitions and output edges of the last node in the list (Mul) to the layer_norm_node.
      // Remove all the other nodes.
      graph_utils::FinalizeNodeFusion(graph, nodes_to_remove, layer_norm_node);
    }
#ifdef ENABLE_TRAINING
    // Add one extra output def so there are two output defs, matching what the gradient builder expects.
    layer_norm_node.MutableOutputDefs().push_back(&graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("saved_inv_std_var"), nullptr));
#endif
    modified = true;
  }

  return Status::OK();
}