odps/df/backends/sqlalchemy/compiler.py (555 lines of code) (raw):
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from datetime import timedelta
from ..core import Backend
from ..errors import CompileError
from ...expr.reduction import *
from ...expr.arithmetic import *
from ...expr.datetimes import *
from ...expr.window import *
from ...expr.merge import *
from ...expr.utils import highest_precedence_data_type
from ... import types as df_types
from ...utils import is_constant_scalar, traverse_until_source, is_source_collection
from ....compat import lzip
from .... import utils
from . import types
try:
from sqlalchemy import Table as SATable, select, func, case, \
extract, desc, distinct, literal, join, union, union_all
from sqlalchemy.sql.expression import Alias, text
# define the compile ext
from .ext import *
has_sqlalchemy = True
except ImportError:
has_sqlalchemy = False
# Binary expression node name -> Python operator applied to the two compiled
# operands. ``visit_binary_op`` indexes this table with ``expr.node_name``.
BINARY_OP = {
'Add': operator.add,
'Substract': operator.sub,
'Multiply': operator.mul,
# ``operator.div`` is only referenced when running on Python 2.
'Divide': operator.div if six.PY2 else operator.truediv,
'Mod': operator.mod,
'FloorDivide': operator.floordiv,
'Greater': operator.gt,
'GreaterEqual': operator.ge,
'Less': operator.lt,
'LessEqual': operator.le,
'Equal': operator.eq,
'NotEqual': operator.ne,
'And': operator.and_,
'Or': operator.or_
}
# Unary expression node name -> Python operator applied to the compiled
# operand. ``visit_unary_op`` indexes this table with ``expr.node_name``.
UNARY_OP = {
'Negate': operator.neg,
'Invert': operator.inv,
}
# Date-offset scalar kind -> ``datetime.timedelta`` keyword argument name.
# ``visit_binary_op`` derives the key by dropping the last six characters of
# the offset scalar's class name.
DATE_KEY_DIC = {
'Day': 'days',
'Hour': 'hours',
'Minute': 'minutes',
'Second': 'seconds',
'MilliSecond': 'milliseconds',
'MicroSecond': 'microseconds',
}
# Math node name -> name of the SQL function fetched from ``sqlalchemy.func``
# by ``visit_math``.
MATH_COMPILE_DIC = {
'Abs': 'abs',
'Sqrt': 'sqrt',
'Sin': 'sin',
'Cos': 'cos',
'Tan': 'tan',
'Exp': 'exp',
'Arccos': 'acos',
'Arcsin': 'asin',
'Arctan': 'atan',
'Ceil': 'ceil',
'Floor': 'floor'
}
# Window node name -> name of the SQL function fetched from
# ``sqlalchemy.func`` by the ``visit_*_window`` methods.
WINDOW_COMPILE_DIC = {
'CumSum': 'sum',
'CumMean': 'avg',
'CumStd': 'stddev',
'CumMax': 'max',
'CumMin': 'min',
'CumCount': 'count',
'Lag': 'lag',
'Lead': 'lead',
'Rank': 'rank',
'DenseRank': 'dense_rank',
'PercentRank': 'percent_rank',
'RowNumber': 'row_number'
}
# Datetime accessor class name -> field name handed to ``date_part`` by
# ``visit_datetime_op`` when the engine is not MySQL.
DATE_PARTS_DIC = {
'Year': 'year',
'Month': 'month',
'Day': 'day',
'Hour': 'hour',
'Minute': 'minute',
'Second': 'second',
'WeekOfYear': 'week',
'DayOfYear': 'doy',
'MicroSecond': 'microseconds',
'UnixTimestamp': 'epoch',
}
class SQLAlchemyCompiler(Backend):
    """Backend that compiles a DataFrame expression DAG into SQLAlchemy constructs.

    Each ``visit_*`` method compiles a single expression node and stores the
    result in ``_expr_to_sqlalchemy`` so that parent nodes can look it up.
    """

    def __init__(self, expr_dag):
        """Initialise the compiler state for ``expr_dag``."""
        # No engine is known until a bound source table has been visited.
        self._sa_engine = None
        # Counter backing the generated ``t<N>`` aliases; starts at 1.
        self._id_gen = itertools.count(1)
        # Compiled form of every node visited so far.
        self._expr_to_sqlalchemy = ExprDictionary()
        self._expr_dag = expr_dag
