in src/PackWeightMatrixForGConv.cc [159:242]
void PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>::pack_unpack_(
    const T* src,
    T* dst,
    bool ispack) {
  // Kernel extents. For fewer than 3 spatial dims the temporal extent is 1,
  // and for 1-D convolution the R extent is 1 as well.
  // (A local can't be named "T" — T is the template type parameter.)
  const int F = SPATIAL_DIM <= 2 ? 1 : conv_param_.K[SPATIAL_DIM - 3];
  const int R = SPATIAL_DIM == 1 ? 1 : conv_param_.K[SPATIAL_DIM - 2];
  const int S = conv_param_.K[SPATIAL_DIM - 1];
  const int G = conv_param_.G;
  const int IC_per_G = conv_param_.IC / G;
  const int OC_per_G = conv_param_.OC / G;
  // Input channels per group rounded up to a multiple of 4; the packed
  // layout carries this padding, zero-filled below when packing.
  const int paddedICPerG = (IC_per_G + 3) / 4 * 4;
  // With the transpose option set, the weights are laid out as
  // G K/G (T R S C/G) instead of G (T R S C/G) K/G.
  const bool tr = (trans_ == matrix_op_t::Transpose);
  if (fbgemmOptimizedGConv(conv_param_)) {
    // Optimized layout (the only fully supported case): visit every
    // (t, r, s, output-channel, group) slot and move each input channel
    // between its packed and unpacked positions.
    for (int t = 0; t < F; ++t) {
      for (int r = 0; r < R; ++r) {
        for (int s = 0; s < S; ++s) {
          for (int oc = 0; oc < OC_per_G; ++oc) {
            for (int grp = 0; grp < G; ++grp) {
              for (int ic = 0; ic < IC_per_G; ++ic) {
                const int p_idx = packed_index_(t, r, s, oc, grp, ic);
                const int up_idx = unpacked_index_(t, r, s, oc, grp, ic, tr);
                if (ispack) {
                  dst[p_idx] = src[up_idx]; // unpacked -> packed
                } else {
                  dst[up_idx] = src[p_idx]; // packed -> unpacked
                }
              }
              if (ispack) {
                // Zero-fill the channel-padding tail of the packed buffer.
                for (int ic = IC_per_G; ic < paddedICPerG; ++ic) {
                  dst[packed_index_(t, r, s, oc, grp, ic)] = 0;
                }
              }
            }
          }
        }
      }
    }
  } else if (tr && ispack) {
    // Pack & transposed: G K/G (T R S C/G) => G (T R S C/G) K/G.
    transposeConvWeights(conv_param_, src, dst);
  } else if (tr) {
    // Unpack & transposed: invert the transpose by hand,
    // G (T R S C/G) K/G => G K/G (T R S C/G).
    // TODO: Wrap this as a inverseTransposeConvWeights()?
    const int FRS = F * R * S;
    for (int t = 0; t < F; ++t) {
      for (int r = 0; r < R; ++r) {
        for (int s = 0; s < S; ++s) {
          const int frs = (t * R + r) * S + s; // flattened spatial position
          for (int oc = 0; oc < OC_per_G; ++oc) {
            for (int grp = 0; grp < G; ++grp) {
              for (int ic = 0; ic < IC_per_G; ++ic) {
                // Destination layout: G K/G (T R S C/G)
                const int to =
                    ((grp * OC_per_G + oc) * FRS + frs) * IC_per_G + ic;
                // Source layout: G (T R S C/G) K/G
                const int from =
                    ((grp * FRS + frs) * IC_per_G + ic) * OC_per_G + oc;
                dst[to] = src[from];
              }
            }
          }
        }
      }
    }
  } else {
    // Unsupported non-transposed cases: the data needs no reshuffling,
    // so just copy all G * T*R*S * K/G * C/G weights verbatim.
    const int kernel_prod = std::accumulate(
        conv_param_.K.begin(),
        conv_param_.K.end(),
        1,
        std::multiplies<int>());
    memcpy(dst, src, G * kernel_prod * OC_per_G * IC_per_G * sizeof(inpType));
  }
}