in src/datatrove/pipeline/dedup/minhash.py [0:0]
def get_worker_hash_range(self, sig_files, rank, world_size):
    """Return the (hash_min, hash_max) slice of the hash space that this worker handles within its bucket.

    Each bucket is shared by world_size // num_buckets workers. Boundary hashes are sampled at evenly
    spaced line offsets in the first signature file, whose lines are expected to be sorted by hash so
    that the per-worker slices cover the whole hash range.
    """
    workers_per_bucket = world_size // self.config.num_buckets
    bucket, bucket_worker = divmod(rank, workers_per_bucket)
    hash_min, hash_max = (
        0,
        _mersenne_prime if self.config.hash_config.precision == 64 else self.config.hash_config.max,
    )
    if workers_per_bucket > 1 and len(sig_files):
        # take the first file and find bucket_worker boundaries. all workers in a bucket process the same set of
        # files, so this should be consistent across workers (and span the entire range of hashes)
        with self.input_folder.open(sig_files[0], mode="rb") as f:
            line_size = struct.calcsize(f"{self.config.hashes_per_bucket}{self.config.hash_config.struct_format}I")
            L, rem = divmod(f.size, line_size)
            assert rem == 0, "file size not divisible by line size"
            assert L >= workers_per_bucket, f"tried to use {workers_per_bucket=} but there are only {L} lines"
            if bucket_worker > 0:
                # not the first worker in the bucket: read the lower boundary from the file
                f.seek(line_size * (L // workers_per_bucket) * bucket_worker, os.SEEK_SET)
                hash_min = struct.unpack(
                    self.config.hash_config.struct_format,
                    f.read(struct.calcsize(self.config.hash_config.struct_format)),
                )[0]
            if bucket_worker + 1 < workers_per_bucket:
                # not the last worker in the bucket: read the upper boundary from the file
                f.seek(line_size * (L // workers_per_bucket) * (bucket_worker + 1), os.SEEK_SET)
                hash_max = struct.unpack(
                    self.config.hash_config.struct_format,
                    f.read(struct.calcsize(self.config.hash_config.struct_format)),
                )[0]
    return hash_min, hash_max
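
# --- Illustration (not part of the module) ----------------------------------------------------
# A minimal, self-contained sketch of the boundary-sampling technique used above, under the
# assumption that a signature file is a flat sequence of fixed-size records sorted by a leading
# unsigned 64-bit hash. The names `worker_hash_range`, `record_struct` and the "<QI" record
# layout are hypothetical and chosen only for this example; they are not part of datatrove.
import io
import struct

def worker_hash_range(buf: bytes, record_struct: str, worker: int, n_workers: int,
                      hash_max_default: int) -> tuple[int, int]:
    # Each record starts with the hash value we partition on.
    record_size = struct.calcsize(record_struct)
    n_records, rem = divmod(len(buf), record_size)
    assert rem == 0 and n_records >= n_workers
    lo, hi = 0, hash_max_default
    step = n_records // n_workers
    f = io.BytesIO(buf)
    if worker > 0:  # not the first worker: lower bound comes from the sampled record
        f.seek(record_size * step * worker)
        lo = struct.unpack("<Q", f.read(8))[0]
    if worker + 1 < n_workers:  # not the last worker: upper bound comes from the sampled record
        f.seek(record_size * step * (worker + 1))
        hi = struct.unpack("<Q", f.read(8))[0]
    return lo, hi

# Example: 8 sorted records of one uint64 hash + one uint32 doc id each ("<QI"), split across
# 2 workers. Worker 0 gets (0, 23), worker 1 gets (23, 2**61 - 1).
hashes = [3, 7, 11, 19, 23, 31, 37, 41]
data = b"".join(struct.pack("<QI", h, i) for i, h in enumerate(hashes))
print(worker_hash_range(data, "<QI", 0, 2, hash_max_default=2**61 - 1))  # (0, 23)
print(worker_hash_range(data, "<QI", 1, 2, hash_max_default=2**61 - 1))  # (23, 2305843009213693951)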