odps/df/expr/core.py (448 lines of code) (raw):
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import itertools
import weakref
from collections import deque, defaultdict, OrderedDict
from ...compat import six, Iterable
from ...dag import DAG, Queue
from . import utils
class NodeMetaclass(type):
    """Metaclass that derives ``__slots__`` for expression nodes.

    Unless a class body sets ``_add_args_slots = False`` (the flag is popped
    from the namespace), the class' ``__slots__`` become the de-duplicated,
    order-preserving union of:

    * the ``_mixin_slots`` of its direct bases,
    * its own ``__slots__``, ``_mixin_slots`` and ``_slots``,
    * the child attribute names listed in ``_args``.

    Names that are also defined in the class namespace are left out, because
    a slot cannot share its name with a class attribute. Every declaration may
    be an iterable of names or a single string (as Python's own ``__slots__``
    allows).
    """

    def __new__(mcs, name, bases, kv):
        def as_names(value):
            # A bare string is one name; list('foo') would split it into characters.
            if isinstance(value, str):
                return [value]
            return list(value)

        mixin_slots = []
        for base in bases:
            mixin_slots.extend(as_names(getattr(base, '_mixin_slots', [])))

        if kv.pop('_add_args_slots', True):
            slots = list(mixin_slots)
            for key in ('__slots__', '_mixin_slots', '_slots', '_args'):
                slots.extend(as_names(kv.get(key, [])))
            # de-duplicate while keeping declaration order
            slots = OrderedDict.fromkeys(slots)
            kv['__slots__'] = tuple(slot for slot in slots if slot not in kv)

        if '__slots__' not in kv:
            kv['__slots__'] = ()
        if mixin_slots:
            # reset so subclasses of this class do not pick up the bases' mixin slots again
            kv['_mixin_slots'] = []
        return type.__new__(mcs, name, bases, kv)

    def __instancecheck__(self, instance):
        # A CollectionExpr carrying a non-None ``_proxy`` is checked against the proxied object.
        from .expressions import CollectionExpr
        if issubclass(type(instance), CollectionExpr) and \
                hasattr(instance, '_proxy') and instance._proxy is not None:
            return super(NodeMetaclass, self).__instancecheck__(instance._proxy)
        return super(NodeMetaclass, self).__instancecheck__(instance)
