in src/RefImplementations.cc [830:925]
// Reference (naive, unoptimized) 2-D convolution of uint8 activations with
// int8 weights, accumulating into int32. Supports grouped and transposed
// convolution.
//
// Layouts, as implied by the indexing below:
//   A: MB x IN_DIM[0] x IN_DIM[1] x IC    (channels innermost)
//   B: G x K[0] x K[1] x (IC/G) x (OC/G)  ("G RS C/G x K")
//   C: MB x OUT_DIM[0] x OUT_DIM[1] x OC  (channels innermost)
// Input positions that fall outside A (padding, or, for transposed
// convolution, positions not hit by the fractional stride) contribute
// A_zero_point instead of an activation value.
//
// @param conv_p       convolution shape / stride / padding / dilation / groups.
// @param A            input activations.
// @param A_zero_point value substituted for out-of-bounds input positions.
// @param B            filter weights.
// @param C            output; every element is overwritten.
FBGEMM_API void conv_ref(
    const conv_param_t<2>& conv_p,
    const uint8_t* A,
    int32_t A_zero_point,
    const int8_t* B,
    int32_t* C) {
  // filters are assumed to be in G RS C/G x K format
  const int IC = conv_p.IC;
  const int OC = conv_p.OC;
  const int G = conv_p.G;
  assert(IC % G == 0);
  assert(OC % G == 0);
  const int IC_per_G = IC / G;
  const int OC_per_G = OC / G;
  const std::array<int, 2> IN_DIM = conv_p.IN_DIM;
  const std::array<int, 2> OUT_DIM = conv_p.OUT_DIM;
  const std::array<int, 2> K = conv_p.K;

  if (conv_p.transposed) {
    // For the reference implementation there is no padding on the input
    // buffer; padding specifies how much is removed from the output buffer.
    //
    // A regular convolution reads input position
    //   h_in = -pad[0] + h * stride[0] + r * dilation[0]
    //   w_in = -pad[1] + w * stride[1] + s * dilation[1]
    // so the transposed convolution inverts that mapping: output (oh, ow)
    // receives input (h_in, w_in) only when oh + pad[0] - r * dilation[0] is
    // an exact multiple of stride[0] (and likewise for the width).
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int oh = 0; oh < OUT_DIM[0]; ++oh) {
        for (int ow = 0; ow < OUT_DIM[1]; ++ow) {
          for (int g = 0; g < G; ++g) {
            for (int oc = 0; oc < OC_per_G; ++oc) {
              int sum = 0;
              for (int r = 0; r < K[0]; ++r) {
                for (int s = 0; s < K[1]; ++s) {
                  const int h = oh + conv_p.pad[0] - r * conv_p.dilation[0];
                  const int w = ow + conv_p.pad[1] - s * conv_p.dilation[1];
                  // Integer division truncates toward zero; the multiply-back
                  // check rejects fractional positions and, together with the
                  // >= 0 test, negative ones.
                  const int h_in = h / conv_p.stride[0];
                  const int w_in = w / conv_p.stride[1];
                  // Independent of ic, so evaluated once per (r, s).
                  const bool in_bounds = h_in * conv_p.stride[0] == h &&
                      h_in >= 0 && h_in < IN_DIM[0] &&
                      w_in * conv_p.stride[1] == w && w_in >= 0 &&
                      w_in < IN_DIM[1];
                  for (int ic = 0; ic < IC_per_G; ++ic) {
                    const int a = in_bounds
                        ? A[((n * IN_DIM[0] + h_in) * IN_DIM[1] + w_in) * IC +
                            g * IC_per_G + ic]
                        : A_zero_point;
                    // G R S IC/G OC/G after transpose. Multiply by OC_per_G
                    // rather than "* OC / G", whose intermediate product is G
                    // times larger than the index and can overflow int.
                    const int b =
                        B[(((g * K[0] + r) * K[1] + s) * IC_per_G + ic) *
                              OC_per_G +
                          oc];
                    sum += a * b;
                  } // for each ic
                } // for each s
              } // for each r
              C[((n * OUT_DIM[0] + oh) * OUT_DIM[1] + ow) * OC +
                g * OC_per_G + oc] = sum;
            } // for each oc
          } // for each g
        } // for each ow
      } // for each oh
    } // for each n
  } else {
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int h = 0; h < OUT_DIM[0]; ++h) {
        for (int w = 0; w < OUT_DIM[1]; ++w) {
          for (int g = 0; g < G; ++g) {
            for (int m = 0; m < OC_per_G; ++m) {
              int sum = 0;
              for (int r = 0; r < K[0]; ++r) {
                const int h_in = -conv_p.pad[0] + h * conv_p.stride[0] +
                    r * conv_p.dilation[0];
                for (int s = 0; s < K[1]; ++s) {
                  const int w_in = -conv_p.pad[1] + w * conv_p.stride[1] +
                      s * conv_p.dilation[1];
                  // Positions in the padding region read the zero point.
                  const bool in_bounds = h_in >= 0 && h_in < IN_DIM[0] &&
                      w_in >= 0 && w_in < IN_DIM[1];
                  for (int c = 0; c < IC_per_G; ++c) {
                    const int a = in_bounds
                        ? A[((n * IN_DIM[0] + h_in) * IN_DIM[1] + w_in) * IC +
                            g * IC_per_G + c]
                        : A_zero_point;
                    const int b =
                        B[(((g * K[0] + r) * K[1] + s) * IC_per_G + c) *
                              OC_per_G +
                          m];
                    sum += a * b;
                  } // for each c
                } // for each s
              } // for each r
              C[((n * OUT_DIM[0] + h) * OUT_DIM[1] + w) * OC + g * OC_per_G +
                m] = sum;
            } // for each m
          } // for each group
        } // for each w
      } // for each h
    } // for each n
  }
}