in tinynn/converter/base.py [0:0]
def init_operations(self):
    log.debug('Initialize operators...')
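    # Walk the TorchScript graph with a deque so converters can prepend
    # nodes they generate on the fly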
    node_queue = collections.deque(self.graph.nodes())
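    # Scope bookkeeping for nested prim::PythonOp blocks; scope_map counts how
    # many times each scope name has been entered so repeats get unique suffixes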
    scope_map = {}
    current_scope = None
    while node_queue:
        node = node_queue.popleft()
        k = node.kind()
        output_tensors = []
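        # Pick the converter for this node kind; kinds without a dedicated
        # converter fall back to NoTrackOperator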
        converter_type = OPERATOR_CONVERTER_DICT.get(k, NoTrackOperator)
        converter = converter_type(
            node,
            self.tensor_map,
            current_scope,
            not self.strict_symmetric_check,
            self.q_type,
            self.hybrid_q_type,
            self.map_bilstm_to_lstm,
            self.enable_mtk_ops,
            self.hybrid_asymmetric_inputs,
            self.unroll_rnn,
            self.separated_rnn_gate_calc,
            self.conv_transpose_with_bias,
            self.legacy_gelu,
        )
        # Don't track the operator if none of its input nodes is tracked, unless it
        # has a custom implementation (e.g. prim::* ops)
        if converter_type.run == NoTrackOperator.run and converter_type != NoTrackOperator:
            no_track_flag = True
            for n in converter.input_names:
                if self.common_graph.has_nested_names(n):
                    nested_names = self.common_graph.get_list_expanded_names(n)
                    for x in nested_names:
                        if x in self.common_graph.tensor_map and self.common_graph.tensor_map[x].buffer is None:
                            no_track_flag = False
                            break
                elif n in self.common_graph.tensor_map and self.common_graph.tensor_map[n].buffer is None:
                    no_track_flag = False
                    break
            if no_track_flag:
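                # Untracked quantize/dequantize nodes are remapped to variants
                # that still carry the quantization parameters along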
                if converter_type == ATenDequantizeOperator:
                    converter_type = TrackQParamsOperator
                elif converter_type == ATenQuantizePerTensorOperator:
                    converter_type = TrackRevQParamsOperator
                else:
                    converter_type = NoTrackOperator
                converter = converter_type(
                    node,
                    self.tensor_map,
                    current_scope,
                    not self.strict_symmetric_check,
                    self.q_type,
                    self.hybrid_q_type,
                    self.map_bilstm_to_lstm,
                    self.enable_mtk_ops,
                    self.hybrid_asymmetric_inputs,
                    self.unroll_rnn,
                    self.separated_rnn_gate_calc,
                    self.conv_transpose_with_bias,
                    self.legacy_gelu,
                )
        if k != 'prim::Constant':
            log.debug(f'{k} {converter.input_names} -> {converter.output_names} {converter_type.__name__}')
        # Don't fetch attrs and schemas for non-tracking nodes
        if converter_type not in (NoTrackOperator, TrackRevQParamsOperator, TrackQParamsOperator):
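            # fetch_all_attrs may raise StopIteration; treat that as having no attrs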
            try:
                attrs = converter.fetch_all_attrs(node)
            except StopIteration:
                attrs = None
            args = converter.fetch_annotated_args(node)
        else:
            attrs = None
            args = None
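        # Parse the node; converters may emit replacement nodes, which are
        # pushed to the front of the queue in their original order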
        converter.parse(node, attrs, args, self.common_graph)
        outputs = converter.output_names
        new_nodes = converter.output_nodes
        if output_tensors is not None:
            output_tensors.extend(converter.get_output_tensors())
        if len(new_nodes) > 0:
            node_queue.extendleft(reversed(new_nodes))
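        # A prim::PythonOp opens a fresh, uniquely numbered scope;
        # a prim::Return resets the current scope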
        if k == 'prim::PythonOp':
            s = node.scopeName()
            scope_map.setdefault(s, 0)
            scope_map[s] += 1
            current_scope = f'{s}_{scope_map[s]}'
            converter.prepare_scope_tensors(node, attrs, args, self.common_graph, current_scope)
        elif k == 'prim::Return':
            current_scope = None
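        # Record the produced tensors; keep detached copies when tensor
        # preservation is requested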
        assert len(output_tensors) == len(outputs)
        for t, name in zip(output_tensors, outputs):
            self.tensor_map[name] = t
            if self.preserve_tensors and isinstance(t, torch.Tensor):
                self.tensor_map_copies[name] = t.detach().clone()