odps/df/backends/odpssql/compiler.py (1,212 lines of code) (raw):
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from datetime import date, datetime
from decimal import Decimal
from ...expr.reduction import *
from ...expr.arithmetic import BinOp, Add, Substract, Power, Invert, Negate, Abs
from ...expr.merge import JoinCollectionExpr, UnionCollectionExpr
from ...expr.element import MappedExpr, Func
from ...expr.window import CumSum, NthValue, QCut
from ...expr.datetimes import DTScalar
from ...expr.collections import RowAppliedCollectionExpr
from ...expr import element, strings, datetimes, composites
from ...utils import is_source_collection, traverse_until_source
from ... import types as df_types
from . import types
from .models import MemCacheReference
from ..core import Backend
from .... import utils
from ....models import Function
from ..errors import CompileError
# Binary expression node name -> operator text placed between the two
# compiled operands. The compiler upper-cases the text before emitting it,
# so 'and' / 'or' become 'AND' / 'OR'.
BINARY_OP_COMPILE_DIC = {
'Add': '+',
'Substract': '-',
'Multiply': '*',
'Divide': '/',
'Greater': '>',
'GreaterEqual': '>=',
'Less': '<',
'LessEqual': '<=',
'Equal': '==',
'NotEqual': '!=',
'And': 'and',
'Or': 'or'
}
# Unary expression node name -> prefix operator written directly in front
# of the compiled operand.
UNARY_OP_COMPILE_DIC = {
'Negate': '-'
}
# Window expression node name -> SQL window function name.
# NOTE(review): not referenced in the visible part of this file; presumably
# consumed by a window visitor further down -- confirm.
WINDOW_COMPILE_DIC = {
'CumSum': 'sum',
'CumMean': 'avg',
'CumMedian': 'median',
'CumStd': 'stddev',
'CumMax': 'max',
'CumMin': 'min',
'CumCount': 'count',
'NthValue': 'nth_value',
'Lag': 'lag',
'Lead': 'lead',
'Rank': 'rank',
'DenseRank': 'dense_rank',
'PercentRank': 'percent_rank',
'RowNumber': 'row_number',
'QCut': 'ntile',
'CumeDist': 'cume_dist',
}
# Math expression node name -> single-argument SQL function name
# (upper-cased by visit_math when the call is rendered).
MATH_COMPILE_DIC = {
'Abs': 'abs',
'Sqrt': 'sqrt',
'Sin': 'sin',
'Sinh': 'sinh',
'Cos': 'cos',
'Cosh': 'cosh',
'Tan': 'tan',
'Tanh': 'tanh',
'Exp': 'exp',
'Arccos': 'acos',
'Arcsin': 'asin',
'Arctan': 'atan',
'Ceil': 'ceil',
'Floor': 'floor'
}
# Datetime part name -> date-part code passed to DATEPART / DATEADD.
# Keys are matched against expression class names (datetime accessors) and
# against DTScalar class names with their last six characters removed.
DATE_PARTS_DIC = {
'Year': 'yyyy',
'Month': 'mm',
'Day': 'dd',
'Hour': 'hh',
'Minute': 'mi',
'Second': 'ss'
}
class OdpsSQLCompiler(Backend):
"""
OdpsSQLCompiler will compile an Expr into an ODPS SQL.
"""
def __init__(self, ctx, indent_size=2, beautify=False):
    """Create a compiler bound to a compilation context.

    :param ctx: context object that stores compiled fragments, collection
        aliases and join hints
    :param indent_size: number of spaces used when indenting nested SQL
    :param beautify: when true, select lists are laid out one field per line
    """
    self._ctx = ctx
    self._beautify = beautify
    self._indent_size = indent_size

    # Fragments compiled ahead of time for `join`, `union` and similar
    # nodes, keyed by the node they belong to.
    self._sub_compiles = defaultdict(list)
    self._union_no_alias = {}

    # While handling a `join` or `union`, every child branch is compiled
    # first and the node's children are then replaced by None so that the
    # following traversal does not visit them again. The callbacks stored
    # here put the original children back once compilation has finished.
    self._callbacks = []

    # Ids of the expressions that have to be served from the memory cache.
    self._mem_ref_caches = set()

    self._re_init()
def _re_init(self):
    """Clear every clause slot of the statement currently being built."""
    for slot in (
        '_select_clause',
        '_from_clause',
        '_where_clause',
        '_group_by_clause',
        '_having_clause',
        '_order_by_clause',
        '_table_sample_clause',
        '_limit',
    ):
        setattr(self, slot, None)
def _cleanup(self):
    """Drop the pre-compiled fragments and run the queued restore callbacks."""
    self._sub_compiles = defaultdict(list)
    pending = self._callbacks
    for restore in pending:
        restore()
    # Only forget the callbacks once all of them have run.
    self._callbacks = []
@classmethod
def _need_recursive_handle_in_expr(cls, node):
    """Tell whether `node` is an IN / NOT IN test against a sequence.

    Such a node needs its right-hand side compiled as a sub-query. Nodes
    whose arguments have all been replaced by None were already handled.
    """
    if not isinstance(node, (element.IsIn, element.NotIn)):
        return False
    if all(arg is None for arg in node.args):
        return False
    return isinstance(node.values[0], SequenceExpr)
@classmethod
def _retrieve_until_find_root(cls, expr):
    """Yield the nodes under `expr` that must be compiled ahead of time.

    These are join / union / lateral-view collections that still have
    their children attached, plus IN / NOT IN children whose values are a
    sequence expression.
    """
    for node in traverse_until_source(expr, top_down=True, unique=True):
        is_collection_root = isinstance(
            node, (JoinCollectionExpr, UnionCollectionExpr, LateralViewCollectionExpr)
        ) and any(arg is not None for arg in node.args)
        if is_collection_root:
            yield node
            continue
        # Decide on all children before yielding any of them: the consumer
        # mutates nodes between two yields.
        in_nodes = [child for child in node.children()
                    if cls._need_recursive_handle_in_expr(child)]
        for in_node in in_nodes:
            yield in_node
def _compile_union_node(self, expr, traversed):
    """Compile both branches of a union and detach them from `expr`.

    The SQL of the left and right branch is appended, in that order, to
    ``self._sub_compiles[expr]``. Afterwards every argument of `expr` is
    set to None so the main traversal skips the branches; a callback that
    restores the arguments is queued on ``self._callbacks``.
    """
    if isinstance(expr.lhs, UnionCollectionExpr):
        # A union nested on the left side is flagged in _union_no_alias.
        self._union_no_alias[expr.lhs] = True

    for side in ('lhs', 'rhs'):
        branch_sql = self._compile(getattr(expr, side))
        self._sub_compiles[expr].append(branch_sql)

    saved_args = expr.args
    self._ctx._expr_raw_args[id(expr)] = saved_args
    for attr in expr._args:
        setattr(expr, attr, None)

    def restore():
        for attr, value in zip(expr._args, saved_args):
            setattr(expr, attr, value)

    self._callbacks.append(restore)
# Compile both sides of a join (and its predicate, if any) ahead of the main
# traversal. The fragments are appended to self._sub_compiles[expr] in the
# order [lhs, rhs, predicate-or-None]. Afterwards every argument of `expr`
# is replaced by None so that later traversal does not descend into the
# join's children; a callback restoring them is queued on self._callbacks.
def _compile_join_node(self, expr, traversed):
# ids of the nodes visited while compiling the two sides
travs = set()
compiled, trav = self._compile(expr.lhs, return_traversed=True)
travs.update(trav)
# A left side that is neither a source table nor itself a join is wrapped
# into an aliased sub-query; otherwise its already-compiled text is reused.
if not is_source_collection(expr.lhs) and not isinstance(expr.lhs, JoinCollectionExpr):
self._sub_compiles[expr].append(
'(\n{0}\n) {1}'.format(utils.indent(compiled, self._indent_size),
self._ctx.get_collection_alias(expr.lhs, create=True)[0])
)
else:
self._sub_compiles[expr].append(self._ctx.get_expr_compiled(expr.lhs))
compiled, trav = self._compile(expr.rhs, return_traversed=True)
travs.update(trav)
# The right side is wrapped into a sub-query unless it is a source table.
if not is_source_collection(expr.rhs):
self._sub_compiles[expr].append(
'(\n{0}\n) {1}'.format(utils.indent(compiled, self._indent_size),
self._ctx.get_collection_alias(expr.rhs, create=True)[0])
)
else:
self._sub_compiles[expr].append(self._ctx.get_expr_compiled(expr.rhs))
# Third slot: the compiled join condition, or None when there is none.
if expr.predicate is None:
self._sub_compiles[expr].append(None)
traversed.update(travs)
else:
self._compile(expr.predicate, traversed)
self._sub_compiles[expr].append(self._ctx.get_expr_compiled(expr.predicate))
# Join hints are collected on the context under the alias of the right
# side; add_select_clause later renders them into a /*+ ... */ comment.
if expr._mapjoin:
self._ctx._mapjoin_hints.append(self._ctx.get_collection_alias(expr.rhs)[0])
if expr._skewjoin:
skewjoin_expr = self._ctx.get_collection_alias(expr.rhs)[0]
# optional column list, then optional list of value tuples
if isinstance(expr._skewjoin, list):
skewjoin_expr += '({0})'.format(', '.join(expr._skewjoin))
if expr._skewjoin_values:
skewjoin_expr += '({0})'.format(
', '.join(
'({0})'.format(', '.join(repr(s) for s in values))
for values in expr._skewjoin_values
)
)
self._ctx._skewjoin_hints.append(skewjoin_expr)
# Remember the real arguments (needed later to resolve column aliases),
# then blank them out on the node.
args = expr.args
self._ctx._expr_raw_args[id(expr)] = args
for arg_name in expr._args:
setattr(expr, arg_name, None)
# restores the original arguments; run from _cleanup
def cb():
for arg_name, arg in zip(expr._args, args):
setattr(expr, arg_name, arg)
self._callbacks.append(cb)
@classmethod
def _find_table(cls, expr):
    """Return the first collection met when walking `expr` top-down.

    Raises StopIteration when the traversal contains no collection.
    """
    collections = (node for node in expr.traverse(top_down=True, unique=True)
                   if isinstance(node, CollectionExpr))
    return next(collections)
