in train.py [0:0]
def _standard_error(values):
    """Return the standard error of the mean of *values* (std / sqrt(n))."""
    return np.std(values) / np.sqrt(len(values))


def _top_decoy_index(ranking):
    """Return the negative-sample index of the best-ranked non-ground-truth candidate.

    *ranking* is one row of candidate indices ordered from most to least
    likely (i.e. lowest to highest energy). Index 0 is the ground-truth
    configuration and index k >= 1 is negative sample k - 1, so the ground
    truth is skipped and the result is shifted down by one.
    """
    best = ranking[1] if ranking[0] == 0 else ranking[0]
    return best - 1


def _wrap_chi(chi):
    """Return a copy of *chi* with angles above 180 shifted down by 360.

    Written arithmetically instead of via masked assignment so the caller's
    array/tensor (which comes straight from the dataloader) is not mutated.
    """
    return chi - 360 * (chi > 180)


def _print_preliminary(accs, losses, rotamer_recovery, best_rotamer_recovery):
    """Print the mean and standard error of every accumulated metric."""
    print("Mean cumulative accuracy of: ", np.mean(accs), _standard_error(accs))
    print("Mean losses of: ", np.mean(losses), _standard_error(losses))
    print(
        "Rotamer recovery ",
        np.mean(rotamer_recovery),
        _standard_error(rotamer_recovery),
    )
    print(
        "Best Rotamer recovery ",
        np.mean(best_rotamer_recovery),
        _standard_error(best_rotamer_recovery),
    )


def test(test_dataloader, model, FLAGS, logdir, test=False):
    """Evaluate an energy model on candidate ranking and rotamer recovery.

    For every batch the model assigns energies to the ground-truth
    configuration (``node_pos``) and to ``neg_sample`` negative candidates
    per example (``node_neg``). Metrics accumulated per batch / example:

    * accuracy: fraction of examples whose ground truth has the lowest energy;
    * loss: mean negative log-probability of the ground truth under a softmax
      over the negated energies of all candidates (ground truth included);
    * rotamer recovery: ``score.all()`` from ``compute_rotamer_score_planar``
      for the lowest-energy *negative* candidate versus the ground truth;
    * best rotamer recovery (printed only): max score over all negative
      candidates, or 1 when ``test`` is true.

    Args:
        test_dataloader: iterable of ``(node_pos, node_neg, gt_chis, neg_chis,
            res)`` batches.
        model: module mapping a batch of configurations to energies.
        FLAGS: options object; reads ``train``, ``neg_sample`` and ``cuda``.
        logdir: unused; kept for interface compatibility.
        test: if true, make exactly one pass over ``test_dataloader`` and skip
            the per-candidate best-recovery search. Otherwise cycle through
            the loader (at most 20 passes) until more than 20 batches have
            been evaluated, print preliminary statistics once, and stop.

    Returns:
        ``(mean accuracy, mean loss, mean rotamer recovery)``.
    """
    # Negatives per example; presumably must match how the loader built node_neg.
    neg_sample = 150 if FLAGS.train else FLAGS.neg_sample
    num_passes = 1 if test else 20
    max_batches = 20  # non-test evaluation stops once this many batches are exceeded

    accs = []
    losses = []
    rotamer_recovery = []
    best_rotamer_recovery = []
    counter = 0

    for _ in range(num_passes):
        for node_pos, node_neg, gt_chis, neg_chis, res in tqdm(test_dataloader):
            counter += 1

            if FLAGS.cuda:
                node_pos = node_pos.cuda()
                node_neg = node_neg.cuda()

            with torch.no_grad():
                # Call the module (not .forward) so registered hooks run.
                energy_pos = model(node_pos)
                energy_neg = model(node_neg).view(energy_pos.size(0), neg_sample)
                # Column 0 is the ground truth, column k >= 1 is negative k - 1.
                # Negated energies act as logits: highest logit == lowest energy.
                partition_function = -torch.cat([energy_pos, energy_neg], dim=1)
                idx = torch.argmax(partition_function, dim=1)
                sort_idx = torch.argsort(partition_function, dim=1, descending=True)
                log_prob = (-energy_pos) - torch.logsumexp(
                    partition_function, dim=1, keepdim=True
                )
                loss = (-log_prob).mean()

            # idx == 0 means the ground truth configuration has the lowest energy.
            accs.append(idx.eq(0).float().mean().item())
            losses.append(loss.item())

            sort_idx = sort_idx.cpu().numpy()
            for i in range(node_pos.shape[0]):
                gt_chi, _ = gt_chis[i][0]

                neg_chi, chi_valid = neg_chis[i][_top_decoy_index(sort_idx[i])]
                score, _ = compute_rotamer_score_planar(
                    gt_chi, _wrap_chi(neg_chi), chi_valid, res[i]
                )
                rotamer_recovery.append(score.all())

                if test:
                    best_score = 1
                else:
                    best_score = 0
                    for cand_chi, cand_valid in neg_chis[i]:
                        cand_score, _ = compute_rotamer_score_planar(
                            gt_chi, _wrap_chi(cand_chi), cand_valid, res[i]
                        )
                        best_score = max(best_score, cand_score)
                best_rotamer_recovery.append(best_score)

            if not test and counter > max_batches:
                # Report preliminary scores and stop: a bare `break` here would
                # only leave the inner loop and the outer passes would keep going.
                _print_preliminary(
                    accs, losses, rotamer_recovery, best_rotamer_recovery
                )
                return np.mean(accs), np.mean(losses), np.mean(rotamer_recovery)

    return np.mean(accs), np.mean(losses), np.mean(rotamer_recovery)