source/backend/metal/AllShader.cpp (2,305 lines of code) (raw):
#include "AllShader.hpp"
// Metal source for the ReLU / ReLU6 element-wise kernels.
// Both kernels process one M4 (4-lane) element per thread at index gid.x and
// do nothing for threads with gid.x >= p.size.
//   relu6: out = clamp(in, p.minV, p.maxV)
//   relu : out = max(in, 0) + min(in, 0) * p.minV, i.e. negative inputs are
//          scaled by p.minV (plain ReLU when p.minV == 0).
// Param.remain is declared but not read by either kernel in this string.
// NOTE(review): M4 is not defined in this string; presumably a type prefix is
// prepended when the Metal library is compiled — confirm against the loader.
const char* shader_MetalReLU6_metal =
"struct Param {\n"
" float minV;\n"
" float maxV;\n"
" int size;\n"
" int remain;\n"
"};\n"
"kernel void relu6(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant Param &p [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if (gid.x<p.size) {\n"
" out[int(gid.x)]=clamp(in[int(gid.x)],(M4)p.minV,(M4)p.maxV);\n"
" }\n"
"}\n"
"kernel void relu(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant Param &p [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if (gid.x<p.size) {\n"
" auto V=in[int(gid.x)];\n"
" out[int(gid.x)]=fmax(V,(M4)0)+fmin(V,(M4)0)*M4(p.minV);\n"
" }\n"
"}\n"
;
// Metal source for the depthwise convolution kernel `conv_depthwise`.
// One thread produces one output M4 at (gid.x, gid.y) of plane gid.z, where
// gid.z ranges over slice*batch and the channel slice is oz = gid.z / batch.
// Input/output planes are addressed as base + gid.z * {input,output}_size and
// the weights of slice oz start at wt + oz * kernel_size.
// The kernel window is clipped to [sx, ex) x [sy, ey) so taps that would fall
// into the padding are skipped instead of read. Accumulation happens in
// FLOAT4, starting from biasTerms[oz], and the result goes through
// activate(..., cst.activation).
// Names used but not defined here: conv_activation_type / activate() (defined
// in shader_MetalConvolutionActivation_metal in this file) and M4, FLOAT4,
// UP_DIV — NOTE(review): presumably from a prepended prefix; confirm.
const char* shader_MetalConvolutionDepthwise_metal =
"struct conv_dw_cst {\n"
" int input_width;\n"
" int input_height;\n"
" int input_size;\n"
" int output_width;\n"
" int output_height;\n"
" int output_size;\n"
" int slice;\n"
" int batch;\n"
" \n"
" int kernel_x;\n"
" int kernel_y;\n"
" int kernel_size;\n"
" int stride_x;\n"
" int stride_y;\n"
" int pad_x;\n"
" int pad_y;\n"
" int dilation_x;\n"
" int dilation_y;\n"
" conv_activation_type activation;\n"
"};\n"
"kernel void conv_depthwise(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv_dw_cst& cst [[buffer(2)]],\n"
" const device M4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.slice*cst.batch) return;\n"
" \n"
" int oz=gid.z/cst.batch;\n"
" int offset_x=(int)gid.x*cst.stride_x-cst.pad_x;\n"
" int offset_y=(int)gid.y*cst.stride_y-cst.pad_y;\n"
" int sx=max(0,(UP_DIV(-offset_x,cst.dilation_x)));\n"
" int ex=min(cst.kernel_x,UP_DIV(cst.input_width-offset_x,cst.dilation_x));\n"
" int sy=max(0,(UP_DIV(-offset_y,cst.dilation_y)));\n"
" int ey=min(cst.kernel_y,UP_DIV(cst.input_height-offset_y,cst.dilation_y));\n"
" offset_x += sx*cst.dilation_x;\n"
" offset_y += sy*cst.dilation_y;\n"
" auto z_wt=wt+(int)oz*cst.kernel_size;\n"
" auto z_in=in+(int)gid.z*cst.input_size;\n"
" auto z_out=out+(int)gid.z*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x;\n"
" FLOAT4 result=FLOAT4(biasTerms[oz]);\n"
" for (auto ky=sy,y=offset_y; ky<ey; ky++,y += cst.dilation_y) {\n"
" for (auto kx=sx,x=offset_x; kx<ex; kx++,x += cst.dilation_x) {\n"
" auto wt4=z_wt[ky*cst.kernel_x+kx];\n"
" auto in4=z_in[ y*cst.input_width+x];\n"
" result += FLOAT4(in4*wt4);\n"
" }\n"
" }\n"
" *z_out=activate((M4)result,cst.activation);\n"
"}\n"
;
// Metal source shared by the convolution kernels. It defines:
//   - conv_activation_type: None = 0, ReLU = 1, ReLU6 = 2;
//   - activate(V, type): max(V, 0) for ReLU, clamp(V, 0, 6) for ReLU6, and V
//     unchanged for anything else.
// The convolution strings in this file use these names without defining them,
// so this source has to be visible earlier in the same Metal compilation.
const char* shader_MetalConvolutionActivation_metal =
"typedef enum : int {\n"
" None=0,\n"
" ReLU=1,\n"
" ReLU6=2,\n"
"} conv_activation_type;\n"
"inline M4 activate(M4 V,conv_activation_type type) {\n"
" switch (type) {\n"
" case ReLU:\n"
" return max(V,(M4)0);\n"
" case ReLU6:\n"
" return clamp(V,(M4)0,(M4)6);\n"
" default: // None\n"
" return V;\n"
" }\n"
"}\n"
;
// Metal source for the general (non-depthwise) convolution kernels.
//
// Memory layout shared by every kernel below, as used by `conv`:
//   - element (slice s, batch b) of a tensor starts at
//     base + (s*batch + b) * {input,output}_size;
//   - weights for output slice oc and input slice z start at
//     wt + (oc*input_slice + z) * kernel_size.
// Moving to the next output slice of the same batch therefore means advancing
// z_out by output_size*batch, not output_size.
//
// Kernels:
//   conv              - generic; one output M4 per thread.
//   convk3s1d1p1_w2z4 - 3x3, stride 1, dilation 1, pad 1. Each thread handles
//                       2 output columns x up to 4 output slices. It returns
//                       early when its first output slice is out of range, and
//                       steps z_out by output_size*batch between slices, as
//                       conv_z4/conv_z2 do.
//   conv_s1d1p0_w2    - stride 1, dilation 1, pad 0; 2 output columns per thread.
//   conv_s1d1p0_w4    - same, 4 output columns. It only reads weights 0..2 of
//                       each kernel row — NOTE(review): presumably only valid for
//                       kernel_x == 3; confirm at the dispatch site.
//   conv_z4 / conv_z2 - generic; 4 / 2 output slices per thread.
//
// Depends on names defined elsewhere: conv_activation_type / activate()
// (shader_MetalConvolutionActivation_metal) and M4, M4x4, FLOAT4, UP_DIV.
// NOTE(review): these last four are presumably supplied by a prepended prefix.
// This file also looks generated from per-op .metal sources; if so, apply the
// same change to the source .metal file.
const char* shader_MetalConvolution_metal =
"#define CONV_UNROLL (4)\n"
"#define CONV_MUL_PACK_W2(x,y) "
" x += FLOAT4(in00*k00);"
" y += FLOAT4(in01*k00);"
" x += FLOAT4(in01*k01);"
" y += FLOAT4(in02*k01);"
" x += FLOAT4(in02*k02);"
" y += FLOAT4(in03*k02);"
" "
" x += FLOAT4(in10*k10);"
" y += FLOAT4(in11*k10);"
" x += FLOAT4(in11*k11);"
" y += FLOAT4(in12*k11);"
" x += FLOAT4(in12*k12);"
" y += FLOAT4(in13*k12);"
" "
" x += FLOAT4(in20*k20);"
" y += FLOAT4(in21*k20);"
" x += FLOAT4(in21*k21);"
" y += FLOAT4(in22*k21);"
" x += FLOAT4(in22*k22);"
" y += FLOAT4(in23*k22);\n"
" \n"
"#define CONV_NEXT_FLT "
" z_wt += ws; "
" "
" k00=z_wt[0],k01=z_wt[1],k02=z_wt[2];"
" k10=z_wt[3],k11=z_wt[4],k12=z_wt[5];"
" k20=z_wt[6],k21=z_wt[7],k22=z_wt[8];\n"
"struct conv_constants {\n"
" int input_width;\n"
" int input_height;\n"
" int input_size;\n"
" int input_slice;\n"
" int output_width;\n"
" int output_height;\n"
" int output_size;\n"
" int output_slice;\n"
" int batch;\n"
" int oz_size;\n"
" int threadgroup_input_slice;\n"
" \n"
" int kernel_x;\n"
" int kernel_y;\n"
" int kernel_size;\n"
" int stride_x;\n"
" int stride_y;\n"
" int pad_x;\n"
" int pad_y;\n"
" int dilation_x;\n"
" int dilation_y;\n"
" conv_activation_type activation; \n"
"};\n"
"kernel void conv(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.oz_size) return;\n"
" \n"
" int idx_w=gid.x;\n"
" int idx_h=gid.y;\n"
" int idx_c=gid.z/cst.batch;\n"
" int idx_b=gid.z % cst.batch;\n"
" \n"
" int offset_x=(int)idx_w*cst.stride_x-cst.pad_x;\n"
" int offset_y=(int)idx_h*cst.stride_y-cst.pad_y;\n"
" int sx=max(0,(UP_DIV(-offset_x,cst.dilation_x)));\n"
" int ex=min(cst.kernel_x,UP_DIV(cst.input_width-offset_x,cst.dilation_x));\n"
" int kw=ex-sx;\n"
" int sy=max(0,(UP_DIV(-offset_y,cst.dilation_y)));\n"
" int ey=min(cst.kernel_y,UP_DIV(cst.input_height-offset_y,cst.dilation_y));\n"
" int kh=ey-sy;\n"
" offset_x += sx*cst.dilation_x;\n"
" offset_y += sy*cst.dilation_y;\n"
" \n"
" auto z_in=in+idx_b*cst.input_size+offset_y*cst.input_width+offset_x;\n"
" auto z_wt=wt+idx_c*cst.input_slice*cst.kernel_size+sy*cst.kernel_x+sx;\n"
" auto z_out=out+idx_b*cst.output_size+(int)idx_c*cst.batch*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x;\n"
" int dilation_h=cst.input_width*cst.dilation_y;\n"
" FLOAT4 result=FLOAT4(biasTerms[idx_c]);\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" for (auto y=0; y<kh; y++) {\n"
" for (auto x=0; x<kw; x++) {\n"
" auto wt4=z_wt[z*cst.kernel_size+y*cst.kernel_x+x];\n"
" auto in4=z_in[z*cst.input_size*cst.batch+y*dilation_h+x*cst.dilation_x];\n"
" result += FLOAT4(in4*wt4);\n"
" }\n"
" }\n"
" }\n"
" *z_out=activate(M4(result),cst.activation);\n"
"}\n"
"kernel void convk3s1d1p1_w2z4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*2 >= cst.output_width || (int)gid.y >= cst.output_height) return;\n"
" \n"
" int idx_w=gid.x << 1;\n"
" int idx_h=gid.y;\n"
" int idx_c=gid.z/cst.batch;\n"
" int idx_b=gid.z % cst.batch;\n"
" if (idx_c*CONV_UNROLL >= cst.output_slice) return;\n"
" int4 uz=idx_c*CONV_UNROLL+int4(0,1,2,3);\n"
" bool3 valids=uz.yzw<cst.output_slice;\n"
" bool valid_x=(int)(gid.x*2+1)<cst.output_width;\n"
" int offset_x=(int)gid.x*2-cst.pad_x;\n"
" int offset_y=(int)gid.y-cst.pad_y;\n"
" auto z_in=in+idx_b*cst.input_size+offset_y*cst.input_width+offset_x;\n"
" auto z_flt=wt+uz[0]*cst.input_slice*cst.kernel_size;\n"
" auto z_out=out+idx_b*cst.output_size+uz[0]*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" \n"
" int ws=cst.input_slice*cst.kernel_size;\n"
" FLOAT4 result0=0,result1=0,result2=0,result3=0;\n"
" FLOAT4 result4=0,result5=0,result6=0,result7=0;\n"
" for (auto z=0; z<cst.input_slice; z++,z_flt += cst.kernel_size,z_in += (cst.input_size*cst.batch)) {\n"
" auto in00=(offset_x<0 || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+0);\n"
" auto in01=(offset_x+1>=cst.input_width || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+1);\n"
" auto in02=(offset_x+2>=cst.input_width || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+2);\n"
" auto in03=(offset_x+3>=cst.input_width || offset_y<0) ? (M4)0.f : *(z_in+0*cst.input_width+3);\n"
" auto in10=(offset_x<0 || offset_y+1>=cst.input_height) ? (M4)0.f : *(z_in+1*cst.input_width+0);\n"
" auto in11=(offset_x+1>=cst.input_width || offset_y+1>=cst.input_height) ? (M4)0.f : *(z_in+1*cst.input_width+1);\n"
" auto in12=(offset_x+2>=cst.input_width || offset_y+1>=cst.input_height) ? (M4)0.f : *(z_in+1*cst.input_width+2);\n"
" auto in13=(offset_x+3>=cst.input_width || offset_y+1>=cst.input_height) ? (M4)0.f : *(z_in+1*cst.input_width+3);\n"
" \n"
" auto in20=(offset_x<0 || offset_y+2>=cst.input_height) ? (M4)0.f : *(z_in+2*cst.input_width+0);\n"
" auto in21=(offset_x+1>=cst.input_width || offset_y+2>=cst.input_height) ? (M4)0.f : *(z_in+2*cst.input_width+1);\n"
" auto in22=(offset_x+2>=cst.input_width || offset_y+2>=cst.input_height) ? (M4)0.f : *(z_in+2*cst.input_width+2);\n"
" auto in23=(offset_x+3>=cst.input_width || offset_y+2>=cst.input_height) ? (M4)0.f : *(z_in+2*cst.input_width+3);\n"
" \n"
" auto z_wt=z_flt;\n"
" auto k00=z_wt[0],k01=z_wt[1],k02=z_wt[2];\n"
" auto k10=z_wt[3],k11=z_wt[4],k12=z_wt[5];\n"
" auto k20=z_wt[6],k21=z_wt[7],k22=z_wt[8];\n"
" CONV_MUL_PACK_W2(result0,result4);\n"
" if (valids[0]) {\n"
" CONV_NEXT_FLT;\n"
" CONV_MUL_PACK_W2(result1,result5);\n"
" }\n"
" if (valids[1]) {\n"
" CONV_NEXT_FLT;\n"
" CONV_MUL_PACK_W2(result2,result6);\n"
" }\n"
" if (valids[2]) {\n"
" CONV_NEXT_FLT;\n"
" CONV_MUL_PACK_W2(result3,result7);\n"
" }\n"
" }\n"
" /* true */ *z_out=activate(M4(result0+FLOAT4(biasTerms[uz[0]])),cst.activation);\n"
" if(valid_x) {\n"
" *(z_out+1)=activate(M4(result4+FLOAT4(biasTerms[uz[0]])),cst.activation);\n"
" }\n"
" if (valids[0]) {\n"
" z_out += cst.output_size*cst.batch;\n"
" *z_out=activate(M4(result1+FLOAT4(biasTerms[uz[1]])),cst.activation);\n"
" if(valid_x) {\n"
" *(z_out+1)=activate(M4(result5+FLOAT4(biasTerms[uz[1]])),cst.activation);\n"
" }\n"
" }\n"
" if (valids[1]) {\n"
" z_out += cst.output_size*cst.batch;\n"
" *z_out=activate(M4(result2+FLOAT4(biasTerms[uz[2]])),cst.activation);\n"
" if(valid_x) {\n"
" *(z_out+1)=activate(M4(result6+FLOAT4(biasTerms[uz[2]])),cst.activation);\n"
" }\n"
" }\n"
" if (valids[2]) {\n"
" z_out += cst.output_size*cst.batch;\n"
" *z_out=activate(M4(result3+FLOAT4(biasTerms[uz[3]])),cst.activation);\n"
" if(valid_x) {\n"
" *(z_out+1)=activate(M4(result7+FLOAT4(biasTerms[uz[3]])),cst.activation);\n"
" }\n"
" }\n"
"}\n"
"kernel void conv_s1d1p0_w2(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*2 >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.oz_size) return;\n"
" \n"
" int idx_w=gid.x << 1;\n"
" int idx_h=gid.y;\n"
" int idx_c=gid.z/cst.batch;\n"
" int idx_b=gid.z % cst.batch;\n"
" if (idx_b >= cst.batch || idx_c >= cst.output_slice) return;\n"
" bool valid=(idx_w+1<cst.output_width);\n"
" \n"
" auto z_in=in+idx_b*cst.input_size+idx_h*cst.input_width+idx_w;\n"
" auto z_wt=wt+idx_c*cst.input_slice*cst.kernel_size;\n"
" auto z_out=out+idx_b*cst.output_size+idx_c*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" FLOAT4 result0=FLOAT4(biasTerms[idx_c]);\n"
" FLOAT4 result1=result0;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" for (auto y=0; y<cst.kernel_y; y++) {\n"
" auto wt4=z_wt[z*cst.kernel_size+y*cst.kernel_x];\n"
" auto in4_0=z_in[z*cst.batch*cst.input_size+y*cst.input_width];\n"
" result0 += FLOAT4(in4_0*wt4);\n"
" for (auto x=1; x<cst.kernel_x; x++) {\n"
" in4_0=z_in[z*cst.batch*cst.input_size+y*cst.input_width+x];\n"
" result1 += FLOAT4(in4_0*wt4);\n"
" wt4=z_wt[z*cst.kernel_size+y*cst.kernel_x+x];\n"
" result0 += FLOAT4(in4_0*wt4);\n"
" }\n"
" in4_0=z_in[z*cst.input_size*cst.batch+y*cst.input_width+cst.kernel_x];\n"
" result1 += FLOAT4(in4_0*wt4);\n"
" }\n"
" }\n"
" *z_out=activate(M4(result0),cst.activation);\n"
" if(valid) { *(z_out+1)=activate(M4(result1),cst.activation);}\n"
"}\n"
"kernel void conv_s1d1p0_w4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*4 >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.oz_size) return;\n"
" \n"
" int idx_w=gid.x << 2;\n"
" int idx_h=gid.y;\n"
" int idx_c=gid.z/cst.batch;\n"
" int idx_b=gid.z % cst.batch;\n"
" \n"
" if (idx_b >= cst.batch || idx_c >= cst.output_slice) return;\n"
" int3 uz=idx_w+int3(1,2,3);\n"
" bool3 valids=uz.xyz<cst.output_width;\n"
" \n"
" auto z_in=in+idx_b*cst.input_size+idx_h*cst.input_width+idx_w;\n"
" auto z_wt=wt+idx_c*cst.input_slice*cst.kernel_size;\n"
" auto z_out=out+idx_b*cst.output_size+idx_c*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" FLOAT4 result0=FLOAT4(biasTerms[idx_c]);\n"
" FLOAT4 result1=result0;\n"
" FLOAT4 result2=result0;\n"
" FLOAT4 result3=result0;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" for (auto y=0; y<cst.kernel_y; y++) {\n"
" auto wt_base=z_wt+z*cst.kernel_size+y*cst.kernel_x;\n"
" auto wt4_0=wt_base[0];\n"
" auto wt4_1=wt_base[1];\n"
" auto wt4_2=wt_base[2];\n"
" auto z_in_base=z_in+z*cst.batch*cst.input_size+y*cst.input_width;\n"
" auto in4_0=z_in_base[0];\n"
" result0 += FLOAT4(in4_0*wt4_0);\n"
" \n"
" in4_0=z_in_base[1];\n"
" result0 += FLOAT4(in4_0*wt4_1);\n"
" result1 += FLOAT4(in4_0*wt4_0);\n"
" in4_0=z_in_base[2];\n"
" result0 += FLOAT4(in4_0*wt4_2);\n"
" result1 += FLOAT4(in4_0*wt4_1);\n"
" result2 += FLOAT4(in4_0*wt4_0);\n"
" in4_0=z_in_base[3];\n"
" result1 += FLOAT4(in4_0*wt4_2);\n"
" result2 += FLOAT4(in4_0*wt4_1);\n"
" result3 += FLOAT4(in4_0*wt4_0);\n"
" \n"
" in4_0=z_in_base[4];\n"
" result2 += FLOAT4(in4_0*wt4_2);\n"
" result3 += FLOAT4(in4_0*wt4_1);\n"
" in4_0=z_in_base[5];\n"
" result3 += FLOAT4(in4_0*wt4_2);\n"
" }\n"
" }\n"
" *z_out=activate(M4(result0),cst.activation);\n"
" if(valids[0]) { *(z_out+1)=activate(M4(result1),cst.activation);}\n"
" if(valids[1]) { *(z_out+2)=activate(M4(result2),cst.activation);}\n"
" if(valids[2]) { *(z_out+3)=activate(M4(result3),cst.activation);}\n"
"}\n"
"kernel void conv_z4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height) return;\n"
" \n"
" int idx_w=gid.x;\n"
" int idx_h=gid.y;\n"
" int idx_c=gid.z/cst.batch;\n"
" int idx_b=gid.z % cst.batch;\n"
" if (idx_b >= cst.batch || idx_c*4 >= cst.output_slice) return;\n"
" int4 uz=idx_c*CONV_UNROLL+int4(0,1,2,3);\n"
" bool3 valids=uz.yzw<cst.output_slice;\n"
" \n"
" int offset_x=idx_w*cst.stride_x-cst.pad_x;\n"
" int offset_y=idx_h*cst.stride_y-cst.pad_y;\n"
" int sx=max(0,(UP_DIV(-offset_x,cst.dilation_x)));\n"
" int ex=min(cst.kernel_x,UP_DIV(cst.input_width-offset_x,cst.dilation_x));\n"
" int kw=ex-sx;\n"
" int sy=max(0,(UP_DIV(-offset_y,cst.dilation_y)));\n"
" int ey=min(cst.kernel_y,UP_DIV(cst.input_height-offset_y,cst.dilation_y));\n"
" int kh=ey-sy;\n"
" offset_x += sx*cst.dilation_x;\n"
" offset_y += sy*cst.dilation_y;\n"
" \n"
" auto z_in=in+idx_b*cst.input_size+offset_y*cst.input_width+offset_x;\n"
" auto z_wt=wt+uz[0]*cst.input_slice*cst.kernel_size+sy*cst.kernel_x+sx;\n"
" auto z_out=out+idx_b*cst.output_size+uz[0]*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" \n"
" int ws=cst.input_slice*cst.kernel_size;\n"
" int dilation_h=cst.input_width*cst.dilation_y;\n"
" FLOAT4 result0=0,result1=0,result2=0,result3=0;\n"
" for (auto z=0; z<cst.input_slice; z++,z_wt += cst.kernel_size,z_in += cst.input_size*cst.batch) {\n"
" for (auto y=0; y<kh; y++) {\n"
" for (auto x=0; x<kw; x++) {\n"
" auto x_wt=z_wt+y*cst.kernel_x+x;\n"
" auto in4=z_in[ y*dilation_h+x*cst.dilation_x];\n"
" /* true */ result0 += FLOAT4(in4**x_wt);\n"
" if (valids.x) { x_wt += ws; result1 += FLOAT4(in4**x_wt); }\n"
" if (valids.y) { x_wt += ws; result2 += FLOAT4(in4**x_wt); }\n"
" if (valids.z) { x_wt += ws; result3 += FLOAT4(in4**x_wt); }\n"
" }\n"
" }\n"
" }\n"
" /* true */ *z_out=activate(M4(result0+FLOAT4(biasTerms[uz[0]])),cst.activation);\n"
" if (valids.x) { z_out += cst.output_size*cst.batch; *z_out=activate(M4(result1+FLOAT4(biasTerms[uz[1]])),cst.activation); }\n"
" if (valids.y) { z_out += cst.output_size*cst.batch; *z_out=activate(M4(result2+FLOAT4(biasTerms[uz[2]])),cst.activation); }\n"
" if (valids.z) { z_out += cst.output_size*cst.batch; *z_out=activate(M4(result3+FLOAT4(biasTerms[uz[3]])),cst.activation); }\n"
"}\n"
"kernel void conv_z2(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height) return;\n"
" \n"
" int idx_w=gid.x;\n"
" int idx_h=gid.y;\n"
" int idx_c=gid.z/cst.batch;\n"
" int idx_b=gid.z % cst.batch;\n"
" if (idx_b >= cst.batch || idx_c*2 >= cst.output_slice) return;\n"
" int2 uz=idx_c*2+int2(0,1);\n"
" bool valids=uz.y<cst.output_slice;\n"
" \n"
" int offset_x=idx_w*cst.stride_x-cst.pad_x;\n"
" int offset_y=idx_h*cst.stride_y-cst.pad_y;\n"
" int sx=max(0,(UP_DIV(-offset_x,cst.dilation_x)));\n"
" int ex=min(cst.kernel_x,UP_DIV(cst.input_width-offset_x,cst.dilation_x));\n"
" int kw=ex-sx;\n"
" int sy=max(0,(UP_DIV(-offset_y,cst.dilation_y)));\n"
" int ey=min(cst.kernel_y,UP_DIV(cst.input_height-offset_y,cst.dilation_y));\n"
" int kh=ey-sy;\n"
" offset_x += sx*cst.dilation_x;\n"
" offset_y += sy*cst.dilation_y;\n"
" \n"
" auto z_in=in+idx_b*cst.input_size+offset_y*cst.input_width+offset_x;\n"
" auto z_wt=wt+uz[0]*cst.input_slice*cst.kernel_size+sy*cst.kernel_x+sx;\n"
" auto z_out=out+idx_b*cst.output_size+uz[0]*cst.batch*cst.output_size+idx_h*cst.output_width+idx_w;\n"
" \n"
" int ws=cst.input_slice*cst.kernel_size;\n"
" int dilation_h=cst.input_width*cst.dilation_y;\n"
" FLOAT4 result0=0,result1=0;\n"
" for (auto z=0; z<cst.input_slice; z++,z_wt += cst.kernel_size,z_in += cst.input_size*cst.batch) {\n"
" for (auto y=0; y<kh; y++) {\n"
" for (auto x=0; x<kw; x++) {\n"
" auto x_wt=z_wt+y*cst.kernel_x+x;\n"
" auto in4=z_in[ y*dilation_h+x*cst.dilation_x];\n"
" /* true */ result0 += FLOAT4(in4**x_wt);\n"
" if (valids) { x_wt += ws; result1 += FLOAT4(in4**x_wt); }\n"
" }\n"
" }\n"
" }\n"
" /* true */ *z_out=activate(M4(result0+FLOAT4(biasTerms[uz[0]])),cst.activation);\n"
" if (valids) { z_out += cst.output_size*cst.batch; *z_out=activate(M4(result1+FLOAT4(biasTerms[uz[1]])),cst.activation); }\n"
"}\n"
;
// Metal source for single-axis reductions: mean, sum, min, max and prod.
// Thread (gid.x, gid.y) reduces the axis for outside index gid.x and inside
// index gid.y. It starts at in + gid.x*outside_step + gid.y, steps by
// inside_size for axis_size elements, and writes out[gid.x*inside_size + gid.y].
// Threads outside [0, outside_size) x [0, inside_size) do nothing.
// define_reduce(name) emits two kernels per op:
//   reduce_<name>_f - elements of type M, accumulated in FLOAT;
//   reduce_<name>_s - int elements, accumulated in int, so reduce_mean_s
//                     uses integer division.
// min/max seed from the first element in type T, so they read one element
// even when axis_size is 0.
// NOTE(review): M and FLOAT are not defined in this string; presumably
// supplied by a prepended prefix — confirm.
const char* shader_MetalReduction_metal =
"struct reduce_shape {\n"
" int outside_size;\n"
" int axis_size;\n"
" int inside_size;\n"
" int outside_step;\n"
"};\n"
"template <typename M,typename T>\n"
"static inline void reduce_mean(const device T *in,device T *out,constant reduce_shape &s,uint2 gid) {\n"
" auto axis_in=in+gid.x*s.outside_step+gid.y;\n"
" M summer=0;\n"
" for (int i=0; i<s.axis_size; i++,axis_in += s.inside_size) {\n"
" summer += M(*axis_in);\n"
" }\n"
" out[int(gid.x)*s.inside_size+int(gid.y)]=T(summer/s.axis_size);\n"
"}\n"
"template <typename M,typename T>\n"
"static inline void reduce_sum(const device T *in,device T *out,constant reduce_shape &s,uint2 gid) {\n"
" auto axis_in=in+gid.x*s.outside_step+gid.y;\n"
" M summer=0;\n"
" for (int i=0; i<s.axis_size; i++,axis_in += s.inside_size) {\n"
" summer += M(*axis_in);\n"
" }\n"
" out[int(gid.x)*s.inside_size+int(gid.y)]=T(summer);\n"
"}\n"
"template <typename M,typename T>\n"
"static inline void reduce_min(const device T *in,device T *out,constant reduce_shape &s,uint2 gid) {\n"
" auto axis_in=in+gid.x*s.outside_step+gid.y;\n"
" T summer=*axis_in; axis_in += s.inside_size;\n"
" for (int i=1; i<s.axis_size; i++,axis_in += s.inside_size) {\n"
" summer=min(summer,*axis_in);\n"
" }\n"
" out[int(gid.x)*s.inside_size+int(gid.y)]=summer;\n"
"}\n"
"template <typename M,typename T>\n"
"static inline void reduce_max(const device T *in,device T *out,constant reduce_shape &s,uint2 gid) {\n"
" auto axis_in=in+gid.x*s.outside_step+gid.y;\n"
" T summer=*axis_in; axis_in += s.inside_size;\n"
" for (int i=1; i<s.axis_size; i++,axis_in += s.inside_size) {\n"
" summer=max(summer,*axis_in);\n"
" }\n"
" out[int(gid.x)*s.inside_size+int(gid.y)]=summer;\n"
"}\n"
"template <typename M,typename T>\n"
"static inline void reduce_prod(const device T *in,device T *out,constant reduce_shape &s,uint2 gid) {\n"
" auto axis_in=in+gid.x*s.outside_step+gid.y;\n"
" M summer=1;\n"
" for (int i=0; i<s.axis_size; i++,axis_in += s.inside_size) {\n"
" summer *= M(*axis_in);\n"
" }\n"
" out[int(gid.x)*s.inside_size+int(gid.y)]=T(summer);\n"
"}\n"
"#define define_reduce(name) "
"kernel void reduce_##name##_f(const device M *in [[buffer(0)]],"
" device M *out [[buffer(1)]],"
" constant reduce_shape &s [[buffer(2)]],"
" uint2 gid [[thread_position_in_grid]]) { "
" if (gid.x<(uint)s.outside_size && gid.y<(uint)s.inside_size) reduce_##name<FLOAT,M>(in,out,s,gid); "
"} "
"kernel void reduce_##name##_s(const device int *in [[buffer(0)]],"
" device int *out [[buffer(1)]],"
" constant reduce_shape &s [[buffer(2)]],"
" uint2 gid [[thread_position_in_grid]]) { "
" if (gid.x<(uint)s.outside_size && gid.y<(uint)s.inside_size) reduce_##name<int,int>(in,out,s,gid); "
"}\n"
"define_reduce(mean);\n"
"define_reduce(sum);\n"
"define_reduce(min);\n"
"define_reduce(max);\n"
"define_reduce(prod);\n"
;
// Metal source for softmax along one axis.
//   softmax_plane       - scalar M layout; one thread per (inside, outside)
//                         pair, softmax over axis_length strided elements.
//   softmax_on_reorder  - M4 layout where the softmax axis is the packed
//                         4-lane (slice) axis. Lane k of slice i is a real
//                         element only when i*4 + k < flat_length; the rest
//                         are padding.
//   softmax_off_reorder - M4 layout where each lane is reduced independently.
// Padding handling in softmax_on_reorder:
//   - Padding lanes are excluded from the max with -INFINITY. Masking them to
//     0 would force max >= 0; for very negative logits every exp() would then
//     underflow and the result would be 0/0 = NaN.
//   - Padding is excluded from the sum with 0.
//   - Padding is written as 0 in the output.
// The exponent is computed as float4(x) - max in both the sum and the output
// loops, so the normaliser matches the numerators.
// NOTE(review): M / M4 are not defined in this string; presumably supplied by a
// prepended prefix. This file also looks generated from per-op .metal sources;
// if so, apply the same change to the source .metal file.
const char* shader_MetalSoftmax_metal =
"struct softmax_shape {\n"
" int inside_size;\n"
" int axis_length;\n"
" int outside_size;\n"
" int flat_length;\n"
"};\n"
"static inline float softmax_max4(float4 V) {\n"
" return max(max(V[0],V[1]),max(V[2],V[3]));\n"
"}\n"
"static inline float softmax_sum4(float4 V) {\n"
" return V[0]+V[1]+V[2]+V[3];\n"
"}\n"
"static inline float4 softmax_filter(float4 V,int z,int limit) {\n"
" return select(0,V,z*4+int4(0,1,2,3)<limit);\n"
"}\n"
"static inline float4 softmax_filter_max(float4 V,int z,int limit) {\n"
" return select(float4(-INFINITY),V,z*4+int4(0,1,2,3)<limit);\n"
"}\n"
"kernel void softmax_plane(const device M *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" constant softmax_shape& s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= s.inside_size || (int)gid.y >= s.outside_size) return;\n"
" \n"
" auto axis_off=gid.y*s.axis_length*s.inside_size+gid.x;\n"
" auto axis_in=in+axis_off;\n"
" auto axis_out=out+axis_off;\n"
" \n"
" // get max\n"
" float max1=-INFINITY;\n"
" for (int i=0; i<s.axis_length; i++) {\n"
" max1=max(max1,float(axis_in[i*s.inside_size]));\n"
" }\n"
" \n"
" // get sum\n"
" float sum1=0;\n"
" for (int i=0; i<s.axis_length; i++) {\n"
" sum1 += float(exp(float(axis_in[i*s.inside_size])-float(max1)));\n"
" }\n"
" \n"
" // output\n"
" for (int i=0; i<s.axis_length; i++) {\n"
" axis_out[i*s.inside_size]=M(exp(float(axis_in[i*s.inside_size])-float(max1))/sum1);\n"
" }\n"
"}\n"
"kernel void softmax_on_reorder(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant softmax_shape& s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= s.inside_size || (int)gid.y >= s.outside_size) return;\n"
" \n"
" auto axis_off=gid.y*s.axis_length*s.inside_size+gid.x;\n"
" auto axis_in=in+axis_off;\n"
" auto axis_out=out+axis_off;\n"
" // get max (padding lanes excluded via -INFINITY)\n"
" auto max4=softmax_filter_max(float4(axis_in[0]),0,s.flat_length);\n"
" for (int i=1; i<s.axis_length; i++) {\n"
" max4=max(max4,softmax_filter_max(float4(axis_in[i*s.inside_size]),i,s.flat_length));\n"
" }\n"
" float max1=softmax_max4(max4);\n"
" \n"
" // get sum (padding lanes excluded via 0)\n"
" float4 sum4=0;\n"
" for (int i=0; i<s.axis_length; i++) {\n"
" sum4 += softmax_filter(exp(float4(axis_in[i*s.inside_size])-max1),i,s.flat_length);\n"
" }\n"
" float sum1=softmax_sum4(sum4);\n"
" \n"
" // output (padding lanes written as 0)\n"
" for (int i=0; i<s.axis_length; i++) {\n"
" axis_out[i*s.inside_size]=M4(softmax_filter(exp(float4(axis_in[i*s.inside_size])-max1)/sum1,i,s.flat_length));\n"
" }\n"
"}\n"
"kernel void softmax_off_reorder(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant softmax_shape& s [[buffer(2)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= s.inside_size || (int)gid.y >= s.outside_size) return;\n"
" auto axis_off=gid.y*s.axis_length*s.inside_size+gid.x;\n"
" auto axis_in=in+axis_off;\n"
" auto axis_out=out+axis_off;\n"
" // get max\n"
" auto max4=axis_in[0];\n"
" for (int i=1; i<s.axis_length; i++) {\n"
" max4=max(max4,axis_in[i*s.inside_size]);\n"
" }\n"
" // get sum\n"
" float4 sum4=0;\n"
" for (int i=0; i<s.axis_length; i++) {\n"
" sum4 += exp(float4(axis_in[i*s.inside_size]-max4));\n"
" }\n"
" // output\n"
" for (int i=0; i<s.axis_length; i++) {\n"
" axis_out[i*s.inside_size]=M4(exp(float4(axis_in[i*s.inside_size]-max4))/sum4);\n"
" }\n"
"}\n"
;
// Metal source for layer normalisation and RMS normalisation over rows of
// length cst.inside; there are cst.outside rows.
// Each thread writes one element (x1 variants) or one 4-wide vector (x4
// variants) of row gid.y, and recomputes the whole row's statistics itself:
//   layernorm_x1 / layernorm_x4         : (x - mean) / sqrt(var + eps),
//                                         var = mean of squared deviations
//   layernorm_x1_rms / layernorm_x4_rms : x / sqrt(mean(x^2) + eps)
// When cst.has_gamma_beta is non-zero, the result is multiplied by gamma[gid.x]
// and beta[gid.x] is added (float / float4 buffers). Otherwise the result is
// written as is.
// The x4 variants address rows as cst.inside/4 vectors.
// NOTE(review): they presumably require cst.inside % 4 == 0; confirm at the
// dispatch site.
// NOTE(review): M / M4 are not defined in this string; presumably supplied by a
// prepended prefix.
const char* shader_MetalLayerNorm_metal =
"struct layernorm_constants {\n"
" int inside;\n"
" int outside;\n"
" float eps;\n"
" int has_gamma_beta;\n"
"};\n"
"kernel void layernorm_x1(const device M *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" constant layernorm_constants& cst [[buffer(2)]],\n"
" const device float *gamma [[buffer(3)]],\n"
" const device float *beta [[buffer(4)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.inside || (int)gid.y >= cst.outside) {\n"
" return;\n"
" }\n"
" auto in_data=in+gid.y*cst.inside;\n"
" auto out_data=out+gid.y*cst.inside;\n"
" float mean;\n"
" float sum=0.0f;\n"
" float square_sum=0.0f;\n"
" \n"
" for(int i=0; i<cst.inside; i++) {\n"
" sum += in_data[i];\n"
" }\n"
" mean=sum/cst.inside;\n"
" \n"
" for(int i=0; i<cst.inside; i++) {\n"
" float dis=(in_data[i]-mean);\n"
" square_sum += dis*dis;\n"
" }\n"
" float var=1.0/sqrt(square_sum/cst.inside+cst.eps);\n"
" \n"
" float norm=var*((float)in_data[gid.x]-mean);\n"
" if(cst.has_gamma_beta) {\n"
" out_data[gid.x]=(M)(norm*gamma[gid.x]+beta[gid.x]);\n"
" } else {\n"
" out_data[gid.x]=(M)(norm);\n"
" }\n"
"}\n"
"kernel void layernorm_x4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant layernorm_constants& cst [[buffer(2)]],\n"
" const device float4 *gamma [[buffer(3)]],\n"
" const device float4 *beta [[buffer(4)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.inside/4 || (int)gid.y >= cst.outside) {\n"
" return;\n"
" }\n"
" auto in_data=in+gid.y*cst.inside/4;\n"
" auto out_data=out+gid.y*cst.inside/4;\n"
" float mean;\n"
" float sum=0.0f;\n"
" float square_sum=0.0f;\n"
" \n"
" for(int i=0; i<cst.inside/4; i++) {\n"
" sum += in_data[i].x;\n"
" sum += in_data[i].y;\n"
" sum += in_data[i].z;\n"
" sum += in_data[i].w;\n"
" }\n"
" mean=sum/cst.inside;\n"
" \n"
" for(int i=0; i<cst.inside/4; i++) {\n"
" float dis=(in_data[i].x-mean);\n"
" square_sum += dis*dis;\n"
" dis=(in_data[i].y-mean);\n"
" square_sum += dis*dis;\n"
" dis=(in_data[i].z-mean);\n"
" square_sum += dis*dis;\n"
" dis=(in_data[i].w-mean);\n"
" square_sum += dis*dis;\n"
" }\n"
" float var=1.0/sqrt(square_sum/cst.inside+cst.eps);\n"
" \n"
" float4 norm=var*((float4)in_data[gid.x]-mean);\n"
" if(cst.has_gamma_beta) {\n"
" out_data[gid.x]=(M4)(norm*gamma[gid.x]+beta[gid.x]);\n"
" } else {\n"
" out_data[gid.x]=(M4)(norm);\n"
" }\n"
"}\n"
"kernel void layernorm_x1_rms(const device M *in [[buffer(0)]],\n"
" device M *out [[buffer(1)]],\n"
" constant layernorm_constants& cst [[buffer(2)]],\n"
" const device float *gamma [[buffer(3)]],\n"
" const device float *beta [[buffer(4)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.inside || (int)gid.y >= cst.outside) {\n"
" return;\n"
" }\n"
" auto in_data=in+gid.y*cst.inside;\n"
" auto out_data=out+gid.y*cst.inside;\n"
" float square_sum=0.0f;\n"
" \n"
" for(int i=0; i<cst.inside; i++) {\n"
" float dis=in_data[i];\n"
" square_sum += dis*dis;\n"
" }\n"
" float var=1.0/sqrt(square_sum/cst.inside+cst.eps);\n"
" \n"
" float norm=var*((float)in_data[gid.x]);\n"
" if(cst.has_gamma_beta) {\n"
" out_data[gid.x]=(M)(norm*gamma[gid.x]+beta[gid.x]);\n"
" } else {\n"
" out_data[gid.x]=(M)(norm);\n"
" }\n"
"}\n"
"kernel void layernorm_x4_rms(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant layernorm_constants& cst [[buffer(2)]],\n"
" const device float4 *gamma [[buffer(3)]],\n"
" const device float4 *beta [[buffer(4)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.inside/4 || (int)gid.y >= cst.outside) {\n"
" return;\n"
" }\n"
" auto in_data=in+gid.y*cst.inside/4;\n"
" auto out_data=out+gid.y*cst.inside/4;\n"
" float square_sum=0.0f;\n"
" for(int i=0; i<cst.inside/4; i++) {\n"
" float dis=in_data[i].x;\n"
" square_sum += dis*dis;\n"
" dis=in_data[i].y;\n"
" square_sum += dis*dis;\n"
" dis=in_data[i].z;\n"
" square_sum += dis*dis;\n"
" dis=in_data[i].w;\n"
" square_sum += dis*dis;\n"
" }\n"
" float var=1.0/sqrt(square_sum/cst.inside+cst.eps);\n"
" \n"
" float4 norm=var*((float4)in_data[gid.x]);\n"
" if(cst.has_gamma_beta) {\n"
" out_data[gid.x]=(M4)(norm*gamma[gid.x]+beta[gid.x]);\n"
" } else {\n"
" out_data[gid.x]=(M4)(norm);\n"
" }\n"
"}\n"
;
// Metal (MSL) source for the Winograd convolution path: input-tile transforms
// (winograd_transform_source*) and output-tile transforms (winograd_transform_dest*).
// The "2_5" kernels read 6x6 input patches and the "2_3" kernels read 4x4 patches;
// both dest kernels write 2x2 output tiles (tile size 2 with, per the names, 5x5 / 3x3
// filters, since 2+5-1 = 6 and 2+3-1 = 4).
// NOTE(review): the meaning of the trailing "_1" in the kernel names is not visible
// here — confirm against the host code that selects these kernels.
// Comments placed between the literal fragments below are stripped before the
// adjacent string literals are concatenated, so they do not alter the shader text.
const char* shader_MetalConvolutionWinograd_metal =
// winograd_constants: input_shape.x / .y are the input plane width / height (see
// get_input); unit_width / unit_height are the number of tiles per row / column;
// unit is the tile step in pixels (tile origin = tile index * unit - pad).
// NOTE(review): input_shape.z is used as the slice count of the transformed buffer and
// output_shape.w as the batch count when addressing planes — confirm with host code.
"struct winograd_constants {\n"
" int4 input_shape;\n"
" int4 output_shape;\n"
" int pad_x;\n"
" int pad_y;\n"
" int unit_width;\n"
" int unit_height;\n"
" int unit;\n"
" conv_activation_type activation;\n"
"};\n"
// get_input: reads one M4 from a width-major plane; coordinates outside
// [0, input_shape.x) x [0, input_shape.y) yield 0, which implements zero padding.
"static inline M4 get_input(const device M4 *input,int x,int y,constant winograd_constants &cst) {\n"
" return x<cst.input_shape.x && y<cst.input_shape.y && x >= 0 && y >= 0 ? input[x+y*cst.input_shape.x] : 0;\n"
"}\n"
// winograd_transform_source2_5_1: one thread per tile (pos.x, pos.y) of slice pos.z.
// Loads the 6x6 patch whose origin is (pos * unit - pad), applies the 1-D input
// transform along y (S -> m), then along x (m -> stored values), and stores the 36
// results `stride` elements apart. The tile index is split into (index / 4, index % 4)
// so four consecutive tiles share one destination row; the resulting layout is
// [36][UP_DIV(tiles,4)][input_shape.z][4 tiles][M4].
// NOTE(review): pos.z is not bounds-checked; assumes the dispatch z-extent equals the
// slice count.
"kernel void winograd_transform_source2_5_1(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant winograd_constants &cst [[buffer(2)]],\n"
" constant int &batch_idx [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" auto pos=int3(gid);\n"
" if (pos.x<cst.unit_width && pos.y<cst.unit_height) {\n"
" int ix=pos.x*cst.unit-cst.pad_x;\n"
" int iy=pos.y*cst.unit-cst.pad_y;\n"
" auto z_in=in+(pos.z*cst.output_shape.w+batch_idx)*cst.input_shape.x*cst.input_shape.y;\n"
// S<x><y> = input at (ix + x, iy + y), zero outside the plane.
" auto S00=get_input(z_in,ix+0,iy+0,cst);\n"
" auto S10=get_input(z_in,ix+1,iy+0,cst);\n"
" auto S20=get_input(z_in,ix+2,iy+0,cst);\n"
" auto S30=get_input(z_in,ix+3,iy+0,cst);\n"
" auto S40=get_input(z_in,ix+4,iy+0,cst);\n"
" auto S50=get_input(z_in,ix+5,iy+0,cst);\n"
" auto S01=get_input(z_in,ix+0,iy+1,cst);\n"
" auto S11=get_input(z_in,ix+1,iy+1,cst);\n"
" auto S21=get_input(z_in,ix+2,iy+1,cst);\n"
" auto S31=get_input(z_in,ix+3,iy+1,cst);\n"
" auto S41=get_input(z_in,ix+4,iy+1,cst);\n"
" auto S51=get_input(z_in,ix+5,iy+1,cst);\n"
" auto S02=get_input(z_in,ix+0,iy+2,cst);\n"
" auto S12=get_input(z_in,ix+1,iy+2,cst);\n"
" auto S22=get_input(z_in,ix+2,iy+2,cst);\n"
" auto S32=get_input(z_in,ix+3,iy+2,cst);\n"
" auto S42=get_input(z_in,ix+4,iy+2,cst);\n"
" auto S52=get_input(z_in,ix+5,iy+2,cst);\n"
" auto S03=get_input(z_in,ix+0,iy+3,cst);\n"
" auto S13=get_input(z_in,ix+1,iy+3,cst);\n"
" auto S23=get_input(z_in,ix+2,iy+3,cst);\n"
" auto S33=get_input(z_in,ix+3,iy+3,cst);\n"
" auto S43=get_input(z_in,ix+4,iy+3,cst);\n"
" auto S53=get_input(z_in,ix+5,iy+3,cst);\n"
" auto S04=get_input(z_in,ix+0,iy+4,cst);\n"
" auto S14=get_input(z_in,ix+1,iy+4,cst);\n"
" auto S24=get_input(z_in,ix+2,iy+4,cst);\n"
" auto S34=get_input(z_in,ix+3,iy+4,cst);\n"
" auto S44=get_input(z_in,ix+4,iy+4,cst);\n"
" auto S54=get_input(z_in,ix+5,iy+4,cst);\n"
" auto S05=get_input(z_in,ix+0,iy+5,cst);\n"
" auto S15=get_input(z_in,ix+1,iy+5,cst);\n"
" auto S25=get_input(z_in,ix+2,iy+5,cst);\n"
" auto S35=get_input(z_in,ix+3,iy+5,cst);\n"
" auto S45=get_input(z_in,ix+4,iy+5,cst);\n"
" auto S55=get_input(z_in,ix+5,iy+5,cst);\n"
// m<x><j>: row j of the 1-D transform applied along y to column x of S.
" auto m00=+S00-1.25*S02+0.25*S04;\n"
" auto m10=+S10-1.25*S12+0.25*S14;\n"
" auto m20=+S20-1.25*S22+0.25*S24;\n"
" auto m30=+S30-1.25*S32+0.25*S34;\n"
" auto m40=+S40-1.25*S42+0.25*S44;\n"
" auto m50=+S50-1.25*S52+0.25*S54;\n"
" auto m01=+0.666667*S01+0.666667*S02-0.166667*S03-0.166667*S04;\n"
" auto m11=+0.666667*S11+0.666667*S12-0.166667*S13-0.166667*S14;\n"
" auto m21=+0.666667*S21+0.666667*S22-0.166667*S23-0.166667*S24;\n"
" auto m31=+0.666667*S31+0.666667*S32-0.166667*S33-0.166667*S34;\n"
" auto m41=+0.666667*S41+0.666667*S42-0.166667*S43-0.166667*S44;\n"
" auto m51=+0.666667*S51+0.666667*S52-0.166667*S53-0.166667*S54;\n"
" auto m02=-0.666667*S01+0.666667*S02+0.166667*S03-0.166667*S04;\n"
" auto m12=-0.666667*S11+0.666667*S12+0.166667*S13-0.166667*S14;\n"
" auto m22=-0.666667*S21+0.666667*S22+0.166667*S23-0.166667*S24;\n"
" auto m32=-0.666667*S31+0.666667*S32+0.166667*S33-0.166667*S34;\n"
" auto m42=-0.666667*S41+0.666667*S42+0.166667*S43-0.166667*S44;\n"
" auto m52=-0.666667*S51+0.666667*S52+0.166667*S53-0.166667*S54;\n"
" auto m03=-0.0833333*S01-0.0416667*S02+0.0833333*S03+0.0416667*S04;\n"
" auto m13=-0.0833333*S11-0.0416667*S12+0.0833333*S13+0.0416667*S14;\n"
" auto m23=-0.0833333*S21-0.0416667*S22+0.0833333*S23+0.0416667*S24;\n"
" auto m33=-0.0833333*S31-0.0416667*S32+0.0833333*S33+0.0416667*S34;\n"
" auto m43=-0.0833333*S41-0.0416667*S42+0.0833333*S43+0.0416667*S44;\n"
" auto m53=-0.0833333*S51-0.0416667*S52+0.0833333*S53+0.0416667*S54;\n"
" auto m04=+0.0833333*S01-0.0416667*S02-0.0833333*S03+0.0416667*S04;\n"
" auto m14=+0.0833333*S11-0.0416667*S12-0.0833333*S13+0.0416667*S14;\n"
" auto m24=+0.0833333*S21-0.0416667*S22-0.0833333*S23+0.0416667*S24;\n"
" auto m34=+0.0833333*S31-0.0416667*S32-0.0833333*S33+0.0416667*S34;\n"
" auto m44=+0.0833333*S41-0.0416667*S42-0.0833333*S43+0.0416667*S44;\n"
" auto m54=+0.0833333*S51-0.0416667*S52-0.0833333*S53+0.0416667*S54;\n"
" auto m05=+4.0*S01-5.0*S03+S05;\n"
" auto m15=+4.0*S11-5.0*S13+S15;\n"
" auto m25=+4.0*S21-5.0*S23+S25;\n"
" auto m35=+4.0*S31-5.0*S33+S35;\n"
" auto m45=+4.0*S41-5.0*S43+S45;\n"
" auto m55=+4.0*S51-5.0*S53+S55;\n"
// Destination addressing: row = tile / 4, column = tile % 4 + 4 * slice; the 36
// transformed values of one tile are `stride` elements apart.
" int dst_x_origin=pos.z;\n"
" int dst_y_origin=cst.unit_width*pos.y+pos.x;\n"
" int dst_y_stride=cst.input_shape.z*4;\n"
" int dst_y=dst_y_origin/4;\n"
" int dst_x=dst_y_origin % 4+4*dst_x_origin;\n"
" int src_height=UP_DIV(cst.unit_width*cst.unit_height,4);\n"
" int stride=src_height*dst_y_stride;\n"
" auto xy_out=out+dst_y*dst_y_stride+dst_x;\n"
// Apply the same 1-D transform along x to each m<.><j> and store in order.
" *xy_out=+m00-1.25*m20+0.25*m40;\n"
" xy_out += stride; *xy_out=+0.666667*m10+0.666667*m20-0.166667*m30-0.166667*m40;\n"
" xy_out += stride; *xy_out=-0.666667*m10+0.666667*m20+0.166667*m30-0.166667*m40;\n"
" xy_out += stride; *xy_out=-0.0833333*m10-0.0416667*m20+0.0833333*m30+0.0416667*m40;\n"
" xy_out += stride; *xy_out=+0.0833333*m10-0.0416667*m20-0.0833333*m30+0.0416667*m40;\n"
" xy_out += stride; *xy_out=+4.0*m10-5.0*m30+m50;\n"
" xy_out += stride; *xy_out=+m01-1.25*m21+0.25*m41;\n"
" xy_out += stride; *xy_out=+0.666667*m11+0.666667*m21-0.166667*m31-0.166667*m41;\n"
" xy_out += stride; *xy_out=-0.666667*m11+0.666667*m21+0.166667*m31-0.166667*m41;\n"
" xy_out += stride; *xy_out=-0.0833333*m11-0.0416667*m21+0.0833333*m31+0.0416667*m41;\n"
" xy_out += stride; *xy_out=+0.0833333*m11-0.0416667*m21-0.0833333*m31+0.0416667*m41;\n"
" xy_out += stride; *xy_out=+4.0*m11-5.0*m31+m51;\n"
" xy_out += stride; *xy_out=+m02-1.25*m22+0.25*m42;\n"
" xy_out += stride; *xy_out=+0.666667*m12+0.666667*m22-0.166667*m32-0.166667*m42;\n"
" xy_out += stride; *xy_out=-0.666667*m12+0.666667*m22+0.166667*m32-0.166667*m42;\n"
" xy_out += stride; *xy_out=-0.0833333*m12-0.0416667*m22+0.0833333*m32+0.0416667*m42;\n"
" xy_out += stride; *xy_out=+0.0833333*m12-0.0416667*m22-0.0833333*m32+0.0416667*m42;\n"
" xy_out += stride; *xy_out=+4.0*m12-5.0*m32+m52;\n"
" xy_out += stride; *xy_out=+m03-1.25*m23+0.25*m43;\n"
" xy_out += stride; *xy_out=+0.666667*m13+0.666667*m23-0.166667*m33-0.166667*m43;\n"
" xy_out += stride; *xy_out=-0.666667*m13+0.666667*m23+0.166667*m33-0.166667*m43;\n"
" xy_out += stride; *xy_out=-0.0833333*m13-0.0416667*m23+0.0833333*m33+0.0416667*m43;\n"
" xy_out += stride; *xy_out=+0.0833333*m13-0.0416667*m23-0.0833333*m33+0.0416667*m43;\n"
" xy_out += stride; *xy_out=+4.0*m13-5.0*m33+m53;\n"
" xy_out += stride; *xy_out=+m04-1.25*m24+0.25*m44;\n"
" xy_out += stride; *xy_out=+0.666667*m14+0.666667*m24-0.166667*m34-0.166667*m44;\n"
" xy_out += stride; *xy_out=-0.666667*m14+0.666667*m24+0.166667*m34-0.166667*m44;\n"
" xy_out += stride; *xy_out=-0.0833333*m14-0.0416667*m24+0.0833333*m34+0.0416667*m44;\n"
" xy_out += stride; *xy_out=+0.0833333*m14-0.0416667*m24-0.0833333*m34+0.0416667*m44;\n"
" xy_out += stride; *xy_out=+4.0*m14-5.0*m34+m54;\n"
" xy_out += stride; *xy_out=+m05-1.25*m25+0.25*m45;\n"
" xy_out += stride; *xy_out=+0.666667*m15+0.666667*m25-0.166667*m35-0.166667*m45;\n"
" xy_out += stride; *xy_out=-0.666667*m15+0.666667*m25+0.166667*m35-0.166667*m45;\n"
" xy_out += stride; *xy_out=-0.0833333*m15-0.0416667*m25+0.0833333*m35+0.0416667*m45;\n"
" xy_out += stride; *xy_out=+0.0833333*m15-0.0416667*m25-0.0833333*m35+0.0416667*m45;\n"
" xy_out += stride; *xy_out=+4.0*m15-5.0*m35+m55;\n"
" }\n"
"}\n"
// winograd_transform_source2_3_1: same scheme with a 4x4 input patch and 16 stored
// values. split_idx shifts the input rows by split_idx*unit_height tiles, while the
// destination tile index uses only pos.y — presumably so the image is processed in
// horizontal bands, one band per dispatch (confirm with host code).
"kernel void winograd_transform_source2_3_1(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant winograd_constants &cst [[buffer(2)]],\n"
" constant int &batch_idx [[buffer(3)]],\n"
" constant int &split_idx [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" auto pos=int3(gid);\n"
" if (pos.x<cst.unit_width && pos.y<cst.unit_height) {\n"
" int ix=pos.x*cst.unit-cst.pad_x;\n"
" int iy=(pos.y+split_idx*cst.unit_height)*cst.unit-cst.pad_y;\n"
" auto z_in=in+(pos.z*cst.output_shape.w+batch_idx)*cst.input_shape.x*cst.input_shape.y;\n"
" auto S00=get_input(z_in,ix+0,iy+0,cst);\n"
" auto S10=get_input(z_in,ix+1,iy+0,cst);\n"
" auto S20=get_input(z_in,ix+2,iy+0,cst);\n"
" auto S30=get_input(z_in,ix+3,iy+0,cst);\n"
" auto S01=get_input(z_in,ix+0,iy+1,cst);\n"
" auto S11=get_input(z_in,ix+1,iy+1,cst);\n"
" auto S21=get_input(z_in,ix+2,iy+1,cst);\n"
" auto S31=get_input(z_in,ix+3,iy+1,cst);\n"
" auto S02=get_input(z_in,ix+0,iy+2,cst);\n"
" auto S12=get_input(z_in,ix+1,iy+2,cst);\n"
" auto S22=get_input(z_in,ix+2,iy+2,cst);\n"
" auto S32=get_input(z_in,ix+3,iy+2,cst);\n"
" auto S03=get_input(z_in,ix+0,iy+3,cst);\n"
" auto S13=get_input(z_in,ix+1,iy+3,cst);\n"
" auto S23=get_input(z_in,ix+2,iy+3,cst);\n"
" auto S33=get_input(z_in,ix+3,iy+3,cst);\n"
" auto m00=+S00-S02;\n"
" auto m10=+S10-S12;\n"
" auto m20=+S20-S22;\n"
" auto m30=+S30-S32;\n"
" auto m01=+0.5*S01+0.5*S02;\n"
" auto m11=+0.5*S11+0.5*S12;\n"
" auto m21=+0.5*S21+0.5*S22;\n"
" auto m31=+0.5*S31+0.5*S32;\n"
" auto m02=-0.5*S01+0.5*S02;\n"
" auto m12=-0.5*S11+0.5*S12;\n"
" auto m22=-0.5*S21+0.5*S22;\n"
" auto m32=-0.5*S31+0.5*S32;\n"
" auto m03=-S01+S03;\n"
" auto m13=-S11+S13;\n"
" auto m23=-S21+S23;\n"
" auto m33=-S31+S33;\n"
" int dst_x_origin=pos.z;\n"
" int dst_y_origin=cst.unit_width*pos.y+pos.x;\n"
" int dst_y_stride=cst.input_shape.z*4;\n"
" int dst_y=dst_y_origin/4;\n"
" int dst_x=dst_y_origin % 4+4*dst_x_origin;\n"
" int src_height=UP_DIV(cst.unit_width*cst.unit_height,4);\n"
" int stride=src_height*dst_y_stride;\n"
" // [mSrcUnit*mSrcUnit,UP_DIV(uw*wh,4),UP_DIV(ci,4),(uw*wh)_4,ci_4]\n"
" auto xy_out=out+dst_y*dst_y_stride+dst_x;\n"
" *xy_out=+m00-m20;\n"
" xy_out += stride; *xy_out=+0.5*m10+0.5*m20;\n"
" xy_out += stride; *xy_out=-0.5*m10+0.5*m20;\n"
" xy_out += stride; *xy_out=-m10+m30;\n"
" xy_out += stride; *xy_out=+m01-m21;\n"
" xy_out += stride; *xy_out=+0.5*m11+0.5*m21;\n"
" xy_out += stride; *xy_out=-0.5*m11+0.5*m21;\n"
" xy_out += stride; *xy_out=-m11+m31;\n"
" xy_out += stride; *xy_out=+m02-m22;\n"
" xy_out += stride; *xy_out= +0.5*m12+0.5*m22;\n"
" xy_out += stride; *xy_out=-0.5*m12+0.5*m22;\n"
" xy_out += stride; *xy_out=-m12+m32;\n"
" xy_out += stride; *xy_out=+m03-m23;\n"
" xy_out += stride; *xy_out=+0.5*m13+0.5*m23;\n"
" xy_out += stride; *xy_out=-0.5*m13+0.5*m23;\n"
" xy_out += stride; *xy_out=-m13+m33;\n"
" }\n"
"}\n"
// set_output: applies the configured activation and stores one M4 at (x, y) of a
// plane with row stride output_shape.x. Bounds checks are the caller's job.
"static inline void set_output(constant winograd_constants &cst,device M4 *output,int x,int y,M4 V) {\n"
" output[y*cst.output_shape.x+x]=activate(V,cst.activation);\n"
"}\n"
// winograd_transform_dest2_5_1: one thread per tile (pos.x, pos.y) of output slice
// pos.z. Gathers the 36 transformed values of the tile (consecutive values dst_w
// elements apart), applies the output transform, adds biasTerms[pos.z] and writes the
// 2x2 output tile at (pos * unit). The (ox+0, oy+0) pixel is written unconditionally;
// the other three are guarded against the right/bottom edge.
"kernel void winograd_transform_dest2_5_1(const device M4 *in [[buffer(0)]],\n"
" const device M4 *biasTerms [[buffer(1)]],\n"
" device M4 *out [[buffer(2)]],\n"
" constant winograd_constants &cst [[buffer(3)]],\n"
" constant int &batch_idx [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" auto pos=int3(gid);\n"
" if (pos.x<cst.unit_width && pos.y<cst.unit_height) {\n"
" int dst_w=UP_DIV(cst.unit_width*cst.unit_height,4);\n"
" int dst_x_origin=cst.unit_width*pos.y+pos.x;\n"
" int dst_x=dst_x_origin/4;\n"
" int dst_y=4*pos.z+dst_x_origin % 4;\n"
" int dst_y_stride=dst_w*36;\n"
" auto xy_in=in+dst_y*dst_y_stride+dst_x;\n"
" auto S00=*xy_in; xy_in += dst_w;\n"
" auto S10=*xy_in; xy_in += dst_w;\n"
" auto S20=*xy_in; xy_in += dst_w;\n"
" auto S30=*xy_in; xy_in += dst_w;\n"
" auto S40=*xy_in; xy_in += dst_w;\n"
" auto S50=*xy_in; xy_in += dst_w;\n"
" auto S01=*xy_in; xy_in += dst_w;\n"
" auto S11=*xy_in; xy_in += dst_w;\n"
" auto S21=*xy_in; xy_in += dst_w;\n"
" auto S31=*xy_in; xy_in += dst_w;\n"
" auto S41=*xy_in; xy_in += dst_w;\n"
" auto S51=*xy_in; xy_in += dst_w;\n"
" auto S02=*xy_in; xy_in += dst_w;\n"
" auto S12=*xy_in; xy_in += dst_w;\n"
" auto S22=*xy_in; xy_in += dst_w;\n"
" auto S32=*xy_in; xy_in += dst_w;\n"
" auto S42=*xy_in; xy_in += dst_w;\n"
" auto S52=*xy_in; xy_in += dst_w;\n"
" auto S03=*xy_in; xy_in += dst_w;\n"
" auto S13=*xy_in; xy_in += dst_w;\n"
" auto S23=*xy_in; xy_in += dst_w;\n"
" auto S33=*xy_in; xy_in += dst_w;\n"
" auto S43=*xy_in; xy_in += dst_w;\n"
" auto S53=*xy_in; xy_in += dst_w;\n"
" auto S04=*xy_in; xy_in += dst_w;\n"
" auto S14=*xy_in; xy_in += dst_w;\n"
" auto S24=*xy_in; xy_in += dst_w;\n"
" auto S34=*xy_in; xy_in += dst_w;\n"
" auto S44=*xy_in; xy_in += dst_w;\n"
" auto S54=*xy_in; xy_in += dst_w;\n"
" auto S05=*xy_in; xy_in += dst_w;\n"
" auto S15=*xy_in; xy_in += dst_w;\n"
" auto S25=*xy_in; xy_in += dst_w;\n"
" auto S35=*xy_in; xy_in += dst_w;\n"
" auto S45=*xy_in; xy_in += dst_w;\n"
" auto S55=*xy_in;\n"
" auto m00=+S00+S01+S02+S03+S04;\n"
" auto m10=+S10+S11+S12+S13+S14;\n"
" auto m20=+S20+S21+S22+S23+S24;\n"
" auto m30=+S30+S31+S32+S33+S34;\n"
" auto m40=+S40+S41+S42+S43+S44;\n"
" auto m50=+S50+S51+S52+S53+S54;\n"
" auto m01=+S01-S02+2.0*S03-2.0*S04+S05;\n"
" auto m11=+S11-S12+2.0*S13-2.0*S14+S15;\n"
" auto m21=+S21-S22+2.0*S23-2.0*S24+S25;\n"
" auto m31=+S31-S32+2.0*S33-2.0*S34+S35;\n"
" auto m41=+S41-S42+2.0*S43-2.0*S44+S45;\n"
" auto m51=+S51-S52+2.0*S53-2.0*S54+S55;\n"
" // write output\n"
" auto b4=biasTerms[int(pos.z)];\n"
" int oy=pos.y*cst.unit;\n"
" int ox=pos.x*cst.unit;\n"
" auto z_out=out+(pos.z*cst.output_shape.w+batch_idx)*cst.output_shape.x*cst.output_shape.y;\n"
" \n"
" /* if true */ {\n"
" set_output(cst,z_out,ox+0,oy+0,b4+m00+m10+m20+m30+m40);\n"
" }\n"
" if (ox+1<cst.output_shape.x) {\n"
" set_output(cst,z_out,ox+1,oy+0,b4+m10-m20+2.0*m30-2.0*m40+m50);\n"
" }\n"
" if (oy+1<cst.output_shape.y) {\n"
" set_output(cst,z_out,ox+0,oy+1,b4+m01+m11+m21+m31+m41);\n"
" }\n"
" if (ox+1<cst.output_shape.x && oy+1<cst.output_shape.y) {\n"
" set_output(cst,z_out,ox+1,oy+1,b4+m11-m21+2.0*m31-2.0*m41+m51);\n"
" }\n"
" }\n"
"}\n"
// winograd_transform_dest2_3_1: 16-value (4x4) variant of the output transform.
// split_idx shifts the output rows by split_idx*unit_height tiles, mirroring
// winograd_transform_source2_3_1.
"kernel void winograd_transform_dest2_3_1(const device M4 *in [[buffer(0)]],\n"
" const device M4 *biasTerms [[buffer(1)]],\n"
" device M4 *out [[buffer(2)]],\n"
" constant winograd_constants &cst [[buffer(3)]],\n"
" constant int &batch_idx [[buffer(4)]],\n"
" constant int &split_idx [[buffer(5)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" auto pos=int3(gid);\n"
" if (pos.x<cst.unit_width && pos.y<cst.unit_height) {\n"
" int dst_w=UP_DIV(cst.unit_width*cst.unit_height,4);\n"
" int dst_x_origin=cst.unit_width*pos.y+pos.x;\n"
" int dst_x=dst_x_origin/4;\n"
" int dst_y=4*pos.z+dst_x_origin % 4;\n"
" int dst_y_stride=dst_w*16;\n"
" auto xy_in=in+dst_y*dst_y_stride+dst_x;\n"
" auto S00=*xy_in; xy_in += dst_w;\n"
" auto S10=*xy_in; xy_in += dst_w;\n"
" auto S20=*xy_in; xy_in += dst_w;\n"
" auto S30=*xy_in; xy_in += dst_w;\n"
" auto S01=*xy_in; xy_in += dst_w;\n"
" auto S11=*xy_in; xy_in += dst_w;\n"
" auto S21=*xy_in; xy_in += dst_w;\n"
" auto S31=*xy_in; xy_in += dst_w;\n"
" auto S02=*xy_in; xy_in += dst_w;\n"
" auto S12=*xy_in; xy_in += dst_w;\n"
" auto S22=*xy_in; xy_in += dst_w;\n"
" auto S32=*xy_in; xy_in += dst_w;\n"
" auto S03=*xy_in; xy_in += dst_w;\n"
" auto S13=*xy_in; xy_in += dst_w;\n"
" auto S23=*xy_in; xy_in += dst_w;\n"
" auto S33=*xy_in;\n"
" auto m00=+S00+S01+S02;\n"
" auto m10=+S10+S11+S12;\n"
" auto m20=+S20+S21+S22;\n"
" auto m30=+S30+S31+S32;\n"
" auto m01=+S01-S02+S03;\n"
" auto m11=+S11-S12+S13;\n"
" auto m21=+S21-S22+S23;\n"
" auto m31=+S31-S32+S33;\n"
" // write output\n"
" auto b4=biasTerms[int(pos.z)];\n"
" int oy=(pos.y+split_idx*cst.unit_height)*cst.unit;\n"
" int ox=pos.x*cst.unit;\n"
" auto z_out=out+(pos.z*cst.output_shape.w+batch_idx)*cst.output_shape.x*cst.output_shape.y;\n"
" \n"
" /* if true */ {\n"
" set_output(cst,z_out,ox+0,oy+0,b4+m00+m10+m20);\n"
" }\n"
" if (ox+1<cst.output_shape.x) {\n"
" set_output(cst,z_out,ox+1,oy+0,b4+m10-m20+m30);\n"
" }\n"
" if (oy+1<cst.output_shape.y) {\n"
" set_output(cst,z_out,ox+0,oy+1,b4+m01+m11+m21);\n"
" }\n"
" if (ox+1<cst.output_shape.x && oy+1<cst.output_shape.y) {\n"
" set_output(cst,z_out,ox+1,oy+1,b4+m11-m21+m31);\n"
" }\n"
" }\n"
"}\n"
;
// Metal (MSL) source for a generic strided matrix multiply, one thread per output
// element. Shape/stride conventions shared by both kernels:
//   mat_size.x = output columns, mat_size.y = output rows, mat_size.z = reduction length;
//   row r of in0 starts at r*in_stride.x and advances in_stride.y per reduction step;
//   column c of in1 starts at c*in_stride.z and advances in_stride.w per reduction step.
// The output is written row-major with row stride mat_size.x. Products are accumulated
// in FLOAT and narrowed to M once per output element.
const char* shader_MetalMatMul_metal =
"struct matmul_shape {\n"
" int4 mat_size;\n"
" int4 in_stride;\n"
"};\n"
// matmul: out[y][x] = sum_i in0[y][i] * in1[i][x].
"kernel void matmul(const device M *in0 [[buffer(0)]],\n"
" const device M *in1 [[buffer(1)]],\n"
" device M *out [[buffer(2)]],\n"
" constant matmul_shape &s [[buffer(3)]],\n"
" uint2 gid[[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.mat_size.x && (int)gid.y<s.mat_size.y) {\n"
" auto off_in0=in0+int(gid.y)*s.in_stride.x;\n"
" auto off_in1=in1+int(gid.x)*s.in_stride.z;\n"
" FLOAT V=0.f;\n"
" for (int i=0; i<s.mat_size.z; i++,off_in0 += s.in_stride.y,off_in1 += s.in_stride.w) {\n"
" V += FLOAT(*off_in0)*FLOAT(*off_in1);\n"
" }\n"
" out[int(gid.y)*s.mat_size.x+int(gid.x)]=M(V);\n"
" }\n"
"}\n"
// matmul_bias: matmul plus biasValue[x] for output column x. The bias is added to the
// FLOAT accumulator before the single narrowing conversion to M (the deconv kernels
// likewise seed their FLOAT4 accumulators with the bias). The previous form,
// M(V)+biasValue[x], narrowed first and then added in M precision, which loses
// accuracy — and can overflow — when M is a reduced-precision type (e.g. half; confirm
// which types the host defines for M/FLOAT).
"kernel void matmul_bias(const device M *in0 [[buffer(0)]],\n"
" const device M *in1 [[buffer(1)]],\n"
" const device M *biasValue [[buffer(2)]],\n"
" device M *out [[buffer(3)]],\n"
" constant matmul_shape &s [[buffer(4)]],\n"
" uint2 gid[[thread_position_in_grid]]) {\n"
" if ((int)gid.x<s.mat_size.x && (int)gid.y<s.mat_size.y) {\n"
" auto off_in0=in0+int(gid.y)*s.in_stride.x;\n"
" auto off_in1=in1+int(gid.x)*s.in_stride.z;\n"
" FLOAT V=0.f;\n"
" for (int i=0; i<s.mat_size.z; i++,off_in0 += s.in_stride.y,off_in1 += s.in_stride.w) {\n"
" V += FLOAT(*off_in0)*FLOAT(*off_in1);\n"
" }\n"
" out[int(gid.y)*s.mat_size.x+int(gid.x)]=M(V+FLOAT(biasValue[(int)(gid.x)]));\n"
" }\n"
"}\n"
;
// Metal (MSL) source for a per-plane scale-and-bias (scale_ca): out = in * scale + bias,
// computed in float and narrowed back to M4. Comments between the literal fragments
// below are stripped before string concatenation and do not alter the shader text.
const char* shader_MetalScale_metal =
"struct scale_shape {\n"
" int size;\n"
" int steps;\n"
" int batch;\n"
" int offset;\n"
"};\n"
// scale_ca: gid.x indexes an M4 element within a plane of `size` elements; gid.y runs
// over steps*batch planes. The parameter index z = gid.y / batch selects the scale at
// scalesbias[z] and the bias at scalesbias[z + offset], i.e. scales and biases share
// one float4 buffer with the biases starting `offset` entries in.
// NOTE(review): `steps` is presumably the channel-slice count — confirm with host code.
"kernel void scale_ca(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant scale_shape &s [[buffer(2)]],\n"
" const device float4 *scalesbias[[buffer(3)]],\n"
" uint2 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= s.size || (int)gid.y >= s.steps*s.batch) return;\n"
" int z=gid.y/s.batch;\n"
" int offset=s.offset;\n"
" float4 scale=scalesbias[z];\n"
" float4 bias=scalesbias[z+offset];\n"
" out[int(gid.y)*s.size+int(gid.x)] =\n"
" (M4)((float4)in[int(gid.y)*s.size+int(gid.x)]*scale+bias);\n"
"}\n"
;
// Metal (MSL) source for transposed convolution: a dense kernel (deconv, M4x4 weights)
// and a depthwise kernel (deconv_depthwise, element-wise M4 weights). Both gather:
// each thread computes one output M4 pixel by finding the kernel taps and input pixels
// that map onto it under the stride and dilation, accumulating in FLOAT4 from the bias.
// Comments between the literal fragments below are stripped before string
// concatenation and do not alter the shader text.
const char* shader_MetalDeconvolution_metal =
// deconv_constants: delta_ky/delta_kx and delta_iy/delta_ix are the per-iteration steps
// of the kernel and input indices in the tap loops. They come from the host.
// NOTE(review): presumably chosen so every step stays stride-aligned — confirm.
"struct deconv_constants {\n"
" int input_width;\n"
" int input_height;\n"
" int input_size;\n"
" int input_slice;\n"
" int output_width;\n"
" int output_height;\n"
" int output_size;\n"
" int output_slice;\n"
" int kernel_x;\n"
" int kernel_y;\n"
" int kernel_size;\n"
" int stride_x;\n"
" int stride_y;\n"
" int pad_x;\n"
" int pad_y;\n"
" int dilation_x;\n"
" int dilation_y;\n"
" int delta_ky;\n"
" int delta_kx;\n"
" int delta_iy;\n"
" int delta_ix;\n"
" int batch;\n"
" conv_activation_type activation;\n"
"};\n"
// deconv: gid.z = o*batch + b (output slice o, batch b).
// The `%` test skips accumulation when the tap (min_ky, min_kx) does not land on a
// stride-aligned input position; the output is then bias + activation only.
// Weights are indexed wt[(o*input_slice + z)*kernel_size + ky*kernel_x + kx];
// inputs in[(z*batch + b)*input_size + iy*input_width + ix];
// outputs out[gid.z*output_size + y*output_width + x].
"kernel void deconv(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant deconv_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
" \n"
" int b=gid.z % cst.batch;\n"
" int o=gid.z/cst.batch;\n"
" FLOAT4 result=FLOAT4(biasTerms[o]);\n"
" int oy=(int)gid.y+cst.pad_y;\n"
" int ox=(int)gid.x+cst.pad_x;\n"
" int max_sy=min((cst.input_height-1)*cst.stride_y,oy/cst.stride_y*cst.stride_y);\n"
" int max_sx=min((cst.input_width-1)*cst.stride_x,ox/cst.stride_x*cst.stride_x);\n"
" int min_ky=UP_DIV(oy-max_sy,cst.dilation_y);\n"
" int min_kx=UP_DIV(ox-max_sx,cst.dilation_x);\n"
" \n"
" if ((oy-min_ky*cst.dilation_y) % cst.stride_y == 0 && (ox-min_kx*cst.dilation_x) % cst.stride_x == 0) {\n"
" int min_sy=max(0,ROUND_UP(oy+cst.dilation_y-cst.kernel_y*cst.dilation_y,cst.stride_y));\n"
" int min_sx=max(0,ROUND_UP(ox+cst.dilation_x-cst.kernel_x*cst.dilation_x,cst.stride_x));\n"
" int max_ky=(oy-min_sy)/cst.dilation_y;\n"
" int max_kx=(ox-min_sx)/cst.dilation_x;\n"
" int min_iy=(oy-max_ky*cst.dilation_y)/cst.stride_y;\n"
" int min_ix=(ox-max_kx*cst.dilation_x)/cst.stride_x;\n"
" \n"
" auto o_wt=wt+o*cst.input_slice*cst.kernel_size;\n"
" auto b_in=in+b*cst.input_size;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" for (auto ky=max_ky,iy=min_iy; ky >= min_ky; ky -= cst.delta_ky,iy += cst.delta_iy) {\n"
" for (auto kx=max_kx,ix=min_ix; kx >= min_kx; kx -= cst.delta_kx,ix += cst.delta_ix) {\n"
" auto wt4=o_wt[z*cst.kernel_size+ky*cst.kernel_x+kx];\n"
" auto in4=b_in[z*cst.input_size*cst.batch+iy*cst.input_width+ix];\n"
" result += FLOAT4(in4*wt4);\n"
" }\n"
" }\n"
" }\n"
" }\n"
" out[(int)gid.z*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x]=activate(M4(result),cst.activation);\n"
"}\n"
// deconv_depthwise: per-slice variant. Output slice oz = gid.z / batch uses
// weights wt[oz*kernel_size + ky*kernel_x + kx] (element-wise M4 multiply) and reads
// only input plane gid.z, i.e. the same (slice, batch) pair as the output.
"kernel void deconv_depthwise(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant deconv_constants& cst [[buffer(2)]],\n"
" const device M4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= cst.output_width || (int)gid.y >= cst.output_height || (int)gid.z >= cst.batch*cst.output_slice) return;\n"
" int oz=(int)gid.z/cst.batch;\n"
" FLOAT4 result=FLOAT4(biasTerms[oz]);\n"
" \n"
" int oy=(int)gid.y+cst.pad_y;\n"
" int ox=(int)gid.x+cst.pad_x;\n"
" int max_sy=min((cst.input_height-1)*cst.stride_y,oy/cst.stride_y*cst.stride_y);\n"
" int max_sx=min((cst.input_width-1)*cst.stride_x,ox/cst.stride_x*cst.stride_x);\n"
" int min_ky=UP_DIV(oy-max_sy,cst.dilation_y);\n"
" int min_kx=UP_DIV(ox-max_sx,cst.dilation_x);\n"
" \n"
" if ((oy-min_ky*cst.dilation_y) % cst.stride_y == 0 && (ox-min_kx*cst.dilation_x) % cst.stride_x == 0) {\n"
" int min_sy=max(0,ROUND_UP(oy+cst.dilation_y-cst.kernel_y*cst.dilation_y,cst.stride_y));\n"
" int min_sx=max(0,ROUND_UP(ox+cst.dilation_x-cst.kernel_x*cst.dilation_x,cst.stride_x));\n"
" int max_ky=(oy-min_sy)/cst.dilation_y;\n"
" int max_kx=(ox-min_sx)/cst.dilation_x;\n"
" int min_iy=(oy-max_ky*cst.dilation_y)/cst.stride_y;\n"
" int min_ix=(ox-max_kx*cst.dilation_x)/cst.stride_x;\n"
" \n"
" auto z_wt=wt+oz*cst.kernel_size;\n"
" auto z_in=in+(int)gid.z*cst.input_size;\n"
" for (auto ky=max_ky,iy=min_iy; ky >= min_ky; ky -= cst.delta_ky,iy += cst.delta_iy) {\n"
" for (auto kx=max_kx,ix=min_ix; kx >= min_kx; kx -= cst.delta_kx,ix += cst.delta_ix) {\n"
" auto wt4=z_wt[ky*cst.kernel_x+kx];\n"
" auto in4=z_in[iy*cst.input_width+ix];\n"
" result += FLOAT4(in4*wt4);\n"
" }\n"
" }\n"
" }\n"
" out[(int)gid.z*cst.output_size+(int)gid.y*cst.output_width+(int)gid.x]=activate(M4(result),cst.activation);\n"
"}\n"
;
// Metal (MSL) source for 2-D max and average pooling, one thread per output M4 pixel
// (gid.x, gid.y) of plane gid.z. The window for output (x, y) starts at
// (x*stride - pad) and spans kernel_width x kernel_height input pixels.
// NOTE(review): there is no batch field; gid.z is checked only against `slice`, so
// the host presumably folds batch into `slice` — confirm.
// Comments between the literal fragments below are stripped before string
// concatenation and do not alter the shader text.
const char* shader_MetalPooling_metal =
"struct pooling_sizes {\n"
" int input_width;\n"
" int input_height;\n"
" int output_width;\n"
" int output_height;\n"
" int slice;\n"
" int kernel_width;\n"
" int kernel_height;\n"
" int stride_width;\n"
" int stride_height;\n"
" int pad_width;\n"
" int pad_height;\n"
"};\n"
// pooling_max: window coordinates are clamped to [0, input_width-1] x
// [0, input_height-1], so padded positions re-read border pixels instead of being
// skipped or treated as -inf. The accumulator starts from the (clamped) window origin.
"kernel void pooling_max(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant pooling_sizes& s [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if (any(gid >= uint3(s.output_width,s.output_height,s.slice))) return;\n"
" \n"
" int off_x=gid.x*s.stride_width-s.pad_width;\n"
" int off_y=gid.y*s.stride_height-s.pad_height;\n"
" int x_max=s.input_width-1;\n"
" int y_max=s.input_height-1;\n"
" int ex=off_x+s.kernel_width;\n"
" int ey=off_y+s.kernel_height;\n"
" \n"
" auto z_in=in+(int)gid.z*s.input_width*s.input_height;\n"
" auto result=M4(z_in[clamp(off_y,0,y_max)*s.input_width+clamp(off_x,0,x_max)]);\n"
" for (int y=off_y; y<ey; y++) {\n"
" auto y_in=z_in+clamp(y,0,y_max)*s.input_width;\n"
" for (int x=off_x; x<ex; x++) {\n"
" result=max(result,y_in[clamp(x,0,x_max)]);\n"
" }\n"
" }\n"
" out[(int)gid.z*s.output_width*s.output_height+(int)gid.y*s.output_width+(int)gid.x]=result;\n"
"}\n"
// pooling_avg: the window is clipped to the input (sx = max(off_x, 0),
// ex = min(off_x + kernel_width, input_width), likewise for y), and the sum is divided
// by the number of in-bounds taps, so padding is excluded from the average. The sum
// is accumulated in FLOAT4. A window with no in-bounds taps leaves the sum at 0, so
// the output is 0.
"kernel void pooling_avg(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant pooling_sizes& s [[buffer(2)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if (any(gid >= uint3(s.output_width,s.output_height,s.slice))) return;\n"
" \n"
" int off_x=gid.x*s.stride_width-s.pad_width;\n"
" int off_y=gid.y*s.stride_height-s.pad_height;\n"
" int sx=off_x+max(0,-off_x);\n"
" int sy=off_y+max(0,-off_y);\n"
" int ex=off_x+min(s.kernel_width,s.input_width-off_x);\n"
" int ey=off_y+min(s.kernel_height,s.input_height-off_y);\n"
" \n"
" FLOAT4 result=0;\n"
" auto z_in=in+(int)gid.z*s.input_width*s.input_height;\n"
" for (int y=sy; y<ey; y++) {\n"
" for (int x=sx; x<ex; x++) {\n"
" result += FLOAT4(z_in[y*s.input_width+x]);\n"
" }\n"
" }\n"
" int count=(ey-sy)*(ex-sx);\n"
" FLOAT4 div=count>0 ? 1.f/count : 1;\n"
" out[(int)gid.z*s.output_width*s.output_height+(int)gid.y*s.output_width+(int)gid.x]=M4(result*div);\n"
"}\n"
;
const char* shader_MetalROIPooling_metal =
"struct ROI_shape {\n"
" int input_width;\n"
" int input_height;\n"
" int input_size;\n"
" int input_batch;\n"
" int output_width;\n"
" int output_height;\n"
" int output_size;\n"
" int batch;\n"
" float spatial_scale;\n"
"};\n"
"kernel void ROI_pooling(const device M4 *in [[buffer(0)]],\n"
" const device M *roi [[buffer(1)]],\n"
" device M4 *out [[buffer(2)]],\n"
" constant ROI_shape &s [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= s.output_width || (int)gid.y >= s.output_height) return;\n"
" \n"
" int ob=gid.z % s.batch;\n"
" int iz=gid.z/s.batch;\n"
" \n"
" auto b_roi=roi+ob*5;\n"
" int ib=int(b_roi[0]);\n"
" int x1=round(float(b_roi[1])*s.spatial_scale);\n"
" int y1=round(float(b_roi[2])*s.spatial_scale);\n"
" int x2=round(float(b_roi[3])*s.spatial_scale);\n"
" int y2=round(float(b_roi[4])*s.spatial_scale);\n"
" \n"
" int roi_w=max(x2-x1+1,1);\n"
" int roi_h=max(y2-y1+1,1);\n"
" float bin_size_w=(float)roi_w/(float)s.output_width;\n"
" float bin_size_h=(float)roi_h/(float)s.output_height;\n"
" \n"
" int w_start=clamp(x1+(int)floor(gid.x*bin_size_w) ,0,s.input_width);\n"
" int w_end=clamp(x1+(int)ceil((gid.x+1)*bin_size_w),0,s.input_width);\n"
" int h_start=clamp(y1+(int)floor(gid.y*bin_size_h) ,0,s.input_height);\n"
" int h_end=clamp(y1+(int)ceil((gid.y+1)*bin_size_h),0,s.input_height);\n"
" \n"
" int is_empty=(h_end <= h_start) || (w_end <= w_start);\n"
" auto z_in=in+(ib+iz*s.input_batch)*s.input_size;\n"
" auto max4=is_empty ? 0 : z_in[h_start*s.input_width+w_start];\n"
" for (int y=h_start; y<h_end; y++) {\n"
" auto y_in=z_in+y*s.input_width;\n"
" for (int x=w_start; x<w_end; x++) {\n"
" max4=max(max4,y_in[x]);\n"
" }\n"
" }\n"
" out[int(gid.z)*s.output_size+int(gid.y)*s.output_width+int(gid.x)]=max4;\n"
"}\n"
;
// Metal source for the 1x1 (pointwise) convolution kernels.
// Layout, as used by the indexing below: input and output are M4 arrays laid out
// [slice][batch][spatial], so element (z, b, i) lives at (z*batch + b)*size + i;
// weights hold one 4x4 block per (output slice, input slice). Every kernel starts
// from the bias, accumulates in FLOAT precision, applies
// `activate(..., cst.activation)` and stores only positions inside the output.
const char* shader_MetalConvolution1x1_metal =
"#define CONV_UNROLL (4)\n"
"#define CONV_UNROLL_L (8)\n"
"struct conv1x1_constants {\n"
" int input_size;\n"
" int input_slice;\n"
" int output_width;\n"
" int output_height;\n"
" int output_size;\n"
" int output_slice;\n"
" int output_channel;\n"
" int batch;\n"
" int block_size;\n"
" conv_activation_type activation;\n"
" float scale_coef;\n"
"};\n"
// conv1x1_g1z4: thread (x, y, z) computes CONV_UNROLL (4) consecutive spatial
// positions starting at 4*x, for output slice y and batch z, with M4x4 weights.
"kernel void conv1x1_g1z4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n"
" \n"
" int rx=gid.x*CONV_UNROLL;\n"
" int uz=gid.y;\n"
" auto xy_wt=wt+uz*cst.input_slice;\n"
" auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n"
" auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.output_size*cst.batch+rx;\n"
" auto biasValue=FLOAT4(biasTerms[uz]);\n"
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
" int computeSize=min(cst.output_size-rx,CONV_UNROLL);\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=*xy_in0;\n"
" auto in41=*(xy_in0+1);\n"
" auto in42=*(xy_in0+2);\n"
" auto in43=*(xy_in0+3);\n"
" auto w=xy_wt[z];\n"
" \n"
" result0 += FLOAT4(in40*w);\n"
" result1 += FLOAT4(in41*w);\n"
" result2 += FLOAT4(in42*w);\n"
" result3 += FLOAT4(in43*w);\n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" \n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
" if (computeSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" if (computeSize>2) {xy_out[2]=activate(M4(result2),cst.activation); }\n"
" if (computeSize>3) {xy_out[3]=activate(M4(result3),cst.activation); }\n"
"}\n"
// conv1x1_g1z4_w8: same tiling as conv1x1_g1z4, with int8 weights (char4x4)
// dequantized on the fly. cst.block_size is the NUMBER of quantization blocks along
// the input slices; each block spans `block` slices and owns one (scale, bias) M4
// pair in dequantScale, both divided by cst.scale_coef. Per weight row i:
// w = q * scale[i] + bias[i].
"kernel void conv1x1_g1z4_w8(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device MNN::char4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" const device M4 *dequantScale [[buffer(5)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n"
" int rx=gid.x*CONV_UNROLL;\n"
" int uz=gid.y;\n"
" auto xy_wt=wt+uz*cst.input_slice;\n"
" auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n"
" auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.output_size*cst.batch+rx;\n"
" auto biasValue=FLOAT4(biasTerms[uz]);\n"
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
" int computeSize=min(cst.output_size-rx,CONV_UNROLL);\n"
" int block=(cst.input_slice+cst.block_size-1)/cst.block_size;\n"
" for (int bi=0; bi<cst.block_size; ++bi) {\n"
" FLOAT4 bs0=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+0])/(FLOAT)cst.scale_coef;\n"
" FLOAT4 bs1=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+1])/(FLOAT)cst.scale_coef;\n"
" FLOAT4 scale=bs0;\n"
" FLOAT4 dequant_bias=bs1;\n"
" int zmin=bi*block;\n"
" int zmax=min(zmin+block,cst.input_slice);\n"
" for (int z=zmin; z<zmax; z++) {\n"
" auto in40=(FLOAT4)*xy_in0;\n"
" auto in41=computeSize>1 ? (FLOAT4)*(xy_in0+1) : (FLOAT4)0.0;\n"
" auto in42=computeSize>2 ? (FLOAT4)*(xy_in0+2) : (FLOAT4)0.0;\n"
" auto in43=computeSize>3 ? (FLOAT4)*(xy_in0+3) : (FLOAT4)0.0;\n"
" auto w=xy_wt[z];\n"
" FLOAT4x4 w_fp32=FLOAT4x4(FLOAT4(w[0]),FLOAT4(w[1]),FLOAT4(w[2]),FLOAT4(w[3]));\n"
" FLOAT4x4 w_dequant;\n"
" for (int i=0; i<4; ++i) {\n"
" w_dequant[i]=w_fp32[i]*scale[i]+dequant_bias[i];\n"
" }\n"
" result0 += FLOAT4(in40*w_dequant);\n"
" result1 += FLOAT4(in41*w_dequant);\n"
" result2 += FLOAT4(in42*w_dequant);\n"
" result3 += FLOAT4(in43*w_dequant);\n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" }\n"
" /* true */ \n"
" xy_out[0]=activate(M4(result0),cst.activation);\n"
" if (computeSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" if (computeSize>2) {xy_out[2]=activate(M4(result2),cst.activation); }\n"
" if (computeSize>3) {xy_out[3]=activate(M4(result3),cst.activation); }\n"
"}\n"
// conv1x1_g1z4_w4: same as the w8 variant, but each weight row is a uchar2 holding
// four 4-bit values (high nibble first), each mapped to [-8, 7] by subtracting 8.
// Input reads past the tail (computeSize < 4) are guarded exactly as in the w8
// kernel, so the last tile never reads beyond the valid spatial range.
"kernel void conv1x1_g1z4_w4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device MNN::uchar4x2 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" const device M4 *dequantScale [[buffer(5)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*CONV_UNROLL >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n"
" int rx=gid.x*CONV_UNROLL;\n"
" int uz=gid.y;\n"
" auto xy_wt=wt+uz*cst.input_slice;\n"
" auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n"
" auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.output_size*cst.batch+rx;\n"
" auto biasValue=FLOAT4(biasTerms[uz]);\n"
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
" int computeSize=min(cst.output_size-rx,CONV_UNROLL);\n"
" int block=(cst.input_slice+cst.block_size-1)/cst.block_size;\n"
" for (int bi=0; bi<cst.block_size; ++bi) {\n"
" FLOAT4 scale=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+0])/(FLOAT)cst.scale_coef;\n"
" FLOAT4 dequant_bias=FLOAT4(dequantScale[2*(uz*cst.block_size+bi)+1])/(FLOAT)cst.scale_coef;\n"
" int zmin=bi*block;\n"
" int zmax=min(zmin+block,cst.input_slice);\n"
" for (int z=zmin; z<zmax; z++) {\n"
" auto in40=(FLOAT4)*xy_in0;\n"
" auto in41=computeSize>1 ? (FLOAT4)*(xy_in0+1) : (FLOAT4)0.0;\n"
" auto in42=computeSize>2 ? (FLOAT4)*(xy_in0+2) : (FLOAT4)0.0;\n"
" auto in43=computeSize>3 ? (FLOAT4)*(xy_in0+3) : (FLOAT4)0.0;\n"
" MNN::uchar4x2 w_int4=xy_wt[z];\n"
" /* weight int4->float */\n"
" FLOAT4x4 w_dequant;\n"
" for (int i=0; i<4; ++i) {\n"
" FLOAT4 w4=FLOAT4((float)(w_int4[i][0] >> 4)-8,(float)(w_int4[i][0] & 15)-8,(float)(w_int4[i][1] >> 4)-8,(float)(w_int4[i][1] & 15)-8);\n"
" FLOAT4 res=w4*scale[i]+dequant_bias[i];\n"
" w_dequant[i]=res;\n"
" }\n"
" result0 += FLOAT4(in40*w_dequant);\n"
" result1 += FLOAT4(in41*w_dequant);\n"
" result2 += FLOAT4(in42*w_dequant);\n"
" result3 += FLOAT4(in43*w_dequant);\n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" }\n"
" \n"
" /* true */ \n"
" xy_out[0]=activate(M4(result0),cst.activation);\n"
" if (computeSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" if (computeSize>2) {xy_out[2]=activate(M4(result2),cst.activation); }\n"
" if (computeSize>3) {xy_out[3]=activate(M4(result3),cst.activation); }\n"
"}\n"
// conv1x1_g1z8: like conv1x1_g1z4 but CONV_UNROLL_L (8) positions per thread.
"kernel void conv1x1_g1z8(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*CONV_UNROLL_L >= cst.output_size || (int)gid.y >= cst.output_slice || (int)gid.z >= cst.batch) return;\n"
" int rx=gid.x*CONV_UNROLL_L;\n"
" int uz=gid.y;\n"
" auto xy_wt=wt+uz*cst.input_slice;\n"
" auto xy_in0=in+(int)gid.z*cst.input_size+rx+0;\n"
" auto xy_out=out+(int)gid.z*cst.output_size+uz*cst.batch*cst.output_size+rx;\n"
" auto biasValue=FLOAT4(biasTerms[uz]);\n"
" FLOAT4 result0=biasValue,result1=biasValue,result2=biasValue,result3=biasValue;\n"
" FLOAT4 result4=biasValue,result5=biasValue,result6=biasValue,result7=biasValue;\n"
" int computeSize=min(cst.output_size-rx,CONV_UNROLL_L);\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=xy_in0[0];\n"
" auto in41=xy_in0[1];\n"
" auto in42=xy_in0[2];\n"
" auto in43=xy_in0[3];\n"
" auto in44=xy_in0[4];\n"
" auto in45=xy_in0[5];\n"
" auto in46=xy_in0[6];\n"
" auto in47=xy_in0[7];\n"
" auto w=xy_wt[z];\n"
" result0 += FLOAT4(in40*w);\n"
" result1 += FLOAT4(in41*w);\n"
" result2 += FLOAT4(in42*w);\n"
" result3 += FLOAT4(in43*w);\n"
" result4 += FLOAT4(in44*w);\n"
" result5 += FLOAT4(in45*w);\n"
" result6 += FLOAT4(in46*w);\n"
" result7 += FLOAT4(in47*w);\n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
" if (computeSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" if (computeSize>2) {xy_out[2]=activate(M4(result2),cst.activation); }\n"
" if (computeSize>3) {xy_out[3]=activate(M4(result3),cst.activation); }\n"
" if (computeSize>4) {xy_out[4]=activate(M4(result4),cst.activation); }\n"
" if (computeSize>5) {xy_out[5]=activate(M4(result5),cst.activation); }\n"
" if (computeSize>6) {xy_out[6]=activate(M4(result6),cst.activation); }\n"
" if (computeSize>7) {xy_out[7]=activate(M4(result7),cst.activation); }\n"
"}\n"
// conv1x1_w4h4: 16 consecutive positions along cst.output_width per thread;
// gid.y = idx_c*batch + idx_b. idx_h is fixed at 0, so this presumably assumes the
// host passes the flattened spatial size as output_width — confirm.
"kernel void conv1x1_w4h4(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x*16 >= cst.output_width || (int)gid.y >= cst.batch*cst.output_slice) return;\n"
" int idx_w=gid.x << 4;\n"
" int idx_h=0;\n"
" int idx_c=gid.y/cst.batch;\n"
" int idx_b=gid.y % cst.batch;\n"
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
" auto xy_in0=in+(int)idx_b*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_size+idx_c*cst.output_size*cst.batch+idx_h*cst.output_width+idx_w;\n"
" auto biasValue=FLOAT4(biasTerms[idx_c]);\n"
" FLOAT4 result00=biasValue,result01=biasValue,result02=biasValue,result03=biasValue;\n"
" FLOAT4 result10=biasValue,result11=biasValue,result12=biasValue,result13=biasValue;\n"
" FLOAT4 result20=biasValue,result21=biasValue,result22=biasValue,result23=biasValue;\n"
" FLOAT4 result30=biasValue,result31=biasValue,result32=biasValue,result33=biasValue;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in00=xy_in0[0];\n"
" auto in01=xy_in0[1];\n"
" auto in02=xy_in0[2];\n"
" auto in03=xy_in0[3];\n"
" auto in10=xy_in0[4];\n"
" auto in11=xy_in0[5];\n"
" auto in12=xy_in0[6];\n"
" auto in13=xy_in0[7];\n"
" \n"
" auto in20=xy_in0[8];\n"
" auto in21=xy_in0[9];\n"
" auto in22=xy_in0[10];\n"
" auto in23=xy_in0[11];\n"
" auto in30=xy_in0[12];\n"
" auto in31=xy_in0[13];\n"
" auto in32=xy_in0[14];\n"
" auto in33=xy_in0[15];\n"
" auto w=xy_wt[z];\n"
" result00 += FLOAT4(in00*w);\n"
" result01 += FLOAT4(in01*w);\n"
" result02 += FLOAT4(in02*w);\n"
" result03 += FLOAT4(in03*w);\n"
" result10 += FLOAT4(in10*w);\n"
" result11 += FLOAT4(in11*w);\n"
" result12 += FLOAT4(in12*w);\n"
" result13 += FLOAT4(in13*w);\n"
" \n"
" result20 += FLOAT4(in20*w);\n"
" result21 += FLOAT4(in21*w);\n"
" result22 += FLOAT4(in22*w);\n"
" result23 += FLOAT4(in23*w);\n"
" result30 += FLOAT4(in30*w);\n"
" result31 += FLOAT4(in31*w);\n"
" result32 += FLOAT4(in32*w);\n"
" result33 += FLOAT4(in33*w);\n"
" \n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" int widthSize=min(cst.output_width-idx_w,16);\n"
" /* true */ *xy_out=activate(M4(result00),cst.activation);\n"
" if (widthSize>1) {xy_out[1]=activate(M4(result01),cst.activation); }\n"
" if (widthSize>2) {xy_out[2]=activate(M4(result02),cst.activation); }\n"
" if (widthSize>3) {xy_out[3]=activate(M4(result03),cst.activation); }\n"
" if (widthSize>4) {xy_out[4]=activate(M4(result10),cst.activation); }\n"
" if (widthSize>5) {xy_out[5]=activate(M4(result11),cst.activation); }\n"
" if (widthSize>6) {xy_out[6]=activate(M4(result12),cst.activation); }\n"
" if (widthSize>7) {xy_out[7]=activate(M4(result13),cst.activation); }\n"
" if (widthSize>8) {xy_out[8]=activate(M4(result20),cst.activation); }\n"
" if (widthSize>9) {xy_out[9]=activate(M4(result21),cst.activation); }\n"
" if (widthSize>10) {xy_out[10]=activate(M4(result22),cst.activation); }\n"
" if (widthSize>11) {xy_out[11]=activate(M4(result23),cst.activation); }\n"
" if (widthSize>12) {xy_out[12]=activate(M4(result30),cst.activation); }\n"
" if (widthSize>13) {xy_out[13]=activate(M4(result31),cst.activation); }\n"
" if (widthSize>14) {xy_out[14]=activate(M4(result32),cst.activation); }\n"
" if (widthSize>15) {xy_out[15]=activate(M4(result33),cst.activation); }\n"
"}\n"
// conv1x1_w2c2: 2 positions x 2 output slices per thread.
// gid.y = idx_b*channel_pack + pair, with channel_pack = ceil(output_channel/8)
// slice pairs. The grid bound therefore is batch*channel_pack. The former bound
// (gid.y*2 >= batch*output_slice) rejected the last batch's trailing pairs whenever
// output_slice was odd and batch > 1, leaving those outputs unwritten.
// NOTE(review): bias/weights for slice idx_c+1 are read even when output_slice is
// odd; this presumably relies on them being padded to an even slice count.
"kernel void conv1x1_w2c2(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" int channel_pack=(cst.output_channel+7) >> 3;\n"
" if ((int)gid.x*2 >= cst.output_width || (int)gid.y >= cst.batch*channel_pack) return;\n"
" int idx_w=gid.x << 1;\n"
" int idx_h=0;\n"
" int idx_c=(gid.y % channel_pack) << 1;\n"
" int idx_b=gid.y/channel_pack;\n"
" \n"
" if(idx_b >= cst.batch || idx_c >= cst.output_slice) return;\n"
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
" auto xy_in0=in+(int)idx_b*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_size+idx_c*cst.output_size*cst.batch+idx_h*cst.output_width+idx_w;\n"
" auto biasValue0=FLOAT4(biasTerms[idx_c]);\n"
" auto biasValue1=FLOAT4(biasTerms[idx_c+1]);\n"
" FLOAT4 result0=biasValue0,result1=biasValue0;\n"
" FLOAT4 result4=biasValue1,result5=biasValue1;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=xy_in0[0];\n"
" auto in41=xy_in0[1];\n"
" auto w0=xy_wt[z];\n"
" auto w1=xy_wt[cst.input_slice+z];\n"
" result0 += FLOAT4(in40*w0);\n"
" result1 += FLOAT4(in41*w0);\n"
" result4 += FLOAT4(in40*w1);\n"
" result5 += FLOAT4(in41*w1);\n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" int widthSize=min(cst.output_width-idx_w,2);\n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
" if (widthSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" \n"
" int channelSize=min(cst.output_slice-idx_c,2);\n"
" if(channelSize>1) {\n"
" /* true */ {xy_out[cst.output_size*cst.batch +0]=activate(M4(result4),cst.activation); }\n"
" if (widthSize>1) {xy_out[cst.output_size*cst.batch +1]=activate(M4(result5),cst.activation); }\n"
" }\n"
"}\n"
// conv1x1_w4c2: 4 positions x 2 output slices per thread; same thread decomposition
// (and corrected grid bound) as conv1x1_w2c2.
"kernel void conv1x1_w4c2(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant conv1x1_constants& cst [[buffer(2)]],\n"
" const device M4x4 *wt [[buffer(3)]],\n"
" const device M4 *biasTerms [[buffer(4)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" int channel_pack=(cst.output_channel+7) >> 3;\n"
" if ((int)gid.x*4 >= cst.output_width || (int)gid.y >= cst.batch*channel_pack) return;\n"
" int idx_w=gid.x << 2;\n"
" int idx_h=0;\n"
" int idx_c=(gid.y % channel_pack) << 1;\n"
" int idx_b=gid.y/channel_pack;\n"
" if(idx_b >= cst.batch || idx_c >= cst.output_slice) return;\n"
" auto xy_wt=wt+idx_c*cst.input_slice;\n"
" auto xy_in0=in+(int)idx_b*cst.input_size+idx_h*cst.output_width+idx_w;\n"
" auto xy_out=out+(int)idx_b*cst.output_size+idx_c*cst.output_size*cst.batch+idx_h*cst.output_width+idx_w;\n"
" auto biasValue0=FLOAT4(biasTerms[idx_c]);\n"
" auto biasValue1=FLOAT4(biasTerms[idx_c+1]);\n"
" FLOAT4 result0=biasValue0,result1=biasValue0;\n"
" FLOAT4 result4=biasValue0,result5=biasValue0;\n"
" FLOAT4 result2=biasValue1,result3=biasValue1;\n"
" FLOAT4 result6=biasValue1,result7=biasValue1;\n"
" for (auto z=0; z<cst.input_slice; z++) {\n"
" auto in40=xy_in0[0];\n"
" auto in41=xy_in0[1];\n"
" auto in44=xy_in0[2];\n"
" auto in45=xy_in0[3];\n"
" auto w0=xy_wt[z];\n"
" auto w1=xy_wt[cst.input_slice+z];\n"
" result0 += FLOAT4(in40*w0);\n"
" result1 += FLOAT4(in41*w0);\n"
" result4 += FLOAT4(in44*w0);\n"
" result5 += FLOAT4(in45*w0);\n"
" result2 += FLOAT4(in40*w1);\n"
" result3 += FLOAT4(in41*w1);\n"
" result6 += FLOAT4(in44*w1);\n"
" result7 += FLOAT4(in45*w1);\n"
" xy_in0 += cst.input_size*cst.batch;\n"
" }\n"
" int widthSize=min(cst.output_width-idx_w,4);\n"
" /* true */ *xy_out=activate(M4(result0),cst.activation);\n"
" if (widthSize>1) {xy_out[1]=activate(M4(result1),cst.activation); }\n"
" if (widthSize>2) {xy_out[2]=activate(M4(result4),cst.activation); }\n"
" if (widthSize>3) {xy_out[3]=activate(M4(result5),cst.activation); }\n"
" \n"
" int channelSize=min(cst.output_slice-idx_c,2);\n"
" if(channelSize>1) {\n"
" /* true */ xy_out[cst.output_size*cst.batch]=activate(M4(result2),cst.activation);\n"
" if (widthSize>1) {xy_out[cst.output_size*cst.batch +1]=activate(M4(result3),cst.activation); }\n"
" if (widthSize>2) {xy_out[cst.output_size*cst.batch +2]=activate(M4(result6),cst.activation); }\n"
" if (widthSize>3) {xy_out[cst.output_size*cst.batch +3]=activate(M4(result7),cst.activation); }\n"
" }\n"
"}\n"
;
// Metal source for `matmul4x4`, a product over 4x4 tiles.
// For thread (x, y, z), z selects the group:
//   input rows  iy = x + z*output_width  (cst.multi_length M4x4 tiles per row)
//   weight rows ky = y + z*output_height (cst.multi_length M4x4 tiles per row)
// result_j = sum_k W[ky][k] * I[iy][k][j] for j = 0..3. These are written to 4 output
// rows (4*y .. 4*y+3) at column iy, each row being output_width*group M4 wide.
// gid.z is not bounds-checked; presumably the grid is exactly `group` deep — confirm.
const char* shader_MetalConvolutionGEMM_metal =
"struct matmul4x4_const {\n"
" int output_width;\n"
" int output_height;\n"
" int multi_length;\n"
" int group;\n"
"};\n"
"template <typename IType,typename OType>\n"
"static inline void matmul4x4_template(const device IType *in,\n"
" device OType *out,\n"
" const device IType *kt,\n"
" constant matmul4x4_const &cst,\n"
" uint3 gid) {\n"
" if ((int)gid.x<cst.output_width && (int)gid.y<cst.output_height) {\n"
" auto ky=(int)gid.y+(int)gid.z*cst.output_height;\n"
" auto iy=(int)gid.x+(int)gid.z*cst.output_width;\n"
" auto off_in=in+iy*cst.multi_length;\n"
" auto off_wt=kt+ky*cst.multi_length;\n"
" auto off_out=out+iy+4*(int)gid.y*cst.output_width*cst.group;\n"
" \n"
" FLOAT4 result0=0,result1=0,result2=0,result3=0;\n"
" for (int k=0; k<cst.multi_length; ++k) {\n"
" auto w4x4=off_wt[k];\n"
" auto i4x4=off_in[k];\n"
" result0 += FLOAT4(w4x4*i4x4[0]);\n"
" result1 += FLOAT4(w4x4*i4x4[1]);\n"
" result2 += FLOAT4(w4x4*i4x4[2]);\n"
" result3 += FLOAT4(w4x4*i4x4[3]);\n"
" }\n"
// Four consecutive output rows, stride output_width*group M4 elements.
" *off_out=OType(result0); off_out += cst.output_width*cst.group;\n"
" *off_out=OType(result1); off_out += cst.output_width*cst.group;\n"
" *off_out=OType(result2); off_out += cst.output_width*cst.group;\n"
" *off_out=OType(result3);\n"
" }\n"
"}\n"
"kernel void matmul4x4(const device M4x4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" const device M4x4 *kt [[buffer(2)]],\n"
" constant matmul4x4_const &cst [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" matmul4x4_template<M4x4,M4>(in,out,kt,cst,gid);\n"
"}\n"
;
// Metal source for the resize kernels (nearest, bilinear, bicubic).
// Thread (x, y, z) produces output pixel (x, y) of plane z, where z ranges over
// c.sliceMap planes of c.input_size / c.output_size M4 elements each. The source
// coordinate is srcX = x*s.x + s.y, srcY = y*s.z + s.w, with the scale/offset pairs
// supplied by the host in `s`. All three kernels clamp their sample indices to the
// input plane, and all three derive the integer base index with floor().
const char* shader_MetalResize_metal =
"struct resize_shape {\n"
" int input_width;\n"
" int input_height;\n"
" int input_size;\n"
" int output_width;\n"
" int output_height;\n"
" int output_size;\n"
" int sliceMap;\n"
"};\n"
// Nearest neighbour: floor of the source coordinate, clamped like the other two
// kernels. Nothing else bounds srcX/srcY, so an offset or scale mapping an edge
// pixel just outside the input previously read out of bounds.
"kernel void resize_nearest(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant resize_shape &c [[buffer(2)]],\n"
" constant float4& s [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= c.output_width || (int)gid.y >= c.output_height || (int)gid.z >= c.sliceMap) return;\n"
" \n"
" float srcX=gid.x*s.x+s.y,srcY=gid.y*s.z+s.w;\n"
" int left=clamp(int(floor(srcX)),0,c.input_width-1);\n"
" int top=clamp(int(floor(srcY)),0,c.input_height-1);\n"
" \n"
" auto in_z=in+gid.z*c.input_size;\n"
" auto in_top=in_z+top*c.input_width;\n"
" out[int(gid.z)*c.output_size+int(gid.y)*c.output_width+int(gid.x)]=in_top[left];\n"
"}\n"
// Bilinear: weights come from the unclamped fractional position; the four taps are
// clamped to the input plane.
"kernel void resize_bilinear(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant resize_shape &c [[buffer(2)]],\n"
" constant float4& s [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= c.output_width || (int)gid.y >= c.output_height || (int)gid.z >= c.sliceMap) return;\n"
" \n"
" float srcX=gid.x*s.x+s.y,srcY=gid.y*s.z+s.w;\n"
" int srcXInt=int(floor(srcX));\n"
" int srcYInt=int(floor(srcY));\n"
" int left=clamp(srcXInt,0,c.input_width-1);\n"
" int right=clamp(srcXInt+1,0,c.input_width-1);\n"
" int top=clamp(srcYInt,0,c.input_height-1);\n"
" int bottom=clamp(srcYInt+1,0,c.input_height-1);\n"
" float x2_factor=srcX-float(srcXInt);\n"
" float y2_factor=srcY-float(srcYInt);\n"
" float x1_factor=1-x2_factor;\n"
" float y1_factor=1-y2_factor;\n"
" \n"
" auto in_z=in+gid.z*c.input_size;\n"
" auto in_top=in_z+top*c.input_width;\n"
" auto in_bottom=in_z+bottom*c.input_width;\n"
" auto tl=float4(in_top[left])*x1_factor*y1_factor;\n"
" auto tr=float4(in_top[right])*x2_factor*y1_factor;\n"
" auto bl=float4(in_bottom[left])*x1_factor*y2_factor;\n"
" auto br=float4(in_bottom[right])*x2_factor*y2_factor;\n"
" out[int(gid.z)*c.output_size+int(gid.y)*c.output_width+int(gid.x)]=M4(tl+tr+bl+br);\n"
"}\n"
// Cubic polynomial through samples A, B, C, D evaluated at `factor` between B and C
// (returns B at factor 0).
"static inline float4 resize_cubic_interpolation(float4 A,float4 B,float4 C,float4 D,float factor) {\n"
" float4 a=(B-C)+0.5f*(B-A)+(D-C)*0.5f;\n"
" float4 b=C-((B-A)+(B-C))-(B+D)*0.5f;\n"
" float4 c=(C-A)*0.5f;\n"
" float4 d=B;\n"
" return ((a*factor+b)*factor+c)*factor+d;\n"
"}\n"
// Bicubic: 4x4 taps around floor(x), floor(y). The base index must use floor(), not
// int() truncation, to agree with x_factor = x - floor(x). For x in (-1, 0),
// truncation picked base 0 while the factor assumed base -1, shifting the taps.
"kernel void resize_cubic(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant resize_shape &c [[buffer(2)]],\n"
" constant float4& s [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) {\n"
" if ((int)gid.x >= c.output_width || (int)gid.y >= c.output_height || (int)gid.z >= c.sliceMap) return;\n"
" float x=gid.x*s.x+s.y,y=gid.y*s.z+s.w;\n"
" \n"
" float x_factor=x-floor(x);\n"
" float y_factor=y-floor(y);\n"
" int ix=int(floor(x));\n"
" int iy=int(floor(y));\n"
" \n"
" int4 xp=int4(ix-1,ix+0,ix+1,ix+2);\n"
" xp=clamp(xp,0,c.input_width-1);\n"
" \n"
" int4 yp=int4(iy-1,iy+0,iy+1,iy+2);\n"
" yp=clamp(yp,0,c.input_height-1);\n"
" \n"
" auto in_z=in+gid.z*c.input_size;\n"
" float4x4 ABCD;\n"
" for (int i=0; i<4; i++) {\n"
" auto in_y=in_z+yp[i]*c.input_width;\n"
" float4 A=float4(in_y[xp[0]]);\n"
" float4 B=float4(in_y[xp[1]]);\n"
" float4 C=float4(in_y[xp[2]]);\n"
" float4 D=float4(in_y[xp[3]]);\n"
" ABCD[i]=resize_cubic_interpolation(A,B,C,D,x_factor);\n"
" }\n"
" \n"
" auto val=M4(resize_cubic_interpolation(ABCD[0],ABCD[1],ABCD[2],ABCD[3],y_factor));\n"
" out[int(gid.z)*c.output_size+int(gid.y)*c.output_width+int(gid.x)]=val;\n"
"}\n"
;
// Metal source for PReLU.
// - prelu: a single slope for every element. out = v where the sign bit of v is clear,
//   slope*v where it is set. It takes no size argument and does no bounds check, so
//   the host presumably dispatches exactly one thread per M4 element — confirm.
// - prelu_slopes: one float4 of slopes per channel slice. Thread (x, y, z) =
//   (position within s.size, slice, batch), data laid out [slice][batch][size].
const char* shader_MetalPReLU_metal =
"struct prelu_shape {\n"
" int size;\n"
" int slice;\n"
" int batch;\n"
"};\n"
"kernel void prelu(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" constant float &slope [[buffer(2)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" auto v4=in[int(gid)];\n"
" out[int(gid)]=select(v4,M4(slope)*v4,signbit(v4));\n"
"}\n"
// gid.z is checked as well. A thread with gid.z >= batch would map (via
// z = gid.z + gid.y*batch) onto an element of a later slice and overwrite it using
// slope[gid.y], or run past the end of the buffer for the last slice.
"kernel void prelu_slopes(const device M4 *in [[buffer(0)]],\n"
" device M4 *out [[buffer(1)]],\n"
" const device float4 *slope [[buffer(2)]],\n"
" constant prelu_shape& s [[buffer(3)]],\n"
" uint3 gid [[thread_position_in_grid]]) { // size,slice,batch\n"
" if ((int)gid.x >= s.size || (int)gid.y >= s.slice || (int)gid.z >= s.batch) return;\n"
" \n"
" int z=gid.z+gid.y*s.batch;\n"
" auto v4=in[z*s.size+int(gid.x)];\n"
" out[z*s.size+int(gid.x)]=select(v4,M4(slope[int(gid.y)])*v4,signbit(v4));\n"
"}\n"
;
// Common Metal prelude. It contains:
//   - `using namespace metal`;
//   - the UP_DIV / ROUND_UP helpers;
//   - the storage (M*) and accumulation (FLOAT*) typedefs;
//   - integer-limit macros;
//   - small integer / fixed-point helpers and packed int8/int4 matrix structs in
//     namespace MNN.
// NOTE(review): the other shader strings in this file use M4, FLOAT4, UP_DIV and MNN::*
// without defining them. This string is therefore presumably prepended to them by the
// host before compilation — confirm.
const char* shader_MetalDefine_metal =
"using namespace metal;\n"
"// –––––––––––––––––––––––––––––––––––––––––––––––––––\n"
"// Macro\n"
"// –––––––––––––––––––––––––––––––––––––––––––––––––––\n"
"#define UP_DIV(x,y) ( ((x)+(y)-1)/(y) )\n"
"#define ROUND_UP(x,y) ( ((x)+(y)-1)/(y)*(y) )\n"
"// whether computer with float32 when store with float16\n"
"#define MNN_METAL_FLOAT32_COMPUTER 1 //\n"
// Storage types M*: float when MNN_METAL_FULL_PRECISION is nonzero, half otherwise.
// MNN_METAL_FULL_PRECISION is not defined in this string; presumably the host supplies
// it as a compile-time macro — confirm.
"#if MNN_METAL_FULL_PRECISION\n"
"typedef float M;\n"
"typedef float2 M2;\n"
"typedef float3 M3;\n"
"typedef float4 M4;\n"
"typedef float2x2 M2x2;\n"
"typedef float2x3 M2x3;\n"
"typedef float2x4 M2x4;\n"
"typedef float3x2 M3x2;\n"
"typedef float3x3 M3x3;\n"
"typedef float3x4 M3x4;\n"
"typedef float4x2 M4x2;\n"
"typedef float4x3 M4x3;\n"
"typedef float4x4 M4x4;\n"
"#else\n"
"typedef half M;\n"
"typedef half2 M2;\n"
"typedef half3 M3;\n"
"typedef half4 M4;\n"
"typedef half2x2 M2x2;\n"
"typedef half2x3 M2x3;\n"
"typedef half2x4 M2x4;\n"
"typedef half3x2 M3x2;\n"
"typedef half3x3 M3x3;\n"
"typedef half3x4 M3x4;\n"
"typedef half4x2 M4x2;\n"
"typedef half4x3 M4x3;\n"
"typedef half4x4 M4x4;\n"
"#endif\n"
// Accumulation types FLOAT*: MNN_METAL_FLOAT32_COMPUTER is hard-wired to 1 above, so
// these are always float. The half branch is reachable only by editing that define.
"#if MNN_METAL_FLOAT32_COMPUTER\n"
"typedef float FLOAT;\n"
"typedef float2 FLOAT2;\n"
"typedef float3 FLOAT3;\n"
"typedef float4 FLOAT4;\n"
"typedef float2x2 FLOAT2x2;\n"
"typedef float2x3 FLOAT2x3;\n"
"typedef float2x4 FLOAT2x4;\n"
"typedef float3x2 FLOAT3x2;\n"
"typedef float3x3 FLOAT3x3;\n"
"typedef float3x4 FLOAT3x4;\n"
"typedef float4x2 FLOAT4x2;\n"
"typedef float4x3 FLOAT4x3;\n"
"typedef float4x4 FLOAT4x4;\n"
"#else\n"
"typedef half FLOAT;\n"
"typedef half2 FLOAT2;\n"
"typedef half3 FLOAT3;\n"
"typedef half4 FLOAT4;\n"
"typedef half2x2 FLOAT2x2;\n"
"typedef half2x3 FLOAT2x3;\n"
"typedef half2x4 FLOAT2x4;\n"
"typedef half3x2 FLOAT3x2;\n"
"typedef half3x3 FLOAT3x3;\n"
"typedef half3x4 FLOAT3x4;\n"
"typedef half4x2 FLOAT4x2;\n"
"typedef half4x3 FLOAT4x3;\n"
"typedef half4x4 FLOAT4x4;\n"
"#endif\n"
"namespace MNN {\n"
" \n"
" // –––––––––––––––––––––––––––––––––––––––––––––––––––\n"
" // Number Limit\n"
" // –––––––––––––––––––––––––––––––––––––––––––––––––––\n"
// These are preprocessor macros, so despite appearing inside `namespace MNN` they are
// global to the rest of the compiled shader text.
"#define INT8_MAX 127\n"
"#define INT8_MIN -128\n"
"#define INT16_MAX 32767\n"
"#define INT16_MIN -32768\n"
"#define INT32_MAX 2147483647\n"
"#define INT32_MIN -2147483648\n"
"#define UINT8_MAX 255\n"
"#define UINT16_MAX 65535\n"
"#define UINT32_MAX 4294967295U\n"
" \n"
// num_limits<T>::max()/min(): integer range of T, returned as int (0/0 for
// unspecialized T).
"    template<typename T> struct num_limits {\n"
"        static int max() { return 0; };\n"
"        static int min() { return 0; };\n"
"    };\n"
" template<> struct num_limits<char> {\n"
" static int max() { return INT8_MAX; };\n"
" static int min() { return INT8_MIN; };\n"
" };\n"
" template<> struct num_limits<uchar> {\n"
" static int max() { return UINT8_MAX; };\n"
" static int min() { return 0; };\n"
" };\n"
" template<> struct num_limits<short> {\n"
" static int max() { return INT16_MAX; };\n"
" static int min() { return INT16_MIN; };\n"
" };\n"
" template<> struct num_limits<ushort> {\n"
" static int max() { return UINT16_MAX; };\n"
" static int min() { return 0; };\n"
" };\n"
" template<> struct num_limits<int> {\n"
" static int max() { return INT32_MAX; };\n"
" static int min() { return INT32_MIN; };\n"
" };\n"
// NOTE(review): max() returns int, and UINT32_MAX (4294967295U) does not fit in int.
// The returned value is not the uint maximum (typically it comes out as -1). Confirm no
// caller relies on num_limits<uint>::max().
" template<> struct num_limits<uint> {\n"
" static int max() { return UINT32_MAX; };\n"
" static int min() { return 0; };\n"
" };\n"
" \n"
" // –––––––––––––––––––––––––––––––––––––––––––––––––––\n"
" // Function\n"
" // –––––––––––––––––––––––––––––––––––––––––––––––––––\n"
// Integer dot product of two int4 vectors.
" inline int dot(int4 i4,int4 w4) {\n"
" return i4[0]*w4[0]+i4[1]*w4[1]+i4[2]*w4[2]+i4[3]*w4[3];\n"
" }\n"
" \n"
// NOTE(review): despite the name, this returns mulhi(a,b)*2, i.e. the doubled high half
// of the product, with no rounding and no saturation. Confirm this matches the
// host-side quantization reference.
" template <typename T>\n"
" inline T saturate_round_x2_high_mul(T a,int b) {\n"
" return mulhi(a,b)*2;\n"
" }\n"
" \n"
// x / 2^exponent rounded to nearest, ties away from zero. The (x<0) term raises the
// threshold for negative x. Assumes 0 <= exponent < 32.
" template <typename T>\n"
" inline T round_divide_by_pot(T x,int exponent) {\n"
" int mask=(1 << exponent)-1;\n"
" T remainder=x & mask;\n"
" T threshold=(mask >> 1)+T(x<0);\n"
" return (x >> exponent)+T(remainder>threshold);\n"
" }\n"
" \n"
" // –––––––––––––––––––––––––––––––––––––––––––––––––––\n"
" // Typedef\n"
" // –––––––––––––––––––––––––––––––––––––––––––––––––––\n"
" \n"
// The packed structs below each hold four small vectors v[0..3]. Each provides:
//   - operator[] for thread / device / threadgroup references, const and non-const;
//   - explicit conversions to the matching half/float matrix, passing v[i] as the i-th
//     vector argument of the matrix constructor.
" typedef struct short4x4 {\n"
" private:\n"
" short4 v[4];\n"
" public:\n"
" short4x4(short4 a) {\n"
" v[0]=a; v[1]=a; v[2]=a; v[3]=a;\n"
" }\n"
" short4x4(short4 a,short4 b,short4 c,short4 d) {\n"
" v[0]=a; v[1]=b; v[2]=c; v[3]=d;\n"
" }\n"
" \n"
" inline thread short4& operator[] (const int index) {\n"
" return v[index];\n"
" }\n"
" inline device short4& operator[] (const int index) device {\n"
" return v[index];\n"
" }\n"
" inline threadgroup short4& operator[] (const int index) threadgroup {\n"
" return v[index];\n"
" }\n"
" \n"
" inline const thread short4& operator[] (const int index) const {\n"
" return v[index];\n"
" }\n"
" inline const device short4& operator[] (const int index) const device {\n"
" return v[index];\n"
" }\n"
" inline const threadgroup short4& operator[] (const int index) const threadgroup {\n"
" return v[index];\n"
" }\n"
" \n"
" inline explicit operator half4x4() const {\n"
" return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n"
" }\n"
" inline explicit operator half4x4() const device{\n"
" return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n"
" }\n"
" inline explicit operator half4x4() const threadgroup {\n"
" return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n"
" }\n"
" \n"
" inline explicit operator float4x4() const {\n"
" return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n"
" }\n"
" inline explicit operator float4x4() const device {\n"
" return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n"
" }\n"
" inline explicit operator float4x4() const threadgroup {\n"
" return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n"
" }\n"
" } short4x4;\n"
" \n"
// char4x4: int8 4x4 block; used as the weight type of conv1x1_g1z4_w8.
" typedef struct char4x4 {\n"
" private:\n"
" char4 v[4];\n"
" public:\n"
" char4x4(char4 a) {\n"
" v[0]=a; v[1]=a; v[2]=a; v[3]=a;\n"
" }\n"
" char4x4(char4 a,char4 b,char4 c,char4 d) {\n"
" v[0]=a; v[1]=b; v[2]=c; v[3]=d;\n"
" }\n"
" \n"
" inline thread char4& operator[] (const int index) {\n"
" return v[index];\n"
" }\n"
" inline device char4& operator[] (const int index) device {\n"
" return v[index];\n"
" }\n"
" inline threadgroup char4& operator[] (const int index) threadgroup {\n"
" return v[index];\n"
" }\n"
" \n"
" inline const thread char4& operator[] (const int index) const {\n"
" return v[index];\n"
" }\n"
" inline const device char4& operator[] (const int index) const device {\n"
" return v[index];\n"
" }\n"
" inline const threadgroup char4& operator[] (const int index) const threadgroup {\n"
" return v[index];\n"
" }\n"
" \n"
" inline explicit operator half4x4() const {\n"
" return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n"
" }\n"
" inline explicit operator half4x4() const device {\n"
" return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n"
" }\n"
" inline explicit operator half4x4() const threadgroup {\n"
" return half4x4( half4(v[0]),half4(v[1]),half4(v[2]),half4(v[3]) );\n"
" }\n"
" \n"
" inline explicit operator float4x4() const {\n"
" return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n"
" }\n"
" inline explicit operator float4x4() const device {\n"
" return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n"
" }\n"
" inline explicit operator float4x4() const threadgroup {\n"
" return float4x4( float4(v[0]),float4(v[1]),float4(v[2]),float4(v[3]) );\n"
" }\n"
" } char4x4;\n"
// char4x2: four char2 vectors, convertible to half4x2 / float4x2.
" typedef struct char4x2 {\n"
" private:\n"
" char2 v[4];\n"
" public:\n"
" char4x2(char2 a) {\n"
" v[0]=a; v[1]=a; v[2]=a; v[3]=a;\n"
" }\n"
" char4x2(char2 a,char2 b,char2 c,char2 d) {\n"
" v[0]=a; v[1]=b; v[2]=c; v[3]=d;\n"
" }\n"
" \n"
" inline thread char2& operator[] (const int index) {\n"
" return v[index];\n"
" }\n"
" inline device char2& operator[] (const int index) device {\n"
" return v[index];\n"
" }\n"
" inline threadgroup char2& operator[] (const int index) threadgroup {\n"
" return v[index];\n"
" }\n"
" \n"
" inline const thread char2& operator[] (const int index) const {\n"
" return v[index];\n"
" }\n"
" inline const device char2& operator[] (const int index) const device {\n"
" return v[index];\n"
" }\n"
" inline const threadgroup char2& operator[] (const int index) const threadgroup {\n"
" return v[index];\n"
" }\n"
" \n"
" inline explicit operator half4x2() const {\n"
" return half4x2( half2(v[0]),half2(v[1]),half2(v[2]),half2(v[3]) );\n"
" }\n"
" inline explicit operator half4x2() const device {\n"
" return half4x2( half2(v[0]),half2(v[1]),half2(v[2]),half2(v[3]) );\n"
" }\n"
" inline explicit operator half4x2() const threadgroup {\n"
" return half4x2( half2(v[0]),half2(v[1]),half2(v[2]),half2(v[3]) );\n"
" }\n"
" \n"
" inline explicit operator float4x2() const {\n"
" return float4x2( float2(v[0]),float2(v[1]),float2(v[2]),float2(v[3]) );\n"
" }\n"
" inline explicit operator float4x2() const device {\n"
" return float4x2( float2(v[0]),float2(v[1]),float2(v[2]),float2(v[3]) );\n"
" }\n"
" inline explicit operator float4x2() const threadgroup {\n"
" return float4x2( float2(v[0]),float2(v[1]),float2(v[2]),float2(v[3]) );\n"
" }\n"
" } char4x2;\n"
// uchar4x2: four uchar2 vectors. conv1x1_g1z4_w4 uses it as packed int4 weights, with
// two 4-bit values per byte.
" typedef struct uchar4x2 {\n"
" private:\n"
" uchar2 v[4];\n"
" public:\n"
" uchar4x2(uchar2 a) {\n"
" v[0]=a; v[1]=a; v[2]=a; v[3]=a;\n"
" }\n"
" uchar4x2(uchar2 a,uchar2 b,uchar2 c,uchar2 d) {\n"
" v[0]=a; v[1]=b; v[2]=c; v[3]=d;\n"
" }\n"
" \n"
" inline thread uchar2& operator[] (const int index) {\n"
" return v[index];\n"
" }\n"
" inline device uchar2& operator[] (const int index) device {\n"
" return v[index];\n"
" }\n"
" inline threadgroup uchar2& operator[] (const int index) threadgroup {\n"
" return v[index];\n"
" }\n"
" \n"
" inline const thread uchar2& operator[] (const int index) const {\n"
" return v[index];\n"
" }\n"
" inline const device uchar2& operator[] (const int index) const device {\n"
" return v[index];\n"
" }\n"
" inline const threadgroup uchar2& operator[] (const int index) const threadgroup {\n"
" return v[index];\n"
" }\n"
" \n"
" inline explicit operator half4x2() const {\n"
" return half4x2( half2(v[0]),half2(v[1]),half2(v[2]),half2(v[3]) );\n"
" }\n"
" inline explicit operator half4x2() const device {\n"
" return half4x2( half2(v[0]),half2(v[1]),half2(v[2]),half2(v[3]) );\n"
" }\n"
" inline explicit operator half4x2() const threadgroup {\n"
" return half4x2( half2(v[0]),half2(v[1]),half2(v[2]),half2(v[3]) );\n"
" }\n"
" \n"
" inline explicit operator float4x2() const {\n"
" return float4x2( float2(v[0]),float2(v[1]),float2(v[2]),float2(v[3]) );\n"
" }\n"
" inline explicit operator float4x2() const device {\n"
" return float4x2( float2(v[0]),float2(v[1]),float2(v[2]),float2(v[3]) );\n"
" }\n"
" inline explicit operator float4x2() const threadgroup {\n"
" return float4x2( float2(v[0]),float2(v[1]),float2(v[2]),float2(v[3]) );\n"
" }\n"
" } uchar4x2;\n"
"}\n"
;
// Metal source for scalar element-wise binary kernels on M (not M4) values:
// eltwise_prod (in0*in1), eltwise_max (max(in0, in1)) and eltwise_add (in0+in1).
// One thread per element. shape.x is the element count used for the bounds check;
// the other components of `shape` are not read by these kernels.
const char* shader_MetalEltwise_metal =
"kernel void eltwise_prod(device const M *in0 [[buffer(0)]],\n"
" device const M *in1 [[buffer(1)]],\n"
" device M *out [[buffer(2)]],\n"
" constant int4& shape [[buffer(3)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" if ((int)gid<shape.x) {\n"
" out[(int)gid]=in0[(int)gid]*in1[(int)gid];\n"
" }\n"
"}\n"
"kernel void eltwise_max(device const M *in0 [[buffer(0)]],\n"
" device const M *in1 [[buffer(1)]],\n"
" device M *out [[buffer(2)]],\n"
" constant int4& shape [[buffer(3)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" if ((int)gid<shape.x) {\n"
" out[(int)gid]=max(in0[(int)gid],in1[(int)gid]);\n"
" }\n"
"}\n"
"kernel void eltwise_add(device const M *in0 [[buffer(0)]],\n"
" device const M *in1 [[buffer(1)]],\n"
" device M *out [[buffer(2)]],\n"
" constant int4& shape [[buffer(3)]],\n"
" uint gid [[thread_position_in_grid]]) {\n"
" if ((int)gid<shape.x) {\n"
" out[(int)gid]=in0[(int)gid]+in1[(int)gid];\n"
" }\n"
"}\n"
;