in grok/visualization.py [0:0]
def get_metric_data(data, limits=None):
    """Select, per architecture, the metric entries whose ``T`` lies within limits.

    The caller-supplied ``limits`` are layered over a deep copy of the
    module-level ``default_metric_limits``; the caller's mapping is never
    mutated.

    Args:
        data: Mapping from architecture key to a record holding a ``"T"``
            tensor and a ``"metrics"`` tensor. ``"metrics"`` is indexed as
            ``[metric, position, ...]`` (at least 3-D) with exactly five rows
            along dim 0: learning rate, train loss, train accuracy, val loss,
            val accuracy, in that order. Dim 1 of ``"metrics"`` is selected
            with the same indices as ``"T"``, so the two must be aligned —
            presumably ``"T"`` is 1-D; TODO confirm against the loader.
        limits: Optional overrides for ``default_metric_limits``. Keys are
            expected in ``min_<name>`` / ``max_<name>`` pairs; ``min_T`` /
            ``max_T`` bound the selected ``T`` values. ``None`` (the default)
            uses the defaults unchanged.

    Returns:
        dict mapping each architecture yielded by ``filter_archs(data, limits)``
        to a dict with keys ``"T"``, ``"learning_rate"``, ``"train_loss"``,
        ``"train_accuracy"``, ``"val_loss"`` and ``"val_accuracy"``, each
        restricted to the positions where ``min_T <= T <= max_T``.

    Raises:
        AssertionError: if some ``max_<name>`` is not >= its ``min_<name>``.
        KeyError: if a limit key lacks its ``min_``/``max_`` partner.
        ValueError: if a record's ``"metrics"`` does not have exactly five
            rows along dim 0.
    """
    merged = deepcopy(default_metric_limits)
    if limits is not None:
        merged.update(limits)
    limits = merged

    # Check each metric once: min_X and max_X map to the same name, and
    # dict.fromkeys dedups while keeping a stable order. Raise explicitly
    # (instead of ``assert``) so validation still runs under ``python -O``.
    metric_names = dict.fromkeys(
        k.replace("min_", "").replace("max_", "") for k in limits
    )
    for metric in metric_names:
        # ``not (>=)`` rather than ``<`` preserves the original semantics
        # for values that do not compare (e.g. NaN fails the check).
        if not limits["max_" + metric] >= limits["min_" + metric]:
            raise AssertionError(f"invalid {metric} limits")

    d = {}
    for arch in filter_archs(data, limits):
        logger.debug(arch)
        record = data[arch]
        t_values = record["T"]
        in_range = torch.logical_and(
            t_values >= limits["min_T"], t_values <= limits["max_T"]
        )
        indices = torch.nonzero(in_range).squeeze(dim=-1)
        logger.debug(f"indices={indices}")
        # Unpacking along dim 0 names the five metric rows and enforces that
        # exactly five exist; reuse these slices rather than re-indexing
        # ``metrics`` once per output key.
        (
            learning_rate,
            train_loss,
            train_accuracy,
            val_loss,
            val_accuracy,
        ) = record["metrics"][:, indices, :]
        d[arch] = {
            "T": t_values[indices],
            "learning_rate": learning_rate,
            "train_loss": train_loss,
            "train_accuracy": train_accuracy,
            "val_loss": val_loss,
            "val_accuracy": val_accuracy,
        }
    return d