def __init__()

in models/wavenet.py [0:0]


    def __init__(self, config: Config) -> None:
        """
        Create a new WaveNet.

        Builds the WaveNet network (wrapped in ``torch.nn.DataParallel``) from
        ``config.model``, an Adam optimizer over this module's parameters,
        mu-law encode/decode transforms, and the training criterion.

        Args:
            config: Project configuration. Its ``model`` section supplies the
                network hyperparameters as well as ``learning_rate``,
                ``quantize_channels``, ``input_type`` and
                ``output_distribution``.

        Raises:
            RuntimeError: If the input type is not mu-law quantized and
                ``config.model.output_distribution`` is neither
                ``"Logistic"`` nor ``"Normal"``.
        """
        super().__init__(config)

        self.config = config
        model_cfg = config.model
        self.scalar_input: bool = is_scalar_input(model_cfg.input_type)

        self.model = torch.nn.DataParallel(
            wavenet.WaveNet(
                out_channels=model_cfg.out_channels,
                layers=model_cfg.layers,
                stacks=model_cfg.stacks,
                residual_channels=model_cfg.residual_channels,
                gate_channels=model_cfg.gate_channels,
                skip_out_channels=model_cfg.skip_out_channels,
                cin_channels=model_cfg.cin_channels,
                gin_channels=model_cfg.gin_channels,
                n_speakers=model_cfg.n_speakers,
                dropout=model_cfg.dropout,
                kernel_size=model_cfg.kernel_size,
                cin_pad=model_cfg.cin_pad,
                upsample_conditional_features=model_cfg.upsample_conditional_features,
                upsample_params=model_cfg.upsample_params,
                scalar_input=self.scalar_input,
                output_distribution=model_cfg.output_distribution,
            )
        )

        # Built after self.model is assigned so that self.parameters() already
        # includes the network's weights (assumes the base class is a
        # torch.nn.Module that registers submodules on attribute assignment).
        self.optimizer: torch.optim.Optimizer = torch.optim.Adam(
            self.parameters(), lr=model_cfg.learning_rate
        )
        self.compand: torch.nn.Module = torchaudio.transforms.MuLawEncoding(
            model_cfg.quantize_channels
        )
        self.expand: torch.nn.Module = torchaudio.transforms.MuLawDecoding(
            model_cfg.quantize_channels
        )

        # Declare the attribute type once; re-annotating it in every branch
        # is reported by mypy as a redefinition.
        # NOTE(review): the criterion is created after the optimizer, so any
        # parameters a loss module owned would not be optimized; the loss
        # modules are presumably parameter-free — confirm.
        self.criterion: torch.nn.Module
        if is_mulaw_quantize(model_cfg.input_type):
            self.criterion = torch.nn.CrossEntropyLoss()
        elif model_cfg.output_distribution == "Logistic":
            self.criterion = DiscretizedMixturelogisticLoss(config)
        elif model_cfg.output_distribution == "Normal":
            self.criterion = MixtureGaussianLoss(config)
        else:
            distribution = model_cfg.output_distribution
            raise RuntimeError(
                f"Not supported output distribution type: {distribution}"
            )