# tinynn/converter/operators/optimize.py
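# Context (assumed, not shown in this excerpt): this method lives on the graph
# optimizer class, and the surrounding module imports `functools`, `re` and
# `numpy as np`, the TFLite op wrappers as `tfl`, and provides the
# `ExtendedOperator` enum plus the `is_group_conv_node` predicate used below.
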
def group_conv_rewrite_pass(self):
    vertices = self.graph.graph.vs.select(functools.partial(is_group_conv_node, graph_converter=self.graph.graph))

    remove_ids = []
    restore_mapping = []
    for conv in vertices:
        restore_nodes = []
        # For each successor of a transformable node,
        #   a. if it is an output node, remove it anyway, since it will always be reconstructed
        #   b. otherwise, record the edge info so that we can restore it after the reconstruction
        for out_edge in conv.out_edges():
            next_node = self.graph.graph.vs[out_edge.target]
            if next_node['node_type'] == ExtendedOperator.OUTPUT_NODE:
                remove_ids.append(next_node.index)
                del self.graph.tensor_map[next_node['outputs'][0]]
                del self.graph.tensor_node_map[next_node['outputs'][0]]
            else:
                restore_nodes.append((out_edge['name'], next_node['name']))

        # Remove the tensor mappings, since the tensors they refer to are going to be recreated
        for output_name in conv['outputs']:
            del self.graph.tensor_map[output_name]
            del self.graph.tensor_node_map[output_name]

        restore_mapping.append(restore_nodes)
        remove_ids.append(conv.index)
    # Make sure the nodes are processed in topological order (by the numeric part of the
    # node name), keeping each restore mapping paired with its op
    sorted_pairs = sorted(zip(vertices, restore_mapping), key=lambda p: int(re.search(r'\d+', p[0]['name'])[0]))
    sorted_ops = [v['op'] for v, _ in sorted_pairs]
    restore_mapping = [m for _, m in sorted_pairs]

    # Delete the nodes from the graph before the transformation
    self.graph.graph.delete_vertices(remove_ids)
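
    # Rebuild each grouped convolution as SPLIT -> per-group CONV_2D -> CONCATENATION,
    # all operating along the channel axis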
    for conv, mapping in zip(sorted_ops, restore_mapping):
        input_tensor = conv.inputs[0]
        weight_tensor = conv.inputs[1]
        bias_tensor = conv.inputs[2] if len(conv.inputs) > 2 else None
        output_tensor = conv.outputs[0]

        num_input_channel = input_tensor.shape[3]
        num_weight_channel = weight_tensor.shape[3]
        num_chunks = num_input_channel // num_weight_channel
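        # TFLite CONV_2D filters are laid out as [out_channels, kh, kw, in_channels_per_group],
        # so dividing the input channel count by the filter's last dimension yields the group count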
        ops = []

        input_tensors = [
            self.create_transform_tensor(arr, quantization=input_tensor.quantization)
            for arr in np.split(input_tensor.tensor, num_chunks, 3)
        ]
        output_tensors = [
            self.create_transform_tensor(arr, quantization=output_tensor.quantization)
            for arr in np.split(output_tensor.tensor, num_chunks, 3)
        ]
        weights = [
            self.create_attr_tensor(arr, quantization=weight_tensor.quantization)
            for arr in np.split(weight_tensor.tensor, num_chunks, 0)
        ]
        if bias_tensor is not None:
            biases = [
                self.create_attr_tensor(arr, quantization=bias_tensor.quantization)
                for arr in np.split(bias_tensor.tensor, num_chunks, 0)
            ]
        else:
            biases = [None] * num_chunks
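
        # Split along axis 3: activations are NHWC, so dim 3 is the channel axis.
        # TFLite's SPLIT op takes the split axis as its first input tensor.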
        dim_tensor = self.create_attr_tensor(np.array(3, dtype='int32'))
        ops.append(tfl.SplitOperator([dim_tensor, input_tensor], input_tensors, num_chunks))
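
        # Emit one CONV_2D per group, copying the hyper-parameters of the original op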
        for it, ot, w, b in zip(input_tensors, output_tensors, weights, biases):
            inputs = [it, w]
            if b is not None:
                inputs.append(b)
            ops.append(
                tfl.Conv2dOperator(
                    inputs,
                    [ot],
                    strideH=conv.strideH,
                    strideW=conv.strideW,
                    dilationHFactor=conv.dilationHFactor,
                    dilationWFactor=conv.dilationWFactor,
                    fusedActivationFunction=conv.fusedActivationFunction,
                    padding=conv.padding,
                )
            )

        ops.append(tfl.ConcatenationOperator(output_tensors, [output_tensor], 3))
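
        # Wire the new ops into the graph and reconnect the downstream consumers recorded earlier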
        for op in ops:
            self.graph.add_operator(op, transform=True)

        self.graph.try_restore_edges(mapping)
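
# `is_group_conv_node` is referenced above but not defined in this excerpt. Below is a
# minimal sketch of what such a predicate could look like; the attribute layout and the
# `ExtendedOperator.CONV_2D` member are assumptions, not the actual tinynn implementation.
# The idea: a grouped convolution surfaces as a CONV_2D whose filter consumes only an
# integer fraction of the input channels (depthwise convs use a separate builtin op).
def is_group_conv_node_sketch(vertex, graph_converter):
    # Hypothetical helper; tinynn's real predicate may differ
    if vertex['node_type'] != ExtendedOperator.CONV_2D:
        return False
    conv = vertex['op']
    num_input_channel = conv.inputs[0].shape[3]
    num_weight_channel = conv.inputs[1].shape[3]
    return num_weight_channel != num_input_channel and num_input_channel % num_weight_channel == 0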