odps/df/backends/sqlalchemy/engine.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import logging
import types as tps
from decimal import Decimal

from .compiler import SQLAlchemyCompiler
from . import analyzer as ana
from . import rewriter as rwr
from .types import df_schema_to_sqlalchemy_columns
from ..core import ExecuteNode, Engine
from ..errors import CompileError
from ..utils import write_table, refresh_dynamic
from ..frame import ResultFrame
from ..context import context
from ... import DataFrame
from ...utils import is_source_collection, is_constant_scalar
# NB: the star import below also supplies `six`, `types`, `options`,
# `DependencyNotInstalledError` and the expression classes used in this module.
from ...expr.expressions import *
from ...expr.core import ExprDAG
from ...expr.dynamic import DynamicMixin
from ...types import DynamicSchema, Unknown
from ...backends.odpssql.types import df_schema_to_odps_schema, df_type_to_odps_type
from ....types import PartitionSpec
from ....utils import gen_temp_table
from ....models import TableSchema, Partition

logger = logging.getLogger(__name__)


class SQLExecuteNode(ExecuteNode):
    def _sql(self):
        raise NotImplementedError

    def __repr__(self):
        buf = six.StringIO()
        sql = self._sql()
        buf.write('MPP SQL compiled: \n\n')
        buf.write(sql)
        return buf.getvalue()

    def _repr_html_(self):
        buf = six.StringIO()
        sql = self._sql()
        buf.write('<h4>MPP SQL compiled</h4>')
        buf.write('<code>%s</code>' % sql)
        return buf.getvalue()


_engine_to_connections = {}


class SQLAlchemyEngine(Engine):
    def __init__(self, odps=None):
        self._odps = odps

    def _new_execute_node(self, expr_dag):
        node = SQLExecuteNode(expr_dag)

        def _sql(*_):
            return self._compile_sql(expr_dag)

        def verify(*_):
            try:
                _sql(*_)
                return True
            except NotImplementedError:
                return False

        node._sql = tps.MethodType(_sql, node)
        node.verify = tps.MethodType(verify, node)
        return node

    def _new_analyzer(self, expr_dag, on_sub=None):
        return ana.Analyzer(expr_dag, on_sub=on_sub)

    def _new_rewriter(self, expr_dag):
        return rwr.Rewriter(expr_dag)

    def _compile_sql(self, expr_dag):
        self._rewrite(expr_dag)

        sa = self._compile(expr_dag.root, expr_dag)
        return self._sa_to_sql(sa)

    def _sa_to_sql(self, sa):
        try:
            return sa.compile(compile_kwargs={"literal_binds": True})
        except NotImplementedError:
            return sa.compile()

    def _status_ui(self, ui):
        ui.status('Try to execute by sqlalchemy...', clear_keys=True)

    @classmethod
    def _get_or_create_conn(cls, engine):
        if engine in _engine_to_connections:
            return _engine_to_connections[engine]
        conn = engine.connect()
        _engine_to_connections[engine] = conn
        return conn

    def _get_cast_func(self, tp):
        df_types_to_builtin_types = {
            types.Integer: int,
            types.Float: float
        }
        for df_type, builtin_type in six.iteritems(df_types_to_builtin_types):
            if isinstance(tp, df_type):
                return lambda value: builtin_type(value)
        return lambda value: value

    def _get_sa_table(self, table_name, engine, schema):
        import sqlalchemy

        metadata = sqlalchemy.MetaData(bind=engine)
        columns = df_schema_to_sqlalchemy_columns(schema, engine=engine)
        table = sqlalchemy.Table(table_name, metadata, *columns,
                                 extend_existing=True)
        return table
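
    # `head` maps directly to a LIMIT clause; `tail` is emulated by counting
    # the rows first and skipping ahead with OFFSET, since generic SQL has no
    # cheap "last N rows" operation.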
    def _get_result(self, table_name, engine, schema, head, tail):
        import sqlalchemy

        table = self._get_sa_table(table_name, engine, schema)
        sa = sqlalchemy.select([table.alias('t1')])
        if head:
            sa = sa.limit(head)
        elif tail:
            count = sa.alias('t2').count().execute().scalar()
            skip = max(count - tail, 0)
            if skip:
                sa = sa.offset(skip)
        conn = self._get_or_create_conn(engine)
        return conn.execute(sa)

    @classmethod
    def _create_table(cls, table_name, sa, expr):
        from .ext import SACreateTempTableAs

        return SACreateTempTableAs(table_name, sa)

    def _convert_result(self, result, schema):
        res = [list(r) for r in result]
        if len(res) > 0:
            record = res[0]
            for i, r in enumerate(record):
                if schema[i].type != types.decimal and isinstance(r, Decimal):
                    cast = self._get_cast_func(schema[i].type)
                    for r in res:
                        r[i] = cast(r[i])
        return res

    def _run(self, sa, ui, expr_dag, src_expr, progress_proportion=1,
             head=None, tail=None, fetch=True, tmp_table_name=None,
             execution_options=None):
        self._status_ui(ui)

        schema = expr_dag.root.schema
        execution_options = dict() if execution_options is None else execution_options

        try:
            if isinstance(src_expr, (CollectionExpr, SequenceExpr)):
                if tmp_table_name is None:
                    tmp_table_name = gen_temp_table()
                to_execute = self._create_table(tmp_table_name, sa, expr_dag.root)
                logger.info('Sql compiled:\n' + self._sa_to_sql(to_execute))
                conn = self._get_or_create_conn(sa.bind)
                conn.execution_options(**execution_options).execute(to_execute)
                if fetch:
                    result = self._get_result(tmp_table_name, sa.bind, schema, head, tail)
                    res = self._convert_result(result, schema)
                    return tmp_table_name, res
                else:
                    return tmp_table_name, None
            else:
                logger.info('Sql compiled:\n' + self._sa_to_sql(sa))
                conn = self._get_or_create_conn(sa.bind)
                res = conn.execution_options(**execution_options).execute(sa).scalar()
                if src_expr.dtype != types.decimal and isinstance(res, Decimal):
                    return self._get_cast_func(src_expr.dtype)(res)
                return res
        finally:
            ui.inc(progress_proportion)

    def _cache(self, expr_dag, dag, expr, **kwargs):
        if is_source_collection(expr_dag.root) or \
                is_constant_scalar(expr_dag.root):
            return

        execute_dag = ExprDAG(expr_dag.root, dag=expr_dag)

        if isinstance(expr, CollectionExpr):
            table_name = gen_temp_table()
            table = self._get_table(table_name, expr_dag)

            root = expr_dag.root
            sub = CollectionExpr(_source_data=table, _schema=expr.schema)
            sub.add_deps(root)
            expr_dag.substitute(root, sub)

            kw = dict(kwargs)
            execute_node = self._execute(execute_dag, dag, expr,
                                         execute_kw={'fetch': False,
                                                     'temp_table_name': table_name},
                                         **kw)

            def callback(res):
                if isinstance(expr, DynamicMixin):
                    sub._schema = res.schema
                    refresh_dynamic(sub, expr_dag)

            execute_node.callback = callback
        else:
            assert isinstance(expr, Scalar)  # sequence is not cache-able

            class ValueHolder(object):
                pass

            sub = Scalar(_value_type=expr.dtype)
            sub._value = ValueHolder()

            root = expr_dag.root
            sub.add_deps(root)
            expr_dag.substitute(root, sub)

            execute_node = self._execute(execute_dag, dag, expr, **kwargs)

            def callback(res):
                sub._value = res

            execute_node.callback = callback

        return sub, execute_node

    def _compile(self, expr, dag):
        compiler = SQLAlchemyCompiler(dag)
        return compiler.compile(expr)

    def _get_table(self, table_name, expr_dag, bind=None):
        if bind is None:
            bind = next(e for e in expr_dag.traverse()
                        if is_source_collection(e) and
                        e._source_data.bind)._source_data.bind
        return self._get_sa_table(table_name, bind, expr_dag.root.schema)
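
    # Main execution entry: compile the expression DAG into a SQLAlchemy
    # statement, materialize collection results into a temporary table, then
    # fetch the rows and cache the table so later executions can reuse it.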
    def _do_execute(self, expr_dag, expr, ui=None, progress_proportion=1,
                    head=None, tail=None, **kwargs):
        expr_dag = self._convert_table(expr_dag)
        self._rewrite(expr_dag)

        src_expr = expr
        expr = expr_dag.root

        if isinstance(expr, Scalar) and expr.value is not None:
            ui.inc(progress_proportion)
            return expr.value

        sqlalchemy_expr = self._compile(expr, expr_dag)
        fetch = kwargs.pop('fetch', True)
        temp_table_name = kwargs.pop('temp_table_name', None)
        execution_options = kwargs.pop('execution_options',
                                       options.df.sqlalchemy.execution_options)
        result = self._run(sqlalchemy_expr, ui, expr_dag, src_expr,
                           progress_proportion=progress_proportion,
                           head=head, tail=tail, fetch=fetch,
                           tmp_table_name=temp_table_name,
                           execution_options=execution_options)

        if not isinstance(src_expr, Scalar):
            # reset schema
            if isinstance(src_expr, CollectionExpr) and \
                    (isinstance(src_expr._schema, DynamicSchema) or
                     any(isinstance(col.type, Unknown) for col in src_expr._schema.columns)):
                src_expr._schema = expr_dag.root.schema
            table_name, result = result
            table = self._get_table(table_name, expr_dag, sqlalchemy_expr.bind)
            context.cache(src_expr, table)
            if fetch:
                return ResultFrame(result, schema=expr_dag.root.schema)
            else:
                return table
        else:
            context.cache(src_expr, result)
            return result

    def _do_persist(self, expr_dag, expr, name, partitions=None, partition=None,
                    odps=None, project=None, ui=None, progress_proportion=1,
                    execute_percent=0.5, lifecycle=None, overwrite=True,
                    drop_table=False, create_table=True, drop_partition=False,
                    create_partition=False, cast=False, **kwargs):
        expr_dag = self._convert_table(expr_dag)
        self._rewrite(expr_dag)

        src_expr = expr
        expr = expr_dag.root
        odps = odps or self._odps

        try:
            import pandas
        except ImportError:
            raise DependencyNotInstalledError('persist requires pandas')

        df = self._do_execute(expr_dag, src_expr, ui=ui,
                              progress_proportion=progress_proportion * execute_percent,
                              **kwargs)
        schema = TableSchema(columns=df.columns)

        if partitions is not None:
            if drop_partition:
                raise ValueError('Cannot drop partitions when specifying `partitions`')
            if create_partition:
                raise ValueError('Cannot create partitions when specifying `partitions`')

            if isinstance(partitions, tuple):
                partitions = list(partitions)
            if not isinstance(partitions, list):
                partitions = [partitions, ]

            for p in partitions:
                if p not in schema:
                    raise ValueError(
                        'Partition field(%s) does not exist in DataFrame schema' % p)

            schema = df_schema_to_odps_schema(schema)
            columns = [c for c in schema.columns if c.name not in partitions]
            ps = [Partition(name=t, type=schema.get_type(t)) for t in partitions]
            schema = TableSchema(columns=columns, partitions=ps)

            if odps.exist_table(name, project=project) or not create_table:
                t = odps.get_table(name, project=project)
            else:
                t = odps.create_table(name, schema, project=project)
        elif partition is not None:
            if odps.exist_table(name, project=project) or not create_table:
                t = odps.get_table(name, project=project)

                partition = self._get_partition(partition, t)
                if drop_partition:
                    t.delete_partition(partition, if_exists=True)
                if create_partition:
                    t.create_partition(partition, if_not_exists=True)
            else:
                partition = self._get_partition(partition)
                project_obj = odps.get_project(project)
                column_names = [n for n in expr.schema.names if n not in partition]
                column_types = [
                    df_type_to_odps_type(expr.schema[n].type, project=project_obj)
                    for n in column_names
                ]
                partition_names = [n for n in partition.keys]
                partition_types = ['string'] * len(partition_names)
                t = odps.create_table(
                    name,
                    TableSchema.from_lists(
                        column_names, column_types, partition_names, partition_types
                    ),
                    project=project,
                )
                if create_partition is None or create_partition is True:
                    t.create_partition(partition)
        else:
            if drop_partition:
                raise ValueError('Cannot drop partition for non-partition table')
            if create_partition:
                raise ValueError('Cannot create partition for non-partition table')

            if odps.exist_table(name, project=project) or not create_table:
                t = odps.get_table(name, project=project)
                if t.table_schema.partitions:
                    raise CompileError('Cannot insert into partition table %s without specifying '
                                       '`partition` or `partitions`.' % name)
            else:
                t = odps.create_table(name, df_schema_to_odps_schema(schema), project=project)

        write_table(df, t, ui=ui, cast=cast, overwrite=overwrite,
                    partitions=partitions, partition=partition,
                    progress_proportion=progress_proportion * (1 - execute_percent))

        if partition:
            filters = []
            df = DataFrame(t)
            for k in partition.keys:
                filters.append(df[k] == partition[k])
            return df.filter(*filters)
        return DataFrame(t)