core/maxframe/dataframe/groupby/sample.py (93 lines of code) (raw):

# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from typing import Optional, Sequence, Union

import numpy as np
import pandas as pd

from ... import opcodes
from ...core import ENTITY_TYPE, OutputType, get_output_types
from ...serialization.serializables import (
    BoolField,
    DictField,
    Float32Field,
    Int32Field,
    Int64Field,
    KeyField,
    NDArrayField,
    StringField,
)
from ...tensor.random import RandomStateField
from ..initializer import Series as asseries
from ..operators import DataFrameOperator, DataFrameOperatorMixin
from ..utils import parse_index


class GroupBySample(DataFrameOperator, DataFrameOperatorMixin):
    """Operator that samples items from each group of a groupby object.

    Instances are normally built by :func:`groupby_sample`, which fills the
    serializable fields below and then calls the operator with the groupby
    object to obtain the resulting tileable.
    """

    _op_type_ = opcodes.RAND_SAMPLE
    _op_module_ = "dataframe.groupby"

    # parameters of the originating groupby (a private copy, see
    # ``groupby_sample``)
    groupby_params = DictField("groupby_params", default=None)
    # number of items to draw per group (the ``n`` argument of the public API)
    size = Int64Field("size", default=None)
    # fraction of items to draw per group
    frac = Float32Field("frac", default=None)
    # whether the same row may be drawn more than once
    replace = BoolField("replace", default=None)
    # sampling weights; when it is an entity it is also passed as an input
    weights = KeyField("weights", default=None)
    seed = Int32Field("seed", default=None)
    _random_state = RandomStateField("random_state", default=None)
    # error policy forwarded from ``groupby_sample`` ('ignore' or 'raise')
    errors = StringField("errors", default=None)

    # for chunks
    # num of instances for chunks
    input_nsplits = NDArrayField("input_nsplits", default=None)

    def __init__(self, random_state=None, **kw):
        """Create the operator, storing ``random_state`` in ``_random_state``."""
        super().__init__(_random_state=random_state, **kw)

    @property
    def random_state(self):
        """Random state object stored on this operator."""
        return self._random_state

    def _set_inputs(self, inputs):
        """Bind ``inputs`` and re-attach ``weights`` when it is an entity.

        The first input is skipped; when ``weights`` is an entity, the input
        that follows it replaces ``self.weights``. This mirrors the order in
        which ``__call__`` builds the input list.
        """
        super()._set_inputs(inputs)
        input_iter = iter(inputs)
        next(input_iter)
        if isinstance(self.weights, ENTITY_TYPE):
            self.weights = next(input_iter)

    def __call__(self, groupby):
        """Build the tileable holding the per-group sample of ``groupby``.

        Parameters
        ----------
        groupby
            Groupby object to sample from. Its ``op.groupby_params`` is read
            but not modified.

        Returns
        -------
        The new tileable created by ``new_tileable``; it has the output type
        of the (optionally column-selected) source data and an unknown
        number of rows.
        """
        # Walk up the first-input chain until reaching the DataFrame or
        # Series the groupby was built from.
        df = groupby
        while df.op.output_types[0] not in (OutputType.dataframe, OutputType.series):
            df = df.inputs[0]

        # Read the selection WITHOUT removing it: this dict belongs to the
        # caller's groupby object, and popping the key would silently drop
        # the column selection for every later use of that same object.
        selection = groupby.op.groupby_params.get("selection")
        if df.ndim > 1 and selection:
            # A tuple that is not itself a column label is treated as a
            # collection of column labels.
            if isinstance(selection, tuple) and selection not in df.dtypes:
                selection = list(selection)
            result_df = df[selection]
        else:
            result_df = df

        params = result_df.params
        # The number of sampled rows is not known at graph-construction time.
        params["shape"] = (
            (np.nan,) if result_df.ndim == 1 else (np.nan, result_df.shape[-1])
        )
        # Derive the index metadata from an empty slice of the source index.
        params["index_value"] = parse_index(result_df.index_value.to_pandas()[:0])

        input_dfs = [df]
        if isinstance(self.weights, ENTITY_TYPE):
            input_dfs.append(self.weights)

        self._output_types = get_output_types(result_df)
        return self.new_tileable(input_dfs, **params)


