in sparse_autoencoder/model.py [0:0]
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Run preprocess -> encode -> activation -> decode and update latent statistics.

    :param x: input data (shape: [batch, n_inputs]). The statistics update below
        also handles extra leading dimensions ([..., n_inputs]) and a single
        unbatched example ([n_inputs]), provided preprocess/encode/decode
        support those shapes.
    :return: autoencoder latents pre activation (shape: [batch, n_latents])
             autoencoder latents (shape: [batch, n_latents])
             reconstructed data (shape: [batch, n_inputs])

    Side effect: ``self.stats_last_nonzero`` is modified in place. For each
    latent, it is reset when that latent is nonzero for any example in this
    call, and then every entry is incremented by one. A latent that fired in
    this call therefore ends at 1; a latent that did not fire grows by 1.
    NOTE(review): assumes ``stats_last_nonzero`` is a tensor of shape
    [n_latents] (presumably an integer buffer) — confirm against ``__init__``.
    """
    x, info = self.preprocess(x)
    latents_pre_act = self.encode_pre_act(x)
    latents = self.activation(latents_pre_act)
    recons = self.decode(latents, info)

    # Per-latent mask: True where the latent was zero for every example.
    # Reduce over *all* leading dims (not just dim 0) so the mask is always
    # [n_latents]. A bare `.all(dim=0)` collapses a 1-D input to a scalar and
    # leaves a 3-D input with a mask that cannot broadcast into the buffer.
    never_fired = latents == 0
    if never_fired.ndim > 1:
        never_fired = never_fired.flatten(0, -2).all(dim=0)

    # Zero the counter of every latent that fired, then age all counters by one.
    self.stats_last_nonzero *= never_fired.long()
    self.stats_last_nonzero += 1

    return latents_pre_act, latents, recons