in tensorflow/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc [315:1216]
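// Generates a Metal kernel specialized for 1x1 convolution (unit stride, no
// padding). Each thread accumulates two output rows (gid_y and gid_y + 1) and
// `z_out` consecutive output slices for one x column, looping over all source
// slices with dot products against the reordered `filters` buffer, then adds
// bias and stores the results (optionally post-processed via the $2
// placeholder; $0/$1/$2 are substituted later when the compute task is built).
// For example, with z_out == 2 the accumulation loop below emits roughly:
//   r0.x += dot(tmp[0], src_0);  l0.x += dot(tmp[0], src_1);
//   ...
//   r1.w += dot(tmp[7], src_0);  l1.w += dot(tmp[7], src_1);
//   tmp += 8;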
std::string GetKernelForConv1x1(const Convolution2DAttributes& params,
int z_out) {
std::string code;
code.reserve(16 * 1024); // Reserve large enough buffer.
std::string channels[4] = {"x", "y", "z", "w"};
code += R"(
#include <metal_stdlib>
using namespace metal;
struct uniforms {
int4 src_size;
int4 dst_size;
int4 stride_padding;
int4 kernel_dilation;
uint4 work_group_size;
};
$0
kernel void ComputeFunction(
$1
uint3 group_id[[threadgroup_position_in_grid]],
uint3 tid3d[[thread_position_in_threadgroup]])
{
int gid_x = group_id.y * params.work_group_size.x + tid3d.x;
int gid_y = (group_id.z * params.work_group_size.y + tid3d.y) << 1u;
)";
code += " int gid_z = (group_id.x * params.work_group_size.z + tid3d.z) * " +
std::to_string(z_out) + "u;\n";
for (int i = 0; i < z_out; ++i) {
const std::string s_i = std::to_string(i);
code += " ACCUM_FLT4 r" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n";
code += " ACCUM_FLT4 l" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n";
}
code += R"(
device FLT4* tmp = filters + gid_z * 4 * params.src_size.w;
int y0 = clamp(gid_y, 0, params.src_size.y - 1);
int y1 = clamp(gid_y + 1, 0, params.src_size.y - 1);
int x0 = clamp(gid_x, 0, params.src_size.x - 1);
int s = 0;
device FLT4* src_loc_0 = src_buffer + y0 * params.src_size.x + x0;
device FLT4* src_loc_1 = src_buffer + y1 * params.src_size.x + x0;
do {
FLT4 src_0 = *src_loc_0;
FLT4 src_1 = *src_loc_1;
src_loc_0 += params.src_size.z;
src_loc_1 += params.src_size.z;
)";
for (int i = 0; i < z_out * 4; ++i) {
const std::string s_i = std::to_string(i);
code += " r" + std::to_string(i / 4) + "." + channels[i % 4] +
" += dot(tmp[" + s_i + "], src_0);\n";
code += " l" + std::to_string(i / 4) + "." + channels[i % 4] +
" += dot(tmp[" + s_i + "], src_1);\n";
}
code += " tmp += " + std::to_string(z_out * 4) + ";\n";
code += R"(
s += 1;
} while (s < params.src_size.w);
const int offset_0 = gid_z * params.dst_size.z + gid_y * params.dst_size.x + gid_x;
const int offset_1 = offset_0 + params.dst_size.x;
bool y0_in = gid_y < params.dst_size.y;
bool y1_in = gid_y + 1 < params.dst_size.y;
device FLT4* bias_loc = biases + gid_z;
)";
for (int i = 0; i < z_out; ++i) {
const std::string s_i = std::to_string(i);
code += " r" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n";
code += " l" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n";
}
code += R"(
if (gid_x >= params.dst_size.x || gid_y >= params.dst_size.y) {
return;
}
)";
for (int i = 0; i < z_out; ++i) {
const std::string s_i = std::to_string(i);
code += " if (gid_z + " + s_i + "< params.dst_size.w) {\n";
code += " if (y0_in) {\n";
code += " FLT4 value = FLT4(r" + s_i + ");\n";
code += " int linear_index = offset_0 + params.dst_size.z * " + s_i +
";\n";
code += " uint3 gid = uint3(gid_x, gid_y, gid_z + " + s_i + ");\n";
code += " $2\n";
code += " dst_buffer[linear_index] = value;\n";
code += " }\n";
code += " if (y1_in) {\n";
code += " FLT4 value = FLT4(l" + s_i + ");\n";
code += " int linear_index = offset_1 + params.dst_size.z * " + s_i +
";\n";
code += " uint3 gid = uint3(gid_x, gid_y + 1, gid_z + " + s_i + ");\n";
code += " $2\n";
code += " dst_buffer[linear_index] = value;\n";
code += " }\n";
code += " }\n";
}
code += " }\n";
return code;
}
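// Generates the general convolution kernel: the same two-row by `z_out`-slice
// output tile as the 1x1 variant, but with explicit loops over the kernel
// window (kernel_dilation.x/y) and stride/padding/dilation handling. Taps that
// fall outside the source are clamped and zeroed via the m0/m1 multipliers.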
std::string GetKernelForConvGeneric(const Convolution2DAttributes& params,
int z_out) {
std::string code;
code.reserve(16 * 1024); // Reserve large enough buffer.
std::string channels[4] = {"x", "y", "z", "w"};
code += R"(
#include <metal_stdlib>
using namespace metal;
struct uniforms {
int4 src_size;
int4 dst_size;
int4 stride_padding;
int4 kernel_dilation;
uint4 work_group_size;
};
$0
kernel void ComputeFunction(
$1
uint3 group_id[[threadgroup_position_in_grid]],
uint3 tid3d[[thread_position_in_threadgroup]])
{
int gid_x = group_id.y * params.work_group_size.x + tid3d.x;
int gid_y = (group_id.z * params.work_group_size.y + tid3d.y) * 2;
)";
code += " int gid_z = (group_id.x * params.work_group_size.z + tid3d.z) * " +
std::to_string(z_out) + "u;\n";
for (int i = 0; i < z_out; ++i) {
const std::string s_i = std::to_string(i);
code += " ACCUM_FLT4 r" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n";
code += " ACCUM_FLT4 l" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n";
}
code += R"(
device FLT4* tmp = filters + gid_z * 4 * params.src_size.w * params.kernel_dilation.x * params.kernel_dilation.y;
int y0 = gid_y * params.stride_padding.y + params.stride_padding.w;
int y1 = (gid_y + 1) * params.stride_padding.y + params.stride_padding.w;
int x0 = gid_x * params.stride_padding.x + params.stride_padding.z;
int y = 0;
do {
int coord_y0 = y * params.kernel_dilation.w + y0;
int coord_y1 = y * params.kernel_dilation.w + y1;
bool y0_out = coord_y0 < 0 || coord_y0 >= params.src_size.y;
bool y1_out = coord_y1 < 0 || coord_y1 >= params.src_size.y;
coord_y0 = clamp(coord_y0, 0, params.src_size.y - 1);
coord_y1 = clamp(coord_y1, 0, params.src_size.y - 1);
int x = 0;
do {
int coord_x0 = x * params.kernel_dilation.z + x0;
bool x0_out = coord_x0 < 0 || coord_x0 >= params.src_size.x;
coord_x0 = clamp(coord_x0, 0, params.src_size.x - 1);
FLT m0 = !(y0_out || x0_out);
FLT m1 = !(y1_out || x0_out);
int s = 0;
device FLT4* src_loc_0 = src_buffer + coord_y0 * params.src_size.x + coord_x0;
device FLT4* src_loc_1 = src_buffer + coord_y1 * params.src_size.x + coord_x0;
do {
FLT4 src_0 = *src_loc_0 * m0;
FLT4 src_1 = *src_loc_1 * m1;
src_loc_0 += params.src_size.z;
src_loc_1 += params.src_size.z;
)";
for (int i = 0; i < z_out * 4; ++i) {
const std::string s_i = std::to_string(i);
code += " r" + std::to_string(i / 4) + "." + channels[i % 4] +
" += dot(tmp[" + s_i + "], src_0);\n";
code += " l" + std::to_string(i / 4) + "." + channels[i % 4] +
" += dot(tmp[" + s_i + "], src_1);\n";
}
code += " tmp += " + std::to_string(z_out * 4) + ";\n";
code += R"(
s += 1;
} while (s < params.src_size.w);
x++;
} while (x < params.kernel_dilation.x);
y++;
} while (y < params.kernel_dilation.y);
const int offset_0 = gid_z * params.dst_size.z + gid_y * params.dst_size.x + gid_x;
const int offset_1 = offset_0 + params.dst_size.x;
bool p0_in = gid_x < params.dst_size.x && gid_y < params.dst_size.y;
bool p1_in = gid_x < params.dst_size.x && gid_y + 1 < params.dst_size.y;
device FLT4* bias_loc = biases + gid_z;
)";
for (int i = 0; i < z_out; ++i) {
const std::string s_i = std::to_string(i);
code += " r" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n";
code += " l" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n";
}
code += R"(
if (gid_x >= params.dst_size.x || gid_y >= params.dst_size.y) {
return;
}
)";
for (int i = 0; i < z_out; ++i) {
const std::string s_i = std::to_string(i);
code += " if (gid_z + " + s_i + "< params.dst_size.w) {\n";
code += " if (p0_in) {\n";
code += " FLT4 value = FLT4(r" + s_i + ");\n";
code += " int linear_index = offset_0 + params.dst_size.z * " + s_i +
";\n";
code += " uint3 gid = uint3(gid_x, gid_y, gid_z + " + s_i + ");\n";
code += " $2\n";
code += " dst_buffer[linear_index] = value;\n";
code += " }\n";
code += " if (p1_in) {\n";
code += " FLT4 value = FLT4(l" + s_i + ");\n";
code += " int linear_index = offset_1 + params.dst_size.z * " + s_i +
";\n";
code += " uint3 gid = uint3(gid_x, gid_y + 1, gid_z + " + s_i + ");\n";
code += " $2\n";
code += " dst_buffer[linear_index] = value;\n";
code += " }\n";
code += " }\n";
}
code += " }\n";
return code;
}
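// "Precise" variant: dispatched as a 1-D grid. Each thread decodes its linear
// id into a slice group and two consecutive pixels of the flattened x-y plane
// (params.slices = {dst width, ceil(dst_w * dst_h / 2)}), then runs the same
// windowed accumulation as the generic kernel, which avoids the per-dimension
// rounding of the tiled dispatch.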
std::string GetKernelForConvPrecise(int z_out) {
std::string channels[4] = {"x", "y", "z", "w"};
std::string code;
code.reserve(16 * 1024); // Reserve large enough buffer.
code += R"(
#include <metal_stdlib>
using namespace metal;
struct uniforms {
int4 src_size;
int4 dst_size;
int4 stride_padding;
int4 kernel_dilation;
int4 slices;
};
$0
kernel void ComputeFunction(
$1
uint3 ugid[[thread_position_in_grid]])
{
int linear_id = ugid.x;
int gid_z = linear_id / params.slices.y;
int linear_xy = (linear_id - gid_z * params.slices.y) << 1;
)";
code += " gid_z *= " + std::to_string(z_out) + ";\n";
code += R"(
int gid_y0 = linear_xy / params.slices.x;
int gid_x0 = linear_xy - gid_y0 * params.slices.x;
linear_xy += 1;
int gid_y1 = linear_xy / params.slices.x;
int gid_x1 = linear_xy - gid_y1 * params.slices.x;
if (gid_z >= params.dst_size.w) return;
)";
for (int i = 0; i < z_out; ++i) {
const std::string s_i = std::to_string(i);
code += " ACCUM_FLT4 r" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n";
code += " ACCUM_FLT4 l" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n";
}
code += R"(
device FLT4* tmp = filters + gid_z * 4 * params.src_size.w *
params.kernel_dilation.x * params.kernel_dilation.y;
int y0 = gid_y0 * params.stride_padding.y + params.stride_padding.w;
int y1 = gid_y1 * params.stride_padding.y + params.stride_padding.w;
int x0 = gid_x0 * params.stride_padding.x + params.stride_padding.z;
int x1 = gid_x1 * params.stride_padding.x + params.stride_padding.z;
)";
code += R"(
int y = 0;
do {
int coord_y0 = y * params.kernel_dilation.w + y0;
int coord_y1 = y * params.kernel_dilation.w + y1;
bool y0_out = coord_y0 < 0 || coord_y0 >= params.src_size.y;
bool y1_out = coord_y1 < 0 || coord_y1 >= params.src_size.y;
coord_y0 = clamp(coord_y0, 0, params.src_size.y - 1);
coord_y1 = clamp(coord_y1, 0, params.src_size.y - 1);
int x = 0;
do {
int coord_x0 = x * params.kernel_dilation.z + x0;
int coord_x1 = x * params.kernel_dilation.z + x1;
bool x0_out = coord_x0 < 0 || coord_x0 >= params.src_size.x;
bool x1_out = coord_x1 < 0 || coord_x1 >= params.src_size.x;
coord_x0 = clamp(coord_x0, 0, params.src_size.x - 1);
coord_x1 = clamp(coord_x1, 0, params.src_size.x - 1);
FLT m0 = !(y0_out || x0_out);
FLT m1 = !(y1_out || x1_out);
device FLT4* src_loc_0 = src_buffer + coord_y0 * params.src_size.x + coord_x0;
device FLT4* src_loc_1 = src_buffer + coord_y1 * params.src_size.x + coord_x1;
int s = 0;
do {
FLT4 src_0 = *src_loc_0 * m0;
FLT4 src_1 = *src_loc_1 * m1;
src_loc_0 += params.src_size.z;
src_loc_1 += params.src_size.z;
)";
for (int i = 0; i < z_out * 4; ++i) {
const std::string s_i = std::to_string(i);
code += " r" + std::to_string(i / 4) + "." + channels[i % 4] +
" += dot(tmp[" + s_i + "], src_0);\n";
code += " l" + std::to_string(i / 4) + "." + channels[i % 4] +
" += dot(tmp[" + s_i + "], src_1);\n";
}
code += " tmp += " + std::to_string(z_out * 4) + ";\n";
code += R"(
s += 1;
} while (s < params.src_size.w);
x++;
} while (x < params.kernel_dilation.x);
y++;
} while (y < params.kernel_dilation.y);
const int offset_0 = gid_z * params.dst_size.z + gid_y0 * params.dst_size.x + gid_x0;
const int offset_1 = gid_z * params.dst_size.z + gid_y1 * params.dst_size.x + gid_x1;
bool p0_in = gid_x0 < params.dst_size.x && gid_y0 < params.dst_size.y;
bool p1_in = gid_x1 < params.dst_size.x && gid_y1 < params.dst_size.y;
device FLT4* bias_loc = biases + gid_z;
)";
for (int i = 0; i < z_out; ++i) {
const std::string s_i = std::to_string(i);
code += " r" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n";
code += " l" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n";
}
for (int i = 0; i < z_out; ++i) {
const std::string s_i = std::to_string(i);
code += " if (gid_z + " + s_i + "< params.dst_size.w) {\n";
code += " if (p0_in) {\n";
code += " FLT4 value = FLT4(r" + s_i + ");\n";
code += " int linear_index = offset_0 + params.dst_size.z * " + s_i +
";\n";
code += " uint3 gid = uint3(gid_x0, gid_y0, gid_z + " + s_i + ");\n";
code += " $2\n";
code += " dst_buffer[linear_index] = value;\n";
code += " }\n";
code += " if (p1_in) {\n";
code += " FLT4 value = FLT4(l" + s_i + ");\n";
code += " int linear_index = offset_1 + params.dst_size.z * " + s_i +
";\n";
code += " uint3 gid = uint3(gid_x1, gid_y1, gid_z + " + s_i + ");\n";
code += " $2\n";
code += " dst_buffer[linear_index] = value;\n";
code += " }\n";
code += " }\n";
}
code += " }\n";
return code;
}
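// PowerVR-oriented precise 1x1 kernel: 1-D dispatch with one pixel and
// `z_out` output slices per thread, accumulating in float4 regardless of the
// FLT storage precision.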
std::string GetKernelForConvPrecise1x1PowerVR(int z_out) {
std::string channels[4] = {"x", "y", "z", "w"};
std::string code;
code.reserve(16 * 1024); // Reserve large enough buffer.
code += R"(
#include <metal_stdlib>
using namespace metal;
struct uniforms {
int4 src_size;
int4 dst_size;
int4 slices;
int4 dummy0;
};
$0
kernel void ComputeFunction(
$1
uint3 ugid[[thread_position_in_grid]])
{
int linear_id = ugid.x;
int gid_z = linear_id / params.slices.y;
int linear_xy = linear_id - gid_z * params.slices.y;
)";
code += " gid_z *= " + std::to_string(z_out) + ";\n";
code += R"(
int gid_y0 = linear_xy / params.slices.x;
int gid_x0 = linear_xy - gid_y0 * params.slices.x;
if (gid_z >= params.dst_size.w) return;
)";
for (int i = 0; i < z_out; ++i) {
const std::string s_i = std::to_string(i);
code += " float4 r" + s_i + " = float4(0.0f, 0.0f, 0.0f, 0.0f);\n";
}
code += R"(
device FLT4* tmp = filters + gid_z * 4 * params.src_size.w;
device FLT4* src_loc_0 = src_buffer + gid_y0 * params.src_size.x + gid_x0;
int s = 0;
do {
FLT4 src_0 = *src_loc_0;
src_loc_0 += params.src_size.z;
)";
for (int i = 0; i < z_out * 4; ++i) {
const std::string s_i = std::to_string(i);
code += " r" + std::to_string(i / 4) + "." + channels[i % 4] +
" += dot(tmp[" + s_i + "], src_0);\n";
}
code += " tmp += " + std::to_string(z_out * 4) + ";\n";
code += R"(
s += 1;
} while (s < params.src_size.w);
const int offset_0 = gid_z * params.dst_size.z + gid_y0 * params.dst_size.x + gid_x0;
device FLT4* bias_loc = biases + gid_z;
)";
for (int i = 0; i < z_out; ++i) {
const std::string s_i = std::to_string(i);
code += " r" + s_i + " += float4(bias_loc[" + s_i + "]);\n";
}
for (int i = 0; i < z_out; ++i) {
const std::string s_i = std::to_string(i);
code += " if (gid_z + " + s_i + "< params.dst_size.w) {\n";
code += " FLT4 value = FLT4(r" + s_i + ");\n";
code +=
" int linear_index = offset_0 + params.dst_size.z * " + s_i + ";\n";
code += " uint3 gid = uint3(gid_x0, gid_y0, gid_z + " + s_i + ");\n";
code += " $2\n";
code += " dst_buffer[linear_index] = value;\n";
code += " }\n";
}
code += " }\n";
return code;
}
// Reorders weights so that the kernels above read the `filters` buffer with a
// cache-friendly, sequential access pattern. Used by Convolution1x1,
// ConvolutionGeneric and the precise variants.
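// The layout follows the loop nest below: [dst slice group][ky][kx][src slice]
// [slice within group][4 dst channels][4 src channels], with zeros for
// channels past the tensor's real I/O extents. Note that the buffer is sized
// for exactly dst_depth output slices, so this relies on z_out (from
// GetNumOutputSlices) evenly dividing dst_depth.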
std::vector<float> ReorderWeightsForConv(const Convolution2DAttributes& params,
int z_out) {
const int dst_depth = IntegralDivideRoundUp(params.weights.shape.o, 4);
const int src_depth = IntegralDivideRoundUp(params.weights.shape.i, 4);
std::vector<float> weights_reordered(params.weights.shape.w *
params.weights.shape.h * dst_depth * 4 *
src_depth * 4);
int counter = 0;
for (int d = 0; d < IntegralDivideRoundUp(dst_depth, z_out); ++d) {
for (int y = 0; y < params.weights.shape.h; ++y) {
for (int x = 0; x < params.weights.shape.w; ++x) {
for (int s = 0; s < src_depth; ++s) {
for (int k = 0; k < z_out; ++k) {
for (int j = 0; j < 4; ++j) {
for (int i = 0; i < 4; ++i) {
int src_ch = s * 4 + i;
int dst_ch = (d * z_out + k) * 4 + j;
if (src_ch >= params.weights.shape.i ||
dst_ch >= params.weights.shape.o) {
weights_reordered[counter++] = 0.0f;
} else {
const int f_index =
params.weights.shape.LinearIndex({dst_ch, y, x, src_ch});
weights_reordered[counter++] = params.weights.data[f_index];
}
}
}
}
}
}
}
}
return weights_reordered;
}
uint3 GetWorkGroupForConv() { return {8, 4, 1}; }
uint3 GetWorkGroupForConvPrecise() { return {32, 1, 1}; }
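// Packs the `uniforms` struct consumed by GetKernelForConv1x1 and
// GetKernelForConvGeneric: src_size = {w, h, w*h, src slices}, dst_size
// likewise, stride_padding = {stride_x, stride_y, -pad_x, -pad_y},
// kernel_dilation = {kernel_w, kernel_h, dilation_x, dilation_y}, then the
// work group size (padded to an int4).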
std::vector<uint8_t> GetUniformBufferForConv(
const BHWC& src_size, const BHWC& dst_size,
const Convolution2DAttributes& params) {
const int3 group_size = GetWorkGroupForConv();
std::vector<int> uniform_params = {
src_size.w,
src_size.h,
src_size.w * src_size.h,
IntegralDivideRoundUp(src_size.c, 4),
dst_size.w,
dst_size.h,
dst_size.w * dst_size.h,
IntegralDivideRoundUp(dst_size.c, 4),
params.strides.w,
params.strides.h,
-params.padding.prepended.w,
-params.padding.prepended.h,
params.weights.shape.w,
params.weights.shape.h,
params.dilations.w,
params.dilations.h,
group_size.x,
group_size.y,
group_size.z,
1u, // dummy, for alignment
};
return VectorToUint8Vector(uniform_params);
}
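// Same layout as above, except that the trailing int4 is `slices` =
// {dst width, ceil(dst_w * dst_h / 2), 0, 0}, matching the
// two-pixels-per-thread precise kernel.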
std::vector<uint8_t> GetUniformBufferForConvPrecise(
const BHWC& src_size, const BHWC& dst_size,
const Convolution2DAttributes& params) {
std::vector<int> uniform_params = {
src_size.w,
src_size.h,
src_size.w * src_size.h,
IntegralDivideRoundUp(src_size.c, 4),
dst_size.w,
dst_size.h,
dst_size.w * dst_size.h,
IntegralDivideRoundUp(dst_size.c, 4),
params.strides.w,
params.strides.h,
-params.padding.prepended.w,
-params.padding.prepended.h,
params.weights.shape.w,
params.weights.shape.h,
params.dilations.w,
params.dilations.h,
dst_size.w,
IntegralDivideRoundUp(dst_size.w * dst_size.h, 2),
0u, // dummy, for alignment
0u, // dummy, for alignment
};
return VectorToUint8Vector(uniform_params);
}
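// Uniforms for the PowerVR 1x1 kernel: only src/dst sizes plus slices =
// {dst width, dst_w * dst_h, 0, 0} (one pixel per thread); the remaining
// entries pad the struct.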
std::vector<uint8_t> GetUniformBufferForConvPrecise1x1(
const BHWC& src_size, const BHWC& dst_size,
const Convolution2DAttributes& params) {
std::vector<int> uniform_params = {
src_size.w,
src_size.h,
src_size.w * src_size.h,
IntegralDivideRoundUp(src_size.c, 4),
dst_size.w,
dst_size.h,
dst_size.w * dst_size.h,
IntegralDivideRoundUp(dst_size.c, 4),
dst_size.w,
IntegralDivideRoundUp(dst_size.w * dst_size.h, 1),
0u, // dummy, for alignment
0u, // dummy, for alignment
0u, // dummy, for alignment
0u, // dummy, for alignment
0u, // dummy, for alignment
0u, // dummy, for alignment
};
return VectorToUint8Vector(uniform_params);
}
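// Threadgroup grid for the tiled kernels: x over output width, y over half
// the output height (two rows per thread), z over dst slices / z_out. The
// resize functions below reorder these axes to match how the kernels read
// group_id.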
uint3 GetGroupsCountForConv(const uint3& group_size, const BHWC& dst_shape) {
const int dst_depth = IntegralDivideRoundUp(dst_shape.c, 4);
int groups_x = IntegralDivideRoundUp(dst_shape.w, group_size.x);
int groups_y = IntegralDivideRoundUp(IntegralDivideRoundUp(dst_shape.h, 2),
group_size.y);
const int z_out = GetNumOutputSlices(dst_shape.c);
int groups_z = IntegralDivideRoundUp(IntegralDivideRoundUp(dst_depth, z_out),
group_size.z);
return {groups_x, groups_y, groups_z};
}
uint3 GetGroupsCountForConvPrecise(const uint3& group_size,
const BHWC& dst_shape, int xy_pixels) {
const int z_out = GetNumOutputSlices(dst_shape.c);
const int dst_depth = IntegralDivideRoundUp(dst_shape.c, 4);
int xy_size = IntegralDivideRoundUp(dst_shape.w * dst_shape.h, xy_pixels);
int z_size = IntegralDivideRoundUp(dst_depth, z_out);
int task_size = xy_size * z_size;
return {IntegralDivideRoundUp(task_size, group_size.x), 1, 1};
}
int GetConvolutionThreadsCount(const BHWC& dst_shape) {
const uint3 group_size = GetWorkGroupForConv();
const uint3 groups_count = GetGroupsCountForConv(group_size, dst_shape);
return groups_count.x * groups_count.y * groups_count.z * group_size.x *
group_size.y * group_size.z;
}
int GetConvolutionPreciseThreadsCount(const BHWC& dst_shape, int xy_pixels) {
const uint3 group_size = GetWorkGroupForConvPrecise();
const uint3 groups_count =
GetGroupsCountForConvPrecise(group_size, dst_shape, xy_pixels);
return groups_count.x * groups_count.y * groups_count.z * group_size.x *
group_size.y * group_size.z;
}
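// True for pointwise convolutions: 1x1 kernel, unit stride and dilation, and
// no padding.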
bool IsConv1x1(const Convolution2DAttributes& attr) {
return attr.weights.shape.h == 1 && attr.weights.shape.w == 1 &&
attr.strides.h == 1 && attr.strides.w == 1 && attr.dilations.h == 1 &&
attr.dilations.w == 1 && attr.padding.prepended.h == 0 &&
attr.padding.prepended.w == 0 && attr.padding.appended.h == 0 &&
attr.padding.appended.w == 0;
}
} // namespace
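// Builds the task descriptor for the default convolution path, using helpers
// defined earlier in this file (GetKernelForConv, ReorderWeightsForConvShared,
// GetUniformBufferForConvShared).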
std::vector<ComputeTaskDescriptorPtr> Convolution(
int id, ValueId input_id, ValueId output_id,
const Convolution2DAttributes& params, const RuntimeOptions& options) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = false;
desc->shader_source = GetKernelForConv(params);
desc->input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc->output_buffer = {
output_id, "device FLT4* dst_buffer",
[input_id, params](const std::map<ValueId, BHWC>& buffers) {
return CalculateOutputShape(buffers.find(input_id)->second, params);
}};
auto weights_reordered = ReorderWeightsForConvShared(params);
auto weights = options.storage_precision == RuntimeOptions::Precision::FP32
? VectorToUint8Vector(weights_reordered)
: VectorFloatToHalf(weights_reordered);
auto biases = options.storage_precision == RuntimeOptions::Precision::FP32
? VectorToUint8Vector(params.bias.data)
: VectorFloatToHalf(params.bias.data);
desc->immutable_buffers = {
{"device FLT4* const weights", weights},
{"device FLT4* const biases", biases},
};
desc->uniform_buffers = {
{"constant uniforms& params",
[input_id, output_id, params](const std::map<ValueId, BHWC>& buffers) {
const auto& input_dimensions = buffers.find(input_id)->second;
const auto& output_dimensions = buffers.find(output_id)->second;
return GetUniformBufferForConvShared(input_dimensions,
output_dimensions, params);
}},
};
desc->resize_function = [output_id,
params](const std::map<ValueId, BHWC>& buffers) {
const auto& output_dims = buffers.find(output_id)->second;
const int num_output_slices = GetNumOutputSlices(params.weights.shape.o);
const uint3 group_size{8, 4, 1};
int groups_x = IntegralDivideRoundUp(output_dims.w, group_size.x);
int groups_y = IntegralDivideRoundUp(output_dims.h, group_size.y);
const int dst_depth = IntegralDivideRoundUp(params.weights.shape.o, 4);
int groups_z = IntegralDivideRoundUp(dst_depth, num_output_slices);
return std::make_pair(group_size, uint3{groups_x, groups_y, groups_z});
};
return {desc};
}
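// Fast path for pointwise convolutions; intended for attributes that satisfy
// CheckConvolution1x1Support().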
std::vector<ComputeTaskDescriptorPtr> Convolution1x1(
int id, ValueId input_id, ValueId output_id,
const Convolution2DAttributes& params,
const metal::RuntimeOptions& options) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = false;
const int z_out = GetNumOutputSlices(params.weights.shape.o);
desc->shader_source = GetKernelForConv1x1(params, z_out);
desc->input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc->output_buffer = {
output_id, "device FLT4* dst_buffer",
[input_id, params](const std::map<ValueId, BHWC>& buffers) {
auto out_shape =
CalculateOutputShape(buffers.find(input_id)->second, params);
return out_shape;
}};
auto weights_reordered = ReorderWeightsForConv(params, z_out);
auto weights =
options.storage_precision == metal::RuntimeOptions::Precision::FP32
? VectorToUint8Vector(weights_reordered)
: VectorFloatToHalf(weights_reordered);
auto biases =
options.storage_precision == metal::RuntimeOptions::Precision::FP32
? VectorToUint8Vector(params.bias.data)
: VectorFloatToHalf(params.bias.data);
desc->immutable_buffers = {
{"device FLT4* const filters", weights},
{"device FLT4* const biases", biases},
};
desc->uniform_buffers = {
{"constant uniforms& params",
[input_id, output_id, params](const std::map<ValueId, BHWC>& buffers) {
const auto& input_dimensions = buffers.find(input_id)->second;
const auto& output_dimensions = buffers.find(output_id)->second;
return GetUniformBufferForConv(input_dimensions, output_dimensions,
params);
}},
};
desc->resize_function = [output_id,
params](const std::map<ValueId, BHWC>& buffers) {
const auto& output_dims = buffers.find(output_id)->second;
const uint3 group_size = GetWorkGroupForConv();
const uint3 groups_count = GetGroupsCountForConv(group_size, output_dims);
return std::make_pair(
group_size, uint3{groups_count.z, groups_count.x, groups_count.y});
};
return {desc};
}
bool CheckConvolution1x1Support(const Convolution2DAttributes& attr) {
return IsConv1x1(attr);
}
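// General path: handles arbitrary kernel size, stride, padding and dilation
// via GetKernelForConvGeneric.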
std::vector<ComputeTaskDescriptorPtr> ConvolutionGeneric(
int id, ValueId input_id, ValueId output_id,
const Convolution2DAttributes& params,
const metal::RuntimeOptions& options) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = false;
const int z_out = GetNumOutputSlices(params.weights.shape.o);
desc->shader_source = GetKernelForConvGeneric(params, z_out);
desc->input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc->output_buffer = {
output_id, "device FLT4* dst_buffer",
[input_id, params](const std::map<ValueId, BHWC>& buffers) {
auto out_shape =
CalculateOutputShape(buffers.find(input_id)->second, params);
return out_shape;
}};
auto weights_reordered = ReorderWeightsForConv(params, z_out);
auto weights =
options.storage_precision == metal::RuntimeOptions::Precision::FP32
? VectorToUint8Vector(weights_reordered)
: VectorFloatToHalf(weights_reordered);
auto biases =
options.storage_precision == metal::RuntimeOptions::Precision::FP32
? VectorToUint8Vector(params.bias.data)
: VectorFloatToHalf(params.bias.data);
desc->immutable_buffers = {
{"device FLT4* const filters", weights},
{"device FLT4* const biases", biases},
};
desc->uniform_buffers = {
{"constant uniforms& params",
[input_id, output_id, params](const std::map<ValueId, BHWC>& buffers) {
const auto& input_dimensions = buffers.find(input_id)->second;
const auto& output_dimensions = buffers.find(output_id)->second;
return GetUniformBufferForConv(input_dimensions, output_dimensions,
params);
}},
};
desc->resize_function = [output_id,
params](const std::map<ValueId, BHWC>& buffers) {
const auto& output_dims = buffers.find(output_id)->second;
const uint3 group_size = GetWorkGroupForConv();
const uint3 groups_count = GetGroupsCountForConv(group_size, output_dims);
return std::make_pair(
group_size, uint3{groups_count.z, groups_count.x, groups_count.y});
};
return {desc};
}
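// Precise-dispatch path: 1-D grid, two output pixels per thread (see
// GetKernelForConvPrecise / GetGroupsCountForConvPrecise).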
std::vector<ComputeTaskDescriptorPtr> ConvolutionPrecise(
int id, ValueId input_id, ValueId output_id,
const Convolution2DAttributes& params,
const metal::RuntimeOptions& options) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = false;
const int z_out = GetNumOutputSlices(params.weights.shape.o);
desc->shader_source = GetKernelForConvPrecise(z_out);
desc->input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc->output_buffer = {
output_id, "device FLT4* dst_buffer",
[input_id, params](const std::map<ValueId, BHWC>& buffers) {
auto out_shape =
CalculateOutputShape(buffers.find(input_id)->second, params);
return out_shape;
}};
auto weights_reordered = ReorderWeightsForConv(params, z_out);
auto weights =
options.storage_precision == metal::RuntimeOptions::Precision::FP32
? VectorToUint8Vector(weights_reordered)
: VectorFloatToHalf(weights_reordered);
auto biases =
options.storage_precision == metal::RuntimeOptions::Precision::FP32
? VectorToUint8Vector(params.bias.data)
: VectorFloatToHalf(params.bias.data);
desc->immutable_buffers = {
{"device FLT4* const filters", weights},
{"device FLT4* const biases", biases},
};
desc->uniform_buffers = {
{"constant uniforms& params",
[input_id, output_id, params](const std::map<ValueId, BHWC>& buffers) {
const auto& input_dimensions = buffers.find(input_id)->second;
const auto& output_dimensions = buffers.find(output_id)->second;
return GetUniformBufferForConvPrecise(input_dimensions,
output_dimensions, params);
}},
};
desc->resize_function = [output_id,
params](const std::map<ValueId, BHWC>& buffers) {
const auto& output_dims = buffers.find(output_id)->second;
const uint3 group_size = GetWorkGroupForConvPrecise();
const uint3 groups_count =
GetGroupsCountForConvPrecise(group_size, output_dims, 2);
return std::make_pair(group_size, groups_count);
};
return {desc};
}
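// Ratio of launched threads between the tiled and precise dispatches for a
// given output shape; values above 1.0 mean the precise variant needs fewer
// threads.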
float GetThreadsRatioUsualToPreciseConvolution(const BHWC& dst_shape) {
return static_cast<float>(GetConvolutionThreadsCount(dst_shape)) /
static_cast<float>(GetConvolutionPreciseThreadsCount(dst_shape, 2));
}
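// PowerVR-oriented precise 1x1 path: one pixel per thread and float
// accumulation (see GetKernelForConvPrecise1x1PowerVR).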
std::vector<ComputeTaskDescriptorPtr> ConvolutionPrecise1x1PowerVR(
int id, ValueId input_id, ValueId output_id,
const Convolution2DAttributes& params, const RuntimeOptions& options) {
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
desc->is_linkable = false;
const int z_out = GetNumOutputSlices(params.weights.shape.o);
desc->shader_source = GetKernelForConvPrecise1x1PowerVR(z_out);
desc->input_buffers = {
{input_id, "device FLT4* const src_buffer"},
};
desc->output_buffer = {
output_id, "device FLT4* dst_buffer",
[input_id, params](const std::map<ValueId, BHWC>& buffers) {
auto out_shape =
CalculateOutputShape(buffers.find(input_id)->second, params);
return out_shape;
}};
auto weights_reordered = ReorderWeightsForConv(params, z_out);
auto weights =
options.storage_precision == metal::RuntimeOptions::Precision::FP32
? VectorToUint8Vector(weights_reordered)
: VectorFloatToHalf(weights_reordered);
auto biases =
options.storage_precision == metal::RuntimeOptions::Precision::FP32
? VectorToUint8Vector(params.bias.data)
: VectorFloatToHalf(params.bias.data);
desc->immutable_buffers = {
{"device FLT4* const filters", weights},
{"device FLT4* const biases", biases},
};
desc->uniform_buffers = {
{"constant uniforms& params",
[input_id, output_id, params](const std::map<ValueId, BHWC>& buffers) {
const auto& input_dimensions = buffers.find(input_id)->second;
const auto& output_dimensions = buffers.find(output_id)->second;
return GetUniformBufferForConvPrecise1x1(input_dimensions,
output_dimensions, params);
}},
};
desc->resize_function = [output_id,
params](const std::map<ValueId, BHWC>& buffers) {
const auto& output_dims = buffers.find(output_id)->second;
const uint3 group_size = GetWorkGroupForConvPrecise();
const uint3 groups_count =
GetGroupsCountForConvPrecise(group_size, output_dims, 1);
return std::make_pair(group_size, groups_count);
};
return {desc};
}
bool CheckConvolutionPrecise1x1Support(const Convolution2DAttributes& attr) {
return IsConv1x1(attr);
}
} // namespace metal
} // namespace gpu
} // namespace tflite