in Dassl.pytorch/dassl/utils/torchtools.py [0:0]
def load_pretrained_weights(model, weight_path):
    r"""Load pretrained weights to model.

    Features::
        - Incompatible layers (unmatched in name or size) will be ignored;
          those layers keep the values the model already has.
        - A leading "module." prefix on checkpoint keys (typically present
          when the weights were saved from an ``nn.DataParallel``-wrapped
          model) is stripped before matching.

    Args:
        model (nn.Module): network model. Its parameters/buffers are updated
            in place via ``load_state_dict``.
        weight_path (str): path to pretrained weights.

    Returns:
        None. Progress is reported via ``print``; if no layer matches at all,
        a ``UserWarning`` is emitted instead.

    Examples::
        >>> weight_path = 'log/my_model/model-best.pth.tar'
        >>> load_pretrained_weights(model, weight_path)
    """
    checkpoint = load_checkpoint(weight_path)
    # A checkpoint may either wrap the weights under "state_dict" or be the
    # raw state dict itself.
    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    model_dict = model.state_dict()
    new_state_dict = OrderedDict()
    matched_layers, discarded_layers = [], []

    prefix = "module."
    for k, v in state_dict.items():
        if k.startswith(prefix):
            k = k[len(prefix):]  # discard the "module." prefix
        # Only accept entries whose name exists in the model AND whose
        # shape agrees; everything else is reported as discarded.
        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    # Start from the model's own state so that layers absent from (or
    # incompatible with) the checkpoint keep their current values and the
    # full key set still satisfies load_state_dict's default strict check.
    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    if not matched_layers:
        warnings.warn(
            f"Cannot load {weight_path} (check the key names manually)"
        )
    else:
        print(f"Successfully loaded pretrained weights from {weight_path}")
        if discarded_layers:
            print(
                f"Layers discarded due to unmatched keys or size: {discarded_layers}"
            )