in scripts/create_partial_metrics.py [0:0]
def child(hparams):
    """Evaluate a range of saved checkpoints and persist the collected metrics.

    For each epoch listed in ``hparams.epochs`` (a comma-separated string of
    ints), locate the matching checkpoint under ``<expt_dir>/checkpoints``,
    rebuild the model from the stored hyper-parameters and state dict, run a
    full pass over the test dataloader, and stack the resulting loss/accuracy.
    The metrics, the list of epochs actually processed, and the dataloader are
    saved to ``<expt_dir>/activations/activations_<first>_<last>.pt``.

    Args:
        hparams: argparse-style namespace; must provide ``expt_dir``,
            ``epochs`` (comma-separated string), and ``gpu`` (int index).
            Mutated in place: ``logdir`` is set, and every key in each
            checkpoint's ``hyper_parameters`` is copied onto it.
    """
    expt_dir = hparams.expt_dir
    epochs = [int(e) for e in hparams.epochs.split(",")]
    device = torch.device(f"cuda:{hparams.gpu}")
    ckpt_dir = expt_dir + "/" + "checkpoints"
    hparams.logdir = expt_dir
    results = {
        "val_loss": None,
        "val_accuracy": None,
    }
    processed_epochs = []
    for epoch in tqdm(epochs):
        # Checkpoints may be named "epoch=N-step=S.ckpt" or plain
        # "epoch=N.ckpt"; take the last match (highest step) when several exist.
        ckpt_files = glob.glob(ckpt_dir + f"/epoch={epoch}-step=*.ckpt")
        ckpt_files += glob.glob(ckpt_dir + f"/epoch={epoch}.ckpt")
        if not ckpt_files:
            # BUG FIX: the original caught FileNotFoundError, but an empty
            # glob result raises IndexError on ckpt_files[-1] — missing
            # checkpoints crashed instead of being skipped.
            continue
        ckpt_file = ckpt_files[-1]
        # BUG FIX: map onto the requested device instead of hard-coded
        # cuda:0 (resolves the original FIXME).
        ckpt = torch.load(ckpt_file, map_location=device)
        processed_epochs.append(epoch)
        # Restore the hyper-parameters the checkpoint was trained with.
        for k, v in ckpt["hyper_parameters"].items():
            setattr(hparams, k, v)
        # Older checkpoints lack the "transformer." prefix on state-dict
        # keys; normalize so load_state_dict accepts either layout.
        new_state_dict = {}
        for k, v in ckpt["state_dict"].items():
            if k.startswith("transformer."):
                new_state_dict[k] = v
            else:
                new_state_dict["transformer." + k] = v
        ckpt["state_dict"] = new_state_dict
        model = trainer.TrainableTransformer(hparams).float()
        model.load_state_dict(ckpt["state_dict"])
        model = model.to(device).eval()
        dl = model.test_dataloader()
        dl.reset_iteration(shuffle=False)
        # BUG FIX: the comprehension variable previously shadowed the outer
        # loop's index; use a distinct per-batch index.
        outputs = [
            model.test_step(batch, batch_idx)
            for batch_idx, batch in enumerate(dl)
        ]
        r = model.test_epoch_end(outputs)["log"]
        # Normalize each metric to shape (1,) so epochs stack along dim 0.
        loss = r["test_loss"].squeeze().unsqueeze(0)
        acc = r["test_accuracy"].squeeze().unsqueeze(0)
        if results["val_loss"] is None:
            results["val_loss"] = loss
            results["val_accuracy"] = acc
        else:
            results["val_loss"] = torch.cat([results["val_loss"], loss], dim=0)
            results["val_accuracy"] = torch.cat(
                [results["val_accuracy"], acc], dim=0
            )
    # Move accumulated metrics to CPU before serialization.
    for k, v in results.items():
        results[k] = v.to("cpu")
    # BUG FIX: the legacy torch.LongTensor constructor's device keyword is
    # unreliable; torch.tensor is the supported API (CPU by default).
    results["epochs"] = torch.tensor(processed_epochs, dtype=torch.long)
    results["dl"] = dl
    os.makedirs(expt_dir + "/activations", exist_ok=True)
    ptfile = (
        expt_dir + f"/activations/activations_{epochs[0]:010d}_{epochs[-1]:010d}.pt"
    )
    torch.save(results, ptfile)