in src/main.py [0:0]
def main(args: dict) -> None:
    """Train a GCN on protein graphs read from MongoDB and report the test score.

    The function runs the whole experiment end to end:

    1. Pass ``args`` through ``setup`` (seeded with ``config.seed``).
    2. Build a TLS MongoDB connection URI from the database settings.
    3. Create one ``ProteinDataset`` per split ("train", "valid", "test"),
       each driven by a ``$match`` / ``$project`` aggregation pipeline.
    4. Train for up to ``args["n_epochs"]`` epochs, evaluating on the
       validation split after every epoch and stopping early when the
       ``EarlyStopper`` says so.
    5. Reload the checkpoint kept by the stopper and evaluate on the test split.

    Args:
        args: Run configuration. After ``setup`` returns, the following keys
            are read here: ``db_username``, ``db_password``, ``db_host``,
            ``db_port``, ``knn``, ``batch_size``, ``device``, ``lr``,
            ``patience`` and ``n_epochs``. Whether ``setup`` fills in any of
            them is not visible from this function.
            ``db_username`` and ``db_password`` are expected as raw
            (not percent-encoded) values; they are escaped here.

    Returns:
        None. Progress and final scores are printed to stdout.
    """
    # Function-scope import so this block does not depend on the file header.
    from urllib.parse import quote_plus

    args = setup(args, seed=config.seed)

    # The username and password of a MongoDB URI must be percent-escaped;
    # otherwise characters such as "@", ":", "/" or "%" in a credential yield
    # an invalid or mis-parsed connection string. str() keeps non-string
    # values formatting the same way the plain "{}" placeholder did.
    uri = (
        "mongodb://{}:{}@{}:{}/?tls=true&tlsCAFile=rds-combined-ca-bundle.pem"
        "&replicaSet=rs0&readPreference=secondaryPreferred&retryWrites=false"
    ).format(
        quote_plus(str(args["db_username"])),
        quote_plus(str(args["db_password"])),
        args["db_host"],
        args["db_port"],
    )

    # One dataset per split; all share the same projection and kNN setting.
    train_set, valid_set, test_set = [
        ProteinDataset(
            [
                {"$match": match_by_split(split)},
                {"$project": config.projection},
            ],
            db_uri=uri,
            db_name="proteins",
            collection_name="proteins",
            k=args["knn"],
        )
        for split in ("train", "valid", "test")
    ]

    # Only the training data goes through the shuffle buffer and worker
    # processes; validation and test use the DataLoader defaults.
    train_loader = data.DataLoader(
        BufferedShuffleDataset(train_set, buffer_size=config.buffer_size),
        batch_size=args["batch_size"],
        collate_fn=collate_protein_graphs,
        num_workers=config.num_workers,
    )
    valid_loader = data.DataLoader(
        valid_set,
        batch_size=args["batch_size"],
        collate_fn=collate_protein_graphs,
    )
    test_loader = data.DataLoader(
        test_set,
        batch_size=args["batch_size"],
        collate_fn=collate_protein_graphs,
    )

    # Create the model with given dimensions
    dim_nfeats = len(d1_to_index)
    model = GCN(dim_nfeats, config.h_feats, config.n_classes).to(
        args["device"]
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=args["lr"])
    stopper = EarlyStopper(args["patience"])

    for epoch in range(args["n_epochs"]):
        # Train
        run_a_train_epoch(args, epoch, model, train_loader, optimizer)

        # Validation and early stop
        val_score = run_an_eval_epoch(args, model, valid_loader)
        early_stop = stopper.step(val_score, model)
        print(
            "epoch {:d}/{:d}, validation roc-auc {:.4f}, ".format(
                epoch + 1, args["n_epochs"], val_score
            )
        )
        print("best validation roc-auc {:.4f}".format(stopper.best_score))
        if early_stop:
            break

    # Evaluate the checkpoint held by the stopper rather than the weights
    # left over from the final epoch.
    stopper.load_checkpoint(model)
    test_score = run_an_eval_epoch(args, model, test_loader)
    print("Best validation score {:.4f}".format(stopper.best_score))
    print("Test score {:.4f}".format(test_score))