def check_data_params()

in XLM/src/data/loader.py [0:0]


def check_data_params(params):
    """
    Check datasets parameters.

    Validates the language list and the training-step specifications stored
    as strings on `params`, replaces them in place with their parsed form,
    and builds the tables of monolingual / parallel dataset paths.

    Attributes read from `params`:
        data_path, lgs, clm_steps, mlm_steps, mt_steps, ae_steps, bt_steps,
        encoder_only, eval_only.

    Attributes set on `params`:
        langs, id2lang, lang2id, n_langs, clm_steps, mlm_steps, mt_steps,
        ae_steps, bt_steps, bt_src_langs, mono_dataset, para_dataset.

    Step formats (comma-separated lists, empty entries are ignored):
        clm_steps / mlm_steps: "l1" -> (l1, None), "l1-l2" -> (l1, l2)
        mt_steps:              "l1-l2" with l1 != l2
        ae_steps:              "l1"
        bt_steps:              "l1-l2-l3" with l1 == l3 and l1 != l2

    Raises:
        AssertionError: if `data_path` is not a directory, if a step is
            malformed, duplicated or uses an unknown language, if decoder
            steps (MT / AE / BT) are requested with `encoder_only`, or -
            unless `eval_only` is set - if a required monolingual file is
            missing both as "<name>.pth" and as its shard "<name>.0.pth".
    """
    # data path
    assert os.path.isdir(params.data_path), params.data_path

    # check languages
    params.langs = params.lgs.split('-') if params.lgs != 'debug' else ['en']
    assert len(params.langs) == len(set(params.langs)) >= 1
    # ids follow alphabetical order, whatever order the languages have in `lgs`
    params.id2lang = dict(enumerate(sorted(params.langs)))
    params.lang2id = {lang: idx for idx, lang in params.id2lang.items()}
    params.n_langs = len(params.langs)

    # CLM and MLM / TLM steps share the same format, so they are parsed and
    # validated by the same code (CLM first, then MLM)
    for name in ('clm_steps', 'mlm_steps'):
        steps = [
            (parts[0], None) if len(parts) == 1 else tuple(parts)
            for parts in (s.split('-') for s in getattr(params, name).split(',') if s)
        ]
        setattr(params, name, steps)
        # reject "l1-l2-l3" here instead of failing on tuple unpacking below
        assert all(len(step) == 2 for step in steps), steps
        assert all(l1 in params.langs and (l2 is None or l2 in params.langs)
                   for l1, l2 in steps), steps
        assert len(steps) == len(set(steps)), steps

    # machine translation steps
    params.mt_steps = [tuple(s.split('-'))
                       for s in params.mt_steps.split(',') if s]
    assert all(len(step) == 2 for step in params.mt_steps)
    assert all(l1 in params.langs and l2 in params.langs
               for l1, l2 in params.mt_steps)
    assert all(l1 != l2 for l1, l2 in params.mt_steps)
    assert len(params.mt_steps) == len(set(params.mt_steps))
    assert len(params.mt_steps) == 0 or not params.encoder_only

    # denoising auto-encoder steps
    params.ae_steps = [s for s in params.ae_steps.split(',') if s]
    assert all(lang in params.langs for lang in params.ae_steps)
    assert len(params.ae_steps) == len(set(params.ae_steps))
    assert len(params.ae_steps) == 0 or not params.encoder_only

    # back-translation steps
    params.bt_steps = [tuple(s.split('-'))
                       for s in params.bt_steps.split(',') if s]
    assert all(len(step) == 3 for step in params.bt_steps)
    assert all(l1 in params.langs and l2 in params.langs and l3 in params.langs
               for l1, l2, l3 in params.bt_steps)
    assert all(l1 == l3 and l1 != l2 for l1, l2, l3 in params.bt_steps)
    assert len(params.bt_steps) == len(set(params.bt_steps))
    assert len(params.bt_steps) == 0 or not params.encoder_only
    params.bt_src_langs = [l1 for l1, _, _ in params.bt_steps]

    # check monolingual datasets
    required_mono = {l1 for l1, l2 in params.mlm_steps + params.clm_steps
                     if l2 is None}
    required_mono.update(params.ae_steps, params.bt_src_langs)
    params.mono_dataset = {
        lang: {
            splt: os.path.join(params.data_path, f'{splt}.{lang}.pth')
            for splt in ['train', 'valid', 'test']
        } for lang in params.langs if lang in required_mono
    }
    # every missing "<name>.pth" is reported, whether or not a shard exists
    for paths in params.mono_dataset.values():
        for p in paths.values():
            if not os.path.isfile(p):
                logger.error(f"{p} not found")
    if not params.eval_only:
        for paths in params.mono_dataset.values():
            for p in paths.values():
                # A dataset may be stored in shards; the first shard is
                # "<name>.0.pth". Only the extension is rewritten so that a
                # "pth" substring elsewhere in the path is left alone.
                stem, ext = os.path.splitext(p)
                assert os.path.isfile(p) or os.path.isfile(f'{stem}.0{ext}'), p

    # check parallel datasets
    # (monolingual (l1, None) steps end up in these sets too, but can never
    # match a (src, tgt) pair of languages)
    required_para_train = set(
        params.clm_steps + params.mlm_steps + params.mt_steps)
    required_para = required_para_train | {
        (l2, l3) for _, l2, l3 in params.bt_steps}
    params.para_dataset = {
        (src, tgt): {
            splt: (os.path.join(params.data_path, f'{splt}.{src}-{tgt}.{src}.pth'),
                   os.path.join(params.data_path, f'{splt}.{src}-{tgt}.{tgt}.pth'))
            for splt in ['train', 'valid', 'test']
            # pairs needed only for back-translation have no training split
            if splt != 'train'
            or (src, tgt) in required_para_train
            or (tgt, src) in required_para_train
        }
        for src in params.langs for tgt in params.langs
        # one entry per unordered pair, keyed in alphabetical order
        if src < tgt and ((src, tgt) in required_para or (tgt, src) in required_para)
    }
    for paths in params.para_dataset.values():
        for p1, p2 in paths.values():
            if not os.path.isfile(p1):
                logger.error(f"{p1} not found")
            if not os.path.isfile(p2):
                logger.error(f"{p2} not found")