def temporary_weight_swap()

in sparse_autoencoder/train.py [0:0]


def temporary_weight_swap(model: torch.nn.Module, new_weights: list[torch.Tensor]):
    for _p, new_p in zip(model.parameters(), new_weights, strict=True):
        assert _p.shape == new_p.shape
        _p.data, new_p.data = new_p.data, _p.data

    yield

    for _p, new_p in zip(model.parameters(), new_weights, strict=True):
        assert _p.shape == new_p.shape
        _p.data, new_p.data = new_p.data, _p.data