in decontamination/decontaminate.py [0:0]
def main(args):
    """Flag training examples contaminated with eval-set n-grams.

    Builds an n-gram -> eval-example index from ``args.eval_dataset``,
    scans ``args.train_dataset`` for overlaps, publishes a contamination
    report, and (optionally) publishes a decontaminated copy of the
    training data with the flagged completions removed.
    """
    # Map each eval n-gram to the position of the example it came from
    # (later examples overwrite earlier ones on collision, as before).
    ngram_to_idx = {}
    task_names = []
    reference_texts = []
    evaluation = load_dataset(args.eval_dataset, split="train", num_proc=args.num_proc)
    for sample in tqdm(evaluation):
        grams = get_ngrams(tokenize(sample["text"]), args.ngram_length)
        if grams:
            position = len(reference_texts)
            for gram in grams:
                ngram_to_idx[gram] = position
            task_names.append(sample.get("task_name", "unknown"))
            reference_texts.append(sample["text"])

    # Training data comes either from a local json/csv file or from the Hub.
    source_path = Path(args.train_dataset)
    if source_path.exists() and source_path.suffix in (".json", ".csv"):
        loader = Dataset.from_json if source_path.suffix == ".json" else Dataset.from_csv
        train_data = loader(args.train_dataset)
    else:
        train_data = load_dataset(args.train_dataset, split="train", num_proc=args.num_proc)

    # Look up every training batch against the eval n-gram index.
    report = train_data.map(
        lambda batch: retrieve_ngrams_batch(batch, ngram_to_idx, task_names, reference_texts, args.ngram_length),
        batched=True,
        batch_size=1000,
        num_proc=args.num_proc,
        remove_columns=train_data.column_names,
    )
    report = report.map(lambda example: add_match_stats(example), num_proc=args.num_proc)

    # Publish the full report first, then keep only matches above the
    # similarity threshold for the actual decontamination step.
    report.push_to_hub(args.report_dataset_name, private=args.private)
    report = report.filter(lambda x: x["diff_ratio"] > args.diff_threshold)
    if args.save_decontaminated:
        flagged_completions = set(report["completion"])
        cleaned = train_data.filter(lambda x: x["completion"] not in flagged_completions)
        cleaned.push_to_hub(args.decontaminated_dataset_name, private=args.private)