def get_worker_hash_range()

in src/datatrove/pipeline/dedup/minhash.py


    def get_worker_hash_range(self, sig_files, rank, world_size):
        # map the global rank to a (bucket, worker-within-bucket) pair
        workers_per_bucket = world_size // self.config.num_buckets
        bucket, bucket_worker = divmod(rank, workers_per_bucket)
        # default: the full hash range (used when a single worker handles the whole bucket)
        hash_min, hash_max = (
            0,
            _mersenne_prime if self.config.hash_config.precision == 64 else self.config.hash_config.max,
        )
        if workers_per_bucket > 1 and len(sig_files):
            # take the first file and find bucket_worker boundaries. all workers in a bucket process the same set of
            # files, so this should be consistent across workers (and span the entire range of hashes)
            with self.input_folder.open(sig_files[0], mode="rb") as f:
                # each signature line holds `hashes_per_bucket` hashes followed by a uint32 doc index
                line_size = struct.calcsize(f"{self.config.hashes_per_bucket}{self.config.hash_config.struct_format}I")
                L, rem = divmod(f.size, line_size)
                assert rem == 0, "file size not divisible by line size"
                assert L >= workers_per_bucket, f"tried to use {workers_per_bucket=} but there are only {L} lines"
                if bucket_worker > 0:
                    # not the first worker in the bucket: read this worker's lower hash boundary
                    f.seek(line_size * (L // workers_per_bucket) * bucket_worker, os.SEEK_SET)
                    hash_min = struct.unpack(
                        self.config.hash_config.struct_format,
                        f.read(struct.calcsize(self.config.hash_config.struct_format)),
                    )[0]
                if bucket_worker + 1 < workers_per_bucket:
                    # not the last worker in the bucket: read this worker's upper hash boundary
                    f.seek(line_size * (L // workers_per_bucket) * (bucket_worker + 1), os.SEEK_SET)
                    hash_max = struct.unpack(
                        self.config.hash_config.struct_format,
                        f.read(struct.calcsize(self.config.hash_config.struct_format)),
                    )[0]
        return hash_min, hash_max
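
For illustration only, here is a minimal, self-contained sketch of the same partitioning idea (an assumption-laden toy, not the datatrove implementation): it writes a small sorted signature file with the `<hashes_per_bucket hashes><uint32 doc id>` line layout, then derives per-worker (hash_min, hash_max) boundaries with the same divmod-and-seek logic. The constants, file name, and helper names below are made up for this example.

import os
import struct
import tempfile

# toy parameters (assumptions for this sketch, not datatrove defaults)
HASHES_PER_BUCKET = 3
STRUCT_FORMAT = "Q"        # 64-bit unsigned hashes
NUM_BUCKETS = 2
WORLD_SIZE = 8             # => 4 workers per bucket
HASH_MAX = (1 << 61) - 1   # stand-in for the Mersenne-prime upper bound


def write_toy_sig_file(path, num_lines):
    """Write `num_lines` sorted fake signature lines: HASHES_PER_BUCKET hashes + a uint32 doc id."""
    line_fmt = f"<{HASHES_PER_BUCKET}{STRUCT_FORMAT}I"
    with open(path, "wb") as f:
        for i in range(num_lines):
            first_hash = (i + 1) * HASH_MAX // (num_lines + 1)  # strictly increasing first hash
            f.write(struct.pack(line_fmt, first_hash, first_hash + 1, first_hash + 2, i))


def toy_worker_hash_range(sig_file, rank, world_size):
    """Same boundary logic as get_worker_hash_range above, as a free function over a plain file path."""
    workers_per_bucket = world_size // NUM_BUCKETS
    bucket, bucket_worker = divmod(rank, workers_per_bucket)
    hash_min, hash_max = 0, HASH_MAX
    line_size = struct.calcsize(f"<{HASHES_PER_BUCKET}{STRUCT_FORMAT}I")
    n_lines = os.path.getsize(sig_file) // line_size
    hash_size = struct.calcsize("<" + STRUCT_FORMAT)
    with open(sig_file, "rb") as f:
        if bucket_worker > 0:  # not the first worker: read its lower boundary from the file
            f.seek(line_size * (n_lines // workers_per_bucket) * bucket_worker)
            hash_min = struct.unpack("<" + STRUCT_FORMAT, f.read(hash_size))[0]
        if bucket_worker + 1 < workers_per_bucket:  # not the last worker: read its upper boundary
            f.seek(line_size * (n_lines // workers_per_bucket) * (bucket_worker + 1))
            hash_max = struct.unpack("<" + STRUCT_FORMAT, f.read(hash_size))[0]
    return bucket, bucket_worker, hash_min, hash_max


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "bucket_000.minhash.sig")
    write_toy_sig_file(path, num_lines=100)
    for rank in range(WORLD_SIZE):
        print(rank, toy_worker_hash_range(path, rank, WORLD_SIZE))

In this scheme, consecutive workers of a bucket share a boundary value (one worker's hash_max is the next worker's hash_min, with the first starting at 0 and the last ending at the maximum hash), so together they span the entire hash range.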