in graphlearn_torch/python/utils/common.py [0:0]
def merge_hetero_sampler_output(
        in_sample: Any, out_sample: Any, device,
        edge_dir: Literal['in', 'out'] = 'out'):
    """Merge the heterogeneous sampler output ``in_sample`` into ``out_sample``.

    Both samples must expose ``node``, ``row`` and ``col`` dicts of tensors and
    the (possibly ``None``) attributes ``edge`` and ``edge_types``. ``row`` and
    ``col`` are keyed by edge-type tuples; the first / last element of such a
    tuple is used as the key into ``node`` for the row / column endpoints.

    The merge proceeds as follows:

    1. In each sample, ``row`` / ``col`` entries are replaced by the values
       they index in that sample's own ``node`` tensors.
    2. ``node`` tensors are concatenated per key and de-duplicated with
       ``Tensor.unique()``; ``row`` / ``col`` are concatenated per key.
    3. The merged ``row`` / ``col`` values are mapped through ``id2idx`` of
       the merged ``node`` tensor of the matching node type.
    4. ``edge`` is concatenated per key and ``edge_types`` is unioned, each
       only when present on both samples. With ``edge_dir == 'out'`` every
       merged edge type whose first and last elements differ is passed
       through ``reverse_edge_type``.

    NOTE: both arguments are modified in place: ``in_sample.row`` /
    ``in_sample.col`` are rewritten by step 1 and ``out_sample`` receives the
    merged result. ``batch`` is not merged; ``out_sample.batch`` is left as is.

    Args:
        in_sample: Sampler output to merge from.
        out_sample: Sampler output to merge into; also the return value.
        device: Device on which the empty seed tensor is allocated for keys
            that exist only in ``in_sample``.
        edge_dir: ``'in'`` or ``'out'``; controls whether merged edge types
            are reversed (see step 4).

    Returns:
        ``out_sample``, updated in place.
    """

    def subid2gid(sample):
        # Swap each row/col index for the node value it points at, so that
        # entries coming from different samples become comparable.
        for etype, local_idx in sample.row.items():
            sample.row[etype] = sample.node[etype[0]][local_idx]
        for etype, local_idx in sample.col.items():
            sample.col[etype] = sample.node[etype[-1]][local_idx]

    def merge_tensor_dict(in_dict, out_dict, unique=False):
        for key, val in in_dict.items():
            existing = out_dict.get(key)
            if existing is None:
                # Seed with an empty tensor of the incoming dtype. A default
                # (float32) empty tensor would make torch.cat promote integer
                # ids to float32 and silently corrupt ids above 2**24.
                existing = torch.empty(0, dtype=val.dtype, device=device)
            merged = torch.cat((existing, val))
            out_dict[key] = merged.unique() if unique else merged

    subid2gid(in_sample)
    subid2gid(out_sample)
    merge_tensor_dict(in_sample.node, out_sample.node, unique=True)
    merge_tensor_dict(in_sample.row, out_sample.row)
    merge_tensor_dict(in_sample.col, out_sample.col)

    # One id -> position table per node type; built lazily and shared by all
    # edge types that touch that node type instead of being rebuilt each time.
    position_tables = {}

    def to_merged_index(ntype, ids):
        table = position_tables.get(ntype)
        if table is None:
            table = id2idx(out_sample.node[ntype])
            position_tables[ntype] = table
        return table[ids.to(torch.int64)]

    for etype, ids in out_sample.row.items():
        out_sample.row[etype] = to_merged_index(etype[0], ids)
    for etype, ids in out_sample.col.items():
        out_sample.col[etype] = to_merged_index(etype[-1], ids)

    if in_sample.edge is not None and out_sample.edge is not None:
        merge_tensor_dict(in_sample.edge, out_sample.edge, unique=False)

    if out_sample.edge_types is not None and in_sample.edge_types is not None:
        # Order-preserving union (out_sample's types first) so the resulting
        # list does not depend on set iteration order.
        merged_types = list(dict.fromkeys(
            (*out_sample.edge_types, *in_sample.edge_types)))
        if edge_dir == 'out':
            merged_types = [
                reverse_edge_type(etype) if etype[0] != etype[-1] else etype
                for etype in merged_types
            ]
        out_sample.edge_types = merged_types

    return out_sample