def scale_image()

in src/image_gen_aux/image_processor.py [0:0]


    def scale_image(image: torch.Tensor, scale: float, mutiple_factor: int = 8) -> "tuple[torch.Tensor, float]":
        """
        Scales an image by `scale`, rounding each new dimension down to a multiple of `mutiple_factor`.

        Height and width are multiplied by the same factor and then each one is independently floored to a multiple of
        `mutiple_factor`, so the aspect ratio is preserved only approximately.

        Args:
            image (`torch.Tensor`): The input image tensor of shape (batch, channels, height, width).
            scale (`float`): The scaling factor applied to the image dimensions. Must be greater than 0.
            mutiple_factor (`int`, *optional*, defaults to 8): The factor by which the new dimensions should be
                divisible. Must be at least 1. (The parameter keeps its original spelling for backward compatibility.)

        Returns:
            `tuple[torch.Tensor, float]`: The scaled image tensor and the effective scale that was actually applied,
            computed as `new_height / height` (it can differ from `scale` because of the rounding, and the width's
            effective scale may differ slightly from it). When `scale == 1.0` the input tensor is returned unchanged
            (it is not resized to a multiple of `mutiple_factor`) together with `scale`.

        Raises:
            ValueError: If `scale` is not positive, if `mutiple_factor` is less than 1, or if the scaled height or
                width would be 0 after rounding down to a multiple of `mutiple_factor`.
        """

        # Fast path: no resizing at all (dimensions are intentionally left as-is).
        if scale == 1.0:
            return image, scale

        if scale <= 0:
            raise ValueError(f"`scale` must be greater than 0, got {scale}.")
        if mutiple_factor < 1:
            raise ValueError(f"`mutiple_factor` must be at least 1, got {mutiple_factor}.")

        _batch, _channels, height, width = image.shape

        # Scale both dimensions by the same factor (truncating towards zero).
        new_height = int(height * scale)
        new_width = int(width * scale)

        # Round each dimension down to a multiple of mutiple_factor.
        new_height = (new_height // mutiple_factor) * mutiple_factor
        new_width = (new_width // mutiple_factor) * mutiple_factor

        # interpolate cannot produce an empty output; fail early with an actionable message instead.
        if new_height == 0 or new_width == 0:
            raise ValueError(
                f"Scaling an image of size {height}x{width} by {scale} with a multiple factor of {mutiple_factor} "
                f"results in an empty image ({new_height}x{new_width})."
            )

        # The rounding above changes the effective scale, so report the one actually applied (height axis).
        effective_scale = new_height / height

        # Resize the image using the calculated dimensions
        resized_image = torch.nn.functional.interpolate(
            image, size=(new_height, new_width), mode="bilinear", align_corners=False
        )

        return resized_image, effective_scale