class Node(six.with_metaclass(NodeMetaclass)):
    """Base class of expression-tree nodes.

    Subclasses name the attributes holding their children in ``_args`` and
    optional auxiliary child attributes in ``_extra_args``. ``NodeMetaclass``
    turns the ``_args`` names into slots. A child attribute may hold a single
    node, ``None``, or a list/tuple of nodes.
    """

    __slots__ = '_args_indexes', '__weakref__'
    _args = ()  # names of the attributes holding the child(ren) of the current node
    _extra_args = ()  # extra child attributes, only visited when ``extra=True``
    _non_table = False

    def __init__(self, *args, **kwargs):
        # position of every child attribute name inside ``_args``
        self._args_indexes = dict((name, i) for i, name in enumerate(self._args))
        self._init(*args, **kwargs)

    def _init(self, *args, **kwargs):
        """Assign positional values to ``_args`` in order, then set every keyword as an attribute."""
        for arg, value in zip(self._args, args):
            setattr(self, arg, value)

        for key, value in six.iteritems(kwargs):
            setattr(self, key, value)

    @property
    def _node_id(self):
        """Identity key of this node (``id(self)``), used for traversal bookkeeping."""
        return id(self)

    def _init_attr(self, attr, val):
        """Set ``attr`` to ``val`` only when the attribute is not set yet."""
        if not hasattr(self, attr):
            setattr(self, attr, val)

    @property
    def args(self):
        """Values of the ``_args`` attributes, in order; unset ones are ``None``."""
        return tuple(getattr(self, arg, None) for arg in self._args)

    def iter_args(self):
        """Yield ``(name, value)`` for every attribute in ``_args``."""
        for name, arg in zip(self._args, self.args):
            yield name, arg

    @property
    def extra_args(self):
        """Values of the ``_extra_args`` attributes, in order; unset ones are ``None``."""
        return tuple(getattr(self, extra_arg, None)
                     for extra_arg in self._extra_args)

    def arg_name_values(self, extra=False):
        """Yield ``(name, value)`` for ``_args`` and, when ``extra`` is true, ``_extra_args``."""
        arg_names = self._args if not extra else self._args + self._extra_args
        arg_values = self.args if not extra else self.args + self.extra_args
        for arg_name, arg in zip(arg_names, arg_values):
            yield arg_name, arg

    def _data_source(self):
        """Yield data sources owned by this node itself; the base class yields only ``None``."""
        yield None

    def data_source(self):
        """Yield every non-``None`` data source reachable from this node.

        Traversal does not descend below nodes whose ``_source_data`` is set.
        """
        stop_cond = lambda x: hasattr(x, '_source_data') and x._source_data is not None
        for n in self.traverse(top_down=True, unique=True, stop_cond=stop_cond):
            for ds in n._data_source():
                if ds is not None:
                    yield ds

    def substitute(self, old_arg, new_arg, dag=None):
        """Replace every direct child that *is* ``old_arg`` by ``new_arg``.

        When ``dag`` is given, the work is delegated to ``dag.substitute`` with
        this node as the only parent. Otherwise ``new_arg`` inherits the
        ``_name`` of ``old_arg`` if it has none, and all ``_args`` and
        ``_extra_args`` (including list/tuple valued ones) are rewritten.
        """
        if dag is not None:
            dag.substitute(old_arg, new_arg, parents=[self])
            return

        if hasattr(old_arg, '_name') and old_arg._name is not None and \
                new_arg._name is None:
            new_arg._name = old_arg._name

        for arg_name, arg in self.arg_name_values(extra=True):
            if isinstance(arg, (list, tuple)):
                # rebuild the container with its own type, swapping by identity
                subs = [new_arg if sub is old_arg else sub for sub in arg]
                setattr(self, arg_name, type(arg)(subs))
            elif arg is old_arg:
                setattr(self, arg_name, new_arg)

    def children(self, extra=False):
        """Return the non-``None`` children, with list/tuple attributes flattened one level."""
        args = []
        all_args = self.args if not extra else self.args + self.extra_args
        for arg in all_args:
            if isinstance(arg, (list, tuple)):
                args.extend(arg)
            else:
                args.append(arg)

        return [arg for arg in args if arg is not None]

    def leaves(self):
        """Yield each distinct reachable node that has no children."""
        for n in self.traverse(unique=True):
            if not n.children():
                yield n

    def traverse(self, top_down=False, unique=False, traversed=None,
                 extra=False, stop_cond=None):
        """Walk the tree rooted at this node depth-first.

        :param top_down: yield a node before its children (pre-order) when true,
            otherwise after them.
        :param unique: yield every node (by ``_node_id``) at most once.
        :param traversed: optional set of ``_node_id`` values shared with the
            caller; ids already in it are treated as visited (only with ``unique``).
        :param extra: also follow ``_extra_args`` children.
        :param stop_cond: callable; the children of nodes for which it returns
            true are not visited (the node itself is still yielded).
        """
        traversed = traversed if traversed is not None else set()

        def is_trav(n):
            # records ``n`` as visited; True when it had been visited before
            if not unique:
                return False
            if n._node_id in traversed:
                return True
            traversed.add(n._node_id)
            return False

        q = deque()
        q.append(self)
        if is_trav(self):
            return

        checked = set()  # nodes whose children have already been queued (bottom-up)
        yields = set()  # nodes already yielded (bottom-up, unique mode)
        while len(q) > 0:
            curr = q.popleft()
            if top_down:
                yield curr
                if stop_cond is None or not stop_cond(curr):
                    children = [c for c in curr.children(extra=extra)
                                if not is_trav(c)]
                    q.extendleft(children[::-1])
            else:
                if curr._node_id not in checked:
                    children = curr.children(extra=extra)
                    if len(children) == 0:
                        yield curr
                    else:
                        # revisit ``curr`` once its children have been yielded
                        q.appendleft(curr)
                        if stop_cond is None or not stop_cond(curr):
                            q.extendleft([c for c in children
                                          if not is_trav(c) or c._node_id not in checked][::-1])
                    checked.add(curr._node_id)
                else:
                    if curr._node_id not in yields:
                        yield curr
                        if unique:
                            yields.add(curr._node_id)

    def __eq__(self, other):
        return self.equals(other)

    def equals(self, other):
        """Structural equality: same type (or subclass) and pairwise-equal attributes.

        Attributes reported by ``utils.get_attrs`` are compared; nested nodes
        are compared with ``equals`` and lists/tuples element-wise.
        """
        if other is None:
            return False

        if not isinstance(other, type(self)):
            return False

        def slot_values(obj):
            return [getattr(obj, slot, None) for slot in utils.get_attrs(obj)]

        def values_equal(x, y):
            if isinstance(x, Node):
                return x.equals(y)
            if isinstance(y, Node):
                # compare the node against the *other* value, not against itself
                return y.equals(x)
            if isinstance(x, (tuple, list)) and isinstance(y, (tuple, list)):
                # map() stops at the shorter sequence, so lengths must be checked first
                return len(x) == len(y) and all(map(values_equal, x, y))
            return x == y

        return all(map(values_equal, slot_values(self), slot_values(other)))

    def __hash__(self):
        # hashed on the node type and its (non-None) children
        return hash((type(self), tuple(self.children())))

    def is_ancestor(self, other):
        """Return True if ``other`` (proxy unwrapped) is this node or reachable from it."""
        other = utils.get_proxied_expr(other)
        for n in self.traverse(top_down=True, unique=True):
            if n is other:
                return True
        return False

    def path(self, other, strict=False):
        """Return the first path yielded by ``all_path``, or ``None`` when there is none."""
        return next(self.all_path(other, strict=strict), None)

    def _all_path(self, other):
        """Yield paths (lists of nodes from ``self`` to ``other``) found by an iterative DFS.

        NOTE(review): child positions are kept per ``_node_id`` and never reset,
        so a node shared by several parents is explored below only once and
        not every path may be reported.
        """
        if self is other:
            yield [self, ]

        node_poses = defaultdict(lambda: 0)  # next child index to visit per node

        q = deque()
        q.append(self)

        while len(q) > 0:
            curr = q[-1]
            children = curr.children()
            pos = node_poses[curr._node_id]
            if len(children) == 0 or pos >= len(children):
                q.pop()
                # TODO: add this in the future
                # if pos >= len(children):
                #     node_poses[curr._node_id] = 0
                continue

            n = children[pos]
            q.append(n)
            if n is other:
                yield list(q)
                q.pop()
            node_poses[curr._node_id] += 1

    def all_path(self, other, strict=False):
        """Yield the paths from this node down to ``other``.

        If none exists and ``strict`` is false, yield the paths from ``other``
        down to this node instead.
        """
        # remember, if the node has been changed into another one during traversing
        # the modification may not be applied to the paths
        i = 0
        for i, path in zip(itertools.count(1), self._all_path(other)):
            yield path

        if i == 0 and not strict:
            for path in other._all_path(self):
                yield path

    def _attr_dict(self):
        """Map every attribute reported by ``utils.get_attrs`` to its value (``None`` if unset)."""
        slots = utils.get_attrs(self)
        return dict((attr, getattr(self, attr, None)) for attr in slots)

    def _copy_type(self):
        return type(self)

    def copy(self, clear_keys=None, **kw):
        """Return a shallow copy of the proxied expression.

        :param clear_keys: attribute names to leave out of the copy.
        :param kw: attribute values overriding those of the original.
        """
        proxied = utils.get_proxied_expr(self)
        attr_dict = proxied._attr_dict()
        if clear_keys is not None:
            for k in clear_keys:
                attr_dict.pop(k, None)
        attr_dict.update(kw)
        copied = type(proxied)(**attr_dict)
        return copied

    def copy_to(self, target):
        """Copy every set attribute reported by ``utils.get_attrs`` onto ``target``."""
        slots = utils.get_attrs(self)
        for attr in slots:
            if hasattr(self, attr):
                setattr(target, attr, getattr(self, attr, None))

    def copy_tree(self, on_copy=None, extra=False, stop_cond=None):
        """Return a copy of the whole tree rooted at this node.

        Nodes are copied bottom-up, so each copy references the copies of its
        children. With ``stop_cond``, children below stopping nodes are not
        copied and the originals are referenced instead.

        :param on_copy: a callable or an iterable of callables, each invoked as
            ``func(original, copied)`` for every copied node.
        """
        if on_copy is not None:
            if isinstance(on_copy, Iterable):
                # materialize so a generator is not exhausted after the first node
                on_copy = list(on_copy)
            else:
                on_copy = [on_copy, ]

        expr_id_to_copied = dict()

        def get(n):
            if n is None:
                return n
            try:
                return expr_id_to_copied[n._node_id]
            except KeyError:
                # not copied because traversal stopped above it: keep the original
                if stop_cond is not None:
                    return n
                raise

        for node in self.traverse(unique=True, extra=extra, stop_cond=stop_cond):
            node = utils.get_proxied_expr(node)
            attr_dict = node._attr_dict()

            for arg_name, arg in node.arg_name_values(extra=extra):
                if isinstance(arg, (tuple, list)):
                    attr_dict[arg_name] = type(arg)(get(it) for it in arg)
                else:
                    attr_dict[arg_name] = get(arg)

            copied_node = type(node)(**attr_dict)
            expr_id_to_copied[node._node_id] = copied_node
            if on_copy is not None:
                for func in on_copy:
                    func(node, copied_node)

        return expr_id_to_copied[self._node_id]

    def to_dag(self, copy=True, on_copy=None, dag=None, validate=True, stop_cond=None):
        """Add this tree (or a copy of it) into an ``ExprDAG`` and return the DAG.

        Edges point from child to parent. A new ``ExprDAG`` rooted at the
        (copied) expression is created unless ``dag`` is given. Children of
        nodes for which ``stop_cond`` is true are added but not expanded.
        """
        if copy:
            expr = self.copy_tree(on_copy=on_copy, extra=True)
        else:
            expr = self

        if dag is None:
            dag = ExprDAG(expr)

        queue = Queue()
        if not dag.contains_node(expr):
            dag.add_node(expr)
        queue.put(expr)

        traversed = set()
        traversed.add(expr._node_id)
        while not queue.empty():
            node = queue.get()
            for child in node.children(extra=True):
                if not dag.contains_node(child):
                    dag.add_node(child)
                if not dag.contains_edge(child, node):
                    dag.add_edge(child, node, validate=False)

                if child._node_id in traversed:
                    continue
                traversed.add(child._node_id)
                if stop_cond is None or not stop_cond(child):
                    queue.put(child)

        if validate:
            dag._validate()  # validate the DAG
        return dag

    def __getstate__(self):
        # Raw slot values via object.__getattribute__; an unset slot raises AttributeError.
        slots = utils.get_attrs(self)

        return tuple((slot, object.__getattribute__(self, slot)) for slot in slots
                     if not slot.startswith('__'))

    def __setstate__(self, state):
        self.__init__(**dict(state))
