# codegen_sources/model/src/data/loader.py
def check_data_params(params):
"""
Check datasets parameters.
"""
# data path
assert os.path.isdir(params.data_path), params.data_path
# check languages
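    # params.lgs is a dash-separated language list (the special value "debug" maps to ["en"])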
params.langs = params.lgs.split("-") if params.lgs != "debug" else ["en"]
assert len(params.langs) == len(set(params.langs)) >= 1, [
l for l in params.langs if params.langs.count(l) >= 2
]
# assert sorted(params.langs) == params.langs
params.id2lang = {k: v for k, v in enumerate(sorted(params.langs))}
params.lang2id = {k: v for v, k in params.id2lang.items()}
params.n_langs = len(params.langs)
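    # optional id remapping: "src1:dest1,src2:dest2" makes each source language
    # reuse the language id of its destination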
if params.lgs_id_mapping != "":
mappings = params.lgs_id_mapping.split(",")
for m in mappings:
split = m.split(":")
assert len(split) == 2, f"Cannot parse {m} in {params.lgs_id_mapping}"
source, dest = split
assert (
source in params.langs
), f"unknown source {source} from {m}. Not part of the languages in {params.langs}"
assert (
dest in params.langs
), f"unknown destination language {dest} from {m}. Not part of the languages in {params.langs}"
params.lang2id[source] = params.lang2id[dest]
# CLM steps
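    # each step is either "lang" (monolingual) or "lang1-lang2" (parallel pair)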
clm_steps = [s.split("-") for s in params.clm_steps.split(",") if len(s) > 0]
params.clm_steps = [(s[0], None) if len(s) == 1 else tuple(s) for s in clm_steps]
assert all(
[
(l1 in params.langs) and (l2 in params.langs or l2 is None)
for l1, l2 in params.clm_steps
]
)
assert len(params.clm_steps) == len(set(params.clm_steps))
# MLM / TLM steps
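    # same format as CLM steps: "lang" for MLM, "lang1-lang2" for TLM on parallel data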
mlm_steps = [s.split("-") for s in params.mlm_steps.split(",") if len(s) > 0]
params.mlm_steps = [(s[0], None) if len(s) == 1 else tuple(s) for s in mlm_steps]
assert all(
[
(l1 in params.langs) and (l2 in params.langs or l2 is None)
for l1, l2 in params.mlm_steps
]
)
assert len(params.mlm_steps) == len(set(params.mlm_steps))
# machine translation steps
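    # "src-tgt" pairs with distinct source and target languages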
params.mt_steps = [
tuple(s.split("-")) for s in params.mt_steps.split(",") if len(s) > 0
]
assert all([len(x) == 2 for x in params.mt_steps])
assert all(
[l1 in params.langs and l2 in params.langs for l1, l2 in params.mt_steps]
)
assert all([l1 != l2 for l1, l2 in params.mt_steps])
assert len(params.mt_steps) == len(set(params.mt_steps))
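    # MT steps with spans: "lang1-lang2-span" triplets, where the third element
    # names the span dataset loaded alongside the parallel pair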
params.mt_spans_steps = [
tuple(s.split("-")) for s in params.mt_spans_steps.split(",") if len(s) > 0
]
assert all((len(split) == 3 for split in params.mt_spans_steps))
assert all(
[l1 != l2 and l1 != l3 and l2 != l3 for l1, l2, l3 in params.mt_spans_steps]
)
assert len(params.mt_spans_steps) == len(set(params.mt_spans_steps))
assert (
len(params.mt_steps) + len(params.mt_spans_steps) == 0
or not params.encoder_only
)
    assert (
        len(params.mt_spans_steps) > 0
    ) == params.spans_emb_encoder, (
        f"mt_spans steps and spans_emb_encoder must be used together "
        f"(got {len(params.mt_spans_steps)} mt_spans steps, "
        f"spans_emb_encoder={params.spans_emb_encoder})"
    )
# do steps
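    # "src-tgt" pairs, same constraints as MT steps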
params.do_steps = [
tuple(s.split("-")) for s in params.do_steps.split(",") if len(s) > 0
]
assert all([len(x) == 2 for x in params.do_steps])
assert all(
[l1 in params.langs and l2 in params.langs for l1, l2 in params.do_steps]
)
assert all([l1 != l2 for l1, l2 in params.do_steps])
assert len(params.do_steps) == len(set(params.do_steps))
# classification steps
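    # "lang-label" pairs; only the source has to be a known language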
params.classif_steps = [
tuple(s.split("-")) for s in params.classif_steps.split(",") if len(s) > 0
]
assert all([len(x) == 2 for x in params.classif_steps])
assert all([l1 in params.langs for l1, l2 in params.classif_steps])
assert all([l1 != l2 for l1, l2 in params.classif_steps])
assert len(params.classif_steps) == len(set(params.classif_steps))
    assert (
        len(params.classif_steps) + len(params.mt_spans_steps) == 0
        or params.n_classes_classif > 0
    ), "classification / mt_spans steps require n_classes_classif > 0"
    params.use_classifier = len(params.classif_steps) > 0
# denoising auto-encoder steps
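    # single languages; AE needs a decoder, so it is incompatible with encoder_only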
params.ae_steps = [s for s in params.ae_steps.split(",") if len(s) > 0]
assert all([lang in params.langs for lang in params.ae_steps])
assert len(params.ae_steps) == len(set(params.ae_steps))
assert len(params.ae_steps) == 0 or not params.encoder_only
# back-translation steps
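    # "lang1-lang2-lang3" triplets with lang1 == lang3: translate lang1 -> lang2 -> back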
params.bt_steps = [
tuple(s.split("-")) for s in params.bt_steps.split(",") if len(s) > 0
]
assert all([len(x) == 3 for x in params.bt_steps])
assert all(
[
l1 in params.langs and l2 in params.langs and l3 in params.langs
for l1, l2, l3 in params.bt_steps
]
)
assert all([l1 == l3 and l1 != l2 for l1, l2, l3 in params.bt_steps])
assert len(params.bt_steps) == len(set(params.bt_steps))
assert len(params.bt_steps) == 0 or not params.encoder_only
params.bt_src_langs = [l1 for l1, _, _ in params.bt_steps]
# self-training steps
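    # "src-tgt1|tgt2|...": one source language and pipe-separated target languages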
params.st_steps = [
(s.split("-")[0], tuple(s.split("-")[1].split("|")))
for s in params.st_steps.split(",")
if len(s) > 0
]
assert all([len(x) == 2 for x in params.st_steps])
assert all(
[
l1 in params.langs and all([l2 in params.langs for l2 in langs2])
for l1, langs2 in params.st_steps
]
), params.st_steps
assert all([l1 != l2 for l1, langs2 in params.st_steps for l2 in langs2])
assert len(params.st_steps) == len(set(params.st_steps))
assert all([len(langs2) > 0 for l1, langs2 in params.st_steps]), params.st_steps
params.st_src_langs = [l1 for l1, _ in params.st_steps]
params.st_tgt_langs = list(
set([l2 for _, langs2 in params.st_steps for l2 in langs2])
)
if len(params.st_src_langs) > 0:
logger.info(f"st source langs: {params.st_src_langs}")
logger.info(f"st target langs: {params.st_tgt_langs}")
# unit tests path
assert os.path.isfile(params.unit_tests_path), params.unit_tests_path
# check monolingual datasets
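    # monolingual data is required for CLM/MLM steps without a second language,
    # AE steps, and the source languages of BT steps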
required_mono = set(
[l1 for l1, l2 in (params.mlm_steps + params.clm_steps) if l2 is None]
+ params.ae_steps
+ params.bt_src_langs
)
params.mono_dataset = {
lang: {
splt: os.path.join(params.data_path, "%s.%s.pth" % (splt, lang))
for splt in DATASET_SPLITS
}
for lang in params.langs
if lang in required_mono
}
for lang in params.st_src_langs:
if lang not in params.mono_dataset:
params.mono_dataset[lang] = dict()
params.mono_dataset[lang][SELF_TRAINED] = os.path.join(
params.data_path, "%s.%s.pth" % (SELF_TRAINED, lang)
)
for paths in params.mono_dataset.values():
for p in paths.values():
if not os.path.isfile(p):
logger.error(f"{p} not found")
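    # outside eval-only mode, each monolingual file must exist, either as
    # "<split>.<lang>.pth" or as a "<split>.<lang>.0.pth" first shard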
if not params.eval_only:
assert all(
[
all(
[
os.path.isfile(p) or os.path.isfile(p.replace("pth", "0.pth"))
for p in paths.values()
]
)
for paths in params.mono_dataset.values()
]
), [
[
p
for p in paths.values()
if not (os.path.isfile(p) or os.path.isfile(p.replace("pth", "0.pth")))
]
for paths in params.mono_dataset.values()
]
assert isinstance(
params.n_sentences_eval, int
), f"n_sentences_eval was {params.n_sentences_eval}, it should be an int"
# check parallel datasets
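    # parallel data is needed for the parallel training steps above, plus the
    # (lang2, lang3) pairs of BT steps and, for ST steps, source/target pairs
    # in both directions as well as pairs among the target languages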
required_para_train = set(
params.clm_steps
+ params.mlm_steps
+ params.mt_steps
+ params.classif_steps
+ params.do_steps
)
required_para = (
required_para_train
| set([(l2, l3) for _, l2, l3 in params.bt_steps])
| set([(l1, l2) for l1, langs2 in params.st_steps for l2 in langs2])
| set([(l2, l1) for l1, langs2 in params.st_steps for l2 in langs2])
| set(
[
(l2_1, l2_2)
for l1, langs2 in params.st_steps
for l2_1 in langs2
for l2_2 in langs2
if l2_1 != l2_2
]
)
)
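    # parallel files are named "<split>.<src>-<tgt>.<lang>.pth"; each language
    # pair is stored once with src < tgt, and the train split is only listed
    # for pairs that are actually trained on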
params.para_dataset = {
(src, tgt): {
splt: (
os.path.join(
params.data_path, "%s.%s-%s.%s.pth" % (splt, src, tgt, src)
),
os.path.join(
params.data_path, "%s.%s-%s.%s.pth" % (splt, src, tgt, tgt)
),
)
for splt in DATASET_SPLITS
if splt != "train"
or (src, tgt) in required_para_train
or (tgt, src) in required_para_train
}
for src in params.langs
for tgt in params.langs
if src < tgt and ((src, tgt) in required_para or (tgt, src) in required_para)
}
for lang, label in params.classif_steps:
params.para_dataset[(lang, label)] = {
splt: (
os.path.join(
params.data_path, "%s.%s-%s.%s.pth" % (splt, lang, label, lang)
),
os.path.join(
params.data_path, "%s.%s-%s.%s.pth" % (splt, lang, label, label)
),
)
for splt in DATASET_SPLITS
}
for lang1, lang2, span in params.mt_spans_steps:
params.para_dataset[(lang1, lang2, span)] = {
splt: (
os.path.join(
params.data_path, "%s.%s-%s.%s.pth" % (splt, lang1, lang2, lang1)
),
os.path.join(
params.data_path, "%s.%s-%s.%s.pth" % (splt, lang1, lang2, lang2)
),
os.path.join(
params.data_path, "%s.%s-%s.%s.pth" % (splt, lang1, span, span)
),
)
for splt in DATASET_SPLITS
}
for step_paths in params.para_dataset.values():
for paths in step_paths.values():
for p in paths:
if not os.path.isfile(p):
logger.error(f"{p} not found")
params.validation_metrics = params.validation_metrics.replace(
"#obf_proba", str(params.obf_proba)
)
params.stopping_criterion = params.stopping_criterion.replace(
"#obf_proba", str(params.obf_proba)
)
# parse which datasets should have sentence ids
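    # format: comma-separated "split|langs" items, where langs is "lang",
    # "lang1-lang2", or one of the aliases "para", "mono", "all"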
params.has_sentence_ids = (
[s.split("|") for s in params.has_sentence_ids.split(",")]
if params.has_sentence_ids != ""
else []
)
assert all([len(x) == 2 for x in params.has_sentence_ids]), params.has_sentence_ids
params.has_sentence_ids = [
(split, tuple(langs.split("-"))) for split, langs in params.has_sentence_ids
]
assert all(
[len(langs) == 1 or len(langs) == 2 for split, langs in params.has_sentence_ids]
), params.has_sentence_ids
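    # expand the "para" / "mono" / "all" aliases into concrete dataset keys;
    # appending while iterating is fine here since appended entries never
    # match an alias again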
for split, langs in params.has_sentence_ids:
if langs == ("para",) or langs == ("all",):
params.has_sentence_ids += [
(split, langs) for langs in params.para_dataset.keys()
]
if langs == ("mono",) or langs == ("all",):
params.has_sentence_ids += [
(split, (lang,)) for lang in params.mono_dataset.keys()
]
assert all(
[
all([lang in params.langs + ["para", "mono", "all"] for lang in langs])
for split, langs in params.has_sentence_ids
]
), params.has_sentence_ids
assert len(set(params.has_sentence_ids)) == len(params.has_sentence_ids)
    assert (
        len(params.mono_dataset) > 0 or len(params.para_dataset) > 0
    ), "No dataset to load, you probably forgot to set a training step."