# core/maxframe/dataframe/groupby/sample.py
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import Optional, Sequence, Union
import numpy as np
import pandas as pd
from ... import opcodes
from ...core import ENTITY_TYPE, OutputType, get_output_types
from ...serialization.serializables import (
BoolField,
DictField,
Float32Field,
Int32Field,
Int64Field,
KeyField,
NDArrayField,
StringField,
)
from ...tensor.random import RandomStateField
from ..initializer import Series as asseries
from ..operators import DataFrameOperator, DataFrameOperatorMixin
from ..utils import parse_index
class GroupBySample(DataFrameOperator, DataFrameOperatorMixin):
    """
    Operator that draws a random sample of rows from each group of a
    DataFrame/Series groupby.
    """

    _op_type_ = opcodes.RAND_SAMPLE
    _op_module_ = "dataframe.groupby"

    # parameters of the originating groupby (by, level, selection, ...)
    groupby_params = DictField("groupby_params", default=None)
    # number of rows to draw per group; mutually exclusive with ``frac``
    size = Int64Field("size", default=None)
    # fraction of rows to draw per group; mutually exclusive with ``size``
    frac = Float32Field("frac", default=None)
    # whether a row may be drawn more than once
    replace = BoolField("replace", default=None)
    # optional per-row sampling weights; may be a tileable entity
    weights = KeyField("weights", default=None)
    seed = Int32Field("seed", default=None)
    _random_state = RandomStateField("random_state", default=None)
    # 'ignore' or 'raise' when replace is False and a group is smaller than size
    errors = StringField("errors", default=None)
    # for chunks
    # num of instances for chunks
    input_nsplits = NDArrayField("input_nsplits", default=None)

    def __init__(self, random_state=None, **kw):
        super().__init__(_random_state=random_state, **kw)

    @property
    def random_state(self):
        return self._random_state

    def _set_inputs(self, inputs):
        super()._set_inputs(inputs)
        input_iter = iter(inputs)
        # first input is always the source dataframe / series; skip it
        next(input_iter)
        if isinstance(self.weights, ENTITY_TYPE):
            # rebind weights to its processed counterpart in ``inputs``
            self.weights = next(input_iter)

    def __call__(self, groupby):
        # walk up the graph to the dataframe / series the groupby is built on
        df = groupby
        while df.op.output_types[0] not in (OutputType.dataframe, OutputType.series):
            df = df.inputs[0]

        # Read the selection non-destructively. The previous ``pop`` removed
        # "selection" from the input op's shared params dict, silently
        # corrupting the groupby object for any later, unrelated use; this
        # operator's own ``groupby_params`` is an independent copy anyway.
        selection = groupby.op.groupby_params.get("selection", None)
        if df.ndim > 1 and selection:
            if isinstance(selection, tuple) and selection not in df.dtypes:
                # a tuple that is not itself a column label denotes
                # a multi-column selection
                selection = list(selection)
            result_df = df[selection]
        else:
            result_df = df

        params = result_df.params
        # the number of sampled rows is unknown until execution time
        params["shape"] = (
            (np.nan,) if result_df.ndim == 1 else (np.nan, result_df.shape[-1])
        )
        params["index_value"] = parse_index(result_df.index_value.to_pandas()[:0])

        input_dfs = [df]
        if isinstance(self.weights, ENTITY_TYPE):
            input_dfs.append(self.weights)

        self._output_types = get_output_types(result_df)
        return self.new_tileable(input_dfs, **params)
def groupby_sample(
    groupby,
    n: Optional[int] = None,
    frac: Optional[float] = None,
    replace: bool = False,
    weights: Union[Sequence, pd.Series, None] = None,
    random_state: Optional[np.random.RandomState] = None,
    errors: str = "ignore",
):
    """
    Return a random sample of items from each group.

    Use `random_state` to make the sampling reproducible.

    Parameters
    ----------
    n : int, optional
        Number of items to return for each group. Cannot be combined with
        `frac`; unless `replace` is True it must not exceed the size of
        the smallest group. Defaults to one item when `frac` is also None.
    frac : float, optional
        Fraction of items to return from each group. Cannot be combined
        with `n`.
    replace : bool, default False
        Whether the same row may be drawn more than once.
    weights : list-like, optional
        Sampling probabilities, normalized within each group. Must have
        the same length as the underlying DataFrame or Series, be
        non-negative, and hold at least one positive value per group.
        When None, all rows are equally likely.
    random_state : int, array-like, BitGenerator, np.random.RandomState, optional
        If int, array-like, or BitGenerator (NumPy>=1.17), seed for the
        random number generator.
        If np.random.RandomState, use as numpy RandomState object.
    errors : {'ignore', 'raise'}, default 'ignore'
        When 'ignore', no error is raised if `replace` is False and a
        group holds fewer than `n` rows.

    Returns
    -------
    Series or DataFrame
        A new object of the same type as the caller containing the rows
        randomly sampled from each group.

    See Also
    --------
    DataFrame.sample: Generate random samples from a DataFrame object.
    numpy.random.choice: Generate a random sample from a given 1-D numpy
        array.

    Examples
    --------
    >>> import maxframe.dataframe as md
    >>> df = md.DataFrame(
    ...     {"a": ["red"] * 2 + ["blue"] * 2 + ["black"] * 2, "b": range(6)}
    ... )
    >>> df.execute()
           a  b
    0    red  0
    1    red  1
    2   blue  2
    3   blue  3
    4  black  4
    5  black  5

    Select one row at random for each distinct value in column a. The
    `random_state` argument can be used to guarantee reproducibility:

    >>> df.groupby("a").sample(n=1, random_state=1).execute()
           a  b
    4  black  4
    2   blue  2
    1    red  1

    Set `frac` to sample fixed proportions rather than counts:

    >>> df.groupby("a")["b"].sample(frac=0.5, random_state=2).execute()
    5    5
    2    2
    0    0
    Name: b, dtype: int64

    Control sample probabilities within groups by setting weights:

    >>> df.groupby("a").sample(
    ...     n=1,
    ...     weights=[1, 1, 1, 0, 0, 1],
    ...     random_state=1,
    ... ).execute()
           a  b
    5  black  5
    2   blue  2
    0    red  0
    """
    # work on a private copy of the groupby parameters
    grp_params = dict(groupby.op.groupby_params)
    grp_params.pop("as_index", None)

    # materialize plain list-like weights as a series tileable
    if weights is not None and not isinstance(weights, ENTITY_TYPE):
        weights = asseries(weights)

    # default to one row per group when neither n nor frac is given
    if n is None and frac is None:
        n = 1

    # coerce the seed into an independent numpy RandomState
    state = random_state
    if hasattr(state, "to_numpy"):
        state = state.to_numpy()
    state = copy.deepcopy(state)
    if not isinstance(state, np.random.RandomState):  # pragma: no cover
        state = np.random.RandomState(state)

    op = GroupBySample(
        size=n,
        frac=frac,
        replace=replace,
        weights=weights,
        random_state=state,
        groupby_params=grp_params,
        errors=errors,
    )
    return op(groupby)