in sparse_autoencoder/train.py [0:0]
def use_ema_weights(self):
    """Yield while the model runs on bias-corrected EMA weights.

    This is a generator that yields exactly once; it is presumably wrapped
    by ``contextlib.contextmanager`` where it is defined (the decorator is
    not visible here -- confirm).

    Each tensor in ``self.ema_weights`` is divided by
    ``1 - self.ema_multiplier ** self.ema_steps`` before being handed to
    ``temporary_weight_swap`` together with ``self.model``. The body of the
    caller's ``with`` block executes with gradient tracking disabled.

    Raises:
        AssertionError: if ``self.ema_steps`` is not positive, i.e. no EMA
            update has been recorded and the correction denominator would
            be zero.
    """
    assert self.ema_steps > 0

    # Bias correction: after ``ema_steps`` updates the running average is
    # scaled by (1 - multiplier**steps); dividing by that factor undoes it.
    debias_denominator = 1 - self.ema_multiplier**self.ema_steps
    # Non-in-place foreach op: produces a new list of tensors rather than
    # writing into ``self.ema_weights``.
    corrected_weights = torch._foreach_div(self.ema_weights, debias_denominator)

    # NOTE(review): temporary_weight_swap is defined elsewhere; presumably it
    # restores the model's own weights on exit -- confirm against the helper.
    with torch.no_grad(), temporary_weight_swap(self.model, corrected_weights):
        yield