# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
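"""
Schema conversion utilities between pandas / PyArrow objects and MaxCompute
(ODPS) tables: dtype mappings in both directions, plus construction of
``DataFrameTableMeta`` objects describing how a pandas object is stored as a
table.

A minimal usage sketch (the import path is assumed from this file's location)::

    import pandas as pd
    from maxframe.io.odpsio.schema import pandas_to_odps_schema

    df = pd.DataFrame({"a": [1, 2], "b": ["x", "y"]})
    odps_schema, table_meta = pandas_to_odps_schema(df, unknown_as_string=True)
"""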
import string
from collections import defaultdict
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import pandas as pd
import pyarrow as pa
from odps import types as odps_types
from pandas.api import types as pd_types
from ...core import TILEABLE_TYPE, OutputType
from ...lib.dtypes_extension import ArrowDtype
from ...protocol import DataFrameTableMeta
from ...tensor.core import TENSOR_TYPE
_TEMP_TABLE_PREFIX = "tmp_mf_"
DEFAULT_SINGLE_INDEX_NAME = "_idx_0"
_arrow_to_odps_types = {
pa.string(): odps_types.string,
pa.binary(): odps_types.binary,
pa.int8(): odps_types.tinyint,
pa.int16(): odps_types.smallint,
pa.int32(): odps_types.int_,
pa.int64(): odps_types.bigint,
pa.bool_(): odps_types.boolean,
pa.float32(): odps_types.float_,
pa.float64(): odps_types.double,
pa.date32(): odps_types.date,
pa.timestamp("ms"): odps_types.datetime,
pa.timestamp("us"): odps_types.timestamp,
pa.timestamp("ns"): odps_types.timestamp,
}
_odps_type_to_arrow = {
odps_types.string: pa.string(),
odps_types.binary: pa.binary(),
odps_types.tinyint: pa.int8(),
odps_types.smallint: pa.int16(),
odps_types.int_: pa.int32(),
odps_types.bigint: pa.int64(),
odps_types.boolean: pa.bool_(),
odps_types.float_: pa.float32(),
odps_types.double: pa.float64(),
odps_types.date: pa.date32(),
odps_types.datetime: pa.timestamp("ms"),
odps_types.json: pa.string(),
odps_types.timestamp: pa.timestamp("ns"),
odps_types.timestamp_ntz: pa.timestamp("ns"),
}
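# arrow types that have to be represented with pandas ArrowDtype
# instead of a plain numpy dtype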
_based_for_pandas_pa_dtypes = Union[pa.MapType]
def is_based_for_pandas_dtype(dtype: pa.DataType) -> bool:
"""
    Check whether the arrow type has to be backed by a pandas ArrowDtype.
    If it does, make sure the current environment supports ArrowDtype,
    otherwise an ImportError is raised.
"""
if not isinstance(dtype, _based_for_pandas_pa_dtypes):
return False
if ArrowDtype is None:
raise ImportError("ArrowDtype is not supported in current environment")
return True
def pandas_types_to_arrow_schema(df_obj: pd.DataFrame) -> pa.Schema:
"""
    This is only called when a pandas DataFrame is written to ODPS, so we can
    check here whether ArrowDtype is supported in the current environment.
"""
schema = pa.Schema.from_pandas(df_obj, preserve_index=False)
    for idx, (col_name, col_dtype) in enumerate(df_obj.dtypes.items()):
        if ArrowDtype is not None and isinstance(col_dtype, ArrowDtype):
            # Schema.set returns a new schema instead of mutating in place
            schema = schema.set(idx, pa.field(col_name, col_dtype.pyarrow_dtype))
return schema
def arrow_type_to_odps_type(
arrow_type: pa.DataType, col_name: str, unknown_as_string: bool = False
) -> odps_types.DataType:
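    """
    Map an arrow data type to the corresponding ODPS data type. Nested list,
    map, struct and decimal types are handled recursively; unknown types are
    mapped to string when ``unknown_as_string`` is True, otherwise a TypeError
    is raised.
    """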
if arrow_type in _arrow_to_odps_types:
return _arrow_to_odps_types[arrow_type]
elif isinstance(arrow_type, pa.ListType):
return odps_types.Array(
arrow_type_to_odps_type(arrow_type.value_type, col_name, unknown_as_string)
)
elif isinstance(arrow_type, pa.MapType):
return odps_types.Map(
arrow_type_to_odps_type(arrow_type.key_type, col_name, unknown_as_string),
arrow_type_to_odps_type(arrow_type.item_type, col_name, unknown_as_string),
)
elif isinstance(arrow_type, pa.StructType):
type_dict = {}
for idx in range(arrow_type.num_fields):
field = arrow_type[idx]
type_dict[field.name] = arrow_type_to_odps_type(
field.type, col_name, unknown_as_string
)
return odps_types.Struct(type_dict)
elif isinstance(arrow_type, pa.Decimal128Type):
return odps_types.Decimal(arrow_type.precision, arrow_type.scale)
if unknown_as_string:
return odps_types.string
else:
raise TypeError(
"Unknown type {}, column name is {},"
"specify `unknown_as_string=True` "
"or `as_type` to set column dtype".format(arrow_type, col_name)
)
def arrow_schema_to_odps_schema(
arrow_schema: pa.Schema, unknown_as_string: bool = False
) -> odps_types.OdpsSchema:
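    """Convert an arrow schema into an ODPS table schema, column by column."""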
odps_cols = []
for col_name, col_type in zip(arrow_schema.names, arrow_schema.types):
col_odps_type = arrow_type_to_odps_type(
col_type, col_name, unknown_as_string=unknown_as_string
)
odps_cols.append(odps_types.Column(col_name, col_odps_type))
return odps_types.OdpsSchema(odps_cols)
def odps_type_to_arrow_type(
odps_type: odps_types.DataType, col_name: str
) -> pa.DataType:
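    """
    Map an ODPS data type to the corresponding arrow data type, handling nested
    array, map and struct types as well as decimal, char and varchar.
    """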
if odps_type in _odps_type_to_arrow:
col_type = _odps_type_to_arrow[odps_type]
else:
if isinstance(odps_type, odps_types.Array):
col_type = pa.list_(odps_type_to_arrow_type(odps_type.value_type, col_name))
elif isinstance(odps_type, odps_types.Map):
col_type = pa.map_(
odps_type_to_arrow_type(odps_type.key_type, col_name),
odps_type_to_arrow_type(odps_type.value_type, col_name),
)
elif isinstance(odps_type, odps_types.Struct):
fields = [
(k, odps_type_to_arrow_type(v, col_name))
for k, v in odps_type.field_types.items()
]
col_type = pa.struct(fields)
elif isinstance(odps_type, odps_types.Decimal):
if odps_type.name == "decimal":
                # legacy decimal type declared without precision or scale;
                # the precision and scale follow the internal compat mode
col_type = pa.decimal128(38, 18)
else:
col_type = pa.decimal128(
odps_type.precision or odps_types.Decimal._max_precision,
odps_type.scale or odps_types.Decimal._max_scale,
)
elif isinstance(odps_type, (odps_types.Varchar, odps_types.Char)):
col_type = pa.string()
else:
raise TypeError(
"Unsupported type {}, column name is {}".format(odps_type, col_name)
)
return col_type
def odps_schema_to_arrow_schema(
odps_schema: odps_types.OdpsSchema, with_partitions: bool = False
) -> pa.Schema:
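    """
    Convert an ODPS table schema into an arrow schema, optionally including
    partition columns.
    """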
arrow_schema = []
cols = odps_schema.columns if with_partitions else odps_schema.simple_columns
for col in cols:
col_name = col.name
col_type = odps_type_to_arrow_type(col.type, col_name)
arrow_schema.append(pa.field(col_name, col_type))
return pa.schema(arrow_schema)
def odps_schema_to_pandas_dtypes(
odps_schema: odps_types.OdpsSchema, with_partitions: bool = False
) -> pd.Series:
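    """
    Derive the pandas dtypes of an ODPS schema by converting an empty arrow
    table built from the schema into a pandas DataFrame.
    """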
arrow_schema = odps_schema_to_arrow_schema(
odps_schema, with_partitions=with_partitions
)
return arrow_table_to_pandas_dataframe(arrow_schema.empty_table()).dtypes
def arrow_table_to_pandas_dataframe(
    table: pa.Table, meta: Optional[DataFrameTableMeta] = None
) -> pd.DataFrame:
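    """
    Convert an arrow table to a pandas DataFrame. If ``meta`` is given, column
    dtypes are converted to match the index and column dtypes recorded in it.
    """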
df = table.to_pandas(
types_mapper=lambda x: (
ArrowDtype(x) if is_based_for_pandas_dtype(x) else None
),
ignore_metadata=True,
)
if not meta:
return df
# If meta is passed, we should convert the dtypes to match the ones in the meta
converted_column_dtypes = dict()
for source_dtype, target_col, target_dtype in zip(
df.dtypes.values,
df.columns,
list(meta.pd_index_dtypes.values) + list(meta.pd_column_dtypes.values),
):
if source_dtype != target_dtype:
            # Converting a tz-aware dtype to a tz-naive dtype is a special case.
            # With numpy 1.19 we cannot simply astype to a datetime64 dtype,
            # so drop the timezone information explicitly.
if (
isinstance(source_dtype, pd.DatetimeTZDtype)
and isinstance(target_dtype, np.dtype)
and target_dtype.name.startswith("datetime64")
):
df[target_col] = df[target_col].dt.tz_localize(None)
else:
converted_column_dtypes[target_col] = target_dtype
if converted_column_dtypes:
df = df.astype(converted_column_dtypes)
return df
def pandas_dataframe_to_arrow_table(df: pd.DataFrame, nthreads=1) -> pa.Table:
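    """
    Convert a pandas DataFrame to an arrow table without the index, keeping
    ArrowDtype-backed columns as their underlying arrow types.
    """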
schema = pandas_types_to_arrow_schema(df)
return pa.Table.from_pandas(
df, schema=schema, nthreads=nthreads, preserve_index=False
)
def is_scalar_object(df_obj: Any) -> bool:
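    """Return True if the object is a plain scalar or a zero-dimensional tensor."""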
return (
isinstance(df_obj, TENSOR_TYPE) and df_obj.shape == ()
) or pd_types.is_scalar(df_obj)
def _scalar_as_index(df_obj: Any) -> pd.Index:
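    """Build an empty pandas Index whose dtype matches the given scalar."""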
if isinstance(df_obj, TILEABLE_TYPE):
return pd.Index([], dtype=df_obj.dtype)
else:
return pd.Index([df_obj])[:0]
def pandas_to_odps_schema(
df_obj: Any,
unknown_as_string: bool = False,
ignore_index=False,
) -> Tuple[odps_types.OdpsSchema, DataFrameTableMeta]:
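    """
    Build an ODPS schema and the corresponding ``DataFrameTableMeta`` for a
    pandas or MaxFrame DataFrame, Series, Index or scalar. An empty object with
    the same dtypes is converted to arrow first, and its arrow schema is then
    mapped to ODPS types.
    """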
from ... import dataframe as md
from .arrow import pandas_to_arrow
if is_scalar_object(df_obj):
empty_index = None
elif hasattr(df_obj, "index_value"):
empty_index = df_obj.index_value.to_pandas()[:0]
elif not isinstance(df_obj, pd.Index):
empty_index = df_obj.index[:0]
else:
empty_index = df_obj[:0]
if hasattr(df_obj, "columns_value"):
empty_columns = df_obj.dtypes.index
elif hasattr(df_obj, "columns"):
empty_columns = df_obj.columns
else:
empty_columns = None
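    # record which columns (or, for Series and Index, whether the data itself)
    # use the datetime64[ms] dtype; this is passed on to pandas_to_arrow below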
ms_cols = None
if isinstance(df_obj, (md.DataFrame, pd.DataFrame)):
empty_df_obj = pd.DataFrame(
[], columns=empty_columns, index=empty_index
).astype(df_obj.dtypes)
ms_cols = [
col for col, dt in df_obj.dtypes.items() if dt == np.dtype("datetime64[ms]")
]
elif isinstance(df_obj, (md.Series, pd.Series)):
empty_df_obj = pd.Series([], name=df_obj.name, index=empty_index).astype(
df_obj.dtype
)
ms_cols = df_obj.dtype == np.dtype("datetime64[ms]")
elif isinstance(df_obj, (md.Index, pd.Index)):
empty_df_obj = empty_index
if isinstance(empty_index, pd.MultiIndex):
ms_cols = [
idx
for idx, dt in enumerate(empty_index.dtypes.values)
if dt == np.dtype("datetime64[ms]")
]
else:
ms_cols = df_obj.dtype == np.dtype("datetime64[ms]")
else:
empty_df_obj = df_obj
arrow_data, table_meta = pandas_to_arrow(
empty_df_obj, ignore_index=ignore_index, ms_cols=ms_cols
)
return (
arrow_schema_to_odps_schema(
arrow_data.schema, unknown_as_string=unknown_as_string
),
table_meta,
)
def build_table_column_name(
col_idx: int, pd_col_name: Any, records: Dict[str, str]
) -> str:
"""
Convert column name to MaxCompute acceptable names
Parameters
----------
col_idx:
index of the column
pd_col_name:
column name in pandas
records:
record for existing columns
Returns
-------
converted column name
"""
def _is_col_name_legal(name: str):
if len(name) < 1 or len(name) > 128:
return False
if name[0] not in string.ascii_letters and name[0] != "_":
return False
for ch in name:
if ch not in string.digits and ch not in string.ascii_letters and ch != "_":
return False
return True
try:
return records[pd_col_name]
except KeyError:
pass
if isinstance(pd_col_name, str):
col_name = pd_col_name
elif isinstance(pd_col_name, tuple):
col_name = "_".join(str(x) for x in pd_col_name)
else:
col_name = str(pd_col_name)
col_name = col_name.lower()
if not _is_col_name_legal(col_name):
col_name = f"_column_{col_idx}"
records[pd_col_name] = col_name
return col_name
def build_dataframe_table_meta(
df_obj: Any, ignore_index: bool = False
) -> DataFrameTableMeta:
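    """
    Build a ``DataFrameTableMeta`` describing how a pandas or MaxFrame object
    is stored as a MaxCompute table: the temporary table name, normalized and
    de-duplicated table column names, index column names, and the pandas dtypes
    of the columns and index levels.
    """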
from ... import dataframe as md
    col_to_count = defaultdict(int)
    col_to_idx = defaultdict(int)
pd_col_to_col_name = dict()
if isinstance(df_obj, (md.DataFrame, pd.DataFrame)):
obj_type = OutputType.dataframe
elif isinstance(df_obj, (md.Series, pd.Series)):
obj_type = OutputType.series
elif isinstance(df_obj, (md.Index, pd.Index)):
obj_type = OutputType.index
elif is_scalar_object(df_obj):
obj_type = OutputType.scalar
else: # pragma: no cover
raise TypeError(f"Cannot accept type {type(df_obj)}")
if obj_type == OutputType.scalar:
pd_dtypes = pd.Series([])
column_index_names = []
index_obj = _scalar_as_index(df_obj)
elif obj_type == OutputType.index:
pd_dtypes = pd.Series([])
column_index_names = []
index_obj = df_obj
elif obj_type == OutputType.series:
pd_dtypes = pd.Series([df_obj.dtype], index=[df_obj.name])
column_index_names = [None]
index_obj = df_obj.index
else:
pd_dtypes = df_obj.dtypes
column_index_names = list(pd_dtypes.index.names)
index_obj = df_obj.index
if isinstance(df_obj, TILEABLE_TYPE):
table_name = _TEMP_TABLE_PREFIX + str(df_obj.key)
else:
table_name = None
sql_columns = [None] * len(pd_dtypes)
pd_col_names = pd_dtypes.index
if obj_type == OutputType.series and df_obj.name is None:
        # use a special table column name for an unnamed series
pd_col_names = ["_data"]
for idx, col in enumerate(pd_col_names):
sql_columns[idx] = col_name = build_table_column_name(
idx, col, pd_col_to_col_name
)
col_to_count[col_name] += 1
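    # append numeric suffixes to duplicated table column names so that every
    # generated column name is unique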
final_sql_columns = []
for col in sql_columns:
if col_to_count[col] > 1:
col_name = f"{col}_{col_to_idx[col]}"
col_to_idx[col] += 1
while col_name in col_to_count:
col_name = f"{col}_{col_to_idx[col]}"
col_to_idx[col] += 1
final_sql_columns.append(col_name)
else:
final_sql_columns.append(col)
if hasattr(index_obj, "index_value"):
pd_index_val = index_obj.index_value.to_pandas()
else:
pd_index_val = index_obj
level_dtypes = [
pd_index_val.get_level_values(level).dtype
for level in range(pd_index_val.nlevels)
]
index_dtypes = pd.Series(level_dtypes, index=pd_index_val.names)
if ignore_index and obj_type != OutputType.index:
table_index_column_names = []
pd_index_dtypes = pd.Series([], index=[])
else:
table_index_column_names = [f"_idx_{i}" for i in range(len(index_obj.names))]
pd_index_dtypes = index_dtypes
return DataFrameTableMeta(
table_name=table_name,
type=obj_type,
table_column_names=final_sql_columns,
table_index_column_names=table_index_column_names,
pd_column_dtypes=pd_dtypes,
pd_column_level_names=column_index_names,
pd_index_dtypes=pd_index_dtypes,
)