# odps/df/expr/merge.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from ...compat import six, reduce
from ...models import TableSchema
from ...utils import to_lower_str
from ..utils import to_collection
from .arithmetic import Equal
from .core import Node, ExprDictionary
from .expressions import CollectionExpr, TypedExpr, ProjectCollectionExpr, \
Column, BooleanSequenceExpr, SequenceExpr, Expr, Scalar, repr_obj, CallableColumn
from .errors import ExpressionError
class JoinCollectionExpr(CollectionExpr):
    """Collection expression representing a join of two collections.

    The joined schema is built from the columns of both sides; when names
    conflict, ``_left_suffix`` / ``_right_suffix`` are appended.  Bookkeeping:

    * ``_column_origins`` maps each output column name to a tuple
      ``(side index, source column name)`` where side 0 is the left input.
    * ``_renamed_columns`` maps an original column name to its suffixed
      output name(s): a 1-tuple ``(left_name,)`` or 2-tuple
      ``(left_name, right_name)``.
    * ``_column_conflict`` records whether any name clash was seen.
    """
    __slots__ = '_how', '_left_suffix', '_right_suffix', '_column_origins', \
        '_renamed_columns', '_column_conflict', '_mapjoin', '_skewjoin', \
        '_skewjoin_values'
    _args = '_lhs', '_rhs', '_predicate'

    def _init(self, *args, **kwargs):
        """Initialize bookkeeping, validate inputs and build the schema.

        :raises TypeError: when either side is not a CollectionExpr
        :raises ValueError: when no suffixes are given but column names overlap
        """
        self._init_attr('_left_suffix', None)
        self._init_attr('_right_suffix', None)
        self._init_attr('_column_origins', dict())
        self._init_attr('_renamed_columns', dict())
        self._init_attr('_column_conflict', False)
        super(JoinCollectionExpr, self)._init(*args, **kwargs)
        if not isinstance(self._lhs, CollectionExpr):
            raise TypeError(
                'Can only join collection expressions, got %s for left expr.' % type(self._lhs)
            )
        if not isinstance(self._rhs, CollectionExpr):
            raise TypeError(
                'Can only join collection expressions, got %s for right expr.' % type(self._rhs)
            )
        # self-join: view() the right side so the two sides are distinct nodes
        if self._rhs is self._lhs:
            self._rhs = self._rhs.view()
        if isinstance(self._lhs, JoinCollectionExpr) and \
                (self._rhs is self._lhs._lhs or self._rhs is self._lhs._rhs):
            self._rhs = self._rhs.view()
        if self._left_suffix is None and self._right_suffix is None:
            # without suffixes any shared column name is a hard error
            overlaps = set(self._lhs.schema.names).intersection(self._rhs.schema.names)
            if len(overlaps) > 0:
                raise ValueError(
                    'Column conflict exists in join, overlap columns: %s' % ','.join(overlaps)
                )
        self._set_schema()
        self._validate_predicates(self._predicate)
        self._how = self._how.upper()

    def _defunc(self, field):
        """Resolve a callable field: a one-argument callable receives the
        join itself; otherwise it receives the left and right collections."""
        if inspect.isfunction(field):
            if six.get_function_code(field).co_argcount == 1:
                return field(self)
            else:
                return field(self._lhs, self._rhs)
        return field

    def _get_child(self, expr):
        # unwrap analyzer-generated projection wrappers down to the real child
        while isinstance(expr, JoinProjectCollectionExpr):
            expr = expr.input
        return expr

    def _is_column(self, col, expr):
        # whether `col` is a direct column of `expr` (or its unwrapped child)
        return col.input is expr or col.input is self._get_child(expr)

    def _get_fields(self, fields, ret_raw_fields=False):
        """Expand a mixed list of fields (collections, columns, names,
        callables) into the concrete column expressions to select.

        :param fields: iterable of fields to expand
        :param ret_raw_fields: when True also return the un-expanded inputs
        """
        selects = []
        raw_selects = []
        for field in fields:
            field = self._defunc(field)
            if isinstance(field, CollectionExpr):
                if any(c is self for c in field.children()) or \
                        any(c is self._get_child(self._lhs) for c in field.children()) or \
                        any(c is self._get_child(self._rhs) for c in field.children()):
                    selects.extend(self._get_fields(field._project_fields))
                elif field is self:
                    # selecting the join itself means all of its columns
                    selects.extend(self._get_fields(self._schema.names))
                elif field is self._get_child(self._lhs):
                    # map left-side names through any suffix renaming
                    fields = [self._renamed_columns.get(n, [n])[0]
                              for n in field.schema.names]
                    selects.extend(self._get_fields(fields))
                elif field is self._get_child(self._rhs):
                    # map right-side names through any suffix renaming
                    fields = [self._renamed_columns.get(n, [None, n])[1]
                              for n in field.schema.names]
                    selects.extend(self._get_fields(fields))
                else:
                    selects.extend(self._get_fields(field._fetch_fields()))
                raw_selects.append(field)
            else:
                select = self._get_field(field)
                selects.append(select)
                raw_selects.append(select)
        if ret_raw_fields:
            return selects, raw_selects
        return selects

    def _get_field(self, field):
        """Resolve a single field against the joined collection.

        String names become Column/CallableColumn on the join; expression
        fields are traversed and any column rooted at either child is
        substituted in-place with the corresponding joined column.

        :raises ValueError: unknown column name
        :raises ExpressionError: sequence field unrelated to this join
        """
        field = self._defunc(field)
        if isinstance(field, six.string_types):
            if field not in self._schema:
                raise ValueError('Field(%s) does not exist' % field)
            cls = Column
            # names shadowed by a method need the callable column wrapper
            if callable(getattr(type(self), field, None)):
                cls = CallableColumn
            return cls(self, _name=field, _data_type=self._schema[field].type)
        root = field
        has_path = False
        for expr in root.traverse(top_down=True, unique=True,
                                  stop_cond=lambda x: isinstance(x, Column) or x is self):
            if isinstance(expr, Column):
                if self._is_column(expr, self):
                    has_path = True
                    continue
                # idx selects which entry of _renamed_columns applies:
                # 0 for the left side, 1 for the right side
                if self._is_column(expr, self._lhs):
                    has_path = True
                    idx = 0
                elif self._is_column(expr, self._rhs):
                    has_path = True
                    idx = 1
                elif expr.input._id == self._get_child(self._rhs)._id:
                    # In case that the right collection has been copied
                    expr.substitute(expr.input, self._get_child(self._rhs))
                    has_path = True
                    idx = 1
                elif isinstance(self._lhs, JoinCollectionExpr):
                    # recurse into a nested join on the left
                    try:
                        expr = self._lhs._get_field(expr)
                    except ExpressionError:
                        continue
                    has_path = True
                    idx = 0
                elif isinstance(self._rhs, JoinCollectionExpr):
                    # recurse into a nested join on the right
                    try:
                        expr = self._rhs._get_field(expr)
                    except ExpressionError:
                        continue
                    has_path = True
                    idx = 1
                else:
                    continue
                name = expr.source_name
                if name in self._renamed_columns:
                    name = self._renamed_columns[name][idx]
                to_sub = self._get_field(name)
                if expr.is_renamed():
                    to_sub = to_sub.rename(expr.name)
                # rewrite the original column node in place
                to_sub.copy_to(expr)
        if isinstance(field, SequenceExpr) and not has_path:
            raise ExpressionError('field must come from Join collection '
                                  'or its left and right child collection: %s'
                                  % repr_obj(field))
        return root

    def origin_collection(self, column_name):
        """Return ``(source collection, original name)`` for an output column."""
        idx, name = self._column_origins[column_name]
        return [self._lhs, self._rhs][idx], name

    @property
    def node_name(self):
        return self.__class__.__name__

    def iter_args(self):
        # labelled (name, arg) pairs, e.g. for repr / visiting
        for it in zip(['collection(left)', 'collection(right)', 'on'], self.args):
            yield it

    @property
    def column_conflict(self):
        # True when any column name clash was detected while building schema
        return self._column_conflict

    def accept(self, visitor):
        visitor.visit_join(self)

    def _get_non_suffixes_fields(self):
        # columns that keep their original names; overridden by InnerJoin
        return set()

    def _get_predicate_fields(self):
        """Return names of columns used in equality predicates where both
        sides refer to the same (non-renamed) column name."""
        predicate_fields = set()
        if not self._predicate:
            return predicate_fields
        for p in self._predicate:
            if isinstance(p, six.string_types):
                predicate_fields.add(p)
            elif isinstance(p, (tuple, Equal)):
                if isinstance(p, Equal):
                    p = p.lhs, p.rhs
                if isinstance(p, tuple) and len(p) != 2:
                    continue
                left_name = None
                if isinstance(p[0], six.string_types):
                    left_name = p[0]
                elif isinstance(p[0], Column) and not p[0].is_renamed():
                    left_name = p[0].name
                if left_name is None:
                    continue
                right_name = None
                if isinstance(p[1], six.string_types):
                    right_name = p[1]
                elif isinstance(p[1], Column) and not p[1].is_renamed():
                    right_name = p[1].name
                # only identically-named pairs qualify
                if left_name == right_name:
                    predicate_fields.add(left_name)
        return predicate_fields

    def _set_schema(self):
        """Build the joined schema, suffixing conflicting column names and
        recording origin/renaming bookkeeping."""
        names, typos = [], []
        non_suffixes_fields = self._get_non_suffixes_fields()
        for col in self._lhs.schema.columns:
            name = col.name
            # name comparison is case-insensitive via _name_indexes
            if to_lower_str(col.name) in self._rhs.schema._name_indexes:
                self._column_conflict = True
            if to_lower_str(col.name) in self._rhs.schema._name_indexes and \
                    col.name not in non_suffixes_fields:
                name = '%s%s' % (col.name, self._left_suffix)
                self._renamed_columns[col.name] = (name,)
            names.append(name)
            typos.append(col.type)
            self._column_origins[name] = 0, col.name
        for col in self._rhs.schema.columns:
            name = col.name
            if to_lower_str(col.name) in self._lhs.schema._name_indexes:
                self._column_conflict = True
            if to_lower_str(col.name) in self._lhs.schema._name_indexes and \
                    col.name not in non_suffixes_fields:
                name = '%s%s' % (col.name, self._right_suffix)
                # extend the left-side 1-tuple into (left_name, right_name)
                self._renamed_columns[col.name] = \
                    self._renamed_columns[col.name][0], name
            if name in non_suffixes_fields:
                # e.g. inner-join equi-predicate columns: keep only the left copy
                continue
            names.append(name)
            typos.append(col.type)
            self._column_origins[name] = 1, col.name
        # prefer a TableSchema subclass if either side has one
        if issubclass(type(self._lhs.schema), TableSchema):
            schema_type = type(self._lhs.schema)
        else:
            schema_type = type(self._rhs.schema)
        self._schema = schema_type.from_lists(names, typos)

    def _validate_equal(self, equal_expr):
        """Whether an Equal expression links the two sides of this join
        (in either direction)."""
        # FIXME: sometimes may be wrong, e.g. t3 = t1.join(t2, 'name')); t4.join(t3, t4.id == t3.id)
        return (equal_expr.lhs.is_ancestor(self._get_child(self._lhs)) and
                equal_expr.rhs.is_ancestor(self._get_child(self._rhs))) or \
               (equal_expr.lhs.is_ancestor(self._get_child(self._rhs)) and
                equal_expr.rhs.is_ancestor(self._get_child(self._lhs)))

    def _reverse_equal(self, equal_expr):
        """Swap the operands of an Equal so its left operand comes from the
        left collection."""
        if equal_expr.lhs.is_ancestor(self._get_child(self._rhs)) and \
                equal_expr.rhs.is_ancestor(self._get_child(self._lhs)):
            # the equal's left side is on the right collection and vise versa
            equal_expr._rhs, equal_expr._lhs = equal_expr._lhs, equal_expr._rhs

    def _validate_predicates(self, predicates):
        """Check join predicates and normalize them into boolean expressions.

        At least one valid predicate connecting the two sides is required
        (mapjoin is exempt).  ``self._predicate`` is replaced by the list of
        normalized boolean expressions, or None when empty.

        :raises ExpressionError: on an invalid predicate
        """
        if predicates is None:
            return
        is_validate = False
        subs = []
        if self._mapjoin:
            # mapjoin may legally have no equality predicate
            is_validate = True
        for p in predicates:
            if (isinstance(p, tuple) and len(p) == 2) or isinstance(p, six.string_types):
                # (left_name, right_name) pair, or one name shared by both sides
                if isinstance(p, six.string_types):
                    left_name, right_name = p, p
                else:
                    left_name, right_name = p
                left_col = self._lhs._get_field(left_name)
                right_col = self._rhs._get_field(right_name)
                if not left_col.is_ancestor(self._lhs) or not right_col.is_ancestor(self._rhs):
                    raise ExpressionError('Invalid predicate: {0!s}'.format(repr_obj(p)))
                subs.append(left_col == right_col)
                is_validate = True
            elif isinstance(p, BooleanSequenceExpr):
                if not is_validate:
                    # search for at least one Equal linking the two sides
                    it = (expr for expr in p.traverse(top_down=True, unique=True)
                          if isinstance(expr, Equal))
                    while not is_validate:
                        try:
                            equal_expr = next(it)
                            validate = self._validate_equal(equal_expr)
                            if validate:
                                is_validate = True
                                break
                        except StopIteration:
                            break
            else:
                if not is_validate:
                    raise ExpressionError('Invalid predicate: {0!s}'.format(repr_obj(p)))
        if not is_validate:
            raise ExpressionError('Invalid predicate: no validate predicate assigned')
        # boolean expressions are kept verbatim, after the generated equalities
        for p in predicates:
            if isinstance(p, BooleanSequenceExpr):
                subs.append(p)
        if len(subs) == 0:
            self._predicate = None
        else:
            self._predicate = subs

    def _merge_joined_fields(self, merge_columns):
        """Merge suffixed predicate columns back into single output columns.

        :param merge_columns: True/'auto'/'left'/'right' to merge all
            predicate columns with that policy, a column name, a list of
            names, or a ``{name: policy}`` dict.  'auto' prefers the left
            value and falls back to the right when the left is null.
        :return: a JoinFieldMergedCollectionExpr, or ``self`` when falsy
        :raises ValueError: no predicate fields, or an unknown column
        """
        if not merge_columns:
            return self
        predicate_fields = self._get_predicate_fields()
        if not predicate_fields:
            raise ValueError('No fields in predicate. Cannot merge columns.')
        src_map = self._column_origins
        rename_map = dict()
        # rename_map: source name -> [left output name, right output name]
        for name, src in six.iteritems(src_map):
            src_id, src_name = src
            if src_name not in predicate_fields:
                continue
            if src_name not in rename_map:
                rename_map[src_name] = [None, None]
            rename_map[src_name][src_id] = name
        # normalize merge_columns into a {source name: policy} dict
        if merge_columns in ('auto', 'left', 'right') or (isinstance(merge_columns, bool) and merge_columns):
            merge_columns = dict((k, merge_columns) for k in six.iterkeys(rename_map))
        if isinstance(merge_columns, six.string_types):
            merge_columns = {merge_columns: 'auto'}
        if isinstance(merge_columns, list):
            merge_columns = dict((k, 'auto') for k in merge_columns)
        excludes = set()
        for col, action in six.iteritems(merge_columns):
            if col not in rename_map:
                raise ValueError('Column {0} not exists in join predicate.'.format(col))
            if isinstance(action, bool) and action:
                merge_columns[col] = 'auto'
            else:
                merge_columns[col] = action.lower()
            excludes.update(rename_map[col])
        selects = []
        merged = set()
        for col in self.schema.names:
            if col not in excludes:
                selects.append(self[col])
            else:
                src_name = src_map[col][1]
                if src_name in merged:
                    # the pair was already emitted as one merged column
                    continue
                merged.add(src_name)
                left_name, right_name = rename_map[src_name]
                left_col = self[left_name]
                right_col = self[right_name]
                if merge_columns[src_name] == 'auto':
                    # prefer the left value, fall back to right when null
                    selects.append(left_col.isnull().ifelse(right_col, left_col).rename(src_name))
                elif merge_columns[src_name] == 'left':
                    selects.append(left_col.rename(src_name))
                elif merge_columns[src_name] == 'right':
                    selects.append(right_col.rename(src_name))
        selected = self.select(*selects)
        return JoinFieldMergedCollectionExpr(_input=self, _fields=selected._fields,
                                             _schema=selected._schema, _rename_map=rename_map)
