source/backend/metal/AllShader.cpp (2,305 lines of code) (raw):

#include "AllShader.hpp"
// Metal source for the element-wise activation kernels.
//  - relu6: writes clamp(in, p.minV, p.maxV) for every element index gid.x < p.size.
//  - relu : writes fmax(V,0)+fmin(V,0)*p.minV, i.e. p.minV scales the negative part of V.
// M4 is not defined in this string; it is presumably supplied by a shared prelude (TODO confirm).
const char* shader_MetalReLU6_metal = "struct Param {\n" " float minV;\n" " float maxV;\n" " int size;\n" " int remain;\n" "};\n" "kernel void relu6(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant Param &p [[buffer(2)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if (gid.x<p.size) {\n" " out[int(gid.x)]=clamp(in[int(gid.x)],(M4)p.minV,(M4)p.maxV);\n" " }\n" "}\n" "kernel void relu(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant Param &p [[buffer(2)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if (gid.x<p.size) {\n" " auto V=in[int(gid.x)];\n" " out[int(gid.x)]=fmax(V,(M4)0)+fmin(V,(M4)0)*M4(p.minV);\n" " }\n" "}\n" ;
// Metal source for the depthwise convolution kernel. One thread produces one output
// element (gid.x, gid.y) of plane gid.z; threads outside output_width/output_height/
// slice*batch return immediately. The weight and bias index is oz = gid.z / batch.
// [sx,ex) and [sy,ey) are the kernel taps visited; they are clipped using UP_DIV, which is
// defined outside this string (presumably a round-up division - TODO confirm).
const char* shader_MetalConvolutionDepthwise_metal = "struct conv_dw_cst {\n" " int input_width;\n" " int input_height;\n" " int input_size;\n" " int output_width;\n" " int output_height;\n" " int output_size;\n" " int slice;\n" " int batch;\n" " \n" " int kernel_x;\n" " int kernel_y;\n" " int kernel_size;\n" " int stride_x;\n" " int stride_y;\n" " int pad_x;\n" " int pad_y;\n" " int dilation_x;\n" " int dilation_y;\n" " conv_activation_type activation;\n" "};\n" "kernel void conv_depthwise(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant conv_dw_cst& cst [[buffer(2)]],\n" " const device M4 *wt [[buffer(3)]],\n" " const device M4 *biasTerms [[buffer(4)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.slice*cst.batch) return;\n" " \n" " int oz=gid.z/cst.batch;\n" " int offset_x=(int)gid.x*cst.stride_x-cst.pad_x;\n" " int offset_y=(int)gid.y*cst.stride_y-cst.pad_y;\n" " int sx=max(0,(UP_DIV(-offset_x,cst.dilation_x)));\n" " int ex=min(cst.kernel_x,UP_DIV(cst.input_width-offset_x,cst.dilation_x));\n" " int sy=max(0,(UP_DIV(-offset_y,cst.dilation_y)));\n" " int 
ey=min(cst.kernel_y,UP_DIV(cst.input_height-offset_y,cst.dilation_y));\n" " offset_x += sx*cst.dilation_x;\n" " offset_y += sy*cst.dilation_y;\n" " auto z_wt=wt+(int)oz*cst.kernel_size;\n" " auto z_in=in+(int)gid.z*cst.input_size;\n" " auto z_out=out+(int)gid.z*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x;\n" " FLOAT4 result=FLOAT4(biasTerms[oz]);\n" " for (auto ky=sy,y=offset_y; ky<ey; ky++,y += cst.dilation_y) {\n" " for (auto kx=sx,x=offset_x; kx<ex; kx++,x += cst.dilation_x) {\n" " auto wt4=z_wt[ky*cst.kernel_x+kx];\n" " auto in4=z_in[ y*cst.input_width+x];\n" " result += FLOAT4(in4*wt4);\n" " }\n" " }\n" " *z_out=activate((M4)result,cst.activation);\n" "}\n" ;
// Metal source defining conv_activation_type and activate(): ReLU returns max(V,0),
// ReLU6 returns clamp(V,0,6), any other value (None) returns V unchanged.
// The convolution shader strings in this file call activate() and use this enum.
const char* shader_MetalConvolutionActivation_metal = "typedef enum : int {\n" " None=0,\n" " ReLU=1,\n" " ReLU6=2,\n" "} conv_activation_type;\n" "inline M4 activate(M4 V,conv_activation_type type) {\n" " switch (type) {\n" " case ReLU:\n" " return max(V,(M4)0);\n" " case ReLU6:\n" " return clamp(V,(M4)0,(M4)6);\n" " default: // None\n" " return V;\n" " }\n" "}\n" ;
// Metal source for the general convolution kernels.
//  - CONV_UNROLL: number of output slices handled per thread by the *z4 kernels.
//  - CONV_MUL_PACK_W2(x,y): accumulates a 3x3 window into two horizontally adjacent outputs;
//    x uses input columns 0..2 and y uses columns 1..3 of the in{row}{col} locals with k{row}{col}.
//  - CONV_NEXT_FLT: advances z_wt by ws and reloads the nine k{row}{col} weights.
// The macro bodies have no line continuations between their pieces; each #define ends at the next \n.
const char* shader_MetalConvolution_metal = "#define CONV_UNROLL (4)\n" "#define CONV_MUL_PACK_W2(x,y) " " x += FLOAT4(in00*k00);" " y += FLOAT4(in01*k00);" " x += FLOAT4(in01*k01);" " y += FLOAT4(in02*k01);" " x += FLOAT4(in02*k02);" " y += FLOAT4(in03*k02);" " " " x += FLOAT4(in10*k10);" " y += FLOAT4(in11*k10);" " x += FLOAT4(in11*k11);" " y += FLOAT4(in12*k11);" " x += FLOAT4(in12*k12);" " y += FLOAT4(in13*k12);" " " " x += FLOAT4(in20*k20);" " y += FLOAT4(in21*k20);" " x += FLOAT4(in21*k21);" " y += FLOAT4(in22*k21);" " x += FLOAT4(in22*k22);" " y += FLOAT4(in23*k22);\n" " \n" "#define CONV_NEXT_FLT " " z_wt += ws; " " " " k00=z_wt[0],k01=z_wt[1],k02=z_wt[2];" " k10=z_wt[3],k11=z_wt[4],k12=z_wt[5];" " k20=z_wt[6],k21=z_wt[7],k22=z_wt[8];\n" "struct conv_constants {\n" " int input_width;\n" " int input_height;\n" " int input_size;\n" " int input_slice;\n" " int output_width;\n" " int output_height;\n" " int output_size;\n" " 
int output_slice;\n" " int batch;\n" " int oz_size;\n" " int threadgroup_input_slice;\n" " \n" " int kernel_x;\n" " int kernel_y;\n" " int kernel_size;\n" " int stride_x;\n" " int stride_y;\n" " int pad_x;\n" " int pad_y;\n" " int dilation_x;\n" " int dilation_y;\n" " conv_activation_type activation; \n" "};\n"
// conv: generic kernel, one thread per output element. gid.z is split into an output slice
// (idx_c = gid.z / batch) and a batch index (idx_b = gid.z % batch); threads with
// gid.z >= oz_size return. Addressing uses idx_b*size + slice*batch*size for both buffers,
// so consecutive slices are batch*size elements apart. kw/kh are the numbers of kernel
// taps left after clipping against the input borders.
"kernel void conv(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant conv_constants& cst [[buffer(2)]],\n" " const device M4x4 *wt [[buffer(3)]],\n" " const device M4 *biasTerms [[buffer(4)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.oz_size) return;\n" " \n" " int idx_w=gid.x;\n" " int idx_h=gid.y;\n" " int idx_c=gid.z/cst.batch;\n" " int idx_b=gid.z % cst.batch;\n" " \n" " int offset_x=(int)idx_w*cst.stride_x-cst.pad_x;\n" " int offset_y=(int)idx_h*cst.stride_y-cst.pad_y;\n" " int sx=max(0,(UP_DIV(-offset_x,cst.dilation_x)));\n" " int ex=min(cst.kernel_x,UP_DIV(cst.input_width-offset_x,cst.dilation_x));\n" " int kw=ex-sx;\n" " int sy=max(0,(UP_DIV(-offset_y,cst.dilation_y)));\n" " int ey=min(cst.kernel_y,UP_DIV(cst.input_height-offset_y,cst.dilation_y));\n" " int kh=ey-sy;\n" " offset_x += sx*cst.dilation_x;\n" " offset_y += sy*cst.dilation_y;\n" " \n" " auto z_in=in+idx_b*cst.input_size+offset_y*cst.input_width+offset_x;\n" " auto z_wt=wt+idx_c*cst.input_slice*cst.kernel_size+sy*cst.kernel_x+sx;\n" " auto z_out=out+idx_b*cst.output_size+(int)idx_c*cst.batch*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x;\n" " int dilation_h=cst.input_width*cst.dilation_y;\n" " FLOAT4 result=FLOAT4(biasTerms[idx_c]);\n"
// Accumulate over every input slice and every surviving kernel tap.
" for (auto z=0; z<cst.input_slice; z++) {\n" " for (auto y=0; y<kh; y++) {\n" " for (auto x=0; x<kw; x++) {\n" " auto wt4=z_wt[z*cst.kernel_size+y*cst.kernel_x+x];\n" " auto in4=z_in[z*cst.input_size*cst.batch+y*dilation_h+x*cst.dilation_x];\n" " result += FLOAT4(in4*wt4);\n" " }\n" " }\n" " }\n" " 
*z_out=activate(M4(result),cst.activation);\n" "}\n" "kernel void convk3s1d1p1_w2z4(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant conv_constants& cst [[buffer(2)]],\n" " const device M4x4 *wt [[buffer(3)]],\n" " const device M4 *biasTerms [[buffer(4)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x*2 >= cst.output_width || (int)gid.y >= cst.output_height) return;\n" " \n" " int idx_w=gid.x << 1;\n" " int idx_h=gid.y;\n" " int idx_c=gid.z/cst.batch;\n" " int idx_b=gid.z % cst.batch;\n" " int4 uz=idx_c*CONV_UNROLL+int4(0,1,2,3);\n" " bool3 valids=uz.yzw<cst.output_slice;\n" " bool valid_x=(int)(gid.x*2+1)<cst.output_width;\n" " int offset_x=(int)gid.x*2-cst.pad_x;\n" " int offset_y=(int)gid.y-cst.pad_y;\n" " auto z_in=in+idx_b*cst.input_size+offset_y*cst.input_width+offset_x;\n" " auto z_flt=wt+uz[0]*cst.input_slice*cst.kernel_size;\n" " auto z_out=out+idx_b*cst.output_size+uz[0]*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n" " \n" " int ws=cst.input_slice*cst.kernel_size;\n" " FLOAT4 result0=0,result1=0,result2=0,result3=0;\n" " FLOAT4 result4=0,result5=0,result6=0,result7=0;\n" " for (auto z=0; z<cst.input_slice; z++,z_flt += cst.kernel_size,z_in += (cst.input_size*cst.batch)) {\n" " auto in00=(offset_x<0 || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+0);\n" " auto in01=(offset_x+1>=cst.input_width || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+1);\n" " auto in02=(offset_x+2>=cst.input_width || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+2);\n" " auto in03=(offset_x+3>=cst.input_width || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+3);\n" " auto in10=(offset_x<0 || offset_y+1>=cst.input_height) ? (M4)0.f : *(z_in+1*cst.input_width+0);\n" " auto in11=(offset_x+1>=cst.input_width || offset_y+1>=cst.input_height) ? (M4)0.f : *(z_in+1*cst.input_width+1);\n" " auto in12=(offset_x+2>=cst.input_width || offset_y+1>=cst.input_height) ? 
(M4)0.f : *(z_in+1*cst.input_width+2);\n" " auto in13=(offset_x+3>=cst.input_width || offset_y+1>=cst.input_height) ? (M4)0.f : *(z_in+1*cst.input_width+3);\n" " \n" " auto in20=(offset_x<0 || offset_y+2>=cst.input_height) ? (M4)0.f : *(z_in+2*cst.input_width+0);\n" " auto in21=(offset_x+1>=cst.input_width || offset_y+2>=cst.input_height) ? (M4)0.f : *(z_in+2*cst.input_width+1);\n" " auto in22=(offset_x+2>=cst.input_width || offset_y+2>=cst.input_height) ? (M4)0.f : *(z_in+2*cst.input_width+2);\n" " auto in23=(offset_x+3>=cst.input_width || offset_y+2>=cst.input_height) ? (M4)0.f : *(z_in+2*cst.input_width+3);\n" " \n"
// Load the nine weights of the first output slice, then step through the remaining unrolled
// slices (CONV_NEXT_FLT moves z_wt by ws = input_slice*kernel_size) while they exist.
" auto z_wt=z_flt;\n" " auto k00=z_wt[0],k01=z_wt[1],k02=z_wt[2];\n" " auto k10=z_wt[3],k11=z_wt[4],k12=z_wt[5];\n" " auto k20=z_wt[6],k21=z_wt[7],k22=z_wt[8];\n" " CONV_MUL_PACK_W2(result0,result4);\n" " if (valids[0]) {\n" " CONV_NEXT_FLT;\n" " CONV_MUL_PACK_W2(result1,result5);\n" " }\n" " if (valids[1]) {\n" " CONV_NEXT_FLT;\n" " CONV_MUL_PACK_W2(result2,result6);\n" " }\n" " if (valids[2]) {\n" " CONV_NEXT_FLT;\n" " CONV_MUL_PACK_W2(result3,result7);\n" " }\n" " }\n"
// Store phase. z_out was based at uz[0]*batch*output_size, so the next output slice lies
// batch*output_size elements further on (the same stride conv_z4/conv_z2 use). Stepping by
// output_size alone is only correct for batch == 1 and lands in another batch's plane otherwise.
" /* true */ *z_out=activate(M4(result0+FLOAT4(biasTerms[uz[0]])),cst.activation);\n" " if(valid_x) {\n" " *(z_out+1)=activate(M4(result4+FLOAT4(biasTerms[uz[0]])),cst.activation);\n" " }\n" " if (valids[0]) {\n" " z_out += cst.output_size*cst.batch;\n" " *z_out=activate(M4(result1+FLOAT4(biasTerms[uz[1]])),cst.activation);\n" " if(valid_x) {\n" " *(z_out+1)=activate(M4(result5+FLOAT4(biasTerms[uz[1]])),cst.activation);\n" " }\n" " }\n" " if (valids[1]) {\n" " z_out += cst.output_size*cst.batch;\n" " *z_out=activate(M4(result2+FLOAT4(biasTerms[uz[2]])),cst.activation);\n" " if(valid_x) {\n" " *(z_out+1)=activate(M4(result6+FLOAT4(biasTerms[uz[2]])),cst.activation);\n" " }\n" " }\n" " if (valids[2]) {\n" " z_out += cst.output_size*cst.batch;\n" " *z_out=activate(M4(result3+FLOAT4(biasTerms[uz[3]])),cst.activation);\n" " if(valid_x) {\n" " 
*(z_out+1)=activate(M4(result7+FLOAT4(biasTerms[uz[3]])),cst.activation);\n" " }\n" " }\n" "}\n"
// conv_s1d1p0_w2: each thread computes two horizontally adjacent outputs (idx_w, idx_w+1) of
// one output slice. Input addressing uses idx_h/idx_w directly (no stride, pad or dilation
// terms). The second output reuses each loaded input with the previous column's weight, and
// one extra input at column kernel_x is read after the inner loop. That read happens even
// when `valid` is false; only the store of result1 is guarded.
"kernel void conv_s1d1p0_w2(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant conv_constants& cst [[buffer(2)]],\n" " const device M4x4 *wt [[buffer(3)]],\n" " const device M4 *biasTerms [[buffer(4)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x*2 >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.oz_size) return;\n" " \n" " int idx_w=gid.x << 1;\n" " int idx_h=gid.y;\n" " int idx_c=gid.z/cst.batch;\n" " int idx_b=gid.z % cst.batch;\n" " if (idx_b >= cst.batch || idx_c >= cst.output_slice) return;\n" " bool valid=(idx_w+1<cst.output_width);\n" " \n" " auto z_in=in+idx_b*cst.input_size+idx_h*cst.input_width+idx_w;\n" " auto z_wt=wt+idx_c*cst.input_slice*cst.kernel_size;\n" " auto z_out=out+idx_b*cst.output_size+idx_c*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n" " FLOAT4 result0=FLOAT4(biasTerms[idx_c]);\n" " FLOAT4 result1=result0;\n" " for (auto z=0; z<cst.input_slice; z++) {\n" " for (auto y=0; y<cst.kernel_y; y++) {\n" " auto wt4=z_wt[z*cst.kernel_size+y*cst.kernel_x];\n" " auto in4_0=z_in[z*cst.batch*cst.input_size+y*cst.input_width];\n" " result0 += FLOAT4(in4_0*wt4);\n" " for (auto x=1; x<cst.kernel_x; x++) {\n" " in4_0=z_in[z*cst.batch*cst.input_size+y*cst.input_width+x];\n" " result1 += FLOAT4(in4_0*wt4);\n" " wt4=z_wt[z*cst.kernel_size+y*cst.kernel_x+x];\n" " result0 += FLOAT4(in4_0*wt4);\n" " }\n" " in4_0=z_in[z*cst.input_size*cst.batch+y*cst.input_width+cst.kernel_x];\n" " result1 += FLOAT4(in4_0*wt4);\n" " }\n" " }\n" " *z_out=activate(M4(result0),cst.activation);\n" " if(valid) { *(z_out+1)=activate(M4(result1),cst.activation);}\n" "}\n"
// conv_s1d1p0_w4: same addressing scheme as conv_s1d1p0_w2, producing four adjacent outputs.
"kernel void conv_s1d1p0_w4(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant conv_constants& cst [[buffer(2)]],\n" " const device M4x4 *wt [[buffer(3)]],\n" " const device M4 *biasTerms 
[[buffer(4)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x*4 >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.oz_size) return;\n" " \n" " int idx_w=gid.x << 2;\n" " int idx_h=gid.y;\n" " int idx_c=gid.z/cst.batch;\n" " int idx_b=gid.z % cst.batch;\n" " \n" " if (idx_b >= cst.batch || idx_c >= cst.output_slice) return;\n"
// uz holds the x coordinates of outputs 1..3; valids[k] is true when output k+1 lies inside
// output_width and is only consulted when storing.
" int3 uz=idx_w+int3(1,2,3);\n" " bool3 valids=uz.xyz<cst.output_width;\n" " \n" " auto z_in=in+idx_b*cst.input_size+idx_h*cst.input_width+idx_w;\n" " auto z_wt=wt+idx_c*cst.input_slice*cst.kernel_size;\n" " auto z_out=out+idx_b*cst.output_size+idx_c*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n" " FLOAT4 result0=FLOAT4(biasTerms[idx_c]);\n" " FLOAT4 result1=result0;\n" " FLOAT4 result2=result0;\n" " FLOAT4 result3=result0;\n"
// The row body is written out for exactly three weights (wt_base[0..2]) and six inputs
// (z_in_base[0..5]); kernel_x is not consulted here, so a 3-wide kernel is assumed.
// NOTE(review): the six input reads are unconditional, including for outputs whose valids
// flag is false - confirm the input buffer tolerates the over-read.
" for (auto z=0; z<cst.input_slice; z++) {\n" " for (auto y=0; y<cst.kernel_y; y++) {\n" " auto wt_base=z_wt+z*cst.kernel_size+y*cst.kernel_x;\n" " auto wt4_0=wt_base[0];\n" " auto wt4_1=wt_base[1];\n" " auto wt4_2=wt_base[2];\n" " auto z_in_base=z_in+z*cst.batch*cst.input_size+y*cst.input_width;\n" " auto in4_0=z_in_base[0];\n" " result0 += FLOAT4(in4_0*wt4_0);\n" " \n" " in4_0=z_in_base[1];\n" " result0 += FLOAT4(in4_0*wt4_1);\n" " result1 += FLOAT4(in4_0*wt4_0);\n" " in4_0=z_in_base[2];\n" " result0 += FLOAT4(in4_0*wt4_2);\n" " result1 += FLOAT4(in4_0*wt4_1);\n" " result2 += FLOAT4(in4_0*wt4_0);\n" " in4_0=z_in_base[3];\n" " result1 += FLOAT4(in4_0*wt4_2);\n" " result2 += FLOAT4(in4_0*wt4_1);\n" " result3 += FLOAT4(in4_0*wt4_0);\n" " \n" " in4_0=z_in_base[4];\n" " result2 += FLOAT4(in4_0*wt4_2);\n" " result3 += FLOAT4(in4_0*wt4_1);\n" " in4_0=z_in_base[5];\n" " result3 += FLOAT4(in4_0*wt4_2);\n" " }\n" " }\n" " *z_out=activate(M4(result0),cst.activation);\n" " if(valids[0]) { *(z_out+1)=activate(M4(result1),cst.activation);}\n" " if(valids[1]) { *(z_out+2)=activate(M4(result2),cst.activation);}\n" " if(valids[2]) { 
*(z_out+3)=activate(M4(result3),cst.activation);}\n" "}\n"
// conv_z4: generic convolution where each thread computes one spatial position for up to
// CONV_UNROLL (4) consecutive output slices uz[0..3]. Threads whose first slice is out of
// range return early; valids.x/y/z say whether slices uz[1..3] exist. Weights of successive
// output slices are ws = input_slice*kernel_size apart, outputs batch*output_size apart.
// Bias is added at store time rather than used as the accumulator's initial value.
"kernel void conv_z4(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant conv_constants& cst [[buffer(2)]],\n" " const device M4x4 *wt [[buffer(3)]],\n" " const device M4 *biasTerms [[buffer(4)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height) return;\n" " \n" " int idx_w=gid.x;\n" " int idx_h=gid.y;\n" " int idx_c=gid.z/cst.batch;\n" " int idx_b=gid.z % cst.batch;\n" " if (idx_b >= cst.batch || idx_c*4 >= cst.output_slice) return;\n" " int4 uz=idx_c*CONV_UNROLL+int4(0,1,2,3);\n" " bool3 valids=uz.yzw<cst.output_slice;\n" " \n" " int offset_x=idx_w*cst.stride_x-cst.pad_x;\n" " int offset_y=idx_h*cst.stride_y-cst.pad_y;\n" " int sx=max(0,(UP_DIV(-offset_x,cst.dilation_x)));\n" " int ex=min(cst.kernel_x,UP_DIV(cst.input_width-offset_x,cst.dilation_x));\n" " int kw=ex-sx;\n" " int sy=max(0,(UP_DIV(-offset_y,cst.dilation_y)));\n" " int ey=min(cst.kernel_y,UP_DIV(cst.input_height-offset_y,cst.dilation_y));\n" " int kh=ey-sy;\n" " offset_x += sx*cst.dilation_x;\n" " offset_y += sy*cst.dilation_y;\n" " \n" " auto z_in=in+idx_b*cst.input_size+offset_y*cst.input_width+offset_x;\n" " auto z_wt=wt+uz[0]*cst.input_slice*cst.kernel_size+sy*cst.kernel_x+sx;\n" " auto z_out=out+idx_b*cst.output_size+uz[0]*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n" " \n" " int ws=cst.input_slice*cst.kernel_size;\n" " int dilation_h=cst.input_width*cst.dilation_y;\n" " FLOAT4 result0=0,result1=0,result2=0,result3=0;\n"
// Each loaded input vector is reused for every valid output slice; x_wt hops by ws between them.
" for (auto z=0; z<cst.input_slice; z++,z_wt += cst.kernel_size,z_in += cst.input_size*cst.batch) {\n" " for (auto y=0; y<kh; y++) {\n" " for (auto x=0; x<kw; x++) {\n" " auto x_wt=z_wt+y*cst.kernel_x+x;\n" " auto in4=z_in[ y*dilation_h+x*cst.dilation_x];\n" " /* true */ result0 += FLOAT4(in4**x_wt);\n" " if (valids.x) { x_wt += ws; result1 += FLOAT4(in4**x_wt); }\n" " if (valids.y) { x_wt += ws; result2 
+= FLOAT4(in4**x_wt); }\n" " if (valids.z) { x_wt += ws; result3 += FLOAT4(in4**x_wt); }\n" " }\n" " }\n" " }\n"
// Store the accumulated slices; each further slice is batch*output_size elements after the previous one.
" /* true */ *z_out=activate(M4(result0+FLOAT4(biasTerms[uz[0]])),cst.activation);\n" " if (valids.x) { z_out += cst.output_size*cst.batch; *z_out=activate(M4(result1+FLOAT4(biasTerms[uz[1]])),cst.activation); }\n" " if (valids.y) { z_out += cst.output_size*cst.batch; *z_out=activate(M4(result2+FLOAT4(biasTerms[uz[2]])),cst.activation); }\n" " if (valids.z) { z_out += cst.output_size*cst.batch; *z_out=activate(M4(result3+FLOAT4(biasTerms[uz[3]])),cst.activation); }\n" "}\n"
// conv_z2: same scheme as conv_z4 but with two output slices per thread (uz = idx_c*2 + {0,1});
// `valids` is a single bool telling whether the second slice exists.
"kernel void conv_z2(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant conv_constants& cst [[buffer(2)]],\n" " const device M4x4 *wt [[buffer(3)]],\n" " const device M4 *biasTerms [[buffer(4)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height) return;\n" " \n" " int idx_w=gid.x;\n" " int idx_h=gid.y;\n" " int idx_c=gid.z/cst.batch;\n" " int idx_b=gid.z % cst.batch;\n" " if (idx_b >= cst.batch || idx_c*2 >= cst.output_slice) return;\n" " int2 uz=idx_c*2+int2(0,1);\n" " bool valids=uz.y<cst.output_slice;\n" " \n" " int offset_x=idx_w*cst.stride_x-cst.pad_x;\n" " int offset_y=idx_h*cst.stride_y-cst.pad_y;\n" " int sx=max(0,(UP_DIV(-offset_x,cst.dilation_x)));\n" " int ex=min(cst.kernel_x,UP_DIV(cst.input_width-offset_x,cst.dilation_x));\n" " int kw=ex-sx;\n" " int sy=max(0,(UP_DIV(-offset_y,cst.dilation_y)));\n" " int ey=min(cst.kernel_y,UP_DIV(cst.input_height-offset_y,cst.dilation_y));\n" " int kh=ey-sy;\n" " offset_x += sx*cst.dilation_x;\n" " offset_y += sy*cst.dilation_y;\n" " \n" " auto z_in=in+idx_b*cst.input_size+offset_y*cst.input_width+offset_x;\n" " auto z_wt=wt+uz[0]*cst.input_slice*cst.kernel_size+sy*cst.kernel_x+sx;\n" " auto z_out=out+idx_b*cst.output_size+uz[0]*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n" " \n" " int 
ws=cst.input_slice*cst.kernel_size;\n" " int dilation_h=cst.input_width*cst.dilation_y;\n" " FLOAT4 result0=0,result1=0;\n" " for (auto z=0; z<cst.input_slice; z++,z_wt += cst.kernel_size,z_in += cst.input_size*cst.batch) {\n" " for (auto y=0; y<kh; y++) {\n" " for (auto x=0; x<kw; x++) {\n" " auto x_wt=z_wt+y*cst.kernel_x+x;\n" " auto in4=z_in[ y*dilation_h+x*cst.dilation_x];\n" " /* true */ result0 += FLOAT4(in4**x_wt);\n" " if (valids) { x_wt += ws; result1 += FLOAT4(in4**x_wt); }\n" " }\n" " }\n" " }\n" " /* true */ *z_out=activate(M4(result0+FLOAT4(biasTerms[uz[0]])),cst.activation);\n" " if (valids) { z_out += cst.output_size*cst.batch; *z_out=activate(M4(result1+FLOAT4(biasTerms[uz[1]])),cst.activation); }\n" "}\n" ;
// Metal source for the reduction kernels. The data is addressed as
// in[outside*outside_step + axis*inside_size + inside]; gid.x is the outside index and gid.y
// the inside index, and every helper walks the axis with a stride of inside_size.
// Template parameter M is the accumulator type, T the storage type. reduce_mean divides the
// sum by axis_size; reduce_min/reduce_max seed the running value from the first element and
// therefore read one element even when axis_size is 0.
const char* shader_MetalReduction_metal = "struct reduce_shape {\n" " int outside_size;\n" " int axis_size;\n" " int inside_size;\n" " int outside_step;\n" "};\n" "template <typename M,typename T>\n" "static inline void reduce_mean(const device T *in,device T *out,constant reduce_shape &s,uint2 gid) {\n" " auto axis_in=in+gid.x*s.outside_step+gid.y;\n" " M summer=0;\n" " for (int i=0; i<s.axis_size; i++,axis_in += s.inside_size) {\n" " summer += M(*axis_in);\n" " }\n" " out[int(gid.x)*s.inside_size+int(gid.y)]=T(summer/s.axis_size);\n" "}\n" "template <typename M,typename T>\n" "static inline void reduce_sum(const device T *in,device T *out,constant reduce_shape &s,uint2 gid) {\n" " auto axis_in=in+gid.x*s.outside_step+gid.y;\n" " M summer=0;\n" " for (int i=0; i<s.axis_size; i++,axis_in += s.inside_size) {\n" " summer += M(*axis_in);\n" " }\n" " out[int(gid.x)*s.inside_size+int(gid.y)]=T(summer);\n" "}\n" "template <typename M,typename T>\n" "static inline void reduce_min(const device T *in,device T *out,constant reduce_shape &s,uint2 gid) {\n" " auto axis_in=in+gid.x*s.outside_step+gid.y;\n" " T summer=*axis_in; axis_in += s.inside_size;\n" " for (int i=1; i<s.axis_size; i++,axis_in += s.inside_size) {\n" " summer=min(summer,*axis_in);\n" " 
}\n" " out[int(gid.x)*s.inside_size+int(gid.y)]=summer;\n" "}\n" "template <typename M,typename T>\n" "static inline void reduce_max(const device T *in,device T *out,constant reduce_shape &s,uint2 gid) {\n" " auto axis_in=in+gid.x*s.outside_step+gid.y;\n" " T summer=*axis_in; axis_in += s.inside_size;\n" " for (int i=1; i<s.axis_size; i++,axis_in += s.inside_size) {\n" " summer=max(summer,*axis_in);\n" " }\n" " out[int(gid.x)*s.inside_size+int(gid.y)]=summer;\n" "}\n" "template <typename M,typename T>\n" "static inline void reduce_prod(const device T *in,device T *out,constant reduce_shape &s,uint2 gid) {\n" " auto axis_in=in+gid.x*s.outside_step+gid.y;\n" " M summer=1;\n" " for (int i=0; i<s.axis_size; i++,axis_in += s.inside_size) {\n" " summer *= M(*axis_in);\n" " }\n" " out[int(gid.x)*s.inside_size+int(gid.y)]=T(summer);\n" "}\n"
// define_reduce(name) stamps out two kernels per reduction: reduce_<name>_f on M data with a
// FLOAT accumulator and reduce_<name>_s on int data with an int accumulator. Both only run
// the helper for gid.x < outside_size and gid.y < inside_size.
"#define define_reduce(name) " "kernel void reduce_##name##_f(const device M *in [[buffer(0)]]," " device M *out [[buffer(1)]]," " constant reduce_shape &s [[buffer(2)]]," " uint2 gid [[thread_position_in_grid]]) { " " if (gid.x<(uint)s.outside_size && gid.y<(uint)s.inside_size) reduce_##name<FLOAT,M>(in,out,s,gid); " "} " "kernel void reduce_##name##_s(const device int *in [[buffer(0)]]," " device int *out [[buffer(1)]]," " constant reduce_shape &s [[buffer(2)]]," " uint2 gid [[thread_position_in_grid]]) { " " if (gid.x<(uint)s.outside_size && gid.y<(uint)s.inside_size) reduce_##name<int,int>(in,out,s,gid); " "}\n" "define_reduce(mean);\n" "define_reduce(sum);\n" "define_reduce(min);\n" "define_reduce(max);\n" "define_reduce(prod);\n" ;
// Metal source for the softmax kernels. softmax_max4/softmax_sum4 reduce the four lanes of
// a float4 to a scalar; softmax_filter zeroes the lanes whose flat index z*4+lane is not
// below `limit`.
const char* shader_MetalSoftmax_metal = "struct softmax_shape {\n" " int inside_size;\n" " int axis_length;\n" " int outside_size;\n" " int flat_length;\n" "};\n" "static inline float softmax_max4(float4 V) {\n" " return max(max(V[0],V[1]),max(V[2],V[3]));\n" "}\n" "static inline float softmax_sum4(float4 V) {\n" " return V[0]+V[1]+V[2]+V[3];\n" "}\n" "static inline float4 softmax_filter(float4 V,int 
z,int limit) {\n" " return select(0,V,z*4+int4(0,1,2,3)<limit);\n" "}\n" "kernel void softmax_plane(const device M *in [[buffer(0)]],\n" " device M *out [[buffer(1)]],\n" " constant softmax_shape& s [[buffer(2)]],\n" " uint2 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x >= s.inside_size || (int)gid.y >= s.outside_size) return;\n" " \n" " auto axis_off=gid.y*s.axis_length*s.inside_size+gid.x;\n" " auto axis_in=in+axis_off;\n" " auto axis_out=out+axis_off;\n" " \n" " // get max\n" " float max1=-INFINITY;\n" " for (int i=0; i<s.axis_length; i++) {\n" " max1=max(max1,float(axis_in[i*s.inside_size]));\n" " }\n" " \n" " // get sum\n" " float sum1=0;\n" " for (int i=0; i<s.axis_length; i++) {\n" " sum1 += float(exp(float(axis_in[i*s.inside_size])-float(max1)));\n" " }\n" " \n" " // output\n" " for (int i=0; i<s.axis_length; i++) {\n" " axis_out[i*s.inside_size]=M(exp(float(axis_in[i*s.inside_size])-float(max1))/sum1);\n" " }\n" "}\n" "kernel void softmax_on_reorder(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant softmax_shape& s [[buffer(2)]],\n" " uint2 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x >= s.inside_size || (int)gid.y >= s.outside_size) return;\n" " \n" " auto axis_off=gid.y*s.axis_length*s.inside_size+gid.x;\n" " auto axis_in=in+axis_off;\n" " auto axis_out=out+axis_off;\n" " // get max\n" " auto max4=softmax_filter(float4(axis_in[0]),0,s.flat_length);\n" " for (int i=1; i<s.axis_length; i++) {\n" " max4=max(max4,softmax_filter(float4(axis_in[i*s.inside_size]),i,s.flat_length));\n" " }\n" " float max1=softmax_max4(max4);\n" " \n" " // get sum\n" " float4 sum4=0;\n" " for (int i=0; i<s.axis_length; i++) {\n" " sum4 += softmax_filter(exp(float4(axis_in[i*s.inside_size]-max1)),i,s.flat_length);\n" " }\n" " float sum1=softmax_sum4(sum4);\n" " \n" " // output\n" " for (int i=0; i<s.axis_length; i++) {\n" " axis_out[i*s.inside_size]=M4(exp(float4(axis_in[i*s.inside_size])-max1)/sum1);\n" " }\n" "}\n" "kernel 
void softmax_off_reorder(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant softmax_shape& s [[buffer(2)]],\n" " uint2 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x >= s.inside_size || (int)gid.y >= s.outside_size) return;\n" " auto axis_off=gid.y*s.axis_length*s.inside_size+gid.x;\n" " auto axis_in=in+axis_off;\n" " auto axis_out=out+axis_off;\n" " // get max\n" " auto max4=axis_in[0];\n" " for (int i=1; i<s.axis_length; i++) {\n" " max4=max(max4,axis_in[i*s.inside_size]);\n" " }\n" " // get sum\n" " float4 sum4=0;\n" " for (int i=0; i<s.axis_length; i++) {\n" " sum4 += exp(float4(axis_in[i*s.inside_size]-max4));\n" " }\n" " // output\n" " for (int i=0; i<s.axis_length; i++) {\n" " axis_out[i*s.inside_size]=M4(exp(float4(axis_in[i*s.inside_size]-max4))/sum4);\n" " }\n" "}\n" ;
// Metal source for the layer-norm kernels. `inside` is the length of the normalised row and
// `outside` the number of rows; gid.y selects the row and gid.x the element (x1) or the
// 4-element vector (x4) to write. Every thread recomputes the statistics of its whole row.
// The local named `var` actually holds 1/sqrt(variance + eps).
// gamma/beta are only read when has_gamma_beta is non-zero.
const char* shader_MetalLayerNorm_metal = "struct layernorm_constants {\n" " int inside;\n" " int outside;\n" " float eps;\n" " int has_gamma_beta;\n" "};\n" "kernel void layernorm_x1(const device M *in [[buffer(0)]],\n" " device M *out [[buffer(1)]],\n" " constant layernorm_constants& cst [[buffer(2)]],\n" " const device float *gamma [[buffer(3)]],\n" " const device float *beta [[buffer(4)]],\n" " uint2 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x >= cst.inside || (int)gid.y >= cst.outside) {\n" " return;\n" " }\n" " auto in_data=in+gid.y*cst.inside;\n" " auto out_data=out+gid.y*cst.inside;\n" " float mean;\n" " float sum=0.0f;\n" " float square_sum=0.0f;\n" " \n" " for(int i=0; i<cst.inside; i++) {\n" " sum += in_data[i];\n" " }\n" " mean=sum/cst.inside;\n" " \n" " for(int i=0; i<cst.inside; i++) {\n" " float dis=(in_data[i]-mean);\n" " square_sum += dis*dis;\n" " }\n" " float var=1.0/sqrt(square_sum/cst.inside+cst.eps);\n" " \n" " float norm=var*((float)in_data[gid.x]-mean);\n" " if(cst.has_gamma_beta) {\n" " out_data[gid.x]=(M)(norm*gamma[gid.x]+beta[gid.x]);\n" " } else {\n" " out_data[gid.x]=(M)(norm);\n" " }\n" "}\n" "kernel void 
layernorm_x4(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant layernorm_constants& cst [[buffer(2)]],\n" " const device float4 *gamma [[buffer(3)]],\n" " const device float4 *beta [[buffer(4)]],\n" " uint2 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x >= cst.inside/4 || (int)gid.y >= cst.outside) {\n" " return;\n" " }\n"
// Rows are addressed in units of M4, i.e. inside/4 vectors per row; the loops below also run
// to inside/4, so any remainder of inside modulo 4 is not visited (assumes inside % 4 == 0 -
// TODO confirm against the host-side kernel selection).
" auto in_data=in+gid.y*cst.inside/4;\n" " auto out_data=out+gid.y*cst.inside/4;\n" " float mean;\n" " float sum=0.0f;\n" " float square_sum=0.0f;\n" " \n" " for(int i=0; i<cst.inside/4; i++) {\n" " sum += in_data[i].x;\n" " sum += in_data[i].y;\n" " sum += in_data[i].z;\n" " sum += in_data[i].w;\n" " }\n" " mean=sum/cst.inside;\n" " \n" " for(int i=0; i<cst.inside/4; i++) {\n" " float dis=(in_data[i].x-mean);\n" " square_sum += dis*dis;\n" " dis=(in_data[i].y-mean);\n" " square_sum += dis*dis;\n" " dis=(in_data[i].z-mean);\n" " square_sum += dis*dis;\n" " dis=(in_data[i].w-mean);\n" " square_sum += dis*dis;\n" " }\n" " float var=1.0/sqrt(square_sum/cst.inside+cst.eps);\n" " \n" " float4 norm=var*((float4)in_data[gid.x]-mean);\n" " if(cst.has_gamma_beta) {\n" " out_data[gid.x]=(M4)(norm*gamma[gid.x]+beta[gid.x]);\n" " } else {\n" " out_data[gid.x]=(M4)(norm);\n" " }\n" "}\n"
// layernorm_x1_rms: RMS variant - no mean is subtracted; the scale is
// 1/sqrt(mean(x^2) + eps). When has_gamma_beta is set, both gamma and beta are still applied.
"kernel void layernorm_x1_rms(const device M *in [[buffer(0)]],\n" " device M *out [[buffer(1)]],\n" " constant layernorm_constants& cst [[buffer(2)]],\n" " const device float *gamma [[buffer(3)]],\n" " const device float *beta [[buffer(4)]],\n" " uint2 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x >= cst.inside || (int)gid.y >= cst.outside) {\n" " return;\n" " }\n" " auto in_data=in+gid.y*cst.inside;\n" " auto out_data=out+gid.y*cst.inside;\n" " float square_sum=0.0f;\n" " \n" " for(int i=0; i<cst.inside; i++) {\n" " float dis=in_data[i];\n" " square_sum += dis*dis;\n" " }\n" " float var=1.0/sqrt(square_sum/cst.inside+cst.eps);\n" " \n" " float norm=var*((float)in_data[gid.x]);\n" " if(cst.has_gamma_beta) {\n" " 
out_data[gid.x]=(M)(norm*gamma[gid.x]+beta[gid.x]);\n" " } else {\n" " out_data[gid.x]=(M)(norm);\n" " }\n" "}\n"
// layernorm_x4_rms: RMS variant of layernorm_x4; rows are inside/4 M4 vectors long and the
// squared sum runs over all four lanes of each vector.
"kernel void layernorm_x4_rms(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant layernorm_constants& cst [[buffer(2)]],\n" " const device float4 *gamma [[buffer(3)]],\n" " const device float4 *beta [[buffer(4)]],\n" " uint2 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x >= cst.inside/4 || (int)gid.y >= cst.outside) {\n" " return;\n" " }\n" " auto in_data=in+gid.y*cst.inside/4;\n" " auto out_data=out+gid.y*cst.inside/4;\n" " float square_sum=0.0f;\n" " for(int i=0; i<cst.inside/4; i++) {\n" " float dis=in_data[i].x;\n" " square_sum += dis*dis;\n" " dis=in_data[i].y;\n" " square_sum += dis*dis;\n" " dis=in_data[i].z;\n" " square_sum += dis*dis;\n" " dis=in_data[i].w;\n" " square_sum += dis*dis;\n" " }\n" " float var=1.0/sqrt(square_sum/cst.inside+cst.eps);\n" " \n" " float4 norm=var*((float4)in_data[gid.x]);\n" " if(cst.has_gamma_beta) {\n" " out_data[gid.x]=(M4)(norm*gamma[gid.x]+beta[gid.x]);\n" " } else {\n" " out_data[gid.x]=(M4)(norm);\n" " }\n" "}\n" ;
// Metal source for the Winograd convolution kernels. get_input() returns the element at
// (x, y) of a plane input_shape.x wide and 0 for coordinates outside
// [0, input_shape.x) x [0, input_shape.y).
const char* shader_MetalConvolutionWinograd_metal = "struct winograd_constants {\n" " int4 input_shape;\n" " int4 output_shape;\n" " int pad_x;\n" " int pad_y;\n" " int unit_width;\n" " int unit_height;\n" " int unit;\n" " conv_activation_type activation;\n" "};\n" "static inline M4 get_input(const device M4 *input,int x,int y,constant winograd_constants &cst) {\n" " return x<cst.input_shape.x && y<cst.input_shape.y && x >= 0 && y >= 0 ? 
input[x+y*cst.input_shape.x] : 0;\n" "}\n" "kernel void winograd_transform_source2_5_1(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant winograd_constants &cst [[buffer(2)]],\n" " constant int &batch_idx [[buffer(3)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " auto pos=int3(gid);\n" " if (pos.x<cst.unit_width && pos.y<cst.unit_height) {\n" " int ix=pos.x*cst.unit-cst.pad_x;\n" " int iy=pos.y*cst.unit-cst.pad_y;\n" " auto z_in=in+(pos.z*cst.output_shape.w+batch_idx)*cst.input_shape.x*cst.input_shape.y;\n" " auto S00=get_input(z_in,ix+0,iy+0,cst);\n" " auto S10=get_input(z_in,ix+1,iy+0,cst);\n" " auto S20=get_input(z_in,ix+2,iy+0,cst);\n" " auto S30=get_input(z_in,ix+3,iy+0,cst);\n" " auto S40=get_input(z_in,ix+4,iy+0,cst);\n" " auto S50=get_input(z_in,ix+5,iy+0,cst);\n" " auto S01=get_input(z_in,ix+0,iy+1,cst);\n" " auto S11=get_input(z_in,ix+1,iy+1,cst);\n" " auto S21=get_input(z_in,ix+2,iy+1,cst);\n" " auto S31=get_input(z_in,ix+3,iy+1,cst);\n" " auto S41=get_input(z_in,ix+4,iy+1,cst);\n" " auto S51=get_input(z_in,ix+5,iy+1,cst);\n" " auto S02=get_input(z_in,ix+0,iy+2,cst);\n" " auto S12=get_input(z_in,ix+1,iy+2,cst);\n" " auto S22=get_input(z_in,ix+2,iy+2,cst);\n" " auto S32=get_input(z_in,ix+3,iy+2,cst);\n" " auto S42=get_input(z_in,ix+4,iy+2,cst);\n" " auto S52=get_input(z_in,ix+5,iy+2,cst);\n" " auto S03=get_input(z_in,ix+0,iy+3,cst);\n" " auto S13=get_input(z_in,ix+1,iy+3,cst);\n" " auto S23=get_input(z_in,ix+2,iy+3,cst);\n" " auto S33=get_input(z_in,ix+3,iy+3,cst);\n" " auto S43=get_input(z_in,ix+4,iy+3,cst);\n" " auto S53=get_input(z_in,ix+5,iy+3,cst);\n" " auto S04=get_input(z_in,ix+0,iy+4,cst);\n" " auto S14=get_input(z_in,ix+1,iy+4,cst);\n" " auto S24=get_input(z_in,ix+2,iy+4,cst);\n" " auto S34=get_input(z_in,ix+3,iy+4,cst);\n" " auto S44=get_input(z_in,ix+4,iy+4,cst);\n" " auto S54=get_input(z_in,ix+5,iy+4,cst);\n" " auto S05=get_input(z_in,ix+0,iy+5,cst);\n" " auto S15=get_input(z_in,ix+1,iy+5,cst);\n" " 
auto S25=get_input(z_in,ix+2,iy+5,cst);\n" " auto S35=get_input(z_in,ix+3,iy+5,cst);\n" " auto S45=get_input(z_in,ix+4,iy+5,cst);\n" " auto S55=get_input(z_in,ix+5,iy+5,cst);\n" " auto m00=+S00-1.25*S02+0.25*S04;\n" " auto m10=+S10-1.25*S12+0.25*S14;\n" " auto m20=+S20-1.25*S22+0.25*S24;\n" " auto m30=+S30-1.25*S32+0.25*S34;\n" " auto m40=+S40-1.25*S42+0.25*S44;\n" " auto m50=+S50-1.25*S52+0.25*S54;\n" " auto m01=+0.666667*S01+0.666667*S02-0.166667*S03-0.166667*S04;\n" " auto m11=+0.666667*S11+0.666667*S12-0.166667*S13-0.166667*S14;\n" " auto m21=+0.666667*S21+0.666667*S22-0.166667*S23-0.166667*S24;\n" " auto m31=+0.666667*S31+0.666667*S32-0.166667*S33-0.166667*S34;\n" " auto m41=+0.666667*S41+0.666667*S42-0.166667*S43-0.166667*S44;\n" " auto m51=+0.666667*S51+0.666667*S52-0.166667*S53-0.166667*S54;\n" " auto m02=-0.666667*S01+0.666667*S02+0.166667*S03-0.166667*S04;\n" " auto m12=-0.666667*S11+0.666667*S12+0.166667*S13-0.166667*S14;\n" " auto m22=-0.666667*S21+0.666667*S22+0.166667*S23-0.166667*S24;\n" " auto m32=-0.666667*S31+0.666667*S32+0.166667*S33-0.166667*S34;\n" " auto m42=-0.666667*S41+0.666667*S42+0.166667*S43-0.166667*S44;\n" " auto m52=-0.666667*S51+0.666667*S52+0.166667*S53-0.166667*S54;\n" " auto m03=-0.0833333*S01-0.0416667*S02+0.0833333*S03+0.0416667*S04;\n" " auto m13=-0.0833333*S11-0.0416667*S12+0.0833333*S13+0.0416667*S14;\n" " auto m23=-0.0833333*S21-0.0416667*S22+0.0833333*S23+0.0416667*S24;\n" " auto m33=-0.0833333*S31-0.0416667*S32+0.0833333*S33+0.0416667*S34;\n" " auto m43=-0.0833333*S41-0.0416667*S42+0.0833333*S43+0.0416667*S44;\n" " auto m53=-0.0833333*S51-0.0416667*S52+0.0833333*S53+0.0416667*S54;\n" " auto m04=+0.0833333*S01-0.0416667*S02-0.0833333*S03+0.0416667*S04;\n" " auto m14=+0.0833333*S11-0.0416667*S12-0.0833333*S13+0.0416667*S14;\n" " auto m24=+0.0833333*S21-0.0416667*S22-0.0833333*S23+0.0416667*S24;\n" " auto m34=+0.0833333*S31-0.0416667*S32-0.0833333*S33+0.0416667*S34;\n" " auto 
m44=+0.0833333*S41-0.0416667*S42-0.0833333*S43+0.0416667*S44;\n" " auto m54=+0.0833333*S51-0.0416667*S52-0.0833333*S53+0.0416667*S54;\n" " auto m05=+4.0*S01-5.0*S03+S05;\n" " auto m15=+4.0*S11-5.0*S13+S15;\n" " auto m25=+4.0*S21-5.0*S23+S25;\n" " auto m35=+4.0*S31-5.0*S33+S35;\n" " auto m45=+4.0*S41-5.0*S43+S45;\n" " auto m55=+4.0*S51-5.0*S53+S55;\n" " int dst_x_origin=pos.z;\n" " int dst_y_origin=cst.unit_width*pos.y+pos.x;\n" " int dst_y_stride=cst.input_shape.z*4;\n" " int dst_y=dst_y_origin/4;\n" " int dst_x=dst_y_origin % 4+4*dst_x_origin;\n" " int src_height=UP_DIV(cst.unit_width*cst.unit_height,4);\n" " int stride=src_height*dst_y_stride;\n" " auto xy_out=out+dst_y*dst_y_stride+dst_x;\n" " *xy_out=+m00-1.25*m20+0.25*m40;\n" " xy_out += stride; *xy_out=+0.666667*m10+0.666667*m20-0.166667*m30-0.166667*m40;\n" " xy_out += stride; *xy_out=-0.666667*m10+0.666667*m20+0.166667*m30-0.166667*m40;\n" " xy_out += stride; *xy_out=-0.0833333*m10-0.0416667*m20+0.0833333*m30+0.0416667*m40;\n" " xy_out += stride; *xy_out=+0.0833333*m10-0.0416667*m20-0.0833333*m30+0.0416667*m40;\n" " xy_out += stride; *xy_out=+4.0*m10-5.0*m30+m50;\n" " xy_out += stride; *xy_out=+m01-1.25*m21+0.25*m41;\n" " xy_out += stride; *xy_out=+0.666667*m11+0.666667*m21-0.166667*m31-0.166667*m41;\n" " xy_out += stride; *xy_out=-0.666667*m11+0.666667*m21+0.166667*m31-0.166667*m41;\n" " xy_out += stride; *xy_out=-0.0833333*m11-0.0416667*m21+0.0833333*m31+0.0416667*m41;\n" " xy_out += stride; *xy_out=+0.0833333*m11-0.0416667*m21-0.0833333*m31+0.0416667*m41;\n" " xy_out += stride; *xy_out=+4.0*m11-5.0*m31+m51;\n" " xy_out += stride; *xy_out=+m02-1.25*m22+0.25*m42;\n" " xy_out += stride; *xy_out=+0.666667*m12+0.666667*m22-0.166667*m32-0.166667*m42;\n" " xy_out += stride; *xy_out=-0.666667*m12+0.666667*m22+0.166667*m32-0.166667*m42;\n" " xy_out += stride; *xy_out=-0.0833333*m12-0.0416667*m22+0.0833333*m32+0.0416667*m42;\n" " xy_out += stride; 
*xy_out=+0.0833333*m12-0.0416667*m22-0.0833333*m32+0.0416667*m42;\n" " xy_out += stride; *xy_out=+4.0*m12-5.0*m32+m52;\n" " xy_out += stride; *xy_out=+m03-1.25*m23+0.25*m43;\n" " xy_out += stride; *xy_out=+0.666667*m13+0.666667*m23-0.166667*m33-0.166667*m43;\n" " xy_out += stride; *xy_out=-0.666667*m13+0.666667*m23+0.166667*m33-0.166667*m43;\n" " xy_out += stride; *xy_out=-0.0833333*m13-0.0416667*m23+0.0833333*m33+0.0416667*m43;\n" " xy_out += stride; *xy_out=+0.0833333*m13-0.0416667*m23-0.0833333*m33+0.0416667*m43;\n" " xy_out += stride; *xy_out=+4.0*m13-5.0*m33+m53;\n" " xy_out += stride; *xy_out=+m04-1.25*m24+0.25*m44;\n" " xy_out += stride; *xy_out=+0.666667*m14+0.666667*m24-0.166667*m34-0.166667*m44;\n" " xy_out += stride; *xy_out=-0.666667*m14+0.666667*m24+0.166667*m34-0.166667*m44;\n" " xy_out += stride; *xy_out=-0.0833333*m14-0.0416667*m24+0.0833333*m34+0.0416667*m44;\n" " xy_out += stride; *xy_out=+0.0833333*m14-0.0416667*m24-0.0833333*m34+0.0416667*m44;\n" " xy_out += stride; *xy_out=+4.0*m14-5.0*m34+m54;\n" " xy_out += stride; *xy_out=+m05-1.25*m25+0.25*m45;\n" " xy_out += stride; *xy_out=+0.666667*m15+0.666667*m25-0.166667*m35-0.166667*m45;\n" " xy_out += stride; *xy_out=-0.666667*m15+0.666667*m25+0.166667*m35-0.166667*m45;\n" " xy_out += stride; *xy_out=-0.0833333*m15-0.0416667*m25+0.0833333*m35+0.0416667*m45;\n" " xy_out += stride; *xy_out=+0.0833333*m15-0.0416667*m25-0.0833333*m35+0.0416667*m45;\n" " xy_out += stride; *xy_out=+4.0*m15-5.0*m35+m55;\n" " }\n" "}\n" "kernel void winograd_transform_source2_3_1(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant winograd_constants &cst [[buffer(2)]],\n" " constant int &batch_idx [[buffer(3)]],\n" " constant int &split_idx [[buffer(4)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " auto pos=int3(gid);\n" " if (pos.x<cst.unit_width && pos.y<cst.unit_height) {\n" " int ix=pos.x*cst.unit-cst.pad_x;\n" " int iy=(pos.y+split_idx*cst.unit_height)*cst.unit-cst.pad_y;\n" " 
auto z_in=in+(pos.z*cst.output_shape.w+batch_idx)*cst.input_shape.x*cst.input_shape.y;\n" " auto S00=get_input(z_in,ix+0,iy+0,cst);\n" " auto S10=get_input(z_in,ix+1,iy+0,cst);\n" " auto S20=get_input(z_in,ix+2,iy+0,cst);\n" " auto S30=get_input(z_in,ix+3,iy+0,cst);\n" " auto S01=get_input(z_in,ix+0,iy+1,cst);\n" " auto S11=get_input(z_in,ix+1,iy+1,cst);\n" " auto S21=get_input(z_in,ix+2,iy+1,cst);\n" " auto S31=get_input(z_in,ix+3,iy+1,cst);\n" " auto S02=get_input(z_in,ix+0,iy+2,cst);\n" " auto S12=get_input(z_in,ix+1,iy+2,cst);\n" " auto S22=get_input(z_in,ix+2,iy+2,cst);\n" " auto S32=get_input(z_in,ix+3,iy+2,cst);\n" " auto S03=get_input(z_in,ix+0,iy+3,cst);\n" " auto S13=get_input(z_in,ix+1,iy+3,cst);\n" " auto S23=get_input(z_in,ix+2,iy+3,cst);\n" " auto S33=get_input(z_in,ix+3,iy+3,cst);\n" " auto m00=+S00-S02;\n" " auto m10=+S10-S12;\n" " auto m20=+S20-S22;\n" " auto m30=+S30-S32;\n" " auto m01=+0.5*S01+0.5*S02;\n" " auto m11=+0.5*S11+0.5*S12;\n" " auto m21=+0.5*S21+0.5*S22;\n" " auto m31=+0.5*S31+0.5*S32;\n" " auto m02=-0.5*S01+0.5*S02;\n" " auto m12=-0.5*S11+0.5*S12;\n" " auto m22=-0.5*S21+0.5*S22;\n" " auto m32=-0.5*S31+0.5*S32;\n" " auto m03=-S01+S03;\n" " auto m13=-S11+S13;\n" " auto m23=-S21+S23;\n" " auto m33=-S31+S33;\n" " int dst_x_origin=pos.z;\n" " int dst_y_origin=cst.unit_width*pos.y+pos.x;\n" " int dst_y_stride=cst.input_shape.z*4;\n" " int dst_y=dst_y_origin/4;\n" " int dst_x=dst_y_origin % 4+4*dst_x_origin;\n" " int src_height=UP_DIV(cst.unit_width*cst.unit_height,4);\n" " int stride=src_height*dst_y_stride;\n" " // [mSrcUnit*mSrcUnit,UP_DIV(uw*wh,4),UP_DIV(ci,4),(uw*wh)_4,ci_4]\n" " auto xy_out=out+dst_y*dst_y_stride+dst_x;\n" " *xy_out=+m00-m20;\n" " xy_out += stride; *xy_out=+0.5*m10+0.5*m20;\n" " xy_out += stride; *xy_out=-0.5*m10+0.5*m20;\n" " xy_out += stride; *xy_out=-m10+m30;\n" " xy_out += stride; *xy_out=+m01-m21;\n" " xy_out += stride; *xy_out=+0.5*m11+0.5*m21;\n" " xy_out += stride; *xy_out=-0.5*m11+0.5*m21;\n" " xy_out += 
stride; *xy_out=-m11+m31;\n" " xy_out += stride; *xy_out=+m02-m22;\n" " xy_out += stride; *xy_out= +0.5*m12+0.5*m22;\n" " xy_out += stride; *xy_out=-0.5*m12+0.5*m22;\n" " xy_out += stride; *xy_out=-m12+m32;\n" " xy_out += stride; *xy_out=+m03-m23;\n" " xy_out += stride; *xy_out=+0.5*m13+0.5*m23;\n" " xy_out += stride; *xy_out=-0.5*m13+0.5*m23;\n" " xy_out += stride; *xy_out=-m13+m33;\n" " }\n" "}\n" "static inline void set_output(constant winograd_constants &cst,device M4 *output,int x,int y,M4 V) {\n" " output[y*cst.output_shape.x+x]=activate(V,cst.activation);\n" "}\n" "kernel void winograd_transform_dest2_5_1(const device M4 *in [[buffer(0)]],\n" " const device M4 *biasTerms [[buffer(1)]],\n" " device M4 *out [[buffer(2)]],\n" " constant winograd_constants &cst [[buffer(3)]],\n" " constant int &batch_idx [[buffer(4)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " auto pos=int3(gid);\n" " if (pos.x<cst.unit_width && pos.y<cst.unit_height) {\n" " int dst_w=UP_DIV(cst.unit_width*cst.unit_height,4);\n" " int dst_x_origin=cst.unit_width*pos.y+pos.x;\n" " int dst_x=dst_x_origin/4;\n" " int dst_y=4*pos.z+dst_x_origin % 4;\n" " int dst_y_stride=dst_w*36;\n" " auto xy_in=in+dst_y*dst_y_stride+dst_x;\n" " auto S00=*xy_in; xy_in += dst_w;\n" " auto S10=*xy_in; xy_in += dst_w;\n" " auto S20=*xy_in; xy_in += dst_w;\n" " auto S30=*xy_in; xy_in += dst_w;\n" " auto S40=*xy_in; xy_in += dst_w;\n" " auto S50=*xy_in; xy_in += dst_w;\n" " auto S01=*xy_in; xy_in += dst_w;\n" " auto S11=*xy_in; xy_in += dst_w;\n" " auto S21=*xy_in; xy_in += dst_w;\n" " auto S31=*xy_in; xy_in += dst_w;\n" " auto S41=*xy_in; xy_in += dst_w;\n" " auto S51=*xy_in; xy_in += dst_w;\n" " auto S02=*xy_in; xy_in += dst_w;\n" " auto S12=*xy_in; xy_in += dst_w;\n" " auto S22=*xy_in; xy_in += dst_w;\n" " auto S32=*xy_in; xy_in += dst_w;\n" " auto S42=*xy_in; xy_in += dst_w;\n" " auto S52=*xy_in; xy_in += dst_w;\n" " auto S03=*xy_in; xy_in += dst_w;\n" " auto S13=*xy_in; xy_in += dst_w;\n" " auto 
S23=*xy_in; xy_in += dst_w;\n" " auto S33=*xy_in; xy_in += dst_w;\n" " auto S43=*xy_in; xy_in += dst_w;\n" " auto S53=*xy_in; xy_in += dst_w;\n" " auto S04=*xy_in; xy_in += dst_w;\n" " auto S14=*xy_in; xy_in += dst_w;\n" " auto S24=*xy_in; xy_in += dst_w;\n" " auto S34=*xy_in; xy_in += dst_w;\n" " auto S44=*xy_in; xy_in += dst_w;\n" " auto S54=*xy_in; xy_in += dst_w;\n" " auto S05=*xy_in; xy_in += dst_w;\n" " auto S15=*xy_in; xy_in += dst_w;\n" " auto S25=*xy_in; xy_in += dst_w;\n" " auto S35=*xy_in; xy_in += dst_w;\n" " auto S45=*xy_in; xy_in += dst_w;\n" " auto S55=*xy_in;\n" " auto m00=+S00+S01+S02+S03+S04;\n" " auto m10=+S10+S11+S12+S13+S14;\n" " auto m20=+S20+S21+S22+S23+S24;\n" " auto m30=+S30+S31+S32+S33+S34;\n" " auto m40=+S40+S41+S42+S43+S44;\n" " auto m50=+S50+S51+S52+S53+S54;\n" " auto m01=+S01-S02+2.0*S03-2.0*S04+S05;\n" " auto m11=+S11-S12+2.0*S13-2.0*S14+S15;\n" " auto m21=+S21-S22+2.0*S23-2.0*S24+S25;\n" " auto m31=+S31-S32+2.0*S33-2.0*S34+S35;\n" " auto m41=+S41-S42+2.0*S43-2.0*S44+S45;\n" " auto m51=+S51-S52+2.0*S53-2.0*S54+S55;\n" " // write output\n" " auto b4=biasTerms[int(pos.z)];\n" " int oy=pos.y*cst.unit;\n" " int ox=pos.x*cst.unit;\n" " auto z_out=out+(pos.z*cst.output_shape.w+batch_idx)*cst.output_shape.x*cst.output_shape.y;\n" " \n" " /* if true */ {\n" " set_output(cst,z_out,ox+0,oy+0,b4+m00+m10+m20+m30+m40);\n" " }\n" " if (ox+1<cst.output_shape.x) {\n" " set_output(cst,z_out,ox+1,oy+0,b4+m10-m20+2.0*m30-2.0*m40+m50);\n" " }\n" " if (oy+1<cst.output_shape.y) {\n" " set_output(cst,z_out,ox+0,oy+1,b4+m01+m11+m21+m31+m41);\n" " }\n" " if (ox+1<cst.output_shape.x && oy+1<cst.output_shape.y) {\n" " set_output(cst,z_out,ox+1,oy+1,b4+m11-m21+2.0*m31-2.0*m41+m51);\n" " }\n" " }\n" "}\n" "kernel void winograd_transform_dest2_3_1(const device M4 *in [[buffer(0)]],\n" " const device M4 *biasTerms [[buffer(1)]],\n" " device M4 *out [[buffer(2)]],\n" " constant winograd_constants &cst [[buffer(3)]],\n" " constant int &batch_idx [[buffer(4)]],\n" " 
constant int &split_idx [[buffer(5)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " auto pos=int3(gid);\n" " if (pos.x<cst.unit_width && pos.y<cst.unit_height) {\n" " int dst_w=UP_DIV(cst.unit_width*cst.unit_height,4);\n" " int dst_x_origin=cst.unit_width*pos.y+pos.x;\n" " int dst_x=dst_x_origin/4;\n" " int dst_y=4*pos.z+dst_x_origin % 4;\n" " int dst_y_stride=dst_w*16;\n" " auto xy_in=in+dst_y*dst_y_stride+dst_x;\n" " auto S00=*xy_in; xy_in += dst_w;\n" " auto S10=*xy_in; xy_in += dst_w;\n" " auto S20=*xy_in; xy_in += dst_w;\n" " auto S30=*xy_in; xy_in += dst_w;\n" " auto S01=*xy_in; xy_in += dst_w;\n" " auto S11=*xy_in; xy_in += dst_w;\n" " auto S21=*xy_in; xy_in += dst_w;\n" " auto S31=*xy_in; xy_in += dst_w;\n" " auto S02=*xy_in; xy_in += dst_w;\n" " auto S12=*xy_in; xy_in += dst_w;\n" " auto S22=*xy_in; xy_in += dst_w;\n" " auto S32=*xy_in; xy_in += dst_w;\n" " auto S03=*xy_in; xy_in += dst_w;\n" " auto S13=*xy_in; xy_in += dst_w;\n" " auto S23=*xy_in; xy_in += dst_w;\n" " auto S33=*xy_in;\n" " auto m00=+S00+S01+S02;\n" " auto m10=+S10+S11+S12;\n" " auto m20=+S20+S21+S22;\n" " auto m30=+S30+S31+S32;\n" " auto m01=+S01-S02+S03;\n" " auto m11=+S11-S12+S13;\n" " auto m21=+S21-S22+S23;\n" " auto m31=+S31-S32+S33;\n" " // write output\n" " auto b4=biasTerms[int(pos.z)];\n" " int oy=(pos.y+split_idx*cst.unit_height)*cst.unit;\n" " int ox=pos.x*cst.unit;\n" " auto z_out=out+(pos.z*cst.output_shape.w+batch_idx)*cst.output_shape.x*cst.output_shape.y;\n" " \n" " /* if true */ {\n" " set_output(cst,z_out,ox+0,oy+0,b4+m00+m10+m20);\n" " }\n" " if (ox+1<cst.output_shape.x) {\n" " set_output(cst,z_out,ox+1,oy+0,b4+m10-m20+m30);\n" " }\n" " if (oy+1<cst.output_shape.y) {\n" " set_output(cst,z_out,ox+0,oy+1,b4+m01+m11+m21);\n" " }\n" " if (ox+1<cst.output_shape.x && oy+1<cst.output_shape.y) {\n" " set_output(cst,z_out,ox+1,oy+1,b4+m11-m21+m31);\n" " }\n" " }\n" "}\n" ; const char* shader_MetalMatMul_metal = "struct matmul_shape {\n" " int4 mat_size;\n" " int4 
in_stride;\n" "};\n"
// matmul: one thread per output element (gid.x, gid.y), guarded by
// mat_size.x / mat_size.y. Computes the dot product of mat_size.z terms:
// in0 starts at gid.y*in_stride.x and advances by in_stride.y per term,
// in1 starts at gid.x*in_stride.z and advances by in_stride.w per term.
// The sum is accumulated in FLOAT and stored as M at out[gid.y*mat_size.x+gid.x].
// M and FLOAT are not defined inside this string.
"kernel void matmul(const device M *in0 [[buffer(0)]],\n" " const device M *in1 [[buffer(1)]],\n" " device M *out [[buffer(2)]],\n" " constant matmul_shape &s [[buffer(3)]],\n" " uint2 gid[[thread_position_in_grid]]) {\n" " if ((int)gid.x<s.mat_size.x && (int)gid.y<s.mat_size.y) {\n" " auto off_in0=in0+int(gid.y)*s.in_stride.x;\n" " auto off_in1=in1+int(gid.x)*s.in_stride.z;\n" " FLOAT V=0.f;\n" " for (int i=0; i<s.mat_size.z; i++,off_in0 += s.in_stride.y,off_in1 += s.in_stride.w) {\n" " V += FLOAT(*off_in0)*FLOAT(*off_in1);\n" " }\n" " out[int(gid.y)*s.mat_size.x+int(gid.x)]=M(V);\n" " }\n" "}\n"
// matmul_bias: identical to matmul, plus biasValue[gid.x] added per output column.
// NOTE(review): the bias is added after the accumulator is narrowed to M, so the
// addition happens in M precision rather than FLOAT -- confirm this is intended.
"kernel void matmul_bias(const device M *in0 [[buffer(0)]],\n" " const device M *in1 [[buffer(1)]],\n" " const device M *biasValue [[buffer(2)]],\n" " device M *out [[buffer(3)]],\n" " constant matmul_shape &s [[buffer(4)]],\n" " uint2 gid[[thread_position_in_grid]]) {\n" " if ((int)gid.x<s.mat_size.x && (int)gid.y<s.mat_size.y) {\n" " auto off_in0=in0+int(gid.y)*s.in_stride.x;\n" " auto off_in1=in1+int(gid.x)*s.in_stride.z;\n" " FLOAT V=0.f;\n" " for (int i=0; i<s.mat_size.z; i++,off_in0 += s.in_stride.y,off_in1 += s.in_stride.w) {\n" " V += FLOAT(*off_in0)*FLOAT(*off_in1);\n" " }\n" " out[int(gid.y)*s.mat_size.x+int(gid.x)]=M(V)+biasValue[(int)(gid.x)];\n" " }\n" "}\n" ;
// Metal source for the per-channel scale kernel, stored as one C string.
const char* shader_MetalScale_metal = "struct scale_shape {\n" " int size;\n" " int steps;\n" " int batch;\n" " int offset;\n" "};\n"
// scale_ca: one thread per (gid.x, gid.y) with gid.x<size and gid.y<steps*batch.
// z=gid.y/batch selects the coefficients: scale=scalesbias[z] and
// bias=scalesbias[z+offset], both read from the same buffer.
// Computes in*scale+bias in float4 and stores the result as M4.
"kernel void scale_ca(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant scale_shape &s [[buffer(2)]],\n" " const device float4 *scalesbias[[buffer(3)]],\n" " uint2 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x >= s.size || (int)gid.y >= s.steps*s.batch) return;\n" " int z=gid.y/s.batch;\n" " int offset=s.offset;\n" " float4 scale=scalesbias[z];\n" " float4 bias=scalesbias[z+offset];\n" " out[int(gid.y)*s.size+int(gid.x)] =\n" " (M4)((float4)in[int(gid.y)*s.size+int(gid.x)]*scale+bias);\n" "}\n" ; const char* 
shader_MetalDeconvolution_metal =
// Metal source for the deconvolution kernels, stored as one C string built from
// adjacent literals. M4, M4x4, FLOAT4, UP_DIV, ROUND_UP, activate and
// conv_activation_type are used but not defined inside this string.
"struct deconv_constants {\n" " int input_width;\n" " int input_height;\n" " int input_size;\n" " int input_slice;\n" " int output_width;\n" " int output_height;\n" " int output_size;\n" " int output_slice;\n" " int kernel_x;\n" " int kernel_y;\n" " int kernel_size;\n" " int stride_x;\n" " int stride_y;\n" " int pad_x;\n" " int pad_y;\n" " int dilation_x;\n" " int dilation_y;\n" " int delta_ky;\n" " int delta_kx;\n" " int delta_iy;\n" " int delta_ix;\n" " int batch;\n" " conv_activation_type activation;\n" "};\n"
// deconv: one thread per output element; gid.z encodes o*batch+b
// (b=gid.z%batch, o=gid.z/batch). The accumulator starts at biasTerms[o].
// When the first candidate kernel tap is stride-aligned in both axes, it adds
// in4*wt4 for every input slice z and every (ky,kx)/(iy,ix) pair walked with the
// delta_* steps; otherwise only the bias is kept. The result goes through
// activate() before it is stored.
"kernel void deconv(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant deconv_constants& cst [[buffer(2)]],\n" " const device M4x4 *wt [[buffer(3)]],\n" " const device M4 *biasTerms [[buffer(4)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n" " \n" " int b=gid.z % cst.batch;\n" " int o=gid.z/cst.batch;\n" " FLOAT4 result=FLOAT4(biasTerms[o]);\n" " int oy=(int)gid.y+cst.pad_y;\n" " int ox=(int)gid.x+cst.pad_x;\n" " int max_sy=min((cst.input_height-1)*cst.stride_y,oy/cst.stride_y*cst.stride_y);\n" " int max_sx=min((cst.input_width-1)*cst.stride_x,ox/cst.stride_x*cst.stride_x);\n" " int min_ky=UP_DIV(oy-max_sy,cst.dilation_y);\n" " int min_kx=UP_DIV(ox-max_sx,cst.dilation_x);\n" " \n" " if ((oy-min_ky*cst.dilation_y) % cst.stride_y == 0 && (ox-min_kx*cst.dilation_x) % cst.stride_x == 0) {\n" " int min_sy=max(0,ROUND_UP(oy+cst.dilation_y-cst.kernel_y*cst.dilation_y,cst.stride_y));\n" " int min_sx=max(0,ROUND_UP(ox+cst.dilation_x-cst.kernel_x*cst.dilation_x,cst.stride_x));\n" " int max_ky=(oy-min_sy)/cst.dilation_y;\n" " int max_kx=(ox-min_sx)/cst.dilation_x;\n" " int min_iy=(oy-max_ky*cst.dilation_y)/cst.stride_y;\n" " int min_ix=(ox-max_kx*cst.dilation_x)/cst.stride_x;\n" " \n" " auto o_wt=wt+o*cst.input_slice*cst.kernel_size;\n" " auto b_in=in+b*cst.input_size;\n" 
" for (auto z=0; z<cst.input_slice; z++) {\n" " for (auto ky=max_ky,iy=min_iy; ky >= min_ky; ky -= cst.delta_ky,iy += cst.delta_iy) {\n" " for (auto kx=max_kx,ix=min_ix; kx >= min_kx; kx -= cst.delta_kx,ix += cst.delta_ix) {\n" " auto wt4=o_wt[z*cst.kernel_size+ky*cst.kernel_x+kx];\n" " auto in4=b_in[z*cst.input_size*cst.batch+iy*cst.input_width+ix];\n" " result += FLOAT4(in4*wt4);\n" " }\n" " }\n" " }\n" " }\n" " out[(int)gid.z*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x]=activate(M4(result),cst.activation);\n" "}\n"
// deconv_depthwise: same tap selection as deconv, but without the input-slice
// loop. Weights are M4 (multiplied element-wise with the input) taken from
// wt+oz*kernel_size with oz=gid.z/batch, and the input plane is
// in+gid.z*input_size.
"kernel void deconv_depthwise(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant deconv_constants& cst [[buffer(2)]],\n" " const device M4 *wt [[buffer(3)]],\n" " const device M4 *biasTerms [[buffer(4)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n" " int oz=(int)gid.z/cst.batch;\n" " FLOAT4 result=FLOAT4(biasTerms[oz]);\n" " \n" " int oy=(int)gid.y+cst.pad_y;\n" " int ox=(int)gid.x+cst.pad_x;\n" " int max_sy=min((cst.input_height-1)*cst.stride_y,oy/cst.stride_y*cst.stride_y);\n" " int max_sx=min((cst.input_width-1)*cst.stride_x,ox/cst.stride_x*cst.stride_x);\n" " int min_ky=UP_DIV(oy-max_sy,cst.dilation_y);\n" " int min_kx=UP_DIV(ox-max_sx,cst.dilation_x);\n" " \n" " if ((oy-min_ky*cst.dilation_y) % cst.stride_y == 0 && (ox-min_kx*cst.dilation_x) % cst.stride_x == 0) {\n" " int min_sy=max(0,ROUND_UP(oy+cst.dilation_y-cst.kernel_y*cst.dilation_y,cst.stride_y));\n" " int min_sx=max(0,ROUND_UP(ox+cst.dilation_x-cst.kernel_x*cst.dilation_x,cst.stride_x));\n" " int max_ky=(oy-min_sy)/cst.dilation_y;\n" " int max_kx=(ox-min_sx)/cst.dilation_x;\n" " int min_iy=(oy-max_ky*cst.dilation_y)/cst.stride_y;\n" " int min_ix=(ox-max_kx*cst.dilation_x)/cst.stride_x;\n" " \n" " auto z_wt=wt+oz*cst.kernel_size;\n" " auto z_in=in+(int)gid.z*cst.input_size;\n" " for (auto ky=max_ky,iy=min_iy; ky >= 
min_ky; ky -= cst.delta_ky,iy += cst.delta_iy) {\n" " for (auto kx=max_kx,ix=min_ix; kx >= min_kx; kx -= cst.delta_kx,ix += cst.delta_ix) {\n" " auto wt4=z_wt[ky*cst.kernel_x+kx];\n" " auto in4=z_in[iy*cst.input_width+ix];\n" " result += FLOAT4(in4*wt4);\n" " }\n" " }\n" " }\n" " out[(int)gid.z*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x]=activate(M4(result),cst.activation);\n" "}\n" ;
// Metal source for the 2D pooling kernels, stored as one C string.
const char* shader_MetalPooling_metal = "struct pooling_sizes {\n" " int input_width;\n" " int input_height;\n" " int output_width;\n" " int output_height;\n" " int slice;\n" " int kernel_width;\n" " int kernel_height;\n" " int stride_width;\n" " int stride_height;\n" " int pad_width;\n" " int pad_height;\n" "};\n"
// pooling_max: one thread per output element (x, y, plane gid.z). The window
// starts at gid*stride-pad; every tap coordinate is clamped into
// [0,input-1], so taps that fall in the padding reuse the nearest border value.
"kernel void pooling_max(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant pooling_sizes& s [[buffer(2)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if (any(gid >= uint3(s.output_width,s.output_height,s.slice))) return;\n" " \n" " int off_x=gid.x*s.stride_width-s.pad_width;\n" " int off_y=gid.y*s.stride_height-s.pad_height;\n" " int x_max=s.input_width-1;\n" " int y_max=s.input_height-1;\n" " int ex=off_x+s.kernel_width;\n" " int ey=off_y+s.kernel_height;\n" " \n" " auto z_in=in+(int)gid.z*s.input_width*s.input_height;\n" " auto result=M4(z_in[clamp(off_y,0,y_max)*s.input_width+clamp(off_x,0,x_max)]);\n" " for (int y=off_y; y<ey; y++) {\n" " auto y_in=z_in+clamp(y,0,y_max)*s.input_width;\n" " for (int x=off_x; x<ex; x++) {\n" " result=max(result,y_in[clamp(x,0,x_max)]);\n" " }\n" " }\n" " out[(int)gid.z*s.output_width*s.output_height+(int)gid.y*s.output_width+(int)gid.x]=result;\n" "}\n"
// pooling_avg: the window is clipped to the input ([sx,ex) x [sy,ey)), the
// in-bounds taps are summed in FLOAT4 and divided by their count, so padding is
// excluded from the average. When count is not positive the sum is multiplied
// by 1 instead.
"kernel void pooling_avg(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant pooling_sizes& s [[buffer(2)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if (any(gid >= uint3(s.output_width,s.output_height,s.slice))) return;\n" " \n" " int off_x=gid.x*s.stride_width-s.pad_width;\n" " int 
off_y=gid.y*s.stride_height-s.pad_height;\n" " int sx=off_x+max(0,-off_x);\n" " int sy=off_y+max(0,-off_y);\n" " int ex=off_x+min(s.kernel_width,s.input_width-off_x);\n" " int ey=off_y+min(s.kernel_height,s.input_height-off_y);\n" " \n" " FLOAT4 result=0;\n" " auto z_in=in+(int)gid.z*s.input_width*s.input_height;\n" " for (int y=sy; y<ey; y++) {\n" " for (int x=sx; x<ex; x++) {\n" " result += FLOAT4(z_in[y*s.input_width+x]);\n" " }\n" " }\n" " int count=(ey-sy)*(ex-sx);\n" " FLOAT4 div=count>0 ? 1.f/count : 1;\n" " out[(int)gid.z*s.output_width*s.output_height+(int)gid.y*s.output_width+(int)gid.x]=M4(result*div);\n" "}\n" ;
// Metal source for the ROI max-pooling kernel, stored as one C string.
const char* shader_MetalROIPooling_metal = "struct ROI_shape {\n" " int input_width;\n" " int input_height;\n" " int input_size;\n" " int input_batch;\n" " int output_width;\n" " int output_height;\n" " int output_size;\n" " int batch;\n" " float spatial_scale;\n" "};\n"
// ROI_pooling: gid.z encodes iz*batch+ob. Each ROI is five consecutive M values
// at roi+ob*5: [0] is used as the input batch index, [1..4] as x1,y1,x2,y2,
// which are multiplied by spatial_scale and rounded. The ROI is divided into
// output_width x output_height bins; each thread takes the max over its bin,
// clamped to the input, or writes 0 when the bin is empty.
// NOTE(review): only gid.x and gid.y are bounds-checked here; gid.z is used
// unchecked, so the dispatch grid depth must match the buffers exactly -- confirm.
"kernel void ROI_pooling(const device M4 *in [[buffer(0)]],\n" " const device M *roi [[buffer(1)]],\n" " device M4 *out [[buffer(2)]],\n" " constant ROI_shape &s [[buffer(3)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x >= s.output_width || (int)gid.y >= s.output_height) return;\n" " \n" " int ob=gid.z % s.batch;\n" " int iz=gid.z/s.batch;\n" " \n" " auto b_roi=roi+ob*5;\n" " int ib=int(b_roi[0]);\n" " int x1=round(float(b_roi[1])*s.spatial_scale);\n" " int y1=round(float(b_roi[2])*s.spatial_scale);\n" " int x2=round(float(b_roi[3])*s.spatial_scale);\n" " int y2=round(float(b_roi[4])*s.spatial_scale);\n" " \n" " int roi_w=max(x2-x1+1,1);\n" " int roi_h=max(y2-y1+1,1);\n" " float bin_size_w=(float)roi_w/(float)s.output_width;\n" " float bin_size_h=(float)roi_h/(float)s.output_height;\n" " \n" " int w_start=clamp(x1+(int)floor(gid.x*bin_size_w) ,0,s.input_width);\n" " int w_end=clamp(x1+(int)ceil((gid.x+1)*bin_size_w),0,s.input_width);\n" " int h_start=clamp(y1+(int)floor(gid.y*bin_size_h) ,0,s.input_height);\n" " int 
h_end=clamp(y1+(int)ceil((gid.y+1)*bin_size_h),0,s.input_height);\n" " \n" " int is_empty=(h_end <= h_start) || (w_end <= w_start);\n" " auto z_in=in+(ib+iz*s.input_batch)*s.input_size;\n" " auto max4=is_empty ? 0 : z_in[h_start*s.input_width+w_start];\n" " for (int y=h_start; y<h_end; y++) {\n" " auto y_in=z_in+y*s.input_width;\n" " for (int x=w_start; x<w_end; x++) {\n" " max4=max(max4,y_in[x]);\n" " }\n" " }\n" " out[int(gid.z)*s.output_size+int(gid.y)*s.output_width+int(gid.x)]=max4;\n" "}\n" ; const char* shader_MetalConvolution1x1_metal = "#define CONV_UNROLL (4)\n" "#define CONV_UNROLL_L (8)\n" "struct conv1x1_constants {\n" " int input_size;\n" " int input_slice;\n" " int output_width;\n" " int output_height;\n" " int output_size;\n" " int output_slice;\n" " int output_channel;\n" " int batch;\n" " int block_size;\n" " conv_activation_type activation;\n" " float scale_coef;\n" "};\n" "kernel void conv1x1_g1z4(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant conv1x1_constants& cst [[buffer(2)]],\n" " const device M4x4 *wt [[buffer(3)]],\n" " const device M4 *biasTerms [[buffer(4)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x*CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n" " \n" " int rx=gid.x*CONV_UNROLL;\n" " int uz=gid.y;\n" " auto xy_wt=wt+uz*cst.input_slice;\n" " auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n" " auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.output_size*cst.batch+rx;\n" " auto biasValue=FLOAT4(biasTerms[uz]);\n" " FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n" " int computeSize=min(cst.output_size-rx,CONV_UNROLL);\n" " for (auto z=0; z<cst.input_slice; z++) {\n" " auto in40=*xy_in0;\n" " auto in41=*(xy_in0+1);\n" " auto in42=*(xy_in0+2);\n" " auto in43=*(xy_in0+3);\n" " auto w=xy_wt[z];\n" " \n" " result0 += FLOAT4(in40*w);\n" " result1 += FLOAT4(in41*w);\n" " result2 += 
FLOAT4(in42*w);\n" " result3 += FLOAT4(in43*w);\n" " xy_in0 += cst.input_size*cst.batch;\n" " }\n" " \n" " /* true */ *xy_out=activate(M4(result0),cst.activation);\n" " if (computeSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n" " if (computeSize>2) {xy_out[2]=activate(M4(result2),cst.activation); }\n" " if (computeSize>3) {xy_out[3]=activate(M4(result3),cst.activation); }\n" "}\n" "kernel void conv1x1_g1z4_w8(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant conv1x1_constants& cst [[buffer(2)]],\n" " const device MNN::char4x4 *wt [[buffer(3)]],\n" " const device M4 *biasTerms [[buffer(4)]],\n" " const device M4 *dequantScale [[buffer(5)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x*CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n" " int rx=gid.x*CONV_UNROLL;\n" " int uz=gid.y;\n" " auto xy_wt=wt+uz*cst.input_slice;\n" " auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n" " auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.output_size*cst.batch+rx;\n" " auto biasValue=FLOAT4(biasTerms[uz]);\n" " FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n" " int computeSize=min(cst.output_size-rx,CONV_UNROLL);\n" " int block=(cst.input_slice+cst.block_size-1)/cst.block_size;\n" " for (int bi=0; bi<cst.block_size; ++bi) {\n" " FLOAT4 bs0=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+0])/(FLOAT)cst.scale_coef;\n" " FLOAT4 bs1=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+1])/(FLOAT)cst.scale_coef;\n" " FLOAT4 scale=bs0;\n" " FLOAT4 dequant_bias=bs1;\n" " int zmin=bi*block;\n" " int zmax=min(zmin+block,cst.input_slice);\n" " for (int z=zmin; z<zmax; z++) {\n" " auto in40=(FLOAT4)*xy_in0;\n" " auto in41=computeSize>1 ? (FLOAT4)*(xy_in0+1) : (FLOAT4)0.0;\n" " auto in42=computeSize>2 ? (FLOAT4)*(xy_in0+2) : (FLOAT4)0.0;\n" " auto in43=computeSize>3 ? 
(FLOAT4)*(xy_in0+3) : (FLOAT4)0.0;\n" " auto w=xy_wt[z];\n" " FLOAT4x4 w_fp32=FLOAT4x4(FLOAT4(w[0]),FLOAT4(w[1]),FLOAT4(w[2]),FLOAT4(w[3]));\n" " FLOAT4x4 w_dequant;\n" " for (int i=0; i<4; ++i) {\n" " w_dequant[i]=w_fp32[i]*scale[i]+dequant_bias[i];\n" " }\n" " result0 += FLOAT4(in40*w_dequant);\n" " result1 += FLOAT4(in41*w_dequant);\n" " result2 += FLOAT4(in42*w_dequant);\n" " result3 += FLOAT4(in43*w_dequant);\n" " xy_in0 += cst.input_size*cst.batch;\n" " }\n" " }\n" " /* true */ \n" " xy_out[0]=activate(M4(result0),cst.activation);\n" " if (computeSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n" " if (computeSize>2) {xy_out[2]=activate(M4(result2),cst.activation); }\n" " if (computeSize>3) {xy_out[3]=activate(M4(result3),cst.activation); }\n" "}\n" "kernel void conv1x1_g1z4_w4(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant conv1x1_constants& cst [[buffer(2)]],\n" " const device MNN::uchar4x2 *wt [[buffer(3)]],\n" " const device M4 *biasTerms [[buffer(4)]],\n" " const device M4 *dequantScale [[buffer(5)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x*CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n" " int rx=gid.x*CONV_UNROLL;\n" " int uz=gid.y;\n" " auto xy_wt=wt+uz*cst.input_slice;\n" " auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n" " auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.output_size*cst.batch+rx;\n" " auto biasValue=FLOAT4(biasTerms[uz]);\n" " FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n" " int computeSize=min(cst.output_size-rx,CONV_UNROLL);\n" " int block=(cst.input_slice+cst.block_size-1)/cst.block_size;\n" " for (int bi=0; bi<cst.block_size; ++bi) {\n" " FLOAT4 scale=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+0])/(FLOAT)cst.scale_coef;\n" " FLOAT4 dequant_bias=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+1])/(FLOAT)cst.scale_coef;\n" " int zmin=bi*block;\n" " int 
zmax=min(zmin+block,cst.input_slice);\n" " for (int z=zmin; z<zmax; z++) {\n" " auto in40=(FLOAT4)*xy_in0;\n" " auto in41=(FLOAT4)*(xy_in0+1);\n" " auto in42=(FLOAT4)*(xy_in0+2);\n" " auto in43=(FLOAT4)*(xy_in0+3);\n" " MNN::uchar4x2 w_int4=xy_wt[z];\n" " // MNN::char4x4 w_int8(char4(0));\n" " /* weight int4->float */\n" " //FLOAT4x4 w_fp32=FLOAT4x4(FLOAT4(w[0]),FLOAT4(w[1]),FLOAT4(w[2]),FLOAT4(w[3]));\n" " FLOAT4x4 w_dequant;\n" " for (int i=0; i<4; ++i) {\n" " // M4 w4=M4(w_fp32[i]);\n" " FLOAT4 w4=FLOAT4((float)(w_int4[i][0] >> 4)-8,(float)(w_int4[i][0] & 15)-8,(float)(w_int4[i][1] >> 4)-8,(float)(w_int4[i][1] & 15)-8);\n" " FLOAT4 res=w4*scale[i]+dequant_bias[i];\n" " w_dequant[i]=res;\n" " }\n" " result0 += FLOAT4(in40*w_dequant);\n" " result1 += FLOAT4(in41*w_dequant);\n" " result2 += FLOAT4(in42*w_dequant);\n" " result3 += FLOAT4(in43*w_dequant);\n" " xy_in0 += cst.input_size*cst.batch;\n" " }\n" " }\n" " \n" " /* true */ \n" " xy_out[0]=activate(M4(result0),cst.activation);\n" " if (computeSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n" " if (computeSize>2) {xy_out[2]=activate(M4(result2),cst.activation); }\n" " if (computeSize>3) {xy_out[3]=activate(M4(result3),cst.activation); }\n" "}\n" "kernel void conv1x1_g1z8(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant conv1x1_constants& cst [[buffer(2)]],\n" " const device M4x4 *wt [[buffer(3)]],\n" " const device M4 *biasTerms [[buffer(4)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x*CONV_UNROLL_L >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n" " int rx=gid.x*CONV_UNROLL_L;\n" " int uz=gid.y;\n" " auto xy_wt=wt+uz*cst.input_slice;\n" " auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n" " auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.batch*cst.output_size+rx;\n" " auto biasValue=FLOAT4(biasTerms[uz]);\n" " FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n" " 
FLOAT4 result4=biasValue,result5=biasValue,result6=biasValue,result7=biasValue;\n" " int computeSize=min(cst.output_size-rx,CONV_UNROLL_L);\n" " for (auto z=0; z<cst.input_slice; z++) {\n" " auto in40=xy_in0[0];\n" " auto in41=xy_in0[1];\n" " auto in42=xy_in0[2];\n" " auto in43=xy_in0[3];\n" " auto in44=xy_in0[4];\n" " auto in45=xy_in0[5];\n" " auto in46=xy_in0[6];\n" " auto in47=xy_in0[7];\n" " auto w=xy_wt[z];\n" " result0 += FLOAT4(in40*w);\n" " result1 += FLOAT4(in41*w);\n" " result2 += FLOAT4(in42*w);\n" " result3 += FLOAT4(in43*w);\n" " result4 += FLOAT4(in44*w);\n" " result5 += FLOAT4(in45*w);\n" " result6 += FLOAT4(in46*w);\n" " result7 += FLOAT4(in47*w);\n" " xy_in0 += cst.input_size*cst.batch;\n" " }\n" " /* true */ *xy_out=activate(M4(result0),cst.activation);\n" " if (computeSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n" " if (computeSize>2) {xy_out[2]=activate(M4(result2),cst.activation); }\n" " if (computeSize>3) {xy_out[3]=activate(M4(result3),cst.activation); }\n" " if (computeSize>4) {xy_out[4]=activate(M4(result4),cst.activation); }\n" " if (computeSize>5) {xy_out[5]=activate(M4(result5),cst.activation); }\n" " if (computeSize>6) {xy_out[6]=activate(M4(result6),cst.activation); }\n" " if (computeSize>7) {xy_out[7]=activate(M4(result7),cst.activation); }\n" "}\n" "kernel void conv1x1_w4h4(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant conv1x1_constants& cst [[buffer(2)]],\n" " const device M4x4 *wt [[buffer(3)]],\n" " const device M4 *biasTerms [[buffer(4)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x*16 >= cst.output_width || (int)gid.y >= cst.batch*cst.output_slice) return;\n" " int idx_w=gid.x << 4;\n" " int idx_h=0;\n" " int idx_c=gid.y/cst.batch;\n" " int idx_b=gid.y % cst.batch;\n" " auto xy_wt=wt+idx_c*cst.input_slice;\n" " auto xy_in0=in+(int)idx_b*cst.input_size+idx_h*cst.output_width+idx_w;\n" " auto 
xy_out=out+(int)idx_b*cst.output_size+idx_c*cst.output_size*cst.batch+idx_h*cst.output_width+idx_w;\n" " auto biasValue=FLOAT4(biasTerms[idx_c]);\n" " FLOAT4 result00=biasValue,result01=biasValue,result02=biasValue,result03=biasValue;\n" " FLOAT4 result10=biasValue,result11=biasValue,result12=biasValue,result13=biasValue;\n" " FLOAT4 result20=biasValue,result21=biasValue,result22=biasValue,result23=biasValue;\n" " FLOAT4 result30=biasValue,result31=biasValue,result32=biasValue,result33=biasValue;\n" " for (auto z=0; z<cst.input_slice; z++) {\n" " auto in00=xy_in0[0];\n" " auto in01=xy_in0[1];\n" " auto in02=xy_in0[2];\n" " auto in03=xy_in0[3];\n" " auto in10=xy_in0[4];\n" " auto in11=xy_in0[5];\n" " auto in12=xy_in0[6];\n" " auto in13=xy_in0[7];\n" " \n" " auto in20=xy_in0[8];\n" " auto in21=xy_in0[9];\n" " auto in22=xy_in0[10];\n" " auto in23=xy_in0[11];\n" " auto in30=xy_in0[12];\n" " auto in31=xy_in0[13];\n" " auto in32=xy_in0[14];\n" " auto in33=xy_in0[15];\n" " auto w=xy_wt[z];\n" " result00 += FLOAT4(in00*w);\n" " result01 += FLOAT4(in01*w);\n" " result02 += FLOAT4(in02*w);\n" " result03 += FLOAT4(in03*w);\n" " result10 += FLOAT4(in10*w);\n" " result11 += FLOAT4(in11*w);\n" " result12 += FLOAT4(in12*w);\n" " result13 += FLOAT4(in13*w);\n" " \n" " result20 += FLOAT4(in20*w);\n" " result21 += FLOAT4(in21*w);\n" " result22 += FLOAT4(in22*w);\n" " result23 += FLOAT4(in23*w);\n" " result30 += FLOAT4(in30*w);\n" " result31 += FLOAT4(in31*w);\n" " result32 += FLOAT4(in32*w);\n" " result33 += FLOAT4(in33*w);\n" " \n" " xy_in0 += cst.input_size*cst.batch;\n" " }\n" " int widthSize=min(cst.output_width-idx_w,16);\n" " /* true */ *xy_out=activate(M4(result00),cst.activation);\n" " if (widthSize>1) {xy_out[1]=activate(M4(result01),cst.activation); }\n" " if (widthSize>2) {xy_out[2]=activate(M4(result02),cst.activation); }\n" " if (widthSize>3) {xy_out[3]=activate(M4(result03),cst.activation); }\n" " if (widthSize>4) {xy_out[4]=activate(M4(result10),cst.activation); }\n" 
" if (widthSize>5) {xy_out[5]=activate(M4(result11),cst.activation); }\n" " if (widthSize>6) {xy_out[6]=activate(M4(result12),cst.activation); }\n" " if (widthSize>7) {xy_out[7]=activate(M4(result13),cst.activation); }\n" " if (widthSize>8) {xy_out[8]=activate(M4(result20),cst.activation); }\n" " if (widthSize>9) {xy_out[9]=activate(M4(result21),cst.activation); }\n" " if (widthSize>10) {xy_out[10]=activate(M4(result22),cst.activation); }\n" " if (widthSize>11) {xy_out[11]=activate(M4(result23),cst.activation); }\n" " if (widthSize>12) {xy_out[12]=activate(M4(result30),cst.activation); }\n" " if (widthSize>13) {xy_out[13]=activate(M4(result31),cst.activation); }\n" " if (widthSize>14) {xy_out[14]=activate(M4(result32),cst.activation); }\n" " if (widthSize>15) {xy_out[15]=activate(M4(result33),cst.activation); }\n" "}\n" "kernel void conv1x1_w2c2(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant conv1x1_constants& cst [[buffer(2)]],\n" " const device M4x4 *wt [[buffer(3)]],\n" " const device M4 *biasTerms [[buffer(4)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x*2 >= cst.output_width || (int)gid.y*2 >= cst.batch*cst.output_slice) return;\n" " int channel_pack=(cst.output_channel+7) >> 3;\n" " int idx_w=gid.x << 1;\n" " int idx_h=0;\n" " int idx_c=(gid.y % channel_pack) << 1;\n" " int idx_b=gid.y/channel_pack;\n" " \n" " if(idx_b >= cst.batch || idx_c >= cst.output_slice) return;\n" " auto xy_wt=wt+idx_c*cst.input_slice;\n" " auto xy_in0=in+(int)idx_b*cst.input_size+idx_h*cst.output_width+idx_w;\n" " auto xy_out=out+(int)idx_b*cst.output_size+idx_c*cst.output_size*cst.batch+idx_h*cst.output_width+idx_w;\n" " auto biasValue0=FLOAT4(biasTerms[idx_c]);\n" " auto biasValue1=FLOAT4(biasTerms[idx_c+1]);\n" " FLOAT4 result0=biasValue0,result1=biasValue0;\n" " FLOAT4 result4=biasValue1,result5=biasValue1;\n" " for (auto z=0; z<cst.input_slice; z++) {\n" " auto in40=xy_in0[0];\n" " auto in41=xy_in0[1];\n" " auto 
w0=xy_wt[z];\n" " auto w1=xy_wt[cst.input_slice+z];\n" " result0 += FLOAT4(in40*w0);\n" " result1 += FLOAT4(in41*w0);\n" " result4 += FLOAT4(in40*w1);\n" " result5 += FLOAT4(in41*w1);\n" " xy_in0 += cst.input_size*cst.batch;\n" " }\n" " int widthSize=min(cst.output_width-idx_w,2);\n" " /* true */ *xy_out=activate(M4(result0),cst.activation);\n" " if (widthSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n" " \n" " int channelSize=min(cst.output_slice-idx_c,2);\n" " if(channelSize>1) {\n" " /* true */ {xy_out[cst.output_size*cst.batch +0]=activate(M4(result4),cst.activation); }\n" " if (widthSize>1) {xy_out[cst.output_size*cst.batch +1]=activate(M4(result5),cst.activation); }\n" " }\n" "}\n" "kernel void conv1x1_w4c2(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant conv1x1_constants& cst [[buffer(2)]],\n" " const device M4x4 *wt [[buffer(3)]],\n" " const device M4 *biasTerms [[buffer(4)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x*4 >= cst.output_width || (int)gid.y*2 >= cst.batch*cst.output_slice) return;\n" " int channel_pack=(cst.output_channel+7) >> 3;\n" " int idx_w=gid.x << 2;\n" " int idx_h=0;\n" " int idx_c=(gid.y % channel_pack) << 1;\n" " int idx_b=gid.y/channel_pack;\n" " if(idx_b >= cst.batch || idx_c >= cst.output_slice) return;\n" " auto xy_wt=wt+idx_c*cst.input_slice;\n" " auto xy_in0=in+(int)idx_b*cst.input_size+idx_h*cst.output_width+idx_w;\n" " auto xy_out=out+(int)idx_b*cst.output_size+idx_c*cst.output_size*cst.batch+idx_h*cst.output_width+idx_w;\n" " auto biasValue0=FLOAT4(biasTerms[idx_c]);\n" " auto biasValue1=FLOAT4(biasTerms[idx_c+1]);\n" " FLOAT4 result0=biasValue0,result1=biasValue0;\n" " FLOAT4 result4=biasValue0,result5=biasValue0;\n" " FLOAT4 result2=biasValue1,result3=biasValue1;\n" " FLOAT4 result6=biasValue1,result7=biasValue1;\n" " for (auto z=0; z<cst.input_slice; z++) {\n" " auto in40=xy_in0[0];\n" " auto in41=xy_in0[1];\n" " auto in44=xy_in0[2];\n" " auto 
in45=xy_in0[3];\n" " auto w0=xy_wt[z];\n" " auto w1=xy_wt[cst.input_slice+z];\n" " result0 += FLOAT4(in40*w0);\n" " result1 += FLOAT4(in41*w0);\n" " result4 += FLOAT4(in44*w0);\n" " result5 += FLOAT4(in45*w0);\n" " result2 += FLOAT4(in40*w1);\n" " result3 += FLOAT4(in41*w1);\n" " result6 += FLOAT4(in44*w1);\n" " result7 += FLOAT4(in45*w1);\n" " xy_in0 += cst.input_size*cst.batch;\n" " }\n" " int widthSize=min(cst.output_width-idx_w,4);\n" " /* true */ *xy_out=activate(M4(result0),cst.activation);\n" " if (widthSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n" " if (widthSize>2) {xy_out[2]=activate(M4(result4),cst.activation); }\n" " if (widthSize>3) {xy_out[3]=activate(M4(result5),cst.activation); }\n" " \n" " int channelSize=min(cst.output_slice-idx_c,2);\n" " if(channelSize>1) {\n" " /* true */ xy_out[cst.output_size*cst.batch]=activate(M4(result2),cst.activation);\n" " if (widthSize>1) {xy_out[cst.output_size*cst.batch +1]=activate(M4(result3),cst.activation); }\n" " if (widthSize>2) {xy_out[cst.output_size*cst.batch +2]=activate(M4(result6),cst.activation); }\n" " if (widthSize>3) {xy_out[cst.output_size*cst.batch +3]=activate(M4(result7),cst.activation); }\n" " }\n" "}\n" ; const char* shader_MetalConvolutionGEMM_metal = "struct matmul4x4_const {\n" " int output_width;\n" " int output_height;\n" " int multi_length;\n" " int group;\n" "};\n" "template <typename IType,typename OType>\n" "static inline void matmul4x4_template(const device IType *in,\n" " device OType *out,\n" " const device IType *kt,\n" " constant matmul4x4_const &cst,\n" " uint3 gid) {\n" " if ((int)gid.x<cst.output_width && (int)gid.y<cst.output_height) {\n" " auto ky=(int)gid.y+(int)gid.z*cst.output_height;\n" " auto iy=(int)gid.x+(int)gid.z*cst.output_width;\n" " auto off_in=in+iy*cst.multi_length;\n" " auto off_wt=kt+ky*cst.multi_length;\n" " auto off_out=out+iy+4*(int)gid.y*cst.output_width*cst.group;\n" " \n" " FLOAT4 result0=0,result1=0,result2=0,result3=0;\n" " for (int k=0; 
k<cst.multi_length; ++k) {\n" " auto w4x4=off_wt[k];\n" " auto i4x4=off_in[k];\n" " result0 += FLOAT4(w4x4*i4x4[0]);\n" " result1 += FLOAT4(w4x4*i4x4[1]);\n" " result2 += FLOAT4(w4x4*i4x4[2]);\n" " result3 += FLOAT4(w4x4*i4x4[3]);\n" " }\n" " *off_out=OType(result0); off_out += cst.output_width*cst.group;\n" " *off_out=OType(result1); off_out += cst.output_width*cst.group;\n" " *off_out=OType(result2); off_out += cst.output_width*cst.group;\n" " *off_out=OType(result3);\n" " }\n" "}\n" "kernel void matmul4x4(const device M4x4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " const device M4x4 *kt [[buffer(2)]],\n" " constant matmul4x4_const &cst [[buffer(3)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " matmul4x4_template<M4x4,M4>(in,out,kt,cst,gid);\n" "}\n" ; const char* shader_MetalResize_metal = "struct resize_shape {\n" " int input_width;\n" " int input_height;\n" " int input_size;\n" " int output_width;\n" " int output_height;\n" " int output_size;\n" " int sliceMap;\n" "};\n" "kernel void resize_nearest(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant resize_shape &c [[buffer(2)]],\n" " constant float4& s [[buffer(3)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x >= c.output_width || (int)gid.y >= c.output_height || (int)gid.z >= c.sliceMap) return;\n" " \n" " float srcX=gid.x*s.x+s.y,srcY=gid.y*s.z+s.w;\n" " int left=floor(srcX);\n" " int top=floor(srcY);\n" " \n" " auto in_z=in+gid.z*c.input_size;\n" " auto in_top=in_z+top*c.input_width;\n" " out[int(gid.z)*c.output_size+int(gid.y)*c.output_width+int(gid.x)]=in_top[left];\n" "}\n" "kernel void resize_bilinear(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant resize_shape &c [[buffer(2)]],\n" " constant float4& s [[buffer(3)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x >= c.output_width || (int)gid.y >= c.output_height || (int)gid.z >= c.sliceMap) return;\n" " \n" " float 
srcX=gid.x*s.x+s.y,srcY=gid.y*s.z+s.w;\n" " int srcXInt=int(floor(srcX));\n" " int srcYInt=int(floor(srcY));\n" " int left=clamp(srcXInt,0,c.input_width-1);\n" " int right=clamp(srcXInt+1,0,c.input_width-1);\n" " int top=clamp(srcYInt,0,c.input_height-1);\n" " int bottom=clamp(srcYInt+1,0,c.input_height-1);\n" " float x2_factor=srcX-float(srcXInt);\n" " float y2_factor=srcY-float(srcYInt);\n" " float x1_factor=1-x2_factor;\n" " float y1_factor=1-y2_factor;\n" " \n" " auto in_z=in+gid.z*c.input_size;\n" " auto in_top=in_z+top*c.input_width;\n" " auto in_bottom=in_z+bottom*c.input_width;\n" " auto tl=float4(in_top[left])*x1_factor*y1_factor;\n" " auto tr=float4(in_top[right])*x2_factor*y1_factor;\n" " auto bl=float4(in_bottom[left])*x1_factor*y2_factor;\n" " auto br=float4(in_bottom[right])*x2_factor*y2_factor;\n" " out[int(gid.z)*c.output_size+int(gid.y)*c.output_width+int(gid.x)]=M4(tl+tr+bl+br);\n" "}\n" "static inline float4 resize_cubic_interpolation(float4 A,float4 B,float4 C,float4 D,float factor) {\n" " float4 a=(B-C)+0.5f*(B-A)+(D-C)*0.5f;\n" " float4 b=C-((B-A)+(B-C))-(B+D)*0.5f;\n" " float4 c=(C-A)*0.5f;\n" " float4 d=B;\n" " return ((a*factor+b)*factor+c)*factor+d;\n" "}\n" "kernel void resize_cubic(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant resize_shape &c [[buffer(2)]],\n" " constant float4& s [[buffer(3)]],\n" " uint3 gid [[thread_position_in_grid]]) {\n" " if ((int)gid.x >= c.output_width || (int)gid.y >= c.output_height || (int)gid.z >= c.sliceMap) return;\n" " float x=gid.x*s.x+s.y,y=gid.y*s.z+s.w;\n" " \n" " float x_factor=x-floor(x);\n" " float y_factor=y-floor(y);\n" " \n" " int4 xp=int4(int(x)-1,int(x)+0,int(x)+1,int(x)+2);\n" " xp=clamp(xp,0,c.input_width-1);\n" " \n" " int4 yp=int4(int(y)-1,int(y)+0,int(y)+1,int(y)+2);\n" " yp=clamp(yp,0,c.input_height-1);\n" " \n" " auto in_z=in+gid.z*c.input_size;\n" " float4x4 ABCD;\n" " for (int i=0; i<4; i++) {\n" " auto in_y=in_z+yp[i]*c.input_width;\n" " float4 
A=float4(in_y[xp[0]]);\n" " float4 B=float4(in_y[xp[1]]);\n" " float4 C=float4(in_y[xp[2]]);\n" " float4 D=float4(in_y[xp[3]]);\n" " ABCD[i]=resize_cubic_interpolation(A,B,C,D,x_factor);\n" " }\n" " \n" " auto val=M4(resize_cubic_interpolation(ABCD[0],ABCD[1],ABCD[2],ABCD[3],y_factor));\n" " out[int(gid.z)*c.output_size+int(gid.y)*c.output_width+int(gid.x)]=val;\n" "}\n" ; const char* shader_MetalPReLU_metal = "struct prelu_shape {\n" " int size;\n" " int slice;\n" " int batch;\n" "};\n" "kernel void prelu(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " constant float &slope [[buffer(2)]],\n" " uint gid [[thread_position_in_grid]]) {\n" " auto v4=in[int(gid)];\n" " out[int(gid)]=select(v4,M4(slope)*v4,signbit(v4));\n" "}\n" "kernel void prelu_slopes(const device M4 *in [[buffer(0)]],\n" " device M4 *out [[buffer(1)]],\n" " const device float4 *slope [[buffer(2)]],\n" " constant prelu_shape& s [[buffer(3)]],\n" " uint3 gid [[thread_position_in_grid]]) { // size,slice,batch\n" " if ((int)gid.x >= s.size || (int)gid.y >= s.slice) return;\n" " \n" " int z=gid.z+gid.y*s.batch;\n" " auto v4=in[z*s.size+int(gid.x)];\n" " out[z*s.size+int(gid.x)]=select(v4,M4(slope[int(gid.y)])*v4,signbit(v4));\n" "}\n" ; const char* shader_MetalDefine_metal = "using namespace metal;\n" "// –––––––––––––––––––––––––––––––––––––––––––––––––––\n" "// Macro\n" "// –––––––––––––––––––––––––––––––––––––––––––––––––––\n" "#define UP_DIV(x,y) ( ((x)+(y)-1)/(y) )\n" "#define ROUND_UP(x,y) ( ((x)+(y)-1)/(y)*(y) )\n" "// whether computer with float32 when store with float16\n" "#define MNN_METAL_FLOAT32_COMPUTER 1 //\n" "#if MNN_METAL_FULL_PRECISION\n" "typedef float M;\n" "typedef float2 M2;\n" "typedef float3 M3;\n" "typedef float4 M4;\n" "typedef float2x2 M2x2;\n" "typedef float2x3 M2x3;\n" "typedef float2x4 M2x4;\n" "typedef float3x2 M3x2;\n" "typedef float3x3 M3x3;\n" "typedef float3x4 M3x4;\n" "typedef float4x2 M4x2;\n" "typedef float4x3 M4x3;\n" "typedef float4x4 
M4x4;\n" "#else\n" "typedef half M;\n" "typedef half2 M2;\n" "typedef half3 M3;\n" "typedef half4 M4;\n" "typedef half2x2 M2x2;\n" "typedef half2x3 M2x3;\n" "typedef half2x4 M2x4;\n" "typedef half3x2 M3x2;\n" "typedef half3x3 M3x3;\n" "typedef half3x4 M3x4;\n" "typedef half4x2 M4x2;\n" "typedef half4x3 M4x3;\n" "typedef half4x4 M4x4;\n" "#endif\n" "#if MNN_METAL_FLOAT32_COMPUTER\n" "typedef float FLOAT;\n" "typedef float2 FLOAT2;\n" "typedef float3 FLOAT3;\n" "typedef float4 FLOAT4;\n" "typedef float2x2 FLOAT2x2;\n" "typedef float2x3 FLOAT2x3;\n" "typedef float2x4 FLOAT2x4;\n" "typedef float3x2 FLOAT3x2;\n" "typedef float3x3 FLOAT3x3;\n" "typedef float3x4 FLOAT3x4;\n" "typedef float4x2 FLOAT4x2;\n" "typedef float4x3 FLOAT4x3;\n" "typedef float4x4 FLOAT4x4;\n" "#else\n" "typedef half FLOAT;\n" "typedef half2 FLOAT2;\n" "typedef half3 FLOAT3;\n" "typedef half4 FLOAT4;\n" "typedef half2x2 FLOAT2x2;\n" "typedef half2x3 FLOAT2x3;\n" "typedef half2x4 FLOAT2x4;\n" "typedef half3x2 FLOAT3x2;\n" "typedef half3x3 FLOAT3x3;\n" "typedef half3x4 FLOAT3x4;\n" "typedef half4x2 FLOAT4x2;\n" "typedef half4x3 FLOAT4x3;\n" "typedef half4x4 FLOAT4x4;\n" "#endif\n" "namespace MNN {\n" " \n" " // –––––––––––––––––––––––––––––––––––––––––––––––––––\n" " // Number Limit\n" " // –––––––––––––––––––––––––––––––––––––––––––––––––––\n" "#define INT8_MAX 127\n" "#define INT8_MIN -128\n" "#define INT16_MAX 32767\n" "#define INT16_MIN -32768\n" "#define INT32_MAX 2147483647\n" "#define INT32_MIN -2147483648\n" "#define UINT8_MAX 255\n" "#define UINT16_MAX 65535\n" "#define UINT32_MAX 4294967295U\n" " \n" " template<typename T> struct num_limits {\n" " static int max() { return 0; };\n" " static int min() { return 0; };\n" " };\n" " template<> struct num_limits<char> {\n" " static int max() { return INT8_MAX; };\n" " static int min() { return INT8_MIN; };\n" " };\n" " template<> struct num_limits<uchar> {\n" " static int max() { return UINT8_MAX; };\n" " static int min() { return 0; };\n" " };\n" 
" template<> struct num_limits<short> {\n" " static int max() { return INT16_MAX; };\n" " static int min() { return INT16_MIN; };\n" " };\n" " template<> struct num_limits<ushort> {\n" " static int max() { return UINT16_MAX; };\n" " static int min() { return 0; };\n" " };\n" " template<> struct num_limits<int> {\n" " static int max() { return INT32_MAX; };\n" " static int min() { return INT32_MIN; };\n" " };\n" " template<> struct num_limits<uint> {\n" " static int max() { return UINT32_MAX; };\n" " static int min() { return 0; };\n" " };\n" " \n" " // –––––––––––––––––––––––––––––––––––––––––––––––––––\n" " // Function\n" " // –––––––––––––––––––––––––––––––––––––––––––––––––––\n" " inline int dot(int4 i4,int4 w4) {\n" " return i4[0]*w4[0]+i4[1]*w4[1]+i4[2]*w4[2]+i4[3]*w4[3];\n" " }\n" " \n" " template <typename T>\n" " inline T saturate_round_x2_high_mul(T a,int b) {\n" " return mulhi(a,b)*2;\n" " }\n" " \n" " template <typename T>\n" " inline T round_divide_by_pot(T x,int exponent) {\n" " int mask=(1 << exponent)-1;\n" " T remainder=x & mask;\n" " T threshold=(mask >> 1)+T(x<0);\n" " return (x >> exponent)+T(remainder>threshold);\n" " }\n" " \n" " // –––––––––––––––––––––––––––––––––––––––––––––––––––\n" " // Typedef\n" " // –––––––––––––––––––––––––––––––––––––––––––––––––––\n" " \n" " typedef struct short4x4 {\n" " private:\n" " short4 v[4];\n" " public:\n" " short4x4(short4 a) {\n" " v[0]=a; v[1]=a; v[2]=a; v[3]=a;\n" " }\n" " short4x4(short4 a,short4 b,short4 c,short4 d) {\n" " v[0]=a; v[1]=b; v[2]=c; v[3]=d;\n" " }\n" " \n" " inline thread short4& operator[] (const int index) {\n" " return v[index];\n" " }\n" " inline device short4& operator[] (const int index) device {\n" " return v[index];\n" " }\n" " inline threadgroup short4& operator[] (const int index) threadgroup {\n" " return v[index];\n" " }\n" " \n" " inline const thread short4& operator[] (const int index) const {\n" " return v[index];\n" " }\n" " inline const device short4& operator[] (const int 
index) const device {\n" " return v[index];\n" " }\n" " inline const threadgroup short4& operator[] (const int index) const threadgroup {\n" " return v[index];\n" " }\n" " \n" " inline explicit operator half4x4() const {\n" " return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n" " }\n" " inline explicit operator half4x4() const device{\n" " return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n" " }\n" " inline explicit operator half4x4() const threadgroup {\n" " return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n" " }\n" " \n" " inline explicit operator float4x4() const {\n" " return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n" " }\n" " inline explicit operator float4x4() const device {\n" " return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n" " }\n" " inline explicit operator float4x4() const threadgroup {\n" " return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n" " }\n" " } short4x4;\n" " \n" " typedef struct char4x4 {\n" " private:\n" " char4 v[4];\n" " public:\n" " char4x4(char4 a) {\n" " v[0]=a; v[1]=a; v[2]=a; v[3]=a;\n" " }\n" " char4x4(char4 a,char4 b,char4 c,char4 d) {\n" " v[0]=a; v[1]=b; v[2]=c; v[3]=d;\n" " }\n" " \n" " inline thread char4& operator[] (const int index) {\n" " return v[index];\n" " }\n" " inline device char4& operator[] (const int index) device {\n" " return v[index];\n" " }\n" " inline threadgroup char4& operator[] (const int index) threadgroup {\n" " return v[index];\n" " }\n" " \n" " inline const thread char4& operator[] (const int index) const {\n" " return v[index];\n" " }\n" " inline const device char4& operator[] (const int index) const device {\n" " return v[index];\n" " }\n" " inline const threadgroup char4& operator[] (const int index) const threadgroup {\n" " return v[index];\n" " }\n" " \n" " inline explicit operator half4x4() const {\n" " return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n" " 
}\n" " inline explicit operator half4x4() const device {\n" " return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n" " }\n" " inline explicit operator half4x4() const threadgroup {\n" " return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n" " }\n" " \n" " inline explicit operator float4x4() const {\n" " return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n" " }\n" " inline explicit operator float4x4() const device {\n" " return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n" " }\n" " inline explicit operator float4x4() const threadgroup {\n" " return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n" " }\n" " } char4x4;\n" " typedef struct char4x2 {\n" " private:\n" " char2 v[4];\n" " public:\n" " char4x2(char2 a) {\n" " v[0]=a; v[1]=a; v[2]=a; v[3]=a;\n" " }\n" " char4x2(char2 a,char2 b,char2 c,char2 d) {\n" " v[0]=a; v[1]=b; v[2]=c; v[3]=d;\n" " }\n" " \n" " inline thread char2& operator[] (const int index) {\n" " return v[index];\n" " }\n" " inline device char2& operator[] (const int index) device {\n" " return v[index];\n" " }\n" " inline threadgroup char2& operator[] (const int index) threadgroup {\n" " return v[index];\n" " }\n" " \n" " inline const thread char2& operator[] (const int index) const {\n" " return v[index];\n" " }\n" " inline const device char2& operator[] (const int index) const device {\n" " return v[index];\n" " }\n" " inline const threadgroup char2& operator[] (const int index) const threadgroup {\n" " return v[index];\n" " }\n" " \n" " inline explicit operator half4x2() const {\n" " return half4x2( half2(v[0]),half2(v[1]),half2(v[2]),half2(v[3]) );\n" " }\n" " inline explicit operator half4x2() const device {\n" " return half4x2( half2(v[0]),half2(v[1]),half2(v[2]),half2(v[3]) );\n" " }\n" " inline explicit operator half4x2() const threadgroup {\n" " return half4x2( half2(v[0]),half2(v[1]),half2(v[2]),half2(v[3]) );\n" " }\n" " \n" " inline explicit 
operator float4x2() const {\n" " return float4x2( float2(v[0]),float2(v[1]),float2(v[2]),float2(v[3]) );\n" " }\n" " inline explicit operator float4x2() const device {\n" " return float4x2( float2(v[0]),float2(v[1]),float2(v[2]),float2(v[3]) );\n" " }\n" " inline explicit operator float4x2() const threadgroup {\n" " return float4x2( float2(v[0]),float2(v[1]),float2(v[2]),float2(v[3]) );\n" " }\n" " } char4x2;\n" " typedef struct uchar4x2 {\n" " private:\n" " uchar2 v[4];\n" " public:\n" " uchar4x2(uchar2 a) {\n" " v[0]=a; v[1]=a; v[2]=a; v[3]=a;\n" " }\n" " uchar4x2(uchar2 a,uchar2 b,uchar2 c,uchar2 d) {\n" " v[0]=a; v[1]=b; v[2]=c; v[3]=d;\n" " }\n" " \n" " inline thread uchar2& operator[] (const int index) {\n" " return v[index];\n" " }\n" " inline device uchar2& operator[] (const int index) device {\n" " return v[index];\n" " }\n" " inline threadgroup uchar2& operator[] (const int index) threadgroup {\n" " return v[index];\n" " }\n" " \n" " inline const thread uchar2& operator[] (const int index) const {\n" " return v[index];\n" " }\n" " inline const device uchar2& operator[] (const int index) const device {\n" " return v[index];\n" " }\n" " inline const threadgroup uchar2& operator[] (const int index) const threadgroup {\n" " return v[index];\n" " }\n" " \n" " inline explicit operator half4x2() const {\n" " return half4x2( half2(v[0]),half2(v[1]),half2(v[2]),half2(v[3]) );\n" " }\n" " inline explicit operator half4x2() const device {\n" " return half4x2( half2(v[0]),half2(v[1]),half2(v[2]),half2(v[3]) );\n" " }\n" " inline explicit operator half4x2() const threadgroup {\n" " return half4x2( half2(v[0]),half2(v[1]),half2(v[2]),half2(v[3]) );\n" " }\n" " \n" " inline explicit operator float4x2() const {\n" " return float4x2( float2(v[0]),float2(v[1]),float2(v[2]),float2(v[3]) );\n" " }\n" " inline explicit operator float4x2() const device {\n" " return float4x2( float2(v[0]),float2(v[1]),float2(v[2]),float2(v[3]) );\n" " }\n" " inline explicit operator float4x2() 
const threadgroup {\n" " return float4x2( float2(v[0]),float2(v[1]),float2(v[2]),float2(v[3]) );\n" " }\n" " } uchar4x2;\n" "}\n" ; const char* shader_MetalEltwise_metal = "kernel void eltwise_prod(device const M *in0 [[buffer(0)]],\n" " device const M *in1 [[buffer(1)]],\n" " device M *out [[buffer(2)]],\n" " constant int4& shape [[buffer(3)]],\n" " uint gid [[thread_position_in_grid]]) {\n" " if ((int)gid<shape.x) {\n" " out[(int)gid]=in0[(int)gid]*in1[(int)gid];\n" " }\n" "}\n" "kernel void eltwise_max(device const M *in0 [[buffer(0)]],\n" " device const M *in1 [[buffer(1)]],\n" " device M *out [[buffer(2)]],\n" " constant int4& shape [[buffer(3)]],\n" " uint gid [[thread_position_in_grid]]) {\n" " if ((int)gid<shape.x) {\n" " out[(int)gid]=max(in0[(int)gid],in1[(int)gid]);\n" " }\n" "}\n" "kernel void eltwise_add(device const M *in0 [[buffer(0)]],\n" " device const M *in1 [[buffer(1)]],\n" " device M *out [[buffer(2)]],\n" " constant int4& shape [[buffer(3)]],\n" " uint gid [[thread_position_in_grid]]) {\n" " if ((int)gid<shape.x) {\n" " out[(int)gid]=in0[(int)gid]+in1[(int)gid];\n" " }\n" "}\n" ;