def format_hetero_sampler_output()

in graphlearn_torch/python/utils/common.py [0:0]


def format_hetero_sampler_output(
  in_sample: Any,
  edge_dir: Literal['in', 'out'] = 'in',
):
  """Normalize a heterogeneous sampler output in place and return it.

  Each entry of ``in_sample.node`` is replaced by the result of calling
  ``.unique()`` on it (presumably a tensor of node ids per node type --
  confirm against the sampler output type).

  When ``edge_dir == 'out'`` and ``in_sample.edge_types`` is not ``None``,
  ``in_sample.edge_types`` is rebuilt as a list. In that list, every edge type
  whose first and last elements differ is replaced by
  ``reverse_edge_type(etype)``. Edge types whose first and last elements are
  equal are kept unchanged.

  Args:
    in_sample: Sampler output exposing a mutable mapping ``node`` and an
      ``edge_types`` attribute (an iterable of edge types, or ``None``).
    edge_dir: Sampling direction, ``'in'`` or ``'out'``. Only ``'out'``
      triggers edge-type reversal. Defaults to ``'in'``.

  Returns:
    The same ``in_sample`` object, mutated in place.
  """
  # Only values are reassigned, never keys added or removed, so assigning
  # while iterating the keys is safe.
  for ntype in in_sample.node.keys():
    in_sample.node[ntype] = in_sample.node[ntype].unique()

  if in_sample.edge_types is not None and edge_dir == 'out':
    # Edge types whose first and last elements match are left as-is. Only
    # the others are passed through reverse_edge_type.
    in_sample.edge_types = [
      etype if etype[0] == etype[-1] else reverse_edge_type(etype)
      for etype in in_sample.edge_types
    ]
  return in_sample