in src/datatrove/pipeline/stats/merger.py [0:0]
def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> DocumentsPipeline:
    """
    Each node reads its shard of stat folders and merges the stat files in each folder into a single merged file.

    Args:
        data: DocumentsPipeline: documents to pass through unchanged
        rank: int: rank of this task (Default value = 0)
        world_size: int: total number of tasks (Default value = 1)
    """
    folders_shard = self.get_leaf_non_empty_folders()[rank::world_size]
    logger.info(f"Merging {len(folders_shard)} stat folders")
    with self.track_time():
        for folder in tqdm(folders_shard):
            input_files = self.input_folder.glob(f"{folder}/[0-9][0-9][0-9][0-9][0-9].json")
            logger.info(f"Processing folder {folder} with {len(input_files)} files")
            stat = MetricStatsDict()
            for file in tqdm(input_files):
                # Use inplace add to avoid creating a new dict
                with self.input_folder.open(file, "rt") as f:
                    for key, item in json.load(f).items():
                        stat[key] += MetricStats.from_dict(item)
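            # write the merged stats, optionally keeping only the top_k keys (by count n) for configured groups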
            with self.output_folder.open(f"{folder}/{STATS_MERGED_NAME}", "wt") as f:
                group_name = Path(folder).parent.name
                if group_name in self.top_k_config.top_k_groups:
                    top_k_keys = heapq.nlargest(self.top_k_config.top_k, stat, key=lambda x: stat.get(x).n)
                    stat = MetricStatsDict(init={s: stat.get(s) for s in top_k_keys})
                json.dump(stat.to_dict(), f)
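            # optionally remove the now-merged per-task input files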
            if self.remove_input:
                for file in input_files:
                    self.input_folder.rm(file)
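    # pass any incoming documents through unchanged so this block can sit inside a larger pipeline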
    if data:
        yield from data
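For reference, here is a minimal standalone sketch of the two core mechanics used above: striding the folder list with [rank::world_size] so every task gets a disjoint shard, and merging per-key stats by in-place addition, plus the heapq-based top-k filtering. SimpleStats, the folder names, and the in-memory shard dicts are hypothetical stand-ins for datatrove's MetricStats objects and the on-disk 00000.json files; this is an illustration under those assumptions, not the library's API.

from dataclasses import dataclass
import heapq


@dataclass
class SimpleStats:
    # hypothetical stand-in for MetricStats: a value count and a running total
    n: int = 0
    total: float = 0.0

    def __iadd__(self, other: "SimpleStats") -> "SimpleStats":
        self.n += other.n
        self.total += other.total
        return self


# striding with [rank::world_size] gives each task a disjoint shard that covers all folders
folders = [f"stat_group/metric/{i:05d}" for i in range(10)]
world_size = 4
shards = [folders[rank::world_size] for rank in range(world_size)]
assert sorted(sum(shards, [])) == sorted(folders)

# merging per-key stats from two shard files is a key-wise in-place add, as in the loop above
file_a = {"en": SimpleStats(n=3, total=120.0), "fr": SimpleStats(n=1, total=40.0)}
file_b = {"en": SimpleStats(n=2, total=80.0)}
merged: dict[str, SimpleStats] = {}
for shard_file in (file_a, file_b):
    for key, item in shard_file.items():
        merged.setdefault(key, SimpleStats())
        merged[key] += item
print(merged["en"])  # SimpleStats(n=5, total=200.0)

# top-k filtering keeps only the keys with the largest counts, mirroring the top_k_config branch
top_keys = heapq.nlargest(1, merged, key=lambda k: merged[k].n)
print(top_keys)  # ['en']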