def child()

in scripts/create_partial_metrics.py [0:0]


def _find_epoch_checkpoint(ckpt_dir, epoch):
    """Return the checkpoint path for ``epoch`` inside ``ckpt_dir``, or None.

    Both naming schemes are searched: ``epoch={epoch}-step=*.ckpt`` and the
    plain ``epoch={epoch}.ckpt``. The plain file wins when it exists;
    otherwise the last step-suffixed match in sorted order is returned.
    """
    # Escape the directory so characters like "[" in expt_dir are not
    # interpreted as glob patterns; sort for a deterministic choice.
    pattern_dir = glob.escape(ckpt_dir)
    candidates = sorted(glob.glob(pattern_dir + f"/epoch={epoch}-step=*.ckpt"))
    candidates += glob.glob(pattern_dir + f"/epoch={epoch}.ckpt")
    return candidates[-1] if candidates else None


def _load_epoch_model(ckpt_file, hparams, device):
    """Rebuild a ``TrainableTransformer`` from ``ckpt_file`` on ``device`` in eval mode.

    Side effect: every entry of the checkpoint's ``hyper_parameters`` is
    copied onto ``hparams`` before the model is constructed.
    """
    # Load on CPU: the model is constructed on CPU and only moved to ``device``
    # afterwards, so mapping to a GPU here would just add a round trip.
    # (Replaces the previous hard-coded "cuda:0".)
    ckpt = torch.load(ckpt_file, map_location="cpu")

    for k, v in ckpt["hyper_parameters"].items():
        setattr(hparams, k, v)

    # Some checkpoints store state-dict keys without the "transformer."
    # prefix; add it where missing so load_state_dict matches the model.
    state_dict = {
        (k if k.startswith("transformer.") else "transformer." + k): v
        for k, v in ckpt["state_dict"].items()
    }

    model = trainer.TrainableTransformer(hparams).float()
    model.load_state_dict(state_dict)
    return model.to(device).eval()


def child(hparams):
    """Evaluate saved checkpoints of one experiment and save per-epoch metrics.

    For every epoch in the comma-separated ``hparams.epochs`` that has a
    checkpoint under ``{hparams.expt_dir}/checkpoints``, the model is rebuilt,
    run over its test dataloader, and the ``test_loss`` / ``test_accuracy``
    from ``test_epoch_end(...)["log"]`` are collected. Epochs without a
    checkpoint are skipped.

    Results are written with ``torch.save`` to
    ``{expt_dir}/activations/activations_{first:010d}_{last:010d}.pt``, where
    first/last are the first and last *requested* epochs. The saved dict has:

    - ``"val_loss"`` / ``"val_accuracy"``: CPU tensors with one row per
      processed epoch (concatenated along dim 0),
    - ``"epochs"``: int64 tensor of the epochs actually processed,
    - ``"dl"``: the test dataloader of the last processed checkpoint.

    Args:
        hparams: namespace with at least ``expt_dir``, ``epochs`` and ``gpu``.
            It is mutated: ``logdir`` is set to ``expt_dir`` and each loaded
            checkpoint's hyper-parameters are copied onto it.

    Raises:
        FileNotFoundError: if none of the requested epochs has a checkpoint.
    """
    expt_dir = hparams.expt_dir
    epochs = [int(e) for e in hparams.epochs.split(",")]
    device = torch.device(f"cuda:{hparams.gpu}")
    ckpt_dir = expt_dir + "/" + "checkpoints"
    hparams.logdir = expt_dir

    losses, accuracies, processed_epochs = [], [], []
    dl = None
    for epoch in tqdm(epochs):
        ckpt_file = _find_epoch_checkpoint(ckpt_dir, epoch)
        if ckpt_file is None:
            # No checkpoint for this epoch: skip it rather than crash.
            continue

        model = _load_epoch_model(ckpt_file, hparams, device)
        processed_epochs.append(epoch)

        dl = model.test_dataloader()
        dl.reset_iteration(shuffle=False)

        outputs = [
            model.test_step(batch, batch_idx) for batch_idx, batch in enumerate(dl)
        ]
        r = model.test_epoch_end(outputs)["log"]

        # Detach and move to CPU per epoch so earlier epochs do not keep GPU
        # memory (or any autograd graph) alive; concatenate once at the end
        # instead of re-concatenating every iteration.
        losses.append(r["test_loss"].squeeze().unsqueeze(0).detach().cpu())
        accuracies.append(r["test_accuracy"].squeeze().unsqueeze(0).detach().cpu())

    if not processed_epochs:
        raise FileNotFoundError(
            f"no checkpoints found in {ckpt_dir} for epochs {epochs}"
        )

    results = {
        "val_loss": torch.cat(losses, dim=0),
        "val_accuracy": torch.cat(accuracies, dim=0),
        "epochs": torch.tensor(processed_epochs, dtype=torch.long),
        "dl": dl,
    }

    os.makedirs(expt_dir + "/activations", exist_ok=True)
    ptfile = (
        expt_dir + f"/activations/activations_{epochs[0]:010d}_{epochs[-1]:010d}.pt"
    )
    torch.save(results, ptfile)