def forward()

in sparse_autoencoder/train.py


    def forward(self, x):
        class EncWrapper(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, pre_bias, weight, latent_bias):
                x = x - pre_bias
                latents_pre_act = F.linear(x, weight, latent_bias)

                inds, vals = sharded_topk(
                    latents_pre_act,
                    k=self.k,
                    sh_comm=self.comms.sh_comm,
                    capacity_factor=4,
                )
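                # sharded_topk returns, per token, the indices and values of the k
                # largest pre-activations across all latent shards; capacity_factor
                # presumably bounds how many local candidates each shard contributes
                # before the global merge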

                ## set num nonzero stat ##
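                # stats_last_nonzero[i] counts steps since latent i last fired
                # (value > 1e-3): counters for latents active this step are zeroed,
                # then every counter is incremented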
                tmp = torch.zeros_like(self.stats_last_nonzero)
                tmp.scatter_add_(
                    0,
                    inds.reshape(-1),
                    (vals > 1e-3).to(tmp.dtype).reshape(-1),
                )
                self.stats_last_nonzero *= 1 - tmp.clamp(max=1)
                self.stats_last_nonzero += 1
                ## end stats ##

                ## auxk
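                # auxiliary top-k over latents that have not fired recently
                # (auxk_mask_fn is assumed to mask out recently-active latents via
                # stats_last_nonzero); the training loop can use these to build an
                # auxiliary loss that revives dead latents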
                if self.auxk is not None:  # for auxk
                    # IMPORTANT: has to go after stats update!
                    # WARN: auxk_mask_fn can mutate latents_pre_act!
                    auxk_inds, auxk_vals = sharded_topk(
                        self.auxk_mask_fn(latents_pre_act),
                        k=self.auxk,
                        sh_comm=self.comms.sh_comm,
                        capacity_factor=2,
                    )
                    ctx.save_for_backward(x, weight, inds, auxk_inds)
                else:
                    ctx.save_for_backward(x, weight, inds)
                    auxk_inds = None
                    auxk_vals = None

                ## end auxk

                return (
                    inds,
                    vals,
                    auxk_inds,
                    auxk_vals,
                )

            @staticmethod
            def backward(ctx, _, grad_vals, __, grad_auxk_vals):
                # encoder backwards
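                # inds / auxk_inds are integer outputs, so their incoming gradients
                # (the ignored `_` and `__` arguments) carry no information;
                # only grad_vals and grad_auxk_vals are used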
                if self.auxk is not None:
                    x, weight, inds, auxk_inds = ctx.saved_tensors

                    all_inds = torch.cat((inds, auxk_inds), dim=-1)
                    all_grad_vals = torch.cat((grad_vals, grad_auxk_vals), dim=-1)
                else:
                    x, weight, inds = ctx.saved_tensors

                    all_inds = inds
                    all_grad_vals = grad_vals

                grad_sum = torch.zeros(self.n_dirs_local, dtype=torch.float32, device=grad_vals.device)
                grad_sum.scatter_add_(
                    -1, all_inds.flatten(), all_grad_vals.flatten().to(torch.float32)
                )

                # gradients are returned in the order of the forward inputs
                # (x, pre_bias, weight, latent_bias); x gets None because the input
                # activations are not optimized
                return (
                    None,
                    # pre_bias grad optimization - can reduce before mat-vec multiply
                    -(grad_sum @ weight),
                    # weight grad via Triton kernel: (sparse latent grads)^T @ x
                    triton_sparse_transpose_dense_matmul(all_inds, all_grad_vals, x, N=self.n_dirs_local),
                    grad_sum,
                )

        pre_bias = self.comms.sh_allreduce_backward(self.pre_bias)
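        # identity in the forward pass; the name suggests the pre_bias gradient is
        # all-reduced across shards in the backward pass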

        # encoder
        inds, vals, auxk_inds, auxk_vals = EncWrapper.apply(
            x, pre_bias, self.encoder.weight, self.latent_bias
        )

        vals = torch.relu(vals)
        if auxk_vals is not None:
            auxk_vals = torch.relu(auxk_vals)

        recons = self.decode_sparse(inds, vals)
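        # decode_sparse reconstructs the input from the sparse (inds, vals) pairs
        # with the decoder weights (and is assumed to add pre_bias back)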

        return recons, {
            "auxk_inds": auxk_inds,
            "auxk_vals": auxk_vals,
        }
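
For orientation, here is a minimal single-device sketch of the computation that EncWrapper and decode_sparse implement together, with the sharding communicator, the custom autograd function, and the Triton kernels stripped out. The helper name dense_topk_sae_forward and the example shapes are illustrative, not part of the repository, and adding pre_bias back in the decoder is an assumption that mirrors the standard top-k sparse autoencoder formulation.

    import torch
    import torch.nn.functional as F

    def dense_topk_sae_forward(x, pre_bias, W_enc, latent_bias, W_dec, k):
        # encoder: subtract pre_bias, project up, keep the k largest pre-activations
        latents_pre_act = F.linear(x - pre_bias, W_enc, latent_bias)  # [batch, n_dirs]
        vals, inds = torch.topk(latents_pre_act, k=k, dim=-1)         # [batch, k]
        vals = torch.relu(vals)

        # decoder: scatter the k active values into a dense latent vector,
        # project back down, and add pre_bias back (assumed, as in decode_sparse)
        latents = torch.zeros_like(latents_pre_act).scatter_(-1, inds, vals)
        recons = latents @ W_dec + pre_bias                           # [batch, d_model]
        return recons, inds, vals

    # illustrative shapes: d_model=16, n_dirs=64, k=4
    x = torch.randn(8, 16)
    recons, inds, vals = dense_topk_sae_forward(
        x,
        pre_bias=torch.zeros(16),
        W_enc=torch.randn(64, 16) * 0.1,
        latent_bias=torch.zeros(64),
        W_dec=torch.randn(64, 16) * 0.1,
        k=4,
    )
    print(recons.shape, inds.shape, vals.shape)  # torch.Size([8, 16]) torch.Size([8, 4]) torch.Size([8, 4])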