in tinynn/converter/operators/optimize.py
def quantize_input_output_type_pass(self):
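    """Convert the dtypes of quantized graph inputs/outputs between int8 and uint8.

    For every graph input whose dtype differs from `quantize_input_output_type`,
    the input tensor itself is rewritten in the target type (values and zero
    point shifted by 128) and a Quantize op is inserted so that downstream ops
    still receive the original type. For graph outputs, a Quantize op is
    appended instead, producing a new output tensor in the target type.
    """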
    remove_edges = []
    remove_vertices = []
    for i, name in enumerate(self.graph.inputs):
        if self.fuse_input_indices is not None:
            if i not in self.fuse_input_indices:
                continue

        node_name = self.graph.tensor_node_map[name]
        node = self.graph.graph.vs.find(name=node_name)
        assert node['node_type'] == ExtendedOperator.INPUT_NODE

        # Update input tensor
        input_tensor = self.graph.tensor_map[node['outputs'][0]]
        input_type = str(input_tensor.dtype)
        if input_type == self.quantize_input_output_type:
            continue

        input_arr = input_tensor.tensor.copy()
        input_quantization = copy.deepcopy(input_tensor.quantization)
        if input_type == 'int8' and self.quantize_input_output_type == 'uint8':
            input_tensor.tensor = (input_tensor.tensor.astype('int32') + 128).astype('uint8')
            input_tensor.quantization.zero_point += 128
            input_tensor.dtype = input_tensor.tensor.dtype
        elif input_type == 'uint8' and self.quantize_input_output_type == 'int8':
            input_tensor.tensor = (input_tensor.tensor.astype('int32') - 128).astype('int8')
            input_tensor.quantization.zero_point -= 128
            input_tensor.dtype = input_tensor.tensor.dtype
        else:
            raise AssertionError(
                f'Unsupported types: input_type: {input_type}, quantize_input_output_type:'
                f' {self.quantize_input_output_type}'
            )
        # Create a new Quantize op: `requantized` keeps the original dtype and
        # quantization, so downstream consumers still see the type they expect
        requantized = self.create_transform_tensor(input_arr, quantization=input_quantization)
        quantize_op = tfl.QuantizeOperator([input_tensor], [requantized])
        self.graph.add_operator(quantize_op)

        # Get the newly-generated node
        new_node_name = self.graph.tensor_node_map[requantized.name]
        new_node = self.graph.graph.vs.find(name=new_node_name)

        # Connect the quantize op to the graph
        self.graph.replace_next_tensors(node, new_node, requantized.name, [new_node_name])

        # Collect the now-unused connections from the input node
        for edge in node.out_edges():
            target_vertex = edge.target_vertex
            if target_vertex['name'] != new_node_name:
                remove_edges.append(edge.index)
    output_mapping = {}
    for i, name in enumerate(self.graph.outputs):
        if self.fuse_output_indices is not None:
            if i not in self.fuse_output_indices:
                continue

        output_tensor = self.graph.tensor_map[name]
        output_type = str(output_tensor.dtype)
        if output_type == self.quantize_input_output_type:
            continue

        node_name = self.graph.tensor_node_map[name]
        node = self.graph.graph.vs.find(name=node_name)
        for edge in node.out_edges():
            next_node = edge.target_vertex
            if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                remove_vertices.append(next_node.index)

        # Update output tensor
        output_arr = output_tensor.tensor.copy()
        output_quantization = copy.deepcopy(output_tensor.quantization)
        if output_type == 'int8' and self.quantize_input_output_type == 'uint8':
            output_arr = (output_arr.astype('int32') + 128).astype('uint8')
            output_quantization.zero_point += 128
        elif output_type == 'uint8' and self.quantize_input_output_type == 'int8':
            output_arr = (output_arr.astype('int32') - 128).astype('int8')
            output_quantization.zero_point -= 128
        else:
            raise AssertionError(
                f'Unsupported types: output_type: {output_type}, quantize_input_output_type:'
                f' {self.quantize_input_output_type}'
            )
        # Unlike inputs, the original tensor keeps its type here; a Quantize op
        # produces `requantized` in the target type as the new graph output
        requantized = self.create_transform_tensor(output_arr, quantization=output_quantization)
        quantize_op = tfl.QuantizeOperator([output_tensor], [requantized])
        self.graph.add_operator(quantize_op)

        output_mapping[name] = requantized.name

    if len(output_mapping) > 0:
        new_outputs = []
        output_names = []
        for name in self.graph.outputs:
            if name in output_mapping:
                new_outputs.append(output_mapping[name])
                output_names.append(output_mapping[name])
            else:
                new_outputs.append(name)

        self.graph.outputs.clear()
        self.graph.outputs.extend(new_outputs)
        self.graph.add_outputs(output_names)

    # Remove the collected edges & vertices
    self.graph.graph.delete_edges(remove_edges)
    self.graph.graph.delete_vertices(remove_vertices)
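
Why the ±128 shift is safe: with affine quantization, real = scale * (q - zero_point), so shifting both the stored values and the zero point by the same amount leaves the dequantized values unchanged. A minimal standalone sketch of that identity (not part of optimize.py; assumes only numpy, with made-up scale and zero-point values):

import numpy as np

scale, zero_point = 0.05, -3
q_int8 = np.array([-128, -3, 0, 127], dtype='int8')
real = scale * (q_int8.astype('int32') - zero_point)

# Mirror the int8 -> uint8 branch of the pass above
q_uint8 = (q_int8.astype('int32') + 128).astype('uint8')
real_shifted = scale * (q_uint8.astype('int32') - (zero_point + 128))

assert np.allclose(real, real_shifted)  # dequantized values are unchanged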