inline __device__ void store()

in candle-flash-attn-v1/kernels/fmha/smem_tile.h [1139:1217]


    // Writes one group ("loop") of accumulator fragments into shared memory.
    //
    //   acc : accumulator fragments, indexed [MMA along M][MMA along N].
    //   mi  : selects which group of MMAS_M_PER_LOOP consecutive M-MMAs is written
    //         by this call, i.e. acc[mi * MMAS_M_PER_LOOP + mj][ni] for
    //         mj in [0, MMAS_M_PER_LOOP) and ni in [0, MMAS_N).
    //
    // For every N-MMA `ni`, each thread issues four 8-byte (uint2) stores:
    //   reg(0..1) -> row (mj * M_PER_MMA + 0), current column chunk
    //   reg(2..3) -> row (mj * M_PER_MMA + 8), current column chunk
    //   reg(4..7) -> the same two rows, column chunk XOR 32
    // Row offsets are computed in bytes (BYTES_PER_ROW) and added to smem_write_;
    // any per-row/per-thread swizzle is presumably already folded into smem_write_
    // by code outside this method — TODO confirm.
    //
    // Side effect: updates this->smem_write_. At the start of iteration ni it equals
    // (value on entry) ^ (64 * ni); for MMAS_N in {2, 4, 8} it is restored to its
    // entry value after the last ni (see the NOTE below for MMAS_N == 1).
    inline __device__ void store(const Accumulator (&acc)[M][N], int mi) {
        static constexpr int M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA;
        // The number of MMAs (along M) stored per call.
        // NOTE(review): integer division — assumes LOOPS divides Mma_tile::MMAS_M,
        // otherwise the remainder MMAs are never stored.
        static constexpr int MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS;

        // The XOR sequence at the bottom of the ni-loop only visits 2 * MMAS_N distinct
        // 32-byte chunks when MMAS_N is 1, 2, 4 or 8. For other values (e.g. 3, 5..7)
        // it wraps back to the entry value early and a later `ni` would overwrite the
        // data of an earlier one, so reject them at compile time.
        static_assert(Mma_tile::MMAS_N == 1 || Mma_tile::MMAS_N == 2 ||
                      Mma_tile::MMAS_N == 4 || Mma_tile::MMAS_N == 8,
                      "store(): MMAS_N must be 1, 2, 4 or 8");

        #pragma unroll
        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {

            // First column chunk: registers 0..3 of each fragment.
            #pragma unroll
            for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {
                const Accumulator &frag = acc[mi * MMAS_M_PER_LOOP + mj][ni];
                // Byte offsets of the two rows (8 apart) written for this MMA.
                const int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;
                const int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;

                uint2 top, bot;
                top.x = frag.reg(0);
                top.y = frag.reg(1);
                bot.x = frag.reg(2);
                bot.y = frag.reg(3);

                fmha::sts(this->smem_write_ + row_0, top);
                fmha::sts(this->smem_write_ + row_1, bot);
            }

            // Swizzle the write pointer to the neighbouring chunk using a XOR of 32 bytes.
            this->smem_write_ ^= 32;

            // Second column chunk: registers 4..7 of each fragment, same rows.
            #pragma unroll
            for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {
                const Accumulator &frag = acc[mi * MMAS_M_PER_LOOP + mj][ni];
                const int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;
                const int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;

                uint2 top, bot;
                top.x = frag.reg(4);
                top.y = frag.reg(5);
                bot.x = frag.reg(6);
                bot.y = frag.reg(7);

                fmha::sts(this->smem_write_ + row_0, top);
                fmha::sts(this->smem_write_ + row_1, bot);
            }

            // Cancel the 32-byte XOR above and move to the next chunk pair: the pointer
            // goes from (entry ^ (64 * ni + 32)) to (entry ^ (64 * (ni + 1))), wrapping
            // back to entry after the last ni when MMAS_N is 2, 4 or 8. The constant is
            // (64 * ni + 32) ^ (64 * (ni + 1)) taken modulo that wrap-around.
            if(        Mma_tile::MMAS_N >= 8 && ni % 4 == 3 ) {
                this->smem_write_ ^= 15 * 32;
            } else if( Mma_tile::MMAS_N >= 4 && ni % 2 == 1 ) {
                this->smem_write_ ^= 7 * 32;
            } else {
                // NOTE(review): this branch also handles MMAS_N == 1, where it leaves
                // smem_write_ at (entry ^ 64) after the loop instead of restoring the
                // entry value (that would need ^= 32). Behavior kept as-is; confirm
                // that callers/readers expect it.
                this->smem_write_ ^= 3 * 32;
            }
        }
    }