def test_tbsm(args, use_gpu)

in tbsm_pytorch.py [0:0]


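# NOTE: this excerpt relies on module-level imports from tbsm_pytorch.py,
# inferred from usage below: sys, numpy as np, os.path (as path), and
# sklearn.metrics.roc_auc_score; `tp` (the data-loading module), `get_tbsm`,
# and `data_wrap` are helpers defined elsewhere in the same file.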
def test_tbsm(args, use_gpu):
    # prepare data
    test_ld, N_test = tp.make_tbsm_data_and_loader(args, "test")

    # buffers for model scores and ground-truth labels over the full test set
    z_test = np.zeros((N_test, ), dtype=np.float64)
    t_test = np.zeros((N_test, ), dtype=np.float64)

    # check saved model exists
    if not path.exists(args.save_model):
        sys.exit("Can't find saved model. Exiting...")

    # create or load TBSM
    tbsm, device = get_tbsm(args, use_gpu)
    print(args.save_model)

    # main eval loop
    # NOTE: call to tbsm.eval() not needed here, see
    # https://discuss.pytorch.org/t/model-eval-vs-with-torch-no-grad/19615
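    # (wrapping this loop in `with torch.no_grad():` would avoid building the
    #  autograd graph during the forward pass; the .detach() calls below only
    #  strip the outputs from the graph after the fact)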
    offset = 0
    for X, lS_o, lS_i, T in test_ld:

        batchSize = X[0].shape[0]

        # forward pass through TBSM
        Z = tbsm(*data_wrap(X, lS_o, lS_i, use_gpu, device))

        # collect scores and targets for this mini-batch
        z_test[offset: offset + batchSize] = np.squeeze(
            Z.detach().cpu().numpy(), axis=1
        )
        t_test[offset: offset + batchSize] = np.squeeze(
            T.detach().cpu().numpy(), axis=1
        )
        offset += batchSize

    if args.quality_metric == "auc":
        # compute AUC metric
        auc_score = 100.0 * roc_auc_score(t_test.astype(int), z_test)
        print("auc score: ", auc_score)
    else:
        sys.exit("Metric not supported.")
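
For reference, sklearn's roc_auc_score expects integer class labels first and real-valued scores second, which is exactly how the call above passes t_test.astype(int) and z_test. Below is a minimal self-contained sketch of the same metric computation on toy data (the labels and scores are made up, not taken from the TBSM code):

import numpy as np
from sklearn.metrics import roc_auc_score

# toy ground-truth labels (0/1) and model scores for four examples
t_toy = np.array([0, 1, 1, 0], dtype=int)
z_toy = np.array([0.1, 0.9, 0.4, 0.3], dtype=np.float64)

# scaled to a percentage, as in test_tbsm above
auc_score = 100.0 * roc_auc_score(t_toy, z_toy)
print("auc score: ", auc_score)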