def load_metric_data()

in grok/visualization.py [0:0]


def _read_metrics_csv(csv_path):
    """Read the logged metrics of one run from ``csv_path``.

    Returns:
        ``(train_t, val_t)``: float32 tensors of shape ``(3, n_train)`` with
        rows (learning_rate, train_loss, train_accuracy) and ``(2, n_val)``
        with rows (val_loss, val_accuracy). Only CSV rows with a non-empty
        ``train_loss`` / ``val_loss`` contribute to the respective tensor; an
        empty selection yields a tensor with zero columns.

    Raises:
        OSError, csv.Error, KeyError, TypeError, ValueError: if the file is
        missing or a required column is absent or non-numeric.
    """
    with open(csv_path, "r", newline="") as fh:
        rows = list(csv.DictReader(fh))
    train_rows = [
        [
            float(r["learning_rate"]),
            float(r["train_loss"]),
            float(r["train_accuracy"]),
        ]
        for r in rows
        if r["train_loss"]
    ]
    val_rows = [
        [float(r["val_loss"]), float(r["val_accuracy"])]
        for r in rows
        if r["val_loss"]
    ]
    # reshape keeps the metric axis even when no rows matched, so callers can
    # always slice [:, :n] without special-casing an empty (0,) tensor.
    train_t = torch.tensor(train_rows, dtype=torch.float32).reshape(-1, 3).T
    val_t = torch.tensor(val_rows, dtype=torch.float32).reshape(-1, 2).T
    return train_t, val_t


def load_metric_data(data_dir, epochs=100000, load_partial_data=True):
    """Load train/val metric curves for every experiment under ``data_dir``.

    The entries of ``data_dir`` are grouped by ``factor_expts`` into
    ``archs[arch][t] -> expt_name``. For each run, the file
    ``<data_dir>/<expt>/default/version_0/metrics.csv`` is read and its first
    ``epochs`` logged values are stored.

    Args:
        data_dir: directory whose entries are experiment directories.
        epochs: number of logged values kept per run (length of the last axis).
        load_partial_data: if True, runs with fewer than ``epochs`` logged
            values (but at least one) are kept and zero-padded on the right.
            If False, only runs with at least ``epochs`` train and val values
            are kept.

    Returns:
        dict mapping each arch to ``{"T": LongTensor (n_kept,),
        "metrics": FloatTensor (5, n_kept, epochs)}``. Metric indices 0..4 are
        learning_rate, train_loss, train_accuracy, val_loss, val_accuracy.
        Runs whose metrics file is missing or unparsable are skipped, and a
        warning is logged.
    """
    # NOTE(review): original comment was "layers x heads x d_model x train_pct";
    # presumably arch keys encode (layers, heads, d_model) and the inner keys t
    # are train_pct values -- confirm against factor_expts.
    data = {}
    archs = factor_expts(os.listdir(data_dir))
    logger.debug(archs)
    for arch in archs:
        runs = archs[arch]
        T = sorted(runs.keys())
        # Rows are addressed by position i in T, so allocate len(T) rows.
        metrics = torch.zeros((len(T), 5, epochs))
        # Explicit validity mask instead of overloading T == 0 as "invalid".
        keep = torch.zeros(len(T), dtype=torch.bool)
        for i, t in enumerate(tqdm(T)):
            expt = runs[t]
            logger.debug(expt)
            log_dir = os.path.join(data_dir, expt)
            csv_path = os.path.join(log_dir, "default", "version_0", "metrics.csv")
            logger.debug("loading %s", log_dir)
            try:
                train_t, val_t = _read_metrics_csv(csv_path)
            except (OSError, csv.Error, KeyError, TypeError, ValueError) as err:
                # Best-effort: an unreadable run is skipped, not fatal.
                logger.warning("skipping %s: %r", log_dir, err)
                continue
            n_train = min(train_t.shape[-1], epochs)
            n_val = min(val_t.shape[-1], epochs)
            if load_partial_data:
                usable = n_train > 0 or n_val > 0
            else:
                usable = n_train == epochs and n_val == epochs
            if not usable:
                continue
            # Unfilled tail positions stay zero (only possible in partial mode).
            metrics[i, :3, :n_train] = train_t[:, :n_train]
            metrics[i, 3:, :n_val] = val_t[:, :n_val]
            keep[i] = True
        data[arch] = {
            "T": torch.LongTensor(T)[keep],
            # (n_kept, 5, epochs) -> (5, n_kept, epochs)
            "metrics": torch.transpose(metrics[keep], 0, 1),
        }
    return data