def _new_alias(self):
return 't%s' % next(self._id_gen)
def compile(self, expr, traversed=None):
    """Compile ``expr`` and return the SQLAlchemy object for the DAG root.

    Every node yielded by ``traverse_until_source(expr)`` whose ``id`` is not
    already in ``traversed`` is visited once; ``traversed`` (when supplied) is
    updated in place. If ``expr`` is a source collection the result is
    wrapped in ``select([...])``.
    """
    seen = set() if traversed is None else traversed

    for node in traverse_until_source(expr):
        node_id = id(node)
        if node_id in seen:
            continue
        node.accept(self)
        seen.add(node_id)

    compiled_root = self._expr_to_sqlalchemy[self._expr_dag.root]
    if not is_source_collection(expr):
        return compiled_root
    return select([compiled_root])
def _add(self, expr, op):
self._expr_to_sqlalchemy[expr] = op
def _gen_select_columns(self, fields):
    """Return the compiled form of each field, labelled with the field name.

    A ``Scalar`` carrying a non-``None`` value is wrapped in ``literal(...)``
    before being labelled; every other field is labelled directly.
    """
    def labelled(field):
        compiled = self._expr_to_sqlalchemy[field]
        if isinstance(field, Scalar) and field._value is not None:
            compiled = literal(compiled)
        return compiled.label(field.name)

    return [labelled(field) for field in fields]
def visit_source_collection(self, expr):
    """Register the source table of ``expr`` under a fresh alias.

    The ``bind`` of the first bound table encountered is remembered as the
    compiler's engine.

    Raises:
        ValueError: the data source is not a SQLAlchemy ``Table``.
    """
    source = next(expr.data_source())
    if not isinstance(source, SATable):
        raise ValueError('Source data must be a sqlalchemy table')

    if source.bind and self._sa_engine is None:
        self._sa_engine = source.bind

    aliased = source.alias(self._new_alias())
    self._add(expr, aliased)
def visit_project_collection(self, expr):
    """Compile a projection into a ``SELECT`` over the compiled input.

    The statement is aliased unless ``expr`` is the root of the DAG.
    """
    columns = self._gen_select_columns(expr._fields)
    query = select(columns).select_from(self._expr_to_sqlalchemy[expr.input])

    if expr is self._expr_dag.root:
        self._add(expr, query)
    else:
        self._add(expr, query.alias(self._new_alias()))