class InnerJoin(JoinCollectionExpr):
    """Inner join.  Equality-predicate columns sharing the same name on
    both sides keep their original, non-suffixed names (the right-side
    copy is dropped from the schema)."""

    def _init(self, *args, **kwargs):
        # must be set before super()._init, which calls self._how.upper()
        self._how = 'INNER'
        super(InnerJoin, self)._init(*args, **kwargs)

    def _get_non_suffixes_fields(self):
        # same-name equi-predicate columns are emitted once, unsuffixed
        return self._get_predicate_fields()
class LeftJoin(JoinCollectionExpr):
    """Left outer join."""

    def _init(self, *args, **kwargs):
        # must be set before super()._init, which calls self._how.upper()
        self._how = 'LEFT OUTER'
        super(LeftJoin, self)._init(*args, **kwargs)
class RightJoin(JoinCollectionExpr):
    """Right outer join."""

    def _init(self, *args, **kwargs):
        # must be set before super()._init, which calls self._how.upper()
        self._how = 'RIGHT OUTER'
        super(RightJoin, self)._init(*args, **kwargs)
class OuterJoin(JoinCollectionExpr):
    """Full outer join."""

    def _init(self, *args, **kwargs):
        # must be set before super()._init, which calls self._how.upper()
        self._how = 'FULL OUTER'
        super(OuterJoin, self)._init(*args, **kwargs)
