scripts/models/densenet.py (102 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates.
import sys
from collections import OrderedDict
from functools import partial
import torch.nn as nn
from inplace_abn import ABN
from modules import GlobalAvgPool2d, DenseModule
from .util import try_index
class DenseNet(nn.Module):
    """DenseNet backbone made of a stem, four dense blocks and three transitions.

    Submodules are registered in this order: ``mod1`` (stem), ``mod2``, then
    ``tra2``/``mod3``, ``tra3``/``mod4``, ``tra4``/``mod5``, then ``bn_out``.
    ``classifier`` is registered last, and only when ``classes != 0``.
    Each transition applies ``norm_act`` before its 1x1 convolution.
    """

    def __init__(
        self,
        structure,
        norm_act=ABN,
        input_3x3=False,
        growth=32,
        theta=0.5,
        classes=0,
        dilation=1,
    ):
        """Build the network.

        Parameters
        ----------
        structure : list of int
            Number of layers in each dense block; must contain exactly four entries.
        norm_act : callable
            Factory that takes a channel count and returns a normalization / activation module.
        input_3x3 : bool
            When `True` the stem uses three `3x3` convolutions instead of one `7x7` convolution.
        growth : int
            Growth factor, i.e. the number of channels in each dense layer. The stem
            produces `2 * growth` channels.
        theta : float
            Channel reduction factor used by every transition block.
        classes : int
            When not `0`, append global average pooling and a fully-connected layer with
            `classes` outputs.
        dilation : int or list of int
            Dilation factors, looked up per dense block via `try_index`, or `1` for no
            dilation. A transition block omits its average pooling when the dense block
            that follows it has a dilation other than `1`.

        Raises
        ------
        ValueError
            If `structure` does not have exactly four values.
        """
        super().__init__()
        self.structure = structure
        if len(structure) != 4:
            raise ValueError("Expected a structure with four values")

        stem_width = growth * 2

        # Stem. Modules are created in the same order they are registered.
        stem = OrderedDict()
        if input_3x3:
            stem["conv1"] = nn.Conv2d(3, stem_width, 3, stride=2, padding=1, bias=False)
            stem["bn1"] = norm_act(stem_width)
            stem["conv2"] = nn.Conv2d(
                stem_width, stem_width, 3, stride=1, padding=1, bias=False
            )
            stem["bn2"] = norm_act(stem_width)
            # NOTE(review): conv3 has no norm_act after it, unlike conv1/conv2; adding
            # one would change the parameter layout, so it is left as-is.
            stem["conv3"] = nn.Conv2d(
                stem_width, stem_width, 3, stride=1, padding=1, bias=False
            )
        else:
            stem["conv1"] = nn.Conv2d(3, stem_width, 7, stride=2, padding=3, bias=False)
        stem["pool"] = nn.MaxPool2d(3, stride=2, padding=1)
        self.mod1 = nn.Sequential(stem)

        channels = stem_width
        for block_idx in range(4):
            block_dilation = try_index(dilation, block_idx)

            # Every dense block except the first is preceded by a transition block.
            if block_idx > 0:
                reduced = int(channels * theta)
                transition = OrderedDict()
                transition["bn"] = norm_act(channels)
                transition["conv"] = nn.Conv2d(channels, reduced, 1, bias=False)
                # Downsample only when the next dense block is not dilated.
                if block_dilation == 1:
                    transition["pool"] = nn.AvgPool2d(2, 2)
                self.add_module(f"tra{block_idx + 1}", nn.Sequential(transition))
                channels = reduced

            dense = DenseModule(
                channels,
                growth,
                structure[block_idx],
                norm_act=norm_act,
                dilation=block_dilation,
            )
            self.add_module(f"mod{block_idx + 2}", dense)
            channels = dense.out_channels

        # Final normalization / activation, then the optional classification head.
        self.bn_out = norm_act(channels)
        if classes != 0:
            head = OrderedDict()
            head["avg_pool"] = GlobalAvgPool2d()
            head["fc"] = nn.Linear(channels, classes)
            self.classifier = nn.Sequential(head)

    def forward(self, x):
        """Run the stem, each dense block with its transition, `bn_out`, and the classifier if present.

        `x` is presumably an image batch with 3 input channels (matching `conv1`);
        confirm the expected layout with callers.
        """
        out = self.mod1(x)
        out = self.tra2(self.mod2(out))
        out = self.tra3(self.mod3(out))
        out = self.tra4(self.mod4(out))
        out = self.bn_out(self.mod5(out))
        if hasattr(self, "classifier"):
            out = self.classifier(out)
        return out
# Depth suffix -> keyword arguments forwarded to DenseNet by the matching factory.
_NETS = {
    "121": {"structure": [6, 12, 24, 16]},
    "169": {"structure": [6, 12, 32, 32]},
    "201": {"structure": [6, 12, 48, 32]},
    "264": {"structure": [6, 12, 64, 48]},
}


def _register_nets():
    """Expose one ``net_densenet<depth>`` factory per entry of ``_NETS`` on this module.

    Each factory is ``functools.partial(DenseNet, **params)``. Registration runs
    inside a function so the loop variables do not leak into the module namespace.

    Returns
    -------
    list of str
        The names of the registered factories, in ``_NETS`` iteration order.
    """
    module = sys.modules[__name__]
    exported = []
    for depth, params in _NETS.items():
        net_name = "net_densenet" + depth
        # NOTE: partial keeps a reference to the same ``structure`` list, so every
        # model built from this factory receives that one shared list object.
        setattr(module, net_name, partial(DenseNet, **params))
        exported.append(net_name)
    return exported


__all__ = _register_nets()