def __init__()

in compert/model.py


    def __init__(
            self,
            num_genes,
            num_drugs,
            num_cell_types,
            device="cpu",
            seed=0,
            patience=5,
            loss_ae='gauss',
            doser_type='logsigm',
            decoder_activation='linear',
            hparams=""):
        super(ComPert, self).__init__()
        # set generic attributes
        self.num_genes = num_genes
        self.num_drugs = num_drugs
        self.num_cell_types = num_cell_types
        self.device = device
        self.seed = seed
        self.loss_ae = loss_ae
        # early-stopping
        self.patience = patience
        self.best_score = -1e3
        self.patience_trials = 0

        # set hyperparameters
        self.set_hparams_(seed, hparams)

        # set models
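        # the encoder compresses a gene-expression profile into a latent
        # "basal state" of dimension hparams["dim"]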
        self.encoder = MLP(
            [num_genes] +
            [self.hparams["autoencoder_width"]] *
            self.hparams["autoencoder_depth"] +
            [self.hparams["dim"]])
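        # e.g. (assumed values) with autoencoder_width=512,
        # autoencoder_depth=4 and dim=256, the size list above expands to
        # [num_genes, 512, 512, 512, 512, 256]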

        # the decoder mirrors the encoder but emits num_genes * 2 values:
        # two distribution parameters per gene (mean and variance for the
        # Gaussian loss, mean and dispersion for the NB loss)
        self.decoder = MLP(
            [self.hparams["dim"]] +
            [self.hparams["autoencoder_width"]] *
            self.hparams["autoencoder_depth"] +
            [num_genes * 2], last_layer_act=decoder_activation)

        # adversary classifiers try to predict the drug and the cell type
        # from the latent code; training against them pushes this
        # information out of the basal state
        self.adversary_drugs = MLP(
            [self.hparams["dim"]] +
            [self.hparams["adversary_width"]] *
            self.hparams["adversary_depth"] +
            [num_drugs])

        self.adversary_cell_types = MLP(
            [self.hparams["dim"]] +
            [self.hparams["adversary_width"]] *
            self.hparams["adversary_depth"] +
            [num_cell_types])

        # set dosers
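        # dosers model per-drug dose response: either one small 1 -> 1 MLP
        # per drug, or a shared GeneralizedSigmoid with the chosen
        # nonlinearity (default 'logsigm')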
        self.doser_type = doser_type
        if doser_type == 'mlp':
            self.dosers = torch.nn.ModuleList()
            for _ in range(num_drugs):
                self.dosers.append(
                    MLP([1] +
                        [self.hparams["dosers_width"]] *
                        self.hparams["dosers_depth"] +
                        [1],
                        batch_norm=False))
        else:
            self.dosers = GeneralizedSigmoid(num_drugs, self.device,
                                             nonlin=doser_type)

        # learnable embeddings place every drug and cell type in the same
        # latent space as the encoder output
        self.drug_embeddings = torch.nn.Embedding(
            num_drugs, self.hparams["dim"])
        self.cell_type_embeddings = torch.nn.Embedding(
            num_cell_types, self.hparams["dim"])

        # losses
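        # reconstruction: negative binomial for count data ('nb'),
        # Gaussian otherwise; adversaries use standard classification losses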
        if self.loss_ae == 'nb':
            self.loss_autoencoder = NBLoss()
        else:
            self.loss_autoencoder = GaussianLoss()
        self.loss_adversary_drugs = torch.nn.BCEWithLogitsLoss()
        self.loss_adversary_cell_types = torch.nn.CrossEntropyLoss()
        self.iteration = 0

        self.to(self.device)

        # optimizers
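        # separate optimizers let the autoencoder (with embeddings), the
        # adversaries and the dosers be stepped independently, e.g. in
        # alternating adversarial updates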
        self.optimizer_autoencoder = torch.optim.Adam(
            list(self.encoder.parameters()) +
            list(self.decoder.parameters()) +
            list(self.drug_embeddings.parameters()) +
            list(self.cell_type_embeddings.parameters()),
            lr=self.hparams["autoencoder_lr"],
            weight_decay=self.hparams["autoencoder_wd"])

        self.optimizer_adversaries = torch.optim.Adam(
            list(self.adversary_drugs.parameters()) +
            list(self.adversary_cell_types.parameters()),
            lr=self.hparams["adversary_lr"],
            weight_decay=self.hparams["adversary_wd"])

        self.optimizer_dosers = torch.optim.Adam(
            self.dosers.parameters(),
            lr=self.hparams["dosers_lr"],
            weight_decay=self.hparams["dosers_wd"])

        # learning rate schedulers
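        # StepLR multiplies the learning rate by its default gamma (0.1)
        # every hparams["step_size_lr"] scheduler steps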
        self.scheduler_autoencoder = torch.optim.lr_scheduler.StepLR(
            self.optimizer_autoencoder, step_size=self.hparams["step_size_lr"])

        self.scheduler_adversary = torch.optim.lr_scheduler.StepLR(
            self.optimizer_adversaries, step_size=self.hparams["step_size_lr"])

        self.scheduler_dosers = torch.optim.lr_scheduler.StepLR(
            self.optimizer_dosers, step_size=self.hparams["step_size_lr"])

        self.history = {'epoch': [], 'stats_epoch': []}
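
Below is a minimal usage sketch. The dimensions and argument values are assumptions for illustration; ComPert, MLP, GeneralizedSigmoid, NBLoss and GaussianLoss are defined elsewhere in compert/model.py, and hparams can stay at its default "" so that set_hparams_ presumably falls back to its built-in defaults.

    import torch
    from compert.model import ComPert

    # hypothetical dataset dimensions, for illustration only
    model = ComPert(
        num_genes=2000,
        num_drugs=10,
        num_cell_types=3,
        device="cuda" if torch.cuda.is_available() else "cpu",
        seed=42,
        loss_ae="gauss",        # or 'nb' for count data
        doser_type="logsigm")   # or 'mlp' for per-drug dose MLPs

    # encode a batch of (already preprocessed) expression profiles
    genes = torch.randn(8, 2000, device=model.device)
    latent_basal = model.encoder(genes)   # shape: (8, hparams["dim"])

For reference, a hypothetical minimal sketch of the MLP helper this constructor assumes (the real implementation lives in compert/model.py and may differ): it only has to honor the call pattern used above, i.e. a list of layer sizes plus batch_norm and last_layer_act keyword arguments.

    import torch

    class MLP(torch.nn.Module):
        # hypothetical sketch, not the actual compert implementation
        def __init__(self, sizes, batch_norm=True, last_layer_act="linear"):
            super().__init__()
            layers = []
            for i in range(len(sizes) - 1):
                layers.append(torch.nn.Linear(sizes[i], sizes[i + 1]))
                if i < len(sizes) - 2:  # hidden layers only
                    if batch_norm:
                        layers.append(torch.nn.BatchNorm1d(sizes[i + 1]))
                    layers.append(torch.nn.ReLU())
            if last_layer_act == "relu":  # 'linear' adds no final activation
                layers.append(torch.nn.ReLU())
            self.network = torch.nn.Sequential(*layers)

        def forward(self, x):
            return self.network(x)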