class JoinFieldMergedCollectionExpr(ProjectCollectionExpr):
    """Projection produced by ``JoinCollectionExpr._merge_joined_fields``.

    ``_rename_map`` maps each merged source column name to its
    ``[left output name, right output name]`` pair in the underlying join.
    Field resolution mirrors JoinCollectionExpr but lets merged columns be
    addressed by their original (unsuffixed) names.
    """
    __slots__ = '_rename_map',

    def _get_fields(self, fields, ret_raw_fields=False):
        """Expand a mixed list of fields against the merged projection.

        :param fields: iterable of fields to expand
        :param ret_raw_fields: when True also return the un-expanded inputs
        """
        selects = []
        raw_selects = []
        joined_expr = self._input
        for field in fields:
            field = self._defunc(field)
            if isinstance(field, CollectionExpr):
                if any(c is self for c in field.children()) or \
                        any(c is joined_expr._lhs for c in field.children()) or \
                        any(c is joined_expr._rhs for c in field.children()):
                    selects.extend(self._get_fields(field._project_fields))
                elif field is self:
                    selects.extend(self._get_fields(self._schema.names))
                elif field is joined_expr._lhs:
                    # merged columns keep their original names; others are
                    # mapped through the join's suffix renaming (left side)
                    fields = [joined_expr._renamed_columns.get(n, [n])[0]
                              if n not in self._rename_map else n
                              for n in field.schema.names]
                    selects.extend(self._get_fields(fields))
                elif field is joined_expr._rhs:
                    # same, taking the right side of the renaming
                    fields = [joined_expr._renamed_columns.get(n, [None, n])[1]
                              if n not in self._rename_map else n
                              for n in field.schema.names]
                    selects.extend(self._get_fields(fields))
                else:
                    selects.extend(super(JoinFieldMergedCollectionExpr, self)._get_fields(field))
                raw_selects.append(field)
            else:
                select = self._get_field(field)
                selects.append(select)
                raw_selects.append(select)
        if ret_raw_fields:
            return selects, raw_selects
        return selects

    def _get_field(self, field):
        """Resolve a single field against the merged projection.

        :raises ValueError: unknown column name
        :raises ExpressionError: sequence field unrelated to this collection
        """
        field = self._defunc(field)
        if isinstance(field, six.string_types):
            if field not in self._schema:
                raise ValueError('Field(%s) does not exist' % field)
            cls = Column
            # names shadowed by a method need the callable column wrapper
            if callable(getattr(type(self), field, None)):
                cls = CallableColumn
            return cls(self, _name=field, _data_type=self._schema[field].type)
        joined_expr = self._input
        root = field
        has_path = False
        for expr in root.traverse(top_down=True, unique=True,
                                  stop_cond=lambda x: isinstance(x, Column) or x is self):
            if isinstance(expr, Column):
                if expr.input is self or expr.input is joined_expr:
                    has_path = True
                    continue
                # idx selects the left (0) / right (1) entry of the renaming
                if expr.input is joined_expr._lhs:
                    has_path = True
                    idx = 0
                elif expr.input is joined_expr._rhs:
                    has_path = True
                    idx = 1
                elif isinstance(joined_expr._lhs, JoinCollectionExpr):
                    try:
                        expr = joined_expr._lhs._get_field(expr)
                    except ExpressionError:
                        continue
                    has_path = True
                    idx = 0
                elif isinstance(joined_expr._rhs, JoinCollectionExpr):
                    try:
                        expr = joined_expr._rhs._get_field(expr)
                    except ExpressionError:
                        continue
                    has_path = True
                    idx = 1
                else:
                    continue
                name = expr.source_name
                # merged columns are addressed by their original name
                if name not in self._rename_map and name in joined_expr._renamed_columns:
                    name = joined_expr._renamed_columns[name][idx]
                to_sub = self._get_field(name)
                if expr.is_renamed():
                    to_sub = to_sub.rename(expr.name)
                # rewrite the original column node in place
                to_sub.copy_to(expr)
        if isinstance(field, SequenceExpr) and not has_path:
            raise ExpressionError('field must come from Join collection '
                                  'or its left and right child collection: %s'
                                  % repr_obj(field))
        return root
