models/resnet.py (159 lines of code) (raw):

# Copyright (c) Alibaba, Inc. and its affiliates.
'''
PyTorch implementation of ResNet taken from
https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
which is originally licensed under MIT.
'''

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.base import BaseModel


class BasicBlock(nn.Module):
    '''
    Residual block made of two 3x3 convolutions, each followed by a
    normalization layer, plus a shortcut branch that is added to the result.

    Args:
        in_planes: number of channels of the block input.
        mid_planes: number of channels produced by the first convolution.
        out_planes: number of channels produced by the block.
        norm: callable that takes a channel count and returns a
            normalization module (e.g. ``nn.BatchNorm2d``).
        stride: stride of the first convolution (and of the shortcut
            projection when one is created).
    '''

    def __init__(self, in_planes, mid_planes, out_planes, norm, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = norm(mid_planes)
        self.conv2 = nn.Conv2d(mid_planes, out_planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = norm(out_planes)

        # An empty Sequential acts as the identity shortcut. A 1x1
        # convolution + norm projection replaces it only when the main branch
        # changes the spatial size (stride != 1) or the channel count, so
        # that both branches can be summed in forward().
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                norm(out_planes)
            )

    def forward(self, x):
        '''Return relu(main_branch(x) + shortcut(x)).'''
        out = self.bn1(self.conv1(x))
        out = F.relu(out)
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(BaseModel):
    '''
    ResNet backbone with a 3x3 stem, four residual stages (64, 128, 256 and
    512 channels), a fixed 4x4 pooling layer and a linear head.

    NOTE(review): ``forward`` and ``build_aux_layers`` are not defined in
    this file; presumably they come from ``BaseModel`` and combine
    ``forward_features`` / ``forward_classifier`` -- confirm against
    models/base.py.

    Args:
        block: block class used to build every stage; it is called as
            ``block(in_planes, mid_planes, out_planes, norm, stride)``.
        num_blocks: sequence of four ints, the number of blocks per stage.
        num_classes: stored on the instance as ``self.num_classes``; not
            otherwise used in this file.
        num_outputs: output dimension of the final linear layer.
        pooling: ``'avgpool'`` or ``'maxpool'``; selects the 4x4 pooling
            applied after the last stage.
        norm: callable that takes a channel count and returns a
            normalization module.
        return_features: stored on the instance as ``self.return_features``;
            presumably consumed by ``BaseModel.forward`` -- confirm.

    Raises:
        ValueError: if ``pooling`` is not one of the supported names.
    '''

    def __init__(self, block, num_blocks, num_classes=10, num_outputs=10,
                 pooling='avgpool', norm=nn.BatchNorm2d,
                 return_features=False):
        super(ResNet, self).__init__()
        if pooling == 'avgpool':
            self.pooling = nn.AvgPool2d(4)
        elif pooling == 'maxpool':
            self.pooling = nn.MaxPool2d(4)
        else:
            # ValueError is a subclass of Exception, so callers that catch
            # Exception keep working while the error type is now specific.
            raise ValueError('Unsupported pooling: %s' % pooling)

        # Running input-channel count; _make_layer() reads and updates it
        # while the stages are built, so the stage order below matters.
        self.in_planes = 64
        self.return_features = return_features

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = norm(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1,
                                       norm=norm)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2,
                                       norm=norm)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2,
                                       norm=norm)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2,
                                       norm=norm)
        self.linear = nn.Linear(512, num_outputs)

        self.num_classes = num_classes
        self.num_outputs = num_outputs
        self.penultimate_layer_dim = 512
        # Not defined in this file; presumably provided by BaseModel.
        self.build_aux_layers()

    def _make_layer(self, block, planes, num_blocks, norm, stride):
        '''
        Build one stage of ``num_blocks`` blocks with ``planes`` channels.

        Only the first block uses ``stride``; the remaining ones use stride
        1. ``self.in_planes`` is set to ``planes`` after each block so the
        next block (and the next stage) receives the right input width.
        '''
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        # Distinct loop name so the ``stride`` argument is not shadowed.
        for block_stride in strides:
            layers.append(
                block(self.in_planes, planes, planes, norm, block_stride))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward_features(self, x):
        '''
        Run the stem, the four stages and the pooling layer, then flatten.

        The shape comments below (batch dimension omitted) hold for a
        3x32x32 input, as used in the ``__main__`` check of this file.
        '''
        c1 = F.relu(self.bn1(self.conv1(x)))  # (64,32,32)
        h1 = self.layer1(c1)  # (64,32,32)
        h2 = self.layer2(h1)  # (128,16,16)
        h3 = self.layer3(h2)  # (256,8,8)
        h4 = self.layer4(h3)  # (512,4,4)
        p4 = self.pooling(h4)  # (512,1,1)
        p4 = p4.view(p4.size(0), -1)  # (512)
        return p4

    def forward_classifier(self, p4):
        '''Apply the final linear layer to the flattened features.'''
        return self.linear(p4)  # (num_outputs), 10 by default


