in part_generator.py [0:0]
def __init__(self, image_size, latent_dim, network_capacity=16):
    """Set up the learned 4x4 seed tensor and the stack of generator blocks.

    Args:
        image_size: Target resolution; the number of blocks is
            ``int(log2(image_size) - 1)``.
        latent_dim: Passed unchanged as the first argument of every
            ``GeneratorBlock``.
        network_capacity: Base channel multiplier; every channel count
            below is this value times a power of two.
    """
    super().__init__()
    self.image_size = image_size
    self.latent_dim = latent_dim

    depth = int(log2(image_size) - 1)
    self.num_layers = depth

    # Learned starting tensor, initialised from a standard normal.
    seed_channels = 4 * network_capacity
    self.initial_block = nn.Parameter(torch.randn((seed_channels, 4, 4)))

    # Channel ladder: the seed width, then capacity * 2**depth down to
    # capacity * 2.
    widths = [seed_channels]
    widths.extend(
        network_capacity * (2 ** power) for power in range(depth, 0, -1)
    )

    self.blocks = nn.ModuleList([])
    final_index = depth - 1
    for position, (base_in, out_width) in enumerate(zip(widths, widths[1:])):
        # Each block's input is widened by capacity * 2**(depth - 1 - position)
        # extra channels.
        # NOTE(review): presumably room for features concatenated in
        # forward() -- confirm against the caller.
        widened_in = base_in + network_capacity * (2 ** (depth - 1 - position))
        stage = GeneratorBlock(
            latent_dim,
            widened_in,
            out_width,
            upsample=position != 0,  # every block except the first
            upsample_rgb=position != final_index,  # every block except the last
        )
        self.blocks.append(stage)