def check_data_params()

in codegen_sources/model/src/data/loader.py
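
The function body relies on module-level names defined in loader.py: os, logger, and the constants DATASET_SPLITS and SELF_TRAINED.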


def check_data_params(params):
    """
    Check dataset parameters.
    """
    # data path
    assert os.path.isdir(params.data_path), params.data_path

    # check languages
    params.langs = params.lgs.split("-") if params.lgs != "debug" else ["en"]
    assert len(params.langs) == len(set(params.langs)) >= 1, [
        l for l in params.langs if params.langs.count(l) >= 2
    ]

    # assert sorted(params.langs) == params.langs
    params.id2lang = {k: v for k, v in enumerate(sorted(params.langs))}
    params.lang2id = {k: v for v, k in params.id2lang.items()}
    params.n_langs = len(params.langs)
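    # Example (hypothetical languages): --lgs "python-java-cpp" yields
    # params.langs = ["python", "java", "cpp"], params.n_langs = 3, and
    # params.lang2id = {"cpp": 0, "java": 1, "python": 2} (ids follow sorted order).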
    if params.lgs_id_mapping != "":
        mappings = params.lgs_id_mapping.split(",")
        for m in mappings:
            split = m.split(":")
            assert len(split) == 2, f"Cannot parse {m} in {params.lgs_id_mapping}"
            source, dest = split
            assert (
                source in params.langs
            ), f"unknown source {source} from {m}. Not part of the languages in {params.langs}"
            assert (
                dest in params.langs
            ), f"unknown destination language {dest} from {m}. Not part of the languages in {params.langs}"
            params.lang2id[source] = params.lang2id[dest]
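    # Example (hypothetical names): --lgs_id_mapping "java_np:java" makes
    # "java_np" reuse the language id of "java"; both names must appear in --lgs.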

    # CLM steps
    clm_steps = [s.split("-") for s in params.clm_steps.split(",") if len(s) > 0]
    params.clm_steps = [(s[0], None) if len(s) == 1 else tuple(s) for s in clm_steps]
    assert all(
        [
            (l1 in params.langs) and (l2 in params.langs or l2 is None)
            for l1, l2 in params.clm_steps
        ]
    )
    assert len(params.clm_steps) == len(set(params.clm_steps))
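    # Example (hypothetical): --clm_steps "python,java-cpp" parses to
    # [("python", None), ("java", "cpp")]: a monolingual step and a parallel-pair step.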

    # MLM / TLM steps
    mlm_steps = [s.split("-") for s in params.mlm_steps.split(",") if len(s) > 0]
    params.mlm_steps = [(s[0], None) if len(s) == 1 else tuple(s) for s in mlm_steps]
    assert all(
        [
            (l1 in params.langs) and (l2 in params.langs or l2 is None)
            for l1, l2 in params.mlm_steps
        ]
    )
    assert len(params.mlm_steps) == len(set(params.mlm_steps))
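    # Same "l1" / "l1-l2" format as clm_steps: a single language is an MLM step
    # on monolingual data, a pair is a TLM step on parallel data.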

    # machine translation steps
    params.mt_steps = [
        tuple(s.split("-")) for s in params.mt_steps.split(",") if len(s) > 0
    ]

    assert all([len(x) == 2 for x in params.mt_steps])
    assert all(
        [l1 in params.langs and l2 in params.langs for l1, l2 in params.mt_steps]
    )
    assert all([l1 != l2 for l1, l2 in params.mt_steps])
    assert len(params.mt_steps) == len(set(params.mt_steps))
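    # Example (hypothetical): --mt_steps "java-python,python-java" parses to
    # [("java", "python"), ("python", "java")]; source and target must differ.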

    params.mt_spans_steps = [
        tuple(s.split("-")) for s in params.mt_spans_steps.split(",") if len(s) > 0
    ]
    assert all((len(split) == 3 for split in params.mt_spans_steps))
    assert all(
        [l1 != l2 and l1 != l3 and l2 != l3 for l1, l2, l3 in params.mt_spans_steps]
    )
    assert len(params.mt_spans_steps) == len(set(params.mt_spans_steps))
    assert (
        len(params.mt_steps) + len(params.mt_spans_steps) == 0
        or not params.encoder_only
    )

    assert (
        len(params.mt_spans_steps) > 0
    ) == params.spans_emb_encoder, f"mt_spans steps and spans_emb_encoder must be used together: {len(params.mt_spans_steps)} mt_spans steps, spans_emb_encoder={params.spans_emb_encoder}"
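    # Example (hypothetical span name): --mt_spans_steps "java-python-java_spans"
    # parses to [("java", "python", "java_spans")]; the three names must be
    # pairwise distinct and spans_emb_encoder must be enabled.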

    # deobfuscation (DO) steps
    params.do_steps = [
        tuple(s.split("-")) for s in params.do_steps.split(",") if len(s) > 0
    ]
    assert all([len(x) == 2 for x in params.do_steps])
    assert all(
        [l1 in params.langs and l2 in params.langs for l1, l2 in params.do_steps]
    )
    assert all([l1 != l2 for l1, l2 in params.do_steps])
    assert len(params.do_steps) == len(set(params.do_steps))
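    # Deobfuscation steps use the same "src-tgt" format as mt_steps; both names
    # must be in params.langs and must differ.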

    # classification steps
    params.classif_steps = [
        tuple(s.split("-")) for s in params.classif_steps.split(",") if len(s) > 0
    ]
    assert all([len(x) == 2 for x in params.classif_steps])
    assert all([l1 in params.langs for l1, l2 in params.classif_steps])
    assert all([l1 != l2 for l1, l2 in params.classif_steps])
    assert len(params.classif_steps) == len(set(params.classif_steps))
    assert (
        len(params.classif_steps) + len(params.mt_spans_steps) == 0
        or params.n_classes_classif > 0
    )
    params.use_classifier = len(params.classif_steps) > 0
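    # Example (hypothetical label name): --classif_steps "java-java_labels" parses
    # to [("java", "java_labels")]; only the source must be in params.langs, and
    # n_classes_classif must be positive.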

    # denoising auto-encoder steps
    params.ae_steps = [s for s in params.ae_steps.split(",") if len(s) > 0]
    assert all([lang in params.langs for lang in params.ae_steps])
    assert len(params.ae_steps) == len(set(params.ae_steps))
    assert len(params.ae_steps) == 0 or not params.encoder_only
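    # Example (hypothetical): --ae_steps "java,python" keeps both languages as
    # denoising auto-encoder steps; this requires a decoder (not encoder_only).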

    # back-translation steps
    params.bt_steps = [
        tuple(s.split("-")) for s in params.bt_steps.split(",") if len(s) > 0
    ]
    assert all([len(x) == 3 for x in params.bt_steps])
    assert all(
        [
            l1 in params.langs and l2 in params.langs and l3 in params.langs
            for l1, l2, l3 in params.bt_steps
        ]
    )
    assert all([l1 == l3 and l1 != l2 for l1, l2, l3 in params.bt_steps])
    assert len(params.bt_steps) == len(set(params.bt_steps))
    assert len(params.bt_steps) == 0 or not params.encoder_only
    params.bt_src_langs = [l1 for l1, _, _ in params.bt_steps]
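    # Example (hypothetical): --bt_steps "java-python-java" parses to
    # [("java", "python", "java")]: back-translation must loop back to the
    # source language (l1 == l3, l1 != l2), so bt_src_langs = ["java"].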

    # self-training steps
    params.st_steps = [
        (s.split("-")[0], tuple(s.split("-")[1].split("|")))
        for s in params.st_steps.split(",")
        if len(s) > 0
    ]
    assert all([len(x) == 2 for x in params.st_steps])

    assert all(
        [
            l1 in params.langs and all([l2 in params.langs for l2 in langs2])
            for l1, langs2 in params.st_steps
        ]
    ), params.st_steps
    assert all([l1 != l2 for l1, langs2 in params.st_steps for l2 in langs2])
    assert len(params.st_steps) == len(set(params.st_steps))
    assert all([len(langs2) > 0 for l1, langs2 in params.st_steps]), params.st_steps
    params.st_src_langs = [l1 for l1, _ in params.st_steps]
    params.st_tgt_langs = list(
        set([l2 for _, langs2 in params.st_steps for l2 in langs2])
    )
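    # Example (hypothetical): --st_steps "java-python|cpp" parses to
    # [("java", ("python", "cpp"))], giving st_src_langs = ["java"] and
    # st_tgt_langs = ["python", "cpp"] (set order may vary).
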
    if len(params.st_src_langs) > 0:
        logger.info(f"st source langs: {params.st_src_langs}")
        logger.info(f"st target langs: {params.st_tgt_langs}")
        # unit tests path
        assert os.path.isfile(params.unit_tests_path), params.unit_tests_path

    # check monolingual datasets
    required_mono = set(
        [l1 for l1, l2 in (params.mlm_steps + params.clm_steps) if l2 is None]
        + params.ae_steps
        + params.bt_src_langs
    )
    params.mono_dataset = {
        lang: {
            splt: os.path.join(params.data_path, "%s.%s.pth" % (splt, lang))
            for splt in DATASET_SPLITS
        }
        for lang in params.langs
        if lang in required_mono
    }
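    # E.g., with data_path "/data" and "java" in required_mono, mono_dataset["java"]
    # maps each split in DATASET_SPLITS (assumed "train"/"valid"/"test") to
    # "/data/<split>.java.pth".
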
    for lang in params.st_src_langs:
        if lang not in params.mono_dataset:
            params.mono_dataset[lang] = dict()
        params.mono_dataset[lang][SELF_TRAINED] = os.path.join(
            params.data_path, "%s.%s.pth" % (SELF_TRAINED, lang)
        )
    for paths in params.mono_dataset.values():
        for p in paths.values():
            if not os.path.isfile(p):
                logger.error(f"{p} not found")

    if not params.eval_only:
        assert all(
            [
                all(
                    [
                        # accept single-file or sharded (".0.pth") datasets
                        os.path.isfile(p) or os.path.isfile(p.replace(".pth", ".0.pth"))
                        for p in paths.values()
                    ]
                )
                for paths in params.mono_dataset.values()
            ]
        ), [
            [
                p
                for p in paths.values()
                if not (os.path.isfile(p) or os.path.isfile(p.replace(".pth", ".0.pth")))
            ]
            for paths in params.mono_dataset.values()
        ]
    assert isinstance(
        params.n_sentences_eval, int
    ), f"n_sentences_eval was {params.n_sentences_eval}, it should be an int"
    # check parallel datasets
    required_para_train = set(
        params.clm_steps
        + params.mlm_steps
        + params.mt_steps
        + params.classif_steps
        + params.do_steps
    )
    required_para = (
        required_para_train
        | set([(l2, l3) for _, l2, l3 in params.bt_steps])
        | set([(l1, l2) for l1, langs2 in params.st_steps for l2 in langs2])
        | set([(l2, l1) for l1, langs2 in params.st_steps for l2 in langs2])
        | set(
            [
                (l2_1, l2_2)
                for l1, langs2 in params.st_steps
                for l2_1 in langs2
                for l2_2 in langs2
                if l2_1 != l2_2
            ]
        )
    )

    params.para_dataset = {
        (src, tgt): {
            splt: (
                os.path.join(
                    params.data_path, "%s.%s-%s.%s.pth" % (splt, src, tgt, src)
                ),
                os.path.join(
                    params.data_path, "%s.%s-%s.%s.pth" % (splt, src, tgt, tgt)
                ),
            )
            for splt in DATASET_SPLITS
            if splt != "train"
            or (src, tgt) in required_para_train
            or (tgt, src) in required_para_train
        }
        for src in params.langs
        for tgt in params.langs
        if src < tgt and ((src, tgt) in required_para or (tgt, src) in required_para)
    }
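    # E.g., an mt step ("java", "python") yields the key ("java", "python")
    # (src < tgt lexicographically) mapping each split to the path pair
    # ("/data/<split>.java-python.java.pth", "/data/<split>.java-python.python.pth");
    # valid/test paths are built for every required pair, train paths only for
    # pairs needed during training.
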
    for lang, label in params.classif_steps:
        params.para_dataset[(lang, label)] = {
            splt: (
                os.path.join(
                    params.data_path, "%s.%s-%s.%s.pth" % (splt, lang, label, lang)
                ),
                os.path.join(
                    params.data_path, "%s.%s-%s.%s.pth" % (splt, lang, label, label)
                ),
            )
            for splt in DATASET_SPLITS
        }

    for lang1, lang2, span in params.mt_spans_steps:
        params.para_dataset[(lang1, lang2, span)] = {
            splt: (
                os.path.join(
                    params.data_path, "%s.%s-%s.%s.pth" % (splt, lang1, lang2, lang1)
                ),
                os.path.join(
                    params.data_path, "%s.%s-%s.%s.pth" % (splt, lang1, lang2, lang2)
                ),
                os.path.join(
                    params.data_path, "%s.%s-%s.%s.pth" % (splt, lang1, span, span)
                ),
            )
            for splt in DATASET_SPLITS
        }

    for step_paths in params.para_dataset.values():
        for paths in step_paths.values():
            for p in paths:
                if not os.path.isfile(p):
                    logger.error(f"{p} not found")

    params.validation_metrics = params.validation_metrics.replace(
        "#obf_proba", str(params.obf_proba)
    )
    params.stopping_criterion = params.stopping_criterion.replace(
        "#obf_proba", str(params.obf_proba)
    )
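    # E.g., a hypothetical metric name "valid_obf_#obf_proba_acc" becomes
    # "valid_obf_0.5_acc" when params.obf_proba is 0.5.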

    # parse which datasets should have sentence ids
    params.has_sentence_ids = (
        [s.split("|") for s in params.has_sentence_ids.split(",")]
        if params.has_sentence_ids != ""
        else []
    )

    assert all([len(x) == 2 for x in params.has_sentence_ids]), params.has_sentence_ids
    params.has_sentence_ids = [
        (split, tuple(langs.split("-"))) for split, langs in params.has_sentence_ids
    ]
    assert all(
        [len(langs) == 1 or len(langs) == 2 for split, langs in params.has_sentence_ids]
    ), params.has_sentence_ids
    for split, langs in params.has_sentence_ids:
        if langs == ("para",) or langs == ("all",):
            params.has_sentence_ids += [
                (split, langs) for langs in params.para_dataset.keys()
            ]
        if langs == ("mono",) or langs == ("all",):
            params.has_sentence_ids += [
                (split, (lang,)) for lang in params.mono_dataset.keys()
            ]
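    # Example (hypothetical): --has_sentence_ids "valid|java,train|para" parses to
    # [("valid", ("java",)), ("train", ("para",))]; the "para" / "mono" / "all"
    # wildcards are expanded above into one entry per loaded dataset (appending
    # while iterating is safe here: the appended concrete entries never match
    # the wildcard branches).
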
    assert all(
        [
            all([lang in params.langs + ["para", "mono", "all"] for lang in langs])
            for split, langs in params.has_sentence_ids
        ]
    ), params.has_sentence_ids
    assert len(set(params.has_sentence_ids)) == len(params.has_sentence_ids)

    assert (
        len(params.mono_dataset) > 0 or len(params.para_dataset) > 0
    ), "No dataset to be loaded, you probably forget to set a training step."