at::ScalarType numpy_dtype_to_aten()

in congestion_control/Utils.cpp [43:66]


// Maps a NumPy type number (the `int` stored in an array's descriptor) to the
// ATen scalar type with the same element width and signedness.
//
// Integer types are resolved through NumPy's fixed-width aliases (NPY_INT32,
// NPY_INT64) rather than the C-type numbers alone, because the width of C
// `long` differs between data models (64-bit on LP64, 32-bit on LLP64). The
// aliases are macros that expand to one of NPY_INT / NPY_LONG / NPY_LONGLONG,
// so they are compared with `if` after the switch: listing them as `case`
// labels next to those C-type numbers would create duplicate case values.
//
// @param dtype NumPy type number.
// @return The corresponding at::ScalarType.
// @throws std::runtime_error if dtype has no supported ATen equivalent.
at::ScalarType numpy_dtype_to_aten(int dtype) {
  switch (dtype) {
  case NPY_DOUBLE:
    return at::kDouble;
  case NPY_FLOAT:
    return at::kFloat;
  case NPY_HALF:
    return at::kHalf;
  case NPY_SHORT:
    return at::kShort;
  case NPY_BYTE:
    return at::kChar;
  case NPY_UBYTE:
    return at::kByte;
  case NPY_BOOL:
    return at::kBool;
  default:
    break; // Integer widths are platform-dependent; handled below.
  }

  // NPY_INT32 aliases NPY_INT where `long` is 64-bit, but NPY_LONG where
  // `long` is 32-bit; checking both covers every 32-bit signed integer.
  if (dtype == NPY_INT || dtype == NPY_INT32) {
    return at::kInt;
  }
  // NPY_INT64 aliases NPY_LONG where `long` is 64-bit, but NPY_LONGLONG
  // where `long` is 32-bit. NPY_LONGLONG is checked explicitly as well so that
  // np.longlong arrays (a distinct type number from np.int64 when `long` is
  // 64-bit) are accepted instead of rejected.
  if (dtype == NPY_LONGLONG || dtype == NPY_INT64) {
    return at::kLong;
  }
  throw std::runtime_error(folly::sformat("Unsupported dtype: {}", dtype));
}