cpc/cpc_union.go (296 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package cpc import ( "fmt" "github.com/apache/datasketches-go/internal" "math/bits" ) type CpcUnion struct { seed uint64 lgK int // Note: at most one of bitMatrix and accumulator will be non-null at any given moment. // accumulator is a sketch object that is employed until it graduates out of Sparse mode. // At that point, it is converted into a full-sized bitMatrix, which is mathematically a sketch, // but doesn't maintain any of the "extra" fields of our sketch objects, so some additional work // is required when getResult is called at the end. bitMatrix []uint64 accumulator *CpcSketch } func NewCpcUnionSketch(lgK int, seed uint64) (CpcUnion, error) { acc, err := NewCpcSketch(lgK, seed) if err != nil { return CpcUnion{}, err } return CpcUnion{ seed: seed, lgK: lgK, // We begin with the accumulator holding an EMPTY_MERGED sketch object. // As an optimization the accumulator could start as NULL, but that would require changes elsewhere. accumulator: acc, }, nil } func NewCpcUnionSketchWithDefault(lgK int) (CpcUnion, error) { return NewCpcUnionSketch(lgK, internal.DEFAULT_UPDATE_SEED) } func (u *CpcUnion) GetFamilyId() int { return internal.FamilyEnum.CPC.Id } func (u *CpcUnion) Update(source *CpcSketch) error { if err := checkSeeds(u.seed, source.seed); err != nil { return err } sourceFlavorOrd := source.getFlavor() if sourceFlavorOrd == CpcFlavorEmpty { return nil } if err := u.checkUnionState(); err != nil { return err } // Downsample union if the source sketch has a smaller lgK. if source.lgK < u.lgK { if err := u.reduceUnionK(source.lgK); err != nil { return err } } // If the source is past SPARSE mode, ensure union is in bitMatrix mode. if sourceFlavorOrd > CpcFlavorSparse && u.accumulator != nil { bitMatrix, err := u.accumulator.bitMatrixOfSketch() if err != nil { return err } u.bitMatrix = bitMatrix u.accumulator = nil } state := (sourceFlavorOrd - 1) << 1 if u.bitMatrix != nil { state |= 1 } switch state { case 0: // Case A: source is SPARSE, union.accumulator valid, bitMatrix == nil. if u.accumulator.lgK == 0 { return fmt.Errorf("union accumulator cannot be nil") } // If the union is EMPTY and lgK matches, copy the source. if u.accumulator.getFlavor() == CpcFlavorEmpty && u.lgK == source.lgK { cp, err := source.Copy() if err != nil { return err } u.accumulator = cp break } if err := walkTableUpdatingSketch(u.accumulator, source.pairTable); err != nil { return err } // If accumulator has graduated beyond SPARSE, switch to bitMatrix. if u.accumulator.getFlavor() > CpcFlavorSparse { bitMatrix, err := u.accumulator.bitMatrixOfSketch() if err != nil { return err } u.bitMatrix = bitMatrix u.accumulator = nil } case 1: // Case B: source is SPARSE, union already in bitMatrix mode. u.orTableIntoMatrix(source.pairTable) case 3, 5: // Case C: source is HYBRID or PINNED, union in bitMatrix mode. if err := u.orWindowIntoMatrix(source.slidingWindow, 0, source.lgK); err != nil { return err } u.orTableIntoMatrix(source.pairTable) case 7: // Case D: source is SLIDING, union in bitMatrix mode. sourceMatrix, err := source.bitMatrixOfSketch() if err != nil { return err } if err := u.orMatrixIntoMatrix(sourceMatrix, source.lgK); err != nil { return err } default: return fmt.Errorf("illegal Union state: %d", state) } return nil } func (u *CpcUnion) GetResult() (*CpcSketch, error) { if err := u.checkUnionState(); err != nil { return nil, err } if u.accumulator != nil { if u.accumulator.numCoupons == 0 { result, err := NewCpcSketch(u.lgK, u.accumulator.seed) if err != nil { return nil, err } result.mergeFlag = true return result, nil } if u.accumulator.getFlavor() != CpcFlavorSparse { return nil, fmt.Errorf("accumulator must be SPARSE") } // Return a copy of the accumulator. result, err := u.accumulator.Copy() if err != nil { return nil, err } result.mergeFlag = true return result, nil } // Case: union contains a bitMatrix. matrix := u.bitMatrix lgK := u.lgK result, err := NewCpcSketch(u.lgK, u.seed) if err != nil { return nil, err } numCoupons := countBitsSetInMatrix(matrix) result.numCoupons = numCoupons flavor := determineFlavor(lgK, numCoupons) if flavor <= CpcFlavorSparse { return nil, fmt.Errorf("flavor must be greater than SPARSE") } offset := determineCorrectOffset(lgK, numCoupons) result.windowOffset = offset k := 1 << lgK window := make([]byte, k) result.slidingWindow = window newTableLgSize := max(lgK-4, 2) table, err := NewPairTable(newTableLgSize, 6+lgK) if err != nil { return nil, err } result.pairTable = table maskForClearingWindow := (0xFF << offset) ^ -1 maskForFlippingEarlyZone := (1 << offset) - 1 allSurprisesORed := uint64(0) for i := 0; i < k; i++ { pattern := matrix[i] window[i] = byte((pattern >> offset) & 0xFF) pattern &= uint64(maskForClearingWindow) pattern ^= uint64(maskForFlippingEarlyZone) allSurprisesORed |= pattern for pattern != 0 { col := bits.TrailingZeros64(pattern) pattern ^= 1 << col rowCol := (i << 6) | col isNovel, err := table.maybeInsert(rowCol) if err != nil { return nil, err } if !isNovel { return nil, fmt.Errorf("isNovel must be true") } } } result.fiCol = bits.TrailingZeros64(allSurprisesORed) if result.fiCol > offset { result.fiCol = offset } result.mergeFlag = true return result, nil } func (u *CpcUnion) checkUnionState() error { if u == nil { return fmt.Errorf("union cannot be nil") } accumulator := u.accumulator if (accumulator != nil) == (u.bitMatrix != nil) { return fmt.Errorf("accumulator and bitMatrix cannot be both valid or both nil") } if accumulator != nil { if accumulator.numCoupons > 0 { if accumulator.slidingWindow != nil || accumulator.pairTable == nil { return fmt.Errorf("non-empty union accumulator must be SPARSE") } } if u.lgK != accumulator.lgK { return fmt.Errorf("union LgK must equal accumulator LgK") } } return nil } func (u *CpcUnion) reduceUnionK(newLgK int) error { if newLgK < u.lgK { if u.bitMatrix != nil { newK := 1 << newLgK newMatrix := make([]uint64, newK) orMatrixIntoMatrix(newMatrix, newLgK, u.bitMatrix, u.lgK) u.bitMatrix = newMatrix u.lgK = newLgK } else { oldSketch := u.accumulator if oldSketch.numCoupons == 0 { acc, err := NewCpcSketch(newLgK, oldSketch.seed) if err != nil { return err } u.accumulator = acc u.lgK = newLgK return nil } newSketch, err := NewCpcSketch(newLgK, oldSketch.seed) if err != nil { return err } if err := walkTableUpdatingSketch(newSketch, oldSketch.pairTable); err != nil { return err } finalNewFlavor := newSketch.getFlavor() if finalNewFlavor == CpcFlavorSparse { u.accumulator = newSketch u.lgK = newLgK return nil } // The new sketch has graduated beyond sparse, so convert to bitMatrix. bitMatrix, err := newSketch.bitMatrixOfSketch() if err != nil { return err } u.bitMatrix = bitMatrix u.lgK = newLgK // Ensure that the accumulator is cleared. u.accumulator = nil } } return nil } func (u *CpcUnion) orWindowIntoMatrix(srcWindow []byte, srcOffset int, srcLgK int) error { //assert(destLgK <= srcLgK) if u.lgK > srcLgK { return fmt.Errorf("destLgK <= srcLgK") } destMask := (1 << u.lgK) - 1 // downsamples when destLgK < srcLgK srcK := 1 << srcLgK for srcRow := 0; srcRow < srcK; srcRow++ { u.bitMatrix[srcRow&destMask] |= uint64(srcWindow[srcRow]) << srcOffset } return nil } func (u *CpcUnion) orTableIntoMatrix(srcTable *pairTable) { slots := srcTable.slotsArr numSlots := 1 << srcTable.lgSizeInts destMask := (1 << u.lgK) - 1 // downsamples when destLgK < srcLgK for i := 0; i < numSlots; i++ { rowCol := slots[i] if rowCol != -1 { col := rowCol & 63 row := rowCol >> 6 u.bitMatrix[row&destMask] |= 1 << col // Set the bit. } } } func (u *CpcUnion) orMatrixIntoMatrix(srcMatrix []uint64, srcLgK int) error { if u.lgK > srcLgK { return fmt.Errorf("destLgK <= srcLgK") } destMask := (1 << u.lgK) - 1 // downsamples when destLgK < srcLgK srcK := 1 << srcLgK for srcRow := 0; srcRow < srcK; srcRow++ { u.bitMatrix[srcRow&destMask] |= srcMatrix[srcRow] } return nil } func (u *CpcUnion) getNumCoupons() uint64 { if u.bitMatrix != nil { return countBitsSetInMatrix(u.bitMatrix) } return u.accumulator.numCoupons } func (u *CpcUnion) GetBitMatrix() ([]uint64, error) { if err := u.checkUnionState(); err != nil { return nil, err } if u.bitMatrix != nil { return u.bitMatrix, nil } if u.accumulator == nil { return nil, fmt.Errorf("both bitMatrix and accumulator are nil, invalid union state") } bm, err := u.accumulator.bitMatrixOfSketch() if err != nil { return nil, fmt.Errorf("accumulator.bitMatrixOfSketch failed: %v", err) } return bm, nil }