def _compile_in_node(self, expr, traversed):
    """Compile an IN / NOT IN test whose values come from a sequence.

    ``self._sub_compiles[expr]`` receives the compiled left operand and
    then the SQL of a sub-query selecting the value column. The node's
    arguments are set to None afterwards and a restoring callback is
    queued on ``self._callbacks``.
    """
    self._compile(expr.input)
    self._sub_compiles[expr].append(self._ctx.get_expr_compiled(expr.input))

    # Project the owning collection down to the single value column.
    value_column = expr.values[0]
    sub_collection = self._find_table(value_column)[[value_column, ]]
    sub_sql = self._compile(sub_collection)
    self._sub_compiles[expr].append(sub_sql)

    saved_args = expr.args
    self._ctx._expr_raw_args[id(expr)] = saved_args
    for attr in expr._args:
        setattr(expr, attr, None)

    def restore():
        for attr, value in zip(expr._args, saved_args):
            setattr(expr, attr, value)

    self._callbacks.append(restore)
def _compile_lateral_view(self, expr, traversed):
    """Compile a lateral-view collection into a single FROM fragment.

    The fragment is the compiled input followed by one line per lateral
    view and is stored as the only entry of ``self._sub_compiles[expr]``.
    The node's arguments are then set to None and a restoring callback is
    queued on ``self._callbacks``.
    """
    visited = set()
    input_sql, input_visited = self._compile(expr.input, return_traversed=True)
    visited.update(input_visited)

    reuse_compiled = is_source_collection(expr.input) or \
        isinstance(expr.input, LateralViewCollectionExpr)
    if reuse_compiled:
        input_sql = self._ctx.get_expr_compiled(expr.input)
    else:
        # Anything else becomes an aliased sub-query.
        alias = self._ctx.get_collection_alias(expr.input, create=True)[0]
        input_sql = '(\n{0}\n) {1}'.format(
            utils.indent(input_sql, self._indent_size), alias)

    from_lines = [input_sql]
    for view in expr.lateral_views:
        self._compile(view)
        from_lines.append(self._ctx.get_expr_compiled(view))

    traversed.update(visited)
    self._sub_compiles[expr].append(' \n'.join(from_lines))

    saved_args = expr.args
    self._ctx._expr_raw_args[id(expr)] = saved_args
    for attr in expr._args:
        setattr(expr, attr, None)

    def restore():
        for attr, value in zip(expr._args, saved_args):
            setattr(expr, attr, value)

    self._callbacks.append(restore)
def _compile(self, expr, traversed=None,
             return_traversed=False, root_expr=None):
    """Compile `expr` and return the SQL text of the current statement.

    :param traversed: set of ``id()`` values of nodes already visited;
        a fresh set is used when omitted, and it is updated in place
    :param return_traversed: when true, return ``(sql, traversed)``
    :param root_expr: the expression handed to :meth:`compile`; used to
        decide whether pending join hints need a ``SELECT *`` to live in
    """
    roots = self._retrieve_until_find_root(expr)
    if traversed is None:
        traversed = set()

    # Checked in order; the first matching type decides the handler.
    pre_compilers = (
        (JoinCollectionExpr, self._compile_join_node),
        (UnionCollectionExpr, self._compile_union_node),
        (LateralViewCollectionExpr, self._compile_lateral_view),
        ((element.IsIn, element.NotIn), self._compile_in_node),
    )
    for root in roots:
        if root is None:
            continue
        for root_type, pre_compile in pre_compilers:
            if isinstance(root, root_type):
                pre_compile(root, traversed)
                break
        root.accept(self)
        traversed.add(id(root))

    for node in traverse_until_source(expr):
        node_id = id(node)
        if node_id in traversed:
            continue
        node.accept(self)
        traversed.add(node_id)

    # Join hints are emitted by add_select_clause; when the root statement
    # has no select list yet, create one so that the hints are not lost.
    hints_without_select = (
        expr is root_expr
        and self._select_clause is None
        and (self._ctx._mapjoin_hints or self._ctx._skewjoin_hints)
    )
    if hints_without_select:
        self.add_select_clause(expr, '* ')

    sql = self.to_sql().strip()
    if return_traversed:
        return sql, traversed
    return sql
def compile(self, expr):
    """Compile `expr` into ODPS SQL.

    Returns the SQL string, or -- when memory-cached collections are
    referenced -- a list holding the ';'-terminated dependency statements
    followed by the SQL string. The compiler state is cleaned up in either
    case, also when compilation fails.
    """
    try:
        sql = self._compile(expr, root_expr=expr)
        symbol_columns = dict(self._ctx.get_all_need_alias_column_symbols())
        sql = self._fill_back_columns(sql, symbol_columns)
        sql = self._re_join_select_fields(sql, symbol_columns)

        if not self._mem_ref_caches:
            return sql

        statements = []
        for dep_sql in self._ctx.get_mem_cache_dep_sqls(*self._mem_ref_caches):
            if not dep_sql.endswith(';'):
                dep_sql = dep_sql + ';'
            statements.append(dep_sql)
        return statements + [sql, ]
    finally:
        self._cleanup()
def to_sql(self):
    """Render the collected clauses as SQL and reset the clause slots.

    An ``ORDER BY`` without an explicit limit gets
    ``options.df.odps.sort.limit`` as its limit when that option is set.
    """
    if self._from_clause.startswith('SELECT'):
        # special case due to `union`: the FROM slot holds a full statement
        parts = [self._from_clause]
    else:
        parts = [
            'SELECT {0} '.format(self._select_clause or '*'),
            'FROM {0} '.format(self._from_clause),
        ]

    # tablesample has to be placed right after the from clause
    if self._table_sample_clause:
        parts.append('TABLESAMPLE {0}'.format(self._table_sample_clause))
    if self._where_clause:
        parts.append('WHERE {0} '.format(self._where_clause))
    if self._group_by_clause:
        parts.append(self._group_by_clause)
    if self._having_clause:
        parts.append('HAVING {0} '.format(self._having_clause))

    order_by = self._order_by_clause
    if order_by:
        # `order by` requires a limit, `sort by` does not
        use_default_limit = order_by.startswith('ORDER BY') and not self._limit \
            and options.df.odps.sort.limit
        if use_default_limit:
            self._limit = options.df.odps.sort.limit
        parts.append(order_by)

    if self._limit is not None:
        parts.append('LIMIT {0}'.format(self._limit))

    self._re_init()
    return '\n'.join(parts)
def _fill_back_columns(self, sql, symbols_to_columns):
    """Replace column placeholders in `sql` by ``alias.column`` references.

    `sql` contains ``%(symbol)s`` placeholders and ``###[symbol]##``
    markers. Placeholders are filled by %-formatting; each marker then
    becomes `` AS name`` when the resolved column name differs from the
    column's own name, and an empty string otherwise.

    :param symbols_to_columns: mapping from symbol to column expression
    """
    resolved = {}
    for symbol, column in six.iteritems(symbols_to_columns):
        try:
            collection, name = self._retrieve_column_alias_collection(column)
        except KeyError:
            # related to a sorted column, nothing to fill in
            continue
        alias = self._ctx.get_collection_alias(collection)[0]
        resolved[symbol] = '{0}.{1}'.format(alias, self._quote(name))

    sql = sql % resolved

    marker = re.compile(r'###\[(col_\d+)\]##')

    def alias_suffix(match):
        symbol = match.group(1)
        resolved_name = self._unquote(resolved[symbol])
        column = symbols_to_columns[symbol]
        if resolved_name == column.name:
            return ''
        return ' AS {0}'.format(self._quote(column.name))

    return marker.sub(alias_suffix, sql)
