scripts/train_mhop.py [225:254]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                best_mrr = mrr

        logger.info("Training finished!")

    elif args.do_predict:
        acc = predict(args, model, eval_dataloader, device, logger)
        logger.info(f"test performance {acc}")

def predict(args, model, eval_dataloader, device, logger):
    """Evaluate ``model`` on ``eval_dataloader`` and report mean reciprocal ranks.

    Every batch is passed through ``move_to_cuda``, fed to the model with
    gradient tracking disabled, and scored by ``mhop_eval``. The per-example
    reciprocal ranks found under the ``"rrs_1"`` and ``"rrs_2"`` keys of each
    result are accumulated and averaged over the whole loader.

    Args:
        args: Run configuration; forwarded unchanged to ``mhop_eval``.
        model: Model to evaluate. It is put in eval mode for the duration of
            the call and returned to the mode it was in on entry.
        eval_dataloader: Iterable yielding evaluation batches.
        device: Not used by this function; kept so existing callers keep
            working (batches are placed by ``move_to_cuda``).
        logger: Logger that receives the example count and both MRR values.

    Returns:
        A dict with keys ``"mrr_1"`` and ``"mrr_2"`` (the two means) and
        ``"mrr_avg"`` (their arithmetic mean).

    Note:
        If the loader yields no examples, ``np.mean`` is applied to empty
        lists, so the reported values are NaN.
    """
    # Remember the caller's mode so a predict-only run is not silently left
    # in training mode; fall back to the historical "switch to train"
    # behaviour if the model does not expose a `training` flag.
    was_training = getattr(model, "training", True)
    model.eval()
    rrs_1, rrs_2 = [], []  # per-example reciprocal ranks
    try:
        for batch in tqdm(eval_dataloader):
            batch_to_feed = move_to_cuda(batch)
            with torch.no_grad():
                outputs = model(batch_to_feed)
                eval_results = mhop_eval(outputs, args)
            rrs_1.extend(eval_results["rrs_1"])
            rrs_2.extend(eval_results["rrs_2"])
    finally:
        # Restore the mode even when a batch fails, so the model is never
        # left stuck in eval mode after an exception.
        model.train(was_training)

    mrr_1 = np.mean(rrs_1)
    mrr_2 = np.mean(rrs_2)
    logger.info(f"evaluated {len(rrs_1)} examples...")
    logger.info(f'MRR-1: {mrr_1}')
    logger.info(f'MRR-2: {mrr_2}')
    return {"mrr_1": mrr_1, "mrr_2": mrr_2, "mrr_avg": (mrr_1 + mrr_2) / 2}


# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



scripts/train_momentum.py [206:235]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                best_mrr = mrr

        logger.info("Training finished!")

    elif args.do_predict:
        acc = predict(args, model, eval_dataloader, device, logger)
        logger.info(f"test performance {acc}")

def predict(args, model, eval_dataloader, device, logger):
    """Evaluate ``model`` on ``eval_dataloader`` and report mean reciprocal ranks.

    Every batch is passed through ``move_to_cuda``, fed to the model with
    gradient tracking disabled, and scored by ``mhop_eval``. The per-example
    reciprocal ranks found under the ``"rrs_1"`` and ``"rrs_2"`` keys of each
    result are accumulated and averaged over the whole loader.

    Args:
        args: Run configuration; forwarded unchanged to ``mhop_eval``.
        model: Model to evaluate. It is put in eval mode for the duration of
            the call and returned to the mode it was in on entry.
        eval_dataloader: Iterable yielding evaluation batches.
        device: Not used by this function; kept so existing callers keep
            working (batches are placed by ``move_to_cuda``).
        logger: Logger that receives the example count and both MRR values.

    Returns:
        A dict with keys ``"mrr_1"`` and ``"mrr_2"`` (the two means) and
        ``"mrr_avg"`` (their arithmetic mean).

    Note:
        If the loader yields no examples, ``np.mean`` is applied to empty
        lists, so the reported values are NaN.
    """
    # Remember the caller's mode so a predict-only run is not silently left
    # in training mode; fall back to the historical "switch to train"
    # behaviour if the model does not expose a `training` flag.
    was_training = getattr(model, "training", True)
    model.eval()
    rrs_1, rrs_2 = [], []  # per-example reciprocal ranks
    try:
        for batch in tqdm(eval_dataloader):
            batch_to_feed = move_to_cuda(batch)
            with torch.no_grad():
                outputs = model(batch_to_feed)
                eval_results = mhop_eval(outputs, args)
            rrs_1.extend(eval_results["rrs_1"])
            rrs_2.extend(eval_results["rrs_2"])
    finally:
        # Restore the mode even when a batch fails, so the model is never
        # left stuck in eval mode after an exception.
        model.train(was_training)

    mrr_1 = np.mean(rrs_1)
    mrr_2 = np.mean(rrs_2)
    logger.info(f"evaluated {len(rrs_1)} examples...")
    logger.info(f'MRR-1: {mrr_1}')
    logger.info(f'MRR-2: {mrr_2}')
    return {"mrr_1": mrr_1, "mrr_2": mrr_2, "mrr_avg": (mrr_1 + mrr_2) / 2}


# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



