in models/src/wavegrad/nn.py [0:0]
def __init__(self, config):
    """Build the upsampling branch, the downsampling branch and the FiLM links.

    Args:
        config: Configuration object; only ``config.model`` is read. It must
            expose ``upsampling_preconv_out_channels``,
            ``upsampling_out_channels``, ``upsampling_dilations``,
            ``downsampling_preconv_out_channels``,
            ``downsampling_out_channels``, ``downsampling_dilations`` and
            ``factors`` (the per-stage ones as sequences).

    Note:
        Per-stage sequences are combined with ``zip``, so if their lengths
        disagree the surplus entries are silently ignored.
    """
    super(WaveGradNN, self).__init__()
    model_cfg = config.model
    # Materialize config sequences as lists once, so ``[x] + seq`` and the
    # slicing below work even if the config stores tuples or another
    # sequence type.
    up_out_sizes = list(model_cfg.upsampling_out_channels)
    down_out_sizes = list(model_cfg.downsampling_out_channels)
    # DBlocks (and every FiLM after the first) use the upsampling factors in
    # reverse order, excluding factors[0].
    down_factors = list(model_cfg.factors[1:][::-1])

    # Building upsampling branch (mels -> signal)
    self.ublock_preconv = Conv1dWithInitialization(
        in_channels=MEL_NUM_BANDS,
        out_channels=model_cfg.upsampling_preconv_out_channels,
        kernel_size=3,
        stride=1,
        padding=1,
    )
    # Each UBlock consumes the previous stage's output channels.
    up_in_sizes = [model_cfg.upsampling_preconv_out_channels] + up_out_sizes[:-1]
    self.ublocks = torch.nn.ModuleList(
        [
            UBlock(
                in_channels=in_size,
                out_channels=out_size,
                factor=factor,
                dilations=dilations,
            )
            for in_size, out_size, factor, dilations in zip(
                up_in_sizes,
                up_out_sizes,
                model_cfg.factors,
                model_cfg.upsampling_dilations,
            )
        ]
    )
    self.ublock_postconv = Conv1dWithInitialization(
        in_channels=up_out_sizes[-1],
        out_channels=1,
        kernel_size=3,
        stride=1,
        padding=1,
    )

    # Building downsampling branch (starting from signal)
    self.dblock_preconv = Conv1dWithInitialization(
        in_channels=1,
        out_channels=model_cfg.downsampling_preconv_out_channels,
        kernel_size=5,
        stride=1,
        padding=2,
    )
    # Each DBlock consumes the previous stage's output channels.
    down_in_sizes = [model_cfg.downsampling_preconv_out_channels] + down_out_sizes[:-1]
    self.dblocks = torch.nn.ModuleList(
        [
            DBlock(
                in_channels=in_size,
                out_channels=out_size,
                factor=factor,
                dilations=dilations,
            )
            for in_size, out_size, factor, dilations in zip(
                down_in_sizes,
                down_out_sizes,
                down_factors,
                model_cfg.downsampling_dilations,
            )
        ]
    )

    # Building FiLM connections (in order of downscaling stream).
    # The first FiLM sits at downscale factor 1, presumably on the
    # stride-1 dblock_preconv output. Its input width must therefore follow
    # the configured preconv channel count instead of a hard-coded 32.
    film_in_sizes = [model_cfg.downsampling_preconv_out_channels] + down_out_sizes
    # FiLM outputs feed the upsampling stages in reverse order.
    film_out_sizes = up_out_sizes[::-1]
    film_factors = [1] + down_factors
    self.films = torch.nn.ModuleList(
        [
            FiLM(
                in_channels=in_size,
                out_channels=out_size,
                # Cumulative downscaling of this FiLM's input, for proper
                # positional encodings initialization. np.prod replaces
                # np.product (deprecated in NumPy 1.25, removed in 2.0).
                input_dscaled_by=np.prod(film_factors[: i + 1]),
            )
            for i, (in_size, out_size) in enumerate(
                zip(film_in_sizes, film_out_sizes)
            )
        ]
    )