def forward()

in shaDow/layers.py


    def forward(self, feats_in_l, idx_targets, sizes_subg):
        # feats_in_l:  list of per-layer node features, each of shape [#nodes, F]
        # idx_targets: indices of the target (center) nodes in the flattened batch
        # sizes_subg:  number of nodes in each sampled subgraph, shape [#subg]
        if self.prediction_task == 'link':
            # link prediction: two target nodes (source and destination) per subgraph pair
            assert idx_targets.shape[0] == 2 * sizes_subg.shape[0]
        else:
            assert idx_targets.shape[0] == sizes_subg.shape[0]
        if self.type_pool == 'center':
            # represent each subgraph by its target (center) node embedding only
            if self.type_res == 'none':
                feat_in = feats_in_l[-1][idx_targets]
                if self.prediction_task == 'node':
                    # plain center pooling for node tasks: return the raw target embedding
                    return feat_in
            else:       # regular JK: jumping-knowledge residue over all layers' target embeddings
                feats_root_l = [f[idx_targets] for f in feats_in_l]
                feat_in = self.f_residue(feats_root_l)
            feat_in = self.aggr_target_emb(feat_in)
        elif self.type_pool in ['max', 'mean', 'sum']:
            # first pool the subgraph at each layer, then apply the residue.
            # offsets = exclusive prefix sums of the subgraph sizes, i.e., the
            # position where each subgraph's nodes start in the flattened batch
            offsets = torch.cumsum(sizes_subg, dim=0)
            offsets = torch.roll(offsets, 1)
            offsets[0] = 0
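            # e.g., sizes_subg = [3, 2, 4] -> cumsum = [3, 5, 9]
            #       -> roll = [9, 3, 5] -> offsets = [0, 3, 5]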
            idx = torch.arange(feats_in_l[-1].shape[0], device=feats_in_l[-1].device)
            if self.type_res == 'none':
                # F.embedding_bag treats each subgraph as one "bag" and pools all of them in a single call
                feat_pool = F.embedding_bag(idx, feats_in_l[-1], offsets, mode=self.type_pool)
                feat_root = feats_in_l[-1][idx_targets]
            else:
                # pool each layer's features separately, then combine them via the residue
                feat_pool_l = []
                for feat in feats_in_l:
                    feat_pool = F.embedding_bag(idx, feat, offsets, mode=self.type_pool)
                    feat_pool_l.append(feat_pool)
                feat_pool = self.f_residue(feat_pool_l)
                feat_root = self.f_residue([f[idx_targets] for f in feats_in_l])
            # aggr_target_emb reduces the target embeddings to one row per subgraph
            # (for the link task it must collapse each source/destination pair),
            # so the concatenation with feat_pool aligns along dim 0
            feat_in = torch.cat([self.aggr_target_emb(feat_root), feat_pool], dim=1)
        elif self.type_pool == 'sort':
            # SortPool: sort the nodes within each subgraph and keep the top-k
            if self.type_res == 'none':
                feat_pool_in = feats_in_l[-1]
                feat_root = feats_in_l[-1][idx_targets]
            else:
                feat_pool_in = self.f_residue(feats_in_l)
                feat_root = self.f_residue([f[idx_targets] for f in feats_in_l])
            arange = torch.arange(sizes_subg.size(0), device=sizes_subg.device)
            # idx_batch maps each node to the index of the subgraph it belongs to
            idx_batch = torch.repeat_interleave(arange, sizes_subg)
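            # e.g., sizes_subg = [3, 2] -> idx_batch = [0, 0, 0, 1, 1]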
            feat_pool_k = global_sort_pool(feat_pool_in, idx_batch, self.k)     # #subg x (k * F)
            feat_pool = self.nn_pool(feat_pool_k)
            feat_in = torch.cat([self.aggr_target_emb(feat_root), feat_pool], dim=1)
        else:
            raise NotImplementedError
        return self.f_norm(self.nn(feat_in))
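
Note: the method presumably relies on module-level imports of torch, F (torch.nn.functional), and global_sort_pool (from torch_geometric.nn); they are not shown in this excerpt.

The 'max'/'mean'/'sum' branch leans on a trick worth spelling out: F.embedding_bag interprets offsets as the start position of each "bag", so the exclusive prefix sums of sizes_subg let a single fused call pool every subgraph at once, with no Python loop over subgraphs. Below is a minimal standalone sketch of that trick; the toy sizes and features are illustrative, not from the repo:

    import torch
    import torch.nn.functional as F

    # toy batch: 5 nodes from 2 subgraphs of sizes [3, 2], feature dim 4
    feat = torch.randn(5, 4)
    sizes_subg = torch.tensor([3, 2])

    # exclusive prefix sums mark where each subgraph starts in the flat batch
    offsets = torch.cumsum(sizes_subg, dim=0)   # [3, 5]
    offsets = torch.roll(offsets, 1)            # [5, 3]
    offsets[0] = 0                              # [0, 3]

    idx = torch.arange(feat.shape[0])
    feat_pool = F.embedding_bag(idx, feat, offsets, mode='mean')   # shape [2, 4]

    # sanity check against an explicit per-subgraph mean
    expected = torch.stack([feat[0:3].mean(0), feat[3:5].mean(0)])
    assert torch.allclose(feat_pool, expected)

The same offsets work unchanged for mode='max' and mode='sum', which is why all three pooling types share one code path in the method above.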