odps/df/backends/odpssql/analyzer.py (594 lines of code) (raw):
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import sys
import math as pymath
from ..analyzer import BaseAnalyzer
from ...expr.arithmetic import *
from ...expr.composites import ListDictGetItem
from ...expr.math import *
from ...expr.datetimes import *
from ...expr.strings import *
from ...expr.strings import Count as StrCount
from ...expr.element import *
from ...expr.reduction import *
from ...expr.collections import *
from ...expr.merge import *
from ...expr.window import QCut
from ...utils import output
from ..errors import CompileError
from ..utils import refresh_dynamic
from ... import types
from .... import compat
from ....utils import to_text
_NAN = float('nan')
class Analyzer(BaseAnalyzer):
    """Expression rewriter for the ODPS SQL backend.

    Each ``visit_*`` hook receives one node of the expression DAG and may
    replace it (through ``self._sub``) with an equivalent expression built
    from simpler operations or from Python UDFs (``MappedExpr`` / ``map``).
    A hook raises ``NotImplementedError`` when it has no rewrite for the
    node.  Most UDF-based rewrites are only attempted when
    ``options.df.analyze`` is enabled.
    """

    def _parents(self, expr):
        """Return the successors of ``expr`` in the expression DAG."""
        return self._dag.successors(expr)

    def visit_composite_op(self, expr):
        """Rewrite negative indexing of list-typed ``ListDictGetItem`` nodes.

        A constant negative key ``k`` becomes ``input[input.len() - (-k)]``.
        A non-constant key becomes
        ``(key >= 0).ifelse(expr, input[input.len() + key])``.  Rewritten
        nodes are flagged with ``_negative_handled`` so that visiting them
        again is a no-op.
        """
        if not (isinstance(expr, ListDictGetItem)
                and isinstance(expr.input.dtype, types.List)):
            raise NotImplementedError

        if is_constant_scalar(expr._negative_handled) and expr._negative_handled.value:
            # already produced by an earlier rewrite
            return

        key_expr = expr._key
        if is_constant_scalar(key_expr):
            if key_expr.value >= 0:
                # a non-negative constant index needs no rewrite
                return
            sub = expr.input[expr.input.len() - (-key_expr.value)]
            sub._negative_handled = Scalar(True)
        else:
            # ``expr`` itself is reused as the non-negative branch below, so
            # it is flagged as well to stop it from being rewritten again
            expr._negative_handled = Scalar(True)
            neg_expr = expr.input[expr.input.len() + key_expr]
            neg_expr._negative_handled = Scalar(True)
            sub = (key_expr >= 0).ifelse(expr, neg_expr).rename(expr.name)
        self._sub(expr, sub)

    @staticmethod
    def _make_isna_expr(expr, invert=False):
        """Build an "is missing" predicate for ``expr.input``.

        By default only NULL counts as missing.  When
        ``options.df.odps.nan_handler`` is set and the input holds float data
        (a float column of a collection, or a float sequence/scalar), NaN is
        treated as missing too: handler ``"py"`` checks it with a Python UDF
        based on ``math.isnan``, any other handler uses ``func.ISNAN``.

        :param expr: node whose ``input`` is tested.
        :param invert: build the "is present" predicate instead.
        :return: boolean expression.
        """
        from ... import func as df_func

        input_expr = expr.input
        has_float_col = (
            isinstance(input_expr, CollectionExpr)
            and any(isinstance(col.type, types.Float) for col in input_expr.columns)
        ) or (
            isinstance(input_expr, (SequenceExpr, Scalar))
            and isinstance(input_expr.dtype, types.Float)
        )
        if options.df.odps.nan_handler is None or not has_float_col:
            return input_expr.isnull() if not invert else input_expr.notnull()

        handler = options.df.odps.nan_handler.lower()
        if handler == "py":
            # NULL is already covered by isnull()/notnull(); guard against
            # None so math.isnan is never called with it
            isnan_udf = lambda x: x is not None and pymath.isnan(x)
            isnan_udf._identifier = "isnan"
            handle_fun = lambda x: x.map(isnan_udf, rtype=bool)
        else:
            handle_fun = functools.partial(df_func.ISNAN, rtype=bool)

        if not invert:
            return input_expr.isnull() | handle_fun(input_expr)
        else:
            return input_expr.notnull() & ~handle_fun(input_expr)

    def visit_element_op(self, expr):
        """Rewrite ``between``, ``cut``, ``isna``, ``notna`` and ``fillna``."""
        if isinstance(expr, Between):
            # the second comparison uses a copy of the input node
            # (presumably to avoid sharing one node object twice — confirm)
            if expr.inclusive:
                sub = (expr.left <= expr.input) & (expr.input.copy() <= expr.right)
            else:
                sub = (expr.left < expr.input) & (expr.input.copy() < expr.right)
            self._sub(expr, sub.rename(expr.name))
        elif isinstance(expr, Cut):
            self._sub(expr, self._get_cut_sub_expr(expr))
        elif isinstance(expr, IsNa):
            self._sub(expr, self._make_isna_expr(expr))
        elif isinstance(expr, NotNa):
            self._sub(expr, self._make_isna_expr(expr, invert=True))
        elif isinstance(expr, FillNa):
            cond = self._make_isna_expr(expr)
            sub = cond.ifelse(expr._fill_value, expr.input)
            if isinstance(expr, (SequenceExpr, Scalar)):
                sub = sub.rename(expr.name)
            self._sub(expr, sub)
        else:
            raise NotImplementedError

    def visit_sample(self, expr):
        """Rewrite part-based sampling into a filter on ODPS ``SAMPLE`` calls.

        One ``SAMPLE`` predicate is built per requested part index (or a
        single one when no index is given) and the predicates are OR-ed.

        :raises CompileError: for sampling with replacement, weights or strata.
        """
        if expr._replace.value or expr._weights is not None or expr._strata is not None:
            raise CompileError('ODPS SQL does not support specified sample method')
        if not expr._parts:
            raise NotImplementedError

        idxes = [None, ] if expr._i is None else expr._i
        condition = None
        for idx in idxes:
            inputs = [expr._parts]
            if idx is not None:
                # the index handed to SAMPLE is shifted by one
                # (presumably SAMPLE numbers parts from 1 — confirm)
                new_val = idx.value + 1
                inputs.append(Scalar(_value=new_val, _value_type=idx.value_type))
            if expr._sampled_fields:
                inputs.extend(expr._sampled_fields)
            cond = MappedExpr(_inputs=inputs, _func='SAMPLE', _data_type=types.boolean)
            if condition is None:
                condition = cond
            else:
                condition |= cond

        sub = FilterCollectionExpr(_input=expr.input, _predicate=condition,
                                   _schema=expr.schema)
        expr.input.optimize_banned = True
        self._sub(expr, sub)

    def _visit_pivot(self, expr):
        """Replace a ``pivot`` node with its generic rewrite."""
        self._sub(expr, self._get_pivot_sub_expr(expr))

    def _visit_pivot_table(self, expr):
        """Replace a ``pivot_table`` node with its generic rewrite."""
        self._sub(expr, self._get_pivot_table_sub_expr(expr))

    def visit_pivot(self, expr):
        """Dispatch to the ``pivot`` or the ``pivot_table`` rewrite."""
        if isinstance(expr, PivotCollectionExpr):
            self._visit_pivot(expr)
        else:
            self._visit_pivot_table(expr)

    def visit_extract_kv(self, expr):
        """Rewrite ``extract_kv`` as a key-collecting aggregation plus a map-reduce.

        ``KeyAgg`` collects the distinct keys of every KV column.
        ``callback`` is registered as a dependency on that aggregation and
        receives its result.  It builds the output schema (the intact columns
        followed by one ``<column>_<key>`` column per key) and substitutes a
        ``map_reduce`` that splits every KV string into those columns.  Until
        then ``expr`` is replaced by a ``CollectionExpr`` whose
        ``DynamicSchema`` lists only the intact columns.
        """
        kv_delimiter = expr._kv_delimiter.value
        item_delimiter = expr._item_delimiter.value
        default = expr._default.value if expr._default else None

        class KeyAgg(object):
            # Aggregator collecting the distinct keys of a KV string column;
            # the final value is the sorted keys joined by item_delimiter.
            def buffer(self):
                return set()

            def __call__(self, buf, val):
                if not val:
                    return

                def validate_kv(v):
                    parts = v.split(kv_delimiter)
                    if len(parts) != 2:
                        raise ValueError('Malformed KV pair: %s' % v)
                    return parts[0]

                buf.update([validate_kv(item) for item in val.split(item_delimiter)])

            def merge(self, buf, pbuffer):
                buf.update(pbuffer)

            def getvalue(self, buf):
                return item_delimiter.join(sorted(buf))

        columns_expr = expr.input.exclude(expr._intact).apply(
            KeyAgg, names=[c.name for c in expr._columns])
        intact_names = [g.name for g in expr._intact]
        intact_types = [g.dtype for g in expr._intact]
        exprs = [expr]

        def callback(result, new_expr):
            expr = exprs[0]
            names = list(intact_names)
            tps = list(intact_types)
            # KV column name -> {key: index of the output field for that key}
            kv_slot_map = dict()
            for col, key_str in compat.izip(result.columns, result[0]):
                kv_slot_map[col.name] = dict()
                for k in key_str.split(item_delimiter):
                    names.append('%s_%s' % (col.name, k))
                    tps.append(expr._column_type)
                    kv_slot_map[col.name][k] = len(names) - 1
            kv_slot_names = list(kv_slot_map.keys())

            # converts the extracted text to the requested numeric type
            type_adapter = None
            if isinstance(expr._column_type, types.Float):
                type_adapter = float
            elif isinstance(expr._column_type, types.Integer):
                type_adapter = int

            @output(names, tps)
            def mapper(row):
                # intact values first, then one slot per collected key; keys
                # absent from a row keep ``default``
                ret = [default, ] * len(names)
                ret[:len(intact_names)] = [getattr(row, col) for col in intact_names]
                for col in kv_slot_names:
                    kv_val = getattr(row, col)
                    if not kv_val:
                        continue
                    for kv_item in kv_val.split(item_delimiter):
                        k, v = kv_item.split(kv_delimiter)
                        if type_adapter:
                            v = type_adapter(v)
                        ret[kv_slot_map[col][k]] = v
                return tuple(ret)

            new_expr._schema = TableSchema.from_lists(names, tps)
            extracted = expr.input.map_reduce(mapper)
            self._sub(new_expr, extracted)
            # trigger refresh of dynamic operations
            refresh_dynamic(extracted, self._dag)

        sub = CollectionExpr(_schema=DynamicSchema.from_lists(intact_names, intact_types),
                             _deps=[(columns_expr, callback)])
        self._sub(expr, sub)

    def visit_value_counts(self, expr):
        """Replace ``value_counts`` with its generic rewrite."""
        self._sub(expr, self._get_value_counts_sub_expr(expr))

    def _gen_mapped_expr(self, expr, inputs, func, name,
                         args=None, kwargs=None, multiple=False):
        """Wrap ``func`` into a ``MappedExpr`` producing ``expr``'s dtype.

        :param expr: node being replaced; its dtype becomes the output type
            (``_data_type`` for sequences, ``_value_type`` otherwise).
        :param inputs: expressions fed to ``func``.
        :param func: Python callable executed as the UDF.
        :param name: name of the resulting expression.
        :param args: stored as ``_func_args``.
        :param kwargs: stored as ``_func_kwargs``.
        :param multiple: stored as ``_multiple``; callers in this class pass
            ``False`` so ``func`` receives its inputs as positional args
            instead of a namedtuple.
        :return: the new ``MappedExpr``.
        """
        mapped_kwargs = dict(_inputs=inputs, _func=func, _name=name,
                             _func_args=args, _func_kwargs=kwargs,
                             _multiple=multiple)
        if isinstance(expr, SequenceExpr):
            mapped_kwargs['_data_type'] = expr.dtype
        else:
            mapped_kwargs['_value_type'] = expr.dtype
        return MappedExpr(**mapped_kwargs)

    def visit_binary_op(self, expr):
        """Rewrite some binary operators into Python UDFs.

        ``//`` and ``%`` are evaluated with Python's operators.  ``+`` / ``-``
        between two datetimes, or involving a millisecond scalar, go through
        a datetime-arithmetic UDF (``datetime + datetime`` is left untouched).
        Other additions/subtractions are not rewritten.  Every UDF returns
        None when one of its operands is None.
        """
        if not options.df.analyze:
            raise NotImplementedError

        if isinstance(expr, FloorDivide):
            func = lambda lhs, rhs: lhs // rhs if lhs is not None and rhs is not None else None
            # multiple False will pass *args instead of namedtuple
            sub = self._gen_mapped_expr(expr, (expr.lhs, expr.rhs),
                                        func, expr.name, multiple=False)
            self._sub(expr, sub)
            return

        if isinstance(expr, Mod):
            func = lambda lhs, rhs: lhs % rhs if lhs is not None and rhs is not None else None
            sub = self._gen_mapped_expr(expr, (expr.lhs, expr.rhs),
                                        func, expr.name, multiple=False)
            self._sub(expr, sub)
            return

        if isinstance(expr, Add) and \
                all(child.dtype == types.datetime for child in (expr.lhs, expr.rhs)):
            return
        elif isinstance(expr, (Add, Substract)):
            if expr.lhs.dtype == types.datetime and expr.rhs.dtype == types.datetime:
                pass
            elif any(isinstance(child, MilliSecondScalar) for child in (expr.lhs, expr.rhs)):
                pass
            else:
                return

            def func(lhs, rhs, method):
                from datetime import datetime, timedelta

                if lhs is None or rhs is None:
                    return None
                # non-datetime operands are millisecond counts
                if not isinstance(lhs, datetime):
                    lhs = timedelta(milliseconds=lhs)
                if not isinstance(rhs, datetime):
                    rhs = timedelta(milliseconds=rhs)
                if method == '+':
                    res = lhs + rhs
                else:
                    res = lhs - rhs
                if isinstance(res, timedelta):
                    # Exact integer arithmetic: int(res.total_seconds() * 1000)
                    # loses a millisecond to float rounding (1001 ms -> 1000).
                    micros = (res.days * 86400 + res.seconds) * 1000000 + res.microseconds
                    millis = abs(micros) // 1000
                    # truncate toward zero, like the former int() conversion
                    return millis if micros >= 0 else -millis
                return res
            func._identifier = "dt_add_sub"

            method = Scalar('+') if isinstance(expr, Add) else Scalar('-')
            inputs = expr.lhs, expr.rhs, method
            sub = self._gen_mapped_expr(expr, inputs, func, expr.name, multiple=False)
            self._sub(expr, sub)
            return

        raise NotImplementedError

    def visit_unary_op(self, expr):
        """Rewrite ``~`` on integers as a Python UDF (None stays None)."""
        if not options.df.analyze:
            raise NotImplementedError

        if isinstance(expr, Invert) and isinstance(expr.input.dtype, types.Integer):
            sub = expr.input.map(lambda x: ~x if x is not None else None)
            self._sub(expr, sub)
            return

        raise NotImplementedError

    def visit_math(self, expr):
        """Rewrite ``arccosh``, ``arcsinh``, ``arctanh``, ``radians`` and ``degrees``.

        Non-decimal inputs are mapped through the matching ``math`` function
        as a Python UDF.  None stays None; out-of-domain arguments of
        ``arccosh`` / ``arctanh`` yield NaN.
        """
        if not options.df.analyze:
            raise NotImplementedError

        if expr.dtype != types.decimal:
            if isinstance(expr, Arccosh):
                def func(x):
                    if x is None:
                        return None
                    try:
                        return pymath.acosh(x)
                    except ValueError:
                        # out of domain (x < 1): NaN, consistent with arctanh
                        return _NAN
            elif isinstance(expr, Arcsinh):
                def func(x):
                    if x is None:
                        return None
                    return pymath.asinh(x)
            elif isinstance(expr, Arctanh):
                def func(x):
                    if x is None:
                        return None
                    try:
                        return pymath.atanh(x)
                    except ValueError:
                        return _NAN
            elif isinstance(expr, Radians):
                def func(x):
                    if x is None:
                        return None
                    return pymath.radians(x)
            elif isinstance(expr, Degrees):
                def func(x):
                    if x is None:
                        return None
                    return pymath.degrees(x)
            else:
                raise NotImplementedError

            func._identifier = "m_" + type(expr).__name__.lower()
            sub = expr.input.map(func, expr.dtype)
            self._sub(expr, sub)
            return

        raise NotImplementedError

    @staticmethod
    def _udf_date_format(expr):
        """Return the ``date_format`` argument of ``expr`` as a UDF input.

        A constant string format is copied into a new Scalar with every
        ``%`` doubled; any other value is returned unchanged.
        """
        # NOTE(review): the doubling presumably protects the format from a
        # later %-interpolation step when the scalar is rendered — confirm.
        if not isinstance(expr._date_format, StringScalar) \
                or expr._date_format._value is None:
            return expr._date_format
        return Scalar(to_text(expr.date_format).replace("%", "%%"))

    @staticmethod
    def _udf_pattern(expr):
        """Return the ``pat`` argument of a string op as a UDF input.

        A constant string pattern is replaced by a new Scalar holding
        ``re.escape`` of its text; any other value is returned unchanged.
        """
        # NOTE(review): the UDFs use this value as a regex (or a literal for
        # non-regex replace), so the escaping is presumably undone when the
        # scalar is rendered into SQL — confirm before changing.
        if not isinstance(expr._pat, StringScalar) or expr._pat._value is None:
            return expr._pat
        return Scalar(re.escape(to_text(expr.pat)))

    def visit_datetime_op(self, expr):
        """Rewrite ``strftime`` as a Python UDF (None stays None)."""
        if isinstance(expr, Strftime):
            if not options.df.analyze:
                raise NotImplementedError

            def func(x, fmt):
                if x is None:
                    return None
                return x.strftime(fmt)
            func._identifier = "strftime"

            inputs = expr.input, self._udf_date_format(expr)
            sub = self._gen_mapped_expr(expr, inputs, func, expr.name, multiple=False)
            self._sub(expr, sub)
            return

        raise NotImplementedError

    def visit_string_op(self, expr):
        """Rewrite string operations.

        ``ljust`` / ``rjust`` / ``zfill`` / ``cat`` are always rewritten with
        string-expression arithmetic.  The other supported operations become
        Python UDFs, and only when ``options.df.analyze`` is enabled.
        """
        # --- rewrites that only combine other string expressions ---
        if isinstance(expr, Ljust):
            rest = expr.width - expr.input.len()
            sub = expr.input + (rest >= 0).ifelse(expr._fillchar.repeat(rest), '')
            self._sub(expr, sub.rename(expr.name))
            return
        elif isinstance(expr, Rjust):
            rest = expr.width - expr.input.len()
            sub = (rest >= 0).ifelse(expr._fillchar.repeat(rest), '') + expr.input
            self._sub(expr, sub.rename(expr.name))
            return
        elif isinstance(expr, Zfill):
            fillchar = Scalar('0')
            rest = expr.width - expr.input.len()
            sub = (rest >= 0).ifelse(fillchar.repeat(rest), '') + expr.input
            self._sub(expr, sub.rename(expr.name))
            return
        elif isinstance(expr, CatStr):
            result = expr.input
            others = expr._others if isinstance(expr._others, Iterable) else (expr._others, )
            for other in others:
                if expr.na_rep is not None:
                    # fill NULLs of every operand and return without
                    # replacing ``expr`` itself
                    for e in (result, ) + tuple(others):
                        self._sub(e, e.fillna(expr.na_rep), parents=(expr, ))
                    return
                # NULL operands are skipped rather than concatenated
                if expr._sep is not None:
                    result = other.isnull().ifelse(result, result + expr._sep + other)
                else:
                    result = other.isnull().ifelse(result, result + other)
            self._sub(expr, result.rename(expr.name))
            return

        if not options.df.analyze:
            raise NotImplementedError

        # --- UDF rewrites ---
        # Branches that only set ``func`` get a single-input ``map`` at the
        # end; the others build their own MappedExpr and return directly.
        func = None
        if isinstance(expr, Contains) and expr.regex:
            def func(x, pat, case, flags):
                if x is None:
                    return None

                flgs = 0
                if not case:
                    flgs = re.I
                if flags > 0:
                    flgs = flgs | flags
                r = re.compile(pat, flgs)
                return r.search(x) is not None
            func._identifier = "str_contains"

            inputs = expr.input, self._udf_pattern(expr), expr._case, expr._flags
            sub = self._gen_mapped_expr(expr, inputs, func,
                                        expr.name, multiple=False)
            self._sub(expr, sub)
            return
        elif isinstance(expr, StrCount):
            def func(x, pat, flags):
                if x is None:
                    return None
                regex = re.compile(pat, flags=flags)
                return len(regex.findall(x))
            func._identifier = "str_count"

            inputs = expr.input, self._udf_pattern(expr), expr._flags
            sub = self._gen_mapped_expr(expr, inputs, func, expr.name, multiple=False)
            self._sub(expr, sub)
            return
        elif isinstance(expr, Find) and expr.end is not None:
            start = expr.start
            end = expr.end
            substr = expr.sub

            def func(x):
                if x is None:
                    return None
                return x.find(substr, start, end)
            func._identifier = "str_find"
        elif isinstance(expr, RFind):
            start = expr.start
            end = expr.end
            substr = expr.sub

            def func(x):
                if x is None:
                    return None
                return x.rfind(substr, start, end)
            func._identifier = "str_rfind"
        elif isinstance(expr, Extract):
            def func(x, pat, flags, group):
                if x is None:
                    return None
                regex = re.compile(pat, flags=flags)
                m = regex.search(x)
                if m:
                    if group is None:
                        return m.group()
                    return m.group(group)
            func._identifier = "str_extract"

            inputs = expr.input, self._udf_pattern(expr), expr._flags, expr._group
            sub = self._gen_mapped_expr(expr, inputs, func, expr.name, multiple=False)
            self._sub(expr, sub)
            return
        elif isinstance(expr, Replace):
            def func(x, pat, repl, n, case, flags, use_regex):
                if x is None:
                    return None

                use_re = use_regex and (not case or len(pat) > 1 or flags)

                if use_re:
                    if not case:
                        flags |= re.IGNORECASE
                    regex = re.compile(pat, flags=flags)
                    # re.sub treats count=0 as "replace all"
                    n = n if n >= 0 else 0

                    return regex.sub(repl, x, count=n)
                else:
                    return x.replace(pat, repl, n)
            func._identifier = "str_replace"

            inputs = expr.input, self._udf_pattern(expr), expr._repl, expr._n, \
                expr._case, expr._flags, expr._regex
            sub = self._gen_mapped_expr(expr, inputs, func,
                                        expr.name, multiple=False)
            self._sub(expr, sub)
            return
        elif isinstance(expr, (Lstrip, Strip, Rstrip)) and expr.to_strip != ' ':
            to_strip = expr.to_strip

            if isinstance(expr, Lstrip):
                def func(x):
                    if x is None:
                        return None
                    return x.lstrip(to_strip)
                func._identifier = "str_lstrip"
            elif isinstance(expr, Strip):
                def func(x):
                    if x is None:
                        return None
                    return x.strip(to_strip)
                func._identifier = "str_strip"
            elif isinstance(expr, Rstrip):
                def func(x):
                    if x is None:
                        return None
                    return x.rstrip(to_strip)
                func._identifier = "str_rstrip"
        elif isinstance(expr, Pad):
            side = expr.side
            if side not in ('left', 'right', 'both'):
                raise NotImplementedError
            fillchar = expr.fillchar
            width = expr.width

            def func(x, width, fillchar, side):
                if x is None:
                    return None
                # padding on the left means right-justifying the text
                if side == 'left':
                    return x.rjust(width, fillchar)
                elif side == 'right':
                    return x.ljust(width, fillchar)
                else:
                    return x.center(width, fillchar)
            func._identifier = "str_pad"

            inputs = expr.input, Scalar(width), Scalar(fillchar), Scalar(side)
            sub = self._gen_mapped_expr(expr, inputs, func, expr.name, multiple=False)
            self._sub(expr, sub)
            return
        elif isinstance(expr, Slice):
            start, end, step = expr.start, expr.end, expr.step

            if end is None and step is None:
                raise NotImplementedError
            if isinstance(start, six.integer_types) and \
                    isinstance(end, six.integer_types) and step is None:
                if start >= 0 and end >= 0:
                    raise NotImplementedError

            # Only the present bounds are passed to the UDF; ``flag`` records
            # which ones (0x4 start, 0x2 end, 0x1 step) so they can be
            # unpacked from ``*args`` in order.
            flag = 0x4 if start is not None else 0
            flag |= 0x2 if end is not None else 0
            flag |= 0x1 if step is not None else 0

            def func(x, flag, *args):
                if x is None:
                    return None

                idx = 0
                s, e, t = None, None, None
                for i in range(3):
                    if i == 0 and (flag & 0x4):
                        s = args[idx]
                        idx += 1
                    if i == 1 and (flag & 0x2):
                        e = args[idx]
                        idx += 1
                    if i == 2 and (flag & 0x1):
                        t = args[idx]
                        idx += 1
                return x[s: e: t]
            func._identifier = "str_slice"

            inputs = expr.input, Scalar(flag), expr._start, expr._end, expr._step
            sub = self._gen_mapped_expr(expr, tuple(i for i in inputs if i is not None),
                                        func, expr.name, multiple=False)
            self._sub(expr, sub)
            return
        elif isinstance(expr, Swapcase):
            func = lambda x: x.swapcase() if x is not None else None
            func._identifier = "str_swapcase"
        elif isinstance(expr, Title):
            func = lambda x: x.title() if x is not None else None
            func._identifier = "str_title"
        elif isinstance(expr, Strptime):
            def func(x, date_fmt):
                from datetime import datetime
                return datetime.strptime(x, date_fmt) if x is not None else None
            func._identifier = "strptime"

            inputs = expr.input, self._udf_date_format(expr)
            sub = self._gen_mapped_expr(expr, inputs, func, expr.name, multiple=False)
            self._sub(expr, sub)
            return
        else:
            if isinstance(expr, Isalnum):
                func = lambda x: x.isalnum() if x is not None else None
                func._identifier = "str_isalnum"
            elif isinstance(expr, Isalpha):
                func = lambda x: x.isalpha() if x is not None else None
                func._identifier = "str_isalpha"
            elif isinstance(expr, Isdigit):
                func = lambda x: x.isdigit() if x is not None else None
                func._identifier = "str_isdigit"
            elif isinstance(expr, Isspace):
                func = lambda x: x.isspace() if x is not None else None
                func._identifier = "str_isspace"
            elif isinstance(expr, Islower):
                func = lambda x: x.islower() if x is not None else None
                func._identifier = "str_islower"
            elif isinstance(expr, Isupper):
                func = lambda x: x.isupper() if x is not None else None
                func._identifier = "str_isupper"
            elif isinstance(expr, Istitle):
                func = lambda x: x.istitle() if x is not None else None
                func._identifier = "str_istitle"
            elif isinstance(expr, (Isnumeric, Isdecimal)):
                def u_safe(s):
                    # isnumeric/isdecimal need a unicode object; when the
                    # conversion is unavailable or fails, keep the value as is
                    try:
                        return unicode(s, "unicode_escape")
                    except Exception:
                        return s

                if isinstance(expr, Isnumeric):
                    func = lambda x: u_safe(x).isnumeric() if x is not None else None
                    func._identifier = "str_isnumeric"
                else:
                    func = lambda x: u_safe(x).isdecimal() if x is not None else None
                    func._identifier = "str_isdecimal"

        if func is not None:
            sub = expr.input.map(func, expr.dtype)
            self._sub(expr, sub)
            return

        raise NotImplementedError

    def visit_rank_window(self, expr):
        """Rewrite ``qcut`` as ``expr - 1``; other rank windows are unsupported."""
        if isinstance(expr, QCut):
            self._sub(expr, expr - 1)
            return

        raise NotImplementedError

    def visit_reduction(self, expr):
        """Rewrite ``var``, ``moment``, ``skew`` and ``kurtosis`` (plain and grouped).

        With ``n`` the count and ``mK`` the order-K moment from
        ``_get_moment_sub_expr`` (``center=True`` for skew/kurtosis):

        * var: ``std(ddof) ** 2``;
        * skew: ``m3 / s**3 * n**2 / (n - 1) / (n - 2)`` with ``s = std(ddof=1)``;
        * kurtosis: ``((n*n - 1) * m4 / s**4 - 3 * (n - 1)**2) / (n - 2) / (n - 3)``
          with ``s = std(ddof=0)``.

        Unique reductions deduplicate the input first.
        """
        expr_input = expr.input
        if getattr(expr, '_unique', False):
            expr_input = expr_input.unique()

        if isinstance(expr, (Var, GroupedVar)):
            std = expr_input.std(ddof=expr._ddof)
            if isinstance(expr, GroupedVar):
                std = std.to_grouped_reduction(expr._grouped)
            sub = (std ** 2).rename(expr.name)
            self._sub(expr, sub)
            return
        elif isinstance(expr, (Moment, GroupedMoment)):
            order = expr._order
            center = expr._center

            sub = self._get_moment_sub_expr(expr, expr_input, order, center)
            sub = sub.rename(expr.name)
            self._sub(expr, sub)
            return
        elif isinstance(expr, (Skewness, GroupedSkewness)):
            std = expr_input.std(ddof=1)
            if isinstance(expr, GroupedSequenceReduction):
                std = std.to_grouped_reduction(expr._grouped)
            cnt = expr_input.count()
            if isinstance(expr, GroupedSequenceReduction):
                cnt = cnt.to_grouped_reduction(expr._grouped)
            sub = self._get_moment_sub_expr(expr, expr_input, 3, True) / (std ** 3)
            sub *= (cnt ** 2) / (cnt - 1) / (cnt - 2)
            sub = sub.rename(expr.name)
            self._sub(expr, sub)
            return
        elif isinstance(expr, (Kurtosis, GroupedKurtosis)):
            std = expr_input.std(ddof=0)
            if isinstance(expr, GroupedSequenceReduction):
                std = std.to_grouped_reduction(expr._grouped)
            cnt = expr_input.count()
            if isinstance(expr, GroupedSequenceReduction):
                cnt = cnt.to_grouped_reduction(expr._grouped)
            m4 = self._get_moment_sub_expr(expr, expr_input, 4, True)
            sub = 1.0 / (cnt - 2) / (cnt - 3) * ((cnt * cnt - 1) * m4 / (std ** 4) - 3 * (cnt - 1) ** 2)
            sub = sub.rename(expr.name)
            self._sub(expr, sub)
            return

        raise NotImplementedError