inline __device__ void load_()

in kernels/fmha/utils.h [695:758]


inline __device__ void load_(Functor &fct, const uint32_t (&preds)[M]) {

    // Issue N predicated loads through the functor `fct`.
    //
    // The predicates are packed in `preds`: each 32-bit register holds BYTES_PER_REG bytes and
    // the low PREDS_PER_BYTE bits of each byte hold one predicate each. The predicate of load
    // `k` is bit `k % PREDS_PER_BYTE` of byte `k / PREDS_PER_BYTE`.
    //
    // `fct.clear(i)` is called for every i in [0, N) before any `fct.load(i, pred)` is issued,
    // and `fct.load` is called exactly once per index, in increasing index order.

    // The number of complete bytes (where we use all the predicates in a byte).
    enum { COMPLETE = N / PREDS_PER_BYTE };
    // The remainder (the number of predicates used in the last, partial byte).
    enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE };
    // Make sure we got the math right and the remainder is between 0 and 3.
    static_assert(REMAINDER >= 0 && REMAINDER <= 3, "");
    // The number of predicate bytes we read: the complete ones plus the partial one, if any.
    enum { BYTES_USED = COMPLETE + (REMAINDER > 0 ? 1 : 0) };
    // Make sure we did allocate enough predicates. The partial byte must be counted as well:
    // when COMPLETE is a multiple of BYTES_PER_REG, the remainder lives in the next register.
    static_assert(Div_up<BYTES_USED, BYTES_PER_REG>::VALUE <= M,
                  "Not enough predicate registers for N loads");

    // Clear the fetch registers.
    #pragma unroll
    for( int ii = 0; ii < N; ++ii ) {
        fct.clear(ii);
    }

    // The predicates extracted from the byte being processed.
    bool p[PREDS_PER_BYTE];

    // Run complete steps.
    #pragma unroll
    for( int ii = 0; ii < COMPLETE; ++ii ) {

        // The register that contains the byte for this step.
        uint32_t reg = preds[ii / BYTES_PER_REG];

        // Extract the predicates (all of them, before any load of this step is issued).
        #pragma unroll
        for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
            uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj);
            p[jj] = (reg & mask) != 0u;
        }

        // Issue the loads.
        #pragma unroll
        for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
            fct.load(ii * PREDS_PER_BYTE + jj, p[jj]);
        }
    }

    // Skip the rest of the code if we do not have a remainder.
    if( REMAINDER > 0 ) {

        // The register that contains the partial byte.
        uint32_t reg = preds[COMPLETE / BYTES_PER_REG];

        // Extract the predicates. Only the first REMAINDER bits of the byte are used.
        #pragma unroll
        for( int jj = 0; jj < REMAINDER; ++jj ) {
            uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj);
            p[jj] = (reg & mask) != 0u;
        }

        // Issue the loads.
        #pragma unroll
        for( int ii = 0; ii < REMAINDER; ++ii ) {
            fct.load(COMPLETE * PREDS_PER_BYTE + ii, p[ii]);
        }
    }
}