# utils/trainer.py
# NOTE: assumes the usual module-level imports at the top of utils/trainer.py,
# e.g. `import contextlib`, `import logging`, `import numpy as np`, `import torch`.
def train_step(self, samples):
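    # Accumulate gradients locally over every mini-batch in *samples*
    # and apply a single optimizer update at the end.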
    self.optimizer.zero_grad()
    logging_outputs = []
    for i, sample in enumerate(samples):
        sample = self.prepare_sample(sample)
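        # maybe_no_sync is redefined on each iteration so that it closes
        # over the current mini-batch index *i*.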
        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.args.world_size > 1
                and hasattr(self.model, 'no_sync')
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager
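        # On a single worker (or when the model has no `no_sync` method),
        # the ExitStack() above turns the `with` block into a no-op wrapper.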
        with maybe_no_sync():
            # forward and backward
            sample_size = int(sample['sample_size'])
            total_loss, logging_output = self.model(**sample)
            avg_loss = total_loss / sample_size
            self.optimizer.backward(avg_loss)
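            # The loss above is normalized by the per-batch sample size so that
            # gradients from differently sized mini-batches are on a comparable scale.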
            # logging_output = {
            #     'sample_size': sample['sample_size'],
            #     'total_loss': total_loss.item()
            # }
        logging_outputs.append(logging_output)

    # gather logging outputs from all replicas
    # if self.args.world_size > 1:
    #     logging_outputs = distributed_utils.all_gather_list(logging_outputs)
    #     logging_outputs = list(chain.from_iterable(logging_outputs))
    #     sample_size = sum(x['sample_size'] for x in logging_outputs)
    #     logging_output = {
    #         'sample_size': sum(x['sample_size'] for x in logging_outputs),
    #         'total_loss': sum(x['total_loss'] for x in logging_outputs),
    #     }
    #     avg_loss = logging_output['total_loss'] / logging_output['sample_size']
    #     avg_ppl = math.exp(avg_loss)
    #     logging_output.update({
    #         'avg_loss': avg_loss,
    #         'avg_ppl': avg_ppl
    #     })
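    # Average every logged metric over the accumulated mini-batches.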
    avg_logging_output = {
        k: np.average([x[k] for x in logging_outputs])
        for k in logging_outputs[0]
    }
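
    # Periodically release cached GPU memory; skipped entirely for CPU runs.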
    if (
        0 < self.args.empty_cache_freq <= self._num_updates
        and self._num_updates % self.args.empty_cache_freq == 0
        and not self.args.cpu
    ):
        torch.cuda.empty_cache()
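
    # Clip gradients and take the optimizer step; if an overflow is detected
    # (e.g. under mixed-precision training), clear the gradients and skip
    # this update.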
    try:
        # self.optimizer.multiply_grads(self.args.world_size / float(sample_size))
        # clip grads
        # if self.args.clip_norm > 0.:
        grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
        # logging.info(f'Iter {self._num_updates} Gradient Norm: {grad_norm}')
        # take an optimization step
        self.optimizer.step()
        self.take_one_step()
    except OverflowError as e:
        logging.error('| WARNING: overflow detected, ' + str(e))
        self.optimizer.zero_grad()

    return avg_logging_output