in hugegraph-ml/src/hugegraph_ml/models/care_gnn.py [0:0]
def RLModule(self, graph, epoch, idx):
    """Adjust the per-relation threshold ``p`` of every layer.

    For each layer in ``self.layers`` and each edge type in ``self.edges``
    that has not been marked as converged (``layer.cvg[etype]``), this:

    1. averages ``layer.dist[etype]`` over the in-edges of ``idx``
       (formula 5);
    2. moves ``layer.p[etype]`` by ``self.step_size`` -- down when the
       average distance grew compared to ``layer.last_avg_dist[etype]``,
       up otherwise -- keeping it inside ``(0, 1]``, and appends the
       direction (``-1`` / ``+1``) to ``layer.f[etype]`` (formula 6);
    3. marks the relation as converged once ``epoch >= 9`` and the last
       ten recorded directions nearly cancel out (formula 7).

    Relations for which ``idx`` has no in-edges are skipped for this call:
    the mean of an empty tensor is NaN, which would otherwise be stored in
    ``last_avg_dist`` and make every later comparison evaluate to False.

    Args:
        graph: Graph object exposing
            ``in_edges(idx, form="eid", etype=...)``.
        epoch: Zero-based epoch counter; convergence is only tested from
            ``epoch >= 9`` on.
        idx: Node ids whose incoming edges are used for the average.

    Returns:
        None. The state kept on each layer (``p``, ``f``,
        ``last_avg_dist``, ``cvg``) is updated in place.
    """
    for layer in self.layers:
        for etype in self.edges:
            if layer.cvg[etype]:
                continue

            # formula 5
            eid = graph.in_edges(idx, form="eid", etype=etype)
            dists = layer.dist[etype][eid]
            if dists.numel() == 0:
                # No incoming edges of this type: th.mean would return NaN
                # and poison last_avg_dist, so leave the state untouched.
                continue
            avg_dist = th.mean(dists)

            # formula 6
            # NOTE(review): nesting reconstructed from a source without
            # indentation -- the direction is recorded even when the clamp
            # leaves p unchanged; confirm against the upstream file.
            if layer.last_avg_dist[etype] < avg_dist:
                # Distance grew: lower the threshold, but keep it > 0.
                if layer.p[etype] - self.step_size > 0:
                    layer.p[etype] -= self.step_size
                layer.f[etype].append(-1)
            else:
                # Distance did not grow: raise the threshold, capped at 1.
                if layer.p[etype] + self.step_size <= 1:
                    layer.p[etype] += self.step_size
                layer.f[etype].append(+1)
            layer.last_avg_dist[etype] = avg_dist

            # formula 7
            # Converged when the last ten moves (almost) cancel each other.
            if epoch >= 9 and abs(sum(layer.f[etype][-10:])) <= 2:
                layer.cvg[etype] = True