source/backend/opencl/execution/cl/conv_2d_mnn_cl.cpp (1,357 lines of code) (raw):

#include "opencl_source_map.hpp" namespace MNN { const char* conv_2d = "#ifdef MNN_SUPPORT_FP16\n" "#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n" "#endif\n" "#define READ_INPUT_IMAGE(i, base) "" int in_width_value##i = in_width##i + base; "" in_width_value##i = "" select(in_idx + in_width_value##i, -1, (in_width_value##i < 0 || in_width_value##i >= input_shape.y)); "" in##i=RI_F(input,SAMPLER,(int2)(in_width_value##i,in_hb_value));\n" "#define CALCULATE_OUTPUT(i) "" out##i = mad(in##i.x, weights0, out##i); "" out##i = mad(in##i.y, weights1, out##i); "" out##i = mad(in##i.z, weights2, out##i); "" out##i=mad(in##i.w,weights3,out##i); \n" "#define CALCULATE_OUTPUT_WEIGHTS4(i, j) "" out##i = mad(in##j.x, weights4, out##i); "" out##i = mad(in##j.y, weights5, out##i); "" out##i = mad(in##j.z, weights6, out##i); "" out##i=mad(in##j.w,weights7,out##i);\n" "#define CALCULATE_OUTPUT_OPT(i) "" out##i = mad(in_sm##i[local_idx].x, weights0, out##i); "" out##i = mad(in_sm##i[local_idx].y, weights1, out##i); "" out##i = mad(in_sm##i[local_idx].z, weights2, out##i); "" out##i=mad(in_sm##i[local_idx].w,weights3,out##i); \n" "#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0,__private const int global_size_dim1,\n" "__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" "#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n" "#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n" "#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n" "#define UNIT 4\n" "#define MOD_NUM 15\n" "#ifdef INPUT_CHANNEL_LEAVE\n" " #define PADZEROSVEC(k, channel, data0, data1, data2, data3) "" data0 = (k << 2) < channel ? data0 : 0; "" data1 = (k << 2) + 1 < channel ? data1 : 0; "" data2 = (k << 2) + 2 < channel ? data2 : 0; "" data3=(k << 2)+3<channel ? data3 : 0;\n" "#else\n" " #define PADZEROSVEC(k,channel,data0,data1,data2,data3)\n" "#endif\n" "__kernel\n" "#if SET_ATTRIBUTE\n" "__attribute__((work_group_size_hint(16,16,1)))\n" "#endif\n" "void conv_2d_1x1_mali(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,__read_only image2d_t input,\n" " #ifdef BUFFER_INP_FP32\n" " __global const float *kernel_ptr,\n" " __global const float *bias_ptr,\n" " #else\n" " __global const FLOAT *kernel_ptr,\n" " __global const FLOAT *bias_ptr,\n" " #endif\n" " __write_only image2d_t output,\n" " __private const int in_c_block,__private const int out_h,\n" " __private const int out_w) {\n" " const int out_c_w_idx=get_global_id(0); //c/4 w\n" " const int out_b_h_idx=get_global_id(1); //b h\n" " DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n" " const int out_c_idx=out_c_w_idx/out_w_blocks;\n" " const int out_w_idx=out_c_w_idx % out_w_blocks;\n" " const int out_w4_idx=mul24(out_w_idx,4);\n" " #ifdef BUFFER_INP_FP32\n" " FLOAT4 out0=CONVERT_FLOAT4(vload4(out_c_idx,(__global float *)bias_ptr));\n" " #else\n" " FLOAT4 out0=vload4(out_c_idx,(__global FLOAT *)bias_ptr);\n" " #endif\n" " FLOAT4 out1=out0;\n" " FLOAT4 out2=out0;\n" " FLOAT4 out3=out0;\n" " FLOAT4 weights0;\n" " FLOAT4 weights1;\n" " FLOAT4 weights2;\n" " FLOAT4 weights3;\n" " FLOAT4 in0; \n" " FLOAT4 in1; \n" " FLOAT4 in2;\n" " FLOAT4 in3; \n" " FLOAT16 weight16;\n" " const int intput_width_idx0=out_w4_idx;\n" " const int intput_width_idx1=out_w4_idx+1;\n" " const int intput_width_idx2=out_w4_idx+2;\n" " const int intput_width_idx3=out_w4_idx+3;\n" " for (int in_channel_block_idx=0; in_channel_block_idx<in_c_block; ++in_channel_block_idx) {\n" " int input_width_base=mul24(in_channel_block_idx,out_w);\n" " int offset=mad24(out_c_idx,in_c_block,in_channel_block_idx)*4;\n" " in0=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx0,out_b_h_idx));\n" " in1=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx1,out_b_h_idx));\n" " in2=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx2,out_b_h_idx));\n" " in3=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx3,out_b_h_idx));\n" " #ifdef BUFFER_INP_FP32\n" " weights0=CONVERT_FLOAT4(vload4(offset,(__global float *)kernel_ptr));\n" " weights1=CONVERT_FLOAT4(vload4(offset+1,(__global float *)kernel_ptr));\n" " weights2=CONVERT_FLOAT4(vload4(offset+2,(__global float *)kernel_ptr));\n" " weights3=CONVERT_FLOAT4(vload4(offset+3,(__global float *)kernel_ptr));\n" " #else\n" " weights0=vload4(offset,(__global FLOAT *)kernel_ptr);\n" " weights1=vload4(offset+1,(__global FLOAT *)kernel_ptr);\n" " weights2=vload4(offset+2,(__global FLOAT *)kernel_ptr);\n" " weights3=vload4(offset+3,(__global FLOAT *)kernel_ptr);\n" " #endif\n" " \n" " out0.x += dot(weights0,in0);\n" " out0.y += dot(weights1,in0);\n" " out0.z += dot(weights2,in0);\n" " out0.w += dot(weights3,in0);\n" " out1.x += dot(weights0,in1);\n" " out1.y += dot(weights1,in1);\n" " out1.z += dot(weights2,in1);\n" " out1.w += dot(weights3,in1);\n" " out2.x += dot(weights0,in2);\n" " out2.y += dot(weights1,in2);\n" " out2.z += dot(weights2,in2);\n" " out2.w += dot(weights3,in2);\n" " out3.x += dot(weights0,in3);\n" " out3.y += dot(weights1,in3);\n" " out3.z += dot(weights2,in3);\n" " out3.w += dot(weights3,in3);\n" " }\n" "#ifdef RELU\n" " out0=fmax(out0,(FLOAT4)0);\n" " out1=fmax(out1,(FLOAT4)0);\n" " out2=fmax(out2,(FLOAT4)0);\n" " out3=fmax(out3,(FLOAT4)0);\n" "#endif\n" "#ifdef RELU6\n" " out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n" " out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n" " out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n" " out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n" "#endif\n" " const int out_x_base=out_c_idx*out_w;\n" " const int remain=out_w-out_w4_idx;\n" " int output_idx=out_x_base+out_w4_idx;\n" " \n" " if (remain >= 4) {\n" " WI_F(output,(int2)(output_idx,out_b_h_idx),out0);\n" " WI_F(output,(int2)(output_idx+1,out_b_h_idx),out1);\n" " WI_F(output,(int2)(output_idx+2,out_b_h_idx),out2);\n" " WI_F(output,(int2)(output_idx+3,out_b_h_idx),out3);\n" " } else if (remain == 3) {\n" " WI_F(output,(int2)(output_idx,out_b_h_idx),out0);\n" " WI_F(output,(int2)(output_idx+1,out_b_h_idx),out1);\n" " WI_F(output,(int2)(output_idx+2,out_b_h_idx),out2);\n" " } else if (remain == 2) {\n" " WI_F(output,(int2)(output_idx,out_b_h_idx),out0);\n" " WI_F(output,(int2)(output_idx+1,out_b_h_idx),out1);\n" " } else if (remain == 1) {\n" " WI_F(output,(int2)(output_idx,out_b_h_idx),out0);\n" " }\n" "}\n" "__kernel\n" "#if SET_ATTRIBUTE\n" "__attribute__((work_group_size_hint(16,16,1)))\n" "#endif\n" "void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" " __global const char *kernel_ptr,\n" " __global const float *dequantScaleOffset,\n" "#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" " __global const uchar *kernel_ptr,\n" " __global const float *dequantScaleOffset,\n" "#elif (defined USE_BUFFER)\n" " __global const FLOAT *weights,\n" "#else\n" " __read_only image2d_t weights,\n" "#endif\n" " __read_only image2d_t bias,\n" " __write_only image2d_t output,\n" " __private const int2 input_shape,\n" " __private const int in_channel_block,__private const int2 output_shape,\n" " __private const int2 stride_shape,\n" " __private const int output_width_4,\n" " __private const int out_channel_blocks\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n" " ,__private const int blockDim\n" " ,__private const int inChannel\n" "#endif\n" ") {\n" " const int output_channel_width_idx=get_global_id(0);\n" " const int output_batch_height_idx=get_global_id(1);\n" " DEAL_NON_UNIFORM_DIM2(output_channel_width_idx,output_batch_height_idx);\n" " const int output_channel_block_idx=output_channel_width_idx/output_width_4;\n" " const int output_width_block_idx=output_channel_width_idx % output_width_4;\n" "#if (defined USE_LOW_BIT_WEIGHT_INT4)\n" " int weight_ic_offset=output_channel_block_idx*8;\n" " int weight_oc_offset=out_channel_blocks*8;\n" "#else\n" " int weight_ic_offset=output_channel_block_idx*16;\n" " int weight_oc_offset=out_channel_blocks*16;\n" "#endif\n" " FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(output_channel_block_idx,0));\n" " FLOAT4 out1=out0;\n" " FLOAT4 out2=out0;\n" " FLOAT4 out3=out0;\n" "#ifdef MNN_CONV_S1D1\n" " int intput_width_idx0=output_width_block_idx << 2;\n" " int intput_width_idx1=intput_width_idx0+1;\n" " int intput_width_idx2=intput_width_idx0+2;\n" " int intput_width_idx3=intput_width_idx0+3;\n" "#else\n" " int intput_width_idx0=mul24(output_width_block_idx,stride_shape.y*4);\n" " int intput_width_idx1=intput_width_idx0+stride_shape.y;\n" " int intput_width_idx2=intput_width_idx1+stride_shape.y;\n" " int intput_width_idx3=intput_width_idx2+stride_shape.y;\n" " intput_width_idx0=select(intput_width_idx0,INT_MIN,intput_width_idx0 >= input_shape.y);\n" " intput_width_idx1=select(intput_width_idx1,INT_MIN,intput_width_idx1 >= input_shape.y);\n" " intput_width_idx2=select(intput_width_idx2,INT_MIN,intput_width_idx2 >= input_shape.y);\n" " intput_width_idx3=select(intput_width_idx3,INT_MIN,intput_width_idx3 >= input_shape.y);\n" "#endif\n" " int batch_index=output_batch_height_idx/output_shape.x;\n" " int input_height_block_idx=mul24((output_batch_height_idx % output_shape.x),stride_shape.x)+batch_index*input_shape.x;\n" " FLOAT4 in0;\n" " FLOAT4 in1;\n" " FLOAT4 in2;\n" " FLOAT4 in3;\n" " FLOAT4 weights0;\n" " FLOAT4 weights1;\n" " FLOAT4 weights2;\n" " FLOAT4 weights3;\n" " int weight_offset=output_channel_block_idx*in_channel_block*4*4;\n" " for (int in_channel_block_idx=0; in_channel_block_idx<in_channel_block; ++in_channel_block_idx) {\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n" " int kindex=(in_channel_block_idx*4)/blockDim*out_channel_blocks*8;\n" " COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(output_channel_block_idx,dequantScaleOffset+kindex));\n" " COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n" " COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n" "#endif\n" " int input_width_base=in_channel_block_idx*input_shape.y;\n" " int weights_width_base=in_channel_block_idx << 2;\n" " \n" "#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" " FLOAT16 weights=CONVERT_FLOAT16(vload16(0,kernel_ptr+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n" " FLOAT4 weights0=CONVERT_FLOAT4(weights.s0123)*scale0+offset0;\n" " FLOAT4 weights1=CONVERT_FLOAT4(weights.s4567)*scale0+offset0;\n" " FLOAT4 weights2=CONVERT_FLOAT4(weights.s89ab)*scale0+offset0;\n" " FLOAT4 weights3=CONVERT_FLOAT4(weights.scdef)*scale0+offset0;\n" "#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" " uchar8 charWeightsInt4=vload8(0,kernel_ptr+weight_ic_offset+in_channel_block_idx*weight_oc_offset);\n" " char4 charWeights0=(char4)(0,0,0,0);\n" " char4 charWeights1=(char4)(0,0,0,0);\n" " char4 charWeights2=(char4)(0,0,0,0);\n" " char4 charWeights3=(char4)(0,0,0,0);\n" " charWeights0.x=(charWeightsInt4.s0 >> 4)-8;\n" " charWeights0.y=(charWeightsInt4.s0 & MOD_NUM)-8;\n" " charWeights0.z=(charWeightsInt4.s1 >> 4)-8;\n" " charWeights0.w=(charWeightsInt4.s1 & MOD_NUM)-8;\n" " charWeights1.x=(charWeightsInt4.s2 >> 4)-8;\n" " charWeights1.y=(charWeightsInt4.s2 & MOD_NUM)-8;\n" " charWeights1.z=(charWeightsInt4.s3 >> 4)-8;\n" " charWeights1.w=(charWeightsInt4.s3 & MOD_NUM)- 8;\n" " charWeights2.x=(charWeightsInt4.s4 >> 4)-8;\n" " charWeights2.y=(charWeightsInt4.s4 & MOD_NUM)-8;\n" " charWeights2.z=(charWeightsInt4.s5 >> 4)-8;\n" " charWeights2.w=(charWeightsInt4.s5 & MOD_NUM)-8;\n" " charWeights3.x=(charWeightsInt4.s6 >> 4)-8;\n" " charWeights3.y=(charWeightsInt4.s6 & MOD_NUM)-8;\n" " charWeights3.z=(charWeightsInt4.s7 >> 4)-8;\n" " charWeights3.w=(charWeightsInt4.s7 & MOD_NUM)-8;\n" " weights0=mad(CONVERT_FLOAT4(charWeights0),scale0,offset0);\n" " weights1=mad(CONVERT_FLOAT4(charWeights1),scale0,offset0);\n" " weights2=mad(CONVERT_FLOAT4(charWeights2),scale0,offset0);\n" " weights3=mad(CONVERT_FLOAT4(charWeights3),scale0,offset0);\n" "#elif (defined USE_BUFFER)\n" " weights0=vload4(weights_width_base,weights+weight_offset);\n" " weights1=vload4(weights_width_base+1,weights+weight_offset);\n" " weights2=vload4(weights_width_base+2,weights+weight_offset);\n" " weights3=vload4(weights_width_base+3,weights+weight_offset);\n" "#else\n" " weights0=RI_F(weights,SAMPLER,(int2)(weights_width_base+0,output_channel_block_idx));\n" " weights1=RI_F(weights,SAMPLER,(int2)(weights_width_base+1,output_channel_block_idx));\n" " weights2=RI_F(weights,SAMPLER,(int2)(weights_width_base+2,output_channel_block_idx));\n" " weights3=RI_F(weights,SAMPLER,(int2)(weights_width_base+3,output_channel_block_idx));\n" "#endif\n" " PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" " in0=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx0,input_height_block_idx));\n" " in1=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx1,input_height_block_idx));\n" " in2=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx2,input_height_block_idx));\n" " in3=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx3,input_height_block_idx));\n" " CALCULATE_OUTPUT(0);\n" " CALCULATE_OUTPUT(1);\n" " CALCULATE_OUTPUT(2);\n" " CALCULATE_OUTPUT(3);\n" " }\n" "#ifdef RELU\n" " out0=fmax(out0,(FLOAT4)0);\n" " out1=fmax(out1,(FLOAT4)0);\n" " out2=fmax(out2,(FLOAT4)0);\n" " out3=fmax(out3,(FLOAT4)0);\n" "#endif\n" "#ifdef RELU6\n" " out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n" " out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n" " out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n" " out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n" "#endif\n" " const int out_x_base=mul24(output_channel_block_idx,output_shape.y);\n" " int out_x_idx=output_width_block_idx << 2;\n" " const int remain=output_shape.y-out_x_idx;\n" " int output_idx=out_x_base+out_x_idx;\n" " if (remain >= 4) {\n" " WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" " WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n" " WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n" " WI_F(output,(int2)(output_idx+3,output_batch_height_idx),out3);\n" " } else if (remain == 3) {\n" " WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" " WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n" " WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n" " } else if (remain == 2) {\n" " WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" " WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n" " } else if (remain == 1) {\n" " WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" " }\n" "}\n" "__kernel\n" "#if SET_ATTRIBUTE\n" "__attribute__((work_group_size_hint(16,16,1)))\n" "#endif\n" "void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" " __global const char *kernel_ptr,\n" " __global const float *dequantScaleOffset,\n" "#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" " __global const uchar *kernel_ptr,\n" " __global const float *dequantScaleOffset,\n" "#elif (defined USE_BUFFER)\n" " __global const FLOAT *weights,\n" "#else\n" " __read_only image2d_t weights,\n" "#endif\n" " __read_only image2d_t bias,\n" " __write_only image2d_t output,\n" " __private const int2 input_shape,\n" " __private const int in_channel_block,__private const int2 output_shape,\n" " __private const int2 stride_shape,\n" " __private const int output_width_4,\n" " __private const int out_channel_blocks\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n" " ,__private const int blockDim\n" " ,__private const int inChannel\n" "#endif\n" ") {\n" " const int output_channel_width_idx=get_global_id(0);\n" " const int output_batch_height_idx=get_global_id(1);\n" " DEAL_NON_UNIFORM_DIM2(output_channel_width_idx,output_batch_height_idx);\n" " const int output_channel_block_idx=output_channel_width_idx/output_width_4;\n" " const int output_width_block_idx=output_channel_width_idx % output_width_4;\n" " const int output_channel_idx=output_channel_block_idx << 1;\n" "#if (defined USE_LOW_BIT_WEIGHT_INT4)\n" " int weight_ic_offset=output_channel_block_idx*16;\n" " int weight_oc_offset=out_channel_blocks*8;\n" "#else\n" " int weight_ic_offset=output_channel_block_idx*32;\n" " int weight_oc_offset=out_channel_blocks*16;\n" "#endif\n" " FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(output_channel_idx,0));\n" " FLOAT4 out1=out0;\n" " FLOAT4 out2=out0;\n" " FLOAT4 out3=out0;\n" " \n" " FLOAT4 out4=RI_F(bias,SAMPLER,(int2)(output_channel_idx+1,0));\n" " FLOAT4 out5=out4;\n" " FLOAT4 out6=out4;\n" " FLOAT4 out7=out4;\n" "#ifdef MNN_CONV_S1D1\n" " int intput_width_idx0=output_width_block_idx << 2;\n" " int intput_width_idx1=intput_width_idx0+1;\n" " int intput_width_idx2=intput_width_idx0+2;\n" " int intput_width_idx3=intput_width_idx0+3;\n" "#else\n" " int intput_width_idx0=mul24(output_width_block_idx,stride_shape.y*4);\n" " int intput_width_idx1=intput_width_idx0+stride_shape.y;\n" " int intput_width_idx2=intput_width_idx1+stride_shape.y;\n" " int intput_width_idx3=intput_width_idx2+stride_shape.y;\n" " intput_width_idx0=select(intput_width_idx0,INT_MIN,intput_width_idx0 >= input_shape.y);\n" " intput_width_idx1=select(intput_width_idx1,INT_MIN,intput_width_idx1 >= input_shape.y);\n" " intput_width_idx2=select(intput_width_idx2,INT_MIN,intput_width_idx2 >= input_shape.y);\n" " intput_width_idx3=select(intput_width_idx3,INT_MIN,intput_width_idx3 >= input_shape.y);\n" "#endif\n" " int batch_index=output_batch_height_idx/output_shape.x;\n" " int input_height_block_idx=mul24((output_batch_height_idx % output_shape.x),stride_shape.x)+batch_index*input_shape.x;\n" " FLOAT4 in0;\n" " FLOAT4 in1;\n" " FLOAT4 in2;\n" " FLOAT4 in3;\n" " FLOAT4 weights0;\n" " FLOAT4 weights1;\n" " FLOAT4 weights2;\n" " FLOAT4 weights3;\n" " FLOAT4 weights4;\n" " FLOAT4 weights5;\n" " FLOAT4 weights6;\n" " FLOAT4 weights7;\n" " int weight_offset=output_channel_idx*in_channel_block*4*4;\n" " int weight_offset1=weight_offset+in_channel_block*4*4;\n" " for (int in_channel_block_idx=0; in_channel_block_idx<in_channel_block; ++in_channel_block_idx) {\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n" " int kindex=(in_channel_block_idx*4)/blockDim*out_channel_blocks*8;\n" " // already pack to 16,no need boundry protect\n" " COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(output_channel_idx,dequantScaleOffset+kindex));\n" " COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n" " COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n" " COMPUTE_FLOAT8 ScaleOffset1=CONVERT_COMPUTE_FLOAT8(vload8(output_channel_idx+1,dequantScaleOffset+kindex));\n" " COMPUTE_FLOAT4 scale1=(COMPUTE_FLOAT4)(ScaleOffset1.s0,ScaleOffset1.s2,ScaleOffset1.s4,ScaleOffset1.s6);\n" " COMPUTE_FLOAT4 offset1=(COMPUTE_FLOAT4)(ScaleOffset1.s1,ScaleOffset1.s3,ScaleOffset1.s5,ScaleOffset1.s7);\n" "#endif\n" " \n" " int input_width_base=in_channel_block_idx*input_shape.y;\n" " int weights_width_base=in_channel_block_idx << 2;\n" " in0=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx0,input_height_block_idx));\n" " in1=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx1,input_height_block_idx));\n" " in2=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx2,input_height_block_idx));\n" " in3=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx3,input_height_block_idx));\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" " FLOAT16 weightsInt80=CONVERT_FLOAT16(vload16(0,kernel_ptr+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n" " #ifdef CHANNEL_BOUNDARY_PROTECT\n" " FLOAT16 weightsInt81=output_channel_idx+1 >= out_channel_blocks ? (FLOAT16)0 : CONVERT_FLOAT16(vload16(0,kernel_ptr+16+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n" " #else\n" " FLOAT16 weightsInt81=CONVERT_FLOAT16(vload16(0,kernel_ptr+16+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n" " #endif\n" " FLOAT4 weights0=CONVERT_FLOAT4(weightsInt80.s0123)*scale0+offset0;\n" " FLOAT4 weights1=CONVERT_FLOAT4(weightsInt80.s4567)*scale0+offset0;\n" " FLOAT4 weights2=CONVERT_FLOAT4(weightsInt80.s89ab)*scale0+offset0;\n" " FLOAT4 weights3=CONVERT_FLOAT4(weightsInt80.scdef)*scale0+offset0;\n" " FLOAT4 weights4=CONVERT_FLOAT4(weightsInt81.s0123)*scale1+offset1;\n" " FLOAT4 weights5=CONVERT_FLOAT4(weightsInt81.s4567)*scale1+offset1;\n" " FLOAT4 weights6=CONVERT_FLOAT4(weightsInt81.s89ab)*scale1+offset1;\n" " FLOAT4 weights7=CONVERT_FLOAT4(weightsInt81.scdef)*scale1+offset1;\n" "#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" " uchar16 charWeightsInt4=vload16(0,kernel_ptr+weight_ic_offset+in_channel_block_idx*weight_oc_offset);\n" " char4 charWeights0=(char4)(0,0,0,0);\n" " char4 charWeights1=(char4)(0,0,0,0);\n" " char4 charWeights2=(char4)(0,0,0,0);\n" " char4 charWeights3=(char4)(0,0,0,0);\n" " char4 charWeights4=(char4)(0,0,0,0);\n" " char4 charWeights5=(char4)(0,0,0,0);\n" " char4 charWeights6=(char4)(0,0,0,0);\n" " char4 charWeights7=(char4)(0,0,0,0);\n" " charWeights0.x=(charWeightsInt4.s0 >> 4)-8;\n" " charWeights0.y=(charWeightsInt4.s0 & MOD_NUM)-8;\n" " charWeights0.z=(charWeightsInt4.s1 >> 4)-8;\n" " charWeights0.w=(charWeightsInt4.s1 & MOD_NUM)-8;\n" " charWeights1.x=(charWeightsInt4.s2 >> 4)-8;\n" " charWeights1.y=(charWeightsInt4.s2 & MOD_NUM)-8;\n" " charWeights1.z=(charWeightsInt4.s3 >> 4)-8;\n" " charWeights1.w=(charWeightsInt4.s3 & MOD_NUM)-8;\n" " charWeights2.x=(charWeightsInt4.s4 >> 4)-8;\n" " charWeights2.y=(charWeightsInt4.s4 & MOD_NUM)-8;\n" " charWeights2.z=(charWeightsInt4.s5 >> 4)-8;\n" " charWeights2.w=(charWeightsInt4.s5 & MOD_NUM)-8;\n" " charWeights3.x=(charWeightsInt4.s6 >> 4)-8;\n" " charWeights3.y=(charWeightsInt4.s6 & MOD_NUM)-8;\n" " charWeights3.z=(charWeightsInt4.s7 >> 4)-8;\n" " charWeights3.w=(charWeightsInt4.s7 & MOD_NUM)-8;\n" " charWeights4.x=(charWeightsInt4.s8 >> 4)-8;\n" " charWeights4.y=(charWeightsInt4.s8 & MOD_NUM)-8;\n" " charWeights4.z=(charWeightsInt4.s9 >> 4)-8;\n" " charWeights4.w=(charWeightsInt4.s9 & MOD_NUM)-8;\n" " charWeights5.x=(charWeightsInt4.sa >> 4)-8;\n" " charWeights5.y=(charWeightsInt4.sa & MOD_NUM)-8;\n" " charWeights5.z=(charWeightsInt4.sb >> 4)-8;\n" " charWeights5.w=(charWeightsInt4.sb & MOD_NUM)-8;\n" " charWeights6.x=(charWeightsInt4.sc >> 4)-8;\n" " charWeights6.y=(charWeightsInt4.sc & MOD_NUM)-8;\n" " charWeights6.z=(charWeightsInt4.sd >> 4)-8;\n" " charWeights6.w=(charWeightsInt4.sd & MOD_NUM)-8;\n" " charWeights7.x=(charWeightsInt4.se >> 4)-8;\n" " charWeights7.y=(charWeightsInt4.se & MOD_NUM)-8;\n" " charWeights7.z=(charWeightsInt4.sf >> 4)-8;\n" " charWeights7.w=(charWeightsInt4.sf & MOD_NUM)-8;\n" " weights0=mad(CONVERT_FLOAT4(charWeights0),scale0,offset0);\n" " weights1=mad(CONVERT_FLOAT4(charWeights1),scale0,offset0);\n" " weights2=mad(CONVERT_FLOAT4(charWeights2),scale0,offset0);\n" " weights3=mad(CONVERT_FLOAT4(charWeights3),scale0,offset0);\n" " weights4=mad(CONVERT_FLOAT4(charWeights4),scale1,offset1);\n" " weights5=mad(CONVERT_FLOAT4(charWeights5),scale1,offset1);\n" " weights6=mad(CONVERT_FLOAT4(charWeights6),scale1,offset1);\n" " weights7=mad(CONVERT_FLOAT4(charWeights7),scale1,offset1);\n" "#elif (defined USE_BUFFER)\n" " weights0=vload4(weights_width_base,weights+weight_offset);\n" " weights1=vload4(weights_width_base+1,weights+weight_offset);\n" " weights2=vload4(weights_width_base+2,weights+weight_offset);\n" " weights3=vload4(weights_width_base+3,weights+weight_offset);\n" " #ifdef CHANNEL_BOUNDARY_PROTECT\n" " weights4=output_channel_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(weights_width_base,weights+weight_offset1);\n" " weights5=output_channel_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(weights_width_base+1,weights+weight_offset1);\n" " weights6=output_channel_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(weights_width_base+2,weights+weight_offset1);\n" " weights7=output_channel_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(weights_width_base+3,weights+weight_offset1);\n" " #else\n" " weights4=vload4(weights_width_base,weights+weight_offset1);\n" " weights5=vload4(weights_width_base+1,weights+weight_offset1);\n" " weights6=vload4(weights_width_base+2,weights+weight_offset1);\n" " weights7=vload4(weights_width_base+3,weights+weight_offset1);\n" " #endif\n" "#else\n" " weights0=RI_F(weights,SAMPLER,(int2)(weights_width_base+0,output_channel_idx));\n" " weights1=RI_F(weights,SAMPLER,(int2)(weights_width_base+1,output_channel_idx));\n" " weights2=RI_F(weights,SAMPLER,(int2)(weights_width_base+2,output_channel_idx));\n" " weights3=RI_F(weights,SAMPLER,(int2)(weights_width_base+3,output_channel_idx));\n" " \n" " weights4=RI_F(weights,SAMPLER,(int2)(weights_width_base+0,output_channel_idx+1));\n" " weights5=RI_F(weights,SAMPLER,(int2)(weights_width_base+1,output_channel_idx+1));\n" " weights6=RI_F(weights,SAMPLER,(int2)(weights_width_base+2,output_channel_idx+1));\n" " weights7=RI_F(weights,SAMPLER,(int2)(weights_width_base+3,output_channel_idx+1));\n" "#endif\n" " PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" " PADZEROSVEC(in_channel_block_idx,inChannel,weights4,weights5,weights6,weights7);\n" " CALCULATE_OUTPUT(0);\n" " CALCULATE_OUTPUT(1);\n" " CALCULATE_OUTPUT(2);\n" " CALCULATE_OUTPUT(3);\n" " \n" " CALCULATE_OUTPUT_WEIGHTS4(4,0);\n" " CALCULATE_OUTPUT_WEIGHTS4(5,1);\n" " CALCULATE_OUTPUT_WEIGHTS4(6,2);\n" " CALCULATE_OUTPUT_WEIGHTS4(7,3);\n" " }\n" "#ifdef RELU\n" " out0=fmax(out0,(FLOAT4)0);\n" " out1=fmax(out1,(FLOAT4)0);\n" " out2=fmax(out2,(FLOAT4)0);\n" " out3=fmax(out3,(FLOAT4)0);\n" " out4=fmax(out4,(FLOAT4)0);\n" " out5=fmax(out5,(FLOAT4)0);\n" " out6=fmax(out6,(FLOAT4)0);\n" " out7=fmax(out7,(FLOAT4)0);\n" "#endif\n" "#ifdef RELU6\n" " out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n" " out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n" " out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n" " out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n" " out4=clamp(out4,(FLOAT4)0,(FLOAT4)6);\n" " out5=clamp(out5,(FLOAT4)0,(FLOAT4)6);\n" " out6=clamp(out6,(FLOAT4)0,(FLOAT4)6);\n" " out7=clamp(out7,(FLOAT4)0,(FLOAT4)6);\n" "#endif\n" " const int out_x_base=mul24(output_channel_idx,output_shape.y);\n" " int out_x_idx=output_width_block_idx << 2;\n" " const int remain=output_shape.y-out_x_idx;\n" " int output_idx=out_x_base+out_x_idx;\n" " if (remain >= 4) {\n" " WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" " WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n" " WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n" " WI_F(output,(int2)(output_idx+3,output_batch_height_idx),out3);\n" " } else if (remain == 3) {\n" " WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" " WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n" " WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n" " } else if (remain == 2) {\n" " WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" " WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n" " } else if (remain == 1) {\n" " WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" " }\n" " \n" " if(output_channel_idx+1 >= out_channel_blocks)\n" " return;\n" " output_idx += output_shape.y;\n" " if (remain >= 4) {\n" " WI_F(output,(int2)(output_idx,output_batch_height_idx),out4);\n" " WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out5);\n" " WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out6);\n" " WI_F(output,(int2)(output_idx+3,output_batch_height_idx),out7);\n" " } else if (remain == 3) {\n" " WI_F(output,(int2)(output_idx,output_batch_height_idx),out4);\n" " WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out5);\n" " WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out6);\n" " } else if (remain == 2) {\n" " WI_F(output,(int2)(output_idx,output_batch_height_idx),out4);\n" " WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out5);\n" " } else if (remain == 1) {\n" " WI_F(output,(int2)(output_idx,output_batch_height_idx),out4);\n" " }\n" "}\n" "__kernel\n" "#if SET_ATTRIBUTE\n" "__attribute__((work_group_size_hint(16,16,1)))\n" "#endif\n" "void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" " __global const char *kernel_ptr,\n" " __global const float *dequantScaleOffset,\n" "#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" " __global const uchar *kernel_ptr,\n" " __global const float *dequantScaleOffset,\n" "#elif (defined USE_BUFFER)\n" " __global const FLOAT *weights,\n" "#else\n" " __read_only image2d_t weights,\n" "#endif\n" "#ifdef BIAS\n" " __read_only image2d_t bias,\n" "#endif\n" " __write_only image2d_t output,\n" " __private const int2 input_shape,\n" " __private const int in_channel_block_length,\n" " __private const int2 output_shape,\n" " __private const int2 weights_shape,\n" " __private const int2 stride_shape,\n" " __private const int2 padding_shape,\n" " __private const int2 dilation_shape,\n" " __private const int out_width_blocks,\n" " __private const int out_channel_blocks,\n" " __private const int out_height_blocks\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n" " ,__private const int blockDim\n" " ,__private const int inChannel\n" "#endif\n" ") {\n" " const int output_channel_width_idx=get_global_id(0);\n" " const int output_batch_height_idx=get_global_id(1);\n" " DEAL_NON_UNIFORM_DIM2(output_channel_width_idx,output_batch_height_idx);\n" " const int out_channel_block_idx=output_channel_width_idx/out_width_blocks;\n" " const int out_height_block_idx=output_channel_width_idx % out_width_blocks;\n" "#ifdef BIAS\n" " FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(out_channel_block_idx,0));\n" "#else\n" " FLOAT4 out0=(FLOAT4)0;\n" "#endif\n" " FLOAT4 out1=out0;\n" " FLOAT4 out2=out0;\n" " FLOAT4 out3=out0;\n" " int in_width0=mad24(out_height_block_idx,stride_shape.y<<2,-padding_shape.y);\n" " int in_width1=in_width0+stride_shape.y;\n" " int in_width2=in_width0+stride_shape.y*2;\n" " int in_width3=in_width0+stride_shape.y*3;\n" " \n" "#ifdef MNN_CONV_S1D1\n" " const int height_start=mad24((output_batch_height_idx % output_shape.x),1,-padding_shape.x);\n" " const int kh_start=select(0,(-height_start),height_start<0);\n" " int in_height_start=kh_start+height_start;\n" " int in_height_end=min(weights_shape.x+height_start,input_shape.x);\n" " const int batch_idx=mul24((output_batch_height_idx/output_shape.x),input_shape.x);\n" " const int weights_h_idx=mul24(out_channel_block_idx,mul24(weights_shape.y,weights_shape.x))+mul24(select(0,(-height_start),height_start<0),weights_shape.y);\n" "#else\n" " const int height_start=mad24((output_batch_height_idx % output_shape.x),stride_shape.x,-padding_shape.x);\n" " const int kh_start=select(0,(-height_start+dilation_shape.x-1)/dilation_shape.x,height_start<0);\n" " int in_height_start=mad24(kh_start,dilation_shape.x,height_start);\n" " int in_height_end=min(mad24(weights_shape.x,dilation_shape.x,height_start),input_shape.x);\n" " const int batch_idx=mul24((output_batch_height_idx/output_shape.x),input_shape.x);\n" " const int weights_h_idx=mul24(out_channel_block_idx,mul24(weights_shape.y,weights_shape.x))+mul24(select(0,(-height_start+dilation_shape.x-1)/dilation_shape.x,height_start<0),weights_shape.y);\n" "#endif\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n" " const int weight_oc_offset=out_channel_blocks*weights_shape.x*weights_shape.y*4;\n" "#endif\n" " FLOAT4 in0,in1,in2,in3;\n" " FLOAT4 weights0,weights1,weights2,weights3;\n" " for (int in_channel_block_idx=0; in_channel_block_idx<in_channel_block_length; ++in_channel_block_idx) {\n" " \n" "#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n" " int kindex=(in_channel_block_idx*4)/blockDim*out_channel_blocks*8;\n" " COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx,dequantScaleOffset+kindex));\n" " COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n" " COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n" "#endif\n" " \n" " const int in_idx=mul24(in_channel_block_idx,input_shape.y);\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n" " int weight_offset=((((4*in_channel_block_idx+0)* out_channel_blocks+out_channel_block_idx) *weights_shape.x+kh_start)*weights_shape.y+0)*4;\n" "#else\n" " int weights_x_idx=in_channel_block_idx << 2;\n" " int weights_y_idx=weights_h_idx;\n" "#endif\n" " for (int iy=in_height_start; iy<in_height_end; iy += dilation_shape.x) {\n" " int in_hb_value=iy+batch_idx;\n" "#ifdef MNN_CONV_S1D1\n" " {\n" " READ_INPUT_IMAGE(0,0);\n" " READ_INPUT_IMAGE(1,0);\n" " READ_INPUT_IMAGE(2,0);\n" " READ_INPUT_IMAGE(3,0);\n" " \n" "#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" " char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n" " char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n" " char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*2);\n" " char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*3);\n" " weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" " weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" " weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" " weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" " weight_offset += 4;\n" "#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" " uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n" " uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n" " uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*2/2);\n" " uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*3/2);\n" " char4 charWeight0=(char4)(0,0,0,0);\n" " char4 charWeight1=(char4)(0,0,0,0);\n" " char4 charWeight2=(char4)(0,0,0,0);\n" " char4 charWeight3=(char4)(0,0,0,0);\n" " charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n" " charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n" " charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n" " charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n" " charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n" " charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n" " charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n" " charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n" " charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n" " charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n" " charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n" " charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n" " charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n" " charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n" " charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n" " charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n" " weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" " weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" " weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" " weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" " weight_offset += 4;\n" "#elif (defined USE_BUFFER)\n" " weights0=vload4(0,weights+weight_offset);\n" " weights1=vload4(0,weights+weight_offset+weight_oc_offset);\n" " weights2=vload4(0,weights+weight_offset+weight_oc_offset*2);\n" " weights3=vload4(0,weights+weight_offset+weight_oc_offset*3);\n" " weight_offset += 4;\n" "#else\n" " weights0=RI_F(weights,SAMPLER,(int2)(weights_x_idx+0,weights_y_idx));\n" " weights1=RI_F(weights,SAMPLER,(int2)(weights_x_idx+1,weights_y_idx));\n" " weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx));\n" " weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx++));\n" "#endif\n" " PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" " CALCULATE_OUTPUT(0);\n" " CALCULATE_OUTPUT(1);\n" " CALCULATE_OUTPUT(2);\n" " CALCULATE_OUTPUT(3);\n" " }\n" " for (int w=1; w<weights_shape.y; w++){\n" " in0=in1;\n" " in1=in2;\n" " in2=in3;\n" " READ_INPUT_IMAGE(3,w);\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" " char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n" " char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n" " char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*2);\n" " char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*3);\n" " weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" " weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" " weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" " weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" " weight_offset += 4;\n" "#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" " uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n" " uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n" " uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*2/2);\n" " uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*3/2);\n" " char4 charWeight0=(char4)(0,0,0,0);\n" " char4 charWeight1=(char4)(0,0,0,0);\n" " char4 charWeight2=(char4)(0,0,0,0);\n" " char4 charWeight3=(char4)(0,0,0,0);\n" " charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n" " charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n" " charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n" " charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n" " charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n" " charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n" " charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n" " charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n" " charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n" " charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n" " charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n" " charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n" " charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n" " charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n" " charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n" " charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n" " weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" " weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" " weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" " weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" " weight_offset += 4;\n" "#elif (defined USE_BUFFER)\n" " weights0=vload4(0,weights+weight_offset);\n" " weights1=vload4(0,weights+weight_offset+weight_oc_offset);\n" " weights2=vload4(0,weights+weight_offset+weight_oc_offset*2);\n" " weights3=vload4(0,weights+weight_offset+weight_oc_offset*3);\n" " weight_offset += 4;\n" "#else\n" " weights0=RI_F(weights,SAMPLER,(int2)(weights_x_idx+0,weights_y_idx));\n" " weights1=RI_F(weights,SAMPLER,(int2)(weights_x_idx+1,weights_y_idx));\n" " weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx));\n" " weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx++));\n" "#endif\n" " PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" " CALCULATE_OUTPUT(0);\n" " CALCULATE_OUTPUT(1);\n" " CALCULATE_OUTPUT(2);\n" " CALCULATE_OUTPUT(3);\n" " }\n" "#else\n" " for (int w=0; w<weights_shape.y; w++) {\n" " int input_width_base=mul24(w,dilation_shape.y);\n" " READ_INPUT_IMAGE(0,input_width_base);\n" " READ_INPUT_IMAGE(1,input_width_base);\n" " READ_INPUT_IMAGE(2,input_width_base);\n" " READ_INPUT_IMAGE(3,input_width_base);\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" " char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n" " char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n" " char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*2);\n" " char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*3);\n" " weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" " weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" " weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" " weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" " weight_offset += 4;\n" "#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" " uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n" " uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n" " uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*2/2);\n" " uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*3/2);\n" " char4 charWeight0=(char4)(0,0,0,0);\n" " char4 charWeight1=(char4)(0,0,0,0);\n" " char4 charWeight2=(char4)(0,0,0,0);\n" " char4 charWeight3=(char4)(0,0,0,0);\n" " charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n" " charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n" " charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n" " charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n" " charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n" " charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n" " charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n" " charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n" " charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n" " charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n" " charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n" " charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n" " charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n" " charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n" " charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n" " charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n" " weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" " weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" " weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" " weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" " weight_offset += 4;\n" "#elif (defined USE_BUFFER)\n" " weights0=vload4(0,weights+weight_offset);\n" " weights1=vload4(0,weights+weight_offset+weight_oc_offset);\n" " weights2=vload4(0,weights+weight_offset+weight_oc_offset*2);\n" " weights3=vload4(0,weights+weight_offset+weight_oc_offset*3);\n" " weight_offset += 4;\n" "#else\n" " weights0=RI_F(weights,SAMPLER,(int2)(weights_x_idx+0,weights_y_idx)); \n" " weights1=RI_F(weights,SAMPLER,(int2)(weights_x_idx+1,weights_y_idx)); \n" " weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx)); \n" " weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx++));\n" "#endif\n" " PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" " CALCULATE_OUTPUT(0);\n" " CALCULATE_OUTPUT(1);\n" " CALCULATE_OUTPUT(2);\n" " CALCULATE_OUTPUT(3);\n" " }\n" "#endif\n" " }\n" " }\n" "#ifdef RELU\n" " out0=fmax(out0,(FLOAT4)0);\n" " out1=fmax(out1,(FLOAT4)0);\n" " out2=fmax(out2,(FLOAT4)0);\n" " out3=fmax(out3,(FLOAT4)0);\n" "#endif\n" "#ifdef RELU6\n" " out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n" " out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n" " out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n" " out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n" "#endif\n" " const int out_x_base=mul24(out_channel_block_idx,output_shape.y);\n" " int out_x_idx=out_height_block_idx << 2;\n" " const int remain=output_shape.y-out_x_idx;\n" " int output_idx=out_x_base+out_x_idx;\n" " if (remain >= 4) {\n" " WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" " WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n" " WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n" " WI_F(output,(int2)(output_idx+3,output_batch_height_idx),out3);\n" " } else if (remain == 3) {\n" " WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" " WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n" " WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n" " } else if (remain == 2) {\n" " WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" " WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n" " } else if (remain == 1) {\n" " WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n" " }\n" "}\n" "__kernel\n" "#if SET_ATTRIBUTE\n" "__attribute__((work_group_size_hint(16,16,1)))\n" "#endif\n" "void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" " __global const char *kernel_ptr,\n" " __global const float *dequantScaleOffset,\n" "#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" " __global const uchar *kernel_ptr,\n" " __global const float *dequantScaleOffset,\n" "#elif (defined USE_BUFFER)\n" " __global const FLOAT *weights,\n" "#else\n" " __read_only image2d_t weights,\n" "#endif\n" "#ifdef BIAS\n" " __read_only image2d_t bias,\n" "#endif\n" " __write_only image2d_t output,\n" " __private const int2 input_shape,\n" " __private const int in_channel_block_length,\n" " __private const int2 output_shape,\n" " __private const int2 weights_shape,\n" " __private const int2 stride_shape,\n" " __private const int2 padding_shape,\n" " __private const int2 dilation_shape,\n" " __private const int out_width_blocks,\n" " __private const int out_channel_blocks,\n" " __private const int out_height_blocks\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n" " ,__private const int blockDim\n" " ,__private const int inChannel\n" "#endif\n" ") {\n" " const int output_channel_width_idx=get_global_id(0);\n" " const int output_batch_height_idx=get_global_id(1);\n" " DEAL_NON_UNIFORM_DIM2(output_channel_width_idx,output_batch_height_idx);\n" " const int out_channel_block_idx=(output_channel_width_idx/out_width_blocks) << 1;\n" " const int out_width_block_idx=output_channel_width_idx % out_width_blocks;\n" " const int out_height_block_idx=(output_batch_height_idx % out_height_blocks);\n" " const int out_batch_block_idx=output_batch_height_idx/out_height_blocks;\n" "#ifdef BIAS\n" " FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(out_channel_block_idx,0));\n" " FLOAT4 out4=RI_F(bias,SAMPLER,(int2)(out_channel_block_idx+1,0));\n" "#else\n" " FLOAT4 out0=(FLOAT4)0;\n" " FLOAT4 out4=(FLOAT4)0;\n" "#endif\n" " FLOAT4 out1=out0;\n" " FLOAT4 out2=out0;\n" " FLOAT4 out3=out0;\n" " FLOAT4 out5=out4;\n" " FLOAT4 out6=out4;\n" " FLOAT4 out7=out4;\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n" " const int weight_oc_offset=weights_shape.x*weights_shape.y*4;\n" " const int weight_ic_offset=out_channel_blocks*weight_oc_offset;\n" "#endif\n" " int in_width0=mad24(out_width_block_idx,stride_shape.y,-padding_shape.y);\n" " int in_height0=mad24(out_height_block_idx,stride_shape.x<<2,-padding_shape.x);\n" " int in_height1=in_height0+stride_shape.x;\n" " int in_height2=in_height1+stride_shape.x;\n" " int in_height3=in_height2+stride_shape.x;\n" " int weight_size=mul24(weights_shape.y,weights_shape.x);\n" " \n" " const int weights_h_idx=mul24(out_channel_block_idx,weight_size);\n" " const int batch_idx=mul24(out_batch_block_idx,input_shape.x);\n" " \n" " FLOAT4 in0,in1,in2,in3;\n" " FLOAT4 weights0,weights1,weights2,weights3,weights4,weights5,weights6,weights7;\n" " for (int in_channel_block_idx=0; in_channel_block_idx<in_channel_block_length; ++in_channel_block_idx) {\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n" " int kindex=(in_channel_block_idx*4)/blockDim*out_channel_blocks*8;\n" " COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx,dequantScaleOffset+kindex));\n" " COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n" " COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n" " COMPUTE_FLOAT8 ScaleOffset1=CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx+1,dequantScaleOffset+kindex));\n" " COMPUTE_FLOAT4 scale1=(COMPUTE_FLOAT4)(ScaleOffset1.s0,ScaleOffset1.s2,ScaleOffset1.s4,ScaleOffset1.s6);\n" " COMPUTE_FLOAT4 offset1=(COMPUTE_FLOAT4)(ScaleOffset1.s1,ScaleOffset1.s3,ScaleOffset1.s5,ScaleOffset1.s7);\n" " \n" "#endif\n" " const int in_idx=mul24(in_channel_block_idx,input_shape.y);\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n" " int weight_offset=((((4*in_channel_block_idx+0)* out_channel_blocks+out_channel_block_idx) *weights_shape.x+0)*weights_shape.y+0)*4;\n" "#else\n" " int weights_x_idx=in_channel_block_idx << 2;\n" " int weights_y_idx=weights_h_idx;\n" "#endif\n" " for (int iy=0; iy<weights_shape.x*dilation_shape.x; iy += dilation_shape.x) {\n" " int h0=select(in_height0+iy+batch_idx,-1,(in_height0+iy<0 || in_height0+iy >= input_shape.x));\n" " int h1=select(in_height1+iy+batch_idx,-1,(in_height1+iy<0 || in_height1+iy >= input_shape.x));\n" " int h2=select(in_height2+iy+batch_idx,-1,(in_height2+iy<0 || in_height2+iy >= input_shape.x));\n" " int h3=select(in_height3+iy+batch_idx,-1,(in_height3+iy<0 || in_height3+iy >= input_shape.x));\n" " for (int ix=0; ix<weights_shape.y*dilation_shape.y; ix += dilation_shape.y) {\n" " int w0=select(in_width0+ix+in_idx,-1,(in_width0+ix<0 || in_width0+ix >= input_shape.y));\n" " \n" " in0=RI_F(input,SAMPLER,(int2)(w0,h0));\n" " in1=RI_F(input,SAMPLER,(int2)(w0,h1));\n" " in2=RI_F(input,SAMPLER,(int2)(w0,h2));\n" " in3=RI_F(input,SAMPLER,(int2)(w0,h3));\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" " char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n" " char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_ic_offset);\n" " char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_ic_offset*2);\n" " char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_ic_offset*3);\n" " weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" " weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" " weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" " weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" " #ifdef CHANNEL_BOUNDARY_PROTECT\n" " charWeight0=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n" " charWeight1=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset);\n" " charWeight2=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2);\n" " charWeight3=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3);\n" " \n" " #else\n" " charWeight0=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n" " charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset);\n" " charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2);\n" " charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3);\n" " #endif\n" " weights4=mad(CONVERT_FLOAT4(charWeight0),scale1,offset1);\n" " weights5=mad(CONVERT_FLOAT4(charWeight1),scale1,offset1);\n" " weights6=mad(CONVERT_FLOAT4(charWeight2),scale1,offset1);\n" " weights7=mad(CONVERT_FLOAT4(charWeight3),scale1,offset1);\n" " weight_offset += 4;\n" "#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" " uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n" " uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_ic_offset/2);\n" " uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_ic_offset*2/2);\n" " uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_ic_offset*3/2);\n" " char4 charWeight0=(char4)(0,0,0,0);\n" " char4 charWeight1=(char4)(0,0,0,0);\n" " char4 charWeight2=(char4)(0,0,0,0);\n" " char4 charWeight3=(char4)(0,0,0,0);\n" " charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n" " charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n" " charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n" " charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n" " charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n" " charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n" " charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n" " charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n" " charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n" " charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n" " charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n" " charWeight2.w=(charWeightInt42.s1 & MOD_NUM)- 8;\n" " charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n" " charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n" " charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n" " charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n" " weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" " weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" " weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" " weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" " charWeightInt40=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n" " charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset/2);\n" " charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset*2/2);\n" " charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset*3/2);\n" " charWeight0=(char4)(0,0,0,0);\n" " charWeight1=(char4)(0,0,0,0);\n" " charWeight2=(char4)(0,0,0,0);\n" " charWeight3=(char4)(0,0,0,0);\n" " charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n" " charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n" " charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n" " charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n" " charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n" " charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n" " charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n" " charWeight1.w=(charWeightInt41.s1 & MOD_NUM)- 8;\n" " charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n" " charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n" " charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n" " charWeight2.w=(charWeightInt42.s1 & MOD_NUM)- 8;\n" " charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n" " charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n" " charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n" " charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n" " weights4=mad(CONVERT_FLOAT4(charWeight0),scale1,offset1);\n" " weights5=mad(CONVERT_FLOAT4(charWeight1),scale1,offset1);\n" " weights6=mad(CONVERT_FLOAT4(charWeight2),scale1,offset1);\n" " weights7=mad(CONVERT_FLOAT4(charWeight3),scale1,offset1);\n" " weight_offset += 4;\n" "#elif (defined USE_BUFFER)\n" " weights0=vload4(0,weights+weight_offset);\n" " weights1=vload4(0,weights+weight_offset+weight_ic_offset);\n" " weights2=vload4(0,weights+weight_offset+weight_ic_offset*2);\n" " weights3=vload4(0,weights+weight_offset+weight_ic_offset*3);\n" " #ifdef CHANNEL_BOUNDARY_PROTECT\n" " weights4=out_channel_block_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(0,weights+weight_offset+weight_oc_offset);\n" " weights5=out_channel_block_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(0,weights+weight_offset+weight_ic_offset+weight_oc_offset);\n" " weights6=out_channel_block_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(0,weights+weight_offset+weight_ic_offset*2+weight_oc_offset);\n" " weights7=out_channel_block_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(0,weights+weight_offset+weight_ic_offset*3+weight_oc_offset);\n" " #else\n" " weights4=vload4(0,weights+weight_offset+weight_oc_offset);\n" " weights5=vload4(0,weights+weight_offset+weight_ic_offset+weight_oc_offset);\n" " weights6=vload4(0,weights+weight_offset+weight_ic_offset*2+weight_oc_offset);\n" " weights7=vload4(0,weights+weight_offset+weight_ic_offset*3+weight_oc_offset);\n" " #endif\n" " weight_offset += 4;\n" "#else\n" " weights0=RI_F(weights,SAMPLER,(int2)(weights_x_idx+0,weights_y_idx));\n" " weights1=RI_F(weights,SAMPLER,(int2)(weights_x_idx+1,weights_y_idx));\n" " weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx));\n" " weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx));\n" " weights4=RI_F(weights,SAMPLER,(int2)(weights_x_idx+0,weight_size+weights_y_idx));\n" " weights5=RI_F(weights,SAMPLER,(int2)(weights_x_idx+1,weight_size+weights_y_idx));\n" " weights6=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weight_size+weights_y_idx));\n" " weights7=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weight_size+weights_y_idx++));\n" "#endif\n" " PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" " PADZEROSVEC(in_channel_block_idx,inChannel,weights4,weights5,weights6,weights7);\n" " \n" " CALCULATE_OUTPUT(0);\n" " CALCULATE_OUTPUT(1);\n" " CALCULATE_OUTPUT(2);\n" " CALCULATE_OUTPUT(3);\n" " CALCULATE_OUTPUT_WEIGHTS4(4,0);\n" " CALCULATE_OUTPUT_WEIGHTS4(5,1);\n" " CALCULATE_OUTPUT_WEIGHTS4(6,2);\n" " CALCULATE_OUTPUT_WEIGHTS4(7,3);\n" " }\n" " }\n" " }\n" "#ifdef RELU\n" " out0=fmax(out0,(FLOAT4)0);\n" " out1=fmax(out1,(FLOAT4)0);\n" " out2=fmax(out2,(FLOAT4)0);\n" " out3=fmax(out3,(FLOAT4)0);\n" " out4=fmax(out4,(FLOAT4)0);\n" " out5=fmax(out5,(FLOAT4)0);\n" " out6=fmax(out6,(FLOAT4)0);\n" " out7=fmax(out7,(FLOAT4)0);\n" "#endif\n" "#ifdef RELU6\n" " out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n" " out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n" " out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n" " out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n" " out4=clamp(out4,(FLOAT4)0,(FLOAT4)6);\n" " out5=clamp(out5,(FLOAT4)0,(FLOAT4)6);\n" " out6=clamp(out6,(FLOAT4)0,(FLOAT4)6);\n" " out7=clamp(out7,(FLOAT4)0,(FLOAT4)6);\n" "#endif\n" " const int out_x_base=mul24(out_channel_block_idx,output_shape.y);\n" " const int out_y_base=mul24(out_batch_block_idx,output_shape.x);\n" " int out_x_idx=out_width_block_idx;\n" " int out_y_idx=out_height_block_idx << 2;\n" " const int remain_y=output_shape.x-out_y_idx;\n" " int output_idx=out_x_base+out_x_idx;\n" " int output_idy=out_y_base+out_y_idx;\n" " \n" " if(remain_y >= 4){\n" " WI_F(output,(int2)(output_idx,output_idy),out0);\n" " WI_F(output,(int2)(output_idx,output_idy+1),out1);\n" " WI_F(output,(int2)(output_idx,output_idy+2),out2);\n" " WI_F(output,(int2)(output_idx,output_idy+3),out3);\n" " }else if(remain_y == 3){\n" " WI_F(output,(int2)(output_idx,output_idy),out0);\n" " WI_F(output,(int2)(output_idx,output_idy+1),out1);\n" " WI_F(output,(int2)(output_idx,output_idy+2),out2);\n" " }else if(remain_y == 2){\n" " WI_F(output,(int2)(output_idx,output_idy),out0);\n" " WI_F(output,(int2)(output_idx,output_idy+1),out1);\n" " }else if(remain_y == 1){\n" " WI_F(output,(int2)(output_idx,output_idy),out0);\n" " }\n" " \n" " if(out_channel_block_idx+1 >= out_channel_blocks) {\n" " return;\n" " }\n" " output_idx += output_shape.y;\n" " if(remain_y >= 4){\n" " WI_F(output,(int2)(output_idx,output_idy),out4);\n" " WI_F(output,(int2)(output_idx,output_idy+1),out5);\n" " WI_F(output,(int2)(output_idx,output_idy+2),out6);\n" " WI_F(output,(int2)(output_idx,output_idy+3),out7);\n" " }else if(remain_y == 3){\n" " WI_F(output,(int2)(output_idx,output_idy),out4);\n" " WI_F(output,(int2)(output_idx,output_idy+1),out5);\n" " WI_F(output,(int2)(output_idx,output_idy+2),out6);\n" " }else if(remain_y == 2){\n" " WI_F(output,(int2)(output_idx,output_idy),out4);\n" " WI_F(output,(int2)(output_idx,output_idy+1),out5);\n" " }else if(remain_y == 1){\n" " WI_F(output,(int2)(output_idx,output_idy),out4);\n" " }\n" "}\n" "__kernel\n" "#if SET_ATTRIBUTE\n" "__attribute__((work_group_size_hint(16,16,1)))\n" "#endif\n" "void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" " __global const char *kernel_ptr,\n" " __global const float *dequantScaleOffset,\n" "#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" " __global const uchar *kernel_ptr,\n" " __global const float *dequantScaleOffset,\n" "#elif (defined USE_BUFFER)\n" " __global const FLOAT *weights,\n" "#else\n" " __read_only image2d_t weights,\n" "#endif\n" "#ifdef BIAS\n" " __read_only image2d_t bias,\n" "#endif\n" " __write_only image2d_t output,\n" " __private const int2 input_shape,\n" " __private const int in_channel_block_length,\n" " __private const int2 output_shape,\n" " __private const int2 weights_shape,\n" " __private const int2 stride_shape,\n" " __private const int2 padding_shape,\n" " __private const int2 dilation_shape,\n" " __private const int out_width_blocks,\n" " __private const int out_channel_blocks,\n" " __private const int out_height_blocks\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n" " ,__private const int blockDim\n" " ,__private const int inChannel\n" "#endif\n" ") {\n" " const int output_channel_width_idx=get_global_id(0);\n" " const int output_batch_height_idx=get_global_id(1);\n" " DEAL_NON_UNIFORM_DIM2(output_channel_width_idx,output_batch_height_idx);\n" " const int out_channel_block_idx=output_channel_width_idx/out_width_blocks;\n" " const int out_width_block_idx=output_channel_width_idx % out_width_blocks;\n" " const int out_height_block_idx=(output_batch_height_idx % out_height_blocks);\n" " const int out_batch_block_idx=output_batch_height_idx/out_height_blocks;\n" "#ifdef BIAS\n" " FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(out_channel_block_idx,0));\n" "#else\n" " FLOAT4 out0=(FLOAT4)0;\n" "#endif\n" " FLOAT4 out1=out0;\n" " FLOAT4 out2=out0;\n" " FLOAT4 out3=out0;\n" " int in_width0=mad24(out_width_block_idx,stride_shape.y,-padding_shape.y);\n" " int in_height0=mad24(out_height_block_idx,stride_shape.x<<2,-padding_shape.x);\n" " int in_height1=in_height0+stride_shape.x;\n" " int in_height2=in_height1+stride_shape.x;\n" " int in_height3=in_height2+stride_shape.x;\n" " int weight_size=mul24(weights_shape.y,weights_shape.x);\n" " \n" " const int weights_h_idx=mul24(out_channel_block_idx,weight_size);\n" " const int batch_idx=mul24(out_batch_block_idx,input_shape.x);\n" " \n" " FLOAT4 in0,in1,in2,in3;\n" " FLOAT4 weights0,weights1,weights2,weights3;\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n" " const int weight_oc_offset=out_channel_blocks*weights_shape.x*weights_shape.y*4;\n" "#endif\n" " for (int in_channel_block_idx=0; in_channel_block_idx<in_channel_block_length; ++in_channel_block_idx) {\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n" " int kindex=(in_channel_block_idx*4)/blockDim*out_channel_blocks*8;\n" " COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx,dequantScaleOffset+kindex));\n" " COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n" " COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n" "#endif\n" " const int in_idx=mul24(in_channel_block_idx,input_shape.y);\n" "#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n" " int weight_offset=((((4*in_channel_block_idx+0)* out_channel_blocks+out_channel_block_idx) *weights_shape.x+0)*weights_shape.y+0)*4;\n" "#else\n" " int weights_x_idx=in_channel_block_idx << 2;\n" " int weights_y_idx=weights_h_idx;\n" "#endif\n" " for (int iy=0; iy<weights_shape.x*dilation_shape.x; iy += dilation_shape.x) {\n" " int h0=select(in_height0+iy+batch_idx,-1,(in_height0+iy<0 || in_height0+iy >= input_shape.x));\n" " int h1=select(in_height1+iy+batch_idx,-1,(in_height1+iy<0 || in_height1+iy >= input_shape.x));\n" " int h2=select(in_height2+iy+batch_idx,-1,(in_height2+iy<0 || in_height2+iy >= input_shape.x));\n" " int h3=select(in_height3+iy+batch_idx,-1,(in_height3+iy<0 || in_height3+iy >= input_shape.x));\n" " for (int ix=0; ix<weights_shape.y*dilation_shape.y; ix += dilation_shape.y) {\n" " int w0=select(in_width0+ix+in_idx,-1,(in_width0+ix<0 || in_width0+ix >= input_shape.y));\n" " \n" " in0=RI_F(input,SAMPLER,(int2)(w0,h0));\n" " in1=RI_F(input,SAMPLER,(int2)(w0,h1));\n" " in2=RI_F(input,SAMPLER,(int2)(w0,h2));\n" " in3=RI_F(input,SAMPLER,(int2)(w0,h3));\n" " \n" "#if (defined USE_LOW_BIT_WEIGHT_INT8)\n" " char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n" " char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n" " char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*2);\n" " char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*3);\n" " weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" " weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" " weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" " weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" " weight_offset += 4;\n" "#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n" " uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n" " uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n" " uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*2/2);\n" " uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*3/2);\n" " char4 charWeight0=(char4)(0,0,0,0);\n" " char4 charWeight1=(char4)(0,0,0,0);\n" " char4 charWeight2=(char4)(0,0,0,0);\n" " char4 charWeight3=(char4)(0,0,0,0);\n" " charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n" " charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n" " charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n" " charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n" " charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n" " charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n" " charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n" " charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n" " charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n" " charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n" " charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n" " charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n" " charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n" " charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n" " charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n" " charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n" " weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n" " weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n" " weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n" " weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n" " weight_offset += 4;\n" "#elif (defined USE_BUFFER)\n" " weights0=vload4(0,weights+weight_offset);\n" " weights1=vload4(0,weights+weight_offset+weight_oc_offset);\n" " weights2=vload4(0,weights+weight_offset+weight_oc_offset*2);\n" " weights3=vload4(0,weights+weight_offset+weight_oc_offset*3);\n" " weight_offset += 4;\n" "#else\n" " weights0=RI_F(weights,SAMPLER,(int2)(weights_x_idx+0,weights_y_idx));\n" " weights1=RI_F(weights,SAMPLER,(int2)(weights_x_idx+1,weights_y_idx));\n" " weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx));\n" " weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx++));\n" "#endif\n" " PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n" " CALCULATE_OUTPUT(0);\n" " CALCULATE_OUTPUT(1);\n" " CALCULATE_OUTPUT(2);\n" " CALCULATE_OUTPUT(3);\n" " }\n" " }\n" " }\n" "#ifdef RELU\n" " out0=fmax(out0,(FLOAT4)0);\n" " out1=fmax(out1,(FLOAT4)0);\n" " out2=fmax(out2,(FLOAT4)0);\n" " out3=fmax(out3,(FLOAT4)0);\n" "#endif\n" "#ifdef RELU6\n" " out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n" " out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n" " out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n" " out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n" "#endif\n" " const int out_x_base=mul24(out_channel_block_idx,output_shape.y);\n" " const int out_y_base=mul24(out_batch_block_idx,output_shape.x);\n" " int out_x_idx=out_width_block_idx;\n" " int out_y_idx=out_height_block_idx << 2;\n" " const int remain_y=output_shape.x-out_y_idx;\n" " int output_idx=out_x_base+out_x_idx;\n" " int output_idy=out_y_base+out_y_idx;\n" " if(remain_y >= 4){\n" " WI_F(output,(int2)(output_idx,output_idy),out0);\n" " WI_F(output,(int2)(output_idx,output_idy+1),out1);\n" " WI_F(output,(int2)(output_idx,output_idy+2),out2);\n" " WI_F(output,(int2)(output_idx,output_idy+3),out3);\n" " }else if(remain_y == 3){\n" " WI_F(output,(int2)(output_idx,output_idy),out0);\n" " WI_F(output,(int2)(output_idx,output_idy+1),out1);\n" " WI_F(output,(int2)(output_idx,output_idy+2),out2);\n" " }else if(remain_y == 2){\n" " WI_F(output,(int2)(output_idx,output_idy),out0);\n" " WI_F(output,(int2)(output_idx,output_idy+1),out1);\n" " }else{\n" " WI_F(output,(int2)(output_idx,output_idy),out0);\n" " }\n" "}\n" ; }