models/resnet_imagenet.py
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet

from models.base import BaseModel


class ResNet_ImageNet(ResNet, BaseModel):
    def __init__(self, block, num_blocks, num_classes=1000, num_outputs=1000,
                 return_features=False):
        super(ResNet_ImageNet, self).__init__(block, num_blocks, num_classes=num_outputs)
        self.return_features = return_features
        # Width of the feature vector fed into the final fc layer.
        self.penultimate_layer_dim = self.fc.weight.shape[1]
        # print('self.penultimate_layer_dim:', self.penultimate_layer_dim)
        self.num_classes = num_classes
        self.num_outputs = num_outputs
        self.build_aux_layers()
        # Route calls to this module through BaseModel.forward (which combines
        # forward_features() and forward_classifier()) instead of ResNet.forward.
        # Bind it explicitly so the instance is passed as `self`.
        self.forward = BaseModel.forward.__get__(self)

    def forward_features(self, x):
        c1 = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        # Shapes below are for ResNet-50 (Bottleneck blocks) with 224x224 input.
        h1 = self.layer1(c1)        # (N, 256, 56, 56)
        h2 = self.layer2(h1)        # (N, 512, 28, 28)
        h3 = self.layer3(h2)        # (N, 1024, 14, 14)
        h4 = self.layer4(h3)        # (N, 2048, 7, 7)
        p4 = self.avgpool(h4)       # (N, 2048, 1, 1)
        p4 = torch.flatten(p4, 1)   # (N, 2048)
        return p4

    def forward_classifier(self, p4):
        return self.fc(p4)          # (N, num_outputs)

    def check_fc_dict(self, state_dict):
        """Pad a checkpoint's fc weights/bias with this model's (randomly
        initialised) trailing rows when the checkpoint was saved with fewer
        output nodes than the current fc layer."""
        if state_dict['fc.weight'].shape != self.fc.weight.shape:
            new_node_num = self.fc.weight.shape[0] - state_dict['fc.weight'].shape[0]
            state_dict['fc.weight'] = torch.cat((state_dict['fc.weight'], self.fc.weight[-new_node_num:, :]), dim=0)
            state_dict['fc.bias'] = torch.cat((state_dict['fc.bias'], self.fc.bias[-new_node_num:]), dim=0)


def ResNet50(num_classes=1000, num_outputs=1000, return_features=False):
    return ResNet_ImageNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, num_outputs=num_outputs,
                           return_features=return_features)


def ResNet18(num_classes=2, num_outputs=2, return_features=False):
    # ResNet-18 uses BasicBlock (2 blocks per stage), not Bottleneck.
    return ResNet_ImageNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, num_outputs=num_outputs,
                           return_features=return_features)


if __name__ == '__main__':
    from thop import profile

    net = ResNet50(num_classes=10, num_outputs=10)
    x = torch.randn(1, 3, 224, 224)
    flops, params = profile(net, inputs=(x,))
    y = net(x)
    print(y.size())
    print('GFLOPS: %.4f, #params: %.4fM' % (flops / 1e9, params / 1e6))  # GFLOPS: 4.1095, #params: 23.5285M
    # Count parameters by substring of the parameter name; 'projection' layers
    # (auxiliary head from build_aux_layers) are excluded from the total.
    # Note: the 'bn' filter matches by name and does not include the downsample
    # BatchNorm layers (named 'layerX.Y.downsample.1.*').
    bn_parameter_number, fc_parameter_number, all_parameter_number = 0, 0, 0
    for name, p in net.named_parameters():
        if 'bn' in name:
            bn_parameter_number += p.numel()
        if 'fc' in name:
            fc_parameter_number += p.numel()
        if 'projection' not in name:
            all_parameter_number += p.numel()
    # Storage cost in MB, assuming 4 bytes (float32) per parameter.
    all_size = all_parameter_number * 4 / 1e6
    bn_size = bn_parameter_number * 4 / 1e6
    fc_size = fc_parameter_number * 4 / 1e6
    print('all: %.4f MB (x2: %.4f MB)' % (all_size, 2 * all_size))
    print('all + bn: %.4f MB (bn: %.4f MB)' % (all_size + bn_size, bn_size))
    print('all + fc: %.4f MB (fc: %.4f MB)' % (all_size + fc_size, fc_size))
    print('all + bn + fc: %.4f MB' % (all_size + bn_size + fc_size))
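
    # Sketch of how check_fc_dict is meant to be used (illustrative only; the
    # "checkpoint" is built in memory here rather than loaded from disk):
    # a state dict saved from a 10-output head is padded so its fc entries fit
    # a model whose fc layer has 20 outputs.
    small_ckpt = ResNet50(num_classes=10, num_outputs=10).state_dict()
    wider_net = ResNet50(num_classes=20, num_outputs=20)
    print('fc.weight in checkpoint:', small_ckpt['fc.weight'].shape)   # torch.Size([10, 2048])
    wider_net.check_fc_dict(small_ckpt)
    print('fc.weight after padding:', small_ckpt['fc.weight'].shape)   # torch.Size([20, 2048])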
'''
module.conv1.weight
module.bn1.weight
module.bn1.bias
module.layer1.0.conv1.weight
module.layer1.0.bn1.weight
module.layer1.0.bn1.bias
module.layer1.0.conv2.weight
module.layer1.0.bn2.weight
module.layer1.0.bn2.bias
module.layer1.0.conv3.weight
module.layer1.0.bn3.weight
module.layer1.0.bn3.bias
module.layer1.0.downsample.0.weight
module.layer1.0.downsample.1.weight
module.layer1.0.downsample.1.bias
module.layer1.1.conv1.weight
module.layer1.1.bn1.weight
module.layer1.1.bn1.bias
module.layer1.1.conv2.weight
module.layer1.1.bn2.weight
module.layer1.1.bn2.bias
module.layer1.1.conv3.weight
module.layer1.1.bn3.weight
module.layer1.1.bn3.bias
module.layer1.2.conv1.weight
module.layer1.2.bn1.weight
module.layer1.2.bn1.bias
module.layer1.2.conv2.weight
module.layer1.2.bn2.weight
module.layer1.2.bn2.bias
module.layer1.2.conv3.weight
module.layer1.2.bn3.weight
module.layer1.2.bn3.bias
module.layer2.0.conv1.weight
module.layer2.0.bn1.weight
module.layer2.0.bn1.bias
module.layer2.0.conv2.weight
module.layer2.0.bn2.weight
module.layer2.0.bn2.bias
module.layer2.0.conv3.weight
module.layer2.0.bn3.weight
module.layer2.0.bn3.bias
module.layer2.0.downsample.0.weight
module.layer2.0.downsample.1.weight
module.layer2.0.downsample.1.bias
module.layer2.1.conv1.weight
module.layer2.1.bn1.weight
module.layer2.1.bn1.bias
module.layer2.1.conv2.weight
module.layer2.1.bn2.weight
module.layer2.1.bn2.bias
module.layer2.1.conv3.weight
module.layer2.1.bn3.weight
module.layer2.1.bn3.bias
module.layer2.2.conv1.weight
module.layer2.2.bn1.weight
module.layer2.2.bn1.bias
module.layer2.2.conv2.weight
module.layer2.2.bn2.weight
module.layer2.2.bn2.bias
module.layer2.2.conv3.weight
module.layer2.2.bn3.weight
module.layer2.2.bn3.bias
module.layer2.3.conv1.weight
module.layer2.3.bn1.weight
module.layer2.3.bn1.bias
module.layer2.3.conv2.weight
module.layer2.3.bn2.weight
module.layer2.3.bn2.bias
module.layer2.3.conv3.weight
module.layer2.3.bn3.weight
module.layer2.3.bn3.bias
module.layer3.0.conv1.weight
module.layer3.0.bn1.weight
module.layer3.0.bn1.bias
module.layer3.0.conv2.weight
module.layer3.0.bn2.weight
module.layer3.0.bn2.bias
module.layer3.0.conv3.weight
module.layer3.0.bn3.weight
module.layer3.0.bn3.bias
module.layer3.0.downsample.0.weight
module.layer3.0.downsample.1.weight
module.layer3.0.downsample.1.bias
module.layer3.1.conv1.weight
module.layer3.1.bn1.weight
module.layer3.1.bn1.bias
module.layer3.1.conv2.weight
module.layer3.1.bn2.weight
module.layer3.1.bn2.bias
module.layer3.1.conv3.weight
module.layer3.1.bn3.weight
module.layer3.1.bn3.bias
module.layer3.2.conv1.weight
module.layer3.2.bn1.weight
module.layer3.2.bn1.bias
module.layer3.2.conv2.weight
module.layer3.2.bn2.weight
module.layer3.2.bn2.bias
module.layer3.2.conv3.weight
module.layer3.2.bn3.weight
module.layer3.2.bn3.bias
module.layer3.3.conv1.weight
module.layer3.3.bn1.weight
module.layer3.3.bn1.bias
module.layer3.3.conv2.weight
module.layer3.3.bn2.weight
module.layer3.3.bn2.bias
module.layer3.3.conv3.weight
module.layer3.3.bn3.weight
module.layer3.3.bn3.bias
module.layer3.4.conv1.weight
module.layer3.4.bn1.weight
module.layer3.4.bn1.bias
module.layer3.4.conv2.weight
module.layer3.4.bn2.weight
module.layer3.4.bn2.bias
module.layer3.4.conv3.weight
module.layer3.4.bn3.weight
module.layer3.4.bn3.bias
module.layer3.5.conv1.weight
module.layer3.5.bn1.weight
module.layer3.5.bn1.bias
module.layer3.5.conv2.weight
module.layer3.5.bn2.weight
module.layer3.5.bn2.bias
module.layer3.5.conv3.weight
module.layer3.5.bn3.weight
module.layer3.5.bn3.bias
module.layer4.0.conv1.weight
module.layer4.0.bn1.weight
module.layer4.0.bn1.bias
module.layer4.0.conv2.weight
module.layer4.0.bn2.weight
module.layer4.0.bn2.bias
module.layer4.0.conv3.weight
module.layer4.0.bn3.weight
module.layer4.0.bn3.bias
module.layer4.0.downsample.0.weight
module.layer4.0.downsample.1.weight
module.layer4.0.downsample.1.bias
module.layer4.1.conv1.weight
module.layer4.1.bn1.weight
module.layer4.1.bn1.bias
module.layer4.1.conv2.weight
module.layer4.1.bn2.weight
module.layer4.1.bn2.bias
module.layer4.1.conv3.weight
module.layer4.1.bn3.weight
module.layer4.1.bn3.bias
module.layer4.2.conv1.weight
module.layer4.2.bn1.weight
module.layer4.2.bn1.bias
module.layer4.2.conv2.weight
module.layer4.2.bn2.weight
module.layer4.2.bn2.bias
module.layer4.2.conv3.weight
module.layer4.2.bn3.weight
module.layer4.2.bn3.bias
module.fc.weight
module.fc.bias
module.projection.0.weight
module.projection.0.bias
module.projection.2.weight
module.projection.2.bias
'''