in train.py [0:0]
def main(args):
    """
    Load data, train and evaluate the model, and save the evaluation scores.
    All configuration comes from the args object.

    Args:
        args: Parsed command-line arguments. Must include the PyTorch Lightning
            pre-defined Trainer args plus: task, model_name, node_h_dim_s,
            node_h_dim_v, edge_h_dim_s, edge_h_dim_v, pretrained_weights, ls,
            bs, early_stopping_patience, num_workers.

    Returns:
        None
    """
    pl.seed_everything(42, workers=True)
    # 1. Load data
    train_dataset = data_loaders.get_dataset(
        args.task, MODEL_TYPES[args.model_name], split="train"
    )
    valid_dataset = data_loaders.get_dataset(
        args.task, MODEL_TYPES[args.model_name], split="valid"
    )
    print("Data loaded:", len(train_dataset), len(valid_dataset))
    # 2. Prepare data loaders
    if MODEL_TYPES[args.model_name] == "seq":
        DataLoader = torch.utils.data.DataLoader
    else:
        DataLoader = torch_geometric.data.DataLoader
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.bs,
        shuffle=True,
        num_workers=args.num_workers,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=args.bs,
        shuffle=False,
        num_workers=args.num_workers,
    )
    # 3. Prepare model
    datum = None
    if MODEL_TYPES[args.model_name] != "seq":
        # use a sample datum to infer node/edge feature dims for graph models
        datum = train_dataset[0][0]
    dict_args = vars(args)
    model = init_model(
        datum=datum,
        num_outputs=train_dataset.num_outputs,
        weights=train_dataset.pos_weights,
        classify=IS_CLASSIFY[args.task],
        **dict_args
    )
    if args.pretrained_weights:
        # load pretrained weights
        checkpoint = torch.load(
            args.pretrained_weights, map_location=torch.device("cpu")
        )
        load_state_dict_to_model(model, checkpoint["state_dict"])
    # 4. Training
    # callbacks
    early_stop_callback = EarlyStopping(
        monitor="val_loss", patience=args.early_stopping_patience
    )
    # init ModelCheckpoint callback, monitoring 'val_loss'
    checkpoint_callback = ModelCheckpoint(monitor="val_loss")
    # init pl.Trainer
    trainer = pl.Trainer.from_argparse_args(
        args,
        deterministic=True,
        callbacks=[early_stop_callback, checkpoint_callback],
    )
    # train
    trainer.fit(model, train_loader, valid_loader)
    print("Training finished")
    print(
        "checkpoint_callback.best_model_path:",
        checkpoint_callback.best_model_path,
    )
    # 5. Evaluation
    # load the best model (lowest val_loss) saved by the checkpoint callback
    model = model.load_from_checkpoint(
        checkpoint_path=checkpoint_callback.best_model_path,
        weights=train_dataset.pos_weights,
    )
    print("Testing performance on test set")
    # load test data
    test_dataset = data_loaders.get_dataset(
        args.task, MODEL_TYPES[args.model_name], split="test"
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.bs,
        shuffle=False,
        num_workers=args.num_workers,
    )
    scores = evaluate(model, test_loader, args.task)
    # save scores to file
    with open(os.path.join(trainer.log_dir, "scores.json"), "w") as f:
        json.dump(scores, f)
    return None
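

# --- Illustrative only: a minimal sketch of a CLI entry point that drives main() ---
# Not part of the original excerpt. It assumes an argparse parser combined with the
# Lightning `pl.Trainer.add_argparse_args` / `from_argparse_args` pairing used above;
# the flag names follow main()'s docstring and attribute usage, the defaults are
# placeholders, and train.py's real parser may define additional options (e.g. model
# hyperparameters forwarded to init_model via **dict_args).
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train and evaluate a model")
    # task / model selection and data loading options used directly in main()
    parser.add_argument("--task", type=str, required=True)
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--bs", type=int, default=32, help="batch size")
    parser.add_argument("--num_workers", type=int, default=4)
    # optional warm start and early stopping
    parser.add_argument("--pretrained_weights", type=str, default=None)
    parser.add_argument("--early_stopping_patience", type=int, default=10)
    # hidden-dimension args listed in the docstring (placeholder defaults)
    parser.add_argument("--node_h_dim_s", type=int, default=100)
    parser.add_argument("--node_h_dim_v", type=int, default=16)
    parser.add_argument("--edge_h_dim_s", type=int, default=32)
    parser.add_argument("--edge_h_dim_v", type=int, default=1)
    # append the PyTorch Lightning Trainer flags (e.g. --max_epochs)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    main(args)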