src/datatrove/pipeline/dedup/minhash.py
def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1):
    with self.track_time():
        # check if we can skip the sig writing step (signature files for this rank already exist)
        if not self.check_can_skip_sig_writing(rank):
            # one signature file per bucket for this rank
            buckets = [
                self.output_folder.open(f"bucket_{bi:03d}/{rank:05d}.minhash.sig", mode="wb")
                for bi in range(self.config.num_buckets)
            ]
            for doc_idx, doc in enumerate(data):
                self.stat_update(StatHints.total)
                shingles = self.get_shingles(doc.text)
                if shingles.size != 0:
                    sig = self.get_signature(shingles)
                    for bi, (bucket, bucket_sig) in enumerate(zip(buckets, sig)):
                        # print(f"{self.hashes_per_bucket=} {bucket_sig=}")
                        # each record: hashes_per_bucket hash values followed by the uint32 doc index, little-endian
                        bucket.write(
                            struct.pack(
                                f"<{self.config.hashes_per_bucket}{self.config.hash_config.struct_format}I",
                                *bucket_sig,
                                doc_idx,
                            )
                        )
            for file in buckets:
                file.close()

            logger.info("Sorting buckets...")
            for bi in range(self.config.num_buckets):
                # read all records, sort and write back
                dtype = np.dtype(
                    [
                        (f"field{i + 1}", f"<{self.config.hash_config.struct_format}")
                        for i in range(self.config.hashes_per_bucket)
                    ]
                    + [(f"field{self.config.hashes_per_bucket + 1}", "<I")]
                )
                with self.output_folder.open(f"bucket_{bi:03d}/{rank:05d}.minhash.sig", mode="rb") as fi:
                    records = np.frombuffer(fi.read(), dtype=dtype)
                    indices = np.argsort(records, order=dtype.names)
                with self.output_folder.open(f"bucket_{bi:03d}/{rank:05d}.minhash.sig", mode="wb") as fo:
                    for idx in indices:
                        fo.write(records[idx].tobytes())
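
The writer packs each signature as hashes_per_bucket hash values followed by a uint32 document index; the sort step then reads those fixed-width records back with a matching numpy structured dtype and argsorts them lexicographically across all fields. Below is a minimal, self-contained sketch of that record layout and sort, assuming hashes_per_bucket=3 and 64-bit hashes (struct_format="Q"); the constants and sample values are illustrative, not taken from the library's configuration.

import struct

import numpy as np

HASHES_PER_BUCKET = 3  # hypothetical config value
STRUCT_FORMAT = "Q"  # hypothetical: unsigned 64-bit hashes

# one record = HASHES_PER_BUCKET hash values followed by a uint32 document index
buf = b"".join(
    struct.pack(f"<{HASHES_PER_BUCKET}{STRUCT_FORMAT}I", *hashes, doc_idx)
    for doc_idx, hashes in enumerate([(5, 2, 9), (1, 8, 3), (1, 8, 2)])
)

# matching numpy structured dtype: one field per hash plus one for the doc index
dtype = np.dtype(
    [(f"field{i + 1}", f"<{STRUCT_FORMAT}") for i in range(HASHES_PER_BUCKET)]
    + [(f"field{HASHES_PER_BUCKET + 1}", "<I")]
)
records = np.frombuffer(buf, dtype=dtype)

# lexicographic sort over all fields (hashes first, doc index last), as in run()
indices = np.argsort(records, order=dtype.names)
print(records[indices])
# [(1, 8, 2, 2) (1, 8, 3, 1) (5, 2, 9, 0)]

Because both struct.pack with the "<" prefix and a field-list np.dtype produce packed, unaligned layouts, the bytes written per record line up exactly with the dtype's itemsize, which is what lets the sort step reinterpret the whole file with a single np.frombuffer call.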