def pre_train_from_snapshots()

in seamseg/utils/snapshot.py [0:0]


def pre_train_from_snapshots(model, snapshots, modules):
    """Initialize sub-modules of ``model`` from the state dicts stored in snapshot files.

    Each loaded snapshot is expected to be a mapping with a ``"state_dict"`` entry,
    which in turn maps module names to their parameters.

    Each entry of ``snapshots`` is either a plain path or a string of the form
    ``"<module_name>:<path>"``:

    * Plain path: for every name in ``modules`` that is a key of the snapshot's
      ``state_dict``, ``_load_pretraining_dict(getattr(model, name), state_dict[name])``
      is called. Names missing from the snapshot are silently skipped.
    * ``"<module_name>:<path>"``: only ``module_name`` is loaded. It must be one of
      ``modules`` and must be present in the snapshot's ``state_dict``.

    Only the first ``":"`` separates the module name from the path. Paths that
    themselves contain colons (e.g. ``"body:C:/ckpt/model.pth"``) therefore work
    when a module prefix is given. A plain path containing ``":"`` is always
    interpreted as ``"<module_name>:<path>"``.

    Args:
        model: Object whose attributes named in ``modules`` receive the weights.
        snapshots: Iterable of snapshot specifications (see above).
        modules: Collection of module names that may be initialized.

    Raises:
        ValueError: If an explicit module name is not in ``modules``.
        KeyError: If an explicit module name is missing from the snapshot's
            ``state_dict``.
    """
    for spec in snapshots:
        if ":" in spec:
            # Split on the first colon only, so the path itself may contain colons.
            target, path = spec.split(":", 1)
            # Reject unknown module names before paying for torch.load.
            if target not in modules:
                raise ValueError("Unrecognized network module {}".format(target))
        else:
            target, path = None, spec

        # NOTE: torch.load unpickles the file; only load snapshots from trusted sources.
        state_dict = torch.load(path, map_location="cpu")["state_dict"]

        if target is None:
            for name in modules:
                if name in state_dict:
                    _load_pretraining_dict(getattr(model, name), state_dict[name])
        else:
            if target not in state_dict:
                raise KeyError(
                    "Snapshot {} does not contain module {}".format(path, target))
            _load_pretraining_dict(getattr(model, target), state_dict[target])