FBGEMM_API void im2col_ref()

in src/RefImplementations.cc [599:747]


// Reference im2col for 3D (T x H x W) uint8 activations.
//
// For every output position and every kernel tap (q, r, s), the IC / G input
// channels of each group g are copied from A into Ao. When the tap has no
// corresponding input element, the slice is filled with A_zero_point instead.
//
// Layouts, as implied by the indexing below:
//   A  : [MB][IN_DIM[0]][IN_DIM[1]][IN_DIM[2]][IC]
//   Ao : [MB][OUT_DIM[0]][OUT_DIM[1]][OUT_DIM[2]][G][K[0]][K[1]][K[2]][IC / G]
//
// conv_p        convolution shape; reads MB, IC, G, IN_DIM, OUT_DIM, K,
//               stride[0..2], pad[0..2], dilation[0..2] and transposed.
//               IC must be divisible by G (asserted).
// A             input activations.
// A_zero_point  fill value for taps without a source element. It is passed to
//               memset, so only its low 8 bits end up in the output.
// Ao            output buffer; must be large enough for the Ao layout above.
//
// All element offsets are computed in 64-bit arithmetic: the Ao index is a
// product of MB, the output dims, the kernel dims and IC, which exceeds the
// range of a 32-bit int for large 3D problems.
FBGEMM_API void im2col_ref(
    const conv_param_t<3>& conv_p,
    const uint8_t* A,
    int32_t A_zero_point,
    uint8_t* Ao) {
  const int IC = conv_p.IC;
  const int G = conv_p.G;
  assert(IC % G == 0);
  const auto& IN_DIM = conv_p.IN_DIM;
  const auto& OUT_DIM = conv_p.OUT_DIM;
  const auto& K = conv_p.K;
  const bool transposed = conv_p.transposed;

  const int ic_per_group = IC / G;
  const auto slice_bytes = sizeof(uint8_t) * ic_per_group;

  // Maps output coordinate `out` and kernel tap `k` along spatial dimension `d`
  // to the input coordinate that feeds it, or -1 when there is none.
  auto src_coord = [&](int out, int d, int k) -> int {
    if (transposed) {
      // Transposed convolution: only positions that are an exact multiple of
      // the stride correspond to a real input element.
      const int pos = out + conv_p.pad[d] - k * conv_p.dilation[d];
      const int src = pos / conv_p.stride[d];
      const bool hit =
          src * conv_p.stride[d] == pos && src >= 0 && src < IN_DIM[d];
      return hit ? src : -1;
    }
    const int src =
        -conv_p.pad[d] + out * conv_p.stride[d] + k * conv_p.dilation[d];
    return (src >= 0 && src < IN_DIM[d]) ? src : -1;
  };

  for (int n = 0; n < conv_p.MB; ++n) {
    for (int ot = 0; ot < OUT_DIM[0]; ++ot) {
      for (int oh = 0; oh < OUT_DIM[1]; ++oh) {
        for (int ow = 0; ow < OUT_DIM[2]; ++ow) {
          // Linear index of this output position; invariant over g/q/r/s.
          const int64_t out_pos =
              ((static_cast<int64_t>(n) * OUT_DIM[0] + ot) * OUT_DIM[1] + oh) *
                  OUT_DIM[2] +
              ow;
          for (int q = 0; q < K[0]; ++q) {
            const int t_in = src_coord(ot, 0, q);
            for (int r = 0; r < K[1]; ++r) {
              const int h_in = src_coord(oh, 1, r);
              for (int s = 0; s < K[2]; ++s) {
                const int w_in = src_coord(ow, 2, s);
                const bool has_src = t_in >= 0 && h_in >= 0 && w_in >= 0;
                // Linear index of the source input position (unused when the
                // tap has no source element).
                const int64_t in_pos = has_src
                    ? ((static_cast<int64_t>(n) * IN_DIM[0] + t_in) *
                           IN_DIM[1] +
                       h_in) *
                            IN_DIM[2] +
                        w_in
                    : 0;
                for (int g = 0; g < G; ++g) {
                  uint8_t* dst = Ao +
                      ((((out_pos * G + g) * K[0] + q) * K[1] + r) * K[2] +
                       s) *
                          ic_per_group;
                  if (has_src) {
                    memcpy(
                        dst,
                        A + in_pos * IC +
                            static_cast<int64_t>(g) * ic_per_group,
                        slice_bytes);
                  } else {
                    memset(dst, A_zero_point, slice_bytes);
                  }
                } // for each g
              } // for each s
            } // for each r
          } // for each q
        } // for each ow
      } // for each oh
    } // for each ot
  } // for each n
}