def train_step()

in utils/trainer.py [0:0]


    def train_step(self, samples):
        self.optimizer.zero_grad()
        logging_outputs = []

        for i, sample in enumerate(samples):
            sample = self.prepare_sample(sample)

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (
                    self.args.world_size > 1
                    and hasattr(self.model, 'no_sync')
                    and i < len(samples) - 1
                ):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # no-op context manager

            with maybe_no_sync():
                # forward and backward
                sample_size = int(sample['sample_size'])
                total_loss, logging_output = self.model(**sample)

                avg_loss = total_loss / sample_size
                self.optimizer.backward(avg_loss)

                # the model's logging_output is expected to contain at least
                # 'sample_size' and 'total_loss' for this micro-batch
                logging_outputs.append(logging_output)

        # NOTE: in multi-GPU runs the per-replica logging outputs could be
        # gathered here (e.g. with distributed_utils.all_gather_list) before
        # computing global statistics such as avg_loss / avg_ppl; currently
        # each replica only aggregates its own micro-batches.

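        # average the local logging outputs key-by-key (a plain mean over
        # micro-batches, not weighted by sample_size)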
        avg_logging_output = {
            k: np.average([x[k] for x in logging_outputs])
            for k in logging_outputs[0]
        }

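        # periodically release PyTorch's cached GPU memory (skipped on CPU runs)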
        if (
            0 < self.args.empty_cache_freq <= self._num_updates and
            self._num_updates % self.args.empty_cache_freq == 0 and
            not self.args.cpu
        ):
            torch.cuda.empty_cache()

        try:
            # clip gradients to args.clip_norm
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
            # optional diagnostic:
            # logging.info(f'Iter {self._num_updates} Gradient Norm: {grad_norm}')

            # take an optimization step
            self.optimizer.step()
            self.take_one_step()
        except OverflowError as e:
            logging.warning('overflow detected, skipping this update: %s', e)
            self.optimizer.zero_grad()

        return avg_logging_output
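
The method relies on `contextlib`, `logging`, `numpy` (as `np`), and `torch` being imported at module level in utils/trainer.py. For context, below is a minimal sketch of how `train_step` might be driven from an epoch loop. The `train_epoch` helper, the `batch_iterator` (assumed to yield one list of prepared sample dicts per update, one dict per gradient-accumulation micro-batch), and the `log_interval` argument are illustrative assumptions, not part of the trainer.

import logging

def train_epoch(trainer, batch_iterator, log_interval=100):
    """Hypothetical driver loop for Trainer.train_step.

    `batch_iterator` is assumed to yield lists of sample dicts matching
    the `samples` argument of train_step above.
    """
    for step, samples in enumerate(batch_iterator, start=1):
        # one call = one optimizer update over len(samples) micro-batches
        stats = trainer.train_step(samples)
        if step % log_interval == 0:
            logging.info(
                'step %d | %s',
                step,
                ' | '.join('%s=%.4f' % (k, v) for k, v in stats.items()),
            )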