def numpy_to_pt()

in src/image_gen_aux/image_processor.py [0:0]


    def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
        """
        Turn a NumPy image array into a float32 PyTorch tensor.

        A 3-D input first receives a trailing singleton axis so that it becomes
        4-D. The 4-D array then has its last axis moved to position 1 (axis
        order ``(0, 3, 1, 2)``) before being wrapped as a tensor and cast to
        ``torch.float32``.

        Args:
            images (np.ndarray): A 3-D or 4-D array. Presumably laid out as
                (batch, height, width) or (batch, height, width, channels) --
                the code only fixes the axis permutation, so confirm the layout
                against the callers. Note that a 3-D array is NOT interpreted
                as a single (height, width, channels) image.

        Returns:
            torch.Tensor: A 4-D float32 tensor whose axes are the input's axes
            reordered as ``(0, 3, 1, 2)``.
        """
        lacks_last_axis = images.ndim == 3
        # Append a size-1 axis at the end so the permutation below always sees 4 axes.
        four_dim = images[..., None] if lacks_last_axis else images

        # Move the last axis to position 1; the remaining axes keep their order.
        reordered = four_dim.transpose(0, 3, 1, 2)

        return torch.from_numpy(reordered).float()