in models/parallel_wavegan.py [0:0]
def __init__(self, config: Config) -> None:
    """
    Create a new ParallelWaveGAN.

    Builds, from the ``model`` section of ``config``, the generator and
    discriminator networks (each wrapped in ``torch.nn.DataParallel``), one
    optimizer and one learning-rate scheduler per network, and the loss
    modules stored in ``self.criterion``.

    Args:
        config: Experiment configuration. It is converted with
            ``OmegaConf.to_container`` and passed through
            ``remove_none_values_from_dict`` before any value is read here.

    Raises:
        AttributeError: If a configured generator/discriminator, optimizer
            or scheduler type name is not found in ``models``,
            ``optimizers`` or ``torch.optim.lr_scheduler`` respectively.
        ValueError: If ``use_subband_stft_loss`` is enabled while
            ``generator_params.out_channels`` is not greater than 1.
    """
    super().__init__(config)
    # fix-me: the device is hard-coded to the first CUDA device when CUDA is
    # available, otherwise the CPU.
    self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
    self.config: Config = remove_none_values_from_dict(
        OmegaConf.to_container(config)
    )
    # All settings below are read from the sanitized copy, never from the
    # raw ``config`` argument, so every component sees the same values.
    model_cfg = self.config.model

    # Model
    generator_class = getattr(models, model_cfg.generator_type)
    discriminator_class = getattr(models, model_cfg.discriminator_type)
    generator = generator_class(**model_cfg.generator_params)
    discriminator = discriminator_class(**model_cfg.discriminator_params)
    self.model = torch.nn.ModuleDict(
        {
            "generator": torch.nn.DataParallel(generator).to(self.device),
            "discriminator": torch.nn.DataParallel(discriminator).to(
                self.device
            ),
        }
    )

    # Optimizer
    generator_optimizer_class = getattr(
        optimizers,
        model_cfg.generator_optimizer_type,
    )
    discriminator_optimizer_class = getattr(
        optimizers,
        model_cfg.discriminator_optimizer_type,
    )
    # ``.module`` unwraps the DataParallel container so each optimizer is
    # handed the parameters of the underlying network.
    self.optimizer: Dict[str, torch.optim.Optimizer] = {
        "generator": generator_optimizer_class(
            self.model["generator"].module.parameters(),
            **model_cfg.generator_optimizer,
        ),
        "discriminator": discriminator_optimizer_class(
            self.model["discriminator"].module.parameters(),
            **model_cfg.discriminator_optimizer,
        ),
    }

    # Scheduler
    generator_scheduler_class = getattr(
        torch.optim.lr_scheduler,
        model_cfg.generator_scheduler_type,
    )
    discriminator_scheduler_class = getattr(
        torch.optim.lr_scheduler,
        model_cfg.discriminator_scheduler_type,
    )
    self.scheduler: Dict[str, torch.optim.lr_scheduler._LRScheduler] = {
        "generator": generator_scheduler_class(
            optimizer=self.optimizer["generator"],
            **model_cfg.generator_scheduler_params,
        ),
        "discriminator": discriminator_scheduler_class(
            optimizer=self.optimizer["discriminator"],
            **model_cfg.discriminator_scheduler_params,
        ),
    }

    # Loss
    self.criterion: Dict[str, torch.nn.Module] = {
        "stft": MultiResolutionSTFTLoss(
            **model_cfg.stft_loss_params  # pyre-ignore
        ).to(self.device),
        "mse": torch.nn.MSELoss().to(self.device),
    }
    if model_cfg.use_feat_match_loss:
        self.criterion["l1"] = torch.nn.L1Loss().to(self.device)

    out_channels = model_cfg.generator_params.out_channels
    if out_channels > 1:
        # Multi-channel generator output: register a PQMF module with one
        # subband per generator output channel.
        self.criterion["pqmf"] = PQMF(subbands=out_channels).to(self.device)
    if model_cfg.use_subband_stft_loss:
        # Raise explicitly instead of asserting: asserts are stripped under
        # ``python -O``, which would let this misconfiguration through with
        # no "pqmf" criterion registered.
        if out_channels <= 1:
            raise ValueError(
                "use_subband_stft_loss requires "
                "generator_params.out_channels > 1, "
                f"got {out_channels}"
            )
        self.criterion["sub_stft"] = MultiResolutionSTFTLoss(
            **model_cfg.subband_stft_loss_params  # pyre-ignore
        ).to(self.device)

    # Mu-law encode/decode transforms, built with torchaudio's default
    # arguments.
    self.compand: torch.nn.Module = torchaudio.transforms.MuLawEncoding()
    self.expand: torch.nn.Module = torchaudio.transforms.MuLawDecoding()