in modules/SwissArmyTransformer/sat/training/model_io.py [0:0]
def zero3_state_dict(module, *args, destination=None, prefix='', keep_vars=False):
    """Return a dictionary containing references to the whole state of an already partitioned module.

    The state of ``module`` itself is written by ``zero3_save_to_state_dict``;
    every non-``None`` entry of ``module._modules`` is then visited recursively
    with ``prefix + name + '.'`` as its prefix. Pre-hooks registered in
    ``module._state_dict_pre_hooks`` run before saving and hooks registered in
    ``module._state_dict_hooks`` run after all submodules have been visited.

    Args:
        module: The module whose state is collected. It must expose
            ``_version``, ``_modules``, ``_state_dict_pre_hooks`` and
            ``_state_dict_hooks``.
        *args: Deprecated positional form of ``destination``, ``prefix`` and
            ``keep_vars`` (in that order). A positional value is only used
            when the matching keyword argument was left at its default.
            Passing any positional argument emits a warning.
        destination: Mapping to fill. When ``None``, a new ``OrderedDict``
            carrying an ``_metadata`` ``OrderedDict`` attribute is created.
        prefix: String prepended to the names of this module's entries.
        keep_vars: Forwarded unchanged to the pre-hooks, to
            ``zero3_save_to_state_dict`` and to the recursive calls.

    Returns:
        ``destination``, or the value returned by the last state-dict hook
        that returned something other than ``None``.
    """
    # TODO: Remove `args` and the parsing logic when BC allows.
    if len(args) > 0:
        if destination is None:
            destination = args[0]
        if len(args) > 1 and prefix == '':
            prefix = args[1]
        if len(args) > 2 and keep_vars is False:
            keep_vars = args[2]
        # DeprecationWarning is ignored by default, so no category is passed
        # and the warning is emitted with the default category instead.
        warnings.warn(
            "Positional args are being deprecated, use kwargs instead. Refer to "
            "https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict"
            " for details.")

    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()

    local_metadata = dict(version=module._version)
    if hasattr(destination, "_metadata"):
        # The metadata key is the prefix minus its last character (the
        # trailing '.' appended for submodules below).
        destination._metadata[prefix[:-1]] = local_metadata

    for hook in module._state_dict_pre_hooks.values():
        hook(module, prefix, keep_vars)
    zero3_save_to_state_dict(module, destination, prefix, keep_vars)

    # Use a dedicated loop variable: rebinding `module` here would make the
    # hook loop below run the hooks of the last child (or fail on a `None`
    # child) instead of the hooks of the module being serialized.
    for name, child in module._modules.items():
        if child is not None:
            zero3_state_dict(child, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)

    for hook in module._state_dict_hooks.values():
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            # A hook may replace the destination mapping entirely.
            destination = hook_result
    return destination