def _re_join_select_fields(self, sql, symbols_to_columns):
    """Lay out the marked select lists of `sql` once columns are resolved.

    Only active in beautify mode; otherwise `sql` is returned untouched.
    ``_join_select_fields`` wraps every select list in ``/{_iN}.../`` and
    every field in ``//{_iN}...{_iN}//``. Each wrapped list is replaced
    here by the output of ``_join_compiled_fields`` for its fields.

    :param sql: SQL text possibly containing select-list markers
    :param symbols_to_columns: unused here; kept for interface symmetry
        with ``_fill_back_columns``
    :return: the SQL text with all markers replaced
    """
    if not self._beautify:
        return sql

    # Raw string: `\d` is not a valid escape in a plain string literal.
    reg = re.compile(r'/{_i(\d+)}')
    for select_idx in reg.findall(sql):
        # one field: //{_iN}<field>{_iN}//
        s_regex_str = '//{{_i{0}}}(.+?){{_i{0}}}//'.format(select_idx)
        # whole list, with an optional leading newline and its indent
        regex_str = '\n?( *)/{{_i{0}}}({1})+/'.format(select_idx, s_regex_str)
        regex = re.compile(regex_str, (re.M | re.DOTALL))
        s_regex = re.compile(s_regex_str, (re.M | re.DOTALL))

        def repl(matched, s_regex=s_regex):
            space = len(matched.group(1))
            joined = matched.group()
            fields = s_regex.findall(joined)
            sub = self._join_compiled_fields(fields)
            if not joined.startswith('\n'):
                sub = sub.lstrip('\n')
            # keep the extra indentation of nested statements
            if space > self._indent_size:
                return utils.indent(sub, space - self._indent_size)
            return sub

        sql = regex.sub(repl, sql)
    return sql
def _retrieve_column_alias_collection(self, expr):
    """Find the aliased collection that a column finally comes from.

    Starting at the column's input, the search follows join origins,
    lateral views and ``input`` links until it reaches a collection that
    has an alias registered in the context.

    :param expr: column expression
    :return: ``(collection, column_name)`` where `column_name` is the name
        of the column inside `collection`
    :raises CompileError: when no aliased collection can be reached
    :raises KeyError: propagated from the origin / raw-argument lookups
        (callers treat it as "no alias needed")
    """
    column_name = expr.source_name
    collection = expr.input

    while True:
        if isinstance(collection, JoinCollectionExpr):
            idx, column_name = collection._column_origins[column_name]
            # the join's own arguments were substituted out; use the saved ones
            args = self._ctx._expr_raw_args[id(collection)]
            lhs, rhs = args[0], args[1]
            collection = (lhs, rhs)[idx]
        elif isinstance(collection, LateralViewCollectionExpr):
            # the arguments were substituted out here as well
            args = self._ctx._expr_raw_args[id(collection)]
            while True:
                for lv in args[2]:
                    if utils.to_lower_str(column_name) in lv.schema._name_indexes:
                        return lv, column_name
                if isinstance(args[0], LateralViewCollectionExpr):
                    args = self._ctx._expr_raw_args[id(args[0])]
                else:
                    return args[0], column_name
        elif self._ctx.get_collection_alias(collection, silent=True):
            return collection, column_name
        elif isinstance(getattr(collection, 'input', None), CollectionExpr):
            collection = collection.input
        else:
            # Nothing left to follow. Without this exit the loop would spin
            # forever and the error below could never be raised.
            break

    raise CompileError('Cannot find table alias for column: \n%s' % repr_obj(expr))
def sub_sql_to_from_clause(self, expr):
    """Turn the statement built so far into an aliased sub-query.

    The current clauses are rendered, wrapped in parentheses under the
    alias of `expr` (created on demand), and become the FROM clause of a
    fresh statement.
    """
    inner_sql = utils.indent(self.to_sql(), self._indent_size)
    alias = self._ctx.get_collection_alias(expr, create=True)[0]
    self._re_init()
    self._from_clause = '(\n{0}\n) {1}'.format(inner_sql, alias)
def add_select_clause(self, expr, select_clause):
    """Set the select list, nesting the current statement when required.

    The statement built so far is pushed into a sub-query when it already
    has a select list or an order-by, when a limit would otherwise apply
    to a summary / group-by / mutate / row-apply, or when the input is a
    reshuffled collection. Pending map-join and skew-join hints are
    consumed and rendered in front of the select list.
    """
    must_nest = (
        self._select_clause is not None
        or self._order_by_clause is not None
        or (isinstance(expr, Summary) and self._limit is not None)
        or (isinstance(expr, (GroupByCollectionExpr, MutateCollectionExpr,
                              RowAppliedCollectionExpr))
            and self._limit is not None)
        or isinstance(expr.input, ReshuffledCollectionExpr)
    )
    if must_nest:
        self.sub_sql_to_from_clause(expr.input)

    self._select_clause = select_clause

    hints = []
    for hint_name, ctx_attr in (('mapjoin', '_mapjoin_hints'),
                                ('skewjoin', '_skewjoin_hints')):
        pending = getattr(self._ctx, ctx_attr)
        if pending:
            hints.append('{0}({1})'.format(hint_name, ', '.join(pending)))
            # each hint is emitted only once
            setattr(self._ctx, ctx_attr, [])

    if hints:
        self._select_clause = '/*+ {0} */ {1}'.format(
            ', '.join(hints), self._select_clause
        )
def add_from_clause(self, expr, from_clause):
    """Record the FROM clause unless one is already present."""
    if self._from_clause is not None:
        return
    self._from_clause = from_clause
def add_where_clause(self, expr, where_clause):
    """Set the WHERE clause.

    When the statement already has a where clause, a select list or a
    limit, it is first pushed into a sub-query.
    """
    occupied = (self._where_clause, self._select_clause, self._limit)
    if not all(slot is None for slot in occupied):
        self.sub_sql_to_from_clause(expr.input)
    self._where_clause = where_clause
def add_group_by_clause(self, expr, group_by_clause):
    """Set the GROUP BY clause.

    The current statement is nested first when it already groups, or when
    `expr` is a reshuffle applied on top of an existing select list.
    """
    must_nest = self._group_by_clause is not None or (
        isinstance(expr, ReshuffledCollectionExpr)
        and self._select_clause is not None
    )
    if must_nest:
        self.sub_sql_to_from_clause(expr.input)
    self._group_by_clause = group_by_clause
def add_having_clause(self, expr, having_clause):
    """Record the HAVING clause; a second call must pass the same clause."""
    current = self._having_clause
    if current is None:
        current = self._having_clause = having_clause
    assert having_clause == current
def add_order_by_clause(self, expr, order_by_clause):
    """Set the ORDER BY clause, nesting first if one is already present."""
    already_ordered = self._order_by_clause is not None
    if already_ordered:
        self.sub_sql_to_from_clause(expr.input)
    self._order_by_clause = order_by_clause
def add_table_sample_clause(self, expr, table_sample_clause):
    """Set the TABLESAMPLE clause.

    Sampling has to apply to the result of any where / sample / group-by /
    having / order-by / limit collected so far, so the statement is nested
    first when one of those is present.
    """
    earlier = (
        self._where_clause,
        self._table_sample_clause,
        self._group_by_clause,
        self._having_clause,
        self._order_by_clause,
        self._limit,
    )
    if not all(slot is None for slot in earlier):
        self.sub_sql_to_from_clause(expr.input)
    self._table_sample_clause = table_sample_clause
def set_limit(self, expr, limit):
    """Set the LIMIT, nesting the statement first if it already has one."""
    has_limit = self._limit is not None
    if has_limit:
        self.sub_sql_to_from_clause(expr.input)
    self._limit = limit
def visit_source_collection(self, expr):
    """Compile a source collection into ``<name> <alias>``.

    A memory-cache reference is addressed by its reference name and its
    expression id is remembered for later. A table is addressed as
    ``project[.schema].table``, with the table name back-quoted when
    ``options.df.quote`` is set.
    """
    source_data = expr._source_data
    alias = self._ctx.register_collection(expr)
    is_mem_cache = isinstance(source_data, MemCacheReference)

    if is_mem_cache:
        qualified_name = source_data.ref_name
    else:
        name_parts = [source_data.project.name]
        schema = source_data.get_schema()
        if schema is not None:
            name_parts.append(schema.name)
        table_name = source_data.name
        name_parts.append("`%s`" % table_name if options.df.quote else table_name)
        qualified_name = '.'.join(name_parts)

    from_clause = '{0} {1}'.format(qualified_name, alias)
    self.add_from_clause(expr, from_clause)
    self._ctx.add_expr_compiled(expr, from_clause)

    if is_mem_cache:
        self._mem_ref_caches.add(source_data.expr_id)
def _compile_select_field(self, field):
    """Return the select-list entry for `field`.

    Anything that is not a plain column (or a column whose type was
    changed) always gets ``AS name``. A plain column that is still a
    ``%(symbol)s`` placeholder gets a ``###[symbol]##`` marker appended so
    that the alias decision can be taken once the column is resolved. A
    resolved column gets ``AS name`` only when its name differs.
    """
    compiled = self._ctx.get_expr_compiled(field)
    with_alias = '{0} AS {1}'

    is_plain_column = isinstance(field, Column) and \
        field._source_data_type == field._data_type
    if not is_plain_column:
        return with_alias.format(compiled, self._quote(field.name))

    if compiled.startswith('%(') and compiled.endswith(')s'):
        symbol = compiled[2:-2]
        return '{0}###[{1}]##'.format(compiled, symbol)

    if field.name != self._unquote(compiled):
        return with_alias.format(compiled, self._quote(field.name))
    return compiled
