source/backend/opencl/execution/cl/conv_2d_mnn_cl.cpp (1,357 lines of code) (raw):
#include "opencl_source_map.hpp"
namespace MNN {
const char* conv_2d =
// ---- Shared OpenCL preamble for the kernels in this source string ----
// NOTE: everything between the quotes is OpenCL C text that is compiled at
// runtime; the C++ comments placed between the literals are stripped before
// literal concatenation and do not change that text.
// Enable half-precision support when the host defines MNN_SUPPORT_FP16.
"#ifdef MNN_SUPPORT_FP16\n"
"#pragma OPENCL EXTENSION cl_khr_fp16 : enable\n"
"#endif\n"
// READ_INPUT_IMAGE(i, base): input column in_width##i + base. A column outside
// [0, input_shape.y) is replaced by x = -1, which the CLK_ADDRESS_CLAMP sampler
// below resolves to the image border color instead of real data; an in-range
// column is offset by in_idx. Reads row in_hb_value into in##i. The expansion
// site must have in_idx, in_hb_value, input, in_width##i and in##i in scope.
"#define READ_INPUT_IMAGE(i, base) "" int in_width_value##i = in_width##i + base; "" in_width_value##i = "" select(in_idx + in_width_value##i, -1, (in_width_value##i < 0 || in_width_value##i >= input_shape.y)); "" in##i=RI_F(input,SAMPLER,(int2)(in_width_value##i,in_hb_value));\n"
// CALCULATE_OUTPUT(i): out##i += in##i.x*weights0 + in##i.y*weights1
//                               + in##i.z*weights2 + in##i.w*weights3  (via mad).
"#define CALCULATE_OUTPUT(i) "" out##i = mad(in##i.x, weights0, out##i); "" out##i = mad(in##i.y, weights1, out##i); "" out##i = mad(in##i.z, weights2, out##i); "" out##i=mad(in##i.w,weights3,out##i); \n"
// CALCULATE_OUTPUT_WEIGHTS4(i, j): same accumulation into out##i, but using
// input in##j and the second weight set weights4..weights7.
"#define CALCULATE_OUTPUT_WEIGHTS4(i, j) "" out##i = mad(in##j.x, weights4, out##i); "" out##i = mad(in##j.y, weights5, out##i); "" out##i = mad(in##j.z, weights6, out##i); "" out##i=mad(in##j.w,weights7,out##i);\n"
// CALCULATE_OUTPUT_OPT(i): like CALCULATE_OUTPUT, with the input taken from
// in_sm##i[local_idx] instead of in##i.
"#define CALCULATE_OUTPUT_OPT(i) "" out##i = mad(in_sm##i[local_idx].x, weights0, out##i); "" out##i = mad(in_sm##i[local_idx].y, weights1, out##i); "" out##i = mad(in_sm##i[local_idx].z, weights2, out##i); "" out##i=mad(in_sm##i[local_idx].w,weights3,out##i); \n"
// GLOBAL_SIZE_2_DIMS / GLOBAL_SIZE_3_DIMS declare leading kernel parameters
// carrying the launched global size; DEAL_NON_UNIFORM_DIM2 / _DIM3 make any
// work-item at or beyond that size return immediately.
"#define GLOBAL_SIZE_2_DIMS __private const int global_size_dim0,__private const int global_size_dim1,\n"
// Integer (non-normalized) coordinates, nearest filtering; out-of-range reads
// return the border color (CLK_ADDRESS_CLAMP).
"__constant sampler_t SAMPLER=CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
"#define DEAL_NON_UNIFORM_DIM2(input1, input2) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1) { "" return; "" }\n"
"#define GLOBAL_SIZE_3_DIMS "" __private const int global_size_dim0,__private const int global_size_dim1,__private const int global_size_dim2,\n"
"#define DEAL_NON_UNIFORM_DIM3(input1, input2, input3) "" if (input1 >= global_size_dim0 || input2 >= global_size_dim1 || input3 >= global_size_dim2) { "" return; "" }\n"
"#define UNIT 4\n"
// MOD_NUM: low-nibble mask (0xF) used when unpacking 4-bit weights.
"#define MOD_NUM 15\n"
// PADZEROSVEC(k, channel, d0..d3): when INPUT_CHANNEL_LEAVE is defined, zero
// dj whenever input channel 4*k + j >= channel (the padding lanes of the last
// 4-channel block); otherwise it expands to nothing.
"#ifdef INPUT_CHANNEL_LEAVE\n"
" #define PADZEROSVEC(k, channel, data0, data1, data2, data3) "" data0 = (k << 2) < channel ? data0 : 0; "" data1 = (k << 2) + 1 < channel ? data1 : 0; "" data2 = (k << 2) + 2 < channel ? data2 : 0; "" data3=(k << 2)+3<channel ? data3 : 0;\n"
"#else\n"
" #define PADZEROSVEC(k,channel,data0,data1,data2,data3)\n"
"#endif\n"
// ---- conv_2d_1x1_mali ----
// 1x1 convolution whose weights and bias come from buffers (not images).
// Each work-item produces one 4-channel output block (out_c_idx) for four
// consecutive output columns (out0..out3) of one image row:
//   global id 0 = out_c_idx * out_w_blocks + out_w_idx
//   global id 1 = out_b_h_idx, used directly as both input and output row
// Input channel block b is read from image columns b*out_w + column, i.e. the
// input width is taken to be out_w. NOTE(review): this presumably relies on
// stride 1 / no padding so input and output widths match -- confirm on host.
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void conv_2d_1x1_mali(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks,__read_only image2d_t input,\n"
" #ifdef BUFFER_INP_FP32\n"
" __global const float *kernel_ptr,\n"
" __global const float *bias_ptr,\n"
" #else\n"
" __global const FLOAT *kernel_ptr,\n"
" __global const FLOAT *bias_ptr,\n"
" #endif\n"
" __write_only image2d_t output,\n"
" __private const int in_c_block,__private const int out_h,\n"
" __private const int out_w) {\n"
" const int out_c_w_idx=get_global_id(0); //c/4 w\n"
" const int out_b_h_idx=get_global_id(1); //b h\n"
" DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n"
" const int out_c_idx=out_c_w_idx/out_w_blocks;\n"
" const int out_w_idx=out_c_w_idx % out_w_blocks;\n"
" const int out_w4_idx=mul24(out_w_idx,4);\n"
// All four column accumulators start from this block's bias. With
// BUFFER_INP_FP32 the bias (and weights) are stored as float and converted.
" #ifdef BUFFER_INP_FP32\n"
" FLOAT4 out0=CONVERT_FLOAT4(vload4(out_c_idx,(__global float *)bias_ptr));\n"
" #else\n"
" FLOAT4 out0=vload4(out_c_idx,(__global FLOAT *)bias_ptr);\n"
" #endif\n"
" FLOAT4 out1=out0;\n"
" FLOAT4 out2=out0;\n"
" FLOAT4 out3=out0;\n"
" FLOAT4 weights0;\n"
" FLOAT4 weights1;\n"
" FLOAT4 weights2;\n"
" FLOAT4 weights3;\n"
" FLOAT4 in0; \n"
" FLOAT4 in1; \n"
" FLOAT4 in2;\n"
" FLOAT4 in3; \n"
// NOTE(review): weight16 is declared but never used in this kernel.
" FLOAT16 weight16;\n"
" const int intput_width_idx0=out_w4_idx;\n"
" const int intput_width_idx1=out_w4_idx+1;\n"
" const int intput_width_idx2=out_w4_idx+2;\n"
" const int intput_width_idx3=out_w4_idx+3;\n"
// Weights: 16 values per (output block, input block) pair, starting at element
// (out_c_idx*in_c_block + in_channel_block_idx)*16. weightsN holds the four
// input-channel weights of output channel N of the block, so dot(weightsN, in)
// adds that channel's contribution for one column.
" for (int in_channel_block_idx=0; in_channel_block_idx<in_c_block; ++in_channel_block_idx) {\n"
" int input_width_base=mul24(in_channel_block_idx,out_w);\n"
" int offset=mad24(out_c_idx,in_c_block,in_channel_block_idx)*4;\n"
" in0=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx0,out_b_h_idx));\n"
" in1=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx1,out_b_h_idx));\n"
" in2=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx2,out_b_h_idx));\n"
" in3=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx3,out_b_h_idx));\n"
" #ifdef BUFFER_INP_FP32\n"
" weights0=CONVERT_FLOAT4(vload4(offset,(__global float *)kernel_ptr));\n"
" weights1=CONVERT_FLOAT4(vload4(offset+1,(__global float *)kernel_ptr));\n"
" weights2=CONVERT_FLOAT4(vload4(offset+2,(__global float *)kernel_ptr));\n"
" weights3=CONVERT_FLOAT4(vload4(offset+3,(__global float *)kernel_ptr));\n"
" #else\n"
" weights0=vload4(offset,(__global FLOAT *)kernel_ptr);\n"
" weights1=vload4(offset+1,(__global FLOAT *)kernel_ptr);\n"
" weights2=vload4(offset+2,(__global FLOAT *)kernel_ptr);\n"
" weights3=vload4(offset+3,(__global FLOAT *)kernel_ptr);\n"
" #endif\n"
" \n"
" out0.x += dot(weights0,in0);\n"
" out0.y += dot(weights1,in0);\n"
" out0.z += dot(weights2,in0);\n"
" out0.w += dot(weights3,in0);\n"
" out1.x += dot(weights0,in1);\n"
" out1.y += dot(weights1,in1);\n"
" out1.z += dot(weights2,in1);\n"
" out1.w += dot(weights3,in1);\n"
" out2.x += dot(weights0,in2);\n"
" out2.y += dot(weights1,in2);\n"
" out2.z += dot(weights2,in2);\n"
" out2.w += dot(weights3,in2);\n"
" out3.x += dot(weights0,in3);\n"
" out3.y += dot(weights1,in3);\n"
" out3.z += dot(weights2,in3);\n"
" out3.w += dot(weights3,in3);\n"
" }\n"
// Optional fused activation: ReLU, or ReLU6 (clamp to [0, 6]).
"#ifdef RELU\n"
" out0=fmax(out0,(FLOAT4)0);\n"
" out1=fmax(out1,(FLOAT4)0);\n"
" out2=fmax(out2,(FLOAT4)0);\n"
" out3=fmax(out3,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n"
" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n"
" out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n"
" out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
// Output block b occupies columns [b*out_w, (b+1)*out_w). Only the columns
// that exist are written: remain < 4 for a trailing partial column block.
" const int out_x_base=out_c_idx*out_w;\n"
" const int remain=out_w-out_w4_idx;\n"
" int output_idx=out_x_base+out_w4_idx;\n"
" \n"
" if (remain >= 4) {\n"
" WI_F(output,(int2)(output_idx,out_b_h_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,out_b_h_idx),out1);\n"
" WI_F(output,(int2)(output_idx+2,out_b_h_idx),out2);\n"
" WI_F(output,(int2)(output_idx+3,out_b_h_idx),out3);\n"
" } else if (remain == 3) {\n"
" WI_F(output,(int2)(output_idx,out_b_h_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,out_b_h_idx),out1);\n"
" WI_F(output,(int2)(output_idx+2,out_b_h_idx),out2);\n"
" } else if (remain == 2) {\n"
" WI_F(output,(int2)(output_idx,out_b_h_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,out_b_h_idx),out1);\n"
" } else if (remain == 1) {\n"
" WI_F(output,(int2)(output_idx,out_b_h_idx),out0);\n"
" }\n"
"}\n"
// ---- conv_2d_1x1 ----
// 1x1 convolution with optional stride; input, output and bias are images.
// Each work-item computes one 4-channel output block for four consecutive
// output columns (out0..out3):
//   global id 0 = output_channel_block_idx * output_width_4 + output_width_block_idx
//   global id 1 = batch * output_shape.x + output row
// Shape int2 parameters are (x = height, y = width). Channel block b occupies
// image columns [b*width, (b+1)*width); batch n occupies rows [n*height, ...).
// The weight source is chosen at build time:
//   USE_LOW_BIT_WEIGHT_INT8 / _INT4: quantized buffer plus per-quant-block
//       (scale, offset) pairs, dequantized as q*scale + offset
//   USE_BUFFER: FLOAT buffer, 16 values per (output block, input block)
//   otherwise:  image, x = in_block*4 + k, y = output block
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void conv_2d_1x1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_BUFFER)\n"
" __global const FLOAT *weights,\n"
"#else\n"
" __read_only image2d_t weights,\n"
"#endif\n"
" __read_only image2d_t bias,\n"
" __write_only image2d_t output,\n"
" __private const int2 input_shape,\n"
" __private const int in_channel_block,__private const int2 output_shape,\n"
" __private const int2 stride_shape,\n"
" __private const int output_width_4,\n"
" __private const int out_channel_blocks\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" ,__private const int blockDim\n"
" ,__private const int inChannel\n"
"#endif\n"
") {\n"
" const int output_channel_width_idx=get_global_id(0);\n"
" const int output_batch_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(output_channel_width_idx,output_batch_height_idx);\n"
" const int output_channel_block_idx=output_channel_width_idx/output_width_4;\n"
" const int output_width_block_idx=output_channel_width_idx % output_width_4;\n"
// Byte offsets into the quantized weights: INT4 packs two values per byte,
// so its per-block strides (8) are half of the INT8 ones (16).
"#if (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int weight_ic_offset=output_channel_block_idx*8;\n"
" int weight_oc_offset=out_channel_blocks*8;\n"
"#else\n"
" int weight_ic_offset=output_channel_block_idx*16;\n"
" int weight_oc_offset=out_channel_blocks*16;\n"
"#endif\n"
" FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(output_channel_block_idx,0));\n"
" FLOAT4 out1=out0;\n"
" FLOAT4 out2=out0;\n"
" FLOAT4 out3=out0;\n"
// MNN_CONV_S1D1: unit stride, input columns used without a bounds check.
// Otherwise: strided columns; any column >= input width becomes INT_MIN,
// pushing the read outside the image so the CLAMP sampler returns the border
// color instead of data from the next channel block.
"#ifdef MNN_CONV_S1D1\n"
" int intput_width_idx0=output_width_block_idx << 2;\n"
" int intput_width_idx1=intput_width_idx0+1;\n"
" int intput_width_idx2=intput_width_idx0+2;\n"
" int intput_width_idx3=intput_width_idx0+3;\n"
"#else\n"
" int intput_width_idx0=mul24(output_width_block_idx,stride_shape.y*4);\n"
" int intput_width_idx1=intput_width_idx0+stride_shape.y;\n"
" int intput_width_idx2=intput_width_idx1+stride_shape.y;\n"
" int intput_width_idx3=intput_width_idx2+stride_shape.y;\n"
" intput_width_idx0=select(intput_width_idx0,INT_MIN,intput_width_idx0 >= input_shape.y);\n"
" intput_width_idx1=select(intput_width_idx1,INT_MIN,intput_width_idx1 >= input_shape.y);\n"
" intput_width_idx2=select(intput_width_idx2,INT_MIN,intput_width_idx2 >= input_shape.y);\n"
" intput_width_idx3=select(intput_width_idx3,INT_MIN,intput_width_idx3 >= input_shape.y);\n"
"#endif\n"
// Input row = (output row * vertical stride) within this batch's rows.
" int batch_index=output_batch_height_idx/output_shape.x;\n"
" int input_height_block_idx=mul24((output_batch_height_idx % output_shape.x),stride_shape.x)+batch_index*input_shape.x;\n"
" FLOAT4 in0;\n"
" FLOAT4 in1;\n"
" FLOAT4 in2;\n"
" FLOAT4 in3;\n"
" FLOAT4 weights0;\n"
" FLOAT4 weights1;\n"
" FLOAT4 weights2;\n"
" FLOAT4 weights3;\n"
" int weight_offset=output_channel_block_idx*in_channel_block*4*4;\n"
// Per input-channel block: weightsK holds the weights of input channel
// 4*in_channel_block_idx + K for the block's 4 output channels (one per lane);
// CALCULATE_OUTPUT multiplies in.x by weights0, in.y by weights1, and so on.
" for (int in_channel_block_idx=0; in_channel_block_idx<in_channel_block; ++in_channel_block_idx) {\n"
// Quant block index = (in_channel_block_idx*4)/blockDim. Each quant block
// holds out_channel_blocks*8 floats; vload8 at this output block yields
// interleaved (scale, offset) pairs for its 4 output channels.
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int kindex=(in_channel_block_idx*4)/blockDim*out_channel_blocks*8;\n"
" COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(output_channel_block_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n"
" COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n"
"#endif\n"
" int input_width_base=in_channel_block_idx*input_shape.y;\n"
" int weights_width_base=in_channel_block_idx << 2;\n"
" \n"
// INT8: 16 signed bytes per (input block, output block).
// NOTE(review): this branch declares new weights0..weights3 in the loop body,
// shadowing the function-scope ones; the INT4 branch assigns the outer ones.
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" FLOAT16 weights=CONVERT_FLOAT16(vload16(0,kernel_ptr+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n"
" FLOAT4 weights0=CONVERT_FLOAT4(weights.s0123)*scale0+offset0;\n"
" FLOAT4 weights1=CONVERT_FLOAT4(weights.s4567)*scale0+offset0;\n"
" FLOAT4 weights2=CONVERT_FLOAT4(weights.s89ab)*scale0+offset0;\n"
" FLOAT4 weights3=CONVERT_FLOAT4(weights.scdef)*scale0+offset0;\n"
// INT4: 8 bytes, two 4-bit values per byte (high nibble first), each made
// signed by subtracting 8 before dequantization.
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar8 charWeightsInt4=vload8(0,kernel_ptr+weight_ic_offset+in_channel_block_idx*weight_oc_offset);\n"
" char4 charWeights0=(char4)(0,0,0,0);\n"
" char4 charWeights1=(char4)(0,0,0,0);\n"
" char4 charWeights2=(char4)(0,0,0,0);\n"
" char4 charWeights3=(char4)(0,0,0,0);\n"
" charWeights0.x=(charWeightsInt4.s0 >> 4)-8;\n"
" charWeights0.y=(charWeightsInt4.s0 & MOD_NUM)-8;\n"
" charWeights0.z=(charWeightsInt4.s1 >> 4)-8;\n"
" charWeights0.w=(charWeightsInt4.s1 & MOD_NUM)-8;\n"
" charWeights1.x=(charWeightsInt4.s2 >> 4)-8;\n"
" charWeights1.y=(charWeightsInt4.s2 & MOD_NUM)-8;\n"
" charWeights1.z=(charWeightsInt4.s3 >> 4)-8;\n"
" charWeights1.w=(charWeightsInt4.s3 & MOD_NUM)- 8;\n"
" charWeights2.x=(charWeightsInt4.s4 >> 4)-8;\n"
" charWeights2.y=(charWeightsInt4.s4 & MOD_NUM)-8;\n"
" charWeights2.z=(charWeightsInt4.s5 >> 4)-8;\n"
" charWeights2.w=(charWeightsInt4.s5 & MOD_NUM)-8;\n"
" charWeights3.x=(charWeightsInt4.s6 >> 4)-8;\n"
" charWeights3.y=(charWeightsInt4.s6 & MOD_NUM)-8;\n"
" charWeights3.z=(charWeightsInt4.s7 >> 4)-8;\n"
" charWeights3.w=(charWeightsInt4.s7 & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeights0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeights1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeights2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeights3),scale0,offset0);\n"
"#elif (defined USE_BUFFER)\n"
" weights0=vload4(weights_width_base,weights+weight_offset);\n"
" weights1=vload4(weights_width_base+1,weights+weight_offset);\n"
" weights2=vload4(weights_width_base+2,weights+weight_offset);\n"
" weights3=vload4(weights_width_base+3,weights+weight_offset);\n"
"#else\n"
" weights0=RI_F(weights,SAMPLER,(int2)(weights_width_base+0,output_channel_block_idx));\n"
" weights1=RI_F(weights,SAMPLER,(int2)(weights_width_base+1,output_channel_block_idx));\n"
" weights2=RI_F(weights,SAMPLER,(int2)(weights_width_base+2,output_channel_block_idx));\n"
" weights3=RI_F(weights,SAMPLER,(int2)(weights_width_base+3,output_channel_block_idx));\n"
"#endif\n"
// Zero the weights of padded input channels in the last block. inChannel is
// only a kernel argument in the low-bit builds, so presumably the host only
// defines INPUT_CHANNEL_LEAVE for those -- confirm.
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" in0=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx0,input_height_block_idx));\n"
" in1=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx1,input_height_block_idx));\n"
" in2=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx2,input_height_block_idx));\n"
" in3=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx3,input_height_block_idx));\n"
" CALCULATE_OUTPUT(0);\n"
" CALCULATE_OUTPUT(1);\n"
" CALCULATE_OUTPUT(2);\n"
" CALCULATE_OUTPUT(3);\n"
" }\n"
// Optional fused activation: ReLU, or ReLU6 (clamp to [0, 6]).
"#ifdef RELU\n"
" out0=fmax(out0,(FLOAT4)0);\n"
" out1=fmax(out1,(FLOAT4)0);\n"
" out2=fmax(out2,(FLOAT4)0);\n"
" out3=fmax(out3,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n"
" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n"
" out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n"
" out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
// Store only the output columns that exist (remain < 4 on the last block).
" const int out_x_base=mul24(output_channel_block_idx,output_shape.y);\n"
" int out_x_idx=output_width_block_idx << 2;\n"
" const int remain=output_shape.y-out_x_idx;\n"
" int output_idx=out_x_base+out_x_idx;\n"
" if (remain >= 4) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n"
" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n"
" WI_F(output,(int2)(output_idx+3,output_batch_height_idx),out3);\n"
" } else if (remain == 3) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n"
" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n"
" } else if (remain == 2) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n"
" } else if (remain == 1) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" }\n"
"}\n"
// ---- conv_2d_1x1_c8h1w4 ----
// Variant of conv_2d_1x1 in which each work-item computes two adjacent
// 4-channel output blocks, output_channel_idx = 2*pair and output_channel_idx+1,
// for four output columns: out0..out3 for the first block, out4..out7 for the
// second.
//   global id 0 = pair index * output_width_4 + output_width_block_idx
//   global id 1 = batch * output_shape.x + output row
// out_channel_blocks counts 4-channel output blocks. When
// output_channel_idx + 1 >= out_channel_blocks (e.g. an odd block count), the
// second block does not exist: it is not stored, and with
// CHANNEL_BOUNDARY_PROTECT its INT8 / USE_BUFFER weights are not loaded.
// Weight-source selection and layouts are the same as in conv_2d_1x1.
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_BUFFER)\n"
" __global const FLOAT *weights,\n"
"#else\n"
" __read_only image2d_t weights,\n"
"#endif\n"
" __read_only image2d_t bias,\n"
" __write_only image2d_t output,\n"
" __private const int2 input_shape,\n"
" __private const int in_channel_block,__private const int2 output_shape,\n"
" __private const int2 stride_shape,\n"
" __private const int output_width_4,\n"
" __private const int out_channel_blocks\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" ,__private const int blockDim\n"
" ,__private const int inChannel\n"
"#endif\n"
") {\n"
" const int output_channel_width_idx=get_global_id(0);\n"
" const int output_batch_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(output_channel_width_idx,output_batch_height_idx);\n"
" const int output_channel_block_idx=output_channel_width_idx/output_width_4;\n"
" const int output_width_block_idx=output_channel_width_idx % output_width_4;\n"
" const int output_channel_idx=output_channel_block_idx << 1;\n"
// Byte offsets into the quantized weights, covering both blocks of the pair
// (INT8: 2 x 16 bytes; INT4: 2 x 8 bytes, two values per byte).
"#if (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int weight_ic_offset=output_channel_block_idx*16;\n"
" int weight_oc_offset=out_channel_blocks*8;\n"
"#else\n"
" int weight_ic_offset=output_channel_block_idx*32;\n"
" int weight_oc_offset=out_channel_blocks*16;\n"
"#endif\n"
" FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(output_channel_idx,0));\n"
" FLOAT4 out1=out0;\n"
" FLOAT4 out2=out0;\n"
" FLOAT4 out3=out0;\n"
" \n"
// For a missing second block this bias read falls outside the bias image and
// gets the CLAMP border color; those outputs are never stored anyway.
" FLOAT4 out4=RI_F(bias,SAMPLER,(int2)(output_channel_idx+1,0));\n"
" FLOAT4 out5=out4;\n"
" FLOAT4 out6=out4;\n"
" FLOAT4 out7=out4;\n"
// Input column computation: see conv_2d_1x1 (INT_MIN forces border reads).
"#ifdef MNN_CONV_S1D1\n"
" int intput_width_idx0=output_width_block_idx << 2;\n"
" int intput_width_idx1=intput_width_idx0+1;\n"
" int intput_width_idx2=intput_width_idx0+2;\n"
" int intput_width_idx3=intput_width_idx0+3;\n"
"#else\n"
" int intput_width_idx0=mul24(output_width_block_idx,stride_shape.y*4);\n"
" int intput_width_idx1=intput_width_idx0+stride_shape.y;\n"
" int intput_width_idx2=intput_width_idx1+stride_shape.y;\n"
" int intput_width_idx3=intput_width_idx2+stride_shape.y;\n"
" intput_width_idx0=select(intput_width_idx0,INT_MIN,intput_width_idx0 >= input_shape.y);\n"
" intput_width_idx1=select(intput_width_idx1,INT_MIN,intput_width_idx1 >= input_shape.y);\n"
" intput_width_idx2=select(intput_width_idx2,INT_MIN,intput_width_idx2 >= input_shape.y);\n"
" intput_width_idx3=select(intput_width_idx3,INT_MIN,intput_width_idx3 >= input_shape.y);\n"
"#endif\n"
" int batch_index=output_batch_height_idx/output_shape.x;\n"
" int input_height_block_idx=mul24((output_batch_height_idx % output_shape.x),stride_shape.x)+batch_index*input_shape.x;\n"
" FLOAT4 in0;\n"
" FLOAT4 in1;\n"
" FLOAT4 in2;\n"
" FLOAT4 in3;\n"
" FLOAT4 weights0;\n"
" FLOAT4 weights1;\n"
" FLOAT4 weights2;\n"
" FLOAT4 weights3;\n"
" FLOAT4 weights4;\n"
" FLOAT4 weights5;\n"
" FLOAT4 weights6;\n"
" FLOAT4 weights7;\n"
// USE_BUFFER base offsets of the first and second output block's weights.
" int weight_offset=output_channel_idx*in_channel_block*4*4;\n"
" int weight_offset1=weight_offset+in_channel_block*4*4;\n"
" for (int in_channel_block_idx=0; in_channel_block_idx<in_channel_block; ++in_channel_block_idx) {\n"
// (scale, offset) pairs for both output blocks of this quant block. The
// OpenCL comment below states the array is padded, so the second block's
// load needs no bounds check (NOTE(review): host-side guarantee, not visible
// here).
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int kindex=(in_channel_block_idx*4)/blockDim*out_channel_blocks*8;\n"
" // already pack to 16,no need boundry protect\n"
" COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(output_channel_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n"
" COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n"
" COMPUTE_FLOAT8 ScaleOffset1=CONVERT_COMPUTE_FLOAT8(vload8(output_channel_idx+1,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale1=(COMPUTE_FLOAT4)(ScaleOffset1.s0,ScaleOffset1.s2,ScaleOffset1.s4,ScaleOffset1.s6);\n"
" COMPUTE_FLOAT4 offset1=(COMPUTE_FLOAT4)(ScaleOffset1.s1,ScaleOffset1.s3,ScaleOffset1.s5,ScaleOffset1.s7);\n"
"#endif\n"
" \n"
" int input_width_base=in_channel_block_idx*input_shape.y;\n"
" int weights_width_base=in_channel_block_idx << 2;\n"
" in0=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx0,input_height_block_idx));\n"
" in1=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx1,input_height_block_idx));\n"
" in2=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx2,input_height_block_idx));\n"
" in3=RI_F(input,SAMPLER,(int2)(input_width_base+intput_width_idx3,input_height_block_idx));\n"
// INT8: 16 bytes for the first block, the next 16 for the second. As in
// conv_2d_1x1, this branch re-declares (shadows) weights0..weights7.
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" FLOAT16 weightsInt80=CONVERT_FLOAT16(vload16(0,kernel_ptr+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n"
" #ifdef CHANNEL_BOUNDARY_PROTECT\n"
" FLOAT16 weightsInt81=output_channel_idx+1 >= out_channel_blocks ? (FLOAT16)0 : CONVERT_FLOAT16(vload16(0,kernel_ptr+16+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n"
" #else\n"
" FLOAT16 weightsInt81=CONVERT_FLOAT16(vload16(0,kernel_ptr+16+weight_ic_offset+in_channel_block_idx*weight_oc_offset));\n"
" #endif\n"
" FLOAT4 weights0=CONVERT_FLOAT4(weightsInt80.s0123)*scale0+offset0;\n"
" FLOAT4 weights1=CONVERT_FLOAT4(weightsInt80.s4567)*scale0+offset0;\n"
" FLOAT4 weights2=CONVERT_FLOAT4(weightsInt80.s89ab)*scale0+offset0;\n"
" FLOAT4 weights3=CONVERT_FLOAT4(weightsInt80.scdef)*scale0+offset0;\n"
" FLOAT4 weights4=CONVERT_FLOAT4(weightsInt81.s0123)*scale1+offset1;\n"
" FLOAT4 weights5=CONVERT_FLOAT4(weightsInt81.s4567)*scale1+offset1;\n"
" FLOAT4 weights6=CONVERT_FLOAT4(weightsInt81.s89ab)*scale1+offset1;\n"
" FLOAT4 weights7=CONVERT_FLOAT4(weightsInt81.scdef)*scale1+offset1;\n"
// INT4: 16 bytes = 32 values (high nibble first, minus 8); bytes s0..s7 feed
// the first block (charWeights0..3), bytes s8..sf the second (charWeights4..7).
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar16 charWeightsInt4=vload16(0,kernel_ptr+weight_ic_offset+in_channel_block_idx*weight_oc_offset);\n"
" char4 charWeights0=(char4)(0,0,0,0);\n"
" char4 charWeights1=(char4)(0,0,0,0);\n"
" char4 charWeights2=(char4)(0,0,0,0);\n"
" char4 charWeights3=(char4)(0,0,0,0);\n"
" char4 charWeights4=(char4)(0,0,0,0);\n"
" char4 charWeights5=(char4)(0,0,0,0);\n"
" char4 charWeights6=(char4)(0,0,0,0);\n"
" char4 charWeights7=(char4)(0,0,0,0);\n"
" charWeights0.x=(charWeightsInt4.s0 >> 4)-8;\n"
" charWeights0.y=(charWeightsInt4.s0 & MOD_NUM)-8;\n"
" charWeights0.z=(charWeightsInt4.s1 >> 4)-8;\n"
" charWeights0.w=(charWeightsInt4.s1 & MOD_NUM)-8;\n"
" charWeights1.x=(charWeightsInt4.s2 >> 4)-8;\n"
" charWeights1.y=(charWeightsInt4.s2 & MOD_NUM)-8;\n"
" charWeights1.z=(charWeightsInt4.s3 >> 4)-8;\n"
" charWeights1.w=(charWeightsInt4.s3 & MOD_NUM)-8;\n"
" charWeights2.x=(charWeightsInt4.s4 >> 4)-8;\n"
" charWeights2.y=(charWeightsInt4.s4 & MOD_NUM)-8;\n"
" charWeights2.z=(charWeightsInt4.s5 >> 4)-8;\n"
" charWeights2.w=(charWeightsInt4.s5 & MOD_NUM)-8;\n"
" charWeights3.x=(charWeightsInt4.s6 >> 4)-8;\n"
" charWeights3.y=(charWeightsInt4.s6 & MOD_NUM)-8;\n"
" charWeights3.z=(charWeightsInt4.s7 >> 4)-8;\n"
" charWeights3.w=(charWeightsInt4.s7 & MOD_NUM)-8;\n"
" charWeights4.x=(charWeightsInt4.s8 >> 4)-8;\n"
" charWeights4.y=(charWeightsInt4.s8 & MOD_NUM)-8;\n"
" charWeights4.z=(charWeightsInt4.s9 >> 4)-8;\n"
" charWeights4.w=(charWeightsInt4.s9 & MOD_NUM)-8;\n"
" charWeights5.x=(charWeightsInt4.sa >> 4)-8;\n"
" charWeights5.y=(charWeightsInt4.sa & MOD_NUM)-8;\n"
" charWeights5.z=(charWeightsInt4.sb >> 4)-8;\n"
" charWeights5.w=(charWeightsInt4.sb & MOD_NUM)-8;\n"
" charWeights6.x=(charWeightsInt4.sc >> 4)-8;\n"
" charWeights6.y=(charWeightsInt4.sc & MOD_NUM)-8;\n"
" charWeights6.z=(charWeightsInt4.sd >> 4)-8;\n"
" charWeights6.w=(charWeightsInt4.sd & MOD_NUM)-8;\n"
" charWeights7.x=(charWeightsInt4.se >> 4)-8;\n"
" charWeights7.y=(charWeightsInt4.se & MOD_NUM)-8;\n"
" charWeights7.z=(charWeightsInt4.sf >> 4)-8;\n"
" charWeights7.w=(charWeightsInt4.sf & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeights0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeights1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeights2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeights3),scale0,offset0);\n"
" weights4=mad(CONVERT_FLOAT4(charWeights4),scale1,offset1);\n"
" weights5=mad(CONVERT_FLOAT4(charWeights5),scale1,offset1);\n"
" weights6=mad(CONVERT_FLOAT4(charWeights6),scale1,offset1);\n"
" weights7=mad(CONVERT_FLOAT4(charWeights7),scale1,offset1);\n"
"#elif (defined USE_BUFFER)\n"
" weights0=vload4(weights_width_base,weights+weight_offset);\n"
" weights1=vload4(weights_width_base+1,weights+weight_offset);\n"
" weights2=vload4(weights_width_base+2,weights+weight_offset);\n"
" weights3=vload4(weights_width_base+3,weights+weight_offset);\n"
" #ifdef CHANNEL_BOUNDARY_PROTECT\n"
" weights4=output_channel_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(weights_width_base,weights+weight_offset1);\n"
" weights5=output_channel_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(weights_width_base+1,weights+weight_offset1);\n"
" weights6=output_channel_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(weights_width_base+2,weights+weight_offset1);\n"
" weights7=output_channel_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(weights_width_base+3,weights+weight_offset1);\n"
" #else\n"
" weights4=vload4(weights_width_base,weights+weight_offset1);\n"
" weights5=vload4(weights_width_base+1,weights+weight_offset1);\n"
" weights6=vload4(weights_width_base+2,weights+weight_offset1);\n"
" weights7=vload4(weights_width_base+3,weights+weight_offset1);\n"
" #endif\n"
// Image weights: rows output_channel_idx and output_channel_idx+1; a missing
// second row reads the CLAMP border color.
"#else\n"
" weights0=RI_F(weights,SAMPLER,(int2)(weights_width_base+0,output_channel_idx));\n"
" weights1=RI_F(weights,SAMPLER,(int2)(weights_width_base+1,output_channel_idx));\n"
" weights2=RI_F(weights,SAMPLER,(int2)(weights_width_base+2,output_channel_idx));\n"
" weights3=RI_F(weights,SAMPLER,(int2)(weights_width_base+3,output_channel_idx));\n"
" \n"
" weights4=RI_F(weights,SAMPLER,(int2)(weights_width_base+0,output_channel_idx+1));\n"
" weights5=RI_F(weights,SAMPLER,(int2)(weights_width_base+1,output_channel_idx+1));\n"
" weights6=RI_F(weights,SAMPLER,(int2)(weights_width_base+2,output_channel_idx+1));\n"
" weights7=RI_F(weights,SAMPLER,(int2)(weights_width_base+3,output_channel_idx+1));\n"
"#endif\n"
// Zero padded input-channel weights for both blocks (INPUT_CHANNEL_LEAVE).
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights4,weights5,weights6,weights7);\n"
" CALCULATE_OUTPUT(0);\n"
" CALCULATE_OUTPUT(1);\n"
" CALCULATE_OUTPUT(2);\n"
" CALCULATE_OUTPUT(3);\n"
" \n"
// Second block: input column j (in0..in3) with weights4..7 into out4..out7.
" CALCULATE_OUTPUT_WEIGHTS4(4,0);\n"
" CALCULATE_OUTPUT_WEIGHTS4(5,1);\n"
" CALCULATE_OUTPUT_WEIGHTS4(6,2);\n"
" CALCULATE_OUTPUT_WEIGHTS4(7,3);\n"
" }\n"
// Optional fused activation: ReLU, or ReLU6 (clamp to [0, 6]).
"#ifdef RELU\n"
" out0=fmax(out0,(FLOAT4)0);\n"
" out1=fmax(out1,(FLOAT4)0);\n"
" out2=fmax(out2,(FLOAT4)0);\n"
" out3=fmax(out3,(FLOAT4)0);\n"
" out4=fmax(out4,(FLOAT4)0);\n"
" out5=fmax(out5,(FLOAT4)0);\n"
" out6=fmax(out6,(FLOAT4)0);\n"
" out7=fmax(out7,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n"
" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n"
" out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n"
" out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n"
" out4=clamp(out4,(FLOAT4)0,(FLOAT4)6);\n"
" out5=clamp(out5,(FLOAT4)0,(FLOAT4)6);\n"
" out6=clamp(out6,(FLOAT4)0,(FLOAT4)6);\n"
" out7=clamp(out7,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
// Store the first block's existing columns (remain < 4 on the last block).
" const int out_x_base=mul24(output_channel_idx,output_shape.y);\n"
" int out_x_idx=output_width_block_idx << 2;\n"
" const int remain=output_shape.y-out_x_idx;\n"
" int output_idx=out_x_base+out_x_idx;\n"
" if (remain >= 4) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n"
" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n"
" WI_F(output,(int2)(output_idx+3,output_batch_height_idx),out3);\n"
" } else if (remain == 3) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n"
" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n"
" } else if (remain == 2) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n"
" } else if (remain == 1) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" }\n"
" \n"
// The second block exists only if output_channel_idx+1 < out_channel_blocks;
// its columns start output_shape.y further along the image row.
" if(output_channel_idx+1 >= out_channel_blocks)\n"
" return;\n"
" output_idx += output_shape.y;\n"
" if (remain >= 4) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out4);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out5);\n"
" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out6);\n"
" WI_F(output,(int2)(output_idx+3,output_batch_height_idx),out7);\n"
" } else if (remain == 3) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out4);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out5);\n"
" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out6);\n"
" } else if (remain == 2) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out4);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out5);\n"
" } else if (remain == 1) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out4);\n"
" }\n"
"}\n"
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void conv_2d_c4h1w4(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_BUFFER)\n"
" __global const FLOAT *weights,\n"
"#else\n"
" __read_only image2d_t weights,\n"
"#endif\n"
"#ifdef BIAS\n"
" __read_only image2d_t bias,\n"
"#endif\n"
" __write_only image2d_t output,\n"
" __private const int2 input_shape,\n"
" __private const int in_channel_block_length,\n"
" __private const int2 output_shape,\n"
" __private const int2 weights_shape,\n"
" __private const int2 stride_shape,\n"
" __private const int2 padding_shape,\n"
" __private const int2 dilation_shape,\n"
" __private const int out_width_blocks,\n"
" __private const int out_channel_blocks,\n"
" __private const int out_height_blocks\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" ,__private const int blockDim\n"
" ,__private const int inChannel\n"
"#endif\n"
") {\n"
" const int output_channel_width_idx=get_global_id(0);\n"
" const int output_batch_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(output_channel_width_idx,output_batch_height_idx);\n"
" const int out_channel_block_idx=output_channel_width_idx/out_width_blocks;\n"
" const int out_height_block_idx=output_channel_width_idx % out_width_blocks;\n"
"#ifdef BIAS\n"
" FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(out_channel_block_idx,0));\n"
"#else\n"
" FLOAT4 out0=(FLOAT4)0;\n"
"#endif\n"
" FLOAT4 out1=out0;\n"
" FLOAT4 out2=out0;\n"
" FLOAT4 out3=out0;\n"
" int in_width0=mad24(out_height_block_idx,stride_shape.y<<2,-padding_shape.y);\n"
" int in_width1=in_width0+stride_shape.y;\n"
" int in_width2=in_width0+stride_shape.y*2;\n"
" int in_width3=in_width0+stride_shape.y*3;\n"
" \n"
"#ifdef MNN_CONV_S1D1\n"
" const int height_start=mad24((output_batch_height_idx % output_shape.x),1,-padding_shape.x);\n"
" const int kh_start=select(0,(-height_start),height_start<0);\n"
" int in_height_start=kh_start+height_start;\n"
" int in_height_end=min(weights_shape.x+height_start,input_shape.x);\n"
" const int batch_idx=mul24((output_batch_height_idx/output_shape.x),input_shape.x);\n"
" const int weights_h_idx=mul24(out_channel_block_idx,mul24(weights_shape.y,weights_shape.x))+mul24(select(0,(-height_start),height_start<0),weights_shape.y);\n"
"#else\n"
" const int height_start=mad24((output_batch_height_idx % output_shape.x),stride_shape.x,-padding_shape.x);\n"
" const int kh_start=select(0,(-height_start+dilation_shape.x-1)/dilation_shape.x,height_start<0);\n"
" int in_height_start=mad24(kh_start,dilation_shape.x,height_start);\n"
" int in_height_end=min(mad24(weights_shape.x,dilation_shape.x,height_start),input_shape.x);\n"
" const int batch_idx=mul24((output_batch_height_idx/output_shape.x),input_shape.x);\n"
" const int weights_h_idx=mul24(out_channel_block_idx,mul24(weights_shape.y,weights_shape.x))+mul24(select(0,(-height_start+dilation_shape.x-1)/dilation_shape.x,height_start<0),weights_shape.y);\n"
"#endif\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n"
" const int weight_oc_offset=out_channel_blocks*weights_shape.x*weights_shape.y*4;\n"
"#endif\n"
" FLOAT4 in0,in1,in2,in3;\n"
" FLOAT4 weights0,weights1,weights2,weights3;\n"
" for (int in_channel_block_idx=0; in_channel_block_idx<in_channel_block_length; ++in_channel_block_idx) {\n"
" \n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int kindex=(in_channel_block_idx*4)/blockDim*out_channel_blocks*8;\n"
" COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n"
" COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n"
"#endif\n"
" \n"
" const int in_idx=mul24(in_channel_block_idx,input_shape.y);\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n"
" int weight_offset=((((4*in_channel_block_idx+0)* out_channel_blocks+out_channel_block_idx) *weights_shape.x+kh_start)*weights_shape.y+0)*4;\n"
"#else\n"
" int weights_x_idx=in_channel_block_idx << 2;\n"
" int weights_y_idx=weights_h_idx;\n"
"#endif\n"
" for (int iy=in_height_start; iy<in_height_end; iy += dilation_shape.x) {\n"
" int in_hb_value=iy+batch_idx;\n"
"#ifdef MNN_CONV_S1D1\n"
" {\n"
" READ_INPUT_IMAGE(0,0);\n"
" READ_INPUT_IMAGE(1,0);\n"
" READ_INPUT_IMAGE(2,0);\n"
" READ_INPUT_IMAGE(3,0);\n"
" \n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n"
" char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n"
" char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*2);\n"
" char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*3);\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_BUFFER)\n"
" weights0=vload4(0,weights+weight_offset);\n"
" weights1=vload4(0,weights+weight_offset+weight_oc_offset);\n"
" weights2=vload4(0,weights+weight_offset+weight_oc_offset*2);\n"
" weights3=vload4(0,weights+weight_offset+weight_oc_offset*3);\n"
" weight_offset += 4;\n"
"#else\n"
" weights0=RI_F(weights,SAMPLER,(int2)(weights_x_idx+0,weights_y_idx));\n"
" weights1=RI_F(weights,SAMPLER,(int2)(weights_x_idx+1,weights_y_idx));\n"
" weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx));\n"
" weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx++));\n"
"#endif\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" CALCULATE_OUTPUT(0);\n"
" CALCULATE_OUTPUT(1);\n"
" CALCULATE_OUTPUT(2);\n"
" CALCULATE_OUTPUT(3);\n"
" }\n"
" for (int w=1; w<weights_shape.y; w++){\n"
" in0=in1;\n"
" in1=in2;\n"
" in2=in3;\n"
" READ_INPUT_IMAGE(3,w);\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n"
" char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n"
" char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*2);\n"
" char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*3);\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_BUFFER)\n"
" weights0=vload4(0,weights+weight_offset);\n"
" weights1=vload4(0,weights+weight_offset+weight_oc_offset);\n"
" weights2=vload4(0,weights+weight_offset+weight_oc_offset*2);\n"
" weights3=vload4(0,weights+weight_offset+weight_oc_offset*3);\n"
" weight_offset += 4;\n"
"#else\n"
" weights0=RI_F(weights,SAMPLER,(int2)(weights_x_idx+0,weights_y_idx));\n"
" weights1=RI_F(weights,SAMPLER,(int2)(weights_x_idx+1,weights_y_idx));\n"
" weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx));\n"
" weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx++));\n"
"#endif\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" CALCULATE_OUTPUT(0);\n"
" CALCULATE_OUTPUT(1);\n"
" CALCULATE_OUTPUT(2);\n"
" CALCULATE_OUTPUT(3);\n"
" }\n"
"#else\n"
" for (int w=0; w<weights_shape.y; w++) {\n"
" int input_width_base=mul24(w,dilation_shape.y);\n"
" READ_INPUT_IMAGE(0,input_width_base);\n"
" READ_INPUT_IMAGE(1,input_width_base);\n"
" READ_INPUT_IMAGE(2,input_width_base);\n"
" READ_INPUT_IMAGE(3,input_width_base);\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n"
" char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n"
" char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*2);\n"
" char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*3);\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_BUFFER)\n"
" weights0=vload4(0,weights+weight_offset);\n"
" weights1=vload4(0,weights+weight_offset+weight_oc_offset);\n"
" weights2=vload4(0,weights+weight_offset+weight_oc_offset*2);\n"
" weights3=vload4(0,weights+weight_offset+weight_oc_offset*3);\n"
" weight_offset += 4;\n"
"#else\n"
" weights0=RI_F(weights,SAMPLER,(int2)(weights_x_idx+0,weights_y_idx)); \n"
" weights1=RI_F(weights,SAMPLER,(int2)(weights_x_idx+1,weights_y_idx)); \n"
" weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx)); \n"
" weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx++));\n"
"#endif\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" CALCULATE_OUTPUT(0);\n"
" CALCULATE_OUTPUT(1);\n"
" CALCULATE_OUTPUT(2);\n"
" CALCULATE_OUTPUT(3);\n"
" }\n"
"#endif\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(FLOAT4)0);\n"
" out1=fmax(out1,(FLOAT4)0);\n"
" out2=fmax(out2,(FLOAT4)0);\n"
" out3=fmax(out3,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n"
" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n"
" out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n"
" out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
" const int out_x_base=mul24(out_channel_block_idx,output_shape.y);\n"
" int out_x_idx=out_height_block_idx << 2;\n"
" const int remain=output_shape.y-out_x_idx;\n"
" int output_idx=out_x_base+out_x_idx;\n"
" if (remain >= 4) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n"
" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n"
" WI_F(output,(int2)(output_idx+3,output_batch_height_idx),out3);\n"
" } else if (remain == 3) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n"
" WI_F(output,(int2)(output_idx+2,output_batch_height_idx),out2);\n"
" } else if (remain == 2) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" WI_F(output,(int2)(output_idx+1,output_batch_height_idx),out1);\n"
" } else if (remain == 1) {\n"
" WI_F(output,(int2)(output_idx,output_batch_height_idx),out0);\n"
" }\n"
"}\n"
// ---------------------------------------------------------------------------
// OpenCL kernel conv_2d_c8h4w1 (the lines below are kernel source text; these
// C++ comments sit between the concatenated literals and are not part of it).
// Each work item computes a tile of 2 output channel blocks x 4 output rows x
// 1 output column:
//   - out_channel_block_idx = (gid0 / out_width_blocks) << 1 is the first of
//     two consecutive channel blocks; out0..out3 accumulate that block and
//     out4..out7 accumulate block + 1.
//   - out_height_block_idx << 2 is the first of the 4 output rows.
// The weight source is selected at build time: USE_LOW_BIT_WEIGHT_INT8 /
// USE_LOW_BIT_WEIGHT_INT4 (quantized bytes dequantized with scale/offset from
// dequantScaleOffset), USE_BUFFER (FLOAT buffer), or otherwise an image.
// ---------------------------------------------------------------------------
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void conv_2d_c8h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_BUFFER)\n"
" __global const FLOAT *weights,\n"
"#else\n"
" __read_only image2d_t weights,\n"
"#endif\n"
"#ifdef BIAS\n"
" __read_only image2d_t bias,\n"
"#endif\n"
" __write_only image2d_t output,\n"
" __private const int2 input_shape,\n"
" __private const int in_channel_block_length,\n"
" __private const int2 output_shape,\n"
" __private const int2 weights_shape,\n"
" __private const int2 stride_shape,\n"
" __private const int2 padding_shape,\n"
" __private const int2 dilation_shape,\n"
" __private const int out_width_blocks,\n"
" __private const int out_channel_blocks,\n"
" __private const int out_height_blocks\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" ,__private const int blockDim\n"
" ,__private const int inChannel\n"
"#endif\n"
") {\n"
// Decompose the 2-D global id: dim 0 = (channel-block pair, output column),
// dim 1 = (batch, group of 4 output rows).
" const int output_channel_width_idx=get_global_id(0);\n"
" const int output_batch_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(output_channel_width_idx,output_batch_height_idx);\n"
" const int out_channel_block_idx=(output_channel_width_idx/out_width_blocks) << 1;\n"
" const int out_width_block_idx=output_channel_width_idx % out_width_blocks;\n"
" const int out_height_block_idx=(output_batch_height_idx % out_height_blocks);\n"
" const int out_batch_block_idx=output_batch_height_idx/out_height_blocks;\n"
// Accumulators start from the bias of each channel block (or zero).
"#ifdef BIAS\n"
" FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(out_channel_block_idx,0));\n"
" FLOAT4 out4=RI_F(bias,SAMPLER,(int2)(out_channel_block_idx+1,0));\n"
"#else\n"
" FLOAT4 out0=(FLOAT4)0;\n"
" FLOAT4 out4=(FLOAT4)0;\n"
"#endif\n"
" FLOAT4 out1=out0;\n"
" FLOAT4 out2=out0;\n"
" FLOAT4 out3=out0;\n"
" FLOAT4 out5=out4;\n"
" FLOAT4 out6=out4;\n"
" FLOAT4 out7=out4;\n"
// Buffer/quantized weights: weight_oc_offset steps one output channel block
// (kh*kw*4 elements); weight_ic_offset steps one input channel.
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n"
" const int weight_oc_offset=weights_shape.x*weights_shape.y*4;\n"
" const int weight_ic_offset=out_channel_blocks*weight_oc_offset;\n"
"#endif\n"
// Top-left input coordinate of the tile; the 4 output rows map to input rows
// spaced stride_shape.x apart.
" int in_width0=mad24(out_width_block_idx,stride_shape.y,-padding_shape.y);\n"
" int in_height0=mad24(out_height_block_idx,stride_shape.x<<2,-padding_shape.x);\n"
" int in_height1=in_height0+stride_shape.x;\n"
" int in_height2=in_height1+stride_shape.x;\n"
" int in_height3=in_height2+stride_shape.x;\n"
" int weight_size=mul24(weights_shape.y,weights_shape.x);\n"
" \n"
" const int weights_h_idx=mul24(out_channel_block_idx,weight_size);\n"
" const int batch_idx=mul24(out_batch_block_idx,input_shape.x);\n"
" \n"
" FLOAT4 in0,in1,in2,in3;\n"
" FLOAT4 weights0,weights1,weights2,weights3,weights4,weights5,weights6,weights7;\n"
" for (int in_channel_block_idx=0; in_channel_block_idx<in_channel_block_length; ++in_channel_block_idx) {\n"
// Quantized weights: (in_channel_block_idx*4)/blockDim picks the quantization
// block along input channels. Each block row of dequantScaleOffset holds
// out_channel_blocks*8 floats, i.e. 4 interleaved (scale, offset) pairs per
// channel block: even lanes are scales, odd lanes are offsets.
// NOTE(review): ScaleOffset1 is loaded for block +1 with no bounds check; for
// the last channel block this reads past the current row (past the buffer end
// for the final quantization block). Confirm the buffer is padded.
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int kindex=(in_channel_block_idx*4)/blockDim*out_channel_blocks*8;\n"
" COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n"
" COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n"
" COMPUTE_FLOAT8 ScaleOffset1=CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx+1,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale1=(COMPUTE_FLOAT4)(ScaleOffset1.s0,ScaleOffset1.s2,ScaleOffset1.s4,ScaleOffset1.s6);\n"
" COMPUTE_FLOAT4 offset1=(COMPUTE_FLOAT4)(ScaleOffset1.s1,ScaleOffset1.s3,ScaleOffset1.s5,ScaleOffset1.s7);\n"
" \n"
"#endif\n"
// Input image columns for this input channel block start at
// in_channel_block_idx * input_shape.y.
" const int in_idx=mul24(in_channel_block_idx,input_shape.y);\n"
// Buffer/quantized weights: the offset formula indexes a layout of
// [input channel][out channel block][kh][kw][4]. Image weights: 4 columns per
// input channel block and one row per kernel tap; the second output block's
// rows are weight_size further down.
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n"
" int weight_offset=((((4*in_channel_block_idx+0)* out_channel_blocks+out_channel_block_idx) *weights_shape.x+0)*weights_shape.y+0)*4;\n"
"#else\n"
" int weights_x_idx=in_channel_block_idx << 2;\n"
" int weights_y_idx=weights_h_idx;\n"
"#endif\n"
// Loop over kernel taps (dilated). Out-of-range input rows/columns are mapped
// to coordinate -1 via select(). With SAMPLER's CLK_ADDRESS_CLAMP mode such
// reads return the border colour, assumed zero here (confirm for the image
// channel order in use). Row coordinates include the batch offset batch_idx.
" for (int iy=0; iy<weights_shape.x*dilation_shape.x; iy += dilation_shape.x) {\n"
" int h0=select(in_height0+iy+batch_idx,-1,(in_height0+iy<0 || in_height0+iy >= input_shape.x));\n"
" int h1=select(in_height1+iy+batch_idx,-1,(in_height1+iy<0 || in_height1+iy >= input_shape.x));\n"
" int h2=select(in_height2+iy+batch_idx,-1,(in_height2+iy<0 || in_height2+iy >= input_shape.x));\n"
" int h3=select(in_height3+iy+batch_idx,-1,(in_height3+iy<0 || in_height3+iy >= input_shape.x));\n"
" for (int ix=0; ix<weights_shape.y*dilation_shape.y; ix += dilation_shape.y) {\n"
" int w0=select(in_width0+ix+in_idx,-1,(in_width0+ix<0 || in_width0+ix >= input_shape.y));\n"
" \n"
" in0=RI_F(input,SAMPLER,(int2)(w0,h0));\n"
" in1=RI_F(input,SAMPLER,(int2)(w0,h1));\n"
" in2=RI_F(input,SAMPLER,(int2)(w0,h2));\n"
" in3=RI_F(input,SAMPLER,(int2)(w0,h3));\n"
// INT8: weights0-3 come from block out_channel_block_idx and weights4-7 from
// block +1, one vector per input channel of the block. Under
// CHANNEL_BOUNDARY_PROTECT the raw block +1 bytes are replaced by 0 when that
// block does not exist. After dequantization, weights4-7 then equal offset1,
// not zero; those accumulators are never stored in that case.
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n"
" char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_ic_offset);\n"
" char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_ic_offset*2);\n"
" char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_ic_offset*3);\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" #ifdef CHANNEL_BOUNDARY_PROTECT\n"
" charWeight0=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n"
" charWeight1=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset);\n"
" charWeight2=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2);\n"
" charWeight3=out_channel_block_idx+1 >= out_channel_blocks ? (char4)0 : vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3);\n"
" \n"
" #else\n"
" charWeight0=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n"
" charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset);\n"
" charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*2);\n"
" charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset+weight_ic_offset*3);\n"
" #endif\n"
" weights4=mad(CONVERT_FLOAT4(charWeight0),scale1,offset1);\n"
" weights5=mad(CONVERT_FLOAT4(charWeight1),scale1,offset1);\n"
" weights6=mad(CONVERT_FLOAT4(charWeight2),scale1,offset1);\n"
" weights7=mad(CONVERT_FLOAT4(charWeight3),scale1,offset1);\n"
" weight_offset += 4;\n"
// INT4: two 4-bit weights per byte, so element offsets are halved. The high
// nibble is unpacked first, and each nibble is biased by -8 (MOD_NUM == 15
// masks the low nibble).
// NOTE(review): unlike the INT8 and USE_BUFFER paths, the block +1 loads below
// have no CHANNEL_BOUNDARY_PROTECT guard. Confirm the INT4 weight buffer is
// padded when out_channel_blocks is odd.
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_ic_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_ic_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_ic_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)- 8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" charWeightInt40=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n"
" charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset/2);\n"
" charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset*2/2);\n"
" charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2+weight_ic_offset*3/2);\n"
" charWeight0=(char4)(0,0,0,0);\n"
" charWeight1=(char4)(0,0,0,0);\n"
" charWeight2=(char4)(0,0,0,0);\n"
" charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)- 8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)- 8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" weights4=mad(CONVERT_FLOAT4(charWeight0),scale1,offset1);\n"
" weights5=mad(CONVERT_FLOAT4(charWeight1),scale1,offset1);\n"
" weights6=mad(CONVERT_FLOAT4(charWeight2),scale1,offset1);\n"
" weights7=mad(CONVERT_FLOAT4(charWeight3),scale1,offset1);\n"
" weight_offset += 4;\n"
// USE_BUFFER: float weights read directly. Under CHANNEL_BOUNDARY_PROTECT,
// weights4-7 are set to 0 when block +1 does not exist.
"#elif (defined USE_BUFFER)\n"
" weights0=vload4(0,weights+weight_offset);\n"
" weights1=vload4(0,weights+weight_offset+weight_ic_offset);\n"
" weights2=vload4(0,weights+weight_offset+weight_ic_offset*2);\n"
" weights3=vload4(0,weights+weight_offset+weight_ic_offset*3);\n"
" #ifdef CHANNEL_BOUNDARY_PROTECT\n"
" weights4=out_channel_block_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(0,weights+weight_offset+weight_oc_offset);\n"
" weights5=out_channel_block_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(0,weights+weight_offset+weight_ic_offset+weight_oc_offset);\n"
" weights6=out_channel_block_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(0,weights+weight_offset+weight_ic_offset*2+weight_oc_offset);\n"
" weights7=out_channel_block_idx+1 >= out_channel_blocks ? (FLOAT4)0 : vload4(0,weights+weight_offset+weight_ic_offset*3+weight_oc_offset);\n"
" #else\n"
" weights4=vload4(0,weights+weight_offset+weight_oc_offset);\n"
" weights5=vload4(0,weights+weight_offset+weight_ic_offset+weight_oc_offset);\n"
" weights6=vload4(0,weights+weight_offset+weight_ic_offset*2+weight_oc_offset);\n"
" weights7=vload4(0,weights+weight_offset+weight_ic_offset*3+weight_oc_offset);\n"
" #endif\n"
" weight_offset += 4;\n"
// Image weights: column = input channel, row = kernel tap. The block +1
// weights are weight_size rows below, and weights_y_idx advances once per tap.
"#else\n"
" weights0=RI_F(weights,SAMPLER,(int2)(weights_x_idx+0,weights_y_idx));\n"
" weights1=RI_F(weights,SAMPLER,(int2)(weights_x_idx+1,weights_y_idx));\n"
" weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx));\n"
" weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx));\n"
" weights4=RI_F(weights,SAMPLER,(int2)(weights_x_idx+0,weight_size+weights_y_idx));\n"
" weights5=RI_F(weights,SAMPLER,(int2)(weights_x_idx+1,weight_size+weights_y_idx));\n"
" weights6=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weight_size+weights_y_idx));\n"
" weights7=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weight_size+weights_y_idx++));\n"
"#endif\n"
// PADZEROSVEC (defined when INPUT_CHANNEL_LEAVE is set) zeroes the weight
// vector of every input channel index >= inChannel in this block. Its
// definition when INPUT_CHANNEL_LEAVE is unset is not visible here.
// NOTE(review): inChannel is only a kernel parameter in the low-bit builds.
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights4,weights5,weights6,weights7);\n"
" \n"
// out0..3 += rows in0..3 x weights0..3 (first block);
// out4..7 += the same rows x weights4..7 (second block).
" CALCULATE_OUTPUT(0);\n"
" CALCULATE_OUTPUT(1);\n"
" CALCULATE_OUTPUT(2);\n"
" CALCULATE_OUTPUT(3);\n"
" CALCULATE_OUTPUT_WEIGHTS4(4,0);\n"
" CALCULATE_OUTPUT_WEIGHTS4(5,1);\n"
" CALCULATE_OUTPUT_WEIGHTS4(6,2);\n"
" CALCULATE_OUTPUT_WEIGHTS4(7,3);\n"
" }\n"
" }\n"
" }\n"
// Optional fused activation.
"#ifdef RELU\n"
" out0=fmax(out0,(FLOAT4)0);\n"
" out1=fmax(out1,(FLOAT4)0);\n"
" out2=fmax(out2,(FLOAT4)0);\n"
" out3=fmax(out3,(FLOAT4)0);\n"
" out4=fmax(out4,(FLOAT4)0);\n"
" out5=fmax(out5,(FLOAT4)0);\n"
" out6=fmax(out6,(FLOAT4)0);\n"
" out7=fmax(out7,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n"
" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n"
" out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n"
" out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n"
" out4=clamp(out4,(FLOAT4)0,(FLOAT4)6);\n"
" out5=clamp(out5,(FLOAT4)0,(FLOAT4)6);\n"
" out6=clamp(out6,(FLOAT4)0,(FLOAT4)6);\n"
" out7=clamp(out7,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
// Store. Output column = channel_block * output_shape.y + out_width_block_idx;
// output row = batch * output_shape.x + 4 consecutive rows. Rows past
// output_shape.x are skipped via remain_y. The second channel block is stored
// output_shape.y columns further right, unless it is >= out_channel_blocks.
" const int out_x_base=mul24(out_channel_block_idx,output_shape.y);\n"
" const int out_y_base=mul24(out_batch_block_idx,output_shape.x);\n"
" int out_x_idx=out_width_block_idx;\n"
" int out_y_idx=out_height_block_idx << 2;\n"
" const int remain_y=output_shape.x-out_y_idx;\n"
" int output_idx=out_x_base+out_x_idx;\n"
" int output_idy=out_y_base+out_y_idx;\n"
" \n"
" if(remain_y >= 4){\n"
" WI_F(output,(int2)(output_idx,output_idy),out0);\n"
" WI_F(output,(int2)(output_idx,output_idy+1),out1);\n"
" WI_F(output,(int2)(output_idx,output_idy+2),out2);\n"
" WI_F(output,(int2)(output_idx,output_idy+3),out3);\n"
" }else if(remain_y == 3){\n"
" WI_F(output,(int2)(output_idx,output_idy),out0);\n"
" WI_F(output,(int2)(output_idx,output_idy+1),out1);\n"
" WI_F(output,(int2)(output_idx,output_idy+2),out2);\n"
" }else if(remain_y == 2){\n"
" WI_F(output,(int2)(output_idx,output_idy),out0);\n"
" WI_F(output,(int2)(output_idx,output_idy+1),out1);\n"
" }else if(remain_y == 1){\n"
" WI_F(output,(int2)(output_idx,output_idy),out0);\n"
" }\n"
" \n"
" if(out_channel_block_idx+1 >= out_channel_blocks) {\n"
" return;\n"
" }\n"
" output_idx += output_shape.y;\n"
" if(remain_y >= 4){\n"
" WI_F(output,(int2)(output_idx,output_idy),out4);\n"
" WI_F(output,(int2)(output_idx,output_idy+1),out5);\n"
" WI_F(output,(int2)(output_idx,output_idy+2),out6);\n"
" WI_F(output,(int2)(output_idx,output_idy+3),out7);\n"
" }else if(remain_y == 3){\n"
" WI_F(output,(int2)(output_idx,output_idy),out4);\n"
" WI_F(output,(int2)(output_idx,output_idy+1),out5);\n"
" WI_F(output,(int2)(output_idx,output_idy+2),out6);\n"
" }else if(remain_y == 2){\n"
" WI_F(output,(int2)(output_idx,output_idy),out4);\n"
" WI_F(output,(int2)(output_idx,output_idy+1),out5);\n"
" }else if(remain_y == 1){\n"
" WI_F(output,(int2)(output_idx,output_idy),out4);\n"
" }\n"
"}\n"
"__kernel\n"
"#if SET_ATTRIBUTE\n"
"__attribute__((work_group_size_hint(16,16,1)))\n"
"#endif\n"
"void conv_2d_c4h4w1(GLOBAL_SIZE_2_DIMS __read_only image2d_t input,\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" __global const char *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" __global const uchar *kernel_ptr,\n"
" __global const float *dequantScaleOffset,\n"
"#elif (defined USE_BUFFER)\n"
" __global const FLOAT *weights,\n"
"#else\n"
" __read_only image2d_t weights,\n"
"#endif\n"
"#ifdef BIAS\n"
" __read_only image2d_t bias,\n"
"#endif\n"
" __write_only image2d_t output,\n"
" __private const int2 input_shape,\n"
" __private const int in_channel_block_length,\n"
" __private const int2 output_shape,\n"
" __private const int2 weights_shape,\n"
" __private const int2 stride_shape,\n"
" __private const int2 padding_shape,\n"
" __private const int2 dilation_shape,\n"
" __private const int out_width_blocks,\n"
" __private const int out_channel_blocks,\n"
" __private const int out_height_blocks\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" ,__private const int blockDim\n"
" ,__private const int inChannel\n"
"#endif\n"
") {\n"
" const int output_channel_width_idx=get_global_id(0);\n"
" const int output_batch_height_idx=get_global_id(1);\n"
" DEAL_NON_UNIFORM_DIM2(output_channel_width_idx,output_batch_height_idx);\n"
" const int out_channel_block_idx=output_channel_width_idx/out_width_blocks;\n"
" const int out_width_block_idx=output_channel_width_idx % out_width_blocks;\n"
" const int out_height_block_idx=(output_batch_height_idx % out_height_blocks);\n"
" const int out_batch_block_idx=output_batch_height_idx/out_height_blocks;\n"
"#ifdef BIAS\n"
" FLOAT4 out0=RI_F(bias,SAMPLER,(int2)(out_channel_block_idx,0));\n"
"#else\n"
" FLOAT4 out0=(FLOAT4)0;\n"
"#endif\n"
" FLOAT4 out1=out0;\n"
" FLOAT4 out2=out0;\n"
" FLOAT4 out3=out0;\n"
" int in_width0=mad24(out_width_block_idx,stride_shape.y,-padding_shape.y);\n"
" int in_height0=mad24(out_height_block_idx,stride_shape.x<<2,-padding_shape.x);\n"
" int in_height1=in_height0+stride_shape.x;\n"
" int in_height2=in_height1+stride_shape.x;\n"
" int in_height3=in_height2+stride_shape.x;\n"
" int weight_size=mul24(weights_shape.y,weights_shape.x);\n"
" \n"
" const int weights_h_idx=mul24(out_channel_block_idx,weight_size);\n"
" const int batch_idx=mul24(out_batch_block_idx,input_shape.x);\n"
" \n"
" FLOAT4 in0,in1,in2,in3;\n"
" FLOAT4 weights0,weights1,weights2,weights3;\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n"
" const int weight_oc_offset=out_channel_blocks*weights_shape.x*weights_shape.y*4;\n"
"#endif\n"
" for (int in_channel_block_idx=0; in_channel_block_idx<in_channel_block_length; ++in_channel_block_idx) {\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" int kindex=(in_channel_block_idx*4)/blockDim*out_channel_blocks*8;\n"
" COMPUTE_FLOAT8 ScaleOffset0=CONVERT_COMPUTE_FLOAT8(vload8(out_channel_block_idx,dequantScaleOffset+kindex));\n"
" COMPUTE_FLOAT4 scale0=(COMPUTE_FLOAT4)(ScaleOffset0.s0,ScaleOffset0.s2,ScaleOffset0.s4,ScaleOffset0.s6);\n"
" COMPUTE_FLOAT4 offset0=(COMPUTE_FLOAT4)(ScaleOffset0.s1,ScaleOffset0.s3,ScaleOffset0.s5,ScaleOffset0.s7);\n"
"#endif\n"
" const int in_idx=mul24(in_channel_block_idx,input_shape.y);\n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8) || (defined USE_LOW_BIT_WEIGHT_INT4) || (defined USE_BUFFER)\n"
" int weight_offset=((((4*in_channel_block_idx+0)* out_channel_blocks+out_channel_block_idx) *weights_shape.x+0)*weights_shape.y+0)*4;\n"
"#else\n"
" int weights_x_idx=in_channel_block_idx << 2;\n"
" int weights_y_idx=weights_h_idx;\n"
"#endif\n"
" for (int iy=0; iy<weights_shape.x*dilation_shape.x; iy += dilation_shape.x) {\n"
" int h0=select(in_height0+iy+batch_idx,-1,(in_height0+iy<0 || in_height0+iy >= input_shape.x));\n"
" int h1=select(in_height1+iy+batch_idx,-1,(in_height1+iy<0 || in_height1+iy >= input_shape.x));\n"
" int h2=select(in_height2+iy+batch_idx,-1,(in_height2+iy<0 || in_height2+iy >= input_shape.x));\n"
" int h3=select(in_height3+iy+batch_idx,-1,(in_height3+iy<0 || in_height3+iy >= input_shape.x));\n"
" for (int ix=0; ix<weights_shape.y*dilation_shape.y; ix += dilation_shape.y) {\n"
" int w0=select(in_width0+ix+in_idx,-1,(in_width0+ix<0 || in_width0+ix >= input_shape.y));\n"
" \n"
" in0=RI_F(input,SAMPLER,(int2)(w0,h0));\n"
" in1=RI_F(input,SAMPLER,(int2)(w0,h1));\n"
" in2=RI_F(input,SAMPLER,(int2)(w0,h2));\n"
" in3=RI_F(input,SAMPLER,(int2)(w0,h3));\n"
" \n"
"#if (defined USE_LOW_BIT_WEIGHT_INT8)\n"
" char4 charWeight0=vload4(0,kernel_ptr+weight_offset);\n"
" char4 charWeight1=vload4(0,kernel_ptr+weight_offset+weight_oc_offset);\n"
" char4 charWeight2=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*2);\n"
" char4 charWeight3=vload4(0,kernel_ptr+weight_offset+weight_oc_offset*3);\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_LOW_BIT_WEIGHT_INT4)\n"
" uchar2 charWeightInt40=vload2(0,kernel_ptr+weight_offset/2);\n"
" uchar2 charWeightInt41=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset/2);\n"
" uchar2 charWeightInt42=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*2/2);\n"
" uchar2 charWeightInt43=vload2(0,kernel_ptr+weight_offset/2+weight_oc_offset*3/2);\n"
" char4 charWeight0=(char4)(0,0,0,0);\n"
" char4 charWeight1=(char4)(0,0,0,0);\n"
" char4 charWeight2=(char4)(0,0,0,0);\n"
" char4 charWeight3=(char4)(0,0,0,0);\n"
" charWeight0.x=(charWeightInt40.s0 >> 4)-8;\n"
" charWeight0.y=(charWeightInt40.s0 & MOD_NUM)-8;\n"
" charWeight0.z=(charWeightInt40.s1 >> 4)-8;\n"
" charWeight0.w=(charWeightInt40.s1 & MOD_NUM)-8;\n"
" charWeight1.x=(charWeightInt41.s0 >> 4)-8;\n"
" charWeight1.y=(charWeightInt41.s0 & MOD_NUM)-8;\n"
" charWeight1.z=(charWeightInt41.s1 >> 4)-8;\n"
" charWeight1.w=(charWeightInt41.s1 & MOD_NUM)-8;\n"
" charWeight2.x=(charWeightInt42.s0 >> 4)-8;\n"
" charWeight2.y=(charWeightInt42.s0 & MOD_NUM)-8;\n"
" charWeight2.z=(charWeightInt42.s1 >> 4)-8;\n"
" charWeight2.w=(charWeightInt42.s1 & MOD_NUM)-8;\n"
" charWeight3.x=(charWeightInt43.s0 >> 4)-8;\n"
" charWeight3.y=(charWeightInt43.s0 & MOD_NUM)-8;\n"
" charWeight3.z=(charWeightInt43.s1 >> 4)-8;\n"
" charWeight3.w=(charWeightInt43.s1 & MOD_NUM)-8;\n"
" weights0=mad(CONVERT_FLOAT4(charWeight0),scale0,offset0);\n"
" weights1=mad(CONVERT_FLOAT4(charWeight1),scale0,offset0);\n"
" weights2=mad(CONVERT_FLOAT4(charWeight2),scale0,offset0);\n"
" weights3=mad(CONVERT_FLOAT4(charWeight3),scale0,offset0);\n"
" weight_offset += 4;\n"
"#elif (defined USE_BUFFER)\n"
" weights0=vload4(0,weights+weight_offset);\n"
" weights1=vload4(0,weights+weight_offset+weight_oc_offset);\n"
" weights2=vload4(0,weights+weight_offset+weight_oc_offset*2);\n"
" weights3=vload4(0,weights+weight_offset+weight_oc_offset*3);\n"
" weight_offset += 4;\n"
"#else\n"
" weights0=RI_F(weights,SAMPLER,(int2)(weights_x_idx+0,weights_y_idx));\n"
" weights1=RI_F(weights,SAMPLER,(int2)(weights_x_idx+1,weights_y_idx));\n"
" weights2=RI_F(weights,SAMPLER,(int2)(weights_x_idx+2,weights_y_idx));\n"
" weights3=RI_F(weights,SAMPLER,(int2)(weights_x_idx+3,weights_y_idx++));\n"
"#endif\n"
" PADZEROSVEC(in_channel_block_idx,inChannel,weights0,weights1,weights2,weights3);\n"
" CALCULATE_OUTPUT(0);\n"
" CALCULATE_OUTPUT(1);\n"
" CALCULATE_OUTPUT(2);\n"
" CALCULATE_OUTPUT(3);\n"
" }\n"
" }\n"
" }\n"
"#ifdef RELU\n"
" out0=fmax(out0,(FLOAT4)0);\n"
" out1=fmax(out1,(FLOAT4)0);\n"
" out2=fmax(out2,(FLOAT4)0);\n"
" out3=fmax(out3,(FLOAT4)0);\n"
"#endif\n"
"#ifdef RELU6\n"
" out0=clamp(out0,(FLOAT4)0,(FLOAT4)6);\n"
" out1=clamp(out1,(FLOAT4)0,(FLOAT4)6);\n"
" out2=clamp(out2,(FLOAT4)0,(FLOAT4)6);\n"
" out3=clamp(out3,(FLOAT4)0,(FLOAT4)6);\n"
"#endif\n"
" const int out_x_base=mul24(out_channel_block_idx,output_shape.y);\n"
" const int out_y_base=mul24(out_batch_block_idx,output_shape.x);\n"
" int out_x_idx=out_width_block_idx;\n"
" int out_y_idx=out_height_block_idx << 2;\n"
" const int remain_y=output_shape.x-out_y_idx;\n"
" int output_idx=out_x_base+out_x_idx;\n"
" int output_idy=out_y_base+out_y_idx;\n"
" if(remain_y >= 4){\n"
" WI_F(output,(int2)(output_idx,output_idy),out0);\n"
" WI_F(output,(int2)(output_idx,output_idy+1),out1);\n"
" WI_F(output,(int2)(output_idx,output_idy+2),out2);\n"
" WI_F(output,(int2)(output_idx,output_idy+3),out3);\n"
" }else if(remain_y == 3){\n"
" WI_F(output,(int2)(output_idx,output_idy),out0);\n"
" WI_F(output,(int2)(output_idx,output_idy+1),out1);\n"
" WI_F(output,(int2)(output_idx,output_idy+2),out2);\n"
" }else if(remain_y == 2){\n"
" WI_F(output,(int2)(output_idx,output_idy),out0);\n"
" WI_F(output,(int2)(output_idx,output_idy+1),out1);\n"
" }else{\n"
" WI_F(output,(int2)(output_idx,output_idy),out0);\n"
" }\n"
"}\n"
;
}