def main()

in decontamination/decontaminate.py

This function builds an n-gram index over an evaluation dataset, flags training examples that share n-grams with it, publishes a contamination report, and optionally publishes a decontaminated copy of the training set.


def main(args):
    # Load the evaluation data and build an n-gram index over it
    eval_ngrams, eval_datasets, eval_texts = {}, [], []
    eval_data = load_dataset(args.eval_dataset, split="train", num_proc=args.num_proc)
    for example in tqdm(eval_data):
        tokens = tokenize(example["text"])
        ngrams = get_ngrams(tokens, args.ngram_length)
        if ngrams:
            # Map each n-gram to the index of the eval example it came from
            # (an n-gram shared by several eval examples keeps the last index)
            idx = len(eval_texts)
            eval_ngrams.update(zip(ngrams, [idx] * len(ngrams)))
            eval_datasets.append(example.get("task_name", "unknown"))
            eval_texts.append(example["text"])

    # Load the training data, either from a local JSON/CSV file or from the Hub
    train_dataset_path = Path(args.train_dataset)
    if train_dataset_path.exists() and train_dataset_path.suffix in [".json", ".csv"]:
        if train_dataset_path.suffix == ".json":
            train_data = Dataset.from_json(args.train_dataset)
        else:
            train_data = Dataset.from_csv(args.train_dataset)
    else:
        train_data = load_dataset(args.train_dataset, split="train", num_proc=args.num_proc)

    # Scan the training data for n-grams that also appear in the eval index
    contamination_report = train_data.map(
        lambda batch: retrieve_ngrams_batch(batch, eval_ngrams, eval_datasets, eval_texts, args.ngram_length),
        batched=True, batch_size=1000, num_proc=args.num_proc, remove_columns=train_data.column_names
    )

    # Add match statistics (including the diff_ratio used for filtering below)
    contamination_report = contamination_report.map(add_match_stats, num_proc=args.num_proc)

    # Publish the full contamination report
    contamination_report.push_to_hub(args.report_dataset_name, private=args.private)

    # Keep only matches whose similarity exceeds the threshold
    contamination_report = contamination_report.filter(lambda x: x["diff_ratio"] > args.diff_threshold)

    if args.save_decontaminated:
        # Drop every training example whose completion was flagged above
        contaminated_completions = set(contamination_report["completion"])
        filtered_data = train_data.filter(lambda x: x["completion"] not in contaminated_completions)
        filtered_data.push_to_hub(args.decontaminated_dataset_name, private=args.private)
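
The helper functions called above are defined elsewhere in decontamination/decontaminate.py and are not part of this excerpt. A minimal sketch of what tokenize and get_ngrams are assumed to do (whitespace tokenization and contiguous n-grams; the real implementations may differ):

def tokenize(text):
    # Assumed: lowercase, whitespace-split tokenization.
    return text.lower().split()

def get_ngrams(tokens, n):
    # Assumed: all contiguous n-grams as tuples, so they can serve as
    # dictionary keys in eval_ngrams. Returns [] when len(tokens) < n,
    # which the "if ngrams:" guard above relies on.
    return [tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)]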
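
retrieve_ngrams_batch is also not shown. Judging from its call site (a batched map over train_data that drops the original columns, and a report carrying a "completion" column), a hypothetical sketch is below; the "ngram", "bench_name", and "bench_text" column names are assumptions, and it reuses the sketched tokenize/get_ngrams:

def retrieve_ngrams_batch(batch, eval_ngrams, eval_datasets, eval_texts, ngram_len):
    # Sketch: emit one report row for each training completion that shares
    # an n-gram with the eval index.
    new_batch = {"completion": [], "ngram": [], "bench_name": [], "bench_text": []}
    for completion in batch["completion"]:
        for ngram in get_ngrams(tokenize(completion), ngram_len):
            if ngram in eval_ngrams:
                idx = eval_ngrams[ngram]
                new_batch["completion"].append(completion)
                new_batch["ngram"].append(" ".join(ngram))
                new_batch["bench_name"].append(eval_datasets[idx])
                new_batch["bench_text"].append(eval_texts[idx])
                break  # one hit is enough to flag the example
    return new_batch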
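
add_match_stats produces the diff_ratio that the report is filtered on. A sketch assuming it is a difflib similarity ratio between the flagged completion and the matched benchmark text (bench_text is the assumed field name carried over from the sketch above):

import difflib

def add_match_stats(example):
    # Assumed: diff_ratio is a similarity score in [0, 1]; higher values
    # mean the training completion is closer to the benchmark text.
    example["diff_ratio"] = difflib.SequenceMatcher(
        None, example["completion"], example["bench_text"]
    ).ratio()
    return example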
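
Assuming the script exposes argparse flags named after the args attributes used in main (an assumption; check the script's actual parser), an invocation could look like:

python decontamination/decontaminate.py \
    --train_dataset my_org/my_train_set \
    --eval_dataset my_org/eval_benchmarks \
    --report_dataset_name my_org/contamination_report \
    --decontaminated_dataset_name my_org/my_train_set_decontaminated \
    --ngram_length 10 \
    --diff_threshold 0.5 \
    --num_proc 16 \
    --save_decontaminated \
    --private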