in tinynn/graph/tracer.py [0:0]
def gen_module_constrctor_line(module, mod_cache=None):
    """Generates the constructor line for a loaded module.

    Args:
        module: The module instance whose constructor call should be reproduced.
        mod_cache: Optional dict forwarded to ``new_module_getattr_gen`` (via ``patch``)
            while the module's repr is rendered. A fresh dict is created when ``None``.

    Returns:
        tuple: ``(line, mod_cache)``, where ``line`` is a string of the form
        ``'<qualified name>(<args>)'`` and ``mod_cache`` is the (possibly newly created)
        cache dict.
    """
    # Arguments (by keyword name, or by non-negative positional index) that are dropped
    # from the repr-derived argument list of these modules before it is reused.
    ignored_args = {
        'torch.nn.ZeroPad2d': ['value'],
        'torch.nn.UpsamplingNearest2d': ['mode'],
        'torch.nn.UpsamplingBilinear2d': ['mode'],
    }
    # Modules whose constructor line is rebuilt from the constructor signature and the
    # instance attributes instead of from their repr.
    legacy_modules = {
        'torch.nn.FractionalMaxPool2d',
        'torch.nn.FractionalMaxPool3d',
        'torch.nn.TransformerEncoderLayer',
        'torch.nn.TransformerDecoderLayer',
        'torch.nn.Upsample',
    }

    def _format_value(value):
        """Renders an argument value as source text; strings are wrapped in double quotes."""
        if isinstance(value, str):
            return f'"{value}"'
        return f'{value}'

    def _skip_ignored_args(name, *args, **kwargs):
        """Drops the ignored arguments registered for `name` and re-renders the rest."""
        remaining = list(args)
        positions = []
        for arg in ignored_args[name]:
            if isinstance(arg, int):
                positions.append(arg)
            elif isinstance(arg, str):
                kwargs.pop(arg, None)
            else:
                raise AttributeError(f'Unknown type {type(arg)} in ignored args')
        # `args` arrives as an (immutable) tuple, so delete from a list copy. Go from the
        # highest index down so earlier deletions don't shift the positions of later ones.
        for index in sorted(set(positions), reverse=True):
            if 0 <= index < len(remaining):
                del remaining[index]
        return args_as_string(tuple(remaining), kwargs)

    def _legacy_constructor_line(mod, mod_cls, name):
        """Rebuilds `name(...)` from the constructor signature of `mod_cls` and `mod`'s attributes."""
        known_constants = set(mod_cls.__constants__)
        args = []
        for p in get_constructor_args(mod_cls):
            prop_name = p.name
            if prop_name == 'self':
                continue
            if prop_name not in known_constants:
                # Accept a non-constant only if the class itself exposes it as a plain value.
                if not hasattr(mod_cls, prop_name) or not isinstance(
                    getattr(mod_cls, prop_name), (type(None), str, int, float, bool)
                ):
                    log.warning(
                        f'Argument "{prop_name}" of the constructor of {name} is not a known constant, skipping'
                    )
                    continue
            if p.default is p.empty:
                # Required argument: emitted positionally.
                args.append(_format_value(getattr(mod, prop_name)))
                continue
            # If loading from a legacy model, the property can be missing.
            # Skip it so that the constructor's default value is used.
            if not hasattr(mod, prop_name):
                continue
            prop_value = getattr(mod, prop_name)
            # Skip the arg if it has the same value as the default one.
            if p.default == prop_value:
                continue
            args.append(f'{prop_name}={_format_value(prop_value)}')
        return f'{name}({", ".join(args)})'

    name = qualified_name(type(module), short=True)
    if mod_cache is None:
        mod_cache = {}
    module_cls = type(module)

    if name in legacy_modules and hasattr(module_cls, '__constants__') and hasattr(module_cls, '__init__'):
        result = _legacy_constructor_line(module, module_cls, name)
    else:
        # Legacy modules lacking `__constants__` also take this path, so `result` is
        # always assigned.
        with patch(module_cls, '__getattribute__', new_module_getattr_gen, name, mod_cache):
            if module_cls.__repr__ is torch.nn.Module.__repr__:
                # Default nn.Module repr: only the extra_repr part describes the arguments.
                result = f'{name}({module.extra_repr()})'
            else:
                # Custom repr: assume it starts with the bare class name and prefix it
                # with the namespace part of the qualified name.
                ns = '.'.join(name.split('.')[:-1])
                result = f'{ns}.{repr(module)}'

    if name in ignored_args:
        start_pos = result.find('(')
        end_pos = result.rfind(')')
        head = result[: start_pos + 1]
        tail = result[end_pos:]
        mid = result[start_pos + 1 : end_pos]
        # NOTE(review): `mid` is repr text of the module and is evaluated as Python call
        # arguments. This is only acceptable for the trusted torch modules listed in
        # `ignored_args`; never route untrusted reprs through here.
        result = (
            head
            + eval(
                f'_skip_ignored_args("{name}", {mid})',
                globals(),
                {'_skip_ignored_args': _skip_ignored_args},
            )
            + tail
        )
    return result, mod_cache