# update() — excerpt from compert/model.py


    def update(self, genes, drugs, cell_types):
        """
        Update ComPert's parameters given a minibatch of genes, drugs, and
        cell types.

        Performs one training step, alternating between two phases based on
        the iteration counter:

        * adversary phase -- iterations where ``self.iteration`` is NOT a
          multiple of ``hparams["adversary_steps"]``: trains the drug and
          cell-type adversary classifiers on the basal latent, plus a
          gradient penalty on each adversary's output w.r.t. that latent.
        * autoencoder phase -- the remaining iterations: trains the
          autoencoder and dosers on reconstruction loss MINUS the weighted
          adversary losses, pushing the basal latent to become
          uninformative about drug and cell type.

        Returns a dict of scalar floats for logging.
        """

        # Move the minibatch tensors where the model expects them
        # (helper defined elsewhere in the class).
        genes, drugs, cell_types = self.move_inputs_(genes, drugs, cell_types)

        # Forward pass; also retrieve the basal latent (the embedding the
        # adversaries operate on).
        gene_reconstructions, latent_basal = self.predict(
            genes, drugs, cell_types, return_latent_basal=True)

        reconstruction_loss = self.loss_autoencoder(
            gene_reconstructions, genes)

        # Adversary tries to recover which drugs were applied from the
        # basal latent. Target is drug presence (entry > 0) as a 0/1 float
        # vector -- presumably `drugs` encodes a dose per drug; confirm
        # against the data loader.
        adversary_drugs_predictions = self.adversary_drugs(
            latent_basal)
        adversary_drugs_loss = self.loss_adversary_drugs(
            adversary_drugs_predictions, drugs.gt(0).float())

        # Same for cell type; argmax(1) converts what looks like a one-hot
        # encoding into class indices -- TODO confirm one-hot assumption.
        adversary_cell_types_predictions = self.adversary_cell_types(
            latent_basal)
        adversary_cell_types_loss = self.loss_adversary_cell_types(
            adversary_cell_types_predictions, cell_types.argmax(1))

        # two place-holders for when adversary is not executed
        # (i.e. the autoencoder branch below); reported as 0 in that case.
        adversary_drugs_penalty = torch.Tensor([0])
        adversary_cell_types_penalty = torch.Tensor([0])

        # NOTE(review): truthy when iteration is NOT a multiple of
        # adversary_steps, so for adversary_steps > 1 the adversaries are
        # updated on most iterations and the autoencoder only on every
        # adversary_steps-th one. Looks intentional (GAN-style schedule),
        # but worth confirming against the paper/config.
        if self.iteration % self.hparams["adversary_steps"]:
            # Gradient penalty: mean squared gradient of the summed
            # adversary outputs w.r.t. the basal latent. create_graph=True
            # keeps this term differentiable so it can participate in the
            # backward() below.
            adversary_drugs_penalty = torch.autograd.grad(
                adversary_drugs_predictions.sum(),
                latent_basal,
                create_graph=True)[0].pow(2).mean()

            adversary_cell_types_penalty = torch.autograd.grad(
                adversary_cell_types_predictions.sum(),
                latent_basal,
                create_graph=True)[0].pow(2).mean()

            # Step only the adversaries: minimize their classification
            # losses plus the penalty terms weighted by penalty_adversary.
            self.optimizer_adversaries.zero_grad()
            (adversary_drugs_loss +
             self.hparams["penalty_adversary"] *
             adversary_drugs_penalty +
             adversary_cell_types_loss +
             self.hparams["penalty_adversary"] *
             adversary_cell_types_penalty).backward()
            self.optimizer_adversaries.step()
        else:
            # Step the autoencoder and dosers: minimize reconstruction loss
            # while MAXIMIZING the adversary losses (subtracted, weighted by
            # reg_adversary) so the basal latent sheds drug / cell-type
            # information.
            self.optimizer_autoencoder.zero_grad()
            self.optimizer_dosers.zero_grad()
            (reconstruction_loss -
             self.hparams["reg_adversary"] *
             adversary_drugs_loss -
             self.hparams["reg_adversary"] *
             adversary_cell_types_loss).backward()
            self.optimizer_autoencoder.step()
            self.optimizer_dosers.step()

        self.iteration += 1

        # .item() detaches each statistic from the autograd graph and
        # returns a plain Python float.
        return {
            "loss_reconstruction": reconstruction_loss.item(),
            "loss_adv_drugs": adversary_drugs_loss.item(),
            "loss_adv_cell_types": adversary_cell_types_loss.item(),
            "penalty_adv_drugs": adversary_drugs_penalty.item(),
            "penalty_adv_cell_types": adversary_cell_types_penalty.item()
        }