def merge_hetero_sampler_output()

in graphlearn_torch/python/utils/common.py [0:0]


def merge_hetero_sampler_output(
    in_sample: Any, out_sample: Any, device,
    edge_dir: Literal['in', 'out']='out'):
  """Merge the heterogeneous sampler output ``in_sample`` into ``out_sample``.

  Both samples are expected to expose ``node``, ``row`` and ``col`` dicts of
  tensors plus ``edge`` and ``edge_types`` attributes (either may be ``None``).
  ``row``/``col`` are keyed by edge type; the first and last element of each
  key select the entry of ``node`` that the values index into.

  Steps:
    1. In both samples, replace ``row``/``col`` values by the ``node`` entries
       they index.
    2. Concatenate ``in_sample``'s ``node``/``row``/``col`` (and ``edge``, when
       both are set) onto ``out_sample``'s; merged ``node`` tensors are
       de-duplicated with ``Tensor.unique()``.
    3. Re-index the merged ``row``/``col`` against the merged ``node`` tensors
       through ``id2idx``.
    4. Union the ``edge_types`` lists when both are set; for
       ``edge_dir == 'out'`` every edge type whose first and last elements
       differ is passed through ``reverse_edge_type``.

  NOTE: both arguments are modified in place (``in_sample.row``/``col`` are
  rewritten in step 1; ``out_sample`` receives the merged result).

  Args:
    in_sample: Sampler output that is merged into ``out_sample``.
    out_sample: Sampler output that receives the merged data.
    device: Device on which placeholder tensors for keys missing from
      ``out_sample`` are created.
    edge_dir: ``'in'`` or ``'out'``; only affects how ``edge_types`` are
      reported (see step 4).

  Returns:
    ``out_sample``, updated in place.
  """
  def subid2gid(sample):
    # Assigning to existing keys while iterating is safe: the dict size never
    # changes.
    for k, v in sample.row.items():
      sample.row[k] = sample.node[k[0]][v]
    for k, v in sample.col.items():
      sample.col[k] = sample.node[k[-1]][v]

  def merge_tensor_dict(in_dict, out_dict, unique=False):
    for k, v in in_dict.items():
      # The placeholder for a missing key must carry v's dtype: a default
      # (float32) empty tensor would make torch.cat promote integer ids to
      # float32, changing their dtype and rounding ids above 2**24.
      vals = out_dict.get(k, torch.empty(0, dtype=v.dtype, device=device))
      merged = torch.cat((vals, v))
      out_dict[k] = merged.unique() if unique else merged

  subid2gid(in_sample)
  subid2gid(out_sample)
  merge_tensor_dict(in_sample.node, out_sample.node, unique=True)
  merge_tensor_dict(in_sample.row, out_sample.row)
  merge_tensor_dict(in_sample.col, out_sample.col)

  # Map the merged row/col values back to positions in the merged node tensors
  # (presumably id2idx builds an id -> position lookup table; it is defined
  # outside this function).
  for k, v in out_sample.row.items():
    out_sample.row[k] = id2idx(out_sample.node[k[0]])[v.to(torch.int64)]
  for k, v in out_sample.col.items():
    out_sample.col[k] = id2idx(out_sample.node[k[-1]])[v.to(torch.int64)]

  # Merging of `batch` is currently disabled:
  # if in_sample.batch is not None and out_sample.batch is not None:
  #   merge_tensor_dict(in_sample.batch, out_sample.batch)
  if in_sample.edge is not None and out_sample.edge is not None:
    merge_tensor_dict(in_sample.edge, out_sample.edge, unique=False)
  if out_sample.edge_types is not None and in_sample.edge_types is not None:
    # Order-preserving union: same members as a set union, but the resulting
    # order is deterministic (out_sample's types first, then new ones).
    out_sample.edge_types = list(
      dict.fromkeys([*out_sample.edge_types, *in_sample.edge_types]))
    if edge_dir == 'out':
      out_sample.edge_types = [
        reverse_edge_type(etype) if etype[0] != etype[-1] else etype
        for etype in out_sample.edge_types
      ]

  return out_sample