in spark/sql/group.go [111:152]
// numericAgg applies the aggregate function called name to numeric columns of
// the grouped DataFrame and returns the result of gd.Agg.
//
// When cols is empty, every column whose schema DataType reports IsNumeric is
// aggregated, in schema order. Otherwise each entry of cols must exactly match
// the name of a numeric column in the schema; entries that do not, including
// names absent from the schema entirely, are reported together in a single
// sparkerrors.InvalidInputError and no aggregation is performed.
//
// An error from retrieving the DataFrame schema is returned unchanged.
func (gd *GroupedData) numericAgg(ctx context.Context, name string, cols ...string) (DataFrame, error) {
	schema, err := gd.df.Schema(ctx)
	if err != nil {
		return nil, err
	}

	// Collect numeric columns both in schema order (used when cols is empty)
	// and as a set, so validating explicit cols is a constant-time lookup per
	// column instead of a scan of every numeric column.
	numericCols := make([]string, 0, len(schema.Fields))
	numericSet := make(map[string]struct{}, len(schema.Fields))
	for _, field := range schema.Fields {
		if field.DataType.IsNumeric() {
			numericCols = append(numericCols, field.Name)
			numericSet[field.Name] = struct{}{}
		}
	}

	// Default to all numeric columns; an explicit column list overrides that
	// but must be entirely numeric.
	aggCols := numericCols
	if len(cols) > 0 {
		var invalidCols []string
		for _, col := range cols {
			if _, ok := numericSet[col]; !ok {
				invalidCols = append(invalidCols, col)
			}
		}
		if len(invalidCols) > 0 {
			return nil, sparkerrors.WithStringf(sparkerrors.InvalidInputError,
				"columns %v are not numeric", invalidCols)
		}
		aggCols = cols
	}

	// NOTE(review): if cols is empty and the schema has no numeric columns,
	// Agg is called with zero columns — confirm Agg handles that case.
	finalColumns := make([]column.Convertible, len(aggCols))
	for i, col := range aggCols {
		finalColumns[i] = column.NewColumn(column.NewUnresolvedFunctionWithColumns(name, functions.Col(col)))
	}
	return gd.Agg(ctx, finalColumns...)
}