def athena2pyarrow()

in awswrangler/_data_types.py [0:0]


def athena2pyarrow(dtype: str, df_type: str | None = None) -> pa.DataType:  # noqa: PLR0911,PLR0912
    """Athena to PyArrow data types conversion.

    Parameters
    ----------
    dtype : str
        Athena type name, e.g. ``"bigint"``, ``"decimal(10,2)"``, ``"array<string>"``,
        ``"struct<a:int,b:string>"`` or ``"map<string,int>"``. The name is lower-cased
        and surrounding whitespace is stripped before matching.
    df_type : str, optional
        Pandas dtype string of the source column. Only consulted for ``timestamp``:
        ``datetime64[ns|us|ms|s]`` selects the matching Arrow unit; any other value
        (including ``None``) falls back to nanoseconds. Not propagated to the element
        types of nested (array/struct/map) types.

    Returns
    -------
    pa.DataType
        The equivalent PyArrow type. Nested types are converted recursively.

    Raises
    ------
    exceptions.UnsupportedType
        If ``dtype`` is not a recognised Athena type.
    """
    dtype = dtype.strip().lower()
    # Copy taken before spaces are removed; the nested-type branches below slice
    # and parse this copy rather than the space-stripped ``dtype``.
    orig_dtype: str = dtype
    dtype = dtype.replace(" ", "")
    if dtype == "tinyint":
        return pa.int8()
    if dtype == "smallint":
        return pa.int16()
    if dtype in ("int", "integer"):
        return pa.int32()
    if dtype == "bigint":
        return pa.int64()
    if dtype in ("float", "real"):
        return pa.float32()
    if dtype == "double":
        return pa.float64()
    if dtype == "boolean":
        return pa.bool_()
    if dtype in ("string", "uuid") or dtype.startswith(("char", "varchar")):
        return pa.string()
    if dtype == "timestamp":
        # Keep the pandas datetime resolution when it is known; default to ns.
        units = {"datetime64[ns]": "ns", "datetime64[us]": "us", "datetime64[ms]": "ms", "datetime64[s]": "s"}
        return pa.timestamp(unit=units.get(df_type or "", "ns"))
    if dtype == "date":
        return pa.date32()
    # Exact membership test: ``("binary" or "varbinary")`` would collapse to the
    # string "binary" and turn this into a substring check (rejecting "varbinary"
    # while accepting "", "a", "bin", ...).
    if dtype in ("binary", "varbinary"):
        return pa.binary()
    if dtype.startswith("decimal"):
        precision, scale = dtype.replace("decimal(", "").replace(")", "").split(sep=",")
        return pa.decimal128(precision=int(precision), scale=int(scale))
    if dtype.startswith("array"):
        # Strip the leading "array<" (6 chars) and the trailing ">".
        return pa.list_(value_type=athena2pyarrow(dtype=orig_dtype[6:-1]), list_size=-1)
    if dtype.startswith("struct"):
        # Strip "struct<" (7 chars) and ">", then convert each "name:type" field.
        return pa.struct(
            [(f.split(":", 1)[0].strip(), athena2pyarrow(f.split(":", 1)[1])) for f in _split_struct(orig_dtype[7:-1])]
        )
    if dtype.startswith("map"):
        # Strip "map<" (4 chars) and ">"; _split_map yields the key and value types.
        parts: list[str] = _split_map(s=orig_dtype[4:-1])
        return pa.map_(athena2pyarrow(parts[0]), athena2pyarrow(parts[1]))
    raise exceptions.UnsupportedType(f"Unsupported Athena type: {dtype}")