# optimum/amd/ryzenai/models/yolox/image_processing_yolox.py
# Copyright 2023 The HuggingFace Team. All rights reserved.
# Licensed under the MIT License.
from typing import Iterable, List, Optional, Tuple, Union
import cv2
import numpy as np
import torch
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import (
PaddingMode,
flip_channel_order,
pad,
)
from transformers.image_utils import (
ChannelDimension,
ImageInput,
get_image_size,
infer_channel_dimension_format,
make_list_of_images,
to_numpy_array,
)
from transformers.utils import TensorType
from ..detection_utils import non_max_suppression


def postprocess(outputs, img_size, strides):
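    """
    Decode the raw YOLOX head outputs into box predictions expressed in model-input pixels.

    Each pyramid level is flattened to `(batch, num_cells, 5 + num_classes)`; objectness and class scores are
    passed through a sigmoid, box centres are offset by their grid cell and scaled by the level's stride, and box
    sizes are recovered with `exp(...) * stride`. Boxes stay in `(cx, cy, w, h)` format.
    """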
grids = []
expanded_strides = []
device = strides.device
dtype = strides.dtype
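    # Flatten each level from (batch, channels, h, w) to (batch, h*w, channels), then concatenate all levels.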
outputs = [out.reshape(*out.shape[:2], -1).transpose(2, 1) for out in outputs]
    outputs = torch.cat(outputs, dim=1)
outputs[..., 4:] = outputs[..., 4:].sigmoid()
hsizes = [img_size[0] // stride for stride in strides]
wsizes = [img_size[1] // stride for stride in strides]
for hsize, wsize, stride in zip(hsizes, wsizes, strides):
xv, yv = torch.meshgrid(
torch.arange(wsize, device=device, dtype=dtype),
torch.arange(hsize, device=device, dtype=dtype),
indexing="xy",
)
grid = torch.stack((xv, yv), 2).reshape(1, -1, 2)
grids.append(grid)
shape = grid.shape[:2]
expanded_strides.append(torch.full((*shape, 1), stride, dtype=dtype, device=device))
grids = torch.cat(grids, 1)
expanded_strides = torch.cat(expanded_strides, 1)
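    # Box centres are offsets within their grid cell, scaled by the stride; box sizes are predicted in log space.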
outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * expanded_strides
return outputs


class YoloXImageProcessor(BaseImageProcessor):
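    r"""
    Image processor for YOLOX object-detection models running through Ryzen AI.

    `preprocess` flips the channel order (RGB -> BGR), resizes images with a common ratio so they fit inside
    `size` (640x640 by default) and pads the bottom/right borders with the value 114.
    `post_process_object_detection` decodes the raw model outputs, applies non-maximum suppression and rescales
    the boxes back to the original image sizes.

    A minimal usage sketch (the model side is illustrative and not defined in this file):

    ```python
    processor = YoloXImageProcessor()
    inputs = processor(images=image, return_tensors="pt")  # `image`: a PIL image or numpy array
    # outputs = model(inputs["pixel_values"])  # run the exported YOLOX model (not shown here)
    # results = processor.post_process_object_detection(outputs, target_sizes=inputs["target_sizes"])
    ```

    Args:
        size (`dict`, *optional*, defaults to `{"height": 640, "width": 640}`):
            Spatial size of the padded model input.
        stride (`List[int]`, *optional*, defaults to `[8, 16, 32]`):
            Strides of the detection heads, used to rebuild the prediction grids during post-processing.
    """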
model_input_names = ["pixel_values"]

    def __init__(
self,
        size: Optional[dict] = None,
stride: List[int] = [8, 16, 32],
**kwargs,
):
size = size if size is not None else {"height": 640, "width": 640}
super().__init__(**kwargs)
self.size = size
self.resample = cv2.INTER_LINEAR
self.data_format = ChannelDimension.LAST
self.stride = stride

    def resize(
self,
image: np.ndarray,
size: Tuple[int, int],
resample=cv2.INTER_LINEAR,
) -> np.ndarray:
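        """
        Resize `image` to `size = (height, width)` with OpenCV and cast back to `uint8`. `cv2.resize` expects
        `(width, height)`, hence the reversed tuple.
        """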
image = cv2.resize(
image,
(size[1], size[0]),
interpolation=resample,
).astype(np.uint8)
return image

    # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image
def pad(
self,
image: np.ndarray,
output_size: Tuple[int, int],
constant_values: Union[float, Iterable[float]] = 114,
data_format: Optional[ChannelDimension] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
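        """
        Pad the bottom and right of `image` with `constant_values` (114 by default) so that it reaches
        `output_size = (height, width)`.
        """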
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
output_height, output_width = output_size
pad_bottom = output_height - input_height
pad_right = output_width - input_width
padding = ((0, pad_bottom), (0, pad_right))
padded_image = pad(
image,
padding,
mode=PaddingMode.CONSTANT,
constant_values=constant_values,
data_format=data_format,
input_data_format=input_data_format,
)
return padded_image

    def preprocess(
self,
images: ImageInput,
return_tensors: Optional[Union[TensorType, str]] = None,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> BatchFeature:
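        """
        Prepare an image or batch of images for the model: convert to numpy, flip the channel order (RGB -> BGR),
        resize with a common ratio so the image fits inside `self.size`, then pad the bottom/right to exactly
        `self.size`. Returns a `BatchFeature` holding `pixel_values` (float32, channels-last by default) and
        `target_sizes` (the original image shapes, used later to rescale the boxes).
        """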
data_format = data_format if data_format is not None else self.data_format
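        # Remember the layout so post_process_object_detection can tell whether the raw outputs need permuting.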
self.data_format = data_format
images = make_list_of_images(images)
# All transformations expect numpy arrays
images = [to_numpy_array(image) for image in images]
# We assume that all images have the same channel dimension format.
target_sizes = []
padded_images = []
        for image in images:
            if input_data_format is None:
                input_data_format = infer_channel_dimension_format(images[0])
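            # YOLOX consumes BGR images, so flip the channel order of the (RGB) input.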
image = flip_channel_order(image, input_data_format=input_data_format)
input_height, input_width = get_image_size(image, channel_dim=input_data_format)
target_sizes.append(image.shape)
ratio = min(self.size["height"] / input_height, self.size["width"] / input_width)
size = (int(ratio * input_height), int(ratio * input_width))
resized_image = self.resize(image, size=size, resample=self.resample)
padded_img = self.pad(
resized_image,
output_size=(self.size["height"], self.size["width"]),
data_format=data_format,
input_data_format=input_data_format,
)
padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
padded_images.append(padded_img)
data = {"pixel_values": padded_images, "target_sizes": target_sizes}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs

    def post_process_object_detection(
self,
outputs,
threshold: float = 0.1,
nms_threshold: float = 0.45,
        target_sizes: Optional[Union[TensorType, List[Tuple]]] = None,
        agnostic_nms: bool = True,
        merge_nms: bool = False,
        max_detections: int = 1000,
        data_format: Optional[Union[str, ChannelDimension]] = None,
):
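        """
        Turn raw model outputs into detections: decode the per-level feature maps with `postprocess`, apply
        non-maximum suppression, and, if `target_sizes` (the original image shapes returned by `preprocess`) is
        given, rescale the boxes back to the original image coordinates. Returns one
        `{"scores", "labels", "boxes"}` dict per image.
        """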
data_format = data_format if data_format is not None else self.data_format
if merge_nms:
raise ValueError("Merge NMS is not yet supported!")
outputs = list(outputs.values())
if not isinstance(outputs[0], torch.Tensor):
outputs = [torch.Tensor(out) for out in outputs]
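        # `postprocess` works on (batch, channels, height, width) maps, so permute channels-last outputs.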
if data_format == ChannelDimension.LAST:
outputs = [torch.permute(out, (0, 3, 1, 2)) for out in outputs]
predictions = postprocess(outputs, (self.size["height"], self.size["width"]), torch.Tensor(self.stride))
dets = non_max_suppression(
predictions,
threshold,
nms_threshold,
agnostic=agnostic_nms,
max_detections=max_detections,
)
results = []
for i, det in enumerate(dets):
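            # Map the boxes back to the original image by undoing the resize ratio applied in `preprocess`.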
if target_sizes is not None:
input_height, input_width, _ = target_sizes[i]
ratio = min(self.size["height"] / input_height, self.size["width"] / input_width)
det[:, :4] /= ratio
results.append({"scores": det[:, 4], "labels": det[:, 5], "boxes": det[:, :4]})
return results