def load_objects()

in ignite/handlers/checkpoint.py [0:0]


    def load_objects(to_load: Mapping, checkpoint: Union[str, Mapping], **kwargs: Any) -> None:
        """Helper method to apply ``load_state_dict`` on the objects from ``to_load`` using states from ``checkpoint``.

        Args:
            to_load: a dictionary with objects, e.g. `{"model": model, "optimizer": optimizer, ...}`
            checkpoint: a string filepath or a dictionary with state_dicts to load, e.g. `{"model": model_state_dict,
                "optimizer": opt_state_dict}`. If `to_load` contains a single key, then checkpoint can contain
                directly corresponding state_dict.
            kwargs: Keyword arguments accepted for `nn.Module.load_state_dict()`. Only `strict` is honoured: it is
                forwarded to objects that are instances of `torch.nn.Module`; any other key triggers a warning and
                is ignored. Passing `strict=False` enables the user to load part of the pretrained model
                (useful for example, in Transfer Learning)

        Raises:
            TypeError: if ``checkpoint`` is neither a string nor a mapping.
            ValueError: if a key of ``to_load`` is not found in the checkpoint (except for the single-object case
                where the checkpoint is used directly as the object's state_dict).

        Examples:
            .. code-block:: python

                import torch
                from ignite.engine import Engine, Events
                from ignite.handlers import ModelCheckpoint, Checkpoint
                trainer = Engine(lambda engine, batch: None)
                handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=None, create_dir=True)
                model = torch.nn.Linear(3, 3)
                optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
                to_save = {"weights": model, "optimizer": optimizer}
                trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
                trainer.run(torch.randn(10, 1), 5)

                to_load = to_save
                checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
                checkpoint = torch.load(checkpoint_fp)
                Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

                # or using a string for checkpoint filepath

                to_load = to_save
                checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pth"
                Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint_fp)

        Note:
            If ``to_load`` contains objects of type torch `DistributedDataParallel`_ or
            `DataParallel`_, method ``load_state_dict`` will be applied to their internal wrapped model
            (``obj.module``).

        .. _DistributedDataParallel: https://pytorch.org/docs/stable/generated/
            torch.nn.parallel.DistributedDataParallel.html
        .. _DataParallel: https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html
        """
        # Imported locally so the type check does not depend on how the module binds the name ``collections``
        # (the ``collections.Mapping`` alias was removed from the standard library in Python 3.10).
        from collections.abc import Mapping as MappingABC

        # Validate the arguments before touching the filesystem.
        if not isinstance(checkpoint, (MappingABC, str)):
            raise TypeError(f"Argument checkpoint should be a string or a dictionary, but given {type(checkpoint)}")

        Checkpoint._check_objects(to_load, "load_state_dict")

        if isinstance(checkpoint, str):
            # NOTE: torch.load unpickles the file content; only load checkpoints coming from a trusted source.
            checkpoint_obj = torch.load(checkpoint)
        else:
            checkpoint_obj = checkpoint

        if any(k != "strict" for k in kwargs):
            warnings.warn("kwargs contains keys other than strict and these will be ignored")

        is_state_dict_strict = kwargs.get("strict", True)

        def _load_object(obj: Any, state: Any) -> None:
            # Restore ``state`` into ``obj``, unwrapping (Distributed)DataParallel containers first.
            if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
                obj = obj.module
            if isinstance(obj, torch.nn.Module):
                obj.load_state_dict(state, strict=is_state_dict_strict)
            else:
                # Optimizers, schedulers, etc. do not accept ``strict``; passing it would raise TypeError.
                obj.load_state_dict(state)

        if len(to_load) == 1:
            # single object and checkpoint is directly a state_dict
            key, obj = next(iter(to_load.items()))
            if key not in checkpoint_obj:
                _load_object(obj, checkpoint_obj)
                return

        # multiple objects to load (or a single object stored under its key)
        for k, obj in to_load.items():
            if k not in checkpoint_obj:
                raise ValueError(f"Object labeled by '{k}' from `to_load` is not found in the checkpoint")
            _load_object(obj, checkpoint_obj[k])