class JoinProjectCollectionExpr(ProjectCollectionExpr):
    """
    Only for analyzer, project join should generate normal `ProjectCollectionExpr`.
    """
    # marker subclass: JoinCollectionExpr._get_child unwraps instances of
    # this class to reach the underlying collection
    __slots__ = ()
# Dispatch table from normalized `how` values to join expression classes.
# `join` falls back to the generic JoinCollectionExpr for unknown keys.
_join_dict = {
    'INNER': InnerJoin,
    'LEFT': LeftJoin,
    'RIGHT': RightJoin,
    'OUTER': OuterJoin
}
def _make_different_sources(left, right, predicate=None):
    """Ensure `left` and `right` share no expression nodes.

    Any node of `right`'s DAG that also appears in `left`'s tree is replaced
    by a copy, and its successors inside `right` are re-pointed at the copy.
    When the right root itself is shared, predicate expressions referencing
    it are re-pointed too (but only via successors outside `left`'s tree).

    :return: ``(left, possibly-copied right)``
    """
    # TODO: move to analyzer, do it before analyze and optimize
    exprs = ExprDictionary()
    for n in left.traverse(unique=True):
        exprs[n] = True
    subs = ExprDictionary()
    if getattr(right, '_proxy', None) is not None:
        right = right._proxy
    dag = right.to_dag(copy=False, validate=False)
    for n in dag.traverse():
        if n in exprs:
            # shared node: substitute a copy into right's successors
            copied = subs.get(n, n.copy())
            for p in dag.successors(n):
                if p in exprs and p not in subs:
                    subs[p] = p.copy()
                subs.get(p, p).substitute(n, copied)
            subs[n] = copied
            if predicate and n is right:
                # predicates still pointing at the original right root must
                # reference the copy instead
                for p in predicate:
                    if not isinstance(p, Expr):
                        continue
                    p_dag = p.to_dag(copy=False, validate=False)
                    for p_n in p_dag.traverse(top_down=True):
                        if p_n is right:
                            succs = list(s for s in p_dag.successors(p_n)
                                         if s not in exprs)
                            p_dag.substitute(right, copied, parents=succs)
    return left, subs.get(right, right)
