in models/wavenet.py [0:0]
def loss(self, spectrograms: Tensor, waveforms: Tensor) -> Tensor:
    """
    Compute the training loss for one batch.

    The waveform is encoded according to ``self.config.model.input_type``,
    fed to ``self.model`` with the spectrograms as conditioning (``c``),
    and the output is scored by ``self.criterion`` against the waveform
    shifted left by one sample.

    Args:
        spectrograms: Conditioning passed to the model as ``c``.
        waveforms: Batch of waveforms, indexed as [batch_size, n_samples].

    Returns:
        The value produced by ``self.criterion`` (the negative log
        likelihood loss).

    Raises:
        RuntimeError: If the configured input type is not one of
            "raw", "mulaw" or "mulaw-quantize".
    """
    model_cfg = self.config.model
    input_type = model_cfg.input_type

    # Targets are the inputs shifted by one step.
    shifted = waveforms[:, 1:]  # [batch_size, n_samples-1]

    if input_type == "raw":
        # Only a channel axis is added; the samples are used as-is.
        net_input = waveforms.unsqueeze(1)
    elif input_type == "mulaw":
        companded_input = self.compand(waveforms)  # [batch_size, n_samples]
        companded_target = self.compand(shifted)
        n_channels = model_cfg.quantize_channels
        net_input = self.label_2_float(companded_input, n_channels).unsqueeze(1)
        shifted = self.label_2_float(companded_target, n_channels)
    elif input_type == "mulaw-quantize":
        companded_input = self.compand(waveforms)  # [batch_size, n_samples]
        shifted = self.compand(shifted)
        # One-hot puts the class axis last; move it to position 1 for the model.
        encoded = F.one_hot(companded_input, model_cfg.quantize_channels)
        net_input = encoded.transpose(1, 2).float()
    else:
        raise RuntimeError("Not supported input type: {}".format(input_type))

    prediction = self.model(net_input, c=spectrograms)

    if input_type in ("mulaw", "raw"):
        # These two modes get a trailing singleton axis on the target.
        shifted = shifted.unsqueeze(2)  # [batch_size, n_samples-1, 1]

    # Drop the final output step so the output length matches the target.
    return self.criterion(prediction[:, :, :-1], shifted)