def gen_module_constrctor_line()

in tinynn/graph/tracer.py [0:0]


def gen_module_constrctor_line(module, mod_cache=None):
    """Generates the constructor line for a loaded module"""
    ignored_args = {
        'torch.nn.ZeroPad2d': ['value'],
        'torch.nn.UpsamplingNearest2d': ['mode'],
        'torch.nn.UpsamplingBilinear2d': ['mode'],
    }
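    # Modules whose constructor line is rebuilt from `__constants__` and the
    # inspected `__init__` signature instead of `repr()`, since attributes may
    # be missing when the module comes from a legacy serialized model.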
    legacy_modules = {
        'torch.nn.FractionalMaxPool2d',
        'torch.nn.FractionalMaxPool3d',
        'torch.nn.TransformerEncoderLayer',
        'torch.nn.TransformerDecoderLayer',
        'torch.nn.Upsample',
    }

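    # Drops the ignored positional/keyword args for `name` and re-serializes
    # the remaining ones into an argument string.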
    def _skip_ignored_args(name, *args, **kwargs):
        # `*args` arrives as a tuple; make it a list so `del args[arg]` works
        args = list(args)
        iargs = ignored_args[name]
        for arg in iargs:
            if isinstance(arg, int):
                if 0 <= arg < len(args):
                    del args[arg]
            elif isinstance(arg, str):
                if arg in kwargs:
                    del kwargs[arg]
            else:
                raise AttributeError(f'Unknown type {type(arg)} in ignored args')
        return args_as_string(args, kwargs)

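    # Qualified name of the module class (e.g. 'torch.nn.Upsample'), matching
    # the keys used in `ignored_args` and `legacy_modules` above.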
    name = qualified_name(type(module), short=True)

    if mod_cache is None:
        mod_cache = {}
    module_cls = type(module)
    if name in legacy_modules:
        if hasattr(module_cls, '__constants__') and hasattr(module_cls, '__init__'):
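            # Rebuild the constructor line from the inspected `__init__`
            # signature and the module's current attribute values. Every
            # module in `legacy_modules` is expected to define
            # `__constants__`; otherwise `result` would be left unassigned.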
            known_constants = set(module_cls.__constants__)
            arg_info = get_constructor_args(module_cls)
            args = []
            for p in arg_info:
                prop_name = p.name
                if prop_name == 'self':
                    continue
                if prop_name not in known_constants:
                    if not hasattr(module_cls, prop_name) or not isinstance(
                        getattr(module_cls, prop_name), (type(None), str, int, float, bool)
                    ):
                        log.warning(
                            f'Argument "{prop_name}" of the constructor of {name} is not a known constant, skipping'
                        )
                        continue
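                # Required args (no default) are emitted positionally;
                # optional args are emitted as keyword args below.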
                if p.default is p.empty:
                    prop_value = getattr(module, prop_name)
                    args.append(f'{prop_value}')
                else:
                    # When loading from a legacy model, the property can be
                    # missing; skip it so that the default value is used.
                    if not hasattr(module, prop_name):
                        continue
                    prop_value = getattr(module, prop_name)
                    # Append as a keyword argument
                    default_value = p.default
                    # Skip the arg if it matches the default value
                    if default_value == prop_value:
                        continue
                    if isinstance(prop_value, str):
                        prop_value_str = f'"{prop_value}"'
                    else:
                        prop_value_str = prop_value
                    args.append(f'{prop_name}={prop_value_str}')
            mid = ', '.join(args)
            result = f'{name}({mid})'
    else:
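        # Non-legacy path: rely on the module's own repr, patching
        # `__getattribute__` so attribute accesses made while building the
        # repr can be tracked via `mod_cache`.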
        with patch(module_cls, '__getattribute__', new_module_getattr_gen, name, mod_cache):
            if getattr(module_cls, '__repr__') == getattr(torch.nn.Module, '__repr__'):
                result = f'{name}({module.extra_repr()})'
            else:
                ns = '.'.join(name.split('.')[:-1])
                result = f'{ns}.{repr(module)}'

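        # Post-process the repr-derived line: `eval` parses the textual arg
        # list back into Python values so `_skip_ignored_args` can drop the
        # ignored entries and re-serialize the rest.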
        if name in ignored_args:
            start_pos = result.find('(')
            end_pos = result.rfind(')')
            head = result[: start_pos + 1]
            tail = result[end_pos:]
            mid = result[start_pos + 1 : end_pos]
            result = head + eval(f'_skip_ignored_args("{name}", {mid})') + tail

    return result, mod_cache
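
Usage sketch (an illustration, not part of tracer.py: it assumes `tinynn` is
installed so that the function can be imported from `tinynn.graph.tracer`,
and the printed lines are indicative since the exact formatting depends on
the PyTorch version):

import torch.nn as nn

from tinynn.graph.tracer import gen_module_constrctor_line

mod_cache = {}  # reuse the cache across calls, as the signature suggests
for m in (nn.ZeroPad2d(2), nn.UpsamplingNearest2d(scale_factor=2)):
    line, mod_cache = gen_module_constrctor_line(m, mod_cache)
    print(line)

# Indicative output: the `value` / `mode` entries are stripped because both
# classes appear in `ignored_args`, e.g.
#   torch.nn.ZeroPad2d((2, 2, 2, 2))
#   torch.nn.UpsamplingNearest2d(scale_factor=2.0)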