def setup()

in src/datatuner/lm/evaluate.py [0:0]


def setup(args):
    """Set up the model, tokenizer and auxiliary components used for evaluation.

    Args:
        args: Parsed argument namespace. The attributes read here are
            ``model_checkpoint``, ``input``, ``out_folder``, ``seed``,
            ``model_type``, ``smoothing``, ``device``, ``task_config``,
            ``reranker`` and ``cons_classifier``.

    Returns:
        A tuple ``(model, tokenizer, task_config, learned_fields,
        input_text_fields, reranker, out_folder, is_local, cons_classifier)``.

        * ``out_folder`` is ``None`` when ``args.input`` is set; in that case
          no output folder is created and ``eval_args.json`` is not written.
        * ``reranker`` is ``None`` unless ``args.reranker`` is given.
        * ``cons_classifier`` is ``None`` unless ``args.cons_classifier`` is set.
        * ``is_local`` always refers to the main ``args.model_checkpoint``
          (the same flag that decided whether an MLflow run was started
          here), never to the reranker checkpoint.
    """
    out_folder = None

    assert args.model_checkpoint
    model_directory, is_local = get_model_directory(args.model_checkpoint)

    if not args.input:
        # Create the output folder (timestamped unless one was given) and
        # record the evaluation arguments inside it.
        if args.out_folder:
            out_folder = Path(args.out_folder)
        else:
            out_folder = Path(f"eval_results/{get_curr_time()}")
        out_folder.mkdir(parents=True, exist_ok=True)

        eval_args_file = out_folder / "eval_args.json"
        # Context manager guarantees the file is flushed and closed.
        with open(eval_args_file, "w") as eval_args_fp:
            json.dump(vars(args), eval_args_fp, indent=2)

    if not is_local:
        # mlflow.start_run's first positional parameter is ``run_id``; the
        # checkpoint identifier is presumably an MLflow run id when it is not
        # a local path — confirm against get_model_directory.
        mlflow.start_run(args.model_checkpoint, nested=True)

    random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    model, tokenizer = load_pretrained(model_directory, model_type=args.model_type, smoothing=args.smoothing)
    model.to(args.device)
    model.eval()

    # An explicit task config overrides the one stored next to the checkpoint.
    task_config = load_task_config(args.task_config or (model_directory / "task_config.json"))
    data_shape = task_config["data_shape"]
    # Fields flagged ``learn: true`` (excluding ``special`` entries) ...
    learned_fields = [x["id"] for x in data_shape if x["learn"] is True and x["type"] != "special"]
    # ... and non-learned fields of type ``text``.
    input_text_fields = [x["id"] for x in data_shape if x["learn"] is False and x["type"] == "text"]

    reranker = None
    if args.reranker is not None:
        # Use a separate variable: rebinding ``is_local`` here would make the
        # returned flag describe the reranker instead of the main model.
        reranker_model_directory, reranker_is_local = get_model_directory(args.reranker)
        reranker = Reranker(reranker_model_directory, args.device, is_local=reranker_is_local)

    cons_classifier = None
    if args.cons_classifier:
        dataset = args.cons_classifier
        # The same directory serves as both model and data location.
        classifier_dir = f"{PACKAGE_LOCATION}/{dataset}_consistency_roberta-large_lower"
        cons_classifier = ConsistencyClassifier(
            {
                "model_name_or_path": classifier_dir,
                "model_type": "roberta",
                "model_name": "roberta-large",
                "task_name": "mnli",
                "data_dir": classifier_dir,
                "output_dir": "tmp",
                "no_cuda": True,
                "overwrite_cache": True,
                "do_lower_case": True,
            }
        )

    return (
        model,
        tokenizer,
        task_config,
        learned_fields,
        input_text_fields,
        reranker,
        out_folder,
        is_local,
        cons_classifier,
    )