in NMT/src/data/loader.py [0:0]
def check_all_data_params(params):
    """
    Check and normalize dataset parameters.

    Parses the string-encoded attributes of `params` in place (languages,
    monolingual / parallel / back-parallel dataset paths, training
    directions, vocabulary files) into lists, tuples and dictionaries, and
    asserts that they are mutually consistent and that every referenced
    data file exists on disk.

    Raises:
        AssertionError: if any parameter is inconsistent or a file is missing.
    """
    # check languages: unique, at least two, and provided in sorted order
    params.langs = params.langs.split(',')
    assert len(params.langs) == len(set(params.langs)) >= 2
    assert sorted(params.langs) == params.langs
    params.id2lang = {k: v for k, v in enumerate(sorted(params.langs))}
    params.lang2id = {k: v for v, k in params.id2lang.items()}
    params.n_langs = len(params.langs)

    # check monolingual datasets ("lang:train,valid,test" entries joined by ';')
    params.mono_dataset = {k: v for k, v in [x.split(':') for x in params.mono_dataset.split(';') if len(x) > 0]}
    # monolingual datasets are given if and only if n_mono != 0
    assert not (len(params.mono_dataset) == 0) ^ (params.n_mono == 0)
    if len(params.mono_dataset) > 0:
        assert type(params.mono_dataset) is dict
        assert all(lang in params.langs for lang in params.mono_dataset.keys())
        assert all(len(v.split(',')) == 3 for v in params.mono_dataset.values())
        params.mono_dataset = {k: tuple(v.split(',')) for k, v in params.mono_dataset.items()}
        # the train file (index 0) is mandatory; valid/test files (index > 0) may be empty
        assert all(all(((i > 0 and path == '') or os.path.isfile(path)) for i, path in enumerate(paths))
                   for paths in params.mono_dataset.values())

    # check parallel datasets ("lang1-lang2:train,valid,test" entries joined by ';')
    params.para_dataset = {k: v for k, v in [x.split(':') for x in params.para_dataset.split(';') if len(x) > 0]}
    assert type(params.para_dataset) is dict
    assert all(len(k.split('-')) == 2 for k in params.para_dataset.keys())
    assert all(len(v.split(',')) == 3 for v in params.para_dataset.values())
    params.para_dataset = {tuple(k.split('-')): tuple(v.split(',')) for k, v in params.para_dataset.items()}
    # n_para == 0 must hold exactly when no parallel train set is given
    assert not (params.n_para == 0) ^ (all(v[0] == '' for v in params.para_dataset.values()))
    for (lang1, lang2), (train_path, valid_path, test_path) in params.para_dataset.items():
        assert lang1 < lang2 and lang1 in params.langs and lang2 in params.langs
        # 'XX' in a path is a placeholder substituted with each language code;
        # the train set is optional, valid/test sets are mandatory
        assert train_path == '' or os.path.isfile(train_path.replace('XX', lang1))
        assert train_path == '' or os.path.isfile(train_path.replace('XX', lang2))
        assert os.path.isfile(valid_path.replace('XX', lang1))
        assert os.path.isfile(valid_path.replace('XX', lang2))
        assert os.path.isfile(test_path.replace('XX', lang1))
        assert os.path.isfile(test_path.replace('XX', lang2))

    # check back-parallel datasets ("lang1-lang2:src_path,tgt_path" entries joined by ';')
    params.back_dataset = {k: v for k, v in [x.split(':') for x in params.back_dataset.split(';') if len(x) > 0]}
    assert type(params.back_dataset) is dict
    # back-parallel datasets are given if and only if n_back != 0
    assert not (len(params.back_dataset) == 0) ^ (params.n_back == 0)
    assert all(len(k.split('-')) == 2 for k in params.back_dataset.keys())
    assert all(len(v.split(',')) == 2 for v in params.back_dataset.values())
    params.back_dataset = {
        tuple(k.split('-')): tuple(v.split(','))
        for k, v in params.back_dataset.items()
    }
    for (lang1, lang2), (src_path, tgt_path) in params.back_dataset.items():
        assert lang1 in params.langs and lang2 in params.langs
        assert os.path.isfile(src_path)
        assert os.path.isfile(tgt_path)

    # check parallel directions ("lang1-lang2" entries joined by ',')
    params.para_directions = [x.split('-') for x in params.para_directions.split(',') if len(x) > 0]
    if len(params.para_directions) > 0:
        assert params.n_para != 0
        assert type(params.para_directions) is list
        assert all(len(x) == 2 for x in params.para_directions)
        params.para_directions = [tuple(x) for x in params.para_directions]
        assert len(params.para_directions) == len(set(params.para_directions))
        # check that every direction has an associated train set
        for lang1, lang2 in params.para_directions:
            assert lang1 in params.langs and lang2 in params.langs
            # parallel datasets are keyed by the lexicographically sorted pair
            k = (lang1, lang2) if lang1 < lang2 else (lang2, lang1)
            assert k in params.para_dataset
            assert params.para_dataset[k][0] != ''

    # check monolingual directions (language codes joined by ',')
    params.mono_directions = [x for x in params.mono_directions.split(',') if len(x) > 0]
    if len(params.mono_directions) > 0:
        assert params.n_mono != 0
        assert type(params.mono_directions) is list
        assert all(lang in params.langs for lang in params.mono_directions)
        assert all(lang in params.mono_dataset for lang in params.mono_directions)

    # check directions with pivot ("lang1-lang2-lang3" entries joined by ',')
    params.pivo_directions = [x.split('-') for x in params.pivo_directions.split(',') if len(x) > 0]
    if len(params.pivo_directions) > 0:
        assert type(params.pivo_directions) is list
        assert all(len(x) == 3 for x in params.pivo_directions)
        params.pivo_directions = [tuple(x) for x in params.pivo_directions]
        assert len(params.pivo_directions) == len(set(params.pivo_directions))
        # check that every direction has the data it needs
        for lang1, lang2, lang3 in params.pivo_directions:
            assert lang1 in params.langs
            assert lang2 in params.langs
            assert lang3 in params.langs
            # lang2 == lang3 (2-lang, autoencoding step):
            # requires a parallel training set between lang1 and lang2
            if lang1 != lang2 == lang3:
                k = (lang1, lang2) if lang1 < lang2 else (lang2, lang1)
                assert k in params.para_dataset
                assert params.para_dataset[k][0] != ''
            # lang1 == lang3 (2-lang round trip lang1 -> lang2 -> lang1):
            # requires monolingual data in lang1
            elif lang1 == lang3 != lang2:
                assert lang1 in params.mono_dataset
            # three distinct languages:
            # requires a parallel training set between lang1 and lang3
            else:
                assert lang1 != lang2 and lang2 != lang3 and lang1 != lang3
                k = (lang1, lang3) if lang1 < lang3 else (lang3, lang1)
                assert k in params.para_dataset
                assert params.para_dataset[k][0] != ''
        assert params.otf_backprop_temperature == -1 or params.otf_backprop_temperature > 0
        assert params.otf_update_enc or params.otf_update_dec
    else:
        # on-the-fly backprop temperature is only meaningful with pivot directions
        assert params.otf_backprop_temperature == -1

    # check back-parallel directions ("lang1-lang2" entries joined by ',')
    params.back_directions = [x.split('-') for x in params.back_directions.split(',') if len(x) > 0]
    if len(params.back_directions) > 0:
        assert type(params.back_directions) is list
        assert all(len(x) == 2 for x in params.back_directions)
        params.back_directions = [tuple(x) for x in params.back_directions]
        assert len(params.back_directions) == len(set(params.back_directions))
        # check that every direction has an associated train set
        for lang1, lang2 in params.back_directions:
            assert lang1 in params.langs
            assert lang2 in params.langs
            assert lang1 != lang2  # might not be necessary (could be a denoising autoencoder)
            assert (lang1, lang2) in params.back_dataset

    # check all monolingual datasets are used
    for lang in params.mono_dataset:
        assert lang in params.mono_directions or any(lang1 == lang3 == lang for (lang1, _, lang3) in params.pivo_directions)

    # check all parallel datasets with a train set are used,
    # either directly or as one end of a pivot direction
    for (lang1, lang2), (train_path, _, _) in params.para_dataset.items():
        assert (train_path == '' or
                (lang1, lang2) in params.para_directions or
                (lang2, lang1) in params.para_directions or
                any((lang1 == _lang1 and lang2 == _lang2) or (lang1 == _lang2 and lang2 == _lang1) or
                    (lang1 == _lang1 and lang2 == _lang3) or (lang1 == _lang3 and lang2 == _lang1)
                    for _lang1, _lang2, _lang3 in params.pivo_directions))

    # check all back-parallel datasets are used
    for lang_pair in params.back_dataset:
        assert lang_pair in params.back_directions

    # check there is at least one direction / some data
    assert len(params.mono_directions) + len(params.para_directions) + len(params.pivo_directions) > 0
    assert not params.n_mono == params.n_para == 0

    # check vocabulary parameters ("lang:path" entries joined by ';')
    params.vocab = {k: v for k, v in [x.split(':') for x in params.vocab.split(';') if len(x) > 0]}
    if len(params.vocab) > 0:
        assert type(params.vocab) is dict
        # vocabulary files, when given, must cover every language
        assert set(params.vocab.keys()) == set(params.langs)
        assert all(os.path.isfile(path) for path in params.vocab.values())
    # a positive min count only makes sense when vocabulary files are given
    assert params.vocab_min_count == 0 or (params.vocab_min_count > 0 and len(params.vocab) > 0)

    # check coefficients: each lambda is "0" exactly when its objective is unused
    assert not (params.lambda_dis == "0") ^ (params.n_dis == 0)
    assert not (params.lambda_xe_mono == "0") ^ (len(params.mono_directions) == 0)
    assert not (params.lambda_xe_para == "0") ^ (len(params.para_directions) == 0)
    assert not (params.lambda_xe_back == "0") ^ (len(params.back_directions) == 0)
    assert not (params.lambda_xe_otfd == "0") ^ (len([True for _, lang2, lang3 in params.pivo_directions if lang2 != lang3]) == 0)
    assert not (params.lambda_xe_otfa == "0") ^ (len([True for _, lang2, lang3 in params.pivo_directions if lang2 == lang3]) == 0)

    # max length / max vocab / sentence noise
    assert params.max_len > 0
    assert params.max_vocab == -1 or params.max_vocab > 0
    if len(params.mono_directions) == 0:
        # sentence noise only applies to monolingual (denoising) training
        assert params.word_shuffle == 0
        assert params.word_dropout == 0
        assert params.word_blank == 0
    else:
        assert params.word_shuffle == 0 or params.word_shuffle > 1
        assert 0 <= params.word_dropout < 1
        assert 0 <= params.word_blank < 1