func()

in spark/sql/group.go [111:152]


// numericAgg applies the aggregate function called name to numeric columns of
// the grouped DataFrame and returns the result of Agg.
//
// The DataFrame schema is fetched first; a column counts as numeric when its
// DataType reports IsNumeric. If cols is empty, the aggregate is applied to
// every numeric column in schema order. Otherwise every entry of cols must
// exactly match the name of a numeric column. If any do not, an
// InvalidInputError listing the offending names (in the order given) is
// returned and no aggregation is attempted.
//
// Each selected column c becomes the unresolved function call name(c), and
// the resulting expressions are passed to Agg.
func (gd *GroupedData) numericAgg(ctx context.Context, name string, cols ...string) (DataFrame, error) {
	schema, err := gd.df.Schema(ctx)
	if err != nil {
		return nil, err
	}

	// Collect numeric column names in schema order (used when cols is empty)
	// and a set of the same names so validating cols is O(1) per entry rather
	// than a full scan of numericCols for every requested column.
	numericCols := make([]string, 0, len(schema.Fields))
	isNumeric := make(map[string]struct{}, len(schema.Fields))
	for _, field := range schema.Fields {
		if field.DataType.IsNumeric() {
			numericCols = append(numericCols, field.Name)
			isNumeric[field.Name] = struct{}{}
		}
	}

	// Default to all numeric columns; an explicit column list overrides this
	// only after every requested name has been validated.
	aggCols := numericCols
	if len(cols) > 0 {
		var invalidCols []string
		for _, col := range cols {
			if _, ok := isNumeric[col]; !ok {
				invalidCols = append(invalidCols, col)
			}
		}
		if len(invalidCols) > 0 {
			return nil, sparkerrors.WithStringf(sparkerrors.InvalidInputError,
				"columns %v are not numeric", invalidCols)
		}
		aggCols = cols
	}

	finalColumns := make([]column.Convertible, len(aggCols))
	for i, col := range aggCols {
		finalColumns[i] = column.NewColumn(column.NewUnresolvedFunctionWithColumns(name, functions.Col(col)))
	}
	return gd.Agg(ctx, finalColumns...)
}