in src/image_gen_aux/image_processor.py [0:0]
def scale_image(
    image: torch.Tensor, scale: float, mutiple_factor: int = 8
) -> "tuple[torch.Tensor, float]":
    """
    Scales an image and snaps its new dimensions down to multiples of `mutiple_factor`.

    Args:
        image (`torch.Tensor`): The input image tensor of shape (batch, channels, height, width).
        scale (`float`): The scaling factor applied to the image dimensions. Must be positive.
        mutiple_factor (`int`, *optional*, defaults to 8): The factor by which the new dimensions should be
            divisible. Each resized dimension is at least `mutiple_factor`. (The parameter name keeps its
            historical spelling for backward compatibility.)

    Returns:
        `tuple[torch.Tensor, float]`: The resized image and the effective scale actually applied, computed as
        `new_height / height`. Height and width are snapped independently, so the effective horizontal scale
        may differ slightly from the returned value.

    Raises:
        ValueError: If `scale` is not positive or `mutiple_factor` is less than 1.

    Note:
        When `scale == 1.0` the input tensor is returned unchanged, even if its dimensions are not multiples
        of `mutiple_factor`.
    """
    if scale == 1.0:
        return image, scale

    if scale <= 0:
        raise ValueError(f"`scale` must be positive, got {scale}.")
    if mutiple_factor < 1:
        raise ValueError(f"`mutiple_factor` must be >= 1, got {mutiple_factor}.")

    _batch, _channels, height, width = image.shape

    # Scale, then snap down to the nearest multiple. Clamp to one multiple so that small scales never
    # produce a zero-sized dimension, which interpolate cannot handle.
    new_height = max(mutiple_factor, (int(height * scale) // mutiple_factor) * mutiple_factor)
    new_width = max(mutiple_factor, (int(width * scale) // mutiple_factor) * mutiple_factor)

    # Snapping (and clamping) may change the final height, so report the scale that was really applied.
    # This is measured on the height axis only.
    scale = new_height / height

    resized_image = torch.nn.functional.interpolate(
        image, size=(new_height, new_width), mode="bilinear", align_corners=False
    )
    return resized_image, scale