std::vector<int> genDimFlags()

in functorch/csrc/CompileCache.cpp [79:103]


/// Computes one flag word per dimension describing a tensor's size/stride
/// layout.
///
/// Each entry combines exactly one size category with exactly one stride
/// category:
///   - size:   SIZE_MISSING (size == 0), SIZE_ONE (size == 1), else SIZE_OTHER
///   - stride: tested in this order, first match wins:
///       STRIDE_ZERO                  stride == 0
///       STRIDE_ONE                   stride == 1
///       STRIDE_CONTIGUOUS            stride == stride[d+1] * size[d+1]
///       STRIDE_TRANSPOSED_CONTIGUOUS stride == stride[d-1] * size[d-1] and
///                                    dimension d-1 is not STRIDE_CONTIGUOUS
///       STRIDE_AS_ARG                otherwise
///
/// @param sizes   Size of each dimension.
/// @param strides Stride of each dimension. It must have at least
///                sizes.size() entries (presumably both come from the same
///                tensor — TODO confirm at call sites). A shorter array
///                fails the bounds check in ArrayRef::at() instead of
///                reading out of bounds.
/// @return        One flag value per entry of `sizes`, in the same order.
std::vector<int> genDimFlags(c10::IntArrayRef sizes, c10::IntArrayRef strides) {
  // Keep the count unsigned, matching ArrayRef::size(), so there is no
  // narrowing and no signed/unsigned mix in the comparisons below.
  const size_t nDims = sizes.size();
  std::vector<int> dimflags(nDims);
  for (size_t dim = 0; dim < nDims; ++dim) {
    const int64_t size = sizes[dim];
    // Checked access: the loop bound comes from `sizes`, not `strides`.
    const int64_t stride = strides.at(dim);

    // All properties of a dimension are packed into a uint8.
    uint8_t flag =
        (size == 0 ? SIZE_MISSING : (size == 1 ? SIZE_ONE : SIZE_OTHER));
    if (stride == 0) {
      flag |= STRIDE_ZERO;
    } else if (stride == 1) {
      flag |= STRIDE_ONE;
    } else if (dim + 1 < nDims &&
               stride == strides.at(dim + 1) * sizes[dim + 1]) {
      // Stride is derivable from the next dimension.
      flag |= STRIDE_CONTIGUOUS;
    } else if (dim > 0 && stride == strides[dim - 1] * sizes[dim - 1] &&
               (dimflags[dim - 1] & STRIDE_CONTIGUOUS) == 0) {
      // Stride is derivable from the previous dimension. This is skipped when
      // the previous dimension is already expressed in terms of this one,
      // presumably so that two strides never depend on each other.
      // strides[dim - 1] was bounds-checked on the previous iteration.
      flag |= STRIDE_TRANSPOSED_CONTIGUOUS;
    } else {
      flag |= STRIDE_AS_ARG;
    }
    dimflags[dim] = flag;
  }
  return dimflags;
}