def _extract_df_inputs(o):
    """Yield the ``Node`` objects contained in ``o``.

    ``o`` may be a ``Node`` or an arbitrarily nested ``dict`` (its values are
    searched), ``list``, ``set`` or ``tuple``. A top-level value of any other
    kind yields a single ``None``. ``None`` results produced by nested values
    are dropped.
    """
    if isinstance(o, Node):
        yield o
        return

    if isinstance(o, dict):
        # snapshot the members before yielding anything, like the eager unpacking did
        members = list(six.itervalues(o))
    elif isinstance(o, (list, set, tuple)):
        members = list(o)
    else:
        yield None
        return

    for member in members:
        for found in _extract_df_inputs(member):
            if found is not None:
                yield found
class ExprProxy(object):
    """Weakly-referencing, hashable stand-in for an expression.

    The hash and ``_node_id`` of the expression are captured at construction,
    so the proxy stays hashable and comparable after the expression has been
    garbage collected. Calling the proxy returns the expression, or ``None``
    once it is gone.

    :param expr: the expression to reference.
    :param d: optional mapping; when the expression dies, this proxy's entry
        is removed from it.
    :param compare: when true, comparing against a non-proxy uses
        ``expr.equals`` instead of node identity.
    """

    def __init__(self, expr, d=None, compare=False):
        if d is not None:
            def callback(_):
                # the referent has been collected: drop our entry from ``d``
                if self in d:
                    del d[self]
        else:
            callback = None

        self._ref = weakref.ref(expr, callback)
        self._cmp = compare
        self._hash = hash(expr)
        self._expr_id = expr._node_id

    def __call__(self):
        return self._ref()

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        if isinstance(other, ExprProxy):
            # dereference once each so an object cannot vanish between checks
            obj, other_obj = self._ref(), other()
            if obj is not None and other_obj is not None:
                return obj is other_obj
            return self._expr_id == other._expr_id

        obj = self._ref()
        if obj is not None and self._cmp:
            return obj.equals(other)

        other_id = getattr(other, '_node_id', None)
        if other_id is None:
            # not a node: cannot be the proxied expression
            return False
        return self._expr_id == other_id

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``
        return not self.__eq__(other)
