arrow/decimal/decimal.go (359 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package decimal

import (
	"errors"
	"fmt"
	"math"
	"math/big"
	"math/bits"
	"unsafe"

	"github.com/apache/arrow-go/v18/arrow/decimal128"
	"github.com/apache/arrow-go/v18/arrow/decimal256"
	"github.com/apache/arrow-go/v18/arrow/internal/debug"
)

// DecimalTypes is a generic constraint listing every decimal type implemented
// in this package (32-, 64-, 128- and 256-bit). It is the single point to
// update when a new width is added; every other generic helper here is
// constrained by it.
type DecimalTypes interface {
	Decimal32 | Decimal64 | Decimal128 | Decimal256
}

// Num is an interface that is useful for building generic types over all of the
// decimal type implementations. It exposes the methods needed to operate on
// decimal values without having to care about their bit width. Methods that
// take an int32 scale or precision expect the caller to supply it, because the
// value itself does not carry one.
type Num[T DecimalTypes] interface { Negate() T Add(T) T Sub(T) T Mul(T) T Div(T) (res, rem T) Pow(T) T FitsInPrecision(int32) bool Abs() T Sign() int Rescale(int32, int32) (T, error) Cmp(T) int IncreaseScaleBy(int32) T ReduceScaleBy(int32, bool) T ToFloat32(int32) float32 ToFloat64(int32) float64 ToBigFloat(int32) *big.Float ToString(int32) string } type ( Decimal32 int32 Decimal64 int64 Decimal128 = decimal128.Num Decimal256 = decimal256.Num ) func MaxPrecision[T DecimalTypes]() int { // max precision is computed by Floor(log10(2^(nbytes * 8 - 1) - 1)) var z T return int(math.Floor(math.Log10(math.Pow(2, float64(unsafe.Sizeof(z))*8-1) - 1))) } func (d Decimal32) Negate() Decimal32 { return -d } func (d Decimal64) Negate() Decimal64 { return -d } func (d Decimal32) Add(rhs Decimal32) Decimal32 { return d + rhs } func (d Decimal64) Add(rhs Decimal64) Decimal64 { return d + rhs } func (d Decimal32) Sub(rhs Decimal32) Decimal32 { return d - rhs } func (d Decimal64) Sub(rhs Decimal64) Decimal64 { return d - rhs } func (d Decimal32) Mul(rhs Decimal32) Decimal32 { return d * rhs } func (d Decimal64) Mul(rhs Decimal64) Decimal64 { return d * rhs } func (d Decimal32) Div(rhs Decimal32) (res, rem Decimal32) { return d / rhs, d % rhs } func (d Decimal64) Div(rhs Decimal64) (res, rem Decimal64) { return d / rhs, d % rhs } // about 4-5x faster than using math.Pow which requires converting to float64 // and back to integers func intPow[T int32 | int64](base, exp T) T { result := T(1) for { if exp&1 == 1 { result *= base } exp >>= 1 if exp == 0 { break } base *= base } return result } func (d Decimal32) Pow(rhs Decimal32) Decimal32 { return Decimal32(intPow(int32(d), int32(rhs))) } func (d Decimal64) Pow(rhs Decimal64) Decimal64 { return Decimal64(intPow(int64(d), int64(rhs))) } func (d Decimal32) Sign() int { if d == 0 { return 0 } return int(1 | (d >> 31)) } func (d Decimal64) Sign() int { if d == 0 { return 0 } return int(1 | (d >> 63)) } func (n Decimal32) Abs() Decimal32 { 
if n < 0 { return -n } return n } func (n Decimal64) Abs() Decimal64 { if n < 0 { return -n } return n } func (n Decimal32) FitsInPrecision(prec int32) bool { debug.Assert(prec > 0, "precision must be > 0") debug.Assert(prec <= 9, "precision must be <= 9") return n.Abs() < Decimal32(math.Pow10(int(prec))) } func (n Decimal64) FitsInPrecision(prec int32) bool { debug.Assert(prec > 0, "precision must be > 0") debug.Assert(prec <= 18, "precision must be <= 18") return n.Abs() < Decimal64(math.Pow10(int(prec))) } func (n Decimal32) ToString(scale int32) string { return n.ToBigFloat(scale).Text('f', int(scale)) } func (n Decimal64) ToString(scale int32) string { return n.ToBigFloat(scale).Text('f', int(scale)) } var pt5 = big.NewFloat(0.5) func decimalFromString[T interface { Decimal32 | Decimal64 FitsInPrecision(int32) bool }](v string, prec, scale int32) (n T, err error) { var nbits = uint(unsafe.Sizeof(T(0))) * 8 var out *big.Float out, _, err = big.ParseFloat(v, 10, nbits, big.ToNearestEven) if scale < 0 { var tmp big.Int val, _ := out.Int(&tmp) if val.BitLen() > int(nbits) { return n, fmt.Errorf("bitlen too large for decimal%d", nbits) } n = T(val.Int64() / int64(math.Pow10(int(-scale)))) } else { var precInBits = uint(math.Round(float64(prec+scale+1)/math.Log10(2))) + 1 p := (&big.Float{}).SetInt(big.NewInt(int64(math.Pow10(int(scale))))) out.SetPrec(precInBits).Mul(out, p) if out.Signbit() { out.Sub(out, pt5) } else { out.Add(out, pt5) } var tmp big.Int val, _ := out.Int(&tmp) if val.BitLen() > int(nbits) { return n, fmt.Errorf("bitlen too large for decimal%d", nbits) } n = T(val.Int64()) } if !n.FitsInPrecision(prec) { err = fmt.Errorf("val %v doesn't fit in precision %d", n, prec) } return } func Decimal32FromString(v string, prec, scale int32) (n Decimal32, err error) { return decimalFromString[Decimal32](v, prec, scale) } func Decimal64FromString(v string, prec, scale int32) (n Decimal64, err error) { return decimalFromString[Decimal64](v, prec, scale) } func 
Decimal128FromString(v string, prec, scale int32) (n Decimal128, err error) { return decimal128.FromString(v, prec, scale) } func Decimal256FromString(v string, prec, scale int32) (n Decimal256, err error) { return decimal256.FromString(v, prec, scale) } func scalePositiveFloat64(v float64, prec, scale int32) (float64, error) { v *= math.Pow10(int(scale)) v = math.RoundToEven(v) maxabs := math.Pow10(int(prec)) if v >= maxabs { return 0, fmt.Errorf("cannot convert %f to decimal(precision=%d, scale=%d)", v, prec, scale) } return v, nil } func fromPositiveFloat[T Decimal32 | Decimal64, F float32 | float64](v F, prec, scale int32) (T, error) { if prec > int32(MaxPrecision[T]()) { return T(0), fmt.Errorf("invalid precision %d for converting float to Decimal", prec) } val, err := scalePositiveFloat64(float64(v), prec, scale) if err != nil { return T(0), err } return T(F(val)), nil } func Decimal32FromFloat[F float32 | float64](v F, prec, scale int32) (Decimal32, error) { if v < 0 { dec, err := fromPositiveFloat[Decimal32](-v, prec, scale) if err != nil { return dec, err } return -dec, nil } return fromPositiveFloat[Decimal32](v, prec, scale) } func Decimal64FromFloat[F float32 | float64](v F, prec, scale int32) (Decimal64, error) { if v < 0 { dec, err := fromPositiveFloat[Decimal64](-v, prec, scale) if err != nil { return dec, err } return -dec, nil } return fromPositiveFloat[Decimal64](v, prec, scale) } func Decimal128FromFloat(v float64, prec, scale int32) (Decimal128, error) { return decimal128.FromFloat64(v, prec, scale) } func Decimal256FromFloat(v float64, prec, scale int32) (Decimal256, error) { return decimal256.FromFloat64(v, prec, scale) } func (n Decimal32) ToFloat32(scale int32) float32 { return float32(n.ToFloat64(scale)) } func (n Decimal64) ToFloat32(scale int32) float32 { return float32(n.ToFloat64(scale)) } func (n Decimal32) ToFloat64(scale int32) float64 { return float64(n) * math.Pow10(-int(scale)) } func (n Decimal64) ToFloat64(scale int32) float64 { 
return float64(n) * math.Pow10(-int(scale)) } func (n Decimal32) ToBigFloat(scale int32) *big.Float { f := (&big.Float{}).SetInt64(int64(n)) if scale < 0 { f.SetPrec(32).Mul(f, (&big.Float{}).SetInt64(intPow(10, -int64(scale)))) } else { f.SetPrec(32).Quo(f, (&big.Float{}).SetInt64(intPow(10, int64(scale)))) } return f } func (n Decimal64) ToBigFloat(scale int32) *big.Float { f := (&big.Float{}).SetInt64(int64(n)) if scale < 0 { f.SetPrec(64).Mul(f, (&big.Float{}).SetInt64(intPow(10, -int64(scale)))) } else { f.SetPrec(64).Quo(f, (&big.Float{}).SetInt64(intPow(10, int64(scale)))) } return f } func cmpDec[T Decimal32 | Decimal64](lhs, rhs T) int { switch { case lhs > rhs: return 1 case lhs < rhs: return -1 } return 0 } func (n Decimal32) Cmp(other Decimal32) int { return cmpDec(n, other) } func (n Decimal64) Cmp(other Decimal64) int { return cmpDec(n, other) } func (n Decimal32) IncreaseScaleBy(increase int32) Decimal32 { debug.Assert(increase >= 0, "invalid increase scale for decimal32") debug.Assert(increase <= 9, "invalid increase scale for decimal32") return n * Decimal32(intPow(10, increase)) } func (n Decimal64) IncreaseScaleBy(increase int32) Decimal64 { debug.Assert(increase >= 0, "invalid increase scale for decimal64") debug.Assert(increase <= 18, "invalid increase scale for decimal64") return n * Decimal64(intPow(10, int64(increase))) } func reduceScale[T interface { Decimal32 | Decimal64 Abs() T }](n T, reduce int32, round bool) T { if reduce == 0 { return n } divisor := T(intPow(10, reduce)) if !round { return n / divisor } quo, remainder := n/divisor, n%divisor divisorHalf := divisor / 2 if remainder.Abs() >= divisorHalf { if n > 0 { quo++ } else { quo-- } } return quo } func (n Decimal32) ReduceScaleBy(reduce int32, round bool) Decimal32 { debug.Assert(reduce >= 0, "invalid reduce scale for decimal32") debug.Assert(reduce <= 9, "invalid reduce scale for decimal32") return reduceScale(n, reduce, round) } func (n Decimal64) ReduceScaleBy(reduce int32, 
round bool) Decimal64 { debug.Assert(reduce >= 0, "invalid reduce scale for decimal32") debug.Assert(reduce <= 18, "invalid reduce scale for decimal32") return reduceScale(n, reduce, round) } //lint:ignore U1000 function is being used, staticcheck seems to not follow generics func (n Decimal32) rescaleWouldCauseDataLoss(deltaScale int32, multiplier Decimal32) (out Decimal32, loss bool) { if deltaScale < 0 { debug.Assert(multiplier != 0, "multiplier must not be zero") quo, remainder := bits.Div32(0, uint32(n), uint32(multiplier)) return Decimal32(quo), remainder != 0 } overflow, result := bits.Mul32(uint32(n), uint32(multiplier)) if overflow != 0 { return Decimal32(result), true } out = Decimal32(result) return out, out < n } //lint:ignore U1000 function is being used, staticcheck seems to not follow generics func (n Decimal64) rescaleWouldCauseDataLoss(deltaScale int32, multiplier Decimal64) (out Decimal64, loss bool) { if deltaScale < 0 { debug.Assert(multiplier != 0, "multiplier must not be zero") quo, remainder := bits.Div32(0, uint32(n), uint32(multiplier)) return Decimal64(quo), remainder != 0 } overflow, result := bits.Mul32(uint32(n), uint32(multiplier)) if overflow != 0 { return Decimal64(result), true } out = Decimal64(result) return out, out < n } func rescale[T interface { Decimal32 | Decimal64 rescaleWouldCauseDataLoss(int32, T) (T, bool) Sign() int }](n T, originalScale, newScale int32) (out T, err error) { if originalScale == newScale { return n, nil } deltaScale := newScale - originalScale absDeltaScale := int32(math.Abs(float64(deltaScale))) sign := n.Sign() if n < 0 { n = -n } multiplier := T(intPow(10, absDeltaScale)) var wouldHaveLoss bool out, wouldHaveLoss = n.rescaleWouldCauseDataLoss(deltaScale, multiplier) if wouldHaveLoss { err = errors.New("rescale data loss") } out *= T(sign) return } func (n Decimal32) Rescale(originalScale, newScale int32) (out Decimal32, err error) { return rescale(n, originalScale, newScale) } func (n Decimal64) 
Rescale(originalScale, newScale int32) (out Decimal64, err error) { return rescale(n, originalScale, newScale) } var ( _ Num[Decimal32] = Decimal32(0) _ Num[Decimal64] = Decimal64(0) _ Num[Decimal128] = Decimal128{} _ Num[Decimal256] = Decimal256{} )