def median_filter_cuda()

in whisper/triton_ops.py [0:0]


def median_filter_cuda(x: torch.Tensor, filter_width: int):
    """Apply a median filter of given width along the last dimension of x.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor. It needs at least two dimensions, because the stride of
        its second-to-last dimension is handed to the kernel.
        NOTE(review): presumably a CUDA tensor, since a Triton kernel is
        launched on it -- confirm against the caller.
    filter_width : int
        Size of the sliding window taken along the last dimension.

    Returns
    -------
    torch.Tensor
        A tensor shaped like ``x`` except that the last dimension has
        ``x.shape[-1] - filter_width + 1`` entries (one per full window; no
        padding is applied here).
    """
    # Materialise the contiguous layout once and use that same tensor for both
    # the window view and the kernel launch. Only a single stride is passed to
    # the kernel, so handing it a non-contiguous `x` while sizing the output
    # from a contiguous copy would make the two disagree. For inputs that are
    # already contiguous this returns `x` itself.
    x = x.contiguous()

    # Sliding windows: shape (..., n_windows, filter_width).
    slices = x.unfold(-1, filter_width, 1)

    # One kernel program per row, i.e. the product of all leading dimensions.
    # np.prod yields a NumPy scalar, so convert it to a plain int for the grid.
    grid = int(np.prod(slices.shape[:-2]))

    kernel = median_kernel(filter_width)

    # Output buffer with one element per window.
    y = torch.empty_like(slices[..., 0])

    # Smallest power of two that is >= the output row stride.
    BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
    kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)

    return y