in tinynn/converter/operators/optimize.py [0:0]
def elementwise_op_quantize_passthrough_pass(self):
    """Push [de]quantize boundaries across elementwise ops so the ops run on quantized tensors.

    Candidate ops are the endpoints of edges accepted by ``is_quantize_elementwise_op_edge``
    (an edge is either ``DEQUANTIZE -> op`` or ``op -> QUANTIZE``). For each candidate op:

    * every input produced by a DEQUANTIZE node gets the quantization of that DEQUANTIZE
      node's input; inputs that already have an entry in ``self.graph.q_mapping`` reuse it;
    * for MINIMUM/MAXIMUM, a constant non-first input is quantized with the quantization of
      input 0 and registered in ``self.graph.q_mapping``;
    * every output consumed by a QUANTIZE node gets that QUANTIZE node's output quantization,
      subject to per-op dtype / fixed-quantization constraints checked below.

    The op is rewritten only when it has at least one non-graph-output consumer and the
    rewrite does not increase the number of [de]quantize nodes. The rewrite:

    * inserts a new QUANTIZE after each (kept) DEQUANTIZE-produced input, or plugs in the
      ``q_mapping`` tensor directly, and redirects the op input to it;
    * makes the op write a fresh quantized tensor, followed by a new DEQUANTIZE that
      reproduces the original float tensor, and re-attaches the original consumer edges to
      that DEQUANTIZE node;
    * removes OUTPUT_NODE vertices fed by the op, along with their ``tensor_map`` and
      ``tensor_node_map`` entries.

    The resulting ``DEQUANTIZE -> QUANTIZE`` pairs are presumably folded by a later pass
    (NOTE(review): not visible here — confirm).

    Input replacements are collected as deferred actions and run after all nodes are
    processed. Edge and vertex deletions are applied at the very end, so igraph indices
    stay stable while the graph is being edited.

    Mutates ``self.graph`` (igraph vertices/edges, ``tensor_map``, ``tensor_node_map``,
    ``q_mapping``) in place and returns ``None``.
    """
    edges = self.graph.graph.es.select(
        functools.partial(
            is_quantize_elementwise_op_edge, graph_converter=self.graph.graph, with_lstm=self.hybrid_int16_lstm
        )
    )
    pairs = ((self.graph.graph.vs[edge.source], self.graph.graph.vs[edge.target]) for edge in edges)
    # Keep the elementwise-op side of each matched edge (the side that is not the DEQUANTIZE).
    filtered_nodes = (k[0] if k[0]['node_type'] != ExtendedOperator.DEQUANTIZE else k[1] for k in pairs)
    # De-duplicate while preserving edge order. Iterating a ``set`` follows hash order, which
    # need not be reproducible between runs and made the rewrite order (and thus the names
    # of generated tensors) non-deterministic.
    unique_nodes = list(dict.fromkeys(filtered_nodes))

    actions = []
    remove_edges = []
    # Mirror of ``remove_edges`` used for O(1) membership tests in the consumer loop.
    removed_edge_ids = set()
    remove_vertices = []

    for node in unique_nodes:
        op = node['op']
        input_indices = op_input_indices(op)
        prev_nodes = []
        # float tensor name -> tensor carrying the quantization to use at that op boundary
        q_tensors = dict()
        prev_output_indices = []
        # inputs that already have a quantized counterpart in ``self.graph.q_mapping``
        skip_names = []

        # --- Collect the producers of the op inputs ---
        for i in input_indices:
            prev_node_name = op.inputs[i].name
            prev_node = self.graph.graph.vs.find(name=self.graph.tensor_node_map[prev_node_name])
            prev_nodes.append(prev_node)
            prev_output_indices.append(prev_node['outputs'].index(prev_node_name))
            if prev_node['node_type'] == ExtendedOperator.DEQUANTIZE:
                # The DEQUANTIZE input holds the original quantized tensor and its parameters.
                q_tensors[prev_node_name] = prev_node['op'].inputs[0]
            if prev_node['node_type'] == ExtendedOperator.CONSTANT_NODE:
                ref_name = op.inputs[0].name
                # MINIMUM/MAXIMUM: quantize a constant operand with the quantization of
                # input 0. Input 0 is only usable as the reference when it is itself
                # dequantized. Otherwise indexing ``q_tensors`` raised KeyError. Now the
                # constant stays unquantized and the node fails the count check below.
                if (
                    node['node_type'] in (ExtendedOperator.MINIMUM, ExtendedOperator.MAXIMUM)
                    and i != 0
                    and prev_node_name not in self.graph.q_mapping
                    and ref_name in q_tensors
                ):
                    f_tensor = self.graph.tensor_map[prev_node_name]
                    r_tensor = q_tensors[ref_name]
                    q_arr = np.rint(
                        f_tensor.tensor / r_tensor.quantization.scale + r_tensor.quantization.zero_point
                    )
                    i_type = np.iinfo(r_tensor.tensor.dtype)
                    if np.any(q_arr > i_type.max):
                        warnings.warn('Overflow while quantizing the tensor')
                        q_arr = np.minimum(q_arr, i_type.max)
                    if np.any(q_arr < i_type.min):
                        warnings.warn('Underflow while quantizing the tensor')
                        q_arr = np.maximum(q_arr, i_type.min)
                    q_arr = q_arr.astype(r_tensor.dtype)
                    q_tensor = self.create_attr_tensor(q_arr, quantization=r_tensor.quantization)
                    # NOTE(review): this entry persists even if the node is skipped below.
                    # A constant shared with another MIN/MAX that has different input
                    # quantization will reuse it as-is — confirm this is intended.
                    self.graph.q_mapping[prev_node_name] = q_tensor
                if prev_node_name in self.graph.q_mapping:
                    skip_names.append(prev_node_name)

        # --- Collect the consumers of the op outputs ---
        next_nodes = []
        next_edges = []
        out_nodes = []
        for edge in node.out_edges():
            if edge.index in removed_edge_ids:
                continue
            next_node = self.graph.graph.vs[edge.target]
            if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                out_nodes.append(next_node)
                continue
            next_nodes.append(next_node)
            next_edges.append(edge)
            if next_node['node_type'] != ExtendedOperator.QUANTIZE:
                continue

            skip = False
            name = next_node['op'].inputs[0].name
            q_tensor = next_node['op'].outputs[0]
            assert q_tensor.quantization is not None
            # Per-op constraints on the quantized output. A QUANTIZE consumer violating them
            # is not counted, which makes the node fail the count check below.
            if node['node_type'] in (
                ExtendedOperator.BATCH_MATMUL,
                ExtendedOperator.ABS,
                ExtendedOperator.RSQRT,
            ):
                if q_tensor.dtype not in (np.dtype('int8'), np.dtype('int16')):
                    skip = True
            elif node['node_type'] == ExtendedOperator.DIV:
                if q_tensor.dtype != np.dtype('uint8'):
                    skip = True
            elif node['node_type'] == ExtendedOperator.SOFTMAX:
                # Fixed output quantization expected per dtype (within 0.1% on the scale).
                if q_tensor.dtype == np.dtype('int8'):
                    if (
                        abs(q_tensor.quantization.scale - 1.0 / 256) > 0.001 * 1.0 / 256
                        or q_tensor.quantization.zero_point != -128
                    ):
                        skip = True
                elif q_tensor.dtype == np.dtype('int16'):
                    if (
                        abs(q_tensor.quantization.scale - 1.0 / 32768) > 0.001 * 1.0 / 32768
                        or q_tensor.quantization.zero_point != 0
                    ):
                        skip = True
                elif q_tensor.dtype == np.dtype('uint8'):
                    # uint8 is accepted with any quantization, but a warning is logged.
                    if (
                        abs(q_tensor.quantization.scale - 1.0 / 256) > 0.001 * 1.0 / 256
                        or q_tensor.quantization.zero_point != 0
                    ):
                        log.warning('On some chips, only softmax with scale=1.0/256 and zero_point=0 is supported')
                else:
                    skip = True
            elif node['node_type'] == ExtendedOperator.LOG_SOFTMAX:
                if q_tensor.dtype == np.dtype('int8'):
                    if q_tensor.quantization.scale != 16.0 / 256 or q_tensor.quantization.zero_point != 127:
                        skip = True
                elif q_tensor.dtype == np.dtype('uint8'):
                    if q_tensor.quantization.scale != 16.0 / 256 or q_tensor.quantization.zero_point != 255:
                        skip = True
                else:
                    skip = True
            if not skip:
                q_tensors[name] = q_tensor

        # --- Decide whether the rewrite pays off ---
        # Existing [de]quantize nodes at the op boundary vs. the number needed after the
        # rewrite (one per input not covered by q_mapping, one per non-output consumer).
        cur_qdq_count = len(q_tensors)
        new_qdq_count = len(prev_nodes) + len(next_nodes) - len(skip_names)
        if len(next_nodes) == 0 or new_qdq_count > cur_qdq_count:
            continue
        # Every boundary tensor needs a target quantization. Bail out now instead of
        # raising KeyError after the graph has already been partially mutated.
        if any(out.name not in q_tensors for out in op.outputs):
            continue
        if any(op.inputs[i].name not in q_tensors and op.inputs[i].name not in skip_names for i in input_indices):
            continue

        next_edge_ids = [x.index for x in next_edges]
        remove_edges.extend(next_edge_ids)
        removed_edge_ids.update(next_edge_ids)
        remove_vertices.extend(x.index for x in out_nodes)
        for n in out_nodes:
            del self.graph.tensor_map[n['outputs'][0]]
            del self.graph.tensor_node_map[n['outputs'][0]]

        # --- Rewire inputs to quantized tensors (applied later as deferred actions) ---
        # original input name -> (replacement tensor, number of substitutions queued so far).
        # The count is forwarded as the extra argument of ``replace_operator_input`` when the
        # same tensor feeds the op more than once (its meaning is defined by that helper).
        new_input_by_name = {}
        for prev_node, input_idx, output_idx in zip(prev_nodes, input_indices, prev_output_indices):
            if prev_node['op'] is None:
                prev_out = self.graph.tensor_map[prev_node['outputs'][0]]
            else:
                prev_out = prev_node['op'].outputs[output_idx]

            if prev_out.name in new_input_by_name:
                prev_new_out, use_count = new_input_by_name[prev_out.name]
                actions.append((self.graph.replace_operator_input, (node, input_idx, prev_new_out, True, use_count)))
                new_input_by_name[prev_out.name] = (prev_new_out, use_count + 1)
            elif prev_out.name in skip_names:
                # Already-quantized counterpart exists: use it directly, no QUANTIZE needed.
                prev_new_out = self.graph.q_mapping[prev_out.name]
                self.graph.add_nodes([prev_new_out])
                new_input_by_name[prev_out.name] = (prev_new_out, 1)
                actions.append((self.graph.replace_operator_input, (node, input_idx, prev_new_out, True)))
            else:
                prev_new_out = self.create_transform_tensor(
                    q_tensors[prev_out.name].tensor, quantization=q_tensors[prev_out.name].quantization
                )
                new_input_by_name[prev_out.name] = (prev_new_out, 1)
                self.graph.add_operator(tfl.QuantizeOperator([prev_out], [prev_new_out]))
                actions.append((self.graph.replace_operator_input, (node, input_idx, prev_new_out, True)))

        # --- Make the op produce quantized outputs, followed by a DEQUANTIZE ---
        dequant_node_by_name = {}
        for i, op_out in enumerate(op.outputs):
            new_out = self.create_transform_tensor(
                q_tensors[op_out.name].tensor, quantization=q_tensors[op_out.name].quantization
            )
            # The op now produces ``new_out``; ``op_out`` will be produced by the DEQUANTIZE.
            self.graph.tensor_node_map.pop(op_out.name, None)
            self.graph.tensor_node_map[new_out.name] = node['name']
            self.graph.tensor_map[new_out.name] = new_out
            node['outputs'][i] = new_out.name
            op.outputs[i] = new_out
            self.graph.add_operator(tfl.DequantizeOperator([new_out], [op_out]))
            dequant_node_by_name[op_out.name] = self.graph.graph.vs.find(name=self.graph.tensor_node_map[op_out.name])

        # Re-attach the original consumers to the new DEQUANTIZE nodes.
        for edge in next_edges:
            source = dequant_node_by_name[edge['name']]
            self.graph.graph.add_edge(source, edge.target_vertex, name=edge['name'], label=edge['name'])

    # Run the deferred input replacements; they may report further edge ids to drop.
    ids = []
    for func, args in actions:
        res = func(*args)
        if res is not None:
            ids.extend(res)

    remove_edges = list(set(remove_edges + ids))
    self.graph.graph.delete_edges(remove_edges)
    self.graph.graph.delete_vertices(remove_vertices)