optimum/amd/ryzenai/models/semanticfpn/image_processing_semantic_fpn.py
# Copyright 2023 The HuggingFace Team. All rights reserved.
# Licensed under the MIT License.
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from transformers.image_transforms import (
get_resize_output_image_size,
rescale,
resize,
)
from transformers.image_utils import (
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
)
from transformers.utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, TensorType


class SemanticFPNImageProcessor(BaseImageProcessor):
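    """
    Image processor for SemanticFPN semantic segmentation models.

    Images are resized to a fixed `size` (256x512 by default), rescaled by `rescale_factor` and
    normalized with the ImageNet mean and standard deviation. `post_process_semantic_segmentation`
    converts raw model logits back into per-pixel class maps at the original image resolutions.
    """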
model_input_names = ["pixel_values"]

    def __init__(
        self,
        size: Optional[Dict[str, int]] = None,
        rescale_factor: Union[int, float] = 1 / 255.0,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        **kwargs,
    ):
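        """
        Args:
            size (`Dict[str, int]`, *optional*, defaults to `{"height": 256, "width": 512}`):
                Target height and width of the resized images.
            rescale_factor (`int` or `float`, *optional*, defaults to `1 / 255`):
                Factor used to rescale pixel values.
            image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_MEAN`):
                Per-channel mean used for normalization.
            image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_DEFAULT_STD`):
                Per-channel standard deviation used for normalization.
        """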
size = size if size is not None else {"height": 256, "width": 512}
super().__init__(**kwargs)
self.size = size
self.data_format = ChannelDimension.LAST
self.rescale_factor = rescale_factor
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD

    def resize(
        self,
        image: np.ndarray,
        size: Dict[str, int],
        resample: PILImageResampling = PILImageResampling.BILINEAR,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ):
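        """
        Resize `image` to the height and width given in `size` using `resample` interpolation and
        return the resized array.
        """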
size = get_size_dict(size, default_to_square=False)
if "height" in size and "width" in size:
size = (size["height"], size["width"])
else:
            raise ValueError(f"Size must contain 'height' and 'width' keys. Got {size.keys()}.")
size = get_resize_output_image_size(
input_image=image,
size=size,
default_to_square=False,
input_data_format=input_data_format,
)
image = resize(
image, size=size, resample=resample, data_format=data_format, input_data_format=input_data_format, **kwargs
)
return image

    def preprocess(
self,
images: ImageInput,
return_tensors: Optional[Union[TensorType, str]] = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> BatchFeature:
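        """
        Prepare one image or a batch of images for the model: resize to `self.size`, rescale by
        `self.rescale_factor` and normalize with `self.image_mean` / `self.image_std`. Returns a
        `BatchFeature` with `pixel_values` and `target_sizes`, the original (height, width) of each
        image, which `post_process_semantic_segmentation` uses to restore the input resolution.
        """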
        data_format = data_format if data_format is not None else self.data_format
        # Remember the chosen data format so post_process_semantic_segmentation knows whether the
        # logits are channels-last and need to be permuted.
        self.data_format = data_format
images = make_list_of_images(images)
# All transformations expect numpy arrays
images = [to_numpy_array(image) for image in images]
# We assume that all images have the same channel dimension format.
preprocessed_images = []
target_sizes = []
        for image in images:
            input_data_format = infer_channel_dimension_format(images[0])
            if input_data_format == ChannelDimension.FIRST:
                # Convert (channels, height, width) to (height, width, channels).
                image = image.transpose((1, 2, 0))
                input_data_format = ChannelDimension.LAST
            # Original (height, width), used later to resize predictions back to the input size.
            target_sizes.append(tuple(image.shape[:2]))
image = self.resize(image, size=self.size, input_data_format=input_data_format)
image = rescale(
image=image, scale=self.rescale_factor, data_format=data_format, input_data_format=input_data_format
)
image = self.normalize(
image,
mean=self.image_mean,
std=self.image_std,
data_format=data_format,
input_data_format=input_data_format,
)
image = np.ascontiguousarray(image, dtype=np.float32)
preprocessed_images.append(image)
data = {"pixel_values": preprocessed_images, "target_sizes": target_sizes}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs

    def post_process_semantic_segmentation(
        self,
        outputs,
        target_sizes: Optional[Union[TensorType, List[Tuple]]] = None,
    ):
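        """
        Turn raw model logits into per-image semantic segmentation maps.

        Args:
            outputs: Mapping of output names to logits returned by the model; only the first
                output is used.
            target_sizes (`TensorType` or `List[Tuple]`, *optional*):
                The (height, width) of each original image. If provided, the logits are resized to
                these sizes before taking the per-pixel argmax.

        Returns:
            `List[torch.Tensor]`: One segmentation map of shape `(height, width)` per image.
        """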
        # The model returns a mapping of output names to logits; keep only the first output and
        # make sure it is a torch tensor.
        outputs = list(outputs.values())
        if not isinstance(outputs[0], torch.Tensor):
            outputs = [torch.tensor(out) for out in outputs]
        if isinstance(outputs, (tuple, list)):
            outputs = outputs[0]
        if self.data_format == ChannelDimension.LAST:
            # Logits are (batch, height, width, num_classes); move the class dimension to dim 1.
            outputs = torch.permute(outputs, (0, 3, 1, 2))
if target_sizes is not None:
semantic_segmentation = []
for idx in range(len(target_sizes)):
resized_logits = torch.nn.functional.interpolate(
outputs[idx].unsqueeze(dim=0), size=tuple(target_sizes[idx]), mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = outputs.argmax(dim=1)
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
return semantic_segmentation
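

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the random arrays below stand in for a real RGB
    # image and for the channels-last logits a SemanticFPN model would return; the 19-class
    # output shape is an assumption made purely for demonstration.
    processor = SemanticFPNImageProcessor()

    # Fake 1024x2048 RGB image in channels-last uint8 layout.
    image = np.random.randint(0, 256, size=(1024, 2048, 3), dtype=np.uint8)
    inputs = processor.preprocess(image)
    print(inputs["pixel_values"][0].shape)  # (256, 512, 3), resized and normalized
    print(inputs["target_sizes"])  # [(1024, 2048)], the original image size

    # Fake model output: a mapping from output name to logits of shape
    # (batch, height, width, num_classes), matching the channels-last data format.
    fake_outputs = {"logits": np.random.randn(1, 256, 512, 19).astype(np.float32)}
    maps = processor.post_process_semantic_segmentation(fake_outputs, target_sizes=inputs["target_sizes"])
    print(maps[0].shape)  # torch.Size([1024, 2048]), one class id per pixel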