# tinynn/converter/operators/optimize.py
import copy
import functools
import itertools
import re
import typing
import warnings
import igraph as ig
import numpy as np
from tinynn.util.util import class_conditional, get_logger
from ..schemas.tflite.schema_generated import ActivationFunctionType, Padding
from . import tflite as tfl
from .base import FUSE_ACTIVATION_MAP, ExtendedOperator
from .graph import CommonGraph
log = get_logger(__name__, 'INFO')
class GraphOptimizer(object):
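    """Graph-level optimizer for the TFLite converter.

    Applies a series of fusion, folding and rewrite passes to a
    ``CommonGraph``. Most passes are gated by the optimization level
    (see the constants below) or by a feature flag through the
    ``class_conditional`` decorator, so they silently become no-ops
    when disabled.
    """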
graph: CommonGraph
fuse_tensor_count: int
fuse_attr_count: int
fuse_quant: bool
group_conv_rewrite: bool
tflite_micro_rewrite: bool
quantize_input_output_type: typing.Optional[str]
# Optimization levels
NO_OPTIMIZE: int = 0
FOLD_BUFFER: int = 1
FUSE_BN: int = 2
COMMON_OPTIMIZE: int = 3
BRANCH_OPTIMIZE: int = 4
BRANCH_OPTIMIZE_EXTENDED: int = 5
ALL_OPTIMIZE: int = 5
def __init__(
self,
graph: CommonGraph,
level: int,
fuse_quant: bool,
group_conv_rewrite: bool,
rewrite_quantizable: bool,
tflite_micro_rewrite: bool,
quantize_input_output_type: typing.Optional[str],
fuse_input_indices: typing.Optional[typing.List[int]] = None,
fuse_output_indices: typing.Optional[typing.List[int]] = None,
max_transpose_dims: int = -1,
bypass_elementwise_passthrough_constraint: bool = False,
group_tensors: bool = False,
conv_transpose_with_bias: bool = True,
hybrid_int16_lstm: bool = False,
) -> None:
self.graph = graph
self.fuse_tensor_count = 0
self.fuse_attr_count = 0
self.level = level
self.fuse_quant = fuse_quant
self.group_conv_rewrite = group_conv_rewrite
self.rewrite_quantizable = rewrite_quantizable
self.tflite_micro_rewrite = tflite_micro_rewrite
self.quantize_input_output_type = quantize_input_output_type
self.fuse_input_indices = fuse_input_indices
self.fuse_output_indices = fuse_output_indices
self.max_transpose_dims = max_transpose_dims
self.bypass_elementwise_passthrough_constraint = bypass_elementwise_passthrough_constraint
self.group_tensors = group_tensors
self.conv_transpose_with_bias = conv_transpose_with_bias
self.hybrid_int16_lstm = hybrid_int16_lstm
    def create_attr_tensor(
        self,
        tensor: tfl.Tensor,
        name: typing.Optional[str] = None,
        quantization: typing.Optional[tfl.QuantizationParameters] = None,
    ):
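        """Wraps a value as a new buffer-backed (constant) tensor, generating
        a unique ``fuse_attr*`` name when none is given."""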
if name is None:
if self.fuse_attr_count == 0:
name = 'fuse_attr'
else:
name = f'fuse_attr_{self.fuse_attr_count}'
self.fuse_attr_count += 1
return tfl.Tensor(tensor, name, has_buffer=True, quantization=quantization)
    def create_transform_tensor(
        self,
        tensor: tfl.Tensor,
        name: typing.Optional[str] = None,
        quantization: typing.Optional[tfl.QuantizationParameters] = None,
    ):
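        """Wraps a value as a new intermediate (non-constant) tensor,
        generating a unique ``fuse_transform*`` name when none is given."""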
if name is None:
if self.fuse_tensor_count == 0:
name = 'fuse_transform'
else:
name = f'fuse_transform_{self.fuse_tensor_count}'
self.fuse_tensor_count += 1
return tfl.Tensor(tensor, name, has_buffer=False, quantization=quantization)
@class_conditional(lambda self: self.level >= GraphOptimizer.FUSE_BN)
def fuse_conv_fc_bn(self):
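        """Folds CONV/FC + BATCH_NORM pairs into the conv/fc weights.

        With BN parameters (gamma, beta, mean, var) and s = gamma / sqrt(var + eps),
        the standard folding computed by ``fuse_bn_weight``/``fuse_bn_bias`` is::

            w' = w * s          # broadcast over output channels
            b' = (b - mean) * s + beta

        Transposed convs use the deconv variant of the weight folding. The BN
        node is removed and its output tensor is taken over by the conv/fc node.
        """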
# Find fusable ops
edges = self.graph.graph.es.select(functools.partial(is_bn_fusable_edge, graph_converter=self.graph.graph))
filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target], x) for x in edges)
remove_ids = []
actions = []
for conv, bn, tensor in filtered_pairs:
bn_activ = bn['op'].fusedActivationFunction
conv_activ = getattr(conv['op'], 'fusedActivationFunction', None)
if conv_activ is None and bn_activ != ActivationFunctionType.NONE:
continue
# Find out the output of the batch-norm nodes
new_output = bn['outputs'][0]
assert new_output in self.graph.tensor_map
# For each node that is next of a batch-norm node, we connect it with the conv node
self.graph.connect_next_tensors(bn, conv, new_output)
# Update graph, prepare to drop the output tensor of the conv node and use the output tensor of the
# batch-norm instead
conv['outputs'][0] = new_output
conv['op'].outputs[0] = self.graph.tensor_map[new_output]
self.graph.tensor_node_map[new_output] = conv['name']
tensor['name'] = bn['outputs'][0]
tensor['label'] = bn['outputs'][0]
if bn_activ != ActivationFunctionType.NONE and conv_activ == ActivationFunctionType.NONE:
conv['op'].fusedActivationFunction = bn_activ
# Collect the arguments of the conv and batch-norm nodes
weight = conv['op'].inputs[1]
bias = conv['op'].inputs[2] if len(conv['op'].inputs) > 2 else None
bn_w, bn_b, bn_mean, bn_var = bn['op'].inputs[1:]
bn_w, bn_b, bn_mean, bn_var = (
bn_w.tensor.copy(),
bn_b.tensor.copy(),
bn_mean.tensor.copy(),
bn_var.tensor.copy(),
)
activ_w = weight.tensor.copy()
activ_b = bias.tensor.copy() if bias is not None else None
eps = bn['op'].eps
# Fuse conv/fc and batch-norm
new_weight = fuse_bn_weight(
eps, bn_w, bn_var, activ_w, conv['node_type'] == ExtendedOperator.GENERIC_DECONV
)
new_bias = fuse_bn_bias(eps, bn_w, bn_var, bn_mean, bn_b, activ_b)
# New attribute tensors
new_w = self.create_attr_tensor(new_weight)
new_b = self.create_attr_tensor(new_bias)
# Collect the actions we should take here
# The reason that we don't do the actions here is because we are currently in the loop of vertices,
# the iterator will be invalidated once `replace_operator_input` is called
actions.append((self.graph.replace_operator_input, (conv, 1, new_w)))
if bias is not None:
actions.append((self.graph.replace_operator_input, (conv, 2, new_b)))
else:
actions.append((self.graph.append_operator_input, (conv, new_b)))
remove_ids.append(bn.index)
# Process actions
for func, args in actions:
func(*args)
# Delete batch-norm nodes
for id in remove_ids:
vertex = self.graph.graph.vs[id]
assert vertex['node_type'] == ExtendedOperator.BATCH_NORM
self.graph.graph.delete_vertices(remove_ids)
@class_conditional(lambda self: self.level >= GraphOptimizer.FUSE_BN)
def fuse_bn_conv(self):
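        """Folds BATCH_NORM + CONV pairs (BN first) into the conv weights.

        Since BN(x) = x * s + t with s = gamma / sqrt(var + eps) and
        t = beta - mean * s, the scale is folded into the conv weights along
        the input-channel axis and the shift into the bias. Quantized inputs
        and depthwise-style grouped convs are skipped (see ``_skip_pred``).
        """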
edges = self.graph.graph.es.select(functools.partial(is_rev_bn_fusable_edge, graph_converter=self.graph.graph))
filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]) for x in edges)
def _remove_last_pred(seq):
bn = seq[0]
conv = seq[1]
# Collect the arguments of the conv and batch-norm nodes
weight = conv['op'].inputs[1]
bias = conv['op'].inputs[2] if len(conv['op'].inputs) > 2 else None
bn_w, bn_b, bn_mean, bn_var = bn['op'].inputs[1:]
bn_w, bn_b, bn_mean, bn_var = (
bn_w.tensor.copy(),
bn_b.tensor.copy(),
bn_mean.tensor.copy(),
bn_var.tensor.copy(),
)
activ_w = weight.tensor.copy()
activ_b = bias.tensor.copy() if bias is not None else None
eps = bn['op'].eps
new_weight = fuse_rev_bn_weight(eps, bn_w, bn_var, activ_w)
new_bias = fuse_rev_bn_bias(eps, bn_w, bn_var, bn_mean, bn_b, activ_b, activ_w)
return False, (conv, bias, new_weight, new_bias)
def _remove_last_action(first_node, last_node, custom_data):
conv, bias, new_weight, new_bias = custom_data
new_w = self.create_attr_tensor(new_weight)
new_b = self.create_attr_tensor(new_bias)
actions = []
actions.append((self.graph.replace_operator_input, (conv, 1, new_w)))
if bias is not None:
actions.append((self.graph.replace_operator_input, (conv, 2, new_b)))
else:
actions.append((self.graph.append_operator_input, (conv, new_b)))
return actions
def _skip_pred(seq):
bn = seq[0]['op']
conv = seq[1]['op']
skip = bn.inputs[0].quantization is not None or (
conv.inputs[1].shape[1] == 1 and conv.inputs[1].shape[0] == conv.groups and conv.groups > 1
)
return skip
elinimate_sequences(
self.graph,
filtered_pairs,
True,
None,
_remove_last_pred,
_remove_last_action,
_skip_pred,
force_forward_input=True,
)
@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def fuse_activation(self):
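        """Fuses standalone activation ops into the preceding op's
        ``fusedActivationFunction`` attribute (via ``FUSE_ACTIVATION_MAP``)
        and removes the activation node. Deconvs are skipped when
        ``conv_transpose_with_bias`` is disabled."""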
# Find fusable ops
edges = self.graph.graph.es.select(functools.partial(is_activ_fusable_edge, graph_converter=self.graph.graph))
filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target], x) for x in edges)
remove_ids = []
for pre_activ, activ, tensor in filtered_pairs:
if not self.conv_transpose_with_bias and pre_activ['node_type'] == ExtendedOperator.GENERIC_DECONV:
continue
            # Find out the output of the activation nodes
            new_output = activ['outputs'][0]
            assert new_output in self.graph.tensor_map
            # For each node that is next of the activation node, we connect it with the previous node
            self.graph.connect_next_tensors(activ, pre_activ, new_output)
            # Update graph, prepare to drop the output tensor of the previous node and use the output tensor
            # of the activation instead
pre_activ['outputs'][0] = new_output
pre_activ['op'].outputs[0] = self.graph.tensor_map[new_output]
self.graph.tensor_node_map[new_output] = pre_activ['name']
tensor['name'] = activ['outputs'][0]
tensor['label'] = activ['outputs'][0]
# Fuse activation
pre_activ['op'].fusedActivationFunction = FUSE_ACTIVATION_MAP[activ['node_type']]
remove_ids.append(activ.index)
# Delete activation nodes
self.graph.graph.delete_vertices(remove_ids)
@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def fuse_same_padding(self):
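        """Removes an explicit PAD before MAX_POOL_2D / CONV_2D /
        DEPTHWISE_CONV_2D / CONV_3D when the pad amounts exactly match what
        ``Padding.SAME`` would produce, and flips the consumer to SAME padding."""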
edges = self.graph.graph.es.select(functools.partial(is_padding_fusable_edge, graph_converter=self.graph.graph))
filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]) for x in edges)
def _remove_last_pred(seq):
op = seq[1]['op']
return False, op
def _remove_last_action(first_node, last_node, custom_data):
op = custom_data
op.padding = Padding.SAME
return []
def _skip_pred(seq):
pad_op = seq[0]['op']
next_op = seq[1]['op']
input_shape = pad_op.inputs[0].shape[1:-1]
if seq[1]['node_type'] == ExtendedOperator.MAX_POOL_2D:
kernel_shape = (next_op.filterHeight, next_op.filterWidth)
strides = (next_op.strideH, next_op.strideW)
dilation = (1, 1)
elif seq[1]['node_type'] in (
ExtendedOperator.CONV_2D,
ExtendedOperator.DEPTHWISE_CONV_2D,
):
kernel_shape = next_op.inputs[1].shape[1:-1]
strides = (next_op.strideH, next_op.strideW)
dilation = (next_op.dilationHFactor, next_op.dilationWFactor)
elif seq[1]['node_type'] == ExtendedOperator.CONV_3D:
kernel_shape = next_op.inputs[1].shape[:3]
strides = (next_op.strideD, next_op.strideH, next_op.strideW)
dilation = (next_op.dilationDFactor, next_op.dilationHFactor, next_op.dilationWFactor)
pad_args = get_same_padding_args(input_shape, kernel_shape, strides, dilation)
pad_arr = np.array(pad_args, dtype='int32')
old_pad_arr = pad_op.inputs[1].tensor
skip = not np.array_equal(pad_arr, old_pad_arr)
return skip
elinimate_sequences(
self.graph,
filtered_pairs,
True,
None,
_remove_last_pred,
_remove_last_action,
_skip_pred,
force_forward_input=True,
)
@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def fuse_same_padding_slicing(self):
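        """The transposed-conv counterpart of ``fuse_same_padding``: removes a
        SLICE after TRANSPOSE_CONV / CONV_3D_TRANSPOSE when it crops exactly
        the SAME-padding surplus, switching the conv to ``Padding.SAME`` and
        updating its output-shape input."""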
edges = self.graph.graph.es.select(functools.partial(is_slicing_fusable_edge, graph_converter=self.graph.graph))
filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target], x) for x in edges)
remove_ids = []
actions = []
for prev_node, slice_node, tensor in filtered_pairs:
prev_op = prev_node['op']
slice_op = slice_node['op']
input_shape = slice_op.outputs[0].shape[1:-1]
if prev_node['node_type'] == ExtendedOperator.TRANSPOSE_CONV:
kernel_shape = prev_op.inputs[1].shape[1:-1]
strides = (prev_op.strideH, prev_op.strideW)
dilation = (1, 1)
elif prev_node['node_type'] == ExtendedOperator.CONV_3D_TRANSPOSE:
kernel_shape = prev_op.inputs[1].shape[:3]
strides = (prev_op.strideD, prev_op.strideH, prev_op.strideW)
dilation = (prev_op.dilationDFactor, prev_op.dilationHFactor, prev_op.dilationWFactor)
pad_args = get_same_padding_args(input_shape, kernel_shape, strides, dilation)
pad_arr = np.array(pad_args, dtype='int32')
start_arr = [x for x in slice_op.inputs[1].tensor]
end_arr = [slice_op.inputs[0].shape[i] - x - slice_op.outputs[0].shape[i] for i, x in enumerate(start_arr)]
old_pad_args = [[x, y] for x, y in zip(start_arr, end_arr)]
skip = not np.array_equal(pad_arr, old_pad_args)
if skip:
continue
# Find out the output of the slice nodes
new_output = slice_node['outputs'][0]
assert new_output in self.graph.tensor_map
            # For each node that is next of the slice node, we connect it with the previous node
self.graph.connect_next_tensors(slice_node, prev_node, new_output)
# Update graph, prepare to drop the output tensor of the conv node and use the output tensor of the
# slice op instead
prev_node['outputs'][0] = new_output
prev_node['op'].outputs[0] = self.graph.tensor_map[new_output]
self.graph.tensor_node_map[new_output] = prev_node['name']
tensor['name'] = slice_node['outputs'][0]
tensor['label'] = slice_node['outputs'][0]
# Fuse padding
prev_node['op'].padding = Padding.SAME
new_shape = np.array(prev_node['op'].outputs[0].shape, dtype='int32')
new_shape_tensor = self.create_attr_tensor(new_shape)
actions.append((self.graph.replace_operator_input, (prev_node, 0, new_shape_tensor)))
remove_ids.append(slice_node.index)
for func, args in actions:
func(*args)
        # Delete slice nodes
self.graph.graph.delete_vertices(remove_ids)
@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def fuse_requantize(self):
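        """Fuses a requantizing QUANTIZE into its producer by rewriting the
        producer's output quantization in place. When the producer has several
        consumers, the fusion only happens if every consumer path requantizes
        to identical parameters."""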
# Find fusable ops
edges = self.graph.graph.es.select(
functools.partial(is_requantize_fusable_edge, graph_converter=self.graph.graph)
)
filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target], x) for x in edges)
remove_ids = []
for pre_activ, activ, tensor in filtered_pairs:
if pre_activ.outdegree() > 1:
skip = False
pre_quantize = None
for out_edge in pre_activ.out_edges():
next_node = self.graph.graph.vs[out_edge.target]
while True:
if next_node['node_type'] == ExtendedOperator.QUANTIZE:
if pre_quantize is None:
pre_quantize = next_node['op'].outputs[0].quantization
else:
cur_quantize = next_node['op'].outputs[0].quantization
if (
pre_quantize.scale != cur_quantize.scale
or pre_quantize.zero_point != cur_quantize.zero_point
or pre_quantize.dim != cur_quantize.dim
):
skip = True
break
elif next_node['node_type'] == ExtendedOperator.DEQUANTIZE:
break
elif next_node['node_type'] in (ExtendedOperator.RESHAPE, ExtendedOperator.TRANSPOSE):
if next_node.outdegree() > 1:
skip = True
break
else:
next_node = self.graph.graph.vs[next_node.out_edges()[0].target]
else:
skip = True
break
if skip:
break
if skip:
continue
# Find out the output of the first node in the sequence
output_name = activ['op'].inputs[0].name
output_idx = pre_activ['outputs'].index(output_name)
new_output = pre_activ['outputs'][output_idx]
assert new_output in self.graph.tensor_map
# For each node that is next of the last node, we connect it with the first node
                # Also, replace the tensors when needed
self.graph.replace_next_tensors(activ, pre_activ, new_output)
new_tensor = pre_activ['op'].outputs[0]
old_tensor = activ['op'].outputs[0]
new_tensor.quantization = old_tensor.quantization
else:
                # Find out the output of the requantize node
                new_output = activ['outputs'][0]
                assert new_output in self.graph.tensor_map
                # For each node that is next of the requantize node, we connect it with the previous node
                self.graph.connect_next_tensors(activ, pre_activ, new_output)
                # Update graph, prepare to drop the output tensor of the previous node and use the output
                # tensor of the requantize node instead
pre_activ['outputs'][0] = new_output
pre_activ['op'].outputs[0] = self.graph.tensor_map[new_output]
self.graph.tensor_node_map[new_output] = pre_activ['name']
tensor['name'] = activ['outputs'][0]
tensor['label'] = activ['outputs'][0]
remove_ids.append(activ.index)
        # Delete requantize nodes
self.graph.graph.delete_vertices(remove_ids)
@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def fuse_reciprocal_sqrt(self):
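        """Rewrites SQRT followed by DIV. The SQRT becomes RSQRT; when the DIV
        computes ``1.0 / sqrt(x)`` (constant all-ones numerator, no fused
        activation) the DIV is removed outright, otherwise ``a / sqrt(x)`` is
        rewritten as ``a * rsqrt(x)`` by turning the DIV into a MUL."""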
# Find fusable ops
edges = self.graph.graph.es.select(functools.partial(is_reciprocal_sqrt_edge, graph_converter=self.graph.graph))
filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target], x) for x in edges)
remove_ids = []
for sqrt, div, tensor in filtered_pairs:
sqrt['node_type'] = ExtendedOperator.RSQRT
sqrt['op'] = tfl.RsqrtOperator(sqrt['op'].inputs, sqrt['op'].outputs)
div_op = div['op']
if (
div_op.inputs[0].buffer is not None
and np.all(div_op.inputs[0].tensor == 1.0)
and div['op'].fusedActivationFunction == ActivationFunctionType.NONE
):
new_output = div['outputs'][0]
assert new_output in self.graph.tensor_map
# For each node that is next of the div node, we connect it with the previous node
self.graph.connect_next_tensors(div, sqrt, new_output)
# Update graph, prepare to drop the output tensor of the div node and use the output tensor of the
# sqrt instead
sqrt['outputs'][0] = new_output
sqrt['op'].outputs[0] = self.graph.tensor_map[new_output]
self.graph.tensor_node_map[new_output] = sqrt['name']
tensor['name'] = div['outputs'][0]
tensor['label'] = div['outputs'][0]
# remove div op
remove_ids.append(div.index)
else:
div['node_type'] = ExtendedOperator.MUL
div['op'] = tfl.MulOperator(
div['op'].inputs, div['op'].outputs, fusedActivationFunction=div['op'].fusedActivationFunction
)
# Delete div nodes
self.graph.graph.delete_vertices(remove_ids)
@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def remove_tile_before_binary_elementwise_ops(self):
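        """Drops a TILE feeding a binary elementwise op when implicit
        broadcasting of the un-tiled input produces the same output shape."""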
# Find fusable ops
edges = self.graph.graph.es.select(functools.partial(is_tile_binary_op_edge, graph_converter=self.graph.graph))
filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target], x) for x in edges)
remove_ids = []
actions = []
binary_op_ids = set()
for tile, op_node, tensor in filtered_pairs:
tile_op = tile['op']
binary_op = op_node['op']
input_idx = None
for i in range(2):
try:
_ = tile['outputs'].index(binary_op.inputs[i].name)
input_idx = i
break
except ValueError:
pass
if input_idx is None:
continue
alter_input_idx = 1 - input_idx
try:
out_shape = np.broadcast_shapes(binary_op.inputs[alter_input_idx].shape, tile_op.inputs[0].shape)
if out_shape != binary_op.outputs[0].shape:
continue
except ValueError:
continue
if op_node.index not in binary_op_ids:
binary_op_ids.add(op_node.index)
else:
continue
new_tensor = tile_op.inputs[0]
# Replace input tensors
actions.append((self.graph.replace_operator_input, (op_node, input_idx, new_tensor)))
# remove tile op
remove_ids.append(tile.index)
# Process actions
for func, args in actions:
func(*args)
# Delete tile nodes
self.graph.graph.delete_vertices(remove_ids)
@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def fuse_conv2d_gather(self):
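        """Folds a GATHER on the output-channel axis after a conv into the
        conv itself by permuting the weights (axis 0), the bias and any
        per-channel quantization parameters with the gather indices."""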
# Find fusable ops
edges = self.graph.graph.es.select(functools.partial(is_conv2d_gather_edge, graph_converter=self.graph.graph))
filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target], x) for x in edges)
remove_ids = []
actions = []
for conv, gather, tensor in filtered_pairs:
            # Find out the output of the gather nodes
            new_output = gather['outputs'][0]
            assert new_output in self.graph.tensor_map
            # For each node that is next of the gather node, we connect it with the previous node
            self.graph.connect_next_tensors(gather, conv, new_output)
            # Update graph, prepare to drop the output tensor of the conv node and use the output tensor of the
            # gather instead
conv['outputs'][0] = new_output
conv_out_quant_param = conv['op'].outputs[0].quantization
conv['op'].outputs[0] = self.graph.tensor_map[new_output]
conv['op'].outputs[0].quantization = conv_out_quant_param
self.graph.tensor_node_map[new_output] = conv['name']
tensor['name'] = gather['outputs'][0]
tensor['label'] = gather['outputs'][0]
# permute weight of conv-op
indx = gather['op'].inputs[1].tensor.copy()
w = conv['op'].inputs[1].tensor.copy()
w_quant_param = conv['op'].inputs[1].quantization
new_w = np.take(w, indx, axis=0)
# permute bias of conv-op
            b = conv['op'].inputs[2].tensor.copy() if len(conv['op'].inputs) > 2 else None
            b_quant_param = conv['op'].inputs[2].quantization if len(conv['op'].inputs) > 2 else None
new_b = np.take(b, indx, axis=0) if b is not None else None
if w_quant_param is not None and isinstance(w_quant_param.scale, list) and w_quant_param.dim == 0:
new_w_scale = np.take(w_quant_param.scale, indx, axis=0)
new_w_zeros = np.take(w_quant_param.zero_point, indx, axis=0)
w_quant_param.scale = new_w_scale
w_quant_param.zero_point = new_w_zeros
if new_b is not None:
new_b_scale = np.take(b_quant_param.scale, indx, axis=0)
new_b_zeros = np.take(b_quant_param.zero_point, indx, axis=0)
b_quant_param.scale = new_b_scale
b_quant_param.zero_point = new_b_zeros
new_w = self.create_attr_tensor(new_w, quantization=w_quant_param)
actions.append((self.graph.replace_operator_input, (conv, 1, new_w)))
            if new_b is not None:
                new_b = self.create_attr_tensor(new_b, quantization=b_quant_param)
                actions.append((self.graph.replace_operator_input, (conv, 2, new_b)))
# remove gather op
remove_ids.append(gather.index)
# Process actions
for func, args in actions:
func(*args)
        # Delete gather nodes
self.graph.graph.delete_vertices(remove_ids)
@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def fuse_gather_conv2d(self):
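        """Folds a channel GATHER in front of a conv into the conv by
        permuting the weights along the input-channel axis (axis 3) with
        ``np.argsort`` of the gather indices, i.e. the inverse permutation."""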
# Find fusable ops
edges = self.graph.graph.es.select(functools.partial(is_gather_conv2d_edge, graph_converter=self.graph.graph))
filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]) for x in edges)
def _remove_last_pred(seq):
gather = seq[0]
conv = seq[1]
return False, (gather, conv)
def _remove_last_action(first_node, last_node, custom_data):
gather, conv = custom_data
actions = []
indx = np.argsort(gather['op'].inputs[1].tensor)
w = conv['op'].inputs[1].tensor.copy()
w_quant_param = conv['op'].inputs[1].quantization
new_w = np.take(w, indx, axis=3)
if w_quant_param is not None and isinstance(w_quant_param.scale, list) and w_quant_param.dim == 3:
new_w_scale = np.take(w_quant_param.scale, indx, axis=0)
new_w_zeros = np.take(w_quant_param.zero_point, indx, axis=0)
w_quant_param.scale = new_w_scale
w_quant_param.zero_point = new_w_zeros
new_w = self.create_attr_tensor(new_w, quantization=w_quant_param)
actions.append((self.graph.replace_operator_input, (conv, 1, new_w)))
return actions
elinimate_sequences(
self.graph,
filtered_pairs,
True,
None,
_remove_last_pred,
_remove_last_action,
False,
force_forward_input=True,
)
@class_conditional(lambda self: self.tflite_micro_rewrite)
def split_requantize(self):
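        """For TFLite Micro: splits each requantizing QUANTIZE (quantized
        input, quantized output) into a DEQUANTIZE + QUANTIZE pair that goes
        through a float32 intermediate, preserving the original edges."""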
vertices = self.graph.graph.vs.select(functools.partial(is_requantize_node, graph_converter=self.graph.graph))
remove_ids = []
ops = []
restore_mapping = []
for quantize in vertices:
restore_nodes = []
            # For each node that is next of a requantize node,
            # a. if it is an output node, remove it anyway since it will always be reconstructed
            # b. otherwise, record the info of the edge so that we may restore it after reconstruction
for out_edge in quantize.out_edges():
next_node = self.graph.graph.vs[out_edge.target]
if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
remove_ids.append(next_node.index)
del self.graph.tensor_map[next_node['outputs'][0]]
del self.graph.tensor_node_map[next_node['outputs'][0]]
else:
restore_nodes.append((out_edge['name'], next_node['name']))
# Remove the mapping since they are going to be removed
for output_name in quantize['outputs']:
del self.graph.tensor_map[output_name]
del self.graph.tensor_node_map[output_name]
restore_mapping.append(restore_nodes)
remove_ids.append(quantize.index)
# Make sure the nodes are topologically sorted
sorted_ops = [node['op'] for node in sorted(vertices, key=lambda x: int(re.search(r'\d+', x['name'])[0]))]
# Delete nodes before transformation in the graph
self.graph.graph.delete_vertices(remove_ids)
for quantize, mapping in zip(sorted_ops, restore_mapping):
input_tensor = quantize.inputs[0]
output_tensor = quantize.outputs[0]
intermediate = self.create_transform_tensor(input_tensor.tensor.astype('float32'))
ops.append(tfl.DequantizeOperator([input_tensor], [intermediate]))
ops.append(tfl.QuantizeOperator([intermediate], [output_tensor]))
for op in ops:
self.graph.add_operator(op, transform=True)
self.graph.try_restore_edges(mapping)
def transform_graph(self):
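        """Detaches every transformable node from the graph, recording its
        outgoing edges, then calls ``op.transform`` on each node in
        topological order so it can rebuild itself and restore the edges."""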
# Find transformable ops
filtered_nodes = self.graph.graph.vs.select(
functools.partial(is_transformable_node, graph_converter=self.graph.graph)
)
remove_ids = []
ops = []
restore_mapping = []
for node in filtered_nodes:
restore_nodes = []
# For each node that is next of a transformable node,
# a. if it is an output node, remove it anyway since it will always be reconstructed
# b. otherwise, record the info of the edge so that we may restore it after reconstruction
for out_edge in node.out_edges():
next_node = self.graph.graph.vs[out_edge.target]
if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
remove_ids.append(next_node.index)
del self.graph.tensor_map[next_node['outputs'][0]]
del self.graph.tensor_node_map[next_node['outputs'][0]]
else:
restore_nodes.append((out_edge['name'], next_node['name']))
# Remove the mapping since they are going to be removed
for output_name in node['outputs']:
del self.graph.tensor_map[output_name]
del self.graph.tensor_node_map[output_name]
restore_mapping.append(restore_nodes)
ops.append(node)
remove_ids.append(node.index)
# Make sure the nodes are topologically sorted
sorted_ops = [node['op'] for node in sorted(ops, key=lambda x: int(re.search(r'\d+', x['name'])[0]))]
# Delete nodes before transformation in the graph
self.graph.graph.delete_vertices(remove_ids)
# Do transformation
for op, mapping in zip(sorted_ops, restore_mapping):
op.transform(self.graph, mapping)
@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def fuse_simple_transpose_pass(self):
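        """Merges chains of consecutive TRANSPOSE ops into one by composing
        the permutations: transpose(transpose(x, p1), p2) == transpose(x, p1[p2]).
        If the composed permutation is the identity, the whole chain is removed."""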
edges = self.graph.graph.es.select(
functools.partial(is_transpose_fusable_edge, graph_converter=self.graph.graph)
)
filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in edges]
# Try to fuse the edges
filtered_pairs = fuse_connected_edges(filtered_pairs)
def _remove_first_pred(seq):
new_perm = fuse_transpose_perms(seq)
hints = set()
for node in seq:
if 'direction' in node['op'].extra_hints:
hints.add(node['op'].extra_hints['direction'])
if len(hints) == 1:
hint = next(iter(hints))
else:
hint = None
remove_first = np.array_equal(new_perm, np.sort(new_perm))
return remove_first, (new_perm, hint)
def _remove_first_action(first_node, last_node, custom_data):
# Set fused perm to the first transpose node
new_perm, hint = custom_data
if hint is None:
if 'direction' in first_node['op'].extra_hints:
del first_node['op'].extra_hints['direction']
else:
first_node['op'].extra_hints['direction'] = hint
new_perm_tensor = self.create_attr_tensor(new_perm)
action = (self.graph.replace_operator_input, (first_node, 1, new_perm_tensor))
return [action]
elinimate_sequences(self.graph, filtered_pairs, _remove_first_pred, _remove_first_action)
@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def fuse_simple_gather_pass(self):
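        """Merges chains of consecutive GATHER ops with constant 1-D index
        tensors by composing the indices just like transpose permutations:
        gather(gather(x, i1), i2) == gather(x, i1[i2]). An identity
        composition removes the chain entirely."""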
edges = self.graph.graph.es.select(functools.partial(is_gather_fusable_edge, graph_converter=self.graph.graph))
filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in edges]
# Try to fuse the edges
filtered_pairs = fuse_connected_edges(filtered_pairs)
def _remove_first_pred(seq):
new_perm = fuse_transpose_perms(seq)
hints = set()
for node in seq:
if 'direction' in node['op'].extra_hints:
hints.add(node['op'].extra_hints['direction'])
if len(hints) == 1:
hint = next(iter(hints))
else:
hint = None
remove_first = np.array_equal(new_perm, np.sort(new_perm))
return remove_first, (new_perm, hint)
def _remove_first_action(first_node, last_node, custom_data):
            # Set the fused indices to the first gather node
new_perm, hint = custom_data
if hint is None:
if 'direction' in first_node['op'].extra_hints:
del first_node['op'].extra_hints['direction']
else:
first_node['op'].extra_hints['direction'] = hint
new_perm_tensor = self.create_attr_tensor(new_perm)
action = (self.graph.replace_operator_input, (first_node, 1, new_perm_tensor))
return [action]
def _skip_pred(seq):
for node in seq:
op = node['op']
idx_tensor = op.inputs[1]
if idx_tensor.buffer is None:
return True
if len(idx_tensor.shape) > 1:
return True
return False
elinimate_sequences(self.graph, filtered_pairs, _remove_first_pred, _remove_first_action, skip_pred=_skip_pred)
@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def fuse_dequant_quant_pass(self, q_first):
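        """Collapses DEQUANTIZE/QUANTIZE chains. A DEQUANTIZE -> QUANTIZE pair
        with identical quantization parameters at both ends is removed
        entirely, otherwise it is replaced by a single requantizing QUANTIZE;
        QUANTIZE -> DEQUANTIZE round trips are always removed. ``q_first``
        selects which op the sequence is expected to start with."""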
edges = self.graph.graph.es.select(
functools.partial(is_dequant_quant_fusable_edge, graph_converter=self.graph.graph, q_first=q_first)
)
filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in edges]
r_edges = self.graph.graph.es.select(
functools.partial(is_dequant_quant_fusable_edge, graph_converter=self.graph.graph, q_first=not q_first)
)
r_filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in r_edges]
filtered_pairs = fuse_connected_edges(filtered_pairs + r_filtered_pairs)
new_pairs = []
for seq in filtered_pairs:
start_idx = 0
end_idx = len(seq)
if q_first:
if seq[0]['node_type'] != ExtendedOperator.QUANTIZE:
start_idx += 1
if seq[-1]['node_type'] != ExtendedOperator.DEQUANTIZE:
end_idx -= 1
else:
if seq[0]['node_type'] != ExtendedOperator.DEQUANTIZE:
start_idx += 1
if seq[-1]['node_type'] != ExtendedOperator.QUANTIZE:
end_idx -= 1
new_seq = seq[start_idx:end_idx]
if len(new_seq) >= 2:
new_pairs.append(new_seq)
filtered_pairs = new_pairs
def _remove_first_pred(seq):
first_node, last_node = seq[0], seq[-1]
new_qparams = last_node['op'].outputs[0].quantization
orig_qparams = first_node['op'].inputs[0].quantization
if (
first_node['node_type'] == ExtendedOperator.DEQUANTIZE
and last_node['node_type'] == ExtendedOperator.QUANTIZE
):
assert new_qparams is not None
assert orig_qparams is not None
remove_first = (
new_qparams.scale == orig_qparams.scale
and new_qparams.zero_point == orig_qparams.zero_point
and new_qparams.dim == orig_qparams.dim
)
else:
assert new_qparams is None
assert orig_qparams is None
remove_first = True
return remove_first, None
def _remove_first_action(first_node, last_node, custom_data):
# Set new node type to first node
first_node['node_type'] = ExtendedOperator.QUANTIZE
old_op = first_node['op']
first_node['op'] = tfl.QuantizeOperator(old_op.inputs, old_op.outputs)
return []
elinimate_sequences(self.graph, filtered_pairs, _remove_first_pred, _remove_first_action)
@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def fuse_simple_reshape_pass(self):
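        """Merges chains of consecutive RESHAPE ops into a single one with the
        final shape; if the final shape equals the original input shape, the
        chain is removed entirely."""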
edges = self.graph.graph.es.select(functools.partial(is_reshape_fusable_edge, graph_converter=self.graph.graph))
filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in edges]
# Try to fuse the edge
filtered_pairs = fuse_connected_edges(filtered_pairs)
def _remove_first_pred(seq):
first_node, last_node = seq[0], seq[-1]
new_shape = last_node['op'].inputs[1].tensor
orig_shape = np.array(first_node['op'].inputs[0].shape, dtype='int32')
hints = set()
for node in seq:
if 'direction' in node['op'].extra_hints:
hints.add(node['op'].extra_hints['direction'])
if len(hints) == 1:
hint = next(iter(hints))
else:
hint = None
remove_first = np.array_equal(new_shape, orig_shape)
return remove_first, (new_shape, hint)
def _remove_first_action(first_node, last_node, custom_data):
# Set final shape to the first reshape node
new_shape, hint = custom_data
if hint is None:
if 'direction' in first_node['op'].extra_hints:
del first_node['op'].extra_hints['direction']
else:
first_node['op'].extra_hints['direction'] = hint
new_shape_tensor = self.create_attr_tensor(np.array(new_shape, dtype='int32'))
first_node['op'].newShape = new_shape_tensor.tensor
action = (self.graph.replace_operator_input, (first_node, 1, new_shape_tensor))
return [action]
elinimate_sequences(self.graph, filtered_pairs, _remove_first_pred, _remove_first_action)
@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def fuse_simple_slice_pass(self):
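        """Merges chains of SLICE / STRIDED_SLICE ops into a single op by
        composing their start/end/stride parameters (``fuse_slices``). The
        result is a plain SLICE when every fused stride is 1, and a
        STRIDED_SLICE otherwise."""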
edges = self.graph.graph.es.select(functools.partial(is_slice_fusable_edge, graph_converter=self.graph.graph))
filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in edges]
# Try to fuse the edge
filtered_pairs = fuse_connected_edges(filtered_pairs)
def _remove_first_pred(seq):
fused_info = fuse_slices(seq)
return False, fused_info
def _remove_first_action(first_node, last_node, custom_data):
            # Set the fused slice parameters on the first slice node
start, end, stride = custom_data
if all((x == 1 for x in stride)):
target_class = tfl.SliceOperator
target_type = ExtendedOperator.SLICE
else:
target_class = tfl.StridedSliceOperator
target_type = ExtendedOperator.STRIDED_SLICE
if target_type == ExtendedOperator.SLICE:
size = end - start
start_tensor = self.create_attr_tensor(np.array(start, dtype='int32'))
size_tensor = self.create_attr_tensor(np.array(size, dtype='int32'))
actions = [
(self.graph.replace_operator_input, (first_node, 1, start_tensor)),
(self.graph.replace_operator_input, (first_node, 2, size_tensor)),
]
if first_node['node_type'] != ExtendedOperator.SLICE:
old_slice_op = first_node['op']
first_node['node_type'] = ExtendedOperator.SLICE
first_node['op'] = target_class(old_slice_op.inputs, old_slice_op.outputs)
actions.append((self.graph.remove_operator_input, (first_node, 3)))
else:
start_tensor = self.create_attr_tensor(np.array(start, dtype='int32'))
end_tensor = self.create_attr_tensor(np.array(end, dtype='int32'))
stride_tensor = self.create_attr_tensor(np.array(stride, dtype='int32'))
if first_node['node_type'] == ExtendedOperator.STRIDED_SLICE:
actions = [
(self.graph.replace_operator_input, (first_node, 1, start_tensor)),
(self.graph.replace_operator_input, (first_node, 2, end_tensor)),
(self.graph.replace_operator_input, (first_node, 3, stride_tensor)),
]
else:
old_slice_op = first_node['op']
first_node['node_type'] = ExtendedOperator.STRIDED_SLICE
first_node['op'] = target_class(old_slice_op.inputs, old_slice_op.outputs)
actions = [
(self.graph.replace_operator_input, (first_node, 1, start_tensor)),
(self.graph.replace_operator_input, (first_node, 2, end_tensor)),
(self.graph.append_operator_input, (first_node, stride_tensor)),
]
return actions
elinimate_sequences(self.graph, filtered_pairs, _remove_first_pred, _remove_first_action)
@class_conditional(lambda self: self.group_tensors)
def group_tensors_pass(self):
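        """Deduplicates constant tensors: constants with identical raw
        buffers, dtypes, shapes and quantization parameters are collapsed
        into a single shared tensor, and the memory saving is logged."""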
tensor_map = {}
actions = []
bytes_saved = 0
tensors_saved = 0
for v in self.graph.graph.vs:
if v['node_type'] == ExtendedOperator.CONSTANT_NODE:
tensor = self.graph.tensor_map[v['outputs'][0]]
if tensor.quantization is None:
t_idx = (tensor.buffer.data, tensor.dtype, tensor.shape)
else:
scale = tensor.quantization.scale
zero_point = tensor.quantization.zero_point
if isinstance(scale, list):
scale = tuple(scale)
if isinstance(zero_point, list):
zero_point = tuple(zero_point)
t_idx = (
tensor.buffer.data,
tensor.dtype,
tensor.shape,
scale,
zero_point,
tensor.quantization.dim,
)
if t_idx in tensor_map:
new_tensor = tensor_map[t_idx]
for e in v.out_edges():
target = e.target_vertex
if target['op'] is not None:
for i, inp in enumerate(target['op'].inputs):
if inp.name == tensor.name:
log.debug(f'{inp.name} used in {target["outputs"][0]}:{i} -> {new_tensor.name}')
tensors_saved += 1
bytes_saved += len(inp.buffer.data)
actions.append((self.graph.replace_operator_input, (target, i, new_tensor)))
else:
tensor_map[t_idx] = tensor
# Process actions
for func, args in actions:
func(*args)
log.info(f'{tensors_saved} duplicated tensors found, {bytes_saved / 1024 / 1024:.2f} MB saved')
def cleanup_dead_nodes(self):
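        """If the graph is not weakly connected, iteratively deletes nodes
        that have no consumers and are not input/output/unused nodes.
        Dropping a non-constant node triggers a warning since it usually
        indicates a malformed graph."""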
cleanup_nodes = []
if not self.graph.graph.is_connected('weak'):
while True:
for vertex in self.graph.graph.vs:
if (
vertex['node_type'] not in (ExtendedOperator.OUTPUT_NODE, ExtendedOperator.UNUSED_NODE)
and vertex.outdegree() == 0
):
if vertex['node_type'] == ExtendedOperator.INPUT_NODE:
continue
if vertex['node_type'] != ExtendedOperator.CONSTANT_NODE:
if vertex['op'] is None or vertex['op'].extra_hints.get('warn_on_unused', True):
warnings.warn('Non constant node removed, something must be wrong there')
log.warning('-' * 30)
log.warning('Info of the deleted node:')
log.warning(f'vertex: {vertex}')
# edge = self.graph.graph.es.select(name=vertex['outputs'][0])
# assert edge is None, (
# f'The edge {vertex["outputs"][0]} exists but the connection to the vertex'
# f' {vertex["name"]} is broken, probably there have some conflicts in the names'
# ' of the nodes'
# )
cleanup_nodes.append(vertex.index)
if len(cleanup_nodes) == 0:
break
self.graph.graph.delete_vertices(cleanup_nodes)
cleanup_nodes.clear()
@class_conditional(lambda self: self.level >= GraphOptimizer.FOLD_BUFFER)
def fold_transpose_buffer(self):
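        """Constant-folds TRANSPOSE ops with constant inputs: the permuted
        array becomes a new constant node wired directly to all consumers."""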
edges = self.graph.graph.es.select(
functools.partial(is_constant_transpose_fusable_edge, graph_converter=self.graph.graph)
)
filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target], x) for x in edges)
remove_ids = []
for constant, transpose, tensor in filtered_pairs:
# Calculate the output of the transposed constant nodes
constant_tensor = transpose['op'].inputs[0].tensor
perm_tensor = transpose['op'].inputs[1].tensor
new_constant = np.transpose(constant_tensor, perm_tensor)
new_tensor = self.create_attr_tensor(new_constant, quantization=transpose['op'].outputs[0].quantization)
new_node = self.graph.add_nodes([new_tensor])[0]
# For each node that is next of a constant transpose node, we connect it with the new constant node
for out_edge in transpose.out_edges():
next_node = self.graph.graph.vs[out_edge.target]
self.graph.graph.add_edge(new_node, next_node, name=new_tensor.name, label=new_tensor.name)
log.debug(
f'NEW EDGE: {new_node["label"]} -> {next_node["label"]} {self.graph.tensor_map[out_edge["name"]]}'
)
op = next_node['op']
for idx in range(len(op.inputs)):
if op.inputs[idx].name == transpose['op'].outputs[0].name:
op.inputs[idx] = new_tensor
remove_ids.append(transpose.index)
# Delete constant transpose nodes
self.graph.graph.delete_vertices(remove_ids)
@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def transpose_to_reshape_pass(self):
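        """Rewrites TRANSPOSE nodes that do not change the underlying element
        order (as selected by ``is_transformable_transpose_node``, e.g.
        permutations that only move size-1 axes) into equivalent RESHAPE ops."""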
filtered_nodes = self.graph.graph.vs.select(
functools.partial(is_transformable_transpose_node, graph_converter=self.graph.graph)
)
# Collect actions for the transformable transpose nodes
actions = []
for node in filtered_nodes:
original_op = node['op']
output_shape = np.array(original_op.outputs[0].shape, dtype='int32')
shape_tensor = self.create_attr_tensor(output_shape)
new_op = tfl.ReshapeOperator(original_op.inputs, original_op.outputs, output_shape)
node['op'] = new_op
node['node_type'] = ExtendedOperator.RESHAPE
node['label'] = new_op.type_name()
actions.append((self.graph.replace_operator_input, (node, 1, shape_tensor)))
# Process actions
for func, args in actions:
node = args[0]
func(*args)
@class_conditional(lambda self: self.level >= GraphOptimizer.FOLD_BUFFER)
def fold_reshape_buffer(self):
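        """Constant-folds RESHAPE ops with constant inputs, mirroring
        ``fold_transpose_buffer``."""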
edges = self.graph.graph.es.select(
functools.partial(is_constant_reshape_fusable_edge, graph_converter=self.graph.graph)
)
filtered_pairs = ((self.graph.graph.vs[x.source], self.graph.graph.vs[x.target], x) for x in edges)
remove_ids = []
for constant, reshape, tensor in filtered_pairs:
            # Calculate the output of the reshaped constant nodes
constant_tensor = reshape['op'].inputs[0].tensor
shape_tensor = reshape['op'].inputs[1].tensor
new_constant = np.reshape(constant_tensor, shape_tensor)
new_tensor = self.create_attr_tensor(new_constant, quantization=reshape['op'].inputs[0].quantization)
new_node = self.graph.add_nodes([new_tensor])[0]
            # For each node that is next of a constant reshape node, we connect it with the new constant node
for out_edge in reshape.out_edges():
next_node = self.graph.graph.vs[out_edge.target]
self.graph.graph.add_edge(new_node, next_node, name=new_tensor.name, label=new_tensor.name)
log.debug(
f'NEW EDGE: {new_node["label"]} -> {next_node["label"]} {self.graph.tensor_map[out_edge["name"]]}'
)
op = next_node['op']
for idx in range(len(op.inputs)):
if op.inputs[idx].name == reshape['op'].outputs[0].name:
op.inputs[idx] = new_tensor
remove_ids.append(reshape.index)
        # Delete constant reshape nodes
self.graph.graph.delete_vertices(remove_ids)
@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def remove_noop_pass(self, branch: bool = False):
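        """Eliminates sequences ending with no-op operators (as selected by
        ``is_ending_with_noop_edge``); with ``branch=True`` the edges are not
        chained first, so branching no-ops are handled one by one."""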
edges = self.graph.graph.es.select(
functools.partial(is_ending_with_noop_edge, graph_converter=self.graph.graph, branch=branch)
)
filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in edges]
# Try to fuse the edges
if not branch:
filtered_pairs = fuse_connected_edges(filtered_pairs)
elinimate_sequences(self.graph, filtered_pairs)
@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def fuse_wrapped_reshape_within_transpose_pass(self):
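        """Collapses TRANSPOSE -> RESHAPE -> TRANSPOSE into a single RESHAPE
        when the middle reshape is simple (``is_simple_reshape``) and the net
        permutation of the two transposes is the identity."""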
edges = self.graph.graph.es.select(
functools.partial(is_wrapped_reshape_within_transpose_edge, graph_converter=self.graph.graph)
)
filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in edges]
# Try to fuse the edges
fused_pairs = fuse_connected_edges(filtered_pairs)
# Only TRANSPOSE->RESHAPE->TRANSPOSE is supported here
filtered_pairs = []
for seq in fused_pairs:
seq_len = len(seq)
transpose_first = seq[0]['node_type'] == ExtendedOperator.TRANSPOSE
if seq_len >= 3 and transpose_first:
filtered_pairs.append(seq[:3])
elif seq_len >= 4:
filtered_pairs.append(seq[1:4])
def _skip_pred(seq):
mid_node = seq[1]
orig_shape = mid_node['op'].inputs[0].shape
new_shape = mid_node['op'].outputs[0].shape
if not is_simple_reshape(orig_shape, new_shape):
return True
new_perm = fuse_transpose_perms_extended(seq)
return (new_perm != np.sort(new_perm)).any()
def _remove_last_pred(seq):
orig_tensor = seq[0]['op'].inputs[0].tensor
return False, (seq[2], orig_tensor)
def _remove_last_action(first_node, last_node, custom_data):
            # Convert the last transpose into a reshape with the final shape
last_trans, orig_tensor = custom_data
actions = []
original_op = last_trans['op']
output_shape = np.array(original_op.outputs[0].shape, dtype='int32')
shape_tensor = self.create_attr_tensor(output_shape)
new_op = tfl.ReshapeOperator(original_op.inputs, original_op.outputs, output_shape)
last_trans['op'] = new_op
last_trans['node_type'] = ExtendedOperator.RESHAPE
last_trans['label'] = new_op.type_name()
new_op.inputs[0].tensor = orig_tensor
new_op.inputs[0].shape = new_op.inputs[0].tensor.shape
actions.append((self.graph.replace_operator_input, (last_trans, 1, shape_tensor)))
return actions
elinimate_sequences(self.graph, filtered_pairs, True, None, _remove_last_pred, _remove_last_action, _skip_pred)
@class_conditional(lambda self: self.level >= GraphOptimizer.BRANCH_OPTIMIZE)
def branch_reshape_expand_pass(self):
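        """When a RESHAPE output fans out to several consumers, clones the
        RESHAPE once per consumer so each branch gets its own copy that later
        passes can sink or fuse independently."""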
edges = self.graph.graph.es.select(functools.partial(is_reshape_branch_edge, graph_converter=self.graph.graph))
branch_reshape_nodes = list(set(self.graph.graph.vs[edge.source] for edge in edges))
def _new_reshape(node: ig.Vertex, prev_node: ig.Vertex, next_node: ig.Vertex):
actions = []
op = node['op']
op_out = op.outputs[0]
op_shape = op.inputs[1]
prev_idx = prev_node['outputs'].index(op.inputs[0].name)
if prev_node['node_type'] == ExtendedOperator.INPUT_NODE:
prev_out = self.graph.tensor_map[op.inputs[0].name]
else:
prev_op = prev_node['op']
prev_out = prev_op.outputs[prev_idx]
new_tensor = self.create_transform_tensor(op_out.tensor.copy(), quantization=op_out.quantization)
new_shape = self.create_attr_tensor(op_shape.tensor.copy())
new_op = tfl.ReshapeOperator([prev_out, new_shape], [new_tensor], new_shape.tensor)
new_op.extra_hints.update(op.extra_hints)
self.graph.add_operator(new_op)
next_indices = []
for i, t in enumerate(next_node['op'].inputs):
if t.name == op_out.name:
actions.append((self.graph.replace_operator_input, (next_node, i, new_tensor)))
next_indices.append(i)
assert len(next_indices) > 0, f'{op_out.name} not in {[t.name for t in next_node["op"].inputs]}'
return actions
expand_op_outputs_in_branches(branch_reshape_nodes, _new_reshape, self.graph)
@class_conditional(lambda self: self.level >= GraphOptimizer.BRANCH_OPTIMIZE)
def branch_transpose_expand_pass(self):
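        """The TRANSPOSE counterpart of ``branch_reshape_expand_pass``: clones
        a fanned-out TRANSPOSE once per consumer."""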
edges = self.graph.graph.es.select(
functools.partial(is_transpose_branch_edge, graph_converter=self.graph.graph)
)
branch_transpose_nodes = list(set(self.graph.graph.vs[edge.source] for edge in edges))
def _new_transpose(node: ig.Vertex, prev_node: ig.Vertex, next_node: ig.Vertex):
actions = []
op = node['op']
op_out = op.outputs[0]
op_perm = op.inputs[1]
prev_idx = prev_node['outputs'].index(op.inputs[0].name)
if prev_node['node_type'] in (ExtendedOperator.INPUT_NODE, ExtendedOperator.CONSTANT_NODE):
prev_out = self.graph.tensor_map[op.inputs[0].name]
else:
prev_op = prev_node['op']
prev_out = prev_op.outputs[prev_idx]
new_tensor = self.create_transform_tensor(op_out.tensor.copy(), quantization=op_out.quantization)
new_perm = self.create_attr_tensor(op_perm.tensor.copy())
new_op = tfl.TransposeOperator([prev_out, new_perm], [new_tensor])
new_op.extra_hints.update(op.extra_hints)
self.graph.add_operator(new_op)
next_indices = []
for i, t in enumerate(next_node['op'].inputs):
if t.name == op_out.name:
actions.append((self.graph.replace_operator_input, (next_node, i, new_tensor)))
next_indices.append(i)
assert len(next_indices) > 0, f'{op_out.name} not in {[t.name for t in next_node["op"].inputs]}'
return actions
expand_op_outputs_in_branches(branch_transpose_nodes, _new_transpose, self.graph)
@class_conditional(lambda self: self.level >= GraphOptimizer.BRANCH_OPTIMIZE, 0)
def elementwise_reshape_transpose_passthrough_pass(self) -> int:
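        """Moves TRANSPOSE ops through an adjacent RESHAPE (upward or
        downward, depending on how the reshape maps axes) so neighbouring
        transposes can cancel later. A rewrite is only performed when the
        total number of transposes does not grow. Returns the number of
        rewrites, so callers can repeat the pass until it returns 0."""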
edges = self.graph.graph.es.select(
functools.partial(is_transpose_reshape_op_edge, graph_converter=self.graph.graph)
)
pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in edges)
filtered_nodes = (k[0] if k[0]['node_type'] != ExtendedOperator.TRANSPOSE else k[1] for k in pairs)
unique_nodes = list(set(filtered_nodes))
actions = []
remove_edges = []
remove_vertices = []
processed_nodes = set()
num_actions = 0
for node in unique_nodes:
pending_processed_nodes = set()
op = node['op']
input_indices = op_input_indices(op)
l_shape = op.inputs[0].shape
r_shape = op.outputs[0].shape
if len(l_shape) == 0 or len(r_shape) == 0:
continue
l_map, r_map, _, _ = reshape_mapping(l_shape, r_shape)
mode = None
need_chain = False
for l_val, r_val in zip(l_map, r_map):
if len(l_val) > 1 and len(r_val) == 1:
if mode in (None, 'up'):
mode = 'up'
else:
mode = '?'
break
elif len(r_val) > 1 and len(l_val) == 1:
if mode in (None, 'down'):
mode = 'down'
else:
mode = '?'
break
elif len(r_val) > 1 and len(l_val) > 1:
if len(r_val) != len(l_val) or r_val != l_val:
# TODO: Support this case
mode = '?'
break
else:
need_chain = True
if mode is None:
mode = 'down'
# TODO: Support multi-multi mappings
if mode == '?':
# reset hints if passthrough is not possible
for i in input_indices:
prev_node_name = op.inputs[i].name
prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
if prev_node['node_type'] == ExtendedOperator.TRANSPOSE:
if 'direction' in prev_node['op'].extra_hints:
prev_node['op'].extra_hints.pop('direction')
for edge in node.out_edges():
if edge.index in remove_edges:
continue
next_node = self.graph.graph.vs[edge.target]
if 'direction' in next_node['op'].extra_hints:
next_node['op'].extra_hints.pop('direction')
continue
check_consecutive_indices = []
if need_chain:
new_l_map = []
new_r_map = []
for l_val, r_val in zip(l_map, r_map):
if len(l_val) > 1 and len(r_val) > 1:
if mode == 'down':
check_consecutive_indices.append(l_val)
else:
check_consecutive_indices.append(r_val)
for l_item in l_val:
new_l_map.append([l_item])
for r_item in r_val:
new_r_map.append([r_item])
else:
new_l_map.append(l_val)
new_r_map.append(r_val)
l_map = new_l_map
r_map = new_r_map
prev_nodes = []
cand_perms = dict()
cand_rev_perms = dict()
prev_output_indices = []
num_constant_nodes = 0
prev_hints = set()
skip = False
for i in input_indices:
prev_node_name = op.inputs[i].name
prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
prev_nodes.append(prev_node)
prev_output_indices.append(prev_node['outputs'].index(prev_node_name))
if prev_node['node_type'] == ExtendedOperator.TRANSPOSE:
if prev_node['name'] in processed_nodes:
skip = True
break
pending_processed_nodes.add(prev_node['name'])
if mode == 'down':
perm = tuple(prev_node['op'].inputs[1].tensor.tolist())
cand_perms.setdefault(perm, 0)
cand_perms[perm] += 1
elif mode == 'up':
perm = tuple(np.argsort(prev_node['op'].inputs[1].tensor).tolist())
cand_rev_perms.setdefault(perm, 0)
cand_rev_perms[perm] += 1
if 'direction' in prev_node['op'].extra_hints:
prev_hints.add(prev_node['op'].extra_hints['direction'])
if prev_node['node_type'] == ExtendedOperator.CONSTANT_NODE:
num_constant_nodes += 1
if skip or (self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'up' in prev_hints):
continue
next_nodes = []
next_edges = []
out_nodes = []
next_hints = set()
for edge in node.out_edges():
if edge.index in remove_edges:
continue
next_node = self.graph.graph.vs[edge.target]
if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
out_nodes.append(next_node)
else:
if next_node['name'] in processed_nodes:
skip = True
break
pending_processed_nodes.add(next_node['name'])
next_nodes.append(next_node)
next_edges.append(edge)
if next_node['node_type'] == ExtendedOperator.TRANSPOSE:
if mode == 'down':
perm = tuple(np.argsort(next_node['op'].inputs[1].tensor).tolist())
cand_rev_perms.setdefault(perm, 0)
cand_rev_perms[perm] += 1
elif mode == 'up':
perm = tuple(next_node['op'].inputs[1].tensor.tolist())
cand_perms.setdefault(perm, 0)
cand_perms[perm] += 1
if 'direction' in next_node['op'].extra_hints:
next_hints.add(next_node['op'].extra_hints['direction'])
if skip or (self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'down' in next_hints):
continue
cur_transpose_size = sum(cand_perms.values()) + sum(cand_rev_perms.values())
new_transpose_size = len(prev_nodes) + len(next_nodes) - sum(cand_perms.values()) - num_constant_nodes
# Skip if the number of transpose nodes is not decreasing
if len(cand_perms) == 0 or len(next_nodes) == 0 or new_transpose_size > cur_transpose_size:
continue
elif new_transpose_size == cur_transpose_size:
skip = True
if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED:
if 'down' in prev_hints or 'up' in next_hints:
skip = False
if skip:
continue
perm = max(cand_perms.items(), key=lambda x: x[1])[0]
perm_arr = np.array(perm, dtype='int32')
skip = False
for check_idx in check_consecutive_indices:
if mode == 'down':
target_idx = perm_arr[check_idx]
elif mode == 'up':
perm_sorter = perm_arr.argsort()
target_idx = perm_sorter[np.searchsorted(perm_arr, check_idx, sorter=perm_sorter)]
normalized_src = [x - check_idx[0] for x in check_idx]
normalized_tgt = [x - target_idx[0] for x in target_idx]
if normalized_src != normalized_tgt:
skip = True
break
if skip:
continue
num_actions += 1
remove_edges.extend([x.index for x in next_edges])
remove_vertices.extend([x.index for x in out_nodes])
for pending_processed_node in pending_processed_nodes:
processed_nodes.add(pending_processed_node)
for n in out_nodes:
del self.graph.tensor_map[n['outputs'][0]]
del self.graph.tensor_node_map[n['outputs'][0]]
if mode == 'down':
inv_perm_arr = np.argsort(perm_arr).astype('int32')
l_dict = dict(zip([x[0] for x in l_map], r_map))
indices = map(lambda x: l_dict[x], inv_perm_arr.tolist())
inv_post_perm = list(itertools.chain.from_iterable(indices))
inv_post_perm_arr = np.array(inv_post_perm, dtype='int32')
post_perm_arr = np.argsort(inv_post_perm_arr).astype('int32')
elif mode == 'up':
r_dict = dict(zip([x[0] for x in r_map], l_map))
indices = map(lambda x: r_dict[x], perm)
inv_perm = list(itertools.chain.from_iterable(indices))
inv_perm_arr = np.array(inv_perm, dtype='int32')
post_perm_arr = np.argsort(perm_arr).astype('int32')
inv_post_perm_arr = np.argsort(post_perm_arr).astype('int32')
for prev_node, prev_idx, next_idx in zip(prev_nodes, input_indices, prev_output_indices):
if prev_node['op'] is None:
prev_out = self.graph.tensor_map[prev_node['outputs'][0]]
else:
prev_out = prev_node['op'].outputs[next_idx]
perm_tensor = self.create_attr_tensor(inv_perm_arr)
prev_new_out = self.create_transform_tensor(
np.transpose(prev_out.tensor, inv_perm_arr), quantization=prev_out.quantization
)
transpose_op = tfl.TransposeOperator([prev_out, perm_tensor], [prev_new_out])
transpose_op.extra_hints['direction'] = 'up'
self.graph.add_operator(transpose_op)
actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))
tensor_node_dict = {}
for i, op_out in enumerate(op.outputs):
perm_tensor = self.create_attr_tensor(post_perm_arr)
new_out = self.create_transform_tensor(
np.transpose(op_out.tensor, inv_post_perm_arr), quantization=op_out.quantization
)
# Update relations
if op_out.name in self.graph.tensor_node_map:
del self.graph.tensor_node_map[op_out.name]
self.graph.tensor_node_map[new_out.name] = node['name']
self.graph.tensor_map[new_out.name] = new_out
node['outputs'][i] = new_out.name
op.outputs[i] = new_out
transpose_op = tfl.TransposeOperator([new_out, perm_tensor], [op_out])
transpose_op.extra_hints['direction'] = 'down'
self.graph.add_operator(transpose_op)
tensor_node_dict[op_out.name] = self.graph.graph.vs.find(name=self.graph.tensor_node_map[op_out.name])
# OP specific dim handling logic
old_shape = op.inputs[1].tensor
new_shape = self.create_attr_tensor(old_shape[inv_post_perm_arr])
actions.append((self.graph.replace_operator_input, (node, 1, new_shape, True)))
op.newShape = new_shape.tensor
for edge in next_edges:
source = tensor_node_dict[edge['name']]
self.graph.graph.add_edge(source, edge.target_vertex, name=edge['name'], label=edge['name'])
# Process actions
ids = []
for func, args in actions:
node = args[0]
res = func(*args)
if res is not None:
ids.extend(res)
remove_edges = list(set(remove_edges + ids))
self.graph.graph.delete_edges(remove_edges)
self.graph.graph.delete_vertices(remove_vertices)
return num_actions
@class_conditional(lambda self: self.rewrite_quantizable)
def elementwise_op_quantize_passthrough_pass(self):
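        """Rewrites float elementwise ops sandwiched between DEQUANTIZE and
        QUANTIZE nodes to run directly on quantized tensors, but only for
        op/dtype/qparams combinations the quantized kernels accept (see the
        per-op checks below). Constant float operands are quantized on the
        fly, warning on overflow/underflow."""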
edges = self.graph.graph.es.select(
functools.partial(
is_quantize_elementwise_op_edge, graph_converter=self.graph.graph, with_lstm=self.hybrid_int16_lstm
)
)
pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in edges)
filtered_nodes = (k[0] if k[0]['node_type'] != ExtendedOperator.DEQUANTIZE else k[1] for k in pairs)
unique_nodes = list(set(filtered_nodes))
actions = []
remove_edges = []
remove_vertices = []
for node in unique_nodes:
op = node['op']
input_indices = op_input_indices(op)
prev_nodes = []
q_tensors = dict()
prev_output_indices = []
skip_names = []
for i in input_indices:
prev_node_name = op.inputs[i].name
prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
prev_nodes.append(prev_node)
prev_output_indices.append(prev_node['outputs'].index(prev_node_name))
if prev_node['node_type'] == ExtendedOperator.DEQUANTIZE:
q_tensors[prev_node_name] = prev_node['op'].inputs[0]
if prev_node['node_type'] == ExtendedOperator.CONSTANT_NODE:
if (
node['node_type'] in (ExtendedOperator.MINIMUM, ExtendedOperator.MAXIMUM)
and i != 0
and prev_node_name not in self.graph.q_mapping
):
f_tensor = self.graph.tensor_map[prev_node_name]
r_tensor = q_tensors[op.inputs[0].name]
q_arr = np.rint(
f_tensor.tensor / r_tensor.quantization.scale + r_tensor.quantization.zero_point
)
i_type = np.iinfo(r_tensor.tensor.dtype)
if np.any(q_arr > i_type.max):
warnings.warn('Overflow while quantizing the tensor')
q_arr = np.minimum(q_arr, i_type.max)
if np.any(q_arr < i_type.min):
warnings.warn('Underflow while quantizing the tensor')
q_arr = np.maximum(q_arr, i_type.min)
q_arr = q_arr.astype(r_tensor.dtype)
q_tensor = self.create_attr_tensor(q_arr, quantization=r_tensor.quantization)
self.graph.q_mapping[prev_node_name] = q_tensor
if prev_node_name in self.graph.q_mapping:
skip_names.append(prev_node_name)
next_nodes = []
next_edges = []
out_nodes = []
for edge in node.out_edges():
if edge.index in remove_edges:
continue
next_node = self.graph.graph.vs[edge.target]
if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
out_nodes.append(next_node)
else:
next_nodes.append(next_node)
next_edges.append(edge)
if next_node['node_type'] == ExtendedOperator.QUANTIZE:
skip = False
name = next_node['op'].inputs[0].name
q_tensor = next_node['op'].outputs[0]
assert q_tensor.quantization is not None
if node['node_type'] in (
ExtendedOperator.BATCH_MATMUL,
ExtendedOperator.ABS,
ExtendedOperator.RSQRT,
):
if q_tensor.dtype not in (np.dtype('int8'), np.dtype('int16')):
skip = True
elif node['node_type'] == ExtendedOperator.DIV:
if q_tensor.dtype != np.dtype('uint8'):
skip = True
elif node['node_type'] == ExtendedOperator.SOFTMAX:
if q_tensor.dtype == np.dtype('int8'):
if (
abs(q_tensor.quantization.scale - 1.0 / 256) > 0.001 * 1.0 / 256
or q_tensor.quantization.zero_point != -128
):
skip = True
elif q_tensor.dtype == np.dtype('int16'):
if (
abs(q_tensor.quantization.scale - 1.0 / 32768) > 0.001 * 1.0 / 32768
or q_tensor.quantization.zero_point != 0
):
skip = True
elif q_tensor.dtype == np.dtype('uint8'):
if (
abs(q_tensor.quantization.scale - 1.0 / 256) > 0.001 * 1.0 / 256
or q_tensor.quantization.zero_point != 0
):
log.warning(
'On some chips, only softmax with scale=1.0/256 and zero_point=0 is supported'
)
else:
skip = True
elif node['node_type'] == ExtendedOperator.LOG_SOFTMAX:
if q_tensor.dtype == np.dtype('int8'):
if q_tensor.quantization.scale != 16.0 / 256 or q_tensor.quantization.zero_point != 127:
skip = True
elif q_tensor.dtype == np.dtype('uint8'):
if q_tensor.quantization.scale != 16.0 / 256 or q_tensor.quantization.zero_point != 255:
skip = True
else:
skip = True
if not skip:
q_tensors[name] = q_tensor
cur_transpose_size = len(q_tensors)
new_transpose_size = len(prev_nodes) + len(next_nodes) - len(skip_names)
# Skip if the rewrite would increase the number of [de]quantize nodes
if len(next_nodes) == 0 or new_transpose_size > cur_transpose_size:
continue
remove_edges.extend([x.index for x in next_edges])
remove_vertices.extend([x.index for x in out_nodes])
for n in out_nodes:
del self.graph.tensor_map[n['outputs'][0]]
del self.graph.tensor_node_map[n['outputs'][0]]
tensor_node_dict = {}
for prev_node, prev_idx, next_idx in zip(prev_nodes, input_indices, prev_output_indices):
if prev_node['op'] is None:
prev_out = self.graph.tensor_map[prev_node['outputs'][0]]
else:
prev_out = prev_node['op'].outputs[next_idx]
if prev_out.name in tensor_node_dict:
prev_new_out, skip = tensor_node_dict[prev_out.name]
actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True, skip)))
skip += 1
tensor_node_dict[prev_out.name] = (prev_new_out, skip)
else:
if prev_out.name in skip_names:
prev_new_out = self.graph.q_mapping[prev_out.name]
self.graph.add_nodes([prev_new_out])
tensor_node_dict[prev_out.name] = (prev_new_out, 1)
actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))
else:
prev_new_out = self.create_transform_tensor(
q_tensors[prev_out.name].tensor, quantization=q_tensors[prev_out.name].quantization
)
tensor_node_dict[prev_out.name] = (prev_new_out, 1)
self.graph.add_operator(tfl.QuantizeOperator([prev_out], [prev_new_out]))
actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))
tensor_node_dict = {}
for i, op_out in enumerate(op.outputs):
new_out = self.create_transform_tensor(
q_tensors[op_out.name].tensor, quantization=q_tensors[op_out.name].quantization
)
# Update relations
if op_out.name in self.graph.tensor_node_map:
del self.graph.tensor_node_map[op_out.name]
self.graph.tensor_node_map[new_out.name] = node['name']
self.graph.tensor_map[new_out.name] = new_out
node['outputs'][i] = new_out.name
op.outputs[i] = new_out
self.graph.add_operator(tfl.DequantizeOperator([new_out], [op_out]))
tensor_node_dict[op_out.name] = self.graph.graph.vs.find(name=self.graph.tensor_node_map[op_out.name])
for edge in next_edges:
source = tensor_node_dict[edge['name']]
self.graph.graph.add_edge(source, edge.target_vertex, name=edge['name'], label=edge['name'])
# Process actions
ids = []
for func, args in actions:
node = args[0]
res = func(*args)
if res is not None:
ids.extend(res)
remove_edges = list(set(remove_edges + ids))
self.graph.graph.delete_edges(remove_edges)
self.graph.graph.delete_vertices(remove_vertices)
@class_conditional(lambda self: self.level >= GraphOptimizer.BRANCH_OPTIMIZE, 0)
def elementwise_op_transpose_passthrough_pass(self, quantizable_ops_only: bool = False) -> int:
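# Pushes Transpose ops through elementwise ops. An elementwise op commutes with a
# dim permutation, e.g. relu(transpose(x, p)) == transpose(relu(x), p), so the most
# common perm among the surrounding transposes is re-anchored: the inverse perm is
# applied to every input and the perm is re-applied to every output, letting
# adjacent transposes cancel in later fusion passes. PACK/UNPACK change the rank,
# so their perms are padded/stripped around `op.axis`.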
edges = self.graph.graph.es.select(
functools.partial(
is_transpose_elementwise_op_edge,
graph_converter=self.graph.graph,
quantizable_ops_only=quantizable_ops_only,
)
)
pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in edges)
if quantizable_ops_only:
all_edges = self.graph.graph.es.select(
functools.partial(
is_transpose_elementwise_op_edge,
graph_converter=self.graph.graph,
quantizable_ops_only=False,
)
)
all_pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in all_edges)
forward_d = dict(all_pairs)
backward_d = {v: k for k, v in forward_d.items()}
filtered_nodes = []
for s, e in pairs:
if s['node_type'] == ExtendedOperator.TRANSPOSE:
pn = backward_d.get(s, None)
if pn is not None:
filtered_nodes.append(pn)
else:
log.warning('Cannot passthrough transpose upward around requantizable ops')
else:
pn = forward_d.get(e, None)
if pn is not None:
filtered_nodes.append(pn)
else:
log.warning('Cannot passthrough transpose downward around requantizable ops')
else:
filtered_nodes = (k[0] if k[0]['node_type'] != ExtendedOperator.TRANSPOSE else k[1] for k in pairs)
unique_nodes = list(set(filtered_nodes))
actions = []
remove_edges = []
remove_vertices = []
num_actions = 0
for node in unique_nodes:
op = node['op']
input_indices = op_input_indices(op)
prev_nodes = []
cand_perms = dict()
prev_output_indices = []
num_constant_nodes = 0
num_reshape_transpose = 0
prev_hints = set()
for i in input_indices:
prev_node_name = op.inputs[i].name
prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
prev_nodes.append(prev_node)
prev_output_indices.append(prev_node['outputs'].index(prev_node_name))
if prev_node['node_type'] == ExtendedOperator.TRANSPOSE:
perm = tuple(prev_node['op'].inputs[1].tensor.tolist())
if node['node_type'] == ExtendedOperator.PACK:
perm = [i if i < op.axis else i + 1 for i in perm]
perm.insert(op.axis, op.axis)
perm = tuple(perm)
cand_perms.setdefault(perm, 0)
cand_perms[perm] += 1
if 'direction' in prev_node['op'].extra_hints:
prev_hints.add(prev_node['op'].extra_hints['direction'])
if prev_node['node_type'] == ExtendedOperator.CONSTANT_NODE:
num_constant_nodes += 1
if prev_node['node_type'] == ExtendedOperator.RESHAPE:
prev_prev_node_name = self.graph.tensor_node_map[prev_node['op'].inputs[0].name]
prev_prev_node = self.graph.graph.vs.find(name=prev_prev_node_name)
if prev_prev_node['node_type'] == ExtendedOperator.TRANSPOSE:
num_reshape_transpose += 1
if 'direction' in prev_prev_node['op'].extra_hints:
prev_hints.add(prev_prev_node['op'].extra_hints['direction'])
if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'up' in prev_hints:
continue
next_nodes = []
next_edges = []
out_nodes = []
skip_names = []
next_hints = set()
for edge in node.out_edges():
if edge.index in remove_edges:
continue
next_node = self.graph.graph.vs[edge.target]
if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
out_nodes.append(next_node)
elif next_node['node_type'] == ExtendedOperator.UNUSED_NODE:
skip_names.append(edge['label'])
else:
next_nodes.append(next_node)
next_edges.append(edge)
if next_node['node_type'] == ExtendedOperator.TRANSPOSE:
perm = tuple(np.argsort(next_node['op'].inputs[1].tensor).tolist())
if node['node_type'] == ExtendedOperator.UNPACK:
perm = [i if i < op.axis else i + 1 for i in perm]
perm.insert(op.axis, op.axis)
perm = tuple(perm)
cand_perms.setdefault(perm, 0)
cand_perms[perm] += 1
if 'direction' in next_node['op'].extra_hints:
next_hints.add(next_node['op'].extra_hints['direction'])
if next_node['node_type'] == ExtendedOperator.RESHAPE:
o_nodes = [e.target_vertex for e in next_node.out_edges()]
if len(o_nodes) == 1 and o_nodes[0]['node_type'] == ExtendedOperator.TRANSPOSE:
num_reshape_transpose += 1
if 'direction' in o_nodes[0]['op'].extra_hints:
next_hints.add(o_nodes[0]['op'].extra_hints['direction'])
if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'down' in next_hints:
continue
cur_transpose_size = sum(cand_perms.values()) + num_reshape_transpose
new_transpose_size = (
len(prev_nodes) + len(next_nodes) - num_constant_nodes - cur_transpose_size + num_reshape_transpose
)
# Decide whether the rewrite pays off:
# a. a strictly decreasing transpose count is always rewritten
# b. an increasing count is skipped, unless `bypass_elementwise_passthrough_constraint` is set
# c. an unchanged count (or an increasing one under the bypass flag) is rewritten
#    only when a direction hint shows the transpose can keep moving, which
#    requires the optimize level to be at least BRANCH_OPTIMIZE_EXTENDED
is_increasing = new_transpose_size > cur_transpose_size
is_not_decreasing = new_transpose_size >= cur_transpose_size
is_same = new_transpose_size == cur_transpose_size
if len(next_nodes) == 0:
continue
else:
if self.bypass_elementwise_passthrough_constraint:
condition = is_not_decreasing
else:
if is_increasing:
continue
condition = is_same
if condition:
skip = True
if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED:
if 'down' in prev_hints or 'up' in next_hints:
skip = False
if skip:
continue
num_actions += 1
remove_edges.extend([x.index for x in next_edges])
remove_vertices.extend([x.index for x in out_nodes])
for n in out_nodes:
del self.graph.tensor_map[n['outputs'][0]]
del self.graph.tensor_node_map[n['outputs'][0]]
perm = max(cand_perms.items(), key=lambda x: x[1])[0]
perm_arr = np.array(perm, dtype='int32')
inv_perm_arr = np.argsort(perm_arr).astype('int32')
if node['node_type'] == ExtendedOperator.UNPACK:
inv_perm_arr_post = inv_perm_arr[inv_perm_arr != op.axis]
inv_perm_arr_post[inv_perm_arr_post > op.axis] -= 1
perm_arr_post = np.argsort(inv_perm_arr_post).astype('int32')
elif node['node_type'] == ExtendedOperator.PACK:
perm_arr_post = perm_arr
inv_perm_arr_post = inv_perm_arr
perm_arr = perm_arr_post[perm_arr_post != op.axis]
perm_arr[perm_arr > op.axis] -= 1
inv_perm_arr = np.argsort(perm_arr).astype('int32')
else:
perm_arr_post = perm_arr
inv_perm_arr_post = inv_perm_arr
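# Worked example for the UNPACK case: with inv_perm_arr = [2, 0, 1, 3] and
# op.axis = 1, dropping the axis entry gives [2, 0, 3], which is remapped to
# inv_perm_arr_post = [1, 0, 2]; each rank-3 output is then wrapped with a
# transpose using perm_arr_post = argsort([1, 0, 2]) = [1, 0, 2].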
tensor_node_dict = {}
for prev_node, prev_idx, next_idx in zip(prev_nodes, input_indices, prev_output_indices):
if prev_node['op'] is None:
prev_out = self.graph.tensor_map[prev_node['outputs'][0]]
else:
prev_out = prev_node['op'].outputs[next_idx]
if prev_out.name in tensor_node_dict:
prev_new_out, skip = tensor_node_dict[prev_out.name]
actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True, skip)))
skip += 1
tensor_node_dict[prev_out.name] = (prev_new_out, skip)
else:
perm_tensor = self.create_attr_tensor(inv_perm_arr)
if len(prev_out.shape) != perm_tensor.tensor.size:
new_shape = [1] * (perm_tensor.tensor.size - len(prev_out.shape)) + list(prev_out.shape)
prev_out_reshaped = self.create_transform_tensor(
np.reshape(prev_out.tensor, new_shape), quantization=prev_out.quantization
)
new_shape_tensor = self.create_attr_tensor(np.array(new_shape, dtype='int32'))
self.graph.add_operator(
tfl.ReshapeOperator([prev_out, new_shape_tensor], [prev_out_reshaped], new_shape)
)
prev_out = prev_out_reshaped
prev_new_out = self.create_transform_tensor(
np.transpose(prev_out.tensor, inv_perm_arr), quantization=prev_out.quantization
)
tensor_node_dict[prev_out.name] = (prev_new_out, 1)
transpose_op = tfl.TransposeOperator([prev_out, perm_tensor], [prev_new_out])
transpose_op.extra_hints['direction'] = 'up'
self.graph.add_operator(transpose_op)
actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))
tensor_node_dict = {}
for i, op_out in enumerate(op.outputs):
# For unused tensors, we perform inplace shape updates
if op_out.name in skip_names:
orig_shape = np.array(op_out.shape, dtype='int32')
new_shape = orig_shape[inv_perm_arr]
op_out.shape = tuple(new_shape.tolist())
continue
perm_tensor = self.create_attr_tensor(perm_arr_post)
new_out = self.create_transform_tensor(
np.transpose(op_out.tensor, inv_perm_arr_post), quantization=op_out.quantization
)
# Update relations
if op_out.name in self.graph.tensor_node_map:
del self.graph.tensor_node_map[op_out.name]
self.graph.tensor_node_map[new_out.name] = node['name']
self.graph.tensor_map[new_out.name] = new_out
node['outputs'][i] = new_out.name
op.outputs[i] = new_out
transpose_op = tfl.TransposeOperator([new_out, perm_tensor], [op_out])
transpose_op.extra_hints['direction'] = 'down'
self.graph.add_operator(transpose_op)
tensor_node_dict[op_out.name] = self.graph.graph.vs.find(name=self.graph.tensor_node_map[op_out.name])
# OP specific dim handling logic
if node['node_type'] in (ExtendedOperator.CONCATENATION, ExtendedOperator.GATHER, ExtendedOperator.UNPACK):
old_axis = op.axis
new_axis = np.where(inv_perm_arr == old_axis)[0][0]
op.axis = new_axis
elif node['node_type'] == ExtendedOperator.PACK:
old_axis = op.axis
new_axis = np.where(inv_perm_arr_post == old_axis)[0][0]
op.axis = new_axis
elif node['node_type'] == ExtendedOperator.SPLIT_V:
old_dim = op.inputs[2].tensor
new_dim = np.where(inv_perm_arr == old_dim)[0][0]
new_dim_tensor = self.create_attr_tensor(np.array(new_dim, dtype='int32'))
actions.append((self.graph.replace_operator_input, (node, 2, new_dim_tensor, True)))
elif node['node_type'] == ExtendedOperator.SPLIT:
old_dim = op.inputs[0].tensor
new_dim = np.where(inv_perm_arr == old_dim)[0][0]
new_dim_tensor = self.create_attr_tensor(np.array(new_dim, dtype='int32'))
actions.append((self.graph.replace_operator_input, (node, 0, new_dim_tensor, True)))
elif node['node_type'] in (
ExtendedOperator.PAD,
ExtendedOperator.PADV2,
ExtendedOperator.MIRROR_PAD,
ExtendedOperator.TILE,
):
old_pad = op.inputs[1].tensor
new_pad = self.create_attr_tensor(old_pad[inv_perm_arr])
actions.append((self.graph.replace_operator_input, (node, 1, new_pad, True)))
elif node['node_type'] == ExtendedOperator.PRELU:
old_weight = op.inputs[1].tensor
if old_weight.ndim != 1:
assert old_weight.ndim + 1 == len(inv_perm_arr)
new_perm = np.argsort(np.argsort(inv_perm_arr[1:]))
new_perm_t = self.create_attr_tensor(np.array(new_perm, dtype='int32'))
new_weight = self.create_transform_tensor(np.transpose(old_weight, new_perm))
self.graph.add_operator(tfl.TransposeOperator([op.inputs[1], new_perm_t], [new_weight]))
actions.append((self.graph.replace_operator_input, (node, 1, new_weight, True)))
elif node['node_type'] in (ExtendedOperator.SLICE, ExtendedOperator.STRIDED_SLICE):
for i, t in enumerate(op.inputs[1:]):
if t.buffer is None:
# `t` is a 1-D begin/end/stride vector, so reorder it along axis 0 with a
# gather (a rank-N transpose perm would not apply to a 1-D tensor)
new_idx_t = self.create_attr_tensor(np.array(inv_perm_arr, dtype='int32'))
new_t = self.create_transform_tensor(t.tensor[inv_perm_arr])
self.graph.add_operator(tfl.GatherOperator([t, new_idx_t], [new_t], 0))
else:
new_t = self.create_attr_tensor(t.tensor[inv_perm_arr])
actions.append((self.graph.replace_operator_input, (node, i + 1, new_t, True)))
elif node['node_type'] in (
ExtendedOperator.SUM,
ExtendedOperator.ARG_MIN,
ExtendedOperator.ARG_MAX,
ExtendedOperator.REDUCE_MIN,
ExtendedOperator.REDUCE_MAX,
ExtendedOperator.REDUCE_PROD,
ExtendedOperator.MEAN,
):
old_axis = op.inputs[1].tensor.tolist()
new_axis = [np.where(inv_perm_arr == t)[0][0] for t in old_axis]
axis_arr = np.array(new_axis, dtype='int32')
axis_tensor = self.create_attr_tensor(axis_arr)
actions.append((self.graph.replace_operator_input, (node, 1, axis_tensor, True)))
for edge in next_edges:
source = tensor_node_dict[edge['name']]
self.graph.graph.add_edge(source, edge.target_vertex, name=edge['name'], label=edge['name'])
# Process actions
ids = []
for func, args in actions:
node = args[0]
res = func(*args)
if res is not None:
ids.extend(res)
remove_edges = list(set(remove_edges + ids))
self.graph.graph.delete_edges(remove_edges)
self.graph.graph.delete_vertices(remove_vertices)
return num_actions
@class_conditional(lambda self: self.level >= GraphOptimizer.BRANCH_OPTIMIZE, 0)
def elementwise_op_reshape_passthrough_pass(self) -> int:
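# The reshape analogue of the transpose passthrough above: for elementwise ops
# wrapped by simple reshapes (i.e. those expressible as a dim mapping, see
# `is_simple_reshape`), the reshape is moved across the op. The axis tracked for
# ops like PACK/CONCATENATION is marked with -1 in the candidate shapes so that
# it can be located again (via `prev_shape.index(-1)`) after the move.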
edges = self.graph.graph.es.select(
functools.partial(is_reshape_elementwise_op_edge, graph_converter=self.graph.graph)
)
pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in edges)
filtered_nodes = (k[0] if k[0]['node_type'] != ExtendedOperator.RESHAPE else k[1] for k in pairs)
unique_nodes = list(set(filtered_nodes))
actions = []
remove_edges = []
remove_vertices = []
num_actions = 0
for node in unique_nodes:
op = node['op']
dim_indice = op_input_dims(op)
input_indices = op_input_indices(op)
prev_nodes = []
cand_shapes = dict()
cand_next_shapes = dict()
prev_output_indices = []
num_constant_nodes = 0
prev_hints = set()
for i in input_indices:
prev_node_name = op.inputs[i].name
prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
prev_nodes.append(prev_node)
prev_output_indices.append(prev_node['outputs'].index(prev_node_name))
if prev_node['node_type'] == ExtendedOperator.CONSTANT_NODE:
num_constant_nodes += 1
if prev_node['node_type'] == ExtendedOperator.RESHAPE:
mapping = dict()
if not is_simple_reshape(
prev_node['op'].inputs[0].shape, prev_node['op'].outputs[0].shape, mapping
):
continue
new_dim = None
if dim_indice is not None:
rev_mapping = {v: k for k, v in mapping.items()}
if node['node_type'] == ExtendedOperator.PACK:
if dim_indice in rev_mapping:
tmp_new_dim = rev_mapping[dim_indice]
else:
if dim_indice - 1 in rev_mapping:
tmp_new_dim = rev_mapping[dim_indice - 1] + 1
elif dim_indice + 1 in rev_mapping:
tmp_new_dim = rev_mapping[dim_indice + 1] - 1
else:
# TODO: Figure out the rev index
tmp_new_dim = -1
tmp_dim_indice = dim_indice
new_dim = -1
dim_indice = -1
else:
if dim_indice not in rev_mapping:
continue
new_dim = rev_mapping[dim_indice]
shape = tuple(prev_node['op'].inputs[0].shape)
shape = tuple(x if i != new_dim else -1 for i, x in enumerate(shape))
if node['node_type'] == ExtendedOperator.PACK and tmp_new_dim >= 0:
shape = list(shape)
shape.insert(tmp_new_dim, -1)
shape = tuple(shape)
cand_shapes.setdefault(shape, 0)
cand_shapes[shape] += 1
next_shape = tuple(prev_node['op'].outputs[0].shape)
next_shape = tuple(x if i != dim_indice else -1 for i, x in enumerate(next_shape))
if node['node_type'] == ExtendedOperator.PACK:
next_shape = list(next_shape)
next_shape.insert(tmp_dim_indice, -1)
next_shape = tuple(next_shape)
cand_next_shapes.setdefault(next_shape, 0)
cand_next_shapes[next_shape] += 1
if node['node_type'] == ExtendedOperator.PACK:
dim_indice = tmp_dim_indice
if 'direction' in prev_node['op'].extra_hints:
prev_hints.add(prev_node['op'].extra_hints['direction'])
if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'up' in prev_hints:
continue
next_nodes = []
next_edges = []
out_nodes = []
skip_names = []
next_hints = set()
for edge in node.out_edges():
if edge.index in remove_edges:
continue
next_node = self.graph.graph.vs[edge.target]
if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
out_nodes.append(next_node)
elif next_node['node_type'] == ExtendedOperator.UNUSED_NODE:
skip_names.append(edge['label'])
else:
next_nodes.append(next_node)
next_edges.append(edge)
if next_node['node_type'] == ExtendedOperator.RESHAPE:
mapping = dict()
if not is_simple_reshape(
next_node['op'].inputs[0].shape, next_node['op'].outputs[0].shape, mapping
):
continue
new_dim = None
if dim_indice is not None:
if node['node_type'] == ExtendedOperator.UNPACK:
if dim_indice in mapping:
tmp_new_dim = mapping[dim_indice]
else:
if dim_indice - 1 in mapping:
tmp_new_dim = mapping[dim_indice - 1] + 1
elif dim_indice + 1 in mapping:
tmp_new_dim = mapping[dim_indice + 1] - 1
else:
# TODO: Figure out the rev index
tmp_new_dim = -1
tmp_dim_indice = dim_indice
new_dim = -1
dim_indice = -1
else:
if dim_indice not in mapping:
continue
new_dim = mapping[dim_indice]
shape = tuple(next_node['op'].outputs[0].shape)
shape = tuple(x if i != new_dim else -1 for i, x in enumerate(shape))
if node['node_type'] == ExtendedOperator.UNPACK and tmp_new_dim >= 0:
shape = list(shape)
shape.insert(tmp_new_dim, -1)
shape = tuple(shape)
cand_shapes.setdefault(shape, 0)
cand_shapes[shape] += 1
next_shape = tuple(next_node['op'].inputs[0].shape)
next_shape = tuple(x if i != dim_indice else -1 for i, x in enumerate(next_shape))
if node['node_type'] == ExtendedOperator.UNPACK:
next_shape = list(next_shape)
next_shape.insert(tmp_dim_indice, -1)
next_shape = tuple(next_shape)
cand_next_shapes.setdefault(next_shape, 0)
cand_next_shapes[next_shape] += 1
if node['node_type'] == ExtendedOperator.UNPACK:
dim_indice = tmp_dim_indice
if 'direction' in next_node['op'].extra_hints:
next_hints.add(next_node['op'].extra_hints['direction'])
if len(cand_shapes) == 0:
continue
if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED and 'down' in next_hints:
continue
cur_reshape_size = max(cand_shapes.values())
cur_next_reshape_size = max(cand_next_shapes.values())
full_size = len(prev_nodes) + len(next_nodes)
if cur_reshape_size != cur_next_reshape_size:
continue
new_reshape_size = full_size - cur_reshape_size - num_constant_nodes
# Skip if the op is not wrapped by matching reshapes, i.e. the rewrite would
# increase the number of reshape ops
if len(next_nodes) == 0 or new_reshape_size > cur_reshape_size:
continue
elif new_reshape_size == cur_reshape_size:
skip = True
if self.level >= GraphOptimizer.BRANCH_OPTIMIZE_EXTENDED:
if 'down' in prev_hints or 'up' in next_hints:
skip = False
if skip:
continue
num_actions += 1
remove_edges.extend([x.index for x in next_edges])
remove_vertices.extend([x.index for x in out_nodes])
for n in out_nodes:
del self.graph.tensor_map[n['outputs'][0]]
del self.graph.tensor_node_map[n['outputs'][0]]
prev_shape = max(cand_shapes.items(), key=lambda x: x[1])[0]
next_shape = max(cand_next_shapes.items(), key=lambda x: x[1])[0]
tensor_node_dict = {}
for prev_node, prev_idx, next_idx in zip(prev_nodes, input_indices, prev_output_indices):
if prev_node['op'] is None:
prev_out = self.graph.tensor_map[prev_node['outputs'][0]]
else:
prev_out = prev_node['op'].outputs[next_idx]
if prev_out.name in tensor_node_dict:
prev_new_out, skip = tensor_node_dict[prev_out.name]
actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True, skip)))
skip += 1
tensor_node_dict[prev_out.name] = (prev_new_out, skip)
else:
if node['node_type'] == ExtendedOperator.PACK:
tmp_prev_shape = prev_shape
prev_shape = [i for i in prev_shape if i != -1]
prev_shape_aligned = prev_shape
if np.prod(prev_out.shape) != np.prod(prev_shape):
new_prev_shape = prev_out.shape
if len(prev_out.shape) < len(next_shape):
new_prev_shape = [1] * (len(next_shape) - len(prev_out.shape)) + list(prev_out.shape)
mapping = {}
is_simple_reshape(prev_shape, next_shape, mapping)
prev_shape_aligned = np.ones(len(prev_shape), dtype='int32')
for pi, ni in mapping.items():
prev_shape_aligned[pi] = new_prev_shape[ni]
prev_new_out = self.create_transform_tensor(
np.reshape(prev_out.tensor, prev_shape_aligned), quantization=prev_out.quantization
)
tensor_node_dict[prev_out.name] = (prev_new_out, 1)
shape_tensor = self.create_attr_tensor(np.array(prev_new_out.shape, dtype='int32'))
reshape_op = tfl.ReshapeOperator(
[prev_out, shape_tensor], [prev_new_out], newShape=shape_tensor.tensor
)
reshape_op.extra_hints['direction'] = 'up'
self.graph.add_operator(reshape_op)
actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))
if node['node_type'] == ExtendedOperator.PACK:
prev_shape = tmp_prev_shape
tensor_node_dict = {}
for i, op_out in enumerate(op.outputs):
if node['node_type'] == ExtendedOperator.UNPACK:
tmp_prev_shape = prev_shape
prev_shape = [i for i in prev_shape if i != -1]
# For unused tensors, we perform inplace shape updates
if op_out.name in skip_names:
new_shape = np.reshape(op_out.tensor, prev_shape).shape
op_out.shape = tuple(new_shape)
if node['node_type'] == ExtendedOperator.UNPACK:
prev_shape = tmp_prev_shape
continue
new_out = self.create_transform_tensor(
np.reshape(op_out.tensor, prev_shape), quantization=op_out.quantization
)
shape_tensor = self.create_attr_tensor(np.array(op_out.shape, dtype='int32'))
# Update relations
if op_out.name in self.graph.tensor_node_map:
del self.graph.tensor_node_map[op_out.name]
self.graph.tensor_node_map[new_out.name] = node['name']
self.graph.tensor_map[new_out.name] = new_out
node['outputs'][i] = new_out.name
op.outputs[i] = new_out
reshape_op = tfl.ReshapeOperator([new_out, shape_tensor], [op_out], shape_tensor.tensor)
reshape_op.extra_hints['direction'] = 'down'
self.graph.add_operator(reshape_op)
tensor_node_dict[op_out.name] = self.graph.graph.vs.find(name=self.graph.tensor_node_map[op_out.name])
if node['node_type'] == ExtendedOperator.UNPACK:
prev_shape = tmp_prev_shape
# OP specific dim handling logic
if node['node_type'] in (
ExtendedOperator.CONCATENATION,
ExtendedOperator.GATHER,
ExtendedOperator.UNPACK,
ExtendedOperator.PACK,
):
new_axis = prev_shape.index(-1)
op.axis = new_axis
elif node['node_type'] == ExtendedOperator.SPLIT_V:
new_dim = prev_shape.index(-1)
new_dim_tensor = self.create_attr_tensor(np.array(new_dim, dtype='int32'))
actions.append((self.graph.replace_operator_input, (node, 2, new_dim_tensor, True)))
elif node['node_type'] == ExtendedOperator.SPLIT:
new_dim = prev_shape.index(-1)
new_dim_tensor = self.create_attr_tensor(np.array(new_dim, dtype='int32'))
actions.append((self.graph.replace_operator_input, (node, 0, new_dim_tensor, True)))
elif node['node_type'] in (ExtendedOperator.PAD, ExtendedOperator.PADV2, ExtendedOperator.MIRROR_PAD):
old_pad = op.inputs[1].tensor
new_dim = prev_shape.index(-1)
old_dim = next_shape.index(-1)
new_pad = np.zeros((len(prev_shape), 2), dtype='int32')
new_pad[new_dim, :] = old_pad[old_dim, :]
new_pad_tensor = self.create_attr_tensor(new_pad)
actions.append((self.graph.replace_operator_input, (node, 1, new_pad_tensor, True)))
elif node['node_type'] == ExtendedOperator.PRELU:
old_weight = op.inputs[1].tensor
if old_weight.ndim != 1:
new_dim = prev_shape.index(-1)
old_dim = next_shape.index(-1)
new_shape = np.ones(len(prev_shape) - 1, dtype='int32')
new_shape[new_dim - 1] = old_weight.shape[old_dim - 1]
new_shape_t = self.create_attr_tensor(new_shape)
new_weight = self.create_transform_tensor(np.reshape(old_weight, new_shape))
self.graph.add_operator(tfl.ReshapeOperator([op.inputs[1], new_shape_t], [new_weight], new_shape))
actions.append((self.graph.replace_operator_input, (node, 1, new_weight, True)))
elif node['node_type'] == ExtendedOperator.SLICE:
new_dim = prev_shape.index(-1)
old_dim = next_shape.index(-1)
new_start = np.zeros(len(prev_shape), dtype='int32')
new_start[new_dim] = op.inputs[1].tensor[old_dim]
new_start_t = self.create_attr_tensor(new_start)
new_size = np.array(prev_shape, dtype='int32')
new_size[new_dim] = op.inputs[2].tensor[old_dim]
new_size_t = self.create_attr_tensor(new_size)
actions.append((self.graph.replace_operator_input, (node, 1, new_start_t, True)))
actions.append((self.graph.replace_operator_input, (node, 2, new_size_t, True)))
elif node['node_type'] == ExtendedOperator.STRIDED_SLICE:
new_dim = prev_shape.index(-1)
old_dim = next_shape.index(-1)
new_start = np.zeros(len(prev_shape), dtype='int32')
new_start[new_dim] = op.inputs[1].tensor[old_dim]
if op.inputs[1].buffer is None:
new_start_t = self.create_transform_tensor(new_start)
starts_mid = new_start[new_dim : new_dim + 1]
starts_mid_tensor = self.create_transform_tensor(starts_mid)
slice_inputs = [
op.inputs[1],
self.create_attr_tensor(np.array([old_dim], dtype='int32')),
self.create_attr_tensor(np.array([1], dtype='int32')),
]
self.graph.add_operator(tfl.SliceOperator(slice_inputs, [starts_mid_tensor]))
starts_left = new_start[:new_dim]
starts_right = new_start[new_dim + 1 :]
starts_tensors = []
if len(starts_left) > 0:
starts_tensors.append(self.create_attr_tensor(starts_left))
starts_tensors.append(starts_mid_tensor)
if len(starts_right) > 0:
starts_tensors.append(self.create_attr_tensor(starts_right))
if len(starts_tensors) > 1:
self.graph.add_operator(tfl.ConcatenationOperator(starts_tensors, [new_start_t], 0))
else:
new_start_t = starts_tensors[0]
else:
new_start_t = self.create_attr_tensor(new_start)
new_end = np.array(prev_shape, dtype='int32')
new_end[new_dim] = op.inputs[2].tensor[old_dim]
if op.inputs[2].buffer is None:
new_end_t = self.create_transform_tensor(new_end)
ends_mid = new_end[new_dim : new_dim + 1]
ends_mid_tensor = self.create_transform_tensor(ends_mid)
slice_inputs = [
op.inputs[2],
self.create_attr_tensor(np.array([old_dim], dtype='int32')),
self.create_attr_tensor(np.array([1], dtype='int32')),
]
self.graph.add_operator(tfl.SliceOperator(slice_inputs, [ends_mid_tensor]))
ends_left = new_end[:new_dim]
ends_right = new_end[new_dim + 1 :]
ends_tensors = []
if len(ends_left) > 0:
ends_tensors.append(self.create_attr_tensor(ends_left))
ends_tensors.append(ends_mid_tensor)
if len(ends_right) > 0:
ends_tensors.append(self.create_attr_tensor(ends_right))
if len(ends_tensors) > 1:
self.graph.add_operator(tfl.ConcatenationOperator(ends_tensors, [new_end_t], 0))
else:
new_end_t = ends_tensors[0]
else:
new_end_t = self.create_attr_tensor(new_end)
new_stride = np.ones(len(prev_shape), dtype='int32')
new_stride[new_dim] = op.inputs[3].tensor[old_dim]
new_stride_t = self.create_attr_tensor(new_stride)
actions.append((self.graph.replace_operator_input, (node, 1, new_start_t, True)))
actions.append((self.graph.replace_operator_input, (node, 2, new_end_t, True)))
actions.append((self.graph.replace_operator_input, (node, 3, new_stride_t, True)))
elif node['node_type'] == ExtendedOperator.TILE:
old_shape = op.inputs[1].tensor
new_dim = prev_shape.index(-1)
old_dim = next_shape.index(-1)
new_shape = np.ones(len(prev_shape), dtype='int32')
new_shape[new_dim] = old_shape[old_dim]
new_shape_tensor = self.create_attr_tensor(new_shape)
actions.append((self.graph.replace_operator_input, (node, 1, new_shape_tensor, True)))
elif node['node_type'] in (
ExtendedOperator.SUM,
ExtendedOperator.ARG_MIN,
ExtendedOperator.ARG_MAX,
ExtendedOperator.REDUCE_MIN,
ExtendedOperator.REDUCE_MAX,
ExtendedOperator.REDUCE_PROD,
ExtendedOperator.MEAN,
):
new_axis = prev_shape.index(-1)
axis_arr = np.array([new_axis], dtype='int32')
axis_tensor = self.create_attr_tensor(axis_arr)
actions.append((self.graph.replace_operator_input, (node, 1, axis_tensor, True)))
elif dim_indice is not None:
raise NotImplementedError(f'{node["node_type"]} has the property `dims` but is not handled')
for edge in next_edges:
source = tensor_node_dict[edge['name']]
self.graph.graph.add_edge(source, edge.target_vertex, name=edge['name'], label=edge['name'])
# Process actions
ids = []
for func, args in actions:
node = args[0]
res = func(*args)
if res is not None:
ids.extend(res)
remove_edges = list(set(remove_edges + ids))
self.graph.graph.delete_edges(remove_edges)
self.graph.graph.delete_vertices(remove_vertices)
return num_actions
@class_conditional(lambda self: self.level >= GraphOptimizer.COMMON_OPTIMIZE)
def fuse_bmm_add_pass(self):
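# Fuses BATCH_MATMUL/FULLY_CONNECTED + ADD into a single FULLY_CONNECTED with
# bias. FC weights are laid out as (out_features, in_features), so the (in, out)
# weight of a BATCH_MATMUL is transposed first; existing FC nodes are only
# eligible when their bias is absent or all zeros (see the filter below).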
edges = self.graph.graph.es.select(functools.partial(is_bmm_add_edge, graph_converter=self.graph.graph))
filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in edges]
filtered_pairs = [
p
for p in filtered_pairs
if p[0]['node_type'] != ExtendedOperator.FULLY_CONNECTED
or len(p[0]['op'].inputs) == 2
or not np.any(p[0]['op'].inputs[2].tensor)
]
remove_ids = []
ops = []
restore_mapping = []
for bmm, add in filtered_pairs:
restore_nodes = []
# For each node that is next of a transformable node,
# a. if it is an output node, remove it anyway since it will always be reconstructed
# b. otherwise, record the info of the edge so that we may restore it after reconstruction
for out_edge in add.out_edges():
next_node = self.graph.graph.vs[out_edge.target]
if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
remove_ids.append(next_node.index)
del self.graph.tensor_map[next_node['outputs'][0]]
del self.graph.tensor_node_map[next_node['outputs'][0]]
else:
restore_nodes.append((out_edge['name'], next_node['name']))
# Remove the mapping since they are going to be removed
for output_name in add['outputs']:
del self.graph.tensor_map[output_name]
del self.graph.tensor_node_map[output_name]
restore_mapping.append(restore_nodes)
ops.append((bmm, add))
remove_ids.append(bmm.index)
remove_ids.append(add.index)
# Make sure the nodes are topologically sorted
sorted_ops = [
(nodes[0]['op'], nodes[1]['op'])
for nodes in sorted(ops, key=lambda x: int(re.search(r'\d+', x[1]['name'])[0]))
]
# Delete nodes before transformation in the graph
self.graph.graph.delete_vertices(remove_ids)
for (bmm, add), mapping in zip(sorted_ops, restore_mapping):
input_tensor = bmm.inputs[0]
weight_tensor = bmm.inputs[1]
bias_tensor = add.inputs[1]
output_tensor = add.outputs[0]
ops = []
if isinstance(bmm, tfl.BatchMatmulOperator):
weight_t = self.create_transform_tensor(np.transpose(weight_tensor.tensor))
weight_perm = self.create_attr_tensor(np.array([1, 0], dtype='int32'))
ops.append(tfl.TransposeOperator([weight_tensor, weight_perm], [weight_t]))
else:
weight_t = weight_tensor
keep_dims = output_tensor.tensor.ndim > 2
ops.append(
tfl.FullyConnectedOperator(
[input_tensor, weight_t, bias_tensor],
[output_tensor],
fusedActivationFunction=add.fusedActivationFunction,
keepNumDims=keep_dims,
)
)
for op in ops:
self.graph.add_operator(op, transform=True)
self.graph.try_restore_edges(mapping)
@class_conditional(lambda self: self.max_transpose_dims > 0)
def lower_transpose_dim_pass(self):
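# Lowers a transpose whose rank exceeds `max_transpose_dims` by merging runs of
# output dims that read consecutive input dims, yielding
# reshape (merge dims) -> low-rank transpose -> reshape (restore dims).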
vertices = self.graph.graph.vs.select(
functools.partial(
is_high_dim_transpose_node, graph_converter=self.graph.graph, max_transpose_dims=self.max_transpose_dims
)
)
remove_ids = []
ops = []
restore_mapping = []
for trans in vertices:
restore_nodes = []
# For each node that is next of a transformable node,
# a. if it is an output node, remove it anyway since it will always be reconstructed
# b. otherwise, record the info of the edge so that we may restore it after reconstruction
for out_edge in trans.out_edges():
next_node = self.graph.graph.vs[out_edge.target]
if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
remove_ids.append(next_node.index)
del self.graph.tensor_map[next_node['outputs'][0]]
del self.graph.tensor_node_map[next_node['outputs'][0]]
else:
restore_nodes.append((out_edge['name'], next_node['name']))
# Remove the mapping since they are going to be removed
for output_name in trans['outputs']:
del self.graph.tensor_map[output_name]
del self.graph.tensor_node_map[output_name]
restore_mapping.append(restore_nodes)
remove_ids.append(trans.index)
# Make sure the nodes are topologically sorted
sorted_ops = [node['op'] for node in sorted(vertices, key=lambda x: int(re.search(r'\d+', x['name'])[0]))]
# Delete nodes before transformation in the graph
self.graph.graph.delete_vertices(remove_ids)
for trans, mapping in zip(sorted_ops, restore_mapping):
input_tensor = trans.inputs[0]
perm_tensor = trans.inputs[1]
output_tensor = trans.outputs[0]
input_shape = input_tensor.shape
perm = perm_tensor.tensor
output_shape = output_tensor.shape
last_perm = None
last_dim = None
cum_dim = None
new_shape = []
new_perm = []
for d, p in zip(output_shape, perm):
if last_dim is None and last_perm is None:
cum_dim = d
else:
if p - last_perm == 1 or d == 1 or cum_dim == 1:
cum_dim *= d
else:
new_shape.append(cum_dim)
new_perm.append(last_perm)
cum_dim = d
last_dim = d
last_perm = p
new_shape.append(cum_dim)
new_perm.append(last_perm)
new_perm_arr = np.argsort(np.argsort(new_perm)).astype('int32')
reduced_shape_arr = np.array(new_shape, dtype='int32')[np.argsort(new_perm_arr)]
assert (
len(new_shape) <= self.max_transpose_dims
), f"Don't know how to reduce the number of dims of transpose with input shape {input_shape}, perm {perm}"
ops = []
input_reduced = self.create_transform_tensor(
np.reshape(input_tensor.tensor, reduced_shape_arr), quantization=input_tensor.quantization
)
reduced_shape = self.create_attr_tensor(reduced_shape_arr)
ops.append(tfl.ReshapeOperator([input_tensor, reduced_shape], [input_reduced], reduced_shape_arr.tolist()))
transposed = self.create_transform_tensor(
np.transpose(input_reduced.tensor, new_perm_arr), quantization=input_tensor.quantization
)
new_perm_tensor = self.create_attr_tensor(np.array(new_perm_arr, dtype='int32'))
ops.append(tfl.TransposeOperator([input_reduced, new_perm_tensor], [transposed]))
output_shape_tensor = self.create_attr_tensor(np.array(output_shape, dtype='int32'))
ops.append(tfl.ReshapeOperator([transposed, output_shape_tensor], [output_tensor], output_shape))
for op in ops:
self.graph.add_operator(op, transform=True)
self.graph.try_restore_edges(mapping)
@class_conditional(lambda self: self.group_conv_rewrite)
def group_conv_rewrite_pass(self):
vertices = self.graph.graph.vs.select(functools.partial(is_group_conv_node, graph_converter=self.graph.graph))
remove_ids = []
ops = []
restore_mapping = []
for conv in vertices:
restore_nodes = []
# For each node that is next of a transformable node,
# a. if it is an output node, remove it anyway since it will always be reconstructed
# b. otherwise, record the info of the edge so that we may restore it after reconstruction
for out_edge in conv.out_edges():
next_node = self.graph.graph.vs[out_edge.target]
if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
remove_ids.append(next_node.index)
del self.graph.tensor_map[next_node['outputs'][0]]
del self.graph.tensor_node_map[next_node['outputs'][0]]
else:
restore_nodes.append((out_edge['name'], next_node['name']))
# Remove the mapping since they are going to be removed
for output_name in conv['outputs']:
del self.graph.tensor_map[output_name]
del self.graph.tensor_node_map[output_name]
restore_mapping.append(restore_nodes)
remove_ids.append(conv.index)
# Make sure the nodes are topologically sorted
sorted_ops = [node['op'] for node in sorted(vertices, key=lambda x: int(re.search(r'\d+', x['name'])[0]))]
# Delete nodes before transformation in the graph
self.graph.graph.delete_vertices(remove_ids)
for conv, mapping in zip(sorted_ops, restore_mapping):
input_tensor = conv.inputs[0]
weight_tensor = conv.inputs[1]
bias_tensor = conv.inputs[2] if len(conv.inputs) > 2 else None
output_tensor = conv.outputs[0]
num_input_channel = input_tensor.shape[3]
num_weight_channel = weight_tensor.shape[3]
num_chunks = num_input_channel // num_weight_channel
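# TFLite's CONV_2D has no group parameter, so a grouped conv is expanded into
# SPLIT -> per-group CONV_2D -> CONCATENATION along the channel axis; e.g. 64
# input channels with a per-group weight depth of 16 yields 4 chunks.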
ops = []
input_tensors = [
self.create_transform_tensor(arr, quantization=input_tensor.quantization)
for arr in np.split(input_tensor.tensor, num_chunks, 3)
]
output_tensors = [
self.create_transform_tensor(arr, quantization=output_tensor.quantization)
for arr in np.split(output_tensor.tensor, num_chunks, 3)
]
weights = [
self.create_attr_tensor(arr, quantization=weight_tensor.quantization)
for arr in np.split(weight_tensor.tensor, num_chunks, 0)
]
if bias_tensor is not None:
biases = [
self.create_attr_tensor(arr, quantization=bias_tensor.quantization)
for arr in np.split(bias_tensor.tensor, num_chunks, 0)
]
else:
biases = [None] * num_chunks
dim_tensor = self.create_attr_tensor(np.array(3, dtype='int32'))
ops.append(tfl.SplitOperator([dim_tensor, input_tensor], input_tensors, num_chunks))
for it, ot, w, b in zip(input_tensors, output_tensors, weights, biases):
inputs = [it, w]
if b is not None:
inputs.append(b)
ops.append(
tfl.Conv2dOperator(
inputs,
[ot],
strideH=conv.strideH,
strideW=conv.strideW,
dilationHFactor=conv.dilationHFactor,
dilationWFactor=conv.dilationWFactor,
fusedActivationFunction=conv.fusedActivationFunction,
padding=conv.padding,
)
)
ops.append(tfl.ConcatenationOperator(output_tensors, [output_tensor], 3))
for op in ops:
self.graph.add_operator(op, transform=True)
self.graph.try_restore_edges(mapping)
@class_conditional(lambda self: self.group_conv_rewrite)
def group_deconv_rewrite_pass(self):
vertices = self.graph.graph.vs.select(functools.partial(is_group_deconv_node, graph_converter=self.graph.graph))
remove_ids = []
ops = []
restore_mapping = []
for conv in vertices:
restore_nodes = []
# For each node that is next of a transformable node,
# a. if it is an output node, remove it anyway since it will always be reconstructed
# b. otherwise, record the info of the edge so that we may restore it after reconstruction
for out_edge in conv.out_edges():
next_node = self.graph.graph.vs[out_edge.target]
if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
remove_ids.append(next_node.index)
del self.graph.tensor_map[next_node['outputs'][0]]
del self.graph.tensor_node_map[next_node['outputs'][0]]
else:
restore_nodes.append((out_edge['name'], next_node['name']))
# Remove the mapping since they are going to be removed
for output_name in conv['outputs']:
del self.graph.tensor_map[output_name]
del self.graph.tensor_node_map[output_name]
restore_mapping.append(restore_nodes)
remove_ids.append(conv.index)
# Make sure the nodes are topologically sorted
sorted_ops = [node['op'] for node in sorted(vertices, key=lambda x: int(re.search(r'\d+', x['name'])[0]))]
# Delete nodes before transformation in the graph
self.graph.graph.delete_vertices(remove_ids)
for conv, mapping in zip(sorted_ops, restore_mapping):
input_tensor = conv.inputs[2]
weight_tensor = conv.inputs[1]
output_shape_tensor = conv.inputs[0]
bias_tensor = conv.inputs[3] if len(conv.inputs) > 3 else None
output_tensor = conv.outputs[0]
num_output_channel = output_tensor.shape[3]
num_weight_channel = weight_tensor.shape[0]
num_chunks = num_output_channel // num_weight_channel
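# Same SPLIT/CONCAT expansion for grouped TRANSPOSE_CONV: each group consumes a
# slice of the input channels (weight axis 3) and produces `num_weight_channel`
# output channels, so the output shape tensor is patched per group (`new_ost`).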
ops = []
input_tensors = [
self.create_transform_tensor(arr, quantization=input_tensor.quantization)
for arr in np.split(input_tensor.tensor, num_chunks, 3)
]
output_tensors = [
self.create_transform_tensor(arr, quantization=output_tensor.quantization)
for arr in np.split(output_tensor.tensor, num_chunks, 3)
]
weights = [
self.create_attr_tensor(arr, quantization=weight_tensor.quantization)
for arr in np.split(weight_tensor.tensor, num_chunks, 3)
]
if bias_tensor is not None:
biases = [
self.create_attr_tensor(arr, quantization=bias_tensor.quantization)
for arr in np.split(bias_tensor.tensor, num_chunks, 0)
]
else:
biases = [None] * num_chunks
new_os = output_shape_tensor.tensor.copy()
new_os[3] = num_weight_channel
new_ost = self.create_attr_tensor(new_os)
dim_tensor = self.create_attr_tensor(np.array(3, dtype='int32'))
ops.append(tfl.SplitOperator([dim_tensor, input_tensor], input_tensors, num_chunks))
for it, ot, w, b in zip(input_tensors, output_tensors, weights, biases):
inputs = [new_ost, w, it]
if b is not None:
inputs.append(b)
ops.append(
tfl.TransposeConvOperator(
inputs,
[ot],
padding=conv.padding,
strideH=conv.strideH,
strideW=conv.strideW,
)
)
ops.append(tfl.ConcatenationOperator(output_tensors, [output_tensor], 3))
for op in ops:
self.graph.add_operator(op, transform=True)
self.graph.try_restore_edges(mapping)
@class_conditional(lambda self: self.tflite_micro_rewrite)
def cat_split_pass(self):
vertices = self.graph.graph.vs.select(functools.partial(is_large_cat_node, graph_converter=self.graph.graph))
remove_ids = []
ops = []
restore_mapping = []
for cat in vertices:
restore_nodes = []
# For each node that is next of a transformable node,
# a. if it is an output node, remove it anyway since it will always be reconstructed
# b. otherwise, record the info of the edge so that we may restore it after reconstruction
for out_edge in cat.out_edges():
next_node = self.graph.graph.vs[out_edge.target]
if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
remove_ids.append(next_node.index)
del self.graph.tensor_map[next_node['outputs'][0]]
del self.graph.tensor_node_map[next_node['outputs'][0]]
else:
restore_nodes.append((out_edge['name'], next_node['name']))
# Remove the mapping since they are going to be removed
for output_name in cat['outputs']:
del self.graph.tensor_map[output_name]
del self.graph.tensor_node_map[output_name]
restore_mapping.append(restore_nodes)
remove_ids.append(cat.index)
# Make sure the nodes are topologically sorted
sorted_ops = [node['op'] for node in sorted(vertices, key=lambda x: int(re.search(r'\d+', x['name'])[0]))]
# Delete nodes before transformation in the graph
self.graph.graph.delete_vertices(remove_ids)
for cat, mapping in zip(sorted_ops, restore_mapping):
input_tensors = cat.inputs
layer_inputs = input_tensors
output_tensor = cat.outputs[0]
axis = cat.axis
last_layer = False
ops = []
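# Rewrite a wide concat as a tree of concats with at most 10 inputs each, which
# this TFLite-Micro-oriented pass assumes to be the supported width; e.g. 25
# inputs -> 3 partial concats -> 1 final concat writing the original output.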
while True:
layer_outputs = []
while len(layer_inputs) > 0:
curr_inputs = layer_inputs[:10]
input_arrs = [t.tensor for t in curr_inputs]
output_arr = np.concatenate(input_arrs, axis)
if last_layer:
curr_output = output_tensor
else:
curr_output = self.create_transform_tensor(output_arr, quantization=output_tensor.quantization)
layer_outputs.append(curr_output)
ops.append(tfl.ConcatenationOperator(curr_inputs, [curr_output], axis))
layer_inputs = layer_inputs[10:]
if last_layer:
break
if len(layer_outputs) <= 10:
last_layer = True
layer_inputs = layer_outputs
for op in ops:
self.graph.add_operator(op, transform=True)
self.graph.try_restore_edges(mapping)
def input_transpose_pass(self):
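# For inputs flagged in `input_transpose`, the input tensor is re-declared with
# its dims permuted by (0, 2, 3, 1) (NCHW -> NHWC) and a Transpose back to NCHW
# is inserted right after the input (after its QUANTIZE for quantized graphs),
# leaving the rest of the graph untouched. The transpose is hinted 'down' so the
# passthrough passes above can keep sinking it.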
nhwc2nchw_perm = np.array([0, 3, 1, 2], dtype='int32')
nchw2nhwc_perm = np.array([0, 2, 3, 1], dtype='int32')
remove_edges = []
for name, transpose in zip(self.graph.inputs, self.graph.input_transpose):
if transpose is True:
node_name = self.graph.tensor_node_map[name]
node = self.graph.graph.vs.find(name=node_name)
assert node['node_type'] == ExtendedOperator.INPUT_NODE
# For quantized graphs, we insert the transpose op after the quantize op
next_node = None
if node.outdegree() == 1:
next_node = node.out_edges()[0].target_vertex
if next_node['node_type'] != ExtendedOperator.QUANTIZE:
next_node = None
# Transpose input tensor shapes
input_tensor = self.graph.tensor_map[node['name']]
input_tensor.tensor = np.transpose(input_tensor.tensor, nchw2nhwc_perm)
input_tensor.shape = input_tensor.tensor.shape
# Transpose quantize output tensor shapes
last_tensor = input_tensor
last_node = node
if next_node is not None:
last_node = next_node
last_tensor = next_node['op'].outputs[0]
last_tensor.tensor = np.transpose(last_tensor.tensor, nchw2nhwc_perm)
last_tensor.shape = last_tensor.tensor.shape
# Create new transpose op
nhwc2nchw_perm_tensor = self.create_attr_tensor(nhwc2nchw_perm)
transposed = self.create_transform_tensor(
np.transpose(last_tensor.tensor, nhwc2nchw_perm), quantization=last_tensor.quantization
)
transpose_op = tfl.TransposeOperator([last_tensor, nhwc2nchw_perm_tensor], [transposed])
transpose_op.extra_hints['direction'] = 'down'
self.graph.add_operator(transpose_op)
# Get the newly-generated node
new_node_name = self.graph.tensor_node_map[transposed.name]
new_node = self.graph.graph.vs.find(name=new_node_name)
# Connect the transpose op to the graph
self.graph.replace_next_tensors(last_node, new_node, transposed.name, [new_node_name])
# Collect the unused connections
for edge in last_node.out_edges():
target_vertex = edge.target_vertex
if target_vertex['name'] != new_node_name:
remove_edges.append(edge.index)
# Remove the collected edges
self.graph.graph.delete_edges(remove_edges)
@class_conditional(lambda self: self.quantize_input_output_type is not None)
def quantize_input_output_type_pass(self):
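# Rewrites quantized graph I/O between int8 and uint8. Shifting both the stored
# values and the zero-point by 128 preserves the represented value, since
# real = scale * (q - zp) = scale * ((q - 128) - (zp - 128)); a QUANTIZE
# (requantize) op then bridges back to the original type inside the graph.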
remove_edges = []
remove_vertices = []
for i, name in enumerate(self.graph.inputs):
if self.fuse_input_indices is not None:
if i not in self.fuse_input_indices:
continue
node_name = self.graph.tensor_node_map[name]
node = self.graph.graph.vs.find(name=node_name)
assert node['node_type'] == ExtendedOperator.INPUT_NODE
# Update input tensor
input_tensor = self.graph.tensor_map[node['outputs'][0]]
input_type = str(input_tensor.dtype)
if input_type == self.quantize_input_output_type:
continue
input_arr = input_tensor.tensor.copy()
input_quantization = copy.deepcopy(input_tensor.quantization)
if input_type == 'int8' and self.quantize_input_output_type == 'uint8':
input_tensor.tensor = (input_tensor.tensor.astype('int32') + 128).astype('uint8')
input_tensor.quantization.zero_point += 128
input_tensor.dtype = input_tensor.tensor.dtype
elif input_type == 'uint8' and self.quantize_input_output_type == 'int8':
input_tensor.tensor = (input_tensor.tensor.astype('int32') - 128).astype('int8')
input_tensor.quantization.zero_point -= 128
input_tensor.dtype = input_tensor.tensor.dtype
else:
raise AssertionError(
f'Unsupported types: input_type: {input_type}, quantize_input_output_type:'
f' {self.quantize_input_output_type}'
)
# Create new quantize op
requantized = self.create_transform_tensor(input_arr, quantization=input_quantization)
quantize_op = tfl.QuantizeOperator([input_tensor], [requantized])
self.graph.add_operator(quantize_op)
# Get the newly-generated node
new_node_name = self.graph.tensor_node_map[requantized.name]
new_node = self.graph.graph.vs.find(name=new_node_name)
# Connect the quantize op to the graph
self.graph.replace_next_tensors(node, new_node, requantized.name, [new_node_name])
# Collect the unused connections
for edge in node.out_edges():
target_vertex = edge.target_vertex
if target_vertex['name'] != new_node_name:
remove_edges.append(edge.index)
output_mapping = {}
for i, name in enumerate(self.graph.outputs):
if self.fuse_output_indices is not None:
if i not in self.fuse_output_indices:
continue
output_tensor = self.graph.tensor_map[name]
output_type = str(output_tensor.dtype)
if output_type == self.quantize_input_output_type:
continue
node_name = self.graph.tensor_node_map[name]
node = self.graph.graph.vs.find(name=node_name)
for edge in node.out_edges():
next_node = edge.target_vertex
if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
remove_vertices.append(next_node.index)
# Update output tensor
output_arr = output_tensor.tensor.copy()
output_quantization = copy.deepcopy(output_tensor.quantization)
if output_type == 'int8' and self.quantize_input_output_type == 'uint8':
output_arr = (output_arr.astype('int32') + 128).astype('uint8')
output_quantization.zero_point += 128
elif output_type == 'uint8' and self.quantize_input_output_type == 'int8':
output_arr = (output_arr.astype('int32') - 128).astype('int8')
output_quantization.zero_point -= 128
else:
raise AssertionError(
f'Unsupported types: output_type: {output_type}, quantize_input_output_type:'
f' {self.quantize_input_output_type}'
)
requantized = self.create_transform_tensor(output_arr, quantization=output_quantization)
quantize_op = tfl.QuantizeOperator([output_tensor], [requantized])
self.graph.add_operator(quantize_op)
output_mapping[name] = requantized.name
if len(output_mapping) > 0:
new_outputs = []
output_names = []
for name in self.graph.outputs:
if name in output_mapping:
new_outputs.append(output_mapping[name])
output_names.append(output_mapping[name])
else:
new_outputs.append(name)
self.graph.outputs.clear()
self.graph.outputs.extend(new_outputs)
self.graph.add_outputs(output_names)
# Remove the collected edges & vertices
self.graph.graph.delete_edges(remove_edges)
self.graph.graph.delete_vertices(remove_vertices)
def output_transpose_pass(self):
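# For outputs flagged for transpose (by default, any 4-D output), the delivered
# tensor gets its dims permuted by (0, 2, 3, 1) (NCHW -> NHWC): the producing op
# keeps its original output and a Transpose is appended, hoisted above a
# trailing DEQUANTIZE when one produces the graph output.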
nhwc2nchw_perm = np.array([0, 3, 1, 2], dtype='int32')
nchw2nhwc_perm = np.array([0, 2, 3, 1], dtype='int32')
if isinstance(self.graph.output_transpose, (list, tuple)):
assert len(self.graph.output_transpose) == len(self.graph.outputs)
else:
self.graph.output_transpose = [self.graph.output_transpose] * len(self.graph.outputs)
filtered_dict = {}
for i, (name, transpose) in enumerate(zip(self.graph.outputs, self.graph.output_transpose)):
if name in filtered_dict:
old_transpose = filtered_dict[name]
assert (
transpose == old_transpose
), f"outputs {i} points to an exising tensor {name}, but their property `output_transpose` is different"
else:
filtered_dict[name] = transpose
prev_modify_node_indices = {}
prev_modify_next_indices = {}
next_modify_node_indices = {}
for name, transpose in filtered_dict.items():
if name in self.graph.tensor_map:
tensor = self.graph.tensor_map[name]
if transpose is None:
transpose = len(tensor.shape) == 4
else:
transpose = False
for i, n in enumerate(self.graph.outputs):
if name == n:
self.graph.output_transpose[i] = transpose
if transpose:
node_name = self.graph.tensor_node_map[name]
node = self.graph.graph.vs.find(name=node_name)
tensor_idx = node['outputs'].index(name)
prev_node = None
if node['node_type'] == ExtendedOperator.DEQUANTIZE:
prev_node_name = self.graph.tensor_node_map[node['op'].inputs[0].name]
prev_node = self.graph.graph.vs.find(name=prev_node_name)
if prev_node is None:
next_modify_node_indices.setdefault(node, set())
next_modify_node_indices[node].add(tensor_idx)
else:
prev_modify_node_indices.setdefault(node, set())
prev_modify_node_indices[node].add(0)
prev_modify_next_indices.setdefault(node, set())
prev_modify_next_indices[node].add(tensor_idx)
remove_edges = []
remove_vertices = []
actions = []
for node, index in prev_modify_node_indices.items():
next_indices = prev_modify_next_indices[node]
op = node['op']
tensor_names = [node['outputs'][i] for i in index]
next_nodes = {}
for edge in node.out_edges():
if edge['label'] not in tensor_names:
continue
if edge.index in remove_edges:
continue
tensor_idx = tensor_names.index(edge['label'])
next_node = self.graph.graph.vs[edge.target]
if next_node['node_type'] not in (ExtendedOperator.OUTPUT_NODE, ExtendedOperator.UNUSED_NODE):
next_nodes.setdefault(tensor_idx, [])
next_nodes[tensor_idx].append(next_node)
prev_nodes = []
prev_output_indices = []
for i in index:
prev_node_name = op.inputs[i].name
prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
prev_nodes.append(prev_node)
prev_output_indices.append(prev_node['outputs'].index(prev_node_name))
tensor_node_dict = {}
for prev_node, prev_idx, next_idx in zip(prev_nodes, index, prev_output_indices):
if prev_node['op'] is None:
prev_out = self.graph.tensor_map[prev_node['outputs'][0]]
else:
prev_out = prev_node['op'].outputs[next_idx]
if prev_out.name in tensor_node_dict:
prev_new_out, skip = tensor_node_dict[prev_out.name]
actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True, skip)))
skip += 1
tensor_node_dict[prev_out.name] = (prev_new_out, skip)
else:
perm_tensor = self.create_attr_tensor(nchw2nhwc_perm)
prev_new_out = self.create_transform_tensor(
np.transpose(prev_out.tensor, nchw2nhwc_perm), quantization=prev_out.quantization
)
tensor_node_dict[prev_out.name] = (prev_new_out, 1)
prev_transpose_op = tfl.TransposeOperator([prev_out, perm_tensor], [prev_new_out])
prev_transpose_op.extra_hints['direction'] = 'up'
self.graph.add_operator(prev_transpose_op)
actions.append((self.graph.replace_operator_input, (node, prev_idx, prev_new_out, True)))
tensor_mapping = {}
for i in next_indices:
t = op.outputs[i]
t.tensor = np.transpose(t.tensor, nchw2nhwc_perm)
t.shape = t.tensor.shape
if i in next_nodes:
new_t = self.create_transform_tensor(np.transpose(t.tensor, nhwc2nchw_perm))
perm_t = self.create_attr_tensor(nhwc2nchw_perm)
next_transpose_op = tfl.TransposeOperator([t, perm_t], [new_t])
next_transpose_op.extra_hints['direction'] = 'down'
self.graph.add_operator(next_transpose_op)
tensor_mapping[t.name] = new_t
for nodes in next_nodes.values():
for n in nodes:
next_op = n['op']
for i, t in enumerate(next_op.inputs):
if t.name in tensor_mapping:
actions.append((self.graph.replace_operator_input, (n, i, tensor_mapping[t.name])))
for node, index in next_modify_node_indices.items():
op = node['op']
tensor_names = [node['outputs'][i] for i in index]
out_nodes = []
next_nodes = []
next_edges = []
for edge in node.out_edges():
if edge['label'] not in tensor_names:
continue
if edge.index in remove_edges:
continue
next_node = self.graph.graph.vs[edge.target]
tensor_idx = tensor_names.index(edge['label'])
if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
out_nodes.append(next_node)
elif next_node['node_type'] != ExtendedOperator.UNUSED_NODE:
next_nodes.append(next_node)
next_edges.append(edge)
remove_vertices.extend([x.index for x in out_nodes])
remove_edges.extend([x.index for x in next_edges])
for n in out_nodes:
del self.graph.tensor_map[n['outputs'][0]]
del self.graph.tensor_node_map[n['outputs'][0]]
tensor_node_dict = {}
for i, op_out in enumerate(op.outputs):
if i not in index:
continue
op_out.tensor = np.transpose(op_out.tensor, nchw2nhwc_perm)
op_out.shape = op_out.tensor.shape
perm_tensor = self.create_attr_tensor(nchw2nhwc_perm)
new_out = self.create_transform_tensor(
np.transpose(op_out.tensor, nhwc2nchw_perm), quantization=op_out.quantization
)
# Update relations
if op_out.name in self.graph.tensor_node_map:
del self.graph.tensor_node_map[op_out.name]
self.graph.tensor_node_map[new_out.name] = node['name']
self.graph.tensor_map[new_out.name] = new_out
node['outputs'][i] = new_out.name
op.outputs[i] = new_out
next_transpose_op = tfl.TransposeOperator([new_out, perm_tensor], [op_out])
next_transpose_op.extra_hints['direction'] = 'up'
self.graph.add_operator(next_transpose_op)
tensor_node_dict[op_out.name] = (
self.graph.graph.vs.find(name=self.graph.tensor_node_map[new_out.name]),
new_out.name,
)
# Connect next edges and replace next tensors
for edge in next_edges:
old_name = edge['name']
source, new_name = tensor_node_dict[old_name]
target = edge.target_vertex
self.graph.graph.add_edge(source, target, name=new_name, label=new_name)
op = target['op']
for i, op_input in enumerate(op.inputs):
if op_input.name == old_name:
op.inputs[i] = self.graph.tensor_map[new_name]
break
# Process actions
ids = []
for func, args in actions:
node = args[0]
res = func(*args)
if res is not None:
ids.extend(res)
remove_edges = list(set(remove_edges + ids))
self.graph.graph.delete_edges(remove_edges)
self.graph.graph.delete_vertices(remove_vertices)
def connect_unused_tensors_pass(self):
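# Multi-output ops (e.g. SPLIT/UNPACK) may leave some outputs dangling; attach
# UNUSED_NODE sinks to those tensors so later passes can still find and update
# them (e.g. the inplace shape updates in the passthrough passes).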
filtered_nodes = self.graph.graph.vs.select(
functools.partial(is_multi_output_op_node, graph_converter=self.graph.graph)
)
list_unpack_names = set([i for s in self.graph.iterable_map.values() for i in s])
all_tensors = set(self.graph.graph.es['label'])
names = []
for node in filtered_nodes:
output_names = node['outputs']
# Recognizes the pattern SPLIT -> (RESHAPE, ..., RESHAPE)
if not list_unpack_names.isdisjoint(set(output_names)):
output_names = []
outdegree = 0
for edge in node.out_edges():
target_vertex = edge.target_vertex
if target_vertex['node_type'] == ExtendedOperator.RESHAPE:
outdegree += target_vertex.outdegree()
output_names.append(target_vertex['outputs'][0])
# Only nodes with partially unused tensors are supported
if outdegree == 0:
continue
for out in output_names:
if out not in all_tensors:
names.append(out)
self.graph.add_outputs(names, ExtendedOperator.UNUSED_NODE)
def output_list_unpack_pass(self):
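"""Expands graph outputs that refer to tensor lists (see `iterable_map`) into
their individual unpacked tensors.
"""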
output_names = []
unpacked_outputs = []
for name in self.graph.outputs:
if name in self.graph.iterable_map:
names = self.graph.get_list_expanded_names(name)
unpacked_outputs.extend(names)
output_names.extend(names)
else:
unpacked_outputs.append(name)
self.graph.outputs.clear()
self.graph.outputs.extend(unpacked_outputs)
self.graph.add_outputs(output_names)
@class_conditional(lambda self: self.fuse_quant)
def fuse_quant_dequant_nodes(self):
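"""Folds QUANTIZE ops right after input nodes and DEQUANTIZE ops right before
output nodes into the graph I/O, so the model consumes and produces quantized
tensors directly. `fuse_input_indices` / `fuse_output_indices` restrict the
fusion to the given inputs/outputs.
"""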
edges = self.graph.graph.es.select(functools.partial(is_quant_dequant_edge, graph_converter=self.graph.graph))
filtered_pairs = [[self.graph.graph.vs[x.source], self.graph.graph.vs[x.target]] for x in edges]
remove_vertices = []
input_mapping = {}
output_mapping = {}
for prev, next in filtered_pairs:
if prev['node_type'] == ExtendedOperator.INPUT_NODE:
input_name = prev['outputs'][0]
if self.fuse_input_indices is not None:
input_idx = self.graph.inputs.index(input_name)
if input_idx not in self.fuse_input_indices:
continue
remove_vertices.append(prev)
next['node_type'] = prev['node_type']
next['op'] = None
input_mapping.setdefault(input_name, [])
input_mapping[input_name].extend(next['outputs'])
else:
if prev['op'] is not None:
prev_name = prev['op'].outputs[0].name
if self.fuse_output_indices is not None:
output_idx = self.graph.outputs.index(prev_name)
if output_idx not in self.fuse_output_indices:
continue
prev['node_type'] = next['node_type']
new_name = prev['op'].inputs[0].name
prev['op'] = None
output_mapping.setdefault(prev_name, [])
output_mapping[prev_name].append(new_name)
remove_vertices.append(next)
self.graph.graph.delete_vertices(remove_vertices)
if len(input_mapping) > 0:
new_inputs = []
for name in self.graph.inputs:
if name in input_mapping:
new_inputs.extend(input_mapping[name])
else:
new_inputs.append(name)
self.graph.inputs.clear()
self.graph.inputs.extend(new_inputs)
if len(output_mapping) > 0:
new_outputs = []
for name in self.graph.outputs:
if name in output_mapping:
new_outputs.extend(output_mapping[name])
else:
new_outputs.append(name)
self.graph.outputs.clear()
self.graph.outputs.extend(new_outputs)
def optimize(self):
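"""Runs the full optimization pipeline. Several passes are gated by the
optimization level and converter flags via their `class_conditional`
decorators and become no-ops when the condition does not hold.
"""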
# Input/output passes
self.output_list_unpack_pass()
self.input_transpose_pass()
self.output_transpose_pass()
# Connect unused tensors with special nodes
self.connect_unused_tensors_pass()
# Transpose, Reshape and NO-OP cleanup
self.branch_reshape_expand_pass()
self.fuse_simple_reshape_pass()
self.branch_transpose_expand_pass()
self.fuse_simple_transpose_pass()
self.fuse_simple_gather_pass()
for branch in (False, True):
self.remove_noop_pass(branch)
self.fuse_wrapped_reshape_within_transpose_pass()
# Buffer folding, which is needed by the fusion passes below
for _ in range(2):
self.fold_reshape_buffer()
self.fold_transpose_buffer()
# Move `transpose` ops for the rewrite quantizable pass
self.elementwise_op_transpose_passthrough_pass(quantizable_ops_only=True)
self.branch_transpose_expand_pass()
self.fuse_simple_transpose_pass()
# Fuse reciprocal and sqrt
self.fuse_reciprocal_sqrt()
# Map quantizable ops to quantized kernels
self.elementwise_op_quantize_passthrough_pass()
# Remove consecutive dequantize and quantize nodes
self.fuse_dequant_quant_pass(q_first=False)
# OP fusion passes before transformation
self.fuse_conv_fc_bn()
self.fuse_activation()
self.fuse_requantize()
self.fuse_bn_conv()
# Convert TinyNeuralNetwork ops to TFLite ops
self.transform_graph()
# OP fusion passes after transformation
self.fuse_bmm_add_pass()
self.fuse_activation()
# Transpose and reshape cleanup
self.branch_reshape_expand_pass()
self.branch_transpose_expand_pass()
self.fuse_simple_transpose_pass()
self.fuse_simple_reshape_pass()
# Branch transpose & reshape cleanup
for i in range(11):
t_count = self.elementwise_op_transpose_passthrough_pass()
self.branch_transpose_expand_pass()
self.fuse_simple_transpose_pass()
r_count = self.elementwise_op_reshape_passthrough_pass()
self.branch_reshape_expand_pass()
self.fuse_simple_reshape_pass()
c_count = self.elementwise_reshape_transpose_passthrough_pass()
self.branch_transpose_expand_pass()
self.fuse_simple_transpose_pass()
if t_count + r_count + c_count == 0:
log.debug(f'elem p/t pass finished in {i + 1} steps')
break
# Other cleanups
self.fuse_simple_slice_pass()
for branch in (False, True):
self.remove_noop_pass(branch)
self.fuse_wrapped_reshape_within_transpose_pass()
# Buffer folding
for _ in range(2):
self.fold_reshape_buffer()
self.fold_transpose_buffer()
# Transpose and reshape cleanup
for _ in range(2):
self.transpose_to_reshape_pass()
self.branch_reshape_expand_pass()
self.fuse_simple_reshape_pass()
self.fuse_simple_transpose_pass()
self.lower_transpose_dim_pass()
# Some advanced fusion logic
self.fuse_conv2d_gather()
# Remove consecutive dequantize and quantize nodes
self.fuse_dequant_quant_pass(q_first=True)
# Fuse reciprocal and sqrt
self.fuse_reciprocal_sqrt()
# Remove additional tile nodes before elementwise ops
self.remove_tile_before_binary_elementwise_ops()
# Fuse activation
self.fuse_activation()
# Fuse quant/dequant nodes
self.fuse_quant_dequant_nodes()
# Input/output quantize type
self.quantize_input_output_type_pass()
# Fuse same padding
self.fuse_same_padding()
self.fuse_same_padding_slicing()
self.fuse_gather_conv2d()
# Group conv & deconv
self.group_conv_rewrite_pass()
self.group_deconv_rewrite_pass()
# TFLite micro specific
self.cat_split_pass()
self.split_requantize()
# Group the same tensors into one
self.group_tensors_pass()
# Final cleanup
self.cleanup_dead_nodes()
def is_bn_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
source_vertex['node_type']
in (ExtendedOperator.GENERIC_CONV, ExtendedOperator.GENERIC_DECONV, ExtendedOperator.FULLY_CONNECTED)
and target_vertex['node_type'] == ExtendedOperator.BATCH_NORM
and source_vertex.outdegree() == 1
and target_vertex['op'].inputs[1].buffer is not None
and target_vertex['op'].inputs[2].buffer is not None
and source_vertex['op'].inputs[1].buffer is not None
and (
target_vertex['op'].fusedActivationFunction == ActivationFunctionType.NONE
or source_vertex['op'].fusedActivationFunction
in (ActivationFunctionType.NONE, target_vertex['op'].fusedActivationFunction)
)
)
def is_rev_bn_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
target_vertex['node_type'] == ExtendedOperator.GENERIC_CONV
and source_vertex['node_type'] == ExtendedOperator.BATCH_NORM
and source_vertex.outdegree() == 1
and source_vertex['op'].inputs[1].buffer is not None
and source_vertex['op'].inputs[2].buffer is not None
and target_vertex['op'].inputs[1].buffer is not None
and source_vertex['op'].fusedActivationFunction == ActivationFunctionType.NONE
)
def is_padding_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
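"""Matches a PAD/PADV2 feeding a VALID-padded conv/pool where the padding
constant is the identity value of the downstream op: 0.0 for conv ops and the
float32 minimum for MAX_POOL_2D (2-input pads, which default to 0, always
qualify).
"""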
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
source_vertex['node_type'] in (ExtendedOperator.PAD, ExtendedOperator.PADV2)
and (
len(source_vertex['op'].inputs) == 2
or (
len(source_vertex['op'].inputs) == 3
and source_vertex['op'].inputs[2].dtype == np.dtype('float32')
and (
(
source_vertex['op'].inputs[2].tensor[0] == 0.0
and target_vertex['node_type'] != ExtendedOperator.MAX_POOL_2D
)
or (
source_vertex['op'].inputs[2].tensor[0] == np.finfo(np.float32).min
and target_vertex['node_type'] == ExtendedOperator.MAX_POOL_2D
)
)
)
)
and target_vertex['node_type']
in (
ExtendedOperator.CONV_2D,
ExtendedOperator.CONV_3D,
ExtendedOperator.DEPTHWISE_CONV_2D,
ExtendedOperator.MAX_POOL_2D,
)
and source_vertex.outdegree() == 1
and target_vertex['op'].padding == Padding.VALID
)
def is_slicing_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
target_vertex['node_type'] in (ExtendedOperator.SLICE, ExtendedOperator.STRIDED_SLICE)
and (
len(target_vertex['op'].inputs) == 3
or (len(target_vertex['op'].inputs) == 4 and np.all(target_vertex['op'].inputs[3].tensor == 1))
)
and source_vertex['node_type']
in (
ExtendedOperator.TRANSPOSE_CONV,
ExtendedOperator.CONV_3D_TRANSPOSE,
)
and source_vertex.outdegree() == 1
and source_vertex['op'].padding == Padding.VALID
)
def is_requantize_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
source_vertex['node_type']
in (
ExtendedOperator.FULLY_CONNECTED,
ExtendedOperator.GENERIC_CONV,
ExtendedOperator.ADD,
ExtendedOperator.SUB,
ExtendedOperator.MUL,
ExtendedOperator.DIV,
ExtendedOperator.MAX_POOL_2D,
ExtendedOperator.AVERAGE_POOL_2D,
ExtendedOperator.GENERIC_DECONV,
)
and source_vertex['op'].outputs[0].quantization is not None
and target_vertex['node_type'] == ExtendedOperator.QUANTIZE
)
def is_activ_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
source_vertex['node_type']
in (
ExtendedOperator.FULLY_CONNECTED,
ExtendedOperator.GENERIC_CONV,
ExtendedOperator.ADD,
ExtendedOperator.SUB,
ExtendedOperator.MUL,
ExtendedOperator.DIV,
ExtendedOperator.MAX_POOL_2D,
ExtendedOperator.AVERAGE_POOL_2D,
ExtendedOperator.GENERIC_DECONV,
)
and target_vertex['node_type'] in (ExtendedOperator.RELU, ExtendedOperator.RELU6)
and source_vertex['op'].fusedActivationFunction == ActivationFunctionType.NONE
and source_vertex.outdegree() == 1
)
def is_requantize_node(vertex: ig.Vertex, graph_converter: ig.Graph):
return (
vertex['node_type'] == ExtendedOperator.QUANTIZE
and vertex['op'].inputs[0].quantization is not None
and vertex['op'].outputs[0].quantization is not None
)
def is_large_cat_node(vertex: ig.Vertex, graph_converter: ig.Graph):
return vertex['node_type'] == ExtendedOperator.CONCATENATION and len(vertex['op'].inputs) > 10
def is_high_dim_transpose_node(vertex: ig.Vertex, graph_converter: ig.Graph, max_transpose_dims: int):
return vertex['node_type'] == ExtendedOperator.TRANSPOSE and vertex['op'].inputs[1].tensor.size > max_transpose_dims
def is_group_conv_node(vertex: ig.Vertex, graph_converter: ig.Graph):
return (
vertex['node_type'] == ExtendedOperator.CONV_2D
and vertex['op'].inputs[0].shape[3] != vertex['op'].inputs[1].shape[3]
)
def is_group_deconv_node(vertex: ig.Vertex, graph_converter: ig.Graph):
return (
vertex['node_type'] == ExtendedOperator.TRANSPOSE_CONV
and vertex['op'].outputs[0].shape[3] != vertex['op'].inputs[1].shape[0]
)
def is_transformable_node(vertex: ig.Vertex, graph_converter: ig.Graph):
return vertex['node_type'] <= ExtendedOperator.BATCH_NORM and vertex.outdegree() >= 1
def is_transformable_transpose_node(vertex: ig.Vertex, graph_converter: ig.Graph):
return (
vertex['node_type'] == ExtendedOperator.TRANSPOSE
and vertex.outdegree() >= 1
and is_transpose_same_to_reshape_op(vertex['op'])
)
def is_multi_output_op_node(vertex: ig.Vertex, graph_converter: ig.Graph):
return vertex['node_type'] >= 0 and len(vertex['outputs']) > 1 and vertex.outdegree() > 0
def is_quantize_elementwise_op_edge(edge: ig.Edge, graph_converter: ig.Graph, with_lstm: bool):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
(
source_vertex['node_type'] == ExtendedOperator.DEQUANTIZE
and is_quantizable_rewrite_op(target_vertex['node_type'], target_vertex['op'], with_lstm)
)
or (
target_vertex['node_type'] == ExtendedOperator.QUANTIZE
and is_quantizable_rewrite_op(source_vertex['node_type'], source_vertex['op'], with_lstm)
)
) and target_vertex['op'].inputs[0].name in source_vertex['outputs']
def is_transpose_reshape_op_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
(
source_vertex['node_type'] == ExtendedOperator.TRANSPOSE
and target_vertex['node_type'] == ExtendedOperator.RESHAPE
)
or (
target_vertex['node_type'] == ExtendedOperator.TRANSPOSE
and source_vertex['node_type'] == ExtendedOperator.RESHAPE
)
) and target_vertex['op'].inputs[0].name in source_vertex['outputs']
def is_transpose_elementwise_op_edge(edge: ig.Edge, graph_converter: ig.Graph, quantizable_ops_only: bool):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
if quantizable_ops_only:
is_unary = is_elementwise_unary_quantizable_op
is_binary = is_elementwise_binary_quantizable_op
else:
is_unary = is_elementwise_unary_op
is_binary = is_elementwise_binary_op
return (
(
source_vertex['node_type'] == ExtendedOperator.TRANSPOSE
and (
is_unary(target_vertex['node_type'], target_vertex['op'])
or is_binary(target_vertex['node_type'], target_vertex['op'])
)
)
or (
target_vertex['node_type'] == ExtendedOperator.TRANSPOSE
and (
is_unary(source_vertex['node_type'], source_vertex['op'])
or is_binary(source_vertex['node_type'], source_vertex['op'])
)
)
) and (
(
target_vertex['node_type'] != ExtendedOperator.SPLIT
and target_vertex['op'].inputs[0].name in source_vertex['outputs']
)
or (
target_vertex['node_type'] == ExtendedOperator.SPLIT
and target_vertex['op'].inputs[1].name in source_vertex['outputs']
)
)
def is_reshape_elementwise_op_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
(
source_vertex['node_type'] == ExtendedOperator.RESHAPE
and (
is_elementwise_unary_op(target_vertex['node_type'], target_vertex['op'])
or is_elementwise_binary_op(target_vertex['node_type'], target_vertex['op'])
)
)
or (
target_vertex['node_type'] == ExtendedOperator.RESHAPE
and (
is_elementwise_unary_op(source_vertex['node_type'], source_vertex['op'])
or is_elementwise_binary_op(source_vertex['node_type'], source_vertex['op'])
)
)
) and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
def is_elementwise_reduce_op(op_code: ExtendedOperator, op: tfl.BaseOperator):
return (
op_code
in (
ExtendedOperator.SUM,
ExtendedOperator.ARG_MIN,
ExtendedOperator.ARG_MAX,
ExtendedOperator.REDUCE_MIN,
ExtendedOperator.REDUCE_MAX,
ExtendedOperator.REDUCE_PROD,
)
and len(op.inputs[0].shape) == len(op.outputs[0].shape)
) or (
op_code == ExtendedOperator.MEAN
and len(op.inputs[0].shape) == len(op.outputs[0].shape)
and (
len(op.inputs[0].shape) != 4
or (
not np.array_equal(op.inputs[1].tensor, np.array([1, 2], dtype='int32'))
and not np.array_equal(op.inputs[1].tensor, np.array([2, 1], dtype='int32'))
)
)
)
def is_elementwise_unary_quantizable_op(op_code: ExtendedOperator, op: tfl.BaseOperator):
return op_code in (
ExtendedOperator.SOFTMAX,
ExtendedOperator.LOG_SOFTMAX,
)
def is_elementwise_binary_quantizable_op(op_code: ExtendedOperator, op: tfl.BaseOperator):
return False
def is_elementwise_unary_op(op_code: ExtendedOperator, op: tfl.BaseOperator):
return op_code in (
ExtendedOperator.RELU,
ExtendedOperator.SIN,
ExtendedOperator.COS,
ExtendedOperator.TANH,
ExtendedOperator.ELU,
ExtendedOperator.PRELU,
ExtendedOperator.EXP,
ExtendedOperator.LOG,
ExtendedOperator.NEG,
ExtendedOperator.FLOOR,
ExtendedOperator.RELU6,
ExtendedOperator.QUANTIZE,
ExtendedOperator.DEQUANTIZE,
ExtendedOperator.SQRT,
ExtendedOperator.RSQRT,
ExtendedOperator.CAST,
ExtendedOperator.LOGISTIC,
ExtendedOperator.HARD_SWISH,
ExtendedOperator.LEAKY_RELU,
ExtendedOperator.SPLIT,
ExtendedOperator.SPLIT_V,
ExtendedOperator.UNPACK,
ExtendedOperator.PAD,
ExtendedOperator.PADV2,
ExtendedOperator.MIRROR_PAD,
ExtendedOperator.SLICE,
ExtendedOperator.STRIDED_SLICE,
ExtendedOperator.TILE,
ExtendedOperator.GATHER,
ExtendedOperator.ABS,
) or is_elementwise_reduce_op(op_code, op)
def is_quantizable_rewrite_op(op_code: ExtendedOperator, op: tfl.BaseOperator, with_lstm: bool):
return op_code in (
ExtendedOperator.BATCH_MATMUL,
ExtendedOperator.SOFTMAX,
ExtendedOperator.LOG_SOFTMAX,
ExtendedOperator.ABS,
ExtendedOperator.SUM,
ExtendedOperator.DIV,
ExtendedOperator.RSQRT,
ExtendedOperator.MAXIMUM,
ExtendedOperator.MINIMUM,
) or (with_lstm and op_code == ExtendedOperator.UNIDIRECTIONAL_SEQUENCE_LSTM)
def is_elementwise_binary_op(op_code: ExtendedOperator, op: tfl.BaseOperator):
return (
op_code
in (
ExtendedOperator.CONCATENATION,
ExtendedOperator.PACK,
ExtendedOperator.ADD,
ExtendedOperator.SUB,
ExtendedOperator.MUL,
ExtendedOperator.DIV,
ExtendedOperator.MAXIMUM,
ExtendedOperator.MINIMUM,
ExtendedOperator.SQUARED_DIFFERENCE,
)
and len(op.inputs) >= 2
)
def is_non_passthrough_op(op_code: ExtendedOperator, op: tfl.BaseOperator):
return op_code in (
ExtendedOperator.CONV_2D,
ExtendedOperator.AVERAGE_POOL_2D,
ExtendedOperator.DEPTHWISE_CONV_2D,
ExtendedOperator.MAX_POOL_2D,
)
def is_ending_with_noop_edge(edge: ig.Edge, graph_converter: ig.Graph, branch: bool = False):
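"""Matches an edge whose target op is effectively an identity (a
shape-preserving reshape/pad/tile/slice/concat, a transpose with a
consecutive ascending perm, a gather with consecutive ascending indices, or a
same-dtype cast) and can therefore be removed.
"""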
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
if branch:
source_cond_var = source_vertex.outdegree() >= 1
else:
source_cond_var = source_vertex.outdegree() == 1
return (
source_cond_var
and target_vertex.outdegree() >= 1
and target_vertex['op'] is not None
and target_vertex['op'].inputs[0].name in source_vertex['outputs']
and (
(
target_vertex['node_type'] == ExtendedOperator.RESHAPE
and target_vertex['op'].inputs[0].shape == target_vertex['op'].outputs[0].shape
)
or (
target_vertex['node_type'] == ExtendedOperator.TRANSPOSE
and (np.diff(target_vertex['op'].inputs[1].tensor) == 1).all()
)
or (
target_vertex['node_type']
in (ExtendedOperator.PAD, ExtendedOperator.PADV2, ExtendedOperator.MIRROR_PAD)
and target_vertex['op'].inputs[0].shape == target_vertex['op'].outputs[0].shape
)
or (
target_vertex['node_type'] == ExtendedOperator.TILE
and target_vertex['op'].inputs[0].shape == target_vertex['op'].outputs[0].shape
)
or (
target_vertex['node_type'] in (ExtendedOperator.SLICE, ExtendedOperator.STRIDED_SLICE)
and target_vertex['op'].inputs[0].shape == target_vertex['op'].outputs[0].shape
)
or (
target_vertex['node_type'] == ExtendedOperator.CONCATENATION
and len(target_vertex['op'].inputs) == 1
and len(target_vertex['op'].outputs) == 1
and target_vertex['op'].inputs[0].shape == target_vertex['op'].outputs[0].shape
)
or (
target_vertex['node_type'] == ExtendedOperator.GATHER
and target_vertex['op'].inputs[0].shape == target_vertex['op'].outputs[0].shape
and (np.diff(target_vertex['op'].inputs[1].tensor) == 1).all()
)
or (
target_vertex['node_type'] == ExtendedOperator.CAST
and target_vertex['op'].inDataType == target_vertex['op'].outDataType
)
)
)
def is_bmm_add_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
out_dim_idx = None
if source_vertex['node_type'] == ExtendedOperator.BATCH_MATMUL:
out_dim_idx = -1
elif source_vertex['node_type'] == ExtendedOperator.FULLY_CONNECTED:
out_dim_idx = 0
return (
out_dim_idx is not None
and target_vertex['node_type'] == ExtendedOperator.ADD
and source_vertex['op'].inputs[0].tensor.ndim >= 2
and source_vertex['op'].inputs[1].tensor.ndim == 2
and target_vertex['op'].inputs[1].tensor.ndim == 1
and target_vertex['op'].inputs[1].shape[0] == source_vertex['op'].inputs[1].shape[out_dim_idx]
and source_vertex.outdegree() == 1
and target_vertex.outdegree() >= 1
and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
)
def is_wrapped_reshape_within_transpose_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
(
(
target_vertex['node_type'] == ExtendedOperator.TRANSPOSE
and source_vertex['node_type'] == ExtendedOperator.RESHAPE
)
or (
source_vertex['node_type'] == ExtendedOperator.TRANSPOSE
and target_vertex['node_type'] == ExtendedOperator.RESHAPE
)
)
and source_vertex.outdegree() == 1
and target_vertex.outdegree() >= 1
and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
)
def is_slice_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
source_vertex['node_type'] in (ExtendedOperator.SLICE, ExtendedOperator.STRIDED_SLICE)
and source_vertex.outdegree() == 1
and target_vertex['node_type'] in (ExtendedOperator.SLICE, ExtendedOperator.STRIDED_SLICE)
and target_vertex.outdegree() >= 1
and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
and source_vertex['op'].inputs[1].buffer is not None
and source_vertex['op'].inputs[2].buffer is not None
)
def is_transpose_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
source_vertex['node_type'] == ExtendedOperator.TRANSPOSE
and source_vertex.outdegree() == 1
and target_vertex['node_type'] == ExtendedOperator.TRANSPOSE
and target_vertex.outdegree() >= 1
and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
)
def is_gather_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
source_vertex['node_type'] == ExtendedOperator.GATHER
and source_vertex.outdegree() == 1
and target_vertex['node_type'] == ExtendedOperator.GATHER
and target_vertex.outdegree() >= 1
and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
and source_vertex['op'].axis == target_vertex['op'].axis
and source_vertex['op'].batchDims == target_vertex['op'].batchDims
)
def is_reshape_branch_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
source_vertex['node_type'] == ExtendedOperator.RESHAPE
and source_vertex.outdegree() > 1
and target_vertex['node_type'] == ExtendedOperator.RESHAPE
and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
)
def is_transpose_branch_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
source_vertex['node_type'] == ExtendedOperator.TRANSPOSE
and source_vertex.outdegree() > 1
and target_vertex['node_type'] == ExtendedOperator.TRANSPOSE
and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
)
def is_dequant_quant_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph, q_first: bool):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
if q_first:
cond = (
source_vertex['node_type'] == ExtendedOperator.QUANTIZE
and target_vertex['node_type'] == ExtendedOperator.DEQUANTIZE
)
else:
cond = (
source_vertex['node_type'] == ExtendedOperator.DEQUANTIZE
and target_vertex['node_type'] == ExtendedOperator.QUANTIZE
)
return (
cond
and source_vertex.outdegree() == 1
and target_vertex.outdegree() >= 1
and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
)
def is_reshape_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
source_vertex['node_type'] == ExtendedOperator.RESHAPE
and source_vertex.outdegree() == 1
and target_vertex['node_type'] == ExtendedOperator.RESHAPE
and target_vertex.outdegree() >= 1
and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
)
def is_constant_transpose_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
source_vertex['node_type'] == ExtendedOperator.CONSTANT_NODE
and target_vertex['node_type'] == ExtendedOperator.TRANSPOSE
and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
and target_vertex.outdegree() >= 1
)
def is_constant_reshape_fusable_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
source_vertex['node_type'] == ExtendedOperator.CONSTANT_NODE
and target_vertex['node_type'] == ExtendedOperator.RESHAPE
and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
and target_vertex.outdegree() >= 1
)
def is_quant_dequant_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
source_vertex['node_type'] == ExtendedOperator.INPUT_NODE
and target_vertex['node_type'] == ExtendedOperator.QUANTIZE
and source_vertex['outputs'][0] == target_vertex['op'].inputs[0].name
) or (
source_vertex['node_type'] == ExtendedOperator.DEQUANTIZE
and target_vertex['node_type'] == ExtendedOperator.OUTPUT_NODE
)
def is_transpose_same_to_reshape_op(op: tfl.BaseOperator):
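"""Checks whether a transpose only moves size-1 axes around, i.e. it keeps
the flattened element order and is therefore equivalent to a reshape. For
example, permuting a (1, 256, 1, 1) tensor with perm (0, 2, 3, 1) yields
(1, 1, 1, 256) without reordering any element, so this returns True, while a
real (2, 3) -> (3, 2) transpose returns False.
"""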
num_elements = np.prod(op.inputs[0].shape)
input_shape = np.array(op.inputs[0].shape, dtype='int32')
output_shape = np.array(op.outputs[0].shape, dtype='int32')
if np.array_equal(input_shape[input_shape != 1], output_shape[output_shape != 1]):
input_tensor = np.arange(num_elements).reshape(input_shape)
perm = op.inputs[1].tensor
new_tensor = np.transpose(input_tensor, perm)
return np.array_equal(new_tensor.flatten(), input_tensor.flatten())
else:
return False
def is_conv2d_gather_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
source_vertex['node_type'] == ExtendedOperator.CONV_2D
and target_vertex['node_type'] == ExtendedOperator.GATHER
and source_vertex.outdegree() == 1
and target_vertex['op'].inputs[1].buffer is not None
and target_vertex['op'].axis == 3
and source_vertex['op'].inputs[1].tensor.shape[0] == target_vertex['op'].inputs[1].tensor.shape[0]
)
def is_gather_conv2d_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
source_vertex['node_type'] == ExtendedOperator.GATHER
and target_vertex['node_type'] == ExtendedOperator.CONV_2D
and source_vertex.outdegree() == 1
and source_vertex['op'].inputs[1].buffer is not None
and source_vertex['op'].axis == 3
and source_vertex['op'].inputs[1].tensor.shape[0] == target_vertex['op'].inputs[1].tensor.shape[3]
)
def is_reciprocal_sqrt_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
source_vertex['node_type'] == ExtendedOperator.SQRT
and target_vertex['node_type'] == ExtendedOperator.DIV
and source_vertex.outdegree() == 1
)
def is_tile_binary_op_edge(edge: ig.Edge, graph_converter: ig.Graph):
source_vertex = graph_converter.vs[edge.source]
target_vertex = graph_converter.vs[edge.target]
return (
source_vertex['node_type'] == ExtendedOperator.TILE
and target_vertex['node_type']
in (
ExtendedOperator.ADD,
ExtendedOperator.SUB,
ExtendedOperator.MUL,
ExtendedOperator.DIV,
)
and source_vertex.outdegree() == 1
)
def op_input_dims(op: tfl.BaseOperator):
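"""Returns the single dimension index the op operates along (e.g. the axis of
a concat/gather/pack, the split dim, the padded dim of a pad, or the reduced
dim of a reduction), or None when it cannot be narrowed down to exactly one
dimension.
"""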
dim_indices = None
if isinstance(op, (tfl.ConcatenationOperator, tfl.GatherOperator, tfl.PackOperator, tfl.UnpackOperator)):
dim_indices = op.axis
elif isinstance(op, tfl.SplitOperator):
dim_indices = op.inputs[0].tensor.item()
elif isinstance(op, tfl.SplitVOperator):
dim_indices = op.inputs[2].tensor.item()
elif isinstance(op, (tfl.PadOperator, tfl.Padv2Operator, tfl.MirrorPadOperator)):
pads = np.sum(op.inputs[1].tensor, axis=-1)
nonzero_idx = np.nonzero(pads)[0]
# TODO: support multiple indices
if nonzero_idx.size == 1:
dim_indices = nonzero_idx[0]
elif isinstance(op, tfl.PreluOperator):
w_shape = np.array(op.inputs[1].shape, dtype='int32')
nonzero_idx = np.nonzero(w_shape != 1)[0]
if nonzero_idx.size == 1:
dim_indices = nonzero_idx[0] + 1
elif isinstance(op, (tfl.SliceOperator, tfl.StridedSliceOperator, tfl.TileOperator)):
old_shape = np.array(op.inputs[0].shape)
new_shape = np.array(op.outputs[0].shape)
diff = new_shape - old_shape
nonzero_idx = np.nonzero(diff)[0]
# TODO: support multiple indices
if nonzero_idx.size == 1:
dim_indices = nonzero_idx[0]
elif isinstance(
op,
(
tfl.SumOperator,
tfl.MeanOperator,
tfl.ArgMinOperator,
tfl.ArgMaxOperator,
tfl.ReduceMinOperator,
tfl.ReduceMaxOperator,
tfl.ReduceProdOperator,
),
):
# TODO: support multiple indices
if op.inputs[1].tensor.size == 1:
dim_indices = op.inputs[1].tensor[0]
return dim_indices
def op_input_indices(op: tfl.BaseOperator):
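"""Returns the indices of the op inputs that carry activations and must be
rewritten when a transpose/reshape is pushed through the op; 1-element
operands of binary ops broadcast and are skipped.
"""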
if isinstance(op, (tfl.ConcatenationOperator, tfl.PackOperator)):
input_indices = range(len(op.inputs))
elif isinstance(op, tfl.SplitOperator):
input_indices = (1,)
elif isinstance(op, (tfl.BatchMatmulOperator, tfl.MinimumOperator, tfl.MaximumOperator)):
input_indices = range(2)
elif isinstance(
op, (tfl.AddOperator, tfl.SubOperator, tfl.MulOperator, tfl.DivOperator, tfl.SquaredDifferenceOperator)
):
if len(op.inputs[1].shape) == 1 and op.inputs[1].shape[0] == 1:
input_indices = range(1)
elif len(op.inputs[0].shape) == 1 and op.inputs[0].shape[0] == 1:
input_indices = (1,)
else:
input_indices = range(2)
else:
input_indices = range(1)
return input_indices
def fuse_bn_weight(eps, scale, var, weight, transpose):
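"""Folds batch-norm scaling into the preceding conv/fc weight:
W' = W * gamma / sqrt(var + eps), with the per-channel factor broadcast along
axis 1 when `transpose` is set (deconv weights) and axis 0 otherwise.
Here `scale` is the BN gamma.
"""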
if transpose:
shape = [1, -1] + [1] * (len(weight.shape) - 2)
else:
shape = [-1, 1] + [1] * (len(weight.shape) - 2)
inv = 1 / np.sqrt(var + eps)
return weight * (scale * inv).reshape(shape)
def fuse_bn_bias(eps, scale, var, mean, bn_b, activ_b):
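"""Folds batch-norm statistics into the bias term:
b' = (b - mean) * gamma / sqrt(var + eps) + beta, where `activ_b` is the
existing conv/fc bias b (treated as 0 when absent), `scale` is gamma and
`bn_b` is beta; a scalar bias is broadcast to per-channel first.
"""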
inv = 1 / np.sqrt(var + eps)
if activ_b is not None:
if activ_b.shape != mean.shape and activ_b.ndim == 1 and activ_b.size == 1:
activ_b = activ_b.repeat(mean.size)
return (activ_b - mean) * inv * scale + bn_b
else:
return (-mean) * inv * scale + bn_b
def fuse_rev_bn_weight(eps, scale, var, weight):
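"""Folds a batch norm that precedes a conv into its weight by scaling along
the input-channel axis (axis 1): W' = W * gamma / sqrt(var + eps).
"""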
shape = [1, -1] + [1] * (len(weight.shape) - 2)
inv = 1 / np.sqrt(var + eps)
return weight * (scale * inv).reshape(shape)
def fuse_rev_bn_bias(eps, scale, var, mean, bn_b, activ_b, weight):
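"""Computes the conv bias that absorbs a preceding batch norm: the constant
BN offset beta - mean * gamma / sqrt(var + eps) is pushed through the conv by
summing the weight over its spatial dims and added to the original bias (an
elementwise per-channel product is used instead of a matmul when the weight
looks depthwise, i.e. weight.shape[1] == 1).
"""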
reduced_dims = tuple([i for i in range(len(weight.shape)) if i > 1])
inv = 1 / np.sqrt(var + eps)
fused_b = bn_b - mean * inv * scale
if weight.shape[1] == 1 and mean.shape[0] > 1:
offset_b = (weight.sum(reduced_dims) * fused_b.reshape(-1, 1)).reshape(-1)
else:
offset_b = np.matmul(weight.sum(reduced_dims), fused_b.reshape(-1, 1)).reshape(-1)
if activ_b is not None:
if activ_b.shape != mean.shape and activ_b.ndim == 1 and activ_b.size == 1:
activ_b = activ_b.repeat(mean.size)
return activ_b + offset_b
else:
return offset_b
def fuse_slices(seq: typing.Iterable[ig.Vertex]):
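"""Composes a chain of SLICE / STRIDED_SLICE nodes into a single
(start, end, strides) triple. Two consecutive slices with parameters
(s1, e1, t1) and (s2, e2, t2) combine as:
    start   = s1 + s2 * t1
    end     = min(e1, s1 + e2 * t1)
    strides = t1 * t2
where a SLICE node contributes e = start + size and t = 1.
"""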
cur_start = None
cur_end = None
cur_strides = None
for node in seq:
assert node['node_type'] in (ExtendedOperator.SLICE, ExtendedOperator.STRIDED_SLICE)
next_start = node['op'].inputs[1].tensor
if cur_strides is None:
cur_strides = np.ones_like(next_start, dtype='int32')
if cur_start is None:
cur_start = np.zeros_like(next_start, dtype='int32')
if node['node_type'] == ExtendedOperator.SLICE:
next_size = node['op'].inputs[2].tensor
next_end = cur_start + (next_start + next_size) * cur_strides
next_strides = np.ones_like(next_start, dtype='int32')
else:
next_end = node['op'].inputs[2].tensor
next_end = cur_start + next_end * cur_strides
next_strides = node['op'].inputs[3].tensor
if cur_end is None:
cur_start = next_start
cur_end = next_end
cur_strides = next_strides
else:
cur_start += next_start * cur_strides
cur_end = np.min((cur_end, next_end), axis=0)
cur_strides = cur_strides * next_strides
return cur_start, cur_end, cur_strides
def fuse_transpose_perms(seq: typing.Iterable[ig.Vertex]):
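"""Composes the permutations of a chain of TRANSPOSE (or GATHER) nodes via
cur_perm[next_perm]. For example, [0, 2, 3, 1] followed by [0, 3, 1, 2]
composes to the identity [0, 1, 2, 3], which later passes can then remove.
"""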
cur_perm = None
for node in seq:
assert node['node_type'] in (ExtendedOperator.TRANSPOSE, ExtendedOperator.GATHER)
next_perm = node['op'].inputs[1].tensor
if cur_perm is None:
cur_perm = next_perm
else:
cur_perm = cur_perm[next_perm]
return cur_perm
def fuse_transpose_perms_extended(seq: typing.Iterable[ig.Vertex]):
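"""Like fuse_transpose_perms, but also handles a RESHAPE sandwiched between
transposes by translating its shape change into an equivalent permutation;
the sequence is walked in reverse when the reshape increases the number of
dims.
"""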
cur_perm = None
# Reverse the sequence if the number of dims expands across the reshape
if seq[1]['node_type'] == ExtendedOperator.RESHAPE:
if len(seq[1]['op'].inputs[0].shape) < len(seq[1]['op'].outputs[0].shape):
seq = list(reversed(list(seq)))
for node in seq:
if node['node_type'] == ExtendedOperator.TRANSPOSE:
next_perm = node['op'].inputs[1].tensor
if cur_perm is None:
cur_perm = next_perm
else:
cur_perm = cur_perm[next_perm]
elif node['node_type'] == ExtendedOperator.RESHAPE:
if len(seq[1]['op'].inputs[0].shape) > len(seq[1]['op'].outputs[0].shape):
old_shape = node['op'].inputs[0].shape
new_shape = node['op'].outputs[0].shape
else:
new_shape = node['op'].inputs[0].shape
old_shape = node['op'].outputs[0].shape
if old_shape != new_shape:
if len(old_shape) != len(new_shape):
new_shape_padded = list(new_shape) + [None] * (len(old_shape) - len(new_shape))
next_perm = []
new_idx = 0
while new_idx < len(new_shape):
for old, item in zip(old_shape, cur_perm):
if old == new_shape_padded[new_idx] and item not in next_perm:
next_perm.append(item)
new_idx += 1
cur_perm = np.argsort(next_perm)
else:
mapping = {}
for i in range(len(new_shape)):
mapping.setdefault(new_shape[i], [])
mapping[new_shape[i]].append(i)
next_perm = [0] * len(old_shape)
for i in range(len(old_shape)):
next_perm[i] = mapping[old_shape[i]].pop(0)
cur_perm = cur_perm[next_perm]
return cur_perm
def fuse_connected_edges(
filtered_pairs: typing.List[typing.Iterable[ig.Vertex]],
) -> typing.List[typing.Iterable[ig.Vertex]]:
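"""Repeatedly joins vertex sequences that share an endpoint, e.g.
[[a, b], [b, c]] becomes [[a, b, c]], until no head of one sequence matches
the tail of another.
"""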
while True:
heads = {n[0]: i for i, n in enumerate(filtered_pairs)}
tails = {n[-1]: i for i, n in enumerate(filtered_pairs)}
connectables = heads.keys() & tails.keys()
if len(connectables) > 0:
curr_filtered = []
for seq in filtered_pairs:
head_connectable = seq[0] in connectables
preserve = head_connectable and filtered_pairs[tails[seq[0]]][0] in connectables
if preserve:
curr_filtered.append(seq)
elif not head_connectable:
if seq[-1] in connectables:
curr_filtered.append(seq + filtered_pairs[heads[seq[-1]]][1:])
else:
curr_filtered.append(seq)
filtered_pairs = curr_filtered
else:
break
return filtered_pairs
def is_simple_reshape(orig_shape, new_shape, mapping: typing.Optional[typing.Dict[int, int]] = None):
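"""Returns True when `new_shape` differs from `orig_shape` only by inserting
or removing size-1 dims, e.g. (4, 1, 3) -> (4, 3, 1). When `mapping` is
given, it is filled with the index mapping of the shared dims
({0: 0, 2: 1} for the example above).
"""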
if orig_shape == new_shape:
if mapping is not None:
for i in range(len(orig_shape)):
mapping[i] = i
return True
i = 0
j = 0
while True:
if i == len(orig_shape) and j == len(new_shape):
break
elif i == len(orig_shape):
if new_shape[j] == 1:
j += 1
else:
break
elif j == len(new_shape):
if orig_shape[i] == 1:
i += 1
else:
break
elif orig_shape[i] == new_shape[j]:
if mapping is not None:
mapping[i] = j
i += 1
j += 1
elif orig_shape[i] == 1:
i += 1
elif new_shape[j] == 1:
j += 1
else:
break
if i != len(orig_shape) or j != len(new_shape):
return False
else:
return True
def reshape_mapping(shape_1, shape_2):
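"""Greedily groups the axes of two shapes with the same number of elements so
that each group pair covers an equal product of dims. For shapes (2, 3, 4)
and (6, 4) this yields mapping_l = [[0, 1], [2]] and mapping_r = [[0], [1]];
the non_one variants additionally drop size-1 axes from each group.
"""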
i = 0
j = 0
acc_l = 1
start_l = 0
acc_r = 1
start_r = 0
mapping_l = []
mapping_r = []
sign = None
while i < len(shape_1) or j < len(shape_2):
if i < len(shape_1) and j < len(shape_2):
if start_l == i and start_r == j and shape_1[i] == shape_2[j]:
mapping_l.append([i])
mapping_r.append([j])
acc_l = 1
acc_r = 1
i += 1
j += 1
start_l = i
start_r = j
sign = None
else:
if sign in ('l', None):
acc_l = shape_1[i] * acc_l
if sign in ('r', None):
acc_r = shape_2[j] * acc_r
if acc_l == acc_r:
mapping_l.append(list(range(start_l, i + 1)))
mapping_r.append(list(range(start_r, j + 1)))
acc_l = 1
acc_r = 1
i += 1
j += 1
start_l = i
start_r = j
sign = None
elif acc_l < acc_r:
sign = 'l'
i += 1
else:
sign = 'r'
j += 1
elif i < len(shape_1):
assert shape_1[i] == 1
mapping_l[-1].append(i)
i += 1
else:
assert shape_2[j] == 1
mapping_r[-1].append(j)
j += 1
non_one_mapping_l = []
non_one_mapping_r = []
for ml, mr in zip(mapping_l, mapping_r):
new_ml = [i for i in ml if shape_1[i] != 1]
new_mr = [j for j in mr if shape_2[j] != 1]
if len(new_ml) > 0 and len(new_mr) > 0:
non_one_mapping_l.append(new_ml)
non_one_mapping_r.append(new_mr)
return mapping_l, mapping_r, non_one_mapping_l, non_one_mapping_r
def elinimate_sequences(
graph_converter: CommonGraph,
filtered_pairs: typing.List[typing.Iterable[ig.Vertex]],
remove_first_pred: typing.Union[bool, typing.Callable] = False,
remove_first_node_action: typing.Optional[typing.Callable] = None,
remove_last_pred: typing.Union[bool, typing.Callable] = True,
remove_last_node_action: typing.Optional[typing.Callable] = None,
skip_pred: typing.Union[bool, typing.Callable] = False,
input_idx: int = 0,
force_forward_input: bool = False,
):
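"""Removes the matched node sequences from the graph and rewires the
remaining producers and consumers. The `remove_first_pred` /
`remove_last_pred` predicates decide whether the boundary nodes of each
sequence are dropped or kept, and `force_forward_input` reuses the tensor
feeding the sequence instead of the tensor it produces.
"""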
remove_ids = []
actions = []
for seq in filtered_pairs:
first_node = seq[0]
last_node = seq[-1]
if type(skip_pred) is bool:
skip = skip_pred
elif skip_pred is not None:
skip = skip_pred(seq)
if skip:
continue
if type(remove_first_pred) is bool:
remove_first = remove_first_pred
custom_data = None
elif remove_first_pred is not None:
remove_first, custom_data = remove_first_pred(seq)
if type(remove_last_pred) is bool:
remove_last = remove_last_pred
custom_data_last = None
elif remove_last_pred is not None:
remove_last, custom_data_last = remove_last_pred(seq)
# If the first node can also be eliminated, then set the previous node as the first node
if remove_first:
first_node = graph_converter.graph.vs.find(
name=graph_converter.tensor_node_map[first_node['op'].inputs[input_idx].name]
)
if not remove_last:
last_node = seq[-2]
output_idx = 0
if first_node == seq[0]:
next_idx = 1
else:
next_idx = 0
output_name = seq[next_idx]['op'].inputs[input_idx].name
output_idx = first_node['outputs'].index(output_name)
# We use the forward input tensor under the following circumstances:
# 1. the node right before the sequence is an input node
# 2. the first node has multiple outputs and the last node doesn't connect to any output node
use_forward_input = False
if first_node['node_type'] == ExtendedOperator.INPUT_NODE:
use_forward_input = True
branch = first_node.outdegree() > 1
has_output_nodes = False
for edge in last_node.out_edges():
target_vertex = edge.target_vertex
if target_vertex['node_type'] in (ExtendedOperator.OUTPUT_NODE, ExtendedOperator.UNUSED_NODE):
if use_forward_input:
# Cannot optimize away ops between i/o nodes
skip = True
else:
has_output_nodes = True
break
if branch:
output_outdegree = 0
for edge in first_node.out_edges():
target_vertex = edge.target_vertex
if target_vertex == seq[next_idx]:
continue
if target_vertex['node_type'] in (ExtendedOperator.OUTPUT_NODE, ExtendedOperator.UNUSED_NODE):
if has_output_nodes and edge['label'] == output_name:
output_outdegree += 1
break
else:
names = [t.name for t in target_vertex['op'].inputs]
if output_name in names:
output_outdegree += 1
break
if not has_output_nodes:
use_forward_input = True
elif output_outdegree > 0:
skip = True
if force_forward_input and not use_forward_input:
if not has_output_nodes:
use_forward_input = True
else:
skip = True
if skip:
continue
if use_forward_input:
# Find out the output of the first node in the sequence
new_output = first_node['outputs'][output_idx]
assert new_output in graph_converter.tensor_map
# For each node that comes after the last node, connect it to the first node
# Also, replace the tensors when needed
graph_converter.replace_next_tensors(last_node, first_node, new_output)
else:
# Find out the output of the last node in the sequence
new_output = last_node['outputs'][0]
assert new_output in graph_converter.tensor_map
# For each node that comes after the last node, connect it to the first node
graph_converter.connect_next_tensors(last_node, first_node, new_output)
# Update the graph: prepare to drop the output tensor of the intermediate nodes and use the output
# tensor of the last node instead
first_node['outputs'][output_idx] = new_output
if first_node['op'] is not None:
first_node['op'].outputs[output_idx] = graph_converter.tensor_map[new_output]
graph_converter.tensor_node_map[new_output] = first_node['name']
# When the first node is a constant node, we need to set the buffer back
if first_node['node_type'] == ExtendedOperator.CONSTANT_NODE and not use_forward_input:
if seq[0]['node_type'] == ExtendedOperator.CONSTANT_NODE:
old_tensor = graph_converter.tensor_map[seq[0]['name']]
else:
old_tensor = seq[0]['op'].inputs[input_idx]
new_tensor = seq[-1]['op'].outputs[0]
new_tensor.buffer = old_tensor.buffer
if remove_first and remove_last:
# Push the sequence to the removing list
remove_ids.extend([x.index for x in seq])
else:
# Collect actions for the first / last nodes that are preserved
start_index = 0
end_index = len(seq)
if not remove_first:
start_index = 1
if remove_first_node_action is not None:
action = remove_first_node_action(first_node, last_node, custom_data)
if action is not None:
actions.extend(action)
if not remove_last:
end_index = len(seq) - 1
if remove_last_node_action is not None:
action = remove_last_node_action(first_node, last_node, custom_data_last)
if action is not None:
actions.extend(action)
# Push the sequence (except the preserved nodes) to the removing list
remove_ids.extend([x.index for x in seq[start_index:end_index]])
for func, args in actions:
func(*args)
graph_converter.graph.delete_vertices(remove_ids)
def expand_op_outputs_in_branches(
nodes: typing.List[ig.Vertex],
new_op_func: typing.Callable[[ig.Vertex, ig.Vertex, ig.Vertex], None],
graph_converter: CommonGraph,
):
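"""For ops whose output fans out into multiple branches, keeps the original
op for one consumer (preferring output nodes) and uses `new_op_func` to emit
a clone for every other consumer, so that later passes can handle each branch
independently.
"""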
actions = []
for node in nodes:
preserve_node = None
prev_node_name = node['op'].inputs[0].name
prev_node = graph_converter.graph.vs.find(name=graph_converter.tensor_node_map[prev_node_name])
# Collect next nodes and choose one to preserve
next_nodes = []
for edge in node.out_edges():
next_node = graph_converter.graph.vs[edge.target]
if preserve_node is None or next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
preserve_node = next_node
next_nodes.append(next_node)
# For the filtered nodes, use the cloned op as the previous op
filtered_nodes = list(set(next_nodes) - set([preserve_node]))
for next_node in filtered_nodes:
actions.extend(new_op_func(node, prev_node, next_node))
# Process actions
for func, args in actions:
func(*args)
def get_same_padding_args(input_shape, filter_shape, strides, dilation):
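"""Computes explicit TF-style SAME padding for the given spatial dims. For
input_shape=[5], filter_shape=[3], strides=[2], dilation=[1] the total
padding is max(1 + (3 - 1) * 1 - (5 % 2), 0) = 2, giving
[[0, 0], [1, 1], [0, 0]] (the outer [0, 0] pairs cover the batch and channel
dims).
"""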
dim = len(input_shape)
padding = [0] * dim
for i in range(dim):
if input_shape[i] % strides[i] == 0:
padding[i] = max(1 - strides[i] + (filter_shape[i] - 1) * dilation[i], 0)
else:
padding[i] = max(1 + (filter_shape[i] - 1) * dilation[i] - (input_shape[i] % strides[i]), 0)
pad_args = [[0, 0]] + [[x // 2, x - x // 2] for x in padding] + [[0, 0]]
return pad_args