class ExprDAG(DAG):
    """DAG of expression nodes holding a weak reference to its root expression.

    :param root: the root expression.
    :param dag: optional DAG whose ``_graph`` and ``_map`` storage is shared
        (not copied) by the new instance.
    """

    def __init__(self, root, dag=None):
        self._root = weakref.ref(root)
        super(ExprDAG, self).__init__()
        if dag is not None:
            self._graph = dag._graph
            self._map = dag._map

    @property
    def root(self):
        """The root expression, or ``None`` if it has been garbage collected."""
        return self._root()

    @root.setter
    def root(self, root):
        self._root = weakref.ref(root)

    def ensure_all_nodes_in_dag(self):
        """Add any node/edge reachable from the root that is missing from this DAG."""
        self.root.to_dag(copy=False, dag=self, validate=False)

    def traverse(self, root=None, top_down=False, traversed=None, stop_cond=None):
        """Traverse unique nodes from ``root`` (default: the DAG root) via ``Node.traverse``."""
        if root is None:
            root = self.root
        return root.traverse(top_down=top_down, unique=True,
                             traversed=traversed, stop_cond=stop_cond)

    def substitute(self, expr, new_expr, parents=None):
        """Replace ``expr`` by ``new_expr`` in ``parents`` (default: all successors of ``expr``).

        The subtree of ``new_expr`` is added to the DAG, each parent's
        reference is rewritten through ``Node.substitute``, and the edges are
        moved from ``expr`` to ``new_expr``. The root is updated if ``expr``
        was the root, and ``new_expr.cache()`` is called when
        ``expr._need_cache`` is set.
        """
        if expr is self.root:
            self.root = new_expr

        # snapshot: edges of ``expr`` are removed while iterating below
        parents = list(self.successors(expr) if parents is None else parents)

        if expr._need_cache:
            new_expr.cache()

        q = Queue()
        q.put(new_expr)
        while not q.empty():
            node = q.get()
            if not self.contains_node(node):
                self.add_node(node)

            for child in node.children():
                if not self.contains_node(child):
                    q.put(child)
                    self.add_node(child)
                if not self.contains_edge(child, node):
                    self.add_edge(child, node, validate=False)

        for parent in parents:
            parent.substitute(expr, new_expr)
            self.add_edge(new_expr, parent, validate=False)
            try:
                self.remove_edge(expr, parent)
            except KeyError:
                # the edge may not exist (e.g. explicit parents were passed)
                pass

    def prune(self):
        """Repeatedly remove sink entries of ``_graph`` (no successors) other than the root."""
        while True:
            removable = [n for n, succ in six.iteritems(self._graph)
                         if len(succ) == 0 and n is not self.root]
            # stop as soon as nothing is removable; waiting for "only the root is
            # a sink" looped forever on an empty graph or a non-sink root
            if not removable:
                break
            for node in removable:
                self.remove_node(node)

    def closest_ancestors(self, node, cond):
        """Yield nodes satisfying ``cond`` found by ``self.bfs`` over predecessors of ``node``.

        The stop condition passed to ``bfs`` is true for a node none of whose
        successors has been collected yet. NOTE(review): exact stopping
        semantics depend on ``DAG.bfs`` — confirm.
        """
        collected = set()
        stop_cond = lambda n: not collected.intersection(self.successors(n))
        for n in self.bfs(node, self.predecessors, stop_cond):
            if cond(n):
                collected.add(n)
                yield n
