core/maxframe/dataframe/datasource/read_odps_query.py (313 lines of code) (raw):

# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import logging
import re
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
from odps import ODPS
from odps.types import Column, OdpsSchema, validate_data_type
from odps.utils import split_sql_by_semicolon

from ... import opcodes
from ...config import options
from ...core import OutputType
from ...core.graph import DAG
from ...io.odpsio import odps_schema_to_pandas_dtypes
from ...serialization.serializables import (
    AnyField,
    BoolField,
    DictField,
    FieldTypes,
    Int64Field,
    ListField,
    SeriesField,
    StringField,
)
from ...utils import is_empty
from ..utils import parse_index
from .core import ColumnPruneSupportedDataSourceMixin, IncrementalIndexDatasource

logger = logging.getLogger(__name__)

_DEFAULT_ANONYMOUS_COL_PREFIX = "_anon_col_"

# Patterns for the multi-sector ("full") EXPLAIN output.
_EXPLAIN_DEPENDS_REGEX = re.compile(r"([^\s]+) depends on: ([^\n]+)")
_EXPLAIN_JOB_REGEX = re.compile(r"(\S+) is root job")
_EXPLAIN_TASKS_HEADER_REGEX = re.compile(r"In Job ([^:]+):")
_EXPLAIN_ROOT_TASKS_REGEX = re.compile(r"root Tasks: (.+)")
_EXPLAIN_TASK_REGEX = re.compile(r"In Task ([^:]+)")
_EXPLAIN_TASK_SCHEMA_REGEX = re.compile(
    r"In Task ([^:]+)[\S\s]+FS: output: ([^\n #]+)[\s\S]+schema:\s+([\S\s]+)$",
    re.MULTILINE,
)
_EXPLAIN_COLUMN_REGEX = re.compile(r"([^\(]+) \(([^\n]+)\)(?:| AS ([^ ]+))(?:\n|$)")
# Matches engine-generated column names such as ``_c0``.
_ANONYMOUS_COL_REGEX = re.compile(r"^_c(\d+)$")

# Patterns for the single-line ("AdhocSink") EXPLAIN output.
_SIMPLE_SCHEMA_COLS_REGEX = re.compile(r"SELECT (([^:]+:[^, ]+[, ]*)+)FROM")
_SIMPLE_SCHEMA_COL_REGEX = re.compile(r"([^ \.\)]+):([^ ]+)")


@dataclasses.dataclass
class DependencySector:
    """Root names plus ``(upstream, downstream)`` edges parsed from EXPLAIN."""

    roots: List[str]
    dependencies: List[Tuple[str, str]]

    def build_dag(self) -> DAG:
        """Build a DAG containing every root and every dependency edge."""
        dag = DAG()
        for r in self.roots:
            dag.add_node(r)
        for v_from, v_to in self.dependencies:
            dag.add_node(v_from)
            dag.add_node(v_to)
            dag.add_edge(v_from, v_to)
        return dag


@dataclasses.dataclass
class JobsSector(DependencySector):
    """Job-level dependencies; ``jobs`` maps job name to its tasks sector."""

    jobs: Dict[str, "TasksSector"] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
class TasksSector(DependencySector):
    """Task-level dependencies of one job; ``tasks`` maps task name to task."""

    job_name: str
    tasks: Dict[str, "TaskSector"] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
class ColumnSchema:
    """One output column as printed by EXPLAIN (alias may be empty)."""

    column_name: str
    column_type: str
    column_alias: Optional[str]


@dataclasses.dataclass
class TaskSector:
    """A single task with its output target (e.g. ``Screen``) and schema."""

    job_name: str
    task_name: str
    output_target: Optional[str]
    schema: List[ColumnSchema]


def _split_explain_string(explain_string: str) -> List[str]:
    """Split EXPLAIN output into sectors.

    Paragraphs are separated by blank lines; a paragraph that starts with a
    space is treated as a continuation of the previous sector and merged
    back into it.
    """
    parts = explain_string.split("\n\n")
    final_parts = []
    grouped = []
    for part in parts:
        part = part.strip("\n")
        if grouped and not part.startswith(" "):
            final_parts.append("\n\n".join(grouped).strip())
            grouped = []
        grouped.append(part)
    if grouped:
        final_parts.append("\n\n".join(grouped).strip())
    return final_parts


def _find_all_deps(sector: str) -> List[Tuple[str, str]]:
    """Collect ``(dependency, dependent)`` edges from ``X depends on: A, B``."""
    deps = []
    for match in _EXPLAIN_DEPENDS_REGEX.findall(sector):
        descendant = match[0]
        for r in match[1].split(","):
            deps.append((r.strip(), descendant))
    return deps


