in fairseq/trainer.py [0:0]
def train_step(self, samples, raise_oom=False):
    """Do forward, backward and parameter update."""
    if self._dummy_batch == "DUMMY":
        self._dummy_batch = samples[0]

    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()
    metrics.log_start_time("train_wall", priority=800, round=0)

    # forward and backward pass
    logging_outputs, sample_size, ooms = [], 0, 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            is_dummy_batch = True
        else:
            is_dummy_batch = False
        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.args.distributed_world_size > 1
                and hasattr(self.model, "no_sync")
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager
        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size_i, logging_output = self.task.train_step(
                    sample=sample,
                    model=self.model,
                    criterion=self.criterion,
                    optimizer=self.optimizer,
                    update_num=self.get_num_updates(),
                    ignore_grad=is_dummy_batch,
                )
                del loss

            logging_outputs.append(logging_output)
            sample_size += sample_size_i

            # emptying the CUDA cache after the first step can
            # reduce the chance of OOM
            if self.cuda and self.get_num_updates() == 0:
                torch.cuda.empty_cache()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if raise_oom:
                    raise e
                logger.warning(
                    "attempting to recover from OOM in forward/backward pass"
                )
                ooms += 1
                self.zero_grad()
            else:
                raise e

    if is_dummy_batch:
        sample_size *= 0.  # multiply by 0 to preserve device

    if torch.is_tensor(sample_size):
        sample_size = sample_size.float()
    else:
        sample_size = float(sample_size)
    # gather logging outputs from all replicas
    if self._sync_stats():
        logging_outputs, (sample_size, ooms) = self._aggregate_logging_outputs(
            logging_outputs, sample_size, ooms, ignore=is_dummy_batch,
        )

    try:
        # multiply gradients by (# GPUs / sample_size) since DDP
        # already normalizes by the number of GPUs. Thus we get
        # (sum_of_gradients / sample_size).
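        # e.g. (illustrative numbers): with 8 workers and an aggregate
        # sample_size of 4096 tokens, DDP's all-reduce has already averaged
        # the gradients (divided by 8), so scaling by 8 / 4096 leaves
        # sum_of_gradients / 4096.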
        if not self.args.use_bmuf:
            self.optimizer.multiply_grads(
                self.args.distributed_world_size / sample_size
            )
        elif sample_size > 0:  # BMUF needs to check sample size
            num = self.args.distributed_world_size if self._sync_stats() else 1
            self.optimizer.multiply_grads(num / sample_size)

        # clip grads
        grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

        # check that grad norms are consistent across workers
        if not self.args.use_bmuf:
            self._check_grad_norms(grad_norm)

        # take an optimization step
        self.optimizer.step()
        self.set_num_updates(self.get_num_updates() + 1)

        # log stats
        logging_output = self._reduce_and_log_stats(
            logging_outputs, sample_size, grad_norm,
        )
        # clear CUDA cache to reduce memory fragmentation
        if (
            self.args.empty_cache_freq > 0
            and (
                (self.get_num_updates() + self.args.empty_cache_freq - 1)
                % self.args.empty_cache_freq
            ) == 0
            and torch.cuda.is_available()
            and not self.args.cpu
        ):
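            # e.g. with empty_cache_freq=10 this triggers on update 1 and then
            # every 10th update after that (11, 21, ...)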
            torch.cuda.empty_cache()
    except FloatingPointError:
        # re-run the forward and backward pass with hooks attached
        # to print out where it fails
        with NanDetector(self.model):
            self.task.train_step(
                sample, self.model, self.criterion, self.optimizer,
                self.get_num_updates(), ignore_grad=False,
            )
        raise
    except OverflowError as e:
        logger.info("NOTE: overflow detected, " + str(e))
        self.zero_grad()
        logging_output = None
    except RuntimeError as e:
        if "out of memory" in str(e):
            self._log_oom(e)
            logger.error("OOM during optimization, irrecoverable")
            raise e

    if self.args.fp16:
        metrics.log_scalar(
            "loss_scale", self.optimizer.scaler.loss_scale, priority=700, round=0
        )

    metrics.log_stop_time("train_wall")
    return logging_output
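
The maybe_no_sync helper above implements the standard DDP gradient-accumulation pattern: backward passes for every micro-batch except the last run under model.no_sync(), so gradients accumulate locally, and only the final backward triggers the all-reduce. A minimal standalone sketch of that pattern (hypothetical model/criterion/data names, not part of fairseq):

import contextlib

def accumulate_and_step(ddp_model, criterion, optimizer, micro_batches):
    """Accumulate gradients over micro_batches, syncing only on the last one.

    Assumes ddp_model is wrapped in torch.nn.parallel.DistributedDataParallel,
    which provides .no_sync(); plain modules fall back to a no-op context,
    mirroring the hasattr check in train_step above.
    """
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(micro_batches):
        is_last = i == len(micro_batches) - 1
        if hasattr(ddp_model, "no_sync") and not is_last:
            ctx = ddp_model.no_sync()     # suppress the all-reduce for this backward
        else:
            ctx = contextlib.ExitStack()  # last micro-batch: let DDP sync gradients
        with ctx:
            loss = criterion(ddp_model(inputs), targets)
            loss.backward()  # gradients keep summing into .grad until the synced pass
    optimizer.step()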