FBGEMM_API void conv_ref()

in src/RefImplementations.cc [830:925]


/**
 * @brief Reference 2D convolution of quantized activations and weights.
 *
 * Plain loop-nest implementation. Raw products (activation * weight) are
 * accumulated into a 32-bit sum per output element; no zero-point
 * subtraction or requantization is performed here.
 *
 * Buffer layouts, as fixed by the indexing below:
 *  - A: MB x IN_DIM[0] x IN_DIM[1] x IC, channels innermost.
 *  - B: G x K[0] x K[1] x (IC / G) x (OC / G).
 *  - C: MB x OUT_DIM[0] x OUT_DIM[1] x OC, channels innermost.
 *
 * @param conv_p       Convolution description (sizes, stride, pad, dilation,
 *                     groups). When conv_p.transposed is set, the transposed
 *                     path is taken.
 * @param A            Input activations.
 * @param A_zero_point Value substituted for every filter tap that does not
 *                     land on a real input element.
 * @param B            Filter weights.
 * @param C            Output buffer; every element is overwritten.
 *
 * Preconditions (checked with assert only): IC and OC are multiples of G.
 */
FBGEMM_API void conv_ref(
    const conv_param_t<2>& conv_p,
    const uint8_t* A,
    int32_t A_zero_point,
    const int8_t* B,
    int32_t* C) {
  // filters are assumed to be in G RS C/G x K format
  const int IC = conv_p.IC;
  const int OC = conv_p.OC;
  const int G = conv_p.G;
  assert(IC % G == 0);
  assert(OC % G == 0);

  // Per-group channel counts; used by every index computation below.
  const int IC_per_G = IC / G;
  const int OC_per_G = OC / G;

  // Local copies of the loop-invariant geometry. Index 0 is applied to the
  // height axis and index 1 to the width axis throughout.
  const int MB = conv_p.MB;
  const int IN_H = conv_p.IN_DIM[0];
  const int IN_W = conv_p.IN_DIM[1];
  const int OUT_H = conv_p.OUT_DIM[0];
  const int OUT_W = conv_p.OUT_DIM[1];
  const int K_H = conv_p.K[0];
  const int K_W = conv_p.K[1];
  const int pad_h = conv_p.pad[0];
  const int pad_w = conv_p.pad[1];
  const int stride_h = conv_p.stride[0];
  const int stride_w = conv_p.stride[1];
  const int dilation_h = conv_p.dilation[0];
  const int dilation_w = conv_p.dilation[1];

  if (conv_p.transposed) {
    // For the reference implementation there is no padding on the input
    // buffer; padding specifies how much is removed from the output buffer.
    //
    // Stride on the output is a fractional stride on the input. The forward
    // index relation is
    //   h_out = -pad + h_in * stride + r * dilation
    // so here it is inverted: a tap (r, s) contributes a real input element
    // only when (out + pad - tap * dilation) is an exact multiple of the
    // stride and the quotient lies inside the input.
    for (int n = 0; n < MB; ++n) {
      for (int oh = 0; oh < OUT_H; ++oh) {
        for (int ow = 0; ow < OUT_W; ++ow) {
          for (int g = 0; g < G; ++g) {
            for (int oc = 0; oc < OC_per_G; ++oc) {
              int sum = 0;
              for (int r = 0; r < K_H; ++r) {
                const int h = oh + pad_h - r * dilation_h;
                const int h_in = h / stride_h;
                // The multiply-back test rejects non-multiples of the stride
                // (integer division truncates, also for negative h).
                const bool row_hit =
                    h_in * stride_h == h && h_in >= 0 && h_in < IN_H;
                for (int s = 0; s < K_W; ++s) {
                  const int w = ow + pad_w - s * dilation_w;
                  const int w_in = w / stride_w;
                  const bool hit = row_hit && w_in * stride_w == w &&
                      w_in >= 0 && w_in < IN_W;
                  // Only form the activation offset for in-range positions.
                  const int a_base = hit
                      ? ((n * IN_H + h_in) * IN_W + w_in) * IC + g * IC_per_G
                      : 0;
                  const int b_base = ((g * K_H + r) * K_W + s) * IC_per_G;
                  for (int ic = 0; ic < IC_per_G; ++ic) {
                    const int a = hit ? A[a_base + ic] : A_zero_point;
                    // G R S IC OC after transpose. Multiplying by OC / G
                    // (rather than by OC and then dividing by G) gives the
                    // same value when OC % G == 0 while keeping the
                    // intermediate G times smaller.
                    const int b = B[(b_base + ic) * OC_per_G + oc];
                    sum += a * b;
                  } // for each ic
                } // for each s
              } // for each r
              C[((n * OUT_H + oh) * OUT_W + ow) * OC + g * OC_per_G + oc] =
                  sum;
            } // for each oc
          } // for each g
        } // for each ow
      } // for each oh
    } // for each n
  } else {
    for (int n = 0; n < MB; ++n) {
      for (int h = 0; h < OUT_H; ++h) {
        for (int w = 0; w < OUT_W; ++w) {
          for (int g = 0; g < G; ++g) {
            for (int m = 0; m < OC_per_G; ++m) {
              int sum = 0;
              for (int r = 0; r < K_H; ++r) {
                const int h_in = -pad_h + h * stride_h + r * dilation_h;
                const bool row_ok = h_in >= 0 && h_in < IN_H;
                for (int s = 0; s < K_W; ++s) {
                  const int w_in = -pad_w + w * stride_w + s * dilation_w;
                  // Taps that fall in the padding read A_zero_point instead.
                  const bool in_bounds = row_ok && w_in >= 0 && w_in < IN_W;
                  // Only form the activation offset for in-range positions.
                  const int a_base = in_bounds
                      ? ((n * IN_H + h_in) * IN_W + w_in) * IC + g * IC_per_G
                      : 0;
                  const int b_base = ((g * K_H + r) * K_W + s) * IC_per_G;
                  for (int c = 0; c < IC_per_G; ++c) {
                    const int a = in_bounds ? A[a_base + c] : A_zero_point;
                    const int b = B[(b_base + c) * OC_per_G + m];
                    sum += a * b;
                  } // for each c
                } // for each s
              } // for each r
              C[((n * OUT_H + h) * OUT_W + w) * OC + g * OC_per_G + m] = sum;
            } // for each m
          } // for each group
        } // for each w
      } // for each h
    } // for each n
  }
}