in seamseg/utils/snapshot.py [0:0]
def pre_train_from_snapshots(model, snapshots, modules):
    """Initialize sub-modules of ``model`` from one or more snapshot files.

    Each entry of ``snapshots`` is either a plain path, or a string of the
    form ``"<module_name>:<path>"``:

    * plain path: every name in ``modules`` that is also a key of the
      snapshot's ``"state_dict"`` is loaded; names missing from the
      snapshot are silently skipped.
    * ``"<module_name>:<path>"``: only ``module_name`` is loaded, and it
      must be one of ``modules``.

    The snapshot is read with ``torch.load(..., map_location="cpu")`` and is
    expected to be a mapping with a ``"state_dict"`` entry, itself keyed by
    module name. The actual weight transfer is delegated to
    ``_load_pretraining_dict(getattr(model, name), state_dict[name])``.

    Args:
        model: Object exposing each loaded module name as an attribute.
        snapshots: Iterable of snapshot specifiers (see above).
        modules: Collection of recognized module names.

    Raises:
        ValueError: If a specifier names a module that is not in ``modules``.
            This is checked before the snapshot file is read.
        KeyError: If the snapshot has no ``"state_dict"`` entry, or an
            explicitly named module is absent from it.
    """
    for spec in snapshots:
        if ":" in spec:
            # Split on the first colon only, so the path part may itself
            # contain colons without breaking the two-way unpacking.
            module_name, path = spec.split(":", 1)
            # Fail fast: reject an unknown module before paying for the load.
            if module_name not in modules:
                raise ValueError("Unrecognized network module {}".format(module_name))
        else:
            module_name, path = None, spec

        # NOTE(review): torch.load is pickle-based; depending on the torch
        # version / weights_only setting it may execute code embedded in the
        # file. Snapshots are presumably trusted here -- confirm with callers.
        checkpoint = torch.load(path, map_location="cpu")
        state_dict = checkpoint["state_dict"]

        if module_name is None:
            # No explicit module: load whatever recognized modules are present.
            for name in modules:
                if name in state_dict:
                    _load_pretraining_dict(getattr(model, name), state_dict[name])
        else:
            _load_pretraining_dict(getattr(model, module_name), state_dict[module_name])