// source/backend/opencl/execution/cl/matmul_params_buf.cl
#ifdef MNN_SUPPORT_FP16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif
// =================================================================================================
#define USE_INLINE_KEYWORD 1
#ifndef MWG
#define MWG 8 // Tile-size in dimension M (e.g. 64, 128)
#endif
#ifndef NWG
#define NWG 8 // Tile-size in dimension N (e.g. 64, 128)
#endif
#ifndef KWG
#define KWG 16 // Tile-size in dimension K (e.g. 8, 16)
#endif
#ifndef MDIMC
#define MDIMC 8 // Threads per workgroup in M-dimension (e.g. 8, 16, 32)
#endif
#ifndef NDIMC
#define NDIMC 8 // Threads per workgroup in N-dimension (e.g. 8, 16, 32)
#endif
#ifndef MDIMA
#define MDIMA 8 // Re-shaped tile dimension of matrix A: KDIMA * MDIMA (kernel 0 only)
#endif
#ifndef NDIMB
#define NDIMB 8 // Re-shaped tile dimension of matrix B: KDIMB * NDIMB (kernel 0 only)
#endif
#ifndef KWI
#define KWI 2 // Unroll factor of the KWG loop (smaller or equal than KWG)
#endif
#ifndef VWM
#define VWM 1 // Vector width of matrices A and C
#endif
#ifndef VWN
#define VWN 1 // Vector width of matrix B
#endif
#ifndef STRM
#define STRM 0 // Use strided access within a thread in the M-dimension (1) or not (0) (kernel 0 only)
#endif
#ifndef STRN
#define STRN 0 // Use strided access within a thread in the N-dimension (1) or not (0) (kernel 0 only)
#endif
#ifndef SA
#define SA 0 // Use local/shared memory to cache matrix A (1) or not (0) (kernel 0 only)
#endif
#ifndef SB
#define SB 0 // Use local/shared memory to cache matrix B (1) or not (0) (kernel 0 only)
#endif
// Helper parameters based on the above tuning parameters
#define MWI (MWG/MDIMC) // Work per work-item (M-dimension)
#define NWI (NWG/NDIMC) // Work per work-item (N-dimension)
#define KDIMA ((MDIMC*NDIMC)/(MDIMA)) // Re-shaped tile dimension of matrix A: KDIMA * MDIMA
#define KDIMB ((MDIMC*NDIMC)/(NDIMB)) // Re-shaped tile dimension of matrix B: KDIMB * NDIMB
#define MWA (MWG/MDIMA) // Amount of loads-per-thread for matrix A (M-dimension)
#define KWA (KWG/KDIMA) // Amount of loads-per-thread for matrix A (K-dimension)
#define KWB (KWG/KDIMB) // Amount of loads-per-thread for matrix B (K-dimension)
#define NWB (NWG/NDIMB) // Amount of loads-per-thread for matrix B (N-dimension)
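// A worked example (illustration only), using the default values defined above
// (MWG = NWG = 8, KWG = 16, MDIMC = NDIMC = MDIMA = NDIMB = 8, VWM = VWN = 1):
//   MWI   = 8/8     = 1  -> one M element per work-item
//   NWI   = 8/8     = 1  -> one N element per work-item
//   KDIMA = (8*8)/8 = 8,   KDIMB = 8
//   MWA   = 8/8     = 1,   KWA = 16/8 = 2,   KWB = 16/8 = 2,   NWB = 8/8 = 1
// The index arithmetic below assumes that all of these ratios divide exactly.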
// Settings
#ifndef USE_VECTOR_MAD
#define USE_VECTOR_MAD 0 // 0: unroll the vector MAD manually; 1: use the vector MAD directly
#endif
#ifndef GLOBAL_MEM_FENCE
#define GLOBAL_MEM_FENCE 0 // Global synchronisation barrier for potentially better performance
#endif
// Pointers to local memory objects (using a define because CUDA doesn't need them)
#ifndef LOCAL_PTR
#define LOCAL_PTR __local
#endif
// Don't use the non-IEEE754-compliant OpenCL built-in mad() instruction by default. For specific
// devices, this is enabled (see src/routine.cpp).
#ifndef USE_CL_MAD
#define USE_CL_MAD 0
#endif
// BIAS_TYPE
// 0 -> without bias
// 1 -> with bias (add) [N]
// 2 -> with bias (eltwise_add) [M, N]
// 3 -> with bias (eltwise_sub) [M, N]
// 4 -> with bias (eltwise_sub and get negative) [M, N]
// 5 -> with bias (mask 0 for invalid) [M, N]
#ifndef BIAS_TYPE
#define BIAS_TYPE 0
#endif
#if BIAS_TYPE == 1
#define DEAL_BIAS(x, a) x = x + a
#elif BIAS_TYPE == 2
#define DEAL_BIAS(x, a) x = x + a
#elif BIAS_TYPE == 3
#define DEAL_BIAS(x, a) x = x - a
#elif BIAS_TYPE == 4
#define DEAL_BIAS(x, a) x = a - x
#elif BIAS_TYPE == 5
#define DEAL_BIAS(x, a) x = (a == 0 ? (FLOAT)(-FLT_MAX) : x)
#endif
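// Illustration of the macro above (plain expansion, no extra behaviour): with BIAS_TYPE == 3,
//   DEAL_BIAS(result.x, eval.x);   expands to   result.x = result.x - eval.x;
// and with BIAS_TYPE == 5 (mask),
//   DEAL_BIAS(result.x, eval.x);   expands to   result.x = (eval.x == 0 ? (FLOAT)(-FLT_MAX) : result.x);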
// By default the work-group size requirement is enabled. For Qualcomm devices this requirement
// results in worse performance and is therefore disabled (see src/utilities/compile.cpp).
#ifndef RELAX_WORKGROUP_SIZE
#define RELAX_WORKGROUP_SIZE 0
#endif
typedef float real_arg;
#define GetRealArg(x) (FLOAT)x
typedef FLOAT real;
#ifndef PRECISION_COMPUTE
#define PRECISION_COMPUTE COMPUTE_FLOAT
#define CONVERT_PRECISION_COMPUTE(x) CONVERT_COMPUTE_FLOAT(x)
#endif
#ifndef PRECISION_COMPUTE2
#define PRECISION_COMPUTE2 COMPUTE_FLOAT2
#define CONVERT_PRECISION_COMPUTE2(x) CONVERT_COMPUTE_FLOAT2(x)
#endif
#ifndef PRECISION_COMPUTE4
#define PRECISION_COMPUTE4 COMPUTE_FLOAT4
#define CONVERT_PRECISION_COMPUTE4(x) CONVERT_COMPUTE_FLOAT4(x)
#endif
#ifndef PRECISION_COMPUTE8
#define PRECISION_COMPUTE8 COMPUTE_FLOAT8
#define CONVERT_PRECISION_COMPUTE8(x) CONVERT_COMPUTE_FLOAT8(x)
#endif
#ifndef PRECISION_COMPUTE16
#define PRECISION_COMPUTE16 COMPUTE_FLOAT16
#define CONVERT_PRECISION_COMPUTE16(x) CONVERT_COMPUTE_FLOAT16(x)
#endif
#define ZERO (PRECISION_COMPUTE)0.0f
// Sets a variable to zero
#define SetToZero(a) a = ZERO
#define IsZero(a) (a == ZERO)
#define Multiply(c,a,b) c = a * b
#if USE_CL_MAD == 1
#define MultiplyAdd(c,a,b) c = mad(a, b, c)
#else
#define MultiplyAdd(c,a,b) c += a * b
#endif
#define AXPBY(e,a,b,c,d) e = a*b + c*d
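// Illustration of the macros above (plain expansion, no extra behaviour):
//   MultiplyAdd(acc, a, b);                  -> acc = mad(a, b, acc);   (USE_CL_MAD == 1)
//                                            -> acc += a * b;           (otherwise)
//   AXPBY(result, alpha, xval, beta, yval);  -> result = alpha*xval + beta*yval;
// AXPBY implements the final C = alpha*(A*B) + beta*C merge in the store functions below.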
// Force inlining functions or not: some compilers don't support the inline keyword
#ifdef USE_INLINE_KEYWORD
#define INLINE_FUNC inline
#else
#define INLINE_FUNC
#endif
INLINE_FUNC int GetGroupID1() { return get_group_id(1); }
INLINE_FUNC int GetGroupID0() { return get_group_id(0); }
// =================================================================================================
// Data-widths in dimension M
#if VWM == 1
typedef FLOAT realM;
#define COMPUTE_FLOATM PRECISION_COMPUTE
#define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE(x)
#define CONVERT_FLOATM(x) CONVERT_FLOAT(x)
#elif VWM == 2
typedef FLOAT2 realM;
#define COMPUTE_FLOATM PRECISION_COMPUTE2
#define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE2(x)
#define CONVERT_FLOATM(x) CONVERT_FLOAT2(x)
#elif VWM == 4
typedef FLOAT4 realM;
#define COMPUTE_FLOATM PRECISION_COMPUTE4
#define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE4(x)
#define CONVERT_FLOATM(x) CONVERT_FLOAT4(x)
#elif VWM == 8
typedef FLOAT8 realM;
#define COMPUTE_FLOATM PRECISION_COMPUTE8
#define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE8(x)
#define CONVERT_FLOATM(x) CONVERT_FLOAT8(x)
#elif VWM == 16
typedef FLOAT16 realM;
#define COMPUTE_FLOATM PRECISION_COMPUTE16
#define CONVERT_COMPUTE_FLOATM(x) CONVERT_PRECISION_COMPUTE16(x)
#define CONVERT_FLOATM(x) CONVERT_FLOAT16(x)
#endif
// Data-widths in dimension N
#if VWN == 1
typedef FLOAT realN;
typedef int intN;
#define COMPUTE_FLOATN PRECISION_COMPUTE
#define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE(x)
#define CONVERT_FLOATN(x) CONVERT_FLOAT(x)
#elif VWN == 2
typedef FLOAT2 realN;
typedef int2 intN;
#define COMPUTE_FLOATN PRECISION_COMPUTE2
#define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE2(x)
#define CONVERT_FLOATN(x) CONVERT_FLOAT2(x)
#elif VWN == 4
typedef FLOAT4 realN;
typedef int4 intN;
#define COMPUTE_FLOATN PRECISION_COMPUTE4
#define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE4(x)
#define CONVERT_FLOATN(x) CONVERT_FLOAT4(x)
#elif VWN == 8
typedef FLOAT8 realN;
typedef int8 intN;
#define COMPUTE_FLOATN PRECISION_COMPUTE8
#define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE8(x)
#define CONVERT_FLOATN(x) CONVERT_FLOAT8(x)
#elif VWN == 16
typedef FLOAT16 realN;
typedef int16 intN;
#define COMPUTE_FLOATN PRECISION_COMPUTE16
#define CONVERT_COMPUTE_FLOATN(x) CONVERT_PRECISION_COMPUTE16(x)
#define CONVERT_FLOATN(x) CONVERT_FLOAT16(x)
#endif
// =================================================================================================
// Initializes the accumulation registers to zero
INLINE_FUNC COMPUTE_FLOATM InitAccRegisters() {
COMPUTE_FLOATM result;
#if VWM == 1
SetToZero(result);
#elif VWM == 2
SetToZero(result.x);
SetToZero(result.y);
#elif VWM == 4
SetToZero(result.x);
SetToZero(result.y);
SetToZero(result.z);
SetToZero(result.w);
#elif VWM == 8
SetToZero(result.s0);
SetToZero(result.s1);
SetToZero(result.s2);
SetToZero(result.s3);
SetToZero(result.s4);
SetToZero(result.s5);
SetToZero(result.s6);
SetToZero(result.s7);
#elif VWM == 16
SetToZero(result.s0);
SetToZero(result.s1);
SetToZero(result.s2);
SetToZero(result.s3);
SetToZero(result.s4);
SetToZero(result.s5);
SetToZero(result.s6);
SetToZero(result.s7);
SetToZero(result.s8);
SetToZero(result.s9);
SetToZero(result.sA);
SetToZero(result.sB);
SetToZero(result.sC);
SetToZero(result.sD);
SetToZero(result.sE);
SetToZero(result.sF);
#endif
return result;
}
INLINE_FUNC COMPUTE_FLOATN InitAccRegistersN() {
COMPUTE_FLOATN result;
#if VWN == 1
SetToZero(result);
#elif VWN == 2
SetToZero(result.x);
SetToZero(result.y);
#elif VWN == 4
SetToZero(result.x);
SetToZero(result.y);
SetToZero(result.z);
SetToZero(result.w);
#elif VWN == 8
SetToZero(result.s0);
SetToZero(result.s1);
SetToZero(result.s2);
SetToZero(result.s3);
SetToZero(result.s4);
SetToZero(result.s5);
SetToZero(result.s6);
SetToZero(result.s7);
#elif VWN == 16
SetToZero(result.s0);
SetToZero(result.s1);
SetToZero(result.s2);
SetToZero(result.s3);
SetToZero(result.s4);
SetToZero(result.s5);
SetToZero(result.s6);
SetToZero(result.s7);
SetToZero(result.s8);
SetToZero(result.s9);
SetToZero(result.sA);
SetToZero(result.sB);
SetToZero(result.sC);
SetToZero(result.sD);
SetToZero(result.sE);
SetToZero(result.sF);
#endif
return result;
}
// =================================================================================================
// Caches global off-chip memory into on-chip local (shared) memory. This function is specific to
// the A input matrix.
#if SA == 1
INLINE_FUNC void GlobalToLocalA(const __global realM* restrict agm, LOCAL_PTR realM* alm,
const int kSizeM, const int tid, const int kwg) {
const int la0 = tid % MDIMA;
const int la1 = tid / MDIMA;
#pragma unroll
for (int _mia = 0; _mia < MWA/VWM; _mia += 1) {
#pragma unroll
for (int _kia = 0; _kia < KWA; _kia += 1) {
// Computes the indices based on strided/non-strided access
#if STRM == 0
int mg = _mia + la0*(MWA/VWM);
#elif STRM == 1
int mg = la0 + _mia*MDIMA;
#endif
// Computes the indices for the global memory
int kg = _kia + la1*KWA;
int idm = mg + GetGroupID0() * (MWG/VWM);
int idk = kg + kwg;
// Loads the data from global memory (not transposed) into the local memory
alm[kg*(MWG/VWM) + mg] = agm[idk*(kSizeM/VWM) + idm];
}
}
}
#endif
// Same as above, but now for the B input matrix
#if SB == 1
INLINE_FUNC void GlobalToLocalB(const __global realN* restrict bgm, LOCAL_PTR realN* blm,
const int kSizeN, const int tid, const int kwg) {
const int lb0 = tid % NDIMB;
const int lb1 = tid / NDIMB;
#pragma unroll
for (int _kib = 0; _kib < KWB; _kib += 1) {
#pragma unroll
for (int _nib = 0; _nib < NWB/VWN; _nib += 1) {
// Computes the indices based on strided/non-strided access
#if STRN == 0
int ng = _nib + lb0*(NWB/VWN);
#elif STRN == 1
int ng = lb0 + _nib*NDIMB;
#endif
// Computes the indices for the global memory
int kg = _kib + lb1*KWB;
int idn = ng + GetGroupID1() * (NWG/VWN);
int idk = kg + kwg;
// Loads the data from global memory (transposed) into the local memory
blm[kg*(NWG/VWN) + ng] = bgm[idk*(kSizeN/VWN) + idn];
}
}
}
#endif
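// Note on the local tile layout used above (a description of the index arithmetic, nothing new):
// alm is addressed as a row-major [KWG][MWG/VWM] array of realM vectors (alm[kg*(MWG/VWM) + mg]),
// and blm as a row-major [KWG][NWG/VWN] array of realN vectors (blm[kg*(NWG/VWN) + ng]). Each of
// the MDIMC*NDIMC work-items therefore loads (MWA/VWM)*KWA realM vectors of A and KWB*(NWB/VWN)
// realN vectors of B per KWG-sized step of the K loop.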
// =================================================================================================
// Caches global off-chip memory directly into per-thread private memory (registers). This function
// is specific to the A input matrix.
#if SA == 0
INLINE_FUNC int GlobalIndexA() {
// Computes the indices based on strided/non-strided access
#if STRM == 0
// [MWG/MWI, MWI/VWM, VWM]
int mg = get_local_id(0)*(MWI/VWM);
#elif STRM == 1
// [MWI/VWM, MWG/MWI, VWM]
int mg = get_local_id(0);
#endif
// Computes the indices for the global memory
// [kSizeM/MWG, (MWG/VWM), VWM]
int idm = mg + GetGroupID0() * (MWG/VWM);
return idm;
}
INLINE_FUNC realM GlobalToPrivateOptA(const __global realM* restrict agm, const int base, const int _mi,
const int astride/*kSizeM*/, const int idk) {
// Computes the indices based on strided/non-strided access
#if STRM == 0
// [MWG/MWI, MWI/VWM, VWM]
int idm = base + _mi;
#elif STRM == 1
// [MWI/VWM, MWG/MWI, VWM]
int idm = base + _mi*MDIMC;
#endif
// Loads the data from global memory (not transposed) and stores into registers
// [kSizeK, kSizeM/VWM, VWM]
return agm[idk*(astride/VWM)+idm];
}
INLINE_FUNC realM GlobalToPrivateA(const __global realM* restrict agm, const int _mi,
const int kSizeM, const int idk) {
// Computes the indices based on strided/non-strided access
#if STRM == 0
// [MWG/MWI, MWI/VWM, VWM]
int mg = _mi + get_local_id(0)*(MWI/VWM);
#elif STRM == 1
// [MWI/VWM, MWG/MWI, VWM]
int mg = get_local_id(0) + _mi*MDIMC;
#endif
// Computes the indices for the global memory
// [kSizeM/MWG, (MWG/VWM), VWM]
int idm = mg + GetGroupID0() * (MWG/VWM);
// Loads the data from global memory (not transposed) and stores into registers
// [kSizeK, kSizeM/VWM, VWM]
return agm[idk*(kSizeM/VWM) + idm];
}
#endif
// Same as above, but now for the B input matrix
#if SB == 0
INLINE_FUNC int GlobalIndexB() {
// Computes the indices based on strided/non-strided access
#if STRN == 0
int ng = get_local_id(1)*(NWI/VWN);
#elif STRN == 1
int ng = get_local_id(1);
#endif
// Computes the indices for the global memory
int idn = ng + GetGroupID1() * (NWG/VWN);
return idn;
}
INLINE_FUNC realN GlobalToPrivateOptB(const __global realN* restrict bgm, const int base, const int _ni,
const int bstride/*kSizeN*/, const int idk) {
// Computes the indices based on strided/non-strided access
#if STRN == 0
int idn = base + _ni;
#elif STRN == 1
int idn = base + _ni*NDIMC;
#endif
// Loads the data from global memory (transposed) and stores into registers
return bgm[idk*(bstride/VWN)+idn];
}
INLINE_FUNC realN GlobalToPrivateB(const __global realN* restrict bgm, const int _ni,
const int kSizeN, const int idk) {
// Computes the indices based on strided/non-strided access
#if STRN == 0
int ng = _ni + get_local_id(1)*(NWI/VWN);
#elif STRN == 1
int ng = get_local_id(1) + _ni*NDIMC;
#endif
// Computes the indices for the global memory
int idn = ng + GetGroupID1() * (NWG/VWN);
// Loads the data from global memory (transposed) and stores into registers
return bgm[idk*(kSizeN/VWN) + idn];
}
#endif
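// Illustration of the direct (SA == 0 / SB == 0) addressing above: agm is viewed as a
// [kSizeK][kSizeM/VWM] array of realM vectors. For example, with VWM == 2, kSizeM == 64,
// idk == 3 and idm == 5, the load agm[3*(64/2) + 5] == agm[101] fetches scalar columns
// 10..11 of row k == 3. bgm is addressed the same way using kSizeN and VWN.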
// =================================================================================================
// Caches on-chip local memory into per-thread private memory (registers). This function is
// specific to the A input matrix.
#if SA == 1
INLINE_FUNC realM LocalToPrivateA(LOCAL_PTR realM* alm, const int _mi, const int kg) {
#if STRM == 0
int mg = _mi + get_local_id(0)*(MWI/VWM);
#elif STRM == 1
int mg = get_local_id(0) + _mi*MDIMC;
#endif
return alm[kg*(MWG/VWM) + mg];
}
#endif
// Same as above, but now for the B input matrix
#if SB == 1
INLINE_FUNC realN LocalToPrivateB(LOCAL_PTR realN* blm, const int _ni, const int kg) {
#if STRN == 0
int ng = _ni + get_local_id(1)*(NWI/VWN);
#elif STRN == 1
int ng = get_local_id(1) + _ni*NDIMC;
#endif
return blm[kg*(NWG/VWN) + ng];
}
#endif
// The vectorised multiply-add function
INLINE_FUNC COMPUTE_FLOATM MultiplyAddVector(COMPUTE_FLOATM cvec, COMPUTE_FLOATM avec, PRECISION_COMPUTE bval) {
#if USE_VECTOR_MAD == 1
#if USE_CL_MAD == 1
cvec = mad(avec, (COMPUTE_FLOATM)bval, cvec);
#else
cvec += avec * bval;
#endif
#else
#if VWM == 1
MultiplyAdd(cvec, avec, bval);
#elif VWM == 2
MultiplyAdd(cvec.x , avec.x, bval);
MultiplyAdd(cvec.y , avec.y, bval);
#elif VWM == 4
MultiplyAdd(cvec.x , avec.x, bval);
MultiplyAdd(cvec.y , avec.y, bval);
MultiplyAdd(cvec.z , avec.z, bval);
MultiplyAdd(cvec.w , avec.w, bval);
#elif VWM == 8
MultiplyAdd(cvec.s0, avec.s0, bval);
MultiplyAdd(cvec.s1, avec.s1, bval);
MultiplyAdd(cvec.s2, avec.s2, bval);
MultiplyAdd(cvec.s3, avec.s3, bval);
MultiplyAdd(cvec.s4, avec.s4, bval);
MultiplyAdd(cvec.s5, avec.s5, bval);
MultiplyAdd(cvec.s6, avec.s6, bval);
MultiplyAdd(cvec.s7, avec.s7, bval);
#elif VWM == 16
MultiplyAdd(cvec.s0, avec.s0, bval);
MultiplyAdd(cvec.s1, avec.s1, bval);
MultiplyAdd(cvec.s2, avec.s2, bval);
MultiplyAdd(cvec.s3, avec.s3, bval);
MultiplyAdd(cvec.s4, avec.s4, bval);
MultiplyAdd(cvec.s5, avec.s5, bval);
MultiplyAdd(cvec.s6, avec.s6, bval);
MultiplyAdd(cvec.s7, avec.s7, bval);
MultiplyAdd(cvec.s8, avec.s8, bval);
MultiplyAdd(cvec.s9, avec.s9, bval);
MultiplyAdd(cvec.sA, avec.sA, bval);
MultiplyAdd(cvec.sB, avec.sB, bval);
MultiplyAdd(cvec.sC, avec.sC, bval);
MultiplyAdd(cvec.sD, avec.sD, bval);
MultiplyAdd(cvec.sE, avec.sE, bval);
MultiplyAdd(cvec.sF, avec.sF, bval);
#endif
#endif
return cvec;
}
// The vectorised multiply-add function
INLINE_FUNC COMPUTE_FLOATN MultiplyAddVectorN(COMPUTE_FLOATN cvec, PRECISION_COMPUTE avec, COMPUTE_FLOATN bval) {
#if USE_VECTOR_MAD == 1
#if USE_CL_MAD == 1
cvec = mad((COMPUTE_FLOATN)avec, bval, cvec);
#else
cvec += avec * bval;
#endif
#else
#if VWN == 1
MultiplyAdd(cvec, avec, bval);
#elif VWN == 2
MultiplyAdd(cvec.x , avec, bval.x);
MultiplyAdd(cvec.y , avec, bval.y);
#elif VWN == 4
MultiplyAdd(cvec.x , avec, bval.x);
MultiplyAdd(cvec.y , avec, bval.y);
MultiplyAdd(cvec.z , avec, bval.z);
MultiplyAdd(cvec.w , avec, bval.w);
#elif VWN == 8
MultiplyAdd(cvec.s0, avec, bval.s0);
MultiplyAdd(cvec.s1, avec, bval.s1);
MultiplyAdd(cvec.s2, avec, bval.s2);
MultiplyAdd(cvec.s3, avec, bval.s3);
MultiplyAdd(cvec.s4, avec, bval.s4);
MultiplyAdd(cvec.s5, avec, bval.s5);
MultiplyAdd(cvec.s6, avec, bval.s6);
MultiplyAdd(cvec.s7, avec, bval.s7);
#elif VWN == 16
MultiplyAdd(cvec.s0, avec, bval.s0);
MultiplyAdd(cvec.s1, avec, bval.s1);
MultiplyAdd(cvec.s2, avec, bval.s2);
MultiplyAdd(cvec.s3, avec, bval.s3);
MultiplyAdd(cvec.s4, avec, bval.s4);
MultiplyAdd(cvec.s5, avec, bval.s5);
MultiplyAdd(cvec.s6, avec, bval.s6);
MultiplyAdd(cvec.s7, avec, bval.s7);
MultiplyAdd(cvec.s8, avec, bval.s8);
MultiplyAdd(cvec.s9, avec, bval.s9);
MultiplyAdd(cvec.sA, avec, bval.sA);
MultiplyAdd(cvec.sB, avec, bval.sB);
MultiplyAdd(cvec.sC, avec, bval.sC);
MultiplyAdd(cvec.sD, avec, bval.sD);
MultiplyAdd(cvec.sE, avec, bval.sE);
MultiplyAdd(cvec.sF, avec, bval.sF);
#endif
#endif
return cvec;
}
// =================================================================================================
// Merges the results in Cpm with the global array in Cgm. This also performs the multiplication
// with the constants: Cgm = alpha*A*B + beta*Cgm = alpha*Cpm + beta*Cgm
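// Three compile-time cases are handled by the store functions below (a summary, not extra logic):
//   neither ONLY_HAVE_ALPHA nor HAVE_ALPHA_BETA defined  ->  C = A*B        (result = c_value)
//   ONLY_HAVE_ALPHA                                      ->  C = alpha*(A*B)
//   HAVE_ALPHA_BETA                                      ->  C = alpha*(A*B) + beta*C
// For StoreResultsN, the optional bias and RELU/RELU6 activation are applied afterwards
// when BIAS_TYPE > 0.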
typedef struct {
int index[2];
} INT2;
INLINE_FUNC INT2 StoreIndexM() {
INT2 res;
#if STRM == 0
int mg = get_local_id(0)*(MWI/VWM);
#elif STRM == 1
int mg = get_local_id(0);
#endif
#if STRN == 0
int ng = get_local_id(1)*NWI;
#elif STRN == 1
int ng = get_local_id(1)*VWN;
#endif
int idm = mg + GetGroupID0() * (MWG/VWM);
int idn = ng + GetGroupID1() * NWG;
res.index[0] = idm;
res.index[1] = idn;
return res;
}
// layout : [N, M]
INLINE_FUNC void StoreResultsM(__global realM* cgm, COMPUTE_FLOATM c_value, const INT2 baseOffset, const int _mi, const int _ni,
const int cstride/*kSizeM*/,
const PRECISION_COMPUTE alpha, const PRECISION_COMPUTE beta) {
#if STRM == 0
int idm = _mi + baseOffset.index[0];
#elif STRM == 1
int idm = baseOffset.index[0] + _mi*MDIMC;
#endif
#if STRN == 0
int idn = _ni + baseOffset.index[1];
#elif STRN == 1
int idn = _ni%VWN + baseOffset.index[1] + (_ni/VWN)*VWN*NDIMC;
#endif
int index = idn*(cstride/VWM) + idm;
COMPUTE_FLOATM result = c_value;
// The final multiplication with alpha (in case beta == 0)
#ifdef ONLY_HAVE_ALPHA
COMPUTE_FLOATM xval = c_value;
#if VWM == 1
Multiply(result, alpha, xval);
#elif VWM == 2
Multiply(result.x, alpha, xval.x);
Multiply(result.y, alpha, xval.y);
#elif VWM == 4
Multiply(result.x, alpha, xval.x);
Multiply(result.y, alpha, xval.y);
Multiply(result.z, alpha, xval.z);
Multiply(result.w, alpha, xval.w);
#elif VWM == 8
Multiply(result.s0, alpha, xval.s0);
Multiply(result.s1, alpha, xval.s1);
Multiply(result.s2, alpha, xval.s2);
Multiply(result.s3, alpha, xval.s3);
Multiply(result.s4, alpha, xval.s4);
Multiply(result.s5, alpha, xval.s5);
Multiply(result.s6, alpha, xval.s6);
Multiply(result.s7, alpha, xval.s7);
#elif VWM == 16
Multiply(result.s0, alpha, xval.s0);
Multiply(result.s1, alpha, xval.s1);
Multiply(result.s2, alpha, xval.s2);
Multiply(result.s3, alpha, xval.s3);
Multiply(result.s4, alpha, xval.s4);
Multiply(result.s5, alpha, xval.s5);
Multiply(result.s6, alpha, xval.s6);
Multiply(result.s7, alpha, xval.s7);
Multiply(result.s8, alpha, xval.s8);
Multiply(result.s9, alpha, xval.s9);
Multiply(result.sA, alpha, xval.sA);
Multiply(result.sB, alpha, xval.sB);
Multiply(result.sC, alpha, xval.sC);
Multiply(result.sD, alpha, xval.sD);
Multiply(result.sE, alpha, xval.sE);
Multiply(result.sF, alpha, xval.sF);
#endif
#endif
// The final multiplication with alpha and the addition with beta*C
#ifdef HAVE_ALPHA_BETA
COMPUTE_FLOATM xval = c_value;
COMPUTE_FLOATM yval = CONVERT_COMPUTE_FLOATM(cgm[index]);
#if VWM == 1
AXPBY(result, alpha, xval, beta, yval);
#elif VWM == 2
AXPBY(result.x, alpha, xval.x, beta, yval.x);
AXPBY(result.y, alpha, xval.y, beta, yval.y);
#elif VWM == 4
AXPBY(result.x, alpha, xval.x, beta, yval.x);
AXPBY(result.y, alpha, xval.y, beta, yval.y);
AXPBY(result.z, alpha, xval.z, beta, yval.z);
AXPBY(result.w, alpha, xval.w, beta, yval.w);
#elif VWM == 8
AXPBY(result.s0, alpha, xval.s0, beta, yval.s0);
AXPBY(result.s1, alpha, xval.s1, beta, yval.s1);
AXPBY(result.s2, alpha, xval.s2, beta, yval.s2);
AXPBY(result.s3, alpha, xval.s3, beta, yval.s3);
AXPBY(result.s4, alpha, xval.s4, beta, yval.s4);
AXPBY(result.s5, alpha, xval.s5, beta, yval.s5);
AXPBY(result.s6, alpha, xval.s6, beta, yval.s6);
AXPBY(result.s7, alpha, xval.s7, beta, yval.s7);
#elif VWM == 16
AXPBY(result.s0, alpha, xval.s0, beta, yval.s0);
AXPBY(result.s1, alpha, xval.s1, beta, yval.s1);
AXPBY(result.s2, alpha, xval.s2, beta, yval.s2);
AXPBY(result.s3, alpha, xval.s3, beta, yval.s3);
AXPBY(result.s4, alpha, xval.s4, beta, yval.s4);
AXPBY(result.s5, alpha, xval.s5, beta, yval.s5);
AXPBY(result.s6, alpha, xval.s6, beta, yval.s6);
AXPBY(result.s7, alpha, xval.s7, beta, yval.s7);
AXPBY(result.s8, alpha, xval.s8, beta, yval.s8);
AXPBY(result.s9, alpha, xval.s9, beta, yval.s9);
AXPBY(result.sA, alpha, xval.sA, beta, yval.sA);
AXPBY(result.sB, alpha, xval.sB, beta, yval.sB);
AXPBY(result.sC, alpha, xval.sC, beta, yval.sC);
AXPBY(result.sD, alpha, xval.sD, beta, yval.sD);
AXPBY(result.sE, alpha, xval.sE, beta, yval.sE);
AXPBY(result.sF, alpha, xval.sF, beta, yval.sF);
#endif
#endif
cgm[index] = CONVERT_FLOATM(result);
}
INLINE_FUNC INT2 StoreIndexN() {
INT2 res;
#if STRM == 0
int mg = get_local_id(0)*MWI;
#elif STRM == 1
int mg = get_local_id(0)*VWM;
#endif
#if STRN == 0
int ng = get_local_id(1)*(NWI/VWN);
#elif STRN == 1
int ng = get_local_id(1);
#endif
int idm = mg + GetGroupID0() * MWG;
int idn = ng + GetGroupID1() * (NWG/VWN);
res.index[0] = idm;
res.index[1] = idn;
return res;
}
// layout : [M, N]
INLINE_FUNC void StoreResultsN(__global realN* cgn, COMPUTE_FLOATN c_value,
const INT2 baseOffset,
#if BIAS_TYPE > 0
#if BIAS_TYPE > 1
__global realN* egm,
#else
realN* epm,
#endif
#endif
const int _mi, const int _ni,
const int cstride/*kSizeN*/, const int dstride/*kSizeN*/, const PRECISION_COMPUTE alpha, const PRECISION_COMPUTE beta) {
#if STRM == 0
int idm = _mi + baseOffset.index[0];
#elif STRM == 1
int idm = _mi%VWM + baseOffset.index[0] + (_mi/VWM)*VWM*MDIMC;
#endif
#if STRN == 0
int idn = _ni + baseOffset.index[1];
#elif STRN == 1
int idn = baseOffset.index[1] + _ni*NDIMC;
#endif
int index = idm * (cstride/VWN) + idn;
COMPUTE_FLOATN result = c_value;
// The final multiplication with alpha (in case beta == 0)
#ifdef ONLY_HAVE_ALPHA
COMPUTE_FLOATN xval = c_value;
#if VWN == 1
Multiply(result, alpha, xval);
#elif VWN == 2
Multiply(result.x, alpha, xval.x);
Multiply(result.y, alpha, xval.y);
#elif VWN == 4
Multiply(result.x, alpha, xval.x);
Multiply(result.y, alpha, xval.y);
Multiply(result.z, alpha, xval.z);
Multiply(result.w, alpha, xval.w);
#elif VWN == 8
Multiply(result.s0, alpha, xval.s0);
Multiply(result.s1, alpha, xval.s1);
Multiply(result.s2, alpha, xval.s2);
Multiply(result.s3, alpha, xval.s3);
Multiply(result.s4, alpha, xval.s4);
Multiply(result.s5, alpha, xval.s5);
Multiply(result.s6, alpha, xval.s6);
Multiply(result.s7, alpha, xval.s7);
#elif VWN == 16
Multiply(result.s0, alpha, xval.s0);
Multiply(result.s1, alpha, xval.s1);
Multiply(result.s2, alpha, xval.s2);
Multiply(result.s3, alpha, xval.s3);
Multiply(result.s4, alpha, xval.s4);
Multiply(result.s5, alpha, xval.s5);
Multiply(result.s6, alpha, xval.s6);
Multiply(result.s7, alpha, xval.s7);
Multiply(result.s8, alpha, xval.s8);
Multiply(result.s9, alpha, xval.s9);
Multiply(result.sA, alpha, xval.sA);
Multiply(result.sB, alpha, xval.sB);
Multiply(result.sC, alpha, xval.sC);
Multiply(result.sD, alpha, xval.sD);
Multiply(result.sE, alpha, xval.sE);
Multiply(result.sF, alpha, xval.sF);
#endif
#endif
// The final multiplication with alpha and the addition with beta*C
#ifdef HAVE_ALPHA_BETA
COMPUTE_FLOATN xval = c_value;
COMPUTE_FLOATN yval = CONVERT_COMPUTE_FLOATN(cgn[index]);
#if VWN == 1
AXPBY(result, alpha, xval, beta, yval);
#elif VWN == 2
AXPBY(result.x, alpha, xval.x, beta, yval.x);
AXPBY(result.y, alpha, xval.y, beta, yval.y);
#elif VWN == 4
AXPBY(result.x, alpha, xval.x, beta, yval.x);
AXPBY(result.y, alpha, xval.y, beta, yval.y);
AXPBY(result.z, alpha, xval.z, beta, yval.z);
AXPBY(result.w, alpha, xval.w, beta, yval.w);
#elif VWN == 8
AXPBY(result.s0, alpha, xval.s0, beta, yval.s0);
AXPBY(result.s1, alpha, xval.s1, beta, yval.s1);
AXPBY(result.s2, alpha, xval.s2, beta, yval.s2);
AXPBY(result.s3, alpha, xval.s3, beta, yval.s3);
AXPBY(result.s4, alpha, xval.s4, beta, yval.s4);
AXPBY(result.s5, alpha, xval.s5, beta, yval.s5);
AXPBY(result.s6, alpha, xval.s6, beta, yval.s6);
AXPBY(result.s7, alpha, xval.s7, beta, yval.s7);
#elif VWN == 16
AXPBY(result.s0, alpha, xval.s0, beta, yval.s0);
AXPBY(result.s1, alpha, xval.s1, beta, yval.s1);
AXPBY(result.s2, alpha, xval.s2, beta, yval.s2);
AXPBY(result.s3, alpha, xval.s3, beta, yval.s3);
AXPBY(result.s4, alpha, xval.s4, beta, yval.s4);
AXPBY(result.s5, alpha, xval.s5, beta, yval.s5);
AXPBY(result.s6, alpha, xval.s6, beta, yval.s6);
AXPBY(result.s7, alpha, xval.s7, beta, yval.s7);
AXPBY(result.s8, alpha, xval.s8, beta, yval.s8);
AXPBY(result.s9, alpha, xval.s9, beta, yval.s9);
AXPBY(result.sA, alpha, xval.sA, beta, yval.sA);
AXPBY(result.sB, alpha, xval.sB, beta, yval.sB);
AXPBY(result.sC, alpha, xval.sC, beta, yval.sC);
AXPBY(result.sD, alpha, xval.sD, beta, yval.sD);
AXPBY(result.sE, alpha, xval.sE, beta, yval.sE);
AXPBY(result.sF, alpha, xval.sF, beta, yval.sF);
#endif
#endif
#if BIAS_TYPE > 0
#if BIAS_TYPE == 1
COMPUTE_FLOATN eval = CONVERT_COMPUTE_FLOATN(epm[_ni]);
#elif BIAS_TYPE == 5
int index_bias = idm * (dstride/VWN) + idn;
intN eval = ((__global intN*)egm)[index_bias];
#else
int index_bias = idm * (dstride/VWN) + idn;
COMPUTE_FLOATN eval = CONVERT_COMPUTE_FLOATN(egm[index_bias]);
#endif
#if VWN == 1
DEAL_BIAS(result, eval);
#ifdef RELU
result = fmax(result, (COMPUTE_FLOATN)0);
#endif
#ifdef RELU6
result = clamp(result, (COMPUTE_FLOATN)0, (COMPUTE_FLOATN)6);
#endif
#elif VWN == 2
DEAL_BIAS(result.x, eval.x);
DEAL_BIAS(result.y, eval.y);
#ifdef RELU
result = fmax(result, (COMPUTE_FLOATN)0);
#endif
#ifdef RELU6
result = clamp(result, (COMPUTE_FLOATN)0, (COMPUTE_FLOATN)6);
#endif
#elif VWN == 4
DEAL_BIAS(result.x, eval.x);
DEAL_BIAS(result.y, eval.y);
DEAL_BIAS(result.z, eval.z);
DEAL_BIAS(result.w, eval.w);
#ifdef RELU
result = fmax(result, (COMPUTE_FLOATN)0);
#endif
#ifdef RELU6
result = clamp(result, (COMPUTE_FLOATN)0, (COMPUTE_FLOATN)6);
#endif
#elif VWN == 8
DEAL_BIAS(result.s0, eval.s0);
DEAL_BIAS(result.s1, eval.s1);
DEAL_BIAS(result.s2, eval.s2);
DEAL_BIAS(result.s3, eval.s3);
DEAL_BIAS(result.s4, eval.s4);
DEAL_BIAS(result.s5, eval.s5);
DEAL_BIAS(result.s6, eval.s6);
DEAL_BIAS(result.s7, eval.s7);
#ifdef RELU
result = fmax(result, (COMPUTE_FLOATN)0);
#endif
#ifdef RELU6
result = clamp(result, (COMPUTE_FLOATN)0, (COMPUTE_FLOATN)6);
#endif
#elif VWN == 16
DEAL_BIAS(result.s0, eval.s0);
DEAL_BIAS(result.s1, eval.s1);
DEAL_BIAS(result.s2, eval.s2);
DEAL_BIAS(result.s3, eval.s3);
DEAL_BIAS(result.s4, eval.s4);
DEAL_BIAS(result.s5, eval.s5);
DEAL_BIAS(result.s6, eval.s6);
DEAL_BIAS(result.s7, eval.s7);
DEAL_BIAS(result.s8, eval.s8);
DEAL_BIAS(result.s9, eval.s9);
DEAL_BIAS(result.sA, eval.sA);
DEAL_BIAS(result.sB, eval.sB);
DEAL_BIAS(result.sC, eval.sC);
DEAL_BIAS(result.sD, eval.sD);
DEAL_BIAS(result.sE, eval.sE);
DEAL_BIAS(result.sF, eval.sF);
#ifdef RELU
result = fmax(result, (COMPUTE_FLOATN)0);
#endif
#ifdef RELU6
result = clamp(result, (COMPUTE_FLOATN)0, (COMPUTE_FLOATN)6);
#endif
#endif
#endif
cgn[index] = CONVERT_FLOATN(result);
}
// Main body of the matrix-multiplication algorithm. It calls various (inlined) functions.
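// A rough sketch of the control flow inside XgemmBody (descriptive only):
//  - SA == 1 || SB == 1: loop over K in KWG-sized tiles; each iteration caches a tile of A and/or
//    B in local memory (GlobalToLocalA/B), synchronises with a barrier, then runs the KWG loop
//    unrolled by a factor KWI while accumulating into cpn/cpm.
//  - SA == 0 && SB == 0: loop over K in steps of 4, loading A and B straight from global memory
//    into registers (GlobalToPrivateOptA/B) before accumulating.
// The accumulators are finally merged into C through StoreResultsN (OUTPUTMN) or StoreResultsM.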
INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, const int4 stride,
const __global realM* restrict agm, const __global realN* restrict bgm,
#if BIAS_TYPE > 0
__global realN* restrict egm,
#endif
__global realM* cgm, const real_arg alpha, const real_arg beta
#if SA == 1 && SB == 1
, LOCAL_PTR realM* alm, LOCAL_PTR realN* blm
#elif SA == 1
, LOCAL_PTR realM* alm
#elif SB == 1
, LOCAL_PTR realN* blm
#endif
) {
#ifdef OUTPUTMN
#pragma promote_to_registers
COMPUTE_FLOATN cpn[MWI*(NWI/VWN)]; // MWI * NWI
#else
#pragma promote_to_registers
COMPUTE_FLOATM cpm[NWI*(MWI/VWM)]; // NWI * MWI
#endif
// Combined thread identifier (volatile to disable caching)
#if SA == 1 || SB == 1
volatile int tid = get_local_id(0) + MDIMC*get_local_id(1);
#endif
// Initializes the accumulation registers
#ifdef OUTPUTMN
#pragma unroll
for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
#pragma unroll
for (int _mi = 0; _mi < MWI; _mi += 1) {
cpn[_mi * (NWI/VWN) + _ni] = InitAccRegistersN();
}
}
#else
#pragma unroll
for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
#pragma unroll
for (int _ni = 0; _ni < NWI; _ni += 1) {
cpm[_ni * (MWI/VWM) + _mi] = InitAccRegisters();
}
}
#endif
// Loops over all workgroup tiles
#if SA == 1 || SB == 1
// Allocates workitem-private memory (registers)
#pragma promote_to_registers
COMPUTE_FLOATM apm[MWI/VWM]; // MWI * 1
#pragma promote_to_registers
COMPUTE_FLOATN bpm[NWI/VWN]; // 1 * NWI
for (int kwg = 0; kwg < kSizeK; kwg += KWG) {
// Loads data: off-chip --> local (matrix A)
#if SA == 1
GlobalToLocalA(agm, alm, kSizeM, tid, kwg);
#endif
// Loads data: off-chip --> local (matrix B)
#if SB == 1
GlobalToLocalB(bgm, blm, kSizeN, tid, kwg);
#endif
barrier(CLK_LOCAL_MEM_FENCE);
// Loops over all workitem tiles, unrolled by a factor KWI
for (int pwi = 0; pwi < KWG; pwi += KWI) {
#pragma unroll
for (int _pit = 0; _pit < KWI; _pit += 1) {
#if SA == 0 || SB == 0
int idk = kwg + pwi + _pit;
#endif
int kg = pwi + _pit;
// Loads matrix A into per-thread registers (from local or global memory)
#pragma unroll
for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
// Loads data: local --> private (matrix A)
#if SA == 1
apm[_mi] = CONVERT_COMPUTE_FLOATM(LocalToPrivateA(alm, _mi, kg));
// Loads data: off-chip --> private (matrix A)
#elif SA == 0
apm[_mi] = CONVERT_COMPUTE_FLOATM(GlobalToPrivateA(agm, _mi, kSizeM, idk));
#endif
}
// Loads matrix B into per-thread registers (from local or global memory)
#pragma unroll
for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
// Loads data: local --> private (matrix B)
#if SB == 1
bpm[_ni] = CONVERT_COMPUTE_FLOATN(LocalToPrivateB(blm, _ni, kg));
// Loads data: off-chip --> private (matrix B)
#else
bpm[_ni] = CONVERT_COMPUTE_FLOATN(GlobalToPrivateB(bgm, _ni, kSizeN, idk));
#endif
}
// Performs the accumulation (Cpm += Apm * Bpm)
#ifdef OUTPUTMN
#pragma unroll
for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
#pragma unroll
for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
const COMPUTE_FLOATM aval = apm[_mi];
#if VWM == 1
// [MWI/VWM, VWM, NWI/VWN, VWN]
cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni], aval, bpm[_ni]);
#elif VWM == 2
cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni], aval.x, bpm[_ni]);
cpn[(_mi*VWM + 1)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 1)*(NWI/VWN) + _ni], aval.y, bpm[_ni]);
#elif VWM == 4
cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni], aval.x, bpm[_ni]);
cpn[(_mi*VWM + 1)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 1)*(NWI/VWN) + _ni], aval.y, bpm[_ni]);
cpn[(_mi*VWM + 2)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 2)*(NWI/VWN) + _ni], aval.z, bpm[_ni]);
cpn[(_mi*VWM + 3)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 3)*(NWI/VWN) + _ni], aval.w, bpm[_ni]);
#elif VWM == 8
cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni], aval.s0, bpm[_ni]);
cpn[(_mi*VWM + 1)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 1)*(NWI/VWN) + _ni], aval.s1, bpm[_ni]);
cpn[(_mi*VWM + 2)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 2)*(NWI/VWN) + _ni], aval.s2, bpm[_ni]);
cpn[(_mi*VWM + 3)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 3)*(NWI/VWN) + _ni], aval.s3, bpm[_ni]);
cpn[(_mi*VWM + 4)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 4)*(NWI/VWN) + _ni], aval.s4, bpm[_ni]);
cpn[(_mi*VWM + 5)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 5)*(NWI/VWN) + _ni], aval.s5, bpm[_ni]);
cpn[(_mi*VWM + 6)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 6)*(NWI/VWN) + _ni], aval.s6, bpm[_ni]);
cpn[(_mi*VWM + 7)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 7)*(NWI/VWN) + _ni], aval.s7, bpm[_ni]);
#elif VWM == 16
cpn[(_mi*VWM + 0 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 0 )*(NWI/VWN) + _ni], aval.s0, bpm[_ni]);
cpn[(_mi*VWM + 1 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 1 )*(NWI/VWN) + _ni], aval.s1, bpm[_ni]);
cpn[(_mi*VWM + 2 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 2 )*(NWI/VWN) + _ni], aval.s2, bpm[_ni]);
cpn[(_mi*VWM + 3 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 3 )*(NWI/VWN) + _ni], aval.s3, bpm[_ni]);
cpn[(_mi*VWM + 4 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 4 )*(NWI/VWN) + _ni], aval.s4, bpm[_ni]);
cpn[(_mi*VWM + 5 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 5 )*(NWI/VWN) + _ni], aval.s5, bpm[_ni]);
cpn[(_mi*VWM + 6 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 6 )*(NWI/VWN) + _ni], aval.s6, bpm[_ni]);
cpn[(_mi*VWM + 7 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 7 )*(NWI/VWN) + _ni], aval.s7, bpm[_ni]);
cpn[(_mi*VWM + 8 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 8 )*(NWI/VWN) + _ni], aval.s8, bpm[_ni]);
cpn[(_mi*VWM + 9 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 9 )*(NWI/VWN) + _ni], aval.s9, bpm[_ni]);
cpn[(_mi*VWM + 10)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 10)*(NWI/VWN) + _ni], aval.sA, bpm[_ni]);
cpn[(_mi*VWM + 11)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 11)*(NWI/VWN) + _ni], aval.sB, bpm[_ni]);
cpn[(_mi*VWM + 12)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 12)*(NWI/VWN) + _ni], aval.sC, bpm[_ni]);
cpn[(_mi*VWM + 13)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 13)*(NWI/VWN) + _ni], aval.sD, bpm[_ni]);
cpn[(_mi*VWM + 14)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 14)*(NWI/VWN) + _ni], aval.sE, bpm[_ni]);
cpn[(_mi*VWM + 15)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 15)*(NWI/VWN) + _ni], aval.sF, bpm[_ni]);
#endif
}
}
#else
#pragma unroll
for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
#pragma unroll
for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
const COMPUTE_FLOATM aval = apm[_mi];
#if VWN == 1
cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]);
#elif VWN == 2
cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x);
cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y);
#elif VWN == 4
cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x);
cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y);
cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].z);
cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w);
#elif VWN == 8
cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0);
cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1);
cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2);
cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3);
cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4);
cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5);
cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6);
cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7);
#elif VWN == 16
cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0);
cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1);
cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2);
cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3);
cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4);
cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5);
cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bpm[_ni].s6);
cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7);
cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8);
cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9);
cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA);
cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB);
cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC);
cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD);
cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE);
cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF);
#endif
}
}
#endif
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
#else
// Allocates workitem-private memory (registers)
int baseIndexA = GlobalIndexA();
int baseIndexB = GlobalIndexB();
#pragma unroll
for (int _kj = 0; _kj < kSizeK; _kj += 4) {
#ifdef OUTPUTMN
#pragma promote_to_registers
COMPUTE_FLOATN bpm[NWI/VWN]; // 1 * NWI
#pragma unroll
for(int _ki = 0; _ki < 4; _ki += 1) {
int idk = _kj + _ki;
#pragma unroll
for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
// Loads data: off-chip --> private (matrix B)
bpm[_ni] = CONVERT_COMPUTE_FLOATN(GlobalToPrivateOptB(bgm, baseIndexB, _ni, stride.s1/*kSizeN*/, idk));
}
#pragma unroll
for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
const COMPUTE_FLOATM aval = CONVERT_COMPUTE_FLOATM(GlobalToPrivateOptA(agm, baseIndexA, _mi, stride.s0/*kSizeM*/, idk));
#pragma unroll
for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
#if VWM == 1
// [MWI/VWM, VWM, NWI/VWN, VWN]
cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni], aval, bpm[_ni]);
#elif VWM == 2
cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni], aval.x, bpm[_ni]);
cpn[(_mi*VWM + 1)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 1)*(NWI/VWN) + _ni], aval.y, bpm[_ni]);
#elif VWM == 4
cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni], aval.x, bpm[_ni]);
cpn[(_mi*VWM + 1)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 1)*(NWI/VWN) + _ni], aval.y, bpm[_ni]);
cpn[(_mi*VWM + 2)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 2)*(NWI/VWN) + _ni], aval.z, bpm[_ni]);
cpn[(_mi*VWM + 3)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 3)*(NWI/VWN) + _ni], aval.w, bpm[_ni]);
#elif VWM == 8
cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni], aval.s0, bpm[_ni]);
cpn[(_mi*VWM + 1)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 1)*(NWI/VWN) + _ni], aval.s1, bpm[_ni]);
cpn[(_mi*VWM + 2)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 2)*(NWI/VWN) + _ni], aval.s2, bpm[_ni]);
cpn[(_mi*VWM + 3)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 3)*(NWI/VWN) + _ni], aval.s3, bpm[_ni]);
cpn[(_mi*VWM + 4)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 4)*(NWI/VWN) + _ni], aval.s4, bpm[_ni]);
cpn[(_mi*VWM + 5)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 5)*(NWI/VWN) + _ni], aval.s5, bpm[_ni]);
cpn[(_mi*VWM + 6)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 6)*(NWI/VWN) + _ni], aval.s6, bpm[_ni]);
cpn[(_mi*VWM + 7)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 7)*(NWI/VWN) + _ni], aval.s7, bpm[_ni]);
#elif VWM == 16
cpn[(_mi*VWM + 0 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 0 )*(NWI/VWN) + _ni], aval.s0, bpm[_ni]);
cpn[(_mi*VWM + 1 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 1 )*(NWI/VWN) + _ni], aval.s1, bpm[_ni]);
cpn[(_mi*VWM + 2 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 2 )*(NWI/VWN) + _ni], aval.s2, bpm[_ni]);
cpn[(_mi*VWM + 3 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 3 )*(NWI/VWN) + _ni], aval.s3, bpm[_ni]);
cpn[(_mi*VWM + 4 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 4 )*(NWI/VWN) + _ni], aval.s4, bpm[_ni]);
cpn[(_mi*VWM + 5 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 5 )*(NWI/VWN) + _ni], aval.s5, bpm[_ni]);
cpn[(_mi*VWM + 6 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 6 )*(NWI/VWN) + _ni], aval.s6, bpm[_ni]);
cpn[(_mi*VWM + 7 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 7 )*(NWI/VWN) + _ni], aval.s7, bpm[_ni]);
cpn[(_mi*VWM + 8 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 8 )*(NWI/VWN) + _ni], aval.s8, bpm[_ni]);
cpn[(_mi*VWM + 9 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 9 )*(NWI/VWN) + _ni], aval.s9, bpm[_ni]);
cpn[(_mi*VWM + 10)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 10)*(NWI/VWN) + _ni], aval.sA, bpm[_ni]);
cpn[(_mi*VWM + 11)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 11)*(NWI/VWN) + _ni], aval.sB, bpm[_ni]);
cpn[(_mi*VWM + 12)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 12)*(NWI/VWN) + _ni], aval.sC, bpm[_ni]);
cpn[(_mi*VWM + 13)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 13)*(NWI/VWN) + _ni], aval.sD, bpm[_ni]);
cpn[(_mi*VWM + 14)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 14)*(NWI/VWN) + _ni], aval.sE, bpm[_ni]);
cpn[(_mi*VWM + 15)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 15)*(NWI/VWN) + _ni], aval.sF, bpm[_ni]);
#endif
}
}
}
#else
#pragma promote_to_registers
COMPUTE_FLOATM apm[MWI/VWM]; // MWI * 1
#pragma unroll
for(int _ki = 0; _ki < 4; _ki += 1) {
int idk = _kj + _ki;
#pragma unroll
for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
// Loads data: off-chip --> private (matrix A)
apm[_mi] = CONVERT_COMPUTE_FLOATM(GlobalToPrivateOptA(agm, baseIndexA, _mi, stride.s0/*kSizeM*/, idk));
}
#pragma unroll
for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
const COMPUTE_FLOATN bval = CONVERT_COMPUTE_FLOATN(GlobalToPrivateOptB(bgm, baseIndexB, _ni, stride.s1/*kSizeN*/, idk));
#pragma unroll
for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
const COMPUTE_FLOATM aval = apm[_mi];
#if VWN == 1
cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bval);
#elif VWN == 2
cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bval.x);
cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bval.y);
#elif VWN == 4
cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bval.x);
cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bval.y);
cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bval.z);
cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bval.w);
#elif VWN == 8
cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bval.s0);
cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bval.s1);
cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bval.s2);
cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bval.s3);
cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bval.s4);
cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bval.s5);
cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bval.s6);
cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bval.s7);
#elif VWN == 16
cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bval.s0);
cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bval.s1);
cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bval.s2);
cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bval.s3);
cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bval.s4);
cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bval.s5);
cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bval.s6);
cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bval.s7);
cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bval.s8);
cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bval.s9);
cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bval.sA);
cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bval.sB);
cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bval.sC);
cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bval.sD);
cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bval.sE);
cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bval.sF);
#endif
}
}
}
#endif
}
#endif
#if GLOBAL_MEM_FENCE == 1
barrier(CLK_GLOBAL_MEM_FENCE);
#endif
#ifdef OUTPUTMN
INT2 baseOffset = StoreIndexN();
#if BIAS_TYPE == 1
#pragma promote_to_registers
realN epm[NWI/VWN]; // 1 * NWI (bias along the N dimension)
for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
#if STRN == 0
int idn = _ni + baseOffset.index[1];
#elif STRN == 1
int idn = baseOffset.index[1] + _ni*NDIMC;
#endif
epm[_ni] = egm[idn];
}
#endif
#pragma unroll
for (int _mi = 0; _mi < MWI; _mi += 1) {
#pragma unroll
for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
StoreResultsN((__global realN* )cgm, cpn[_mi * (NWI/VWN) + _ni],
baseOffset,
#if BIAS_TYPE > 1
(__global realN*)egm,
#elif BIAS_TYPE == 1
(realN*)epm,
#endif
_mi, _ni, stride.s2, stride.s3, alpha, beta);
}
}
#else
INT2 baseOffset = StoreIndexM();
// Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta
const int cld = kSizeM;
#pragma unroll
for (int _ni = 0; _ni < NWI; _ni += 1) {
#pragma unroll
for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
StoreResultsM(cgm, cpm[_ni * (MWI/VWM) + _mi], baseOffset, _mi, _ni, stride.s2, alpha, beta);
}
}
#endif
}
// Main entry point of the kernel. This is the regular full version.
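// Note on the expected launch geometry (an assumption inferred from the tiling, not enforced here
// beyond reqd_work_group_size): the local size is (MDIMC, NDIMC, 1) and each work-group produces
// an MWG x NWG tile of C, so the host presumably launches roughly
// (kSizeM * MDIMC / MWG, kSizeN * NDIMC / NWG) work-items in total.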
#if RELAX_WORKGROUP_SIZE == 1
__kernel
#else
__kernel __attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
#endif
void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK,
const real_arg arg_alpha,
const real_arg arg_beta,
const __global realM* restrict agm, // [K, M]
const __global realN* restrict bgm, // [K, N]
#if BIAS_TYPE > 0
__global realN* restrict egm, // [N]
#endif
__global realM* cgm,
__private const int4 offset,
__private const int4 stride
) {
// Adds the offsets (in case a single temporary buffer is shared between A, B, and C)
agm = (const __global realM*)((const __global real*)agm + offset.s0);
bgm = (const __global realN*)((const __global real*)bgm + offset.s1);
cgm = (__global realM*)((__global real*)cgm + offset.s2);
#if BIAS_TYPE > 0
egm = (__global realN*)((__global real*)egm + offset.s3);
#endif
// Allocates workgroup-private memory (local memory)
#if SA == 1
__local realM alm[KWG * MWG/VWM];
#endif
#if SB == 1
__local realN blm[KWG * NWG/VWN];
#endif
// Computes the matrix-multiplication and stores the result in global memory
#if SA == 1 && SB == 1
XgemmBody(kSizeM, kSizeN, kSizeK, stride, agm, bgm,
#if BIAS_TYPE > 0
egm,
#endif
cgm, arg_alpha, arg_beta, alm, blm);
#elif SA == 1
XgemmBody(kSizeM, kSizeN, kSizeK, stride, agm, bgm,
#if BIAS_TYPE > 0
egm,
#endif
cgm, arg_alpha, arg_beta, alm);
#elif SB == 1
XgemmBody(kSizeM, kSizeN, kSizeK, stride, agm, bgm,
#if BIAS_TYPE > 0
egm,
#endif
cgm, arg_alpha, arg_beta, blm);
#else
XgemmBody(kSizeM, kSizeN, kSizeK, stride, agm, bgm,
#if BIAS_TYPE > 0
egm,
#endif
cgm, arg_alpha, arg_beta);
#endif
}
#if RELAX_WORKGROUP_SIZE == 1
__kernel
#else
__kernel __attribute__((reqd_work_group_size(MDIMC, NDIMC, 1)))
#endif
void XgemmBatched(const int kSizeM,
const int kSizeN,
const int kSizeK,
const real_arg arg_alpha,
const real_arg arg_beta,
const __global realM* restrict agm,
const __global realN* restrict bgm,
#if BIAS_TYPE > 0
__global realN* restrict egm,
#endif
__global realM* cgm,
const int4 batch_offset, // [batch_offset_a, batch_offset_b, batch_offset_c, batch_offset_e]
const int4 base_ptr_offset, // [base_ptr_offset_a, base_ptr_offset_b, base_ptr_offset_c, base_ptr_offset_e]
const int4 stride, // [stride_a, stride_b, stride_c, stride_e]
/*
    total_batch         -> [loop_y, loop_x]
    with group batching -> [loop_y, loop_x / group_num]
    group_size == loop_x / group_num
*/
const int4 group // [group_num_a, group_num_b, group_num_e, loop_x]
) {
const int batch = get_group_id(2);
// Sets the offsets
const int a_offset = base_ptr_offset.x + ((batch / group.w) * group.x + (batch % group.w) / group.x) * batch_offset.x;
const int b_offset = base_ptr_offset.y + ((batch / group.w) * group.y + (batch % group.w) / group.y) * batch_offset.y;
const int c_offset = base_ptr_offset.z + batch * batch_offset.z;
const __global realM* restrict agm_ = &agm[a_offset / VWM];
const __global realN* restrict bgm_ = &bgm[b_offset / VWN];
__global realM* restrict cgm_ = &cgm[c_offset / VWM];
#if BIAS_TYPE > 0
const int e_offset = base_ptr_offset.w + ((batch / group.w) * group.z + (batch % group.w) / group.z) * batch_offset.w;
__global realN* restrict egm_ = &egm[e_offset / VWN];
#endif
// Allocates workgroup-private memory (local memory)
#if SA == 1
__local realM alm[KWG * MWG/VWM];
#endif
#if SB == 1
__local realN blm[KWG * NWG/VWN];
#endif
// Computes the matrix-multiplication and stores the result in global memory
#if SA == 1 && SB == 1
XgemmBody(kSizeM, kSizeN, kSizeK, stride, agm_, bgm_,
#if BIAS_TYPE > 0
egm_,
#endif
cgm_, arg_alpha, arg_beta, alm, blm);
#elif SA == 1
XgemmBody(kSizeM, kSizeN, kSizeK, stride, agm_, bgm_,
#if BIAS_TYPE > 0
egm_,
#endif
cgm_, arg_alpha, arg_beta, alm);
#elif SB == 1
XgemmBody(kSizeM, kSizeN, kSizeK, stride, agm_, bgm_,
#if BIAS_TYPE > 0
egm_,
#endif
cgm_, arg_alpha, arg_beta, blm);
#else
XgemmBody(kSizeM, kSizeN, kSizeK, stride, agm_, bgm_,
#if BIAS_TYPE > 0
egm_,
#endif
cgm_, arg_alpha, arg_beta);
#endif
}