in lib/models/wgangp.py [0:0]
def __init__(self, num_filters, resample=None, batchnorm=True, inplace=False):
    """Build a pre-activation residual block with a fixed channel count.

    The residual branch is stored in ``self.block`` as two stages of
    ``[BatchNorm2d ->] ReLU -> conv``. ``self.conv_shortcut`` holds a
    projection for the skip path when the spatial size changes, and is
    ``None`` when it does not.

    Args:
        num_filters: Number of input and output channels of every conv.
        resample: ``'up'`` maps spatial size H -> 2H using transposed convs.
            ``'down'`` maps H -> floor((H - 1) / 2) + 1 using stride-2 convs.
            ``None`` keeps the spatial size unchanged.
        batchnorm: If True, insert an ``nn.BatchNorm2d`` before each ReLU.
        inplace: Forwarded to ``nn.ReLU``. The first ReLU is never in-place
            when it would act directly on the block's input tensor.

    Raises:
        ValueError: If ``resample`` is not ``'up'``, ``'down'`` or ``None``.
    """
    super(ResBlock, self).__init__()
    if resample == 'up':
        # k=4, s=2, p=1 transposed conv: (H - 1) * 2 - 2 + 4 = 2H.
        # 1x1, s=2, output_padding=1 shortcut: (H - 1) * 2 + 1 + 1 = 2H.
        conv_list = [nn.ConvTranspose2d(num_filters, num_filters, 4, stride=2, padding=1),
                     nn.Conv2d(num_filters, num_filters, 3, padding=1)]
        self.conv_shortcut = nn.ConvTranspose2d(num_filters, num_filters, 1,
                                                stride=2, output_padding=1)
    elif resample == 'down':
        # Both the k=3/s=2/p=1 conv and the 1x1/s=2 shortcut give
        # floor((H - 1) / 2) + 1, so the two paths line up.
        conv_list = [nn.Conv2d(num_filters, num_filters, 3, padding=1),
                     nn.Conv2d(num_filters, num_filters, 3, stride=2, padding=1)]
        self.conv_shortcut = nn.Conv2d(num_filters, num_filters, 1, stride=2)
    elif resample is None:
        conv_list = [nn.Conv2d(num_filters, num_filters, 3, padding=1),
                     nn.Conv2d(num_filters, num_filters, 3, padding=1)]
        # No projection; presumably forward() uses the identity here -- confirm.
        self.conv_shortcut = None
    else:
        raise ValueError('Invalid resample value.')

    layers = []
    for stage, conv in enumerate(conv_list):
        if batchnorm:
            layers.append(nn.BatchNorm2d(num_filters))
        # Without a preceding BatchNorm, the first ReLU receives the block's
        # input tensor itself. Running it in-place would overwrite the
        # caller's tensor, which the skip path presumably also consumes in
        # forward() (confirm). A BatchNorm or conv output is a fresh tensor,
        # so in-place is safe everywhere else.
        relu_inplace = inplace and (batchnorm or stage > 0)
        layers.append(nn.ReLU(relu_inplace))
        layers.append(conv)
    self.block = nn.Sequential(*layers)