in sparse_autoencoder/train.py [0:0]
def forward(self, x):
    """Encode `x` into a sparse top-k latent code and decode a reconstruction.

    Returns `(recons, info)` where `info` carries the auxiliary-k indices and
    values (`None` when `self.auxk` is unset), used by the caller for the
    aux loss. NOTE(review): assumes `x` is a 2-D batch of activations of
    width matching `self.pre_bias` — confirm against the caller.
    """

    class EncWrapper(torch.autograd.Function):
        """Custom autograd wrapper for the encoder's top-k selection.

        Closes over `self` from the enclosing method: reads `self.k`,
        `self.comms`, `self.stats_last_nonzero`, `self.auxk`,
        `self.auxk_mask_fn`, and `self.n_dirs_local`. The hand-written
        backward computes gradients only for the k (and auxk) selected
        latents, which keeps the encoder backward sparse.
        """

        @staticmethod
        def forward(ctx, x, pre_bias, weight, latent_bias):
            # Center the input before the linear encoder.
            x = x - pre_bias
            # Pre-activation latents: x @ weight.T + latent_bias.
            latents_pre_act = F.linear(x, weight, latent_bias)
            # Top-k across the sharded latent dimension (project helper;
            # presumably gathers candidates across `sh_comm` ranks — verify).
            inds, vals = sharded_topk(
                latents_pre_act,
                k=self.k,
                sh_comm=self.comms.sh_comm,
                capacity_factor=4,
            )
            ## set num nonzero stat ##
            # Count, per latent direction, whether it fired this step
            # (selected by top-k AND value above the 1e-3 threshold).
            tmp = torch.zeros_like(self.stats_last_nonzero)
            tmp.scatter_add_(
                0,
                inds.reshape(-1),
                (vals > 1e-3).to(tmp.dtype).reshape(-1),
            )
            # stats_last_nonzero tracks steps since each latent last fired:
            # zero it where the latent fired (clamp collapses counts to 0/1),
            # then increment everything by one step.
            self.stats_last_nonzero *= 1 - tmp.clamp(max=1)
            self.stats_last_nonzero += 1
            ## end stats ##
            ## auxk
            if self.auxk is not None:  # for auxk
                # IMPORTANT: has to go after stats update!
                # WARN: auxk_mask_fn can mutate latents_pre_act!
                # Secondary top-k over the masked latents (presumably the
                # dead/rarely-firing ones — see auxk_mask_fn) for the aux loss.
                auxk_inds, auxk_vals = sharded_topk(
                    self.auxk_mask_fn(latents_pre_act),
                    k=self.auxk,
                    sh_comm=self.comms.sh_comm,
                    capacity_factor=2,
                )
                # Saved-tensor tuple shape differs between branches; backward
                # unpacks it based on the same `self.auxk` condition.
                ctx.save_for_backward(x, weight, inds, auxk_inds)
            else:
                ctx.save_for_backward(x, weight, inds)
                auxk_inds = None
                auxk_vals = None
            ## end auxk
            return (
                inds,
                vals,
                auxk_inds,
                auxk_vals,
            )

        @staticmethod
        def backward(ctx, _, grad_vals, __, grad_auxk_vals):
            # encoder backwards
            # No gradient flows through the integer index outputs
            # (positional args `_` and `__`); only the values carry grads.
            if self.auxk is not None:
                x, weight, inds, auxk_inds = ctx.saved_tensors
                # Treat main-k and auxk selections as one combined sparse
                # gradient over the latent dimension.
                all_inds = torch.cat((inds, auxk_inds), dim=-1)
                all_grad_vals = torch.cat((grad_vals, grad_auxk_vals), dim=-1)
            else:
                x, weight, inds = ctx.saved_tensors
                all_inds = inds
                all_grad_vals = grad_vals
            # Per-direction gradient sum, accumulated in float32 for accuracy.
            grad_sum = torch.zeros(self.n_dirs_local, dtype=torch.float32, device=grad_vals.device)
            grad_sum.scatter_add_(
                -1, all_inds.flatten(), all_grad_vals.flatten().to(torch.float32)
            )
            # Gradients w.r.t. forward inputs (x, pre_bias, weight, latent_bias):
            return (
                None,
                # pre_bias grad optimization - can reduce before mat-vec multiply
                # d/d(pre_bias) of F.linear(x - pre_bias, W, b) summed over the
                # batch is -(sum of latent grads) @ W.
                -(grad_sum @ weight),
                # Sparse W grad: outer products of selected latent grads with x
                # (project Triton kernel; assumes it scatters into an
                # (n_dirs_local, d) dense grad — TODO confirm kernel contract).
                triton_sparse_transpose_dense_matmul(all_inds, all_grad_vals, x, N=self.n_dirs_local),
                # latent_bias grad is the per-direction gradient sum.
                grad_sum,
            )

    # Allreduce pre_bias gradients across the shard group in backward
    # (project helper; forward presumably passes the value through — verify).
    pre_bias = self.comms.sh_allreduce_backward(self.pre_bias)
    # encoder
    inds, vals, auxk_inds, auxk_vals = EncWrapper.apply(
        x, pre_bias, self.encoder.weight, self.latent_bias
    )
    # ReLU applied OUTSIDE the custom Function so autograd handles its
    # gradient; EncWrapper.backward sees post-ReLU grads unchanged.
    vals = torch.relu(vals)
    if auxk_vals is not None:
        auxk_vals = torch.relu(auxk_vals)
    recons = self.decode_sparse(inds, vals)
    return recons, {
        "auxk_inds": auxk_inds,
        "auxk_vals": auxk_vals,
    }