in torcharrow/dtypes.py [0:0]
def dtype_of_type(typ: ty.Union[ty.Type, DType]) -> DType:
    """Infer a torcharrow ``DType`` from a Python type annotation.

    Supported inputs:

    * a ``DType`` instance, returned unchanged;
    * ``Tuple[...]`` -> ``Tuple`` of the element dtypes;
    * a namedtuple class whose fields all carry type hints -> ``Struct``;
    * a dataclass -> ``Struct`` with one ``Field`` per dataclass field;
    * ``List[T]`` / ``list[T]`` -> ``List`` of the dtype of ``T``;
    * ``Dict[K, V]`` / ``dict[K, V]`` -> ``Map`` of the key/value dtypes;
    * ``Optional[T]`` (``Union[T, None]``) -> the dtype of ``T`` made nullable;
    * the scalars ``float`` (-> ``float32``), ``int`` (-> ``int64``),
      ``str`` (-> ``string``) and ``bool`` (-> ``boolean``).

    Raises:
        TypeError: if ``typ`` is ``None``, is a container annotation without
            the required type arguments, is a union other than
            ``Optional[T]``, or is otherwise not convertible to a dtype.
    """
    if typ is None:
        raise TypeError("Can't infer dtype from None")
    if isinstance(typ, DType):
        return typ

    if typing_inspect.is_tuple_type(typ):
        return Tuple([dtype_of_type(a) for a in typing_inspect.get_args(typ)])

    # namedtuple classes are tuple subclasses exposing `_fields`; the field
    # types can only be recovered from their annotations.
    if inspect.isclass(typ) and issubclass(typ, tuple) and hasattr(typ, "_fields"):
        fields = typ._fields
        field_types = getattr(typ, "__annotations__", None)
        if field_types is None or any(n not in field_types for n in fields):
            raise TypeError(
                f"Can't infer type from namedtuple without type hints: {typ}"
            )
        return Struct([Field(n, dtype_of_type(field_types[n])) for n in fields])

    if is_dataclass(typ):
        return Struct(
            [Field(f.name, dtype_of_type(f.type)) for f in dataclasses.fields(typ)]
        )

    origin = get_origin(typ)
    # Compare against typing.List / typing.Dict and the builtins, mirroring
    # each other; get_origin() normalizes both spellings to the builtin.
    if origin in (ty.List, list):
        args = get_args(typ)
        if len(args) != 1:
            raise TypeError(
                f"Can't infer dtype from {typ}: list element type is required"
            )
        return List(dtype_of_type(args[0]))
    if origin in (ty.Dict, dict):
        args = get_args(typ)
        if len(args) != 2:
            raise TypeError(
                f"Can't infer dtype from {typ}: dict key and value types are required"
            )
        return Map(dtype_of_type(args[0]), dtype_of_type(args[1]))

    if typing_inspect.is_optional_type(typ):
        # Drop the NoneType member by identity: issubclass() would raise for
        # non-class members such as List[int], and the member order inside
        # the Union is not guaranteed to put None last.
        non_null = [a for a in get_args(typ) if a is not type(None)]
        if len(non_null) != 1:
            raise TypeError(
                f"Can't infer dtype from {typ}: only Optional[T] unions are supported"
            )
        return dtype_of_type(non_null[0]).with_null()

    # same inference rules as for values above
    if typ is float:
        # PyTorch defaults to use Single-precision floating-point format (float32) for Python float type
        return float32
    if typ is int:
        return int64
    if typ is str:
        return string
    if typ is bool:
        return boolean
    raise TypeError(f"Can't infer dtype from {typ}")