def forward()

in sparse_autoencoder/kernels.py [0:0]


        def forward(ctx, output, target):
            """Launch the Triton MSE kernel on ``output``/``target`` and return its buffer.

            Both inputs are saved on the autograd context ``ctx`` so the
            backward pass can read them back. A float32 buffer of length ``a``
            is zero-initialised on ``output.device`` and handed to the kernel,
            which is launched on a 1-D grid of ``a`` programs. That buffer is
            then returned.

            NOTE(review): ``a``, ``b`` and ``BLOCK_SIZE_B`` are free variables
            from the enclosing scope (not visible here). Presumably ``a`` and
            ``b`` are the two dimensions of ``output`` and each program reduces
            one row of length ``b`` into one slot of the buffer. The "fp16" in
            the kernel name suggests half-precision inputs. TODO: confirm both
            points against the caller and the kernel definition.

            Returns:
                A float32 tensor of length ``a`` on ``output.device``.
            """
            # Keep the raw inputs around for the backward pass.
            ctx.save_for_backward(output, target)

            # Zeroed float32 destination, one slot per program in the grid.
            per_row_loss = torch.zeros(a, dtype=torch.float32, device=output.device)

            # 1-D launch grid with one program per leading-dim index.
            launch_grid = (a,)

            # stride(0) is the element step between consecutive leading-dim
            # indices, passed separately because the two inputs may differ.
            launch_options = dict(
                stride_a_output=output.stride(0),
                stride_a_target=target.stride(0),
                a=a,
                b=b,
                BLOCK_SIZE_B=BLOCK_SIZE_B,
            )
            triton_mse_loss_fp16_kernel[launch_grid](
                output, target, per_row_loss, **launch_options
            )

            return per_row_loss