in tfjs-backend-wasm/src/cc/conv2d_impl.cc [114:288]
void conv2d(const size_t x_id, const size_t batch_size,
const size_t input_height, const size_t input_width,
const size_t filter_id, const size_t filter_height,
const size_t filter_width, const size_t bias_id, size_t pad_top,
size_t pad_right, size_t pad_bottom, size_t pad_left,
const bool is_same_pad, const size_t dilation_height,
const size_t dilation_width, const size_t stride_height,
const size_t stride_width, const size_t input_channels,
const size_t output_channels, const bool is_depthwise,
const FusableActivation activation, const size_t prelu_weights_id,
const float leakyrelu_alpha, const size_t out_id) {
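  // Look up the input, filter, and output tensors in the backend's tensor
  // registry and grab their float buffers; a bias_id of 0 means no bias.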
auto& x_info = backend::get_tensor_info(x_id);
auto& filter_info = backend::get_tensor_info(filter_id);
auto& out_info = backend::get_tensor_info_out(out_id);
const float* x_buf = x_info.f32();
const float* filter_buf = filter_info.f32();
const float* bias_buf = nullptr;
if (bias_id != 0) {
bias_buf = backend::get_tensor_info_out(bias_id).f32();
}
float* out_buf = out_info.f32_write();
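  // PRELU and LEAKYRELU are applied as a separate pass after the convolution,
  // so route the raw conv output into a scratch buffer; the activation pass
  // below writes the final result into the out_id tensor.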
std::vector<float> intermediate_output;
if (prelu_weights_id != 0 || activation == FusableActivation::LEAKYRELU) {
intermediate_output.resize(out_info.size);
out_buf = intermediate_output.data();
}
xnn_operator_t conv2d_op = nullptr;
size_t flags = 0;
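  // With TensorFlow-style SAME padding, XNNPACK computes the padding itself,
  // so zero the explicit pad amounts and set the flag instead.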
if (is_same_pad) {
    pad_top = pad_right = pad_bottom = pad_left = 0;
flags |= XNN_FLAG_TENSORFLOW_SAME_PADDING;
}
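  // A depthwise conv is expressed as a grouped convolution with one group per
  // input channel; output_channels / input_channels is the channel multiplier.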
size_t groups;
size_t group_input_channels;
size_t group_output_channels;
const size_t input_pixel_stride = input_channels;
const size_t output_pixel_stride = output_channels;
if (is_depthwise) {
groups = input_channels;
group_input_channels = 1;
group_output_channels = output_channels / input_channels;
flags |= XNN_FLAG_DEPTHWISE_CONVOLUTION;
} else {
groups = 1;
group_input_channels = input_channels;
group_output_channels = output_channels;
}
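  // PRELU and LEAKYRELU cannot be expressed as an output clamp, so the cached
  // operator is created with a LINEAR (identity) clamp for them.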
FusableActivation clamp_method = activation;
if (activation == FusableActivation::PRELU ||
activation == FusableActivation::LEAKYRELU) {
clamp_method = FusableActivation::LINEAR;
}
float output_min = -std::numeric_limits<float>::infinity();
float output_max = std::numeric_limits<float>::infinity();
if (activation == FusableActivation::RELU) {
output_min = 0;
} else if (activation == FusableActivation::RELU6) {
output_min = 0;
output_max = 6;
}
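  // Operators are cached on their full configuration (including the filter and
  // bias tensor ids) so repeated calls with the same setup reuse the operator
  // and its packed weights.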
OperatorCacheKey cache_key = {pad_top,
pad_right,
pad_bottom,
pad_left,
filter_height,
filter_width,
stride_height,
stride_width,
dilation_height,
dilation_width,
groups,
group_input_channels,
group_output_channels,
input_pixel_stride,
output_pixel_stride,
clamp_method,
filter_id,
bias_id,
flags,
output_min,
output_max};
auto operator_cache_idx = operator_cache.find(cache_key);
if (operator_cache_idx == operator_cache.end()) {
    // Declared outside the is_depthwise branch so the transposed data outlives
    // it; ownership moves into the cache map below.
std::vector<float> transposed_filter;
const float* filter_xnn;
if (is_depthwise) {
      // For depthwiseConv2d, XNNPACK and TensorFlow expect the same weights
      // layout:
      // [filter_height, filter_width, input_channels, channel_multiplier]
filter_xnn = filter_buf;
} else {
      // For regular conv2d, XNNPACK expects weights laid out as:
      // [output_channels, filter_height, filter_width, input_channels]
      // TensorFlow lays weights out as:
      // [filter_height, filter_width, input_channels, output_channels]
      // Moving output_channels to the outermost dimension is a 2D transpose of
      // the flattened [HWI, O] matrix.
transposed_filter.resize(filter_info.size);
std::vector<size_t> filter_shape = {
filter_height * filter_width * input_channels, output_channels};
std::vector<size_t> perm = {1, 0};
transpose(filter_buf, filter_shape, perm, transposed_filter.data());
filter_xnn = transposed_filter.data();
}
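    // The create call packs filter_xnn and bias_buf into the operator's
    // internal layout; the transposed copy is still moved into the cache below
    // so its storage outlives the operator either way.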
xnn_status status = xnn_create_convolution2d_nhwc_f32(
pad_top, pad_right, pad_bottom, pad_left, filter_height, filter_width,
stride_height, stride_width, dilation_height, dilation_width, groups,
group_input_channels, group_output_channels, input_pixel_stride,
output_pixel_stride, filter_xnn, bias_buf, output_min, output_max,
flags, &conv2d_op);
if (status != xnn_status_success) {
util::warn(
"XNN status for xnn_create_convolution2d_nhwc_f32 is not successful. "
"Got status %d. Use -c dbg to see XNN logs.",
status);
}
operator_cache.emplace(
cache_key,
// Move ownership of the transposed filter to the cache map.
CachedInfo{conv2d_op, std::move(transposed_filter)});
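    // Track which cache entries depend on the filter (and bias) tensors, so
    // those entries can be evicted when the tensors are disposed.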
associate_tensor_with_key(filter_id, cache_key,
filter_operator_cache_key_map);
if (bias_id != 0) {
associate_tensor_with_key(bias_id, cache_key,
bias_operator_cache_key_map);
}
tfjs::backend::xnn_operator_count++;
} else {
conv2d_op = operator_cache_idx->second.op;
}
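  // Bind the concrete shapes and buffer pointers to the (new or cached)
  // operator; setup runs on every call since x_buf and out_buf can change.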
xnn_status status = xnn_setup_convolution2d_nhwc_f32(
conv2d_op, batch_size, input_height, input_width, x_buf, out_buf,
tfjs::backend::threadpool);
if (status != xnn_status_success) {
util::warn(
"XNN status for xnn_setup_convolution2d_nhwc_f32 is not successful. "
"Got status %d. Use -c dbg to see XNN logs.",
status);
}
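  // Execute the convolution on the backend's shared threadpool.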
xnn_run_operator(conv2d_op, tfjs::backend::threadpool);
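  // Unfused activations run as separate elementwise kernels; they read out_buf
  // (the scratch buffer for PRELU/LEAKYRELU) and write the out_id tensor.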
  if (activation == FusableActivation::PRELU) {
    prelu(out_buf, out_info.size, prelu_weights_id, out_id);
  } else if (activation == FusableActivation::LEAKYRELU) {
    leakyrelu(out_buf, out_info.size, leakyrelu_alpha, out_id);
  } else if (activation == FusableActivation::SIGMOID) {
    sigmoid(out_buf, out_info.size, out_id);
  } else if (activation == FusableActivation::ELU) {
    elu(out_buf, out_info.size, out_id);
  }
}