def train_compert()

in compert/train.py [0:0]


def train_compert(args, return_model=False):
    """
    Trains a ComPert autoencoder.

    Builds the model and datasets with ``prepare_compert(args)``, then runs
    up to ``args["max_epochs"]`` epochs over a shuffled training loader.
    Every ``args["checkpoint_freq"]`` epochs (and on the final epoch) the
    model is evaluated, the stats are logged through ``pjson`` and a
    checkpoint ``(state_dict, args, history)`` is written to
    ``args["save_dir"]``.

    Args:
        args: mapping of training options. Keys read here: ``max_epochs``,
            ``max_minutes``, ``checkpoint_freq``, ``save_dir``, ``seed``.
            The whole mapping is also forwarded to ``prepare_compert`` and
            stored in every checkpoint.
        return_model: when true, return ``(autoencoder, datasets)``;
            otherwise the function returns ``None``.

    Returns:
        ``(autoencoder, datasets)`` if ``return_model`` is true, else ``None``.
        ``datasets`` additionally contains the training loader under the
        key ``"loader_tr"``.
    """

    autoencoder, datasets = prepare_compert(args)

    datasets.update({
        "loader_tr": torch.utils.data.DataLoader(
                        datasets["training"],
                        batch_size=autoencoder.hparams["batch_size"],
                        shuffle=True)
    })
    loader_tr = datasets["loader_tr"]
    # Number of minibatches per epoch; used to turn summed stats into means.
    n_batches = len(loader_tr)

    pjson({"training_args": args})
    pjson({"autoencoder_params": autoencoder.hparams})

    start_time = time.time()
    for epoch in range(args["max_epochs"]):
        epoch_training_stats = defaultdict(float)

        for genes, drugs, cell_types in loader_tr:
            minibatch_training_stats = autoencoder.update(
                genes, drugs, cell_types)

            for key, val in minibatch_training_stats.items():
                epoch_training_stats[key] += val

        # Only existing keys are reassigned, so mutating while iterating
        # does not change the dict's size.
        for key, total in epoch_training_stats.items():
            mean_val = total / n_batches
            epoch_training_stats[key] = mean_val
            if key not in autoencoder.history:
                autoencoder.history[key] = []
            # Record the per-batch mean (the same value that is logged
            # below), not the raw sum accumulated over the epoch.
            autoencoder.history[key].append(mean_val)
        autoencoder.history['epoch'].append(epoch)

        elapsed_minutes = (time.time() - start_time) / 60
        # Overwritten every epoch: holds the latest elapsed time only.
        autoencoder.history['elapsed_time_min'] = elapsed_minutes

        # Stopping condition checked here: time ran out OR max epochs
        # reached. The patience-based criterion is delegated to
        # autoencoder.early_stopping below (presumably where any
        # learning-rate decay happens too -- not visible from this function).
        stop = elapsed_minutes > args["max_minutes"] or \
            (epoch == args["max_epochs"] - 1)

        if (epoch % args["checkpoint_freq"]) == 0 or stop:
            evaluation_stats = evaluate(autoencoder, datasets)
            for key, val in evaluation_stats.items():
                if key not in autoencoder.history:
                    autoencoder.history[key] = []
                autoencoder.history[key].append(val)
            autoencoder.history['stats_epoch'].append(epoch)

            pjson({
                "epoch": epoch,
                "training_stats": epoch_training_stats,
                "evaluation_stats": evaluation_stats,
                "ellapsed_minutes": elapsed_minutes
            })

            checkpoint_name = "model_seed={}_epoch={}.pt".format(
                args["seed"], epoch)
            torch.save(
                (autoencoder.state_dict(), args, autoencoder.history),
                os.path.join(args["save_dir"], checkpoint_name))

            pjson({"model_saved": checkpoint_name + "\n"})

            # early_stopping is only consulted when the time/epoch limits
            # have not already requested a stop (short-circuit `or`).
            stop = stop or autoencoder.early_stopping(
                np.mean(evaluation_stats["test"]))
            if stop:
                # Logged for every stop reason, not only patience-based ones.
                pjson({"early_stop": epoch})
                break

    if return_model:
        return autoencoder, datasets