scripts/models/resnext.py (104 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates.
import sys
from collections import OrderedDict
from functools import partial
import torch.nn as nn
from inplace_abn import ABN
from modules import IdentityResidualBlock, GlobalAvgPool2d
from .util import try_index
class ResNeXt(nn.Module):
    """Pre-activation (identity mapping) ResNeXt network.

    The model is assembled from an input stem (``mod1``), four groups of
    ``IdentityResidualBlock`` (``mod2`` to ``mod5``), a closing ``norm_act``
    layer (``bn_out``) and, only when ``classes`` is non-zero, a
    ``classifier`` head (global average pooling followed by a linear layer).
    """

    def __init__(
        self,
        structure,
        groups=64,
        norm_act=ABN,
        input_3x3=False,
        classes=0,
        dilation=1,
        base_channels=(128, 128, 256),
    ):
        """Build the stem, the four residual modules and the optional head.

        Parameters
        ----------
        structure : list of int
            How many residual blocks to stack in each of the four modules.
        groups : int
            Number of groups in each ResNeXt block (forwarded to every
            `IdentityResidualBlock`).
        norm_act : callable
            Factory that receives a channel count and returns the
            normalization / activation module.
        input_3x3 : bool
            When `True` the stem is three `3x3` convolutions; otherwise it is
            a single `7x7` convolution. Both variants end with max pooling.
        classes : int
            When not `0`, a global average pooling + fully-connected head
            with `classes` outputs is attached as `classifier`.
        dilation : list of list of int or list of int or int
            `1` disables dilation. Otherwise one entry per module: either a
            single value shared by all blocks of that module, or one value
            per block.
        base_channels : list of int
            Channels of the blocks in the first residual module. The values
            are doubled for every following module.

        Raises
        ------
        ValueError
            If `structure` does not have four entries, or if `dilation` is
            not `1` and does not have four entries.
        """
        super().__init__()
        self.structure = structure

        if len(structure) != 4:
            raise ValueError("Expected a structure with four values")
        if dilation != 1 and len(dilation) != 4:
            raise ValueError("If dilation is not 1 it must contain four values")

        # Input stem. Layers are created in the same order they are applied.
        stem = OrderedDict()
        if input_3x3:
            stem["conv1"] = nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False)
            stem["bn1"] = norm_act(64)
            stem["conv2"] = nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)
            stem["bn2"] = norm_act(64)
            stem["conv3"] = nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)
        else:
            stem["conv1"] = nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False)
        stem["pool"] = nn.MaxPool2d(3, stride=2, padding=1)
        self.mod1 = nn.Sequential(stem)

        # Residual modules, registered as mod2 .. mod5.
        in_channels = 64
        module_channels = base_channels
        for mod_id, num_blocks in enumerate(structure):
            stage = OrderedDict()
            for block_id in range(num_blocks):
                stride, dil = self._stride_dilation(mod_id, block_id, dilation)
                stage["block%d" % (block_id + 1)] = IdentityResidualBlock(
                    in_channels,
                    module_channels,
                    stride=stride,
                    norm_act=norm_act,
                    groups=groups,
                    dilation=dil,
                )
                # Every block after this one consumes this block's output width.
                in_channels = module_channels[-1]

            self.add_module("mod%d" % (mod_id + 2), nn.Sequential(stage))
            # The next module works with twice as many channels.
            module_channels = [width * 2 for width in module_channels]

        # Final normalization / activation and the optional predictor.
        self.bn_out = norm_act(in_channels)
        if classes != 0:
            head = OrderedDict()
            head["avg_pool"] = GlobalAvgPool2d()
            head["fc"] = nn.Linear(in_channels, classes)
            self.classifier = nn.Sequential(head)

    def forward(self, img):
        """Run `img` through the stem, the four residual modules and `bn_out`.

        Returns the output of `classifier` when the model was built with a
        classification head, otherwise the output of `bn_out`.
        """
        feats = self.mod1(img)
        feats = self.mod2(feats)
        feats = self.mod3(feats)
        feats = self.mod4(feats)
        feats = self.mod5(feats)
        feats = self.bn_out(feats)

        # The head only exists when `classes` was non-zero at construction.
        if not hasattr(self, "classifier"):
            return feats
        return self.classifier(feats)

    @staticmethod
    def _stride_dilation(mod_id, block_id, dilation):
        """Return the `(stride, dilation)` pair for one residual block.

        A module without dilation downsamples with stride 2 in its first
        block, except for the very first module, and uses stride 1 elsewhere.
        A dilated module always keeps stride 1 and takes its dilation factor
        from `try_index(dilation[mod_id], block_id)`.
        """
        module_dilation = 1 if dilation == 1 else dilation[mod_id]

        if module_dilation == 1:
            if mod_id > 0 and block_id == 0:
                return 2, 1
            return 1, 1

        return 1, try_index(module_dilation, block_id)
# Residual-block counts of the four modules, keyed by the depth suffix that
# is appended to "net_resnext" to form the public factory name.
_NETS = {
    "50": {"structure": [3, 4, 6, 3]},
    "101": {"structure": [3, 4, 23, 3]},
    "152": {"structure": [3, 8, 36, 3]},
}

# Publish one module-level factory per table entry (net_resnext50, ...).
# Each factory is ResNeXt with `structure` pre-bound; every other constructor
# argument can still be supplied by the caller.
__all__ = []
for name, params in _NETS.items():
    net_name = "net_resnext" + name
    globals()[net_name] = partial(ResNeXt, **params)
    __all__.append(net_name)