def gen_configs()

in curiosity/cli.py [0:0]


def _parse_gpus(gpu):
    """Normalize the ``gpu`` argument of ``gen_configs`` into a list of GPU ids.

    Args:
        gpu: A tuple/list of GPU ids (ints or numeric strings), a single int,
            or a single numeric string. An empty tuple/list means "no GPU"
            and is mapped to ``[-1]``.

    Returns:
        A non-empty list of ids; ``-1`` means "do not pin a GPU".

    Raises:
        ValueError: if ``gpu`` is of an unsupported type, or a string value
            is not an integer literal.
    """
    if isinstance(gpu, (tuple, list)):
        return [int(e) for e in gpu] if gpu else [-1]
    if isinstance(gpu, int):
        return [gpu]
    if isinstance(gpu, str):
        return [int(gpu)]
    raise ValueError("wrong input type")


def gen_configs(gpu, metrics_dir):
    """
    Create the configuration files for the different models.
    This is separate from hyper parameter tuning and directly
    corresponds to models in the paper table.

    For every model variant this:
      * renders ``configs/model.jsonnet`` with the variant's top-level
        arguments and writes it to ``configs/generated/<name>.json``;
      * appends to ``run_allennlp.sh`` (in the current directory) one
        ``allennlp train`` command and three ``allennlp evaluate`` commands
        (val / test / zero splits), writing metrics into ``metrics_dir``.

    The gpu flag can be taken multiple times and indicates to write
    jobs configured to use those gpus; jobs are assigned to the given
    gpus round-robin. With no gpu given, no device is pinned.

    Args:
        gpu: GPU id(s): a tuple/list of ids, a single int, or a numeric string.
        metrics_dir: Directory path used in the generated evaluate commands
            for the ``--output-file`` metrics JSON files.

    Raises:
        ValueError: if ``gpu`` has an unsupported type.
    """
    gpu_cycle = cycle(_parse_gpus(gpu))
    # Note: key should be valid in a unix filename
    # Values must be strings that represent jsonnet "code"
    # These names are shared to the figure plotting code since the filename
    # is based on it
    all_configs = {
        # This is the full model, so default params
        "glove_bilstm": {},
        # Ablations leaving one feature out
        "glove_bilstm-known": {"disable_known_entities": "true"},
        # Completely ablate out anything related to dialog acts
        "glove_bilstm-da": {"disable_dialog_acts": "true"},
        # Completely ablate out anything related to likes
        "glove_bilstm-like": {"disable_likes": "true"},
        # Completely ablate out anything related to facts
        "glove_bilstm-facts": {"disable_facts": "true"},
        # This is the full model, so default params
        "bert": {"use_glove": "false", "use_bert": "true"},
        # Ablations leaving one feature out
        "bert-known": {
            "disable_known_entities": "true",
            "use_glove": "false",
            "use_bert": "true",
        },
        # Completely ablate out anything related to dialog acts
        "bert-da": {
            "disable_dialog_acts": "true",
            "use_glove": "false",
            "use_bert": "true",
        },
        # Completely ablate out anything related to likes
        "bert-like": {
            "disable_likes": "true",
            "use_glove": "false",
            "use_bert": "true",
        },
        # Completely ablate out anything related to facts
        "bert-facts": {
            "disable_facts": "true",
            "use_glove": "false",
            "use_bert": "true",
        },
    }
    # The generated config directory may not exist in a fresh checkout.
    os.makedirs("configs/generated", exist_ok=True)
    with open("run_allennlp.sh", "w") as exp_f:
        exp_f.write("#!/usr/bin/env bash\n")
        for name, conf in all_configs.items():
            model_conf: str = _jsonnet.evaluate_file(
                "configs/model.jsonnet", tla_codes=conf
            )
            config_path = os.path.join("configs/generated", name + ".json")
            model_path = os.path.join("models", name)
            with open(config_path, "w") as f:
                f.write(model_conf)

            job_gpu = next(gpu_cycle)
            if job_gpu != -1:
                # Training reads its device from the trainer section of the config.
                train_gpu_option = (
                    " -o '" + json.dumps({"trainer": {"cuda_device": job_gpu}}) + "'"
                )
                # Evaluation ignores trainer.cuda_device; it uses --cuda-device.
                eval_gpu_option = f" --cuda-device {job_gpu}"
            else:
                train_gpu_option = ""
                eval_gpu_option = ""
            exp_f.write(f"# Experiments for: {name}\n")
            exp_f.write(
                f"allennlp train --include-package curiosity -s {model_path} -f {config_path}{train_gpu_option}\n"
            )

            for split, dialogs in (
                ("val", VAL_DIALOGS),
                ("test", TEST_DIALOGS),
                ("zero", ZERO_DIALOGS),
            ):
                metrics_out = os.path.join(metrics_dir, f"{name}_{split}_metrics.json")
                exp_f.write(
                    f"allennlp evaluate --include-package curiosity{eval_gpu_option} --output-file {metrics_out} {model_path} {dialogs}\n"
                )
            exp_f.write("\n")