in lib/models/pose_resnet.py [0:0]
def init_weights(self, pretrained=''):
    """Re-initialise the deconv/final layers, then load weights from a checkpoint.

    Steps, in order:

    1. Every ``nn.ConvTranspose2d`` found in ``self.deconv_layers`` gets
       weights drawn from ``normal(0, 0.001)`` (and a zero bias when
       ``self.deconv_with_bias`` is true); every ``nn.BatchNorm2d`` found
       there gets weight 1 and bias 0.
    2. Every ``nn.Conv2d`` found in ``self.final_layer`` gets weights drawn
       from ``normal(0, 0.001)`` and, if it has a bias, a zero bias.
    3. The checkpoint at ``pretrained`` is loaded with ``strict=False``, so
       parameters whose keys appear in the checkpoint are overwritten while
       missing or unexpected keys are tolerated.

    Args:
        pretrained: Path to a checkpoint file. Its content must be either an
            ``OrderedDict`` state dict, or a ``dict`` holding one under the
            ``'state_dict'`` key (a leading ``'module.'`` on those keys is
            stripped).

    Raises:
        ValueError: If ``pretrained`` is not an existing file.
        RuntimeError: If the checkpoint has neither supported layout.
    """
    # Guard clause: without a checkpoint file nothing is initialised at all.
    if not os.path.isfile(pretrained):
        logger.error('=> imagenet pretrained model dose not exist')
        logger.error('=> please download it first')
        raise ValueError('imagenet pretrained model does not exist')

    logger.info('=> init deconv weights from normal distribution')
    for name, m in self.deconv_layers.named_modules():
        if isinstance(m, nn.ConvTranspose2d):
            logger.info('=> init {}.weight as normal(0, 0.001)'.format(name))
            nn.init.normal_(m.weight, std=0.001)
            if self.deconv_with_bias:
                # Only report the bias init when it actually happens.
                logger.info('=> init {}.bias as 0'.format(name))
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            logger.info('=> init {}.weight as 1'.format(name))
            logger.info('=> init {}.bias as 0'.format(name))
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    logger.info('=> init final conv weights from normal distribution')
    # named_modules() (rather than modules()) so the log lines name the layer
    # being initialised instead of reusing a stale name from the loop above.
    for name, m in self.final_layer.named_modules(prefix='final_layer'):
        if isinstance(m, nn.Conv2d):
            logger.info('=> init {}.weight as normal(0, 0.001)'.format(name))
            nn.init.normal_(m.weight, std=0.001)
            if m.bias is not None:  # a conv built with bias=False has no bias
                logger.info('=> init {}.bias as 0'.format(name))
                nn.init.constant_(m.bias, 0)

    logger.info('=> loading pretrained model {}'.format(pretrained))
    # NOTE(security): torch.load unpickles the file -- only pass trusted
    # checkpoints. map_location='cpu' lets checkpoints saved from a GPU load
    # on any host; load_state_dict copies the values into the existing
    # parameters, so the model's own device is unaffected.
    checkpoint = torch.load(pretrained, map_location='cpu')
    if isinstance(checkpoint, OrderedDict):
        state_dict = checkpoint
    elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        # Drop the 'module.' prefix that nn.DataParallel adds when saving.
        prefix = 'module.'
        state_dict = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in checkpoint['state_dict'].items()
        )
    else:
        raise RuntimeError(
            'No state_dict found in checkpoint file {}'.format(pretrained))
    self.load_state_dict(state_dict, strict=False)