in compert/model.py [0:0]
def __init__(
        self,
        num_genes,
        num_drugs,
        num_cell_types,
        device="cpu",
        seed=0,
        patience=5,
        loss_ae='gauss',
        doser_type='logsigm',
        decoder_activation='linear',
        hparams=""):
    """Build the ComPert networks, losses, optimizers and LR schedulers.

    Args:
        num_genes: number of genes; input width of the encoder and half the
            output width of the decoder.
        num_drugs: number of drugs; sizes the drug adversary output, the
            drug embedding table and the dosers.
        num_cell_types: number of cell types; sizes the cell-type adversary
            output and the cell-type embedding table.
        device: target device; the module is moved there with ``self.to``
            and it is also handed to ``GeneralizedSigmoid``.
        seed: forwarded to ``set_hparams_`` together with ``hparams``.
        patience: stored for early stopping.
        loss_ae: ``'nb'`` selects ``NBLoss``; any other value selects
            ``GaussianLoss``.
        doser_type: ``'mlp'`` builds one small MLP per drug; any other value
            is passed as ``nonlin`` to ``GeneralizedSigmoid``.
        decoder_activation: ``last_layer_act`` of the decoder MLP.
        hparams: hyperparameter overrides forwarded to ``set_hparams_``
            (presumably a JSON string or a dict -- confirm there).
    """
    super().__init__()

    # Plain bookkeeping attributes.
    self.num_genes = num_genes
    self.num_drugs = num_drugs
    self.num_cell_types = num_cell_types
    self.device = device
    self.seed = seed
    self.loss_ae = loss_ae

    # Early-stopping state; -1e3 is the initial "best score" sentinel.
    self.patience = patience
    self.best_score = -1e3
    self.patience_trials = 0

    # Expected to populate self.hparams, which sizes everything below.
    self.set_hparams_(seed, hparams)
    hp = self.hparams

    def layer_sizes(first, width, depth, last):
        """Return [first] + depth copies of width + [last]."""
        return [first] + [width] * depth + [last]

    # NOTE: networks are created in the same order as before. If
    # set_hparams_ seeds the RNG (not visible here), reordering these
    # constructions would change the initial weights.
    latent = hp["dim"]
    ae_width = hp["autoencoder_width"]
    ae_depth = hp["autoencoder_depth"]
    self.encoder = MLP(layer_sizes(num_genes, ae_width, ae_depth, latent))
    # The decoder emits 2 * num_genes values (presumably two parameters per
    # gene consumed by the reconstruction loss -- confirm).
    self.decoder = MLP(
        layer_sizes(latent, ae_width, ae_depth, num_genes * 2),
        last_layer_act=decoder_activation)

    adv_width = hp["adversary_width"]
    adv_depth = hp["adversary_depth"]
    self.adversary_drugs = MLP(
        layer_sizes(latent, adv_width, adv_depth, num_drugs))
    self.adversary_cell_types = MLP(
        layer_sizes(latent, adv_width, adv_depth, num_cell_types))

    # Dose-response modules: one 1 -> ... -> 1 MLP per drug, or a single
    # GeneralizedSigmoid covering all drugs.
    self.doser_type = doser_type
    if doser_type == 'mlp':
        self.dosers = torch.nn.ModuleList([
            MLP(layer_sizes(1, hp["dosers_width"], hp["dosers_depth"], 1),
                batch_norm=False)
            for _ in range(num_drugs)
        ])
    else:
        self.dosers = GeneralizedSigmoid(num_drugs, self.device,
                                         nonlin=doser_type)

    self.drug_embeddings = torch.nn.Embedding(num_drugs, latent)
    self.cell_type_embeddings = torch.nn.Embedding(num_cell_types, latent)

    # Reconstruction loss is chosen by loss_ae; the adversaries use
    # BCE-with-logits (drugs) and cross-entropy (cell types).
    self.loss_autoencoder = NBLoss() if self.loss_ae == 'nb' else GaussianLoss()
    self.loss_adversary_drugs = torch.nn.BCEWithLogitsLoss()
    self.loss_adversary_cell_types = torch.nn.CrossEntropyLoss()

    self.iteration = 0
    # Move to the target device before the optimizers are built, as the
    # PyTorch optimizer docs recommend.
    self.to(self.device)

    # Three independent Adam optimizers. The dosers are deliberately kept
    # out of the autoencoder group so they get their own lr / wd.
    autoencoder_params = [
        *self.encoder.parameters(),
        *self.decoder.parameters(),
        *self.drug_embeddings.parameters(),
        *self.cell_type_embeddings.parameters(),
    ]
    self.optimizer_autoencoder = torch.optim.Adam(
        autoencoder_params,
        lr=hp["autoencoder_lr"],
        weight_decay=hp["autoencoder_wd"])

    adversary_params = [
        *self.adversary_drugs.parameters(),
        *self.adversary_cell_types.parameters(),
    ]
    self.optimizer_adversaries = torch.optim.Adam(
        adversary_params,
        lr=hp["adversary_lr"],
        weight_decay=hp["adversary_wd"])

    self.optimizer_dosers = torch.optim.Adam(
        self.dosers.parameters(),
        lr=hp["dosers_lr"],
        weight_decay=hp["dosers_wd"])

    # One StepLR per optimizer, all sharing the same step size.
    step = hp["step_size_lr"]
    self.scheduler_autoencoder = torch.optim.lr_scheduler.StepLR(
        self.optimizer_autoencoder, step_size=step)
    self.scheduler_adversary = torch.optim.lr_scheduler.StepLR(
        self.optimizer_adversaries, step_size=step)
    self.scheduler_dosers = torch.optim.lr_scheduler.StepLR(
        self.optimizer_dosers, step_size=step)

    self.history = {'epoch': [], 'stats_epoch': []}