def _join_select_fields(self, fields):
    """Join compiled select fields into a select list.

    Without beautify the fields are comma-joined. With beautify each field
    is wrapped in ``//{_iN}...{_iN}//`` and the whole list in
    ``/{_iN}.../`` (N being a fresh select id), so the list can be laid out
    later by ``_re_join_select_fields``.
    """
    if not self._beautify:
        return ', '.join(fields)

    select_id = self._ctx.next_select_id()
    wrapped = ''.join(
        ['//{{_i{0}}}{1}{{_i{0}}}//'.format(select_id, field) for field in fields]
    )
    marked = '\n/{{_i{0}}}{1}/'.format(select_id, wrapped)
    return utils.indent(marked, self._indent_size)
def _join_compiled_fields(self, fields):
    """Join compiled fields, aligning the ``AS`` aliases when beautifying.

    Without beautify the fields are comma-joined. With beautify every
    field goes on its own line, and the part before the last `` AS `` is
    padded to a common width. The width is taken from the last line of
    each field's expression part.
    """
    if not self._beautify:
        return ', '.join(fields)

    def last_line(text):
        return text.rsplit('\n', 1)[1] if '\n' in text else text

    split_fields = [field.rsplit(' AS ', 1) for field in fields]
    width = max(len(last_line(parts[0])) for parts in split_fields)

    rows = []
    for parts in split_fields:
        if len(parts) > 1:
            rows.append(parts[0].ljust(width) + ' AS ' + parts[1])
        else:
            rows.append(parts[0])

    return utils.indent('\n' + ',\n'.join(rows), self._indent_size)
def visit_project_collection(self, expr):
    """Compile a projection into the select list of the statement."""
    select_list = self._join_select_fields(
        [self._compile_select_field(field) for field in expr._fields]
    )
    self._ctx.add_expr_compiled(expr, select_list)
    self.add_select_clause(expr, select_list)
def visit_lateral_view(self, expr):
    """Use the pre-compiled lateral-view fragment as the FROM clause."""
    from_clause = self._sub_compiles[expr][0]
    self._ctx.add_expr_compiled(expr, from_clause)
    self.add_from_clause(expr, from_clause)
def visit_filter_collection(self, expr):
    """Compile a filter: its predicate becomes the WHERE clause."""
    condition = self._ctx.get_expr_compiled(expr.args[1])
    self._ctx.add_expr_compiled(expr, condition)
    self.add_where_clause(expr, condition)
def visit_filter_partition_collection(self, expr):
    """Compile a partition filter into a WHERE clause plus a select list.

    The compiled text finally registered for `expr` is the select list;
    the predicate text registered first is overwritten by it.
    """
    condition = self._ctx.get_expr_compiled(expr._predicate)
    self._ctx.add_expr_compiled(expr, condition)
    self.add_where_clause(expr, condition)

    select_list = self._join_select_fields(
        [self._compile_select_field(field) for field in expr.fields]
    )
    self._ctx.add_expr_compiled(expr, select_list)
    self.add_select_clause(expr, select_list)
def visit_slice_collection(self, expr):
    """Compile ``collection[:n]`` into ``LIMIT n``.

    :raises NotImplementedError: when the slice has a start or a step
    :raises CompileError: when the stop value is not positive
    """
    bounds = expr._indexes
    if bounds[0] is not None or bounds[2] is not None:
        raise NotImplementedError

    row_count = bounds[1].value
    if row_count <= 0:
        raise CompileError('limit number must be greater than 0')
    self.set_limit(expr, row_count)
def visit_element_op(self, expr):
    """Compile an element-wise operation and register the SQL text.

    Handles IS NULL / IS NOT NULL, fill-na, IN / NOT IN (against literal
    values or a pre-compiled sub-query), IF, CASE ... WHEN and
    int-to-datetime conversion.

    :raises NotImplementedError: for any other element operation
    """
    lookup = self._ctx.get_expr_compiled

    if isinstance(expr, element.IsNull):
        compiled = '{0} IS NULL'.format(self._parenthesis(expr.input))
    elif isinstance(expr, element.NotNull):
        compiled = '{0} IS NOT NULL'.format(self._parenthesis(expr.input))
    elif isinstance(expr, element.FillNa):
        compiled = 'IF(%(input)s IS NULL, %(value)s, %(else_value)s)' % {
            'input': self._parenthesis(expr.input),
            'value': lookup(expr._fill_value),
            'else_value': lookup(expr.input),
        }
    elif isinstance(expr, (element.IsIn, element.NotIn)):
        keyword = 'IN' if isinstance(expr, element.IsIn) else 'NOT IN'
        if expr.values is not None:
            left = lookup(expr.input)
            right = ', '.join(lookup(value) for value in expr.values)
        else:
            # values were compiled ahead of time as a sub-query
            subs = self._sub_compiles[expr]
            left = subs[0]
            right = subs[1].replace('\n', '')
        compiled = '{0} {1} ({2})'.format(left, keyword, right)
    elif isinstance(expr, element.IfElse):
        compiled = 'IF({0}, {1}, {2})'.format(
            lookup(expr._input), lookup(expr._then), lookup(expr._else)
        )
    elif isinstance(expr, element.Switch):
        subject = (lookup(expr.case) + ' ') if expr.case is not None else ''
        lines = ['CASE {0}'.format(subject)]
        for condition, then in zip(expr.conditions, expr.thens):
            lines.append('WHEN {0} THEN {1} '.format(lookup(condition), lookup(then)))
        if expr.default is not None:
            lines.append('ELSE {0} '.format(lookup(expr.default)))
        lines.append('END')

        if self._beautify:
            # indent everything between CASE and END
            inner = [utils.indent(line, self._indent_size) for line in lines[1:-1]]
            compiled = '\n'.join(lines[:1] + inner + lines[-1:])
        else:
            compiled = ''.join(lines)
    elif isinstance(expr, element.IntToDatetime):
        compiled = 'FROM_UNIXTIME({0})'.format(lookup(expr._input))
    else:
        raise NotImplementedError

    self._ctx.add_expr_compiled(expr, compiled)
def _parenthesis(self, child):
    """Return the compiled text of `child`, parenthesised where needed.

    Binary operations and the element operations that compile to infix or
    multi-token SQL are wrapped so that they can be used as operands.
    """
    compiled = self._ctx.get_expr_compiled(child)
    needs_wrap = isinstance(child, BinOp) or isinstance(
        child,
        (element.IsNull, element.NotNull, element.IsIn, element.NotIn,
         element.Between, element.Switch, element.Cut),
    )
    return '(%s)' % compiled if needs_wrap else compiled
def visit_binary_op(self, expr):
    """Compile a binary operation and register the SQL text.

    * string ``+`` becomes ``CONCAT``;
    * datetime ``+`` / ``-`` with a datetime scalar becomes ``DATEADD``
      (the amount is negated for subtraction);
    * ``**`` becomes ``POW``, cast back when the result type is not float;
    * every operator listed in ``BINARY_OP_COMPILE_DIC`` is rendered infix
      with parenthesised operands.

    :raises CompileError: when two datetimes are added
    :raises NotImplementedError: for unsupported operators and for datetime
        arithmetic that has no datetime scalar in a usable position
    """
    if isinstance(expr, Add) and expr.dtype == df_types.string:
        compiled = 'CONCAT({0}, {1})'.format(
            self._ctx.get_expr_compiled(expr.lhs),
            self._ctx.get_expr_compiled(expr.rhs),
        )
    elif isinstance(expr, (Add, Substract)) and expr.dtype == df_types.datetime:
        if isinstance(expr, Add) and \
                all(child.dtype == df_types.datetime for child in (expr.lhs, expr.rhs)):
            raise CompileError('Cannot add two datetimes')

        if isinstance(expr.rhs, DTScalar):
            dt, scalar = expr.lhs, expr.rhs
        elif isinstance(expr, Add) and isinstance(expr.lhs, DTScalar):
            # Addition commutes, so the scalar may be on the left. This is
            # deliberately not accepted for subtraction: `scalar - datetime`
            # is not `datetime - scalar`.
            dt, scalar = expr.rhs, expr.lhs
        else:
            raise NotImplementedError

        # scalar classes are named '<Part>Scalar'; strip the suffix
        class_name = type(scalar).__name__[:-6]
        date_part = DATE_PARTS_DIC[class_name]
        val = scalar.value
        if isinstance(expr, Substract):
            val = -val
        compiled = 'DATEADD({0}, {1}, {2})'.format(
            self._ctx.get_expr_compiled(dt), repr(val), repr(date_part)
        )
    else:
        op = BINARY_OP_COMPILE_DIC.get(expr.node_name)
        if op is not None:
            lhs, rhs = expr.args
            compiled = '{0} {1} {2}'.format(
                self._parenthesis(lhs), op.upper(), self._parenthesis(rhs)
            )
        elif isinstance(expr, Power):
            compiled = 'POW({0}, {1})'.format(
                self._ctx.get_expr_compiled(expr.lhs),
                self._ctx.get_expr_compiled(expr.rhs)
            )
            if not isinstance(expr.dtype, df_types.Float):
                # POW yields a float; cast back to the declared type
                compiled = self._cast(compiled, df_types.float64, expr.dtype)
        else:
            raise NotImplementedError

    self._ctx.add_expr_compiled(expr, compiled)
