scripts/modules/dense.py (65 lines of code) (raw):

# Copyright (c) Facebook, Inc. and its affiliates.

from collections import OrderedDict

import torch
import torch.nn as nn
from inplace_abn import ABN


class DenseModule(nn.Module):
    """Densely connected stack of bottlenecked conv layers (DenseNet-style).

    Layer ``k`` (0-based) receives the channel-wise concatenation of the module
    input and the outputs of layers ``0 .. k-1``, i.e.
    ``in_channels + k * growth`` channels, and produces ``growth`` new channels
    via::

        norm_act -> 1x1 conv (growth * bottleneck_factor channels, no bias)
        norm_act -> 3x3 conv, dilated (growth channels, no bias)

    The module output is the concatenation of the input and every layer's
    output, so it has ``out_channels == in_channels + growth * layers``
    channels and the same spatial size as the input.

    Args:
        in_channels: Number of channels of the input tensor.
        growth: Number of channels each layer adds to the feature stack.
        layers: Number of densely connected layers.
        bottleneck_factor: Width multiplier of the 1x1 bottleneck conv, whose
            output has ``growth * bottleneck_factor`` channels.
        norm_act: Factory called as ``norm_act(num_channels)`` that returns the
            module placed before each conv. Defaults to ``inplace_abn.ABN``
            (presumably batch norm + activation — see the inplace_abn docs).
        dilation: Dilation of the 3x3 convs. Padding is set equal to it, which
            keeps the spatial size unchanged for a 3x3 kernel at stride 1.
    """

    def __init__(
        self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1
    ):
        super().__init__()
        self.in_channels = in_channels
        self.growth = growth
        self.layers = layers

        # The submodule names ("convs1"/"convs3", "bn"/"conv") define the
        # state_dict key layout; keep them stable for checkpoint compatibility.
        self.convs1 = nn.ModuleList()
        self.convs3 = nn.ModuleList()

        bottleneck_channels = self.growth * bottleneck_factor
        layer_in_channels = in_channels
        # Build each layer's 1x1 and 3x3 stages together, layer by layer, so
        # parameters are created (and randomly initialized) in this order.
        for _ in range(self.layers):
            self.convs1.append(
                nn.Sequential(
                    OrderedDict(
                        [
                            ("bn", norm_act(layer_in_channels)),
                            (
                                "conv",
                                nn.Conv2d(
                                    layer_in_channels,
                                    bottleneck_channels,
                                    1,
                                    bias=False,
                                ),
                            ),
                        ]
                    )
                )
            )
            self.convs3.append(
                nn.Sequential(
                    OrderedDict(
                        [
                            ("bn", norm_act(bottleneck_channels)),
                            (
                                "conv",
                                nn.Conv2d(
                                    bottleneck_channels,
                                    self.growth,
                                    3,
                                    padding=dilation,
                                    bias=False,
                                    dilation=dilation,
                                ),
                            ),
                        ]
                    )
                )
            )
            # Each layer's output is concatenated onto the next layer's input.
            layer_in_channels += self.growth

    @property
    def out_channels(self):
        """Number of channels of the tensor returned by :meth:`forward`."""
        return self.in_channels + self.growth * self.layers

    def forward(self, x):
        """Run all dense layers and return the concatenated feature stack.

        Args:
            x: Input feature map, expected as a batched ``(N, in_channels, H, W)``
                tensor; ``dim=1`` is used as the channel axis for concatenation.

        Returns:
            Tensor of shape ``(N, out_channels, H, W)``: ``x`` followed by the
            ``growth``-channel output of every layer, in layer order.
        """
        features = [x]
        for conv1, conv3 in zip(self.convs1, self.convs3):
            # Every layer sees everything produced so far; this re-concatenates
            # (allocating a new tensor) at each step.
            out = torch.cat(features, dim=1)
            out = conv3(conv1(out))
            features.append(out)
        return torch.cat(features, dim=1)