struct alignas()

in kernels/fmha/gemm.h [80:152]


struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {

    // Backing storage: the whole fragment is held in Base_::NUM_REGS 32-bit
    // words. reg(), elt() and elt_as() are all different views of this array.
    uint32_t regs_[Base_::NUM_REGS];

    // Total size in bytes of the register array, i.e. one full load/store.
    static constexpr int BYTES_PER_LOAD_STORE = sizeof(uint32_t) * Base_::NUM_REGS;

    // Zero every 32-bit register. Each word is written with an inline PTX
    // `mov.u32` instead of a plain C++ assignment because that form seems to
    // produce better SASS.
    inline __device__ void clear() {
        #pragma unroll
        for( int ri = 0; ri < Base_::NUM_REGS; ++ri ) {
            asm volatile("mov.u32 %0, 0; \n" : "=r"(regs_[ri]) : );
        }
    }

    // Read-only reference to 32-bit register `idx` (no bounds check).
    inline __device__ const uint32_t& reg(int idx) const { return regs_[idx]; }

    // Writable reference to 32-bit register `idx` (no bounds check).
    inline __device__ uint32_t& reg(int idx) { return regs_[idx]; }

    // View the register array as a packed array of Cast_type and return a
    // read-only reference to entry `idx` of that view. The caller must pick an
    // index that stays inside the register storage; nothing is checked here.
    template< typename Cast_type >
    inline __device__ const Cast_type& elt_as(int idx) const {
        const Cast_type *view = reinterpret_cast<const Cast_type*>(regs_);
        return view[idx];
    }

    // Writable counterpart of the const elt_as(): entry `idx` of the register
    // array reinterpreted as a packed Cast_type array.
    template< typename Cast_type >
    inline __device__ Cast_type& elt_as(int idx) {
        Cast_type *view = reinterpret_cast<Cast_type*>(regs_);
        return view[idx];
    }

    // Read-only reference to element `idx`, where the registers are viewed as
    // a packed Data_type_ array (same reinterpretation as elt_as<Data_type_>).
    inline __device__ const Data_type_& elt(int idx) const {
        return elt_as<Data_type_>(idx);
    }

    // Writable reference to element `idx` of the packed Data_type_ view.
    inline __device__ Data_type_& elt(int idx) {
        return elt_as<Data_type_>(idx);
    }

    // Accumulate `other` into this fragment: elt(e) += other.elt(e) for every
    // e in [0, NUM_ELTS_).
    //
    // elt() strides by sizeof(Data_type_), not by 32-bit register, so the loop
    // bound has to be an element count (NUM_ELTS_), not Base_::NUM_REGS.
    // The arithmetic is whatever Data_type_'s `+=` does.
    // NOTE(review): if Data_type_ is an integer type used to carry fp16/bf16
    // bit patterns (e.g. uint16_t), this adds the raw bits as integers rather
    // than performing floating-point addition -- confirm how callers
    // instantiate Fragment before relying on add() for half-precision data.
    inline __device__ void add(const Fragment &other) {
        #pragma unroll
        for( int e = 0; e < NUM_ELTS_; ++e ) {
            elt(e) += other.elt(e);
        }
    }

    // Combine with `other` register by register through fmha::hmul2, storing
    // the result in place (presumably a packed 2-wide half multiply -- confirm
    // against hmul2's definition).
    inline __device__ void hmul(const Fragment &other) {
        #pragma unroll
        for( int ri = 0; ri < Base_::NUM_REGS; ++ri ) {
            regs_[ri] = fmha::hmul2(regs_[ri], other.regs_[ri]);
        }
    }

    // In-place ReLU: every 32-bit register is replaced by
    // fmha::hrelu2<Elem_t>(register) (presumably two packed Elem_t values per
    // register -- confirm against hrelu2's definition).
    template< typename Elem_t >
    inline __device__ void hrelu_() {
        #pragma unroll
        for( int ri = 0; ri < Base_::NUM_REGS; ++ri ) {
            regs_[ri] = fmha::hrelu2<Elem_t>(regs_[ri]);
        }
    }
};