models/wavenet.py [270:298]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        return loss

    def train_step(
        self, spectrograms: Tensor, waveforms: Tensor
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
        """
        Runs a single train step of the model.

        Returns:
          A tuple containing the overall model loss and a dictionary of
          losses to log to Tensorboard. The overall loss is printed to the
          console and logged to Tensorboard.
        """

        # Forward pass.
        loss = self.loss(spectrograms, waveforms)

        # Learning rate schedule: resolve the schedule function named in the
        # config and recompute the learning rate for the current step.
        if self.config.model.lr_schedule:
            current_lr = self.config.model.learning_rate
            lr_schedule_fn = getattr(lrschedule, self.config.model.lr_schedule)
            lr_schedule_kwargs = remove_none_values_from_dict(
                OmegaConf.to_container(self.config.model.lr_schedule_kwargs)
            )
            current_lr = lr_schedule_fn(
                current_lr, self.global_step, **lr_schedule_kwargs
            )
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = current_lr
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
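
A note on this block: the schedule is resolved dynamically. getattr(lrschedule,
self.config.model.lr_schedule) looks up a schedule function by the name stored
in the config, and remove_none_values_from_dict strips unset kwargs so the
schedule function's own defaults apply. Neither helper's implementation appears
above, so the sketch below is an assumption; it presumes schedules share the
signature (init_lr, global_step, **kwargs) -> float.

sketch: schedule helpers (hypothetical, not the repository's code):
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
from typing import Any, Dict


def remove_none_values_from_dict(d: Dict[str, Any]) -> Dict[str, Any]:
    """Drop None-valued entries so schedule functions fall back to defaults."""
    return {key: value for key, value in d.items() if value is not None}


# A plausible member of the lrschedule module: exponential decay with warmup.
def exponential(
    init_lr: float,
    global_step: int,
    decay_rate: float = 0.5,
    decay_steps: int = 50000,
    warmup_steps: int = 0,
) -> float:
    """Map (base learning rate, step) to the learning rate for this step."""
    if warmup_steps and global_step < warmup_steps:
        # Linear warmup from ~0 up to init_lr over warmup_steps steps.
        return init_lr * (global_step + 1) / warmup_steps
    return init_lr * decay_rate ** (global_step / decay_steps)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -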



models/wavernn.py [177:204]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        return loss

    def train_step(
        self, spectrograms: Tensor, waveforms: Tensor
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
        """
        Runs a single train step of the model.

        Returns:
          A tuple containing the overall model loss and a dictionary of
          losses to log to Tensorboard. The overall loss is printed to the
          console and logged to Tensorboard.
        """
        # Forward pass.
        loss = self.loss(spectrograms, waveforms)

        # Learning rate schedule: resolve the schedule function named in the
        # config and recompute the learning rate for the current step.
        if self.config.model.lr_schedule:
            current_lr = self.config.model.learning_rate
            lr_schedule_fn = getattr(lrschedule, self.config.model.lr_schedule)
            lr_schedule_kwargs = remove_none_values_from_dict(
                OmegaConf.to_container(self.config.model.lr_schedule_kwargs)
            )
            current_lr = lr_schedule_fn(
                current_lr, self.global_step, **lr_schedule_kwargs
            )
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = current_lr
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
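
Both models expose the same train_step contract: compute the loss, apply the
learning rate schedule, and return the loss plus a dictionary of values to log.
The snippets end before the method returns, so where backward() and
optimizer.step() run is not shown; the loop below assumes they belong to the
caller. model, loader, and writer are hypothetical names, not identifiers from
the repository.

sketch: driver loop over the shared train_step contract (hypothetical):
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
from torch.utils.tensorboard import SummaryWriter


def train_epoch(model, loader, writer: SummaryWriter) -> None:
    """Drive one epoch of (spectrogram, waveform) batches through the model."""
    for spectrograms, waveforms in loader:
        model.optimizer.zero_grad()
        # train_step computes the loss and applies the LR schedule for
        # this step.
        loss, logs = model.train_step(spectrograms, waveforms)
        loss.backward()
        model.optimizer.step()
        step = model.global_step
        writer.add_scalar("loss/train", loss.item(), step)
        for name, value in logs.items():
            writer.add_scalar(name, value.item(), step)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -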



