std::string GetKernelForConv1x1()

in tensorflow/tensorflow/lite/delegates/gpu/metal/kernels/conv.cc [315:1216]


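// Generates Metal shader source for a 1x1 convolution. Each thread computes
// two output rows (gid_y and gid_y + 1) across z_out consecutive output
// slices, accumulating into r0..r{z_out-1} and l0..l{z_out-1}. $0, $1 and $2
// are substitution points resolved later by the compute-task machinery
// (declarations, buffer arguments, and linked post-ops applied to `value`).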
std::string GetKernelForConv1x1(const Convolution2DAttributes& params,
                                int z_out) {
  std::string code;
  code.reserve(16 * 1024);  // Reserve a large enough buffer.
  std::string channels[4] = {"x", "y", "z", "w"};
  code += R"(
#include <metal_stdlib>
using namespace metal;

struct uniforms {
  int4 src_size;
  int4 dst_size;
  int4 stride_padding;
  int4 kernel_dilation;
  uint4 work_group_size;
};
$0

kernel void ComputeFunction(
                            $1
                            uint3 group_id[[threadgroup_position_in_grid]],
                            uint3 tid3d[[thread_position_in_threadgroup]])
{
  int gid_x = group_id.y * params.work_group_size.x + tid3d.x;
  int gid_y = (group_id.z * params.work_group_size.y + tid3d.y) << 1u;
  )";
  code += "  int gid_z = (group_id.x * params.work_group_size.z + tid3d.z) * " +
          std::to_string(z_out) + "u;\n";
  for (int i = 0; i < z_out; ++i) {
    const std::string s_i = std::to_string(i);
    code += "  ACCUM_FLT4 r" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n";
    code += "  ACCUM_FLT4 l" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n";
  }
  code += R"(
  device FLT4* tmp = filters + gid_z * 4 * params.src_size.w;

  int y0 = clamp(gid_y, 0, params.src_size.y - 1);
  int y1 = clamp(gid_y + 1, 0, params.src_size.y - 1);
  int x0 = clamp(gid_x, 0, params.src_size.x - 1);

  int s = 0;

  device FLT4* src_loc_0 = src_buffer + y0 * params.src_size.x + x0;
  device FLT4* src_loc_1 = src_buffer + y1 * params.src_size.x + x0;
  do {
    FLT4 src_0 = *src_loc_0;
    FLT4 src_1 = *src_loc_1;
    src_loc_0 += params.src_size.z;
    src_loc_1 += params.src_size.z;
    )";
  for (int i = 0; i < z_out * 4; ++i) {
    const std::string s_i = std::to_string(i);
    code += "    r" + std::to_string(i / 4) + "." + channels[i % 4] +
            " += dot(tmp[" + s_i + "], src_0);\n";
    code += "    l" + std::to_string(i / 4) + "." + channels[i % 4] +
            " += dot(tmp[" + s_i + "], src_1);\n";
  }

  code += "    tmp += " + std::to_string(z_out * 4) + ";\n";
  code += R"(
    s += 1;
  } while (s < params.src_size.w);
  const int offset_0 = gid_z * params.dst_size.z + gid_y * params.dst_size.x + gid_x;
  const int offset_1 = offset_0 + params.dst_size.x;
  bool y0_in = gid_y < params.dst_size.y;
  bool y1_in = gid_y + 1 < params.dst_size.y;

  device FLT4* bias_loc = biases + gid_z;
  )";
  for (int i = 0; i < z_out; ++i) {
    const std::string s_i = std::to_string(i);
    code += "  r" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n";
    code += "  l" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n";
  }
  code += R"(
  if (gid_x >= params.dst_size.x || gid_y >= params.dst_size.y) {
      return;
  }
  )";
  for (int i = 0; i < z_out; ++i) {
    const std::string s_i = std::to_string(i);
    code += "  if (gid_z + " + s_i + "< params.dst_size.w) {\n";
    code += "    if (y0_in) {\n";
    code += "      FLT4 value = FLT4(r" + s_i + ");\n";
    code += "      int linear_index = offset_0 + params.dst_size.z * " + s_i +
            ";\n";
    code += "      uint3 gid = uint3(gid_x, gid_y, gid_z + " + s_i + ");\n";
    code += "      $2\n";
    code += "      dst_buffer[linear_index] = value;\n";
    code += "    }\n";
    code += "    if (y1_in) {\n";
    code += "      FLT4 value = FLT4(l" + s_i + ");\n";
    code += "      int linear_index = offset_1 + params.dst_size.z * " + s_i +
            ";\n";
    code += "      uint3 gid = uint3(gid_x, gid_y + 1, gid_z + " + s_i + ");\n";
    code += "      $2\n";
    code += "      dst_buffer[linear_index] = value;\n";
    code += "    }\n";
    code += "  }\n";
  }
  code += "  }\n";
  return code;
}

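// Generates Metal shader source for a general 2-D convolution. Uses the same
// two-rows by z_out-slices tiling as the 1x1 variant, with explicit loops
// over the filter window (params.kernel_dilation.xy). Out-of-range taps are
// clamped into bounds and zeroed via the m0/m1 multipliers rather than
// branching.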
std::string GetKernelForConvGeneric(const Convolution2DAttributes& params,
                                    int z_out) {
  std::string code;
  code.reserve(16 * 1024);  // Reserve a large enough buffer.
  std::string channels[4] = {"x", "y", "z", "w"};
  code += R"(
#include <metal_stdlib>
using namespace metal;

struct uniforms {
  int4 src_size;
  int4 dst_size;
  int4 stride_padding;
  int4 kernel_dilation;
  uint4 work_group_size;
};
$0

kernel void ComputeFunction(
                            $1
                            uint3 group_id[[threadgroup_position_in_grid]],
                            uint3 tid3d[[thread_position_in_threadgroup]])
{
  int gid_x = group_id.y * params.work_group_size.x + tid3d.x;
  int gid_y = (group_id.z * params.work_group_size.y + tid3d.y) * 2;
  )";
  code += "  int gid_z = (group_id.x * params.work_group_size.z + tid3d.z) * " +
          std::to_string(z_out) + "u;\n";
  for (int i = 0; i < z_out; ++i) {
    const std::string s_i = std::to_string(i);
    code += "  ACCUM_FLT4 r" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n";
    code += "  ACCUM_FLT4 l" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n";
  }
  code += R"(
  device FLT4* tmp = filters + gid_z * 4 * params.src_size.w * params.kernel_dilation.x * params.kernel_dilation.y;

  int y0 = gid_y * params.stride_padding.y + params.stride_padding.w;
  int y1 = (gid_y + 1) * params.stride_padding.y + params.stride_padding.w;
  int x0 = gid_x * params.stride_padding.x + params.stride_padding.z;

  int y = 0;
  do {
    int coord_y0 = y * params.kernel_dilation.w + y0;
    int coord_y1 = y * params.kernel_dilation.w + y1;
    bool y0_out = coord_y0 < 0 || coord_y0 >= params.src_size.y;
    bool y1_out = coord_y1 < 0 || coord_y1 >= params.src_size.y;
    coord_y0 = clamp(coord_y0, 0, params.src_size.y - 1);
    coord_y1 = clamp(coord_y1, 0, params.src_size.y - 1);
    int x = 0;
    do {
      int coord_x0 = x * params.kernel_dilation.z + x0;
      bool x0_out = coord_x0 < 0 || coord_x0 >= params.src_size.x;
      coord_x0 = clamp(coord_x0, 0, params.src_size.x - 1);
      FLT m0 = !(y0_out || x0_out);
      FLT m1 = !(y1_out || x0_out);
      int s = 0;
      device FLT4* src_loc_0 = src_buffer + coord_y0 * params.src_size.x + coord_x0;
      device FLT4* src_loc_1 = src_buffer + coord_y1 * params.src_size.x + coord_x0;
      do {
        FLT4 src_0 = *src_loc_0 * m0;
        FLT4 src_1 = *src_loc_1 * m1;
        src_loc_0 += params.src_size.z;
        src_loc_1 += params.src_size.z;
    )";
  for (int i = 0; i < z_out * 4; ++i) {
    const std::string s_i = std::to_string(i);
    code += "        r" + std::to_string(i / 4) + "." + channels[i % 4] +
            " += dot(tmp[" + s_i + "], src_0);\n";
    code += "        l" + std::to_string(i / 4) + "." + channels[i % 4] +
            " += dot(tmp[" + s_i + "], src_1);\n";
  }

  code += "        tmp += " + std::to_string(z_out * 4) + ";\n";
  code += R"(
        s += 1;
      } while (s < params.src_size.w);
      x++;
    } while (x < params.kernel_dilation.x);
    y++;
  } while (y < params.kernel_dilation.y);
  const int offset_0 = gid_z * params.dst_size.z + gid_y * params.dst_size.x + gid_x;
  const int offset_1 = offset_0 + params.dst_size.x;
  bool p0_in = gid_x < params.dst_size.x && gid_y < params.dst_size.y;
  bool p1_in = gid_x < params.dst_size.x && gid_y + 1 < params.dst_size.y;

  device FLT4* bias_loc = biases + gid_z;
  )";
  for (int i = 0; i < z_out; ++i) {
    const std::string s_i = std::to_string(i);
    code += "  r" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n";
    code += "  l" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n";
  }
  code += R"(
  if (gid_x >= params.dst_size.x || gid_y >= params.dst_size.y) {
      return;
  }
  )";
  for (int i = 0; i < z_out; ++i) {
    const std::string s_i = std::to_string(i);
    code += "  if (gid_z + " + s_i + "< params.dst_size.w) {\n";
    code += "    if (p0_in) {\n";
    code += "      FLT4 value = FLT4(r" + s_i + ");\n";
    code += "      int linear_index = offset_0 + params.dst_size.z * " + s_i +
            ";\n";
    code += "      uint3 gid = uint3(gid_x, gid_y, gid_z + " + s_i + ");\n";
    code += "      $2\n";
    code += "      dst_buffer[linear_index] = value;\n";
    code += "    }\n";
    code += "    if (p1_in) {\n";
    code += "      FLT4 value = FLT4(l" + s_i + ");\n";
    code += "      int linear_index = offset_1 + params.dst_size.z * " + s_i +
            ";\n";
    code += "      uint3 gid = uint3(gid_x, gid_y + 1, gid_z + " + s_i + ");\n";
    code += "      $2\n";
    code += "      dst_buffer[linear_index] = value;\n";
    code += "    }\n";
    code += "  }\n";
  }
  code += "  }\n";
  return code;
}

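// Generates Metal shader source for the "precise" convolution variant: a 1-D
// dispatch in which each thread decodes its slice group and two consecutive
// output pixels from a linear id. Unlike the generic kernel, the two pixels
// carry independent x/y coordinates, so a pair may wrap across a row
// boundary.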
std::string GetKernelForConvPrecise(int z_out) {
  std::string channels[4] = {"x", "y", "z", "w"};
  std::string code;
  code.reserve(16 * 1024);  // Reserve a large enough buffer.
  code += R"(
#include <metal_stdlib>
using namespace metal;

struct uniforms {
    int4 src_size;
    int4 dst_size;
    int4 stride_padding;
    int4 kernel_dilation;
    int4 slices;
};
$0

kernel void ComputeFunction(
                            $1
                            uint3 ugid[[thread_position_in_grid]])
{
    int linear_id = ugid.x;
    int gid_z = linear_id / params.slices.y;
    int linear_xy = (linear_id - gid_z * params.slices.y) << 1;
    )";
  code += "    gid_z *= " + std::to_string(z_out) + ";\n";
  code += R"(
    int gid_y0 = linear_xy / params.slices.x;
    int gid_x0 = linear_xy - gid_y0 * params.slices.x;
    linear_xy += 1;
    int gid_y1 = linear_xy / params.slices.x;
    int gid_x1 = linear_xy - gid_y1 * params.slices.x;

    if (gid_z >= params.dst_size.w) return;
    )";
  for (int i = 0; i < z_out; ++i) {
    const std::string s_i = std::to_string(i);
    code += "  ACCUM_FLT4 r" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n";
    code += "  ACCUM_FLT4 l" + s_i + " = ACCUM_FLT4(0.0f, 0.0f, 0.0f, 0.0f);\n";
  }
  code += R"(
    device FLT4* tmp = filters + gid_z * 4 * params.src_size.w *
                       params.kernel_dilation.x * params.kernel_dilation.y;

    int y0 = gid_y0 * params.stride_padding.y + params.stride_padding.w;
    int y1 = gid_y1 * params.stride_padding.y + params.stride_padding.w;
    int x0 = gid_x0 * params.stride_padding.x + params.stride_padding.z;
    int x1 = gid_x1 * params.stride_padding.x + params.stride_padding.z;
)";
  code += R"(
  int y = 0;
  do {
    int coord_y0 = y * params.kernel_dilation.w + y0;
    int coord_y1 = y * params.kernel_dilation.w + y1;
    bool y0_out = coord_y0 < 0 || coord_y0 >= params.src_size.y;
    bool y1_out = coord_y1 < 0 || coord_y1 >= params.src_size.y;
    coord_y0 = clamp(coord_y0, 0, params.src_size.y - 1);
    coord_y1 = clamp(coord_y1, 0, params.src_size.y - 1);
    int x = 0;
    do {
      int coord_x0 = x * params.kernel_dilation.z + x0;
      int coord_x1 = x * params.kernel_dilation.z + x1;
      bool x0_out = coord_x0 < 0 || coord_x0 >= params.src_size.x;
      bool x1_out = coord_x1 < 0 || coord_x1 >= params.src_size.x;
      coord_x0 = clamp(coord_x0, 0, params.src_size.x - 1);
      coord_x1 = clamp(coord_x1, 0, params.src_size.x - 1);
      FLT m0 = !(y0_out || x0_out);
      FLT m1 = !(y1_out || x1_out);
      device FLT4* src_loc_0 = src_buffer + coord_y0 * params.src_size.x + coord_x0;
      device FLT4* src_loc_1 = src_buffer + coord_y1 * params.src_size.x + coord_x1;
      int s = 0;
      do {
        FLT4 src_0 = *src_loc_0 * m0;
        FLT4 src_1 = *src_loc_1 * m1;
        src_loc_0 += params.src_size.z;
        src_loc_1 += params.src_size.z;
)";
  for (int i = 0; i < z_out * 4; ++i) {
    const std::string s_i = std::to_string(i);
    code += "        r" + std::to_string(i / 4) + "." + channels[i % 4] +
            " += dot(tmp[" + s_i + "], src_0);\n";
    code += "        l" + std::to_string(i / 4) + "." + channels[i % 4] +
            " += dot(tmp[" + s_i + "], src_1);\n";
  }

  code += "        tmp += " + std::to_string(z_out * 4) + ";\n";
  code += R"(
        s += 1;
      } while (s < params.src_size.w);
      x++;
    } while (x < params.kernel_dilation.x);
    y++;
  } while (y < params.kernel_dilation.y);
  const int offset_0 = gid_z * params.dst_size.z + gid_y0 * params.dst_size.x + gid_x0;
  const int offset_1 = gid_z * params.dst_size.z + gid_y1 * params.dst_size.x + gid_x1;
  bool p0_in = gid_x0 < params.dst_size.x && gid_y0 < params.dst_size.y;
  bool p1_in = gid_x1 < params.dst_size.x && gid_y1 < params.dst_size.y;

  device FLT4* bias_loc = biases + gid_z;
  )";
  for (int i = 0; i < z_out; ++i) {
    const std::string s_i = std::to_string(i);
    code += "  r" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n";
    code += "  l" + s_i + " += TO_ACCUM4_TYPE(bias_loc[" + s_i + "]);\n";
  }
  for (int i = 0; i < z_out; ++i) {
    const std::string s_i = std::to_string(i);
    code += "  if (gid_z + " + s_i + "< params.dst_size.w) {\n";
    code += "    if (p0_in) {\n";
    code += "      FLT4 value = FLT4(r" + s_i + ");\n";
    code += "      int linear_index = offset_0 + params.dst_size.z * " + s_i +
            ";\n";
    code += "      uint3 gid = uint3(gid_x0, gid_y0, gid_z + " + s_i + ");\n";
    code += "      $2\n";
    code += "      dst_buffer[linear_index] = value;\n";
    code += "    }\n";
    code += "    if (p1_in) {\n";
    code += "      FLT4 value = FLT4(l" + s_i + ");\n";
    code += "      int linear_index = offset_1 + params.dst_size.z * " + s_i +
            ";\n";
    code += "      uint3 gid = uint3(gid_x1, gid_y1, gid_z + " + s_i + ");\n";
    code += "      $2\n";
    code += "      dst_buffer[linear_index] = value;\n";
    code += "    }\n";
    code += "  }\n";
  }
  code += "  }\n";
  return code;
}

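// Generates Metal shader source for a 1x1 convolution tuned for PowerVR:
// a 1-D dispatch with one output pixel per thread and float4 accumulators
// regardless of the FLT storage precision.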
std::string GetKernelForConvPrecise1x1PowerVR(int z_out) {
  std::string channels[4] = {"x", "y", "z", "w"};
  std::string code;
  code.reserve(16 * 1024);  // Reserve a large enough buffer.
  code += R"(
#include <metal_stdlib>
using namespace metal;

struct uniforms {
    int4 src_size;
    int4 dst_size;
    int4 slices;
    int4 dummy0;
};
$0

kernel void ComputeFunction(
                            $1
                            uint3 ugid[[thread_position_in_grid]])
{
  int linear_id = ugid.x;
  int gid_z = linear_id / params.slices.y;
  int linear_xy = linear_id - gid_z * params.slices.y;
)";
  code += "    gid_z *= " + std::to_string(z_out) + ";\n";
  code += R"(
  int gid_y0 = linear_xy / params.slices.x;
  int gid_x0 = linear_xy - gid_y0 * params.slices.x;

  if (gid_z >= params.dst_size.w) return;
)";
  for (int i = 0; i < z_out; ++i) {
    const std::string s_i = std::to_string(i);
    code += "  float4 r" + s_i + " = float4(0.0f, 0.0f, 0.0f, 0.0f);\n";
  }
  code += R"(
  device FLT4* tmp = filters + gid_z * 4 * params.src_size.w;

  device FLT4* src_loc_0 = src_buffer + gid_y0 * params.src_size.x + gid_x0;
  int s = 0;
  do {
    FLT4 src_0 = *src_loc_0;
    src_loc_0 += params.src_size.z;
)";
  for (int i = 0; i < z_out * 4; ++i) {
    const std::string s_i = std::to_string(i);
    code += "    r" + std::to_string(i / 4) + "." + channels[i % 4] +
            " += dot(tmp[" + s_i + "], src_0);\n";
  }

  code += "    tmp += " + std::to_string(z_out * 4) + ";\n";
  code += R"(
    s += 1;
  } while (s < params.src_size.w);
  const int offset_0 = gid_z * params.dst_size.z + gid_y0 * params.dst_size.x + gid_x0;

  device FLT4* bias_loc = biases + gid_z;
  )";
  for (int i = 0; i < z_out; ++i) {
    const std::string s_i = std::to_string(i);
    code += "  r" + s_i + " += float4(bias_loc[" + s_i + "]);\n";
  }
  for (int i = 0; i < z_out; ++i) {
    const std::string s_i = std::to_string(i);
    code += "  if (gid_z + " + s_i + "< params.dst_size.w) {\n";
    code += "    FLT4 value = FLT4(r" + s_i + ");\n";
    code +=
        "    int linear_index = offset_0 + params.dst_size.z * " + s_i + ";\n";
    code += "    uint3 gid = uint3(gid_x0, gid_y0, gid_z + " + s_i + ");\n";
    code += "    $2\n";
    code += "    dst_buffer[linear_index] = value;\n";
    code += "  }\n";
  }
  code += "  }\n";
  return code;
}

// Reorders weights so that the memory access pattern of
// Convolution1x1/ConvolutionGeneric is cache friendly: for every filter
// position and source slice, the z_out * 4 FLT4 values consumed by one
// iteration of the generated inner loop are laid out contiguously.
std::vector<float> ReorderWeightsForConv(const Convolution2DAttributes& params,
                                         int z_out) {
  const int dst_depth = IntegralDivideRoundUp(params.weights.shape.o, 4);
  const int src_depth = IntegralDivideRoundUp(params.weights.shape.i, 4);
  std::vector<float> weights_reordered(params.weights.shape.w *
                                       params.weights.shape.h * dst_depth * 4 *
                                       src_depth * 4);
  int counter = 0;
  for (int d = 0; d < IntegralDivideRoundUp(dst_depth, z_out); ++d) {
    for (int y = 0; y < params.weights.shape.h; ++y) {
      for (int x = 0; x < params.weights.shape.w; ++x) {
        for (int s = 0; s < src_depth; ++s) {
          for (int k = 0; k < z_out; ++k) {
            for (int j = 0; j < 4; ++j) {
              for (int i = 0; i < 4; ++i) {
                int src_ch = s * 4 + i;
                int dst_ch = (d * z_out + k) * 4 + j;
                if (src_ch >= params.weights.shape.i ||
                    dst_ch >= params.weights.shape.o) {
                  weights_reordered[counter++] = 0.0f;
                } else {
                  const int f_index =
                      params.weights.shape.LinearIndex({dst_ch, y, x, src_ch});
                  weights_reordered[counter++] = params.weights.data[f_index];
                }
              }
            }
          }
        }
      }
    }
  }
  return weights_reordered;
}
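// For reference, a sketch of the index math above (not part of the original
// file): with output slices grouped by z_out, the element written at
// `counter` for loop indices (d, y, x, s, k, j, i) lands at
//
//   counter = (((((d * H + y) * W + x) * src_depth + s) * z_out + k) * 4 + j) * 4 + i
//
// so in the generated kernels tmp[k * 4 + j] holds the weights of output
// channel (gid_z + k) * 4 + j against the four input channels of source
// slice s, and `tmp += z_out * 4` advances by exactly one source slice.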

uint3 GetWorkGroupForConv() { return {8, 4, 1}; }
uint3 GetWorkGroupForConvPrecise() { return {32, 1, 1}; }

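// Packs the host-side parameters into the layout of the shader's `struct
// uniforms`: src_size = (w, h, w * h, src_slices), dst_size likewise,
// stride_padding = (stride w, stride h, -padding w, -padding h),
// kernel_dilation = (kernel w, kernel h, dilation w, dilation h), then the
// work-group size.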
std::vector<uint8_t> GetUniformBufferForConv(
    const BHWC& src_size, const BHWC& dst_size,
    const Convolution2DAttributes& params) {
  const uint3 group_size = GetWorkGroupForConv();
  std::vector<int> uniform_params = {
      src_size.w,
      src_size.h,
      src_size.w * src_size.h,
      IntegralDivideRoundUp(src_size.c, 4),
      dst_size.w,
      dst_size.h,
      dst_size.w * dst_size.h,
      IntegralDivideRoundUp(dst_size.c, 4),
      params.strides.w,
      params.strides.h,
      -params.padding.prepended.w,
      -params.padding.prepended.h,
      params.weights.shape.w,
      params.weights.shape.h,
      params.dilations.w,
      params.dilations.h,
      group_size.x,
      group_size.y,
      group_size.z,
      1u,  // dummy, for alignment
  };
  return VectorToUint8Vector(uniform_params);
}

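// Same layout as above, except the trailing work-group field becomes
// `slices`: slices.x is the destination width and slices.y is
// ceil(w * h / 2), matching the two-pixels-per-thread decode in
// GetKernelForConvPrecise.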
std::vector<uint8_t> GetUniformBufferForConvPrecise(
    const BHWC& src_size, const BHWC& dst_size,
    const Convolution2DAttributes& params) {
  std::vector<int> uniform_params = {
      src_size.w,
      src_size.h,
      src_size.w * src_size.h,
      IntegralDivideRoundUp(src_size.c, 4),
      dst_size.w,
      dst_size.h,
      dst_size.w * dst_size.h,
      IntegralDivideRoundUp(dst_size.c, 4),
      params.strides.w,
      params.strides.h,
      -params.padding.prepended.w,
      -params.padding.prepended.h,
      params.weights.shape.w,
      params.weights.shape.h,
      params.dilations.w,
      params.dilations.h,
      dst_size.w,
      IntegralDivideRoundUp(dst_size.w * dst_size.h, 2),
      0u,  // dummy, for alignment
      0u,  // dummy, for alignment
  };
  return VectorToUint8Vector(uniform_params);
}

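// For the precise 1x1 kernel the struct shrinks to src_size, dst_size,
// slices and dummy0; slices.y is simply w * h (the divide-by-1 keeps the
// expression parallel with the two-pixel variant above), padded to 16 ints.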
std::vector<uint8_t> GetUniformBufferForConvPrecise1x1(
    const BHWC& src_size, const BHWC& dst_size,
    const Convolution2DAttributes& params) {
  std::vector<int> uniform_params = {
      src_size.w,
      src_size.h,
      src_size.w * src_size.h,
      IntegralDivideRoundUp(src_size.c, 4),
      dst_size.w,
      dst_size.h,
      dst_size.w * dst_size.h,
      IntegralDivideRoundUp(dst_size.c, 4),
      dst_size.w,
      IntegralDivideRoundUp(dst_size.w * dst_size.h, 1),
      0u,  // dummy, for alignment
      0u,  // dummy, for alignment
      0u,  // dummy, for alignment
      0u,  // dummy, for alignment
      0u,  // dummy, for alignment
      0u,  // dummy, for alignment
  };
  return VectorToUint8Vector(uniform_params);
}

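// Thread-group counts for the tiled convolution kernels: x covers the output
// width, y covers row pairs (two rows per thread), z covers groups of z_out
// slices. Callers dispatch this as {z, x, y}, so the kernels read gid_x from
// group_id.y, gid_y from group_id.z and gid_z from group_id.x.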
uint3 GetGroupsCountForConv(const uint3& group_size, const BHWC& dst_shape) {
  const int dst_depth = IntegralDivideRoundUp(dst_shape.c, 4);
  int groups_x = IntegralDivideRoundUp(dst_shape.w, group_size.x);
  int groups_y = IntegralDivideRoundUp(IntegralDivideRoundUp(dst_shape.h, 2),
                                       group_size.y);
  const int z_out = GetNumOutputSlices(dst_shape.c);
  int groups_z = IntegralDivideRoundUp(IntegralDivideRoundUp(dst_depth, z_out),
                                       group_size.z);
  return {groups_x, groups_y, groups_z};
}

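// 1-D dispatch size for the precise kernels: one thread per xy_pixels output
// pixels per slice group, flattened into the x dimension.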
uint3 GetGroupsCountForConvPrecise(const uint3& group_size,
                                   const BHWC& dst_shape, int xy_pixels) {
  const int z_out = GetNumOutputSlices(dst_shape.c);
  const int dst_depth = IntegralDivideRoundUp(dst_shape.c, 4);
  int xy_size = IntegralDivideRoundUp(dst_shape.w * dst_shape.h, xy_pixels);
  int z_size = IntegralDivideRoundUp(dst_depth, z_out);
  int task_size = xy_size * z_size;
  return {IntegralDivideRoundUp(task_size, group_size.x), 1, 1};
}

int GetConvolutionThreadsCount(const BHWC& dst_shape) {
  const uint3 group_size = GetWorkGroupForConv();
  const uint3 groups_count = GetGroupsCountForConv(group_size, dst_shape);
  return groups_count.x * groups_count.y * groups_count.z * group_size.x *
         group_size.y * group_size.z;
}

int GetConvolutionPreciseThreadsCount(const BHWC& dst_shape, int xy_pixels) {
  const uint3 group_size = GetWorkGroupForConvPrecise();
  const uint3 groups_count =
      GetGroupsCountForConvPrecise(group_size, dst_shape, xy_pixels);
  return groups_count.x * groups_count.y * groups_count.z * group_size.x *
         group_size.y * group_size.z;
}

bool IsConv1x1(const Convolution2DAttributes& attr) {
  return attr.weights.shape.h == 1 && attr.weights.shape.w == 1 &&
         attr.strides.h == 1 && attr.strides.w == 1 && attr.dilations.h == 1 &&
         attr.dilations.w == 1 && attr.padding.prepended.h == 0 &&
         attr.padding.prepended.w == 0 && attr.padding.appended.h == 0 &&
         attr.padding.appended.w == 0;
}

}  // namespace

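// Note: GetNumOutputSlices, GetKernelForConv, ReorderWeightsForConvShared and
// GetUniformBufferForConvShared, used below, are defined earlier in conv.cc,
// above the range quoted here.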
std::vector<ComputeTaskDescriptorPtr> Convolution(
    int id, ValueId input_id, ValueId output_id,
    const Convolution2DAttributes& params, const RuntimeOptions& options) {
  auto desc = std::make_shared<ComputeTaskDescriptor>();
  desc->id = id;
  desc->is_linkable = false;
  desc->shader_source = GetKernelForConv(params);

  desc->input_buffers = {
      {input_id, "device FLT4* const src_buffer"},
  };

  desc->output_buffer = {
      output_id, "device FLT4* dst_buffer",
      [input_id, params](const std::map<ValueId, BHWC>& buffers) {
        return CalculateOutputShape(buffers.find(input_id)->second, params);
      }};

  auto weights_reordered = ReorderWeightsForConvShared(params);
  auto weights = options.storage_precision == RuntimeOptions::Precision::FP32
                     ? VectorToUint8Vector(weights_reordered)
                     : VectorFloatToHalf(weights_reordered);
  auto biases = options.storage_precision == RuntimeOptions::Precision::FP32
                    ? VectorToUint8Vector(params.bias.data)
                    : VectorFloatToHalf(params.bias.data);
  desc->immutable_buffers = {
      {"device FLT4* const weights", weights},
      {"device FLT4* const biases", biases},
  };

  desc->uniform_buffers = {
      {"constant uniforms& params",
       [input_id, output_id, params](const std::map<ValueId, BHWC>& buffers) {
         const auto& input_dimensions = buffers.find(input_id)->second;
         const auto& output_dimensions = buffers.find(output_id)->second;
         return GetUniformBufferForConvShared(input_dimensions,
                                              output_dimensions, params);
       }},
  };

  desc->resize_function = [output_id,
                           params](const std::map<ValueId, BHWC>& buffers) {
    const auto& output_dims = buffers.find(output_id)->second;
    const int num_output_slices = GetNumOutputSlices(params.weights.shape.o);
    const uint3 group_size{8, 4, 1};
    int groups_x = IntegralDivideRoundUp(output_dims.w, group_size.x);
    int groups_y = IntegralDivideRoundUp(output_dims.h, group_size.y);
    const int dst_depth = IntegralDivideRoundUp(params.weights.shape.o, 4);
    int groups_z = IntegralDivideRoundUp(dst_depth, num_output_slices);
    return std::make_pair(group_size, uint3{groups_x, groups_y, groups_z});
  };

  return {desc};
}

std::vector<ComputeTaskDescriptorPtr> Convolution1x1(
    int id, ValueId input_id, ValueId output_id,
    const Convolution2DAttributes& params,
    const metal::RuntimeOptions& options) {
  auto desc = std::make_shared<ComputeTaskDescriptor>();
  desc->id = id;
  desc->is_linkable = false;
  const int z_out = GetNumOutputSlices(params.weights.shape.o);
  desc->shader_source = GetKernelForConv1x1(params, z_out);

  desc->input_buffers = {
      {input_id, "device FLT4* const src_buffer"},
  };

  desc->output_buffer = {
      output_id, "device FLT4* dst_buffer",
      [input_id, params](const std::map<ValueId, BHWC>& buffers) {
        auto out_shape =
            CalculateOutputShape(buffers.find(input_id)->second, params);
        return out_shape;
      }};

  auto weights_reordered = ReorderWeightsForConv(params, z_out);
  auto weights =
      options.storage_precision == metal::RuntimeOptions::Precision::FP32
          ? VectorToUint8Vector(weights_reordered)
          : VectorFloatToHalf(weights_reordered);
  auto biases =
      options.storage_precision == metal::RuntimeOptions::Precision::FP32
          ? VectorToUint8Vector(params.bias.data)
          : VectorFloatToHalf(params.bias.data);
  desc->immutable_buffers = {
      {"device FLT4* const filters", weights},
      {"device FLT4* const biases", biases},
  };

  desc->uniform_buffers = {
      {"constant uniforms& params",
       [input_id, output_id, params](const std::map<ValueId, BHWC>& buffers) {
         const auto& input_dimensions = buffers.find(input_id)->second;
         const auto& output_dimensions = buffers.find(output_id)->second;
         return GetUniformBufferForConv(input_dimensions, output_dimensions,
                                        params);
       }},
  };

  desc->resize_function = [output_id,
                           params](const std::map<ValueId, BHWC>& buffers) {
    const auto& output_dims = buffers.find(output_id)->second;
    const uint3 group_size = GetWorkGroupForConv();
    const uint3 groups_count = GetGroupsCountForConv(group_size, output_dims);
    return std::make_pair(
        group_size, uint3{groups_count.z, groups_count.x, groups_count.y});
  };

  return {desc};
}

bool CheckConvolution1x1Support(const Convolution2DAttributes& attr) {
  return IsConv1x1(attr);
}

std::vector<ComputeTaskDescriptorPtr> ConvolutionGeneric(
    int id, ValueId input_id, ValueId output_id,
    const Convolution2DAttributes& params,
    const metal::RuntimeOptions& options) {
  auto desc = std::make_shared<ComputeTaskDescriptor>();
  desc->id = id;
  desc->is_linkable = false;
  const int z_out = GetNumOutputSlices(params.weights.shape.o);
  desc->shader_source = GetKernelForConvGeneric(params, z_out);

  desc->input_buffers = {
      {input_id, "device FLT4* const src_buffer"},
  };

  desc->output_buffer = {
      output_id, "device FLT4* dst_buffer",
      [input_id, params](const std::map<ValueId, BHWC>& buffers) {
        auto out_shape =
            CalculateOutputShape(buffers.find(input_id)->second, params);
        return out_shape;
      }};

  auto weights_reordered = ReorderWeightsForConv(params, z_out);
  auto weights =
      options.storage_precision == metal::RuntimeOptions::Precision::FP32
          ? VectorToUint8Vector(weights_reordered)
          : VectorFloatToHalf(weights_reordered);
  auto biases =
      options.storage_precision == metal::RuntimeOptions::Precision::FP32
          ? VectorToUint8Vector(params.bias.data)
          : VectorFloatToHalf(params.bias.data);
  desc->immutable_buffers = {
      {"device FLT4* const filters", weights},
      {"device FLT4* const biases", biases},
  };

  desc->uniform_buffers = {
      {"constant uniforms& params",
       [input_id, output_id, params](const std::map<ValueId, BHWC>& buffers) {
         const auto& input_dimensions = buffers.find(input_id)->second;
         const auto& output_dimensions = buffers.find(output_id)->second;
         return GetUniformBufferForConv(input_dimensions, output_dimensions,
                                        params);
       }},
  };

  desc->resize_function = [output_id,
                           params](const std::map<ValueId, BHWC>& buffers) {
    const auto& output_dims = buffers.find(output_id)->second;
    const uint3 group_size = GetWorkGroupForConv();
    const uint3 groups_count = GetGroupsCountForConv(group_size, output_dims);
    return std::make_pair(
        group_size, uint3{groups_count.z, groups_count.x, groups_count.y});
  };

  return {desc};
}

std::vector<ComputeTaskDescriptorPtr> ConvolutionPrecise(
    int id, ValueId input_id, ValueId output_id,
    const Convolution2DAttributes& params,
    const metal::RuntimeOptions& options) {
  auto desc = std::make_shared<ComputeTaskDescriptor>();
  desc->id = id;
  desc->is_linkable = false;
  const int z_out = GetNumOutputSlices(params.weights.shape.o);
  desc->shader_source = GetKernelForConvPrecise(z_out);

  desc->input_buffers = {
      {input_id, "device FLT4* const src_buffer"},
  };

  desc->output_buffer = {
      output_id, "device FLT4* dst_buffer",
      [input_id, params](const std::map<ValueId, BHWC>& buffers) {
        auto out_shape =
            CalculateOutputShape(buffers.find(input_id)->second, params);
        return out_shape;
      }};

  auto weights_reordered = ReorderWeightsForConv(params, z_out);
  auto weights =
      options.storage_precision == metal::RuntimeOptions::Precision::FP32
          ? VectorToUint8Vector(weights_reordered)
          : VectorFloatToHalf(weights_reordered);
  auto biases =
      options.storage_precision == metal::RuntimeOptions::Precision::FP32
          ? VectorToUint8Vector(params.bias.data)
          : VectorFloatToHalf(params.bias.data);
  desc->immutable_buffers = {
      {"device FLT4* const filters", weights},
      {"device FLT4* const biases", biases},
  };

  desc->uniform_buffers = {
      {"constant uniforms& params",
       [input_id, output_id, params](const std::map<ValueId, BHWC>& buffers) {
         const auto& input_dimensions = buffers.find(input_id)->second;
         const auto& output_dimensions = buffers.find(output_id)->second;
         return GetUniformBufferForConvPrecise(input_dimensions,
                                               output_dimensions, params);
       }},
  };

  desc->resize_function = [output_id,
                           params](const std::map<ValueId, BHWC>& buffers) {
    const auto& output_dims = buffers.find(output_id)->second;
    const uint3 group_size = GetWorkGroupForConvPrecise();
    const uint3 groups_count =
        GetGroupsCountForConvPrecise(group_size, output_dims, 2);
    return std::make_pair(group_size, groups_count);
  };

  return {desc};
}

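// Heuristic used for kernel selection: a ratio above 1 means the usual tiled
// dispatch would launch more threads than the precise 1-D dispatch for this
// output shape, so the precise variant wastes less work.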
float GetThreadsRatioUsualToPreciseConvolution(const BHWC& dst_shape) {
  return static_cast<float>(GetConvolutionThreadsCount(dst_shape)) /
         static_cast<float>(GetConvolutionPreciseThreadsCount(dst_shape, 2));
}

std::vector<ComputeTaskDescriptorPtr> ConvolutionPrecise1x1PowerVR(
    int id, ValueId input_id, ValueId output_id,
    const Convolution2DAttributes& params, const RuntimeOptions& options) {
  auto desc = std::make_shared<ComputeTaskDescriptor>();
  desc->id = id;
  desc->is_linkable = false;
  const int z_out = GetNumOutputSlices(params.weights.shape.o);
  desc->shader_source = GetKernelForConvPrecise1x1PowerVR(z_out);

  desc->input_buffers = {
      {input_id, "device FLT4* const src_buffer"},
  };

  desc->output_buffer = {
      output_id, "device FLT4* dst_buffer",
      [input_id, params](const std::map<ValueId, BHWC>& buffers) {
        auto out_shape =
            CalculateOutputShape(buffers.find(input_id)->second, params);
        return out_shape;
      }};

  auto weights_reordered = ReorderWeightsForConv(params, z_out);
  auto weights =
      options.storage_precision == metal::RuntimeOptions::Precision::FP32
          ? VectorToUint8Vector(weights_reordered)
          : VectorFloatToHalf(weights_reordered);
  auto biases =
      options.storage_precision == metal::RuntimeOptions::Precision::FP32
          ? VectorToUint8Vector(params.bias.data)
          : VectorFloatToHalf(params.bias.data);
  desc->immutable_buffers = {
      {"device FLT4* const filters", weights},
      {"device FLT4* const biases", biases},
  };

  desc->uniform_buffers = {
      {"constant uniforms& params",
       [input_id, output_id, params](const std::map<ValueId, BHWC>& buffers) {
         const auto& input_dimensions = buffers.find(input_id)->second;
         const auto& output_dimensions = buffers.find(output_id)->second;
         return GetUniformBufferForConvPrecise1x1(input_dimensions,
                                                  output_dimensions, params);
       }},
  };

  desc->resize_function = [output_id,
                           params](const std::map<ValueId, BHWC>& buffers) {
    const auto& output_dims = buffers.find(output_id)->second;
    const uint3 group_size = GetWorkGroupForConvPrecise();
    const uint3 groups_count =
        GetGroupsCountForConvPrecise(group_size, output_dims, 1);
    return std::make_pair(group_size, groups_count);
  };

  return {desc};
}

bool CheckConvolutionPrecise1x1Support(const Convolution2DAttributes& attr) {
  return IsConv1x1(attr);
}
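
// A hedged sketch of how a client might choose between these variants; the
// selection rule and threshold here are hypothetical and not part of this
// file:
//
//   std::vector<ComputeTaskDescriptorPtr> SelectConvolution(
//       int id, ValueId input_id, ValueId output_id,
//       const Convolution2DAttributes& attr, const BHWC& dst_shape,
//       const RuntimeOptions& options) {
//     if (CheckConvolutionPrecise1x1Support(attr) &&
//         GetThreadsRatioUsualToPreciseConvolution(dst_shape) > 1.0f) {
//       return ConvolutionPrecise1x1PowerVR(id, input_id, output_id, attr,
//                                           options);
//     }
//     if (CheckConvolution1x1Support(attr)) {
//       return Convolution1x1(id, input_id, output_id, attr, options);
//     }
//     return ConvolutionGeneric(id, input_id, output_id, attr, options);
//   }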

}  // namespace metal
}  // namespace gpu
}  // namespace tflite