def most_interesting()

in grok/visualization.py [0:0]


def most_interesting(metric_data, acc_threshold=95):
    """Select, for every architecture, the single run most worth plotting.

    For each architecture, each run's peak validation accuracy and peak
    validation loss are computed as a maximum over dim 1. The candidate runs
    are those whose peak accuracy is >= ``acc_threshold``. If no run reaches
    the threshold, the runs tied for the highest peak accuracy are used
    instead. Among the candidates, the run with the highest peak validation
    loss is selected; on ties, the first such run wins.

    Args:
        metric_data: mapping ``arch -> {metric_name: tensor}``. The inner dict
            must contain ``"val_accuracy"`` and ``"val_loss"``, which are
            indexed as ``[run, step]``. Every tensor in the inner dict is
            sliced by run along dim 0.
        acc_threshold: minimum peak validation accuracy for a run to be a
            candidate. The default of 95 presumably assumes accuracies in
            percent; confirm against callers.

    Returns:
        dict ``arch -> {metric_name: tensor}`` in which each tensor is the
        selected run's slice with a leading size-1 dimension re-added.
        Returns an empty dict when ``metric_data`` is empty.
    """
    interesting_metric_data = {}
    for arch, metrics in metric_data.items():
        # Reduce over dim 1 only. A blanket squeeze() would also drop the run
        # axis when there is exactly one run, breaking the indexing below.
        max_acc_by_run = torch.max(metrics["val_accuracy"], dim=1).values
        max_loss_by_run = torch.max(metrics["val_loss"], dim=1).values

        # as_tuple=True always yields a 1-D index tensor (no squeeze games).
        candidates = torch.nonzero(
            max_acc_by_run >= acc_threshold, as_tuple=True
        )[0]
        if candidates.numel() == 0:
            # Nobody reached the threshold: fall back to the best-accuracy runs.
            candidates = torch.nonzero(
                max_acc_by_run == max_acc_by_run.max(), as_tuple=True
            )[0]

        # torch.argmax returns the first maximal index, so ties on peak loss
        # still select exactly one run.
        best = int(candidates[torch.argmax(max_loss_by_run[candidates])])

        interesting_metric_data[arch] = {
            k: v[best].unsqueeze(0) for k, v in metrics.items()
        }

    # Return only after every architecture has been processed.
    return interesting_metric_data