in src/image_gen_aux/image_processor.py [0:0]
def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
    """
    Convert a NumPy image batch to a float32 PyTorch tensor.

    A 3-D array is treated as a batch without a channel axis and gets a
    trailing axis of size 1 appended. The (now 4-D) array is then reordered
    from channels-last ``(batch, height, width, channels)`` to channels-first
    ``(batch, channels, height, width)``.

    Args:
        images (np.ndarray): The input image(s) as a 3-D or 4-D NumPy array.

    Returns:
        torch.Tensor: The converted image(s) as a 4-D ``torch.float32`` tensor.

    Raises:
        ValueError: If ``images`` is neither 3-D nor 4-D.
    """
    if images.ndim not in (3, 4):
        raise ValueError(
            f"Expected a 3-D or 4-D array of images, got {images.ndim} dimension(s)."
        )

    if images.ndim == 3:
        # No channel axis present: append one so the transpose below applies.
        images = images[..., None]

    # torch.from_numpy cannot wrap arrays with negative strides (e.g. views
    # produced by `arr[..., ::-1]`), so materialise those first. Other inputs
    # are left alone to avoid a needless copy.
    if any(stride < 0 for stride in images.strides):
        images = np.ascontiguousarray(images)

    # Move the last axis to position 1, then convert to float32.
    return torch.from_numpy(images.transpose(0, 3, 1, 2)).float()