def parse_args()

in tinynn/graph/tracer.py [0:0]


    def parse_args(self, *args, **kwargs):
        """Sets the string representation of the arguments.

        Every positional and keyword argument is converted to a source-code
        string (nested lists/tuples/``torch.Size`` become nested lists of
        strings). Tensors and other graph-tracked objects are replaced by a
        ``'{}'`` placeholder. Their generated names are collected, in order of
        appearance, in ``self.tensor_names`` (unique names) and
        ``self.original_tensor_names`` (original names), and the objects
        themselves in ``self.prev_tensors``. Literal braces inside string
        arguments are doubled, presumably so that a later ``str.format``-style
        fill of the placeholders (in ``args_to_string``, not visible here)
        leaves them intact.

        Side effects:
            Resets and fills ``self.tensor_names``, ``self.original_tensor_names``
            and ``self.prev_tensors``. Sets ``self.args_parsed``,
            ``self.args_parsed_origin`` and ``self.kwargs`` (deep copies), then
            calls ``self.args_to_string(self.args_parsed)``. A
            ``FloatFunctional`` argument is registered as a new node in the
            current graph.

        Args:
            *args: Positional arguments of the traced call.
            **kwargs: Keyword arguments of the traced call; rendered as
                ``key=value`` entries appended after the positional ones.

        Returns:
            ``self``.

        Raises:
            AssertionError: If an argument has an unsupported type.
        """

        def _tensor_name(a, original=False):
            """Returns the expression that refers to ``a`` in the generated code.

            With ``original=True`` the producing node's ``original_name`` is used
            instead of its unique name. Objects accepted by ``constant_handler``,
            and outputs of ``ConstantNode`` / ``FloatFunctional`` nodes, get a
            ``self.`` prefix. If the object is recorded in
            ``tensor_pre_index_dict``, an index suffix such as ``[0]`` or
            ``[0][1]`` is appended.
            """
            ns = ''
            if constant_handler(a, self.unique_name, self.full_name):
                ns = 'self.'
                pre_node_name = current_graph().tensor_pre_node_dict[id(a)]
                if original:
                    node = current_graph().nodes_map[pre_node_name]
                    pre_node_name = node.original_name
            else:
                pre_node_name = current_graph().tensor_pre_node_dict[id(a)]
                node = current_graph().nodes_map[pre_node_name]
                if original:
                    pre_node_name = node.original_name
                if type(node.module) in (ConstantNode, torch.nn.quantized.FloatFunctional):
                    ns = 'self.'

            if id(a) not in current_graph().tensor_pre_index_dict:
                return f"{ns}{pre_node_name}"

            pre_node_index = current_graph().tensor_pre_index_dict[id(a)]
            log.debug(f'pre_index gen func {self.kind}: {pre_node_index}')
            if isinstance(pre_node_index, (list, tuple)):
                # Multi-level index into a (nested) output, e.g. ``[0][1]``
                indices_str = ''.join(f'[{i}]' for i in pre_node_index)
            else:
                indices_str = f'[{pre_node_index}]'
            return f"{ns}{pre_node_name}{indices_str}"

        def _escape_arg(arg: str) -> str:
            """Doubles literal braces so they are not mistaken for ``{}`` placeholders."""
            return arg.replace('{', '{{').replace('}', '}}')

        def _parse_args(arg):
            """Converts the argument to a list of strings"""
            new_arg = []

            for a in arg:
                if isinstance(a, (list, tuple, torch.Size)):
                    # torch.Size is a tuple subclass, so it always ends up here
                    new_arg.append(_parse_args(a))
                elif type(a) in (torch.Tensor, torch.nn.Parameter) or (
                    type(a) in (torch.dtype, torch.device) and id(a) in current_graph().tensor_pre_node_dict
                ):
                    self.prev_tensors.append(a)
                    self.tensor_names.append(_tensor_name(a))
                    self.original_tensor_names.append(_tensor_name(a, original=True))
                    new_arg.append('{}')
                elif type(a) in (str, torch.device):
                    # repr() yields a valid Python literal even when the text contains
                    # quotes, backslashes or newlines; naive '...' wrapping does not.
                    new_arg.append(_escape_arg(repr(str(a))))
                elif type(a) in (int, bool, torch.dtype):
                    new_arg.append(str(a))
                elif type(a) is float:
                    str_arg = str(a)
                    if str_arg in ('nan', 'inf', '-inf'):
                        # These have no literal form in Python source
                        new_arg.append(f"float('{str_arg}')")
                    else:
                        new_arg.append(str_arg)
                elif a is None:
                    new_arg.append('None')
                elif a is Ellipsis:
                    new_arg.append('...')
                elif type(a) is slice:
                    parts = []
                    for x in (a.start, a.stop, a.step):
                        if x is None:
                            parts.append('')
                        else:
                            parts.extend(_parse_args([x]))
                    r = ':'.join(parts)
                    # Drop the trailing colon of an absent step, e.g. '1:5:' -> '1:5'
                    if r.endswith(':'):
                        r = r[:-1]
                    new_arg.append(r)
                elif isinstance(a, torch.nn.quantized.FloatFunctional):
                    # Register the FloatFunctional as a standalone init node so the
                    # generated module constructs it.
                    float_functional_cls = type(a)
                    module_constructor_lines[id(a)] = f'{qualified_name(float_functional_cls, short=True)}()'

                    new_node = TraceNode(a)
                    current_graph().nodes_map[new_node.unique_name] = new_node
                    current_graph().other_init_nodes.append(new_node)
                    current_graph().tensor_pre_node_dict[id(a)] = new_node.unique_name
                    self.tensor_names.append(_tensor_name(a))
                    self.original_tensor_names.append(_tensor_name(a, original=True))
                    self.prev_tensors.append(a)
                    new_arg.append('{}')
                elif isinstance(a, nn.Module):
                    unique_name = current_graph().module_unique_name_dict[id(a)]
                    current_graph().tensor_pre_node_dict[id(a)] = unique_name
                    self.tensor_names.append(f'self.{unique_name}')
                    self.original_tensor_names.append(_tensor_name(a, original=True))
                    self.prev_tensors.append(a)
                    new_arg.append('{}')
                else:
                    msg = f"unsupported type {type(a)} while generating arg for func {self.full_name}"
                    log.error(msg)
                    # Raise explicitly: a bare ``assert False`` is stripped under
                    # ``python -O`` and the argument would be silently dropped.
                    raise AssertionError(msg)

            return new_arg

        self.tensor_names = []
        self.original_tensor_names = []
        self.prev_tensors.clear()
        arg_str = _parse_args(args)

        if kwargs:
            kw_keys, kw_vals = zip(*kwargs.items())
            kw_val_strs = _parse_args(kw_vals)

            for k, v in zip(kw_keys, kw_val_strs):
                if type(v) is list:
                    # Sequence values come back as (nested) lists of strings
                    arg_str.append(f"{k}={self._flatten_list(v)}")
                else:
                    arg_str.append(f"{k}={v}")

        self.args_parsed = copy.deepcopy(arg_str)
        self.kwargs = copy.deepcopy(kwargs)
        self.args_parsed_origin = copy.deepcopy(self.args_parsed)
        self.args_to_string(self.args_parsed)

        return self