in grok/visualization.py [0:0]
def load_metric_data(data_dir, epochs=100000, load_partial_data=True):
    """Load per-architecture train/val metrics from experiment log directories.

    Lists the experiment subdirectories of ``data_dir``, groups them by
    architecture with ``factor_expts``, and reads each run's
    ``default/version_0/metrics.csv``. Runs that are missing, malformed, or
    shorter than ``epochs`` steps are dropped.

    Args:
        data_dir: Directory containing one subdirectory per experiment run.
        epochs: Required run length; longer runs are truncated to this many
            steps.
        load_partial_data: Not implemented; must be False. (The historical
            default of ``True`` is kept for interface compatibility, but it
            used to be silently swallowed by a bare ``except`` and zero out
            every run — now it raises up front.)

    Returns:
        dict mapping arch -> {
            "T": LongTensor of the kept runs' keys,
            "metrics": FloatTensor of shape (5, n_kept, epochs) holding
                [learning_rate, train_loss, train_accuracy,
                 val_loss, val_accuracy],
        }

    Raises:
        NotImplementedError: if ``load_partial_data`` is True.
    """
    if load_partial_data:
        # Fail fast. Previously this was raised inside the per-run ``try``
        # and caught by a bare ``except``, silently discarding all data.
        raise NotImplementedError("load_partial_data is not implemented")

    data = {}
    expts = os.listdir(data_dir)
    archs = factor_expts(expts)
    logger.debug(archs)
    for arch in archs:
        T = sorted(archs[arch].keys())
        data[arch] = {
            "T": torch.LongTensor(T),
            # One slot per possible run index; unused slots stay zero and
            # are filtered out below via the T == 0 sentinel.
            "metrics": torch.zeros((max(T), 5, epochs)),
        }
        for i, t in tqdm(list(enumerate(T))):
            expt = archs[arch][t]
            logger.debug(expt)
            log_dir = data_dir + "/" + expt
            try:
                with open(log_dir + "/default/version_0/metrics.csv", "r") as fh:
                    logger.debug(f"loading {log_dir}")
                    reader = list(csv.DictReader(fh))
                    # CSV rows interleave train-only and val-only entries;
                    # select each kind by its mandatory non-empty column.
                    val_t = torch.FloatTensor(
                        [
                            [
                                float(r["val_loss"]),
                                float(r["val_accuracy"]),
                            ]
                            for r in reader
                            if r["val_loss"]
                        ]
                    ).T
                    train_t = torch.FloatTensor(
                        [
                            [
                                float(r["learning_rate"]),
                                float(r["train_loss"]),
                                float(r["train_accuracy"]),
                            ]
                            for r in reader
                            if r["train_loss"]
                        ]
                    ).T
                    if (val_t.shape[-1] >= epochs) and (train_t.shape[-1] >= epochs):
                        # Row order: [lr, train_loss, train_acc,
                        #             val_loss, val_acc], truncated to epochs.
                        data[arch]["metrics"][i] = torch.cat(
                            [train_t[..., :epochs], val_t[..., :epochs]], dim=0
                        )
                    else:
                        # Run too short: mark it for removal below.
                        data[arch]["T"][i] = 0
            except (OSError, KeyError, ValueError):
                # Missing/unreadable metrics file or a malformed CSV row:
                # drop just this run. (Was a bare ``except`` that also hid
                # the not-implemented error handled above.)
                data[arch]["T"][i] = 0
        # Keep only runs that loaded successfully (nonzero T sentinel).
        indices = torch.nonzero(data[arch]["T"]).squeeze()
        if len(indices.shape) == 0:
            # squeeze() collapsed a single surviving run to a 0-d tensor;
            # restore a 1-d index so fancy indexing keeps the run axis.
            indices = indices.unsqueeze(0)
        data[arch]["T"] = data[arch]["T"][indices]
        data[arch]["metrics"] = data[arch]["metrics"][indices]
        # (n_kept, 5, epochs) -> (5, n_kept, epochs)
        data[arch]["metrics"] = torch.transpose(data[arch]["metrics"], 0, 1)
    return data