def get_metric_data()

in grok/visualization.py [0:0]


def get_metric_data(data, limits=None):
    """Select each architecture's metric series restricted to a ``T`` window.

    ``limits`` overrides are layered over a deep copy of the module-level
    ``default_metric_limits``. The merged limits are sanity-checked, then
    passed to ``filter_archs`` to choose which architectures to keep. For each
    kept architecture, only the positions where ``min_T <= T <= max_T`` are
    retained.

    :param data: mapping from architecture key to a dict with a ``"T"`` tensor
        and a ``"metrics"`` tensor. ``"metrics"`` is indexed as
        ``metrics[:, indices, :]`` and unpacked into exactly five series along
        dim 0. ``"T"`` is presumably 1-D and aligned with dim 1 of
        ``"metrics"`` — TODO confirm against the producer of ``data``.
    :param limits: optional overrides keyed ``"min_<metric>"`` /
        ``"max_<metric>"``. ``None`` (the default) means "no overrides".
        The merged limits must contain ``"min_T"`` and ``"max_T"``.
    :returns: dict mapping each architecture yielded by ``filter_archs`` to a
        dict with keys ``"T"``, ``"learning_rate"``, ``"train_loss"``,
        ``"train_accuracy"``, ``"val_loss"`` and ``"val_accuracy"``, each
        restricted to the selected ``T`` positions.
    :raises AssertionError: if a metric has both bounds and its ``max_`` bound
        is below its ``min_`` bound.
    :raises ValueError: if ``data[arch]["metrics"]`` does not have exactly
        five entries along dim 0.
    """
    # Merge into a private copy so neither the defaults nor the caller's
    # dict is mutated; ``None`` replaces the shared mutable ``{}`` default.
    merged = deepcopy(default_metric_limits)
    if limits is not None:
        merged.update(limits)
    limits = merged

    # Validate each metric once, and only when both of its bounds exist.
    # Keys that have no min_/max_ prefix, or only one bound, are not paired.
    metric_names = {
        key[len("min_"):] if key.startswith("min_") else key[len("max_"):]
        for key in limits
        if key.startswith(("min_", "max_"))
    }
    for metric in metric_names:
        lo_key, hi_key = "min_" + metric, "max_" + metric
        if lo_key in limits and hi_key in limits:
            # Explicit raise (not ``assert``) so the check survives ``python -O``;
            # AssertionError is kept for compatibility with the original.
            if not limits[hi_key] >= limits[lo_key]:
                raise AssertionError(f"invalid {metric} limits")

    d = {}
    for arch in filter_archs(data, limits):
        logger.debug(arch)
        t = data[arch]["T"]
        in_window = torch.logical_and(t >= limits["min_T"], t <= limits["max_T"])
        indices = torch.nonzero(in_window).squeeze(dim=-1)
        # Lazy %-formatting: the tensor is only rendered if debug is enabled.
        logger.debug("indices=%s", indices)
        # Index once; unpacking along dim 0 also enforces exactly five series.
        (
            learning_rate,
            train_loss,
            train_accuracy,
            val_loss,
            val_accuracy,
        ) = data[arch]["metrics"][:, indices, :]
        d[arch] = {
            "T": t[indices],
            "learning_rate": learning_rate,
            "train_loss": train_loss,
            "train_accuracy": train_accuracy,
            "val_loss": val_loss,
            "val_accuracy": val_accuracy,
        }
    return d