def _collate_fn()

in graphlearn_torch/python/loader/link_loader.py [0:0]


  def _collate_fn(self, sampler_out: Union[SamplerOutput, HeteroSamplerOutput]):
    r"""Format a sampler output into a Data/HeteroData object.

    Node and edge features for the sampled ids are gathered from
    ``self.data`` and handed to ``to_data`` (homogeneous) or
    ``to_hetero_data`` (heterogeneous).

    For the out-edge sampling scheme (i.e. the direction of edges in the
    output is inverse to the original graph), the reversed
    edge_label_index is put into the (dst, rev_to, src) subgraph for
    HeteroSamplerOutput and (dst, to, src) for SamplerOutput.
    For the in-edge sampling scheme (i.e. the direction of edges in the
    output is the same as the original graph), the edge types of
    ``sampler_out`` do not need to be reversed.

    Args:
      sampler_out: Output of the sampler, either a ``SamplerOutput``
        (homogeneous) or a ``HeteroSamplerOutput`` (heterogeneous).

    Returns:
      The result of ``to_data`` for a ``SamplerOutput``; otherwise the
      result of ``to_hetero_data``.

    Raises:
      ValueError: If ``sampler_out`` is heterogeneous, contains sampled
        edges, and ``self.edge_dir`` is neither ``'in'`` nor ``'out'``.
    """
    if isinstance(sampler_out, SamplerOutput):
      # Missing feature tables are passed on as ``None`` instead of being
      # indexed (indexing ``None`` would raise a TypeError).
      node_features = self.data.node_features
      if node_features is not None:
        x = node_features[sampler_out.node]
      else:
        x = None
      if self.data.edge_features is not None and sampler_out.edge is not None:
        edge_attr = self.data.edge_features[sampler_out.edge]
      else:
        edge_attr = None
      return to_data(sampler_out,
                     node_feats=x,
                     edge_feats=edge_attr,
                    )

    # Heterogeneous: gather per-type node features, skipping node types that
    # have no feature table (same convention as the edge features below).
    x_dict = {}
    for ntype, ids in sampler_out.node.items():
      nfeat = self.data.get_node_feature(ntype)
      if nfeat is not None:
        x_dict[ntype] = nfeat[ids.to(torch.int64)]

    edge_attr_dict = {}
    if sampler_out.edge is not None:
      for etype, eids in sampler_out.edge.items():
        if self.edge_dir == 'out':
          # With out-edge sampling the output edges are inverse to the
          # original graph, so features are stored under the reversed type.
          feat_etype = reverse_edge_type(etype)
        elif self.edge_dir == 'in':
          feat_etype = etype
        else:
          # Without this, the feature lookup below would hit an unbound name.
          raise ValueError(
            f"Unsupported edge_dir '{self.edge_dir}', expected 'in' or 'out'"
          )
        efeat = self.data.get_edge_feature(feat_etype)
        if efeat is not None:
          edge_attr_dict[etype] = efeat[eids.to(torch.int64)]

    return to_hetero_data(sampler_out,
                          node_feat_dict=x_dict,
                          edge_feat_dict=edge_attr_dict,
                          edge_dir=self.edge_dir,
                         )