scripts/modules/dense.py (65 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates.
from collections import OrderedDict
import torch
import torch.nn as nn
from inplace_abn import ABN
class DenseModule(nn.Module):
    """Densely connected block of ``layers`` bottleneck stages.

    Every stage receives the channel-wise concatenation of the block input
    and the outputs of all previous stages, and contributes ``growth`` new
    channels. A stage is a 1x1 convolution (widening/narrowing to
    ``growth * bottleneck_factor`` channels) followed by a 3x3 convolution
    (down to ``growth`` channels); each convolution is preceded by a
    ``norm_act`` module and has no bias.

    Args:
        in_channels: Number of channels of the tensor fed to the block.
        growth: Number of channels each stage appends.
        layers: Number of stages.
        bottleneck_factor: Multiplier giving the width of the intermediate
            1x1 output (``growth * bottleneck_factor``).
        norm_act: Callable invoked as ``norm_act(num_channels)`` that returns
            the module placed in front of each convolution.
        dilation: Dilation of the 3x3 convolutions. The same value is used
            as their padding, which keeps the spatial size unchanged.
    """

    def __init__(
        self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1
    ):
        super().__init__()
        self.in_channels = in_channels
        self.growth = growth
        self.layers = layers

        self.convs1 = nn.ModuleList()
        self.convs3 = nn.ModuleList()

        bottleneck_width = self.growth * bottleneck_factor
        stage_inputs = in_channels
        for _ in range(self.layers):
            # The 1x1 stage is built before the 3x3 stage of the same layer so
            # that parameters are created in a fixed, interleaved order.
            pointwise = self._norm_then_conv(
                norm_act, stage_inputs, bottleneck_width, 1
            )
            spatial = self._norm_then_conv(
                norm_act,
                bottleneck_width,
                self.growth,
                3,
                padding=dilation,
                dilation=dilation,
            )
            self.convs1.append(pointwise)
            self.convs3.append(spatial)
            # The next stage additionally sees the channels produced here.
            stage_inputs += self.growth

    @staticmethod
    def _norm_then_conv(norm_act, channels_in, channels_out, kernel_size, **conv_kwargs):
        """Build a ``bn`` -> ``conv`` pair wrapped in a named ``nn.Sequential``."""
        norm = norm_act(channels_in)
        conv = nn.Conv2d(
            channels_in, channels_out, kernel_size, bias=False, **conv_kwargs
        )
        return nn.Sequential(OrderedDict([("bn", norm), ("conv", conv)]))

    @property
    def out_channels(self):
        """Channel count of the block output: the input plus ``growth`` per stage."""
        return self.in_channels + self.growth * self.layers

    def forward(self, x):
        """Run all stages and return the input concatenated with every stage output.

        ``x`` is concatenated along dimension 1 and passed through
        ``nn.Conv2d`` layers, so it is expected to be laid out as
        ``(N, in_channels, H, W)``.
        """
        features = [x]
        for pointwise, spatial in zip(self.convs1, self.convs3):
            stacked = torch.cat(features, dim=1)
            features.append(spatial(pointwise(stacked)))
        return torch.cat(features, dim=1)