odps/tunnel/io/types.py (110 lines of code) (raw):
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ... import types as odps_types
try:
import pyarrow as pa
except (AttributeError, ImportError):
pa = None
# Lookup tables for the types that translate one-to-one between Arrow and
# ODPS. Both tables stay empty when pyarrow could not be imported.
if pa is None:
    _ARROW_TO_ODPS_TYPE = {}
    _ODPS_TO_ARROW_TYPE = {}
else:
    _ARROW_TO_ODPS_TYPE = {
        pa.string(): odps_types.string,
        pa.binary(): odps_types.binary,
        pa.int8(): odps_types.tinyint,
        pa.int16(): odps_types.smallint,
        pa.int32(): odps_types.int_,
        pa.int64(): odps_types.bigint,
        pa.bool_(): odps_types.boolean,
        pa.float32(): odps_types.float_,
        pa.float64(): odps_types.double,
        pa.date32(): odps_types.date,
        pa.timestamp("ms"): odps_types.datetime,
        pa.timestamp("ns"): odps_types.timestamp,
    }
    # The reverse table mirrors the forward one entry for entry ...
    _ODPS_TO_ARROW_TYPE = {
        odps_type: arrow_type
        for arrow_type, odps_type in _ARROW_TO_ODPS_TYPE.items()
    }
    # ... plus timestamp_ntz, which shares the nanosecond timestamp
    # representation and therefore only exists in the ODPS -> Arrow direction.
    _ODPS_TO_ARROW_TYPE[odps_types.timestamp_ntz] = pa.timestamp("ns")
def odps_type_to_arrow_type(odps_type):
    """Convert an ODPS data type into the equivalent pyarrow data type.

    Types listed in ``_ODPS_TO_ARROW_TYPE`` are translated by table lookup.
    Arrays, maps, decimals, structs and day-time intervals are built
    structurally, converting nested element types recursively.

    :param odps_type: ODPS type instance to convert
    :return: the corresponding pyarrow data type
    :raises TypeError: if ``odps_type`` is not one of the handled types
    """
    from ... import types

    if odps_type in _ODPS_TO_ARROW_TYPE:
        return _ODPS_TO_ARROW_TYPE[odps_type]

    if isinstance(odps_type, types.Array):
        return pa.list_(odps_type_to_arrow_type(odps_type.value_type))

    if isinstance(odps_type, types.Map):
        return pa.map_(
            odps_type_to_arrow_type(odps_type.key_type),
            odps_type_to_arrow_type(odps_type.value_type),
        )

    if isinstance(odps_type, types.Decimal):
        precision = odps_type.precision or types.Decimal._default_precision
        # A scale of 0 is a legitimate explicit value, so fall back to the
        # default only when the scale is really unset. (``scale or default``
        # would silently replace 0 with the default scale.)
        scale = odps_type.scale
        if scale is None:
            scale = types.Decimal._default_scale
        if odps_type.precision is None and not hasattr(pa, "decimal256"):
            # Without decimal256 only decimal128 is available, so a defaulted
            # precision is capped at the 38 digits decimal128 accepts.
            precision = min(precision, 38)
        # decimal256 is looked up only when it is actually required, so older
        # pyarrow versions keep working for precisions up to 38.
        decimal_cls = pa.decimal256 if precision > 38 else pa.decimal128
        return decimal_cls(precision, scale)

    if isinstance(odps_type, types.Struct):
        return pa.struct(
            [
                (field_name, odps_type_to_arrow_type(field_type))
                for field_name, field_type in odps_type.field_types.items()
            ]
        )

    if isinstance(odps_type, types.IntervalDayTime):
        # Day-time intervals are represented as a (seconds, nanoseconds) struct.
        return pa.struct([("sec", pa.int64()), ("nano", pa.int32())])

    raise TypeError("Unsupported type: {}".format(odps_type))
def odps_schema_to_arrow_schema(odps_schema):
    """Build a pyarrow schema from the simple columns of an ODPS schema.

    Every entry of ``odps_schema.simple_columns`` becomes one Arrow field
    carrying the column name and its converted type, in the same order.

    :param odps_schema: ODPS schema object exposing ``simple_columns``
    :return: a ``pyarrow.Schema`` with one field per simple column
    :raises TypeError: if a column type has no Arrow equivalent
    """
    fields = [
        pa.field(column.name, odps_type_to_arrow_type(column.type))
        for column in odps_schema.simple_columns
    ]
    return pa.schema(fields)
def arrow_type_to_odps_type(arrow_type):
    """Convert a pyarrow data type into the equivalent ODPS data type.

    Types listed in ``_ARROW_TO_ODPS_TYPE`` are translated by table lookup.
    Lists, maps, decimals and structs are built structurally, converting
    nested element types recursively.

    :param arrow_type: pyarrow data type to convert
    :return: the corresponding ODPS type
    :raises TypeError: if ``arrow_type`` is not one of the handled types
    """
    from ... import types

    if arrow_type in _ARROW_TO_ODPS_TYPE:
        return _ARROW_TO_ODPS_TYPE[arrow_type]

    if isinstance(arrow_type, pa.ListType):
        return types.Array(arrow_type_to_odps_type(arrow_type.value_type))

    if isinstance(arrow_type, pa.MapType):
        return types.Map(
            arrow_type_to_odps_type(arrow_type.key_type),
            arrow_type_to_odps_type(arrow_type.item_type),
        )

    # Decimal256Type only exists in newer pyarrow versions.
    arrow_decimal_types = (pa.Decimal128Type,)
    if hasattr(pa, "Decimal256Type"):
        arrow_decimal_types += (pa.Decimal256Type,)

    if isinstance(arrow_type, arrow_decimal_types):
        precision = arrow_type.precision or types.Decimal._default_precision
        # A scale of 0 is a legitimate value for an Arrow decimal, so fall
        # back to the default only when no scale is available at all.
        # (``scale or default`` would silently replace 0 with the default.)
        scale = arrow_type.scale
        if scale is None:
            scale = types.Decimal._default_scale
        return types.Decimal(precision, scale)

    if isinstance(arrow_type, pa.StructType):
        fields = []
        # Index-based access is kept on purpose; it does not rely on
        # StructType being iterable.
        for idx in range(arrow_type.num_fields):
            field = arrow_type[idx]
            fields.append((field.name, arrow_type_to_odps_type(field.type)))
        return types.Struct(fields)

    raise TypeError("Unsupported type: {}".format(arrow_type))
def arrow_schema_to_odps_schema(arrow_schema):
    """Build an ODPS schema from a pyarrow schema.

    The names and types of ``arrow_schema`` are paired up positionally and
    each pair becomes one ODPS column with the converted type.

    :param arrow_schema: pyarrow schema exposing ``names`` and ``types``
    :return: an ``OdpsSchema`` with one column per Arrow field
    :raises TypeError: if a field type has no ODPS equivalent
    """
    from ... import types

    columns = [
        types.Column(field_name, arrow_type_to_odps_type(field_type))
        for field_name, field_type in zip(arrow_schema.names, arrow_schema.types)
    ]
    return types.OdpsSchema(columns)