def visit_unary_op(self, expr):
    """Compile a unary operation and register the SQL text.

    Negating a boolean compiles to ``NOT``; other operators listed in
    ``UNARY_OP_COMPILE_DIC`` are written in front of the operand. For
    operators missing from that table, ``Abs`` becomes ``ABS(...)`` and
    inverting a boolean becomes ``NOT``.

    :raises NotImplementedError: for any other unary operation
    """
    try:
        negates_boolean = isinstance(expr, Negate) and \
            expr.input.dtype == df_types.boolean
        if negates_boolean:
            compiled = 'NOT {0}'.format(self._parenthesis(expr.input))
        else:
            compiled = '{0}{1}'.format(
                UNARY_OP_COMPILE_DIC[expr.node_name],
                self._parenthesis(expr.input),
            )
    except KeyError:
        if isinstance(expr, Abs):
            compiled = 'ABS({0})'.format(self._ctx.get_expr_compiled(expr.input))
        elif isinstance(expr, (Invert, Negate)) and \
                expr.input.dtype == df_types.boolean:
            compiled = 'NOT {0}'.format(self._parenthesis(expr.input))
        else:
            raise NotImplementedError

    self._ctx.add_expr_compiled(expr, compiled)
def visit_math(self, expr):
    """Compile a math function call and register the SQL text.

    Functions listed in ``MATH_COMPILE_DIC`` are rendered as
    ``NAME(input)``. The logarithm family, ``Expm1``, ``Trunc`` and
    ``Round`` are handled separately.

    :raises NotImplementedError: for any other math node
    """
    compiled = None
    try:
        op = MATH_COMPILE_DIC[expr.node_name]
    except KeyError:
        if expr.node_name == 'Log':
            if expr._base is None:
                op = 'ln'
            else:
                compiled = 'LOG({0}, {1})'.format(
                    self._ctx.get_expr_compiled(expr._base),
                    self._ctx.get_expr_compiled(expr.input)
                )
        elif expr.node_name == 'Log2':
            compiled = 'LOG(2, {0})'.format(
                self._ctx.get_expr_compiled(expr.input)
            )
        elif expr.node_name == 'Log10':
            compiled = 'LOG(10, {0})'.format(
                self._ctx.get_expr_compiled(expr.input)
            )
        elif expr.node_name == 'Log1p':
            compiled = 'LN(1 + {0})'.format(
                self._ctx.get_expr_compiled(expr.input)
            )
        elif expr.node_name == 'Expm1':
            # Parenthesised on purpose: parent binary operations do not
            # wrap math nodes, so a bare `EXP(x) - 1` would change meaning
            # inside e.g. a multiplication.
            compiled = '(EXP({0}) - 1)'.format(
                self._ctx.get_expr_compiled(expr.input)
            )
        elif expr.node_name == 'Trunc':
            if expr._decimals is None:
                op = 'TRUNC'
            else:
                compiled = 'TRUNC({0}, {1})'.format(
                    self._ctx.get_expr_compiled(expr.input),
                    self._ctx.get_expr_compiled(expr._decimals)
                )
        elif expr.node_name == 'Round':
            if expr._decimals is None:
                op = 'ROUND'
            else:
                compiled = 'ROUND({0}, {1})'.format(
                    self._ctx.get_expr_compiled(expr.input),
                    self._ctx.get_expr_compiled(expr._decimals)
                )
        else:
            raise NotImplementedError

    if compiled is None:
        # plain single-argument function
        compiled = '{0}({1})'.format(
            op.upper(), self._ctx.get_expr_compiled(expr.input))

    self._ctx.add_expr_compiled(expr, compiled)
def visit_string_op(self, expr):
    """Compile a string operation and register the SQL text.

    Only operations that map onto built-in SQL functions are supported.

    :raises NotImplementedError: for unsupported operations or arguments
        (regex ``contains``, ``find`` with an end, stripping anything but
        a space, ``split`` with a limit, ...)
    """
    # FIXME quite a few operations cannot support by internal function
    compiled = None
    operand = self._ctx.get_expr_compiled(expr.input)

    if isinstance(expr, strings.Capitalize):
        compiled = 'CONCAT(TOUPPER(SUBSTR(%(input)s, 1, 1)), TOLOWER(SUBSTR(%(input)s, 2)))' % {
            'input': operand
        }
    elif isinstance(expr, strings.CatStr):
        nodes = [expr._input]
        if expr._others is not None:
            others = (expr._others, ) if not isinstance(expr._others, Iterable) else expr._others
            for other in others:
                if expr._sep is not None:
                    nodes.extend([expr._sep, other])
                else:
                    nodes.append(other)
        compiled = 'CONCAT(%s)' % ', '.join(self._ctx.get_expr_compiled(e) for e in nodes)
    elif isinstance(expr, strings.Contains):
        if expr.regex:
            raise NotImplementedError
        compiled = 'INSTR(%s, %s) > 0' % (operand, self._ctx.get_expr_compiled(expr._pat))
    elif isinstance(expr, strings.Endswith):
        # TODO: any better solution?
        compiled = 'INSTR(REVERSE(%s), REVERSE(%s)) == 1' % (
            operand, self._ctx.get_expr_compiled(expr._pat))
    elif isinstance(expr, strings.Startswith):
        compiled = 'INSTR(%s, %s) == 1' % (operand, self._ctx.get_expr_compiled(expr._pat))
    elif isinstance(expr, strings.Find):
        # INSTR positions are 1-based, Python's are 0-based
        if isinstance(expr.start, six.integer_types):
            start = expr.start + 1 if expr.start >= 0 else expr.start
        else:
            start = 'IF(%(start)s >= 0, %(start)s + 1, %(start)s)' % {
                'start': self._ctx.get_expr_compiled(expr._start)
            }
        if expr.end is not None:
            raise NotImplementedError
        # Parenthesised on purpose: parent binary operations do not wrap
        # string nodes, so a bare `INSTR(...) - 1` would change meaning
        # inside e.g. a multiplication.
        compiled = '(INSTR(%s, %s, %s) - 1)' % (
            operand, self._ctx.get_expr_compiled(expr._sub), start)
    elif isinstance(expr, strings.Get):
        compiled = 'SUBSTR(%s, %s, 1)' % (operand, expr.index + 1)
    elif isinstance(expr, strings.Len):
        compiled = 'LENGTH(%s)' % operand
    elif isinstance(expr, strings.Lower):
        compiled = 'TOLOWER(%s)' % operand
    elif isinstance(expr, strings.Upper):
        compiled = 'TOUPPER(%s)' % operand
    elif isinstance(expr, (strings.Lstrip, strings.Rstrip, strings.Strip)):
        if expr.to_strip != ' ':
            raise NotImplementedError
        func = {
            'Lstrip': 'LTRIM',
            'Rstrip': 'RTRIM',
            'Strip': 'TRIM'
        }
        compiled = '%s(%s)' % (func[type(expr).__name__], operand)
    elif isinstance(expr, strings.Slice):
        # internal function will be compiled in two cases:
        # 1) start is not None
        # 2) positive start and end
        if expr.end is None and expr.step is None:
            compiled = 'SUBSTR(%s, %s)' % (operand, expr.start + 1)
        else:
            # expr.start and expr.end
            length = expr.end - expr.start
            compiled = 'SUBSTR(%s, %s, %s)' % (operand, expr.start + 1, length)
    elif isinstance(expr, strings.Repeat):
        compiled = 'REPEAT(%s, %s)' % (
            operand, self._ctx.get_expr_compiled(expr._repeats))
    elif isinstance(expr, strings.Split):
        if expr.n != -1:
            raise NotImplementedError
        escape_pat = re.escape(expr.pat)
        # split on the pattern as given / on the pattern taken literally
        nre_compiled = 'SPLIT(%s, \'%s\')' % (operand, utils.escape_odps_string(expr.pat))
        re_compiled = 'SPLIT(%s, \'%s\')' % (operand, utils.escape_odps_string(escape_pat))
        try:
            re.compile(expr.pat)
            is_regex = True
        except Exception:
            # Not a valid regular expression. `Exception` rather than a
            # bare except, so KeyboardInterrupt / SystemExit still pass.
            is_regex = False
        if expr.pat == escape_pat or not is_regex:
            compiled = nre_compiled
        else:
            compiled = 'IF(SIZE(%(re)s) = 0, %(nre)s, %(re)s)' % dict(re=re_compiled, nre=nre_compiled)
    elif isinstance(expr, strings.StringToDict):
        compiled = 'STR_TO_MAP(%s, \'%s\', \'%s\')' % (operand, expr.item_delim, expr.kv_delim)

    if compiled is None:
        raise NotImplementedError
    self._ctx.add_expr_compiled(expr, compiled)
def visit_datetime_op(self, expr):
    """Compile a datetime accessor and register the SQL text.

    Accessors whose class name is a key of ``DATE_PARTS_DIC`` compile to
    ``DATEPART``; week-of-year, weekday, date truncation and unix
    timestamp map onto their own functions.

    :raises NotImplementedError: for any other datetime operation
    """
    # FIXME quite a few operations cannot support by internal function
    class_name = type(expr).__name__
    operand = self._ctx.get_expr_compiled(expr.input)

    if class_name in DATE_PARTS_DIC:
        compiled = 'DATEPART(%s, %r)' % (operand, DATE_PARTS_DIC[class_name])
    elif isinstance(expr, datetimes.WeekOfYear):
        compiled = 'WEEKOFYEAR(%s)' % operand
    elif isinstance(expr, datetimes.WeekDay):
        compiled = 'WEEKDAY(%s)' % operand
    elif isinstance(expr, datetimes.Date):
        compiled = 'DATETRUNC(%s, %r)' % (operand, 'dd')
    elif isinstance(expr, datetimes.UnixTimestamp):
        compiled = 'UNIX_TIMESTAMP(%s)' % operand
    else:
        raise NotImplementedError

    self._ctx.add_expr_compiled(expr, compiled)
