FBGEMM_API void conv_ref()

in src/RefImplementations.cc [929:1058]


// Reference (straightforward, unoptimized) 3D quantized convolution.
//
// Layouts, as implied by the index arithmetic below:
//   A : activations, N x T x H x W x IC (channels innermost), uint8
//   B : weights, G x Q x R x S x (IC/G) x (OC/G), int8
//   C : output, N x OT x OH x OW x OC (channels innermost), int32
//
// Every kernel tap that does not land on a real input element contributes
// A_zero_point in place of the activation value. When conv_p.transposed is
// set, the transposed (fractionally strided) convolution is evaluated instead
// of the direct one. Each output element is overwritten, not accumulated into.
FBGEMM_API void conv_ref(
    const conv_param_t<3>& conv_p,
    const uint8_t* A,
    int32_t A_zero_point,
    const int8_t* B,
    int32_t* C) {
  const int IC = conv_p.IC;
  const int OC = conv_p.OC;
  const int G = conv_p.G;
  assert(IC % G == 0);
  assert(OC % G == 0);

  // Local copies of the spatial extents (T, H, W order).
  const auto in_dim = conv_p.IN_DIM;
  const auto out_dim = conv_p.OUT_DIM;
  const auto kernel = conv_p.K;

  const auto& pad = conv_p.pad;
  const auto& stride = conv_p.stride;
  const auto& dilation = conv_p.dilation;

  if (conv_p.transposed) {
    // The input buffer itself is never padded here; for the transposed case
    // the padding describes how much is trimmed from the output instead.
    //
    // The direct convolution relates positions by
    //   out = -pad + in * stride + k * dilation
    // and the transposed one inverts that relation: for an output position
    // and a kernel tap, in = (out + pad - k * dilation) / stride, which is a
    // real input element only when the division is exact and in range.
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int ot = 0; ot < out_dim[0]; ++ot) {
        for (int oh = 0; oh < out_dim[1]; ++oh) {
          for (int ow = 0; ow < out_dim[2]; ++ow) {
            // Flattened (n, ot, oh, ow) position in the output.
            const int out_pixel =
                ((n * out_dim[0] + ot) * out_dim[1] + oh) * out_dim[2] + ow;
            for (int g = 0; g < G; ++g) {
              const int ic_per_group = IC / G;
              const int oc_per_group = OC / G;
              for (int oc = 0; oc < oc_per_group; ++oc) {
                int acc = 0;
                for (int q = 0; q < kernel[0]; ++q) {
                  const int t = ot + pad[0] - q * dilation[0];
                  for (int r = 0; r < kernel[1]; ++r) {
                    const int h = oh + pad[1] - r * dilation[1];
                    for (int s = 0; s < kernel[2]; ++s) {
                      const int w = ow + pad[2] - s * dilation[2];
                      const int t_in = t / stride[0];
                      const int h_in = h / stride[1];
                      const int w_in = w / stride[2];
                      // Multiplying back detects inexact divisions (integer
                      // division truncates), so only stride-aligned,
                      // in-range positions read from A.
                      const bool hits_input = t_in * stride[0] == t &&
                          t_in >= 0 && t_in < in_dim[0] &&
                          h_in * stride[1] == h && h_in >= 0 &&
                          h_in < in_dim[1] && w_in * stride[2] == w &&
                          w_in >= 0 && w_in < in_dim[2];
                      // Offset of channel 0 of this group at the input
                      // position; only meaningful when hits_input is true.
                      int a_base = 0;
                      if (hits_input) {
                        a_base = (((n * in_dim[0] + t_in) * in_dim[1] + h_in) *
                                      in_dim[2] +
                                  w_in) *
                                IC +
                            g * ic_per_group;
                      }
                      // Flattened (g, q, r, s) kernel tap.
                      const int tap =
                          ((g * kernel[0] + q) * kernel[1] + r) * kernel[2] +
                          s;
                      for (int ic = 0; ic < ic_per_group; ++ic) {
                        const int a =
                            hits_input ? A[a_base + ic] : A_zero_point;
                        // B is G Q R S Cin/G Cout/G
                        const int b =
                            B[(tap * ic_per_group + ic) * oc_per_group + oc];
                        acc += a * b;
                      } // ic
                    } // s
                  } // r
                } // q
                C[out_pixel * OC + g * oc_per_group + oc] = acc;
              } // oc
            } // g
          } // ow
        } // oh
      } // ot
    } // n
  } else {
    // Direct convolution: in = -pad + out * stride + k * dilation.
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int ot = 0; ot < out_dim[0]; ++ot) {
        for (int oh = 0; oh < out_dim[1]; ++oh) {
          for (int ow = 0; ow < out_dim[2]; ++ow) {
            // Flattened (n, ot, oh, ow) position in the output.
            const int out_pixel =
                ((n * out_dim[0] + ot) * out_dim[1] + oh) * out_dim[2] + ow;
            for (int g = 0; g < G; ++g) {
              const int ic_per_group = IC / G;
              const int oc_per_group = OC / G;
              for (int oc = 0; oc < oc_per_group; ++oc) {
                int acc = 0;
                for (int q = 0; q < kernel[0]; ++q) {
                  const int t_in =
                      -pad[0] + ot * stride[0] + q * dilation[0];
                  const bool t_ok = t_in >= 0 && t_in < in_dim[0];
                  for (int r = 0; r < kernel[1]; ++r) {
                    const int h_in =
                        -pad[1] + oh * stride[1] + r * dilation[1];
                    const bool th_ok = t_ok && h_in >= 0 && h_in < in_dim[1];
                    for (int s = 0; s < kernel[2]; ++s) {
                      const int w_in =
                          -pad[2] + ow * stride[2] + s * dilation[2];
                      // Taps falling into the padding region use the zero
                      // point instead of reading from A.
                      const bool inside =
                          th_ok && w_in >= 0 && w_in < in_dim[2];
                      // Offset of channel 0 of this group at the input
                      // position; only meaningful when inside is true.
                      int a_base = 0;
                      if (inside) {
                        a_base = (((n * in_dim[0] + t_in) * in_dim[1] + h_in) *
                                      in_dim[2] +
                                  w_in) *
                                IC +
                            g * ic_per_group;
                      }
                      // Flattened (g, q, r, s) kernel tap.
                      const int tap =
                          ((g * kernel[0] + q) * kernel[1] + r) * kernel[2] +
                          s;
                      for (int ic = 0; ic < ic_per_group; ++ic) {
                        const int a = inside ? A[a_base + ic] : A_zero_point;
                        const int b =
                            B[(tap * ic_per_group + ic) * oc_per_group + oc];
                        acc += a * b;
                      } // ic
                    } // s
                  } // r
                } // q
                C[out_pixel * OC + g * oc_per_group + oc] = acc;
              } // oc
            } // g
          } // ow
        } // oh
      } // ot
    } // n
  }
}