# graphlearn_torch/python/sampler/neighbor_sampler.py
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import math
import threading
from typing import Dict, Literal, Optional, Union

import torch
from .. import py_graphlearn_torch as pywrap
from ..data import Graph
from ..typing import NodeType, EdgeType, NumNeighbors, reverse_edge_type
from ..utils import (
merge_dict, merge_hetero_sampler_output, format_hetero_sampler_output,
id2idx, count_dict
)
from .base import (
BaseSampler, EdgeIndex,
NodeSamplerInput, EdgeSamplerInput,
SamplerOutput, HeteroSamplerOutput, NeighborOutput,
)
from .negative_sampler import RandomNegativeSampler
class NeighborSampler(BaseSampler):
r""" Neighbor Sampler.
"""
def __init__(self,
graph: Union[Graph, Dict[EdgeType, Graph]],
num_neighbors: Optional[NumNeighbors] = None,
               device: torch.device = torch.device('cuda:0'),
               with_edge: bool = False,
               with_neg: bool = False,
               with_weight: bool = False,
               strategy: str = 'random',
               edge_dir: Literal['in', 'out'] = 'out',
               seed: Optional[int] = None):
self.graph = graph
self.num_neighbors = num_neighbors
self.device = device
self.with_edge = with_edge
self.with_neg = with_neg
self.with_weight = with_weight
self.strategy = strategy
self.edge_dir = edge_dir
self._subgraph_op = None
self._sampler = None
self._neg_sampler = None
self._inducer = None
self._sampler_lock = threading.Lock()
self.is_sampler_initialized = False
self.is_neg_sampler_initialized = False
if seed is not None:
pywrap.RandomSeedManager.getInstance().setSeed(seed)
if isinstance(self.graph, Graph): #homo
self._g_cls = 'homo'
if self.graph.mode == 'CPU':
self.device = torch.device('cpu')
else: # hetero
self._g_cls = 'hetero'
self.edge_types = []
self.node_types = set()
for etype, graph in self.graph.items():
self.edge_types.append(etype)
self.node_types.add(etype[0])
self.node_types.add(etype[2])
if self.graph[self.edge_types[0]].mode == 'CPU':
self.device = torch.device('cpu')
self._set_num_neighbors_and_num_hops(self.num_neighbors)
@property
def subgraph_op(self):
self.lazy_init_subgraph_op()
return self._subgraph_op
def lazy_init_sampler(self):
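    # Lazily construct the underlying pywrap sampler(s) on first use, using a
    # double-checked lock so that concurrent callers initialize them only once.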
if not self.is_sampler_initialized:
with self._sampler_lock:
if self._sampler is None:
if self._g_cls == 'homo':
if self.device.type == 'cuda':
self._sampler = pywrap.CUDARandomSampler(self.graph.graph_handler)
            elif not self.with_weight:
self._sampler = pywrap.CPURandomSampler(self.graph.graph_handler)
else:
self._sampler = pywrap.CPUWeightedSampler(self.graph.graph_handler)
self.is_sampler_initialized = True
else: # hetero
self._sampler = {}
for etype, g in self.graph.items():
              if self.device.type == 'cuda':
                self._sampler[etype] = pywrap.CUDARandomSampler(g.graph_handler)
              elif not self.with_weight:
self._sampler[etype] = pywrap.CPURandomSampler(g.graph_handler)
else:
self._sampler[etype] = pywrap.CPUWeightedSampler(g.graph_handler)
self.is_sampler_initialized = True
def lazy_init_neg_sampler(self):
if not self.is_neg_sampler_initialized and self.with_neg:
with self._sampler_lock:
if self._neg_sampler is None:
if self._g_cls == 'homo':
self._neg_sampler = RandomNegativeSampler(
graph=self.graph,
mode=self.device.type.upper(),
edge_dir=self.edge_dir
)
self.is_neg_sampler_initialized = True
else: # hetero
self._neg_sampler = {}
for etype, g in self.graph.items():
self._neg_sampler[etype] = RandomNegativeSampler(
graph=g,
mode=self.device.type.upper(),
edge_dir=self.edge_dir
)
self.is_neg_sampler_initialized = True
def lazy_init_subgraph_op(self):
if self._subgraph_op is None:
with self._sampler_lock:
if self._subgraph_op is None:
if self.device.type == 'cuda':
self._subgraph_op = pywrap.CUDASubGraphOp(self.graph.graph_handler)
else:
self._subgraph_op = pywrap.CPUSubGraphOp(self.graph.graph_handler)
def sample_one_hop(
self,
input_seeds: torch.Tensor,
req_num: int,
etype: EdgeType = None
) -> NeighborOutput:
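    r"""Sample up to ``req_num`` neighbors for every node in ``input_seeds``
    (for edge type ``etype`` on heterogeneous graphs) and return a
    :class:`NeighborOutput` holding the neighbor ids, the per-seed neighbor
    counts and, if ``with_edge`` is set, the sampled edge ids.
    """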
self.lazy_init_sampler()
sampler = self._sampler[etype] if etype is not None else self._sampler
input_seeds = input_seeds.to(self.device)
edge_ids = None
if not self.with_edge:
nbrs, nbrs_num = sampler.sample(input_seeds, req_num)
else:
nbrs, nbrs_num, edge_ids = sampler.sample_with_edge(input_seeds, req_num)
if nbrs.numel() == 0:
      nbrs = torch.tensor([], dtype=torch.int64, device=self.device)
nbrs_num = torch.zeros_like(input_seeds, dtype=torch.int64, device=self.device)
edge_ids = torch.tensor([], device=self.device, dtype=torch.int64) \
if self.with_edge else None
return NeighborOutput(nbrs, nbrs_num, edge_ids)
def sample_from_nodes(
self,
inputs: NodeSamplerInput,
**kwargs
) -> Union[HeteroSamplerOutput, SamplerOutput]:
inputs = NodeSamplerInput.cast(inputs)
input_seeds = inputs.node.to(self.device)
input_type = inputs.input_type
if self._g_cls == 'hetero':
assert input_type is not None
output = self._hetero_sample_from_nodes({input_type: input_seeds})
else:
output = self._sample_from_nodes(input_seeds)
return output
def _sample_from_nodes(
self,
input_seeds: torch.Tensor
) -> SamplerOutput:
r""" Sample on homogenous graphs and induce COO format subgraph.
Note that messages in PyG are passed from src to dst. In 'out' direction,
we sample src's out neighbors and induce [src_index, dst_index] subgraphs.
The direction of sampling is opposite to the direction of message passing.
To be consistent with the semantics of PyG, the final edge index is
transpose to [dst_index, src_index]. In 'in' direction, we don't need to
reverse it.
"""
out_nodes, out_rows, out_cols, out_edges = [], [], [], []
num_sampled_nodes, num_sampled_edges = [], []
inducer = self.get_inducer(input_seeds.numel())
srcs = inducer.init_node(input_seeds)
batch = srcs
num_sampled_nodes.append(input_seeds.numel())
out_nodes.append(srcs)
for req_num in self.num_neighbors:
out_nbrs = self.sample_one_hop(srcs, req_num)
if out_nbrs.nbr.numel() == 0:
break
nodes, rows, cols = inducer.induce_next(
srcs, out_nbrs.nbr, out_nbrs.nbr_num)
out_nodes.append(nodes)
out_rows.append(rows)
out_cols.append(cols)
if out_nbrs.edge is not None:
out_edges.append(out_nbrs.edge)
num_sampled_nodes.append(nodes.size(0))
num_sampled_edges.append(cols.size(0))
srcs = nodes
return SamplerOutput(
node=torch.cat(out_nodes),
      row=(torch.cat(out_cols) if len(out_cols) > 0
           else torch.empty(0, dtype=torch.int64, device=self.device)),
      col=(torch.cat(out_rows) if len(out_rows) > 0
           else torch.empty(0, dtype=torch.int64, device=self.device)),
edge=(torch.cat(out_edges) if out_edges else None),
batch=batch,
num_sampled_nodes=num_sampled_nodes,
num_sampled_edges=num_sampled_edges,
device=self.device
)
def _hetero_sample_from_nodes(
self,
input_seeds_dict: Dict[NodeType, torch.Tensor],
) -> HeteroSamplerOutput:
r""" Sample on heterogenous graphs and induce COO format subgraph dict.
Note that messages in PyG are passed from src to dst. In 'out' direction,
we sample src's out neighbors and induce [src_index, dst_index] subgraphs.
The direction of sampling is opposite to the direction of message passing.
To be consistent with the semantics of PyG, the final edge index is transpose to
[dst_index, src_index] and edge_type is reversed as well. For example,
given the edge_type (u, u2i, i), we sample by meta-path u->i, but return
edge_index_dict {(i, rev_u2i, u) : [i, u]}. In 'in' direction, we don't need to
reverse it.
"""
# sample neighbors hop by hop.
max_input_batch_size = max([t.numel() for t in input_seeds_dict.values()])
inducer = self.get_inducer(max_input_batch_size)
src_dict = inducer.init_node(input_seeds_dict)
batch = src_dict
out_nodes, out_rows, out_cols, out_edges = {}, {}, {}, {}
num_sampled_nodes, num_sampled_edges = {}, {}
merge_dict(src_dict, out_nodes)
count_dict(src_dict, num_sampled_nodes, 1)
for i in range(self.num_hops):
nbr_dict, edge_dict = {}, {}
for etype in self.edge_types:
req_num = self.num_neighbors[etype][i]
        # 'out' sampling seeds from the source node type (etype[0]), while
        # 'in' sampling seeds from the destination node type (etype[-1]).
if self.edge_dir == 'in':
src = src_dict.get(etype[-1], None)
if src is not None and src.numel() > 0:
output = self.sample_one_hop(src, req_num, etype)
if output.nbr.numel() == 0:
continue
nbr_dict[reverse_edge_type(etype)] = [src, output.nbr, output.nbr_num]
if output.edge is not None:
edge_dict[reverse_edge_type(etype)] = output.edge
elif self.edge_dir == 'out':
src = src_dict.get(etype[0], None)
if src is not None and src.numel() > 0:
output = self.sample_one_hop(src, req_num, etype)
if output.nbr.numel() == 0:
continue
nbr_dict[etype] = [src, output.nbr, output.nbr_num]
if output.edge is not None:
edge_dict[etype] = output.edge
if len(nbr_dict) == 0:
continue
nodes_dict, rows_dict, cols_dict = inducer.induce_next(nbr_dict)
merge_dict(nodes_dict, out_nodes)
merge_dict(rows_dict, out_rows)
merge_dict(cols_dict, out_cols)
merge_dict(edge_dict, out_edges)
count_dict(nodes_dict, num_sampled_nodes, i + 2)
count_dict(cols_dict, num_sampled_edges, i + 1)
src_dict = nodes_dict
for etype, rows in out_rows.items():
out_rows[etype] = torch.cat(rows)
out_cols[etype] = torch.cat(out_cols[etype])
if self.with_edge:
out_edges[etype] = torch.cat(out_edges[etype])
res_rows, res_cols, res_edges = {}, {}, {}
for etype, rows in out_rows.items():
rev_etype = reverse_edge_type(etype)
res_rows[rev_etype] = out_cols[etype]
res_cols[rev_etype] = rows
if self.with_edge:
res_edges[rev_etype] = out_edges[etype]
return HeteroSamplerOutput(
node={k : torch.cat(v) for k, v in out_nodes.items()},
row=res_rows,
col=res_cols,
edge=(res_edges if len(res_edges) else None),
batch=batch,
num_sampled_nodes={k : torch.tensor(v, device=self.device)
for k, v in num_sampled_nodes.items()},
num_sampled_edges={
reverse_edge_type(k) : torch.tensor(v, device=self.device)
for k, v in num_sampled_edges.items()},
edge_types=self.edge_types,
device=self.device
)
def sample_from_edges(
self,
inputs: EdgeSamplerInput,
**kwargs,
) -> Union[HeteroSamplerOutput, SamplerOutput]:
r"""Performs sampling from an edge sampler input, leveraging a sampling
function of the same signature as `node_sample`.
Note that in out-edge sampling, we reverse the direction of src and dst
for the output so that features of the sampled nodes during training can
be aggregated from k-hop to (k-1)-hop nodes.
"""
src = inputs.row.to(self.device)
dst = inputs.col.to(self.device)
edge_label = None if inputs.label is None else inputs.label.to(self.device)
input_type = inputs.input_type
neg_sampling = inputs.neg_sampling
num_pos = src.numel()
num_neg = 0
# Negative Sampling
self.lazy_init_neg_sampler()
if neg_sampling is not None:
# When we are doing negative sampling, we append negative information
# of nodes/edges to `src`, `dst`.
# Later on, we can easily reconstruct what belongs to positive and
# negative examples by slicing via `num_pos`.
num_neg = math.ceil(num_pos * neg_sampling.amount)
if neg_sampling.is_binary():
# In the "binary" case, we randomly sample negative pairs of nodes.
if input_type is not None:
neg_pair = self._neg_sampler[input_type].sample(num_neg)
else:
neg_pair = self._neg_sampler.sample(num_neg)
src_neg, dst_neg = neg_pair[0], neg_pair[1]
src = torch.cat([src, src_neg], dim=0)
dst = torch.cat([dst, dst_neg], dim=0)
if edge_label is None:
edge_label = torch.ones(num_pos, device=self.device)
size = (src_neg.size()[0], ) + edge_label.size()[1:]
edge_neg_label = edge_label.new_zeros(size)
edge_label = torch.cat([edge_label, edge_neg_label])
elif neg_sampling.is_triplet():
# TODO: make triplet negative sampling strict.
# In the "triplet" case, we randomly sample negative destinations
# in a "non-strict" manner.
assert num_neg % num_pos == 0
if input_type is not None:
neg_pair = self._neg_sampler[input_type].sample(num_neg, padding=True)
else:
neg_pair = self._neg_sampler.sample(num_neg, padding=True)
dst_neg = neg_pair[1]
dst = torch.cat([dst, dst_neg], dim=0)
assert edge_label is None
# Neighbor Sampling
if input_type is not None: # hetero
if input_type[0] != input_type[-1]: # Two distinct node types:
src_seed, dst_seed = src, dst
src, inverse_src = src.unique(return_inverse=True)
dst, inverse_dst = dst.unique(return_inverse=True)
seed_dict = {input_type[0]: src, input_type[-1]: dst}
else: # Only a single node type: Merge both source and destination.
seed = torch.cat([src, dst], dim=0)
seed, inverse_seed = seed.unique(return_inverse=True)
seed_dict = {input_type[0]: seed}
temp_out = []
for it, node in seed_dict.items():
seeds = NodeSamplerInput(node=node, input_type=it)
temp_out.append(self.sample_from_nodes(seeds))
if len(temp_out) == 2:
out = merge_hetero_sampler_output(temp_out[0],
temp_out[1],
device=self.device,
edge_dir=self.edge_dir)
else:
out = format_hetero_sampler_output(temp_out[0], edge_dir=self.edge_dir)
# edge_label
if neg_sampling is None or neg_sampling.is_binary():
if input_type[0] != input_type[-1]:
inverse_src = id2idx(out.node[input_type[0]])[src_seed]
inverse_dst = id2idx(out.node[input_type[-1]])[dst_seed]
edge_label_index = torch.stack([
inverse_src,
inverse_dst,
], dim=0)
else:
edge_label_index = inverse_seed.view(2, -1)
out.metadata = {'edge_label_index': edge_label_index,
'edge_label': edge_label}
out.input_type = input_type
elif neg_sampling.is_triplet():
if input_type[0] != input_type[-1]:
inverse_src = id2idx(out.node[input_type[0]])[src_seed]
inverse_dst = id2idx(out.node[input_type[-1]])[dst_seed]
src_index = inverse_src
dst_pos_index = inverse_dst[:num_pos]
dst_neg_index = inverse_dst[num_pos:]
else:
src_index = inverse_seed[:num_pos]
dst_pos_index = inverse_seed[num_pos:2 * num_pos]
dst_neg_index = inverse_seed[2 * num_pos:]
dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)
out.metadata = {'src_index': src_index,
'dst_pos_index': dst_pos_index,
'dst_neg_index': dst_neg_index}
out.input_type = input_type
else: #homo
seed = torch.cat([src, dst], dim=0)
seed, inverse_seed = seed.unique(return_inverse=True)
out = self.sample_from_nodes(seed)
# edge_label
if neg_sampling is None or neg_sampling.is_binary():
edge_label_index = inverse_seed.view(2, -1)
out.metadata = {'edge_label_index': edge_label_index,
'edge_label': edge_label}
elif neg_sampling.is_triplet():
src_index = inverse_seed[:num_pos]
dst_pos_index = inverse_seed[num_pos:2 * num_pos]
dst_neg_index = inverse_seed[2 * num_pos:]
dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)
out.metadata = {'src_index': src_index,
'dst_pos_index': dst_pos_index,
'dst_neg_index': dst_neg_index}
return out
def sample_pyg_v1(self, ids: torch.Tensor):
r""" Sample multi-hop neighbors and organize results to PyG's `EdgeIndex`.
Args:
ids: input ids, 1D tensor.
The sampled results that is the same as PyG's `NeighborSampler`(PyG v1)
"""
ids = ids.to(self.device)
adjs = []
srcs = ids
out_ids = ids
batch_size = 0
inducer = self.get_inducer(srcs.numel())
for i, req_num in enumerate(self.num_neighbors):
srcs = inducer.init_node(srcs)
batch_size = srcs.numel() if i == 0 else batch_size
out_nbrs = self.sample_one_hop(srcs, req_num)
nodes, rows, cols = \
inducer.induce_next(srcs, out_nbrs.nbr, out_nbrs.nbr_num)
edge_index = torch.stack([cols, rows]) # we use csr instead of csc in PyG.
out_ids = torch.cat([srcs, nodes])
adj_size = torch.LongTensor([out_ids.size(0), srcs.size(0)])
adjs.append(EdgeIndex(edge_index, out_nbrs.edge, adj_size))
srcs = out_ids
return batch_size, out_ids, adjs[::-1]
def subgraph(
self,
inputs: NodeSamplerInput,
) -> SamplerOutput:
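    r"""Induce a subgraph on the input nodes (optionally expanded by
    multi-hop neighbor sampling when ``num_neighbors`` is set) and return it
    as a :class:`SamplerOutput`, whose ``metadata`` holds the positions of the
    input seeds within the sampled node set.
    """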
self.lazy_init_subgraph_op()
inputs = NodeSamplerInput.cast(inputs)
input_seeds = inputs.node.to(self.device)
if self.num_neighbors is not None:
nodes = [input_seeds]
for num in self.num_neighbors:
nbr = self.sample_one_hop(nodes[-1], num).nbr
nodes.append(torch.unique(nbr))
nodes, mapping = torch.cat(nodes).unique(return_inverse=True)
else:
nodes, mapping = torch.unique(input_seeds, return_inverse=True)
subgraph = self._subgraph_op.node_subgraph(nodes, self.with_edge)
return SamplerOutput(
node=subgraph.nodes,
# The edge index should be reversed.
row=subgraph.cols,
col=subgraph.rows,
edge=subgraph.eids if self.with_edge else None,
device=self.device,
metadata=mapping[:input_seeds.numel()])
def sample_prob(
self,
inputs: NodeSamplerInput,
    node_cnt: Union[int, Dict[NodeType, torch.Tensor]]
) -> Union[torch.Tensor, Dict[NodeType, torch.Tensor]]:
r""" Get the probability of each node being sampled.
"""
self.lazy_init_sampler()
inputs = NodeSamplerInput.cast(inputs)
input_seeds = inputs.node.to(self.device)
input_type = inputs.input_type
if self._g_cls == 'hetero':
assert input_type is not None
output = self._hetero_sample_prob({input_type : input_seeds}, node_cnt)
else:
output = self._sample_prob(input_seeds, node_cnt)
return output
def _sample_prob(
self,
input_seeds: torch.Tensor,
node_cnt: int
) -> torch.Tensor:
last_prob = \
torch.ones(node_cnt, device=self.device, dtype=torch.float32) * 0.01
last_prob[input_seeds] = 1
for req in self.num_neighbors:
cur_prob = torch.zeros(node_cnt, device=self.device, dtype=torch.float32)
self._sampler.cal_nbr_prob(
req, last_prob, last_prob, self.graph.graph_handler, cur_prob
)
last_prob = cur_prob
return last_prob
def _hetero_sample_prob(
self,
input_seeds_dict: Dict[NodeType, torch.Tensor],
    node_dict: Dict[NodeType, torch.Tensor]
) -> Dict[NodeType, torch.Tensor]:
probs = {}
for ntype in node_dict.keys():
probs[ntype] = []
# calculate probs for each subgraph
for i in range(self.num_hops):
for etype in self.edge_types:
req = self.num_neighbors[etype][i]
        # homogeneous subgraph case
if etype[0] == etype[2]:
if len(probs[etype[0]]) == 0:
last_prob = torch.ones(node_dict[etype[0]].size(0),
device=self.device,
dtype=torch.float32) * 0.005
last_prob[input_seeds_dict[etype[0]]] = 1
else:
last_prob = self.aggregate_prob(probs[etype[0]],
node_dict[etype[0]].size(0),
device=self.device)
cur_prob = torch.zeros(node_dict[etype[0]].size(0),
device=self.device,
dtype=torch.float32)
self._sampler[etype].cal_nbr_prob(
req, last_prob, last_prob,
            self.graph[etype].graph_handler, cur_prob
)
last_prob = cur_prob
probs[etype[0]].append(last_prob)
# hetero bipartite graph case
else:
if len(probs[etype[0]]) == 0:
last_prob = torch.ones(node_dict[etype[0]].size(0),
device=self.device,
dtype=torch.float32) * 0.005
last_prob[input_seeds_dict[etype[0]]] = 1
else:
last_prob = self.aggregate_prob(probs[etype[0]],
node_dict[etype[0]].size(0),
device=self.device)
etypes = [nbr_etype
for nbr_etype in self.edge_types
if nbr_etype[0] == etype[2]]
temp_probs = []
# prepare nbr_prob
if len(probs[etype[2]]) == 0:
nbr_prob = torch.ones(node_dict[etype[2]].size(0),
device=self.device,
dtype=torch.float32) * 0.005
if etype[2] in input_seeds_dict:
nbr_prob[input_seeds_dict[etype[2]]] = 1
else:
nbr_prob = self.aggregate_prob(probs[etype[2]],
node_dict[etype[2]].size(0),
device=self.device)
for nbr_etype in etypes:
cur_prob = torch.zeros(node_dict[etype[0]].size(0),
device=self.device,
dtype=torch.float32)
self._sampler[etype].cal_nbr_prob(
req, last_prob, nbr_prob,
              self.graph[nbr_etype].graph_handler, cur_prob
)
last_prob = cur_prob
temp_probs.append(last_prob)
# aggregate prob for the bipartite graph
# with #{subgraphs where the neighbours are}
sub_temp_prob = self.aggregate_prob(temp_probs,
node_dict[etype[0]].size(0),
device=self.device)
probs[etype[0]].append(sub_temp_prob)
# aggregate probs from each subgraph
# with #{subgraphs}
for ntype, prob in probs.items():
res = self.aggregate_prob(
prob, node_dict[ntype].size(0), device=self.device)
if i == self.num_hops - 1:
probs[ntype] = res
else:
probs[ntype] = [res]
return probs
def get_inducer(self, input_batch_size: int):
if self._inducer is None:
self._inducer = self.create_inducer(input_batch_size)
return self._inducer
def create_inducer(self, input_batch_size: int):
max_num_nodes = self._max_sampled_nodes(input_batch_size)
if self.device.type == 'cuda':
if self._g_cls == 'homo':
inducer = pywrap.CUDAInducer(max_num_nodes)
else:
inducer = pywrap.CUDAHeteroInducer(max_num_nodes)
else:
if self._g_cls == 'homo':
inducer = pywrap.CPUInducer(max_num_nodes)
else:
inducer = pywrap.CPUHeteroInducer(max_num_nodes)
return inducer
def _set_num_neighbors_and_num_hops(self, num_neighbors):
if isinstance(num_neighbors, (list, tuple)):
num_neighbors = {key: num_neighbors for key in self.edge_types}
assert isinstance(num_neighbors, dict)
self.num_neighbors = num_neighbors
# Add at least one element to the list to ensure `max` is well-defined
self.num_hops = max([0] + [len(v) for v in num_neighbors.values()])
for key, value in self.num_neighbors.items():
if len(value) != self.num_hops:
raise ValueError(f"Expected the edge type {key} to have "
f"{self.num_hops} entries (got {len(value)})")
def _max_sampled_nodes(
self,
input_batch_size: int,
) -> Union[int, Dict[str, int]]:
if self._g_cls == 'homo':
res = [input_batch_size]
for num in self.num_neighbors:
res.append(res[-1] * num)
return sum(res)
res = {k : [] for k in self.node_types}
for etype, num_list in self.num_neighbors.items():
tmp_res = [input_batch_size]
for num in num_list:
tmp_res.append(tmp_res[-1] * num)
res[etype[0]].extend(tmp_res)
res[etype[2]].extend(tmp_res)
return {k : sum(v) for k, v in res.items()}
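  # Worked example for `_max_sampled_nodes` above (hypothetical numbers): for a
  # homogeneous graph, a batch of 512 seeds and num_neighbors=[10, 5] gives the
  # bound 512 + 512 * 10 + 512 * 10 * 5 = 31232 nodes.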
  def aggregate_prob(self, probs, node_num, device):
    """Aggregate probabilities from each subgraph:

      p = 1 - ((1 - p_0)(1 - p_1)...(1 - p_{k-1}))**(1/k),

    where k := #{subgraphs}.
    """
res = torch.ones(node_num, device=device, dtype=torch.float32)
for temp_prob in probs:
      # Add a small epsilon so that a single p_i == 1 does not force p == 1,
      # i.e. the aggregated importance is not decided by one term alone.
res *= (1 + .002 - temp_prob)
res = 1 - res ** (1/len(probs))
return res.clamp(min=0.0)