def main()

in run_inference.py [0:0]


def main(config_file_path):
    """Evaluate a saved model on the "full", "head" and "tail" ppl test sets.

    Loads the configuration via ``setup``, builds the metadata transformer and
    tokenizer, creates one ``"ppl"`` dataloader per split, loads the model
    stored at ``[TEST] model_path`` and logs loss and perplexity per split.

    Args:
        config_file_path: Path to the configuration file handed to ``setup``.
    """
    config = setup(config_file_path)
    md_transformer = MetaDataTransformer(
        text_index=config.get("TEST", "text_index"),
        md_indices=config.get("TEST", "md_indices", fallback=""),
        md_transformations=config.get("TEST", "md_transformations", fallback=""),
    )

    tokenizer = get_tokenizer(
        tokenizer_type=config.get("TOKENIZER", "tokenizer_type"),
        data_path=config.get("DATA", "train_data_directory_full"),
        md_transformer=md_transformer,
        vocab_limit=int(config.get("TOKENIZER", "vocab_limit")),
        force_new_creation=False,
    )
    # Register the metadata transformer's tokens as special tokens.
    tokenizer.add_special_tokens(md_transformer.get_md_tokens())

    # Single source of truth for the evaluated splits; order drives both
    # dataloader creation and the order of the log output.
    splits = ("full", "head", "tail")

    # Getting dataloaders (all built before the model is loaded, as before)
    ppl_dataloaders = {
        split: get_dataloader(config, tokenizer, md_transformer, "ppl", split,
                              config_section="TEST")
        for split in splits
    }

    # Setting up model
    model = torch.load(config.get("TEST", "model_path"), map_location=device)
    _, dev_loss_fn = get_loss_fn(config.get("MODEL", "model_type"))

    #### Evaluation Cycle ###
    model.eval()
    model.to(device)

    # NOTE(review): if eval_model does not already disable autograd, wrapping
    # this loop in torch.no_grad() would reduce memory use — confirm.
    results = {}
    for split in splits:
        loss, ppl = eval_model(model, ppl_dataloaders[split], dev_loss_fn)
        results[split] = (loss, ppl)

    # Lazy %-style arguments: formatting happens only if INFO is enabled.
    for split in splits:
        loss, ppl = results[split]
        logging.info("%s evaluation: ", split.capitalize())
        logging.info("\t loss: %s ppl: %s", loss, ppl)