in hugegraph-ml/src/hugegraph_ml/models/bgrl.py [0:0]
def forward(  # pylint: disable=too-many-arguments
    self,
    graph,
    feat,
    *,
    drop_edge_p_1=0.3,
    feat_mask_p_1=0.3,
    drop_edge_p_2=0.2,
    feat_mask_p_2=0.4,
):
    """Compute the symmetric BGRL loss for one pair of augmented views.

    Two stochastic views of ``graph`` are built with
    ``get_graph_drop_transform`` and a self-loop is added to each view.
    Every view is encoded by the online encoder followed by the predictor.
    Every view is also encoded by the target encoder, under
    ``torch.no_grad()``. The online prediction of one view is then compared
    (cosine similarity) with the target embedding of the *other* view.

    Args:
        graph: Input graph. Node features are read from ``ndata["feat"]``
            of the augmented views, not from ``feat``.
        feat: Unused by this method. It is kept so existing
            ``model(graph, feat)`` calls keep working.
        drop_edge_p_1: ``drop_edge_p`` passed to the first view's transform.
        feat_mask_p_1: ``feat_mask_p`` passed to the first view's transform.
        drop_edge_p_2: ``drop_edge_p`` passed to the second view's transform.
        feat_mask_p_2: ``feat_mask_p`` passed to the second view's transform.

    Returns:
        ``2 - mean(cos(q1, y2)) - mean(cos(q2, y1))``, where ``q_i`` is the
        online prediction for view ``i`` and ``y_j`` is the target-encoder
        embedding of view ``j``.
    """
    transform_1 = get_graph_drop_transform(
        drop_edge_p=drop_edge_p_1, feat_mask_p=feat_mask_p_1
    )
    transform_2 = get_graph_drop_transform(
        drop_edge_p=drop_edge_p_2, feat_mask_p=feat_mask_p_2
    )
    view_1 = transform_1(graph)
    view_2 = transform_2(graph)
    # NOTE(review): self-loops are presumably added so that every node keeps at
    # least one incoming edge after edge dropping (some DGL conv layers reject
    # zero in-degree nodes). Confirm this against the encoder in use.
    view_1, view_2 = dgl.add_self_loop(view_1), dgl.add_self_loop(view_2)
    feats_1, feats_2 = view_1.ndata["feat"], view_2.ndata["feat"]

    # Calls are kept in the original order so any stochastic layers
    # (e.g. dropout) consume the RNG in the same sequence as before.
    # Online branch on view 1.
    online_q1 = self.predictor(self.online_encoder(view_1, feats_1))
    # Target branch on view 2. The output is produced under no_grad, so it
    # already carries no gradient and needs no extra .detach().
    with torch.no_grad():
        target_y2 = self.target_encoder(view_2, feats_2)
    # Online branch on view 2.
    online_q2 = self.predictor(self.online_encoder(view_2, feats_2))
    # Target branch on view 1.
    with torch.no_grad():
        target_y1 = self.target_encoder(view_1, feats_1)

    # Cross-view pairing: the prediction of view 1 is matched to the target of
    # view 2, and vice versa.
    loss = (
        2
        - cosine_similarity(online_q1, target_y2, dim=-1).mean()  # pylint: disable=E1102
        - cosine_similarity(online_q2, target_y1, dim=-1).mean()  # pylint: disable=E1102
    )
    return loss