in graphlearn_torch/python/utils/common.py [0:0]
def load_and_concatenate_tensors(filename, device):
    """Load every pickled tensor from a file and stack them along dim 0.

    The file is read as a sequence of back-to-back ``pickle`` records (one
    ``pickle.load`` per record until end of file). The loaded tensors are
    copied, in file order, into a single pre-allocated tensor on ``device``.

    Args:
        filename: Path of the file holding the consecutively pickled tensors.
        device: Device on which the combined tensor is allocated; each
            loaded tensor is moved there before being copied in.

    Returns:
        A tensor whose first dimension is the sum of the first dimensions of
        all loaded tensors. Its trailing dimensions and dtype are taken from
        the first tensor in the file.

    Raises:
        ValueError: If the file contains no pickled objects.
    """
    # NOTE(security): pickle.load can execute arbitrary code while
    # deserializing; only call this on files from a trusted source.
    tensor_list = []
    with open(filename, 'rb') as f:
        while True:
            try:
                tensor = pickle.load(f)
            except EOFError:
                # Reaching end of file is the normal loop terminator.
                break
            tensor_list.append(tensor)

    if not tensor_list:
        # Without a first tensor there is no shape/dtype to allocate from.
        raise ValueError(f"No tensors found in file: {filename!r}")

    first = tensor_list[0]
    total_rows = sum(t.shape[0] for t in tensor_list)

    # Allocate the result once and fill it slice by slice, so only one
    # chunk at a time needs a temporary copy on the target device.
    combined_tensor = torch.empty(
        (total_rows, *first.shape[1:]), dtype=first.dtype, device=device)

    start_idx = 0
    for tensor in tensor_list:
        end_idx = start_idx + tensor.shape[0]
        combined_tensor[start_idx:end_idx] = tensor.to(device)
        start_idx = end_idx

    return combined_tensor