def CNNDMSummarizationDataset()

in utils_nlp/dataset/cnndm.py [0:0]


def CNNDMSummarizationDataset(*args, **kwargs):
    """Load the CNN/Daily Mail dataset preprocessed by harvardnlp group.

    Downloads and caches the tarball, extracts it, and builds a pair of
    SummarizationDataset objects (train, test) from the four split files.

    Keyword Args (forwarded to the internal loader):
        top_n (int, optional): Maximum number of examples to load per split;
            -1 (default) loads all examples.
        local_cache_path (str, optional): Directory used to cache the
            downloaded archive. Defaults to ".data".
        prepare_extractive (bool, optional): If True (default), additionally
            word-tokenize the text for extractive summarization.

    Returns:
        tuple: (train SummarizationDataset, test SummarizationDataset).

    Raises:
        ValueError: If any of the expected split files is missing from the
            extracted archive.
    """

    URLS = ["https://s3.amazonaws.com/opennmt-models/Summary/cnndm.tar.gz"]

    def _setup_datasets(
        url, top_n=-1, local_cache_path=".data", prepare_extractive=True
    ):
        FILE_NAME = "cnndm.tar.gz"
        maybe_download(url, FILE_NAME, local_cache_path)
        dataset_tar = os.path.join(local_cache_path, FILE_NAME)
        extracted_files = extract_archive(dataset_tar)

        # Locate the four expected split files inside the archive. The
        # suffixes are mutually exclusive, so an elif chain suffices.
        train_source_file = None
        train_target_file = None
        test_source_file = None
        test_target_file = None
        for fname in extracted_files:
            if fname.endswith("train.txt.src"):
                train_source_file = fname
            elif fname.endswith("train.txt.tgt.tagged"):
                train_target_file = fname
            elif fname.endswith("test.txt.src"):
                test_source_file = fname
            elif fname.endswith("test.txt.tgt.tagged"):
                test_target_file = fname

        # Fail fast with an explicit message instead of an opaque
        # UnboundLocalError below if the archive layout is unexpected.
        missing = [
            suffix
            for suffix, path in (
                ("train.txt.src", train_source_file),
                ("train.txt.tgt.tagged", train_target_file),
                ("test.txt.src", test_source_file),
                ("test.txt.tgt.tagged", test_target_file),
            )
            if path is None
        ]
        if missing:
            raise ValueError(
                "Could not find expected file(s) in archive: {}".format(
                    ", ".join(missing)
                )
            )

        # Only extractive preprocessing needs word-level tokenization.
        extra_kwargs = (
            {"word_tokenize": nltk.word_tokenize} if prepare_extractive else {}
        )

        def _make_dataset(source_file, target_file):
            # Shared construction for the train and test splits.
            return SummarizationDataset(
                source_file,
                target_file=target_file,
                source_preprocessing=[_clean, tokenize.sent_tokenize],
                target_preprocessing=[
                    _clean,
                    _remove_ttags,
                    _target_sentence_tokenization,
                ],
                top_n=top_n,
                **extra_kwargs,
            )

        return (
            _make_dataset(train_source_file, train_target_file),
            _make_dataset(test_source_file, test_target_file),
        )

    return _setup_datasets(*((URLS[0],) + args), **kwargs)