in compert/model.py [0:0]
def update(self, genes, drugs, cell_types):
    """
    Perform one ComPert optimisation step on a minibatch of genes, drugs,
    and cell types.

    Every call runs a forward pass, then updates EITHER the two adversarial
    classifiers OR the autoencoder together with the dosers (see the
    schedule note below), and finally increments ``self.iteration``.

    Args:
        genes: minibatch used both as model input and as the
            reconstruction target of ``self.loss_autoencoder``.
        drugs: minibatch of drug values; entries > 0 become the positive
            targets of the drug adversary.
        cell_types: minibatch of cell-type encodings; the cell-type
            adversary target is ``cell_types.argmax(1)``.

    Returns:
        dict with the Python-float values of "loss_reconstruction",
        "loss_adv_drugs", "loss_adv_cell_types", "penalty_adv_drugs" and
        "penalty_adv_cell_types". Both penalties are 0.0 on steps that do
        not train the adversaries.
    """
    # presumably moves the inputs onto the model's device — confirm in
    # move_inputs_
    genes, drugs, cell_types = self.move_inputs_(genes, drugs, cell_types)
    hp = self.hparams

    reconstructed, basal = self.predict(
        genes, drugs, cell_types, return_latent_basal=True)
    loss_recon = self.loss_autoencoder(reconstructed, genes)

    # Both adversaries read the basal latent code: one predicts which drugs
    # are present (> 0), the other predicts the cell type.
    drug_preds = self.adversary_drugs(basal)
    loss_adv_drugs = self.loss_adversary_drugs(
        drug_preds, drugs.gt(0).float())
    cell_preds = self.adversary_cell_types(basal)
    loss_adv_cells = self.loss_adversary_cell_types(
        cell_preds, cell_types.argmax(1))

    def input_grad_penalty(preds):
        # Mean squared gradient of the summed predictions w.r.t. the basal
        # latent. create_graph=True keeps this differentiable so that it can
        # itself be back-propagated into the adversary parameters.
        (grad_wrt_basal,) = torch.autograd.grad(
            preds.sum(), basal, create_graph=True)
        return grad_wrt_basal.pow(2).mean()

    # NOTE(review): the adversaries are trained whenever
    # iteration % adversary_steps is non-zero, and the autoencoder/dosers
    # otherwise. For an integer iteration counter, adversary_steps == 1
    # therefore never updates the adversaries — confirm this is intended.
    if self.iteration % hp["adversary_steps"] != 0:
        penalty_drugs = input_grad_penalty(drug_preds)
        penalty_cells = input_grad_penalty(cell_preds)
        weight = hp["penalty_adversary"]
        adversary_objective = (loss_adv_drugs + weight * penalty_drugs
                               + loss_adv_cells + weight * penalty_cells)
        # basal is not detached, so this backward also writes gradients into
        # the parameters producing it; only the adversary optimizer steps.
        self.optimizer_adversaries.zero_grad()
        adversary_objective.backward()
        self.optimizer_adversaries.step()
    else:
        # Adversaries are not trained this step: report zero penalties.
        penalty_drugs = torch.zeros(1)
        penalty_cells = torch.zeros(1)
        reg = hp["reg_adversary"]
        # The autoencoder minimises reconstruction error while *maximising*
        # the adversaries' losses (hence the subtraction).
        autoencoder_objective = (loss_recon
                                 - reg * loss_adv_drugs
                                 - reg * loss_adv_cells)
        self.optimizer_autoencoder.zero_grad()
        self.optimizer_dosers.zero_grad()
        autoencoder_objective.backward()
        self.optimizer_autoencoder.step()
        self.optimizer_dosers.step()

    self.iteration += 1

    return {
        "loss_reconstruction": loss_recon.item(),
        "loss_adv_drugs": loss_adv_drugs.item(),
        "loss_adv_cell_types": loss_adv_cells.item(),
        "penalty_adv_drugs": penalty_drugs.item(),
        "penalty_adv_cell_types": penalty_cells.item(),
    }