core/maxframe/dataframe/indexing/where.py (155 lines of code) (raw):

# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect

import numpy as np
import pandas as pd

from ... import opcodes
from ...core import ENTITY_TYPE
from ...serialization.serializables import AnyField, BoolField, Int32Field, StringField
from ...tensor.utils import filter_inputs
from ..core import DATAFRAME_TYPE, SERIES_TYPE
from ..operators import DataFrameOperator, DataFrameOperatorMixin
from ..utils import build_df, build_series, validate_axis


def _supported_where_kwargs(obj, **kwargs):
    """Keep only the keyword arguments accepted by ``type(obj).where``.

    pandas deprecated the ``errors`` and ``try_cast`` keywords of
    ``where``/``mask`` and removed them in pandas 2.0, where passing them
    raises ``TypeError``. Filtering against the installed signature keeps
    the mock dtype inference below working across pandas versions.
    """
    params = inspect.signature(type(obj).where).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        # a ``**kwargs`` catch-all accepts anything
        return kwargs
    return {key: val for key, val in kwargs.items() if key in params}


class DataFrameWhere(DataFrameOperator, DataFrameOperatorMixin):
    """Operator behind ``where`` / ``mask`` on DataFrames and Series.

    ``replace_true`` selects the flavour: ``mask`` sets it to True (replace
    entries where ``cond`` holds), ``where`` sets it to False (replace
    entries where ``cond`` does not hold).
    """

    _op_type_ = opcodes.WHERE

    _input = AnyField("input")
    cond = AnyField("cond", default=None)
    other = AnyField("other", default=None)
    axis = Int32Field("axis", default=None)
    level = AnyField("level", default=None)
    errors = StringField("errors", default=None)
    try_cast = BoolField("try_cast", default=None)
    replace_true = BoolField("replace_true", default=None)

    def __init__(self, input=None, **kw):
        super().__init__(_input=input, **kw)

    @property
    def input(self):
        return self._input

    def __call__(self, df_or_series):
        """Build the output tileable for ``df_or_series``.

        The output keeps the caller's shape, index and (for DataFrames)
        columns. Result dtypes are inferred by running pandas ``where`` on
        mock objects with an all-False condition, so the dtype reflects
        mixing caller values with ``other``.

        Raises
        ------
        NotImplementedError
            If ``cond`` or ``other`` is a DataFrame/Series whose index (or
            columns) differ from the caller's, since alignment is not
            supported.
        """

        def _check_input_index(obj, axis=None):
            axis = axis if axis is not None else self.axis
            if isinstance(obj, DATAFRAME_TYPE) and (
                df_or_series.columns_value.key != obj.columns_value.key
                or df_or_series.index_value.key != obj.index_value.key
            ):
                raise NotImplementedError("Aligning different indices not supported")
            elif (
                isinstance(obj, SERIES_TYPE)
                and df_or_series.axes[axis].index_value.key != obj.index_value.key
            ):
                raise NotImplementedError("Aligning different indices not supported")

        _check_input_index(self.cond, axis=0)
        _check_input_index(self.other)

        if isinstance(df_or_series, DATAFRAME_TYPE):
            mock_obj = build_df(df_or_series)
        else:
            mock_obj = build_series(df_or_series)

        if isinstance(self.other, (pd.DataFrame, DATAFRAME_TYPE)):
            mock_other = build_df(self.other)
        elif isinstance(self.other, (pd.Series, SERIES_TYPE)):
            mock_other = build_series(self.other)
        else:
            mock_other = self.other

        # All-False condition: every cell takes ``other``, so the inferred
        # dtypes account for the values that may be substituted in.
        result_df = mock_obj.where(
            np.zeros(mock_obj.shape, dtype=bool),
            other=mock_other,
            axis=self.axis,
            level=self.level,
            **_supported_where_kwargs(
                mock_obj, errors=self.errors, try_cast=self.try_cast
            ),
        )

        inputs = filter_inputs([df_or_series, self.cond, self.other])
        if isinstance(df_or_series, DATAFRAME_TYPE):
            return self.new_dataframe(
                inputs,
                shape=df_or_series.shape,
                dtypes=result_df.dtypes,
                index_value=df_or_series.index_value,
                columns_value=df_or_series.columns_value,
            )
        else:
            return self.new_series(
                inputs,
                shape=df_or_series.shape,
                name=df_or_series.name,
                dtype=result_df.dtype,
                index_value=df_or_series.index_value,
            )

    def _set_inputs(self, inputs):
        super()._set_inputs(inputs)
        # inputs are ordered as in ``filter_inputs`` above: the caller first,
        # then ``cond`` and ``other`` only when they are entities
        inputs_iter = iter(self._inputs)
        self._input = next(inputs_iter)
        if isinstance(self.cond, ENTITY_TYPE):
            self.cond = next(inputs_iter)
        if isinstance(self.other, ENTITY_TYPE):
            self.other = next(inputs_iter)


