in func/train.py
import itertools

import torch


def init_model(model, ckpt_path, modules_to_keep, logger):
    """Initialize model with weights from ckpt_path.

    Args:
        model (nn.Module): The model to initialize.
        ckpt_path (str): Path to the checkpoint file.
        modules_to_keep (str): A comma-separated string of module name
            prefixes that should be loaded from the checkpoint; each prefix
            is stripped from the parameter names before loading.
        logger (logging.Logger): Logger for progress and warning messages.
    """
    logger.debug('Initializing %s with ckpt path: %s, using modules in it %s',
                 model, ckpt_path, modules_to_keep)
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    elif 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'classy_state_dict' in checkpoint:
        state_dict = checkpoint['classy_state_dict']
        # This is likely a VISSL checkpoint, in which case the trunk params
        # are nested as follows. TODO: support this more generally.
        state_dict = state_dict['base_model']['model']['trunk']
    else:
        state_dict = checkpoint
    if modules_to_keep:
        # Keep only the entries of state_dict whose keys match one of the
        # modules to keep, and strip that prefix from the names.
        filtered_state_dict = {}
        for key, val in state_dict.items():
            for mod_name in modules_to_keep.split(','):
                if key.startswith(mod_name):
                    filtered_state_dict[key[len(mod_name):]] = val
                    break  # Stop at the first matching prefix.
        state_dict = filtered_state_dict
    # Ignore any parameters/buffers (e.g. BN running mean/var) whose shape
    # does not match the model's.
    for name, param in itertools.chain(model.named_parameters(),
                                       model.named_buffers()):
        if name in state_dict and state_dict[name].shape != param.shape:
            logger.warning('Ckpt shape mismatch for %s (%s vs %s). Ignoring.',
                           name, state_dict[name].shape, param.shape)
            del state_dict[name]
    missing_keys, unexp_keys = model.load_state_dict(state_dict, strict=False)
    logger.warning('Could not init from %s: %s', ckpt_path, missing_keys)
    logger.warning('Unused keys in %s: %s', ckpt_path, unexp_keys)
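
A minimal usage sketch (not part of func/train.py): the model, the temporary
checkpoint file, and the 'trunk.' prefix below are illustrative assumptions,
fabricated here to show how modules_to_keep filters entries and strips their
prefix before loading.

import logging
import tempfile

import torch
import torch.nn as nn

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))

# Simulate a pretraining checkpoint: the same weights, saved under a
# 'trunk.' prefix as a backbone checkpoint might store them.
ckpt = {'model': {'trunk.' + k: v for k, v in model.state_dict().items()}}
with tempfile.NamedTemporaryFile(suffix='.pth', delete=False) as f:
    torch.save(ckpt, f.name)

# modules_to_keep='trunk.' keeps only the 'trunk.*' entries and strips the
# prefix, so 'trunk.0.weight' maps onto this model's '0.weight'.
init_model(model, f.name, 'trunk.', logger)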