# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Union
import onnx
from onnx import ModelProto
from ..utils import logging
from .transformations_utils import (
_create_name_sharing_dict,
_deduplicate_gather_matmul,
_deduplicated_cross_model_initializers,
_find_duplicate_initializers,
_find_matching_initializers,
_get_all_inputs,
_get_onnx_opset,
_get_weights_to_tie,
_remove_redundant_initializers,
_replace_input_names,
_unify_onnx_outputs,
cast_int64_tensorproto_to_int32,
)
if TYPE_CHECKING:
    import torch.nn as nn

logger = logging.get_logger()
def remove_duplicate_weights(model: ModelProto, inplace: bool = False) -> ModelProto:
"""
    Finds and removes duplicate weights in a model by keeping only unique weights, and making the duplicates point
    to them.

    This function only removes duplicate weights that are exactly identical (e.g., transposed copies are not handled).

    Args:
        model (`onnx.ModelProto`): The model to remove duplicates from.
        inplace (`bool`, defaults to `False`): Whether to perform this transformation inplace.

    Returns:
        `onnx.ModelProto`: The model without duplicates.
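
    Example (illustrative sketch; `model.onnx` and the output path are placeholders for an existing ONNX export):
        >>> import onnx
        >>> from optimum.onnx.graph_transformations import remove_duplicate_weights
        >>> model = onnx.load("model.onnx")
        >>> deduplicated = remove_duplicate_weights(model, inplace=False)
        >>> onnx.save(deduplicated, "model_deduplicated.onnx")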
"""
if not inplace:
model = copy.deepcopy(model)
duplicates = _find_duplicate_initializers(models=[model])
name_sharing_dict = _create_name_sharing_dict(duplicates)
_replace_input_names(models=[model], name_sharing_dict=name_sharing_dict)
_remove_redundant_initializers(models=[model], name_sharing_dict=name_sharing_dict)
return model
def remove_duplicate_weights_from_tied_info(
onnx_model: ModelProto, torch_model: "nn.Module", tied_params: List[List[str]], save_path: str
):
"""
    Tries to remove potential duplicate ONNX initializers from the tied weight information given in `tied_params`.

    Args:
        onnx_model (`onnx.ModelProto`):
            The ONNX model for which to tie potentially duplicate initializers.
        torch_model (`nn.Module`):
            The PyTorch model corresponding to the ONNX one.
        tied_params (`List[List[str]]`):
            A list of groups of torch parameters that are tied, i.e. shared: within a group, the torch module
            shares the same pointer for all parameters.
        save_path (`str`):
            The path where the deduplicated ONNX model is saved.
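
    Example (illustrative sketch; the ONNX file path is a placeholder, while `gpt2` does tie its `lm_head` weight to
    its token embedding weight):
        >>> import onnx
        >>> from transformers import AutoModelForCausalLM
        >>> from optimum.onnx.graph_transformations import remove_duplicate_weights_from_tied_info
        >>> torch_model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> onnx_model = onnx.load("decoder_model.onnx")  # assumed to be an ONNX export of the same model
        >>> tied_params = [["transformer.wte.weight", "lm_head.weight"]]
        >>> remove_duplicate_weights_from_tied_info(
        ...     onnx_model, torch_model, tied_params, save_path="decoder_model.onnx"
        ... )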
"""
tied_params_with_op, tied_groups_to_tie, tied_groups_ignored = _get_weights_to_tie(tied_params, torch_model)
    if len(tied_groups_ignored) >= 1:
        logger.info(
            f"The groups of weights {tied_groups_ignored} will not be tied as they are either already tied or their tying is not implemented."
        )
initializer_name_to_idx = {}
for idx, initializer in enumerate(onnx_model.graph.initializer):
initializer_name_to_idx[initializer.name] = idx
tied_groups_map = _find_matching_initializers(tied_params_with_op, onnx_model, initializer_name_to_idx)
onnx_model = _deduplicate_gather_matmul(onnx_model, tied_groups_to_tie, tied_groups_map, initializer_name_to_idx)
check_and_save_model(onnx_model, save_path=save_path)
return onnx_model
def replace_atenops_to_gather(model: ModelProto) -> ModelProto:
"""
    Replaces broken ATenOp nodes back with Gather nodes.

    Args:
        model (`onnx.ModelProto`):
            The ONNX model to fix.

    Returns:
        `onnx.ModelProto`: The fixed ONNX model.
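
    Example (illustrative sketch; `model.onnx` is a placeholder for an export containing ATen fallback nodes):
        >>> import onnx
        >>> from optimum.onnx.graph_transformations import replace_atenops_to_gather
        >>> model = onnx.load("model.onnx")
        >>> fixed_model = replace_atenops_to_gather(model)
        >>> onnx.save(fixed_model, "model_fixed.onnx")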
"""
nodes = model.graph.node
for node in nodes:
if node.op_type in ["ATenOp", "ATen"]:
op_num = node.name.split("_")[-1]
new_node = onnx.helper.make_node(
"Gather",
name="Gather_" + op_num,
inputs=[node.input[0], node.input[1]],
outputs=node.output,
)
model.graph.node.remove(node)
model.graph.node.insert(int(op_num), new_node)
onnx.checker.check_model(model)
return model
def check_and_save_model(model: onnx.ModelProto, save_path: Optional[Union[str, Path]]):
    """
    Checks the ONNX model with the ONNX checker and, if `save_path` is provided, saves it to `save_path`,
    using external data if the model already used it or if `FORCE_ONNX_EXTERNAL_DATA=1` is set.
    """
    # We can check ModelProtos that are smaller than 2GB before saving them.
    # For larger models, we need to save them first and then check their save path.
    # https://github.com/onnx/onnx/blob/main/docs/PythonAPIOverview.md#checking-a-large-onnx-model-2gb
    if model.ByteSize() < onnx.checker.MAXIMUM_PROTOBUF:
        # For the try/except, refer to https://github.com/microsoft/onnxruntime/issues/14768
        try:
            onnx.checker.check_model(model)
        except Exception as e:
            if "No Op registered for" in str(e):
                pass
            else:
                raise e

    if save_path is None:
        # Nothing to save; large models cannot be checked without a save path anyway.
        return

    save_path = Path(save_path).as_posix()
    external_file_name = os.path.basename(save_path) + "_data"
    external_file_path = os.path.join(os.path.dirname(save_path), external_file_name)

    # Overwrite any previously saved model and its external data file.
    if save_path.endswith(".onnx") and os.path.isfile(save_path):
        os.remove(save_path)

    model_uses_external_data = False
    if os.path.isfile(external_file_path):
        model_uses_external_data = True
        os.remove(external_file_path)

    FORCE_ONNX_EXTERNAL_DATA = os.getenv("FORCE_ONNX_EXTERNAL_DATA", "0") == "1"
    onnx.save(
        model,
        save_path,
        save_as_external_data=model_uses_external_data or FORCE_ONNX_EXTERNAL_DATA,
        all_tensors_to_one_file=True,
        location=external_file_name,
        convert_attribute=True,
        size_threshold=1024 if not FORCE_ONNX_EXTERNAL_DATA else 100,
    )

    try:
        onnx.checker.check_model(save_path)
    except Exception as e:
        if "No Op registered for" in str(e):
            pass
        else:
            raise e
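

# Illustrative sketch of a `check_and_save_model` call (`merged_model` stands for any in-memory `onnx.ModelProto`;
# the path is a placeholder). Setting `FORCE_ONNX_EXTERNAL_DATA=1` in the environment forces initializers larger
# than 100 bytes into the external `decoder_model_merged.onnx_data` file, even for models smaller than 2GB:
#
#     check_and_save_model(merged_model, save_path="decoder_model_merged.onnx")
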
def merge_decoders(
decoder: Union[ModelProto, Path, str],
decoder_with_past: Union[ModelProto, Path, str],
graph_name: str = "merged",
producer_name: str = "optimum-onnx",
save_path: Optional[Union[str, Path]] = None,
strict: bool = True,
) -> ModelProto:
"""
    Fuses a decoder ONNX model and a decoder-with-past ONNX model into a single ONNX model with `If` control-flow logic.

    Args:
        decoder (`Union[ModelProto, Path, str]`):
            Decoder ONNX model.
        decoder_with_past (`Union[ModelProto, Path, str]`):
            Decoder with past ONNX model.
        graph_name (`str`, defaults to `"merged"`):
            Name of the parent graph (the graph of the control-flow node).
        producer_name (`str`, defaults to `"optimum-onnx"`):
            Graph producer name.
        save_path (`Optional[Union[str, Path]]`, defaults to `None`):
            The path where the merged ONNX model is saved. The model is saved only if a path is given.
        strict (`bool`, defaults to `True`):
            When `True`, the decoder and decoder_with_past are expected to have strictly the same number of outputs.
            When `False`, the decoder is allowed to have more outputs than decoder_with_past, in which case constant
            outputs are added to match the number of outputs.

    Returns:
        `~onnx.ModelProto`: The fused decoder ONNX model.
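
    Example (illustrative sketch; the file names are placeholders for an existing decoder export and its with-past variant):
        >>> from optimum.onnx.graph_transformations import merge_decoders
        >>> merged = merge_decoders(
        ...     "decoder_model.onnx",
        ...     "decoder_with_past_model.onnx",
        ...     save_path="decoder_model_merged.onnx",
        ... )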
"""
if isinstance(decoder, (str, Path)):
decoder = Path(decoder).as_posix()
decoder = onnx.load(decoder)
if isinstance(decoder_with_past, (str, Path)):
decoder_with_past = Path(decoder_with_past).as_posix()
decoder_with_past = onnx.load(decoder_with_past)
decoder_opset = _get_onnx_opset(decoder)
decoder_with_past_opset = _get_onnx_opset(decoder_with_past)
    if decoder_opset != decoder_with_past_opset:
        raise ValueError(
            f"Decoder's opset is {decoder_opset}, but decoder with past's opset is {decoder_with_past_opset}. Make sure both models use the same opset before merging."
        )
_unify_onnx_outputs(decoder, decoder_with_past, strict=strict)
all_inputs = _get_all_inputs([decoder, decoder_with_past])
    # Replace the axis name `sequence_length` of the attention_mask input by `attention_mask_sequence_length`.
    # This is because the `input_ids` and `attention_mask` inputs of the merged model may not always have the same
    # length on the 2nd axis: in the first pass, `input_ids` and `attention_mask` are indeed of the same length, but
    # in later passes `input_ids` is of length 1 while `attention_mask` is of length `past_sequence_length + 1`.
    for inp in all_inputs:
        if inp.name == "attention_mask":
            if inp.type.tensor_type.shape.dim[1].dim_param != "sequence_length":
                raise ValueError("Expected attention_mask second axis to be dynamic and named `sequence_length`.")
            inp.type.tensor_type.shape.dim[1].dim_param = "attention_mask_sequence_length"
deduplicated_initializers = _deduplicated_cross_model_initializers([decoder, decoder_with_past], suffix=graph_name)
    # Keep initializers of dim 0 (or of dim 1 with int32/int64 data, i.e. data types 6 and 7) in the subgraphs for
    # readability purposes, and also because ONNX Runtime breaks after optimization + merge if they are not kept there.
decoder_initializers = []
for initializer in decoder.graph.initializer:
if len(initializer.dims) == 0 or (len(initializer.dims) == 1 and initializer.data_type in [6, 7]):
decoder_initializers.append(initializer)
decoder_with_past_initializers = []
for initializer in decoder_with_past.graph.initializer:
if len(initializer.dims) == 0 or (len(initializer.dims) == 1 and initializer.data_type in [6, 7]):
decoder_with_past_initializers.append(initializer)
# Make subgraphs
no_past_branch = onnx.helper.make_graph(
nodes=decoder.graph.node,
name="no_past",
inputs=[],
outputs=decoder.graph.output,
initializer=decoder_initializers,
)
with_past_branch = onnx.helper.make_graph(
nodes=decoder_with_past.graph.node,
name="with_past",
inputs=[],
outputs=decoder_with_past.graph.output,
initializer=decoder_with_past_initializers,
)
    # Merge the two subgraphs with an `If` node.
use_cache_branch = onnx.helper.make_tensor_value_info(
name="use_cache_branch",
elem_type=onnx.TensorProto.BOOL,
shape=[1],
)
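    # At runtime, the `If` node executes the `with_past` branch when `use_cache_branch` is true and the `no_past`
    # branch otherwise. ONNX requires both branches to return the same number of outputs with matching types,
    # which `_unify_onnx_outputs` above takes care of.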
if_node = onnx.helper.make_node(
"If",
inputs=["use_cache_branch"],
outputs=[output.name for output in no_past_branch.output],
name="optimum::if",
then_branch=with_past_branch,
else_branch=no_past_branch,
)
merged_graph = onnx.helper.make_graph(
nodes=[if_node],
name=graph_name,
inputs=all_inputs + [use_cache_branch],
outputs=no_past_branch.output,
initializer=deduplicated_initializers,
)
    # Preserve the opset imports from both the decoder and decoder-with-past ONNX models
opset_imports = []
opset_domains = set()
for opset_import in list(decoder.opset_import) + list(decoder_with_past.opset_import):
if opset_import.domain not in opset_domains:
opset_imports.append(opset_import)
opset_domains.add(opset_import.domain)
# TODO: update IR version in the future.
merged_model = onnx.helper.make_model_gen_version(
merged_graph, producer_name=producer_name, opset_imports=opset_imports, ir_version=9
)
check_and_save_model(merged_model, save_path=save_path)
return merged_model
def cast_slice_nodes_inputs_to_int32(model: ModelProto) -> ModelProto:
"""
    Converts the inputs of `Slice` nodes from int64 to int32, casting the out-of-range values.

    The constant node inputs are stored in `model.graph.node`, and the only way to check which node
    they are consumed by is to iterate over the nodes and check `node.input` for a match.

    Note that constant inputs to nodes such as `Squeeze` and `Unsqueeze` cannot be converted to int32, as
    these operators explicitly expect int64 inputs according to the ONNX specification:
    https://github.com/onnx/onnx/blob/main/docs/Operators.md
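
    Example (illustrative sketch; `model.onnx` is a placeholder path):
        >>> import onnx
        >>> from optimum.onnx.graph_transformations import cast_slice_nodes_inputs_to_int32
        >>> model = onnx.load("model.onnx")
        >>> model = cast_slice_nodes_inputs_to_int32(model)
        >>> onnx.save(model, "model_int32_slices.onnx")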
"""
map_input_node = {}
map_node_inputs = {}
for node in model.graph.node:
for input_name in node.input:
map_input_node[input_name] = {"op_type": node.op_type, "node_name": node.name}
map_node_inputs[node.name] = node.input
for node in model.graph.node:
if (
node.op_type == "Constant"
and node.attribute[0].t.data_type == 7 # int64
and f"{node.name}_output_0" in map_input_node
and map_input_node[node.name + "_output_0"]["op_type"] == "Slice"
):
logger.debug(f"Converting {node.name} to int32")
# `Slice` node is homogeneous (requires parameters of same type), hence cast to int32 only if all of its inputs are constants
# refer to onnx/defs/schema.h
cast = all(
"Constant" in inp for inp in map_node_inputs[map_input_node[node.name + "_output_0"]["node_name"]][1:]
)
cast_int64_tensorproto_to_int32(node.attribute[0].t, cast=cast)
return model