spark/sql/group.go
//
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sql

import (
    "context"

    proto "github.com/apache/spark-connect-go/v35/internal/generated"
    "github.com/apache/spark-connect-go/v35/spark/sparkerrors"
    "github.com/apache/spark-connect-go/v35/spark/sql/column"
    "github.com/apache/spark-connect-go/v35/spark/sql/functions"
    "github.com/apache/spark-connect-go/v35/spark/sql/types"
)
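
// GroupedData represents a DataFrame that has been grouped by one or more
// columns, together with the kind of grouping that was applied (groupby,
// rollup, cube, or pivot). Its methods run aggregations on each group.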
type GroupedData struct {
    df           *dataFrameImpl
    groupType    string
    groupingCols []column.Convertible
    pivotValues  []types.LiteralType
    pivotCol     column.Convertible
}

// Agg computes aggregates and returns the result as a DataFrame. The aggregate
// expressions are passed as column.Convertible arguments.
func (gd *GroupedData) Agg(ctx context.Context, exprs ...column.Convertible) (DataFrame, error) {
    if len(exprs) == 0 {
        return nil, sparkerrors.WithString(sparkerrors.InvalidInputError, "exprs should not be empty")
    }
    agg := &proto.Aggregate{
        Input: gd.df.relation,
    }
    // Add all grouping and aggregate expressions.
    agg.GroupingExpressions = make([]*proto.Expression, len(gd.groupingCols))
    for i, col := range gd.groupingCols {
        exp, err := col.ToProto(ctx)
        if err != nil {
            return nil, err
        }
        agg.GroupingExpressions[i] = exp
    }
    agg.AggregateExpressions = make([]*proto.Expression, len(exprs))
    for i, expr := range exprs {
        exp, err := expr.ToProto(ctx)
        if err != nil {
            return nil, err
        }
        agg.AggregateExpressions[i] = exp
    }
    // Apply the group type and, for pivots, the pivot column and values.
    switch gd.groupType {
    case "pivot":
        agg.GroupType = proto.Aggregate_GROUP_TYPE_PIVOT
        if len(gd.pivotValues) == 0 {
            return nil, sparkerrors.WithString(sparkerrors.InvalidInputError, "pivotValues should not be empty")
        }
        protoCol, err := gd.pivotCol.ToProto(ctx)
        if err != nil {
            return nil, err
        }
        // Convert the pivot values into literal expressions.
        agg.Pivot = &proto.Aggregate_Pivot{
            Values: make([]*proto.Expression_Literal, len(gd.pivotValues)),
            Col:    protoCol,
        }
        for i, v := range gd.pivotValues {
            exp, err := column.NewLiteral(v).ToProto(ctx)
            if err != nil {
                return nil, err
            }
            agg.Pivot.Values[i] = exp.GetLiteral()
        }
    case "groupby":
        agg.GroupType = proto.Aggregate_GROUP_TYPE_GROUPBY
    case "rollup":
        agg.GroupType = proto.Aggregate_GROUP_TYPE_ROLLUP
    case "cube":
        agg.GroupType = proto.Aggregate_GROUP_TYPE_CUBE
    default:
        return nil, sparkerrors.WithStringf(sparkerrors.InvalidInputError, "unsupported group type %q", gd.groupType)
    }
    rel := &proto.Relation{
        Common: &proto.RelationCommon{
            PlanId: newPlanId(),
        },
        RelType: &proto.Relation_Aggregate{
            Aggregate: agg,
        },
    }
    return NewDataFrame(gd.df.session, rel), nil
}
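
// numericAgg applies the aggregate function identified by name (e.g. "min",
// "sum") to the given columns, or to all numeric columns of the schema when
// no columns are specified. Passing a non-numeric column is an error.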
func (gd *GroupedData) numericAgg(ctx context.Context, name string, cols ...string) (DataFrame, error) {
    schema, err := gd.df.Schema(ctx)
    if err != nil {
        return nil, err
    }
    // Find all numeric columns in the schema.
    numericCols := make([]string, 0)
    for _, field := range schema.Fields {
        if field.DataType.IsNumeric() {
            numericCols = append(numericCols, field.Name)
        }
    }
    aggCols := cols
    if len(cols) > 0 {
        // Reject any explicitly requested column that is not numeric.
        invalidCols := make([]string, 0)
        for _, col := range cols {
            found := false
            for _, nc := range numericCols {
                if col == nc {
                    found = true
                    break
                }
            }
            if !found {
                invalidCols = append(invalidCols, col)
            }
        }
        if len(invalidCols) > 0 {
            return nil, sparkerrors.WithStringf(sparkerrors.InvalidInputError,
                "columns %v are not numeric", invalidCols)
        }
    } else {
        // Default to aggregating over all numeric columns.
        aggCols = numericCols
    }
    // Wrap every column in the requested aggregate function.
    finalColumns := make([]column.Convertible, len(aggCols))
    for i, col := range aggCols {
        finalColumns[i] = column.NewColumn(column.NewUnresolvedFunctionWithColumns(name, functions.Col(col)))
    }
    return gd.Agg(ctx, finalColumns...)
}

// Min computes the minimum value for each numeric column for each group.
func (gd *GroupedData) Min(ctx context.Context, cols ...string) (DataFrame, error) {
    return gd.numericAgg(ctx, "min", cols...)
}

// Max computes the maximum value for each numeric column for each group.
func (gd *GroupedData) Max(ctx context.Context, cols ...string) (DataFrame, error) {
    return gd.numericAgg(ctx, "max", cols...)
}

// Avg computes the average value for each numeric column for each group.
func (gd *GroupedData) Avg(ctx context.Context, cols ...string) (DataFrame, error) {
    return gd.numericAgg(ctx, "avg", cols...)
}

// Sum computes the sum for each numeric column for each group.
func (gd *GroupedData) Sum(ctx context.Context, cols ...string) (DataFrame, error) {
    return gd.numericAgg(ctx, "sum", cols...)
}

// Count computes the number of records for each group.
func (gd *GroupedData) Count(ctx context.Context) (DataFrame, error) {
    return gd.Agg(ctx, functions.Count(functions.Lit(types.Int64(1))).Alias("count"))
}

// Mean computes the average value for each numeric column for each group.
// It is an alias for Avg.
func (gd *GroupedData) Mean(ctx context.Context, cols ...string) (DataFrame, error) {
    return gd.Avg(ctx, cols...)
}
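
// Pivot pivots a column of the current grouping and performs the follow-up
// aggregation once per value in pivotValues, producing one output column per
// pivot value. It can only be applied to a groupby; pivoting an already
// pivoted grouping is an error. A minimal usage sketch, assuming a DataFrame
// df with columns "course", "year", and "earnings", and a GroupBy method on
// DataFrame that returns a *GroupedData:
//
//    grouped := df.GroupBy(functions.Col("course"))
//    pivoted, err := grouped.Pivot(ctx, "year",
//        []types.LiteralType{types.Int64(2012), types.Int64(2013)})
//    if err != nil {
//        // handle error
//    }
//    result, err := pivoted.Sum(ctx, "earnings")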
func (gd *GroupedData) Pivot(ctx context.Context, pivotCol string, pivotValues []types.LiteralType) (*GroupedData, error) {
    if gd.groupType != "groupby" {
        if gd.groupType == "pivot" {
            return nil, sparkerrors.WithString(sparkerrors.InvalidInputError, "pivot cannot be applied on pivot")
        }
        return nil, sparkerrors.WithString(sparkerrors.InvalidInputError, "pivot can only be applied on groupby")
    }
    return &GroupedData{
        df:           gd.df,
        groupType:    "pivot",
        groupingCols: gd.groupingCols,
        pivotValues:  pivotValues,
        pivotCol:     column.NewColumnReferenceWithPlanId(pivotCol, gd.df.PlanId()),
    }, nil
}