def loss()

in models/wavenet.py [0:0]


    def loss(self, spectrograms: Tensor, waveforms: Tensor) -> Tensor:
        """
        Compute the training loss for one batch.

        Each output step is scored against the *next* input sample, so the
        target is the waveform shifted left by one step. The model's final
        output step has no following sample and is dropped before scoring.

        Args:
          spectrograms: Conditioning features, passed to the model as ``c``.
          waveforms: Batch of waveforms, [batch_size, n_samples].

        Returns:
          The negative log likelihood loss, as produced by ``self.criterion``.

        Raises:
          RuntimeError: If ``config.model.input_type`` is not one of
            "raw", "mulaw" or "mulaw-quantize".
        """
        model_cfg = self.config.model
        kind = model_cfg.input_type

        # Next-sample target: output step t is compared with input sample t+1.
        target = waveforms[:, 1:]  # [batch_size, n_samples-1]

        if kind == "raw":
            inputs = waveforms.unsqueeze(1)  # [batch_size, 1, n_samples]
        elif kind in ("mulaw", "mulaw-quantize"):
            # compand's result feeds F.one_hot below, so it yields integer labels.
            inputs = self.compand(waveforms)  # [batch_size, n_samples]
            target = self.compand(target)
            n_channels = model_cfg.quantize_channels
            if kind == "mulaw-quantize":
                # One-hot encode, then move the class axis to dim 1.
                inputs = F.one_hot(inputs, n_channels).transpose(1, 2).float()
            else:
                # Map labels back to floats; both input and target stay continuous.
                inputs = self.label_2_float(inputs, n_channels).unsqueeze(1)
                target = self.label_2_float(target, n_channels)
        else:
            raise RuntimeError(
                "Not supported input type: {}".format(kind)
            )

        output = self.model(inputs, c=spectrograms)

        # Continuous targets ("raw"/"mulaw") get a trailing singleton axis;
        # class-index targets ("mulaw-quantize") stay [batch_size, n_samples-1].
        if kind != "mulaw-quantize":
            target = target.unsqueeze(2)  # [batch_size, n_samples-1, 1]

        # Drop the last prediction along dim 2: nothing follows it to score.
        return self.criterion(output[:, :, :-1], target)