in odps/df/expr/collections.py [0:0]
def pivot_table(expr, values=None, rows=None, columns=None, aggfunc='mean',
                fill_value=None):
    """
    Create a spreadsheet-style pivot table as a DataFrame.

    :param expr: collection to pivot
    :param values: column(s) to aggregate (optional). When omitted, every
                   column of ``expr`` that is not a row key (and, when
                   ``columns`` is given, not a column key) is aggregated.
    :param rows: column(s) to group by; they become the leading output columns
    :param columns: key to group by on the pivot table columns (optional;
                    at most one column is currently supported)
    :param aggfunc: aggregate function or list of aggregate functions, or a
                    dict mapping a name (used in output column names) to an
                    aggregate function
    :param fill_value: value to replace missing values with (optional,
                       default None)
    :return: collection
    :raises ValueError: if ``rows`` is empty; if ``columns`` is given and no
                        value column can be found; if more than one
                        ``columns`` key is given

    :Example:

    >>> df
         A    B      C  D
    0  foo  one  small  1
    1  foo  one  large  2
    2  foo  one  large  2
    3  foo  two  small  3
    4  foo  two  small  3
    5  bar  one  large  4
    6  bar  one  small  5
    7  bar  two  small  6
    8  bar  two  large  7
    >>> table = df.pivot_table(values='D', rows=['A', 'B'], columns='C', aggfunc='sum')
    >>> table
         A    B  large_D_sum  small_D_sum
    0  bar  one          4.0          5.0
    1  bar  two          7.0          6.0
    2  foo  one          4.0          1.0
    3  foo  two          NaN          6.0
    """

    def get_names(iters):
        # Accept both plain column names and column expressions.
        return [r if isinstance(r, six.string_types) else r.name
                for r in iters]

    def get_aggfunc_name(f):
        # Turn an aggregate spec into a name usable inside output column
        # names, e.g. 'percentile(0.5)' -> 'percentile_0_5'.
        if isinstance(f, six.string_types):
            if '(' in f:
                f = re.sub(r' *\( *', '_', f)
                f = re.sub(r' *[+\-\*/,] *', '_', f)
                f = re.sub(r' *\) *', '', f)
                f = f.replace('.', '_')
            return f
        if isinstance(f, FunctionWrapper):
            return f.output_names[0]
        return 'aggregation'

    if not rows:
        raise ValueError('No group keys passed')
    rows = utils.to_list(rows)
    rows_names = get_names(rows)
    rows = [expr._get_field(r) for r in rows]

    if isinstance(aggfunc, dict):
        # Dict keys name the aggregations; dict values are the functions.
        agg_func_names = lkeys(aggfunc)
        aggfunc = lvalues(aggfunc)
    else:
        aggfunc = utils.to_list(aggfunc)
        agg_func_names = [get_aggfunc_name(af) for af in aggfunc]

    if not columns:
        if values is None:
            values = [n for n in expr.schema.names if n not in rows_names]
        else:
            values = utils.to_list(values)
        values = [expr._get_field(v) for v in values]

        # Copy so that appending output names below does not also grow
        # ``rows_names`` through aliasing.
        names = list(rows_names)
        types = [r.dtype for r in rows]
        for func, func_name in zip(aggfunc, agg_func_names):
            for value in values:
                if isinstance(func, six.string_types):
                    # String aggregations are evaluated against the value
                    # column; a ReprWrapper result must be called to obtain
                    # the actual expression.
                    seq = value.eval(func, rewrite=False)
                    if isinstance(seq, ReprWrapper):
                        seq = seq()
                else:
                    seq = value.agg(func)
                seq = seq.rename('{0}_{1}'.format(value.name, func_name))
                names.append(seq.name)
                types.append(seq.dtype)
        schema = TableSchema.from_lists(names, types)

        return PivotTableCollectionExpr(_input=expr, _group=rows, _values=values,
                                        _fill_value=fill_value, _schema=schema,
                                        _agg_func=aggfunc, _agg_func_names=agg_func_names)
    else:
        columns = [expr._get_field(c) for c in utils.to_list(columns)]

        if values:
            values = utils.to_list(values)
        else:
            key_names = set(c.name for c in rows + columns)
            values = [n for n in expr.schema.names if n not in key_names]
            if not values:
                raise ValueError('No values found for pivot_table')
        values = [expr._get_field(v) for v in values]

        if len(columns) > 1:
            raise ValueError('More than one `columns` are not supported yet')

        # Only the row-key columns are known up front; the remaining output
        # columns are derived from the pivot column's values (see the
        # example above), so the result uses a dynamic schema and a
        # dynamically-created subclass mixing in DynamicCollectionExpr.
        schema = DynamicSchema.from_lists(rows_names, [r.dtype for r in rows])
        base_tp = PivotTableCollectionExpr
        tp = type(base_tp.__name__, (DynamicCollectionExpr, base_tp), dict())
        return tp(_input=expr, _group=rows, _values=values,
                  _columns=columns, _agg_func=aggfunc,
                  _fill_value=fill_value, _schema=schema,
                  _agg_func_names=agg_func_names)