scripts/models/wider_resnet.py (162 lines of code) (raw):

# Copyright (c) Facebook, Inc. and its affiliates.
import sys
from collections import OrderedDict
from functools import partial

import torch.nn as nn
from inplace_abn import ABN

from modules import IdentityResidualBlock, GlobalAvgPool2d

# Width of the output of the stem convolution (``mod1``).
_STEM_CHANNELS = 64

# Channel configuration of the residual blocks in mod2 ... mod7. Each tuple is
# passed unchanged as the ``channels`` argument of ``IdentityResidualBlock``;
# its last entry becomes the input width of every following block/module.
_CHANNELS = (
    (128, 128),
    (256, 256),
    (512, 512),
    (512, 1024),
    (512, 1024, 2048),
    (1024, 2048, 4096),
)


def _check_structure(structure):
    """Raise ``ValueError`` unless ``structure`` has one entry per residual module."""
    if len(structure) != len(_CHANNELS):
        raise ValueError("Expected a structure with six values")


def _make_stem():
    """Build ``mod1``: one bias-free 3x3, stride-1 convolution from 3 to 64 channels."""
    return nn.Sequential(
        OrderedDict(
            [
                (
                    "conv1",
                    nn.Conv2d(3, _STEM_CHANNELS, 3, stride=1, padding=1, bias=False),
                )
            ]
        )
    )


def _make_pool():
    """Build the 3x3, stride-2 max-pooling layer used for down-sampling."""
    return nn.MaxPool2d(3, stride=2, padding=1)


def _make_classifier(in_channels, classes):
    """Build the optional head: global average pooling followed by a linear layer."""
    return nn.Sequential(
        OrderedDict(
            [
                ("avg_pool", GlobalAvgPool2d()),
                ("fc", nn.Linear(in_channels, classes)),
            ]
        )
    )


def _a2_block_options(mod_id, block_id, dilation):
    """Return ``(stride, dilation, dropout)`` for one block of ``WiderResNetA2``.

    ``mod_id`` is the 0-based index into ``structure`` (``mod_id == 0`` is
    ``mod2``); ``block_id`` is the 0-based block index inside that module.
    """
    if not dilation:
        dil = 1
        # The first block of mod4, mod5 and mod6 down-samples by striding
        # (mod2 and mod3 are down-sampled by max-pooling instead).
        stride = 2 if block_id == 0 and 2 <= mod_id <= 4 else 1
    else:
        # Dilated variant: only mod4 strides; mod5 uses dilation 2 and
        # mod6/mod7 use dilation 4 to keep the receptive field growing.
        if mod_id == 3:
            dil = 2
        elif mod_id > 3:
            dil = 4
        else:
            dil = 1
        stride = 2 if block_id == 0 and mod_id == 2 else 1

    # Spatial dropout in the two widest modules (mod6 and mod7).
    if mod_id == 4:
        drop = partial(nn.Dropout2d, p=0.3)
    elif mod_id == 5:
        drop = partial(nn.Dropout2d, p=0.5)
    else:
        drop = None
    return stride, dil, drop


class WiderResNet(nn.Module):
    def __init__(self, structure, norm_act=ABN, classes=0):
        """Wider ResNet with pre-activation (identity mapping) blocks

        Every module from ``mod2`` to ``mod6`` is preceded by a stride-2
        max-pooling layer (``pool2`` ... ``pool6``); ``mod7`` is not.

        Parameters
        ----------
        structure : list of int
            Number of residual blocks in each of the six modules of the network.
        norm_act : callable
            Function to create normalization / activation Module.
        classes : int
            If not `0` also include global average pooling and a fully-connected layer with `classes` outputs at the end
            of the network.

        Raises
        ------
        ValueError
            If ``structure`` does not contain exactly six values.
        """
        super().__init__()
        _check_structure(structure)
        self.structure = structure

        # Initial layers
        self.mod1 = _make_stem()

        # Groups of residual blocks
        in_channels = _STEM_CHANNELS
        for mod_id, num in enumerate(structure):
            channels = _CHANNELS[mod_id]
            blocks = []
            for block_id in range(num):
                blocks.append(
                    (
                        "block%d" % (block_id + 1),
                        IdentityResidualBlock(in_channels, channels, norm_act=norm_act),
                    )
                )
                # Subsequent blocks consume this block's output width.
                in_channels = channels[-1]

            # Register pool (if any) before the module to keep the original
            # child ordering; mod7 (mod_id == 5) has no preceding pool.
            if mod_id <= 4:
                self.add_module("pool%d" % (mod_id + 2), _make_pool())
            self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks)))

        # Pooling and predictor
        self.bn_out = norm_act(in_channels)
        if classes != 0:
            self.classifier = _make_classifier(in_channels, classes)

    def forward(self, img):
        """Run the network on ``img`` (a batch whose channel dimension is 3).

        Returns the ``bn_out`` features, or the classifier output when the
        network was built with ``classes != 0``.
        """
        out = self.mod1(img)
        out = self.mod2(self.pool2(out))
        out = self.mod3(self.pool3(out))
        out = self.mod4(self.pool4(out))
        out = self.mod5(self.pool5(out))
        out = self.mod6(self.pool6(out))
        out = self.mod7(out)
        out = self.bn_out(out)

        if hasattr(self, "classifier"):
            out = self.classifier(out)
        return out


