pkg/query/logical/measure/measure_plan_aggregation.go (235 lines of code) (raw):
// Licensed to Apache Software Foundation (ASF) under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Apache Software Foundation (ASF) licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package measure
import (
"context"
"fmt"
"github.com/pkg/errors"
"go.uber.org/multierr"
databasev1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1"
measurev1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/measure/v1"
modelv1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/model/v1"
"github.com/apache/skywalking-banyandb/pkg/query/aggregation"
"github.com/apache/skywalking-banyandb/pkg/query/executor"
"github.com/apache/skywalking-banyandb/pkg/query/logical"
)
var (
_ logical.UnresolvedPlan = (*unresolvedAggregation)(nil)
errUnsupportedAggregationField = errors.New("unsupported aggregation operation on this field")
)
type unresolvedAggregation struct {
unresolvedInput logical.UnresolvedPlan
aggregationField *logical.Field
aggrFunc modelv1.AggregationFunction
isGroup bool
}
func newUnresolvedAggregation(input logical.UnresolvedPlan, aggrField *logical.Field, aggrFunc modelv1.AggregationFunction, isGroup bool) logical.UnresolvedPlan {
return &unresolvedAggregation{
unresolvedInput: input,
aggrFunc: aggrFunc,
aggregationField: aggrField,
isGroup: isGroup,
}
}
func (gba *unresolvedAggregation) Analyze(measureSchema logical.Schema) (logical.Plan, error) {
prevPlan, err := gba.unresolvedInput.Analyze(measureSchema)
if err != nil {
return nil, err
}
// check validity of aggregation fields
schema := prevPlan.Schema()
aggregationFieldRefs, err := schema.CreateFieldRef(gba.aggregationField)
if err != nil {
return nil, err
}
if len(aggregationFieldRefs) == 0 {
return nil, errors.Wrap(errFieldNotDefined, "aggregation schema")
}
fieldRef := aggregationFieldRefs[0]
switch fieldRef.Spec.Spec.FieldType {
case databasev1.FieldType_FIELD_TYPE_INT:
return newAggregationPlan[int64](gba, prevPlan, schema, fieldRef)
case databasev1.FieldType_FIELD_TYPE_FLOAT:
return newAggregationPlan[float64](gba, prevPlan, schema, fieldRef)
default:
return nil, errors.WithMessagef(errUnsupportedAggregationField, "field: %s", fieldRef.Spec.Spec)
}
}
type aggregationPlan[N aggregation.Number] struct {
*logical.Parent
schema logical.Schema
aggregationFieldRef *logical.FieldRef
aggrFunc aggregation.Func[N]
aggrType modelv1.AggregationFunction
isGroup bool
}
func newAggregationPlan[N aggregation.Number](gba *unresolvedAggregation, prevPlan logical.Plan,
measureSchema logical.Schema, fieldRef *logical.FieldRef,
) (*aggregationPlan[N], error) {
aggrFunc, err := aggregation.NewFunc[N](gba.aggrFunc)
if err != nil {
return nil, err
}
return &aggregationPlan[N]{
Parent: &logical.Parent{
UnresolvedInput: gba.unresolvedInput,
Input: prevPlan,
},
schema: measureSchema,
aggrFunc: aggrFunc,
aggregationFieldRef: fieldRef,
isGroup: gba.isGroup,
}, nil
}
func (g *aggregationPlan[N]) String() string {
return fmt.Sprintf("%s aggregation: aggregation{type=%d,field=%s}",
g.Input,
g.aggrType,
g.aggregationFieldRef.Field.Name)
}
func (g *aggregationPlan[N]) Children() []logical.Plan {
return []logical.Plan{g.Input}
}
func (g *aggregationPlan[N]) Schema() logical.Schema {
return g.schema.ProjFields(g.aggregationFieldRef)
}
func (g *aggregationPlan[N]) Execute(ec context.Context) (executor.MIterator, error) {
iter, err := g.Parent.Input.(executor.MeasureExecutable).Execute(ec)
if err != nil {
return nil, err
}
if g.isGroup {
return newAggGroupMIterator(iter, g.aggregationFieldRef, g.aggrFunc), nil
}
return newAggAllIterator(iter, g.aggregationFieldRef, g.aggrFunc), nil
}
type aggGroupIterator[N aggregation.Number] struct {
prev executor.MIterator
aggregationFieldRef *logical.FieldRef
aggrFunc aggregation.Func[N]
err error
}
func newAggGroupMIterator[N aggregation.Number](
prev executor.MIterator,
aggregationFieldRef *logical.FieldRef,
aggrFunc aggregation.Func[N],
) executor.MIterator {
return &aggGroupIterator[N]{
prev: prev,
aggregationFieldRef: aggregationFieldRef,
aggrFunc: aggrFunc,
}
}
func (ami *aggGroupIterator[N]) Next() bool {
if ami.err != nil {
return false
}
return ami.prev.Next()
}
func (ami *aggGroupIterator[N]) Current() []*measurev1.DataPoint {
if ami.err != nil {
return nil
}
ami.aggrFunc.Reset()
group := ami.prev.Current()
var resultDp *measurev1.DataPoint
for _, dp := range group {
value := dp.GetFields()[ami.aggregationFieldRef.Spec.FieldIdx].
GetValue()
v, err := aggregation.FromFieldValue[N](value)
if err != nil {
ami.err = err
return nil
}
ami.aggrFunc.In(v)
if resultDp != nil {
continue
}
resultDp = &measurev1.DataPoint{
TagFamilies: dp.TagFamilies,
}
}
if resultDp == nil {
return nil
}
val, err := aggregation.ToFieldValue(ami.aggrFunc.Val())
if err != nil {
ami.err = err
return nil
}
resultDp.Fields = []*measurev1.DataPoint_Field{
{
Name: ami.aggregationFieldRef.Field.Name,
Value: val,
},
}
return []*measurev1.DataPoint{resultDp}
}
func (ami *aggGroupIterator[N]) Close() error {
return multierr.Combine(ami.err, ami.prev.Close())
}
type aggAllIterator[N aggregation.Number] struct {
prev executor.MIterator
aggregationFieldRef *logical.FieldRef
aggrFunc aggregation.Func[N]
result *measurev1.DataPoint
err error
}
func newAggAllIterator[N aggregation.Number](
prev executor.MIterator,
aggregationFieldRef *logical.FieldRef,
aggrFunc aggregation.Func[N],
) executor.MIterator {
return &aggAllIterator[N]{
prev: prev,
aggregationFieldRef: aggregationFieldRef,
aggrFunc: aggrFunc,
}
}
func (ami *aggAllIterator[N]) Next() bool {
if ami.result != nil || ami.err != nil {
return false
}
var resultDp *measurev1.DataPoint
for ami.prev.Next() {
group := ami.prev.Current()
for _, dp := range group {
value := dp.GetFields()[ami.aggregationFieldRef.Spec.FieldIdx].
GetValue()
v, err := aggregation.FromFieldValue[N](value)
if err != nil {
ami.err = err
return false
}
ami.aggrFunc.In(v)
if resultDp != nil {
continue
}
resultDp = &measurev1.DataPoint{
TagFamilies: dp.TagFamilies,
}
}
}
if resultDp == nil {
return false
}
val, err := aggregation.ToFieldValue(ami.aggrFunc.Val())
if err != nil {
ami.err = err
return false
}
resultDp.Fields = []*measurev1.DataPoint_Field{
{
Name: ami.aggregationFieldRef.Field.Name,
Value: val,
},
}
ami.result = resultDp
return true
}
func (ami *aggAllIterator[N]) Current() []*measurev1.DataPoint {
if ami.result == nil {
return nil
}
return []*measurev1.DataPoint{ami.result}
}
func (ami *aggAllIterator[N]) Close() error {
return ami.prev.Close()
}