sparse_autoencoder/kernels.py [421:466]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def triton_add_mul_(
    x: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    c: float,
):
    """
    In-place fused update ``x += a * b * c`` via ``triton_add_mul_kernel``.

    x : [m, n], updated in place (the function returns None).
    a : [m, n], or [n] to reuse the same row for all m rows of x.
    b : [m, n], or [n] to reuse the same row for all m rows of x.
    c : float scalar multiplier.

    A 1-D ``a``/``b`` is expanded to ``x.shape`` with ``broadcast_to``, which
    yields a view with row stride 0 (no copy). Each tensor's strides are passed
    to the kernel explicitly (presumably so non-contiguous views are handled).
    NOTE(review): device/dtype requirements are defined by the kernel, which is
    not visible here — presumably all tensors must be on the same GPU.

    Raises ValueError if ``x`` is not 2-D or if the (broadcast) shapes of
    ``a`` and ``b`` do not equal ``x.shape``. A 1-D ``a``/``b`` whose length
    differs from ``x.shape[1]`` fails inside ``broadcast_to`` (RuntimeError).
    """
    # The kernel is only given two strides and two extents per tensor, so any
    # other rank would be silently mis-indexed rather than rejected.
    if x.dim() != 2:
        raise ValueError(f"x must be 2-D, got shape {tuple(x.shape)}")

    if a.dim() == 1:
        a = a[None, :].broadcast_to(x.shape)
    if b.dim() == 1:
        b = b[None, :].broadcast_to(x.shape)

    # Explicit check instead of `assert` so it survives `python -O`.
    if not (x.shape == a.shape == b.shape):
        raise ValueError(
            "shape mismatch: x "
            f"{tuple(x.shape)}, a {tuple(a.shape)}, b {tuple(b.shape)}"
        )

    m, n = x.shape
    BLOCK_SIZE_M = 64
    BLOCK_SIZE_N = 64
    # One program per BLOCK_SIZE_M x BLOCK_SIZE_N tile; ceil-division ensures
    # partial tiles at the bottom/right edges are still launched.
    grid = (triton.cdiv(m, BLOCK_SIZE_M), triton.cdiv(n, BLOCK_SIZE_N))
    # Positional order must match triton_add_mul_kernel's signature.
    triton_add_mul_kernel[grid](
        x,
        a,
        b,
        c,
        x.stride(0),
        x.stride(1),
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        BLOCK_SIZE_M,
        BLOCK_SIZE_N,
        m,
        n,
    )
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