class WiderResNetA2(nn.Module):
    def __init__(self, structure, norm_act=ABN, classes=0, dilation=False):
        """Wider ResNet with pre-activation (identity mapping) blocks

        This variant uses down-sampling by max-pooling in the first two blocks and by strided convolution in the others.

        Parameters
        ----------
        structure : list of int
            Number of residual blocks in each of the six modules of the network.
        norm_act : callable
            Function to create normalization / activation Module.
        classes : int
            If not `0` also include global average pooling and a fully-connected layer with `classes` outputs at the end
            of the network.
        dilation : bool
            If `True` apply dilation to the last three modules and change the down-sampling factor from 32 to 8.

        Raises
        ------
        ValueError
            If ``structure`` does not contain exactly six values.
        """
        super().__init__()
        _check_structure(structure)
        self.structure = structure
        self.dilation = dilation

        # Initial layers
        self.mod1 = _make_stem()

        # Groups of residual blocks
        in_channels = _STEM_CHANNELS
        for mod_id, num in enumerate(structure):
            channels = _CHANNELS[mod_id]
            blocks = []
            for block_id in range(num):
                stride, dil, drop = _a2_block_options(mod_id, block_id, dilation)
                blocks.append(
                    (
                        "block%d" % (block_id + 1),
                        IdentityResidualBlock(
                            in_channels,
                            channels,
                            norm_act=norm_act,
                            stride=stride,
                            dilation=dil,
                            dropout=drop,
                        ),
                    )
                )
                # Subsequent blocks consume this block's output width.
                in_channels = channels[-1]

            # Only mod2 and mod3 are preceded by max-pooling in this variant.
            if mod_id < 2:
                self.add_module("pool%d" % (mod_id + 2), _make_pool())
            self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks)))

        # Pooling and predictor
        self.bn_out = norm_act(in_channels)
        if classes != 0:
            self.classifier = _make_classifier(in_channels, classes)

    def forward(self, img):
        """Run the network on ``img`` (a batch whose channel dimension is 3).

        Returns the ``bn_out`` features, or the classifier output when the
        network was built with ``classes != 0``.
        """
        out = self.mod1(img)
        out = self.mod2(self.pool2(out))
        out = self.mod3(self.pool3(out))
        out = self.mod4(out)
        out = self.mod5(out)
        out = self.mod6(out)
        out = self.mod7(out)
        out = self.bn_out(out)

        if hasattr(self, "classifier"):
            return self.classifier(out)
        return out


# Block counts per module for the predefined depths.
# NOTE: the factories below pass these lists by reference and each network
# stores its argument as ``self.structure``, so mutating ``net.structure``
# in place would also alter this table.
_NETS = {
    "16": {"structure": [1, 1, 1, 1, 1, 1]},
    "20": {"structure": [1, 1, 1, 3, 1, 1]},
    "38": {"structure": [3, 3, 6, 3, 1, 1]},
}

# Export one factory per (variant, depth): net_wider_resnet{16,20,38} first,
# then net_wider_resnet{16,20,38}_a2.
__all__ = []
for _net_cls, _suffix in ((WiderResNet, ""), (WiderResNetA2, "_a2")):
    for name, params in _NETS.items():
        net_name = "net_wider_resnet" + name + _suffix
        setattr(sys.modules[__name__], net_name, partial(_net_cls, **params))
        __all__.append(net_name)