def train_step()

in pretraining/fairseq/trainer.py


    def train_step(self, samples, dummy_batch=False):
        """Do forward, backward and parameter update."""
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
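        # seeding the CUDA RNG as well makes GPU-side randomness (e.g. dropout)
        # reproducible, not just CPU-side data shuffling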

        self.model.train()
        self.zero_grad()

        if not dummy_batch:
            self.meters['train_wall'].start()

        # forward and backward pass
        logging_outputs, sample_sizes, ooms = [], [], 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
                # when sample is None, run forward/backward on a dummy batch
                # and ignore the resulting gradients
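                # (running the dummy batch keeps this worker participating in
                # any distributed collectives, so the other workers don't hang)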
                sample = self._prepare_sample(self._dummy_batch)
                ignore_grad = True
            else:
                ignore_grad = False

            try:
                # forward
                loss, sample_size, logging_output = self.task.get_loss(
                    self.model, self.criterion, sample,
                )
                if ignore_grad:
                    loss *= 0

                if self.args.distributed_world_size > 1:
                    # only all-reduce gradients in the last backwards pass
                    if i < len(samples) - 1:
                        self.model.need_reduction = False
                    else:
                        self.model.need_reduction = True
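                    # the distributed wrapper consumes this flag: intermediate
                    # accumulation steps skip the gradient all-reduce, and only
                    # the final backward synchronizes gradients across workers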

                # backward
                self.optimizer.backward(loss)

                if not ignore_grad:
                    logging_outputs.append(logging_output)
                    sample_sizes.append(sample_size)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory, skipping batch')
                    ooms += 1
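                    # discard any partial gradients from the failed pass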
                    self.zero_grad()
                else:
                    raise

        if dummy_batch:
            return None

        # gather logging outputs from all replicas
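        # (all_gather_list pickles arbitrary Python objects and collects one
        # copy from every worker, unlike a tensor-only all-reduce)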
        if self.args.distributed_world_size > 1:
            logging_outputs, sample_sizes, ooms = zip(*distributed_utils.all_gather_list(
                [logging_outputs, sample_sizes, ooms],
            ))
            logging_outputs = list(chain.from_iterable(logging_outputs))
            sample_sizes = list(chain.from_iterable(sample_sizes))
            ooms = sum(ooms)

        if ooms == self.args.distributed_world_size:
            print('| WARNING: OOM in all workers, skipping update')
            self.zero_grad()
            return None

        # aggregate logging outputs and sample sizes
        logging_output = self.criterion._aggregate_logging_outputs(logging_outputs)
        sample_size = self.criterion.__class__.grad_denom(sample_sizes)

        if not all(k in logging_output for k in ['ntokens', 'nsentences']):
            raise Exception((
                'Please update the {}.aggregate_logging_outputs() method to '
                'return ntokens and nsentences'
            ).format(self.criterion.__class__.__name__))

        try:
            # normalize grads by sample size
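            # gradients were summed over the accumulated sub-batches and
            # averaged over workers by the all-reduce, so scaling by
            # world_size / sample_size turns them into a per-sample average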
            if self.no_sample_size_normalization:
                self.optimizer.multiply_grads(self.args.distributed_world_size)
            else:
                self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))

            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

            # take an optimization step
            self.optimizer.step()
            self._num_updates += 1

            # update learning rate
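            # (step_update lets per-update schedules, e.g. warmup, adjust here)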
            self.lr_scheduler.step_update(self._num_updates)

            # update meters
            ntokens = logging_output.get('ntokens', 0)
            nsentences = logging_output.get('nsentences', 0)
            self.meters['wps'].update(ntokens)
            self.meters['ups'].update(1.)
            self.meters['wpb'].update(ntokens)
            self.meters['bsz'].update(nsentences)
            self.meters['gnorm'].update(grad_norm)
            self.meters['clip'].update(
                1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
            )
            self.meters['oom'].update(ooms)
            self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
            if 'nll_loss' in logging_output:
                self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
        except OverflowError as e:
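            # typically raised by the fp16 optimizer when the scaled gradients
            # overflow; the loss scaler backs off and this update is skipped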
            print('| WARNING: overflow detected, ' + str(e))
            self.zero_grad()
            logging_output = None

        if self.args.fp16:
            self.meters['loss_scale'].reset()
            self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

        self.meters['train_wall'].stop()

        return logging_output
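
The core pattern above, summing gradients over several sub-batches, rescaling them by a sample count, clipping, and then stepping, can be illustrated in isolation. Below is a minimal single-process sketch of that pattern in plain PyTorch; the toy model, data, and clip value are hypothetical, and with a single process the world_size factor drops out, leaving only the 1 / sample_size normalization:

    import torch
    import torch.nn.functional as F

    # hypothetical toy setup, for illustration only
    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    samples = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(3)]

    optimizer.zero_grad()
    sample_size = 0
    for x, y in samples:
        # sum-reduced loss, so gradients add up across sub-batches
        loss = F.mse_loss(model(x), y, reduction='sum')
        loss.backward()
        sample_size += y.numel()

    # normalize the summed gradients to a per-sample average, mirroring
    # optimizer.multiply_grads(world_size / sample_size) with world_size == 1
    for p in model.parameters():
        if p.grad is not None:
            p.grad.mul_(1.0 / sample_size)

    # clip, then take the optimization step, as train_step does
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
    optimizer.step()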