def join(left, right, on=None, how='inner', suffixes=('_x', '_y'), mapjoin=False, skewjoin=False):
    """
    Join two collections.

    If `on` is not specified, we will find the common fields of the left and right collection.
    `suffixes` means that if column names conflict, the suffixes will be added automatically.
    For example, both left and right has a field named `col`,
    there will be col_x, and col_y in the joined collection.

    :param left: left collection
    :param right: right collection
    :param on: fields to join on
    :param how: 'inner', 'left', 'right', or 'outer'
    :param suffixes: when name conflict, the suffix will be added to both columns.
    :param mapjoin: set use mapjoin or not, default value False.
    :param skewjoin: set use of skewjoin or not, default value False. Can specify True if
        the collection is skew, or a list specifying columns with skew values, or a list of
        dicts specifying skew combinations.
    :return: collection
    :raises TypeError: when both mapjoin and skewjoin are requested, or
        `skewjoin` has an unsupported type or mixed element types
    :raises ValueError: bad suffixes, or skew columns missing from the right side

    :Example:

    >>> df.dtypes.names
    ['name', 'id']
    >>> df2.dtypes.names
    ['name', 'id1']
    >>> df.join(df2)
    >>> df.join(df2, on='name')
    >>> df.join(df2, on=('id', 'id1'))
    >>> df.join(df2, on=['name', ('id', 'id1')])
    >>> df.join(df2, on=[df.name == df2.name, df.id == df2.id1])
    >>> df.join(df2, mapjoin=False)
    >>> df.join(df2, skewjoin=True)
    >>> df.join(df2, skewjoin=["c0", "c1"])
    >>> df.join(df2, skewjoin=[{"c0": 1, "c1": "2"}, {"c0": 3, "c1": "4"}])
    """
    if mapjoin and skewjoin:
        raise TypeError("Cannot specify mapjoin and skewjoin at the same time")
    if isinstance(left, TypedExpr):
        left = to_collection(left)
    if isinstance(right, TypedExpr):
        right = to_collection(right)
    if on is None and not mapjoin:
        # default to an equi-join on all (case-insensitively) common names
        on = [name for name in left.schema.names
              if to_lower_str(name) in right.schema._name_indexes]
        if not on and len(left.schema) == 1 and len(right.schema) == 1:
            # no common names, but both sides have exactly one column
            on = [(left.schema.names[0], right.schema.names[0])]
    skewjoin_values = None
    if isinstance(skewjoin, (dict, six.string_types)):
        skewjoin = [skewjoin]
    if isinstance(skewjoin, list):
        if all(isinstance(c, six.string_types) for c in skewjoin):
            # a (possibly empty) list of skewed column names
            if any(c not in right.schema.names for c in skewjoin):
                raise ValueError(
                    "All columns specified in `skewjoin` need to exist in the right collection"
                )
        elif skewjoin and all(isinstance(c, dict) for c in skewjoin):
            # a list of {column: value} dicts describing skew combinations.
            # FIX: the `skewjoin and` guard prevents the IndexError that an
            # empty list previously hit at `skewjoin[0]`.
            cols = sorted(skewjoin[0].keys())
            cols_set = set(cols)
            if any(c not in right.schema.names for c in cols):
                raise ValueError(
                    "All columns specified in `skewjoin` need to exist in the right collection"
                )
            if any(cols_set != set(c.keys()) for c in skewjoin):
                raise ValueError("All values in `skewjoin` need to have same columns")
            skewjoin_values = [[d[c] for c in cols] for d in skewjoin]
            skewjoin = cols
        else:
            # FIX: a non-empty list mixing strings and dicts previously fell
            # through unvalidated; reject it explicitly
            raise TypeError("`skewjoin` list should contain either all strings or all dicts")
    elif skewjoin and not isinstance(skewjoin, bool):
        raise TypeError("Cannot accept skewjoin type %s" % type(skewjoin))
    if isinstance(suffixes, (tuple, list)) and len(suffixes) == 2:
        left_suffix, right_suffix = suffixes
    else:
        raise ValueError('suffixes must be a tuple or list with two elements, got %s' % suffixes)
    if not isinstance(on, list):
        on = [on, ]
    for i in range(len(on)):
        it = on[i]
        if inspect.isfunction(it):
            # callable predicates receive both sides and build the condition
            on[i] = it(left, right)
    # make both sides distinct expression trees (handles self-joins)
    left, right = _make_different_sources(left, right, on)
    try:
        return _join_dict[how.upper()](
            _lhs=left, _rhs=right, _predicate=on, _left_suffix=left_suffix, _right_suffix=right_suffix,
            _mapjoin=mapjoin, _skewjoin=skewjoin, _skewjoin_values=skewjoin_values
        )
    except KeyError:
        # unknown `how`: fall back to the generic join keeping the raw string
        return JoinCollectionExpr(
            _lhs=left, _rhs=right, _predicate=on, _how=how, _left_suffix=left_suffix,
            _right_suffix=right_suffix, _mapjoin=mapjoin, _skewjoin=skewjoin,
            _skewjoin_values=skewjoin_values
        )
def inner_join(left, right, on=None, suffixes=('_x', '_y'), mapjoin=False, skewjoin=False):
    """
    Inner join two collections.

    If `on` is not specified, we will find the common fields of the left and right collection.
    `suffixes` means that if column names conflict, the suffixes will be added automatically.
    For example, both left and right has a field named `col`,
    there will be col_x, and col_y in the joined collection.

    :param left: left collection
    :param right: right collection
    :param on: fields to join on
    :param suffixes: when name conflict, the suffix will be added to both columns.
    :return: collection

    :Example:

    >>> df.dtypes.names
    ['name', 'id']
    >>> df2.dtypes.names
    ['name', 'id1']
    >>> df.inner_join(df2)
    >>> df.inner_join(df2, on='name')
    >>> df.inner_join(df2, on=('id', 'id1'))
    >>> df.inner_join(df2, on=['name', ('id', 'id1')])
    >>> df.inner_join(df2, on=[df.name == df2.name, df.id == df2.id1])
    """
    # delegate to the generic join with the join type spelled out explicitly
    return join(left, right, on=on, how='inner', suffixes=suffixes,
                mapjoin=mapjoin, skewjoin=skewjoin)
