in cpc/merging_validation.go [133:269]
// testMerging runs a single merge-validation case.
//
// It builds sketch A (lgKa) from nA values and sketch B (lgKb) from the next
// nB values of the mv.vIn sequence, feeding every one of those values into a
// "direct" sketch D as well. A and B are then merged through a union created
// with lgKm, and the union is checked against D: final lgK, coupon count, bit
// matrix, ICON estimate of the union result, and specialEquals on the result.
//
// On success one result row is emitted through mv.printf using mv.dfmt and
// nil is returned. Otherwise the returned error describes the first step or
// check that failed.
func (mv *MergingValidation) testMerging(lgKm, lgKa, lgKb int, nA, nB int64) error {
	// The union starts at lgKm; the checks further down require that it never
	// ends up above lgKm and that it ends up exactly at lgKd.
	ugM, err := NewCpcUnionSketchWithDefault(lgKm)
	if err != nil {
		return fmt.Errorf("failed to create CpcUnion: %w", err)
	}

	// lgK of the direct sketch: start from lgKm and lower it to lgKa / lgKb,
	// but only for inputs that actually receive values (nA / nB non-zero).
	lgKd := lgKm
	if lgKa < lgKd && nA != 0 {
		lgKd = lgKa
	}
	if lgKb < lgKd && nB != 0 {
		lgKd = lgKb
	}

	skD, err := NewCpcSketchWithDefault(lgKd)
	if err != nil {
		return fmt.Errorf("failed to create CpcSketch: %w", err)
	}
	skA, err := NewCpcSketchWithDefault(lgKa)
	if err != nil {
		return fmt.Errorf("failed to create CpcSketch: %w", err)
	}
	skB, err := NewCpcSketchWithDefault(lgKb)
	if err != nil {
		return fmt.Errorf("failed to create CpcSketch: %w", err)
	}

	// Feed nA values into A and D. mv.vIn is advanced by InverseGoldenU64
	// before each update, so the sequence continues across loops and calls.
	for i := int64(0); i < nA; i++ {
		mv.vIn += common.InverseGoldenU64
		in := mv.vIn
		if err = skA.UpdateUint64(in); err != nil {
			return fmt.Errorf("skA.UpdateUint64 error: %w", err)
		}
		if err = skD.UpdateUint64(in); err != nil {
			return fmt.Errorf("skD.UpdateUint64 error: %w", err)
		}
	}
	// Feed the next nB values into B and D.
	for i := int64(0); i < nB; i++ {
		mv.vIn += common.InverseGoldenU64
		in := mv.vIn
		if err = skB.UpdateUint64(in); err != nil {
			return fmt.Errorf("skB.UpdateUint64 error: %w", err)
		}
		if err = skD.UpdateUint64(in); err != nil {
			return fmt.Errorf("skD.UpdateUint64 error: %w", err)
		}
	}

	// Merge A and B through the union.
	if err := ugM.Update(skA); err != nil {
		return fmt.Errorf("union update skA error: %w", err)
	}
	if err := ugM.Update(skB); err != nil {
		return fmt.Errorf("union update skB error: %w", err)
	}

	// Snapshot the union and sketch state used by the checks and the report.
	finalLgKm := ugM.lgK
	matrixM, err := ugM.GetBitMatrix()
	if err != nil {
		return fmt.Errorf("ugM.GetBitMatrix error: %w", err)
	}
	cM := ugM.getNumCoupons()
	cD := skD.numCoupons
	flavorD := skD.getFlavor()
	flavorA := skA.getFlavor()
	flavorB := skB.getFlavor()
	dOff := skD.windowOffset
	aOff := skA.windowOffset
	bOff := skB.windowOffset
	// Report columns: flavor name followed by the window offset, width 2.
	flavorDoff := fmt.Sprintf("%s%2d", flavorD.String(), dOff)
	flavorAoff := fmt.Sprintf("%s%2d", flavorA.String(), aOff)
	flavorBoff := fmt.Sprintf("%s%2d", flavorB.String(), bOff)
	iconEstD := iconEstimate(lgKd, cD)

	// Scalar checks of the union against the inputs and the direct sketch.
	if finalLgKm > lgKm {
		return fmt.Errorf("finalLgKm > lgKm")
	}
	if cM > (skA.numCoupons + skB.numCoupons) {
		return fmt.Errorf("union coupon count too large")
	}
	if cM != cD {
		return fmt.Errorf("mismatch coupon counts union=%d direct=%d", cM, cD)
	}
	if finalLgKm != lgKd {
		return fmt.Errorf("union lgK mismatch: got %d, expected %d", finalLgKm, lgKd)
	}

	// Compare union bit matrix with direct sketch bit matrix.
	matrixD, err := skD.bitMatrixOfSketch()
	if err != nil {
		return fmt.Errorf("bitMatrixOfSketch error: %w", err)
	}
	if len(matrixM) != len(matrixD) {
		return fmt.Errorf("matrix length mismatch union vs direct")
	}
	for i := range matrixM {
		if matrixM[i] != matrixD[i] {
			return fmt.Errorf("matrix bits mismatch union vs direct")
		}
	}

	// Compare union's result with direct.
	skR, err := ugM.GetResult()
	if err != nil {
		return fmt.Errorf("ugM.GetResult error: %w", err)
	}
	iconEstR := iconEstimate(skR.lgK, skR.numCoupons)
	if math.Abs(iconEstD-iconEstR) > 1e-9 {
		return fmt.Errorf("ICON mismatch direct=%.9g union=%.9g", iconEstD, iconEstR)
	}
	if !specialEquals(skD, skR, false, true) {
		return fmt.Errorf("skD != skR")
	}

	// Print final line
	mv.printf(mv.dfmt,
		lgKm, lgKa, lgKb, lgKd,
		nA, nB, nA+nB,
		flavorAoff, flavorBoff, flavorDoff,
		skA.numCoupons, skB.numCoupons, cD, iconEstR,
	)
	return nil
}