in graphlearn_torch/python/loader/node_loader.py [0:0]
def _collate_fn(self, sampler_out: Union[SamplerOutput, HeteroSamplerOutput]):
    r"""Format a sampler output into a ``Data``/``HeteroData`` batch.

    Node features, edge features and labels are looked up by the ids carried
    in ``sampler_out`` and handed to ``to_data`` / ``to_hetero_data``.

    Args:
      sampler_out: The sampled subgraph. A ``SamplerOutput`` is treated as a
        homogeneous graph; any other value takes the heterogeneous path,
        where ``sampler_out.node`` and ``sampler_out.edge`` are iterated
        with ``.items()`` (i.e. they are expected to be type-keyed mappings).

    Returns:
      Whatever ``to_data`` (homogeneous) or ``to_hetero_data``
      (heterogeneous) builds from the sampler output and the gathered
      features/labels.
    """
    if isinstance(sampler_out, SamplerOutput):
        x = self.data.node_features[sampler_out.node]
        y = (
            self.input_t_label[sampler_out.node]
            if self.input_t_label is not None else None
        )
        # Edge features are optional: both the feature store and the sampled
        # edge ids must be present, otherwise no edge attributes are attached.
        edge_features = self.data.edge_features
        if edge_features is not None and sampler_out.edge is not None:
            edge_attr = edge_features[sampler_out.edge]
        else:
            edge_attr = None
        return to_data(sampler_out, batch_labels=y,
                       node_feats=x, edge_feats=edge_attr)

    # Heterogeneous graph: gather features separately for every node type.
    x_dict = {
        ntype: self.data.get_node_feature(ntype)[ids]
        for ntype, ids in sampler_out.node.items()
    }

    # Labels are only attached for the input (seed) node type. The id lookup
    # is done unconditionally, exactly as before, so a missing input type in
    # the sampler output still surfaces as an error here.
    input_t_ids = sampler_out.node[self._input_type]
    y_dict = (
        {self._input_type: self.input_t_label[input_t_ids]}
        if self.input_t_label is not None else None
    )

    # Stays an empty dict (not None) when no edge type has features.
    edge_attr_dict = {}
    if sampler_out.edge is not None:
        for etype, eids in sampler_out.edge.items():
            efeat = self.data.get_edge_feature(etype)
            if efeat is not None:
                edge_attr_dict[etype] = efeat[eids]

    return to_hetero_data(sampler_out, batch_label_dict=y_dict,
                          node_feat_dict=x_dict,
                          edge_feat_dict=edge_attr_dict,
                          edge_dir=self.data.edge_dir)