in seal_link_pred.py [0:0]
def train():
    """Run one training epoch over ``train_loader`` and return the average loss.

    Relies on module-level names defined elsewhere in this script: ``model``,
    ``optimizer``, ``train_loader``, ``train_dataset``, ``device``, ``args``
    (``use_feature`` / ``use_edge_weight`` flags) and ``emb``.

    Returns:
        The sum over batches of ``loss.item() * data.num_graphs`` divided by
        ``len(train_dataset)``. This is a per-graph average, provided that
        ``train_loader`` iterates exactly ``train_dataset`` once per epoch
        (NOTE(review): confirm against how ``train_loader`` is built).
    """
    model.train()
    # The criterion is built with default arguments only, so one instance can
    # be reused for every batch instead of being reconstructed inside the loop.
    criterion = BCEWithLogitsLoss()

    total_loss = 0
    pbar = tqdm(train_loader, ncols=70)
    for data in pbar:
        data = data.to(device)
        optimizer.zero_grad()
        # Optional model inputs are passed as None when their flag is disabled.
        x = data.x if args.use_feature else None
        edge_weight = data.edge_weight if args.use_edge_weight else None
        node_id = data.node_id if emb else None
        logits = model(data.z, data.edge_index, data.batch, x, edge_weight, node_id)
        # Flatten logits to line up with the labels; BCE expects float targets.
        loss = criterion(logits.view(-1), data.y.to(torch.float))
        loss.backward()
        optimizer.step()
        # With the default reduction the loss is a batch mean, so re-weight it
        # by the number of graphs in the batch before accumulating.
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(train_dataset)