def zero3_state_dict()

in modules/SwissArmyTransformer/sat/training/model_io.py [0:0]


def zero3_state_dict(module, *args, destination=None, prefix='', keep_vars=False):
    r"""Return a dictionary referencing the whole state of an already partitioned module.

    Function-style counterpart of ``torch.nn.Module.state_dict``: the module is
    passed explicitly and each module's own entries are written by
    ``zero3_save_to_state_dict`` (defined elsewhere in this file), after which
    the children in ``module._modules`` are visited recursively.

    Args:
        module: Module whose state is collected. Must expose ``_version``,
            ``_modules``, ``_state_dict_pre_hooks`` and ``_state_dict_hooks``.
        *args: Deprecated positional form of ``destination``, ``prefix`` and
            ``keep_vars`` (in that order). A positional value is used only
            when the matching keyword was left at its default.
        destination: Mapping to fill. When ``None`` a new ``OrderedDict`` with
            an empty ``_metadata`` attribute is created.
        prefix: String prepended to the keys written for this module. Its
            last character is dropped to form the ``_metadata`` key.
        keep_vars: Forwarded unchanged to the pre-hooks, to
            ``zero3_save_to_state_dict`` and to the recursive calls.

    Returns:
        ``destination``, or the object returned by the last state-dict hook of
        ``module`` that returned something other than ``None``.
    """
    # TODO: Remove `args` and the parsing logic when BC allows.
    if args:
        if destination is None:
            destination = args[0]
        if len(args) > 1 and prefix == '':
            prefix = args[1]
        if len(args) > 2 and keep_vars is False:
            keep_vars = args[2]
        # DeprecationWarning is ignored by default
        warnings.warn(
            "Positional args are being deprecated, use kwargs instead. Refer to "
            "https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict"
            " for details.")

    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()

    local_metadata = dict(version=module._version)
    if hasattr(destination, "_metadata"):
        # Drop the last character of the prefix (the trailing '.' added by the
        # recursive call below) so the root module is keyed by ''.
        destination._metadata[prefix[:-1]] = local_metadata

    for hook in module._state_dict_pre_hooks.values():
        hook(module, prefix, keep_vars)
    zero3_save_to_state_dict(module, destination, prefix, keep_vars)

    # The loop variable must not be called `module`: rebinding the parameter
    # would make the post-hooks below run on the last child (or fail on a
    # `None` entry) instead of on the module this call was made for.
    for name, child in module._modules.items():
        if child is not None:
            zero3_state_dict(child, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)

    for hook in module._state_dict_hooks.values():
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination