in src/datatrove/pipeline/dedup/minhash.py [0:0]
def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1):
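    """Merge the sorted signature files of a single bucket (and, optionally, pre-built index files)
    into one ordered stream and write every pair of documents sharing an identical bucket signature
    to a binary .dups file.

    `data` must be None: this block reads the signatures produced by the previous minhash stage
    instead of consuming an input pipeline. `rank` and `world_size` determine which bucket and
    which slice of the hash range this task is responsible for.
    """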
    assert data is None, "You should not use an input block before MinhashDedupBuckets"
    assert (world_size % self.config.num_buckets) == 0, "Number of tasks must be divisible by num_buckets"
    workers_per_bucket = world_size // self.config.num_buckets
    bucket, bucket_worker = divmod(rank, workers_per_bucket)
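    # ranks are grouped by bucket: the first workers_per_bucket ranks handle bucket 0,
    # the next workers_per_bucket handle bucket 1, and so on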
    with self.track_time():
        sig_files = self.input_folder.list_files(subdirectory=f"bucket_{bucket:03d}")
        hash_min, hash_max = self.get_worker_hash_range(sig_files, rank, world_size)
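        # each worker of a bucket only merges the signatures that fall within its assigned hash range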
        logger.info(
            f"Running worker {bucket_worker + 1}/{workers_per_bucket} on bucket {bucket:03d}. "
            f"Hash range: {[hash_min, hash_max]}"
        )
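        # one lazy reader per signature file, yielding HashSig entries restricted to this worker's hash range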
        sig_readers = [
            read_sigs(
                file,
                file_i,
                self.config,
                min_hash=hash_min,
                max_hash=hash_max,
                lines_to_buffer=self.lines_to_buffer,
            )
            for file_i, file in enumerate(self.input_folder.open_files(sig_files, mode="rb"))
        ]
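        # pattern for the index file(s) this job itself writes for this bucket, so they can be skipped below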
        own_index_regex = re.compile(rf"bucket_{bucket:03d}/{self.create_index_name}_\d{{2}}.minhash.index")
        index_files = (
            [
                filename
                for filename in self.index_folder.list_files(subdirectory=f"bucket_{bucket:03d}")
                # exclude the index this job is itself writing (e.g. left over from a partial/interrupted
                # upload) as well as the copies written by the other workers of this same job
                if not self.create_index_name or not own_index_regex.fullmatch(filename)
            ]
            if self.index_folder
            else None
        )
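        # signatures from pre-existing index files are merged in alongside the new ones;
        # their reader ids continue after those already assigned to the input files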
        if index_files:
            logger.info(f"Found {len(index_files)} index file(s): {', '.join(index_files)}")
            sig_readers.extend(
                [
                    read_sigs(
                        file,
                        len(sig_readers) + file_i,
                        self.config,
                        index_file=True,
                        min_hash=hash_min,
                        max_hash=hash_max,
                        lines_to_buffer=self.lines_to_buffer,
                    )
                    for file_i, file in enumerate(self.index_folder.open_files(index_files, mode="rb"))
                ]
            )
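        # k-way merge: seed a min-heap with the first signature from each reader,
        # skipping readers that have nothing in this worker's hash range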
        pq = [x for x in [next(sig_reader, None) for sig_reader in sig_readers] if x is not None]
        heapq.heapify(pq)
        logger.info("Finished initializing signatures priority queue.")
        # out index file
        out_index = None
        if self.index_folder and self.create_index_name:
            out_index = self.index_folder.open(
                f"bucket_{bucket:03d}/{self.create_index_name}_{bucket_worker:02d}.minhash.index", mode="wb"
            )
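        # pop signatures in globally sorted order: two consecutive entries with an identical signature
        # form a duplicate pair, written as fixed-size binary records to the .dups output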
        with self.output_folder.open(f"{bucket:05d}_{bucket_worker:02d}.dups", mode="wb") as out_f:
            last: HashSig | None = None
            while pq:
                v: HashSig = heapq.heappop(pq)
                assert last is None or v >= last, f"Sig queue sort error. {v=} < {last=}"
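                # only documents coming from the input (not from an index) can be flagged as duplicates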
                if not v.is_from_index():
                    if last is not None and last.sig == v.sig:
                        # write (file_id1, doc_id1, file_id2, doc_id2)
                        if last.is_from_index():
                            # we can't actually write -1, so we use SENTINEL instead
                            out_f.write(struct.pack("<4I", SENTINEL, SENTINEL, int(v.file_stem), v.doc_id))
                            self.stat_update("index_match", "total_matches")
                        # if there isn't an index, or we are not only deduping in relation to the index
                        elif not index_files or not self.only_dedup_in_index:
                            out_f.write(
                                struct.pack("<4I", int(last.file_stem), last.doc_id, int(v.file_stem), v.doc_id)
                            )
                            self.stat_update("total_matches")
                    elif out_index:
                        # new sig that isn't part of any index, save to our new index
                        out_index.write(
                            struct.pack(
                                f"<%d{self.config.hash_config.struct_format}" % self.config.hashes_per_bucket,
                                *v.sig,
                            )
                        )
                    last = v
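                # advance the reader this signature came from and keep the heap filled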
                next_sig = next(sig_readers[v.reader_id], None)
                if next_sig:
                    assert next_sig >= v, f"Next sig sort error. {next_sig=} < {v=}"
                    heapq.heappush(pq, next_sig)
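        # close the index file we built, if any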
        if out_index:
            out_index.close()
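    # Note: each duplicate record written above is 16 bytes, four little-endian uint32 values
    # (file_id1, doc_id1, file_id2, doc_id2); matches against an index use SENTINEL for the first pair.
    # A minimal sketch (illustrative only, not the library's actual reader) of iterating such a file:
    #   with output_folder.open(dups_file, mode="rb") as f:
    #       while (record := f.read(16)):
    #           file_id1, doc_id1, file_id2, doc_id2 = struct.unpack("<4I", record)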