def WMT14()

in torchtext/experimental/datasets/raw/wmt14.py [0:0]


def WMT14(root, split,
          language_pair=('de', 'en'),
          train_set='train.tok.clean.bpe.32000',
          valid_set='newstest2013.tok.bpe.32000',
          test_set='newstest2014.tok.bpe.32000'):
    """WMT14 Dataset

    The available datasets include the following:

    **Language pairs**:

    +-----+-----+-----+
    |     |'en' |'de' |
    +-----+-----+-----+
    |'en' |     |   x |
    +-----+-----+-----+
    |'de' |  x  |     |
    +-----+-----+-----+


    Args:
        root: Directory where the datasets are saved. Default: ".data"
        split: split or splits to be returned. Can be a string or tuple of strings. Default: ('train', 'valid', 'test')
        language_pair: tuple or list containing src and tgt language
        train_set: A string to identify train set.
        valid_set: A string to identify validation set.
        test_set: A string to identify test set.

    Returns:
        A ``_RawTextIterableDataset`` whose items are ``(src, tgt)`` pairs read
        in lockstep from the source- and target-language files of the split.

    Raises:
        ValueError: if a language, set identifier, or split is not supported.
        FileNotFoundError: if the source or target file of the split is not
            present in the extracted archive.

    Examples:
        >>> from torchtext.datasets import WMT14
        >>> train_iter, valid_iter, test_iter = WMT14()
        >>> src_sentence, tgt_sentence = next(train_iter)
    """

    supported_language = ['en', 'de']
    supported_train_set = [s for s in NUM_LINES if 'train' in s]
    # Validation sets are newstest files too (see the default valid_set
    # 'newstest2013...'), so they are selected with the same 'test' filter.
    supported_valid_set = [s for s in NUM_LINES if 'test' in s]
    supported_test_set = [s for s in NUM_LINES if 'test' in s]

    assert (len(language_pair) == 2), 'language_pair must contain only 2 elements: src and tgt language respectively'
    src_lang, tgt_lang = language_pair

    if src_lang not in supported_language:
        raise ValueError("Source language '{}' is not supported. Valid options are {}".
                         format(src_lang, supported_language))

    if tgt_lang not in supported_language:
        raise ValueError("Target language '{}' is not supported. Valid options are {}".
                         format(tgt_lang, supported_language))

    if train_set not in supported_train_set:
        raise ValueError("'{}' is not a valid train set identifier. valid options are {}".
                         format(train_set, supported_train_set))

    if valid_set not in supported_valid_set:
        raise ValueError("'{}' is not a valid valid set identifier. valid options are {}".
                         format(valid_set, supported_valid_set))

    if test_set not in supported_test_set:
        raise ValueError("'{}' is not a valid test set identifier. valid options are {}".
                         format(test_set, supported_test_set))

    # This body handles exactly one split string per call. NOTE(review): tuple
    # splits, as described in the docstring, are presumably expanded by a
    # wrapper outside this view — confirm.
    split_to_set = {'train': train_set, 'valid': valid_set, 'test': test_set}
    # Reject unknown splits explicitly instead of silently falling back to the
    # test split; the isinstance check also keeps unhashable values out of the
    # dict lookup.
    if not isinstance(split, str) or split not in split_to_set:
        raise ValueError("'{}' is not a valid split. Valid options are {}".
                         format(split, list(split_to_set)))

    selected_set = split_to_set[split]
    src_file = '{}.{}'.format(selected_set, src_lang)
    tgt_file = '{}.{}'.format(selected_set, tgt_lang)

    dataset_tar = download_from_url(URL, root=root, hash_value=MD5, path=os.path.join(root, _PATH), hash_type='md5')
    extracted_files = extract_archive(dataset_tar)

    filepaths = _construct_filepaths(extracted_files, src_file, tgt_file)
    # Check for None before calling len() (which would raise TypeError on None).
    # Use a real exception rather than assert, so the check survives `python -O`.
    if (filepaths is None or len(filepaths) < 2
            or filepaths[0] is None or filepaths[1] is None):
        raise FileNotFoundError(
            "Files are not found for data type {}".format(split))

    src_data_iter = _read_text_iterator(filepaths[0])
    tgt_data_iter = _read_text_iterator(filepaths[1])

    # The line count is keyed by the set identifier; src_file is just that
    # identifier plus a validated, dot-free language suffix.
    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[selected_set], zip(src_data_iter, tgt_data_iter))