# torchtext/experimental/datasets/raw/wmt14.py
def WMT14(root, split,
language_pair=('de', 'en'),
train_set='train.tok.clean.bpe.32000',
valid_set='newstest2013.tok.bpe.32000',
test_set='newstest2014.tok.bpe.32000'):
"""WMT14 Dataset
    The available datasets include the following:
**Language pairs**:
+-----+-----+-----+
| |'en' |'de' |
+-----+-----+-----+
|'en' | | x |
+-----+-----+-----+
|'de' | x | |
+-----+-----+-----+
Args:
root: Directory where the datasets are saved. Default: ".data"
        split: split or splits to be returned. Can be a string or tuple of strings. Default: ('train', 'valid', 'test')
language_pair: tuple or list containing src and tgt language
        train_set: A string to identify the train set.
        valid_set: A string to identify the validation set.
        test_set: A string to identify the test set.
Examples:
>>> from torchtext.datasets import WMT14
        >>> train_iter, valid_iter, test_iter = WMT14(root='.data', split=('train', 'valid', 'test'))
>>> src_sentence, tgt_sentence = next(train_iter)
"""
supported_language = ['en', 'de']
supported_train_set = [s for s in NUM_LINES if 'train' in s]
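    # Both newstest2013 (validation) and newstest2014 (test) identifiers contain 'test'.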
supported_valid_set = [s for s in NUM_LINES if 'test' in s]
supported_test_set = [s for s in NUM_LINES if 'test' in s]
assert (len(language_pair) == 2), 'language_pair must contain only 2 elements: src and tgt language respectively'
if language_pair[0] not in supported_language:
raise ValueError("Source language '{}' is not supported. Valid options are {}".
format(language_pair[0], supported_language))
if language_pair[1] not in supported_language:
raise ValueError("Target language '{}' is not supported. Valid options are {}".
format(language_pair[1], supported_language))
    if train_set not in supported_train_set:
        raise ValueError("'{}' is not a valid train set identifier. Valid options are {}".
                         format(train_set, supported_train_set))
    if valid_set not in supported_valid_set:
        raise ValueError("'{}' is not a valid validation set identifier. Valid options are {}".
                         format(valid_set, supported_valid_set))
    if test_set not in supported_test_set:
        raise ValueError("'{}' is not a valid test set identifier. Valid options are {}".
                         format(test_set, supported_test_set))
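    # Append the language suffix to each identifier, e.g.
    # ('train.tok.clean.bpe.32000.de', 'train.tok.clean.bpe.32000.en') for the defaults.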
train_filenames = '{}.{}'.format(train_set, language_pair[0]), '{}.{}'.format(train_set, language_pair[1])
valid_filenames = '{}.{}'.format(valid_set, language_pair[0]), '{}.{}'.format(valid_set, language_pair[1])
test_filenames = '{}.{}'.format(test_set, language_pair[0]), '{}.{}'.format(test_set, language_pair[1])
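    # Select the source/target file pair for the requested split; any value other
    # than 'train' or 'valid' falls through to the test set.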
if split == 'train':
src_file, tgt_file = train_filenames
elif split == 'valid':
src_file, tgt_file = valid_filenames
else:
src_file, tgt_file = test_filenames
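    # Download the WMT14 archive into root (verified against MD5) and extract it.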
dataset_tar = download_from_url(URL, root=root, hash_value=MD5, path=os.path.join(root, _PATH), hash_type='md5')
extracted_files = extract_archive(dataset_tar)
data_filenames = {
split: _construct_filepaths(extracted_files, src_file, tgt_file),
}
for key in data_filenames:
        if data_filenames[key] is None or len(data_filenames[key]) == 0:
raise FileNotFoundError(
"Files are not found for data type {}".format(key))
assert data_filenames[split][0] is not None, "Internal Error: File not found for reading"
assert data_filenames[split][1] is not None, "Internal Error: File not found for reading"
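    # Open line iterators over the selected source and target files.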
src_data_iter = _read_text_iterator(data_filenames[split][0])
tgt_data_iter = _read_text_iterator(data_filenames[split][1])
    def _iter(src_data_iter, tgt_data_iter):
        yield from zip(src_data_iter, tgt_data_iter)
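    # os.path.splitext strips the language suffix, so src_file maps back to its NUM_LINES key.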
return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[os.path.splitext(src_file)[0]], _iter(src_data_iter, tgt_data_iter))
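
# Usage sketch (hypothetical, kept as comments so it does not run on import):
# assumes URL, MD5, NUM_LINES, DATASET_NAME and the helper functions used above
# are defined at module level, as in torchtext's other raw translation datasets.
#
#     train_iter = WMT14('.data', 'train', language_pair=('de', 'en'))
#     src_sentence, tgt_sentence = next(iter(train_iter))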