mkldnn_memory_format_t GetDefaultFormat()

in src/operator/nn/mkldnn/mkldnn_base.cc [295:360]


/*!
 * \brief Translate the memory format stored in an MKLDNN memory descriptor
 *        into a default format chosen by the descriptor's dimension count.
 *
 * Mapping performed:
 *  - 1 dimension : the descriptor's own format is returned unchanged.
 *  - 2 dimensions: mkldnn_io is rewritten to mkldnn_oi; any other format is
 *                  returned unchanged.
 *  - 4 dimensions: the first group of formats below maps to mkldnn_nchw and
 *                  the second group maps to mkldnn_oihw.
 *  - 5 dimensions: the listed formats map to mkldnn_goihw.
 *
 * A 4- or 5-dimensional format that is not listed, or any other dimension
 * count, is reported through LOG(FATAL).
 *
 * \param desc MKLDNN memory descriptor; only `desc.data.ndims` and
 *             `desc.data.format` are read.
 * \return The default format for `desc`. mkldnn_format_undef is returned only
 *         after LOG(FATAL), i.e. only if LOG(FATAL) returns at all
 *         (NOTE(review): presumably it aborts or throws — confirm against the
 *         logging backend).
 */
mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) {
  const auto ndims = desc.data.ndims;
  const auto fmt = desc.data.format;

  // Rank 1: nothing to translate.
  if (ndims == 1) return fmt;

  // Rank 2: only the io ordering is rewritten (to oi).
  if (ndims == 2) return (fmt == mkldnn_io) ? mkldnn_oi : fmt;

  if (ndims == 4) {
    switch (fmt) {
      // Formats whose default is mkldnn_nchw.
      case mkldnn_nchw: case mkldnn_nhwc: case mkldnn_chwn:
      case mkldnn_nChw8c: case mkldnn_nChw16c:
        return mkldnn_nchw;
      // Formats whose default is mkldnn_oihw.
      case mkldnn_oihw: case mkldnn_ihwo: case mkldnn_hwio:
      case mkldnn_OIhw8i8o: case mkldnn_OIhw16i16o:
      case mkldnn_OIhw4i16o4i: case mkldnn_OIhw8i16o2i:
      case mkldnn_OIhw8o16i2o: case mkldnn_OIhw8o8i:
      case mkldnn_OIhw16o16i: case mkldnn_IOhw16o16i:
      case mkldnn_Oihw8o: case mkldnn_Oihw16o:
      case mkldnn_Ohwi8o: case mkldnn_Ohwi16o:
      case mkldnn_OhIw16o4i:
        return mkldnn_oihw;
      default:
        LOG(FATAL) << "Unknown MKLDNN format for 4 dimensions: " << fmt;
        return mkldnn_format_undef;
    }
  }

  if (ndims == 5) {
    switch (fmt) {
      // Formats whose default is mkldnn_goihw.
      case mkldnn_goihw: case mkldnn_hwigo:
      case mkldnn_gOIhw8i8o: case mkldnn_gOIhw16i16o:
      case mkldnn_gOIhw4i16o4i: case mkldnn_gOIhw8i16o2i:
      case mkldnn_gOIhw8o16i2o: case mkldnn_gOIhw8o8i:
      case mkldnn_gOIhw16o16i: case mkldnn_gIOhw16o16i:
      case mkldnn_gOihw8o: case mkldnn_Goihw8g:
      case mkldnn_gOihw16o: case mkldnn_Goihw16g:
      case mkldnn_gOhwi8o: case mkldnn_gOhwi16o:
      case mkldnn_gOhIw16o4i:
        return mkldnn_goihw;
      default:
        LOG(FATAL) << "Unknown MKLDNN format for 5 dimensions: " << fmt;
        return mkldnn_format_undef;
    }
  }

  // Any other rank is not handled by this mapping.
  LOG(FATAL) << "Unsupported dimensions: " << ndims;
  return mkldnn_format_undef;
}