in src/runtime/hexagon/ops/conv2d_fp16_hvx.cc [181:401]
/*!
 * \brief Direct fp16 2D convolution on blocked activations/weights using HVX.
 *
 * Each output vector (up to 32 output channels x 2 adjacent output columns = 64 fp16 lanes) is
 * produced by walking every filter tap (fh, fw) and every input channel pair, accumulating in a
 * register that starts at zero, and finally storing the register. Any previous contents of the
 * written output vectors are therefore overwritten, not accumulated into.
 *
 * \param cr_out      Blocked output tensor, addressed through getElementPtr. shape[1], shape[2] and
 *                    shape[3] are used as block counts along height, width and depth (the last
 *                    block row/column receives the out_height % 8 / out_width % 4 remainder).
 * \param cr_act      Blocked input activation tensor, addressed through getElementPtr; shape[3]
 *                    must equal cr_filt.shape[2].
 * \param cr_filt     Chunked weight tensor, addressed through conv_utils::hwio_at; shape[3] must
 *                    equal cr_out.shape[3] and is the number of output-channel blocks iterated.
 * \param out_shape   Logical output shape: shape[1] is the height, shape[2] the width.
 * \param act_shape   Not used by this function.
 * \param bias_flat   Only shape[0] is read (for logging).
 * \param filt_shape  Logical filter shape: shape[0..2] are height, width and input depth.
 * \param pad_shape   shape[0] / shape[1] are the top / left padding; checked to be < 8 / < 4.
 * \param relu        Not used by this function.
 * \param stride_h    Vertical convolution stride.
 * \param stride_w    Horizontal convolution stride.
 * \param zero_block  Not used by this function.
 *
 * NOTE(review): the bias is not added and relu is not applied here. Presumably the caller or a
 * separate kernel handles them; confirm.
 * NOTE(review): pad_top/pad_left are validated but never used for indexing. This assumes the
 * blocked activations already contain the padding; confirm against the caller.
 */
