in custom/language_modeling_with_generation.py
def train_step(self, sample, model, criterion, optimizer, ignore_grad=False):
    model.train()
    do_mle_step = True
    # -- sequence-level training
    if torch.rand(1).item() < self._sequence_level_train_rate:
        # Check whether the current minibatch has at least one legal prefix; if not, fall back to the CE loss.
        if sample['net_input']['src_tokens'].size(1) >= self.sequence_criterion.sequence_prefix_length:
            do_mle_step = False
            loss, sample_size, logging_output = self.sequence_criterion(
                model, sample, generator=self.generator)
            if ignore_grad:
                # Zero out the loss so the backward pass still runs but contributes no gradient.
                loss *= 0
            optimizer.backward(loss)
    # -- normal (MLE) training
    if do_mle_step:
        compute_custom_metrics = self._train_step % self._compute_metrics_interval == 0
        loss, sample_size, logging_output = criterion(
            model, sample, compute_custom_metrics=compute_custom_metrics)
        if ignore_grad:
            loss *= 0
        optimizer.backward(loss)
        # Only advance the step counter on normal training steps, since sequence-level training always computes its own metrics.
        self._train_step += 1
    return loss, sample_size, logging_output
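
For reference, a minimal standalone sketch of the gating logic above (illustrative only; choose_step is not part of the source): with probability sequence_level_train_rate the step attempts sequence-level training, and falls back to the MLE/CE criterion whenever the minibatch is too short to contain a legal prefix.

import torch

def choose_step(sequence_level_train_rate, src_tokens, sequence_prefix_length):
    # Sample once per step, mirroring the torch.rand(1).item() comparison in train_step.
    if torch.rand(1).item() < sequence_level_train_rate:
        # Only take the sequence-level branch when the batch has at least one legal prefix.
        if src_tokens.size(1) >= sequence_prefix_length:
            return "sequence"
    return "mle"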