arrow/compute/expression.go (681 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build go1.18

package compute

import (
	"bytes"
	"encoding/hex"
	"errors"
	"fmt"
	"hash/maphash"
	"reflect"
	"strconv"
	"strings"

	"github.com/apache/arrow-go/v18/arrow"
	"github.com/apache/arrow-go/v18/arrow/array"
	"github.com/apache/arrow-go/v18/arrow/compute/exec"
	"github.com/apache/arrow-go/v18/arrow/compute/internal/kernels"
	"github.com/apache/arrow-go/v18/arrow/internal/debug"
	"github.com/apache/arrow-go/v18/arrow/ipc"
	"github.com/apache/arrow-go/v18/arrow/memory"
	"github.com/apache/arrow-go/v18/arrow/scalar"
)

// hashSeed is the maphash seed handed to the hash computations of the
// expression types in this file. It is generated once per process, so
// expression hashes are not stable across runs.
var hashSeed = maphash.MakeSeed()

// Expression is an interface for mapping one datum to another. An expression
// is one of:
//
//   - A literal Datum
//   - A reference to a single (potentially nested) field of an input Datum
//   - A call to a compute function, with arguments specified by other Expressions
//
// Deprecated: use substrait-go expressions instead.
type Expression interface {
	fmt.Stringer
	// IsBound returns true if this expression has been bound to a particular
	// Datum and/or Schema.
	IsBound() bool
	// IsScalarExpr returns true if this expression is composed only of scalar
	// literals, field references and calls to scalar functions.
	IsScalarExpr() bool
	// IsNullLiteral returns true if this expression is a literal and entirely
	// null.
	IsNullLiteral() bool
	// IsSatisfiable returns true if this expression could evaluate to true.
	IsSatisfiable() bool
	// FieldRef returns a pointer to the underlying field reference, or nil if
	// this expression is not a field reference.
	FieldRef() *FieldRef
	// Type returns the datatype this expression will evaluate to.
	Type() arrow.DataType

	// Hash returns a 64-bit hash value for this expression.
	Hash() uint64
	// Equals reports whether this expression is equal to the given one.
	Equals(Expression) bool

	// Release releases the underlying bound C++ memory that is allocated when
	// a Bind is performed. Any bound expression should get released to ensure
	// no memory leaks.
	Release()
}

// printDatum renders a Datum for use in an expression's String output.
//
// For a scalar datum: an invalid (null) value prints as "null"; string and
// large-string values are quoted with strconv.Quote; binary, fixed-size
// binary and large-binary values print as upper-case hex wrapped in double
// quotes; every other scalar uses the scalar's own String method. Any
// non-scalar datum uses the datum's own String method.
func printDatum(datum Datum) string {
	switch datum := datum.(type) {
	case *ScalarDatum:
		if !datum.Value.IsValid() {
			return "null"
		}

		switch datum.Type().ID() {
		case arrow.STRING, arrow.LARGE_STRING:
			return strconv.Quote(datum.Value.(scalar.BinaryScalar).String())
		case arrow.BINARY, arrow.FIXED_SIZE_BINARY, arrow.LARGE_BINARY:
			return `"` + strings.ToUpper(hex.EncodeToString(datum.Value.(scalar.BinaryScalar).Data())) + `"`
		}

		return datum.Value.String()
	default:
		return datum.String()
	}
}

// Literal is an expression denoting a literal Datum which could be any value
// as a scalar, an array, or so on.
//
// Deprecated: use substrait-go expressions Literal instead.
type Literal struct { Literal Datum } func (Literal) FieldRef() *FieldRef { return nil } func (l *Literal) String() string { return printDatum(l.Literal) } func (l *Literal) Type() arrow.DataType { return l.Literal.(ArrayLikeDatum).Type() } func (l *Literal) IsBound() bool { return l.Type() != nil } func (l *Literal) IsScalarExpr() bool { return l.Literal.Kind() == KindScalar } func (l *Literal) Equals(other Expression) bool { if rhs, ok := other.(*Literal); ok { return l.Literal.Equals(rhs.Literal) } return false } func (l *Literal) IsNullLiteral() bool { if ad, ok := l.Literal.(ArrayLikeDatum); ok { return ad.NullN() == ad.Len() } return true } func (l *Literal) IsSatisfiable() bool { if l.IsNullLiteral() { return false } if sc, ok := l.Literal.(*ScalarDatum); ok && sc.Type().ID() == arrow.BOOL { return sc.Value.(*scalar.Boolean).Value } return true } func (l *Literal) Hash() uint64 { if l.IsScalarExpr() { return scalar.Hash(hashSeed, l.Literal.(*ScalarDatum).Value) } return 0 } func (l *Literal) Release() { l.Literal.Release() } // Parameter represents a field reference and needs to be bound in order to determine // its type and shape. // // Deprecated: use substrait-go field references instead. 
type Parameter struct { ref *FieldRef // post bind props dt arrow.DataType index int } func (Parameter) IsNullLiteral() bool { return false } func (p *Parameter) Type() arrow.DataType { return p.dt } func (p *Parameter) IsBound() bool { return p.Type() != nil } func (p *Parameter) IsScalarExpr() bool { return p.ref != nil } func (p *Parameter) IsSatisfiable() bool { return p.Type() == nil || p.Type().ID() != arrow.NULL } func (p *Parameter) FieldRef() *FieldRef { return p.ref } func (p *Parameter) Hash() uint64 { return p.ref.Hash(hashSeed) } func (p *Parameter) String() string { switch { case p.ref.IsName(): return p.ref.Name() case p.ref.IsFieldPath(): return p.ref.FieldPath().String() default: return p.ref.String() } } func (p *Parameter) Equals(other Expression) bool { if rhs, ok := other.(*Parameter); ok { return p.ref.Equals(*rhs.ref) } return false } func (p *Parameter) Release() {} type comparisonType int8 const ( compNA comparisonType = 0 compEQ comparisonType = 1 compLT comparisonType = 2 compGT comparisonType = 4 compNE comparisonType = compLT | compGT compLE comparisonType = compLT | compEQ compGE comparisonType = compGT | compEQ ) //lint:ignore U1000 ignore that this is unused for now func (c comparisonType) name() string { switch c { case compEQ: return "equal" case compLT: return "less" case compGT: return "greater" case compNE: return "not_equal" case compLE: return "less_equal" case compGE: return "greater_equal" } return "na" } func (c comparisonType) getOp() string { switch c { case compEQ: return "==" case compLT: return "<" case compGT: return ">" case compNE: return "!=" case compLE: return "<=" case compGE: return ">=" } debug.Assert(false, "invalid getop") return "" } var compmap = map[string]comparisonType{ "equal": compEQ, "less": compLT, "greater": compGT, "not_equal": compNE, "less_equal": compLE, "greater_equal": compGE, } func optionsToString(fn FunctionOptions) string { if s, ok := fn.(fmt.Stringer); ok { return s.String() } var b 
strings.Builder v := reflect.Indirect(reflect.ValueOf(fn)) b.WriteByte('{') for i := 0; i < v.Type().NumField(); i++ { fld := v.Type().Field(i) tag := fld.Tag.Get("compute") if tag == "-" { continue } fldVal := v.Field(i) fmt.Fprintf(&b, "%s=%v, ", tag, fldVal.Interface()) } ret := b.String() return ret[:len(ret)-2] + "}" } // Call is a function call with specific arguments which are themselves other // expressions. A call can also have options that are specific to the function // in question. It must be bound to determine the shape and type. // // Deprecated: use substrait-go expression functions instead. type Call struct { funcName string args []Expression dt arrow.DataType options FunctionOptions cachedHash uint64 } func (c *Call) IsNullLiteral() bool { return false } func (c *Call) FieldRef() *FieldRef { return nil } func (c *Call) Type() arrow.DataType { return c.dt } func (c *Call) IsSatisfiable() bool { return c.Type() == nil || c.Type().ID() != arrow.NULL } func (c *Call) String() string { binary := func(op string) string { return "(" + c.args[0].String() + " " + op + " " + c.args[1].String() + ")" } if cmp, ok := compmap[c.funcName]; ok { return binary(cmp.getOp()) } const kleene = "_kleene" if strings.HasSuffix(c.funcName, kleene) { return binary(strings.TrimSuffix(c.funcName, kleene)) } if c.funcName == "make_struct" && c.options != nil { opts := c.options.(*MakeStructOptions) out := "{" for i, a := range c.args { out += opts.FieldNames[i] + "=" + a.String() + ", " } return out[:len(out)-2] + "}" } var b strings.Builder b.WriteString(c.funcName + "(") for _, a := range c.args { b.WriteString(a.String() + ", ") } if c.options != nil { b.WriteString(optionsToString(c.options)) b.WriteString(" ") } ret := b.String() return ret[:len(ret)-2] + ")" } func (c *Call) Hash() uint64 { if c.cachedHash != 0 { return c.cachedHash } var h maphash.Hash h.SetSeed(hashSeed) h.WriteString(c.funcName) c.cachedHash = h.Sum64() for _, arg := range c.args { c.cachedHash = 
exec.HashCombine(c.cachedHash, arg.Hash()) } return c.cachedHash } func (c *Call) IsScalarExpr() bool { for _, arg := range c.args { if !arg.IsScalarExpr() { return false } } return false // return isFuncScalar(c.funcName) } func (c *Call) IsBound() bool { return c.Type() != nil } func (c *Call) Equals(other Expression) bool { rhs, ok := other.(*Call) if !ok { return false } if c.funcName != rhs.funcName || len(c.args) != len(rhs.args) { return false } for i := range c.args { if !c.args[i].Equals(rhs.args[i]) { return false } } if opt, ok := c.options.(FunctionOptionsEqual); ok { return opt.Equals(rhs.options) } return reflect.DeepEqual(c.options, rhs.options) } func (c *Call) Release() { for _, a := range c.args { a.Release() } if r, ok := c.options.(releasable); ok { r.Release() } } // FunctionOptions can be any type which has a TypeName function. The fields // of the type will be used (via reflection) to determine the information to // propagate when serializing to pass to the C++ for execution. 
type FunctionOptions interface { TypeName() string } type FunctionOptionsEqual interface { Equals(FunctionOptions) bool } type FunctionOptionsCloneable interface { Clone() FunctionOptions } type MakeStructOptions struct { FieldNames []string `compute:"field_names"` FieldNullability []bool `compute:"field_nullability"` FieldMetadata []*arrow.Metadata `compute:"field_metadata"` } func (MakeStructOptions) TypeName() string { return "MakeStructOptions" } type NullOptions struct { NanIsNull bool `compute:"nan_is_null"` } func (NullOptions) TypeName() string { return "NullOptions" } type StrptimeOptions struct { Format string `compute:"format"` Unit arrow.TimeUnit `compute:"unit"` } func (StrptimeOptions) TypeName() string { return "StrptimeOptions" } type NullSelectionBehavior = kernels.NullSelectionBehavior const ( SelectionEmitNulls = kernels.EmitNulls SelectionDropNulls = kernels.DropNulls ) type ArithmeticOptions struct { NoCheckOverflow bool `compute:"check_overflow"` } func (ArithmeticOptions) TypeName() string { return "ArithmeticOptions" } type ( CastOptions = kernels.CastOptions FilterOptions = kernels.FilterOptions TakeOptions = kernels.TakeOptions ) func DefaultFilterOptions() *FilterOptions { return &FilterOptions{} } func DefaultTakeOptions() *TakeOptions { return &TakeOptions{BoundsCheck: true} } func DefaultCastOptions(safe bool) *CastOptions { if safe { return &CastOptions{} } return &CastOptions{ AllowIntOverflow: true, AllowTimeTruncate: true, AllowTimeOverflow: true, AllowDecimalTruncate: true, AllowFloatTruncate: true, AllowInvalidUtf8: true, } } func UnsafeCastOptions(dt arrow.DataType) *CastOptions { return NewCastOptions(dt, false) } func SafeCastOptions(dt arrow.DataType) *CastOptions { return NewCastOptions(dt, true) } func NewCastOptions(dt arrow.DataType, safe bool) *CastOptions { opts := DefaultCastOptions(safe) if dt != nil { opts.ToType = dt } else { opts.ToType = arrow.Null } return opts } func Cast(ex Expression, dt arrow.DataType) 
Expression { opts := &CastOptions{} if dt == nil { opts.ToType = arrow.Null } else { opts.ToType = dt } return NewCall("cast", []Expression{ex}, opts) } // Deprecated: Use SetOptions instead type SetLookupOptions struct { ValueSet Datum `compute:"value_set"` SkipNulls bool `compute:"skip_nulls"` } func (SetLookupOptions) TypeName() string { return "SetLookupOptions" } func (s *SetLookupOptions) Release() { s.ValueSet.Release() } func (s *SetLookupOptions) Equals(other FunctionOptions) bool { rhs, ok := other.(*SetLookupOptions) if !ok { return false } return s.SkipNulls == rhs.SkipNulls && s.ValueSet.Equals(rhs.ValueSet) } func (s *SetLookupOptions) FromStructScalar(sc *scalar.Struct) error { if v, err := sc.Field("skip_nulls"); err == nil { s.SkipNulls = v.(*scalar.Boolean).Value } value, err := sc.Field("value_set") if err != nil { return err } if v, ok := value.(scalar.ListScalar); ok { s.ValueSet = NewDatum(v.GetList()) return nil } return errors.New("set lookup options valueset should be a list") } var ( funcOptionsMap map[string]reflect.Type funcOptsTypes = []FunctionOptions{ SetLookupOptions{}, ArithmeticOptions{}, CastOptions{}, FilterOptions{}, NullOptions{}, StrptimeOptions{}, MakeStructOptions{}, } ) func init() { funcOptionsMap = make(map[string]reflect.Type) for _, ft := range funcOptsTypes { funcOptionsMap[ft.TypeName()] = reflect.TypeOf(ft) } } // NewLiteral constructs a new literal expression from any value. It is passed // to NewDatum which will construct the appropriate Datum and/or scalar // value for the type provided. 
func NewLiteral(arg interface{}) Expression { return &Literal{Literal: NewDatum(arg)} } func NullLiteral(dt arrow.DataType) Expression { return &Literal{Literal: NewDatum(scalar.MakeNullScalar(dt))} } // NewRef constructs a parameter expression which refers to a specific field func NewRef(ref FieldRef) Expression { return &Parameter{ref: &ref, index: -1} } // NewFieldRef is shorthand for NewRef(FieldRefName(field)) func NewFieldRef(field string) Expression { return NewRef(FieldRefName(field)) } // NewCall constructs an expression that represents a specific function call with // the given arguments and options. func NewCall(name string, args []Expression, opts FunctionOptions) Expression { return &Call{funcName: name, args: args, options: opts} } // Project is shorthand for `make_struct` to produce a record batch output // from a group of expressions. func Project(values []Expression, names []string) Expression { nulls := make([]bool, len(names)) for i := range nulls { nulls[i] = true } meta := make([]*arrow.Metadata, len(names)) return NewCall("make_struct", values, &MakeStructOptions{FieldNames: names, FieldNullability: nulls, FieldMetadata: meta}) } // Equal is a convenience function for the equal function func Equal(lhs, rhs Expression) Expression { return NewCall("equal", []Expression{lhs, rhs}, nil) } // NotEqual creates a call to not_equal func NotEqual(lhs, rhs Expression) Expression { return NewCall("not_equal", []Expression{lhs, rhs}, nil) } // Less is shorthand for NewCall("less",....) func Less(lhs, rhs Expression) Expression { return NewCall("less", []Expression{lhs, rhs}, nil) } // LessEqual is shorthand for NewCall("less_equal",....) func LessEqual(lhs, rhs Expression) Expression { return NewCall("less_equal", []Expression{lhs, rhs}, nil) } // Greater is shorthand for NewCall("greater",....) 
func Greater(lhs, rhs Expression) Expression {
	return NewCall("greater", []Expression{lhs, rhs}, nil)
}

// GreaterEqual is shorthand for NewCall("greater_equal",....)
func GreaterEqual(lhs, rhs Expression) Expression {
	return NewCall("greater_equal", []Expression{lhs, rhs}, nil)
}

// IsNull creates an expression that returns true if the passed in expression is
// null. Optionally treating NaN as null if desired.
func IsNull(lhs Expression, nanIsNull bool) Expression {
	// The function must be "is_null": it is the counterpart of the
	// "is_valid" call built by IsValid and the one that takes NullOptions.
	return NewCall("is_null", []Expression{lhs}, &NullOptions{NanIsNull: nanIsNull})
}

// IsValid is the inverse of IsNull
func IsValid(lhs Expression) Expression {
	return NewCall("is_valid", []Expression{lhs}, nil)
}

// binop combines two expressions into one.
type binop func(lhs, rhs Expression) Expression

// foldLeft combines args left to right with op, i.e. op(op(a, b), c).
// It returns nil for no arguments and the argument itself for exactly one.
func foldLeft(op binop, args ...Expression) Expression {
	if len(args) == 0 {
		return nil
	}

	folded := args[0]
	for _, a := range args[1:] {
		folded = op(folded, a)
	}
	return folded
}

// and builds a single "and_kleene" call.
func and(lhs, rhs Expression) Expression {
	return NewCall("and_kleene", []Expression{lhs, rhs}, nil)
}

// And constructs a tree of calls to and_kleene for boolean And logic taking
// an arbitrary number of values.
func And(lhs, rhs Expression, ops ...Expression) Expression {
	// There are always at least two operands, so the fold always yields a
	// call and no empty-input fallback is needed.
	operands := make([]Expression, 0, len(ops)+2)
	operands = append(operands, lhs, rhs)
	operands = append(operands, ops...)
	return foldLeft(and, operands...)
}

// or builds a single "or_kleene" call.
func or(lhs, rhs Expression) Expression {
	return NewCall("or_kleene", []Expression{lhs, rhs}, nil)
}

// Or constructs a tree of calls to or_kleene for boolean Or logic taking
// an arbitrary number of values.
func Or(lhs, rhs Expression, ops ...Expression) Expression {
	// There are always at least two operands, so the fold always yields a
	// call and no empty-input fallback is needed.
	operands := make([]Expression, 0, len(ops)+2)
	operands = append(operands, lhs, rhs)
	operands = append(operands, ops...)
	return foldLeft(or, operands...)
}

// Not creates a call to "invert" for the value specified.
func Not(expr Expression) Expression { return NewCall("invert", []Expression{expr}, nil) } func SerializeOptions(opts FunctionOptions, mem memory.Allocator) (*memory.Buffer, error) { sc, err := scalar.ToScalar(opts, mem) if err != nil { return nil, err } if sc, ok := sc.(releasable); ok { defer sc.Release() } arr, err := scalar.MakeArrayFromScalar(sc, 1, mem) if err != nil { return nil, err } defer arr.Release() batch := array.NewRecord(arrow.NewSchema([]arrow.Field{{Type: arr.DataType(), Nullable: true}}, nil), []arrow.Array{arr}, 1) defer batch.Release() buf := &bufferWriteSeeker{mem: mem} wr, err := ipc.NewFileWriter(buf, ipc.WithSchema(batch.Schema()), ipc.WithAllocator(mem)) if err != nil { return nil, err } wr.Write(batch) wr.Close() return buf.buf, nil } // SerializeExpr serializes expressions by converting them to Metadata and // storing this in the schema of a Record. Embedded arrays and scalars are // stored in its columns. Finally the record is written as an IPC file func SerializeExpr(expr Expression, mem memory.Allocator) (*memory.Buffer, error) { var ( cols []arrow.Array metaKey []string metaValue []string visit func(Expression) error ) addScalar := func(s scalar.Scalar) (string, error) { ret := len(cols) arr, err := scalar.MakeArrayFromScalar(s, 1, mem) if err != nil { return "", err } cols = append(cols, arr) return strconv.Itoa(ret), nil } visit = func(e Expression) error { switch e := e.(type) { case *Literal: if !e.IsScalarExpr() { return errors.New("not implemented: serialization of non-scalar literals") } metaKey = append(metaKey, "literal") s, err := addScalar(e.Literal.(*ScalarDatum).Value) if err != nil { return err } metaValue = append(metaValue, s) case *Parameter: if e.ref.Name() == "" { return errors.New("not implemented: serialization of non-name field_ref") } metaKey = append(metaKey, "field_ref") metaValue = append(metaValue, e.ref.Name()) case *Call: metaKey = append(metaKey, "call") metaValue = append(metaValue, e.funcName) for _, 
arg := range e.args { visit(arg) } if e.options != nil { st, err := scalar.ToScalar(e.options, mem) if err != nil { return err } metaKey = append(metaKey, "options") s, err := addScalar(st) if err != nil { return err } metaValue = append(metaValue, s) for _, f := range st.(*scalar.Struct).Value { switch s := f.(type) { case releasable: defer s.Release() } } } metaKey = append(metaKey, "end") metaValue = append(metaValue, e.funcName) } return nil } if err := visit(expr); err != nil { return nil, err } fields := make([]arrow.Field, len(cols)) for i, c := range cols { fields[i].Type = c.DataType() defer c.Release() } metadata := arrow.NewMetadata(metaKey, metaValue) rec := array.NewRecord(arrow.NewSchema(fields, &metadata), cols, 1) defer rec.Release() buf := &bufferWriteSeeker{mem: mem} wr, err := ipc.NewFileWriter(buf, ipc.WithSchema(rec.Schema()), ipc.WithAllocator(mem)) if err != nil { return nil, err } wr.Write(rec) wr.Close() return buf.buf, nil } func DeserializeExpr(mem memory.Allocator, buf *memory.Buffer) (Expression, error) { rdr, err := ipc.NewFileReader(bytes.NewReader(buf.Bytes()), ipc.WithAllocator(mem)) if err != nil { return nil, err } defer rdr.Close() batch, err := rdr.Read() if err != nil { return nil, err } if !batch.Schema().HasMetadata() { return nil, errors.New("serialized Expression's batch repr had no metadata") } if batch.NumRows() != 1 { return nil, fmt.Errorf("serialized Expression's batch repr was not a single row - had %d", batch.NumRows()) } var ( getone func() (Expression, error) index int = 0 metadata = batch.Schema().Metadata() ) getscalar := func(i string) (scalar.Scalar, error) { colIndex, err := strconv.ParseInt(i, 10, 32) if err != nil { return nil, err } if colIndex >= batch.NumCols() { return nil, errors.New("column index out of bounds") } return scalar.GetScalar(batch.Column(int(colIndex)), 0) } getone = func() (Expression, error) { if index >= metadata.Len() { return nil, errors.New("unterminated serialized Expression") } 
key, val := metadata.Keys()[index], metadata.Values()[index] index++ switch key { case "literal": scalar, err := getscalar(val) if err != nil { return nil, err } if r, ok := scalar.(releasable); ok { defer r.Release() } return NewLiteral(scalar), err case "field_ref": return NewFieldRef(val), nil case "call": args := make([]Expression, 0) for metadata.Keys()[index] != "end" { if metadata.Keys()[index] == "options" { optsScalar, err := getscalar(metadata.Values()[index]) if err != nil { return nil, err } if r, ok := optsScalar.(releasable); ok { defer r.Release() } var opts FunctionOptions if optsScalar != nil { typname, err := optsScalar.(*scalar.Struct).Field("_type_name") if err != nil { return nil, err } if typname.DataType().ID() != arrow.BINARY { return nil, errors.New("options scalar typename must be binary") } optionsVal := reflect.New(funcOptionsMap[string(typname.(*scalar.Binary).Data())]).Interface() if err := scalar.FromScalar(optsScalar.(*scalar.Struct), optionsVal); err != nil { return nil, err } opts = optionsVal.(FunctionOptions) } index += 2 return NewCall(val, args, opts), nil } arg, err := getone() if err != nil { return nil, err } args = append(args, arg) } index++ return NewCall(val, args, nil), nil default: return nil, fmt.Errorf("unrecognized serialized Expression key %s", key) } } return getone() }