in main_lincls.py [0:0]
def sanity_check(state_dict, pretrained_weights, linear_keyword):
    """
    Linear classifier should not change any weights other than the linear layer.
    This sanity check asserts nothing wrong happens (e.g., BN stats updated).
    """
    print("=> loading '{}' for sanity check".format(pretrained_weights))
    checkpoint = torch.load(pretrained_weights, map_location="cpu")
    state_dict_pre = checkpoint['state_dict']

    for k in list(state_dict.keys()):
        # only ignore linear layer
        if '%s.weight' % linear_keyword in k or '%s.bias' % linear_keyword in k:
            continue

        # name in pretrained model
        k_pre = 'module.base_encoder.' + k[len('module.'):] \
            if k.startswith('module.') else 'module.base_encoder.' + k

        assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \
            '{} is changed in linear classifier training.'.format(k)

    print("=> sanity check passed.")
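For reference, below is a minimal, self-contained sketch of how this check can be exercised; it is not taken from the repository. `TinyEncoder`, the temporary checkpoint path, and the `'fc'` linear keyword are hypothetical stand-ins for the real backbone, the `--pretrained` checkpoint, and the model-specific linear keyword used by the training script. The fake checkpoint stores every backbone tensor under the 'module.base_encoder.' prefix that the key mapping above reconstructs.

# Illustrative sketch only (hypothetical names, not repository code): build a
# fake pretrained checkpoint, modify only the linear head, run sanity_check.
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    # Hypothetical stand-in for the linear-probe model: a small backbone
    # (conv + BN) plus a linear classifier head named 'fc'.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3)
        self.bn = nn.BatchNorm2d(4)
        self.fc = nn.Linear(4, 10)


probe = TinyEncoder()

# Fake "pretrained" checkpoint: weights and BN buffers saved under the
# 'module.base_encoder.' prefix that k_pre expects.
pretrained_sd = {'module.base_encoder.' + k: v.clone()
                 for k, v in probe.state_dict().items()}
torch.save({'state_dict': pretrained_sd}, '/tmp/fake_pretrained.pth.tar')

# Simulate linear-classifier training: only the head parameters change.
with torch.no_grad():
    probe.fc.weight.add_(1.0)
    probe.fc.bias.add_(1.0)

# Passes: 'fc.weight'/'fc.bias' are skipped, everything else still matches.
sanity_check(probe.state_dict(), '/tmp/fake_pretrained.pth.tar', 'fc')

If any non-linear tensor had drifted, e.g. a BatchNorm running statistic because the backbone was not kept frozen or in eval mode, the assert would fail and report the offending key name.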