in src/datatrove/pipeline/decont/n_grams.py
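NGramsDecontIndexer.run (builds the per-task n-gram hash index used for decontamination):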
    def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1):
        if world_size != 1:
            raise ValueError("Decontamination index building requires a single worker.")
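        # task name -> set of n-gram hashes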
        hashes = defaultdict(set)
        # use whatever data is passed in, with the following format:
        # doc.text -> label
        # doc.metadata["query"] -> query
        if data:
            for doc in data:
                if not self.config.find_query_ngrams and "query" not in doc.metadata:
                    raise ValueError(
                        "find_query_ngrams is False but could not find 'query' field in documents metadata"
                    )
                hashes[doc.metadata.get("task", "input")].update(
                    self.compute_hashes(doc.text, doc.metadata.get("query", None))
                )

        # parse data from lighteval-defined tasks
        from lighteval.tasks.lighteval_task import LightevalTask
        from lighteval.tasks.registry import Registry

        task_dict = Registry(cache_dir=os.getenv("HF_HOME"), custom_tasks=self.custom_lighteval_tasks).get_task_dict(
            self.lighteval_tasks
        )
        LightevalTask.load_datasets(task_dict.values())
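        # hash the gold answers of every eval doc, passing the query along to compute_hashes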
        for task_name, task in task_dict.items():
            for eval_doc in task.eval_docs():
                try:
                    golds = eval_doc.get_golds()
                    query = eval_doc.query
                except Exception as e:
                    logger.warning(f"Error while fetching doc data: {e}")
                    continue
                for gold in golds:
                    hashes[task_name].update(self.compute_hashes(gold, query))

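        # write one binary index file per task; numpy's tofile() needs a real local file, so remote folders get raw bytes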
        for task_name, task_hashes in hashes.items():
            hashes_array = np.array(list(task_hashes), dtype=self.config.hash_config.np_descr)
            logger.info(f"Saving {len(task_hashes)} hashes for {task_name}")
            with self.output_folder.open(f"{task_name.replace(' ', '_')}.index.hashes", mode="wb") as f:
                if self.output_folder.is_local():
                    hashes_array.tofile(f)
                else:
                    f.write(hashes_array.tobytes())
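
# A minimal usage sketch, assuming datatrove's LocalPipelineExecutor and this module's
# NGramsDecontIndexer / NGramsDecontConfig; the constructor arguments shown (output_folder,
# lighteval_tasks, config) and the example task spec are assumptions inferred from the
# attributes referenced in run() above, not a verified signature. run() refuses to build
# the index with more than one worker, so the executor is limited to a single task.
from datatrove.executor import LocalPipelineExecutor
from datatrove.pipeline.decont.n_grams import NGramsDecontConfig, NGramsDecontIndexer

executor = LocalPipelineExecutor(
    pipeline=[
        NGramsDecontIndexer(
            output_folder="decont_index/",  # hypothetical output path for the *.index.hashes files
            lighteval_tasks=["lighteval|gsm8k|0|0"],  # hypothetical lighteval task spec; adjust suite/task names
            config=NGramsDecontConfig(),  # assumed to be constructible with defaults
        ),
    ],
    tasks=1,  # decontamination index building requires a single worker
)
executor.run()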