in src/datatrove/pipeline/stats/base.py
def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> DocumentsPipeline:
    # one MetricStatsDict per stat name, kept separately for every grouping
    groups_dicts: dict[GROUP, dict[str, MetricStatsDict]] = {
        group: defaultdict(MetricStatsDict) for group in self.groups
    }
    for doc in data:
        with self.track_time():
            try:
                doc_stats = self.extract_stats(doc)
            except Exception as e:
                logger.error(f"Error while extracting stats from document {doc.id}", exc_info=e)
                raise e
            for group, counters in groups_dicts.items():
                for stat, value in doc_stats.items():
                    key, value = self.get_kv(doc, value, group)
                    if not isinstance(value, dict):
                        counters[stat][key] += value
                    else:
                        # each key in this dictionary is a suffix for the main stat
                        for suffix, val in value.items():
                            stat_name = stat if not suffix else f"{stat}__{suffix}"
                            counters[stat_name][key] += val
        doc.metadata.update(doc_stats)
        yield doc
    # save to disk: one JSON file per (group, stat_name), sharded by rank
    for group, stats_dict in groups_dicts.items():
        group_top_k_keys = None
        for stat_name, stat_values in stats_dict.items():
            if group in self.top_k_cfg.top_k_groups:
                # no need to recompute this for every stat in the group, as stat.n is constant
                if group_top_k_keys is None:
                    group_top_k_keys = heapq.nlargest(
                        self.top_k_cfg.top_k, stat_values, key=lambda x: stat_values[x].n
                    )
                stat_values = MetricStatsDict(init={s: stat_values[s] for s in group_top_k_keys})
            with self.output_folder.open(f"{group}/{stat_name}/{rank:05d}.json", "wt") as f:
                json.dump(stat_values.to_dict(), f)
    # delete groups_dicts to free memory
    del groups_dicts
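
The aggregation branch on `isinstance(value, dict)` is the easiest part to misread: a dict-valued stat fans out into several counters, one per suffix, with an empty suffix keeping the base stat name. Below is a minimal, self-contained sketch of just that step; the plain `Counter`, the fixed `key`, and the `doc_stats` payload are hypothetical stand-ins for `MetricStatsDict`, `self.get_kv`, and the output of `extract_stats`.

from collections import Counter, defaultdict

# Hypothetical stand-ins: a Counter per stat name instead of a MetricStatsDict,
# and a fixed grouping key instead of whatever self.get_kv(doc, value, group) returns
counters: dict[str, Counter] = defaultdict(Counter)
key = "summary"

# made-up extract_stats() output: one scalar stat and one dict-valued stat
doc_stats = {
    "length": 1200,
    "lines": {"": 34, "ended_with_punct": 28},
}

for stat, value in doc_stats.items():
    if not isinstance(value, dict):
        counters[stat][key] += value
    else:
        for suffix, val in value.items():
            # an empty suffix keeps the base name; others become "stat__suffix"
            stat_name = stat if not suffix else f"{stat}__{suffix}"
            counters[stat_name][key] += val

print(dict(counters))
# {'length': Counter({'summary': 1200}),
#  'lines': Counter({'summary': 34}),
#  'lines__ended_with_punct': Counter({'summary': 28})}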
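The top-k trimming before saving is also worth isolating. For groups listed in `top_k_cfg.top_k_groups`, only the `top_k` keys with the most samples are kept, and the ranking is computed once per group because `.n` (the sample count) depends only on how many documents fell under each key, not on which stat is being saved. A small sketch, assuming a hypothetical `Stat` stand-in that carries nothing but `.n`:

import heapq
from dataclasses import dataclass

@dataclass
class Stat:
    n: int  # sample count; the only field the ranking below looks at

# made-up per-key stats for one group, e.g. keyed by language
stat_values = {"en": Stat(n=900), "de": Stat(n=50), "fr": Stat(n=120)}
top_k = 2

# same selection run() makes: the top_k keys with the largest sample counts
top_keys = heapq.nlargest(top_k, stat_values, key=lambda x: stat_values[x].n)
trimmed = {k: stat_values[k] for k in top_keys}
print(trimmed)  # {'en': Stat(n=900), 'fr': Stat(n=120)}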
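Finally, note the output layout: each (group, stat_name) pair gets its own directory, and each rank writes a separate `{rank:05d}.json` shard inside it, so parallel workers never touch the same file and need no coordination; the per-rank shards are combined later (in datatrove, typically by a merger step such as `StatsMerger`). A sketch of the resulting layout, assuming a plain local directory in place of `self.output_folder` and a placeholder payload rather than the real `MetricStatsDict.to_dict()` shape:

import json
import os

# Assumptions: a local directory stands in for self.output_folder, and the
# dumped payload is a placeholder, not the real serialized stats
output_folder = "stats_output"
group, stat_name, rank = "summary", "length", 3

path = os.path.join(output_folder, group, stat_name, f"{rank:05d}.json")
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "wt") as f:
    json.dump({"placeholder": 1}, f)
print(path)  # stats_output/summary/length/00003.json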