def get_block_n_padding_for_smem_d()

in deep_gemm/jit_kernels/gemm.py [0:0]


def get_block_n_padding_for_smem_d(block_n: int) -> int:
    """Return the padding, in elements, to add to a row of ``block_n`` elements.

    The row length is first converted to 4-byte words
    (``block_n * elem_size // 4``). The padding is the smallest non-negative
    number of words that makes the padded word count congruent to
    ``requirement[0]`` modulo ``requirement[1]`` (i.e. 4 mod 8). That word
    count is then converted back to elements.

    NOTE(review): ``elem_size = 2`` presumably means 2-byte output elements
    (e.g. BF16) and the literal 4 presumably is the shared-memory bank width
    in bytes -- confirm against the kernel that consumes this value.

    Args:
        block_n: Row length in elements.

    Returns:
        The number of padding elements; 0 when the row already satisfies
        the residue requirement.
    """
    # NOTES: padding is for solving bank conflicts, but wastes shared memory space
    elem_size, requirement = 2, (4, 8)
    target_residue, modulus = requirement
    word_size = 4

    # Row length expressed in 4-byte words.
    bank_stride = (block_n * elem_size) // word_size

    # Python's `%` takes the sign of the (positive) modulus, so the result is
    # already in [0, modulus); no negative fix-up branch is required.
    padding = (target_residue - bank_stride) % modulus

    # Convert the padding from words back to elements.
    return (padding * word_size) // elem_size