in src/RefImplementations.cc [929:1058]
FBGEMM_API void conv_ref(
    const conv_param_t<3>& conv_p,
    const uint8_t* A,
    int32_t A_zero_point,
    const int8_t* B,
    int32_t* C) {
  // Naive reference 3D convolution (regular or transposed) on quantized
  // inputs, accumulating int32 sums without requantization.
  //
  // Memory layouts, as fixed by the index arithmetic below:
  //   A : MB x IN_DIM[0] x IN_DIM[1] x IN_DIM[2] x IC      (channels last)
  //   B : G x K[0] x K[1] x K[2] x (IC/G) x (OC/G)         ("G QRS C/G x K")
  //   C : MB x OUT_DIM[0] x OUT_DIM[1] x OUT_DIM[2] x OC   (channels last)
  // Input positions that fall outside A (padding, or non-integral positions
  // in the transposed case) contribute A_zero_point instead of a real value.
  //
  // All flattened offsets are computed in int64_t: for large activations the
  // element count can exceed INT_MAX, and signed int overflow is UB.
  const int IC = conv_p.IC;
  const int OC = conv_p.OC;
  const int G = conv_p.G;
  assert(IC % G == 0);
  assert(OC % G == 0);
  const int ic_per_group = IC / G;
  const int oc_per_group = OC / G;
  const auto& IN_DIM = conv_p.IN_DIM;
  const auto& OUT_DIM = conv_p.OUT_DIM;
  const auto& K = conv_p.K;

  // Offset of channel 0 of voxel (n, t, h, w) in a channels-last buffer with
  // spatial extents (d0, d1, d2) and `channels` channels per voxel.
  auto voxel_offset = [](int n,
                         int t,
                         int h,
                         int w,
                         int d0,
                         int d1,
                         int d2,
                         int channels) -> int64_t {
    return (((static_cast<int64_t>(n) * d0 + t) * d1 + h) * d2 + w) *
        channels;
  };

  // Offset of B[g][q][r][s][0][0]; element (ic, oc) then lives at
  // + ic * oc_per_group + oc.
  auto filter_offset = [&](int g, int q, int r, int s) -> int64_t {
    return (((static_cast<int64_t>(g) * K[0] + q) * K[1] + r) * K[2] + s) *
        ic_per_group * oc_per_group;
  };

  if (conv_p.transposed) {
    // For the reference implementation there is no padding on the input
    // buffer; padding specifies how much is removed from the output.
    // The forward convolution maps input index i and filter tap k to
    //   o = -pad + i * stride + k * dilation,
    // so for each output o we invert it: i = (o + pad - k * dilation) / stride,
    // which is a real input position only when the division is exact and the
    // result lies inside the input extent.
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int ot = 0; ot < OUT_DIM[0]; ++ot) {
        for (int oh = 0; oh < OUT_DIM[1]; ++oh) {
          for (int ow = 0; ow < OUT_DIM[2]; ++ow) {
            const int64_t c_base = voxel_offset(
                n, ot, oh, ow, OUT_DIM[0], OUT_DIM[1], OUT_DIM[2], OC);
            for (int g = 0; g < G; ++g) {
              for (int oc = 0; oc < oc_per_group; ++oc) {
                int32_t sum = 0;
                for (int q = 0; q < K[0]; ++q) {
                  const int t = ot + conv_p.pad[0] - q * conv_p.dilation[0];
                  const int t_in = t / conv_p.stride[0];
                  // Division truncates toward zero, so a negative t is always
                  // rejected by either the exactness check or t_in >= 0.
                  const bool t_ok = t_in * conv_p.stride[0] == t &&
                      t_in >= 0 && t_in < IN_DIM[0];
                  for (int r = 0; r < K[1]; ++r) {
                    const int h = oh + conv_p.pad[1] - r * conv_p.dilation[1];
                    const int h_in = h / conv_p.stride[1];
                    const bool h_ok = h_in * conv_p.stride[1] == h &&
                        h_in >= 0 && h_in < IN_DIM[1];
                    for (int s = 0; s < K[2]; ++s) {
                      const int w =
                          ow + conv_p.pad[2] - s * conv_p.dilation[2];
                      const int w_in = w / conv_p.stride[2];
                      const bool w_ok = w_in * conv_p.stride[2] == w &&
                          w_in >= 0 && w_in < IN_DIM[2];
                      // The bounds test depends only on (q, r, s), not on the
                      // input channel, so evaluate it once here.
                      const bool in_bounds = t_ok && h_ok && w_ok;
                      const int64_t a_base = in_bounds
                          ? voxel_offset(
                                n,
                                t_in,
                                h_in,
                                w_in,
                                IN_DIM[0],
                                IN_DIM[1],
                                IN_DIM[2],
                                IC) +
                              g * ic_per_group
                          : 0;
                      // B is G Q R S Cin/G Cout/G after transpose.
                      const int64_t b_base = filter_offset(g, q, r, s);
                      for (int ic = 0; ic < ic_per_group; ++ic) {
                        const int a =
                            in_bounds ? A[a_base + ic] : A_zero_point;
                        const int b = B
                            [b_base + static_cast<int64_t>(ic) * oc_per_group +
                             oc];
                        sum += a * b;
                      } // for each ic
                    } // for each s
                  } // for each r
                } // for each q
                C[c_base + g * oc_per_group + oc] = sum;
              } // for each oc
            } // for each g
          } // for each ow
        } // for each oh
      } // for each ot
    } // for each n
  } else {
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int t = 0; t < OUT_DIM[0]; ++t) {
        for (int h = 0; h < OUT_DIM[1]; ++h) {
          for (int w = 0; w < OUT_DIM[2]; ++w) {
            const int64_t c_base = voxel_offset(
                n, t, h, w, OUT_DIM[0], OUT_DIM[1], OUT_DIM[2], OC);
            for (int g = 0; g < G; ++g) {
              for (int m = 0; m < oc_per_group; ++m) {
                int32_t sum = 0;
                for (int q = 0; q < K[0]; ++q) {
                  const int t_in = -conv_p.pad[0] + t * conv_p.stride[0] +
                      q * conv_p.dilation[0];
                  const bool t_ok = t_in >= 0 && t_in < IN_DIM[0];
                  for (int r = 0; r < K[1]; ++r) {
                    const int h_in = -conv_p.pad[1] + h * conv_p.stride[1] +
                        r * conv_p.dilation[1];
                    const bool h_ok = h_in >= 0 && h_in < IN_DIM[1];
                    for (int s = 0; s < K[2]; ++s) {
                      const int w_in = -conv_p.pad[2] + w * conv_p.stride[2] +
                          s * conv_p.dilation[2];
                      const bool w_ok = w_in >= 0 && w_in < IN_DIM[2];
                      // Positions in the padding region read as A_zero_point.
                      const bool in_bounds = t_ok && h_ok && w_ok;
                      const int64_t a_base = in_bounds
                          ? voxel_offset(
                                n,
                                t_in,
                                h_in,
                                w_in,
                                IN_DIM[0],
                                IN_DIM[1],
                                IN_DIM[2],
                                IC) +
                              g * ic_per_group
                          : 0;
                      const int64_t b_base = filter_offset(g, q, r, s);
                      for (int c = 0; c < ic_per_group; ++c) {
                        const int a =
                            in_bounds ? A[a_base + c] : A_zero_point;
                        const int b = B
                            [b_base + static_cast<int64_t>(c) * oc_per_group +
                             m];
                        sum += a * b;
                      } // for each c
                    } // for each s
                  } // for each r
                } // for each q
                C[c_base + g * oc_per_group + m] = sum;
              } // for each m
            } // for each group
          } // for each w
        } // for each h
      } // for each t
    } // for each n
  }
}