in lib/torch_util.py [0:0]
# Accepted string aliases -> torch dtype. Built once at import time instead of
# walking an if/elif chain on every call.
_DTYPE_ALIASES = {
    "float32": th.float32,
    "float": th.float32,
    "float64": th.float64,
    "double": th.float64,
    "float16": th.float16,
    "half": th.float16,
    "bfloat16": th.bfloat16,
    "uint8": th.uint8,
    "int8": th.int8,
    "int16": th.int16,
    "short": th.int16,
    "int32": th.int32,
    "int": th.int32,
    "int64": th.int64,
    "long": th.int64,
    "bool": th.bool,
}


def parse_dtype(x):
    """Convert ``x`` to a ``torch.dtype``.

    Args:
        x: Either a ``torch.dtype`` (returned unchanged) or a string naming a
            dtype. Accepted names are the keys of ``_DTYPE_ALIASES``
            (e.g. ``"float32"``, ``"float"``, ``"half"``, ``"long"``), optionally
            prefixed with ``"torch."`` so that ``parse_dtype(str(d)) is d``
            round-trips for the supported dtypes.

    Returns:
        The corresponding ``torch.dtype``.

    Raises:
        ValueError: If ``x`` is a string that does not name a supported dtype.
        TypeError: If ``x`` is neither a ``torch.dtype`` nor a string.
    """
    if isinstance(x, th.dtype):
        return x
    if not isinstance(x, str):
        raise TypeError(f"cannot parse {type(x)} as dtype")
    name = x
    # str(th.float32) == "torch.float32"; strip the prefix so stringified
    # dtypes (e.g. from configs or logs) parse back to the same dtype.
    if name.startswith("torch."):
        name = name[len("torch."):]
    try:
        return _DTYPE_ALIASES[name]
    except KeyError:
        # Report the caller's original string, not the prefix-stripped one.
        raise ValueError(f"cannot parse {x} as a dtype") from None