def left_join(
    left, right, on=None, suffixes=('_x', '_y'), mapjoin=False, merge_columns=None, skewjoin=False
):
    """
    Left join two collections.

    If `on` is not specified, we will find the common fields of the left and right collection.
    `suffixes` means that if column names conflict, the suffixes will be added automatically.
    For example, both left and right has a field named `col`,
    there will be col_x, and col_y in the joined collection.

    :param left: left collection
    :param right: right collection
    :param on: fields to join on
    :param suffixes: when name conflict, the suffix will be added to both columns.
    :param mapjoin: set use mapjoin or not, default value False.
    :param merge_columns: whether to merge columns with the same name into one column without suffix.
        If the value is True, columns in the predicate with same names will be merged,
        with non-null value. If the value is 'left' or 'right', the values of predicates
        on the left / right collection will be taken. You can also pass a dictionary to
        describe the behavior of each column, such as { 'a': 'auto', 'b': 'left' }.
    :return: collection

    :Example:

    >>> df.dtypes.names
    ['name', 'id']
    >>> df2.dtypes.names
    ['name', 'id1']
    >>> df.left_join(df2)
    >>> df.left_join(df2, on='name')
    >>> df.left_join(df2, on=('id', 'id1'))
    >>> df.left_join(df2, on=['name', ('id', 'id1')])
    >>> df.left_join(df2, on=[df.name == df2.name, df.id == df2.id1])
    """
    # build the join, then optionally collapse suffixed predicate columns
    return join(
        left, right, on=on, how='left', suffixes=suffixes,
        mapjoin=mapjoin, skewjoin=skewjoin,
    )._merge_joined_fields(merge_columns)
def right_join(
    left, right, on=None, suffixes=('_x', '_y'), mapjoin=False, merge_columns=None, skewjoin=False
):
    """
    Right join two collections.

    If `on` is not specified, we will find the common fields of the left and right collection.
    `suffixes` means that if column names conflict, the suffixes will be added automatically.
    For example, both left and right has a field named `col`,
    there will be col_x, and col_y in the joined collection.

    :param left: left collection
    :param right: right collection
    :param on: fields to join on
    :param suffixes: when name conflict, the suffix will be added to both columns.
    :param mapjoin: set use mapjoin or not, default value False.
    :param merge_columns: whether to merge columns with the same name into one column without suffix.
        If the value is True, columns in the predicate with same names will be merged,
        with non-null value. If the value is 'left' or 'right', the values of predicates
        on the left / right collection will be taken. You can also pass a dictionary to
        describe the behavior of each column, such as { 'a': 'auto', 'b': 'left' }.
    :return: collection

    :Example:

    >>> df.dtypes.names
    ['name', 'id']
    >>> df2.dtypes.names
    ['name', 'id1']
    >>> df.right_join(df2)
    >>> df.right_join(df2, on='name')
    >>> df.right_join(df2, on=('id', 'id1'))
    >>> df.right_join(df2, on=['name', ('id', 'id1')])
    >>> df.right_join(df2, on=[df.name == df2.name, df.id == df2.id1])
    """
    # build the join, then optionally collapse suffixed predicate columns
    return join(
        left, right, on=on, how='right', suffixes=suffixes,
        mapjoin=mapjoin, skewjoin=skewjoin,
    )._merge_joined_fields(merge_columns)
def outer_join(
    left, right, on=None, suffixes=('_x', '_y'), mapjoin=False, merge_columns=None, skewjoin=False
):
    """
    Outer join two collections.

    If `on` is not specified, we will find the common fields of the left and right collection.
    `suffixes` means that if column names conflict, the suffixes will be added automatically.
    For example, both left and right has a field named `col`,
    there will be col_x, and col_y in the joined collection.

    :param left: left collection
    :param right: right collection
    :param on: fields to join on
    :param suffixes: when name conflict, the suffix will be added to both columns.
    :param mapjoin: set use mapjoin or not, default value False.
    :param merge_columns: whether to merge columns with the same name into one column without suffix.
        If the value is True, columns in the predicate with same names will be merged,
        with non-null value. If the value is 'left' or 'right', the values of predicates
        on the left / right collection will be taken. You can also pass a dictionary to
        describe the behavior of each column, such as { 'a': 'auto', 'b': 'left' }.
    :return: collection

    :Example:

    >>> df.dtypes.names
    ['name', 'id']
    >>> df2.dtypes.names
    ['name', 'id1']
    >>> df.outer_join(df2)
    >>> df.outer_join(df2, on='name')
    >>> df.outer_join(df2, on=('id', 'id1'))
    >>> df.outer_join(df2, on=['name', ('id', 'id1')])
    >>> df.outer_join(df2, on=[df.name == df2.name, df.id == df2.id1])
    """
    # build the join, then optionally collapse suffixed predicate columns
    return join(
        left, right, on=on, how='outer', suffixes=suffixes,
        mapjoin=mapjoin, skewjoin=skewjoin,
    )._merge_joined_fields(merge_columns)
# Expose the join helpers as methods on collection and typed expressions,
# e.g. df.join(df2), df.left_join(df2, on='name').
CollectionExpr.join = join
CollectionExpr.inner_join = inner_join
CollectionExpr.left_join = left_join
CollectionExpr.right_join = right_join
CollectionExpr.outer_join = outer_join
TypedExpr.join = join
TypedExpr.inner_join = inner_join
TypedExpr.left_join = left_join
TypedExpr.right_join = right_join
TypedExpr.outer_join = outer_join
def _get_sequence_source_collection(expr):
    """Return the first CollectionExpr met when traversing `expr` top-down."""
    collections = (node for node in expr.traverse(top_down=True, unique=True)
                   if isinstance(node, CollectionExpr))
    return next(collections)
class UnionCollectionExpr(CollectionExpr):
    """Collection expression representing the union of two collections.

    Both sides must expose the same columns; the right side is reprojected
    to the left side's column order when necessary.  ``_distinct`` records
    whether duplicate rows should be eliminated.
    """
    __slots__ = '_distinct',
    _args = '_lhs', '_rhs',
    node_name = 'Union'

    def _init(self, *args, **kwargs):
        super(UnionCollectionExpr, self)._init(*args, **kwargs)
        self._validate()
        self._schema = self._clean_schema()

    def _validate_collection_child(self):
        """Check schema compatibility of the two sides, reordering the right
        side's columns when only the ordering differs."""
        # identical names and types in identical order: directly compatible
        if self._lhs.schema.names == self._rhs.schema.names and self._lhs.schema.types == self._rhs.schema.types:
            return True
        elif set(self._lhs.schema.names) == set(self._rhs.schema.names) and set(self._lhs.schema.types) == set(
                self._rhs.schema.types):
            # same columns in a different order: reproject the right side to
            # the left ordering, then strictly re-compare the type lists
            self._rhs = self._rhs[self._lhs.schema.names]
            return self._lhs.schema.types == self._rhs.schema.types
        else:
            return False

    def _validate(self):
        """Promote sequences to single-column collections, then verify both
        sides are schema-compatible collections.

        :raises ExpressionError: incompatible schemas or non-collection input
        """
        if isinstance(self._lhs, SequenceExpr):
            source_collection = _get_sequence_source_collection(self._lhs)
            self._lhs = source_collection[[self._lhs]]
        if isinstance(self._rhs, SequenceExpr):
            source_collection = _get_sequence_source_collection(self._rhs)
            self._rhs = source_collection[[self._rhs]]
        if isinstance(self._lhs, CollectionExpr) and isinstance(self._rhs, CollectionExpr):
            if not self._validate_collection_child():
                raise ExpressionError('Table schemas must be equal to form union')
        else:
            raise ExpressionError('Both inputs should be collections or sequences.')

    def _clean_schema(self):
        # the union adopts the left side's schema
        return TableSchema.from_lists(self._lhs.schema.names, self._lhs.schema.types)

    def accept(self, visitor):
        return visitor.visit_union(self)

    def iter_args(self):
        # labelled (name, arg) pairs, e.g. for repr / visiting
        for it in zip(['collection(left)', 'collection(right)'], self.args):
            yield it
