in src/GroupwiseConv.cc [976:1272]
// Driver for groupwise (grouped) quantized int8 convolution.
//
// For each batch element and each bundle of G_together groups, invokes a
// JIT-generated convolution kernel that writes raw 32-bit accumulators into
// outBuffer (and, unless the weights are symmetric, per-output-pixel
// activation row sums into rowOffsetBuf), then requantizes the accumulators
// into `out` via dispatchOutputProcessing().
//
// SPATIAL_DIM == 2 is handled directly. SPATIAL_DIM == 3 is decomposed into
// a sequence of 2D convolutions over the temporal (T) filter axis: the t==0
// slice overwrites the int32 buffer, t>0 slices accumulate into it.
// SPATIAL_DIM == 1 is not implemented and throws.
//
// Parallelization: the work is partitioned across threads over the batch
// dimension when MB >= num_threads, otherwise over output rows (OH).
//
// Parameters:
//   conv_param    - convolution shape/stride/pad descriptor
//   activations   - uint8 input activations
//   a_zero_point  - activation zero point (also used to fill padding input)
//   rowOffsetBuf  - scratch for activation row sums; may be nullptr when the
//                   weight zero point is symmetric (no row offsets needed)
//   packed_weights- group-packed int8 weights
//   out           - final requantized output
//   outBuffer     - int32 scratch the JIT kernel accumulates into
//   outProcess    - requantization parameters (scales, zero points, bias, relu)
//   thread_id / num_threads - parallel partitioning of this call
void fbgemmGroupwiseConv(
    const conv_param_t<SPATIAL_DIM>& conv_param,
    const uint8_t* activations,
    int32_t a_zero_point,
    int32_t* rowOffsetBuf,
    packed_W& packed_weights,
    outType* out,
    int32_t* outBuffer,
    const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
    int thread_id,
    int num_threads) {
  using processOutputType = ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>;
  if (!cpuinfo_initialize()) {
    throw runtime_error("Failed to initialize cpuinfo!");
  }
  // Shape shorthands. Missing leading (temporal/height) dimensions collapse
  // to 1, so the same index arithmetic covers 1D/2D/3D uniformly.
  int MB = conv_param.MB;
  int OT = SPATIAL_DIM <= 2 ? 1 : conv_param.OUT_DIM[SPATIAL_DIM - 3];
  int OH = SPATIAL_DIM == 1 ? 1 : conv_param.OUT_DIM[SPATIAL_DIM - 2];
  int OW = conv_param.OUT_DIM[SPATIAL_DIM - 1];
  // Filter dims: T (temporal) x R (height) x S (width).
  int T = SPATIAL_DIM <= 2 ? 1 : conv_param.K[SPATIAL_DIM - 3];
  int R = SPATIAL_DIM == 1 ? 1 : conv_param.K[SPATIAL_DIM - 2];
  int S = conv_param.K[SPATIAL_DIM - 1];
  int G = conv_param.G;
  int OC = conv_param.OC;
  int IC = conv_param.IC;
  int K_per_G = conv_param.OC / G; // output channels per group
  int C_per_G = conv_param.IC / G; // input channels per group
  int OH_OW = OH * OW;
  int OT_OH_OW = OT * OH * OW;
  int IT = SPATIAL_DIM <= 2 ? 1 : conv_param.IN_DIM[SPATIAL_DIM - 3];
  int IH = SPATIAL_DIM == 1 ? 1 : conv_param.IN_DIM[SPATIAL_DIM - 2];
  int IW = conv_param.IN_DIM[SPATIAL_DIM - 1];
  int IH_IW = IH * IW;
  int IT_IH_IW = IT * IH * IW;
  // Channels per group rounded up to a multiple of 4 — matches the weight
  // packing granularity used by PackWeightMatrixForGConv.
  int paddedCPerG = (C_per_G + 3) / 4 * 4;
  // Row offsets are only needed to compensate a non-zero weight zero point;
  // they can be skipped with tensor-granularity quantization and a zero
  // weight zero point, or when the caller supplies no row-offset buffer.
  bool b_symmetric = (Q_GRAN == QuantizationGranularity::TENSOR &&
                      outProcess.getBZeroPoint()[0] == 0) ||
      rowOffsetBuf == nullptr;
  // Number of groups a single JIT kernel invocation processes together.
  int G_together = PackWeightMatrixForGConv<int8_t, int32_t, SPATIAL_DIM>::
      numOfGroupsTogether(conv_param);
  if (SPATIAL_DIM == 1) {
    throw std::runtime_error("Groupwise 1D not implemented!");
  }
  if (SPATIAL_DIM == 2) {
    // Parallelization: split over batch if there is at least one batch
    // element per thread, otherwise split over output rows.
    int64_t batch_start = 0;
    int64_t batch_end = MB;
    int64_t oh_start = 0;
    int64_t oh_end = OH;
    if (MB >= num_threads) {
      fbgemmPartition1D(thread_id, num_threads, MB, batch_start, batch_end);
    } else {
      fbgemmPartition1D(thread_id, num_threads, OH, oh_start, oh_end);
    }
    if (batch_start >= batch_end || oh_start >= oh_end) {
      // There is no work for this thread
      return;
    }
    // generate convolution + rowOffset kernel
    // The kernel is specialized on which vertical edges (top/bottom padding
    // handling) fall inside this thread's row range.
    bool calculateRowOffset = !b_symmetric;
    bool isTopEdgeIncluded = oh_start == 0;
    bool isBottomEdgeIncluded = oh_end == OH;
    // Degenerate case: a single output row that is both top and bottom edge.
    bool isTopBottomEdgeSame =
        isTopEdgeIncluded && isBottomEdgeIncluded && oh_end == oh_start + 1;
    jit_conv_kernel_fp fpConv = getOrCreateConvKernel<SPATIAL_DIM>(
        conv_param,
        a_zero_point,
        calculateRowOffset,
        isTopEdgeIncluded,
        isBottomEdgeIncluded,
        isTopBottomEdgeSame,
        false);
    // First input row this thread's output rows depend on (accounts for top
    // padding when starting at row 0).
    int ih_start = 0;
    if (oh_start > 0) {
      ih_start = -conv_param.pad[SPATIAL_DIM - 2] +
          oh_start * conv_param.stride[SPATIAL_DIM - 2];
    }
    int32_t* out_start = outBuffer + oh_start * OW * OC;
    const uint8_t* in_start = activations + ih_start * IW * IC;
    int32_t* rowOffsetBuf_start =
        rowOffsetBuf ? rowOffsetBuf + oh_start * OW * G_together : nullptr;
    for (int i = batch_start; i < batch_end; ++i) {
      const uint8_t* in_start_batch = in_start + i * IH_IW * conv_param.IC;
      int32_t* out_start_batch = out_start + i * OH_OW * OC;
      int32_t* rowOffsetBuf_start_batch =
          rowOffsetBuf ? rowOffsetBuf_start + i * OH_OW * G_together : nullptr;
      for (int g = 0; g < G; g += G_together) {
        const uint8_t* in_start_group = in_start_batch + g * C_per_G;
        int8_t* weight_start =
            packed_weights.getBuf() + g * R * S * K_per_G * paddedCPerG;
        // The int32 and row-offset scratch are REUSED across group bundles:
        // each bundle overwrites them, then is requantized immediately below
        // before the next bundle runs.
        int32_t* out_start_group = out_start_batch;
        int32_t* rowOffsetBuf_start_group = rowOffsetBuf_start_batch;
        // Uncomment the following two lines to stop
        // reuse of output and rowoffset buffer
        // out_start_group = out_start_batch + g * K_per_G;
        // rowOffsetBuf_start_group = rowOffsetBuf_start_batch + g * MB * OH_OW;
        // exactly the same compute as the JIT'ed below
        // kernel_compute(
        //     conv_param,
        //     in_start_group,
        //     weight_start,
        //     out_start_group,
        //     a_zero_point,
        //     oh_start,
        //     oh_end,
        //     OW,
        //     rowOffsetBuf_start_group);
        // Run the JIT kernel over output rows [oh_start, oh_end).
        fpConv(
            in_start_group,
            weight_start,
            out_start_group,
            a_zero_point,
            oh_start,
            oh_end,
            OW,
            rowOffsetBuf_start_group);
        const int32_t* inp = out_start_group;
        // Output block for requantization, addressed in the logical
        // (MB*OH*OW) x OC output matrix: rows for this batch element and row
        // range, columns for this group bundle.
        block_type_t block{
            static_cast<int>(i * OT_OH_OW + oh_start * OW),
            static_cast<int>((oh_end - oh_start) * OW),
            g * K_per_G,
            G_together * K_per_G};
        int ld_out = G * K_per_G; // == OC
        int ld_in = G * K_per_G;
        dispatchOutputProcessing(
            outProcess,
            rowOffsetBuf_start_group,
            out,
            inp,
            block,
            ld_out,
            ld_in,
            G,
            C_per_G,
            is_requantization<processOutputType>());
      } // for each g
    } // for each i
  } else {
    assert(SPATIAL_DIM == 3 && "Unsupported SPATIAL_DIM");
    // Build a 2D conv descriptor from the H/W dimensions of the 3D conv;
    // the temporal axis is handled explicitly by the t loop below.
    // NOTE(review): pad indices [1],[2],[4],[5] presumably select the
    // h/w entries out of a (t,h,w)-before/(t,h,w)-after pad array — confirm
    // against conv_param_t's pad layout.
    conv_param_t<> conv_p_2d(
        conv_param.MB,
        conv_param.IC,
        conv_param.OC,
        {conv_param.IN_DIM[SPATIAL_DIM - 2],
         conv_param.IN_DIM[SPATIAL_DIM - 1]},
        conv_param.G,
        {conv_param.K[SPATIAL_DIM - 2], conv_param.K[SPATIAL_DIM - 1]},
        {conv_param.stride[SPATIAL_DIM - 2],
         conv_param.stride[SPATIAL_DIM - 1]},
        {conv_param.pad[1],
         conv_param.pad[2],
         conv_param.pad[4],
         conv_param.pad[5]});
    // Parallelization: same strategy as the 2D path (batch first, else rows).
    int64_t batch_start = 0;
    int64_t batch_end = MB;
    int64_t oh_start = 0;
    int64_t oh_end = OH;
    if (MB >= num_threads) {
      fbgemmPartition1D(thread_id, num_threads, MB, batch_start, batch_end);
    } else {
      fbgemmPartition1D(thread_id, num_threads, OH, oh_start, oh_end);
    }
    if (batch_start >= batch_end || oh_start >= oh_end) {
      // There is no work for this thread
      return;
    }
    // generate convolution + rowOffset kernel
    bool calculateRowOffset = !b_symmetric;
    bool isTopEdgeIncluded = oh_start == 0;
    bool isBottomEdgeIncluded = oh_end == OH;
    bool isTopBottomEdgeSame =
        isTopEdgeIncluded && isBottomEdgeIncluded && oh_end == oh_start + 1;
    // Two kernel variants: overwrite (first temporal tap) vs. accumulate
    // (subsequent taps add into the int32 buffer).
    jit_conv_kernel_fp fpConvNoAccum = getOrCreateConvKernel<2>(
        conv_p_2d,
        a_zero_point,
        calculateRowOffset,
        isTopEdgeIncluded,
        isBottomEdgeIncluded,
        isTopBottomEdgeSame,
        false);
    jit_conv_kernel_fp fpConvAccum = getOrCreateConvKernel<2>(
        conv_p_2d,
        a_zero_point,
        calculateRowOffset,
        isTopEdgeIncluded,
        isBottomEdgeIncluded,
        isTopBottomEdgeSame,
        true);
    jit_conv_kernel_fp fpConv;
    // First input row needed for oh_start (accounts for top padding).
    int ih_start = 0;
    if (oh_start > 0) {
      ih_start = -conv_p_2d.pad[0] + oh_start * conv_p_2d.stride[0];
    }
    // A dummy input plane filled with the activation zero point, substituted
    // for the real input when a temporal tap falls into padding.
    // NOTE(review): a_zero_point (int32_t) is narrowed to uint8_t here —
    // assumed to fit since activations are uint8-quantized.
    vector<uint8_t> zero_points(IH * IW * IC, a_zero_point);
    int32_t* out_start = outBuffer + oh_start * OW * OC;
    const uint8_t* in_start = activations + ih_start * IW * IC;
    int32_t* rowOffsetBuf_start =
        rowOffsetBuf ? rowOffsetBuf + oh_start * OW * G_together : nullptr;
    for (int i = batch_start; i < batch_end; ++i) {
      const uint8_t* in_start_batch = in_start + i * IT_IH_IW * IC;
      int32_t* out_start_batch = out_start + i * OT_OH_OW * OC;
      int32_t* rowOffsetBuf_start_batch = rowOffsetBuf
          ? rowOffsetBuf_start + i * OT_OH_OW * G_together
          : nullptr;
      for (int g = 0; g < G; g += G_together) {
        const uint8_t* in_start_group = in_start_batch + g * C_per_G;
        int8_t* weight_start =
            packed_weights.getBuf() + g * T * R * S * K_per_G * paddedCPerG;
        // Scratch buffers reused across group bundles (see 2D path).
        int32_t* out_start_group = out_start_batch;
        int32_t* rowOffsetBuf_start_group = rowOffsetBuf_start_batch;
        // Uncomment the following two lines to stop
        // reuse of output and rowoffset buffer
        // out_start_group = out_start_batch + g * K_per_G;
        // rowOffsetBuf_start_group = rowOffsetBuf_start_batch + g * MB *
        // OT_OH_OW;
        for (int ot = 0; ot < OT; ++ot) {
          int32_t* out_start_t = out_start_group + ot * OH_OW * OC;
          int32_t* rowOffsetBuf_start_t = rowOffsetBuf
              ? rowOffsetBuf_start_group + ot * OH_OW * G_together
              : nullptr;
          // Accumulate T temporal filter taps into this output time slice.
          for (int t = 0; t < T; ++t) {
            // Input time index for this tap (pad[0]/stride[0] are the
            // temporal padding/stride of the 3D conv).
            int t_in = -conv_param.pad[0] + ot * conv_param.stride[0] + t;
            const uint8_t* in_start_t = in_start_group + t_in * IH_IW * IC;
            int8_t* weight_start_t =
                weight_start + t * R * S * K_per_G * G_together * paddedCPerG;
            if (t_in < 0 || t_in >= IT) {
              // Temporal padding: feed the zero-point-filled plane instead.
              in_start_t = zero_points.data();
            }
            // exactly the same compute as the JIT'ed below
            // kernel_compute(
            //     conv_p_2d,
            //     in_start_t,
            //     weight_start_t,
            //     out_start_t,
            //     a_zero_point,
            //     oh_start,
            //     oh_end,
            //     OW,
            //     rowOffsetBuf_start_t,
            //     t > 0);
            // First tap overwrites the buffer; later taps accumulate.
            fpConv = t > 0 ? fpConvAccum : fpConvNoAccum;
            fpConv(
                in_start_t,
                weight_start_t,
                out_start_t,
                a_zero_point,
                oh_start,
                oh_end,
                OW,
                rowOffsetBuf_start_t);
          }
          // Requantize the fully accumulated time slice for this group
          // bundle; the output pointer is advanced to the ot-th plane.
          const int32_t* inp = out_start_t;
          block_type_t block{
              static_cast<int>(i * OT_OH_OW + oh_start * OW),
              static_cast<int>((oh_end - oh_start) * OW),
              g * K_per_G,
              G_together * K_per_G};
          int ld_out = G * K_per_G; // == OC
          int ld_in = G * K_per_G;
          dispatchOutputProcessing(
              outProcess,
              rowOffsetBuf_start_t,
              out + ot * OH_OW * OC,
              inp,
              block,
              ld_out,
              ld_in,
              G,
              C_per_G,
              is_requantization<processOutputType>());
        } // for each ot
      } // for each g
    } // for each i
  } // SPATIAL_DIM == 3
}