in kernels/fmha/utils.h [695:758]
inline __device__ void load_(Functor &fct, const uint32_t (&preds)[M]) {
    // Issue N predicated loads through `fct`.
    //
    // Predicates are decoded from the packed registers in `preds`:
    // - Each register holds BYTES_PER_REG bytes.
    // - Each byte holds PREDS_PER_BYTE predicates in its low bits.
    // - Predicate k (0 <= k < N) lives in byte b = k / PREDS_PER_BYTE.
    // - That byte is stored in preds[b / BYTES_PER_REG], at bit
    //   (b % BYTES_PER_REG) * 8 + k % PREDS_PER_BYTE.
    //
    // Every slot 0..N-1 is first reset with fct.clear(k). Then fct.load(k, pred_k) is
    // called for each k in increasing order.
    // NOTE(review): presumably Functor::load skips the memory access when pred_k is
    // false, leaving the cleared value in place -- confirm against the functors used.

    // The number of complete bytes (where we use all the predicates in a byte).
    enum { COMPLETE = N / PREDS_PER_BYTE };
    // The number of predicates left over in a trailing, partially-used byte.
    enum { REMAINDER = N - int(COMPLETE) * int(PREDS_PER_BYTE) };
    // Make sure we got the math right: the leftovers must fit in a single partial byte.
    static_assert(REMAINDER >= 0 && int(REMAINDER) < int(PREDS_PER_BYTE),
                  "remainder must fit in one predicate byte");
    // The total number of bytes we read, counting the trailing partial byte if any.
    enum { BYTES = COMPLETE + (REMAINDER > 0 ? 1 : 0) };
    // Make sure we did allocate enough predicate registers. The trailing partial byte
    // must be counted: the remainder path below reads preds[COMPLETE / BYTES_PER_REG].
    // That index is one past the COMPLETE-only register count whenever COMPLETE is a
    // multiple of BYTES_PER_REG.
    static_assert(Div_up<BYTES, BYTES_PER_REG>::VALUE <= M,
                  "not enough predicate registers for N predicates");

    // Clear the fetch registers.
    #pragma unroll
    for( int ii = 0; ii < N; ++ii ) {
        fct.clear(ii);
    }

    // Run complete steps: one fully-populated predicate byte per iteration.
    bool p[PREDS_PER_BYTE];
    #pragma unroll
    for( int ii = 0; ii < COMPLETE; ++ii ) {
        // The register that holds byte ii.
        uint32_t reg = preds[ii / BYTES_PER_REG];

        // Extract the predicates of byte ii (each byte starts on an 8-bit boundary).
        #pragma unroll
        for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
            uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj);
            p[jj] = (reg & mask) != 0u;
        }

        // Issue the loads.
        #pragma unroll
        for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
            fct.load(ii * PREDS_PER_BYTE + jj, p[jj]);
        }
    }

    // Handle the trailing partial byte, if any (REMAINDER is a compile-time constant).
    if( REMAINDER > 0 ) {
        // The register that holds the partial byte (index COMPLETE).
        uint32_t reg = preds[COMPLETE / BYTES_PER_REG];

        // Extract only the REMAINDER predicates that are actually used.
        #pragma unroll
        for( int jj = 0; jj < REMAINDER; ++jj ) {
            uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj);
            p[jj] = (reg & mask) != 0u;
        }

        // Issue the loads.
        #pragma unroll
        for( int ii = 0; ii < REMAINDER; ++ii ) {
            fct.load(COMPLETE * PREDS_PER_BYTE + ii, p[ii]);
        }
    }
}