def visit_apply_collection(self, expr):
    """Apply collections are not supported by this backend.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
def visit_filter_collection(self, expr):
    """Compile a filter into ``<input>.select(<predicate>)``.

    The statement is aliased unless ``expr`` is the root of the DAG.
    """
    source = self._expr_to_sqlalchemy[expr.input]
    condition = self._expr_to_sqlalchemy[expr._predicate]
    query = source.select(condition)

    if expr is self._expr_dag.root:
        self._add(expr, query)
    else:
        self._add(expr, query.alias(self._new_alias()))
def visit_slice_collection(self, expr):
    """Compile ``collection[start:end]`` into a ``SELECT`` with OFFSET/LIMIT.

    ``expr._indexes`` is read as ``(start, end, step)``, each entry being
    either ``None`` or an object exposing ``.value``. ``start`` becomes the
    OFFSET and the LIMIT is the number of rows between ``start`` and ``end``
    (``end - start``, never negative), so that ``[5:10]`` yields five rows
    rather than ten.

    The statement is aliased unless ``expr`` is the root of the DAG.

    Raises:
        NotImplementedError: a step was given.
        CompileError: ``start`` is negative or ``end`` is not positive.
    """
    source = self._expr_to_sqlalchemy[expr.input]

    sliced = expr._indexes
    if sliced[2] is not None:
        raise NotImplementedError

    start = sliced[0].value if sliced[0] is not None else None
    end = sliced[1].value if sliced[1] is not None else None

    # A start of exactly 0 is accepted; only negative values are rejected.
    if start is not None and start < 0:
        raise CompileError('start number must be greater than 0')
    if end is not None and end <= 0:
        raise CompileError('end number must be greater than 0')

    offset = start or 0
    kw = dict()
    if offset > 0:
        kw['offset'] = offset
    if end is not None:
        # LIMIT counts rows, so the already skipped ``offset`` rows have to
        # be subtracted from the end position.
        kw['limit'] = max(end - offset, 0)

    query = source.select(**kw)
    if expr is not self._expr_dag.root:
        query = query.alias(self._new_alias())
    self._add(expr, query)
def visit_element_op(self, expr):
    """Compile an element-wise operation.

    Handles ``IsNull``, ``NotNull``, ``FillNa``, ``IsIn``/``NotIn``,
    inclusive ``Between``, ``IfElse`` and ``Switch``.

    Raises:
        NotImplementedError: a non-inclusive ``Between`` or any other
            element operation.
    """
    compiled = self._expr_to_sqlalchemy
    operand = compiled.get(expr.input)

    if isinstance(expr, element.IsNull):
        result = operand.is_(None)
    elif isinstance(expr, element.NotNull):
        result = operand.isnot(None)
    elif isinstance(expr, element.FillNa):
        result = case([(operand.is_(None), expr.fill_value)], else_=operand)
    elif isinstance(expr, (element.IsIn, element.NotIn)):
        if isinstance(expr, element.IsIn):
            membership = operand.in_
        else:
            membership = operand.notin_

        values = expr._values
        if values is None:
            candidates = [None]
        elif len(values) == 1 and isinstance(values[0], SequenceExpr):
            # A single sequence operand is compiled to a sub-select.
            candidates = select([compiled[values[0]]])
        else:
            candidates = tuple(compiled[it] for it in values)
        result = membership(candidates)
    elif isinstance(expr, element.Between):
        if not expr.inclusive:
            raise NotImplementedError
        lower = compiled[expr._left]
        upper = compiled[expr._right]
        result = operand.between(lower, upper)
    elif isinstance(expr, element.IfElse):
        then_value = compiled[expr._then]
        else_value = compiled[expr._else]
        result = case([(operand, then_value)], else_=else_value)
    elif isinstance(expr, element.Switch):
        whens = lzip(
            [compiled[cond] for cond in expr._conditions],
            [compiled[then] for then in expr._thens],
        )
        fallback = None if expr._default is None else compiled[expr._default]
        if expr._input is None:
            result = case(whens, else_=fallback)
        else:
            result = case(dict(whens), value=operand, else_=fallback)
    else:
        raise NotImplementedError

    self._add(expr, result)
def visit_binary_op(self, expr):
    """Compile a binary expression.

    Special cases handled before the generic ``BINARY_OP`` dispatch:

    * ``Power`` is compiled to the SQL ``pow`` function.
    * ``FloorDivide`` is compiled with the plain division operator.
    * A datetime-typed ``Add``/``Substract`` with a ``DTScalar`` offset is
      compiled to ``date_add(... interval ...)`` on MySQL and to
      ``<datetime> + timedelta(...)`` otherwise. The offset may be the right
      operand, or (for ``Add`` only) the left operand.
    * ``Substract`` of two datetimes is compiled to the absolute value of
      the extracted microseconds divided by 1000.
    * ``Mod`` wraps the result in ``abs`` when the divisor is positive.

    Raises:
        CompileError: two datetimes are added.
        NotImplementedError: a datetime-typed ``Add``/``Substract`` has no
            ``DTScalar`` offset in a supported position, or the offset kind
            is not in ``DATE_KEY_DIC``.
    """
    if isinstance(expr, Power):
        op = func.pow
    elif isinstance(expr, FloorDivide):
        op = operator.div if six.PY2 else operator.truediv
    elif isinstance(expr, (Add, Substract)) and expr.dtype == df_types.datetime:
        if isinstance(expr, Add) and \
                all(child.dtype == df_types.datetime for child in (expr.lhs, expr.rhs)):
            raise CompileError('Cannot add two datetimes')

        # Locate the offset scalar. Both sides must be tested with
        # ``isinstance``; a bare ``(cond, DTScalar)`` tuple is always truthy.
        if isinstance(expr.rhs, DTScalar):
            dt, scalar = expr.lhs, expr.rhs
        elif isinstance(expr, Add) and isinstance(expr.lhs, DTScalar):
            dt, scalar = expr.rhs, expr.lhs
        else:
            raise NotImplementedError

        val = scalar.value
        if isinstance(expr, Substract):
            val = -val

        # Class name without its last six characters, e.g. ``Day``.
        dt_type = type(scalar).__name__[:-6]
        sa_dt = self._expr_to_sqlalchemy[dt]
        try:
            key = DATE_KEY_DIC[dt_type]
        except KeyError:
            raise NotImplementedError

        if self._sa_engine and self._sa_engine.name == 'mysql':
            if dt_type == 'MilliSecond':
                # Expressed as microseconds for the MySQL interval.
                val, dt_type = val * 1000, 'MicroSecond'
            sa_expr = func.date_add(sa_dt, text('interval %d %s' % (val, dt_type.lower())))
        else:
            sa_expr = sa_dt + timedelta(**{key: val})
        self._add(expr, sa_expr)
        return
    elif isinstance(expr, Substract) and expr._lhs.dtype == df_types.datetime and \
            expr._rhs.dtype == df_types.datetime:
        sa_expr = self._expr_to_sqlalchemy[expr._lhs] - self._expr_to_sqlalchemy[expr._rhs]
        if self._sa_engine and self._sa_engine.name == 'mysql':
            sa_expr = func.abs(func.microsecond(sa_expr)
                               .cast(types.df_type_to_sqlalchemy_type(expr.dtype))) / 1000
        else:
            sa_expr = func.abs(extract('MICROSECONDS', sa_expr)
                               .cast(types.df_type_to_sqlalchemy_type(expr.dtype))) / 1000
        self._add(expr, sa_expr)
        return
    elif isinstance(expr, Mod):
        lhs, rhs = self._expr_to_sqlalchemy[expr._lhs], self._expr_to_sqlalchemy[expr._rhs]
        sa_expr = BINARY_OP[expr.node_name](lhs, rhs)
        if not is_constant_scalar(expr._rhs):
            # Sign of the divisor is only known at query time.
            sa_expr = case([(rhs > 0, func.abs(sa_expr))], else_=sa_expr)
        elif expr._rhs.value > 0:
            sa_expr = func.abs(sa_expr)
        self._add(expr, sa_expr)
        return
    else:
        op = BINARY_OP[expr.node_name]

    lhs, rhs = self._expr_to_sqlalchemy[expr._lhs], self._expr_to_sqlalchemy[expr._rhs]
    sa_expr = op(lhs, rhs)
    self._add(expr, sa_expr)
def visit_unary_op(self, expr):
    """Compile a unary expression.

    ``Abs`` is compiled to the SQL ``abs`` function; every other node is
    looked up in ``UNARY_OP`` by its ``node_name``.
    """
    op = func.abs if isinstance(expr, Abs) else UNARY_OP[expr.node_name]
    operand = self._expr_to_sqlalchemy[expr._input]
    self._add(expr, op(operand))
def visit_math(self, expr):
    """Compile a math expression.

    Node names listed in ``MATH_COMPILE_DIC`` map directly to a SQL function
    of one argument. ``Log``, ``Log2``, ``Log10``, ``Trunc`` and ``Round``
    are handled individually.

    Only the name lookup decides whether a node is supported, so a failure
    while fetching the compiled operand is no longer reported as
    ``NotImplementedError``.

    Raises:
        NotImplementedError: the node name is not supported.
    """
    node_name = expr.node_name
    func_name = MATH_COMPILE_DIC.get(node_name)

    if func_name is not None:
        op = getattr(func, func_name)
        sa_expr = op(self._expr_to_sqlalchemy[expr._input])
    elif node_name == 'Log':
        if expr._base is not None:
            sa_expr = SALog('log', self._expr_to_sqlalchemy[expr._base],
                            self._expr_to_sqlalchemy[expr._base],
                            self._expr_to_sqlalchemy[expr._input])
        else:
            sa_expr = SALog('log', None, self._expr_to_sqlalchemy[expr._input])
    elif node_name == 'Log2':
        sa_expr = SALog('log', 2, 2, self._expr_to_sqlalchemy[expr._input])
        sa_expr = sa_expr.cast(types.df_type_to_sqlalchemy_type(expr.dtype))
    elif node_name == 'Log10':
        sa_expr = SALog('log', 10, 10, self._expr_to_sqlalchemy[expr._input])
        sa_expr = sa_expr.cast(types.df_type_to_sqlalchemy_type(expr.dtype))
    elif node_name == 'Trunc':
        operand = self._expr_to_sqlalchemy[expr._input]
        # Zero decimals when none were specified.
        decimals = 0 if expr._decimals is None else self._expr_to_sqlalchemy[expr._decimals]
        sa_expr = SATruncate('trunc', operand, decimals)
    elif node_name == 'Round':
        decimals = 0 if expr._decimals is None else self._expr_to_sqlalchemy[expr._decimals]
        sa_expr = func.round(self._expr_to_sqlalchemy[expr._input], decimals)
    else:
        raise NotImplementedError

    self._add(expr, sa_expr)
def visit_string_op(self, expr):
    """Compile a string operation.

    Supported: ``Capitalize``, non-regex ``Contains``, ``Endswith``,
    ``Startswith``, non-regex ``Replace``, ``Get``, ``Len``, ``Ljust``,
    ``Rjust``, one-sided ``Pad``, ``Lower``, ``Upper``, ``Lstrip``/``Rstrip``/
    ``Strip`` with explicit characters, ``Repeat``, step-less ``Slice`` and
    ``Title``.

    ``Ljust`` keeps the text on the left and therefore pads on the right
    (``rpad``); ``Rjust`` pads on the left (``lpad``). ``Pad`` follows its
    ``side`` attribute.

    Raises:
        NotImplementedError: any other string operation or unsupported
            argument combination.
    """
    if isinstance(expr, strings.Capitalize):
        operand = self._expr_to_sqlalchemy[expr._input]
        tp = types.df_type_to_sqlalchemy_type(expr.dtype)
        # Upper-cased first character followed by the lower-cased remainder.
        sa_expr = func.upper(func.substr(operand, 1, 1)).cast(tp) + \
            func.lower(func.substr(operand, 2)).cast(tp)
    elif isinstance(expr, strings.Contains) and not expr.regex:
        sa_expr = self._expr_to_sqlalchemy[expr._input].contains(
            self._expr_to_sqlalchemy[expr._pat])
    elif isinstance(expr, strings.Endswith):
        sa_expr = self._expr_to_sqlalchemy[expr._input].endswith(
            self._expr_to_sqlalchemy[expr._pat])
    elif isinstance(expr, strings.Startswith):
        sa_expr = self._expr_to_sqlalchemy[expr._input].startswith(
            self._expr_to_sqlalchemy[expr._pat])
    elif isinstance(expr, strings.Replace) and not expr.regex:
        sa_expr = func.replace(self._expr_to_sqlalchemy[expr._input],
                               self._expr_to_sqlalchemy[expr._pat],
                               self._expr_to_sqlalchemy[expr._repl])
    elif isinstance(expr, strings.Get):
        # SQL ``substr`` positions are 1-based.
        sa_expr = func.substr(self._expr_to_sqlalchemy[expr._input],
                              self._expr_to_sqlalchemy[expr._index] + 1, 1)
    elif isinstance(expr, strings.Len):
        sa_expr = func.length(self._expr_to_sqlalchemy[expr._input])
    elif isinstance(expr, (strings.Ljust, strings.Rjust, strings.Pad)):
        if isinstance(expr, strings.Pad):
            if expr.side == 'both':
                raise NotImplementedError
            op = func.lpad if expr.side == 'left' else func.rpad
        else:
            # Left-justified text receives its fill characters on the right.
            op = func.rpad if isinstance(expr, strings.Ljust) else func.lpad
        sa_expr = op(self._expr_to_sqlalchemy[expr._input],
                     self._expr_to_sqlalchemy[expr._width],
                     self._expr_to_sqlalchemy[expr._fillchar])
    elif isinstance(expr, (strings.Lower, strings.Upper)):
        op = func.lower if isinstance(expr, strings.Lower) else func.upper
        sa_expr = op(self._expr_to_sqlalchemy[expr._input])
    elif isinstance(expr, (strings.Lstrip, strings.Rstrip, strings.Strip)):
        if expr._to_strip is None:
            raise NotImplementedError
        if isinstance(expr, strings.Lstrip):
            op = func.ltrim
        elif isinstance(expr, strings.Rstrip):
            op = func.rtrim
        else:
            op = func.btrim
        sa_expr = op(self._expr_to_sqlalchemy[expr._input],
                     self._expr_to_sqlalchemy[expr._to_strip])
    elif isinstance(expr, strings.Repeat):
        sa_expr = func.repeat(self._expr_to_sqlalchemy[expr._input],
                              self._expr_to_sqlalchemy[expr._repeats])
    elif isinstance(expr, strings.Slice):
        if expr.end is None and expr.step is None:
            sa_expr = func.substr(self._expr_to_sqlalchemy[expr._input],
                                  self._expr_to_sqlalchemy[expr._start] + 1)
        elif isinstance(expr.start, six.integer_types) and \
                isinstance(expr.end, six.integer_types) and \
                expr.step is None and expr.start >= 0 and expr.end >= 0:
            # A start of 0 is a valid slice; an end before the start yields
            # an empty string instead of a negative ``substr`` length.
            length = max(expr.end - expr.start, 0)
            sa_expr = func.substr(self._expr_to_sqlalchemy[expr._input],
                                  expr.start + 1, length)
        else:
            raise NotImplementedError
    elif isinstance(expr, strings.Title):
        sa_expr = func.initcap(self._expr_to_sqlalchemy[expr._input])
    else:
        raise NotImplementedError

    self._add(expr, sa_expr)
def visit_datetime_op(self, expr):
    """Compile a datetime accessor.

    Accessors named in ``DATE_PARTS_DIC`` use the MySQL function of the same
    (lower-cased) name, or ``date_part`` on other engines. ``Date`` and
    ``WeekDay`` are handled separately.

    Raises:
        NotImplementedError: any other datetime operation.
    """
    kind = type(expr).__name__
    operand = self._expr_to_sqlalchemy[expr._input]
    on_mysql = self._sa_engine and self._sa_engine.name == 'mysql'

    def as_result_type(sa_expr):
        # Cast to the SQLAlchemy type matching the expression's dtype.
        return sa_expr.cast(types.df_type_to_sqlalchemy_type(expr.dtype))

    if kind in DATE_PARTS_DIC:
        if on_mysql:
            if kind == 'UnixTimestamp':
                extractor = func.unix_timestamp
            else:
                extractor = getattr(func, kind.lower())
            result = as_result_type(extractor(operand))
        else:
            result = as_result_type(func.date_part(DATE_PARTS_DIC[kind], operand))
    elif isinstance(expr, Date):
        if on_mysql:
            result = as_result_type(func.date(operand))
        else:
            result = func.date_trunc('day', operand)
    elif isinstance(expr, WeekDay):
        if on_mysql:
            result = (as_result_type(func.dayofweek(operand)) + 5) % 7
        else:
            result = (as_result_type(func.date_part('dow', operand)) + 6) % 7
    else:
        raise NotImplementedError

    self._add(expr, result)
def visit_groupby(self, expr):
    """Compile a group-by into ``SELECT ... GROUP BY ... [HAVING ...]``.

    When no explicit fields are given, the grouping keys followed by the
    aggregations are selected. The statement is aliased unless ``expr`` is
    the root of the DAG.
    """
    bys, having, aggs, fields = tuple(expr.args[1:])
    if fields is None:
        fields = bys + aggs

    query = select(self._gen_select_columns(fields))
    only_field_is_count = len(fields) == 1 and \
        isinstance(fields[0], (Count, GroupedCount))
    if only_field_is_count:
        # A lone count gets its FROM clause set explicitly.
        query = query.select_from(self._expr_to_sqlalchemy[fields[0].input])

    group_keys = self._gen_select_columns(bys)
    query = query.group_by(*group_keys)
    if having:
        query = query.having(self._expr_to_sqlalchemy[having])

    if expr is self._expr_dag.root:
        self._add(expr, query)
    else:
        self._add(expr, query.alias(self._new_alias()))
def visit_mutate(self, expr):
    """Compile a mutate into a ``SELECT`` of its fields.

    When no explicit fields are given, the grouping keys followed by the
    mutated columns are selected. The statement is aliased unless ``expr``
    is the root of the DAG.
    """
    bys, mutates, fields = tuple(expr.args[1:])
    columns = bys + mutates if fields is None else fields

    query = select(self._gen_select_columns(columns))
    if expr is self._expr_dag.root:
        self._add(expr, query)
    else:
        self._add(expr, query.alias(self._new_alias()))
def visit_sort_column(self, expr):
    """Compile a sort key, wrapping it in ``desc`` for descending order.

    When the input is a collection, the key is the collection column named
    ``expr.source_name``; otherwise it is the compiled input itself.
    """
    source = self._expr_to_sqlalchemy[expr.input]
    if isinstance(expr.input, CollectionExpr):
        sort_key = source.c[expr.source_name]
    else:
        sort_key = source

    self._add(expr, sort_key if expr._ascending else desc(sort_key))
def visit_sort(self, expr):
    """Compile a sort into ``<input>.select(order_by=[...])``.

    The statement is aliased unless ``expr`` is the root of the DAG.
    """
    source = self._expr_to_sqlalchemy[expr.input]
    ordering = [self._expr_to_sqlalchemy[field] for field in expr._sorted_fields]
    query = source.select(order_by=ordering)

    if expr is self._expr_dag.root:
        self._add(expr, query)
    else:
        self._add(expr, query.alias(self._new_alias()))
def visit_distinct(self, expr):
    """Compile a distinct into ``SELECT DISTINCT`` over the unique fields.

    The statement is aliased unless ``expr`` is the root of the DAG.
    """
    columns = self._gen_select_columns(expr._unique_fields)
    query = select(columns, distinct=True)

    if expr is self._expr_dag.root:
        self._add(expr, query)
    else:
        self._add(expr, query.alias(self._new_alias()))
def visit_column(self, expr):
    """Compile a column reference against its compiled input.

    The column is cast when its declared data type differs from the data
    type of the source column.
    """
    source = self._expr_to_sqlalchemy[expr.input]
    column = source.c[expr.source_name]

    needs_cast = expr._source_data_type != expr._data_type
    if needs_cast:
        column = column.cast(types.df_type_to_sqlalchemy_type(expr._data_type))

    self._add(expr, column)
def visit_reduction(self, expr):
    """Compile an aggregation into the matching SQL aggregate call.

    Raises:
        NotImplementedError: the reduction has ``_unique`` set, is of an
            unsupported kind (including variance/standard deviation with a
            ``_ddof`` other than 0 or 1), or is an ``NUnique`` over more than
            one input.
    """
    if getattr(expr, '_unique', False):
        raise NotImplementedError

    compiled = self._expr_to_sqlalchemy
    operand = compiled[expr.input]

    # TODO: MEDIAN does not support
    if isinstance(expr, (Max, GroupedMax)):
        aggregate = func.max
    elif isinstance(expr, (Min, GroupedMin)):
        aggregate = func.min
    elif isinstance(expr, (Count, GroupedCount)):
        aggregate = func.count
    elif isinstance(expr, (Sum, GroupedSum)):
        aggregate = func.sum
    elif isinstance(expr, (Var, GroupedVar)) and expr._ddof in (0, 1):
        if expr._ddof == 0:
            aggregate = func.var_pop
        else:
            aggregate = func.var_samp
    elif isinstance(expr, (Std, GroupedStd)) and expr._ddof in (0, 1):
        if expr._ddof == 0:
            aggregate = func.stddev_pop
        else:
            aggregate = func.stddev_samp
    elif isinstance(expr, (Mean, GroupedMean)):
        aggregate = func.avg
    elif isinstance(expr, (NUnique, GroupedNUnique)):
        def aggregate(*cols):
            return func.count(distinct(*cols))
    elif isinstance(expr, (Cat, GroupedCat)):
        def aggregate(col):
            return func.array_to_string(func.array_agg(col),
                                        compiled[expr._sep])
    else:
        raise NotImplementedError

    counts_collection = isinstance(expr, (Count, GroupedCount)) and \
        isinstance(expr.input, CollectionExpr)
    if counts_collection:
        # Counting a whole collection takes no argument.
        reduced = aggregate()
    elif isinstance(expr, (NUnique, GroupedNUnique)):
        if len(expr.inputs) > 1:
            raise NotImplementedError
        reduced = aggregate(*(compiled[i] for i in expr.inputs))
    else:
        reduced = aggregate(operand)

    self._add(expr, reduced)
def visit_cum_window(self, expr):
    """Compile a cumulative window function into ``f(input) OVER (...)``.

    Raises:
        NotImplementedError: the window is distinct, or its node name is not
            listed in ``WINDOW_COMPILE_DIC``.
    """
    compiled = self._expr_to_sqlalchemy
    operand = compiled[expr._input]

    if expr._distinct.value is True:
        raise NotImplementedError

    func_name = WINDOW_COMPILE_DIC.get(expr.node_name)
    if func_name is None:
        raise NotImplementedError
    window_func = getattr(func, func_name)

    partition_by = None
    if expr._partition_by:
        partition_by = self._gen_select_columns(expr._partition_by)
    order_by = None
    if expr._order_by:
        order_by = self._gen_select_columns(expr._order_by)

    preceding = compiled[expr._preceding] if expr._preceding else None
    following = compiled[expr._following] if expr._following else None
    # No frame is passed at all when neither bound is present.
    if preceding is None and following is None:
        rows = None
    else:
        rows = (preceding, following)

    result = window_func(operand).over(
        partition_by=partition_by, order_by=order_by, rows=rows)
    self._add(expr, result)
def visit_rank_window(self, expr):
    """Compile a ranking window function into ``f() OVER (...)``.

    The result of ``PercentRank`` is additionally cast to the SQLAlchemy
    type matching the expression's dtype.

    Raises:
        NotImplementedError: the node name is not listed in
            ``WINDOW_COMPILE_DIC``.
    """
    func_name = WINDOW_COMPILE_DIC.get(expr.node_name)
    if func_name is None:
        raise NotImplementedError
    window_func = getattr(func, func_name)

    partition_by = None
    if expr._partition_by:
        partition_by = self._gen_select_columns(expr._partition_by)
    order_by = None
    if expr._order_by:
        order_by = self._gen_select_columns(expr._order_by)

    result = window_func().over(partition_by=partition_by, order_by=order_by)
    if isinstance(expr, PercentRank):
        result = result.cast(types.df_type_to_sqlalchemy_type(expr.dtype))

    self._add(expr, result)
def visit_shift_window(self, expr):
    """Compile a shift window function into ``f(input, offset[, default]) OVER (...)``.

    A truthy ``_default`` is passed as a literal cast to the SQLAlchemy type
    of the input's dtype.

    Raises:
        NotImplementedError: the node name is not listed in
            ``WINDOW_COMPILE_DIC``.
    """
    compiled = self._expr_to_sqlalchemy
    operand = compiled[expr._input]

    func_name = WINDOW_COMPILE_DIC.get(expr.node_name)
    if func_name is None:
        raise NotImplementedError
    window_func = getattr(func, func_name)

    partition_by = None
    if expr._partition_by:
        partition_by = self._gen_select_columns(expr._partition_by)
    order_by = None
    if expr._order_by:
        order_by = self._gen_select_columns(expr._order_by)

    call_args = [operand, compiled[expr._offset]]
    if expr._default:
        default_value = literal(compiled[expr._default]).cast(
            types.df_type_to_sqlalchemy_type(expr._input.dtype))
        call_args.append(default_value)

    result = window_func(*call_args).over(
        partition_by=partition_by, order_by=order_by)
    self._add(expr, result)
def visit_scalar(self, expr):
    """Record the Python value of a scalar as its compiled form.

    A scalar without a value is recorded as ``None``. A string-typed scalar
    holding a text value is passed through ``utils.to_str`` first.
    """
    if expr._value is None:
        compiled = None
    elif expr.dtype == df_types.string:
        raw = expr.value
        compiled = utils.to_str(raw) if isinstance(raw, six.text_type) else raw
    else:
        compiled = expr._value

    self._add(expr, compiled)
def visit_cast(self, expr):
    """Compile a cast of the input to the type matching ``expr.dtype``."""
    target_type = types.df_type_to_sqlalchemy_type(expr.dtype)
    operand = self._expr_to_sqlalchemy[expr.input]
    self._add(expr, operand.cast(target_type))
def visit_join(self, expr):
    """Compile a join of the two compiled inputs on the compiled predicate.

    A ``RightJoin`` is expressed by swapping the operands and joining with
    ``isouter=True``; an ``OuterJoin`` uses ``full=True``; a ``LeftJoin``
    uses ``isouter=True``. Any other join passes no extra options.
    """
    left = self._expr_to_sqlalchemy[expr._lhs]
    right = self._expr_to_sqlalchemy[expr._rhs]
    if isinstance(expr, RightJoin):
        left, right = right, left

    condition = self._expr_to_sqlalchemy[expr._predicate]

    if isinstance(expr, OuterJoin):
        options = {'full': True}
    elif isinstance(expr, (LeftJoin, RightJoin)):
        options = {'isouter': True}
    else:
        options = {}

    self._add(expr, join(left, right, onclause=condition, **options))
def visit_union(self, expr):
    """Compile a union (``UNION`` when distinct, else ``UNION ALL``).

    Each side is prepared first: a source collection is wrapped in
    ``select([...])`` and an ``Alias`` is replaced by its underlying element.
    The statement is aliased unless ``expr`` is the root of the DAG.
    """
    def as_selectable(child, compiled):
        if is_source_collection(child):
            return select([compiled])
        if isinstance(compiled, Alias):
            return compiled.element
        return compiled

    left = self._expr_to_sqlalchemy[expr._lhs]
    right = self._expr_to_sqlalchemy[expr._rhs]
    left = as_selectable(expr._lhs, left)
    right = as_selectable(expr._rhs, right)

    combine = union if expr._distinct else union_all
    combined = combine(left, right)

    if expr is self._expr_dag.root:
        self._add(expr, combined)
    else:
        self._add(expr, combined.alias(self._new_alias()))