FBGEMM_API void im2col_ref()

in src/RefImplementations.cc [486:590]


// Reference im2col for a 2D convolution over uint8 activations.
//
// For every output position (n, oh, ow), every group g and every kernel tap
// (r, s) this writes one slot of IC / G bytes into Ao. Ao is indexed as
//   [n][oh][ow][g][r][s][IC / G]
// with extents MB, OUT_DIM[0], OUT_DIM[1], G, K[0], K[1], IC / G, and A is
// indexed as [n][h][w][IC] with extents MB, IN_DIM[0], IN_DIM[1], IC.
//
// conv_p        convolution geometry (MB, IC, G, IN_DIM, OUT_DIM, K, pad,
//               stride, dilation, transposed).
// A             source activations, read only.
// A_zero_point  value stored (via memset, i.e. converted to unsigned char)
//               in slots whose kernel tap has no matching input pixel.
// Ao            destination buffer; the caller must provide room for
//               MB * OUT_DIM[0] * OUT_DIM[1] * K[0] * K[1] * IC bytes.
//
// Offsets are accumulated in 64-bit arithmetic so that buffers with more
// than INT_MAX elements do not trigger signed 32-bit overflow.
FBGEMM_API void im2col_ref(
    const conv_param_t<2>& conv_p,
    const uint8_t* A,
    int32_t A_zero_point,
    uint8_t* Ao) {
  const int IC = conv_p.IC;
  const int G = conv_p.G;
  assert(IC % G == 0);
  if (G <= 0) {
    // No groups means no slot is ever written; returning here also keeps the
    // IC / G division below well defined.
    return;
  }
  const array<int, 2> IN_DIM = conv_p.IN_DIM;
  const array<int, 2> OUT_DIM = conv_p.OUT_DIM;
  const array<int, 2> K = conv_p.K;

  // Loop-invariant: channels per group and the byte size of one slot.
  const int IC_per_G = IC / G;
  const auto slot_bytes = sizeof(uint8_t) * IC_per_G;

  // Fills the G slots that belong to output position (n, oh, ow) and kernel
  // tap (r, s). When `inside` is true the slots are copied from input pixel
  // (n, h_in, w_in); otherwise they are filled with A_zero_point.
  auto write_slots = [&](int n,
                         int oh,
                         int ow,
                         int r,
                         int s,
                         int h_in,
                         int w_in,
                         bool inside) {
    for (int g = 0; g < G; ++g) {
      int64_t dst = n;
      dst = dst * OUT_DIM[0] + oh;
      dst = dst * OUT_DIM[1] + ow;
      dst = dst * G + g;
      dst = dst * K[0] + r;
      dst = dst * K[1] + s;
      dst *= IC_per_G;

      if (inside) {
        int64_t src = n;
        src = src * IN_DIM[0] + h_in;
        src = src * IN_DIM[1] + w_in;
        src = src * IC + static_cast<int64_t>(g) * IC_per_G;
        memcpy(Ao + dst, A + src, slot_bytes);
      } else {
        memset(Ao + dst, A_zero_point, slot_bytes);
      }
    }
  };

  if (conv_p.transposed) {
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int oh = 0; oh < OUT_DIM[0]; ++oh) {
        for (int ow = 0; ow < OUT_DIM[1]; ++ow) {
          for (int r = 0; r < K[0]; ++r) {
            for (int s = 0; s < K[1]; ++s) {
              // Invert the forward mapping: an input pixel contributes only
              // when (o + pad - k * dilation) is an exact multiple of the
              // stride and the quotient lies inside the input.
              const int h = oh + conv_p.pad[0] - r * conv_p.dilation[0];
              const int w = ow + conv_p.pad[1] - s * conv_p.dilation[1];
              const int h_in = h / conv_p.stride[0];
              const int w_in = w / conv_p.stride[1];
              const bool inside = h_in * conv_p.stride[0] == h && h_in >= 0 &&
                  h_in < IN_DIM[0] && w_in * conv_p.stride[1] == w &&
                  w_in >= 0 && w_in < IN_DIM[1];
              write_slots(n, oh, ow, r, s, h_in, w_in, inside);
            } // for each s
          } // for each r
        } // for each ow
      } // for each oh
    } // for each n
  } else {
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int h = 0; h < OUT_DIM[0]; ++h) {
        for (int w = 0; w < OUT_DIM[1]; ++w) {
          for (int r = 0; r < K[0]; ++r) {
            const int h_in =
                -conv_p.pad[0] + h * conv_p.stride[0] + r * conv_p.dilation[0];
            for (int s = 0; s < K[1]; ++s) {
              const int w_in = -conv_p.pad[1] + w * conv_p.stride[1] +
                  s * conv_p.dilation[1];
              // Taps that land outside the input are padding.
              const bool inside = h_in >= 0 && h_in < IN_DIM[0] && w_in >= 0 &&
                  w_in < IN_DIM[1];
              write_slots(n, h, w, r, s, h_in, w_in, inside);
            } // for each s
          } // for each r
        } // for each w
      } // for each h
    } // for each n
  }
}