def run()

in src/datatuner/lm/evaluate.py [0:0]


def run():
    parser = ArgumentParser()
    parser.add_argument("--model_type", type=str, default=None, help="gpt or gpt2")
    parser.add_argument("--input", dest="input", action="store_true")
    parser.add_argument("--cache_pointer", dest="cache_pointer", action="store_true")
    parser.add_argument("--cache_theta", type=float, default=0.0001, help="factor used in the cache pointer mechanism")
    parser.add_argument(
        "--cache_lambda", type=float, default=0, help="weight of the cache probs when cache_pointer is used"
    )
    parser.add_argument(
        "--boost_factor", type=float, default=1, help="boost factor applied to token probabilities during decoding"
    )
    parser.add_argument(
        "--ignore_existing",
        dest="ignore_existing",
        action="store_true",
        help="ignore previous runs, overwrite the test output files, and start from scratch",
    )

    parser.add_argument("--model_checkpoint", type=str, default="", help="Path, url or short name of the model")
    parser.add_argument("--reranking_mode", type=str, default="average", help="Reranking mode")
    parser.add_argument("--out_folder", type=str, default="", help="subfolder name of the eval results folder")
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)"
    )
    parser.add_argument("--task_config", type=str, help="Path to the tokenization config file", default=None)

    parser.add_argument("--filename", type=str, default="data/instances_dev.pkl", help="File to use for decoding")
    parser.add_argument("--reranker", type=str, default=None, help="model used for reranking (in question answering)")
    parser.add_argument("--no_sample", action="store_true", help="Set to use greedy decoding instead of sampling")
    parser.add_argument(
        "--no_postranking", action="store_true", help="Set to disable post reranking in the presence of reranker"
    )
    parser.add_argument("--max_length", type=int, default=100, help="Maximum length of the output utterances")
    parser.add_argument("--min_length", type=int, default=1, help="Minimum length of the output utterances")
    parser.add_argument("--seed", type=int, default=42, help="Seed")
    parser.add_argument("--nbest", type=int, default=5, help="Number of times to run the output generation")
    parser.add_argument("--beam_width", type=int, default=5, help="Beam search width")
    parser.add_argument(
        "--per_step_predictions", type=int, default=2, help="Number of predictions per step of beam search"
    )
    parser.add_argument(
        "--min_complete_in_beam",
        type=int,
        default=10,
        help="Minimum number of complete beam search elements to terminate beam",
    )
    parser.add_argument(
        "--aux_weight", type=float, default=0.5, help="auxiliary model weight (used if a reranker is provided)"
    )
    parser.add_argument(
        "--min_prob",
        type=float,
        default=0.00,
        help="minimum cumulative probability of available tokens to be used on decoding in beam search",
    )
    parser.add_argument(
        "--min_token_prob",
        type=float,
        default=0.00,
        help="minimum probability of token to be used on decoding in beam search",
    )
    parser.add_argument("--prob_window", type=int, default=0, help="Probability window")
    parser.add_argument("--temperature", type=float, default=1, help="Sampling softmax temperature")
    parser.add_argument(
        "--log_every", type=int, default=50, help="frequency of logging the output and computing the metrics"
    )
    parser.add_argument("--dec_dropout", type=float, default=0.0, help="Decoding dropout")
    parser.add_argument("--averaging", type=str, default="default", help="averaging method")
    parser.add_argument("--ewm_alpha", type=int, default=0.5, help="value of com for the EWM average")
    parser.add_argument(
        "--beam_alpha", type=float, default=0.75, help="value of alpha for length penalty in beam search"
    )
    parser.add_argument("--top_k", type=int, default=0, help="Filter top-k tokens before sampling (<=0: no filtering)")
    parser.add_argument(
        "--top_p", type=float, default=0.9, help="Nucleus filtering (top-p) before sampling (<=0.0: no filtering)"
    )
    parser.add_argument("--frac", type=float, default=1.0, help="fraction of test data to consider")
    parser.add_argument("--max_data", type=int, default=0, help="Number of data items (0 includes everything)")
    parser.add_argument("--smoothing", action="store_true", help="If true use label smoothing")
    parser.add_argument("--add_coverage_penalty", action="store_true", help="Add coverage penalty while decoding")
    parser.add_argument("--no_mlflow_logging", action="store_true", help="If true disable logging to mlflow")
    parser.add_argument(
        "--cons_classifier", type=str, default=None, help="consistency classifier checkpoint, to use during decoding)"
    )

    parser.add_argument(
        "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
    )
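
    # Example invocation (paths are hypothetical, for illustration only):
    #   python src/datatuner/lm/evaluate.py --model_checkpoint runs/my_model \
    #       --task_config configs/task.json --filename data/instances_dev.pkl \
    #       --beam_width 5 --nbest 5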
    args = parser.parse_args()
    args.min_complete_in_beam = min(args.min_complete_in_beam, args.beam_width)
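    # When launched via Streamlit, run as an interactive demo (model picked from
    # the sidebar); otherwise behave as a regular command-line tool.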

    if st._is_running_with_streamlit:
        examples = load_test_data()

        st.header("DataTuner Demo")

        args.nbest = 3
        args.beam_width = 3
        args.per_step_predictions = 2
        args.input = True

        client = mlflow.tracking.MlflowClient()

        args.model_checkpoint = st.sidebar.selectbox(
            "Model", get_finished_models([5]), 0, lambda x: " ".join(list(get_run_info(client, x).values()))
        )

        # Reranker selection from the sidebar is currently disabled (dead code).
        if False:
            args.reranker = st.sidebar.selectbox(
                "Reranker",
                [f"{DATA_DIR}/distilgpt2/",
                 f"{DATA_DIR}/gpt2/",
                 f"{DATA_DIR}/gpt2-medium/"],
                -1,
                lambda x: " ".join(list(get_run_info(client, x).values())),
            )

        st.write(f"**Main Model**: {get_run_info(client, args.model_checkpoint)}")
        st.write(f"**Auxiliary Model**: {get_run_info(client, args.reranker)}")

    model, tokenizer, task_config, learned_fields, input_text_fields, reranker, out_folder, is_local, cons_classifier = setup(
        args
    )
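
    # Interactive mode: read the input fields from the user (Streamlit widgets or
    # stdin) and generate outputs one item at a time.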
    if args.input:

        def process_input(inst):
            if st._is_running_with_streamlit:
                write = st.write
            else:
                write = print

            new_item, _ = process_one_item(
                inst,
                tokenizer,
                model,
                task_config,
                args,
                input_cache=input_cache,
                avoided_cache=avoided_cache,
                reranker=reranker,
                cons_classifier=cons_classifier,
            )

            for key in learned_fields:
                write("**Answers:**")
                if type(new_item[key]) == list:

                    for x in new_item[key]:
                        write(x)
                    if reranker is not None and not args.no_postranking:
                        write("\n**Answers Reranked from Pretrained Model:**")
                        for x in new_item["reranked_" + key]:
                            write(x)

                else:
                    text_to_print = f'{key}: {new_item[key]}'
                    write(text_to_print)

        input_cache = {}
        avoided_cache = defaultdict(lambda: 0)

        inst = {}
        empty = False
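
        # Streamlit demo: expose the decoding hyperparameters as sidebar controls
        # and let the user pick an example and edit its input fields.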
        if st._is_running_with_streamlit:
            args.cons_classifier = "ldc"
            mr_key = dataset_fields[args.cons_classifier]["data"] if args.cons_classifier else "linearized_amr"
            option = st.selectbox("Select an example", examples, 0, lambda x: x[mr_key])

            args.repetition_penalty = st.sidebar.slider("Repetition Penalty", 1.0, 10.0, float(args.repetition_penalty))
            args.cache_lambda = st.sidebar.slider("Cache Lambda", 0.0, 1.0, float(args.cache_lambda))
            args.boost_factor = st.sidebar.slider("boost_factor", 0.0, 3.0, float(args.boost_factor))
            args.cache_theta = st.sidebar.slider("Cache Theta", 0.0, 1.0, float(args.cache_theta))

            args.reranking_mode = st.sidebar.selectbox("Auxiliary Mode", ["average", "max"], 0)
            args.averaging = st.sidebar.selectbox("Averaging Method", ["arithmetic", "geometric", "ewm", "min"], 0)

            args.aux_weight = st.sidebar.slider("Weight of Auxiliary Model", 0.0, 1.0, 0.0)
            args.min_prob = st.sidebar.slider("Min Probability", 0.0, 1.0, float(args.min_prob))
            args.min_token_prob = st.sidebar.slider("min_token_prob", 0.0, 1.0, float(args.min_token_prob))
            args.ewm_alpha = st.sidebar.slider("ewm_alpha", 0.0, 1.0, float(args.ewm_alpha))
            args.prob_window = st.sidebar.slider("prob_window", 0, 100, args.prob_window)

            args.top_k = st.sidebar.slider("top_k", 0, 100, int(args.top_k))
            args.nbest = st.sidebar.slider("nbest", 1, 10, int(args.nbest))
            args.beam_width = st.sidebar.slider("beam_width", 1, 10, int(args.beam_width))

            args.top_p = st.sidebar.slider("top_p", 0.0, 1.0, float(args.top_p))
            args.dec_dropout = st.sidebar.slider("Decoding Dropout", 0.0, 1.0, float(args.dec_dropout))
            args.temperature = st.sidebar.slider("temperature", 0.0, 3.0, float(args.temperature))
            args.per_step_predictions = st.sidebar.slider("per_step_predictions", 1, 5, int(args.per_step_predictions))
            args.no_sample = bool(st.sidebar.slider("no_sample", 0, 1, 1))

            for key in input_text_fields:
                text_input = st.text_area(key, option[key] if type(option[key]) == str else "; ".join(option[key]))
                if not text_input:
                    empty = True
                inst[key] = tokenizer.encode(text_input)

            for key in learned_fields:
                st.write(f"{key}: {option[key] if type(option[key]) == str else option[key][-1]}")

            if not empty:
                process_input(inst)

        else:
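            # Plain command-line loop: prompt for each input field and generate.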
            while True:
                inst = {}
                for key in input_text_fields:
                    text_input = input(f"{key}>> ")
                    inst[key] = tokenizer.encode(text_input)

                process_input(inst)

    else:
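        # Batch evaluation: decode every instance in the input file and write
        # aggregated metrics and generations to the output folder.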

        infile = Path(args.filename)

        data = get_dataset_from_file(tokenizer, infile, task_config, args.max_data)
        outfilename = f"generated.json"
        metrics_results = defaultdict(list)

        out_filepath = out_folder / outfilename
        metrics_fields = task_config["metrics_fields"] if "metrics_fields" in task_config else []
        output_to_metrics = {}
        for out_entity in task_config["data_shape"]:
            if "metrics" in out_entity:
                output_to_metrics[out_entity["id"]] = out_entity["metrics"]

        def write_output(final=False):
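            # Aggregate metrics over all outputs so far, write them (and the raw
            # generations) to disk, and record the items whose top beam candidate
            # differs from the reference.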

            stats = aggregate_metrics(all_outputs, learned_fields, metrics_fields, output_to_metrics, final=final)
            for key in stats:
                if "total" in stats[key]:
                    logger.info(f"{key}: {stats[key]['total']}")

            (
                out_folder
                / f"stats._{infile.stem}_{args.max_data}_{'reranked' if args.reranker else ''}_generated.json"
            ).write_text(json.dumps(stats, indent=2))

            out_filepath.write_text(json.dumps(all_outputs, indent=2))
            logger.info(f"written to {out_filepath}")
            key = learned_fields[-1]
            # Check if first item in beam is equal to original
            not_matching_items = [
                item for item in all_outputs if item["original_" + key] != item[key + (" " * len("original_"))][0]
            ]
            (out_folder / f"non_matching_{outfilename}").write_text(json.dumps(not_matching_items, indent=2))
            return stats

        if not args.ignore_existing and out_filepath.exists():
            all_outputs = json.load(open(out_filepath, "r"))
            skip = len(all_outputs)
            for s in range(skip):
                if "extra_fields" in task_config:
                    for field in task_config["extra_fields"]:
                        all_outputs[s][field] = data[s][field]
        else:
            all_outputs = []
            skip = 0

        logger.info(f"skipping {skip} items that were already analyzed")
        for i, inst in enumerate(tqdm(data)):
            original_inst = custom_deep_copy(inst)
            try:
                if random.random() > args.frac:
                    continue

                if i < skip:
                    continue

                new_item, matching = process_one_item(
                    inst,
                    tokenizer,
                    model,
                    task_config,
                    args,
                    metrics_results=metrics_results,
                    metrics_fields=task_config["metrics_fields"] if "metrics_fields" in task_config else [],
                    avoided_cache=defaultdict(lambda: 0),
                    reranker=reranker,
                    cons_classifier=cons_classifier,
                )

                for key in learned_fields:
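                    # Store the generated candidates under a space-padded key (same
                    # length as "original_<key>"); the original value is deleted and
                    # re-added so it follows the generated candidates in the item dict.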
                    new_key = key + " " * len("original_")
                    new_item[new_key] = new_item[key]
                    orig_key = "original_" + key
                    orig = new_item[orig_key]
                    del new_item[key]
                    del new_item[orig_key]
                    new_item[orig_key] = orig

                if "extra_fields" in task_config:
                    for field in task_config["extra_fields"]:
                        new_item[field] = inst[field]

                if not matching:
                    logger.info(json.dumps(new_item, indent=2))

                all_outputs.append(new_item)

                if len(all_outputs) % args.log_every == 0:
                    write_output()

            except Exception as e:
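                # On failure, build a fallback item holding only the decoded
                # inputs and reference outputs, with empty generations.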
                new_item = {}

                for key in learned_fields:

                    orig_key = "original_" + key

                    new_item[orig_key] = original_inst[key]
                    if new_item[orig_key] and type(new_item[orig_key][0]) == list:
                        new_item[orig_key] = new_item[orig_key][-1]
                for key in input_text_fields:
                    new_item[key] = original_inst[key]
                for key in new_item:
                    new_item[key] = tokenizer.decode(new_item[key])

                for key in learned_fields:
                    new_key = key + " " * len("original_")
                    new_item[new_key] = [""]

                all_outputs.append(new_item)
                logger.error(e)
                raise

        stats = write_output(final=True)

        if not is_local and not args.no_mlflow_logging:
            mlflow.log_artifact(out_folder, "evaluation")
            flattened_stats = flatten(stats)
            flattened_stats = {k: flattened_stats[k] for k in flattened_stats if k.count("-") <= 3}
            mlflow.log_metrics(flattened_stats)