in compert/train.py [0:0]
# Module-level imports this function relies on (assumed to appear at the top
# of compert/train.py); prepare_compert, evaluate and pjson are helpers
# assumed to be defined elsewhere in this module.
import os
import time
from collections import defaultdict

import numpy as np
import torch


def train_compert(args, return_model=False):
    """
    Trains a ComPert autoencoder and, if `return_model` is True, returns the
    trained autoencoder together with its datasets.
    """
    autoencoder, datasets = prepare_compert(args)

    # Wrap the training split in a shuffled DataLoader sized by the model's
    # own batch-size hyperparameter.
    datasets.update({
        "loader_tr": torch.utils.data.DataLoader(
            datasets["training"],
            batch_size=autoencoder.hparams["batch_size"],
            shuffle=True)
    })

    pjson({"training_args": args})
    pjson({"autoencoder_params": autoencoder.hparams})

    start_time = time.time()
    for epoch in range(args["max_epochs"]):
        epoch_training_stats = defaultdict(float)

        # Accumulate per-minibatch statistics over one pass through the data.
        for genes, drugs, cell_types in datasets["loader_tr"]:
            minibatch_training_stats = autoencoder.update(
                genes, drugs, cell_types)

            for key, val in minibatch_training_stats.items():
                epoch_training_stats[key] += val

        # Average the accumulated statistics over the minibatches and record
        # the averaged values in the model's training history.
        for key, val in epoch_training_stats.items():
            epoch_training_stats[key] = val / len(datasets["loader_tr"])
            if key not in autoencoder.history:
                autoencoder.history[key] = []
            autoencoder.history[key].append(epoch_training_stats[key])
        autoencoder.history['epoch'].append(epoch)

        elapsed_minutes = (time.time() - start_time) / 60
        autoencoder.history['elapsed_time_min'] = elapsed_minutes

        # Stopping condition: wall-clock budget exceeded OR final epoch
        # reached (early-stopping patience is checked after evaluation below).
        stop = elapsed_minutes > args["max_minutes"] or \
            (epoch == args["max_epochs"] - 1)

        if (epoch % args["checkpoint_freq"]) == 0 or stop:
            # Evaluate periodically (and on the final epoch) and extend the
            # history with the evaluation statistics.
            evaluation_stats = evaluate(autoencoder, datasets)
            for key, val in evaluation_stats.items():
                if key not in autoencoder.history:
                    autoencoder.history[key] = []
                autoencoder.history[key].append(val)
            autoencoder.history['stats_epoch'].append(epoch)

            pjson({
                "epoch": epoch,
                "training_stats": epoch_training_stats,
                "evaluation_stats": evaluation_stats,
                "elapsed_minutes": elapsed_minutes
            })

            # Checkpoint: model weights, training arguments and the
            # accumulated history are saved together as one tuple.
            torch.save(
                (autoencoder.state_dict(), args, autoencoder.history),
                os.path.join(
                    args["save_dir"],
                    "model_seed={}_epoch={}.pt".format(args["seed"], epoch)))

            pjson({"model_saved": "model_seed={}_epoch={}.pt\n".format(
                args["seed"], epoch)})

            # Early stopping on the mean test-set evaluation score.
            stop = stop or autoencoder.early_stopping(
                np.mean(evaluation_stats["test"]))
            if stop:
                pjson({"early_stop": epoch})
                break

    if return_model:
        return autoencoder, datasets
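

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). It illustrates how
# train_compert might be called; only the keys that this function reads
# directly are shown, and the concrete values are assumptions.
# prepare_compert(args) will expect additional entries (dataset paths, model
# hyperparameters, ...) that are omitted here rather than guessed.
# ---------------------------------------------------------------------------
# args = {
#     "max_epochs": 200,        # hard cap on training epochs
#     "max_minutes": 60,        # wall-clock budget in minutes
#     "checkpoint_freq": 20,    # evaluate + checkpoint every N epochs
#     "save_dir": "./checkpoints",
#     "seed": 0,
#     # ... plus whatever prepare_compert(args) requires
# }
# autoencoder, datasets = train_compert(args, return_model=True)
#
# A saved checkpoint could later be restored by unpacking the tuple written
# by torch.save above (seed/epoch values below are placeholders):
# state_dict, saved_args, history = torch.load(
#     os.path.join(args["save_dir"], "model_seed=0_epoch=0.pt"))
# autoencoder, _ = prepare_compert(saved_args)
# autoencoder.load_state_dict(state_dict)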