class ConcatCollectionExpr(CollectionExpr):
    """Collection expression concatenating two collections horizontally
    (column-wise); the output schema is the left columns followed by the
    right columns."""
    _args = '_lhs', '_rhs',
    node_name = 'Concat'

    def _init(self, *args, **kwargs):
        super(ConcatCollectionExpr, self)._init(*args, **kwargs)
        self._schema = self._clean_schema()

    def iter_args(self):
        # labelled (name, arg) pairs, e.g. for repr / visiting
        for it in zip(['collection(left)', 'collection(right)'], self.args):
            yield it

    def _clean_schema(self):
        # concatenated schema: left columns first, then right columns
        return TableSchema.from_lists(
            self._lhs.schema.names + self._rhs.schema.names,
            self._lhs.schema.types + self._rhs.schema.types,
        )

    @staticmethod
    def _get_sequence_source_collection(expr):
        # first collection met when traversing the expression top-down
        return next(it for it in expr.traverse(top_down=True, unique=True) if isinstance(it, CollectionExpr))

    @classmethod
    def validate_input(cls, *inputs):
        """Check that all inputs are collections/sequences and that their
        column names are pairwise disjoint.

        :raises ExpressionError: non-collection input or colliding names
        """
        new_inputs = []
        for i in inputs:
            if isinstance(i, SequenceExpr):
                source_collection = cls._get_sequence_source_collection(i)
                new_inputs.append(source_collection[[i]])
            else:
                new_inputs.append(i)
        if any(not isinstance(i, CollectionExpr) for i in new_inputs):
            raise ExpressionError('Inputs should be collections or sequences.')
        # NOTE(review): the collision check below iterates `inputs`, not the
        # converted `new_inputs` — verify sequence inputs expose `.schema`
        unioned = reduce(lambda a, b: a | b, (set(i.schema.names) for i in inputs))
        total_fields = sum((len(i.schema.names) for i in inputs))
        if total_fields != len(unioned):
            raise ExpressionError('Column names in inputs should not collides with each other.')

    def accept(self, visitor):
        return visitor.visit_concat(self)
def union(left, right, distinct=False):
    """
    Union two collections.

    :param left: left collection
    :param right: right collection
    :param distinct:
    :return: collection

    :Example:

    >>> df['name', 'id'].union(df2['id', 'name'])
    """
    # decouple shared expression nodes before combining the two sides
    lhs, rhs = _make_different_sources(left, right)
    return UnionCollectionExpr(_lhs=lhs, _rhs=rhs, _distinct=distinct)
def __horz_concat(left, rights):
    """Concatenate `rights` onto `left` column-wise, one collection at a
    time, building a chain of ConcatCollectionExpr nodes."""
    from ..utils import to_collection

    result = to_collection(left)
    for item in rights:
        coll = to_collection(item)
        # decouple shared expression nodes before concatenating
        result, coll = _make_different_sources(result, coll)
        result = ConcatCollectionExpr(_lhs=result, _rhs=coll)
    return result
def concat(left, rights, distinct=False, axis=0):
    """
    Concat collections.

    :param left: left collection
    :param rights: right collections, can be a DataFrame object or a list of DataFrames
    :param distinct: whether to remove duplicate entries. only available when axis == 0
    :param axis: when axis == 0, the DataFrames are merged vertically, otherwise horizontally.
    :return: collection

    Note that axis==1 can only be used under Pandas DataFrames or XFlow.

    :Example:

    >>> df['name', 'id'].concat(df2['score'], axis=1)
    """
    from ..utils import to_collection

    if isinstance(rights, Node):
        rights = [rights]
    if not rights:
        raise ValueError('At least one DataFrame should be provided.')
    if axis == 0:
        # vertical: fold every right-hand collection into a union chain
        result = left
        for r in rights:
            result = union(result, r, distinct=distinct)
        return result
    # horizontal: validate disjoint columns, then concatenate column-wise
    rights = [to_collection(r) for r in rights]
    ConcatCollectionExpr.validate_input(left, *rights)
    if hasattr(left, '_xflow_concat'):
        # XFlow-backed collections provide their own concat implementation
        return left._xflow_concat(rights)
    return __horz_concat(left, rights)
def _drop(expr, data, axis=0, columns=None):
    """
    Drop data from a DataFrame.

    :param expr: collection to drop data from
    :param data: data to be removed
    :param axis: 0 for deleting rows, 1 for columns.
    :param columns: columns of data to select, only useful when axis == 0
    :return: collection
    :raises ExpressionError: when axis == 0 and `data` is not a collection
        or sequence

    :Example:

    >>> import pandas as pd
    >>> df1 = DataFrame(pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]}))
    >>> df2 = DataFrame(pd.DataFrame({'a': [2, 3], 'b': [5, 7]}))
    >>> df1.drop(df2)
       a  b  c
    0  1  4  7
    1  3  6  9
    >>> df1.drop(df2, columns='a')
       a  b  c
    0  1  4  7
    >>> df1.drop(['a'], axis=1)
       b  c
    0  4  7
    1  5  8
    2  6  9
    >>> df1.drop(df2, axis=1)
       c
    0  7
    1  8
    2  9
    """
    from ..utils import to_collection
    expr = to_collection(expr)
    if axis == 0:
        if not isinstance(data, (CollectionExpr, SequenceExpr)):
            # FIX: message previously said "axis == 1" in the axis == 0 branch
            raise ExpressionError('data should be a collection or sequence when axis == 0.')
        data = to_collection(data)
        if columns is None:
            columns = [n for n in data.schema.names]
        if isinstance(columns, six.string_types):
            columns = [columns, ]
        # keep only the drop keys, de-duplicated
        data = data.select(*columns).distinct()
        # anti-join: unmatched left rows have all-null right-side columns
        drop_predicates = [data[n].isnull() for n in data.schema.names]
        return expr.left_join(data, on=columns, suffixes=('', '_dp')).filter(*drop_predicates) \
            .select(*expr.schema.names)
    else:
        # axis == 1: drop columns by name (or by another frame's column names)
        if isinstance(data, (CollectionExpr, SequenceExpr)):
            data = to_collection(data).schema.names
        return expr.exclude(data)
