in pretraining/fairseq/trainer.py [0:0]
def train_step(self, samples, dummy_batch=False):
    """Do forward, backward and parameter update."""
    # Set seed based on args.seed and the update number so that we get
    # reproducible results when resuming from checkpoints
    seed = self.args.seed + self.get_num_updates()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    self.model.train()
    self.zero_grad()

    if not dummy_batch:
        self.meters['train_wall'].start()

    # forward and backward pass
    logging_outputs, sample_sizes, ooms = [], [], 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
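            # (running the dummy batch keeps this worker in step with the others,
            # which still perform a forward/backward pass for this iteration)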
            sample = self._prepare_sample(self._dummy_batch)
            ignore_grad = True
        else:
            ignore_grad = False

        try:
            # forward
            loss, sample_size, logging_output = self.task.get_loss(
                self.model, self.criterion, sample,
            )
            if ignore_grad:
                loss *= 0
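
            # When *samples* holds several sub-batches (gradient accumulation),
            # keep gradient synchronization local until the final backward pass
            # so the all-reduce runs once per update rather than once per sub-batch.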
            if self.args.distributed_world_size > 1:
                # only all-reduce gradients in the last backwards pass
                if i < len(samples) - 1:
                    self.model.need_reduction = False
                else:
                    self.model.need_reduction = True

            # backward
            self.optimizer.backward(loss)

            if not ignore_grad:
                logging_outputs.append(logging_output)
                sample_sizes.append(sample_size)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                print('| WARNING: ran out of memory, skipping batch')
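                # count the failure and drop any gradients accumulated so far on
                # this worker; if every worker hit OOM, the update is skipped
                # after the all-gather below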
                ooms += 1
                self.zero_grad()
            else:
                raise e

    if dummy_batch:
        return None

    # gather logging outputs from all replicas
    if self.args.distributed_world_size > 1:
        logging_outputs, sample_sizes, ooms = zip(*distributed_utils.all_gather_list(
            [logging_outputs, sample_sizes, ooms],
        ))
        logging_outputs = list(chain.from_iterable(logging_outputs))
        sample_sizes = list(chain.from_iterable(sample_sizes))
        ooms = sum(ooms)

    if ooms == self.args.distributed_world_size:
        print('| WARNING: OOM in all workers, skipping update')
        self.zero_grad()
        return None

    # aggregate logging outputs and sample sizes
    logging_output = self.criterion._aggregate_logging_outputs(logging_outputs)
    sample_size = self.criterion.__class__.grad_denom(sample_sizes)

    if not all(k in logging_output for k in ['ntokens', 'nsentences']):
        raise Exception((
            'Please update the {}.aggregate_logging_outputs() method to '
            'return ntokens and nsentences'
        ).format(self.criterion.__class__.__name__))
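
    # The scaling below assumes the distributed wrapper averages gradients across
    # workers: multiplying by distributed_world_size undoes that average, and
    # dividing by the aggregated sample_size leaves gradients normalized by the
    # total sample size of this update.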
    try:
        # normalize grads by sample size
        if self.no_sample_size_normalization:
            self.optimizer.multiply_grads(self.args.distributed_world_size)
        else:
            self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))

        # clip grads
        grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

        # take an optimization step
        self.optimizer.step()
        self._num_updates += 1

        # update learning rate
        self.lr_scheduler.step_update(self._num_updates)

        # update meters
        ntokens = logging_output.get('ntokens', 0)
        nsentences = logging_output.get('nsentences', 0)
        self.meters['wps'].update(ntokens)
        self.meters['ups'].update(1.)
        self.meters['wpb'].update(ntokens)
        self.meters['bsz'].update(nsentences)
        self.meters['gnorm'].update(grad_norm)
        self.meters['clip'].update(
            1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
        )
        self.meters['oom'].update(ooms)
        self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
        if 'nll_loss' in logging_output:
            self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
    except OverflowError as e:
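        # with --fp16 this is typically raised when gradients overflow and the
        # loss scale has to be reduced; the parameter update is skipped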
        print('| WARNING: overflow detected, ' + str(e))
        self.zero_grad()
        logging_output = None

    if self.args.fp16:
        self.meters['loss_scale'].reset()
        self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

    self.meters['train_wall'].stop()

    return logging_output
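
For context, train_step consumes a list of sub-batches and performs a single parameter update over all of them. A minimal sketch of the surrounding driver loop might look like the following; the names trainer, epoch_itr, and update_freq are illustrative assumptions, not part of this file:

# hypothetical driver loop; `trainer` and `epoch_itr` are assumed to exist
update_freq = 4  # sub-batches accumulated into one optimizer step
samples = []
for i, sample in enumerate(epoch_itr):
    samples.append(sample)
    if (i + 1) % update_freq == 0:
        # one forward/backward per sub-batch, then a single parameter update
        log_output = trainer.train_step(samples)
        samples = []
        if log_output is not None:
            print('| loss {:.3f}'.format(log_output.get('loss', 0.0)))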