def step()

in apex/apex/optimizers/fp16_optimizer.py


    def step(self, closure=None):
        """
        Not supporting closure.
        """
        # First compute the gradient norm for every group so we know whether any group overflowed
        grads_groups_flat = []
        norm_groups = []
        skip = False
        for i, group in enumerate(self.fp16_groups):
            grads_groups_flat.append(_flatten_dense_tensors([p.grad for p in group]))
            norm_groups.append(self._compute_grad_norm(grads_groups_flat[i]))
            if norm_groups[i] == -1:  # -1 is the sentinel for an inf/NaN (overflow) in this group
                skip = True  # TODO: break early once an overflow is found

        if skip:
            # Overflow detected: adjust the dynamic loss scale and skip the update
            self._update_scale(skip)
            return

        # The norms were computed on scaled gradients, so each value is in fact
        # norm * cur_scale; passing `scale` lets the fused optimizer unscale internally
        self.optimizer.step(grads=[[g] for g in grads_groups_flat],
                            output_params=[[p] for p in self.fp16_groups_flat],
                            scale=self.cur_scale,
                            grad_norms=norm_groups)

        # Propagate the updated flat buffers back to the individual fp16 parameter tensors
        # TODO: we probably don't need this? just to be safe
        for i in range(len(norm_groups)):
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i], self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data

        # Successful step: notify the dynamic loss scaler (it may grow the scale)
        self._update_scale(False)
        return
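
The skip logic above hinges on `_compute_grad_norm` returning -1 when a group's flattened fp16 gradients contain an inf/NaN. That helper's body is not shown here; the following is a minimal sketch of the sentinel convention it would have to implement, operating on one flat gradient tensor as produced by `torch._utils._flatten_dense_tensors`. It is a reconstruction for illustration, not apex's actual implementation.

    import torch

    def _compute_grad_norm(flat_grad):
        # L2 norm of one flattened gradient group; a non-finite result
        # (from an fp16 overflow) is reported with the -1 sentinel that
        # step() checks for.
        norm = float(flat_grad.norm(2))
        if norm == float('inf') or norm != norm:  # norm != norm catches NaN
            return -1
        return norm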
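
For context, here is a minimal training-loop sketch of how this wrapper's `step` is typically driven. It assumes the companion `backward` and `zero_grad` methods and the `dynamic_loss_scale` constructor flag behave as in older apex releases; verify against your installed version.

    import torch
    from apex.optimizers import FusedAdam, FP16_Optimizer

    model = torch.nn.Linear(1024, 1024).cuda().half()
    optimizer = FP16_Optimizer(FusedAdam(model.parameters(), lr=1e-3),
                               dynamic_loss_scale=True)

    for _ in range(10):
        x = torch.randn(32, 1024, device='cuda', dtype=torch.half)
        optimizer.zero_grad()
        loss = model(x).float().pow(2).mean()
        optimizer.backward(loss)  # scales the loss by cur_scale before backward
        optimizer.step()          # silently skips the update on overflow

Note that on overflow `step` returns without raising, so a caller that needs to know whether an update actually happened has to inspect the wrapper's loss-scaling state rather than the return value.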