in src/model.py [0:0]
def build_layers(img_sz, img_fm, init_fm, max_fm, n_layers, n_attr, n_skip,
                 deconv_method, instance_norm, enc_dropout, dec_dropout):
    """
    Build auto-encoder layers.

    Parameters
    ----------
    img_sz : int
        Input image size. Must be a positive power of two.
    img_fm : int
        Number of channels of the input image (also the number of channels
        produced by the last decoder layer).
    init_fm : int
        Number of feature maps produced by the first encoder layer. The count
        doubles at every following layer and is capped at ``max_fm``.
    max_fm : int
        Maximum number of feature maps of any encoder layer.
    n_layers : int
        Number of encoder layers (and of decoder layers). Every encoder layer
        is a stride-2 convolution, so at most ``log2(img_sz)`` are allowed.
    n_attr : int
        Number of extra input channels given to every decoder layer
        (presumably attribute maps concatenated by the caller).
    n_skip : int
        Number of decoder layers, excluding ``dec_layers[0]``, that take
        ``n_out`` additional input channels, where ``n_out`` is the output
        width of the encoder layer built at the same step (presumably a skip
        connection concatenated by the caller). Must be <= ``n_layers - 1``.
    deconv_method : str
        Decoder up-sampling method: 'upsampling', 'convtranspose' or
        'pixelshuffle'.
    instance_norm : bool
        Use ``nn.InstanceNorm2d`` if True, ``nn.BatchNorm2d`` otherwise.
    enc_dropout : float
        Dropout probability in [0, 1) appended to every encoder layer
        (no dropout module when 0).
    dec_dropout : float
        Dropout probability in [0, 1) used in ``dec_layers[0:3]``, never in
        the output layer (no dropout module when 0).

    Returns
    -------
    enc_layers : list of nn.Sequential
        Encoder layers, in construction order (image side first).
    dec_layers : list of nn.Sequential
        Decoder layers in reverse construction order: ``dec_layers[0]`` is
        the one built alongside the last encoder layer and ``dec_layers[-1]``
        outputs ``img_fm`` channels through a ``Tanh``.

    Raises
    ------
    AssertionError
        If any parameter is invalid.
    """
    def _require(condition, message):
        # Explicit check instead of an `assert` statement so validation is not
        # removed under `python -O`; AssertionError is kept as the exception
        # type for callers that relied on the previous asserts.
        if not condition:
            raise AssertionError(message)

    _require(init_fm <= max_fm,
             "init_fm (%r) must be <= max_fm (%r)" % (init_fm, max_fm))
    _require(n_skip <= n_layers - 1,
             "n_skip (%r) must be <= n_layers - 1 (%r)" % (n_skip, n_layers - 1))
    # `img_sz > 0` short-circuits before np.log2, which warns on 0 / negatives
    _require(img_sz > 0 and np.log2(img_sz).is_integer(),
             "img_sz (%r) must be a power of two" % (img_sz,))
    _require(n_layers <= int(np.log2(img_sz)),
             "n_layers (%r) must be <= log2(img_sz) (%r)"
             % (n_layers, int(np.log2(img_sz))))
    _require(isinstance(instance_norm, bool),
             "instance_norm must be a bool, got %r" % (instance_norm,))
    _require(0 <= enc_dropout < 1,
             "enc_dropout (%r) must be in [0, 1)" % (enc_dropout,))
    _require(0 <= dec_dropout < 1,
             "dec_dropout (%r) must be in [0, 1)" % (dec_dropout,))
    # Checked before any layer is built so that an unknown method can never
    # fall through to the pixelshuffle branch below.
    _require(deconv_method in ('upsampling', 'convtranspose', 'pixelshuffle'),
             "unknown deconv_method: %r" % (deconv_method,))

    norm_fn = nn.InstanceNorm2d if instance_norm else nn.BatchNorm2d

    enc_layers = []
    dec_layers = []

    n_in = img_fm
    n_out = init_fm

    for i in range(n_layers):
        enc_layer = []
        dec_layer = []
        # decoder layers built at steps n_layers - n_skip - 1 .. n_layers - 2
        # receive n_out extra input channels (skip connection)
        skip_connection = n_layers - (n_skip + 1) <= i < n_layers - 1
        n_dec_in = n_out + n_attr + (n_out if skip_connection else 0)
        n_dec_out = n_in

        # encoder layer: kernel 4 / stride 2 / padding 1 halves an even size
        enc_layer.append(nn.Conv2d(n_in, n_out, 4, 2, 1))
        if i > 0:  # no normalization on the first encoder layer
            enc_layer.append(norm_fn(n_out, affine=True))
        enc_layer.append(nn.LeakyReLU(0.2, inplace=True))
        if enc_dropout > 0:
            enc_layer.append(nn.Dropout(enc_dropout))

        # decoder layer: every method doubles the spatial size
        if deconv_method == 'upsampling':
            dec_layer.append(nn.UpsamplingNearest2d(scale_factor=2))
            dec_layer.append(nn.Conv2d(n_dec_in, n_dec_out, 3, 1, 1))
        elif deconv_method == 'convtranspose':
            # NOTE(review): bias=False also applies to the output layer
            # (i == 0), which has no normalization after it; left unchanged
            # because adding a bias would change the model's parameters.
            dec_layer.append(nn.ConvTranspose2d(n_dec_in, n_dec_out, 4, 2, 1,
                                                bias=False))
        else:  # 'pixelshuffle' (validated above)
            dec_layer.append(nn.Conv2d(n_dec_in, n_dec_out * 4, 3, 1, 1))
            dec_layer.append(nn.PixelShuffle(2))
        if i > 0:
            dec_layer.append(norm_fn(n_dec_out, affine=True))
            # i >= n_layers - 3 selects the layers that end up at
            # dec_layers[0:3] once the list is reversed
            if dec_dropout > 0 and i >= n_layers - 3:
                dec_layer.append(nn.Dropout(dec_dropout))
            dec_layer.append(nn.ReLU(inplace=True))
        else:
            # output layer: produces img_fm channels
            dec_layer.append(nn.Tanh())

        # update
        n_in = n_out
        n_out = min(2 * n_out, max_fm)

        enc_layers.append(nn.Sequential(*enc_layer))
        # prepend, so the decoder list mirrors the encoder list
        dec_layers.insert(0, nn.Sequential(*dec_layer))

    return enc_layers, dec_layers