# models/parallel_wavegan.py — ParallelWaveGAN.__init__

    def __init__(self, config: Config) -> None:
        """
        Build a ParallelWaveGAN training setup from *config*.

        Instantiates the generator/discriminator networks (wrapped in
        DataParallel and moved to the target device), their optimizers and
        learning-rate schedulers, and the loss criteria.
        """
        super().__init__(config)

        # fix-me: device selection is hard-coded to the first GPU
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.config: Config = remove_none_values_from_dict(
            OmegaConf.to_container(config)
        )

        model_cfg = self.config.model

        # Networks — one per role, class chosen by its configured type name,
        # then wrapped in DataParallel and placed on the target device.
        self.model = torch.nn.ModuleDict()
        for role in ("generator", "discriminator"):
            net_cls = getattr(models, getattr(model_cfg, f"{role}_type"))
            net = net_cls(**getattr(model_cfg, f"{role}_params"))
            self.model[role] = torch.nn.DataParallel(net).to(self.device)

        # Optimizers and schedulers — each optimizer sees the parameters of
        # the underlying module (DataParallel's .module), and each scheduler
        # drives its matching optimizer.
        self.optimizer: Dict[str, torch.optim.Optimizer] = {}
        self.scheduler: Dict[str, torch.optim.lr_scheduler._LRScheduler] = {}
        for role in ("generator", "discriminator"):
            opt_cls = getattr(
                optimizers,
                getattr(model_cfg, f"{role}_optimizer_type"),
            )
            self.optimizer[role] = opt_cls(
                self.model[role].module.parameters(),
                **getattr(model_cfg, f"{role}_optimizer"),
            )
            sched_cls = getattr(
                torch.optim.lr_scheduler,
                getattr(model_cfg, f"{role}_scheduler_type"),
            )
            self.scheduler[role] = sched_cls(
                optimizer=self.optimizer[role],
                **getattr(model_cfg, f"{role}_scheduler_params"),
            )

        # Loss criteria, all moved to the same device as the networks.
        self.criterion: Dict[str, torch.nn.Module] = {
            "stft": MultiResolutionSTFTLoss(
                **model_cfg.stft_loss_params  # pyre-ignore
            ).to(self.device),
            "mse": torch.nn.MSELoss().to(self.device),
        }
        if model_cfg.use_feat_match_loss:
            self.criterion["l1"] = torch.nn.L1Loss().to(self.device)

        # Multi-band generation (>1 output channel) needs a PQMF filter bank.
        if model_cfg.generator_params.out_channels > 1:
            self.criterion["pqmf"] = PQMF(
                subbands=config.model.generator_params.out_channels
            ).to(self.device)
        if model_cfg.use_subband_stft_loss:
            assert model_cfg.generator_params.out_channels > 1
            self.criterion["sub_stft"] = MultiResolutionSTFTLoss(
                **model_cfg.subband_stft_loss_params  # pyre-ignore
            ).to(self.device)

        # Mu-law companding/expanding transforms for the waveform.
        self.compand: torch.nn.Module = torchaudio.transforms.MuLawEncoding()
        self.expand: torch.nn.Module = torchaudio.transforms.MuLawDecoding()