_doc_template = """
Replace values where the condition is {replace_true}.

Parameters
----------
cond : bool Series/DataFrame, array-like, or callable
    Where `cond` is {keep}, keep the original value. Where
    {replace}, replace with corresponding value from `other`.
    If `cond` is callable, it is computed on the Series/DataFrame and
    should return boolean Series/DataFrame or array. The callable must
    not change input Series/DataFrame (though pandas doesn't check it).
other : scalar, Series/DataFrame, or callable
    Entries where `cond` is {replace} are replaced with
    corresponding value from `other`.
    If other is callable, it is computed on the Series/DataFrame and
    should return scalar or Series/DataFrame. The callable must not
    change input Series/DataFrame (though pandas doesn't check it).
inplace : bool, default False
    Whether to perform the operation in place on the data.
axis : int, default None
    Alignment axis if needed.
level : int, default None
    Alignment level if needed.
errors : str, {{'raise', 'ignore'}}, default 'raise'
    Note that currently this parameter won't affect the results and will always coerce to a suitable dtype.

    - 'raise' : allow exceptions to be raised.
    - 'ignore' : suppress exceptions. On error return original object.

try_cast : bool, default False
    Try to cast the result back to the input type (if possible).

Returns
-------
Same type as caller

See Also
--------
:func:`DataFrame.{opposite}` : Return an object of same shape as
    self.

Notes
-----
The {name} method is an application of the if-then idiom. For each
element in the calling DataFrame, if ``cond`` is ``{keep}`` the
element is used; otherwise the corresponding element from the DataFrame
``other`` is used.

The signature for :func:`DataFrame.where` differs from
:func:`numpy.where`. Roughly ``df1.where(m, df2)`` is equivalent to
``np.where(m, df1, df2)``.

For further details and examples see the ``{name}`` documentation in
:ref:`indexing <indexing.where_mask>`.

Examples
--------
>>> import maxframe.tensor as mt
>>> import maxframe.dataframe as md
>>> s = md.Series(range(5))
>>> s.where(s > 0).execute()
0    NaN
1    1.0
2    2.0
3    3.0
4    4.0
dtype: float64

>>> s.mask(s > 0).execute()
0    0.0
1    NaN
2    NaN
3    NaN
4    NaN
dtype: float64

>>> s.where(s > 1, 10).execute()
0    10
1    10
2    2
3    3
4    4
dtype: int64

>>> df = md.DataFrame(mt.arange(10).reshape(-1, 2), columns=['A', 'B'])
>>> df.execute()
   A  B
0  0  1
1  2  3
2  4  5
3  6  7
4  8  9
>>> m = df % 3 == 0
>>> df.where(m, -df).execute()
   A  B
0  0 -1
1 -2  3
2 -4 -5
3  6 -7
4 -8  9
>>> (df.where(m, -df) == mt.where(m, df, -df)).execute()
      A     B
0  True  True
1  True  True
2  True  True
3  True  True
4  True  True
>>> (df.where(m, -df) == df.mask(~m, -df)).execute()
      A     B
0  True  True
1  True  True
2  True  True
3  True  True
4  True  True
"""


def _where(
    df_or_series,
    cond,
    other=np.nan,
    inplace=False,
    axis=None,
    level=None,
    errors="raise",
    try_cast=False,
    replace_true=False,
):
    """Shared implementation of :func:`where` and :func:`mask`.

    ``replace_true`` is False for ``where`` and True for ``mask``.
    Returns the new object, or None when ``inplace`` is set (in which case
    ``df_or_series`` is rebound to the result's data).

    Raises
    ------
    ValueError
        If a DataFrame is combined with a 1-d ``other`` and no ``axis``.
    """
    # As in pandas, callables are evaluated against the caller first.
    # MaxFrame entities are excluded so they are never invoked by mistake.
    if callable(cond) and not isinstance(cond, ENTITY_TYPE):
        cond = cond(df_or_series)
    if callable(other) and not isinstance(other, ENTITY_TYPE):
        other = other(df_or_series)

    if df_or_series.ndim == 2 and getattr(other, "ndim", 2) == 1 and axis is None:
        raise ValueError("Must specify axis=0 or 1")

    axis = validate_axis(axis or 0, df_or_series)
    op = DataFrameWhere(
        cond=cond,
        other=other,
        axis=axis,
        level=level,
        errors=errors,
        try_cast=try_cast,
        replace_true=replace_true,
    )
    result = op(df_or_series)
    if inplace:
        df_or_series.data = result.data
    else:
        return result


def where(
    df_or_series,
    cond,
    other=np.nan,
    inplace=False,
    axis=None,
    level=None,
    errors="raise",
    try_cast=False,
):
    return _where(
        df_or_series,
        cond,
        other=other,
        inplace=inplace,
        axis=axis,
        level=level,
        errors=errors,
        try_cast=try_cast,
        replace_true=False,
    )


def mask(
    df_or_series,
    cond,
    other=np.nan,
    inplace=False,
    axis=None,
    level=None,
    errors="raise",
    try_cast=False,
):
    return _where(
        df_or_series,
        cond,
        other=other,
        inplace=inplace,
        axis=axis,
        level=level,
        errors=errors,
        try_cast=try_cast,
        replace_true=True,
    )


# ``keep``/``replace`` name the condition value under which entries are
# kept / replaced, so each docstring describes its own semantics.
mask.__doc__ = _doc_template.format(
    name="mask", replace_true=True, keep=False, replace=True, opposite="where"
)
where.__doc__ = _doc_template.format(
    name="where", replace_true=False, keep=True, replace=False, opposite="mask"
)