def forward()

in hugegraph-ml/src/hugegraph_ml/models/bgrl.py [0:0]


    def forward(
        self,
        graph,
        feat,
        *,
        drop_edge_p_1=0.3,
        feat_mask_p_1=0.3,
        drop_edge_p_2=0.2,
        feat_mask_p_2=0.4,
    ):
        """Compute the symmetric BGRL bootstrap loss for one training step.

        Two augmented views of ``graph`` are built with
        ``get_graph_drop_transform`` (edge dropping + feature masking). A
        self-loop is then added to each view's nodes, and node features are
        read back from each view's ``ndata["feat"]``. Each view is passed
        through the online encoder and predictor. The result is compared, via
        cosine similarity, with the target encoder's embedding of the *other*
        view. The target branch runs under ``torch.no_grad()``.

        Args:
            graph: Input graph; presumably a DGL graph whose node features are
                stored in ``ndata["feat"]`` (TODO confirm against callers).
            feat: Unused. Kept for interface compatibility; features are taken
                from each augmented view's ``ndata["feat"]`` instead.
            drop_edge_p_1: Edge-drop probability for the first view
                (default 0.3, the previously hard-coded value).
            feat_mask_p_1: Feature-mask probability for the first view
                (default 0.3).
            drop_edge_p_2: Edge-drop probability for the second view
                (default 0.2).
            feat_mask_p_2: Feature-mask probability for the second view
                (default 0.4).

        Returns:
            A scalar tensor ``2 - mean(cos(q1, t1)) - mean(cos(q2, t2))``.
            Here ``q1``/``q2`` are the online predictions for view 1/view 2,
            and ``t1``/``t2`` are the target embeddings of view 2/view 1.
        """
        transform_1 = get_graph_drop_transform(
            drop_edge_p=drop_edge_p_1, feat_mask_p=feat_mask_p_1
        )
        transform_2 = get_graph_drop_transform(
            drop_edge_p=drop_edge_p_2, feat_mask_p=feat_mask_p_2
        )
        # Two independently augmented views of the same graph; each view is
        # fed to BOTH the online and the target network below.
        view_1 = dgl.add_self_loop(transform_1(graph))
        view_2 = dgl.add_self_loop(transform_2(graph))
        feat_1, feat_2 = view_1.ndata["feat"], view_2.ndata["feat"]

        # Call order below matches the original so any stochastic layers
        # inside the encoders/predictor consume RNG in the same sequence.
        # Pass 1: online prediction on view 1 vs. target embedding of view 2.
        online_q1 = self.predictor(self.online_encoder(view_1, feat_1))
        # no_grad already produces tensors with requires_grad=False, so an
        # explicit .detach() on the target outputs is unnecessary.
        with torch.no_grad():
            target_y1 = self.target_encoder(view_2, feat_2)
        # Pass 2: online prediction on view 2 vs. target embedding of view 1.
        online_q2 = self.predictor(self.online_encoder(view_2, feat_2))
        with torch.no_grad():
            target_y2 = self.target_encoder(view_1, feat_1)

        loss = (
            2
            - cosine_similarity(online_q1, target_y1, dim=-1).mean()  # pylint: disable=E1102
            - cosine_similarity(online_q2, target_y2, dim=-1).mean()  # pylint: disable=E1102
        )
        return loss