# main() — training entry point (from train.py)

def main(args):
    """
    Load data, train and evaluate a model, then save test scores to disk.

    Pipeline: seed -> load train/valid datasets -> build data loaders ->
    init model (optionally from pretrained weights) -> fit with early
    stopping + checkpointing -> reload best checkpoint -> evaluate on the
    test split -> dump scores as JSON into the trainer's log directory.

    Args:
        args: Parsed command line arguments. Must include: pytorch-lightning
            pre-defined args, task, node_h_dim_s, node_h_dim_v, edge_h_dim_s,
            edge_h_dim_v, pretrained_weights, ls, bs, early_stopping_patience,
            num_workers.

    Returns:
        None
    """
    # Fixed seed + deterministic trainer below for reproducible runs.
    pl.seed_everything(42, workers=True)
    # 1. Load data
    train_dataset = data_loaders.get_dataset(
        args.task, MODEL_TYPES[args.model_name], split="train"
    )
    valid_dataset = data_loaders.get_dataset(
        args.task, MODEL_TYPES[args.model_name], split="valid"
    )
    print("Data loaded:", len(train_dataset), len(valid_dataset))
    # 2. Prepare data loaders
    # Sequence models use plain tensors; graph models need PyG's batching,
    # so pick the matching DataLoader class.
    if MODEL_TYPES[args.model_name] == "seq":
        DataLoader = torch.utils.data.DataLoader
    else:
        DataLoader = torch_geometric.data.DataLoader

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.bs,
        shuffle=True,
        num_workers=args.num_workers,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=args.bs,
        shuffle=False,
        num_workers=args.num_workers,
    )
    # 3. Prepare model
    datum = None
    if MODEL_TYPES[args.model_name] != "seq":
        # Graph models infer feature dims from a sample datum.
        datum = train_dataset[0][0]
    dict_args = vars(args)
    model = init_model(
        datum=datum,
        num_outputs=train_dataset.num_outputs,
        weights=train_dataset.pos_weights,
        classify=IS_CLASSIFY[args.task],
        **dict_args
    )
    if args.pretrained_weights:
        # Load pretrained weights on CPU; Lightning moves the model to the
        # training device later.
        checkpoint = torch.load(
            args.pretrained_weights, map_location=torch.device("cpu")
        )
        load_state_dict_to_model(model, checkpoint["state_dict"])
    # 4. Training
    # callbacks
    early_stop_callback = EarlyStopping(
        monitor="val_loss", patience=args.early_stopping_patience
    )
    # Init ModelCheckpoint callback, monitoring 'val_loss'
    checkpoint_callback = ModelCheckpoint(monitor="val_loss")
    # init pl.Trainer
    trainer = pl.Trainer.from_argparse_args(
        args,
        deterministic=True,
        callbacks=[early_stop_callback, checkpoint_callback],
    )
    # train
    trainer.fit(model, train_loader, valid_loader)
    print("Training finished")
    print(
        "checkpoint_callback.best_model_path:",
        checkpoint_callback.best_model_path,
    )
    # 5. Evaluation
    # Reload the best (lowest val_loss) checkpoint. load_from_checkpoint is
    # a classmethod — call it on the class, not the instance (instance-level
    # calls are deprecated in Lightning).
    model = type(model).load_from_checkpoint(
        checkpoint_path=checkpoint_callback.best_model_path,
        weights=train_dataset.pos_weights,
    )
    print("Testing performance on test set")
    # load test data
    test_dataset = data_loaders.get_dataset(
        args.task, MODEL_TYPES[args.model_name], split="test"
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.bs,
        shuffle=False,
        num_workers=args.num_workers,
    )
    scores = evaluate(model, test_loader, args.task)
    # Save scores next to the training logs; `with` ensures the file handle
    # is flushed and closed (the original leaked an open file).
    with open(os.path.join(trainer.log_dir, "scores.json"), "w") as f:
        json.dump(scores, f)