in utils/util.py [0:0]
def load_pretrained(args, model, logger=print):
    """Initialize ``model`` from the checkpoint stored at ``args.pretrained_model``.

    Two checkpoint layouts are recognised:

    * A 3-entry dict, treated as a MoCo checkpoint: the query-encoder weights
      are taken from ``ckpt['state_dict']`` (keys prefixed with
      ``module.encoder_q.``), the prefix is stripped, and the ``fc`` head
      parameters (``fc.0.*`` / ``fc.2.*``) are copied into the ``fc_inter``,
      ``fc_intra``, ``fc_order`` and ``fc_tsn`` heads.
    * Anything else: the weights are read from ``ckpt['model']``.

    Loading uses ``strict=False``; missing and unexpected keys are reported
    through ``logger`` instead of raising.

    Args:
        args: object whose ``pretrained_model`` attribute is the checkpoint
            path handed to ``torch.load``.
        model: module whose ``load_state_dict`` receives the weights.
        logger: callable taking one message string (default ``print``).

    Raises:
        KeyError: if the ``'state_dict'`` / ``'model'`` entry, or any of the
            ``fc.0.*`` / ``fc.2.*`` weights in a MoCo checkpoint, is absent.
    """
    ckpt = torch.load(args.pretrained_model, map_location='cpu')
    # NOTE(review): a checkpoint with exactly 3 top-level entries is assumed
    # to be MoCo output (e.g. epoch/arch/state_dict); this size heuristic is
    # fragile -- confirm against the checkpoints actually used.
    if len(ckpt) == 3:  # moco initialization
        # Match on the dotted prefix so sibling keys such as
        # 'module.encoder_q_*' are not picked up, and strip exactly the prefix
        # rather than a hard-coded character count.
        prefix = 'module.encoder_q.'
        state_dict = {k[len(prefix):]: v for k, v in ckpt['state_dict'].items()
                      if k.startswith(prefix)}
        # Seed every task head from the pretrained fc head. fc.0 / fc.2 are
        # presumably the two Linear layers of an MLP projection head
        # (TODO confirm).
        for head in ('fc_inter', 'fc_intra', 'fc_order', 'fc_tsn'):
            for param in ('0.weight', '0.bias', '2.weight', '2.bias'):
                state_dict[head + '.' + param] = state_dict['fc.' + param]
    else:
        state_dict = ckpt['model']
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    logger('Missing keys: {}'.format(missing_keys))
    logger('Unexpect keys: {}'.format(unexpected_keys))
    logger("==> loaded checkpoint '{}'".format(args.pretrained_model))