def visit_composite_op(self, expr):
    """Compile a list / dict operation and register the SQL text.

    :raises NotImplementedError: for any other composite operation
    """
    lookup = self._ctx.get_expr_compiled
    # builders have no input; everything else operates on one
    operand = None
    if getattr(expr, 'input', None) is not None:
        operand = lookup(expr.input)

    if isinstance(expr, composites.ListDictLength):
        compiled = 'SIZE(%s)' % operand
    elif isinstance(expr, composites.ListDictGetItem):
        compiled = '%s[%s]' % (operand, lookup(expr._key))
    elif isinstance(expr, composites.ListSort):
        compiled = 'SORT_ARRAY(%s)' % operand
    elif isinstance(expr, composites.ListContains):
        compiled = 'ARRAY_CONTAINS(%s, %s)' % (operand, lookup(expr._value))
    elif isinstance(expr, composites.DictKeys):
        compiled = 'MAP_KEYS(%s)' % operand
    elif isinstance(expr, composites.DictValues):
        compiled = 'MAP_VALUES(%s)' % operand
    elif isinstance(expr, composites.ListBuilder):
        compiled = 'ARRAY(%s)' % ', '.join(lookup(item) for item in expr._values)
    elif isinstance(expr, composites.DictBuilder):
        # MAP takes its arguments as key1, value1, key2, value2, ...
        compiled = 'MAP(%s)' % ', '.join(
            '%s, %s' % (lookup(key), lookup(value))
            for key, value in zip(expr._keys, expr._values)
        )
    else:
        raise NotImplementedError

    self._ctx.add_expr_compiled(expr, compiled)
def visit_groupby(self, expr):
    """Compile a group-by into select list, GROUP BY and optional HAVING.

    When no explicit fields are given, the select list is the grouping
    keys followed by the aggregations.
    """
    bys, having, aggs, fields = tuple(expr.args[1:])
    if fields is None:
        fields = bys + aggs

    key_sql = [self._ctx.get_expr_compiled(by) for by in bys]
    group_by_clause = 'GROUP BY {0} '.format(self._join_compiled_fields(key_sql))

    select_clause = self._join_select_fields(
        [self._compile_select_field(field) for field in fields]
    )

    self.add_select_clause(expr, select_clause)
    self.add_group_by_clause(expr, group_by_clause)
    if having:
        self.add_having_clause(expr, self._ctx.get_expr_compiled(having))
def visit_mutate(self, expr):
    """Compile a mutate into a select list.

    When no explicit fields are given, the select list is the grouping
    keys followed by the mutated (window) columns.
    """
    bys, mutates, fields = tuple(expr.args[1:])
    selected = bys + mutates if fields is None else fields
    select_clause = self._join_select_fields(
        [self._compile_select_field(field) for field in selected]
    )
    self.add_select_clause(expr, select_clause)
def visit_sort_column(self, expr):
    """Compile a sort key; descending keys get a `` DESC`` suffix.

    A key taken directly from a collection is referred to by its source
    name, a key wrapping a column by that column's source name, and any
    other key by its compiled input.
    """
    source = expr.input
    if isinstance(source, CollectionExpr):
        key_sql = expr._source_name
    elif isinstance(source, Column):
        key_sql = source.source_name
    else:
        key_sql = self._ctx.get_expr_compiled(source)

    compiled = key_sql if expr._ascending else '{0} DESC'.format(key_sql)
    self._ctx.add_expr_compiled(expr, compiled)
def visit_sort(self, expr):
    """Compile a sort into an ``ORDER BY`` clause over its sort keys."""
    key_sql = [self._ctx.get_expr_compiled(key) for key in expr.args[1]]
    order_by_clause = 'ORDER BY {0} '.format(self._join_compiled_fields(key_sql))
    self.add_order_by_clause(expr, order_by_clause)
def visit_sample(self, expr):
    """Compile a sample into a ``TABLESAMPLE`` clause.

    A fraction is rendered as a whole percentage, a row count as a number
    of rows. Sampling by parts or with replacement must have been handled
    before this compiler is reached.
    """
    assert expr._parts is None and not expr._replace.value
    if expr._frac is not None:
        # Round rather than truncate: 0.29 * 100 is 28.999999999999996 in
        # binary floating point and would otherwise become 28 PERCENT.
        percent = int(round(expr._frac.value * 100))
        self.add_table_sample_clause(expr, "({0} PERCENT)".format(percent))
    elif expr._n is not None:
        self.add_table_sample_clause(expr, "({0} ROWS)".format(expr._n.value))
def visit_distinct(self, expr):
    """Register a ``DISTINCT ...`` SELECT clause over the fields in ``expr.args[1]``."""
    compiled_fields = []
    for field in expr.args[1]:
        compiled_fields.append(self._compile_select_field(field))
    joined = self._join_select_fields(compiled_fields)
    self.add_select_clause(expr, 'DISTINCT {0}'.format(joined))
def visit_reduction(self, expr):
    """Compile a reduction (aggregation) expression into a SQL aggregate call.

    Branches are tested in a fixed order; the first match wins:

    * counting a whole collection compiles to ``COUNT(1)``;
    * mean and std only rename the aggregate (``avg``, ``stddev`` or
      ``stddev_samp``);
    * sum over strings compiles to ``WM_CONCAT`` with an empty separator,
      sum over booleans to ``SUM`` over ``IF(..., 1, 0)``;
    * max/min over booleans, any and all compare an aggregate of
      ``IF(..., 1, 0)`` with 1;
    * nunique, cat, quantile and tolist have dedicated templates;
    * anything else uses ``expr.node_name`` upper-cased as the function.

    The ``_unique`` flag (False when the attribute is absent) adds
    ``DISTINCT`` in the branches that support it.

    Raises:
        CompileError: for a std reduction whose ``_ddof`` is neither 0 nor 1.
    """
    is_unique = getattr(expr, '_unique', False)
    lookup = self._ctx.get_expr_compiled

    if isinstance(expr, (Count, GroupedCount)) and isinstance(expr.input, CollectionExpr):
        self._ctx.add_expr_compiled(expr, 'COUNT(1)')
        return

    if isinstance(expr, (Std, GroupedStd)) and expr._ddof not in (0, 1):
        raise CompileError('Does not support %s with ddof=%s' % (
            expr.node_name, expr._ddof))

    compiled = None
    node_name = None
    if isinstance(expr, (Mean, GroupedMean)):
        node_name = 'avg'
    elif isinstance(expr, (Std, GroupedStd)):
        node_name = 'stddev' if expr._ddof == 0 else 'stddev_samp'
    elif isinstance(expr, (Sum, GroupedSum)) and expr.input.dtype == df_types.string:
        template = 'WM_CONCAT(DISTINCT \'\', %s)' if is_unique else 'WM_CONCAT(\'\', %s)'
        compiled = template % lookup(expr.input)
    elif isinstance(expr, (Sum, GroupedSum)) and expr.input.dtype == df_types.boolean:
        template = 'SUM(DISTINCT IF(%s, 1, 0))' if is_unique else 'SUM(IF(%s, 1, 0))'
        compiled = template % lookup(expr.input)
    elif isinstance(expr, (Max, GroupedMax, Min, GroupedMin)) and \
            expr.input.dtype == df_types.boolean:
        compiled = '%s(IF(%s, 1, 0)) == 1' % (expr.node_name, lookup(expr.input))
    elif isinstance(expr, (Any, GroupedAny)):
        compiled = 'MAX(IF(%s, 1, 0)) == 1' % lookup(expr.args[0])
    elif isinstance(expr, (All, GroupedAll)):
        compiled = 'MIN(IF(%s, 1, 0)) == 1' % lookup(expr.args[0])
    elif isinstance(expr, (NUnique, GroupedNUnique)):
        columns = ', '.join(lookup(c) for c in expr.inputs)
        compiled = 'COUNT(DISTINCT %s)' % columns
    elif isinstance(expr, (Cat, GroupedCat)):
        template = 'WM_CONCAT(DISTINCT %s, %s)' if is_unique else 'WM_CONCAT(%s, %s)'
        compiled = template % (lookup(expr._sep), lookup(expr.input))
    elif isinstance(expr, (Quantile, GroupedQuantile)):
        if isinstance(expr._prob, (list, set)):
            # several probabilities are passed as one ARRAY(...) literal
            probs_expr = 'ARRAY(' + ', '.join(str(p) for p in expr._prob) + ')'
        else:
            probs_expr = expr._prob
        if expr.input.data_type in (df_types.float32, df_types.float64) and types.get_local_use_odps2_types():
            func_name = 'PERCENTILE_APPROX'
        else:
            func_name = 'PERCENTILE'
        template = '%s(DISTINCT %s, %s)' if is_unique else '%s(%s, %s)'
        compiled = template % (func_name, lookup(expr.input), probs_expr)
    elif isinstance(expr, (ToList, GroupedToList)):
        func_name = 'COLLECT_SET' if expr._unique else 'COLLECT_LIST'
        compiled = '%s(%s)' % (func_name, lookup(expr.input))
    else:
        node_name = expr.node_name

    if compiled is None:
        # Generic form: only the aggregate name was decided above.
        template = '{0}(DISTINCT {1})' if is_unique else '{0}({1})'
        compiled = template.format(node_name.upper(), lookup(expr.args[0]))
    self._ctx.add_expr_compiled(expr, compiled)
