def use_ema_weights()

in sparse_autoencoder/train.py [0:0]


    def use_ema_weights(self):
        """Temporarily load bias-corrected EMA weights into ``self.model``.

        This is a generator that yields exactly once. It is meant to be driven
        as a context manager. NOTE(review): any ``@contextmanager`` decorator
        sits outside this view, so confirm it. While suspended at the
        ``yield``, two things hold:

        * autograd is disabled via ``torch.no_grad()``;
        * ``self.model`` has been handed the corrected EMA tensors through
          ``temporary_weight_swap``.

        ``temporary_weight_swap`` is defined elsewhere. It presumably restores
        the original weights on exit; verify this.

        Before the swap, every tensor in ``self.ema_weights`` is divided by
        ``1 - self.ema_multiplier ** self.ema_steps``. This is the usual
        correction for an EMA accumulated from a zero start. That zero start is
        an assumption about how ``self.ema_weights`` is initialised; confirm it
        against the update code. ``self.ema_weights`` is presumably a list of
        tensors, since it is passed to ``torch._foreach_div``.

        Raises:
            AssertionError: if ``self.ema_steps`` is not positive.
        """
        # With zero steps, ``m ** 0 == 1``, so the divisor below would be 0.
        # Refuse up front rather than divide by zero.
        assert self.ema_steps > 0

        decay_total = self.ema_multiplier**self.ema_steps
        corrected_weights = torch._foreach_div(self.ema_weights, 1 - decay_total)

        # A multi-item ``with`` is equivalent to nesting the two statements.
        # ``no_grad`` is entered first and exited last, so the swap and its
        # restore both happen with autograd disabled.
        with torch.no_grad(), temporary_weight_swap(self.model, corrected_weights):
            yield