def ResNet18(num_classes=10, num_outputs=10, pooling='avgpool',
             norm=nn.BatchNorm2d, return_features=False):
    '''
    Build a ResNet with BasicBlock and [2,2,2,2] blocks per stage.

    GFLOPS: 0.5579, model size: 11.1740MB
    '''
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes,
                  num_outputs=num_outputs, pooling=pooling, norm=norm,
                  return_features=return_features)


def ResNet34(num_classes=10, num_outputs=10, pooling='avgpool',
             norm=nn.BatchNorm2d, return_features=False):
    '''
    Build a ResNet with BasicBlock and [3,4,6,3] blocks per stage.

    GFLOPS: 1.1635, model size: 21.2859MB
    '''
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes,
                  num_outputs=num_outputs, pooling=pooling, norm=norm,
                  return_features=return_features)


if __name__ == '__main__':
    # Quick profiling run: counts FLOPs/params of ResNet18 with thop on a
    # single random 3x32x32 image and prints the output size.
    from thop import profile

    net = ResNet18(num_classes=10, num_outputs=10, return_features=True)
    x = torch.randn(1, 3, 32, 32)
    flops, params = profile(net, inputs=(x, ))
    # NOTE(review): unpacking two values relies on the forward() inherited
    # from BaseModel returning (outputs, features) when return_features=True
    # -- confirm against models/base.py.
    y, features = net(x)
    print(y.size())
    print('GFLOPS: %.4f, model size: %.4fMB' % (flops/1e9, params/1e6))

# Reference listing of parameter names for ResNet18. The ``projection.*``
# entries are not created in this file; presumably they come from
# build_aux_layers() -- confirm.
'''
conv1.weight
bn1.weight
bn1.bias
layer1.0.conv1.weight
layer1.0.bn1.weight
layer1.0.bn1.bias
layer1.0.conv2.weight
layer1.0.bn2.weight
layer1.0.bn2.bias
layer1.1.conv1.weight
layer1.1.bn1.weight
layer1.1.bn1.bias
layer1.1.conv2.weight
layer1.1.bn2.weight
layer1.1.bn2.bias
layer2.0.conv1.weight
layer2.0.bn1.weight
layer2.0.bn1.bias
layer2.0.conv2.weight
layer2.0.bn2.weight
layer2.0.bn2.bias
layer2.0.shortcut.0.weight
layer2.0.shortcut.1.weight
layer2.0.shortcut.1.bias
layer2.1.conv1.weight
layer2.1.bn1.weight
layer2.1.bn1.bias
layer2.1.conv2.weight
layer2.1.bn2.weight
layer2.1.bn2.bias
layer3.0.conv1.weight
layer3.0.bn1.weight
layer3.0.bn1.bias
layer3.0.conv2.weight
layer3.0.bn2.weight
layer3.0.bn2.bias
layer3.0.shortcut.0.weight
layer3.0.shortcut.1.weight
layer3.0.shortcut.1.bias
layer3.1.conv1.weight
layer3.1.bn1.weight
layer3.1.bn1.bias
layer3.1.conv2.weight
layer3.1.bn2.weight
layer3.1.bn2.bias
layer4.0.conv1.weight
layer4.0.bn1.weight
layer4.0.bn1.bias
layer4.0.conv2.weight
layer4.0.bn2.weight
layer4.0.bn2.bias
layer4.0.shortcut.0.weight
layer4.0.shortcut.1.weight
layer4.0.shortcut.1.bias
layer4.1.conv1.weight
layer4.1.bn1.weight
layer4.1.bn1.bias
layer4.1.conv2.weight
layer4.1.bn2.weight
layer4.1.bn2.bias
linear.weight
linear.bias
projection.0.weight
projection.0.bias
projection.2.weight
projection.2.bias
'''