def _resolve_jobs_sector(sector: str) -> JobsSector:
    """Parse the sector holding ``... is root job`` and job dependencies."""
    match = _EXPLAIN_JOB_REGEX.search(sector)
    roots = [r.strip() for r in match.group(1).split(",")]
    deps = _find_all_deps(sector)
    return JobsSector(roots, deps)


def _resolve_tasks_sector(sector: str) -> TasksSector:
    """Parse an ``In Job X:`` sector listing root tasks and task dependencies."""
    match = _EXPLAIN_ROOT_TASKS_REGEX.search(sector)
    roots = [r.strip() for r in match.group(1).split(",")]

    match = _EXPLAIN_TASKS_HEADER_REGEX.search(sector)
    job_name = match.group(1)

    deps = _find_all_deps(sector)
    return TasksSector(roots, deps, job_name)


def _resolve_task_sector(job_name: str, sector: str) -> TaskSector:
    """Parse an ``In Task X`` sector into its output target and column schema.

    Returns a ``TaskSector`` with no target and an empty schema when the
    sector carries no ``FS: output`` / ``schema:`` section.
    """
    match = _EXPLAIN_TASK_REGEX.match(sector)
    task_name = match.group(1)

    match = _EXPLAIN_TASK_SCHEMA_REGEX.match(sector)
    if match is None:
        return TaskSector(job_name, task_name, None, [])

    out_target = match.group(2)
    out_schema = match.group(3)

    schemas = []
    for match in _EXPLAIN_COLUMN_REGEX.findall(out_schema):
        col_name, data_type, alias = match
        schemas.append(ColumnSchema(col_name.strip(), data_type.strip(), alias.strip()))
    return TaskSector(job_name, task_name, out_target, schemas)


def _select_task_prefix(sector: TasksSector, prefix: str) -> List[TaskSector]:
    """Return the task named ``prefix``, else all tasks named ``prefix_*``."""
    if prefix in sector.tasks:
        return [sector.tasks[prefix]]
    return [v for k, v in sector.tasks.items() if k.startswith(prefix + "_")]


def _parse_full_explain(explain_string: str) -> OdpsSchema:
    """Resolve the query output schema from a multi-sector EXPLAIN output.

    Raises
    ------
    ValueError
        If the sectors are malformed or out of order, if a final task does
        not write to ``Screen``, or if the final tasks do not agree on
        exactly one output schema.
    """
    sectors = _split_explain_string(explain_string)
    jobs_sector = tasks_sector = None

    for sector in sectors:
        if _EXPLAIN_JOB_REGEX.search(sector):
            jobs_sector = _resolve_jobs_sector(sector)
        elif _EXPLAIN_TASKS_HEADER_REGEX.search(sector):
            tasks_sector = _resolve_tasks_sector(sector)
            # explicit check instead of ``assert`` so it survives ``python -O``
            if jobs_sector is None:
                raise ValueError(
                    "Tasks sector found before job dependency sector in explain"
                )
            jobs_sector.jobs[tasks_sector.job_name] = tasks_sector
        elif _EXPLAIN_TASK_REGEX.search(sector):
            if tasks_sector is None:
                raise ValueError("Task sector found before tasks sector in explain")
            task_sector = _resolve_task_sector(tasks_sector.job_name, sector)
            tasks_sector.tasks[task_sector.task_name] = task_sector

    if jobs_sector is None:
        raise ValueError("Cannot detect job dependencies from explain output")

    job_dag = jobs_sector.build_dag()
    # NOTE(review): iter_indep(reverse=True) presumably yields the terminal
    # (output-producing) nodes of the DAG — confirm against core.graph.DAG.
    indep_job_names = list(job_dag.iter_indep(reverse=True))
    # Keyed by the sorted (name, type) tuple itself rather than its hash so
    # two distinct schemas can never collide into one entry.
    schema_signatures = dict()
    for job_name in indep_job_names:
        tasks_sector = jobs_sector.jobs[job_name]
        task_dag = tasks_sector.build_dag()
        indep_task_names = list(task_dag.iter_indep(reverse=True))
        for task_name in indep_task_names:
            for task_sector in _select_task_prefix(tasks_sector, task_name):
                if not task_sector.schema:  # pragma: no cover
                    raise ValueError("Cannot detect output schema")
                if task_sector.output_target != "Screen":
                    raise ValueError("The SQL statement should be an instant query")
                sig_tuples = sorted(
                    (c.column_alias or c.column_name, c.column_type)
                    for c in task_sector.schema
                )
                schema_signatures[tuple(sig_tuples)] = task_sector.schema

    if not schema_signatures:
        raise ValueError("Cannot detect output schema")
    if len(schema_signatures) > 1:
        raise ValueError("Only one final task is allowed in SQL statement")
    schema = next(iter(schema_signatures.values()))
    cols = [
        Column(c.column_alias or c.column_name, validate_data_type(c.column_type))
        for c in schema
    ]
    return OdpsSchema(cols)


