in src/RefImplementations.cc [486:590]
FBGEMM_API void im2col_ref(
    const conv_param_t<2>& conv_p,
    const uint8_t* A,
    int32_t A_zero_point,
    uint8_t* Ao) {
  // Reference im2col for a 2-D (optionally transposed) grouped convolution
  // on uint8 data.
  //
  // Indexing used below:
  //   A  is read  as [MB][IN_DIM[0]][IN_DIM[1]][IC]
  //   Ao is written as [MB][OUT_DIM[0]][OUT_DIM[1]][G][K[0]][K[1]][IC / G]
  //
  // For every output pixel and every kernel tap (r, s) the IC / G channels of
  // each group are copied from the matching input pixel. Taps that fall
  // outside the input (or, for transposed convolution, between strided input
  // samples) are filled with A_zero_point. memset stores that value converted
  // to unsigned char, so only its low 8 bits are used.
  //
  // Requires IC to be a multiple of G (asserted).
  const int IC = conv_p.IC;
  const int G = conv_p.G;
  assert(IC % G == 0);
  const array<int, 2> IN_DIM = conv_p.IN_DIM;
  const array<int, 2> OUT_DIM = conv_p.OUT_DIM;
  const array<int, 2> K = conv_p.K;

  // Loop-invariant per-group channel count and the byte size of one group.
  const int IC_per_G = IC / G;
  const auto group_bytes = sizeof(uint8_t) * IC_per_G;

  // Writes kernel tap (r, s) of output pixel (n, oh, ow) for all groups.
  // When in_bounds is true the data comes from input pixel (n, h_in, w_in);
  // otherwise h_in / w_in are ignored and the tap is filled with the zero
  // point.
  //
  // All element offsets are computed in 64 bits: the im2col buffer holds
  // MB * OH * OW * K[0] * K[1] * IC elements, which can exceed INT_MAX for
  // large problems, and overflowing a signed int is undefined behaviour.
  auto write_tap = [&](int n,
                       int oh,
                       int ow,
                       int r,
                       int s,
                       bool in_bounds,
                       int h_in,
                       int w_in) {
    const int64_t out_pixel =
        (static_cast<int64_t>(n) * OUT_DIM[0] + oh) * OUT_DIM[1] + ow;
    const int64_t in_pixel = in_bounds
        ? (static_cast<int64_t>(n) * IN_DIM[0] + h_in) * IN_DIM[1] + w_in
        : 0;
    for (int g = 0; g < G; ++g) {
      uint8_t* dst =
          Ao + (((out_pixel * G + g) * K[0] + r) * K[1] + s) * IC_per_G;
      if (in_bounds) {
        const uint8_t* src =
            A + in_pixel * IC + static_cast<int64_t>(g) * IC_per_G;
        memcpy(dst, src, group_bytes);
      } else {
        memset(dst, A_zero_point, group_bytes);
      }
    }
  };

  if (conv_p.transposed) {
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int oh = 0; oh < OUT_DIM[0]; ++oh) {
        for (int ow = 0; ow < OUT_DIM[1]; ++ow) {
          for (int r = 0; r < K[0]; ++r) {
            for (int s = 0; s < K[1]; ++s) {
              // Position in the stride-expanded input grid that tap (r, s)
              // reaches from output pixel (oh, ow).
              const int h = oh + conv_p.pad[0] - r * conv_p.dilation[0];
              const int w = ow + conv_p.pad[1] - s * conv_p.dilation[1];
              const int h_in = h / conv_p.stride[0];
              const int w_in = w / conv_p.stride[1];
              // A real input sample exists only where h and w are exact
              // multiples of the stride and the quotient lies inside the
              // input; the multiply-back test rejects non-multiples
              // (integer division truncates).
              const bool in_bounds = h_in * conv_p.stride[0] == h &&
                  h_in >= 0 && h_in < IN_DIM[0] &&
                  w_in * conv_p.stride[1] == w && w_in >= 0 &&
                  w_in < IN_DIM[1];
              write_tap(n, oh, ow, r, s, in_bounds, h_in, w_in);
            } // for each s
          } // for each r
        } // for each ow
      } // for each oh
    } // for each n
  } else {
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int h = 0; h < OUT_DIM[0]; ++h) {
        for (int w = 0; w < OUT_DIM[1]; ++w) {
          for (int r = 0; r < K[0]; ++r) {
            const int h_in =
                -conv_p.pad[0] + h * conv_p.stride[0] + r * conv_p.dilation[0];
            for (int s = 0; s < K[1]; ++s) {
              const int w_in = -conv_p.pad[1] + w * conv_p.stride[1] +
                  s * conv_p.dilation[1];
              // Taps landing in the padding region get the zero point.
              const bool in_bounds = h_in >= 0 && h_in < IN_DIM[0] &&
                  w_in >= 0 && w_in < IN_DIM[1];
              write_tap(n, h, w, r, s, in_bounds, h_in, w_in);
            } // for each s
          } // for each r
        } // for each w
      } // for each h
    } // for each n
  }
}