def visit_user_defined_aggregator(self, expr):
    """Compile a call to a user-defined aggregator.

    The function name is the aggregator itself when it is a string, its
    ``name`` when it is a ``Function``, and otherwise whatever
    ``self._ctx.get_udf`` returns for it. ``repr()`` of each entry in
    ``_func_args`` is appended to the arguments only in the first two
    cases. A truthy ``_unique`` attribute adds ``DISTINCT``.
    """
    aggregator = expr._aggregator
    created_here = False
    if isinstance(aggregator, six.string_types):
        func_name = aggregator
    elif isinstance(aggregator, Function):
        func_name = aggregator.name
    else:
        func_name = self._ctx.get_udf(aggregator)
        created_here = True

    call_args = []
    for item in expr.inputs:
        call_args.append(self._ctx.get_expr_compiled(item))

    if not created_here and hasattr(expr, '_func_args') \
            and expr._func_args is not None:
        call_args.extend(repr(arg) for arg in expr._func_args)

    template = '{0}(DISTINCT {1})' if getattr(expr, '_unique', False) else '{0}({1})'
    self._ctx.add_expr_compiled(expr, template.format(func_name, ', '.join(call_args)))
def visit_column(self, expr):
    """Compile a column reference.

    When the input collection already has an alias the column compiles to
    ``<alias>.<quoted source name>``. Otherwise it is registered through
    ``add_need_alias_column`` and compiles to a ``%(symbol)s`` placeholder.
    The result is wrapped in ``CAST`` when the source data type differs
    from the column's data type.
    """
    alias = self._ctx.get_collection_alias(expr.input, silent=True)
    if not alias:
        placeholder = self._ctx.add_need_alias_column(expr)
        compiled = '%({0})s'.format(placeholder)
    else:
        compiled = '{0}.{1}'.format(alias[0], self._quote(expr.source_name))

    if expr._source_data_type != expr._data_type:
        odps_type = types.df_type_to_odps_type(expr.dtype)
        compiled = 'CAST({0} AS {1})'.format(compiled, odps_type)
    self._ctx.add_expr_compiled(expr, compiled)
def visit_function(self, expr):
    """Compile a function call expression.

    A ``Func`` node supplies its own ``_func_name``. Otherwise the name is
    ``_func`` itself when it is a string, its ``name`` when it is a
    ``Function``, and the result of ``self._ctx.get_udf`` for anything
    else. ``repr()`` of each entry in ``_func_args`` is appended unless
    the name came from ``get_udf``.

    Raises:
        NotImplementedError: when ``expr`` is neither a ``MappedExpr`` nor
            a ``Func``.
    """
    created_here = False
    if isinstance(expr, Func):
        func_name = expr._func_name
    elif isinstance(expr._func, six.string_types):
        func_name = expr._func
    elif isinstance(expr._func, Function):
        func_name = expr._func.name
    else:
        func_name = self._ctx.get_udf(expr._func)
        created_here = True

    if not isinstance(expr, (MappedExpr, Func)):
        raise NotImplementedError
    call_args = []
    for item in expr.inputs:
        call_args.append(self._ctx.get_expr_compiled(item))

    if not created_here and hasattr(expr, '_func_args') \
            and expr._func_args is not None:
        call_args.extend(repr(arg) for arg in expr._func_args)

    compiled = '{0}({1})'.format(func_name, ', '.join(call_args))
    self._ctx.add_expr_compiled(expr, compiled)
def visit_builtin_function(self, expr):
    """Compile a builtin call as ``<_func_name>(<repr of each _func_args entry>)``."""
    func_name = expr._func_name
    rendered_args = [repr(arg) for arg in expr._func_args]
    compiled = '{0}({1})'.format(func_name, ', '.join(rendered_args))
    self._ctx.add_expr_compiled(expr, compiled)
def _quote(self, compiled):
    """Wrap ``compiled`` in backticks when ``options.df.quote`` is enabled."""
    if not options.df.quote:
        return compiled
    return '`{0}`'.format(compiled)
def _unquote(self, compiled):
    """Extract a bare identifier from a compiled field reference.

    With ``options.df.quote`` enabled the first backtick-quoted name is
    returned; otherwise the first word following a dot is returned. The
    input is returned unchanged when nothing matches.
    """
    pattern = r'`([^`]+)`' if options.df.quote else r'\.(\w+)'
    found = re.compile(pattern).search(compiled)
    if not found:
        return compiled
    return found.group(1)
def visit_reshuffle(self, expr):
    """Register DISTRIBUTE BY and optional SORT BY clauses for a reshuffle.

    The distribute keys are unquoted before joining and the clause is
    stored through ``add_group_by_clause``. The SORT BY clause, stored
    through ``add_order_by_clause``, is only added when ``_sort_fields``
    is non-empty.
    """
    distribute_keys = []
    for by in expr._by:
        distribute_keys.append(self._unquote(self._ctx.get_expr_compiled(by)))
    sorts = expr._sort_fields

    joined_keys = self._join_compiled_fields(distribute_keys)
    self.add_group_by_clause(expr, 'DISTRIBUTE BY {0} '.format(joined_keys))

    if not sorts:
        return
    sort_keys = [self._ctx.get_expr_compiled(item) for item in sorts]
    joined_sorts = self._join_compiled_fields(sort_keys)
    self.add_order_by_clause(expr, 'SORT BY {0}'.format(joined_sorts))
def visit_apply_collection(self, expr):
    """Compile a function applied to a collection.

    The function name is resolved from ``expr._func`` (a string, a
    ``Function``, or anything else via ``self._ctx.get_udf``). With
    ``_lateral_view`` set the result is a ``LATERAL VIEW`` fragment
    (``LATERAL VIEW OUTER`` when ``_keep_nulls`` is set) using a newly
    requested collection alias; otherwise it is ``<call> AS (<fields>)``.
    The result is stored both as the SELECT clause and as the compiled
    form of ``expr``.
    """
    func = expr._func
    created_here = False
    if isinstance(func, six.string_types):
        func_name = func
    elif isinstance(func, Function):
        func_name = func.name
    else:
        func_name = self._ctx.get_udf(func)
        created_here = True

    call_args = []
    for field in expr._fields:
        call_args.append(self._ctx.get_expr_compiled(field))
    if not created_here and hasattr(expr, '_func_args') \
            and expr._func_args is not None:
        call_args.extend(repr(arg) for arg in expr._func_args)

    func_call = '{0}({1})'.format(func_name, ', '.join(call_args))
    field_list = ', '.join(self._quote(name) for name in expr.schema.names)

    if not expr._lateral_view:
        compiled = '{0} AS ({1})'.format(func_call, field_list)
    else:
        lv_prefix = 'LATERAL VIEW OUTER ' if expr._keep_nulls else 'LATERAL VIEW '
        compiled = lv_prefix + '{0} {1} AS {2}'.format(
            func_call, self._ctx.get_collection_alias(expr, create=True)[0], field_list)

    self.add_select_clause(expr, compiled)
    self._ctx.add_expr_compiled(expr, compiled)
def _wrap_typed(self, expr, compiled):
if expr._source_data_type != expr._data_type:
compiled = 'cast({0} AS {1})'.format(
compiled, types.df_type_to_odps_type(expr._data_type))
return compiled
def visit_sequence(self, expr):
    """Compile a sequence as its ``_source_name``, passed through ``_wrap_typed``."""
    wrapped = self._wrap_typed(expr, expr._source_name)
    self._ctx.add_expr_compiled(expr, wrapped)
def _compile_window_order_by(self, expr):
    """Return the ORDER BY fragment for one window sort key.

    When the key's input is a ``SequenceExpr`` the compiled input is used,
    with `` DESC`` appended unless ``_ascending`` is set. Any other key is
    looked up directly in the compile context.
    """
    if not isinstance(expr.input, SequenceExpr):
        return self._ctx.get_expr_compiled(expr)

    compiled = self._ctx.get_expr_compiled(expr.input)
    if expr._ascending:
        return compiled
    return '%s DESC' % compiled
