def main()

in src/main.py [0:0]


def main(args: dict) -> None:
    """Train a GCN on protein graphs streamed from MongoDB and report the test score.

    Builds train/valid/test datasets from the ``proteins`` collection of the
    ``proteins`` database. Trains with Adam, using early stopping on the
    validation score. Restores the early stopper's checkpoint and prints the
    test-split score.

    Args:
        args: Run configuration. Keys read directly here (after ``setup``):
            ``db_username``, ``db_password`` (raw, *not* pre-encoded),
            ``db_host``, ``db_port``, ``knn``, ``batch_size``, ``device``,
            ``lr``, ``patience``, ``n_epochs``. ``setup`` and the train/eval
            helpers may read further keys — TODO confirm.

    Returns:
        None. Per-epoch progress and final scores are printed to stdout.
    """
    # Function-scope import keeps this block self-contained.
    from urllib.parse import quote_plus

    # NOTE(review): presumably seeds RNGs and normalises args — confirm in setup().
    args = setup(args, seed=config.seed)

    # Credentials must be percent-encoded. A raw "@", ":", "/" or "%" in the
    # username or password would otherwise corrupt the URI.
    # NOTE(review): the TLS / replicaSet / retryWrites=false options look like
    # Amazon DocumentDB settings, and the CA bundle path is relative, so it is
    # presumably resolved against the current working directory — confirm.
    uri = (
        "mongodb://{}:{}@{}:{}/?tls=true&tlsCAFile=rds-combined-ca-bundle.pem"
        "&replicaSet=rs0&readPreference=secondaryPreferred&retryWrites=false"
    ).format(
        quote_plus(str(args["db_username"])),
        quote_plus(str(args["db_password"])),
        args["db_host"],
        args["db_port"],
    )

    # One dataset per split. Each runs an aggregation pipeline that filters
    # documents to its split and projects the configured fields.
    train_set, valid_set, test_set = (
        ProteinDataset(
            [
                {"$match": match_by_split(split)},
                {"$project": config.projection},
            ],
            db_uri=uri,
            db_name="proteins",
            collection_name="proteins",
            k=args["knn"],
        )
        for split in ("train", "valid", "test")
    )

    # Only the training data is shuffled (through a bounded buffer) and given
    # worker processes. Valid/test use the DataLoader defaults.
    # NOTE(review): if ProteinDataset is an IterableDataset that does not shard
    # via torch.utils.data.get_worker_info(), num_workers > 1 would yield
    # duplicated training samples — confirm.
    train_loader = data.DataLoader(
        BufferedShuffleDataset(train_set, buffer_size=config.buffer_size),
        batch_size=args["batch_size"],
        collate_fn=collate_protein_graphs,
        num_workers=config.num_workers,
    )
    valid_loader = data.DataLoader(
        valid_set,
        batch_size=args["batch_size"],
        collate_fn=collate_protein_graphs,
    )
    test_loader = data.DataLoader(
        test_set,
        batch_size=args["batch_size"],
        collate_fn=collate_protein_graphs,
    )

    # Input feature width is one slot per entry of d1_to_index (presumably a
    # residue vocabulary — confirm).
    dim_nfeats = len(d1_to_index)
    model = GCN(dim_nfeats, config.h_feats, config.n_classes).to(args["device"])

    optimizer = torch.optim.Adam(model.parameters(), lr=args["lr"])
    stopper = EarlyStopper(args["patience"])

    n_epochs = args["n_epochs"]
    for epoch in range(n_epochs):
        run_a_train_epoch(args, epoch, model, train_loader, optimizer)

        # stopper.step() returns a truthy flag once training should stop.
        # It presumably also checkpoints the model on improvement (it is later
        # restored via load_checkpoint) — confirm.
        val_score = run_an_eval_epoch(args, model, valid_loader)
        early_stop = stopper.step(val_score, model)
        print(f"epoch {epoch + 1:d}/{n_epochs:d}, validation roc-auc {val_score:.4f}, ")
        print(f"best validation roc-auc {stopper.best_score:.4f}")
        if early_stop:
            break

    # Evaluate the test split with the weights restored by the early stopper.
    # NOTE(review): with n_epochs == 0 no checkpoint may exist and best_score
    # may be unset — confirm EarlyStopper handles that case.
    stopper.load_checkpoint(model)
    test_score = run_an_eval_epoch(args, model, test_loader)
    print(f"Best validation score {stopper.best_score:.4f}")
    print(f"Test score {test_score:.4f}")