odps/df/expr/merge.py (740 lines of code) (raw):

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Join, union, concat and set-operation expressions for DataFrame collections."""

import inspect

from ...compat import six, reduce
from ...models import TableSchema
from ...utils import to_lower_str
from ..utils import to_collection
from .arithmetic import Equal
from .core import Node, ExprDictionary
from .expressions import CollectionExpr, TypedExpr, ProjectCollectionExpr, \
    Column, BooleanSequenceExpr, SequenceExpr, Expr, Scalar, repr_obj, CallableColumn
from .errors import ExpressionError


class JoinCollectionExpr(CollectionExpr):
    """
    Expression node joining two collections (`_lhs` and `_rhs`) on `_predicate`.

    Column names present on both sides are suffixed with `_left_suffix` /
    `_right_suffix` unless the concrete join type marks them as non-suffixed
    (see `_get_non_suffixes_fields`).
    """
    __slots__ = '_how', '_left_suffix', '_right_suffix', '_column_origins', \
                '_renamed_columns', '_column_conflict', '_mapjoin', '_skewjoin', \
                '_skewjoin_values'
    _args = '_lhs', '_rhs', '_predicate'

    def _init(self, *args, **kwargs):
        self._init_attr('_left_suffix', None)
        self._init_attr('_right_suffix', None)
        # joined column name -> (0 for left / 1 for right, source column name)
        self._init_attr('_column_origins', dict())
        # source column name -> (left renamed name, [right renamed name])
        self._init_attr('_renamed_columns', dict())
        self._init_attr('_column_conflict', False)

        super(JoinCollectionExpr, self)._init(*args, **kwargs)
        if not isinstance(self._lhs, CollectionExpr):
            raise TypeError(
                'Can only join collection expressions, got %s for left expr.' % type(self._lhs)
            )
        if not isinstance(self._rhs, CollectionExpr):
            raise TypeError(
                'Can only join collection expressions, got %s for right expr.' % type(self._rhs)
            )

        # a self-join (or joining a side of a previous join again) needs a
        # distinct view so the two inputs can be told apart
        if self._rhs is self._lhs:
            self._rhs = self._rhs.view()
        if isinstance(self._lhs, JoinCollectionExpr) and \
                (self._rhs is self._lhs._lhs or self._rhs is self._lhs._rhs):
            self._rhs = self._rhs.view()

        if self._left_suffix is None and self._right_suffix is None:
            overlaps = set(self._lhs.schema.names).intersection(self._rhs.schema.names)
            if len(overlaps) > 0:
                raise ValueError(
                    'Column conflict exists in join, overlap columns: %s' % ','.join(overlaps)
                )

        self._set_schema()
        self._validate_predicates(self._predicate)
        self._how = self._how.upper()

    def _defunc(self, field):
        """
        Resolve a function-valued field: a one-argument function is called with
        the joined collection, any other function with (left, right).
        """
        if inspect.isfunction(field):
            if six.get_function_code(field).co_argcount == 1:
                return field(self)
            else:
                return field(self._lhs, self._rhs)
        return field

    def _get_child(self, expr):
        """Strip any `JoinProjectCollectionExpr` wrappers around `expr`."""
        while isinstance(expr, JoinProjectCollectionExpr):
            expr = expr.input

        return expr

    def _is_column(self, col, expr):
        return col.input is expr or col.input is self._get_child(expr)

    def _get_fields(self, fields, ret_raw_fields=False):
        """
        Expand field specs (names, expressions, collections) into columns of
        this join.

        :param fields: iterable of field specs
        :param ret_raw_fields: if True, also return the unexpanded specs
        :return: list of selected columns, or (selects, raw_selects)
        """
        selects = []
        raw_selects = []

        for field in fields:
            field = self._defunc(field)
            if isinstance(field, CollectionExpr):
                if any(c is self for c in field.children()) or \
                        any(c is self._get_child(self._lhs) for c in field.children()) or \
                        any(c is self._get_child(self._rhs) for c in field.children()):
                    selects.extend(self._get_fields(field._project_fields))
                elif field is self:
                    selects.extend(self._get_fields(self._schema.names))
                elif field is self._get_child(self._lhs):
                    # map left source names to their (possibly suffixed) joined names
                    names = [self._renamed_columns.get(n, [n])[0]
                             for n in field.schema.names]
                    selects.extend(self._get_fields(names))
                elif field is self._get_child(self._rhs):
                    # map right source names to their (possibly suffixed) joined names
                    names = [self._renamed_columns.get(n, [None, n])[1]
                             for n in field.schema.names]
                    selects.extend(self._get_fields(names))
                else:
                    selects.extend(self._get_fields(field._fetch_fields()))
                raw_selects.append(field)
            else:
                select = self._get_field(field)
                selects.append(select)
                raw_selects.append(select)

        if ret_raw_fields:
            return selects, raw_selects
        return selects

    def _get_field(self, field):
        """
        Bind `field` (a name or an expression over the join's inputs) to this join.

        Columns referencing the left or right input are rewritten in place to
        reference the corresponding joined column.

        :raises ValueError: if a string field does not exist in the joined schema
        :raises ExpressionError: if a sequence expression has no path to this join
        """
        field = self._defunc(field)

        if isinstance(field, six.string_types):
            if field not in self._schema:
                raise ValueError('Field(%s) does not exist' % field)
            cls = Column
            if callable(getattr(type(self), field, None)):
                cls = CallableColumn
            return cls(self, _name=field, _data_type=self._schema[field].type)

        root = field

        has_path = False

        for expr in root.traverse(top_down=True, unique=True,
                                  stop_cond=lambda x: isinstance(x, Column) or x is self):
            if isinstance(expr, Column):
                if self._is_column(expr, self):
                    has_path = True
                    continue
                if self._is_column(expr, self._lhs):
                    has_path = True
                    idx = 0
                elif self._is_column(expr, self._rhs):
                    has_path = True
                    idx = 1
                elif expr.input._id == self._get_child(self._rhs)._id:
                    # In case that the right collection has been copied
                    expr.substitute(expr.input, self._get_child(self._rhs))
                    has_path = True
                    idx = 1
                elif isinstance(self._lhs, JoinCollectionExpr):
                    try:
                        expr = self._lhs._get_field(expr)
                    except ExpressionError:
                        continue
                    has_path = True
                    idx = 0
                elif isinstance(self._rhs, JoinCollectionExpr):
                    try:
                        expr = self._rhs._get_field(expr)
                    except ExpressionError:
                        continue
                    has_path = True
                    idx = 1
                else:
                    continue

                name = expr.source_name
                if name in self._renamed_columns:
                    name = self._renamed_columns[name][idx]
                to_sub = self._get_field(name)
                if expr.is_renamed():
                    to_sub = to_sub.rename(expr.name)
                to_sub.copy_to(expr)

        if isinstance(field, SequenceExpr) and not has_path:
            raise ExpressionError('field must come from Join collection '
                                  'or its left and right child collection: %s'
                                  % repr_obj(field))
        return root

    def origin_collection(self, column_name):
        """Return (source collection, source column name) for a joined column."""
        idx, name = self._column_origins[column_name]
        return [self._lhs, self._rhs][idx], name

    @property
    def node_name(self):
        return self.__class__.__name__

    def iter_args(self):
        for it in zip(['collection(left)', 'collection(right)', 'on'], self.args):
            yield it

    @property
    def column_conflict(self):
        """Whether any column name exists on both sides (case-insensitively)."""
        return self._column_conflict

    def accept(self, visitor):
        visitor.visit_join(self)

    def _get_non_suffixes_fields(self):
        """Names that should NOT get suffixes on conflict; none by default."""
        return set()

    def _get_predicate_fields(self):
        """
        Collect names that are joined to a same-named column on the other side,
        e.g. `'name'`, `('id', 'id')` or `left.id == right.id` (not renamed).
        """
        predicate_fields = set()
        if not self._predicate:
            return predicate_fields

        for p in self._predicate:
            if isinstance(p, six.string_types):
                predicate_fields.add(p)
            elif isinstance(p, (tuple, Equal)):
                if isinstance(p, Equal):
                    p = p.lhs, p.rhs
                if isinstance(p, tuple) and len(p) != 2:
                    continue

                left_name = None
                if isinstance(p[0], six.string_types):
                    left_name = p[0]
                elif isinstance(p[0], Column) and not p[0].is_renamed():
                    left_name = p[0].name
                if left_name is None:
                    continue

                right_name = None
                if isinstance(p[1], six.string_types):
                    right_name = p[1]
                elif isinstance(p[1], Column) and not p[1].is_renamed():
                    right_name = p[1].name

                if left_name == right_name:
                    predicate_fields.add(left_name)
        return predicate_fields

    def _set_schema(self):
        """
        Build the joined schema, suffixing conflicting names and recording
        `_column_origins` and `_renamed_columns`.
        """
        names, typos = [], []

        non_suffixes_fields = self._get_non_suffixes_fields()

        for col in self._lhs.schema.columns:
            name = col.name
            if to_lower_str(col.name) in self._rhs.schema._name_indexes:
                self._column_conflict = True
            if to_lower_str(col.name) in self._rhs.schema._name_indexes and \
                    col.name not in non_suffixes_fields:
                name = '%s%s' % (col.name, self._left_suffix)
                self._renamed_columns[col.name] = (name,)
            names.append(name)
            typos.append(col.type)

            self._column_origins[name] = 0, col.name
        for col in self._rhs.schema.columns:
            name = col.name
            if to_lower_str(col.name) in self._lhs.schema._name_indexes:
                self._column_conflict = True
            if to_lower_str(col.name) in self._lhs.schema._name_indexes and \
                    col.name not in non_suffixes_fields:
                name = '%s%s' % (col.name, self._right_suffix)
                self._renamed_columns[col.name] = \
                    self._renamed_columns[col.name][0], name
            # a non-suffixed shared column is only kept once (from the left side)
            if name in non_suffixes_fields:
                continue
            names.append(name)
            typos.append(col.type)

            self._column_origins[name] = 1, col.name

        if issubclass(type(self._lhs.schema), TableSchema):
            schema_type = type(self._lhs.schema)
        else:
            schema_type = type(self._rhs.schema)
        self._schema = schema_type.from_lists(names, typos)

    def _validate_equal(self, equal_expr):
        # FIXME: sometimes may be wrong, e.g. t3 = t1.join(t2, 'name')); t4.join(t3, t4.id == t3.id)
        return (equal_expr.lhs.is_ancestor(self._get_child(self._lhs)) and
                equal_expr.rhs.is_ancestor(self._get_child(self._rhs))) or \
               (equal_expr.lhs.is_ancestor(self._get_child(self._rhs)) and
                equal_expr.rhs.is_ancestor(self._get_child(self._lhs)))

    def _reverse_equal(self, equal_expr):
        if equal_expr.lhs.is_ancestor(self._get_child(self._rhs)) and \
                equal_expr.rhs.is_ancestor(self._get_child(self._lhs)):
            # the equal's left side is on the right collection and vice versa
            equal_expr._rhs, equal_expr._lhs = equal_expr._lhs, equal_expr._rhs

    def _validate_predicates(self, predicates):
        """
        Normalize `predicates` into boolean expressions stored in `_predicate`.

        Names and (left, right) name pairs become equality expressions. Unless
        this is a mapjoin, at least one equality linking both inputs is required.

        :raises ExpressionError: when a predicate is invalid or none links both sides
        """
        if predicates is None:
            return

        is_validate = False
        subs = []
        if self._mapjoin:
            is_validate = True
        for p in predicates:
            if (isinstance(p, tuple) and len(p) == 2) or isinstance(p, six.string_types):
                if isinstance(p, six.string_types):
                    left_name, right_name = p, p
                else:
                    left_name, right_name = p
                left_col = self._lhs._get_field(left_name)
                right_col = self._rhs._get_field(right_name)
                if not left_col.is_ancestor(self._lhs) or not right_col.is_ancestor(self._rhs):
                    raise ExpressionError('Invalid predicate: {0!s}'.format(repr_obj(p)))
                subs.append(left_col == right_col)

                is_validate = True
            elif isinstance(p, BooleanSequenceExpr):
                if not is_validate:
                    it = (expr for expr in p.traverse(top_down=True, unique=True)
                          if isinstance(expr, Equal))

                    while not is_validate:
                        try:
                            equal_expr = next(it)
                            validate = self._validate_equal(equal_expr)
                            if validate:
                                is_validate = True
                                break
                        except StopIteration:
                            break
            else:
                if not is_validate:
                    raise ExpressionError('Invalid predicate: {0!s}'.format(repr_obj(p)))
        if not is_validate:
            raise ExpressionError('Invalid predicate: no validate predicate assigned')

        for p in predicates:
            if isinstance(p, BooleanSequenceExpr):
                subs.append(p)

        if len(subs) == 0:
            self._predicate = None
        else:
            self._predicate = subs

    def _merge_joined_fields(self, merge_columns):
        """
        Merge same-named predicate columns of both sides into one unsuffixed column.

        :param merge_columns: falsy (no merge); True / 'auto' / 'left' / 'right'
            (apply to every mergeable column); a column name or a list/tuple/set
            of names (merge with 'auto'); or a dict {column: action}.
        :return: self when nothing is merged, else a JoinFieldMergedCollectionExpr
        :raises ValueError: when predicates have no same-named fields, a column
            is not in the predicates, or an action is unknown
        """
        if not merge_columns:
            return self

        predicate_fields = self._get_predicate_fields()
        if not predicate_fields:
            raise ValueError('No fields in predicate. Cannot merge columns.')

        src_map = self._column_origins
        # source name -> [joined name from left, joined name from right]
        rename_map = dict()
        for name, src in six.iteritems(src_map):
            src_id, src_name = src
            if src_name not in predicate_fields:
                continue
            if src_name not in rename_map:
                rename_map[src_name] = [None, None]
            rename_map[src_name][src_id] = name

        if merge_columns in ('auto', 'left', 'right') or \
                (isinstance(merge_columns, bool) and merge_columns):
            merge_columns = dict((k, merge_columns) for k in six.iterkeys(rename_map))
        if isinstance(merge_columns, six.string_types):
            merge_columns = {merge_columns: 'auto'}
        if isinstance(merge_columns, (list, tuple, set)):
            merge_columns = dict((k, 'auto') for k in merge_columns)

        # normalize into a fresh dict so a caller-supplied mapping is never mutated
        actions = dict()
        excludes = set()
        for col, action in six.iteritems(merge_columns):
            if col not in rename_map:
                raise ValueError('Column {0} not exists in join predicate.'.format(col))
            if isinstance(action, bool) and action:
                action = 'auto'
            elif isinstance(action, six.string_types):
                action = action.lower()
            if action not in ('auto', 'left', 'right'):
                # previously an unknown action silently dropped both columns
                raise ValueError(
                    'Invalid merge action {0!r} for column {1}, '
                    'should be one of auto, left or right.'.format(action, col))
            actions[col] = action
            excludes.update(rename_map[col])

        selects = []
        merged = set()
        for col in self.schema.names:
            if col not in excludes:
                selects.append(self[col])
            else:
                src_name = src_map[col][1]
                if src_name in merged:
                    continue
                merged.add(src_name)

                left_name, right_name = rename_map[src_name]
                left_col = self[left_name]
                right_col = self[right_name]

                action = actions[src_name]
                if action == 'auto':
                    selects.append(left_col.isnull().ifelse(right_col, left_col).rename(src_name))
                elif action == 'left':
                    selects.append(left_col.rename(src_name))
                elif action == 'right':
                    selects.append(right_col.rename(src_name))
        selected = self.select(*selects)
        return JoinFieldMergedCollectionExpr(_input=self, _fields=selected._fields,
                                             _schema=selected._schema, _rename_map=rename_map)


class InnerJoin(JoinCollectionExpr):
    """Inner join; same-named predicate columns are kept once, without suffixes."""

    def _init(self, *args, **kwargs):
        self._how = 'INNER'
        super(InnerJoin, self)._init(*args, **kwargs)

    def _get_non_suffixes_fields(self):
        return self._get_predicate_fields()


class LeftJoin(JoinCollectionExpr):
    """Left outer join."""

    def _init(self, *args, **kwargs):
        self._how = 'LEFT OUTER'
        super(LeftJoin, self)._init(*args, **kwargs)


class RightJoin(JoinCollectionExpr):
    """Right outer join."""

    def _init(self, *args, **kwargs):
        self._how = 'RIGHT OUTER'
        super(RightJoin, self)._init(*args, **kwargs)


class OuterJoin(JoinCollectionExpr):
    """Full outer join."""

    def _init(self, *args, **kwargs):
        self._how = 'FULL OUTER'
        super(OuterJoin, self)._init(*args, **kwargs)


class JoinFieldMergedCollectionExpr(ProjectCollectionExpr):
    """
    Projection over a join produced by `JoinCollectionExpr._merge_joined_fields`.

    `_rename_map` maps each merged source column name to its
    [left joined name, right joined name].
    """
    __slots__ = '_rename_map',

    def _get_fields(self, fields, ret_raw_fields=False):
        selects = []
        raw_selects = []

        joined_expr = self._input
        for field in fields:
            field = self._defunc(field)
            if isinstance(field, CollectionExpr):
                if any(c is self for c in field.children()) or \
                        any(c is joined_expr._lhs for c in field.children()) or \
                        any(c is joined_expr._rhs for c in field.children()):
                    selects.extend(self._get_fields(field._project_fields))
                elif field is self:
                    selects.extend(self._get_fields(self._schema.names))
                elif field is joined_expr._lhs:
                    # merged columns keep their source name; others map to suffixed names
                    names = [joined_expr._renamed_columns.get(n, [n])[0]
                             if n not in self._rename_map else n
                             for n in field.schema.names]
                    selects.extend(self._get_fields(names))
                elif field is joined_expr._rhs:
                    names = [joined_expr._renamed_columns.get(n, [None, n])[1]
                             if n not in self._rename_map else n
                             for n in field.schema.names]
                    selects.extend(self._get_fields(names))
                else:
                    selects.extend(super(JoinFieldMergedCollectionExpr, self)._get_fields(field))
                raw_selects.append(field)
            else:
                select = self._get_field(field)
                selects.append(select)
                raw_selects.append(select)

        if ret_raw_fields:
            return selects, raw_selects
        return selects

    def _get_field(self, field):
        field = self._defunc(field)

        if isinstance(field, six.string_types):
            if field not in self._schema:
                raise ValueError('Field(%s) does not exist' % field)
            cls = Column
            if callable(getattr(type(self), field, None)):
                cls = CallableColumn
            return cls(self, _name=field, _data_type=self._schema[field].type)

        joined_expr = self._input

        root = field
        has_path = False
        for expr in root.traverse(top_down=True, unique=True,
                                  stop_cond=lambda x: isinstance(x, Column) or x is self):
            if isinstance(expr, Column):
                if expr.input is self or expr.input is joined_expr:
                    has_path = True
                    continue
                if expr.input is joined_expr._lhs:
                    has_path = True
                    idx = 0
                elif expr.input is joined_expr._rhs:
                    has_path = True
                    idx = 1
                elif isinstance(joined_expr._lhs, JoinCollectionExpr):
                    try:
                        expr = joined_expr._lhs._get_field(expr)
                    except ExpressionError:
                        continue
                    has_path = True
                    idx = 0
                elif isinstance(joined_expr._rhs, JoinCollectionExpr):
                    try:
                        expr = joined_expr._rhs._get_field(expr)
                    except ExpressionError:
                        continue
                    has_path = True
                    idx = 1
                else:
                    continue

                name = expr.source_name
                if name not in self._rename_map and name in joined_expr._renamed_columns:
                    name = joined_expr._renamed_columns[name][idx]
                to_sub = self._get_field(name)
                if expr.is_renamed():
                    to_sub = to_sub.rename(expr.name)
                to_sub.copy_to(expr)

        if isinstance(field, SequenceExpr) and not has_path:
            raise ExpressionError('field must come from Join collection '
                                  'or its left and right child collection: %s'
                                  % repr_obj(field))
        return root


class JoinProjectCollectionExpr(ProjectCollectionExpr):
    """
    Only for analyzer, project join should generate normal `ProjectCollectionExpr`.
    """
    __slots__ = ()


_join_dict = {
    'INNER': InnerJoin,
    'LEFT': LeftJoin,
    'RIGHT': RightJoin,
    'OUTER': OuterJoin
}


def _make_different_sources(left, right, predicate=None):
    """
    Copy the nodes of `right` that are shared with `left`, so the two inputs
    do not share expression nodes. Expression predicates referring to `right`
    are re-pointed at its copy.

    :return: (left, right or its copy)
    """
    # TODO: move to analyzer, do it before analyze and optimize
    exprs = ExprDictionary()
    for n in left.traverse(unique=True):
        exprs[n] = True

    subs = ExprDictionary()

    if getattr(right, '_proxy', None) is not None:
        right = right._proxy
    dag = right.to_dag(copy=False, validate=False)
    for n in dag.traverse():
        if n in exprs:
            copied = subs.get(n, n.copy())
            for p in dag.successors(n):
                if p in exprs and p not in subs:
                    subs[p] = p.copy()
                subs.get(p, p).substitute(n, copied)
            subs[n] = copied

            if predicate and n is right:
                for p in predicate:
                    if not isinstance(p, Expr):
                        continue
                    p_dag = p.to_dag(copy=False, validate=False)
                    for p_n in p_dag.traverse(top_down=True):
                        if p_n is right:
                            succs = list(s for s in p_dag.successors(p_n) if s not in exprs)
                            p_dag.substitute(right, copied, parents=succs)

    return left, subs.get(right, right)


def join(left, right, on=None, how='inner', suffixes=('_x', '_y'), mapjoin=False,
         skewjoin=False):
    """
    Join two collections.

    If `on` is not specified, we will find the common fields of the left and right collection.
    `suffixes` means that if column names conflict, the suffixes will be added automatically.
    For example, both left and right has a field named `col`,
    there will be col_x, and col_y in the joined collection.

    :param left: left collection
    :param right: right collection
    :param on: fields to join on
    :param how: 'inner', 'left', 'right', or 'outer'
    :param suffixes: when name conflict, the suffix will be added to both columns.
    :param mapjoin: set use mapjoin or not, default value False.
    :param skewjoin: set use of skewjoin or not, default value False. Can specify True if
                     the collection is skew, or a list specifying columns with skew values,
                     or a list of dicts specifying skew combinations.
    :return: collection
    :raises TypeError: if both mapjoin and skewjoin are given, or skewjoin has an unsupported type
    :raises ValueError: if suffixes is not a pair or skewjoin columns are invalid

    :Example:

    >>> df.dtypes.names
    ['name', 'id']
    >>> df2.dtypes.names
    ['name', 'id1']
    >>> df.join(df2)
    >>> df.join(df2, on='name')
    >>> df.join(df2, on=('id', 'id1'))
    >>> df.join(df2, on=['name', ('id', 'id1')])
    >>> df.join(df2, on=[df.name == df2.name, df.id == df2.id1])
    >>> df.join(df2, mapjoin=False)
    >>> df.join(df2, skewjoin=True)
    >>> df.join(df2, skewjoin=["c0", "c1"])
    >>> df.join(df2, skewjoin=[{"c0": 1, "c1": "2"}, {"c0": 3, "c1": "4"}])
    """
    if mapjoin and skewjoin:
        raise TypeError("Cannot specify mapjoin and skewjoin at the same time")

    if isinstance(left, TypedExpr):
        left = to_collection(left)
    if isinstance(right, TypedExpr):
        right = to_collection(right)

    if on is None:
        if not mapjoin:
            on = [name for name in left.schema.names
                  if to_lower_str(name) in right.schema._name_indexes]
        if not on and len(left.schema) == 1 and len(right.schema) == 1:
            on = [(left.schema.names[0], right.schema.names[0])]

    skewjoin_values = None
    if isinstance(skewjoin, (dict, six.string_types)):
        skewjoin = [skewjoin]
    if isinstance(skewjoin, list):
        if (
            all(isinstance(c, six.string_types) for c in skewjoin)
            and any(c not in right.schema.names for c in skewjoin)
        ):
            raise ValueError(
                "All columns specified in `skewjoin` need to exist in the right collection"
            )
        # `skewjoin` guard: an empty list would otherwise hit skewjoin[0]
        elif skewjoin and all(isinstance(c, dict) for c in skewjoin):
            cols = sorted(skewjoin[0].keys())
            cols_set = set(cols)
            if any(c not in right.schema.names for c in cols):
                raise ValueError(
                    "All columns specified in `skewjoin` need to exist in the right collection"
                )
            if any(cols_set != set(c.keys()) for c in skewjoin):
                raise ValueError("All values in `skewjoin` need to have same columns")
            skewjoin_values = [[d[c] for c in cols] for d in skewjoin]
            skewjoin = cols
    elif skewjoin and not isinstance(skewjoin, bool):
        raise TypeError("Cannot accept skewjoin type %s" % type(skewjoin))

    if isinstance(suffixes, (tuple, list)) and len(suffixes) == 2:
        left_suffix, right_suffix = suffixes
    else:
        # wrap in a 1-tuple so a tuple argument is formatted, not unpacked by `%`
        raise ValueError('suffixes must be a tuple or list with two elements, got %s'
                         % (suffixes,))

    if not isinstance(on, list):
        on = [on, ]
    # build a new list: resolve callables without mutating the caller's list
    on = [it(left, right) if inspect.isfunction(it) else it for it in on]

    left, right = _make_different_sources(left, right, on)

    # look the class up first so KeyErrors raised while constructing the join
    # are not mistaken for an unknown `how`
    join_cls = _join_dict.get(how.upper())
    if join_cls is not None:
        return join_cls(
            _lhs=left, _rhs=right, _predicate=on, _left_suffix=left_suffix,
            _right_suffix=right_suffix, _mapjoin=mapjoin, _skewjoin=skewjoin,
            _skewjoin_values=skewjoin_values
        )
    return JoinCollectionExpr(
        _lhs=left, _rhs=right, _predicate=on, _how=how, _left_suffix=left_suffix,
        _right_suffix=right_suffix, _mapjoin=mapjoin, _skewjoin=skewjoin,
        _skewjoin_values=skewjoin_values
    )


def inner_join(left, right, on=None, suffixes=('_x', '_y'), mapjoin=False, skewjoin=False):
    """
    Inner join two collections.

    If `on` is not specified, we will find the common fields of the left and right collection.
    `suffixes` means that if column names conflict, the suffixes will be added automatically.
    For example, both left and right has a field named `col`,
    there will be col_x, and col_y in the joined collection.

    :param left: left collection
    :param right: right collection
    :param on: fields to join on
    :param suffixes: when name conflict, the suffixes will be added to both columns.
    :param mapjoin: set use mapjoin or not, default value False.
    :param skewjoin: set use of skewjoin or not, default value False. See `join`.
    :return: collection

    :Example:

    >>> df.dtypes.names
    ['name', 'id']
    >>> df2.dtypes.names
    ['name', 'id1']
    >>> df.inner_join(df2)
    >>> df.inner_join(df2, on='name')
    >>> df.inner_join(df2, on=('id', 'id1'))
    >>> df.inner_join(df2, on=['name', ('id', 'id1')])
    >>> df.inner_join(df2, on=[df.name == df2.name, df.id == df2.id1])
    """
    return join(left, right, on, suffixes=suffixes, mapjoin=mapjoin, skewjoin=skewjoin)


def left_join(
    left, right, on=None, suffixes=('_x', '_y'), mapjoin=False, merge_columns=None, skewjoin=False
):
    """
    Left join two collections.

    If `on` is not specified, we will find the common fields of the left and right collection.
    `suffixes` means that if column names conflict, the suffixes will be added automatically.
    For example, both left and right has a field named `col`,
    there will be col_x, and col_y in the joined collection.

    :param left: left collection
    :param right: right collection
    :param on: fields to join on
    :param suffixes: when name conflict, the suffixes will be added to both columns.
    :param mapjoin: set use mapjoin or not, default value False.
    :param merge_columns: whether to merge columns with the same name into one column without suffix.
                          If the value is True, columns in the predicate with same names will be merged,
                          with non-null value. If the value is 'left' or 'right', the values of predicates
                          on the left / right collection will be taken. You can also pass a dictionary to
                          describe the behavior of each column, such as { 'a': 'auto', 'b': 'left' }.
    :param skewjoin: set use of skewjoin or not, default value False. See `join`.
    :return: collection

    :Example:

    >>> df.dtypes.names
    ['name', 'id']
    >>> df2.dtypes.names
    ['name', 'id1']
    >>> df.left_join(df2)
    >>> df.left_join(df2, on='name')
    >>> df.left_join(df2, on=('id', 'id1'))
    >>> df.left_join(df2, on=['name', ('id', 'id1')])
    >>> df.left_join(df2, on=[df.name == df2.name, df.id == df2.id1])
    """
    joined = join(left, right, on, how='left', suffixes=suffixes, mapjoin=mapjoin,
                  skewjoin=skewjoin)
    return joined._merge_joined_fields(merge_columns)


def right_join(
    left, right, on=None, suffixes=('_x', '_y'), mapjoin=False, merge_columns=None, skewjoin=False
):
    """
    Right join two collections.

    If `on` is not specified, we will find the common fields of the left and right collection.
    `suffixes` means that if column names conflict, the suffixes will be added automatically.
    For example, both left and right has a field named `col`,
    there will be col_x, and col_y in the joined collection.

    :param left: left collection
    :param right: right collection
    :param on: fields to join on
    :param suffixes: when name conflict, the suffixes will be added to both columns.
    :param mapjoin: set use mapjoin or not, default value False.
    :param merge_columns: whether to merge columns with the same name into one column without suffix.
                          If the value is True, columns in the predicate with same names will be merged,
                          with non-null value. If the value is 'left' or 'right', the values of predicates
                          on the left / right collection will be taken. You can also pass a dictionary to
                          describe the behavior of each column, such as { 'a': 'auto', 'b': 'left' }.
    :param skewjoin: set use of skewjoin or not, default value False. See `join`.
    :return: collection

    :Example:

    >>> df.dtypes.names
    ['name', 'id']
    >>> df2.dtypes.names
    ['name', 'id1']
    >>> df.right_join(df2)
    >>> df.right_join(df2, on='name')
    >>> df.right_join(df2, on=('id', 'id1'))
    >>> df.right_join(df2, on=['name', ('id', 'id1')])
    >>> df.right_join(df2, on=[df.name == df2.name, df.id == df2.id1])
    """
    joined = join(left, right, on, how='right', suffixes=suffixes, mapjoin=mapjoin,
                  skewjoin=skewjoin)
    return joined._merge_joined_fields(merge_columns)


def outer_join(
    left, right, on=None, suffixes=('_x', '_y'), mapjoin=False, merge_columns=None, skewjoin=False
):
    """
    Outer join two collections.

    If `on` is not specified, we will find the common fields of the left and right collection.
    `suffixes` means that if column names conflict, the suffixes will be added automatically.
    For example, both left and right has a field named `col`,
    there will be col_x, and col_y in the joined collection.

    :param left: left collection
    :param right: right collection
    :param on: fields to join on
    :param suffixes: when name conflict, the suffixes will be added to both columns.
    :param mapjoin: set use mapjoin or not, default value False.
    :param merge_columns: whether to merge columns with the same name into one column without suffix.
                          If the value is True, columns in the predicate with same names will be merged,
                          with non-null value. If the value is 'left' or 'right', the values of predicates
                          on the left / right collection will be taken. You can also pass a dictionary to
                          describe the behavior of each column, such as { 'a': 'auto', 'b': 'left' }.
    :param skewjoin: set use of skewjoin or not, default value False. See `join`.
    :return: collection

    :Example:

    >>> df.dtypes.names
    ['name', 'id']
    >>> df2.dtypes.names
    ['name', 'id1']
    >>> df.outer_join(df2)
    >>> df.outer_join(df2, on='name')
    >>> df.outer_join(df2, on=('id', 'id1'))
    >>> df.outer_join(df2, on=['name', ('id', 'id1')])
    >>> df.outer_join(df2, on=[df.name == df2.name, df.id == df2.id1])
    """
    joined = join(left, right, on, how='outer', suffixes=suffixes, mapjoin=mapjoin,
                  skewjoin=skewjoin)
    return joined._merge_joined_fields(merge_columns)


CollectionExpr.join = join
CollectionExpr.inner_join = inner_join
CollectionExpr.left_join = left_join
CollectionExpr.right_join = right_join
CollectionExpr.outer_join = outer_join

TypedExpr.join = join
TypedExpr.inner_join = inner_join
TypedExpr.left_join = left_join
TypedExpr.right_join = right_join
TypedExpr.outer_join = outer_join


def _get_sequence_source_collection(expr):
    """Return the first collection found when traversing `expr` top-down."""
    return next(it for it in expr.traverse(top_down=True, unique=True)
                if isinstance(it, CollectionExpr))


class UnionCollectionExpr(CollectionExpr):
    """Vertical union of two collections (sequences are projected to collections)."""
    __slots__ = '_distinct',
    _args = '_lhs', '_rhs',
    node_name = 'Union'

    def _init(self, *args, **kwargs):
        super(UnionCollectionExpr, self)._init(*args, **kwargs)

        self._validate()
        self._schema = self._clean_schema()

    def _validate_collection_child(self):
        """
        Check both sides have compatible schemas; when only column order
        differs, reorder the right side to match the left.
        """
        if self._lhs.schema.names == self._rhs.schema.names and \
                self._lhs.schema.types == self._rhs.schema.types:
            return True
        elif set(self._lhs.schema.names) == set(self._rhs.schema.names) and \
                set(self._lhs.schema.types) == set(self._rhs.schema.types):
            self._rhs = self._rhs[self._lhs.schema.names]
            return self._lhs.schema.types == self._rhs.schema.types
        else:
            return False

    def _validate(self):
        if isinstance(self._lhs, SequenceExpr):
            source_collection = _get_sequence_source_collection(self._lhs)
            self._lhs = source_collection[[self._lhs]]
        if isinstance(self._rhs, SequenceExpr):
            source_collection = _get_sequence_source_collection(self._rhs)
            self._rhs = source_collection[[self._rhs]]

        if isinstance(self._lhs, CollectionExpr) and isinstance(self._rhs, CollectionExpr):
            if not self._validate_collection_child():
                raise ExpressionError('Table schemas must be equal to form union')
        else:
            raise ExpressionError('Both inputs should be collections or sequences.')

    def _clean_schema(self):
        return TableSchema.from_lists(self._lhs.schema.names, self._lhs.schema.types)

    def accept(self, visitor):
        return visitor.visit_union(self)

    def iter_args(self):
        for it in zip(['collection(left)', 'collection(right)'], self.args):
            yield it


class ConcatCollectionExpr(CollectionExpr):
    """Horizontal concatenation of two collections (columns side by side)."""
    _args = '_lhs', '_rhs',
    node_name = 'Concat'

    def _init(self, *args, **kwargs):
        super(ConcatCollectionExpr, self)._init(*args, **kwargs)
        self._schema = self._clean_schema()

    def iter_args(self):
        for it in zip(['collection(left)', 'collection(right)'], self.args):
            yield it

    def _clean_schema(self):
        return TableSchema.from_lists(
            self._lhs.schema.names + self._rhs.schema.names,
            self._lhs.schema.types + self._rhs.schema.types,
        )

    @staticmethod
    def _get_sequence_source_collection(expr):
        return next(it for it in expr.traverse(top_down=True, unique=True)
                    if isinstance(it, CollectionExpr))

    @classmethod
    def validate_input(cls, *inputs):
        """
        Check that inputs are collections (or sequences) with no overlapping
        column names.

        :raises ExpressionError: on an invalid input type or a column name collision
        """
        new_inputs = []
        for i in inputs:
            if isinstance(i, SequenceExpr):
                source_collection = cls._get_sequence_source_collection(i)
                new_inputs.append(source_collection[[i]])
            else:
                new_inputs.append(i)

        if any(not isinstance(i, CollectionExpr) for i in new_inputs):
            raise ExpressionError('Inputs should be collections or sequences.')

        # use the converted inputs: raw sequences do not expose `schema.names`
        unioned = reduce(lambda a, b: a | b, (set(i.schema.names) for i in new_inputs))
        total_fields = sum((len(i.schema.names) for i in new_inputs))
        if total_fields != len(unioned):
            raise ExpressionError('Column names in inputs should not collides with each other.')

    def accept(self, visitor):
        return visitor.visit_concat(self)


def union(left, right, distinct=False):
    """
    Union two collections.

    :param left: left collection
    :param right: right collection
    :param distinct: whether to remove duplicate rows from the result
    :return: collection

    :Example:

    >>> df['name', 'id'].union(df2['id', 'name'])
    """
    left, right = _make_different_sources(left, right)
    return UnionCollectionExpr(_lhs=left, _rhs=right, _distinct=distinct)


def __horz_concat(left, rights):
    """Concatenate `rights` onto `left` horizontally, one at a time."""
    from ..utils import to_collection
    left = to_collection(left)
    for right in rights:
        right = to_collection(right)
        left, right = _make_different_sources(left, right)
        left = ConcatCollectionExpr(_lhs=left, _rhs=right)
    return left


def concat(left, rights, distinct=False, axis=0):
    """
    Concat collections.

    :param left: left collection
    :param rights: right collections, can be a DataFrame object or a list of DataFrames
    :param distinct: whether to remove duplicate entries. only available when axis == 0
    :param axis: when axis == 0, the DataFrames are merged vertically, otherwise horizontally.
    :return: collection

    Note that axis==1 can only be used under Pandas DataFrames or XFlow.

    :Example:

    >>> df['name', 'id'].concat(df2['score'], axis=1)
    """
    from ..utils import to_collection

    if isinstance(rights, Node):
        rights = [rights, ]
    if not rights:
        raise ValueError('At least one DataFrame should be provided.')

    if axis == 0:
        for right in rights:
            left = union(left, right, distinct=distinct)
        return left
    else:
        rights = [to_collection(r) for r in rights]
        ConcatCollectionExpr.validate_input(left, *rights)

        if hasattr(left, '_xflow_concat'):
            return left._xflow_concat(rights)
        else:
            return __horz_concat(left, rights)


def _drop(expr, data, axis=0, columns=None):
    """
    Drop data from a DataFrame.

    :param expr: collection to drop data from
    :param data: data to be removed
    :param axis: 0 for deleting rows, 1 for columns.
    :param columns: columns of data to select, only useful when axis == 0
    :return: collection

    :Example:

    >>> import pandas as pd
    >>> df1 = DataFrame(pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6], 'c': [7, 8, 9]}))
    >>> df2 = DataFrame(pd.DataFrame({'a': [2, 3], 'b': [5, 7]}))
    >>> df1.drop(df2)
       a  b  c
    0  1  4  7
    1  3  6  9
    >>> df1.drop(df2, columns='a')
       a  b  c
    0  1  4  7
    >>> df1.drop(['a'], axis=1)
       b  c
    0  4  7
    1  5  8
    2  6  9
    >>> df1.drop(df2, axis=1)
       c
    0  7
    1  8
    2  9
    """
    from ..utils import to_collection

    expr = to_collection(expr)
    if axis == 0:
        if not isinstance(data, (CollectionExpr, SequenceExpr)):
            raise ExpressionError('data should be a collection or sequence when axis == 0.')
        data = to_collection(data)
        if columns is None:
            columns = [n for n in data.schema.names]
        if isinstance(columns, six.string_types):
            columns = [columns, ]
        data = data.select(*columns).distinct()

        # anti-join: keep left rows whose join columns found no match in `data`
        drop_predicates = [data[n].isnull() for n in data.schema.names]
        return expr.left_join(data, on=columns, suffixes=('', '_dp')).filter(*drop_predicates) \
            .select(*expr.schema.names)
    else:
        if isinstance(data, (CollectionExpr, SequenceExpr)):
            data = to_collection(data).schema.names
        return expr.exclude(data)


def setdiff(left, *rights, **kwargs):
    """
    Exclude data from a collection, like `except` clause in SQL. All collections involved should
    have same schema.

    :param left: collection to drop data from
    :param rights: collection or list of collections
    :param distinct: whether to remove duplicate entries from the result
    :return: collection

    :Examples:

    >>> import pandas as pd
    >>> df1 = DataFrame(pd.DataFrame({'a': [1, 2, 3, 3, 3], 'b': [1, 2, 3, 3, 3]}))
    >>> df2 = DataFrame(pd.DataFrame({'a': [1, 3], 'b': [1, 3]}))
    >>> df1.setdiff(df2)
       a  b
    0  2  2
    1  3  3
    2  3  3
    >>> df1.setdiff(df2, distinct=True)
       a  b
    0  2  2
    """
    import time
    from ..utils import output

    distinct = kwargs.get('distinct', False)

    if rights and isinstance(rights[0], list):
        rights = rights[0]

    cols = [n for n in left.schema.names]
    types = [n for n in left.schema.types]

    # rows from `left` count +1, rows from any right collection count -1
    counter_col_name = 'exc_counter_%d' % int(time.time())
    left = left[left, Scalar(1).rename(counter_col_name)]
    rights = [r[r, Scalar(-1).rename(counter_col_name)] for r in rights]

    unioned = left
    for r in rights:
        unioned = unioned.union(r)

    if distinct:
        aggregated = unioned.groupby(*cols).agg(**{counter_col_name: unioned[counter_col_name].min()})
        return aggregated.filter(aggregated[counter_col_name] == 1).select(*cols)
    else:
        aggregated = unioned.groupby(*cols).agg(**{counter_col_name: unioned[counter_col_name].sum()})

        @output(cols, types)
        def exploder(row):
            import sys
            irange = xrange if sys.version_info[0] < 3 else range
            for _ in irange(getattr(row, counter_col_name)):
                yield row[:-1]

        return aggregated.map_reduce(mapper=exploder).select(*cols)


def intersect(left, *rights, **kwargs):
    """
    Calc intersection among datasets,

    :param left: collection
    :param rights: collection or list of collections
    :param distinct: whether to remove duplicate entries from the result
    :return: collection

    :Examples:

    >>> import pandas as pd
    >>> df1 = DataFrame(pd.DataFrame({'a': [1, 2, 3, 3, 3], 'b': [1, 2, 3, 3, 3]}))
    >>> df2 = DataFrame(pd.DataFrame({'a': [1, 3, 3], 'b': [1, 3, 3]}))
    >>> df1.intersect(df2)
       a  b
    0  1  1
    1  3  3
    2  3  3
    >>> df1.intersect(df2, distinct=True)
       a  b
    0  1  1
    1  3  3
    """
    import time
    from ..utils import output

    distinct = kwargs.get('distinct', False)

    if rights and isinstance(rights[0], list):
        rights = rights[0]

    cols = [n for n in left.schema.names]
    types = [n for n in left.schema.types]

    # `rights` may be a tuple (varargs) or a list; `(left, ) + list` would raise TypeError
    collections = [left] + list(rights)

    idx_col_name = 'idx_%d' % int(time.time())
    counter_col_name = 'exc_counter_%d' % int(time.time())

    # tag each row with the index of its source collection
    collections = [c[c, Scalar(idx).rename(idx_col_name)] for idx, c in enumerate(collections)]
    unioned = reduce(lambda a, b: a.union(b), collections)
    src_agg = unioned.groupby(*(cols + [idx_col_name])) \
        .agg(**{counter_col_name: unioned.count()})

    aggregators = {
        idx_col_name: src_agg[idx_col_name].nunique(),
        counter_col_name: src_agg[counter_col_name].min(),
    }
    final_agg = src_agg.groupby(*cols).agg(**aggregators)
    # keep only rows present in every collection
    final_agg = final_agg.filter(final_agg[idx_col_name] == len(collections))

    if distinct:
        return final_agg.filter(final_agg[counter_col_name] > 0).select(*cols)
    else:
        @output(cols, types)
        def exploder(row):
            import sys
            irange = xrange if sys.version_info[0] < 3 else range
            for _ in irange(getattr(row, counter_col_name)):
                yield row[:-2]

        return final_agg.map_reduce(mapper=exploder).select(*cols)


CollectionExpr.union = union
SequenceExpr.union = union
CollectionExpr.__horz_concat = __horz_concat
SequenceExpr.__horz_concat = __horz_concat
CollectionExpr.concat = concat
SequenceExpr.concat = concat
CollectionExpr.drop = _drop
SequenceExpr.drop = _drop
CollectionExpr.setdiff = setdiff
SequenceExpr.setdiff = setdiff
CollectionExpr.except_ = setdiff
SequenceExpr.except_ = setdiff
CollectionExpr.intersect = intersect
SequenceExpr.intersect = intersect