source/backend/opencl/execution/cl/matmul_params_buf_mnn_cl.cpp
#include "opencl_source_map.hpp"
namespace MNN {
#ifndef MNN_OPENCL_BUFFER_CLOSED
const char* matmul_params_buf =
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
"// =================================================================================================\n"
"#define USE_INLINE_KEYWORD 1\n"
"#ifndef MWG\n"
" #define MWG 8 // Tile-size in dimension M (e.g. 64,128)\n"
"#endif\n"
"#ifndef NWG\n"
" #define NWG 8 // Tile-size in dimension N (e.g. 64,128)\n"
"#endif\n"
"#ifndef KWG\n"
" #define KWG 16 // Tile-size in dimension K (e.g. 8,16)\n"
"#endif\n"
"#ifndef MDIMC\n"
" #define MDIMC 8 // Threads per workgroup in M-dimension (e.g. 8,16,32)\n"
"#endif\n"
"#ifndef NDIMC\n"
" #define NDIMC 8 // Threads per workgroup in N-dimension (e.g. 8,16,32)\n"
"#endif\n"
"#ifndef MDIMA\n"
" #define MDIMA 8 // Re-shaped tile dimension of matrix A: KDIMA*MDIMA (kernel 0 only)\n"
"#endif\n"
"#ifndef NDIMB\n"
" #define NDIMB 8 // Re-shaped tile dimension of matrix B: KDIMB*NDIMB (kernel 0 only)\n"
"#endif\n"
"#ifndef KWI\n"
" #define KWI 2 // Unroll factor of the KWG loop (smaller or equal than KWG)\n"
"#endif\n"
"#ifndef VWM\n"
" #define VWM 1 // Vector width of matrices A and C\n"
"#endif\n"
"#ifndef VWN\n"
" #define VWN 1 // Vector width of matrix B\n"
"#endif\n"
"#ifndef STRM\n"
" #define STRM 0 // Use strided access within a thread in the M-dimension (1) or not (0) (kernel 0 only)\n"
"#endif\n"
"#ifndef STRN\n"
" #define STRN 0 // Use strided access within a thread in the N-dimension (1) or not (0) (kernel 0 only)\n"
"#endif\n"
"#ifndef SA\n"
" #define SA 0 // Use local/shared memory to cache matrix A (1) or not (0) (kernel 0 only)\n"
"#endif\n"
"#ifndef SB\n"
" #define SB 0 // Use local/shared memory to cache matrix B (1) or not (0) (kernel 0 only)\n"
"#endif\n"
"// Helper parameters based on the above tuning parameters\n"
"#define MWI (MWG/MDIMC) // Work per work-item (M-dimension)\n"
"#define NWI (NWG/NDIMC) // Work per work-item (N-dimension)\n"
"#define KDIMA ((MDIMC*NDIMC)/(MDIMA)) // Re-shaped tile dimension of matrix A: KDIMA*MDIMA\n"
"#define KDIMB ((MDIMC*NDIMC)/(NDIMB)) // Re-shaped tile dimension of matrix B: KDIMB*NDIMB\n"
"#define MWA (MWG/MDIMA) // Amount of loads-per-thread for matrix A (M-dimension)\n"
"#define KWA (KWG/KDIMA) // Amount of loads-per-thread for matrix A (K-dimension)\n"
"#define KWB (KWG/KDIMB) // Amount of loads-per-thread for matrix B (K-dimension)\n"
"#define NWB (NWG/NDIMB) // Amount of loads-per-thread for matrix B (N-dimension)\n"
"// Settings\n"
"#ifndef USE_VECTOR_MAD\n"
" #define USE_VECTOR_MAD 0 // Unroll (0) or don't (1) unroll the vector MAD manually\n"
"#endif\n"
"#ifndef GLOBAL_MEM_FENCE\n"
" #define GLOBAL_MEM_FENCE 0 // Global synchronisation barrier for potential better performance\n"
"#endif\n"
"// Pointers to local memory objects (using a define because CUDA doesn't need them)\n"
"#ifndef LOCAL_PTR\n"
" #define LOCAL_PTR __local\n"
"#endif\n"
"// Don't use the non-IEEE754 compliant OpenCL built-in mad() instruction per default. For specific\n"
"// devices,this is enabled (see src/routine.cpp).\n"
"#ifndef USE_CL_MAD\n"
" #define USE_CL_MAD 0\n"
"#endif\n"
"// BIAS_TYPE\n"
"// 0 -> without bias\n"
"// 1 -> with bias (add) [N]\n"
"// 2 -> with bias (eltwise_add) [M,N]\n"
"// 3 -> with bias (eltwise_sub) [M,N]\n"
"// 4 -> with bias (eltwise_sub and get negative) [M,N]\n"
"// 5 -> with bias (mask 0 for invalid) [M,N]\n"
"#ifndef BIAS_TYPE\n"
" #define BIAS_TYPE 0\n"
"#endif\n"
"#if BIAS_TYPE == 1\n"
"#define DEAL_BIAS(x,a) x=x+a\n"
"#elif BIAS_TYPE == 2\n"
"#define DEAL_BIAS(x,a) x=x+a\n"
"#elif BIAS_TYPE == 3\n"
"#define DEAL_BIAS(x,a) x=x-a\n"
"#elif BIAS_TYPE == 4\n"
"#define DEAL_BIAS(x,a) x=a-x\n"
"#elif BIAS_TYPE == 5\n"
"#define DEAL_BIAS(x,a) x=(a == 0 ? (FLOAT)(-FLT_MAX) : x)\n"
"#endif\n"
"// By default the workgroup size requirement is enabled. For Qualcomm devices the workgroup size\n"
"// requirement results in worse performance and is disabled (src/utilities/compile.cpp)\n"
"#ifndef RELAX_WORKGROUP_SIZE\n"
" #define RELAX_WORKGROUP_SIZE 0\n"
"#endif\n"
"typedef float real_arg;\n"
"#define GetRealArg(x) (FLOAT)x\n"
"typedef FLOAT real;\n"
"#ifndef PRECISION_COMPUTE\n"
"#define PRECISION_COMPUTE COMPUTE_FLOAT\n"
"#define CONVERT_PRECISION_COMPUTE(x) CONVERT_COMPUTE_FLOAT(x)\n"
"#endif\n"
"#ifndef PRECISION_COMPUTE2\n"
"#define PRECISION_COMPUTE2 COMPUTE_FLOAT2\n"
"#define CONVERT_PRECISION_COMPUTE2(x) CONVERT_COMPUTE_FLOAT2(x)\n"
"#endif\n"
"#ifndef PRECISION_COMPUTE4\n"
"#define PRECISION_COMPUTE4 COMPUTE_FLOAT4\n"
"#define CONVERT_PRECISION_COMPUTE4(x) CONVERT_COMPUTE_FLOAT4(x)\n"
"#endif\n"
"#ifndef PRECISION_COMPUTE8\n"
"#define PRECISION_COMPUTE8 COMPUTE_FLOAT8\n"
"#define CONVERT_PRECISION_COMPUTE8(x) CONVERT_COMPUTE_FLOAT8(x)\n"
"#endif\n"
"#ifndef PRECISION_COMPUTE16\n"
"#define PRECISION_COMPUTE16 COMPUTE_FLOAT16\n"
"#define CONVERT_PRECISION_COMPUTE16(x) CONVERT_COMPUTE_FLOAT16(x)\n"
"#endif\n"
"#define ZERO (PRECISION_COMPUTE)0.0f\n"
"// Sets a variable to zero\n"
"#define SetToZero(a) a=ZERO\n"
"#define IsZero(a) (a == ZERO)\n"
"#define Multiply(c,a,b) c=a*b\n"
"#if USE_CL_MAD == 1\n"
"#define MultiplyAdd(c,a,b) c=mad(a,b,c)\n"
"#else\n"
"#define MultiplyAdd(c,a,b) c += a*b\n"
"#endif\n"
"#define AXPBY(e,a,b,c,d) e=a*b+c*d\n"
"// Force inlining functions or not: some compilers don't support the inline keyword\n"
"#ifdef USE_INLINE_KEYWORD\n"
" #define INLINE_FUNC inline\n"
"#else\n"
" #define INLINE_FUNC\n"
"#endif\n"
"INLINE_FUNC int GetGroupID1() { return get_group_id(1); }\n"
"INLINE_FUNC int GetGroupID0() { return get_group_id(0); }\n"
"// =================================================================================================\n"
"// Data-widths in dimension M\n"
"#if VWM == 1\n"
" typedef FLOAT realM;\n"
" #define COMPUTE_FLOATM PRECISION_COMPUTE\n"
" #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE(x)\n"
" #define CONVERT_FLOATM(x) CONVERT_FLOAT(x)\n"
"#elif VWM == 2\n"
" typedef FLOAT2 realM;\n"
" #define COMPUTE_FLOATM PRECISION_COMPUTE2\n"
" #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE2(x)\n"
" #define CONVERT_FLOATM(x) CONVERT_FLOAT2(x)\n"
"#elif VWM == 4\n"
" typedef FLOAT4 realM;\n"
" #define COMPUTE_FLOATM PRECISION_COMPUTE4\n"
" #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE4(x)\n"
" #define CONVERT_FLOATM(x) CONVERT_FLOAT4(x)\n"
"#elif VWM == 8\n"
" typedef FLOAT8 realM;\n"
" #define COMPUTE_FLOATM PRECISION_COMPUTE8\n"
" #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE8(x)\n"
" #define CONVERT_FLOATM(x) CONVERT_FLOAT8(x)\n"
"#elif VWM == 16\n"
" typedef FLOAT16 realM;\n"
" #define COMPUTE_FLOATM PRECISION_COMPUTE16\n"
" #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE16(x)\n"
" #define CONVERT_FLOATM(x) CONVERT_FLOAT16(x)\n"
"#endif\n"
"// Data-widths in dimension N\n"
"#if VWN == 1\n"
" typedef FLOAT realN;\n"
" typedef int intN;\n"
" #define COMPUTE_FLOATN PRECISION_COMPUTE\n"
" #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE(x)\n"
" #define CONVERT_FLOATN(x) CONVERT_FLOAT(x)\n"
"#elif VWN == 2\n"
" typedef FLOAT2 realN;\n"
" typedef int2 intN;\n"
" #define COMPUTE_FLOATN PRECISION_COMPUTE2\n"
" #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE2(x)\n"
" #define CONVERT_FLOATN(x) CONVERT_FLOAT2(x)\n"
"#elif VWN == 4\n"
" typedef FLOAT4 realN;\n"
" typedef int4 intN;\n"
" #define COMPUTE_FLOATN PRECISION_COMPUTE4\n"
" #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE4(x)\n"
" #define CONVERT_FLOATN(x) CONVERT_FLOAT4(x)\n"
"#elif VWN == 8\n"
" typedef FLOAT8 realN;\n"
" typedef int8 intN;\n"
" #define COMPUTE_FLOATN PRECISION_COMPUTE8\n"
" #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE8(x)\n"
" #define CONVERT_FLOATN(x) CONVERT_FLOAT8(x)\n"
"#elif VWN == 16\n"
" typedef FLOAT16 realN;\n"
" typedef int16 intN;\n"
" #define COMPUTE_FLOATN PRECISION_COMPUTE16\n"
" #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE16(x)\n"
" #define CONVERT_FLOATN(x) CONVERT_FLOAT16(x)\n"
"#endif\n"
"// =================================================================================================\n"
"// Initializes the accumulation registers to zero\n"
"INLINE_FUNC COMPUTE_FLOATM InitAccRegisters() {\n"
" COMPUTE_FLOATM result;\n"
" #if VWM == 1\n"
" SetToZero(result);\n"
" #elif VWM == 2\n"
" SetToZero(result.x);\n"
" SetToZero(result.y);\n"
" #elif VWM == 4\n"
" SetToZero(result.x);\n"
" SetToZero(result.y);\n"
" SetToZero(result.z);\n"
" SetToZero(result.w);\n"
" #elif VWM == 8\n"
" SetToZero(result.s0);\n"
" SetToZero(result.s1);\n"
" SetToZero(result.s2);\n"
" SetToZero(result.s3);\n"
" SetToZero(result.s4);\n"
" SetToZero(result.s5);\n"
" SetToZero(result.s6);\n"
" SetToZero(result.s7);\n"
" #elif VWM == 16\n"
" SetToZero(result.s0);\n"
" SetToZero(result.s1);\n"
" SetToZero(result.s2);\n"
" SetToZero(result.s3);\n"
" SetToZero(result.s4);\n"
" SetToZero(result.s5);\n"
" SetToZero(result.s6);\n"
" SetToZero(result.s7);\n"
" SetToZero(result.s8);\n"
" SetToZero(result.s9);\n"
" SetToZero(result.sA);\n"
" SetToZero(result.sB);\n"
" SetToZero(result.sC);\n"
" SetToZero(result.sD);\n"
" SetToZero(result.sE);\n"
" SetToZero(result.sF);\n"
" #endif\n"
" return result;\n"
"}\n"
"INLINE_FUNC COMPUTE_FLOATN InitAccRegistersN() {\n"
" COMPUTE_FLOATN result;\n"
" #if VWN == 1\n"
" SetToZero(result);\n"
" #elif VWN == 2\n"
" SetToZero(result.x);\n"
" SetToZero(result.y);\n"
" #elif VWN == 4\n"
" SetToZero(result.x);\n"
" SetToZero(result.y);\n"
" SetToZero(result.z);\n"
" SetToZero(result.w);\n"
" #elif VWN == 8\n"
" SetToZero(result.s0);\n"
" SetToZero(result.s1);\n"
" SetToZero(result.s2);\n"
" SetToZero(result.s3);\n"
" SetToZero(result.s4);\n"
" SetToZero(result.s5);\n"
" SetToZero(result.s6);\n"
" SetToZero(result.s7);\n"
" #elif VWN == 16\n"
" SetToZero(result.s0);\n"
" SetToZero(result.s1);\n"
" SetToZero(result.s2);\n"
" SetToZero(result.s3);\n"
" SetToZero(result.s4);\n"
" SetToZero(result.s5);\n"
" SetToZero(result.s6);\n"
" SetToZero(result.s7);\n"
" SetToZero(result.s8);\n"
" SetToZero(result.s9);\n"
" SetToZero(result.sA);\n"
" SetToZero(result.sB);\n"
" SetToZero(result.sC);\n"
" SetToZero(result.sD);\n"
" SetToZero(result.sE);\n"
" SetToZero(result.sF);\n"
" #endif\n"
" return result;\n"
"}\n"
"// =================================================================================================\n"
"// Caches global off-chip memory into local (shared) memory on-chip. This function is specific for\n"
"// caching the A input matrix.\n"
"#if SA == 1\n"
"INLINE_FUNC void GlobalToLocalA(const __global realM* restrict agm,LOCAL_PTR realM* alm,\n"
" const int kSizeM,const int tid,const int kwg) {\n"
" const int la0=tid % MDIMA;\n"
" const int la1=tid/MDIMA;\n"
" #pragma unroll\n"
" for (int _mia=0; _mia<MWA/VWM; _mia += 1) {\n"
" #pragma unroll\n"
" for (int _kia=0; _kia<KWA; _kia += 1) {\n"
" // Computes the indices based on strided/non-strided access\n"
" #if STRM == 0\n"
" int mg=_mia+la0*(MWA/VWM);\n"
" #elif STRM == 1\n"
" int mg=la0+_mia*MDIMA;\n"
" #endif\n"
" // Computes the indices for the global memory\n"
" int kg=_kia+la1*KWA;\n"
" int idm=mg+GetGroupID0()*(MWG/VWM);\n"
" int idk=kg+kwg;\n"
" // Loads the data from global memory (not transposed) into the local memory\n"
" alm[kg*(MWG/VWM)+mg]=agm[idk*(kSizeM/VWM)+idm];\n"
" }\n"
" }\n"
"}\n"
"#endif\n"
"// Same as above,but now for the B input matrix\n"
"#if SB == 1\n"
"INLINE_FUNC void GlobalToLocalB(const __global realN* restrict bgm,LOCAL_PTR realN* blm,\n"
" const int kSizeN,const int tid,const int kwg) {\n"
" const int lb0=tid % NDIMB;\n"
" const int lb1=tid/NDIMB;\n"
" #pragma unroll\n"
" for (int _kib=0; _kib<KWB; _kib += 1) {\n"
" #pragma unroll\n"
" for (int _nib=0; _nib<NWB/VWN; _nib += 1) {\n"
" // Computes the indices based on strided/non-strided access\n"
" #if STRN == 0\n"
" int ng=_nib+lb0*(NWB/VWN);\n"
" #elif STRN == 1\n"
" int ng=lb0+_nib*NDIMB;\n"
" #endif\n"
" // Computes the indices for the global memory\n"
" int kg=_kib+lb1*KWB;\n"
" int idn=ng+GetGroupID1()*(NWG/VWN);\n"
" int idk=kg+kwg;\n"
" // Loads the data from global memory (transposed) into the local memory\n"
" blm[kg*(NWG/VWN)+ng]=bgm[idk*(kSizeN/VWN)+idn];\n"
" }\n"
" }\n"
"}\n"
"#endif\n"
"// =================================================================================================\n"
"// Caches global off-chip memory directly into per-thread private memory (registers). This function\n"
"// is specific for caching the A input matrix.\n"
"#if SA == 0\n"
"INLINE_FUNC int GlobalIndexA() {\n"
" // Computes the indices based on strided/non-strided access\n"
" #if STRM == 0\n"
" // [MWG/MWI,MWI/VWM,VWM]\n"
" int mg=get_local_id(0)*(MWI/VWM);\n"
" #elif STRM == 1\n"
" // [MWI/VWM,MWG/MWI,VWM]\n"
" int mg=get_local_id(0);\n"
" #endif\n"
" // Computes the indices for the global memory\n"
" // [kSizeM/MWG,(MWG/VWM),VWM]\n"
" int idm=mg+GetGroupID0()*(MWG/VWM);\n"
" return idm;\n"
"}\n"
"INLINE_FUNC realM GlobalToPrivateOptA(const __global realM* restrict agm,const int base,const int _mi,\n"
" const int astride/*kSizeM*/,const int idk) {\n"
" // Computes the indices based on strided/non-strided access\n"
" #if STRM == 0\n"
" // [MWG/MWI,MWI/VWM,VWM]\n"
" int idm=base+_mi;\n"
" #elif STRM == 1\n"
" // [MWI/VWM,MWG/MWI,VWM]\n"
" int idm=base+_mi*MDIMC;\n"
" #endif\n"
" // Loads the data from global memory (not transposed) and stores into registers\n"
" // [kSizeK,kSizeM/VWM,VWM]\n"
" return agm[idk*(astride/VWM)+idm];\n"
"}\n"
"INLINE_FUNC realM GlobalToPrivateA(const __global realM* restrict agm,const int _mi,\n"
" const int kSizeM,const int idk) {\n"
" // Computes the indices based on strided/non-strided access\n"
" #if STRM == 0\n"
" // [MWG/MWI,MWI/VWM,VWM]\n"
" int mg=_mi+get_local_id(0)*(MWI/VWM);\n"
" #elif STRM == 1\n"
" // [MWI/VWM,MWG/MWI,VWM]\n"
" int mg=get_local_id(0)+_mi*MDIMC;\n"
" #endif\n"
" // Computes the indices for the global memory\n"
" // [kSizeM/MWG,(MWG/VWM),VWM]\n"
" int idm=mg+GetGroupID0()*(MWG/VWM);\n"
" // Loads the data from global memory (not transposed) and stores into registers\n"
" // [kSizeK,kSizeM/VWM,VWM]\n"
" return agm[idk*(kSizeM/VWM)+idm];\n"
"}\n"
"#endif\n"
"// Same as above,but now for the B input matrix\n"
"#if SB == 0\n"
"INLINE_FUNC int GlobalIndexB() {\n"
" // Computes the indices based on strided/non-strided access\n"
" #if STRN == 0\n"
" int ng=get_local_id(1)*(NWI/VWN);\n"
" #elif STRN == 1\n"
" int ng=get_local_id(1);\n"
" #endif\n"
" // Computes the indices for the global memory\n"
" int idn=ng+GetGroupID1()*(NWG/VWN);\n"
" return idn;\n"
"}\n"
"INLINE_FUNC realN GlobalToPrivateOptB(const __global realN* restrict bgm,const int base,const int _ni,\n"
" const int bstride/*kSizeN*/,const int idk) {\n"
" // Computes the indices based on strided/non-strided access\n"
" #if STRN == 0\n"
" int idn=base+_ni;\n"
" #elif STRN == 1\n"
" int idn=base+_ni*NDIMC;\n"
" #endif\n"
" // Loads the data from global memory (transposed) and stores into registers\n"
" return bgm[idk*(bstride/VWN)+idn];\n"
"}\n"
"INLINE_FUNC realN GlobalToPrivateB(const __global realN* restrict bgm,const int _ni,\n"
" const int kSizeN,const int idk) {\n"
" // Computes the indices based on strided/non-strided access\n"
" #if STRN == 0\n"
" int ng=_ni+get_local_id(1)*(NWI/VWN);\n"
" #elif STRN == 1\n"
" int ng=get_local_id(1)+_ni*NDIMC;\n"
" #endif\n"
" // Computes the indices for the global memory\n"
" int idn=ng+GetGroupID1()*(NWG/VWN);\n"
" // Loads the data from global memory (transposed) and stores into registers\n"
" return bgm[idk*(kSizeN/VWN)+idn];\n"
"}\n"
"#endif\n"
"// =================================================================================================\n"
"// Caches on-chip local memory into per-thread private memory (registers). This function is specific\n"
"// for caching the A input matrix.\n"
"#if SA == 1\n"
"INLINE_FUNC realM LocalToPrivateA(LOCAL_PTR realM* alm,const int _mi,const int kg) {\n"
" #if STRM == 0\n"
" int mg=_mi+get_local_id(0)*(MWI/VWM);\n"
" #elif STRM == 1\n"
" int mg=get_local_id(0)+_mi*MDIMC;\n"
" #endif\n"
" return alm[kg*(MWG/VWM)+mg];\n"
"}\n"
"#endif\n"
"// Same as above,but now for the B input matrix\n"
"#if SB == 1\n"
"INLINE_FUNC realN LocalToPrivateB(LOCAL_PTR realN* blm,const int _ni,const int kg) {\n"
" #if STRN == 0\n"
" int ng=_ni+get_local_id(1)*(NWI/VWN);\n"
" #elif STRN == 1\n"
" int ng=get_local_id(1)+_ni*NDIMC;\n"
" #endif\n"
" return blm[kg*(NWG/VWN)+ng];\n"
"}\n"
"#endif\n"
"// The vectorised multiply-add function\n"
"INLINE_FUNC COMPUTE_FLOATM MultiplyAddVector(COMPUTE_FLOATM cvec,COMPUTE_FLOATM avec,PRECISION_COMPUTE bval) {\n"
" #if USE_VECTOR_MAD == 1\n"
" #if USE_CL_MAD == 1\n"
" cvec=mad(avec,(COMPUTE_FLOATM)bval,cvec);\n"
" #else\n"
" cvec += avec*bval;\n"
" #endif\n"
" #else\n"
" #if VWM == 1\n"
" MultiplyAdd(cvec,avec,bval);\n"
" #elif VWM == 2\n"
" MultiplyAdd(cvec.x ,avec.x,bval);\n"
" MultiplyAdd(cvec.y ,avec.y,bval);\n"
" #elif VWM == 4\n"
" MultiplyAdd(cvec.x ,avec.x,bval);\n"
" MultiplyAdd(cvec.y ,avec.y,bval);\n"
" MultiplyAdd(cvec.z ,avec.z,bval);\n"
" MultiplyAdd(cvec.w ,avec.w,bval);\n"
" #elif VWM == 8\n"
" MultiplyAdd(cvec.s0,avec.s0,bval);\n"
" MultiplyAdd(cvec.s1,avec.s1,bval);\n"
" MultiplyAdd(cvec.s2,avec.s2,bval);\n"
" MultiplyAdd(cvec.s3,avec.s3,bval);\n"
" MultiplyAdd(cvec.s4,avec.s4,bval);\n"
" MultiplyAdd(cvec.s5,avec.s5,bval);\n"
" MultiplyAdd(cvec.s6,avec.s6,bval);\n"
" MultiplyAdd(cvec.s7,avec.s7,bval);\n"
" #elif VWM == 16\n"
" MultiplyAdd(cvec.s0,avec.s0,bval);\n"
" MultiplyAdd(cvec.s1,avec.s1,bval);\n"
" MultiplyAdd(cvec.s2,avec.s2,bval);\n"
" MultiplyAdd(cvec.s3,avec.s3,bval);\n"
" MultiplyAdd(cvec.s4,avec.s4,bval);\n"
" MultiplyAdd(cvec.s5,avec.s5,bval);\n"
" MultiplyAdd(cvec.s6,avec.s6,bval);\n"
" MultiplyAdd(cvec.s7,avec.s7,bval);\n"
" MultiplyAdd(cvec.s8,avec.s8,bval);\n"
" MultiplyAdd(cvec.s9,avec.s9,bval);\n"
" MultiplyAdd(cvec.sA,avec.sA,bval);\n"
" MultiplyAdd(cvec.sB,avec.sB,bval);\n"
" MultiplyAdd(cvec.sC,avec.sC,bval);\n"
" MultiplyAdd(cvec.sD,avec.sD,bval);\n"
" MultiplyAdd(cvec.sE,avec.sE,bval);\n"
" MultiplyAdd(cvec.sF,avec.sF,bval);\n"
" #endif\n"
" #endif\n"
" return cvec;\n"
"}\n"
"// The vectorised multiply-add function\n"
"INLINE_FUNC COMPUTE_FLOATN MultiplyAddVectorN(COMPUTE_FLOATN cvec,PRECISION_COMPUTE avec,COMPUTE_FLOATN bval) {\n"
" #if USE_VECTOR_MAD == 1\n"
" #if USE_CL_MAD == 1\n"
" cvec=mad((COMPUTE_FLOATN)avec,bval,cvec);\n"
" #else\n"
" cvec += avec*bval;\n"
" #endif\n"
" #else\n"
" #if VWN == 1\n"
" MultiplyAdd(cvec,avec,bval);\n"
" #elif VWN == 2\n"
" MultiplyAdd(cvec.x ,avec,bval.x);\n"
" MultiplyAdd(cvec.y ,avec,bval.y);\n"
" #elif VWN == 4\n"
" MultiplyAdd(cvec.x ,avec,bval.x);\n"
" MultiplyAdd(cvec.y ,avec,bval.y);\n"
" MultiplyAdd(cvec.z ,avec,bval.z);\n"
" MultiplyAdd(cvec.w ,avec,bval.w);\n"
" #elif VWN == 8\n"
" MultiplyAdd(cvec.s0,avec,bval.s0);\n"
" MultiplyAdd(cvec.s1,avec,bval.s1);\n"
" MultiplyAdd(cvec.s2,avec,bval.s2);\n"
" MultiplyAdd(cvec.s3,avec,bval.s3);\n"
" MultiplyAdd(cvec.s4,avec,bval.s4);\n"
" MultiplyAdd(cvec.s5,avec,bval.s5);\n"
" MultiplyAdd(cvec.s6,avec,bval.s6);\n"
" MultiplyAdd(cvec.s7,avec,bval.s7);\n"
" #elif VWN == 16\n"
" MultiplyAdd(cvec.s0,avec,bval.s0);\n"
" MultiplyAdd(cvec.s1,avec,bval.s1);\n"
" MultiplyAdd(cvec.s2,avec,bval.s2);\n"
" MultiplyAdd(cvec.s3,avec,bval.s3);\n"
" MultiplyAdd(cvec.s4,avec,bval.s4);\n"
" MultiplyAdd(cvec.s5,avec,bval.s5);\n"
" MultiplyAdd(cvec.s6,avec,bval.s6);\n"
" MultiplyAdd(cvec.s7,avec,bval.s7);\n"
" MultiplyAdd(cvec.s8,avec,bval.s8);\n"
" MultiplyAdd(cvec.s9,avec,bval.s9);\n"
" MultiplyAdd(cvec.sA,avec,bval.sA);\n"
" MultiplyAdd(cvec.sB,avec,bval.sB);\n"
" MultiplyAdd(cvec.sC,avec,bval.sC);\n"
" MultiplyAdd(cvec.sD,avec,bval.sD);\n"
" MultiplyAdd(cvec.sE,avec,bval.sE);\n"
" MultiplyAdd(cvec.sF,avec,bval.sF);\n"
" #endif\n"
" #endif\n"
" return cvec;\n"
"}\n"
"// =================================================================================================\n"
"// Merges the results in Cpm with the global array in Cgm. This also performs the multiplication\n"
"// with the constants: Cgm=alpha*A*B+beta*Cgm=alpha*Cpm+beta*Cgm\n"
"typedef struct {\n"
" int index[2];\n"
"} INT2;\n"
"INLINE_FUNC INT2 StoreIndexM() {\n"
" INT2 res;\n"
" #if STRM == 0\n"
" int mg=get_local_id(0)*(MWI/VWM);\n"
" #elif STRM == 1\n"
" int mg=get_local_id(0);\n"
" #endif\n"
" #if STRN == 0\n"
" int ng=get_local_id(1)*NWI;\n"
" #elif STRN == 1\n"
" int ng=get_local_id(1)*VWN;\n"
" #endif\n"
" int idm=mg+GetGroupID0()*(MWG/VWM);\n"
" int idn=ng+GetGroupID1()*NWG;\n"
" res.index[0]=idm;\n"
" res.index[1]=idn;\n"
" return res;\n"
"}\n"
"// layout : [N,M]\n"
"INLINE_FUNC void StoreResultsM(__global realM* cgm,COMPUTE_FLOATM c_value,const INT2 baseOffset,const int _mi,const int _ni,\n"
" const int cstride/*kSizeM*/,\n"
" const PRECISION_COMPUTE alpha,const PRECISION_COMPUTE beta) {\n"
" #if STRM == 0\n"
" int idm=_mi+baseOffset.index[0];\n"
" #elif STRM == 1\n"
" int idm=baseOffset.index[0]+_mi*MDIMC;\n"
" #endif\n"
" #if STRN == 0\n"
" int idn=_ni+baseOffset.index[1];\n"
" #elif STRN == 1\n"
" int idn=_ni%VWN+baseOffset.index[1]+(_ni/VWN)*VWN*NDIMC;\n"
" #endif\n"
" \n"
" int index=idn*(cstride/VWM)+idm;\n"
" COMPUTE_FLOATM result=c_value;\n"
" // The final multiplication with alpha (in case beta == 0)\n"
" #ifdef ONLY_HAVE_ALPHA\n"
" COMPUTE_FLOATM xval=c_value;\n"
" #if VWM == 1\n"
" Multiply(result,alpha,xval);\n"
" #elif VWM == 2\n"
" Multiply(result.x,alpha,xval.x);\n"
" Multiply(result.y,alpha,xval.y);\n"
" #elif VWM == 4\n"
" Multiply(result.x,alpha,xval.x);\n"
" Multiply(result.y,alpha,xval.y);\n"
" Multiply(result.z,alpha,xval.z);\n"
" Multiply(result.w,alpha,xval.w);\n"
" #elif VWM == 8\n"
" Multiply(result.s0,alpha,xval.s0);\n"
" Multiply(result.s1,alpha,xval.s1);\n"
" Multiply(result.s2,alpha,xval.s2);\n"
" Multiply(result.s3,alpha,xval.s3);\n"
" Multiply(result.s4,alpha,xval.s4);\n"
" Multiply(result.s5,alpha,xval.s5);\n"
" Multiply(result.s6,alpha,xval.s6);\n"
" Multiply(result.s7,alpha,xval.s7);\n"
" #elif VWM == 16\n"
" Multiply(result.s0,alpha,xval.s0);\n"
" Multiply(result.s1,alpha,xval.s1);\n"
" Multiply(result.s2,alpha,xval.s2);\n"
" Multiply(result.s3,alpha,xval.s3);\n"
" Multiply(result.s4,alpha,xval.s4);\n"
" Multiply(result.s5,alpha,xval.s5);\n"
" Multiply(result.s6,alpha,xval.s6);\n"
" Multiply(result.s7,alpha,xval.s7);\n"
" Multiply(result.s8,alpha,xval.s8);\n"
" Multiply(result.s9,alpha,xval.s9);\n"
" Multiply(result.sA,alpha,xval.sA);\n"
" Multiply(result.sB,alpha,xval.sB);\n"
" Multiply(result.sC,alpha,xval.sC);\n"
" Multiply(result.sD,alpha,xval.sD);\n"
" Multiply(result.sE,alpha,xval.sE);\n"
" Multiply(result.sF,alpha,xval.sF);\n"
" #endif\n"
" #endif\n"
" // The final multiplication with alpha and the addition with beta*C\n"
" #ifdef HAVE_ALPHA_BETA\n"
" COMPUTE_FLOATM xval=c_value;\n"
" COMPUTE_FLOATM yval=CONVERT_COMPUTE_FLOATM(cgm[index]);\n"
" #if VWM == 1\n"
" AXPBY(result,alpha,xval,beta,yval);\n"
" #elif VWM == 2\n"
" AXPBY(result.x,alpha,xval.x,beta,yval.x);\n"
" AXPBY(result.y,alpha,xval.y,beta,yval.y);\n"
" #elif VWM == 4\n"
" AXPBY(result.x,alpha,xval.x,beta,yval.x);\n"
" AXPBY(result.y,alpha,xval.y,beta,yval.y);\n"
" AXPBY(result.z,alpha,xval.z,beta,yval.z);\n"
" AXPBY(result.w,alpha,xval.w,beta,yval.w);\n"
" #elif VWM == 8\n"
" AXPBY(result.s0,alpha,xval.s0,beta,yval.s0);\n"
" AXPBY(result.s1,alpha,xval.s1,beta,yval.s1);\n"
" AXPBY(result.s2,alpha,xval.s2,beta,yval.s2);\n"
" AXPBY(result.s3,alpha,xval.s3,beta,yval.s3);\n"
" AXPBY(result.s4,alpha,xval.s4,beta,yval.s4);\n"
" AXPBY(result.s5,alpha,xval.s5,beta,yval.s5);\n"
" AXPBY(result.s6,alpha,xval.s6,beta,yval.s6);\n"
" AXPBY(result.s7,alpha,xval.s7,beta,yval.s7);\n"
" #elif VWM == 16\n"
" AXPBY(result.s0,alpha,xval.s0,beta,yval.s0);\n"
" AXPBY(result.s1,alpha,xval.s1,beta,yval.s1);\n"
" AXPBY(result.s2,alpha,xval.s2,beta,yval.s2);\n"
" AXPBY(result.s3,alpha,xval.s3,beta,yval.s3);\n"
" AXPBY(result.s4,alpha,xval.s4,beta,yval.s4);\n"
" AXPBY(result.s5,alpha,xval.s5,beta,yval.s5);\n"
" AXPBY(result.s6,alpha,xval.s6,beta,yval.s6);\n"
" AXPBY(result.s7,alpha,xval.s7,beta,yval.s7);\n"
" AXPBY(result.s8,alpha,xval.s8,beta,yval.s8);\n"
" AXPBY(result.s9,alpha,xval.s9,beta,yval.s9);\n"
" AXPBY(result.sA,alpha,xval.sA,beta,yval.sA);\n"
" AXPBY(result.sB,alpha,xval.sB,beta,yval.sB);\n"
" AXPBY(result.sC,alpha,xval.sC,beta,yval.sC);\n"
" AXPBY(result.sD,alpha,xval.sD,beta,yval.sD);\n"
" AXPBY(result.sE,alpha,xval.sE,beta,yval.sE);\n"
" AXPBY(result.sF,alpha,xval.sF,beta,yval.sF);\n"
" #endif\n"
" #endif\n"
" cgm[index]=CONVERT_FLOATM(result);\n"
"}\n"
"INLINE_FUNC INT2 StoreIndexN() {\n"
" INT2 res;\n"
" #if STRM == 0\n"
" int mg=get_local_id(0)*MWI;\n"
" #elif STRM == 1\n"
" int mg=get_local_id(0)*VWM;\n"
" #endif\n"
" #if STRN == 0\n"
" int ng=get_local_id(1)*(NWI/VWN);\n"
" #elif STRN == 1\n"
" int ng=get_local_id(1);\n"
" #endif\n"
" int idm=mg+GetGroupID0()*MWG;\n"
" int idn=ng+GetGroupID1()*(NWG/VWN);\n"
" \n"
" res.index[0]=idm;\n"
" res.index[1]=idn;\n"
" return res;\n"
"}\n"
"// layout : [M,N]\n"
"INLINE_FUNC void StoreResultsN(__global realN* cgn,COMPUTE_FLOATN c_value,\n"
" const INT2 baseOffset,\n"
" #if BIAS_TYPE>0\n"
" #if BIAS_TYPE>1\n"
" __global realN* egm,\n"
" #else\n"
" realN* epm,\n"
" #endif\n"
" #endif\n"
" const int _mi,const int _ni,\n"
" const int cstride/*kSizeN*/,const int dstride/*kSizeN*/,const PRECISION_COMPUTE alpha,const PRECISION_COMPUTE beta) {\n"
" #if STRM == 0\n"
" int idm=_mi+baseOffset.index[0];\n"
" #elif STRM == 1\n"
" int idm=_mi%VWM+baseOffset.index[0]+(_mi/VWM)*VWM*MDIMC;\n"
" #endif\n"
" #if STRN == 0\n"
" int idn=_ni+baseOffset.index[1];\n"
" #elif STRN == 1\n"
" int idn=baseOffset.index[1]+_ni*NDIMC;\n"
" #endif\n"
" int index=idm*(cstride/VWN)+idn;\n"
" \n"
" COMPUTE_FLOATN result=c_value;\n"
" \n"
" // The final multiplication with alpha (in case beta == 0)\n"
" #ifdef ONLY_HAVE_ALPHA\n"
" COMPUTE_FLOATN xval=c_value;\n"
" #if VWN == 1\n"
" Multiply(result,alpha,xval);\n"
" #elif VWN == 2\n"
" Multiply(result.x,alpha,xval.x);\n"
" Multiply(result.y,alpha,xval.y);\n"
" #elif VWN == 4\n"
" Multiply(result.x,alpha,xval.x);\n"
" Multiply(result.y,alpha,xval.y);\n"
" Multiply(result.z,alpha,xval.z);\n"
" Multiply(result.w,alpha,xval.w);\n"
" #elif VWN == 8\n"
" Multiply(result.s0,alpha,xval.s0);\n"
" Multiply(result.s1,alpha,xval.s1);\n"
" Multiply(result.s2,alpha,xval.s2);\n"
" Multiply(result.s3,alpha,xval.s3);\n"
" Multiply(result.s4,alpha,xval.s4);\n"
" Multiply(result.s5,alpha,xval.s5);\n"
" Multiply(result.s6,alpha,xval.s6);\n"
" Multiply(result.s7,alpha,xval.s7);\n"
" #elif VWN == 16\n"
" Multiply(result.s0,alpha,xval.s0);\n"
" Multiply(result.s1,alpha,xval.s1);\n"
" Multiply(result.s2,alpha,xval.s2);\n"
" Multiply(result.s3,alpha,xval.s3);\n"
" Multiply(result.s4,alpha,xval.s4);\n"
" Multiply(result.s5,alpha,xval.s5);\n"
" Multiply(result.s6,alpha,xval.s6);\n"
" Multiply(result.s7,alpha,xval.s7);\n"
" Multiply(result.s8,alpha,xval.s8);\n"
" Multiply(result.s9,alpha,xval.s9);\n"
" Multiply(result.sA,alpha,xval.sA);\n"
" Multiply(result.sB,alpha,xval.sB);\n"
" Multiply(result.sC,alpha,xval.sC);\n"
" Multiply(result.sD,alpha,xval.sD);\n"
" Multiply(result.sE,alpha,xval.sE);\n"
" Multiply(result.sF,alpha,xval.sF);\n"
" #endif\n"
" #endif\n"
" // The final multiplication with alpha and the addition with beta*C\n"
" #ifdef HAVE_ALPHA_BETA\n"
" COMPUTE_FLOATN xval=c_value;\n"
" COMPUTE_FLOATN yval=CONVERT_COMPUTE_FLOATN(cgn[index]);\n"
" #if VWN == 1\n"
" AXPBY(result,alpha,xval,beta,yval);\n"
" #elif VWN == 2\n"
" AXPBY(result.x,alpha,xval.x,beta,yval.x);\n"
" AXPBY(result.y,alpha,xval.y,beta,yval.y);\n"
" #elif VWN == 4\n"
" AXPBY(result.x,alpha,xval.x,beta,yval.x);\n"
" AXPBY(result.y,alpha,xval.y,beta,yval.y);\n"
" AXPBY(result.z,alpha,xval.z,beta,yval.z);\n"
" AXPBY(result.w,alpha,xval.w,beta,yval.w);\n"
" #elif VWN == 8\n"
" AXPBY(result.s0,alpha,xval.s0,beta,yval.s0);\n"
" AXPBY(result.s1,alpha,xval.s1,beta,yval.s1);\n"
" AXPBY(result.s2,alpha,xval.s2,beta,yval.s2);\n"
" AXPBY(result.s3,alpha,xval.s3,beta,yval.s3);\n"
" AXPBY(result.s4,alpha,xval.s4,beta,yval.s4);\n"
" AXPBY(result.s5,alpha,xval.s5,beta,yval.s5);\n"
" AXPBY(result.s6,alpha,xval.s6,beta,yval.s6);\n"
" AXPBY(result.s7,alpha,xval.s7,beta,yval.s7);\n"
" #elif VWN == 16\n"
" AXPBY(result.s0,alpha,xval.s0,beta,yval.s0);\n"
" AXPBY(result.s1,alpha,xval.s1,beta,yval.s1);\n"
" AXPBY(result.s2,alpha,xval.s2,beta,yval.s2);\n"
" AXPBY(result.s3,alpha,xval.s3,beta,yval.s3);\n"
" AXPBY(result.s4,alpha,xval.s4,beta,yval.s4);\n"
" AXPBY(result.s5,alpha,xval.s5,beta,yval.s5);\n"
" AXPBY(result.s6,alpha,xval.s6,beta,yval.s6);\n"
" AXPBY(result.s7,alpha,xval.s7,beta,yval.s7);\n"
" AXPBY(result.s8,alpha,xval.s8,beta,yval.s8);\n"
" AXPBY(result.s9,alpha,xval.s9,beta,yval.s9);\n"
" AXPBY(result.sA,alpha,xval.sA,beta,yval.sA);\n"
" AXPBY(result.sB,alpha,xval.sB,beta,yval.sB);\n"
" AXPBY(result.sC,alpha,xval.sC,beta,yval.sC);\n"
" AXPBY(result.sD,alpha,xval.sD,beta,yval.sD);\n"
" AXPBY(result.sE,alpha,xval.sE,beta,yval.sE);\n"
" AXPBY(result.sF,alpha,xval.sF,beta,yval.sF);\n"
" #endif\n"
" #endif\n"
" \n"
" \n"
"#if BIAS_TYPE>0\n"
" #if BIAS_TYPE == 1\n"
" COMPUTE_FLOATN eval=CONVERT_COMPUTE_FLOATN(epm[_ni]);\n"
" #elif BIAS_TYPE == 5\n"
" int index_bias=idm*(dstride/VWN)+idn;\n"
" intN eval=((__global intN*)egm)[index_bias];\n"
" #else\n"
" int index_bias=idm*(dstride/VWN)+idn;\n"
" COMPUTE_FLOATN eval=CONVERT_COMPUTE_FLOATN(egm[index_bias]);\n"
" #endif\n"
" \n"
" #if VWN == 1\n"
" DEAL_BIAS(result,eval);\n"
" #ifdef RELU\n"
" result=fmax(result,(COMPUTE_FLOATN)0);\n"
" #endif\n"
" #ifdef RELU6\n"
" result=clamp(result,(COMPUTE_FLOATN)0,(COMPUTE_FLOATN)6);\n"
" #endif\n"
" #elif VWN == 2\n"
" DEAL_BIAS(result.x,eval.x);\n"
" DEAL_BIAS(result.y,eval.y);\n"
" #ifdef RELU\n"
" result=fmax(result,(COMPUTE_FLOATN)0);\n"
" #endif\n"
" #ifdef RELU6\n"
" result=clamp(result,(COMPUTE_FLOATN)0,(COMPUTE_FLOATN)6);\n"
" #endif\n"
" #elif VWN == 4\n"
" DEAL_BIAS(result.x,eval.x);\n"
" DEAL_BIAS(result.y,eval.y);\n"
" DEAL_BIAS(result.z,eval.z);\n"
" DEAL_BIAS(result.w,eval.w);\n"
" #ifdef RELU\n"
" result=fmax(result,(COMPUTE_FLOATN)0);\n"
" #endif\n"
" #ifdef RELU6\n"
" result=clamp(result,(COMPUTE_FLOATN)0,(COMPUTE_FLOATN)6);\n"
" #endif\n"
" #elif VWN == 8\n"
" DEAL_BIAS(result.s0,eval.s0);\n"
" DEAL_BIAS(result.s1,eval.s1);\n"
" DEAL_BIAS(result.s2,eval.s2);\n"
" DEAL_BIAS(result.s3,eval.s3);\n"
" DEAL_BIAS(result.s4,eval.s4);\n"
" DEAL_BIAS(result.s5,eval.s5);\n"
" DEAL_BIAS(result.s6,eval.s6);\n"
" DEAL_BIAS(result.s7,eval.s7);\n"
" #ifdef RELU\n"
" result=fmax(result,(COMPUTE_FLOATN)0);\n"
" #endif\n"
" #ifdef RELU6\n"
" result=clamp(result,(COMPUTE_FLOATN)0,(COMPUTE_FLOATN)6);\n"
" #endif\n"
" #elif VWN == 16\n"
" DEAL_BIAS(result.s0,eval.s0);\n"
" DEAL_BIAS(result.s1,eval.s1);\n"
" DEAL_BIAS(result.s2,eval.s2);\n"
" DEAL_BIAS(result.s3,eval.s3);\n"
" DEAL_BIAS(result.s4,eval.s4);\n"
" DEAL_BIAS(result.s5,eval.s5);\n"
" DEAL_BIAS(result.s6,eval.s6);\n"
" DEAL_BIAS(result.s7,eval.s7);\n"
" DEAL_BIAS(result.s8,eval.s8);\n"
" DEAL_BIAS(result.s9,eval.s9);\n"
" DEAL_BIAS(result.sA,eval.sA);\n"
" DEAL_BIAS(result.sB,eval.sB);\n"
" DEAL_BIAS(result.sC,eval.sC);\n"
" DEAL_BIAS(result.sD,eval.sD);\n"
" DEAL_BIAS(result.sE,eval.sE);\n"
" DEAL_BIAS(result.sF,eval.sF);\n"
" #ifdef RELU\n"
" result=fmax(result,(COMPUTE_FLOATN)0);\n"
" #endif\n"
" #ifdef RELU6\n"
" result=clamp(result,(COMPUTE_FLOATN)0,(COMPUTE_FLOATN)6);\n"
" #endif\n"
" #endif\n"
"#endif\n"
" cgn[index]=CONVERT_FLOATN(result);\n"
"}\n"
"// Main body of the matrix-multiplication algorithm. It calls various (inlined) functions.\n"
"INLINE_FUNC void XgemmBody(const int kSizeM,const int kSizeN,const int kSizeK,const int4 stride,\n"
" const __global realM* restrict agm,const __global realN* restrict bgm,\n"
" #if BIAS_TYPE>0\n"
" __global realN* restrict egm,\n"
" #endif\n"
" __global realM* cgm,const real_arg alpha,const real_arg beta\n"
" #if SA == 1 && SB == 1\n"
" ,LOCAL_PTR realM* alm,LOCAL_PTR realN* blm\n"
" #elif SA == 1\n"
" ,LOCAL_PTR realM* alm\n"
" #elif SB == 1\n"
" ,LOCAL_PTR realN* blm\n"
" #endif\n"
" ) {\n"
" #ifdef OUTPUTMN\n"
" #pragma promote_to_registers\n"
" COMPUTE_FLOATN cpn[MWI*(NWI/VWN)]; // MWI*NWI\n"
" #else\n"
" #pragma promote_to_registers\n"
" COMPUTE_FLOATM cpm[NWI*(MWI/VWM)]; // NWI*MWI\n"
" #endif\n"
" // Combined thread identifier (volatile to disable caching)\n"
" #if SA == 1 || SB == 1\n"
" volatile int tid=get_local_id(0)+MDIMC*get_local_id(1);\n"
" #endif\n"
" // Initializes the accumulation registers\n"
" #ifdef OUTPUTMN\n"
" #pragma unroll\n"
" for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n"
" #pragma unroll\n"
" for (int _mi=0; _mi<MWI; _mi += 1) {\n"
" cpn[_mi*(NWI/VWN)+_ni]=InitAccRegistersN();\n"
" }\n"
" }\n"
" #else\n"
" #pragma unroll\n"
" for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n"
" #pragma unroll\n"
" for (int _ni=0; _ni<NWI; _ni += 1) {\n"
" cpm[_ni*(MWI/VWM)+_mi]=InitAccRegisters();\n"
" }\n"
" }\n"
" #endif\n"
" // Loops over all workgroup tiles\n"
" #if SA == 1 || SB == 1\n"
" // Allocates workitem-private memory (registers)\n"
" #pragma promote_to_registers\n"
" COMPUTE_FLOATM apm[MWI/VWM]; // MWI*1\n"
" #pragma promote_to_registers\n"
" COMPUTE_FLOATN bpm[NWI/VWN]; // 1*NWI\n"
" \n"
" for (int kwg=0; kwg<kSizeK; kwg += KWG) {\n"
" // Loads data: off-chip --> local (matrix A)\n"
" #if SA == 1\n"
" GlobalToLocalA(agm,alm,kSizeM,tid,kwg);\n"
" #endif\n"
" // Loads data: off-chip --> local (matrix B)\n"
" #if SB == 1\n"
" GlobalToLocalB(bgm,blm,kSizeN,tid,kwg);\n"
" #endif\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" // Loops over all workitem tiles,unrolled by a factor KWI\n"
" for (int pwi=0; pwi<KWG; pwi += KWI) {\n"
" #pragma unroll\n"
" for (int _pit=0; _pit<KWI; _pit += 1) {\n"
" #if SA == 0 || SB == 0\n"
" int idk=kwg+pwi+_pit;\n"
" #endif\n"
" int kg=pwi+_pit;\n"
" // Loads matrix A (kernel 0) or matrix B (kernel 1)\n"
" #pragma unroll\n"
" for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n"
" // Loads data: local --> private (matrix A)\n"
" #if SA == 1\n"
" apm[_mi]=CONVERT_COMPUTE_FLOATM(LocalToPrivateA(alm,_mi,kg));\n"
" // Loads data: off-chip --> private (matrix A)\n"
" #elif SA == 0\n"
" apm[_mi]=CONVERT_COMPUTE_FLOATM(GlobalToPrivateA(agm,_mi,kSizeM,idk));\n"
" #endif\n"
" }\n"
" // Loads matrix B (kernel 0) or matrix A (kernel 1)\n"
" #pragma unroll\n"
" for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n"
" // Loads data: local --> private (matrix B)\n"
" #if SB == 1\n"
" bpm[_ni]=CONVERT_COMPUTE_FLOATN(LocalToPrivateB(blm,_ni,kg));\n"
" // Loads data: off-chip --> private (matrix B)\n"
" #else\n"
" bpm[_ni]=CONVERT_COMPUTE_FLOATN(GlobalToPrivateB(bgm,_ni,kSizeN,idk));\n"
" #endif\n"
" }\n"
" // Performs the accumulation (Cpm += Apm*Bpm)\n"
" #ifdef OUTPUTMN\n"
" #pragma unroll\n"
" for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n"
" #pragma unroll\n"
" for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n"
" const COMPUTE_FLOATM aval=apm[_mi];\n"
" #if VWM == 1\n"
" // [MWI/VWM,VWM,NWI/VWN,VWN]\n"
" cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval,bpm[_ni]);\n"
" #elif VWM == 2\n"
" cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval.x,bpm[_ni]);\n"
" cpn[(_mi*VWM+1)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1)*(NWI/VWN)+_ni],aval.y,bpm[_ni]);\n"
" #elif VWM == 4\n"
" cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval.x,bpm[_ni]);\n"
" cpn[(_mi*VWM+1)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1)*(NWI/VWN)+_ni],aval.y,bpm[_ni]);\n"
" cpn[(_mi*VWM+2)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+2)*(NWI/VWN)+_ni],aval.z,bpm[_ni]);\n"
" cpn[(_mi*VWM+3)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+3)*(NWI/VWN)+_ni],aval.w,bpm[_ni]);\n"
" #elif VWM == 8\n"
" cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval.s0,bpm[_ni]);\n"
" cpn[(_mi*VWM+1)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1)*(NWI/VWN)+_ni],aval.s1,bpm[_ni]);\n"
" cpn[(_mi*VWM+2)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+2)*(NWI/VWN)+_ni],aval.s2,bpm[_ni]);\n"
" cpn[(_mi*VWM+3)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+3)*(NWI/VWN)+_ni],aval.s3,bpm[_ni]);\n"
" cpn[(_mi*VWM+4)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+4)*(NWI/VWN)+_ni],aval.s4,bpm[_ni]);\n"
" cpn[(_mi*VWM+5)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+5)*(NWI/VWN)+_ni],aval.s5,bpm[_ni]);\n"
" cpn[(_mi*VWM+6)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+6)*(NWI/VWN)+_ni],aval.s6,bpm[_ni]);\n"
" cpn[(_mi*VWM+7)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+7)*(NWI/VWN)+_ni],aval.s7,bpm[_ni]);\n"
" #elif VWM == 16\n"
" cpn[(_mi*VWM+0 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0 )*(NWI/VWN)+_ni],aval.s0,bpm[_ni]);\n"
" cpn[(_mi*VWM+1 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1 )*(NWI/VWN)+_ni],aval.s1,bpm[_ni]);\n"
" cpn[(_mi*VWM+2 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+2 )*(NWI/VWN)+_ni],aval.s2,bpm[_ni]);\n"
" cpn[(_mi*VWM+3 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+3 )*(NWI/VWN)+_ni],aval.s3,bpm[_ni]);\n"
" cpn[(_mi*VWM+4 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+4 )*(NWI/VWN)+_ni],aval.s4,bpm[_ni]);\n"
" cpn[(_mi*VWM+5 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+5 )*(NWI/VWN)+_ni],aval.s5,bpm[_ni]);\n"
" cpn[(_mi*VWM+6 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+6 )*(NWI/VWN)+_ni],aval.s6,bpm[_ni]);\n"
" cpn[(_mi*VWM+7 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+7 )*(NWI/VWN)+_ni],aval.s7,bpm[_ni]);\n"
" cpn[(_mi*VWM+8 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+8 )*(NWI/VWN)+_ni],aval.s8,bpm[_ni]);\n"
" cpn[(_mi*VWM+9 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+9 )*(NWI/VWN)+_ni],aval.s9,bpm[_ni]);\n"
" cpn[(_mi*VWM+10)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+10)*(NWI/VWN)+_ni],aval.sA,bpm[_ni]);\n"
" cpn[(_mi*VWM+11)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+11)*(NWI/VWN)+_ni],aval.sB,bpm[_ni]);\n"
" cpn[(_mi*VWM+12)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+12)*(NWI/VWN)+_ni],aval.sC,bpm[_ni]);\n"
" cpn[(_mi*VWM+13)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+13)*(NWI/VWN)+_ni],aval.sD,bpm[_ni]);\n"
" cpn[(_mi*VWM+14)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+14)*(NWI/VWN)+_ni],aval.sE,bpm[_ni]);\n"
" cpn[(_mi*VWM+15)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+15)*(NWI/VWN)+_ni],aval.sF,bpm[_ni]);\n"
" #endif\n"
" }\n"
" }\n"
" #else\n"
" #pragma unroll\n"
" for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n"
" #pragma unroll\n"
" for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n"
" const COMPUTE_FLOATM aval=apm[_mi];\n"
" #if VWN == 1\n"
" cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bpm[_ni]);\n"
" #elif VWN == 2\n"
" cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bpm[_ni].x);\n"
" cpm[(_ni*VWN+1)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1)*(MWI/VWM)+_mi],aval,bpm[_ni].y);\n"
" #elif VWN == 4\n"
" cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bpm[_ni].x);\n"
" cpm[(_ni*VWN+1)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1)*(MWI/VWM)+_mi],aval,bpm[_ni].y);\n"
" cpm[(_ni*VWN+2)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+2)*(MWI/VWM)+_mi],aval,bpm[_ni].z);\n"
" cpm[(_ni*VWN+3)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+3)*(MWI/VWM)+_mi],aval,bpm[_ni].w);\n"
" #elif VWN == 8\n"
" cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bpm[_ni].s0);\n"
" cpm[(_ni*VWN+1)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1)*(MWI/VWM)+_mi],aval,bpm[_ni].s1);\n"
" cpm[(_ni*VWN+2)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+2)*(MWI/VWM)+_mi],aval,bpm[_ni].s2);\n"
" cpm[(_ni*VWN+3)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+3)*(MWI/VWM)+_mi],aval,bpm[_ni].s3);\n"
" cpm[(_ni*VWN+4)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+4)*(MWI/VWM)+_mi],aval,bpm[_ni].s4);\n"
" cpm[(_ni*VWN+5)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+5)*(MWI/VWM)+_mi],aval,bpm[_ni].s5);\n"
" cpm[(_ni*VWN+6)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+6)*(MWI/VWM)+_mi],aval,bpm[_ni].s6);\n"
" cpm[(_ni*VWN+7)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+7)*(MWI/VWM)+_mi],aval,bpm[_ni].s7);\n"
" #elif VWN == 16\n"
" cpm[(_ni*VWN+0 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0 )*(MWI/VWM)+_mi],aval,bpm[_ni].s0);\n"
" cpm[(_ni*VWN+1 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1 )*(MWI/VWM)+_mi],aval,bpm[_ni].s1);\n"
" cpm[(_ni*VWN+2 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+2 )*(MWI/VWM)+_mi],aval,bpm[_ni].s2);\n"
" cpm[(_ni*VWN+3 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+3 )*(MWI/VWM)+_mi],aval,bpm[_ni].s3);\n"
" cpm[(_ni*VWN+4 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+4 )*(MWI/VWM)+_mi],aval,bpm[_ni].s4);\n"
" cpm[(_ni*VWN+5 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+5 )*(MWI/VWM)+_mi],aval,bpm[_ni].s5);\n"
" cpm[(_ni*VWN+6 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+6 )*(MWI/VWM)+_mi],aval,bpm[_ni].s6);\n"
" cpm[(_ni*VWN+7 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+7 )*(MWI/VWM)+_mi],aval,bpm[_ni].s7);\n"
" cpm[(_ni*VWN+8 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+8 )*(MWI/VWM)+_mi],aval,bpm[_ni].s8);\n"
" cpm[(_ni*VWN+9 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+9 )*(MWI/VWM)+_mi],aval,bpm[_ni].s9);\n"
" cpm[(_ni*VWN+10)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+10)*(MWI/VWM)+_mi],aval,bpm[_ni].sA);\n"
" cpm[(_ni*VWN+11)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+11)*(MWI/VWM)+_mi],aval,bpm[_ni].sB);\n"
" cpm[(_ni*VWN+12)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+12)*(MWI/VWM)+_mi],aval,bpm[_ni].sC);\n"
" cpm[(_ni*VWN+13)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+13)*(MWI/VWM)+_mi],aval,bpm[_ni].sD);\n"
" cpm[(_ni*VWN+14)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+14)*(MWI/VWM)+_mi],aval,bpm[_ni].sE);\n"
" cpm[(_ni*VWN+15)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+15)*(MWI/VWM)+_mi],aval,bpm[_ni].sF);\n"
" #endif\n"
" }\n"
" }\n"
" #endif\n"
" }\n"
" }\n"
" barrier(CLK_LOCAL_MEM_FENCE);\n"
" }\n"
" #else\n"
" // Allocates workitem-private memory (registers)\n"
" int baseIndexA=GlobalIndexA();\n"
" int baseIndexB=GlobalIndexB();\n"
" #pragma unroll\n"
" for (int _kj=0; _kj<kSizeK; _kj += 4) {\n"
" #ifdef OUTPUTMN\n"
" #pragma promote_to_registers\n"
" COMPUTE_FLOATN bpm[NWI/VWN]; // 1*NWI\n"
" \n"
" #pragma unroll\n"
" for(int _ki=0; _ki<4; _ki += 1) {\n"
" int idk=_kj+_ki;\n"
" #pragma unroll\n"
" for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n"
" // Loads data: off-chip --> private (matrix B)\n"
" bpm[_ni]=CONVERT_COMPUTE_FLOATN(GlobalToPrivateOptB(bgm,baseIndexB,_ni,stride.s1/*kSizeN*/,idk));\n"
" }\n"
" #pragma unroll\n"
" for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n"
" const COMPUTE_FLOATM aval=CONVERT_COMPUTE_FLOATM(GlobalToPrivateOptA(agm,baseIndexA,_mi,stride.s0/*kSizeM*/,idk));\n"
" #pragma unroll\n"
" for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n"
" #if VWM == 1\n"
" // [MWI/VWM,VWM,NWI/VWN,VWN]\n"
" cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval,bpm[_ni]);\n"
" #elif VWM == 2\n"
" cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval.x,bpm[_ni]);\n"
" cpn[(_mi*VWM+1)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1)*(NWI/VWN)+_ni],aval.y,bpm[_ni]);\n"
" #elif VWM == 4\n"
" cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval.x,bpm[_ni]);\n"
" cpn[(_mi*VWM+1)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1)*(NWI/VWN)+_ni],aval.y,bpm[_ni]);\n"
" cpn[(_mi*VWM+2)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+2)*(NWI/VWN)+_ni],aval.z,bpm[_ni]);\n"
" cpn[(_mi*VWM+3)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+3)*(NWI/VWN)+_ni],aval.w,bpm[_ni]);\n"
" #elif VWM == 8\n"
" cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval.s0,bpm[_ni]);\n"
" cpn[(_mi*VWM+1)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1)*(NWI/VWN)+_ni],aval.s1,bpm[_ni]);\n"
" cpn[(_mi*VWM+2)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+2)*(NWI/VWN)+_ni],aval.s2,bpm[_ni]);\n"
" cpn[(_mi*VWM+3)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+3)*(NWI/VWN)+_ni],aval.s3,bpm[_ni]);\n"
" cpn[(_mi*VWM+4)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+4)*(NWI/VWN)+_ni],aval.s4,bpm[_ni]);\n"
" cpn[(_mi*VWM+5)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+5)*(NWI/VWN)+_ni],aval.s5,bpm[_ni]);\n"
" cpn[(_mi*VWM+6)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+6)*(NWI/VWN)+_ni],aval.s6,bpm[_ni]);\n"
" cpn[(_mi*VWM+7)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+7)*(NWI/VWN)+_ni],aval.s7,bpm[_ni]);\n"
" #elif VWM == 16\n"
" cpn[(_mi*VWM+0 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0 )*(NWI/VWN)+_ni],aval.s0,bpm[_ni]);\n"
" cpn[(_mi*VWM+1 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1 )*(NWI/VWN)+_ni],aval.s1,bpm[_ni]);\n"
" cpn[(_mi*VWM+2 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+2 )*(NWI/VWN)+_ni],aval.s2,bpm[_ni]);\n"
" cpn[(_mi*VWM+3 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+3 )*(NWI/VWN)+_ni],aval.s3,bpm[_ni]);\n"
" cpn[(_mi*VWM+4 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+4 )*(NWI/VWN)+_ni],aval.s4,bpm[_ni]);\n"
" cpn[(_mi*VWM+5 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+5 )*(NWI/VWN)+_ni],aval.s5,bpm[_ni]);\n"
" cpn[(_mi*VWM+6 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+6 )*(NWI/VWN)+_ni],aval.s6,bpm[_ni]);\n"
" cpn[(_mi*VWM+7 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+7 )*(NWI/VWN)+_ni],aval.s7,bpm[_ni]);\n"
" cpn[(_mi*VWM+8 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+8 )*(NWI/VWN)+_ni],aval.s8,bpm[_ni]);\n"
" cpn[(_mi*VWM+9 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+9 )*(NWI/VWN)+_ni],aval.s9,bpm[_ni]);\n"
" cpn[(_mi*VWM+10)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+10)*(NWI/VWN)+_ni],aval.sA,bpm[_ni]);\n"
" cpn[(_mi*VWM+11)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+11)*(NWI/VWN)+_ni],aval.sB,bpm[_ni]);\n"
" cpn[(_mi*VWM+12)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+12)*(NWI/VWN)+_ni],aval.sC,bpm[_ni]);\n"
" cpn[(_mi*VWM+13)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+13)*(NWI/VWN)+_ni],aval.sD,bpm[_ni]);\n"
" cpn[(_mi*VWM+14)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+14)*(NWI/VWN)+_ni],aval.sE,bpm[_ni]);\n"
" cpn[(_mi*VWM+15)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+15)*(NWI/VWN)+_ni],aval.sF,bpm[_ni]);\n"
" #endif\n"
" }\n"
" }\n"
" }\n"
" #else\n"
" \n"
" #pragma promote_to_registers\n"
" COMPUTE_FLOATM apm[MWI/VWM]; // MWI*1\n"
" #pragma unroll\n"
" for(int _ki=0; _ki<4; _ki += 1) {\n"
" int idk=_kj+_ki;\n"
" #pragma unroll\n"
" for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n"
" // Loads data: off-chip --> private (matrix B)\n"
" apm[_mi]=CONVERT_COMPUTE_FLOATM(GlobalToPrivateOptA(agm,baseIndexA,_mi,stride.s0/*kSizeM*/,idk));\n"
" }\n"
" #pragma unroll\n"
" for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n"
" const COMPUTE_FLOATN bval=CONVERT_COMPUTE_FLOATN(GlobalToPrivateOptB(bgm,baseIndexB,_ni,stride.s1/*kSizeN*/,idk));\n"
" #pragma unroll\n"
" for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n"
" const COMPUTE_FLOATM aval=apm[_mi];\n"
" #if VWN == 1\n"
" cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bval);\n"
" #elif VWN == 2\n"
" cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bval.x);\n"
" cpm[(_ni*VWN+1)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1)*(MWI/VWM)+_mi],aval,bval.y);\n"
" #elif VWN == 4\n"
" cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bval.x);\n"
" cpm[(_ni*VWN+1)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1)*(MWI/VWM)+_mi],aval,bval.y);\n"
" cpm[(_ni*VWN+2)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+2)*(MWI/VWM)+_mi],aval,bval.z);\n"
" cpm[(_ni*VWN+3)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+3)*(MWI/VWM)+_mi],aval,bval.w);\n"
" #elif VWN == 8\n"
" cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bval.s0);\n"
" cpm[(_ni*VWN+1)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1)*(MWI/VWM)+_mi],aval,bval.s1);\n"
" cpm[(_ni*VWN+2)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+2)*(MWI/VWM)+_mi],aval,bval.s2);\n"
" cpm[(_ni*VWN+3)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+3)*(MWI/VWM)+_mi],aval,bval.s3);\n"
" cpm[(_ni*VWN+4)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+4)*(MWI/VWM)+_mi],aval,bval.s4);\n"
" cpm[(_ni*VWN+5)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+5)*(MWI/VWM)+_mi],aval,bval.s5);\n"
" cpm[(_ni*VWN+6)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+6)*(MWI/VWM)+_mi],aval,bval.s6);\n"
" cpm[(_ni*VWN+7)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+7)*(MWI/VWM)+_mi],aval,bval.s7);\n"
" #elif VWN == 16\n"
" cpm[(_ni*VWN+0 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0 )*(MWI/VWM)+_mi],aval,bval.s0);\n"
" cpm[(_ni*VWN+1 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1 )*(MWI/VWM)+_mi],aval,bval.s1);\n"
" cpm[(_ni*VWN+2 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+2 )*(MWI/VWM)+_mi],aval,bval.s2);\n"
" cpm[(_ni*VWN+3 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+3 )*(MWI/VWM)+_mi],aval,bval.s3);\n"
" cpm[(_ni*VWN+4 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+4 )*(MWI/VWM)+_mi],aval,bval.s4);\n"
" cpm[(_ni*VWN+5 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+5 )*(MWI/VWM)+_mi],aval,bval.s5);\n"
" cpm[(_ni*VWN+6 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+6 )*(MWI/VWM)+_mi],aval,bval.s6);\n"
" cpm[(_ni*VWN+7 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+7 )*(MWI/VWM)+_mi],aval,bval.s7);\n"
" cpm[(_ni*VWN+8 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+8 )*(MWI/VWM)+_mi],aval,bval.s8);\n"
" cpm[(_ni*VWN+9 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+9 )*(MWI/VWM)+_mi],aval,bval.s9);\n"
" cpm[(_ni*VWN+10)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+10)*(MWI/VWM)+_mi],aval,bval.sA);\n"
" cpm[(_ni*VWN+11)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+11)*(MWI/VWM)+_mi],aval,bval.sB);\n"
" cpm[(_ni*VWN+12)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+12)*(MWI/VWM)+_mi],aval,bval.sC);\n"
" cpm[(_ni*VWN+13)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+13)*(MWI/VWM)+_mi],aval,bval.sD);\n"
" cpm[(_ni*VWN+14)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+14)*(MWI/VWM)+_mi],aval,bval.sE);\n"
" cpm[(_ni*VWN+15)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+15)*(MWI/VWM)+_mi],aval,bval.sF);\n"
" #endif\n"
" }\n"
" }\n"
" }\n"
" #endif\n"
" }\n"
" #endif\n"
" \n"
" #if GLOBAL_MEM_FENCE == 1\n"
" barrier(CLK_GLOBAL_MEM_FENCE);\n"
" #endif\n"
" #ifdef OUTPUTMN\n"
" INT2 baseOffset=StoreIndexN();\n"
" #if BIAS_TYPE == 1\n"
" #pragma promote_to_registers\n"
" realN epm[NWI/VWN]; // MWI*1\n"
" for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n"
" #if STRN == 0\n"
" int idn=_ni+baseOffset.index[1];\n"
" #elif STRN == 1\n"
" int idn=baseOffset.index[1]+_ni*NDIMC;\n"
" #endif\n"
" epm[_ni]=egm[idn];\n"
" }\n"
" #endif\n"
" \n"
" \n"
" \n"
" #pragma unroll\n"
" for (int _mi=0; _mi<MWI; _mi += 1) {\n"
" #pragma unroll\n"
" for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n"
" StoreResultsN((__global realN* )cgm,cpn[_mi*(NWI/VWN)+_ni],\n"
" baseOffset,\n"
" #if BIAS_TYPE>1\n"
" (__global realN*)egm,\n"
" #elif BIAS_TYPE == 1\n"
" (realN*)epm,\n"
" #endif\n"
" _mi,_ni,stride.s2,stride.s3,alpha,beta);\n"
" }\n"
" }\n"
" \n"
" #else\n"
" INT2 baseOffset=StoreIndexM();\n"
" // Stores an MWG*NWG tile of results and performs the multiplication with alpha and beta\n"
" const int cld=kSizeM;\n"
" \n"
" #pragma unroll\n"
" for (int _ni=0; _ni<NWI; _ni += 1) {\n"
" #pragma unroll\n"
" for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n"
" StoreResultsM(cgm,cpm[_ni*(MWI/VWM)+_mi],baseOffset,_mi,_ni,stride.s2,alpha,beta);\n"
" }\n"
" }\n"
" #endif\n"
"}\n"
"// Main entry point of the kernel. This is the regular full version.\n"
"#if RELAX_WORKGROUP_SIZE == 1\n"
" __kernel\n"
"#else\n"
" __kernel __attribute__((reqd_work_group_size(MDIMC,NDIMC,1)))\n"
"#endif\n"
"void Xgemm(const int kSizeM,const int kSizeN,const int kSizeK,\n"
" const real_arg arg_alpha,\n"
" const real_arg arg_beta,\n"
" const __global realM* restrict agm,// [K,M]\n"
" const __global realN* restrict bgm,// [K,N]\n"
" #if BIAS_TYPE>0\n"
" __global realN* restrict egm,// [N]\n"
" #endif\n"
" __global realM* cgm,\n"
" __private const int4 offset,\n"
" __private const int4 stride\n"
") {\n"
" \n"
" // Adds the offsets (in case of use of a single temporary buffer for A,B,and C)\n"
" agm=(const __global realM*)((const __global real*)agm+offset.s0);\n"
" bgm=(const __global realN*)((const __global real*)bgm+offset.s1);\n"
" cgm=(__global realM*)((__global real*)cgm+offset.s2);\n"
" \n"
" #if BIAS_TYPE>0\n"
" egm=(__global realN*)((__global real*)egm+offset.s3);\n"
" #endif\n"
" // Allocates workgroup-private memory (local memory)\n"
" #if SA == 1\n"
" __local realM alm[KWG*MWG/VWM];\n"
" #endif\n"
" #if SB == 1\n"
" __local realN blm[KWG*NWG/VWN];\n"
" #endif\n"
" \n"
" // Computes the matrix-multiplication and stores the result in global memory\n"
" #if SA == 1 && SB == 1\n"
" XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm,bgm,\n"
" #if BIAS_TYPE>0\n"
" egm,\n"
" #endif\n"
" cgm,arg_alpha,arg_beta,alm,blm);\n"
" #elif SA == 1\n"
" XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm,bgm,\n"
" #if BIAS_TYPE>0\n"
" egm,\n"
" #endif\n"
" cgm,arg_alpha,arg_beta,alm);\n"
" #elif SB == 1\n"
" XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm,bgm,\n"
" #if BIAS_TYPE>0\n"
" egm,\n"
" #endif\n"
" cgm,arg_alpha,arg_beta,blm);\n"
" #else\n"
" XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm,bgm,\n"
" #if BIAS_TYPE>0\n"
" egm,\n"
" #endif\n"
" cgm,arg_alpha,arg_beta);\n"
" #endif\n"
"}\n"
"#if RELAX_WORKGROUP_SIZE == 1\n"
" __kernel\n"
"#else\n"
" __kernel __attribute__((reqd_work_group_size(MDIMC,NDIMC,1)))\n"
"#endif\n"
"void XgemmBatched(const int kSizeM,\n"
" const int kSizeN,\n"
" const int kSizeK,\n"
" const real_arg arg_alpha,\n"
" const real_arg arg_beta,\n"
" const __global realM* restrict agm,\n"
" const __global realN* restrict bgm,\n"
" #if BIAS_TYPE>0\n"
" __global realN* restrict egm,\n"
" #endif\n"
" __global realM* cgm,\n"
" const int4 batch_offset,// [batch_offset_a,batch_offset_b,batch_offset_c,batch_offset_e]\n"
" const int4 base_ptr_offset,// [base_ptr_offset_a,base_ptr_offset_b,base_ptr_offset_c,base_ptr_offset_e]\n"
" const int4 stride,// [stride_a,stride_b,stride_c,stride_e]\n"
" /*\n"
" total_batch -> [loop_y,loop_x]\n"
" with group batch -> [loop_y,loop_x/group_num]\n"
" group_size == loop_x/group_num\n"
" */\n"
" const int4 group // [group_num_a,group_num_b,group_num_e,loop_x]\n"
") {\n"
" const int batch=get_group_id(2);\n"
" \n"
" // Sets the offsets\n"
" const int a_offset=base_ptr_offset.x+((batch/group.w)*group.x+(batch % group.w)/group.x)*batch_offset.x;\n"
" const int b_offset=base_ptr_offset.y+((batch/group.w)*group.y+(batch % group.w)/group.y)*batch_offset.y;\n"
" const int c_offset=base_ptr_offset.z+batch*batch_offset.z;\n"
" const __global realM* restrict agm_=&agm[a_offset/VWM];\n"
" const __global realN* restrict bgm_=&bgm[b_offset/VWN];\n"
" __global realM* restrict cgm_=&cgm[c_offset/VWM];\n"
" \n"
" #if BIAS_TYPE>0\n"
" const int e_offset=base_ptr_offset.w+((batch/group.w)*group.z+(batch % group.w)/group.z)*batch_offset.w;\n"
" __global realN* restrict egm_=&egm[e_offset/VWN];\n"
" #endif\n"
" \n"
" // Allocates workgroup-private memory (local memory)\n"
" #if SA == 1\n"
" __local realM alm[KWG*MWG/VWM];\n"
" #endif\n"
" #if SB == 1\n"
" __local realN blm[KWG*NWG/VWN];\n"
" #endif\n"
" // Computes the matrix-multiplication and stores the result in global memory\n"
" #if SA == 1 && SB == 1\n"
" XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm_,bgm_,\n"
" #if BIAS_TYPE>0\n"
" egm_,\n"
" #endif\n"
" cgm_,arg_alpha,arg_beta,alm,blm);\n"
" #elif SA == 1\n"
" XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm_,bgm_,\n"
" #if BIAS_TYPE>0\n"
" egm_,\n"
" #endif\n"
" cgm_,arg_alpha,arg_beta,alm);\n"
" #elif SB == 1\n"
" XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm_,bgm_,\n"
" #if BIAS_TYPE>0\n"
" egm_,\n"
" #endif\n"
" cgm_,arg_alpha,arg_beta,blm);\n"
" #else\n"
" XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm_,bgm_,\n"
" #if BIAS_TYPE>0\n"
" egm_,\n"
" #endif\n"
" cgm_,arg_alpha,arg_beta);\n"
" #endif\n"
"}\n"
;
#endif
}