in XLM/src/data/loader.py

import os
from logging import getLogger

logger = getLogger()


def check_data_params(params):
"""
Check datasets parameters.
"""
# data path
assert os.path.isdir(params.data_path), params.data_path
# check languages
params.langs = params.lgs.split('-') if params.lgs != 'debug' else ['en']
assert len(params.langs) == len(set(params.langs)) >= 1
# assert sorted(params.langs) == params.langs
    params.id2lang = dict(enumerate(sorted(params.langs)))
    params.lang2id = {lang: i for i, lang in params.id2lang.items()}
params.n_langs = len(params.langs)
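
    # For illustration (hypothetical flag value): params.lgs == 'fr-en' gives
    #   params.langs   == ['fr', 'en']
    #   params.id2lang == {0: 'en', 1: 'fr'}   # ids follow sorted language order
    #   params.lang2id == {'en': 0, 'fr': 1}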
# CLM steps
clm_steps = [s.split('-')
for s in params.clm_steps.split(',') if len(s) > 0]
params.clm_steps = [(s[0], None) if len(s) == 1 else tuple(s)
for s in clm_steps]
assert all([(l1 in params.langs) and (l2 in params.langs or l2 is None)
for l1, l2 in params.clm_steps])
assert len(params.clm_steps) == len(set(params.clm_steps))
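
    # Example (hypothetical flag value): --clm_steps 'en,en-fr' parses to
    #   params.clm_steps == [('en', None), ('en', 'fr')]
    # A (lang, None) step uses monolingual data; a (lang1, lang2) step draws
    # from the parallel dataset for that pair.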
# MLM / TLM steps
mlm_steps = [s.split('-')
for s in params.mlm_steps.split(',') if len(s) > 0]
params.mlm_steps = [(s[0], None) if len(s) == 1 else tuple(s)
for s in mlm_steps]
assert all([(l1 in params.langs) and (l2 in params.langs or l2 is None)
for l1, l2 in params.mlm_steps])
assert len(params.mlm_steps) == len(set(params.mlm_steps))
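
    # Example (hypothetical flag value): --mlm_steps 'en,fr,en-fr' parses to
    #   params.mlm_steps == [('en', None), ('fr', None), ('en', 'fr')]
    # i.e. MLM on each monolingual corpus plus TLM on the en-fr parallel data.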
# machine translation steps
params.mt_steps = [tuple(s.split('-'))
for s in params.mt_steps.split(',') if len(s) > 0]
assert all([len(x) == 2 for x in params.mt_steps])
    assert all([l1 in params.langs and l2 in params.langs
                for l1, l2 in params.mt_steps])
assert all([l1 != l2 for l1, l2 in params.mt_steps])
assert len(params.mt_steps) == len(set(params.mt_steps))
assert len(params.mt_steps) == 0 or not params.encoder_only
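
    # Example (hypothetical flag value): --mt_steps 'en-fr,fr-en' parses to
    #   params.mt_steps == [('en', 'fr'), ('fr', 'en')]
    # MT steps require a decoder, hence the `not params.encoder_only` check.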
# denoising auto-encoder steps
params.ae_steps = [s for s in params.ae_steps.split(',') if len(s) > 0]
assert all([lang in params.langs for lang in params.ae_steps])
assert len(params.ae_steps) == len(set(params.ae_steps))
assert len(params.ae_steps) == 0 or not params.encoder_only
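
    # Example (hypothetical flag value): --ae_steps 'en,fr' keeps plain
    # language strings: params.ae_steps == ['en', 'fr']. Denoising
    # auto-encoding also requires a decoder.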
# back-translation steps
params.bt_steps = [tuple(s.split('-'))
for s in params.bt_steps.split(',') if len(s) > 0]
assert all([len(x) == 3 for x in params.bt_steps])
    assert all([l1 in params.langs and l2 in params.langs and l3 in params.langs
                for l1, l2, l3 in params.bt_steps])
assert all([l1 == l3 and l1 != l2 for l1, l2, l3 in params.bt_steps])
assert len(params.bt_steps) == len(set(params.bt_steps))
assert len(params.bt_steps) == 0 or not params.encoder_only
params.bt_src_langs = [l1 for l1, _, _ in params.bt_steps]
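
    # Example (hypothetical flag value): --bt_steps 'en-fr-en,fr-en-fr' parses
    # to [('en', 'fr', 'en'), ('fr', 'en', 'fr')]: translate l1 -> l2, then
    # train the l2 -> l1 direction on the resulting synthetic pairs (the
    # asserts above enforce l1 == l3 and l1 != l2). Here
    # params.bt_src_langs == ['en', 'fr'].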
# check monolingual datasets
required_mono = set([l1 for l1, l2 in (params.mlm_steps + params.clm_steps)
if l2 is None] + params.ae_steps + params.bt_src_langs)
params.mono_dataset = {
lang: {
splt: os.path.join(params.data_path, '%s.%s.pth' % (splt, lang))
for splt in ['train', 'valid', 'test']
} for lang in params.langs if lang in required_mono
}
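
    # With a hypothetical data_path of '/data/xlm' and 'en' in required_mono,
    # this yields e.g.
    #   params.mono_dataset['en']['train'] == '/data/xlm/train.en.pth'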
for paths in params.mono_dataset.values():
for p in paths.values():
if not os.path.isfile(p):
logger.error(f"{p} not found")
if not params.eval_only:
        # match '.pth' (not bare 'pth') so other path components are untouched
        assert all([all([os.path.isfile(p) or os.path.isfile(p.replace('.pth', '.0.pth'))
                         for p in paths.values()]) for paths in params.mono_dataset.values()])
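
    # The '.0.pth' fallback accepts corpora pre-split into numbered shards
    # (train.en.0.pth, train.en.1.pth, ...), one file per GPU rank in XLM,
    # and only checks that the first shard exists.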
# check parallel datasets
required_para_train = set(
params.clm_steps + params.mlm_steps + params.mt_steps)
required_para = required_para_train | set(
[(l2, l3) for _, l2, l3 in params.bt_steps])
params.para_dataset = {
(src, tgt): {
splt: (os.path.join(params.data_path, '%s.%s-%s.%s.pth' % (splt, src, tgt, src)),
os.path.join(params.data_path, '%s.%s-%s.%s.pth' % (splt, src, tgt, tgt)))
for splt in ['train', 'valid', 'test']
if splt != 'train' or (src, tgt) in required_para_train or (tgt, src) in required_para_train
} for src in params.langs for tgt in params.langs
if src < tgt and ((src, tgt) in required_para or (tgt, src) in required_para)
}
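
    # Keys are ordered pairs with src < tgt, so 'en-fr' and 'fr-en' steps share
    # one entry. With a hypothetical data_path of '/data/xlm':
    #   params.para_dataset[('en', 'fr')]['valid'] ==
    #       ('/data/xlm/valid.en-fr.en.pth', '/data/xlm/valid.en-fr.fr.pth')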
for paths in params.para_dataset.values():
for p1, p2 in paths.values():
if not os.path.isfile(p1):
logger.error(f"{p1} not found")
if not os.path.isfile(p2):
logger.error(f"{p2} not found")