def _parse_simple_explain(explain_string: str) -> OdpsSchema:
    """Resolve the output schema from a single ``SELECT col:type, ... FROM`` line."""
    fields_match = _SIMPLE_SCHEMA_COLS_REGEX.search(explain_string)
    if not fields_match:
        raise ValueError("Cannot detect output table schema")

    fields_str = fields_match.group(1)
    cols = []
    for field, type_name in _SIMPLE_SCHEMA_COL_REGEX.findall(fields_str):
        cols.append(Column(field, validate_data_type(type_name.rstrip(","))))
    return OdpsSchema(cols)


def _parse_explained_schema(explain_string: str) -> OdpsSchema:
    """Dispatch to the simple or full EXPLAIN parser by output format."""
    if explain_string.startswith("AdhocSink"):
        return _parse_simple_explain(explain_string)
    else:
        return _parse_full_explain(explain_string)


def _build_explain_sql(sql_stmt: str, no_split: bool = False) -> str:
    """Prefix ``EXPLAIN`` to the last statement of a (multi-statement) script.

    With ``no_split`` the whole text is treated as one statement.
    """
    if no_split:
        return "EXPLAIN " + sql_stmt
    sql_parts = split_sql_by_semicolon(sql_stmt)
    if not sql_parts:
        raise ValueError(f"Cannot explain SQL statement {sql_stmt}")
    sql_parts[-1] = "EXPLAIN " + sql_parts[-1]
    return "\n".join(sql_parts)


class DataFrameReadODPSQuery(
    IncrementalIndexDatasource,
    ColumnPruneSupportedDataSourceMixin,
):
    """Datasource operator producing a DataFrame from a MaxCompute SQL query."""

    _op_type_ = opcodes.READ_ODPS_QUERY

    query = StringField("query")
    # column dtypes of the result; None when the schema was not resolved
    dtypes = SeriesField("dtypes", default=None)
    # columns kept after pruning, set via set_pruned_columns
    columns = AnyField("columns", default=None)
    nrows = Int64Field("nrows", default=None)
    use_arrow_dtype = BoolField("use_arrow_dtype", default=None)
    string_as_binary = BoolField("string_as_binary", default=None)
    index_columns = ListField("index_columns", FieldTypes.string, default=None)
    index_dtypes = SeriesField("index_dtypes", default=None)
    # maps engine-generated column names (``_cN``) to user-facing names
    column_renames = DictField("column_renames", default=None)

    def get_columns(self):
        """Return pruned columns if set, otherwise every column in ``dtypes``."""
        return self.columns or list(self.dtypes.index)

    def set_pruned_columns(self, columns, *, keep_order=None):  # pragma: no cover
        self.columns = columns

    def __call__(self, chunk_bytes=None, chunk_size=None):
        """Create the output DataFrame tileable with an empty index template."""
        if is_empty(self.index_columns):
            index_value = parse_index(pd.RangeIndex(0))
        elif len(self.index_columns) == 1:
            index_value = parse_index(
                pd.Index([], name=self.index_columns[0]).astype(
                    self.index_dtypes.iloc[0]
                )
            )
        else:
            idx = pd.MultiIndex.from_frame(
                pd.DataFrame([], columns=self.index_columns).astype(self.index_dtypes)
            )
            index_value = parse_index(idx)

        if self.dtypes is not None:
            columns_value = parse_index(self.dtypes.index, store_data=True)
            shape = (np.nan, len(self.dtypes))
        else:
            # schema skipped: column count is unknown until execution
            columns_value = None
            shape = (np.nan, np.nan)

        self.output_types = [OutputType.dataframe]
        return self.new_tileable(
            [],
            None,
            shape=shape,
            dtypes=self.dtypes,
            index_value=index_value,
            columns_value=columns_value,
            chunk_bytes=chunk_bytes,
            chunk_size=chunk_size,
        )


