in sparse_autoencoder/kernels.py [0:0]
def forward(ctx, output, target):
    """Launch the Triton MSE kernel and return its per-index result buffer.

    Stashes ``output`` and ``target`` on ``ctx`` for the backward pass,
    allocates a zero-filled float32 tensor of length ``a`` on the same
    device as ``output``, and runs ``triton_mse_loss_fp16_kernel`` over a
    1-D grid of size ``a`` that writes into that tensor.

    NOTE(review): ``a``, ``b`` and ``BLOCK_SIZE_B`` are not defined in this
    function; they are resolved from an enclosing/module scope that is not
    visible here. Presumably ``a, b`` are the two dimensions of ``output``
    -- confirm against the surrounding code.
    NOTE(review): the ``ctx`` / ``save_for_backward`` usage suggests this is
    the ``forward`` of a ``torch.autograd.Function``; the class header and
    any decorator are outside this view.

    Args:
        ctx: Context object used to save tensors for the backward pass.
        output: Tensor passed as the kernel's first operand; its dim-0
            stride is forwarded to the kernel.
        target: Tensor passed as the kernel's second operand; its dim-0
            stride is forwarded to the kernel.

    Returns:
        The float32 tensor of length ``a`` that the kernel wrote into.
    """
    # Keep both inputs around so the backward pass can read them.
    ctx.save_for_backward(output, target)

    # Result buffer: always float32, regardless of the inputs' dtype.
    result = torch.zeros(a, dtype=torch.float32, device=output.device)

    # 1-D launch grid of size ``a``.
    grid = (a,)
    output_stride_a = output.stride(0)
    target_stride_a = target.stride(0)

    triton_mse_loss_fp16_kernel[grid](
        output,
        target,
        result,
        stride_a_output=output_stride_a,
        stride_a_target=target_stride_a,
        a=a,
        b=b,
        BLOCK_SIZE_B=BLOCK_SIZE_B,
    )
    return result