def forward()

in xformers/triton/k_layer_norm.py

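Forward pass of the fused Triton layer normalization, written as the forward() of a torch.autograd.Function. The input is viewed as a 2D (M, N) tensor, one Triton program is launched per row, and the per-row statistics are kept around so the backward pass can reuse them.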

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        # allocate output
        y = torch.empty_like(x)

        # reshape input data into 2D tensor
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape

        # allocate the per-row mean and reciprocal standard deviation;
        # the kernel fills them in, and rstd is saved for the backward pass
        mean = torch.empty((M,), dtype=torch.float32, device="cuda")
        rstd = torch.empty((M,), dtype=torch.float32, device="cuda")
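        # the statistics are kept in fp32, whatever the input dtype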

        # Less than 64KB per feature: enqueue fused kernel
        MAX_FUSED_SIZE = 65536 // x.element_size()
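        # (e.g. 32768 elements for fp16 inputs, 16384 for fp32)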
        BLOCK_SIZE_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE_N:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

        if not x_arg.is_contiguous() or not y.is_contiguous():
            global _triton_registered_warnings
            if not _triton_registered_warnings:
                logging.warning("Non-contiguous input tensor found. Making it contiguous,"
                                + " but this could have performance implications")

                _triton_registered_warnings = True

            x_arg = x_arg.contiguous()
            y = y.contiguous()

        # heuristic for the number of warps: scale with the block size,
        # clamped to the [1, 8] range
        num_warps = min(max(BLOCK_SIZE_N // 256, 1), 8)

        # enqueue the kernel on a 1D grid: one program instance per row
        # fmt: off
        if weight is None:
            _layer_norm_non_affine_fw[(M,)](
                x_arg, y, mean, rstd,
                x_arg.stride(0),
                N,
                eps,
                num_warps=num_warps,
                BLOCK_SIZE_N=BLOCK_SIZE_N
            )
        else:
            _layer_norm_fw[(M,)](
                x_arg, y, weight, bias, mean, rstd,
                x_arg.stride(0),
                N,
                eps,
                num_warps=num_warps,
                BLOCK_SIZE_N=BLOCK_SIZE_N
            )
        # fmt: on

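        # the normalized output (not the input) is saved: the backward pass
        # can be expressed in terms of y, rstd and the affine parameters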
        ctx.save_for_backward(y, rstd, weight, bias)
        ctx.BLOCK_SIZE_N = BLOCK_SIZE_N
        ctx.num_warps = num_warps
        ctx.eps = eps
        ctx.N = N

        return y.reshape_as(x)
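
A minimal usage sketch, assuming this forward() is the staticmethod of a
torch.autograd.Function subclass (called _LayerNorm below purely for
illustration; the wrapper name is an assumption, not taken from this file):

    import torch

    # hypothetical autograd.Function wrapping the forward() above
    x = torch.randn(16, 128, 768, device="cuda", dtype=torch.float16, requires_grad=True)
    weight = torch.ones(768, device="cuda", dtype=torch.float16, requires_grad=True)
    bias = torch.zeros(768, device="cuda", dtype=torch.float16, requires_grad=True)

    y = _LayerNorm.apply(x, weight, bias, 1e-5)  # same shape as x
    y.sum().backward()                           # exercises the matching backward()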