def load_and_concatenate_tensors()

in graphlearn_torch/python/utils/common.py [0:0]


def load_and_concatenate_tensors(filename, device):
    """Load all tensors pickled into ``filename`` and concatenate them on dim 0.

    The file is read with repeated ``pickle.load`` calls until end of file,
    so it is expected to contain objects written back-to-back with
    ``pickle.dump``. One output tensor is pre-allocated on ``device`` and
    each loaded tensor is copied into its own row slice of it.

    Args:
        filename: Path of the file holding the pickled tensors.
        device: Target device for the result; passed straight through to
            ``torch.empty(..., device=device)``.

    Returns:
        A tensor of shape ``(sum of every tensor's dim-0 size, *trailing)``
        where ``trailing`` is the shape of the first tensor past dim 0. Its
        dtype is the first tensor's dtype; later tensors are cast to it on
        copy.

    Raises:
        ValueError: If the file contains no pickled objects, or if the
            loaded tensors do not all share the same trailing shape.

    Security: ``pickle.load`` can execute arbitrary code while unpickling.
    Only call this on files from a trusted source.
    """
    tensor_list = []
    with open(filename, 'rb') as f:
        while True:
            try:
                tensor_list.append(pickle.load(f))
            except EOFError:
                break

    if not tensor_list:
        raise ValueError(f"No tensors found in {filename!r}")

    first = tensor_list[0]
    trailing_shape = first.shape[1:]
    # Validate up front: copying into the pre-allocated slice would otherwise
    # silently broadcast a smaller tensor or fail with an opaque error.
    for idx, tensor in enumerate(tensor_list):
        if tensor.shape[1:] != trailing_shape:
            raise ValueError(
                f"Tensor {idx} in {filename!r} has trailing shape "
                f"{tuple(tensor.shape[1:])}, expected {tuple(trailing_shape)}"
            )

    # Pre-allocate the result so only one full-size buffer lives on `device`.
    total_rows = sum(t.shape[0] for t in tensor_list)
    combined_tensor = torch.empty((total_rows, *trailing_shape),
                                  dtype=first.dtype, device=device)

    start_idx = 0
    for tensor in tensor_list:
        end_idx = start_idx + tensor.shape[0]
        # copy_ performs the cross-device transfer and dtype cast itself,
        # avoiding the temporary device tensor that tensor.to(device) makes.
        combined_tensor[start_idx:end_idx].copy_(tensor)
        start_idx = end_idx
    return combined_tensor