in spark/sql/group.go [40:109]
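// Agg applies the given aggregate expressions to the grouped data and
// returns the resulting DataFrame.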
func (gd *GroupedData) Agg(ctx context.Context, exprs ...column.Convertible) (DataFrame, error) {
	if len(exprs) == 0 {
		return nil, sparkerrors.WithString(sparkerrors.InvalidInputError, "exprs should not be empty")
	}
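	// Build the Aggregate relation on top of the current DataFrame's plan.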
	agg := &proto.Aggregate{
		Input: gd.df.relation,
	}
	// Convert all grouping columns into protobuf expressions.
	agg.GroupingExpressions = make([]*proto.Expression, len(gd.groupingCols))
	for i, col := range gd.groupingCols {
		exp, err := col.ToProto(ctx)
		if err != nil {
			return nil, err
		}
		agg.GroupingExpressions[i] = exp
	}
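	// Convert the caller-supplied aggregate expressions likewise.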
	agg.AggregateExpressions = make([]*proto.Expression, len(exprs))
	for i, expr := range exprs {
		exp, err := expr.ToProto(ctx)
		if err != nil {
			return nil, err
		}
		agg.AggregateExpressions[i] = exp
	}
	// Map the group type onto the corresponding protobuf enum.
	switch gd.groupType {
	case "pivot":
		agg.GroupType = proto.Aggregate_GROUP_TYPE_PIVOT
		// Pivot additionally carries the pivot column and its values,
		// with each value converted into a literal expression.
		if len(gd.pivotValues) == 0 {
			return nil, sparkerrors.WithString(sparkerrors.InvalidInputError, "pivotValues should not be empty")
		}
		protoCol, err := gd.pivotCol.ToProto(ctx)
		if err != nil {
			return nil, err
		}
		agg.Pivot = &proto.Aggregate_Pivot{
			Values: make([]*proto.Expression_Literal, len(gd.pivotValues)),
			Col:    protoCol,
		}
		for i, v := range gd.pivotValues {
			exp, err := column.NewLiteral(v).ToProto(ctx)
			if err != nil {
				return nil, err
			}
			agg.Pivot.Values[i] = exp.GetLiteral()
		}
	case "groupby":
		agg.GroupType = proto.Aggregate_GROUP_TYPE_GROUPBY
	case "rollup":
		agg.GroupType = proto.Aggregate_GROUP_TYPE_ROLLUP
	case "cube":
		agg.GroupType = proto.Aggregate_GROUP_TYPE_CUBE
	}
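	// Wrap the aggregate in a fresh relation so it carries its own plan id.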
	rel := &proto.Relation{
		Common: &proto.RelationCommon{
			PlanId: newPlanId(),
		},
		RelType: &proto.Relation_Aggregate{
			Aggregate: agg,
		},
	}
	return NewDataFrame(gd.df.session, rel), nil
}
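
For context, a minimal caller-side sketch of how Agg is typically reached. This is an illustration under assumptions, not verbatim repo API: the session-builder chain, the functions helpers (functions.Col, functions.Sum), and the GroupBy signature are modeled on this repo's example style and may differ in detail.

// Hypothetical usage sketch; helper names and exact signatures are assumptions.
package main

import (
	"context"
	"log"

	"github.com/apache/spark-connect-go/v35/spark/sql"
	"github.com/apache/spark-connect-go/v35/spark/sql/functions"
)

func main() {
	ctx := context.Background()
	// Assumed builder API, modeled on the repo's examples.
	spark, err := sql.NewSessionBuilder().Remote("sc://localhost:15002").Build(ctx)
	if err != nil {
		log.Fatal(err)
	}
	df, err := spark.Sql(ctx, "SELECT * FROM range(100)")
	if err != nil {
		log.Fatal(err)
	}
	// GroupBy records the grouping columns and sets groupType to "groupby";
	// Agg then assembles the Aggregate relation shown above.
	res, err := df.GroupBy(functions.Col("id")).Agg(ctx, functions.Sum(functions.Col("id")))
	if err != nil {
		log.Fatal(err)
	}
	_ = res
}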