in fairseq/trainer.py [0:0]
def train_step(self, samples, dummy_batch=False, raise_oom=False):
    """Do forward, backward and parameter update."""
    if self._dummy_batch is None:
        self._dummy_batch = samples[0]
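
    # re-seed RNGs for a reproducible update, switch model/criterion to train
    # mode and clear any stale gradients before accumulating new ones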
    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    if not dummy_batch:
        self.meters['train_wall'].start()

    # forward and backward pass
    logging_outputs, sample_sizes, ooms = [], [], 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            ignore_grad = True
        else:
            ignore_grad = False

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.args.distributed_world_size > 1
                and hasattr(self.model, 'no_sync')
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size, logging_output = self.task.train_step(
                    sample, self.model, self.criterion, self.optimizer,
                    ignore_grad
                )

            if not ignore_grad:
                logging_outputs.append(logging_output)
                sample_sizes.append(sample_size)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                msg = (
                    '| WARNING: ran out of memory with exception: '
                    + '{};'.format(e)
                    + '\n Skipping batch'
                )
                # TODO: print should really go to a logger; a bare print goes to
                # stdout, which is buffered and in many cases never gets flushed
                # if another exception happens, hence stderr here
                # print(msg)
                print(msg, file=sys.stderr)
                if raise_oom:
                    raise ValueError(msg)
                ooms += 1
                self.zero_grad()
            else:
                raise e
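
    # if any mini-batch OOM'd, replay the reserved OOM batch (self._oom_batch) so
    # every worker still runs the same number of backward passes and gradient syncs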
    if ooms > 0 and self._oom_batch is not None:
        self.handle_ooms(ooms)

    if dummy_batch:
        return None

    # gather logging outputs from all replicas
    if self.args.distributed_world_size > 1 and (
        (not self.args.use_bmuf)
        or (
            self.args.use_bmuf
            and (self.get_num_updates() + 1) % self.args.global_sync_iter == 0
        )
    ):
        logging_outputs, sample_sizes, ooms, prev_norms = \
            zip(*distributed_utils.all_gather_list(
                [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
            ))
        logging_outputs = list(chain.from_iterable(logging_outputs))
        sample_sizes = list(chain.from_iterable(sample_sizes))
        ooms = sum(ooms)

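        # cheap sanity check: outside of BMUF, every worker should report the same
        # gradient norm for the previous update (an all-NaN/inf norm is tolerated,
        # e.g. right after an fp16 overflow)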
        if not self.args.use_bmuf:
            assert (
                all(norm == prev_norms[0] for norm in prev_norms)
                or all(math.isnan(norm) or math.isinf(norm) for norm in prev_norms)
            ), 'Fatal error: gradients are inconsistent between workers'

    self.meters['oom'].update(ooms, len(samples))
    if ooms == self.args.distributed_world_size * len(samples):
        print('| WARNING: OOM in all workers, skipping update')
        self.zero_grad()
        return None

    # aggregate logging outputs and sample sizes
    logging_output = self.task.aggregate_logging_outputs(
        logging_outputs, self.criterion
    )
    sample_size = self.task.grad_denom(sample_sizes, self.criterion)

    if not all(k in logging_output for k in ['ntokens', 'nsentences']):
        raise Exception((
            'Please update the {}.aggregate_logging_outputs() method to '
            'return ntokens and nsentences'
        ).format(self.task.__class__.__name__))
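
    # in fp16 training the optimizer may raise OverflowError anywhere below when
    # gradients overflow; the update is then skipped (see the except clause)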
    try:
        # normalize grads by sample size
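        # (multiplying by world_size undoes the 1/world_size averaging performed by
        # the distributed all-reduce, leaving a plain 1/sample_size normalization)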
        if sample_size > 0:
            self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))

        # clip grads
        grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
        self._prev_grad_norm = grad_norm

        # take an optimization step
        self.optimizer.step()
        self.set_num_updates(self.get_num_updates() + 1)

        # task specific update per step
        self.task.update_step(self._num_updates)

        # update meters
        ntokens = logging_output.get('ntokens', 0)
        nsentences = logging_output.get('nsentences', 0)
        self.meters['wps'].update(ntokens)
        self.meters['ups'].update(1.)
        self.meters['wpb'].update(ntokens)
        self.meters['bsz'].update(nsentences)
        self.meters['gnorm'].update(grad_norm)
        self.meters['clip'].update(
            1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
        )
        self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
        if 'train_acc' in self.meters:
            self.meters['train_acc'].update(
                logging_output.get('acc', 0), sample_size)
        if 'nll_loss' in logging_output:
            self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
    except OverflowError as e:
        print('| WARNING: overflow detected, ' + str(e))
        self.zero_grad()
        logging_output = None
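
    # expose the current fp16 loss scale in the training log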
    if self.args.fp16:
        self.meters['loss_scale'].reset()
        self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

    self.meters['train_wall'].stop()

    return logging_output
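
For context, here is a minimal sketch (not part of trainer.py) of how a training
loop might drive this method; the names epoch_itr, update_freq and trainer are
assumptions for illustration, though fairseq's own train.py groups --update-freq
mini-batches per call in a similar way:

# hypothetical driver: accumulate `update_freq` mini-batches, then make one update
samples = []
for batch in epoch_itr:                           # iterator over training mini-batches
    samples.append(batch)
    if len(samples) == update_freq:
        log_output = trainer.train_step(samples)  # forward/backward + optimizer step
        samples = []
        if log_output is not None:                # None means the update was skipped
            print('| loss {:.3f}'.format(log_output.get('loss', 0)))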