source/backend/opencl/execution/cl/matmul_params_buf_mnn_cl.cpp
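// ---------------------------------------------------------------------------
// Overview (descriptive comment only, not emitted into the OpenCL program):
// matmul_params_buf holds a CLBlast-style tiled GEMM kernel as a C string that
// is compiled at runtime. The tuning macros guarded by #ifndef inside the
// string (MWG, NWG, KWG, MDIMC, NDIMC, MDIMA, NDIMB, KWI, VWM, VWN, STRM,
// STRN, SA, SB, ...) are normally overridden with -D build options; the
// defaults only keep the source self-contained. With those default values the
// derived helper macros resolve as follows:
//   MWI   = MWG/MDIMC           = 1   // work per work-item, M-dimension
//   NWI   = NWG/NDIMC           = 1   // work per work-item, N-dimension
//   KDIMA = (MDIMC*NDIMC)/MDIMA = 8   // re-shaped K tile dimension for A
//   KDIMB = (MDIMC*NDIMC)/NDIMB = 8   // re-shaped K tile dimension for B
//   MWA   = MWG/MDIMA           = 1   // loads per thread for A, M-dimension
//   KWA   = KWG/KDIMA           = 2   // loads per thread for A, K-dimension
//   KWB   = KWG/KDIMB           = 2   // loads per thread for B, K-dimension
//   NWB   = NWG/NDIMB           = 1   // loads per thread for B, N-dimension
// ---------------------------------------------------------------------------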

#include "opencl_source_map.hpp" namespace MNN { #ifndef MNN_OPENCL_BUFFER_CLOSED const char* matmul_params_buf = "#ifdef MNN_SUPPORT_FP16\n" "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" "#endif\n" "// =================================================================================================\n" "#define USE_INLINE_KEYWORD 1\n" "#ifndef MWG\n" " #define MWG 8 // Tile-size in dimension M (e.g. 64,128)\n" "#endif\n" "#ifndef NWG\n" " #define NWG 8 // Tile-size in dimension N (e.g. 64,128)\n" "#endif\n" "#ifndef KWG\n" " #define KWG 16 // Tile-size in dimension K (e.g. 8,16)\n" "#endif\n" "#ifndef MDIMC\n" " #define MDIMC 8 // Threads per workgroup in M-dimension (e.g. 8,16,32)\n" "#endif\n" "#ifndef NDIMC\n" " #define NDIMC 8 // Threads per workgroup in N-dimension (e.g. 8,16,32)\n" "#endif\n" "#ifndef MDIMA\n" " #define MDIMA 8 // Re-shaped tile dimension of matrix A: KDIMA*MDIMA (kernel 0 only)\n" "#endif\n" "#ifndef NDIMB\n" " #define NDIMB 8 // Re-shaped tile dimension of matrix B: KDIMB*NDIMB (kernel 0 only)\n" "#endif\n" "#ifndef KWI\n" " #define KWI 2 // Unroll factor of the KWG loop (smaller or equal than KWG)\n" "#endif\n" "#ifndef VWM\n" " #define VWM 1 // Vector width of matrices A and C\n" "#endif\n" "#ifndef VWN\n" " #define VWN 1 // Vector width of matrix B\n" "#endif\n" "#ifndef STRM\n" " #define STRM 0 // Use strided access within a thread in the M-dimension (1) or not (0) (kernel 0 only)\n" "#endif\n" "#ifndef STRN\n" " #define STRN 0 // Use strided access within a thread in the N-dimension (1) or not (0) (kernel 0 only)\n" "#endif\n" "#ifndef SA\n" " #define SA 0 // Use local/shared memory to cache matrix A (1) or not (0) (kernel 0 only)\n" "#endif\n" "#ifndef SB\n" " #define SB 0 // Use local/shared memory to cache matrix B (1) or not (0) (kernel 0 only)\n" "#endif\n" "// Helper parameters based on the above tuning parameters\n" "#define MWI (MWG/MDIMC) // Work per work-item (M-dimension)\n" "#define NWI (NWG/NDIMC) // Work per work-item (N-dimension)\n" "#define KDIMA ((MDIMC*NDIMC)/(MDIMA)) // Re-shaped tile dimension of matrix A: KDIMA*MDIMA\n" "#define KDIMB ((MDIMC*NDIMC)/(NDIMB)) // Re-shaped tile dimension of matrix B: KDIMB*NDIMB\n" "#define MWA (MWG/MDIMA) // Amount of loads-per-thread for matrix A (M-dimension)\n" "#define KWA (KWG/KDIMA) // Amount of loads-per-thread for matrix A (K-dimension)\n" "#define KWB (KWG/KDIMB) // Amount of loads-per-thread for matrix B (K-dimension)\n" "#define NWB (NWG/NDIMB) // Amount of loads-per-thread for matrix B (N-dimension)\n" "// Settings\n" "#ifndef USE_VECTOR_MAD\n" " #define USE_VECTOR_MAD 0 // Unroll (0) or don't (1) unroll the vector MAD manually\n" "#endif\n" "#ifndef GLOBAL_MEM_FENCE\n" " #define GLOBAL_MEM_FENCE 0 // Global synchronisation barrier for potential better performance\n" "#endif\n" "// Pointers to local memory objects (using a define because CUDA doesn't need them)\n" "#ifndef LOCAL_PTR\n" " #define LOCAL_PTR __local\n" "#endif\n" "// Don't use the non-IEEE754 compliant OpenCL built-in mad() instruction per default. 
For specific\n" "// devices,this is enabled (see src/routine.cpp).\n" "#ifndef USE_CL_MAD\n" " #define USE_CL_MAD 0\n" "#endif\n" "// BIAS_TYPE\n" "// 0 -> without bias\n" "// 1 -> with bias (add) [N]\n" "// 2 -> with bias (eltwise_add) [M,N]\n" "// 3 -> with bias (eltwise_sub) [M,N]\n" "// 4 -> with bias (eltwise_sub and get negative) [M,N]\n" "// 5 -> with bias (mask 0 for invalid) [M,N]\n" "#ifndef BIAS_TYPE\n" " #define BIAS_TYPE 0\n" "#endif\n" "#if BIAS_TYPE == 1\n" "#define DEAL_BIAS(x,a) x=x+a\n" "#elif BIAS_TYPE == 2\n" "#define DEAL_BIAS(x,a) x=x+a\n" "#elif BIAS_TYPE == 3\n" "#define DEAL_BIAS(x,a) x=x-a\n" "#elif BIAS_TYPE == 4\n" "#define DEAL_BIAS(x,a) x=a-x\n" "#elif BIAS_TYPE == 5\n" "#define DEAL_BIAS(x,a) x=(a == 0 ? (FLOAT)(-FLT_MAX) : x)\n" "#endif\n" "// By default the workgroup size requirement is enabled. For Qualcomm devices the workgroup size\n" "// requirement results in worse performance and is disabled (src/utilities/compile.cpp)\n" "#ifndef RELAX_WORKGROUP_SIZE\n" " #define RELAX_WORKGROUP_SIZE 0\n" "#endif\n" "typedef float real_arg;\n" "#define GetRealArg(x) (FLOAT)x\n" "typedef FLOAT real;\n" "#ifndef PRECISION_COMPUTE\n" "#define PRECISION_COMPUTE COMPUTE_FLOAT\n" "#define CONVERT_PRECISION_COMPUTE(x) CONVERT_COMPUTE_FLOAT(x)\n" "#endif\n" "#ifndef PRECISION_COMPUTE2\n" "#define PRECISION_COMPUTE2 COMPUTE_FLOAT2\n" "#define CONVERT_PRECISION_COMPUTE2(x) CONVERT_COMPUTE_FLOAT2(x)\n" "#endif\n" "#ifndef PRECISION_COMPUTE4\n" "#define PRECISION_COMPUTE4 COMPUTE_FLOAT4\n" "#define CONVERT_PRECISION_COMPUTE4(x) CONVERT_COMPUTE_FLOAT4(x)\n" "#endif\n" "#ifndef PRECISION_COMPUTE8\n" "#define PRECISION_COMPUTE8 COMPUTE_FLOAT8\n" "#define CONVERT_PRECISION_COMPUTE8(x) CONVERT_COMPUTE_FLOAT8(x)\n" "#endif\n" "#ifndef PRECISION_COMPUTE16\n" "#define PRECISION_COMPUTE16 COMPUTE_FLOAT16\n" "#define CONVERT_PRECISION_COMPUTE16(x) CONVERT_COMPUTE_FLOAT16(x)\n" "#endif\n" "#define ZERO (PRECISION_COMPUTE)0.0f\n" "// Sets a variable to zero\n" "#define SetToZero(a) a=ZERO\n" "#define IsZero(a) (a == ZERO)\n" "#define Multiply(c,a,b) c=a*b\n" "#if USE_CL_MAD == 1\n" "#define MultiplyAdd(c,a,b) c=mad(a,b,c)\n" "#else\n" "#define MultiplyAdd(c,a,b) c += a*b\n" "#endif\n" "#define AXPBY(e,a,b,c,d) e=a*b+c*d\n" "// Force inlining functions or not: some compilers don't support the inline keyword\n" "#ifdef USE_INLINE_KEYWORD\n" " #define INLINE_FUNC inline\n" "#else\n" " #define INLINE_FUNC\n" "#endif\n" "INLINE_FUNC int GetGroupID1() { return get_group_id(1); }\n" "INLINE_FUNC int GetGroupID0() { return get_group_id(0); }\n" "// =================================================================================================\n" "// Data-widths in dimension M\n" "#if VWM == 1\n" " typedef FLOAT realM;\n" " #define COMPUTE_FLOATM PRECISION_COMPUTE\n" " #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE(x)\n" " #define CONVERT_FLOATM(x) CONVERT_FLOAT(x)\n" "#elif VWM == 2\n" " typedef FLOAT2 realM;\n" " #define COMPUTE_FLOATM PRECISION_COMPUTE2\n" " #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE2(x)\n" " #define CONVERT_FLOATM(x) CONVERT_FLOAT2(x)\n" "#elif VWM == 4\n" " typedef FLOAT4 realM;\n" " #define COMPUTE_FLOATM PRECISION_COMPUTE4\n" " #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE4(x)\n" " #define CONVERT_FLOATM(x) CONVERT_FLOAT4(x)\n" "#elif VWM == 8\n" " typedef FLOAT8 realM;\n" " #define COMPUTE_FLOATM PRECISION_COMPUTE8\n" " #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE8(x)\n" " #define CONVERT_FLOATM(x) 
CONVERT_FLOAT8(x)\n" "#elif VWM == 16\n" " typedef FLOAT16 realM;\n" " #define COMPUTE_FLOATM PRECISION_COMPUTE16\n" " #define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE16(x)\n" " #define CONVERT_FLOATM(x) CONVERT_FLOAT16(x)\n" "#endif\n" "// Data-widths in dimension N\n" "#if VWN == 1\n" " typedef FLOAT realN;\n" " typedef int intN;\n" " #define COMPUTE_FLOATN PRECISION_COMPUTE\n" " #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE(x)\n" " #define CONVERT_FLOATN(x) CONVERT_FLOAT(x)\n" "#elif VWN == 2\n" " typedef FLOAT2 realN;\n" " typedef int2 intN;\n" " #define COMPUTE_FLOATN PRECISION_COMPUTE2\n" " #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE2(x)\n" " #define CONVERT_FLOATN(x) CONVERT_FLOAT2(x)\n" "#elif VWN == 4\n" " typedef FLOAT4 realN;\n" " typedef int4 intN;\n" " #define COMPUTE_FLOATN PRECISION_COMPUTE4\n" " #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE4(x)\n" " #define CONVERT_FLOATN(x) CONVERT_FLOAT4(x)\n" "#elif VWN == 8\n" " typedef FLOAT8 realN;\n" " typedef int8 intN;\n" " #define COMPUTE_FLOATN PRECISION_COMPUTE8\n" " #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE8(x)\n" " #define CONVERT_FLOATN(x) CONVERT_FLOAT8(x)\n" "#elif VWN == 16\n" " typedef FLOAT16 realN;\n" " typedef int16 intN;\n" " #define COMPUTE_FLOATN PRECISION_COMPUTE16\n" " #define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE16(x)\n" " #define CONVERT_FLOATN(x) CONVERT_FLOAT16(x)\n" "#endif\n" "// =================================================================================================\n" "// Initializes the accumulation registers to zero\n" "INLINE_FUNC COMPUTE_FLOATM InitAccRegisters() {\n" " COMPUTE_FLOATM result;\n" " #if VWM == 1\n" " SetToZero(result);\n" " #elif VWM == 2\n" " SetToZero(result.x);\n" " SetToZero(result.y);\n" " #elif VWM == 4\n" " SetToZero(result.x);\n" " SetToZero(result.y);\n" " SetToZero(result.z);\n" " SetToZero(result.w);\n" " #elif VWM == 8\n" " SetToZero(result.s0);\n" " SetToZero(result.s1);\n" " SetToZero(result.s2);\n" " SetToZero(result.s3);\n" " SetToZero(result.s4);\n" " SetToZero(result.s5);\n" " SetToZero(result.s6);\n" " SetToZero(result.s7);\n" " #elif VWM == 16\n" " SetToZero(result.s0);\n" " SetToZero(result.s1);\n" " SetToZero(result.s2);\n" " SetToZero(result.s3);\n" " SetToZero(result.s4);\n" " SetToZero(result.s5);\n" " SetToZero(result.s6);\n" " SetToZero(result.s7);\n" " SetToZero(result.s8);\n" " SetToZero(result.s9);\n" " SetToZero(result.sA);\n" " SetToZero(result.sB);\n" " SetToZero(result.sC);\n" " SetToZero(result.sD);\n" " SetToZero(result.sE);\n" " SetToZero(result.sF);\n" " #endif\n" " return result;\n" "}\n" "INLINE_FUNC COMPUTE_FLOATN InitAccRegistersN() {\n" " COMPUTE_FLOATN result;\n" " #if VWN == 1\n" " SetToZero(result);\n" " #elif VWN == 2\n" " SetToZero(result.x);\n" " SetToZero(result.y);\n" " #elif VWN == 4\n" " SetToZero(result.x);\n" " SetToZero(result.y);\n" " SetToZero(result.z);\n" " SetToZero(result.w);\n" " #elif VWN == 8\n" " SetToZero(result.s0);\n" " SetToZero(result.s1);\n" " SetToZero(result.s2);\n" " SetToZero(result.s3);\n" " SetToZero(result.s4);\n" " SetToZero(result.s5);\n" " SetToZero(result.s6);\n" " SetToZero(result.s7);\n" " #elif VWN == 16\n" " SetToZero(result.s0);\n" " SetToZero(result.s1);\n" " SetToZero(result.s2);\n" " SetToZero(result.s3);\n" " SetToZero(result.s4);\n" " SetToZero(result.s5);\n" " SetToZero(result.s6);\n" " SetToZero(result.s7);\n" " SetToZero(result.s8);\n" " SetToZero(result.s9);\n" " SetToZero(result.sA);\n" " 
SetToZero(result.sB);\n" " SetToZero(result.sC);\n" " SetToZero(result.sD);\n" " SetToZero(result.sE);\n" " SetToZero(result.sF);\n" " #endif\n" " return result;\n" "}\n" "// =================================================================================================\n" "// Caches global off-chip memory into local (shared) memory on-chip. This function is specific for\n" "// caching the A input matrix.\n" "#if SA == 1\n" "INLINE_FUNC void GlobalToLocalA(const __global realM* restrict agm,LOCAL_PTR realM* alm,\n" " const int kSizeM,const int tid,const int kwg) {\n" " const int la0=tid % MDIMA;\n" " const int la1=tid/MDIMA;\n" " #pragma unroll\n" " for (int _mia=0; _mia<MWA/VWM; _mia += 1) {\n" " #pragma unroll\n" " for (int _kia=0; _kia<KWA; _kia += 1) {\n" " // Computes the indices based on strided/non-strided access\n" " #if STRM == 0\n" " int mg=_mia+la0*(MWA/VWM);\n" " #elif STRM == 1\n" " int mg=la0+_mia*MDIMA;\n" " #endif\n" " // Computes the indices for the global memory\n" " int kg=_kia+la1*KWA;\n" " int idm=mg+GetGroupID0()*(MWG/VWM);\n" " int idk=kg+kwg;\n" " // Loads the data from global memory (not transposed) into the local memory\n" " alm[kg*(MWG/VWM)+mg]=agm[idk*(kSizeM/VWM)+idm];\n" " }\n" " }\n" "}\n" "#endif\n" "// Same as above,but now for the B input matrix\n" "#if SB == 1\n" "INLINE_FUNC void GlobalToLocalB(const __global realN* restrict bgm,LOCAL_PTR realN* blm,\n" " const int kSizeN,const int tid,const int kwg) {\n" " const int lb0=tid % NDIMB;\n" " const int lb1=tid/NDIMB;\n" " #pragma unroll\n" " for (int _kib=0; _kib<KWB; _kib += 1) {\n" " #pragma unroll\n" " for (int _nib=0; _nib<NWB/VWN; _nib += 1) {\n" " // Computes the indices based on strided/non-strided access\n" " #if STRN == 0\n" " int ng=_nib+lb0*(NWB/VWN);\n" " #elif STRN == 1\n" " int ng=lb0+_nib*NDIMB;\n" " #endif\n" " // Computes the indices for the global memory\n" " int kg=_kib+lb1*KWB;\n" " int idn=ng+GetGroupID1()*(NWG/VWN);\n" " int idk=kg+kwg;\n" " // Loads the data from global memory (transposed) into the local memory\n" " blm[kg*(NWG/VWN)+ng]=bgm[idk*(kSizeN/VWN)+idn];\n" " }\n" " }\n" "}\n" "#endif\n" "// =================================================================================================\n" "// Caches global off-chip memory directly into per-thread private memory (registers). 
This function\n" "// is specific for caching the A input matrix.\n" "#if SA == 0\n" "INLINE_FUNC int GlobalIndexA() {\n" " // Computes the indices based on strided/non-strided access\n" " #if STRM == 0\n" " // [MWG/MWI,MWI/VWM,VWM]\n" " int mg=get_local_id(0)*(MWI/VWM);\n" " #elif STRM == 1\n" " // [MWI/VWM,MWG/MWI,VWM]\n" " int mg=get_local_id(0);\n" " #endif\n" " // Computes the indices for the global memory\n" " // [kSizeM/MWG,(MWG/VWM),VWM]\n" " int idm=mg+GetGroupID0()*(MWG/VWM);\n" " return idm;\n" "}\n" "INLINE_FUNC realM GlobalToPrivateOptA(const __global realM* restrict agm,const int base,const int _mi,\n" " const int astride/*kSizeM*/,const int idk) {\n" " // Computes the indices based on strided/non-strided access\n" " #if STRM == 0\n" " // [MWG/MWI,MWI/VWM,VWM]\n" " int idm=base+_mi;\n" " #elif STRM == 1\n" " // [MWI/VWM,MWG/MWI,VWM]\n" " int idm=base+_mi*MDIMC;\n" " #endif\n" " // Loads the data from global memory (not transposed) and stores into registers\n" " // [kSizeK,kSizeM/VWM,VWM]\n" " return agm[idk*(astride/VWM)+idm];\n" "}\n" "INLINE_FUNC realM GlobalToPrivateA(const __global realM* restrict agm,const int _mi,\n" " const int kSizeM,const int idk) {\n" " // Computes the indices based on strided/non-strided access\n" " #if STRM == 0\n" " // [MWG/MWI,MWI/VWM,VWM]\n" " int mg=_mi+get_local_id(0)*(MWI/VWM);\n" " #elif STRM == 1\n" " // [MWI/VWM,MWG/MWI,VWM]\n" " int mg=get_local_id(0)+_mi*MDIMC;\n" " #endif\n" " // Computes the indices for the global memory\n" " // [kSizeM/MWG,(MWG/VWM),VWM]\n" " int idm=mg+GetGroupID0()*(MWG/VWM);\n" " // Loads the data from global memory (not transposed) and stores into registers\n" " // [kSizeK,kSizeM/VWM,VWM]\n" " return agm[idk*(kSizeM/VWM)+idm];\n" "}\n" "#endif\n" "// Same as above,but now for the B input matrix\n" "#if SB == 0\n" "INLINE_FUNC int GlobalIndexB() {\n" " // Computes the indices based on strided/non-strided access\n" " #if STRN == 0\n" " int ng=get_local_id(1)*(NWI/VWN);\n" " #elif STRN == 1\n" " int ng=get_local_id(1);\n" " #endif\n" " // Computes the indices for the global memory\n" " int idn=ng+GetGroupID1()*(NWG/VWN);\n" " return idn;\n" "}\n" "INLINE_FUNC realN GlobalToPrivateOptB(const __global realN* restrict bgm,const int base,const int _ni,\n" " const int bstride/*kSizeN*/,const int idk) {\n" " // Computes the indices based on strided/non-strided access\n" " #if STRN == 0\n" " int idn=base+_ni;\n" " #elif STRN == 1\n" " int idn=base+_ni*NDIMC;\n" " #endif\n" " // Loads the data from global memory (transposed) and stores into registers\n" " return bgm[idk*(bstride/VWN)+idn];\n" "}\n" "INLINE_FUNC realN GlobalToPrivateB(const __global realN* restrict bgm,const int _ni,\n" " const int kSizeN,const int idk) {\n" " // Computes the indices based on strided/non-strided access\n" " #if STRN == 0\n" " int ng=_ni+get_local_id(1)*(NWI/VWN);\n" " #elif STRN == 1\n" " int ng=get_local_id(1)+_ni*NDIMC;\n" " #endif\n" " // Computes the indices for the global memory\n" " int idn=ng+GetGroupID1()*(NWG/VWN);\n" " // Loads the data from global memory (transposed) and stores into registers\n" " return bgm[idk*(kSizeN/VWN)+idn];\n" "}\n" "#endif\n" "// =================================================================================================\n" "// Caches on-chip local memory into per-thread private memory (registers). 
This function is specific\n" "// for caching the A input matrix.\n" "#if SA == 1\n" "INLINE_FUNC realM LocalToPrivateA(LOCAL_PTR realM* alm,const int _mi,const int kg) {\n" " #if STRM == 0\n" " int mg=_mi+get_local_id(0)*(MWI/VWM);\n" " #elif STRM == 1\n" " int mg=get_local_id(0)+_mi*MDIMC;\n" " #endif\n" " return alm[kg*(MWG/VWM)+mg];\n" "}\n" "#endif\n" "// Same as above,but now for the B input matrix\n" "#if SB == 1\n" "INLINE_FUNC realN LocalToPrivateB(LOCAL_PTR realN* blm,const int _ni,const int kg) {\n" " #if STRN == 0\n" " int ng=_ni+get_local_id(1)*(NWI/VWN);\n" " #elif STRN == 1\n" " int ng=get_local_id(1)+_ni*NDIMC;\n" " #endif\n" " return blm[kg*(NWG/VWN)+ng];\n" "}\n" "#endif\n" "// The vectorised multiply-add function\n" "INLINE_FUNC COMPUTE_FLOATM MultiplyAddVector(COMPUTE_FLOATM cvec,COMPUTE_FLOATM avec,PRECISION_COMPUTE bval) {\n" " #if USE_VECTOR_MAD == 1\n" " #if USE_CL_MAD == 1\n" " cvec=mad(avec,(COMPUTE_FLOATM)bval,cvec);\n" " #else\n" " cvec += avec*bval;\n" " #endif\n" " #else\n" " #if VWM == 1\n" " MultiplyAdd(cvec,avec,bval);\n" " #elif VWM == 2\n" " MultiplyAdd(cvec.x ,avec.x,bval);\n" " MultiplyAdd(cvec.y ,avec.y,bval);\n" " #elif VWM == 4\n" " MultiplyAdd(cvec.x ,avec.x,bval);\n" " MultiplyAdd(cvec.y ,avec.y,bval);\n" " MultiplyAdd(cvec.z ,avec.z,bval);\n" " MultiplyAdd(cvec.w ,avec.w,bval);\n" " #elif VWM == 8\n" " MultiplyAdd(cvec.s0,avec.s0,bval);\n" " MultiplyAdd(cvec.s1,avec.s1,bval);\n" " MultiplyAdd(cvec.s2,avec.s2,bval);\n" " MultiplyAdd(cvec.s3,avec.s3,bval);\n" " MultiplyAdd(cvec.s4,avec.s4,bval);\n" " MultiplyAdd(cvec.s5,avec.s5,bval);\n" " MultiplyAdd(cvec.s6,avec.s6,bval);\n" " MultiplyAdd(cvec.s7,avec.s7,bval);\n" " #elif VWM == 16\n" " MultiplyAdd(cvec.s0,avec.s0,bval);\n" " MultiplyAdd(cvec.s1,avec.s1,bval);\n" " MultiplyAdd(cvec.s2,avec.s2,bval);\n" " MultiplyAdd(cvec.s3,avec.s3,bval);\n" " MultiplyAdd(cvec.s4,avec.s4,bval);\n" " MultiplyAdd(cvec.s5,avec.s5,bval);\n" " MultiplyAdd(cvec.s6,avec.s6,bval);\n" " MultiplyAdd(cvec.s7,avec.s7,bval);\n" " MultiplyAdd(cvec.s8,avec.s8,bval);\n" " MultiplyAdd(cvec.s9,avec.s9,bval);\n" " MultiplyAdd(cvec.sA,avec.sA,bval);\n" " MultiplyAdd(cvec.sB,avec.sB,bval);\n" " MultiplyAdd(cvec.sC,avec.sC,bval);\n" " MultiplyAdd(cvec.sD,avec.sD,bval);\n" " MultiplyAdd(cvec.sE,avec.sE,bval);\n" " MultiplyAdd(cvec.sF,avec.sF,bval);\n" " #endif\n" " #endif\n" " return cvec;\n" "}\n" "// The vectorised multiply-add function\n" "INLINE_FUNC COMPUTE_FLOATN MultiplyAddVectorN(COMPUTE_FLOATN cvec,PRECISION_COMPUTE avec,COMPUTE_FLOATN bval) {\n" " #if USE_VECTOR_MAD == 1\n" " #if USE_CL_MAD == 1\n" " cvec=mad((COMPUTE_FLOATN)avec,bval,cvec);\n" " #else\n" " cvec += avec*bval;\n" " #endif\n" " #else\n" " #if VWN == 1\n" " MultiplyAdd(cvec,avec,bval);\n" " #elif VWN == 2\n" " MultiplyAdd(cvec.x ,avec,bval.x);\n" " MultiplyAdd(cvec.y ,avec,bval.y);\n" " #elif VWN == 4\n" " MultiplyAdd(cvec.x ,avec,bval.x);\n" " MultiplyAdd(cvec.y ,avec,bval.y);\n" " MultiplyAdd(cvec.z ,avec,bval.z);\n" " MultiplyAdd(cvec.w ,avec,bval.w);\n" " #elif VWN == 8\n" " MultiplyAdd(cvec.s0,avec,bval.s0);\n" " MultiplyAdd(cvec.s1,avec,bval.s1);\n" " MultiplyAdd(cvec.s2,avec,bval.s2);\n" " MultiplyAdd(cvec.s3,avec,bval.s3);\n" " MultiplyAdd(cvec.s4,avec,bval.s4);\n" " MultiplyAdd(cvec.s5,avec,bval.s5);\n" " MultiplyAdd(cvec.s6,avec,bval.s6);\n" " MultiplyAdd(cvec.s7,avec,bval.s7);\n" " #elif VWN == 16\n" " MultiplyAdd(cvec.s0,avec,bval.s0);\n" " MultiplyAdd(cvec.s1,avec,bval.s1);\n" " MultiplyAdd(cvec.s2,avec,bval.s2);\n" " 
MultiplyAdd(cvec.s3,avec,bval.s3);\n" " MultiplyAdd(cvec.s4,avec,bval.s4);\n" " MultiplyAdd(cvec.s5,avec,bval.s5);\n" " MultiplyAdd(cvec.s6,avec,bval.s6);\n" " MultiplyAdd(cvec.s7,avec,bval.s7);\n" " MultiplyAdd(cvec.s8,avec,bval.s8);\n" " MultiplyAdd(cvec.s9,avec,bval.s9);\n" " MultiplyAdd(cvec.sA,avec,bval.sA);\n" " MultiplyAdd(cvec.sB,avec,bval.sB);\n" " MultiplyAdd(cvec.sC,avec,bval.sC);\n" " MultiplyAdd(cvec.sD,avec,bval.sD);\n" " MultiplyAdd(cvec.sE,avec,bval.sE);\n" " MultiplyAdd(cvec.sF,avec,bval.sF);\n" " #endif\n" " #endif\n" " return cvec;\n" "}\n" "// =================================================================================================\n" "// Merges the results in Cpm with the global array in Cgm. This also performs the multiplication\n" "// with the constants: Cgm=alpha*A*B+beta*Cgm=alpha*Cpm+beta*Cgm\n" "typedef struct {\n" " int index[2];\n" "} INT2;\n" "INLINE_FUNC INT2 StoreIndexM() {\n" " INT2 res;\n" " #if STRM == 0\n" " int mg=get_local_id(0)*(MWI/VWM);\n" " #elif STRM == 1\n" " int mg=get_local_id(0);\n" " #endif\n" " #if STRN == 0\n" " int ng=get_local_id(1)*NWI;\n" " #elif STRN == 1\n" " int ng=get_local_id(1)*VWN;\n" " #endif\n" " int idm=mg+GetGroupID0()*(MWG/VWM);\n" " int idn=ng+GetGroupID1()*NWG;\n" " res.index[0]=idm;\n" " res.index[1]=idn;\n" " return res;\n" "}\n" "// layout : [N,M]\n" "INLINE_FUNC void StoreResultsM(__global realM* cgm,COMPUTE_FLOATM c_value,const INT2 baseOffset,const int _mi,const int _ni,\n" " const int cstride/*kSizeM*/,\n" " const PRECISION_COMPUTE alpha,const PRECISION_COMPUTE beta) {\n" " #if STRM == 0\n" " int idm=_mi+baseOffset.index[0];\n" " #elif STRM == 1\n" " int idm=baseOffset.index[0]+_mi*MDIMC;\n" " #endif\n" " #if STRN == 0\n" " int idn=_ni+baseOffset.index[1];\n" " #elif STRN == 1\n" " int idn=_ni%VWN+baseOffset.index[1]+(_ni/VWN)*VWN*NDIMC;\n" " #endif\n" " \n" " int index=idn*(cstride/VWM)+idm;\n" " COMPUTE_FLOATM result=c_value;\n" " // The final multiplication with alpha (in case beta == 0)\n" " #ifdef ONLY_HAVE_ALPHA\n" " COMPUTE_FLOATM xval=c_value;\n" " #if VWM == 1\n" " Multiply(result,alpha,xval);\n" " #elif VWM == 2\n" " Multiply(result.x,alpha,xval.x);\n" " Multiply(result.y,alpha,xval.y);\n" " #elif VWM == 4\n" " Multiply(result.x,alpha,xval.x);\n" " Multiply(result.y,alpha,xval.y);\n" " Multiply(result.z,alpha,xval.z);\n" " Multiply(result.w,alpha,xval.w);\n" " #elif VWM == 8\n" " Multiply(result.s0,alpha,xval.s0);\n" " Multiply(result.s1,alpha,xval.s1);\n" " Multiply(result.s2,alpha,xval.s2);\n" " Multiply(result.s3,alpha,xval.s3);\n" " Multiply(result.s4,alpha,xval.s4);\n" " Multiply(result.s5,alpha,xval.s5);\n" " Multiply(result.s6,alpha,xval.s6);\n" " Multiply(result.s7,alpha,xval.s7);\n" " #elif VWM == 16\n" " Multiply(result.s0,alpha,xval.s0);\n" " Multiply(result.s1,alpha,xval.s1);\n" " Multiply(result.s2,alpha,xval.s2);\n" " Multiply(result.s3,alpha,xval.s3);\n" " Multiply(result.s4,alpha,xval.s4);\n" " Multiply(result.s5,alpha,xval.s5);\n" " Multiply(result.s6,alpha,xval.s6);\n" " Multiply(result.s7,alpha,xval.s7);\n" " Multiply(result.s8,alpha,xval.s8);\n" " Multiply(result.s9,alpha,xval.s9);\n" " Multiply(result.sA,alpha,xval.sA);\n" " Multiply(result.sB,alpha,xval.sB);\n" " Multiply(result.sC,alpha,xval.sC);\n" " Multiply(result.sD,alpha,xval.sD);\n" " Multiply(result.sE,alpha,xval.sE);\n" " Multiply(result.sF,alpha,xval.sF);\n" " #endif\n" " #endif\n" " // The final multiplication with alpha and the addition with beta*C\n" " #ifdef HAVE_ALPHA_BETA\n" " COMPUTE_FLOATM xval=c_value;\n" " 
COMPUTE_FLOATM yval=CONVERT_COMPUTE_FLOATM(cgm[index]);\n" " #if VWM == 1\n" " AXPBY(result,alpha,xval,beta,yval);\n" " #elif VWM == 2\n" " AXPBY(result.x,alpha,xval.x,beta,yval.x);\n" " AXPBY(result.y,alpha,xval.y,beta,yval.y);\n" " #elif VWM == 4\n" " AXPBY(result.x,alpha,xval.x,beta,yval.x);\n" " AXPBY(result.y,alpha,xval.y,beta,yval.y);\n" " AXPBY(result.z,alpha,xval.z,beta,yval.z);\n" " AXPBY(result.w,alpha,xval.w,beta,yval.w);\n" " #elif VWM == 8\n" " AXPBY(result.s0,alpha,xval.s0,beta,yval.s0);\n" " AXPBY(result.s1,alpha,xval.s1,beta,yval.s1);\n" " AXPBY(result.s2,alpha,xval.s2,beta,yval.s2);\n" " AXPBY(result.s3,alpha,xval.s3,beta,yval.s3);\n" " AXPBY(result.s4,alpha,xval.s4,beta,yval.s4);\n" " AXPBY(result.s5,alpha,xval.s5,beta,yval.s5);\n" " AXPBY(result.s6,alpha,xval.s6,beta,yval.s6);\n" " AXPBY(result.s7,alpha,xval.s7,beta,yval.s7);\n" " #elif VWM == 16\n" " AXPBY(result.s0,alpha,xval.s0,beta,yval.s0);\n" " AXPBY(result.s1,alpha,xval.s1,beta,yval.s1);\n" " AXPBY(result.s2,alpha,xval.s2,beta,yval.s2);\n" " AXPBY(result.s3,alpha,xval.s3,beta,yval.s3);\n" " AXPBY(result.s4,alpha,xval.s4,beta,yval.s4);\n" " AXPBY(result.s5,alpha,xval.s5,beta,yval.s5);\n" " AXPBY(result.s6,alpha,xval.s6,beta,yval.s6);\n" " AXPBY(result.s7,alpha,xval.s7,beta,yval.s7);\n" " AXPBY(result.s8,alpha,xval.s8,beta,yval.s8);\n" " AXPBY(result.s9,alpha,xval.s9,beta,yval.s9);\n" " AXPBY(result.sA,alpha,xval.sA,beta,yval.sA);\n" " AXPBY(result.sB,alpha,xval.sB,beta,yval.sB);\n" " AXPBY(result.sC,alpha,xval.sC,beta,yval.sC);\n" " AXPBY(result.sD,alpha,xval.sD,beta,yval.sD);\n" " AXPBY(result.sE,alpha,xval.sE,beta,yval.sE);\n" " AXPBY(result.sF,alpha,xval.sF,beta,yval.sF);\n" " #endif\n" " #endif\n" " cgm[index]=CONVERT_FLOATM(result);\n" "}\n" "INLINE_FUNC INT2 StoreIndexN() {\n" " INT2 res;\n" " #if STRM == 0\n" " int mg=get_local_id(0)*MWI;\n" " #elif STRM == 1\n" " int mg=get_local_id(0)*VWM;\n" " #endif\n" " #if STRN == 0\n" " int ng=get_local_id(1)*(NWI/VWN);\n" " #elif STRN == 1\n" " int ng=get_local_id(1);\n" " #endif\n" " int idm=mg+GetGroupID0()*MWG;\n" " int idn=ng+GetGroupID1()*(NWG/VWN);\n" " \n" " res.index[0]=idm;\n" " res.index[1]=idn;\n" " return res;\n" "}\n" "// layout : [M,N]\n" "INLINE_FUNC void StoreResultsN(__global realN* cgn,COMPUTE_FLOATN c_value,\n" " const INT2 baseOffset,\n" " #if BIAS_TYPE>0\n" " #if BIAS_TYPE>1\n" " __global realN* egm,\n" " #else\n" " realN* epm,\n" " #endif\n" " #endif\n" " const int _mi,const int _ni,\n" " const int cstride/*kSizeN*/,const int dstride/*kSizeN*/,const PRECISION_COMPUTE alpha,const PRECISION_COMPUTE beta) {\n" " #if STRM == 0\n" " int idm=_mi+baseOffset.index[0];\n" " #elif STRM == 1\n" " int idm=_mi%VWM+baseOffset.index[0]+(_mi/VWM)*VWM*MDIMC;\n" " #endif\n" " #if STRN == 0\n" " int idn=_ni+baseOffset.index[1];\n" " #elif STRN == 1\n" " int idn=baseOffset.index[1]+_ni*NDIMC;\n" " #endif\n" " int index=idm*(cstride/VWN)+idn;\n" " \n" " COMPUTE_FLOATN result=c_value;\n" " \n" " // The final multiplication with alpha (in case beta == 0)\n" " #ifdef ONLY_HAVE_ALPHA\n" " COMPUTE_FLOATN xval=c_value;\n" " #if VWN == 1\n" " Multiply(result,alpha,xval);\n" " #elif VWN == 2\n" " Multiply(result.x,alpha,xval.x);\n" " Multiply(result.y,alpha,xval.y);\n" " #elif VWN == 4\n" " Multiply(result.x,alpha,xval.x);\n" " Multiply(result.y,alpha,xval.y);\n" " Multiply(result.z,alpha,xval.z);\n" " Multiply(result.w,alpha,xval.w);\n" " #elif VWN == 8\n" " Multiply(result.s0,alpha,xval.s0);\n" " Multiply(result.s1,alpha,xval.s1);\n" " 
Multiply(result.s2,alpha,xval.s2);\n" " Multiply(result.s3,alpha,xval.s3);\n" " Multiply(result.s4,alpha,xval.s4);\n" " Multiply(result.s5,alpha,xval.s5);\n" " Multiply(result.s6,alpha,xval.s6);\n" " Multiply(result.s7,alpha,xval.s7);\n" " #elif VWN == 16\n" " Multiply(result.s0,alpha,xval.s0);\n" " Multiply(result.s1,alpha,xval.s1);\n" " Multiply(result.s2,alpha,xval.s2);\n" " Multiply(result.s3,alpha,xval.s3);\n" " Multiply(result.s4,alpha,xval.s4);\n" " Multiply(result.s5,alpha,xval.s5);\n" " Multiply(result.s6,alpha,xval.s6);\n" " Multiply(result.s7,alpha,xval.s7);\n" " Multiply(result.s8,alpha,xval.s8);\n" " Multiply(result.s9,alpha,xval.s9);\n" " Multiply(result.sA,alpha,xval.sA);\n" " Multiply(result.sB,alpha,xval.sB);\n" " Multiply(result.sC,alpha,xval.sC);\n" " Multiply(result.sD,alpha,xval.sD);\n" " Multiply(result.sE,alpha,xval.sE);\n" " Multiply(result.sF,alpha,xval.sF);\n" " #endif\n" " #endif\n" " // The final multiplication with alpha and the addition with beta*C\n" " #ifdef HAVE_ALPHA_BETA\n" " COMPUTE_FLOATN xval=c_value;\n" " COMPUTE_FLOATN yval=CONVERT_COMPUTE_FLOATN(cgn[index]);\n" " #if VWN == 1\n" " AXPBY(result,alpha,xval,beta,yval);\n" " #elif VWN == 2\n" " AXPBY(result.x,alpha,xval.x,beta,yval.x);\n" " AXPBY(result.y,alpha,xval.y,beta,yval.y);\n" " #elif VWN == 4\n" " AXPBY(result.x,alpha,xval.x,beta,yval.x);\n" " AXPBY(result.y,alpha,xval.y,beta,yval.y);\n" " AXPBY(result.z,alpha,xval.z,beta,yval.z);\n" " AXPBY(result.w,alpha,xval.w,beta,yval.w);\n" " #elif VWN == 8\n" " AXPBY(result.s0,alpha,xval.s0,beta,yval.s0);\n" " AXPBY(result.s1,alpha,xval.s1,beta,yval.s1);\n" " AXPBY(result.s2,alpha,xval.s2,beta,yval.s2);\n" " AXPBY(result.s3,alpha,xval.s3,beta,yval.s3);\n" " AXPBY(result.s4,alpha,xval.s4,beta,yval.s4);\n" " AXPBY(result.s5,alpha,xval.s5,beta,yval.s5);\n" " AXPBY(result.s6,alpha,xval.s6,beta,yval.s6);\n" " AXPBY(result.s7,alpha,xval.s7,beta,yval.s7);\n" " #elif VWN == 16\n" " AXPBY(result.s0,alpha,xval.s0,beta,yval.s0);\n" " AXPBY(result.s1,alpha,xval.s1,beta,yval.s1);\n" " AXPBY(result.s2,alpha,xval.s2,beta,yval.s2);\n" " AXPBY(result.s3,alpha,xval.s3,beta,yval.s3);\n" " AXPBY(result.s4,alpha,xval.s4,beta,yval.s4);\n" " AXPBY(result.s5,alpha,xval.s5,beta,yval.s5);\n" " AXPBY(result.s6,alpha,xval.s6,beta,yval.s6);\n" " AXPBY(result.s7,alpha,xval.s7,beta,yval.s7);\n" " AXPBY(result.s8,alpha,xval.s8,beta,yval.s8);\n" " AXPBY(result.s9,alpha,xval.s9,beta,yval.s9);\n" " AXPBY(result.sA,alpha,xval.sA,beta,yval.sA);\n" " AXPBY(result.sB,alpha,xval.sB,beta,yval.sB);\n" " AXPBY(result.sC,alpha,xval.sC,beta,yval.sC);\n" " AXPBY(result.sD,alpha,xval.sD,beta,yval.sD);\n" " AXPBY(result.sE,alpha,xval.sE,beta,yval.sE);\n" " AXPBY(result.sF,alpha,xval.sF,beta,yval.sF);\n" " #endif\n" " #endif\n" " \n" " \n" "#if BIAS_TYPE>0\n" " #if BIAS_TYPE == 1\n" " COMPUTE_FLOATN eval=CONVERT_COMPUTE_FLOATN(epm[_ni]);\n" " #elif BIAS_TYPE == 5\n" " int index_bias=idm*(dstride/VWN)+idn;\n" " intN eval=((__global intN*)egm)[index_bias];\n" " #else\n" " int index_bias=idm*(dstride/VWN)+idn;\n" " COMPUTE_FLOATN eval=CONVERT_COMPUTE_FLOATN(egm[index_bias]);\n" " #endif\n" " \n" " #if VWN == 1\n" " DEAL_BIAS(result,eval);\n" " #ifdef RELU\n" " result=fmax(result,(COMPUTE_FLOATN)0);\n" " #endif\n" " #ifdef RELU6\n" " result=clamp(result,(COMPUTE_FLOATN)0,(COMPUTE_FLOATN)6);\n" " #endif\n" " #elif VWN == 2\n" " DEAL_BIAS(result.x,eval.x);\n" " DEAL_BIAS(result.y,eval.y);\n" " #ifdef RELU\n" " result=fmax(result,(COMPUTE_FLOATN)0);\n" " #endif\n" " #ifdef RELU6\n" " 
result=clamp(result,(COMPUTE_FLOATN)0,(COMPUTE_FLOATN)6);\n" " #endif\n" " #elif VWN == 4\n" " DEAL_BIAS(result.x,eval.x);\n" " DEAL_BIAS(result.y,eval.y);\n" " DEAL_BIAS(result.z,eval.z);\n" " DEAL_BIAS(result.w,eval.w);\n" " #ifdef RELU\n" " result=fmax(result,(COMPUTE_FLOATN)0);\n" " #endif\n" " #ifdef RELU6\n" " result=clamp(result,(COMPUTE_FLOATN)0,(COMPUTE_FLOATN)6);\n" " #endif\n" " #elif VWN == 8\n" " DEAL_BIAS(result.s0,eval.s0);\n" " DEAL_BIAS(result.s1,eval.s1);\n" " DEAL_BIAS(result.s2,eval.s2);\n" " DEAL_BIAS(result.s3,eval.s3);\n" " DEAL_BIAS(result.s4,eval.s4);\n" " DEAL_BIAS(result.s5,eval.s5);\n" " DEAL_BIAS(result.s6,eval.s6);\n" " DEAL_BIAS(result.s7,eval.s7);\n" " #ifdef RELU\n" " result=fmax(result,(COMPUTE_FLOATN)0);\n" " #endif\n" " #ifdef RELU6\n" " result=clamp(result,(COMPUTE_FLOATN)0,(COMPUTE_FLOATN)6);\n" " #endif\n" " #elif VWN == 16\n" " DEAL_BIAS(result.s0,eval.s0);\n" " DEAL_BIAS(result.s1,eval.s1);\n" " DEAL_BIAS(result.s2,eval.s2);\n" " DEAL_BIAS(result.s3,eval.s3);\n" " DEAL_BIAS(result.s4,eval.s4);\n" " DEAL_BIAS(result.s5,eval.s5);\n" " DEAL_BIAS(result.s6,eval.s6);\n" " DEAL_BIAS(result.s7,eval.s7);\n" " DEAL_BIAS(result.s8,eval.s8);\n" " DEAL_BIAS(result.s9,eval.s9);\n" " DEAL_BIAS(result.sA,eval.sA);\n" " DEAL_BIAS(result.sB,eval.sB);\n" " DEAL_BIAS(result.sC,eval.sC);\n" " DEAL_BIAS(result.sD,eval.sD);\n" " DEAL_BIAS(result.sE,eval.sE);\n" " DEAL_BIAS(result.sF,eval.sF);\n" " #ifdef RELU\n" " result=fmax(result,(COMPUTE_FLOATN)0);\n" " #endif\n" " #ifdef RELU6\n" " result=clamp(result,(COMPUTE_FLOATN)0,(COMPUTE_FLOATN)6);\n" " #endif\n" " #endif\n" "#endif\n" " cgn[index]=CONVERT_FLOATN(result);\n" "}\n" "// Main body of the matrix-multiplication algorithm. It calls various (inlined) functions.\n" "INLINE_FUNC void XgemmBody(const int kSizeM,const int kSizeN,const int kSizeK,const int4 stride,\n" " const __global realM* restrict agm,const __global realN* restrict bgm,\n" " #if BIAS_TYPE>0\n" " __global realN* restrict egm,\n" " #endif\n" " __global realM* cgm,const real_arg alpha,const real_arg beta\n" " #if SA == 1 && SB == 1\n" " ,LOCAL_PTR realM* alm,LOCAL_PTR realN* blm\n" " #elif SA == 1\n" " ,LOCAL_PTR realM* alm\n" " #elif SB == 1\n" " ,LOCAL_PTR realN* blm\n" " #endif\n" " ) {\n" " #ifdef OUTPUTMN\n" " #pragma promote_to_registers\n" " COMPUTE_FLOATN cpn[MWI*(NWI/VWN)]; // MWI*NWI\n" " #else\n" " #pragma promote_to_registers\n" " COMPUTE_FLOATM cpm[NWI*(MWI/VWM)]; // NWI*MWI\n" " #endif\n" " // Combined thread identifier (volatile to disable caching)\n" " #if SA == 1 || SB == 1\n" " volatile int tid=get_local_id(0)+MDIMC*get_local_id(1);\n" " #endif\n" " // Initializes the accumulation registers\n" " #ifdef OUTPUTMN\n" " #pragma unroll\n" " for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n" " #pragma unroll\n" " for (int _mi=0; _mi<MWI; _mi += 1) {\n" " cpn[_mi*(NWI/VWN)+_ni]=InitAccRegistersN();\n" " }\n" " }\n" " #else\n" " #pragma unroll\n" " for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n" " #pragma unroll\n" " for (int _ni=0; _ni<NWI; _ni += 1) {\n" " cpm[_ni*(MWI/VWM)+_mi]=InitAccRegisters();\n" " }\n" " }\n" " #endif\n" " // Loops over all workgroup tiles\n" " #if SA == 1 || SB == 1\n" " // Allocates workitem-private memory (registers)\n" " #pragma promote_to_registers\n" " COMPUTE_FLOATM apm[MWI/VWM]; // MWI*1\n" " #pragma promote_to_registers\n" " COMPUTE_FLOATN bpm[NWI/VWN]; // 1*NWI\n" " \n" " for (int kwg=0; kwg<kSizeK; kwg += KWG) {\n" " // Loads data: off-chip --> local (matrix A)\n" " #if SA == 1\n" " 
GlobalToLocalA(agm,alm,kSizeM,tid,kwg);\n" " #endif\n" " // Loads data: off-chip --> local (matrix B)\n" " #if SB == 1\n" " GlobalToLocalB(bgm,blm,kSizeN,tid,kwg);\n" " #endif\n" " barrier(CLK_LOCAL_MEM_FENCE);\n" " // Loops over all workitem tiles,unrolled by a factor KWI\n" " for (int pwi=0; pwi<KWG; pwi += KWI) {\n" " #pragma unroll\n" " for (int _pit=0; _pit<KWI; _pit += 1) {\n" " #if SA == 0 || SB == 0\n" " int idk=kwg+pwi+_pit;\n" " #endif\n" " int kg=pwi+_pit;\n" " // Loads matrix A (kernel 0) or matrix B (kernel 1)\n" " #pragma unroll\n" " for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n" " // Loads data: local --> private (matrix A)\n" " #if SA == 1\n" " apm[_mi]=CONVERT_COMPUTE_FLOATM(LocalToPrivateA(alm,_mi,kg));\n" " // Loads data: off-chip --> private (matrix A)\n" " #elif SA == 0\n" " apm[_mi]=CONVERT_COMPUTE_FLOATM(GlobalToPrivateA(agm,_mi,kSizeM,idk));\n" " #endif\n" " }\n" " // Loads matrix B (kernel 0) or matrix A (kernel 1)\n" " #pragma unroll\n" " for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n" " // Loads data: local --> private (matrix B)\n" " #if SB == 1\n" " bpm[_ni]=CONVERT_COMPUTE_FLOATN(LocalToPrivateB(blm,_ni,kg));\n" " // Loads data: off-chip --> private (matrix B)\n" " #else\n" " bpm[_ni]=CONVERT_COMPUTE_FLOATN(GlobalToPrivateB(bgm,_ni,kSizeN,idk));\n" " #endif\n" " }\n" " // Performs the accumulation (Cpm += Apm*Bpm)\n" " #ifdef OUTPUTMN\n" " #pragma unroll\n" " for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n" " #pragma unroll\n" " for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n" " const COMPUTE_FLOATM aval=apm[_mi];\n" " #if VWM == 1\n" " // [MWI/VWM,VWM,NWI/VWN,VWN]\n" " cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval,bpm[_ni]);\n" " #elif VWM == 2\n" " cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval.x,bpm[_ni]);\n" " cpn[(_mi*VWM+1)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1)*(NWI/VWN)+_ni],aval.y,bpm[_ni]);\n" " #elif VWM == 4\n" " cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval.x,bpm[_ni]);\n" " cpn[(_mi*VWM+1)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1)*(NWI/VWN)+_ni],aval.y,bpm[_ni]);\n" " cpn[(_mi*VWM+2)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+2)*(NWI/VWN)+_ni],aval.z,bpm[_ni]);\n" " cpn[(_mi*VWM+3)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+3)*(NWI/VWN)+_ni],aval.w,bpm[_ni]);\n" " #elif VWM == 8\n" " cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval.s0,bpm[_ni]);\n" " cpn[(_mi*VWM+1)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1)*(NWI/VWN)+_ni],aval.s1,bpm[_ni]);\n" " cpn[(_mi*VWM+2)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+2)*(NWI/VWN)+_ni],aval.s2,bpm[_ni]);\n" " cpn[(_mi*VWM+3)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+3)*(NWI/VWN)+_ni],aval.s3,bpm[_ni]);\n" " cpn[(_mi*VWM+4)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+4)*(NWI/VWN)+_ni],aval.s4,bpm[_ni]);\n" " cpn[(_mi*VWM+5)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+5)*(NWI/VWN)+_ni],aval.s5,bpm[_ni]);\n" " cpn[(_mi*VWM+6)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+6)*(NWI/VWN)+_ni],aval.s6,bpm[_ni]);\n" " cpn[(_mi*VWM+7)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+7)*(NWI/VWN)+_ni],aval.s7,bpm[_ni]);\n" " #elif VWM == 16\n" " cpn[(_mi*VWM+0 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0 )*(NWI/VWN)+_ni],aval.s0,bpm[_ni]);\n" " cpn[(_mi*VWM+1 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1 )*(NWI/VWN)+_ni],aval.s1,bpm[_ni]);\n" " cpn[(_mi*VWM+2 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+2 
)*(NWI/VWN)+_ni],aval.s2,bpm[_ni]);\n" " cpn[(_mi*VWM+3 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+3 )*(NWI/VWN)+_ni],aval.s3,bpm[_ni]);\n" " cpn[(_mi*VWM+4 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+4 )*(NWI/VWN)+_ni],aval.s4,bpm[_ni]);\n" " cpn[(_mi*VWM+5 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+5 )*(NWI/VWN)+_ni],aval.s5,bpm[_ni]);\n" " cpn[(_mi*VWM+6 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+6 )*(NWI/VWN)+_ni],aval.s6,bpm[_ni]);\n" " cpn[(_mi*VWM+7 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+7 )*(NWI/VWN)+_ni],aval.s7,bpm[_ni]);\n" " cpn[(_mi*VWM+8 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+8 )*(NWI/VWN)+_ni],aval.s8,bpm[_ni]);\n" " cpn[(_mi*VWM+9 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+9 )*(NWI/VWN)+_ni],aval.s9,bpm[_ni]);\n" " cpn[(_mi*VWM+10)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+10)*(NWI/VWN)+_ni],aval.sA,bpm[_ni]);\n" " cpn[(_mi*VWM+11)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+11)*(NWI/VWN)+_ni],aval.sB,bpm[_ni]);\n" " cpn[(_mi*VWM+12)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+12)*(NWI/VWN)+_ni],aval.sC,bpm[_ni]);\n" " cpn[(_mi*VWM+13)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+13)*(NWI/VWN)+_ni],aval.sD,bpm[_ni]);\n" " cpn[(_mi*VWM+14)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+14)*(NWI/VWN)+_ni],aval.sE,bpm[_ni]);\n" " cpn[(_mi*VWM+15)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+15)*(NWI/VWN)+_ni],aval.sF,bpm[_ni]);\n" " #endif\n" " }\n" " }\n" " #else\n" " #pragma unroll\n" " for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n" " #pragma unroll\n" " for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n" " const COMPUTE_FLOATM aval=apm[_mi];\n" " #if VWN == 1\n" " cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bpm[_ni]);\n" " #elif VWN == 2\n" " cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bpm[_ni].x);\n" " cpm[(_ni*VWN+1)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1)*(MWI/VWM)+_mi],aval,bpm[_ni].y);\n" " #elif VWN == 4\n" " cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bpm[_ni].x);\n" " cpm[(_ni*VWN+1)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1)*(MWI/VWM)+_mi],aval,bpm[_ni].y);\n" " cpm[(_ni*VWN+2)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+2)*(MWI/VWM)+_mi],aval,bpm[_ni].z);\n" " cpm[(_ni*VWN+3)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+3)*(MWI/VWM)+_mi],aval,bpm[_ni].w);\n" " #elif VWN == 8\n" " cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bpm[_ni].s0);\n" " cpm[(_ni*VWN+1)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1)*(MWI/VWM)+_mi],aval,bpm[_ni].s1);\n" " cpm[(_ni*VWN+2)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+2)*(MWI/VWM)+_mi],aval,bpm[_ni].s2);\n" " cpm[(_ni*VWN+3)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+3)*(MWI/VWM)+_mi],aval,bpm[_ni].s3);\n" " cpm[(_ni*VWN+4)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+4)*(MWI/VWM)+_mi],aval,bpm[_ni].s4);\n" " cpm[(_ni*VWN+5)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+5)*(MWI/VWM)+_mi],aval,bpm[_ni].s5);\n" " cpm[(_ni*VWN+6)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+6)*(MWI/VWM)+_mi],aval,bpm[_ni].s6);\n" " cpm[(_ni*VWN+7)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+7)*(MWI/VWM)+_mi],aval,bpm[_ni].s7);\n" " #elif VWN == 16\n" " cpm[(_ni*VWN+0 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0 )*(MWI/VWM)+_mi],aval,bpm[_ni].s0);\n" " cpm[(_ni*VWN+1 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1 )*(MWI/VWM)+_mi],aval,bpm[_ni].s1);\n" " cpm[(_ni*VWN+2 
)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+2 )*(MWI/VWM)+_mi],aval,bpm[_ni].s2);\n" " cpm[(_ni*VWN+3 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+3 )*(MWI/VWM)+_mi],aval,bpm[_ni].s3);\n" " cpm[(_ni*VWN+4 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+4 )*(MWI/VWM)+_mi],aval,bpm[_ni].s4);\n" " cpm[(_ni*VWN+5 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+5 )*(MWI/VWM)+_mi],aval,bpm[_ni].s5);\n" " cpm[(_ni*VWN+6 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+6 )*(MWI/VWM)+_mi],aval,bpm[_ni].s6);\n" " cpm[(_ni*VWN+7 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+7 )*(MWI/VWM)+_mi],aval,bpm[_ni].s7);\n" " cpm[(_ni*VWN+8 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+8 )*(MWI/VWM)+_mi],aval,bpm[_ni].s8);\n" " cpm[(_ni*VWN+9 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+9 )*(MWI/VWM)+_mi],aval,bpm[_ni].s9);\n" " cpm[(_ni*VWN+10)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+10)*(MWI/VWM)+_mi],aval,bpm[_ni].sA);\n" " cpm[(_ni*VWN+11)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+11)*(MWI/VWM)+_mi],aval,bpm[_ni].sB);\n" " cpm[(_ni*VWN+12)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+12)*(MWI/VWM)+_mi],aval,bpm[_ni].sC);\n" " cpm[(_ni*VWN+13)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+13)*(MWI/VWM)+_mi],aval,bpm[_ni].sD);\n" " cpm[(_ni*VWN+14)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+14)*(MWI/VWM)+_mi],aval,bpm[_ni].sE);\n" " cpm[(_ni*VWN+15)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+15)*(MWI/VWM)+_mi],aval,bpm[_ni].sF);\n" " #endif\n" " }\n" " }\n" " #endif\n" " }\n" " }\n" " barrier(CLK_LOCAL_MEM_FENCE);\n" " }\n" " #else\n" " // Allocates workitem-private memory (registers)\n" " int baseIndexA=GlobalIndexA();\n" " int baseIndexB=GlobalIndexB();\n" " #pragma unroll\n" " for (int _kj=0; _kj<kSizeK; _kj += 4) {\n" " #ifdef OUTPUTMN\n" " #pragma promote_to_registers\n" " COMPUTE_FLOATN bpm[NWI/VWN]; // 1*NWI\n" " \n" " #pragma unroll\n" " for(int _ki=0; _ki<4; _ki += 1) {\n" " int idk=_kj+_ki;\n" " #pragma unroll\n" " for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n" " // Loads data: off-chip --> private (matrix B)\n" " bpm[_ni]=CONVERT_COMPUTE_FLOATN(GlobalToPrivateOptB(bgm,baseIndexB,_ni,stride.s1/*kSizeN*/,idk));\n" " }\n" " #pragma unroll\n" " for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n" " const COMPUTE_FLOATM aval=CONVERT_COMPUTE_FLOATM(GlobalToPrivateOptA(agm,baseIndexA,_mi,stride.s0/*kSizeM*/,idk));\n" " #pragma unroll\n" " for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n" " #if VWM == 1\n" " // [MWI/VWM,VWM,NWI/VWN,VWN]\n" " cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval,bpm[_ni]);\n" " #elif VWM == 2\n" " cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval.x,bpm[_ni]);\n" " cpn[(_mi*VWM+1)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1)*(NWI/VWN)+_ni],aval.y,bpm[_ni]);\n" " #elif VWM == 4\n" " cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval.x,bpm[_ni]);\n" " cpn[(_mi*VWM+1)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1)*(NWI/VWN)+_ni],aval.y,bpm[_ni]);\n" " cpn[(_mi*VWM+2)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+2)*(NWI/VWN)+_ni],aval.z,bpm[_ni]);\n" " cpn[(_mi*VWM+3)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+3)*(NWI/VWN)+_ni],aval.w,bpm[_ni]);\n" " #elif VWM == 8\n" " cpn[(_mi*VWM+0)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0)*(NWI/VWN)+_ni],aval.s0,bpm[_ni]);\n" " cpn[(_mi*VWM+1)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1)*(NWI/VWN)+_ni],aval.s1,bpm[_ni]);\n" " 
cpn[(_mi*VWM+2)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+2)*(NWI/VWN)+_ni],aval.s2,bpm[_ni]);\n" " cpn[(_mi*VWM+3)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+3)*(NWI/VWN)+_ni],aval.s3,bpm[_ni]);\n" " cpn[(_mi*VWM+4)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+4)*(NWI/VWN)+_ni],aval.s4,bpm[_ni]);\n" " cpn[(_mi*VWM+5)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+5)*(NWI/VWN)+_ni],aval.s5,bpm[_ni]);\n" " cpn[(_mi*VWM+6)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+6)*(NWI/VWN)+_ni],aval.s6,bpm[_ni]);\n" " cpn[(_mi*VWM+7)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+7)*(NWI/VWN)+_ni],aval.s7,bpm[_ni]);\n" " #elif VWM == 16\n" " cpn[(_mi*VWM+0 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+0 )*(NWI/VWN)+_ni],aval.s0,bpm[_ni]);\n" " cpn[(_mi*VWM+1 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+1 )*(NWI/VWN)+_ni],aval.s1,bpm[_ni]);\n" " cpn[(_mi*VWM+2 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+2 )*(NWI/VWN)+_ni],aval.s2,bpm[_ni]);\n" " cpn[(_mi*VWM+3 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+3 )*(NWI/VWN)+_ni],aval.s3,bpm[_ni]);\n" " cpn[(_mi*VWM+4 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+4 )*(NWI/VWN)+_ni],aval.s4,bpm[_ni]);\n" " cpn[(_mi*VWM+5 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+5 )*(NWI/VWN)+_ni],aval.s5,bpm[_ni]);\n" " cpn[(_mi*VWM+6 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+6 )*(NWI/VWN)+_ni],aval.s6,bpm[_ni]);\n" " cpn[(_mi*VWM+7 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+7 )*(NWI/VWN)+_ni],aval.s7,bpm[_ni]);\n" " cpn[(_mi*VWM+8 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+8 )*(NWI/VWN)+_ni],aval.s8,bpm[_ni]);\n" " cpn[(_mi*VWM+9 )*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+9 )*(NWI/VWN)+_ni],aval.s9,bpm[_ni]);\n" " cpn[(_mi*VWM+10)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+10)*(NWI/VWN)+_ni],aval.sA,bpm[_ni]);\n" " cpn[(_mi*VWM+11)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+11)*(NWI/VWN)+_ni],aval.sB,bpm[_ni]);\n" " cpn[(_mi*VWM+12)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+12)*(NWI/VWN)+_ni],aval.sC,bpm[_ni]);\n" " cpn[(_mi*VWM+13)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+13)*(NWI/VWN)+_ni],aval.sD,bpm[_ni]);\n" " cpn[(_mi*VWM+14)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+14)*(NWI/VWN)+_ni],aval.sE,bpm[_ni]);\n" " cpn[(_mi*VWM+15)*(NWI/VWN)+_ni]=MultiplyAddVectorN(cpn[(_mi*VWM+15)*(NWI/VWN)+_ni],aval.sF,bpm[_ni]);\n" " #endif\n" " }\n" " }\n" " }\n" " #else\n" " \n" " #pragma promote_to_registers\n" " COMPUTE_FLOATM apm[MWI/VWM]; // MWI*1\n" " #pragma unroll\n" " for(int _ki=0; _ki<4; _ki += 1) {\n" " int idk=_kj+_ki;\n" " #pragma unroll\n" " for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n" " // Loads data: off-chip --> private (matrix B)\n" " apm[_mi]=CONVERT_COMPUTE_FLOATM(GlobalToPrivateOptA(agm,baseIndexA,_mi,stride.s0/*kSizeM*/,idk));\n" " }\n" " #pragma unroll\n" " for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n" " const COMPUTE_FLOATN bval=CONVERT_COMPUTE_FLOATN(GlobalToPrivateOptB(bgm,baseIndexB,_ni,stride.s1/*kSizeN*/,idk));\n" " #pragma unroll\n" " for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n" " const COMPUTE_FLOATM aval=apm[_mi];\n" " #if VWN == 1\n" " cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bval);\n" " #elif VWN == 2\n" " cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bval.x);\n" " cpm[(_ni*VWN+1)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1)*(MWI/VWM)+_mi],aval,bval.y);\n" " #elif VWN == 4\n" " cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bval.x);\n" " 
cpm[(_ni*VWN+1)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1)*(MWI/VWM)+_mi],aval,bval.y);\n" " cpm[(_ni*VWN+2)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+2)*(MWI/VWM)+_mi],aval,bval.z);\n" " cpm[(_ni*VWN+3)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+3)*(MWI/VWM)+_mi],aval,bval.w);\n" " #elif VWN == 8\n" " cpm[(_ni*VWN+0)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0)*(MWI/VWM)+_mi],aval,bval.s0);\n" " cpm[(_ni*VWN+1)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1)*(MWI/VWM)+_mi],aval,bval.s1);\n" " cpm[(_ni*VWN+2)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+2)*(MWI/VWM)+_mi],aval,bval.s2);\n" " cpm[(_ni*VWN+3)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+3)*(MWI/VWM)+_mi],aval,bval.s3);\n" " cpm[(_ni*VWN+4)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+4)*(MWI/VWM)+_mi],aval,bval.s4);\n" " cpm[(_ni*VWN+5)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+5)*(MWI/VWM)+_mi],aval,bval.s5);\n" " cpm[(_ni*VWN+6)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+6)*(MWI/VWM)+_mi],aval,bval.s6);\n" " cpm[(_ni*VWN+7)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+7)*(MWI/VWM)+_mi],aval,bval.s7);\n" " #elif VWN == 16\n" " cpm[(_ni*VWN+0 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+0 )*(MWI/VWM)+_mi],aval,bval.s0);\n" " cpm[(_ni*VWN+1 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+1 )*(MWI/VWM)+_mi],aval,bval.s1);\n" " cpm[(_ni*VWN+2 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+2 )*(MWI/VWM)+_mi],aval,bval.s2);\n" " cpm[(_ni*VWN+3 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+3 )*(MWI/VWM)+_mi],aval,bval.s3);\n" " cpm[(_ni*VWN+4 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+4 )*(MWI/VWM)+_mi],aval,bval.s4);\n" " cpm[(_ni*VWN+5 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+5 )*(MWI/VWM)+_mi],aval,bval.s5);\n" " cpm[(_ni*VWN+6 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+6 )*(MWI/VWM)+_mi],aval,bval.s6);\n" " cpm[(_ni*VWN+7 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+7 )*(MWI/VWM)+_mi],aval,bval.s7);\n" " cpm[(_ni*VWN+8 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+8 )*(MWI/VWM)+_mi],aval,bval.s8);\n" " cpm[(_ni*VWN+9 )*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+9 )*(MWI/VWM)+_mi],aval,bval.s9);\n" " cpm[(_ni*VWN+10)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+10)*(MWI/VWM)+_mi],aval,bval.sA);\n" " cpm[(_ni*VWN+11)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+11)*(MWI/VWM)+_mi],aval,bval.sB);\n" " cpm[(_ni*VWN+12)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+12)*(MWI/VWM)+_mi],aval,bval.sC);\n" " cpm[(_ni*VWN+13)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+13)*(MWI/VWM)+_mi],aval,bval.sD);\n" " cpm[(_ni*VWN+14)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+14)*(MWI/VWM)+_mi],aval,bval.sE);\n" " cpm[(_ni*VWN+15)*(MWI/VWM)+_mi]=MultiplyAddVector(cpm[(_ni*VWN+15)*(MWI/VWM)+_mi],aval,bval.sF);\n" " #endif\n" " }\n" " }\n" " }\n" " #endif\n" " }\n" " #endif\n" " \n" " #if GLOBAL_MEM_FENCE == 1\n" " barrier(CLK_GLOBAL_MEM_FENCE);\n" " #endif\n" " #ifdef OUTPUTMN\n" " INT2 baseOffset=StoreIndexN();\n" " #if BIAS_TYPE == 1\n" " #pragma promote_to_registers\n" " realN epm[NWI/VWN]; // MWI*1\n" " for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n" " #if STRN == 0\n" " int idn=_ni+baseOffset.index[1];\n" " #elif STRN == 1\n" " int idn=baseOffset.index[1]+_ni*NDIMC;\n" " #endif\n" " epm[_ni]=egm[idn];\n" " }\n" " #endif\n" " \n" " \n" " \n" " #pragma unroll\n" " for (int _mi=0; _mi<MWI; _mi += 1) {\n" " #pragma unroll\n" " for (int _ni=0; _ni<NWI/VWN; _ni += 1) {\n" " StoreResultsN((__global realN* )cgm,cpn[_mi*(NWI/VWN)+_ni],\n" " baseOffset,\n" " #if BIAS_TYPE>1\n" " (__global 
realN*)egm,\n" " #elif BIAS_TYPE == 1\n" " (realN*)epm,\n" " #endif\n" " _mi,_ni,stride.s2,stride.s3,alpha,beta);\n" " }\n" " }\n" " \n" " #else\n" " INT2 baseOffset=StoreIndexM();\n" " // Stores an MWG*NWG tile of results and performs the multiplication with alpha and beta\n" " const int cld=kSizeM;\n" " \n" " #pragma unroll\n" " for (int _ni=0; _ni<NWI; _ni += 1) {\n" " #pragma unroll\n" " for (int _mi=0; _mi<MWI/VWM; _mi += 1) {\n" " StoreResultsM(cgm,cpm[_ni*(MWI/VWM)+_mi],baseOffset,_mi,_ni,stride.s2,alpha,beta);\n" " }\n" " }\n" " #endif\n" "}\n" "// Main entry point of the kernel. This is the regular full version.\n" "#if RELAX_WORKGROUP_SIZE == 1\n" " __kernel\n" "#else\n" " __kernel __attribute__((reqd_work_group_size(MDIMC,NDIMC,1)))\n" "#endif\n" "void Xgemm(const int kSizeM,const int kSizeN,const int kSizeK,\n" " const real_arg arg_alpha,\n" " const real_arg arg_beta,\n" " const __global realM* restrict agm,// [K,M]\n" " const __global realN* restrict bgm,// [K,N]\n" " #if BIAS_TYPE>0\n" " __global realN* restrict egm,// [N]\n" " #endif\n" " __global realM* cgm,\n" " __private const int4 offset,\n" " __private const int4 stride\n" ") {\n" " \n" " // Adds the offsets (in case of use of a single temporary buffer for A,B,and C)\n" " agm=(const __global realM*)((const __global real*)agm+offset.s0);\n" " bgm=(const __global realN*)((const __global real*)bgm+offset.s1);\n" " cgm=(__global realM*)((__global real*)cgm+offset.s2);\n" " \n" " #if BIAS_TYPE>0\n" " egm=(__global realN*)((__global real*)egm+offset.s3);\n" " #endif\n" " // Allocates workgroup-private memory (local memory)\n" " #if SA == 1\n" " __local realM alm[KWG*MWG/VWM];\n" " #endif\n" " #if SB == 1\n" " __local realN blm[KWG*NWG/VWN];\n" " #endif\n" " \n" " // Computes the matrix-multiplication and stores the result in global memory\n" " #if SA == 1 && SB == 1\n" " XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm,bgm,\n" " #if BIAS_TYPE>0\n" " egm,\n" " #endif\n" " cgm,arg_alpha,arg_beta,alm,blm);\n" " #elif SA == 1\n" " XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm,bgm,\n" " #if BIAS_TYPE>0\n" " egm,\n" " #endif\n" " cgm,arg_alpha,arg_beta,alm);\n" " #elif SB == 1\n" " XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm,bgm,\n" " #if BIAS_TYPE>0\n" " egm,\n" " #endif\n" " cgm,arg_alpha,arg_beta,blm);\n" " #else\n" " XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm,bgm,\n" " #if BIAS_TYPE>0\n" " egm,\n" " #endif\n" " cgm,arg_alpha,arg_beta);\n" " #endif\n" "}\n" "#if RELAX_WORKGROUP_SIZE == 1\n" " __kernel\n" "#else\n" " __kernel __attribute__((reqd_work_group_size(MDIMC,NDIMC,1)))\n" "#endif\n" "void XgemmBatched(const int kSizeM,\n" " const int kSizeN,\n" " const int kSizeK,\n" " const real_arg arg_alpha,\n" " const real_arg arg_beta,\n" " const __global realM* restrict agm,\n" " const __global realN* restrict bgm,\n" " #if BIAS_TYPE>0\n" " __global realN* restrict egm,\n" " #endif\n" " __global realM* cgm,\n" " const int4 batch_offset,// [batch_offset_a,batch_offset_b,batch_offset_c,batch_offset_e]\n" " const int4 base_ptr_offset,// [base_ptr_offset_a,base_ptr_offset_b,base_ptr_offset_c,base_ptr_offset_e]\n" " const int4 stride,// [stride_a,stride_b,stride_c,stride_e]\n" " /*\n" " total_batch -> [loop_y,loop_x]\n" " with group batch -> [loop_y,loop_x/group_num]\n" " group_size == loop_x/group_num\n" " */\n" " const int4 group // [group_num_a,group_num_b,group_num_e,loop_x]\n" ") {\n" " const int batch=get_group_id(2);\n" " \n" " // Sets the offsets\n" " const int a_offset=base_ptr_offset.x+((batch/group.w)*group.x+(batch % 
group.w)/group.x)*batch_offset.x;\n" " const int b_offset=base_ptr_offset.y+((batch/group.w)*group.y+(batch % group.w)/group.y)*batch_offset.y;\n" " const int c_offset=base_ptr_offset.z+batch*batch_offset.z;\n" " const __global realM* restrict agm_=&agm[a_offset/VWM];\n" " const __global realN* restrict bgm_=&bgm[b_offset/VWN];\n" " __global realM* restrict cgm_=&cgm[c_offset/VWM];\n" " \n" " #if BIAS_TYPE>0\n" " const int e_offset=base_ptr_offset.w+((batch/group.w)*group.z+(batch % group.w)/group.z)*batch_offset.w;\n" " __global realN* restrict egm_=&egm[e_offset/VWN];\n" " #endif\n" " \n" " // Allocates workgroup-private memory (local memory)\n" " #if SA == 1\n" " __local realM alm[KWG*MWG/VWM];\n" " #endif\n" " #if SB == 1\n" " __local realN blm[KWG*NWG/VWN];\n" " #endif\n" " // Computes the matrix-multiplication and stores the result in global memory\n" " #if SA == 1 && SB == 1\n" " XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm_,bgm_,\n" " #if BIAS_TYPE>0\n" " egm_,\n" " #endif\n" " cgm_,arg_alpha,arg_beta,alm,blm);\n" " #elif SA == 1\n" " XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm_,bgm_,\n" " #if BIAS_TYPE>0\n" " egm_,\n" " #endif\n" " cgm_,arg_alpha,arg_beta,alm);\n" " #elif SB == 1\n" " XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm_,bgm_,\n" " #if BIAS_TYPE>0\n" " egm_,\n" " #endif\n" " cgm_,arg_alpha,arg_beta,blm);\n" " #else\n" " XgemmBody(kSizeM,kSizeN,kSizeK,stride,agm_,bgm_,\n" " #if BIAS_TYPE>0\n" " egm_,\n" " #endif\n" " cgm_,arg_alpha,arg_beta);\n" " #endif\n" "}\n" ; #endif }
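
As a point of reference, the sketch below shows how an embedded source string such as matmul_params_buf is typically turned into a cl_program, with the tuning macros supplied as -D build options. It is a minimal sketch using only the standard OpenCL host API: the buildMatmulProgram name, the context/device handles, and the chosen parameter values are illustrative assumptions rather than MNN's actual runtime code, and the FLOAT/COMPUTE_FLOAT/CONVERT_* precision macros the kernel also expects are assumed to come from the backend's common option string.

// Illustrative host-side sketch (assumed names; not MNN code): compile the
// embedded GEMM source with explicit tuning macros via the OpenCL C API.
#include <CL/cl.h>
#include <stddef.h>

cl_program buildMatmulProgram(cl_context ctx, cl_device_id dev, const char* src) {
    cl_int err = CL_SUCCESS;
    cl_program program = clCreateProgramWithSource(ctx, 1, &src, NULL, &err);
    if (err != CL_SUCCESS) {
        return NULL;
    }
    // Hypothetical tuning choice: 64x64 output tile per workgroup, 8x8 threads,
    // float4 vector widths, A/B tiles cached in local memory, bias added per
    // output column (BIAS_TYPE=1), [M,N] output layout (OUTPUTMN).
    // The FLOAT*/COMPUTE_FLOAT*/CONVERT_* macros used by the kernel must be
    // appended as well (omitted here; the backend normally supplies them).
    const char* options =
        "-DMWG=64 -DNWG=64 -DKWG=16 "
        "-DMDIMC=8 -DNDIMC=8 -DMDIMA=8 -DNDIMB=8 "
        "-DVWM=4 -DVWN=4 -DSA=1 -DSB=1 "
        "-DBIAS_TYPE=1 -DOUTPUTMN";
    if (clBuildProgram(program, 1, &dev, options, NULL, NULL) != CL_SUCCESS) {
        // On failure, clGetProgramBuildInfo(CL_PROGRAM_BUILD_LOG) reports which
        // macro combination the device compiler rejected.
        clReleaseProgram(program);
        return NULL;
    }
    return program;
}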