def collate_fn()

in hype_kg/codes/dataloader.py [0:0]


    def collate_fn(data):
        """Collate a list of training samples into one batch.

        Each element of ``data`` is indexed as a 4-item sequence
        ``(positive_sample, negative_sample, subsample_weight, mode)``.

        Returns:
            A tuple ``(positive, negative, weights, mode)``:
            ``positive`` and ``negative`` are the per-sample tensors stacked
            along a new leading dimension, ``weights`` is the per-sample
            weight tensors concatenated along dim 0, and ``mode`` is read from
            the first sample only.

        NOTE(review): only ``data[0]``'s mode is used, so every sample in the
        batch is presumably expected to share the same mode — confirm callers
        never mix modes in one batch.
        """
        def column(index):
            # Gather the ``index``-th field of every sample, preserving order.
            return [sample[index] for sample in data]

        positive = torch.stack(column(0), dim=0)
        negative = torch.stack(column(1), dim=0)
        weights = torch.cat(column(2), dim=0)
        # Read after the tensor ops so malformed input fails the same way.
        batch_mode = data[0][3]
        return positive, negative, weights, batch_mode