def read_odps_query(
    query: str,
    odps_entry: ODPS = None,
    index_col: Union[None, str, List[str]] = None,
    string_as_binary: bool = None,
    sql_hints: Dict[str, str] = None,
    anonymous_col_prefix: str = _DEFAULT_ANONYMOUS_COL_PREFIX,
    skip_schema: bool = False,
    **kw,
):
    """
    Read data from a MaxCompute (ODPS) query into DataFrame.

    Supports specifying some columns as indexes. If not specified, RangeIndex
    will be generated.

    Parameters
    ----------
    query: str
        MaxCompute SQL statement.
    odps_entry: ODPS, optional
        ODPS entry to run the query with. Falls back to the global entry and
        then to one built from environment variables.
    index_col: Union[None, str, List[str]]
        Columns to be specified as indexes.
    string_as_binary: bool, optional
        Whether to convert string columns to binary.
    sql_hints: Dict[str, str], optional
        User specified SQL hints.
    anonymous_col_prefix: str, optional
        Prefix for anonymous columns, '_anon_col_' by default.
    skip_schema: bool, optional
        Skip resolving output schema before execution. Once this is configured,
        the output DataFrame cannot be inputs of other DataFrame operators
        before execution.

    Returns
    -------
    result: DataFrame
        DataFrame read from MaxCompute (ODPS) query

    Raises
    ------
    ValueError
        If no ODPS entry is available, the schema cannot be resolved from
        the EXPLAIN output, or ``index_col`` is given with ``skip_schema``.
    KeyError
        If a column in ``index_col`` is not part of the query output.
    """
    no_split_sql = kw.pop("no_split_sql", False)

    # The global setting may be unset (None); copy so that user hints never
    # mutate the global option object.
    hints = dict(options.sql.settings or {})
    if sql_hints:
        hints.update(sql_hints)

    odps_entry = odps_entry or ODPS.from_global() or ODPS.from_environments()
    if odps_entry is None:
        raise ValueError(
            "Need to provide an odps_entry argument or hold a default ODPS entry."
        )

    if options.session.enable_schema or odps_entry.is_schema_namespace_enabled():
        hints["odps.namespace.schema"] = "true"
        hints["odps.sql.allow.namespace.schema"] = "true"

    hints["odps.sql.submit.mode"] = "script"
    # fixme workaround for multi-stage split process
    hints["odps.sql.object.table.split.by.object.size.enabled"] = "false"

    col_renames = {}
    if not skip_schema:
        explain_stmt = _build_explain_sql(query, no_split=no_split_sql)
        inst = odps_entry.execute_sql(explain_stmt, hints=hints)
        logger.debug("Explain instance ID: %s", inst.id)
        explain_str = list(inst.get_task_results().values())[0]

        try:
            odps_schema = _parse_explained_schema(explain_str)
        except Exception as ex:
            # ``Exception`` (not ``BaseException``) so KeyboardInterrupt and
            # SystemExit propagate unchanged instead of becoming ValueError.
            exc = ValueError(
                f"Failed to obtain schema from SQL explain: {ex!r}"
                f"\nExplain instance ID: {inst.id}"
            )
            raise exc.with_traceback(ex.__traceback__) from None

        new_columns = []
        for col in odps_schema.columns:
            anon_match = _ANONYMOUS_COL_REGEX.match(col.name)
            # only rename ``_cN`` names the user did not spell out in the query
            if anon_match and col.name not in query:
                new_name = anonymous_col_prefix + anon_match.group(1)
                col_renames[col.name] = new_name
                new_columns.append(Column(new_name, col.type))
            else:
                new_columns.append(col)

        dtypes = odps_schema_to_pandas_dtypes(OdpsSchema(new_columns))
    else:
        dtypes = None

    if not index_col:
        index_dtypes = None
    else:
        if dtypes is None:
            raise ValueError("Cannot configure index_col when skip_schema is True")

        if isinstance(index_col, str):
            index_col = [index_col]
        missing_cols = [c for c in index_col if c not in dtypes.index]
        if missing_cols:
            raise KeyError(
                f"index_col {missing_cols} not found in query output columns "
                f"{list(dtypes.index)}"
            )
        index_col_set = set(index_col)
        data_cols = [c for c in dtypes.index if c not in index_col_set]
        idx_dtype_vals = [dtypes[c] for c in index_col]
        col_dtype_vals = [dtypes[c] for c in data_cols]
        index_dtypes = pd.Series(idx_dtype_vals, index=index_col)
        dtypes = pd.Series(col_dtype_vals, index=data_cols)

    chunk_bytes = kw.pop("chunk_bytes", None)
    chunk_size = kw.pop("chunk_size", None)
    op = DataFrameReadODPSQuery(
        query=query,
        dtypes=dtypes,
        use_arrow_dtype=kw.pop("use_arrow_dtype", True),
        string_as_binary=string_as_binary,
        index_columns=index_col,
        index_dtypes=index_dtypes,
        column_renames=col_renames,
        no_split_sql=no_split_sql,
    )
    return op(chunk_bytes=chunk_bytes, chunk_size=chunk_size)