in XLM/src/trainer.py [0:0]
def optimize(self, loss):
    """
    Backpropagate `loss` and update the model parameters
    (with AMP, the update is delayed while accumulating gradients).
    """
    # check NaN: (loss != loss) is True only for NaN entries, since NaN
    # compares unequal to itself; training continues after the warning
    if (loss != loss).data.any():
        logger.warning("NaN detected")
        # exit()
    params = self.params

    # optimizers
    names = self.optimizers.keys()
    optimizers = [self.optimizers[k] for k in names]
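
    # The method follows one of two paths:
    # - params.amp == -1: plain full-precision training; every call zeroes the
    #   gradients, backpropagates, optionally clips, and steps each optimizer.
    # - AMP (apex): the backward pass goes through apex.amp.scale_loss, and the
    #   optimizers only step (then zero the gradients) once every
    #   params.accumulate_gradients iterations; intermediate calls only
    #   accumulate scaled gradients.
    # (`clip_grad_norm_` is torch.nn.utils.clip_grad_norm_; `apex` is NVIDIA's
    # mixed-precision library.)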

    # regular optimization
    if params.amp == -1:
        for optimizer in optimizers:
            optimizer.zero_grad()
        loss.backward()
        if params.clip_grad_norm > 0:
            for name in names:
                # norm_check_a = (sum([p.grad.norm(p=2).item() ** 2 for p in self.parameters[name]])) ** 0.5
                clip_grad_norm_(self.parameters[name], params.clip_grad_norm)
                # norm_check_b = (sum([p.grad.norm(p=2).item() ** 2 for p in self.parameters[name]])) ** 0.5
                # print(name, norm_check_a, norm_check_b)
        for optimizer in optimizers:
            optimizer.step()

    # AMP optimization (apex)
    else:
        if self.n_iter % params.accumulate_gradients == 0:
            # accumulation boundary: unscale the gradients, clip, step, then reset
            with apex.amp.scale_loss(loss, optimizers) as scaled_loss:
                scaled_loss.backward()
            if params.clip_grad_norm > 0:
                for name in names:
                    # norm_check_a = (sum([p.grad.norm(p=2).item() ** 2 for p in apex.amp.master_params(self.optimizers[name])])) ** 0.5
                    clip_grad_norm_(apex.amp.master_params(self.optimizers[name]), params.clip_grad_norm)
                    # norm_check_b = (sum([p.grad.norm(p=2).item() ** 2 for p in apex.amp.master_params(self.optimizers[name])])) ** 0.5
                    # print(name, norm_check_a, norm_check_b)
            for optimizer in optimizers:
                optimizer.step()
                optimizer.zero_grad()
        else:
            # intermediate accumulation step: backpropagate only, keeping the
            # gradients scaled (delay_unscale=True) and skipping the optimizer step
            with apex.amp.scale_loss(loss, optimizers, delay_unscale=True) as scaled_loss:
                scaled_loss.backward()
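

# Illustration only (not part of the original file): a minimal sketch of how
# optimize() fits into a training loop. `trainer`, `batches`, and `compute_loss`
# are hypothetical placeholders, not XLM APIs.
def _training_loop_sketch(trainer, batches, compute_loss):
    for batch in batches:
        loss = compute_loss(batch)  # forward pass producing a scalar loss tensor
        trainer.optimize(loss)      # backward pass; with AMP, the optimizers only
                                    # step every params.accumulate_gradients calls
        trainer.n_iter += 1         # advance the counter checked above
                                    # (hypothetical; the real trainer maintains
                                    # its own iteration counter)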