def _compile_window_function(self, func, args, partition_by=None,
order_by=None, preceding=None, following=None):
partition_by = 'PARTITION BY {0}'.format(partition_by or '1')
order_by = 'ORDER BY {0}'.format(order_by) if order_by is not None else ''
if isinstance(preceding, tuple):
window_clause = 'ROWS BETWEEN {0} PRECEDING AND {1} PRECEDING' \
.format(*[self._ctx.get_expr_compiled(p) for p in preceding])
elif isinstance(following, tuple):
window_clause = 'ROWS BETWEEN {0} FOLLOWING AND {1} FOLLOWING' \
.format(*[self._ctx.get_expr_compiled(f) for f in following])
elif preceding is not None and following is not None:
window_clause = 'ROWS BETWEEN {0} PRECEDING AND {1} FOLLOWING' \
.format(self._ctx.get_expr_compiled(preceding), self._ctx.get_expr_compiled(following))
elif preceding is not None:
window_clause = 'ROWS {0} PRECEDING'.format(self._ctx.get_expr_compiled(preceding))
elif following is not None:
window_clause = 'ROWS {0} FOLLOWING'.format(self._ctx.get_expr_compiled(following))
else:
window_clause = ''
over = ' '.join(sub for sub in (partition_by, order_by, window_clause)
if len(sub) > 0)
return '{0}({1}) OVER ({2})'.format(func, args, over)
def visit_cum_window(self, expr):
    """Compile a cumulative window expression.

    A ``CumSum`` over a boolean input sums ``IF(col, 1, 0)``. An
    ``NthValue`` passes ``_nth + 1`` as the position, plus ``true`` when
    ``_skip_nulls`` is set. ``expr.distinct`` prefixes the argument with
    ``DISTINCT``. The SQL function name comes from ``WINDOW_COMPILE_DIC``.
    """
    argument = self._ctx.get_expr_compiled(expr.input)
    if isinstance(expr, CumSum) and expr.input.dtype == df_types.boolean:
        argument = 'IF({0}, 1, 0)'.format(argument)
    elif isinstance(expr, NthValue):
        argument = '{0}, {1}'.format(argument, expr._nth + 1)
        if expr._skip_nulls:
            argument += ', true'
    if expr.distinct:
        argument = 'DISTINCT {0}'.format(argument)

    partition_by = None
    if expr._partition_by:
        partition_by = ', '.join(
            self._ctx.get_expr_compiled(by) for by in expr._partition_by)
    order_by = None
    if expr._order_by:
        order_by = ', '.join(
            self._compile_window_order_by(by) for by in expr._order_by)

    func_name = WINDOW_COMPILE_DIC[expr.node_name].upper()
    compiled = self._compile_window_function(
        func_name, argument, partition_by=partition_by, order_by=order_by,
        preceding=expr._preceding, following=expr._following)
    self._ctx.add_expr_compiled(expr, compiled)
def visit_rank_window(self, expr):
    """Compile a ranking window expression.

    The SQL function name comes from ``WINDOW_COMPILE_DIC``. The argument
    list is empty except for ``QCut``, which passes ``str(expr._bins)``.
    """
    func_name = WINDOW_COMPILE_DIC[expr.node_name].upper()
    argument = str(expr._bins) if isinstance(expr, QCut) else ''

    partition_by = None
    if expr._partition_by:
        partition_by = ', '.join(
            self._ctx.get_expr_compiled(by) for by in expr._partition_by)
    order_by = None
    if expr._order_by:
        order_by = ', '.join(
            self._compile_window_order_by(by) for by in expr._order_by)

    compiled = self._compile_window_function(
        func_name, argument, partition_by=partition_by, order_by=order_by)
    self._ctx.add_expr_compiled(expr, compiled)
def visit_shift_window(self, expr):
    """Compile a shift window expression.

    The arguments are the compiled input, followed by the compiled
    ``_offset`` and ``_default`` when each of them is truthy. The SQL
    function name comes from ``WINDOW_COMPILE_DIC``.
    """
    func_name = WINDOW_COMPILE_DIC[expr.node_name].upper()

    call_fields = [self._ctx.get_expr_compiled(expr.input)]
    for optional in (expr._offset, expr._default):
        if optional:
            call_fields.append(self._ctx.get_expr_compiled(optional))
    argument = self._join_compiled_fields(call_fields)

    partition_by = None
    if expr._partition_by:
        partition_by = ', '.join(
            self._ctx.get_expr_compiled(by) for by in expr._partition_by)
    order_by = None
    if expr._order_by:
        order_by = ', '.join(
            self._compile_window_order_by(by) for by in expr._order_by)

    compiled = self._compile_window_function(
        func_name, argument, partition_by=partition_by, order_by=order_by)
    self._ctx.add_expr_compiled(expr, compiled)
def visit_scalar(self, expr):
    """Compile a scalar literal.

    A None value compiles to ``CAST(NULL AS <type>)`` using
    ``expr._value_type``. Strings are single-quoted with embedded single
    quotes backslash-escaped. int8, int16 and int64 values get a ``Y``,
    ``S`` or ``L`` suffix when ``types.get_local_use_odps2_types()`` is
    truthy. Booleans, datetimes, dates and decimals have dedicated
    renderings, and anything left over falls back to ``repr()``.
    """
    if expr._value is None:
        null_sql = 'CAST(NULL AS {0})'.format(
            types.df_type_to_odps_type(expr._value_type))
        self._ctx.add_expr_compiled(expr, null_sql)
        return

    compiled = None
    if expr.dtype == df_types.string:
        if isinstance(expr.value, six.text_type):
            val = utils.to_str(expr.value)
        else:
            val = expr.value
        compiled = "'{0}'".format(val.replace("'", "\\'"))
    elif isinstance(expr.dtype, df_types.Integer) and types.get_local_use_odps2_types():
        # Other integer widths get no suffix and fall through to repr().
        if expr.dtype == df_types.int8:
            compiled = "{0!r}Y".format(expr._value)
        elif expr.dtype == df_types.int16:
            compiled = "{0!r}S".format(expr._value)
        elif expr.dtype == df_types.int64:
            compiled = "{0!r}L".format(expr._value)
    elif isinstance(expr._value, bool):
        compiled = 'true' if expr._value else 'false'
    elif isinstance(expr._value, datetime):
        # Must precede the ``date`` test: datetime is a subclass of date.
        # FIXME: just ignore shorter than second
        compiled = 'FROM_UNIXTIME({0})'.format(utils.to_timestamp(expr._value))
    elif isinstance(expr._value, date):
        compiled = 'CAST({0!r} AS DATE)'.format(expr._value.strftime("%Y-%m-%d"))
    elif isinstance(expr._value, Decimal):
        compiled = 'CAST({0} AS DECIMAL)'.format(repr(str(expr._value)))

    if compiled is None:
        compiled = repr(expr._value)
    self._ctx.add_expr_compiled(expr, compiled)
@classmethod
def _cast(cls, compiled, source_type, to_type):
    """Return ``CAST(<compiled> AS <target type>)``.

    Both DataFrame types are first converted with
    ``types.df_type_to_odps_type``.

    Raises:
        CompileError: when the target type reports that it cannot be
            explicitly cast from the source type.
    """
    source_odps_type = types.df_type_to_odps_type(source_type)
    target_odps_type = types.df_type_to_odps_type(to_type)
    if target_odps_type.can_explicit_cast(source_odps_type):
        return 'CAST({0} AS {1})'.format(compiled, target_odps_type)
    raise CompileError(
        'Cannot cast from %s to %s' % (source_odps_type, target_odps_type))
def visit_cast(self, expr):
    """Compile a cast expression.

    An integer source cast to datetime compiles to ``FROM_UNIXTIME``. Any
    other change of type goes through ``self._cast``. When the types are
    equal the compiled input is stored unchanged.
    """
    compiled = self._ctx.get_expr_compiled(expr._input)
    int_to_datetime = isinstance(expr.source_type, df_types.Integer) \
        and expr.dtype == df_types.datetime
    if int_to_datetime:
        compiled = 'FROM_UNIXTIME({0})'.format(
            self._ctx.get_expr_compiled(expr.input))
    elif expr.dtype != expr.source_type:
        compiled = self._cast(compiled, expr.source_type, expr.dtype)
    self._ctx.add_expr_compiled(expr, compiled)
def visit_join(self, expr):
    """Assemble the FROM clause of a join from its pre-compiled parts.

    ``self._sub_compiles[expr]`` supplies the left side, the right side
    and the predicate. The right side is indented by
    ``self._indent_size`` and the ``ON`` line is appended only for a
    truthy predicate. The result is stored both as the FROM clause and as
    the compiled form of ``expr``.
    """
    left_sql, right_sql, on_sql = tuple(self._sub_compiles[expr])
    from_clause = '{0} \n{1} JOIN \n{2}'.format(
        left_sql, expr._how, utils.indent(right_sql, self._indent_size)
    )
    if on_sql:
        from_clause = from_clause + '\nON {0}'.format(on_sql)

    self.add_from_clause(expr, from_clause)
    self._ctx.add_expr_compiled(expr, from_clause)
def visit_union(self, expr):
    """Assemble the FROM clause of a union from its pre-compiled sides.

    The operator is ``UNION`` when ``expr._distinct`` is set and
    ``UNION ALL`` otherwise. Unless ``self._union_no_alias`` marks
    ``expr``, the union is wrapped in parentheses and followed by a newly
    requested collection alias. The result is stored both as the FROM
    clause and as the compiled form of ``expr``.
    """
    union_type = 'UNION' if expr._distinct else 'UNION ALL'
    left_sql, right_sql = tuple(self._sub_compiles[expr])
    from_clause = '{0} \n{1}\n{2}'.format(
        left_sql, union_type, utils.indent(right_sql, self._indent_size))

    if self._union_no_alias.get(expr, False):
        compiled = from_clause
    else:
        compiled = '(\n{0}\n) {1}'.format(
            utils.indent(from_clause, self._indent_size),
            self._ctx.get_collection_alias(expr, create=True)[0])

    self.add_from_clause(expr, compiled)
    self._ctx.add_expr_compiled(expr, compiled)