in src/RefImplementations.cc [599:747]
/**
 * @brief Reference im2col for 3-D (T x H x W) quantized convolution.
 *
 * For every output position (n, ot, oh, ow) and every kernel tap (q, r, s)
 * this writes, per group g, a contiguous run of IC / G bytes into Ao:
 *   - a copy of the matching input channels when the tap maps onto a valid
 *     input coordinate, or
 *   - A_zero_point (via memset, i.e. its low byte) when it does not.
 *
 * Layouts, as implied by the index arithmetic below:
 *   A  : [MB][IN_DIM[0]][IN_DIM[1]][IN_DIM[2]][IC]
 *   Ao : [MB][OUT_DIM[0]][OUT_DIM[1]][OUT_DIM[2]][G][K[0]][K[1]][K[2]][IC / G]
 *
 * All offsets are computed in 64-bit arithmetic: the im2col buffer has
 * MB * prod(OUT_DIM) * prod(K) * IC elements, which can exceed the range of a
 * 32-bit int for large 3-D problems.
 *
 * @param conv_p        Convolution parameters; IC must be divisible by G.
 *                      Only pad[0..2] are read here.
 * @param A             Input activations.
 * @param A_zero_point  Fill value for taps that have no valid input element.
 * @param Ao            Output buffer receiving the im2col matrix.
 */
FBGEMM_API void im2col_ref(
    const conv_param_t<3>& conv_p,
    const uint8_t* A,
    int32_t A_zero_point,
    uint8_t* Ao) {
  const int IC = conv_p.IC;
  const int G = conv_p.G;
  assert(IC % G == 0);
  const array<int, 3> IN_DIM = conv_p.IN_DIM;
  const array<int, 3> OUT_DIM = conv_p.OUT_DIM;
  const array<int, 3> K = conv_p.K;

  // Channels handled by one group, widened so every offset product below is
  // evaluated in 64 bits.
  const int64_t ic_per_group = IC / G;
  const size_t group_bytes = sizeof(uint8_t) * (IC / G);

  // Flat index of an output position in [MB][OT][OH][OW] order.
  const auto out_position = [&](int n, int ot, int oh, int ow) -> int64_t {
    return ((static_cast<int64_t>(n) * OUT_DIM[0] + ot) * OUT_DIM[1] + oh) *
        OUT_DIM[2] +
        ow;
  };

  // Flat index of an input position in [MB][IT][IH][IW] order.
  const auto in_position = [&](int n, int t_in, int h_in, int w_in) -> int64_t {
    return ((static_cast<int64_t>(n) * IN_DIM[0] + t_in) * IN_DIM[1] + h_in) *
        IN_DIM[2] +
        w_in;
  };

  // Writes the G channel slices for one (output position, kernel tap) pair:
  // copies from input position in_pos when in_bounds, otherwise fills the
  // slices with the zero point. in_pos is ignored when !in_bounds.
  const auto write_tap = [&](int64_t out_pos,
                             int q,
                             int r,
                             int s,
                             bool in_bounds,
                             int64_t in_pos) {
    for (int g = 0; g < G; ++g) {
      uint8_t* dst = Ao +
          ((((out_pos * G + g) * K[0] + q) * K[1] + r) * K[2] + s) *
              ic_per_group;
      if (in_bounds) {
        memcpy(dst, A + in_pos * IC + g * ic_per_group, group_bytes);
      } else {
        memset(dst, A_zero_point, group_bytes);
      }
    }
  };

  if (conv_p.transposed) {
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int ot = 0; ot < OUT_DIM[0]; ++ot) {
        for (int oh = 0; oh < OUT_DIM[1]; ++oh) {
          for (int ow = 0; ow < OUT_DIM[2]; ++ow) {
            const int64_t out_pos = out_position(n, ot, oh, ow);
            for (int q = 0; q < K[0]; ++q) {
              // A tap contributes only if the (padded, dilated) offset is an
              // exact multiple of the stride and lands inside the input.
              const int t = ot + conv_p.pad[0] - q * conv_p.dilation[0];
              const int t_in = t / conv_p.stride[0];
              const bool t_ok = t_in * conv_p.stride[0] == t && t_in >= 0 &&
                  t_in < IN_DIM[0];
              for (int r = 0; r < K[1]; ++r) {
                const int h = oh + conv_p.pad[1] - r * conv_p.dilation[1];
                const int h_in = h / conv_p.stride[1];
                const bool h_ok = h_in * conv_p.stride[1] == h && h_in >= 0 &&
                    h_in < IN_DIM[1];
                for (int s = 0; s < K[2]; ++s) {
                  const int w = ow + conv_p.pad[2] - s * conv_p.dilation[2];
                  const int w_in = w / conv_p.stride[2];
                  const bool w_ok = w_in * conv_p.stride[2] == w &&
                      w_in >= 0 && w_in < IN_DIM[2];

                  const bool in_bounds = t_ok && h_ok && w_ok;
                  const int64_t in_pos =
                      in_bounds ? in_position(n, t_in, h_in, w_in) : 0;
                  write_tap(out_pos, q, r, s, in_bounds, in_pos);
                } // for each s
              } // for each r
            } // for each q
          } // for each ow
        } // for each oh
      } // for each ot
    } // for each n
  } else {
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int t = 0; t < OUT_DIM[0]; ++t) {
        for (int h = 0; h < OUT_DIM[1]; ++h) {
          for (int w = 0; w < OUT_DIM[2]; ++w) {
            const int64_t out_pos = out_position(n, t, h, w);
            for (int q = 0; q < K[0]; ++q) {
              const int t_in = -conv_p.pad[0] + t * conv_p.stride[0] +
                  q * conv_p.dilation[0];
              const bool t_ok = t_in >= 0 && t_in < IN_DIM[0];
              for (int r = 0; r < K[1]; ++r) {
                const int h_in = -conv_p.pad[1] + h * conv_p.stride[1] +
                    r * conv_p.dilation[1];
                const bool h_ok = h_in >= 0 && h_in < IN_DIM[1];
                for (int s = 0; s < K[2]; ++s) {
                  const int w_in = -conv_p.pad[2] + w * conv_p.stride[2] +
                      s * conv_p.dilation[2];
                  const bool w_ok = w_in >= 0 && w_in < IN_DIM[2];

                  // Taps that fall into the padding region are filled with
                  // the zero point instead of being read from A.
                  const bool in_bounds = t_ok && h_ok && w_ok;
                  const int64_t in_pos =
                      in_bounds ? in_position(n, t_in, h_in, w_in) : 0;
                  write_tap(out_pos, q, r, s, in_bounds, in_pos);
                } // for each s
              } // for each r
            } // for each q
          } // for each w
        } // for each h
      } // for each t
    } // for each n
  }
}