in sparse_autoencoder/train.py [0:0]
def temporary_weight_swap(model: torch.nn.Module, new_weights: list[torch.Tensor]):
for _p, new_p in zip(model.parameters(), new_weights, strict=True):
assert _p.shape == new_p.shape
_p.data, new_p.data = new_p.data, _p.data
yield
for _p, new_p in zip(model.parameters(), new_weights, strict=True):
assert _p.shape == new_p.shape
_p.data, new_p.data = new_p.data, _p.data