def train_step()

in custom/language_modeling_with_generation.py [0:0]


    def train_step(self, sample, model, criterion, optimizer, ignore_grad=False):
        """Run one training step, stochastically choosing between a sequence-level
        update and a standard token-level (MLE) update."""
        model.train()

        do_mle_step = True
        # -- sequence-level training
        if torch.rand(1).item() < self._sequence_level_train_rate:
            # only take a sequence-level step if the current minibatch has at least
            # one legal prefix; otherwise fall back to the token-level CE loss below
            if sample['net_input']['src_tokens'].size(1) >= self.sequence_criterion.sequence_prefix_length:
                do_mle_step = False

                loss, sample_size, logging_output = self.sequence_criterion(model, sample,
                                                                            generator=self.generator)
                if ignore_grad:
                    # zero the loss so this backward pass contributes no gradient
                    loss *= 0
                optimizer.backward(loss)

        # -- normal (token-level MLE) training
        if do_mle_step:
            compute_custom_metrics = self._train_step % self._compute_metrics_interval == 0
            loss, sample_size, logging_output = criterion(model, sample, compute_custom_metrics=compute_custom_metrics)
            if ignore_grad:
                loss *= 0
            optimizer.backward(loss)

            # only track this for normal training steps, since sequence-level
            # training always computes its own metrics
            self._train_step += 1
        return loss, sample_size, logging_output
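
For context, below is a minimal, hypothetical sketch of the interface this train_step expects from the sequence-level criterion: a sequence_prefix_length attribute and a call of the form criterion(model, sample, generator=...) returning (loss, sample_size, logging_output). Those conventions are taken from the call sites above; everything else (ToyModel, the placeholder loss, the field names) is invented purely for illustration.

import torch
import torch.nn as nn


class SequenceCriterionSketch:
    """Hypothetical stand-in for self.sequence_criterion, mirroring only the
    attributes and return convention used by train_step above."""

    def __init__(self, sequence_prefix_length=5):
        # minimum source length required for a legal prefix (checked in train_step)
        self.sequence_prefix_length = sequence_prefix_length

    def __call__(self, model, sample, generator=None):
        # a real criterion would decode continuations with `generator` and score
        # them; here we only return a differentiable placeholder loss
        src_tokens = sample['net_input']['src_tokens']
        loss = model(src_tokens).mean()
        sample_size = src_tokens.numel()
        logging_output = {'loss': loss.item(), 'sample_size': sample_size}
        return loss, sample_size, logging_output


class ToyModel(nn.Module):
    """Tiny language-model stand-in so the sketch runs end to end."""

    def __init__(self, vocab_size=100, dim=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.proj = nn.Linear(dim, vocab_size)

    def forward(self, tokens):
        return self.proj(self.embed(tokens))


# usage mirroring the prefix-length check in train_step
model = ToyModel()
criterion = SequenceCriterionSketch(sequence_prefix_length=5)
sample = {'net_input': {'src_tokens': torch.randint(0, 100, (2, 8))}}
if sample['net_input']['src_tokens'].size(1) >= criterion.sequence_prefix_length:
    loss, sample_size, logging_output = criterion(model, sample, generator=None)
    loss.backward()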