scripts/models/resnet.py (96 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates.
import sys
from collections import OrderedDict
from functools import partial
import torch.nn as nn
from inplace_abn import ABN
from modules import GlobalAvgPool2d, ResidualBlock
from .util import try_index
class ResNet(nn.Module):
"""Standard residual network
Parameters
----------
structure : list of int
Number of residual blocks in each of the four modules of the network
bottleneck : bool
If `True` use "bottleneck" residual blocks with 3 convolutions, otherwise use standard blocks
norm_act : callable
Function to create normalization / activation Module
classes : int
If not `0` also include global average pooling and a fully-connected layer with `classes` outputs at the end
of the network
dilation : int or list of int
List of dilation factors for the four modules of the network, or `1` to ignore dilation
keep_outputs : bool
If `True` output a list with the outputs of all modules
"""
def __init__(
self,
structure,
bottleneck,
norm_act=ABN,
classes=0,
dilation=1,
keep_outputs=False,
):
super(ResNet, self).__init__()
self.structure = structure
self.bottleneck = bottleneck
self.dilation = dilation
self.keep_outputs = keep_outputs
if len(structure) != 4:
raise ValueError("Expected a structure with four values")
if dilation != 1 and len(dilation) != 4:
raise ValueError("If dilation is not 1 it must contain four values")
# Initial layers
layers = [
("conv1", nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False)),
("bn1", norm_act(64)),
]
if try_index(dilation, 0) == 1:
layers.append(("pool1", nn.MaxPool2d(3, stride=2, padding=1)))
self.mod1 = nn.Sequential(OrderedDict(layers))
# Groups of residual blocks
in_channels = 64
if self.bottleneck:
channels = (64, 64, 256)
else:
channels = (64, 64)
for mod_id, num in enumerate(structure):
# Create blocks for module
blocks = []
for block_id in range(num):
stride, dil = self._stride_dilation(dilation, mod_id, block_id)
blocks.append(
(
"block%d" % (block_id + 1),
ResidualBlock(
in_channels,
channels,
norm_act=norm_act,
stride=stride,
dilation=dil,
),
)
)
# Update channels and p_keep
in_channels = channels[-1]
# Create module
self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks)))
# Double the number of channels for the next module
channels = [c * 2 for c in channels]
# Pooling and predictor
if classes != 0:
self.classifier = nn.Sequential(
OrderedDict(
[
("avg_pool", GlobalAvgPool2d()),
("fc", nn.Linear(in_channels, classes)),
]
)
)
@staticmethod
def _stride_dilation(dilation, mod_id, block_id):
d = try_index(dilation, mod_id)
s = 2 if d == 1 and block_id == 0 and mod_id > 0 else 1
return s, d
def forward(self, x):
outs = list()
outs.append(self.mod1(x))
outs.append(self.mod2(outs[-1]))
outs.append(self.mod3(outs[-1]))
outs.append(self.mod4(outs[-1]))
outs.append(self.mod5(outs[-1]))
if hasattr(self, "classifier"):
outs.append(self.classifier(outs[-1]))
if self.keep_outputs:
return outs
else:
return outs[-1]
_NETS = {
"18": {"structure": [2, 2, 2, 2], "bottleneck": False},
"34": {"structure": [3, 4, 6, 3], "bottleneck": False},
"50": {"structure": [3, 4, 6, 3], "bottleneck": True},
"101": {"structure": [3, 4, 23, 3], "bottleneck": True},
"152": {"structure": [3, 8, 36, 3], "bottleneck": True},
}
__all__ = []
for name, params in _NETS.items():
net_name = "net_resnet" + name
setattr(sys.modules[__name__], net_name, partial(ResNet, **params))
__all__.append(net_name)