def groupby_sample(
    groupby,
    n: Optional[int] = None,
    frac: Optional[float] = None,
    replace: bool = False,
    weights: Union[Sequence, pd.Series, None] = None,
    random_state: Optional[np.random.RandomState] = None,
    errors: str = "ignore",
):
    """
    Return a random sample of items from each group.

    You can use `random_state` for reproducibility.

    Parameters
    ----------
    n : int, optional
        Number of items to return for each group. Cannot be used with
        `frac` and must be no larger than the smallest group unless
        `replace` is True. Default is one if `frac` is None.
    frac : float, optional
        Fraction of items to return. Cannot be used with `n`.
    replace : bool, default False
        Allow or disallow sampling of the same row more than once.
    weights : list-like, optional
        Default None results in equal probability weighting.
        If passed a list-like then values must have the same length as
        the underlying DataFrame or Series object and will be used as
        sampling probabilities after normalization within each group.
        Values must be non-negative with at least one positive element
        within each group.
    random_state : int, array-like, BitGenerator, np.random.RandomState, optional
        If int, array-like, or BitGenerator (NumPy>=1.17), seed for
        random number generator
        If np.random.RandomState, use as numpy RandomState object.
    errors : {'ignore', 'raise'}, default 'ignore'
        If ignore, errors will not be raised when `replace` is False
        and size of some group is less than `n`.

    Returns
    -------
    Series or DataFrame
        A new object of same type as caller containing items randomly
        sampled within each group from the caller object.

    See Also
    --------
    DataFrame.sample: Generate random samples from a DataFrame object.
    numpy.random.choice: Generate a random sample from a given 1-D numpy
        array.

    Examples
    --------
    >>> import maxframe.dataframe as md
    >>> df = md.DataFrame(
    ...     {"a": ["red"] * 2 + ["blue"] * 2 + ["black"] * 2, "b": range(6)}
    ... )
    >>> df.execute()
           a  b
    0    red  0
    1    red  1
    2   blue  2
    3   blue  3
    4  black  4
    5  black  5

    Select one row at random for each distinct value in column a. The
    `random_state` argument can be used to guarantee reproducibility:

    >>> df.groupby("a").sample(n=1, random_state=1).execute()
           a  b
    4  black  4
    2   blue  2
    1    red  1

    Set `frac` to sample fixed proportions rather than counts:

    >>> df.groupby("a")["b"].sample(frac=0.5, random_state=2).execute()
    5    5
    2    2
    0    0
    Name: b, dtype: int64

    Control sample probabilities within groups by setting weights:

    >>> df.groupby("a").sample(
    ...     n=1,
    ...     weights=[1, 1, 1, 0, 0, 1],
    ...     random_state=1,
    ... ).execute()
           a  b
    5  black  5
    2   blue  2
    0    red  0
    """
    # Work on a copy so the groupby object's own parameters are not altered
    # when ``as_index`` is removed.
    groupby_params = groupby.op.groupby_params.copy()
    groupby_params.pop("as_index", None)

    # Plain sequences / pandas objects are wrapped into a Series entity.
    if weights is not None and not isinstance(weights, ENTITY_TYPE):
        weights = asseries(weights)

    # Default to one item per group when neither ``n`` nor ``frac`` is given.
    n = 1 if n is None and frac is None else n

    # Give the operator its own copy of the state instead of the caller's
    # object; objects exposing ``to_numpy()`` are converted first.
    rs = copy.deepcopy(
        random_state.to_numpy() if hasattr(random_state, "to_numpy") else random_state
    )
    if not isinstance(rs, np.random.RandomState):  # pragma: no cover
        rs = np.random.RandomState(rs)

    op = GroupBySample(
        size=n,
        frac=frac,
        replace=replace,
        weights=weights,
        random_state=rs,
        groupby_params=groupby_params,
        errors=errors,
    )
    return op(groupby)