in models/wavenet.py [0:0]
def __init__(self, config: Config) -> None:
    """
    Build the data-parallel WaveNet, its optimizer, mu-law transforms and loss.

    Args:
        config: Run configuration. Network hyper-parameters, the learning
            rate, ``quantize_channels`` and ``output_distribution`` are read
            from ``config.model``. The whole ``config`` is also forwarded to
            the base-class constructor and to the mixture-loss constructors.

    Raises:
        RuntimeError: If ``config.model.input_type`` is not mu-law quantized
            and ``config.model.output_distribution`` is neither
            ``"Logistic"`` nor ``"Normal"``.
    """
    super().__init__(config)
    self.config = config

    model_cfg = config.model
    self.scalar_input: bool = is_scalar_input(model_cfg.input_type)

    network = wavenet.WaveNet(
        out_channels=model_cfg.out_channels,
        layers=model_cfg.layers,
        stacks=model_cfg.stacks,
        residual_channels=model_cfg.residual_channels,
        gate_channels=model_cfg.gate_channels,
        skip_out_channels=model_cfg.skip_out_channels,
        cin_channels=model_cfg.cin_channels,
        gin_channels=model_cfg.gin_channels,
        n_speakers=model_cfg.n_speakers,
        dropout=model_cfg.dropout,
        kernel_size=model_cfg.kernel_size,
        cin_pad=model_cfg.cin_pad,
        upsample_conditional_features=model_cfg.upsample_conditional_features,
        upsample_params=model_cfg.upsample_params,
        scalar_input=self.scalar_input,
        output_distribution=model_cfg.output_distribution,
    )
    self.model = torch.nn.DataParallel(network)

    # The optimizer is built only after self.model has been assigned, so
    # self.parameters() can include the network's weights.
    # NOTE(review): compand/expand/criterion are assigned *after* this point,
    # so any parameters they might own would not be optimized; presumably
    # they have none — confirm.
    self.optimizer: torch.optim.Optimizer = torch.optim.Adam(
        self.parameters(), lr=self.config.model.learning_rate
    )

    levels = model_cfg.quantize_channels
    self.compand: torch.nn.Module = torchaudio.transforms.MuLawEncoding(levels)
    self.expand: torch.nn.Module = torchaudio.transforms.MuLawDecoding(levels)

    if is_mulaw_quantize(model_cfg.input_type):
        # Mu-law quantized input: plain cross-entropy loss.
        criterion = torch.nn.CrossEntropyLoss()
    else:
        # Otherwise the loss is chosen by the configured output distribution.
        distribution = model_cfg.output_distribution
        if distribution == "Logistic":
            loss_cls = DiscretizedMixturelogisticLoss
        elif distribution == "Normal":
            loss_cls = MixtureGaussianLoss
        else:
            raise RuntimeError(
                "Not supported output distribution type: {}".format(distribution)
            )
        criterion = loss_cls(config)
    self.criterion: torch.nn.Module = criterion