def dtype_of_type()

in torcharrow/dtypes.py [0:0]


def dtype_of_type(typ: ty.Union[ty.Type, DType]) -> DType:
    """Infer the TorchArrow ``DType`` corresponding to a Python type annotation.

    Supported forms, checked in this order:

    * a ``DType`` instance is returned unchanged;
    * ``typing.Tuple[A, B, ...]`` -> ``Tuple`` of the element dtypes;
    * a namedtuple class whose fields are all annotated -> ``Struct``;
    * a dataclass -> ``Struct`` with one ``Field`` per dataclass field;
    * a type whose origin is ``list`` (e.g. ``typing.List[T]``) -> ``List``;
    * a type whose origin is ``dict`` (e.g. ``typing.Dict[K, V]``) -> ``Map``;
    * ``Optional[T]`` / ``Union[T, None]`` -> the dtype of ``T``, made nullable;
    * ``float`` -> ``float32``, ``int`` -> ``int64``, ``str`` -> ``string``,
      ``bool`` -> ``boolean``.

    Element, field, key and value types are inferred recursively.

    Note: annotations are consumed as-is. String (postponed / forward-reference)
    annotations are not resolved and therefore fail inference.

    Args:
        typ: the Python type (or an existing ``DType``) to convert.

    Returns:
        The inferred ``DType``.

    Raises:
        TypeError: if ``typ`` is None; if a namedtuple lacks annotations for
            some field; if a list/dict type does not carry exactly one/two type
            arguments (e.g. bare ``typing.List``); if a ``Union`` containing
            None has other than exactly one non-None member; or if no dtype
            mapping exists for ``typ``.
    """
    # Explicit check instead of `assert`, so validation survives `python -O`.
    if typ is None:
        raise TypeError("Can't infer dtype from None")

    if isinstance(typ, DType):
        return typ

    if typing_inspect.is_tuple_type(typ):
        return Tuple([dtype_of_type(a) for a in typing_inspect.get_args(typ)])
    # namedtuple classes: tuple subclasses exposing `_fields`; every field
    # needs an annotation to be convertible.
    if inspect.isclass(typ) and issubclass(typ, tuple) and hasattr(typ, "_fields"):
        fields = typ._fields
        field_types = getattr(typ, "__annotations__", None)
        if field_types is None or any(n not in field_types for n in fields):
            raise TypeError(
                f"Can't infer type from namedtuple without type hints: {typ}"
            )
        return Struct([Field(n, dtype_of_type(field_types[n])) for n in fields])
    if is_dataclass(typ):
        return Struct(
            [Field(f.name, dtype_of_type(f.type)) for f in dataclasses.fields(typ)]
        )

    origin = get_origin(typ)
    # Accept both the `typing` alias and the builtin, whichever form
    # `get_origin` reports for the running Python / helper version.
    if origin in (ty.List, list):
        args = get_args(typ)
        if len(args) != 1:
            raise TypeError(
                f"Can't infer dtype from {typ}: expected exactly one element type"
            )
        return List(dtype_of_type(args[0]))
    if origin in (ty.Dict, dict):
        args = get_args(typ)
        if len(args) != 2:
            raise TypeError(
                f"Can't infer dtype from {typ}: expected key and value types"
            )
        key = dtype_of_type(args[0])
        value = dtype_of_type(args[1])
        return Map(key, value)
    if typing_inspect.is_optional_type(typ):
        # None may sit in any position of the Union (e.g. Union[None, T]).
        # Compare by identity: members such as List[int] are not classes, so
        # issubclass() on them would raise.
        non_null = [a for a in get_args(typ) if a is not type(None)]
        if len(non_null) != 1:
            raise TypeError(
                f"Can't infer dtype from {typ}: only Optional[T] with a single "
                "non-None type is supported"
            )
        return dtype_of_type(non_null[0]).with_null()
    # same inference rules as for values above
    if typ is float:
        # PyTorch defaults to use Single-precision floating-point format (float32) for Python float type
        return float32
    if typ is int:
        return int64
    if typ is str:
        return string
    if typ is bool:
        return boolean
    raise TypeError(f"Can't infer dtype from {typ}")