def test()

in train.py [0:0]


def _iter_eval_batches(dataloader, passes):
    """Yield every batch of ``dataloader``, repeating the dataloader ``passes`` times.

    Each pass wraps the dataloader in its own ``tqdm`` progress bar. Flattening
    the passes into one stream lets the caller stop evaluation with a single
    ``break``.
    """
    for _ in range(passes):
        yield from tqdm(dataloader)


def _top_negative_index(ranking):
    """Return the index (into the negatives) of the best-ranked negative.

    ``ranking`` is one row of candidate indices ordered from lowest to highest
    energy. Candidate 0 is the ground-truth (positive) configuration and
    candidate ``k >= 1`` is negative ``k - 1``. If the positive is ranked first
    it is skipped and the runner-up is used.
    """
    best = ranking[1] if ranking[0] == 0 else ranking[0]
    return best - 1


def _mean_and_sem(values):
    """Return ``(mean, standard error of the mean)`` of ``values``."""
    return np.mean(values), np.std(values) / np.sqrt(len(values))


def test(test_dataloader, model, FLAGS, logdir, test=False):
    """Evaluate an energy model on rotamer data from ``test_dataloader``.

    For each batch the model scores the positive (ground-truth) configurations
    and ``neg_sample`` negative configurations per positive. It records:

    * accuracy: fraction of rows where the positive has the lowest energy;
    * loss: mean negative log-likelihood of the positive under a softmax over
      ``-energy`` of all candidates;
    * rotamer recovery: ``score.all()`` of ``compute_rotamer_score_planar``
      applied to the lowest-energy *negative* candidate;
    * best rotamer recovery (only when ``test`` is False): the maximum score
      over all negatives. It is printed in the summary but not returned.

    Args:
        test_dataloader: iterable yielding
            ``(node_pos, node_neg, gt_chis, neg_chis, res)`` batches.
        model: module whose ``forward`` returns energies; assumed to give one
            row per input configuration (e.g. shape ``(batch, 1)``; TODO
            confirm).
        FLAGS: config namespace; reads ``train``, ``neg_sample`` and ``cuda``.
        logdir: not used by this function; kept for interface compatibility.
        test: if True, make a single pass over the dataloader and skip the
            best-rotamer search. If False, cycle over the dataloader (at most
            20 passes) until more than 20 batches have been evaluated. It then
            prints summary statistics once and stops.

    Returns:
        ``(mean accuracy, mean loss, mean rotamer recovery)``.
    """
    passes = 1 if test else 20
    # Outside test mode evaluation stops once more than this many batches
    # have been processed.
    max_batches = 20

    # Negatives per positive. The training-time value is hard-coded here
    # (presumably matching the training sampler -- TODO confirm).
    neg_sample = 150 if FLAGS.train else FLAGS.neg_sample

    accs = []
    losses = []
    rotamer_recovery = []
    best_rotamer_recovery = []

    counter = 0
    batches = _iter_eval_batches(test_dataloader, passes)
    for node_pos, node_neg, gt_chis, neg_chis, res in batches:
        counter += 1

        if FLAGS.cuda:
            node_pos = node_pos.cuda()
            node_neg = node_neg.cuda()

        with torch.no_grad():
            energy_pos = model.forward(node_pos)
            energy_neg = model.forward(node_neg).view(energy_pos.size(0), neg_sample)
            # Column 0 is the positive, the remaining columns the negatives.
            # Negating the energies makes "larger" mean "lower energy".
            partition_function = -torch.cat([energy_pos, energy_neg], dim=1)
            idx = torch.argmax(partition_function, dim=1)
            sort_idx = torch.argsort(partition_function, dim=1, descending=True)

            log_prob = (-energy_pos) - torch.logsumexp(partition_function, dim=1, keepdim=True)
            loss = (-log_prob).mean()

        # Correct when column 0 (the ground truth) has the lowest energy.
        accs.append(idx.eq(0).float().mean().item())
        losses.append(loss.item())

        sort_idx = sort_idx.cpu().numpy()

        for i in range(node_pos.shape[0]):
            # The ground-truth validity mask is not used: the selected
            # negative's mask is what gets passed to the scorer.
            gt_chi, _ = gt_chis[i][0]

            neg_chi, chi_valid = neg_chis[i][_top_negative_index(sort_idx[i])]
            # Shift angles above 180 down by 360 (assumes inputs in [0, 360];
            # TODO confirm). The assignment is in place, so when ``neg_chi`` is
            # a view into ``neg_chis`` this candidate is also seen wrapped in
            # the best-rotamer search below, while the other candidates are not.
            neg_chi[neg_chi > 180] = neg_chi[neg_chi > 180] - 360
            score, _ = compute_rotamer_score_planar(gt_chi, neg_chi, chi_valid, res[i])
            rotamer_recovery.append(score.all())

            if test:
                # Placeholder: the exhaustive search is skipped in test mode.
                best_score = 1
            else:
                best_score = 0
                for cand_chi, cand_valid in neg_chis[i]:
                    cand_score, _ = compute_rotamer_score_planar(
                        gt_chi, cand_chi, cand_valid, res[i]
                    )
                    best_score = max(best_score, cand_score)
            best_rotamer_recovery.append(best_score)

        if counter > max_batches and not test:
            # Report preliminary scores once and stop evaluating entirely
            # (breaking the single flattened loop ends every pass).
            print("Mean cumulative accuracy of: ", *_mean_and_sem(accs))
            print("Mean losses of: ", *_mean_and_sem(losses))
            print("Rotamer recovery ", *_mean_and_sem(rotamer_recovery))
            print("Best Rotamer recovery ", *_mean_and_sem(best_rotamer_recovery))
            break

    return np.mean(accs), np.mean(losses), np.mean(rotamer_recovery)