def run_optimizer()

in classy_vision/tasks/classification_task.py


    def run_optimizer(self, loss):
        """Runs backwards pass and update the optimizer"""

        self.check_inf_nan(loss)

        # Gradient accumulation logic. We always set optimizer_period, even
        # if gradient accumulation is disabled. Assumes all batches have the
        # same size.
        update_idx = self.num_updates // self.get_global_batchsize()
        do_zero_grad = (update_idx % self.optimizer_period) == 0
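        # do_step is expected to be True only on the batch where the optimizer
        # step happens, i.e. the last batch of each accumulation window when
        # optimizer_period > 1 (see _should_do_step).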
        do_step = self._should_do_step()

        if do_zero_grad:
            self.optimizer.zero_grad()

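        # Backward pass. With mixed precision, the loss is scaled first so that
        # small gradients do not underflow in reduced precision.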
        if self.amp_type == AmpType.APEX:
            with apex.amp.scale_loss(loss, self.optimizer.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.amp_type == AmpType.PYTORCH:
            self.amp_grad_scaler.scale(loss).backward()
        else:
            loss.backward()

        if do_step:
            # With gradient accumulation, the gradients of optimizer_period
            # batches have been summed; rescale them to an average before
            # stepping.
            if self.optimizer_period != 1:
                self._rescale_gradients(1 / self.optimizer_period)

            # Clipping must happen after grad accumulation
            if self.clip_grad_norm is not None:
                self._clip_gradients(self.clip_grad_norm)

            if self.amp_type == AmpType.PYTORCH:
                # If using mixed precision, handle underflow-related scaling
                # See https://pytorch.org/docs/stable/amp.html#gradient-scaling
                # for context
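                # GradScaler.step() skips the optimizer step when inf/NaN
                # gradients are detected, and update() adjusts the loss scale
                # for the next iteration.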
                self.amp_grad_scaler.step(self.optimizer, where=self.where)
                self.amp_grad_scaler.update()
            else:
                self.optimizer.step(where=self.where)
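
For reference, a minimal standalone sketch of the same pattern (gradient accumulation combined with PyTorch AMP gradient scaling) follows. The toy model, the ACCUM_STEPS constant and the synthetic data are illustrative assumptions rather than Classy Vision APIs, and the sketch assumes a CUDA device is available.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    ACCUM_STEPS = 4  # plays the role of optimizer_period
    model = nn.Linear(16, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()

    for step in range(100):
        if step % ACCUM_STEPS == 0:
            optimizer.zero_grad()

        x = torch.randn(8, 16, device="cuda")
        y = torch.randint(0, 2, (8,), device="cuda")

        with torch.cuda.amp.autocast():
            loss = F.cross_entropy(model(x), y)

        # Scale the loss so small fp16 gradients do not underflow, then
        # accumulate gradients across micro-batches.
        scaler.scale(loss).backward()

        if step % ACCUM_STEPS == ACCUM_STEPS - 1:
            # Unscale once, then average the accumulated gradients
            # (mirrors _rescale_gradients(1 / optimizer_period)).
            scaler.unscale_(optimizer)
            for p in model.parameters():
                if p.grad is not None:
                    p.grad.div_(ACCUM_STEPS)
            # Clip only after accumulation is complete.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            # step() skips the update if inf/NaN gradients were detected;
            # update() adjusts the loss scale for the next iteration.
            scaler.step(optimizer)
            scaler.update()

Dividing the gradients rather than the loss matches the rescaling order used in run_optimizer above; either placement yields an averaged update.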