void conv2d()

in tfjs-backend-wasm/src/cc/conv2d_impl.cc [114:288]


void conv2d(const size_t x_id, const size_t batch_size,
            const size_t input_height, const size_t input_width,
            const size_t filter_id, const size_t filter_height,
            const size_t filter_width, const size_t bias_id, size_t pad_top,
            size_t pad_right, size_t pad_bottom, size_t pad_left,
            const bool is_same_pad, const size_t dilation_height,
            const size_t dilation_width, const size_t stride_height,
            const size_t stride_width, const size_t input_channels,
            const size_t output_channels, const bool is_depthwise,
            const FusableActivation activation, const size_t prelu_weights_id,
            const float leakyrelu_alpha, const size_t out_id) {
  auto& x_info = backend::get_tensor_info(x_id);
  auto& filter_info = backend::get_tensor_info(filter_id);
  auto& out_info = backend::get_tensor_info_out(out_id);

  const float* x_buf = x_info.f32();
  const float* filter_buf = filter_info.f32();
  const float* bias_buf = nullptr;
  if (bias_id != 0) {
    bias_buf = backend::get_tensor_info_out(bias_id).f32();
  }

  float* out_buf = out_info.f32_write();
  std::vector<float> intermediate_output;

  if (prelu_weights_id != 0 || activation == FusableActivation::LEAKYRELU) {
    intermediate_output.resize(out_info.size);
    out_buf = intermediate_output.data();
  }

  xnn_operator_t conv2d_op = nullptr;

  size_t flags = 0;
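  // XNN_FLAG_TENSORFLOW_SAME_PADDING makes XNNPACK compute TensorFlow-style
  // SAME padding from the input, filter, stride, and dilation sizes, so the
  // explicit pad values are cleared here.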
  if (is_same_pad) {
    pad_top = pad_right = pad_bottom = pad_left = 0;
    flags |= XNN_FLAG_TENSORFLOW_SAME_PADDING;
  }

  size_t groups;
  size_t group_input_channels;
  size_t group_output_channels;
  const size_t input_pixel_stride = input_channels;
  const size_t output_pixel_stride = output_channels;
  if (is_depthwise) {
    groups = input_channels;
    group_input_channels = 1;
    group_output_channels = output_channels / input_channels;
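    // E.g. (illustrative): with input_channels = 8 and a channel multiplier
    // of 2, output_channels = 16, so XNNPACK runs 8 groups, each mapping 1
    // input channel to 2 output channels.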
    flags |= XNN_FLAG_DEPTHWISE_CONVOLUTION;
  } else {
    groups = 1;
    group_input_channels = input_channels;
    group_output_channels = output_channels;
  }

  FusableActivation clamp_method = activation;
  if (activation == FusableActivation::PRELU ||
      activation == FusableActivation::LEAKYRELU) {
    clamp_method = FusableActivation::LINEAR;
  }

  float output_min = -std::numeric_limits<float>::infinity();
  float output_max = std::numeric_limits<float>::infinity();

  if (activation == FusableActivation::RELU) {
    output_min = 0;
  } else if (activation == FusableActivation::RELU6) {
    output_min = 0;
    output_max = 6;
  }

  OperatorCacheKey cache_key = {pad_top,
                                pad_right,
                                pad_bottom,
                                pad_left,
                                filter_height,
                                filter_width,
                                stride_height,
                                stride_width,
                                dilation_height,
                                dilation_width,
                                groups,
                                group_input_channels,
                                group_output_channels,
                                input_pixel_stride,
                                output_pixel_stride,
                                clamp_method,
                                filter_id,
                                bias_id,
                                flags,
                                output_min,
                                output_max};

  auto operator_cache_idx = operator_cache.find(cache_key);
  if (operator_cache_idx == operator_cache.end()) {
    // Declared outside the if/else below so the transposed filter data
    // outlives that scope.
    std::vector<float> transposed_filter;

    const float* filter_xnn;
    if (is_depthwise) {
      // For depthwiseConv2d, xnn pack and TensorFlow expect the same weights
      // layout:
      //   [filter_height, filter_width, input_channels, channel_multiplier]
      filter_xnn = filter_buf;
    } else {
      // For regular conv2d, xnn pack expects weights laid out like:
      //   [output_channels, filter_height, filter_width, input_channels]
      // TensorFlow has weights laid out like:
      //   [filter_height, filter_width, input_channels, output_channels]
      // output_channels can be moved to the outermost dimension with a 2d
      // transpose of the flattened weights matrix.
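      // E.g. (illustrative): a 3x3 filter with 2 input channels and 4 output
      // channels is stored by TensorFlow as [3, 3, 2, 4]. Flattened to a
      // [3 * 3 * 2, 4] = [18, 4] matrix and transposed to [4, 18], it reads
      // back as the [4, 3, 3, 2] layout xnn pack expects.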
      transposed_filter.resize(filter_info.size);
      std::vector<size_t> filter_shape = {
          filter_height * filter_width * input_channels, output_channels};
      std::vector<size_t> perm = {1, 0};

      transpose(filter_buf, filter_shape, perm, transposed_filter.data());

      filter_xnn = transposed_filter.data();
    }

    xnn_status status = xnn_create_convolution2d_nhwc_f32(
        pad_top, pad_right, pad_bottom, pad_left, filter_height, filter_width,
        stride_height, stride_width, dilation_height, dilation_width, groups,
        group_input_channels, group_output_channels, input_pixel_stride,
        output_pixel_stride, filter_xnn, bias_buf, output_min, output_max,
        flags, &conv2d_op);
    if (status != xnn_status_success) {
      util::warn(
          "XNN status for xnn_create_convolution2d_nhwc_f32 is not successful. "
          "Got status %d. Use -c dbg to see XNN logs.",
          status);
    }

    operator_cache.emplace(
        cache_key,
        // Move ownership of the transposed filter to the cache map.
        CachedInfo{conv2d_op, std::move(transposed_filter)});

    associate_tensor_with_key(filter_id, cache_key,
                              filter_operator_cache_key_map);
    if (bias_id != 0) {
      associate_tensor_with_key(bias_id, cache_key,
                                bias_operator_cache_key_map);
    }

    tfjs::backend::xnn_operator_count++;
  } else {
    conv2d_op = operator_cache_idx->second.op;
  }

  xnn_status status = xnn_setup_convolution2d_nhwc_f32(
      conv2d_op, batch_size, input_height, input_width, x_buf, out_buf,
      tfjs::backend::threadpool);
  if (status != xnn_status_success) {
    util::warn(
        "XNN status for xnn_setup_convolution2d_nhwc_f32 is not successful. "
        "Got status %d. Use -c dbg to see XNN logs.",
        status);
  }

  xnn_run_operator(conv2d_op, tfjs::backend::threadpool);

  if (activation == FusableActivation::PRELU) {
    prelu(out_buf, out_info.size, prelu_weights_id, out_id);
  }
  if (activation == FusableActivation::LEAKYRELU) {
    leakyrelu(out_buf, out_info.size, leakyrelu_alpha, out_id);
  }
  if (activation == FusableActivation::SIGMOID) {
    sigmoid(out_buf, out_info.size, out_id);
  }
  if (activation == FusableActivation::ELU) {
    elu(out_buf, out_info.size, out_id);
  }
}
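
A minimal, illustrative sketch of the filter re-layout step above. The wasm
backend's actual transpose helper is a general n-d transpose; transpose_2d
below is a hypothetical stand-in that only handles the perm = {1, 0} case
conv2d needs:

#include <cstddef>

// Treats `in` as a row-major [rows, cols] matrix and writes its transpose,
// a row-major [cols, rows] matrix, into `out`.
void transpose_2d(const float* in, size_t rows, size_t cols, float* out) {
  for (size_t r = 0; r < rows; ++r) {
    for (size_t c = 0; c < cols; ++c) {
      out[c * rows + r] = in[r * cols + c];
    }
  }
}

// Usage mirroring conv2d(), turning an HWIO filter into the OHWI order
// xnn pack expects:
//   rows = filter_height * filter_width * input_channels
//   cols = output_channels
//   std::vector<float> transposed(filter_size);
//   transpose_2d(filter_buf, rows, cols, transposed.data());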