cpc/merging_validation.go (265 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package cpc import ( "fmt" "io" "math" "strings" "github.com/apache/datasketches-go/common" ) type MergingValidation struct { hfmt, dfmt string hStrArr []string vIn uint64 // increments each update // Inputs lgMinK, lgMaxK int lgMulK int uPPO int incLgK int printStream io.Writer printWriter io.Writer } func NewMergingValidation( lgMinK, lgMaxK, lgMulK, uPPO, incLgK int, pS, pW io.Writer, ) *MergingValidation { if uPPO < 1 { uPPO = 1 } if incLgK < 1 { incLgK = 1 } mv := &MergingValidation{ lgMinK: lgMinK, lgMaxK: lgMaxK, lgMulK: lgMulK, uPPO: uPPO, incLgK: incLgK, printStream: pS, printWriter: pW, } mv.assembleFormats() return mv } // Start prints the header, then calls doRangeOfLgK. func (mv *MergingValidation) Start() error { mv.printf(mv.hfmt, mv.toInterfaceSlice(mv.hStrArr)...) return mv.doRangeOfLgK() } // doRangeOfLgK calls multiTestMerging for various (lgK, lgK±1) combinations. func (mv *MergingValidation) doRangeOfLgK() error { for lgK := mv.lgMinK; lgK <= mv.lgMaxK; lgK += mv.incLgK { if err := mv.multiTestMerging(lgK, lgK-1, lgK-1); err != nil { return err } if err := mv.multiTestMerging(lgK, lgK-1, lgK); err != nil { return err } if err := mv.multiTestMerging(lgK, lgK-1, lgK+1); err != nil { return err } if err := mv.multiTestMerging(lgK, lgK, lgK-1); err != nil { return err } if err := mv.multiTestMerging(lgK, lgK, lgK); err != nil { return err } if err := mv.multiTestMerging(lgK, lgK, lgK+1); err != nil { return err } if err := mv.multiTestMerging(lgK, lgK+1, lgK-1); err != nil { return err } if err := mv.multiTestMerging(lgK, lgK+1, lgK); err != nil { return err } if err := mv.multiTestMerging(lgK, lgK+1, lgK+1); err != nil { return err } } return nil } // multiTestMerging loops over nA and nB up to 2^(lgKa+lgMulK) and 2^(lgKb+lgMulK), respectively. func (mv *MergingValidation) multiTestMerging(lgKm, lgKa, lgKb int) error { limA := int64(1 << uint(lgKa+mv.lgMulK)) limB := int64(1 << uint(lgKb+mv.lgMulK)) var nA int64 = 0 for nA <= limA { var nB int64 = 0 for nB <= limB { if err := mv.testMerging(lgKm, lgKa, lgKb, nA, nB); err != nil { return err } nB = int64(math.Round(common.PowerSeriesNextDouble(mv.uPPO, float64(nB), true, 2.0))) } nA = int64(math.Round(common.PowerSeriesNextDouble(mv.uPPO, float64(nA), true, 2.0))) } return nil } // testMerging does the actual test for one combination of (lgKm, lgKa, lgKb, nA, nB). // It merges two sketches A and B into a union ugM, compares it with a direct combined sketch D, // and returns an error if any discrepancy is found. func (mv *MergingValidation) testMerging(lgKm, lgKa, lgKb int, nA, nB int64) error { // Create the union with the minimum lgK among lgKm, lgKa, and lgKb. minLg := lgKm if lgKa < minLg { minLg = lgKa } if lgKb < minLg { minLg = lgKb } ugM, err := NewCpcUnionSketchWithDefault(lgKm) if err != nil { return fmt.Errorf("failed to create CpcUnion: %v", err) } // Determine the direct sketch's lgK: the minimum among non-empty sketches. lgKd := lgKm if lgKa < lgKd && nA != 0 { lgKd = lgKa } if lgKb < lgKd && nB != 0 { lgKd = lgKb } skD, err := NewCpcSketchWithDefault(lgKd) if err != nil { return fmt.Errorf("failed to create CpcSketch: %v", err) } skA, err := NewCpcSketchWithDefault(lgKa) if err != nil { return fmt.Errorf("failed to create CpcSketch: %v", err) } skB, err := NewCpcSketchWithDefault(lgKb) if err != nil { return fmt.Errorf("failed to create CpcSketch: %v", err) } for i := int64(0); i < nA; i++ { mv.vIn += common.InverseGoldenU64 in := mv.vIn if err = skA.UpdateUint64(in); err != nil { return fmt.Errorf("skA.UpdateUint64 error: %v", err) } if err = skD.UpdateUint64(in); err != nil { return fmt.Errorf("skD.UpdateUint64 error: %v", err) } } for i := int64(0); i < nB; i++ { mv.vIn += common.InverseGoldenU64 in := mv.vIn if err = skB.UpdateUint64(in); err != nil { return fmt.Errorf("skB.UpdateUint64 error: %v", err) } if err = skD.UpdateUint64(in); err != nil { return fmt.Errorf("skD.UpdateUint64 error: %v", err) } } if err := ugM.Update(skA); err != nil { return fmt.Errorf("union update skA error: %v", err) } if err := ugM.Update(skB); err != nil { return fmt.Errorf("union update skB error: %v", err) } finalLgKm := ugM.lgK matrixM, err := ugM.GetBitMatrix() if err != nil { return fmt.Errorf("ugM.GetBitMatrix error: %v", err) } cM := ugM.getNumCoupons() cD := skD.numCoupons flavorD := skD.getFlavor() flavorA := skA.getFlavor() flavorB := skB.getFlavor() dOff := skD.windowOffset aOff := skA.windowOffset bOff := skB.windowOffset flavorDoff := fmt.Sprintf("%s%2d", flavorD.String(), dOff) flavorAoff := fmt.Sprintf("%s%2d", flavorA.String(), aOff) flavorBoff := fmt.Sprintf("%s%2d", flavorB.String(), bOff) iconEstD := iconEstimate(lgKd, cD) if finalLgKm > lgKm { return fmt.Errorf("finalLgKm > lgKm") } if cM > (skA.numCoupons + skB.numCoupons) { return fmt.Errorf("union coupon count too large") } if cM != cD { return fmt.Errorf("mismatch coupon counts union=%d direct=%d", cM, cD) } if finalLgKm != lgKd { return fmt.Errorf("union lgK mismatch: got %d, expected %d", finalLgKm, lgKd) } // Compare union bit matrix with direct sketch bit matrix. matrixD, err := skD.bitMatrixOfSketch() if err != nil { return fmt.Errorf("bitMatrixOfSketch error: %v", err) } if len(matrixM) != len(matrixD) { return fmt.Errorf("matrix length mismatch union vs direct") } for i := range matrixM { if matrixM[i] != matrixD[i] { return fmt.Errorf("matrix bits mismatch union vs direct") } } // Compare union's result with direct. skR, err := ugM.GetResult() if err != nil { return err } iconEstR := iconEstimate(skR.lgK, skR.numCoupons) if math.Abs(iconEstD-iconEstR) > 1e-9 { return fmt.Errorf("ICON mismatch direct=%.9g union=%.9g", iconEstD, iconEstR) } if !specialEquals(skD, skR, false, true) { return fmt.Errorf("skD != skR") } // Print final line mv.printf(mv.dfmt, lgKm, lgKa, lgKb, lgKd, nA, nB, nA+nB, flavorAoff, flavorBoff, flavorDoff, skA.numCoupons, skB.numCoupons, cD, iconEstR, ) return nil } // assembleFormats sets up columns for printing the final results. func (mv *MergingValidation) assembleFormats() { assy := [][]string{ {"lgKm", "%4s", "%4d"}, {"lgKa", "%4s", "%4d"}, {"lgKb", "%4s", "%4d"}, {"lgKfd", "%6s", "%6d"}, {"nA", "%12s", "%12d"}, {"nB", "%12s", "%12d"}, {"nA+nB", "%12s", "%12d"}, {"Flavor_a", "%11s", "%11s"}, {"Flavor_b", "%11s", "%11s"}, {"Flavor_fd", "%11s", "%11s"}, {"Coupons_a", "%9s", "%9d"}, {"Coupons_b", "%9s", "%9d"}, {"Coupons_fd", "%9s", "%9d"}, {"IconEst_dr", "%12s", "%,12.0f"}, } cols := len(assy) mv.hStrArr = make([]string, cols) var headerFmt strings.Builder var dataFmt strings.Builder headerFmt.WriteString("\nMerging Validation\n") for i := 0; i < cols; i++ { mv.hStrArr[i] = assy[i][0] headerFmt.WriteString(assy[i][1]) if i < cols-1 { headerFmt.WriteString("\t") } else { headerFmt.WriteString("\n") } dataFmt.WriteString(assy[i][2]) if i < cols-1 { dataFmt.WriteString("\t") } else { dataFmt.WriteString("\n") } } mv.hfmt = headerFmt.String() mv.dfmt = dataFmt.String() } // printf writes to both printStream and printWriter if they are not nil. func (mv *MergingValidation) printf(format string, args ...interface{}) { if mv.printStream != nil { fmt.Fprintf(mv.printStream, format, args...) } if mv.printWriter != nil { fmt.Fprintf(mv.printWriter, format, args...) } } // toInterfaceSlice helps pass a slice of strings to fmt.Fprintf for the header. func (mv *MergingValidation) toInterfaceSlice(ss []string) []interface{} { out := make([]interface{}, len(ss)) for i := range ss { out[i] = ss[i] } return out }