source/backend/opengl/glsl/gemm16x16.glsl (48 lines of code) (raw):

layout(std430) buffer;

// Storage images: uOutput is write-only, uInput and uKernel are read-only.
// FORMAT is not defined in this file; presumably the host injects it as a
// preprocessor define before compilation -- confirm against the shader loader.
layout(binding=0, FORMAT) writeonly mediump uniform image2D uOutput;
layout(binding=1, FORMAT) readonly mediump uniform image2D uInput;
layout(binding=2, FORMAT) readonly mediump uniform image2D uKernel;

// Dispatch bounds: .x limits gl_GlobalInvocationID.x, .y limits gl_GlobalInvocationID.y.
layout(location=3) uniform ivec2 outputSize;
// Number of accumulation steps; each step reads four consecutive texels along x.
layout(location=4) uniform int ic_4;

// XLOCAL / YLOCAL / ZLOCAL are also expected to be defined outside this file.
layout (local_size_x = XLOCAL, local_size_y = YLOCAL, local_size_z = ZLOCAL) in;

// Layout notes carried over from the original author:
//   outputsize   : oc/4, (ob*oh*ow)/4
//   multiLength  : ci/4
//   kernel image : oc/4, ic/4 * ic4 * oc4
//   input  (temp image) : (ib*oh*ow)/4, ic/4 * (ib*oh*ow)%4 * ic4
//   output (temp image) : oc/4 * (ob*oh*ow)%4, (ob*oh*ow)/4 * oc4
// NOTE(review): the original notes list the invocation index as
// (1, oc/4, (ob*oh*ow)/4), but the code below takes the row index from .x and
// the output-channel block from .y -- confirm the intended order against the
// host-side dispatch and uniform setup.
//
// Each in-range invocation accumulates four vec4 results over ic_4 steps and
// writes them to four consecutive rows of uOutput.
void main()
{
    ivec3 gid = ivec3(gl_GlobalInvocationID);
    int rowIndex = gid.x;
    int ocBlock = gid.y;

    // Invocations outside the requested range do nothing.
    if (rowIndex >= outputSize.x || ocBlock >= outputSize.y)
    {
        return;
    }

    vec4 acc0 = vec4(0);
    vec4 acc1 = vec4(0);
    vec4 acc2 = vec4(0);
    vec4 acc3 = vec4(0);

    for (int step = 0; step < ic_4; ++step)
    {
        // Four consecutive x coordinates are consumed per step, in both images.
        int x0 = step * 4;
        int x1 = x0 + 1;
        int x2 = x0 + 2;
        int x3 = x0 + 3;

        vec4 w0 = imageLoad(uKernel, ivec2(x0, ocBlock));
        vec4 w1 = imageLoad(uKernel, ivec2(x1, ocBlock));
        vec4 w2 = imageLoad(uKernel, ivec2(x2, ocBlock));
        vec4 w3 = imageLoad(uKernel, ivec2(x3, ocBlock));

        vec4 in0 = imageLoad(uInput, ivec2(x0, rowIndex));
        vec4 in1 = imageLoad(uInput, ivec2(x1, rowIndex));
        vec4 in2 = imageLoad(uInput, ivec2(x2, rowIndex));
        vec4 in3 = imageLoad(uInput, ivec2(x3, rowIndex));

        // The four kernel texels become the columns of a 4x4 matrix that is
        // applied to each of the four input texels.
        mat4 weights = mat4(w0, w1, w2, w3);
        acc0 += weights * in0;
        acc1 += weights * in1;
        acc2 += weights * in2;
        acc3 += weights * in3;
    }

    // Results go to column rowIndex, rows ocBlock*4 .. ocBlock*4 + 3.
    int outRow = ocBlock * 4;
    imageStore(uOutput, ivec2(rowIndex, outRow), acc0);
    imageStore(uOutput, ivec2(rowIndex, outRow + 1), acc1);
    imageStore(uOutput, ivec2(rowIndex, outRow + 2), acc2);
    imageStore(uOutput, ivec2(rowIndex, outRow + 3), acc3);
}