class ExprDictionary(dict):
    """``dict`` keyed by expressions through weak ``ExprProxy`` wrappers.

    Keys are wrapped in an ``ExprProxy`` (unless they already are one) on the
    way in and unwrapped by calling the proxy on the way out, so the keys are
    not kept alive by the dictionary. Entries added via ``__setitem__`` /
    ``setdefault`` are removed automatically once their key expression is
    garbage collected. ``None`` is never a valid key.
    """

    def _ref(self, obj, ref_self=False):
        """Wrap ``obj`` in an ``ExprProxy``; with ``ref_self`` it unregisters itself from this dict on death."""
        r = self if ref_self else None
        return obj if isinstance(obj, ExprProxy) else ExprProxy(obj, d=r)

    def __getitem__(self, item):
        if item is None:
            raise KeyError
        return dict.__getitem__(self, self._ref(item))

    def __setitem__(self, key, value):
        if key is None:
            raise KeyError
        return dict.__setitem__(self, self._ref(key, True), value)

    def __iter__(self):
        for k in dict.__iter__(self):
            yield k()

    def __delitem__(self, key):
        if key is None:
            raise KeyError
        return dict.__delitem__(self, self._ref(key))

    def __contains__(self, item):
        if item is None:
            return False
        return dict.__contains__(self, self._ref(item))

    def get(self, k, d=None):
        if k is None:
            return d
        return dict.get(self, self._ref(k), d)

    def has_key(self, k):
        # ``dict.has_key`` does not exist on Python 3; use the membership test instead
        return self.__contains__(k)

    def pop(self, k, d=None):
        """Remove ``k`` and return its value, or ``d`` when ``k`` is missing or ``None``."""
        if k is None:
            # consistent with ``get``: a None key is simply missing
            return d
        return dict.pop(self, self._ref(k), d)

    def popitem(self):
        k, v = dict.popitem(self)
        return k(), v

    def setdefault(self, k, d=None):
        if k is None:
            raise KeyError
        return dict.setdefault(self, self._ref(k, True), d)

    def update(self, E=None, **F):
        """Insert items from ``E`` (mapping or iterable of pairs), then from ``F``, like ``dict.update``."""
        if hasattr(E, 'keys'):
            for k in E.keys():
                self[k] = E[k]
        elif E is not None:
            for k, v in E:
                self[k] = v
        # dict.update applies keyword items too, even when E is given
        for k, v in six.iteritems(F):
            self[k] = v

    if six.PY2:
        def items(self):
            return [(k(), v) for k, v in dict.items(self)]

        def keys(self):
            return [k() for k in dict.keys(self)]

        def iteritems(self):
            for k, v in dict.iteritems(self):
                yield k(), v

        def iterkeys(self):
            for k in dict.iterkeys(self):
                yield k()
    else:
        def items(self):
            for k, v in dict.items(self):
                yield k(), v

        def keys(self):
            for k in dict.keys(self):
                yield k()