in run_inference.py [0:0]
def main(config_file_path):
    """Evaluate a trained model's loss and perplexity on the test splits.

    Builds the metadata transformer and tokenizer described by the config,
    loads the saved model checkpoint, then runs evaluation on the "full",
    "head", and "tail" perplexity dataloaders and logs loss/ppl for each.

    Args:
        config_file_path: Path to the config file consumed by ``setup()``.
    """
    config = setup(config_file_path)

    # Optional metadata fields default to "" when absent from the config.
    md_transformer = MetaDataTransformer(
        text_index=config.get("TEST", "text_index"),
        md_indices=config.get("TEST", "md_indices", fallback=""),
        md_transformations=config.get("TEST", "md_transformations", fallback=""),
    )

    # Reuse the tokenizer built from the training data (force_new_creation=False)
    # so test-time token ids match those seen during training.
    tokenizer = get_tokenizer(
        tokenizer_type=config.get("TOKENIZER", "tokenizer_type"),
        data_path=config.get("DATA", "train_data_directory_full"),
        md_transformer=md_transformer,
        vocab_limit=int(config.get("TOKENIZER", "vocab_limit")),
        force_new_creation=False,
    )
    tokenizer.add_special_tokens(md_transformer.get_md_tokens())

    # One perplexity dataloader per evaluation split.
    splits = ("full", "head", "tail")
    ppl_dataloaders = {
        split: get_dataloader(config, tokenizer, md_transformer, "ppl", split,
                              config_section="TEST")
        for split in splits
    }

    # Setting up model.
    # NOTE(security): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    model = torch.load(config.get("TEST", "model_path"), map_location=device)
    _, dev_loss_fn = get_loss_fn(config.get("MODEL", "model_type"))
    # (A no-reduction loss fn was previously built here but never used; removed.)

    #### Evaluation Cycle ###
    model.eval()
    model.to(device)
    for split in splits:
        loss, ppl = eval_model(model, ppl_dataloaders[split], dev_loss_fn)
        logging.info(f"{split.capitalize()} evaluation: ")
        logging.info(f"\t loss: {loss} ppl: {ppl}")