in classy_vision/tasks/classification_task.py [0:0]
def run_optimizer(self, loss):
"""Runs backwards pass and update the optimizer"""
    self.check_inf_nan(loss)

    # Gradient accumulation logic. We always set optimizer_period, even
    # if gradient accumulation is disabled. Assumes all batches have the
    # same size
    update_idx = self.num_updates // self.get_global_batchsize()
    do_zero_grad = (update_idx % self.optimizer_period) == 0
    do_step = self._should_do_step()

    if do_zero_grad:
        self.optimizer.zero_grad()

    if self.amp_type == AmpType.APEX:
        with apex.amp.scale_loss(loss, self.optimizer.optimizer) as scaled_loss:
            scaled_loss.backward()
    elif self.amp_type == AmpType.PYTORCH:
        self.amp_grad_scaler.scale(loss).backward()
    else:
        loss.backward()

    if do_step:
        # Handle gradient accumulation related gradient rescaling
        if self.optimizer_period != 1:
            self._rescale_gradients(1 / self.optimizer_period)

        # Clipping must happen after grad accumulation
        if self.clip_grad_norm is not None:
            self._clip_gradients(self.clip_grad_norm)

        if self.amp_type == AmpType.PYTORCH:
            # If using mixed precision, handle underflow-related scaling
            # See https://pytorch.org/docs/stable/amp.html#gradient-scaling
            # for context
            self.amp_grad_scaler.step(self.optimizer, where=self.where)
            self.amp_grad_scaler.update()
        else:
            self.optimizer.step(where=self.where)
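
# ---------------------------------------------------------------------------
# Minimal sketch (not part of classification_task.py): the same
# accumulate-then-step pattern written as a plain PyTorch loop, without
# Classy Vision's task/optimizer wrappers. The names `model`, `loader`,
# `accum_period`, and `max_norm` are illustrative assumptions, and a CUDA
# device is assumed so that GradScaler actually scales.
# ---------------------------------------------------------------------------
import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

accum_period = 4  # micro-batches accumulated per optimizer step
max_norm = 1.0    # gradient clipping threshold

# Dummy data standing in for a real DataLoader.
loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(8)]

for step, (inputs, targets) in enumerate(loader):
    # Zero gradients only at the start of each accumulation window,
    # mirroring the do_zero_grad check above.
    if step % accum_period == 0:
        optimizer.zero_grad()

    with torch.cuda.amp.autocast():
        loss = F.cross_entropy(model(inputs.cuda()), targets.cuda())

    # Dividing the loss by the period is equivalent to rescaling the
    # summed gradients by 1 / optimizer_period after they are accumulated.
    scaler.scale(loss / accum_period).backward()

    # Step only once a full accumulation window has been processed,
    # mirroring the do_step check above.
    if (step + 1) % accum_period == 0:
        # Unscale before clipping so the norm is measured on real gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        scaler.step(optimizer)  # skipped internally if inf/NaN grads are found
        scaler.update()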