in shaDow/layers.py [0:0]
def forward(self, feats_in_l, idx_targets, sizes_subg):
    """Read out one embedding per subgraph from per-layer node features.

    Parameters
    ----------
    feats_in_l : list of torch.Tensor
        Node-feature tensors, one per layer; ``feats_in_l[-1]`` is the last
        layer. Rows are selected with ``idx_targets``. For the ``max`` /
        ``mean`` / ``sum`` and ``sort`` pools, the rows are assumed to be
        grouped subgraph by subgraph in the order of ``sizes_subg``
        (NOTE(review): this contiguity is implied by the offset/batch
        construction below -- confirm against the sampler).
    idx_targets : torch.Tensor
        Row indices of the target nodes: one per subgraph, or two per
        subgraph when ``self.prediction_task == 'link'``.
    sizes_subg : torch.Tensor
        Number of nodes in each subgraph (presumably a 1-D integer tensor).

    Returns
    -------
    torch.Tensor
        ``self.f_norm(self.nn(feat_in))``. Exception: with
        ``type_pool == 'center'``, ``type_res == 'none'`` and a ``'node'``
        task, the last-layer target rows are returned directly, skipping
        ``self.nn`` and ``self.f_norm``.

    Raises
    ------
    ValueError
        If the number of target indices does not match the number of
        subgraphs (twice that number for link prediction).
    NotImplementedError
        If ``self.type_pool`` is not ``center``, ``max``, ``mean``, ``sum``
        or ``sort``.
    """
    num_subg = sizes_subg.shape[0]
    targets_per_subg = 2 if self.prediction_task == 'link' else 1
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if idx_targets.shape[0] != targets_per_subg * num_subg:
        raise ValueError(
            f"expected {targets_per_subg * num_subg} target indices "
            f"({targets_per_subg} per subgraph for "
            f"prediction_task={self.prediction_task!r}), "
            f"got {idx_targets.shape[0]}"
        )

    if self.type_pool == 'center':
        if self.type_res == 'none':
            feat_in = feats_in_l[-1][idx_targets]
            if self.prediction_task == 'node':
                # Raw last-layer target embeddings; no readout MLP / norm.
                return feat_in
        else:  # regular JK
            feat_in = self.f_residue([f[idx_targets] for f in feats_in_l])
        feat_in = self.aggr_target_emb(feat_in)
    elif self.type_pool in ('max', 'mean', 'sum'):
        # first pool subgraph at each layer, then residue
        # Start row of every subgraph = exclusive prefix sum of the sizes
        # (equivalent to roll(cumsum) with a leading 0, but safe when empty).
        offsets = torch.cumsum(sizes_subg, dim=0) - sizes_subg
        feat_last = feats_in_l[-1]
        # Every row in order: embedding_bag then reduces each
        # [offsets[i], offsets[i + 1]) slice, i.e. one bag per subgraph.
        idx = torch.arange(feat_last.shape[0], device=feat_last.device)
        if self.type_res == 'none':
            feat_pool = F.embedding_bag(idx, feat_last, offsets, mode=self.type_pool)
            feat_root = feat_last[idx_targets]
        else:
            feat_pool = self.f_residue([
                F.embedding_bag(idx, feat, offsets, mode=self.type_pool)
                for feat in feats_in_l
            ])
            feat_root = self.f_residue([f[idx_targets] for f in feats_in_l])
        feat_in = torch.cat([self.aggr_target_emb(feat_root), feat_pool], dim=1)
    elif self.type_pool == 'sort':
        if self.type_res == 'none':
            feat_pool_in = feats_in_l[-1]
            feat_root = feats_in_l[-1][idx_targets]
        else:
            feat_pool_in = self.f_residue(feats_in_l)
            feat_root = self.f_residue([f[idx_targets] for f in feats_in_l])
        # Subgraph id of every node row: 0 repeated sizes[0] times, then 1, ...
        idx_batch = torch.repeat_interleave(sizes_subg)
        feat_pool_k = global_sort_pool(feat_pool_in, idx_batch, self.k)  # #subg x (k * F)
        feat_pool = self.nn_pool(feat_pool_k)
        feat_in = torch.cat([self.aggr_target_emb(feat_root), feat_pool], dim=1)
    else:
        raise NotImplementedError(f"unsupported type_pool: {self.type_pool!r}")
    return self.f_norm(self.nn(feat_in))