def setdiff(left, *rights, **kwargs):
    """
    Exclude data from a collection, like `except` clause in SQL. All collections involved should
    have same schema.

    :param left: collection to drop data from
    :param rights: collection or list of collections
    :param distinct: whether to preserve duplicate entries
    :return: collection

    :Examples:

    >>> import pandas as pd
    >>> df1 = DataFrame(pd.DataFrame({'a': [1, 2, 3, 3, 3], 'b': [1, 2, 3, 3, 3]}))
    >>> df2 = DataFrame(pd.DataFrame({'a': [1, 3], 'b': [1, 3]}))
    >>> df1.setdiff(df2)
       a  b
    0  2  2
    1  3  3
    2  3  3
    >>> df1.setdiff(df2, distinct=True)
       a  b
    0  2  2
    """
    import time
    from ..utils import output
    distinct = kwargs.get('distinct', False)
    if isinstance(rights[0], list):
        rights = rights[0]
    cols = [n for n in left.schema.names]
    types = [n for n in left.schema.types]
    # tag left rows with +1 and right rows with -1, then union everything;
    # the timestamp keeps the helper column name from clashing with data
    counter_col_name = 'exc_counter_%d' % int(time.time())
    left = left[left, Scalar(1).rename(counter_col_name)]
    rights = [r[r, Scalar(-1).rename(counter_col_name)] for r in rights]
    unioned = left
    for r in rights:
        unioned = unioned.union(r)
    if distinct:
        # min == 1 means the row never appeared on any right side
        aggregated = unioned.groupby(*cols).agg(**{counter_col_name: unioned[counter_col_name].min()})
        return aggregated.filter(aggregated[counter_col_name] == 1).select(*cols)
    else:
        # sum > 0: left occurrences outnumber right ones; re-emit the surplus
        aggregated = unioned.groupby(*cols).agg(**{counter_col_name: unioned[counter_col_name].sum()})

        @output(cols, types)
        def exploder(row):
            import sys
            # py2/py3 compatible integer range
            irange = xrange if sys.version_info[0] < 3 else range
            # emit the row as many times as the (positive) counter says,
            # dropping the trailing counter column
            for _ in irange(getattr(row, counter_col_name)):
                yield row[:-1]
        return aggregated.map_reduce(mapper=exploder).select(*cols)
def intersect(left, *rights, **kwargs):
    """
    Calc intersection among datasets,

    :param left: collection
    :param rights: collection or list of collections
    :param distinct: whether to preserve duolicate entries
    :return: collection

    :Examples:

    >>> import pandas as pd
    >>> df1 = DataFrame(pd.DataFrame({'a': [1, 2, 3, 3, 3], 'b': [1, 2, 3, 3, 3]}))
    >>> df2 = DataFrame(pd.DataFrame({'a': [1, 3, 3], 'b': [1, 3, 3]}))
    >>> df1.intersect(df2)
       a  b
    0  1  1
    1  3  3
    2  3  3
    >>> df1.intersect(df2, distinct=True)
       a  b
    0  1  1
    1  3  3
    """
    import time
    from ..utils import output
    distinct = kwargs.get('distinct', False)
    if isinstance(rights[0], list):
        rights = rights[0]
    cols = [n for n in left.schema.names]
    types = [n for n in left.schema.types]
    collections = (left, ) + rights
    # tag every row with its source collection index; timestamps keep the
    # helper column names from clashing with data columns
    idx_col_name = 'idx_%d' % int(time.time())
    counter_col_name = 'exc_counter_%d' % int(time.time())
    collections = [c[c, Scalar(idx).rename(idx_col_name)] for idx, c in enumerate(collections)]
    unioned = reduce(lambda a, b: a.union(b), collections)
    # per (row, source) occurrence counts
    src_agg = unioned.groupby(*(cols + [idx_col_name])) \
        .agg(**{counter_col_name: unioned.count()})
    aggregators = {
        idx_col_name: src_agg[idx_col_name].nunique(),
        counter_col_name: src_agg[counter_col_name].min(),
    }
    final_agg = src_agg.groupby(*cols).agg(**aggregators)
    # keep rows present in every input collection
    final_agg = final_agg.filter(final_agg[idx_col_name] == len(collections))
    if distinct:
        return final_agg.filter(final_agg[counter_col_name] > 0).select(*cols)
    else:
        @output(cols, types)
        def exploder(row):
            import sys
            # py2/py3 compatible integer range
            irange = xrange if sys.version_info[0] < 3 else range
            # emit the row min-count times, dropping the two helper columns
            for _ in irange(getattr(row, counter_col_name)):
                yield row[:-2]
        return final_agg.map_reduce(mapper=exploder).select(*cols)
# Expose union / concat / drop / set-operation helpers as methods on both
# collection and sequence expressions.
CollectionExpr.union = union
SequenceExpr.union = union
CollectionExpr.__horz_concat = __horz_concat
SequenceExpr.__horz_concat = __horz_concat
CollectionExpr.concat = concat
SequenceExpr.concat = concat
CollectionExpr.drop = _drop
SequenceExpr.drop = _drop
CollectionExpr.setdiff = setdiff
SequenceExpr.setdiff = setdiff
# `except` is a Python keyword, hence the trailing underscore
CollectionExpr.except_ = setdiff
SequenceExpr.except_ = setdiff
CollectionExpr.intersect = intersect
SequenceExpr.intersect = intersect