def __init__()

in models/src/wavegrad/nn.py [0:0]


    def __init__(self, config):
        """Build the upsampling branch, the downsampling branch and the FiLM links.

        Args:
            config: Configuration object whose ``model`` section provides
                ``upsampling_preconv_out_channels``, ``upsampling_out_channels``,
                ``upsampling_dilations``, ``factors``,
                ``downsampling_preconv_out_channels``,
                ``downsampling_out_channels`` and ``downsampling_dilations``.
                The channel and factor entries may be any indexable sequence
                (list, tuple, ...); they are copied into lists before being
                concatenated, so mixed sequence types do not raise.
        """
        super().__init__()
        model_config = config.model

        # Normalise config sequences once so list concatenation below works
        # regardless of the container type the config loader produced.
        upsampling_out_channels = list(model_config.upsampling_out_channels)
        downsampling_out_channels = list(model_config.downsampling_out_channels)
        factors = list(model_config.factors)
        # The downsampling stream (DBlocks and the FiLMs after them) uses every
        # upsampling factor except the first, in reverse order.
        downsampling_factors = factors[1:][::-1]

        # Building upsampling branch (mels -> signal)
        self.ublock_preconv = Conv1dWithInitialization(
            in_channels=MEL_NUM_BANDS,
            out_channels=model_config.upsampling_preconv_out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        # Each UBlock consumes the previous block's output channels; the first
        # one consumes the pre-convolution's output.
        upsampling_in_sizes = [
            model_config.upsampling_preconv_out_channels
        ] + upsampling_out_channels[:-1]
        self.ublocks = torch.nn.ModuleList(
            [
                UBlock(
                    in_channels=in_size,
                    out_channels=out_size,
                    factor=factor,
                    dilations=dilations,
                )
                for in_size, out_size, factor, dilations in zip(
                    upsampling_in_sizes,
                    upsampling_out_channels,
                    factors,
                    model_config.upsampling_dilations,
                )
            ]
        )
        self.ublock_postconv = Conv1dWithInitialization(
            in_channels=upsampling_out_channels[-1],
            out_channels=1,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        # Building downsampling branch (starting from signal)
        self.dblock_preconv = Conv1dWithInitialization(
            in_channels=1,
            out_channels=model_config.downsampling_preconv_out_channels,
            kernel_size=5,
            stride=1,
            padding=2,
        )
        downsampling_in_sizes = [
            model_config.downsampling_preconv_out_channels
        ] + downsampling_out_channels[:-1]
        self.dblocks = torch.nn.ModuleList(
            [
                DBlock(
                    in_channels=in_size,
                    out_channels=out_size,
                    factor=factor,
                    dilations=dilations,
                )
                for in_size, out_size, factor, dilations in zip(
                    downsampling_in_sizes,
                    downsampling_out_channels,
                    downsampling_factors,
                    model_config.downsampling_dilations,
                )
            ]
        )

        # Building FiLM connections (in order of downscaling stream).
        # Entry 0 has factor 1 and precedes every DBlock, i.e. it pairs with
        # the downsampling pre-convolution, so its width is that layer's
        # out_channels instead of a hard-coded 32 (same value in the default
        # config). The remaining entries pair with the DBlock outputs.
        film_in_sizes = [
            model_config.downsampling_preconv_out_channels
        ] + downsampling_out_channels
        film_out_sizes = upsampling_out_channels[::-1]
        film_factors = [1] + downsampling_factors
        self.films = torch.nn.ModuleList(
            [
                FiLM(
                    in_channels=in_size,
                    out_channels=out_size,
                    # Cumulative downscaling of this FiLM's input; used for
                    # proper positional encodings initialization.
                    # (np.product was removed in NumPy 2.0; np.prod is the
                    # supported equivalent.)
                    input_dscaled_by=np.prod(film_factors[: i + 1]),
                )
                for i, (in_size, out_size) in enumerate(
                    zip(film_in_sizes, film_out_sizes)
                )
            ]
        )