in grok/visualization.py [0:0]
def _row_max(values):
    """Return the maximum of ``values`` over dim 1 as a 1-D tensor, one entry per row."""
    # ``squeeze()`` drops the reduced dimension (plus any other singleton
    # dimensions, exactly as the previous keepdim+squeeze did); ``atleast_1d``
    # stops a single-row input from collapsing to a 0-d tensor, which made the
    # candidate index lookups in ``most_interesting`` come back empty.
    return torch.atleast_1d(torch.max(values, dim=1).values.squeeze())


def most_interesting(metric_data, *, acc_threshold=95):
    """Select, per architecture, the single most "interesting" row of metrics.

    For each ``arch`` in ``metric_data``, every row of ``val_accuracy`` and
    ``val_loss`` is reduced to its maximum over dim 1. Candidate rows are those
    whose peak validation accuracy is ``>= acc_threshold``. If no row qualifies,
    the rows tied for the highest peak accuracy are used instead. Among the
    candidates, the row with the largest peak validation loss is chosen; on a
    tie, the first such row wins.

    Args:
        metric_data: mapping ``arch -> {key: tensor}``. Each inner dict must
            contain ``"val_accuracy"`` and ``"val_loss"`` (presumably shaped
            ``(num_rows, num_steps)`` -- TODO confirm against callers). Every
            tensor in the inner dict is indexed along dim 0 by the chosen row.
        acc_threshold: peak-accuracy cutoff for candidate rows. The default of
            95 presumably assumes accuracies are expressed in percent.

    Returns:
        A new dict ``arch -> {key: tensor}``. For every key of the input, it
        holds the selected row with a leading dimension of size 1.
    """
    interesting_metric_data = {}
    for arch, arch_data in metric_data.items():
        max_acc_by_row = _row_max(arch_data["val_accuracy"])
        max_loss_by_row = _row_max(arch_data["val_loss"])

        candidates = max_acc_by_row >= acc_threshold
        if not candidates.any():
            # Nothing reaches the threshold: fall back to the best-accuracy row(s).
            candidates = max_acc_by_row == max_acc_by_row.max()
        candidate_idx = torch.nonzero(candidates).flatten()

        # argmax returns the first maximal index, so ties in peak loss still
        # select exactly one row instead of stacking several under an extra dim.
        best = torch.argmax(max_loss_by_row[candidate_idx])
        interesting_idx = int(candidate_idx[best])

        interesting_metric_data[arch] = {
            key: value[interesting_idx].unsqueeze(0)
            for key, value in arch_data.items()
        }
    return interesting_metric_data