in tinynn/graph/tracer.py [0:0]
def parse_args(self, *args, **kwargs):
    """Sets the string representation of the arguments"""

    def _tensor_name(a, convert_to_parameter=False, original=False):
        """Get the tensor name from the computation graph"""
        ns = ''
        if constant_handler(a, self.unique_name, self.full_name):
            # The tensor was lifted into a constant attribute, so the
            # generated code reaches it through `self.`
            ns = 'self.'
            pre_node_name = current_graph().tensor_pre_node_dict[id(a)]
            if original:
                node = current_graph().nodes_map[pre_node_name]
                pre_node_name = node.original_name
        else:
            pre_node_name = current_graph().tensor_pre_node_dict[id(a)]
            node = current_graph().nodes_map[pre_node_name]
            if original:
                pre_node_name = node.original_name
            # Constants and FloatFunctionals also live on the module itself
            if type(node.module) in (ConstantNode, torch.nn.quantized.FloatFunctional):
                ns = 'self.'
        if id(a) in current_graph().tensor_pre_index_dict:
            # The tensor is one element of its producer's output, so an
            # index (or a chain of indices) is appended to the name
            pre_node_index = current_graph().tensor_pre_index_dict[id(a)]
            log.debug(f'pre_index gen func {self.kind}: {pre_node_index}')
            if isinstance(pre_node_index, (list, tuple)):
                indices_str = ''.join([f'[{i}]' for i in pre_node_index])
                return f"{ns}{pre_node_name}{indices_str}"
            else:
                return f"{ns}{pre_node_name}[{pre_node_index}]"
        else:
            return f"{ns}{pre_node_name}"
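    # Illustrative (hypothetical) results: a traced tensor may resolve to
    # 'add_0_f', a constant to 'self.fake_constant_0', and one element of
    # a multi-output op to 'chunk_0[1]'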
    def _escape_arg(arg: str):
        """Escapes the special characters in the argument string"""
        # Double up braces so the string survives the later `str.format`
        # call that fills in the '{}' placeholders
        for c in ('{', '}'):
            if c in arg:
                arg = arg.replace(c, f'{c}{c}')
        return arg
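    # e.g. _escape_arg("'{x}'") -> "'{{x}}'", which `str.format` renders
    # back as '{x}' when the template is filled in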
    def _parse_args(arg):
        """Converts the argument to a list of strings"""
        new_arg = []
        for a in arg:
            if isinstance(a, (list, tuple, torch.Size)):
                # Recurse into nested containers
                new_arg.append(_parse_args(a))
            elif type(a) in (torch.Tensor, torch.nn.Parameter) or (
                type(a) in (torch.dtype, torch.device, torch.Size)
                and id(a) in current_graph().tensor_pre_node_dict
            ):
                # Traced values become '{}' placeholders that are filled
                # in later with the producer node's name
                self.prev_tensors.append(a)
                self.tensor_names.append(_tensor_name(a))
                self.original_tensor_names.append(_tensor_name(a, original=True))
                new_arg.append('{}')
            elif type(a) in (str, torch.device):
                new_arg.append(_escape_arg(f"\'{a}\'"))
            elif type(a) in (int, bool, torch.dtype):
                new_arg.append(str(a))
            elif type(a) is float:
                # `nan` and `inf` have no literal form, so wrap them in float()
                str_arg = str(a)
                if str_arg in ('nan', 'inf', '-inf'):
                    new_arg.append(f"float('{str_arg}')")
                else:
                    new_arg.append(str_arg)
            elif a is None:
                new_arg.append('None')
            elif a is Ellipsis:
                new_arg.append('...')
            elif type(a) is slice:
                # Render slices with Python's `start:stop:step` syntax,
                # dropping the trailing colon when the step is omitted
                t = (a.start, a.stop, a.step)
                parts = []
                for x in t:
                    if x is None:
                        parts.append('')
                    else:
                        parts.extend(_parse_args([x]))
                r = ':'.join(parts)
                if r.endswith(':'):
                    r = r[:-1]
                new_arg.append(r)
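            # e.g. slice(None, 5, None) renders as ':5' and
            # slice(1, None, 2) as '1::2'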
            elif isinstance(a, torch.nn.quantized.FloatFunctional):
                # FloatFunctional modules are recorded as extra init nodes
                # so they can be reconstructed in the generated model
                float_functional_cls = type(a)
                module_constructor_lines[id(a)] = f'{qualified_name(float_functional_cls, short=True)}()'
                new_node = TraceNode(a)
                current_graph().nodes_map[new_node.unique_name] = new_node
                current_graph().other_init_nodes.append(new_node)
                current_graph().tensor_pre_node_dict[id(a)] = new_node.unique_name
                self.tensor_names.append(_tensor_name(a))
                self.original_tensor_names.append(_tensor_name(a, original=True))
                self.prev_tensors.append(a)
                new_arg.append('{}')
            elif isinstance(a, nn.Module):
                # Submodules are referenced by their unique attribute name
                unique_name = current_graph().module_unique_name_dict[id(a)]
                current_graph().tensor_pre_node_dict[id(a)] = unique_name
                self.tensor_names.append(f'self.{unique_name}')
                self.original_tensor_names.append(_tensor_name(a, original=True))
                self.prev_tensors.append(a)
                new_arg.append('{}')
            else:
                log.error(f"unsupported type {type(a)} while generating arg for func {self.full_name}")
                assert False
        return new_arg
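    # Illustrative (hypothetical) example: _parse_args((t, 2, None)) for a
    # traced tensor `t` yields ['{}', '2', 'None'], with the placeholder
    # filled in later from self.tensor_names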
    self.tensor_names = []
    self.original_tensor_names = []
    self.prev_tensors.clear()
    arg_str = _parse_args(args)

    # Keyword arguments are rendered as 'key=value' and appended after
    # the positional arguments
    kw_items = kwargs.items()
    if kw_items:
        kw_keys, kw_vals = zip(*kw_items)
        kw_val_strs = _parse_args(kw_vals)
        for (k, v) in zip(kw_keys, kw_val_strs):
            if type(v) is list:
                # Nested containers come back as lists of strings and
                # need to be flattened into a single literal
                v_str = self._flatten_list(v)
                arg_str.append(f"{k}={v_str}")
            else:
                arg_str.append(f"{k}={v}")

    self.args_parsed = copy.deepcopy(arg_str)
    self.kwargs = copy.deepcopy(kwargs)
    self.args_parsed_origin = copy.deepcopy(self.args_parsed)
    self.args_to_string(self.args_parsed)

    return self
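
For intuition, here is a minimal sketch of how the parsed pieces fit together. It is not part of tracer.py; the argument list, the node names, and the final `.format` call are illustrative assumptions about how the '{}' placeholders pair up with `tensor_names` downstream:

# Minimal sketch (hypothetical values, not tinynn API): suppose parse_args
# ran for a call like torch.add(x, y, alpha=2)
args_parsed = ['{}', '{}', 'alpha=2']    # what _parse_args might yield
tensor_names = ['add_0_f', 'mul_0_f']    # hypothetical producer node names

# Join the pieces into one template and fill the '{}' slots in order,
# the way the generated forward code references earlier nodes
template = ', '.join(args_parsed)        # '{}, {}, alpha=2'
print(template.format(*tensor_names))    # -> add_0_f, mul_0_f, alpha=2

This is also why `_escape_arg` doubles literal braces: a brace inside a string argument would otherwise be consumed by the `.format` call above.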