boolean resizeHashTable()

in flink-runtime/src/main/java/org/apache/flink/runtime/operators/hash/CompactingHashTable.java [883:1117]


    /**
     * Attempts to double the number of hash buckets and to split the entries of every existing
     * bucket between that bucket and its new sibling bucket at index
     * {@code bucket + oldNumBuckets}.
     *
     * <p>If {@code availableMemory} does not hold enough segments for the additional bucket
     * segments, {@code compactPartition} is called for one partition after the other until it
     * does (or all partitions have been tried).
     *
     * @return {@code true} if the bucket array was doubled and all entries were redistributed;
     *     {@code false} if the doubled size does not fit into an {@code int}, if there is still
     *     not enough free memory after compaction, or if the table is closed
     * @throws IOException if one of the consistency checks on bucket headers or hash codes
     *     fails, or if one of the called helpers throws it
     */
    boolean resizeHashTable() throws IOException {
        final int bucketsPerSegment = this.bucketsPerSegmentMask + 1;

        // Both the doubled bucket count and its round-up to whole segments must stay within the
        // int range; otherwise the sizes computed below would silently wrap around.
        if (2L * this.numBuckets > Integer.MAX_VALUE - (bucketsPerSegment - 1)) {
            return false;
        }

        final int newNumBuckets = 2 * this.numBuckets;
        // round up to whole segments
        final int newNumSegments = (newNumBuckets + (bucketsPerSegment - 1)) / bucketsPerSegment;
        final int additionalSegments = newNumSegments - this.buckets.length;
        final int numPartitions = this.partitions.size();

        // try to free memory by compacting partitions, stopping as soon as there is enough
        if (this.availableMemory.size() < additionalSegments) {
            for (int i = 0; i < numPartitions; i++) {
                compactPartition(i);
                if (this.availableMemory.size() >= additionalSegments) {
                    break;
                }
            }
        }

        if (this.availableMemory.size() < additionalSegments || this.closed) {
            return false;
        }

        // NOTE(review): presumably read by the insertion path to suppress a nested resize while
        // this one is in progress - confirm against insertBucketEntryFromStart.
        this.isResizing = true;
        try {
            // allocate new buckets
            final int oldNumBuckets = this.numBuckets;
            final int oldNumSegments = this.buckets.length;
            // computed in long because numBuckets * HASH_BUCKET_SIZE may exceed the int range
            final long startOffset = ((long) oldNumBuckets * HASH_BUCKET_SIZE) % this.segmentSize;

            final MemorySegment[] mergedBuckets = new MemorySegment[newNumSegments];
            System.arraycopy(this.buckets, 0, mergedBuckets, 0, oldNumSegments);
            this.buckets = mergedBuckets;
            this.numBuckets = newNumBuckets;

            // initialize all new buckets; a non-zero start offset means the last old segment is
            // only partially used, so the first new buckets are placed on that segment
            boolean oldSegment = (startOffset != 0);
            final int startSegment = oldSegment ? (oldNumSegments - 1) : oldNumSegments;
            for (int i = startSegment, bucket = oldNumBuckets;
                    i < newNumSegments && bucket < this.numBuckets;
                    i++) {
                final MemorySegment seg;
                final int firstSlot;
                if (oldSegment) {
                    // continue behind the last old bucket on the existing segment
                    seg = this.buckets[i];
                    firstSlot = oldNumBuckets % bucketsPerSegment;
                } else {
                    seg = getNextBuffer();
                    firstSlot = 0;
                }
                // go over all (remaining) buckets in the segment
                for (int k = firstSlot;
                        k < bucketsPerSegment && bucket < this.numBuckets;
                        k++, bucket++) {
                    final int bucketOffset = k * HASH_BUCKET_SIZE;
                    // initialize the header fields
                    seg.put(
                            bucketOffset + HEADER_PARTITION_OFFSET,
                            assignPartition(bucket, (byte) numPartitions));
                    seg.putInt(bucketOffset + HEADER_COUNT_OFFSET, 0);
                    seg.putLong(
                            bucketOffset + HEADER_FORWARD_OFFSET, BUCKET_FORWARD_POINTER_NOT_SET);
                }
                this.buckets[i] = seg;
                oldSegment = false; // we write on at most one old segment
            }

            // entries (hash code + pointer) of the bucket chain that is currently being split
            final IntArrayList hashList = new IntArrayList(NUM_ENTRIES_PER_BUCKET);
            final LongArrayList pointerList = new LongArrayList(NUM_ENTRIES_PER_BUCKET);
            // entries that did not fit into the primary bucket during the first pass
            final IntArrayList overflowHashes = new IntArrayList(64);
            final LongArrayList overflowPointers = new LongArrayList(64);

            // go over all buckets and split them between old and new buckets
            for (int i = 0; i < numPartitions; i++) {
                final InMemoryPartition<T> partition = this.partitions.get(i);
                final MemorySegment[] overflowSegments = partition.overflowSegments;

                // The buckets visited for partition i are i, i + numPartitions, ...
                // NOTE(review): advancing j by exactly one per pass appears to assume that
                // numPartitions <= bucketsPerSegment; the header check below fails otherwise.
                for (int j = 0, bucket = i;
                        j < this.buckets.length && bucket < oldNumBuckets;
                        j++) {
                    MemorySegment segment = this.buckets[j];
                    // go over all buckets in the segment belonging to the partition
                    for (int k = bucket % bucketsPerSegment;
                            k < bucketsPerSegment && bucket < oldNumBuckets;
                            k += numPartitions, bucket += numPartitions) {
                        int bucketOffset = k * HASH_BUCKET_SIZE;
                        if ((int) segment.get(bucketOffset + HEADER_PARTITION_OFFSET) != i) {
                            throw new IOException(
                                    "Accessed wrong bucket! wanted: "
                                            + i
                                            + " got: "
                                            + segment.get(bucketOffset + HEADER_PARTITION_OFFSET));
                        }

                        // collect the entries of all segments that are involved in the bucket
                        // (original bucket plus chained overflow buckets)
                        int countInSegment = segment.getInt(bucketOffset + HEADER_COUNT_OFFSET);
                        int numInSegment = 0;
                        int pointerOffset = bucketOffset + BUCKET_POINTER_START_OFFSET;
                        int hashOffset = bucketOffset + BUCKET_HEADER_LENGTH;
                        while (true) {
                            while (numInSegment < countInSegment) {
                                final int hash = segment.getInt(hashOffset);
                                final int posHashCode = hash % this.numBuckets;
                                // after doubling, an entry must map to the old bucket or to
                                // its sibling
                                if (posHashCode != bucket
                                        && posHashCode != (bucket + oldNumBuckets)) {
                                    throw new IOException(
                                            "wanted: "
                                                    + bucket
                                                    + " or "
                                                    + (bucket + oldNumBuckets)
                                                    + " got: "
                                                    + posHashCode);
                                }
                                final long pointer = segment.getLong(pointerOffset);
                                hashList.add(hash);
                                pointerList.add(pointer);
                                pointerOffset += POINTER_LEN;
                                hashOffset += HASH_CODE_LEN;
                                numInSegment++;
                            }
                            // this segment is done. check if there is another chained bucket
                            final long forwardPointer =
                                    segment.getLong(bucketOffset + HEADER_FORWARD_OFFSET);
                            if (forwardPointer == BUCKET_FORWARD_POINTER_NOT_SET) {
                                break;
                            }
                            // upper 32 bits: index of the overflow segment,
                            // lower 32 bits: offset of the overflow bucket in that segment
                            final int overflowSegNum = (int) (forwardPointer >>> 32);
                            segment = overflowSegments[overflowSegNum];
                            bucketOffset = (int) forwardPointer;
                            countInSegment = segment.getInt(bucketOffset + HEADER_COUNT_OFFSET);
                            pointerOffset = bucketOffset + BUCKET_POINTER_START_OFFSET;
                            hashOffset = bucketOffset + BUCKET_HEADER_LENGTH;
                            numInSegment = 0;
                        }

                        // go back to the primary bucket and reset it for re-insertion
                        segment = this.buckets[j];
                        bucketOffset = k * HASH_BUCKET_SIZE;
                        segment.putInt(bucketOffset + HEADER_COUNT_OFFSET, 0);
                        segment.putLong(
                                bucketOffset + HEADER_FORWARD_OFFSET,
                                BUCKET_FORWARD_POINTER_NOT_SET);

                        // refill table
                        if (hashList.size() != pointerList.size()) {
                            throw new IOException(
                                    "Pointer and hash counts do not match. hashes: "
                                            + hashList.size()
                                            + " pointer: "
                                            + pointerList.size());
                        }
                        final int newBucket = bucket + oldNumBuckets;
                        final MemorySegment newSegment =
                                this.buckets[newBucket / bucketsPerSegment];
                        final int oldBucketOffset =
                                (bucket % bucketsPerSegment) * HASH_BUCKET_SIZE;
                        final int newBucketOffset =
                                (newBucket % bucketsPerSegment) * HASH_BUCKET_SIZE;

                        // We need to avoid overflows in the first run: at most
                        // NUM_ENTRIES_PER_BUCKET entries go into each primary bucket here, the
                        // rest is parked and re-inserted after the partition's overflow buckets
                        // have been reset.
                        // NOTE(review): entries for the sibling bucket are inserted under this
                        // partition's number; presumably assignPartition maps the sibling to
                        // the same partition - confirm.
                        int oldBucketCount = 0;
                        int newBucketCount = 0;
                        while (!hashList.isEmpty()) {
                            final int hash = hashList.removeLast();
                            final long pointer = pointerList.removeLong(pointerList.size() - 1);
                            final int posHashCode = hash % this.numBuckets;
                            if (posHashCode == bucket && oldBucketCount < NUM_ENTRIES_PER_BUCKET) {
                                insertBucketEntryFromStart(
                                        segment,
                                        oldBucketOffset,
                                        hash,
                                        pointer,
                                        partition.getPartitionNumber());
                                oldBucketCount++;
                            } else if (posHashCode == newBucket
                                    && newBucketCount < NUM_ENTRIES_PER_BUCKET) {
                                insertBucketEntryFromStart(
                                        newSegment,
                                        newBucketOffset,
                                        hash,
                                        pointer,
                                        partition.getPartitionNumber());
                                newBucketCount++;
                            } else if (posHashCode == newBucket || posHashCode == bucket) {
                                // primary bucket is full, defer to the second pass
                                overflowHashes.add(hash);
                                overflowPointers.add(pointer);
                            } else {
                                throw new IOException(
                                        "Accessed wrong bucket. Target: "
                                                + bucket
                                                + " or "
                                                + newBucket
                                                + " Hit: "
                                                + posHashCode);
                            }
                        }
                        hashList.clear();
                        pointerList.clear();
                    }
                }

                // reset partition's overflow buckets and reclaim their memory
                this.availableMemory.addAll(partition.resetOverflowBuckets());

                // second pass: insert the parked entries into their target buckets
                while (!overflowHashes.isEmpty()) {
                    final int hash = overflowHashes.removeLast();
                    final long pointer = overflowPointers.removeLong(overflowPointers.size() - 1);
                    final int posHashCode = hash % this.numBuckets;
                    final int bucketArrayPos = posHashCode >>> this.bucketsPerSegmentBits;
                    final int bucketInSegmentPos =
                            (posHashCode & this.bucketsPerSegmentMask) << NUM_INTRA_BUCKET_BITS;
                    insertBucketEntryFromStart(
                            this.buckets[bucketArrayPos],
                            bucketInSegmentPos,
                            hash,
                            pointer,
                            partition.getPartitionNumber());
                }
                overflowHashes.clear();
                overflowPointers.clear();
            }
            return true;
        } finally {
            // always clear the flag, also when a consistency check or a helper throws
            this.isResizing = false;
        }
    }