odps/df/backends/sqlalchemy/compiler.py (555 lines of code) (raw):

#!/usr/bin/env python # -*- coding: utf-8 -*- # Copyright 1999-2022 Alibaba Group Holding Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import itertools from datetime import timedelta from ..core import Backend from ..errors import CompileError from ...expr.reduction import * from ...expr.arithmetic import * from ...expr.datetimes import * from ...expr.window import * from ...expr.merge import * from ...expr.utils import highest_precedence_data_type from ... import types as df_types from ...utils import is_constant_scalar, traverse_until_source, is_source_collection from ....compat import lzip from .... import utils from . import types try: from sqlalchemy import Table as SATable, select, func, case, \ extract, desc, distinct, literal, join, union, union_all from sqlalchemy.sql.expression import Alias, text # define the compile ext from .ext import * has_sqlalchemy = True except ImportError: has_sqlalchemy = False BINARY_OP = { 'Add': operator.add, 'Substract': operator.sub, 'Multiply': operator.mul, 'Divide': operator.div if six.PY2 else operator.truediv, 'Mod': operator.mod, 'FloorDivide': operator.floordiv, 'Greater': operator.gt, 'GreaterEqual': operator.ge, 'Less': operator.lt, 'LessEqual': operator.le, 'Equal': operator.eq, 'NotEqual': operator.ne, 'And': operator.and_, 'Or': operator.or_ } UNARY_OP = { 'Negate': operator.neg, 'Invert': operator.inv, } DATE_KEY_DIC = { 'Day': 'days', 'Hour': 'hours', 'Minute': 'minutes', 'Second': 'seconds', 'MilliSecond': 'milliseconds', 'MicroSecond': 'microseconds', } MATH_COMPILE_DIC = { 'Abs': 'abs', 'Sqrt': 'sqrt', 'Sin': 'sin', 'Cos': 'cos', 'Tan': 'tan', 'Exp': 'exp', 'Arccos': 'acos', 'Arcsin': 'asin', 'Arctan': 'atan', 'Ceil': 'ceil', 'Floor': 'floor' } WINDOW_COMPILE_DIC = { 'CumSum': 'sum', 'CumMean': 'avg', 'CumStd': 'stddev', 'CumMax': 'max', 'CumMin': 'min', 'CumCount': 'count', 'Lag': 'lag', 'Lead': 'lead', 'Rank': 'rank', 'DenseRank': 'dense_rank', 'PercentRank': 'percent_rank', 'RowNumber': 'row_number' } DATE_PARTS_DIC = { 'Year': 'year', 'Month': 'month', 'Day': 'day', 'Hour': 'hour', 'Minute': 'minute', 'Second': 'second', 'WeekOfYear': 'week', 'DayOfYear': 'doy', 'MicroSecond': 'microseconds', 'UnixTimestamp': 'epoch', } class SQLAlchemyCompiler(Backend): def __init__(self, expr_dag): self._expr_dag = expr_dag self._expr_to_sqlalchemy = ExprDictionary() self._id_gen = itertools.count(1) self._sa_engine = None def _new_alias(self): return 't%s' % next(self._id_gen) def compile(self, expr, traversed=None): if traversed is None: traversed = set() for node in traverse_until_source(expr): if id(node) not in traversed: node.accept(self) traversed.add(id(node)) sa_expr = self._expr_to_sqlalchemy[self._expr_dag.root] if is_source_collection(expr): sa_expr = select([sa_expr]) return sa_expr def _add(self, expr, op): self._expr_to_sqlalchemy[expr] = op def _gen_select_columns(self, fields): sa_exprs = [] for field in fields: if not isinstance(field, Scalar) or field._value is None: sa_exprs.append(self._expr_to_sqlalchemy[field].label(field.name)) else: sa_exprs.append( literal(self._expr_to_sqlalchemy[field]).label(field.name)) return sa_exprs def visit_source_collection(self, expr): table = next(expr.data_source()) if not isinstance(table, SATable): raise ValueError('Source data must be a sqlalchemy table') if table.bind and self._sa_engine is None: self._sa_engine = table.bind self._add(expr, table.alias(self._new_alias())) def visit_project_collection(self, expr): selects = select(self._gen_select_columns(expr._fields))\ .select_from(self._expr_to_sqlalchemy[expr.input]) if expr is not self._expr_dag.root: selects = selects.alias(self._new_alias()) self._add(expr, selects) def visit_apply_collection(self, expr): raise NotImplementedError def visit_filter_collection(self, expr): input = self._expr_to_sqlalchemy[expr.input] predicate = self._expr_to_sqlalchemy[expr._predicate] filtered = input.select(predicate) if expr is not self._expr_dag.root: filtered = filtered.alias(self._new_alias()) self._add(expr, filtered) def visit_slice_collection(self, expr): input = self._expr_to_sqlalchemy[expr.input] sliced = expr._indexes if sliced[2] is not None: raise NotImplementedError if sliced[0] is not None and sliced[0].value < 0: raise CompileError('start number must be greater than 0') if sliced[1] is not None and sliced[1].value <= 0: raise CompileError('end number must be greater than 0') kw = dict() if sliced[0] is not None and sliced[0].value > 0: kw['offset'] = sliced[0].value if sliced[1] is not None and sliced[1].value > 0: kw['limit'] = sliced[1].value input = input.select(**kw) if expr is not self._expr_dag.root: input = input.alias(self._new_alias()) self._add(expr, input) def visit_element_op(self, expr): input = self._expr_to_sqlalchemy.get(expr.input) if isinstance(expr, element.IsNull): sa_expr = input.is_(None) elif isinstance(expr, element.NotNull): sa_expr = input.isnot(None) elif isinstance(expr, element.FillNa): sa_expr = case([(input.is_(None), expr.fill_value)], else_=input) elif isinstance(expr, (element.IsIn, element.NotIn)): op = input.in_ if isinstance(expr, element.IsIn) else input.notin_ if expr._values is None: sa_expr = op([None]) elif len(expr._values) == 1 and isinstance(expr._values[0], SequenceExpr): right = select([self._expr_to_sqlalchemy[expr._values[0]]]) sa_expr = op(right) else: sa_expr = op(tuple(self._expr_to_sqlalchemy[it] for it in expr._values)) elif isinstance(expr, element.Between): if not expr.inclusive: raise NotImplementedError sa_expr = input.between( self._expr_to_sqlalchemy[expr._left], self._expr_to_sqlalchemy[expr._right] ) elif isinstance(expr, element.IfElse): sa_expr = case([(input, self._expr_to_sqlalchemy[expr._then])], else_=self._expr_to_sqlalchemy[expr._else]) elif isinstance(expr, element.Switch): conditions = [self._expr_to_sqlalchemy[cond] for cond in expr._conditions] thens = [self._expr_to_sqlalchemy[then] for then in expr._thens] sa_else = self._expr_to_sqlalchemy[expr._default] \ if expr._default is not None else expr._default if expr._input is None: sa_expr = case(lzip(conditions, thens), else_=sa_else) else: sa_expr = case(dict(lzip(conditions, thens)), value=input, else_=sa_else) else: raise NotImplementedError self._add(expr, sa_expr) def visit_binary_op(self, expr): if isinstance(expr, Power): op = func.pow elif isinstance(expr, FloorDivide): op = operator.div if six.PY2 else operator.truediv elif isinstance(expr, (Add, Substract)) and expr.dtype == df_types.datetime: if isinstance(expr, Add) and \ all(child.dtype == df_types.datetime for child in (expr.lhs, expr.rhs)): raise CompileError('Cannot add two datetimes') if isinstance(expr.rhs, DTScalar) or (isinstance(expr, Add) and expr.lhs, DTScalar): if isinstance(expr.rhs, DTScalar): dt, scalar = expr.lhs, expr.rhs else: dt, scalar = expr.rhs, expr.lhs val = scalar.value if isinstance(expr, Substract): val = -val dt_type = type(scalar).__name__[:-6] sa_dt = self._expr_to_sqlalchemy[dt] try: key = DATE_KEY_DIC[dt_type] except KeyError: raise NotImplementedError if self._sa_engine and self._sa_engine.name == 'mysql': if dt_type == 'MilliSecond': val, dt_type = val * 1000, 'MicroSecond' sa_expr = func.date_add(sa_dt, text('interval %d %s' % (val, dt_type.lower()))) else: sa_expr = sa_dt + timedelta(**{key: val}) self._add(expr, sa_expr) return else: raise NotImplementedError elif isinstance(expr, Substract) and expr._lhs.dtype == df_types.datetime and \ expr._rhs.dtype == df_types.datetime: sa_expr = self._expr_to_sqlalchemy[expr._lhs] - self._expr_to_sqlalchemy[expr._rhs] if self._sa_engine and self._sa_engine.name == 'mysql': sa_expr = func.abs(func.microsecond(sa_expr) .cast(types.df_type_to_sqlalchemy_type(expr.dtype))) / 1000 else: sa_expr = func.abs(extract('MICROSECONDS', sa_expr) .cast(types.df_type_to_sqlalchemy_type(expr.dtype))) / 1000 self._add(expr, sa_expr) return elif isinstance(expr, Mod): lhs, rhs = self._expr_to_sqlalchemy[expr._lhs], self._expr_to_sqlalchemy[expr._rhs] sa_expr = BINARY_OP[expr.node_name](lhs, rhs) if not is_constant_scalar(expr._rhs): sa_expr = case([(rhs > 0, func.abs(sa_expr))], else_=sa_expr) elif expr._rhs.value > 0: sa_expr = func.abs(sa_expr) self._add(expr, sa_expr) return else: op = BINARY_OP[expr.node_name] lhs, rhs = self._expr_to_sqlalchemy[expr._lhs], self._expr_to_sqlalchemy[expr._rhs] sa_expr = op(lhs, rhs) self._add(expr, sa_expr) def visit_unary_op(self, expr): if isinstance(expr, Abs): op = func.abs else: op = UNARY_OP[expr.node_name] self._add(expr, op(self._expr_to_sqlalchemy[expr._input])) def visit_math(self, expr): try: op = getattr(func, MATH_COMPILE_DIC[expr.node_name]) sa_expr = op(self._expr_to_sqlalchemy[expr._input]) except KeyError: if expr.node_name == 'Log': if expr._base is not None: sa_expr = SALog('log', self._expr_to_sqlalchemy[expr._base], self._expr_to_sqlalchemy[expr._base], self._expr_to_sqlalchemy[expr._input]) else: sa_expr = SALog('log', None, self._expr_to_sqlalchemy[expr._input]) elif expr.node_name == 'Log2': sa_expr = SALog('log', 2, 2, self._expr_to_sqlalchemy[expr._input]) sa_expr = sa_expr.cast(types.df_type_to_sqlalchemy_type(expr.dtype)) elif expr.node_name == 'Log10': sa_expr = SALog('log', 10, 10, self._expr_to_sqlalchemy[expr._input]) sa_expr = sa_expr.cast(types.df_type_to_sqlalchemy_type(expr.dtype)) elif expr.node_name == 'Trunc': input = self._expr_to_sqlalchemy[expr._input] decimals = 0 if expr._decimals is None else self._expr_to_sqlalchemy[expr._decimals] sa_expr = SATruncate('trunc', input, decimals) elif expr.node_name == 'Round': decimals = 0 if expr._decimals is None else self._expr_to_sqlalchemy[expr._decimals] sa_expr = func.round(self._expr_to_sqlalchemy[expr._input], decimals) else: raise NotImplementedError self._add(expr, sa_expr) def visit_string_op(self, expr): if isinstance(expr, strings.Capitalize): input = self._expr_to_sqlalchemy[expr._input] tp = types.df_type_to_sqlalchemy_type(expr.dtype) sa_expr = func.upper(func.substr(input, 1, 1)).cast(tp) + \ func.lower(func.substr(input, 2)).cast(tp) elif isinstance(expr, strings.Contains) and not expr.regex: sa_expr = self._expr_to_sqlalchemy[expr._input].contains( self._expr_to_sqlalchemy[expr._pat]) elif isinstance(expr, strings.Endswith): sa_expr = self._expr_to_sqlalchemy[expr._input].endswith( self._expr_to_sqlalchemy[expr._pat]) elif isinstance(expr, strings.Startswith): sa_expr = self._expr_to_sqlalchemy[expr._input].startswith( self._expr_to_sqlalchemy[expr._pat]) elif isinstance(expr, strings.Replace) and not expr.regex: sa_expr = func.replace(self._expr_to_sqlalchemy[expr._input], self._expr_to_sqlalchemy[expr._pat], self._expr_to_sqlalchemy[expr._repl]) elif isinstance(expr, strings.Get): sa_expr = func.substr(self._expr_to_sqlalchemy[expr._input], self._expr_to_sqlalchemy[expr._index] + 1, 1) elif isinstance(expr, strings.Len): sa_expr = func.length(self._expr_to_sqlalchemy[expr._input]) elif isinstance(expr, (strings.Ljust, strings.Rjust, strings.Pad)): if isinstance(expr, strings.Pad): if expr.side == 'both': raise NotImplementedError op = func.lpad if expr.side == 'left' else func.rpad else: op = func.lpad if isinstance(expr, strings.Ljust) else func.rpad sa_expr = op(self._expr_to_sqlalchemy[expr._input], self._expr_to_sqlalchemy[expr._width], self._expr_to_sqlalchemy[expr._fillchar]) elif isinstance(expr, (strings.Lower, strings.Upper)): op = func.lower if isinstance(expr, strings.Lower) else func.upper sa_expr = op(self._expr_to_sqlalchemy[expr._input]) elif isinstance(expr, (strings.Lstrip, strings.Rstrip, strings.Strip)): if expr._to_strip is None: raise NotImplementedError op = func.ltrim if isinstance(expr, strings.Lstrip) else ( func.rtrim if isinstance(expr, strings.Rstrip) else func.btrim ) sa_expr = op(self._expr_to_sqlalchemy[expr._input], self._expr_to_sqlalchemy[expr._to_strip]) elif isinstance(expr, strings.Repeat): sa_expr = func.repeat(self._expr_to_sqlalchemy[expr._input], self._expr_to_sqlalchemy[expr._repeats]) elif isinstance(expr, strings.Slice): if expr.end is None and expr.step is None: sa_expr = func.substr(self._expr_to_sqlalchemy[expr._input], self._expr_to_sqlalchemy[expr._start] + 1) elif isinstance(expr.start, six.integer_types) and \ isinstance(expr.end, six.integer_types) and \ expr.step is None and expr.start > 0 and expr.end > 0: length = expr.end - expr.start sa_expr = func.substr(self._expr_to_sqlalchemy[expr._input], expr.start + 1, length) else: raise NotImplementedError elif isinstance(expr, strings.Title): sa_expr = func.initcap(self._expr_to_sqlalchemy[expr._input]) else: raise NotImplementedError self._add(expr, sa_expr) def visit_datetime_op(self, expr): class_name = type(expr).__name__ input = self._expr_to_sqlalchemy[expr._input] if class_name in DATE_PARTS_DIC: if self._sa_engine and self._sa_engine.name == 'mysql': if class_name == 'UnixTimestamp': fun = func.unix_timestamp else: fun = getattr(func, class_name.lower()) sa_expr = fun(input).cast(types.df_type_to_sqlalchemy_type(expr.dtype)) else: sa_expr = func.date_part(DATE_PARTS_DIC[class_name], input)\ .cast(types.df_type_to_sqlalchemy_type(expr.dtype)) elif isinstance(expr, Date): if self._sa_engine and self._sa_engine.name == 'mysql': sa_expr = func.date(input).cast(types.df_type_to_sqlalchemy_type(expr.dtype)) else: sa_expr = func.date_trunc('day', input) elif isinstance(expr, WeekDay): if self._sa_engine and self._sa_engine.name == 'mysql': sa_expr = (func.dayofweek(input).cast(types.df_type_to_sqlalchemy_type(expr.dtype)) + 5) % 7 else: sa_expr = (func.date_part('dow', input).cast(types.df_type_to_sqlalchemy_type(expr.dtype)) + 6) % 7 else: raise NotImplementedError self._add(expr, sa_expr) def visit_groupby(self, expr): bys, having, aggs, fields = tuple(expr.args[1:]) if fields is None: fields = bys + aggs selects = select(self._gen_select_columns(fields)) if len(fields) == 1 and isinstance(fields[0], (Count, GroupedCount)): selects = selects.select_from(self._expr_to_sqlalchemy[fields[0].input]) grouped = selects.group_by(*self._gen_select_columns(bys)) if having: grouped = grouped.having(self._expr_to_sqlalchemy[having]) if expr is not self._expr_dag.root: grouped = grouped.alias(self._new_alias()) self._add(expr, grouped) def visit_mutate(self, expr): bys, mutates, fields = tuple(expr.args[1:]) if fields is None: fields = bys + mutates selects = select(self._gen_select_columns(fields)) if expr is not self._expr_dag.root: selects = selects.alias(self._new_alias()) self._add(expr, selects) def visit_sort_column(self, expr): if isinstance(expr.input, CollectionExpr): sa_expr = self._expr_to_sqlalchemy[expr.input].c[expr.source_name] else: sa_expr = self._expr_to_sqlalchemy[expr.input] if not expr._ascending: sa_expr = desc(sa_expr) self._add(expr, sa_expr) def visit_sort(self, expr): input = self._expr_to_sqlalchemy[expr.input] sa_expr = input.select(order_by=[self._expr_to_sqlalchemy[e] for e in expr._sorted_fields]) if expr is not self._expr_dag.root: sa_expr = sa_expr.alias(self._new_alias()) self._add(expr, sa_expr) def visit_distinct(self, expr): sa_expr = select(self._gen_select_columns(expr._unique_fields), distinct=True) if expr is not self._expr_dag.root: sa_expr = sa_expr.alias(self._new_alias()) self._add(expr, sa_expr) def visit_column(self, expr): table = self._expr_to_sqlalchemy[expr.input] col = table.c[expr.source_name] if expr._source_data_type != expr._data_type: col = col.cast(types.df_type_to_sqlalchemy_type(expr._data_type)) self._add(expr, col) def visit_reduction(self, expr): if getattr(expr, '_unique', False): raise NotImplementedError input = self._expr_to_sqlalchemy[expr.input] # TODO: MEDIAN does not support if isinstance(expr, (Max, GroupedMax)): f = func.max elif isinstance(expr, (Min, GroupedMin)): f = func.min elif isinstance(expr, (Count, GroupedCount)): f = func.count elif isinstance(expr, (Sum, GroupedSum)): f = func.sum elif isinstance(expr, (Var, GroupedVar)) and expr._ddof in (0, 1): f = func.var_pop if expr._ddof == 0 else func.var_samp elif isinstance(expr, (Std, GroupedStd)) and expr._ddof in (0, 1): f = func.stddev_pop if expr._ddof == 0 else func.stddev_samp elif isinstance(expr, (Mean, GroupedMean)): f = func.avg elif isinstance(expr, (NUnique, GroupedNUnique)): f = lambda *x: func.count(distinct(*x)) elif isinstance(expr, (Cat, GroupedCat)): f = lambda x: func.array_to_string(func.array_agg(x), self._expr_to_sqlalchemy[expr._sep]) else: raise NotImplementedError if isinstance(expr, (Count, GroupedCount)) and \ isinstance(expr.input, CollectionExpr): reduced = f() elif isinstance(expr, (NUnique, GroupedNUnique)): if len(expr.inputs) > 1: raise NotImplementedError reduced = f(*(self._expr_to_sqlalchemy[i] for i in expr.inputs)) else: reduced = f(input) self._add(expr, reduced) def visit_cum_window(self, expr): input = self._expr_to_sqlalchemy[expr._input] if expr._distinct.value is True: raise NotImplementedError try: func_name = WINDOW_COMPILE_DIC[expr.node_name] except KeyError: raise NotImplementedError f = getattr(func, func_name) partition_by = self._gen_select_columns(expr._partition_by) \ if expr._partition_by else None order_by = self._gen_select_columns(expr._order_by) \ if expr._order_by else None rows = (self._expr_to_sqlalchemy[expr._preceding] if expr._preceding else None, self._expr_to_sqlalchemy[expr._following] if expr._following else None) rows = None if all(r is None for r in rows) else rows sa_expr = f(input).over(partition_by=partition_by, order_by=order_by, rows=rows) self._add(expr, sa_expr) def visit_rank_window(self, expr): try: func_name = WINDOW_COMPILE_DIC[expr.node_name] except KeyError: raise NotImplementedError f = getattr(func, func_name) partition_by = self._gen_select_columns(expr._partition_by) \ if expr._partition_by else None order_by = self._gen_select_columns(expr._order_by) \ if expr._order_by else None sa_expr = f().over(partition_by=partition_by, order_by=order_by) if isinstance(expr, PercentRank): sa_expr = sa_expr.cast(types.df_type_to_sqlalchemy_type(expr.dtype)) self._add(expr, sa_expr) def visit_shift_window(self, expr): input = self._expr_to_sqlalchemy[expr._input] try: func_name = WINDOW_COMPILE_DIC[expr.node_name] except KeyError: raise NotImplementedError f = getattr(func, func_name) partition_by = self._gen_select_columns(expr._partition_by) \ if expr._partition_by else None order_by = self._gen_select_columns(expr._order_by) \ if expr._order_by else None args = (input, self._expr_to_sqlalchemy[expr._offset]) if expr._default: args += (literal(self._expr_to_sqlalchemy[expr._default]).cast( types.df_type_to_sqlalchemy_type(expr._input.dtype)),) sa_expr = f(*args).over(partition_by=partition_by, order_by=order_by) self._add(expr, sa_expr) def visit_scalar(self, expr): if expr._value is not None: if expr.dtype == df_types.string: val = utils.to_str(expr.value) \ if isinstance(expr.value, six.text_type) else expr.value self._add(expr, val) return else: self._add(expr, expr._value) else: self._add(expr, None) def visit_cast(self, expr): to_type = types.df_type_to_sqlalchemy_type(expr.dtype) self._add(expr, self._expr_to_sqlalchemy[expr.input].cast(to_type)) def visit_join(self, expr): lhs, rhs = self._expr_to_sqlalchemy[expr._lhs], self._expr_to_sqlalchemy[expr._rhs] if isinstance(expr, RightJoin): lhs, rhs = rhs, lhs on = self._expr_to_sqlalchemy[expr._predicate] kw = dict() if isinstance(expr, OuterJoin): kw['full'] = True elif isinstance(expr, (LeftJoin, RightJoin)): kw['isouter'] = True joined = join(lhs, rhs, onclause=on, **kw) self._add(expr, joined) def visit_union(self, expr): lhs, rhs = self._expr_to_sqlalchemy[expr._lhs], self._expr_to_sqlalchemy[expr._rhs] if is_source_collection(expr._lhs): lhs = select([lhs]) elif isinstance(lhs, Alias): lhs = lhs.element if is_source_collection(expr._rhs): rhs = select([rhs]) elif isinstance(rhs, Alias): rhs = rhs.element method = union if expr._distinct else union_all unioned = method(lhs, rhs) if expr is not self._expr_dag.root: unioned = unioned.alias(self._new_alias()) self._add(expr, unioned)