void conv_layer_fp16_hvx(DLTensor& cr_out, const DLTensor& cr_act,  // NOLINT(*)
                         const DLTensor& cr_filt, const DLTensor& out_shape,
                         const DLTensor& act_shape, const DLTensor& bias_flat,
                         const DLTensor& filt_shape, const DLTensor& pad_shape, bool relu,
                         int stride_h, int stride_w, uintptr_t zero_block) {
  int64_t filt_height = filt_shape.shape[0];
  int64_t filt_width = filt_shape.shape[1];
  int64_t filt_idepth = filt_shape.shape[2];

  int pad_top = pad_shape.shape[0];
  int pad_left = pad_shape.shape[1];
  LOG_INFO << "filt_height=" << filt_height << ", filt_width=" << filt_width
           << ", filt_idepth=" << filt_idepth << ", pad_top=" << pad_top
           << ", pad_left=" << pad_left << "\n";

  ICHECK_LT(pad_top, 8) << "pad_top offset cannot be >= 8";
  ICHECK_LT(pad_left, 4) << "pad_left offset cannot be >= 4";

  int a_height = cr_act.shape[1];
  int a_width = cr_act.shape[2];
  int a_depth = cr_act.shape[3];

  int w_height = cr_filt.shape[0];
  int w_width = cr_filt.shape[1];

  int o_depth = cr_out.shape[3];
  int b_depth = bias_flat.shape[0];

  int o_height = cr_out.shape[1];
  int o_width = cr_out.shape[2];

  int out_height = out_shape.shape[1];
  int out_width = out_shape.shape[2];

  LOG_INFO << "a: 1x" << a_height << "x" << a_width << "x" << a_depth << ", w: " << w_height << "x"
           << w_width << "x" << static_cast<int>(cr_filt.shape[2]) << "x"
           << static_cast<int>(cr_filt.shape[3]) << ", o: 1x" << o_height << "x" << o_width << "x"
           << o_depth << ", b: " << b_depth << ", out_shape: " << out_height << "x" << out_width
           << "\n";

  ICHECK_EQ(a_depth, cr_filt.shape[2]) << "input depth should match weights input channels";
  ICHECK_EQ(o_depth, cr_filt.shape[3]) << "output depth should match the weights output channel";

  // The filter width is split into one leading "thin" chunk of (filt_width % 4) columns followed
  // by chunks of 4 columns each.
  // NOTE(review): when filt_width is a multiple of 4 the thin width is 0 and the chunk index
  // below starts at 1. Confirm that the weight layout reserves chunk 0 in that case.
  int rd = round_down(filt_width, 4);
  int wgt_chunk_thin_width = filt_width - rd;

  // Input channels are consumed two at a time. The bound is loop-invariant, so compute it once
  // instead of on every channel iteration.
  const auto in_channel_end = conv_utils::round_up(filt_idepth, 2);

  /*
   * Compute the output vector of either 1 or 2 elements along the width and max 32 elements along
   * the depth to constitute a maximum of 64 elements
   *
   * The weights are loaded directly in the order they're stored, which results
   * in 2 input channels and 32 output channels
   *
   * Weights vector illustration:
   * ------- ------ ------------
   * weights_vec = [0-0,0-1,1-0,1-1,2-0,2-1,3-0,3-1,4-0,4-1,...,31-0,31-1] -> This is the
   * vector representation of weights, where the elements are represented as
   * "out_channel-input_channel"
   *
   *
   * Same 2 input channels have to be multiplied across all output channels in the weights.
   *
   * Activations vector would thus be:
   * ----------- ------ ----- ---- --
   * act_vec = [i0,i1,i0,i1,i0,i1,...,i0,i1] - 2 elements of the input channels broadcasted 32 times
   * to fill 64 elements of the vector
   *
   *
   * Thus the computation is just a vmpy(act_vec,weights_vec) followed by some rearrangement to
   * add every pair of 16b lanes in the vector to reduce along the input channels
   *
   * This result is added to the result of the next pair of input channels all the way until we
   * have reduced across the entire input channels.
   *
   * Then the same vector is added to the results of the following elements along the width and
   * height to finally get 32 elements representing 32 output channels.
   *
   * Since the output block also has the 8h2w32c2w format, the 32 elements of the next element
   * along the width is also added into the same vector such that the first 32 channel elements
   * occupy the even lanes and the next 32 occupy the odd lanes to form a single 64-element vector
   * which is then stored
   */
  // Computes one output vector: output block (out_act_y, out_act_x, out_c), row h inside the
  // block, and column pair wo. When skip_wi_1 is true, only the even column (wi = 0) is computed
  // and the odd lanes receive no contribution (used for an odd trailing output width).
  auto computeConv = [filt_height, filt_width, wgt_chunk_thin_width, in_channel_end, stride_h,
                      stride_w, &cr_out, &cr_act, &cr_filt](int out_act_y, int out_act_x, int out_c,
                                                            int h, int wo, bool skip_wi_1 = false) {
    auto out_element_ptr = getElementPtr(out_act_y, out_act_x, out_c, h, wo, 0, 0, cr_out);

    LOG_INFO << "out_act_y: " << out_act_y << ", out_act_x: " << out_act_x << ", out_c: " << out_c
             << ", h: " << h << ", wo: " << wo << " out_element_ptr: " << out_element_ptr;

    HVX_Vector* out_vector = reinterpret_cast<HVX_Vector*>(out_element_ptr);
    HVX_Vector existing_out_vec = Q6_V_vzero();

    // Logical output coordinates of this vector; they are the same for every filter tap.
    const int out_width_base = out_act_x * 4 + wo * 2;
    const int out_height_idx = out_act_y * 8 + h;

    for (int fh = 0; fh < filt_height; ++fh) {
      // Row-dependent coordinates: fixed for every fw / channel iteration below.
      const int act_height_access_idx = out_height_idx * stride_h + fh;
      const int true_out_act_y = act_height_access_idx / 8;
      const int true_h = act_height_access_idx % 8;
      const int fch = fh / 8;  // weight chunk along the filter height (8 rows per chunk)
      const int fy = fh % 8;   // row inside that weight chunk

      for (int fw = 0; fw < filt_width; ++fw) {
        // Column fw lives either in the leading thin chunk (index 0) or in a 4-wide chunk.
        const bool in_thin_chunk = fw < wgt_chunk_thin_width;
        const int fcw = in_thin_chunk ? 0 : (fw - wgt_chunk_thin_width) / 4 + 1;
        const int fx = in_thin_chunk ? fw : (fw - wgt_chunk_thin_width) % 4;
        const int max_x = (fcw == 0) ? wgt_chunk_thin_width : 4;

        // Blocked activation coordinates read for the even (wi = 0) and odd (wi = 1) output
        // column of this vector. They depend on fw only, not on the input channel.
        int out_width_idx[2];
        int act_width_access_idx[2];
        int true_out_act_x[2];
        int true_wo[2];
        int true_wi[2];
        for (int wi = 0; wi < 2; ++wi) {
          out_width_idx[wi] = out_width_base + wi;
          act_width_access_idx[wi] = out_width_idx[wi] * stride_w + fw;
          true_out_act_x[wi] = act_width_access_idx[wi] / 4;
          true_wo[wi] = (act_width_access_idx[wi] % 4) / 2;
          true_wi[wi] = act_width_access_idx[wi] % 2;
        }

        for (int c = 0; c < in_channel_end; c += 2) {
          const int out_act_cc = c / 32;  // input-channel block
          const int ci = c % 32;          // channel inside that block
          auto wgt_chunk = conv_utils::hwio_at(cr_filt, fch, fcw, out_act_cc, out_c);
          const int act_channel_idx = out_act_cc * 32 + ci;

          // Even output column (wi = 0).
          auto act_element_ptr = getElementPtr(true_out_act_y, true_out_act_x[0], out_act_cc,
                                               true_h, true_wo[0], ci, true_wi[0], cr_act);
          HVX_Vector act_vec = getInputVector(act_element_ptr);

          // Offset of the (fy, fx, ci) weight pair inside its chunk.
          auto wgt_chunk_offset = conv_utils::hwio_to_sm_16b(max_x, fy, fx, ci, 0);
          auto chunk_ptr = reinterpret_cast<uint16_t*>(wgt_chunk) + wgt_chunk_offset;

          LOG_INFO << "act: 0x" << act_height_access_idx << "x" << act_width_access_idx[0] << "x"
                   << act_channel_idx << ", wgt: " << fh << "x" << fw << "x" << act_channel_idx
                   << "x" << out_c * 32 << ", out: 0x" << out_height_idx << "x" << out_width_idx[0]
                   << "x" << out_c * 32 << ", wgt_chunk_offset: " << wgt_chunk_offset;

          // The same weights vector is shared by both output columns.
          HVX_Vector weights_vec = *reinterpret_cast<const HVX_Vector*>(chunk_ptr);
          HVX_Vector reduced_vec = computeOuputVector(act_vec, weights_vec);

          if (!skip_wi_1) {
            // Odd output column (wi = 1): its results are rotated into the odd 16-bit lanes and
            // merged with the even-column results.
            act_element_ptr = getElementPtr(true_out_act_y, true_out_act_x[1], out_act_cc, true_h,
                                            true_wo[1], ci, true_wi[1], cr_act);
            act_vec = getInputVector(act_element_ptr);

            LOG_INFO << "act: 0x" << act_height_access_idx << "x" << act_width_access_idx[1] << "x"
                     << act_channel_idx << ", wgt: " << fh << "x" << fw << "x" << act_channel_idx
                     << "x" << out_c * 32 << ", out: 0x" << out_height_idx << "x"
                     << out_width_idx[1] << "x" << out_c * 32
                     << ", wgt_chunk_offset: " << wgt_chunk_offset;

            HVX_Vector reduced_vec_odd_elements = computeOuputVector(act_vec, weights_vec);
            reduced_vec_odd_elements = Q6_V_vror_VR(reduced_vec_odd_elements, -2);
            reduced_vec = Q6_V_vor_VV(reduced_vec, reduced_vec_odd_elements);
          }

          // Accumulate in qf16, then convert back to hf.
          // NOTE(review): the partial sum is rounded to hf on every step; accumulating in qf16
          // would keep more precision. Left unchanged to preserve current numerics.
          HVX_Vector out_vec_qf16 = Q6_Vqf16_vadd_VhfVhf(reduced_vec, existing_out_vec);
          existing_out_vec = Q6_Vhf_equals_Vqf16(out_vec_qf16);
        }
      }
    }
    *out_vector = existing_out_vec;
  };

  // Both column pairs (wo = 0, 1) of a completely filled output block at row h.
  auto computeFullWidth = [&computeConv](int out_y, int out_x, int out_c, int h) {
    for (int wo = 0; wo < 2; ++wo) {
      computeConv(out_y, out_x, out_c, h, wo);
    }
  };

  // The out_width % 4 trailing columns, which live in the last output block column.
  auto computePartialWidth = [out_width, o_width, &computeConv](int out_y, int out_c, int h) {
    const int out_x = o_width - 1;
    int wo = 0;
    // Whole column pairs first ...
    for (; wo < (out_width % 4) / 2; ++wo) {
      computeConv(out_y, out_x, out_c, h, wo);
    }
    // ... then a single trailing column, if the width is odd.
    if (out_width % 2) {
      computeConv(out_y, out_x, out_c, h, wo, true /* skip_wi_1 */);
    }
  };

  for (int out_c = 0; out_c < cr_filt.shape[3]; ++out_c) {
    // Output block rows that are completely filled (8 rows each).
    for (int out_act_y = 0; out_act_y < out_height / 8; ++out_act_y) {
      for (int out_act_x = 0; out_act_x < out_width / 4; ++out_act_x) {
        for (int h = 0; h < 8; ++h) {
          computeFullWidth(out_act_y, out_act_x, out_c, h);
        }
      }
      for (int h = 0; h < 8; ++h) {
        computePartialWidth(out_act_y, out_c, h);
      }
    }

    // The remaining out_height % 8 rows live in the last output block row.
    const int last_out_y = o_height - 1;
    for (int h = 0; h < out_height % 8; ++h) {
      for (int out_act_x = 0; out_act_x < out_width / 4; ++out_act_x) {
        computeFullWidth(last_out_y, out_act_x, out_c, h);
      }
      computePartialWidth(last_out_y, out_c, h);
    }
  }
}