optimum/amd/ryzenai/models/yolox/image_processing_yolox.py

# Copyright 2023 The HuggingFace Team. All rights reserved.
# Licensed under the MIT License.

from typing import Iterable, List, Optional, Tuple, Union

import cv2
import numpy as np
import torch
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import (
    PaddingMode,
    flip_channel_order,
    pad,
)
from transformers.image_utils import (
    ChannelDimension,
    ImageInput,
    get_image_size,
    infer_channel_dimension_format,
    make_list_of_images,
    to_numpy_array,
)
from transformers.utils import TensorType

from ..detection_utils import non_max_suppression


def postprocess(outputs, img_size, strides):
    """Decode raw per-stride YOLOX head outputs into (cx, cy, w, h, obj, cls...) predictions in input-image pixels."""
    grids = []
    expanded_strides = []

    device = strides.device
    dtype = strides.dtype

    outputs = [out.reshape(*out.shape[:2], -1).transpose(2, 1) for out in outputs]
    outputs = torch.cat(outputs, axis=1)
    outputs[..., 4:] = outputs[..., 4:].sigmoid()

    hsizes = [img_size[0] // stride for stride in strides]
    wsizes = [img_size[1] // stride for stride in strides]

    for hsize, wsize, stride in zip(hsizes, wsizes, strides):
        xv, yv = torch.meshgrid(
            torch.arange(wsize, device=device, dtype=dtype),
            torch.arange(hsize, device=device, dtype=dtype),
            indexing="xy",
        )
        grid = torch.stack((xv, yv), 2).reshape(1, -1, 2)
        grids.append(grid)
        shape = grid.shape[:2]
        expanded_strides.append(torch.full((*shape, 1), stride, dtype=dtype, device=device))

    grids = torch.cat(grids, 1)
    expanded_strides = torch.cat(expanded_strides, 1)
    outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
    outputs[..., 2:4] = torch.exp(outputs[..., 2:4]) * expanded_strides

    return outputs


class YoloXImageProcessor(BaseImageProcessor):
    model_input_names = ["pixel_values"]

    def __init__(
        self,
        size=None,
        stride: List[int] = [8, 16, 32],
        **kwargs,
    ):
        size = size if size is not None else {"height": 640, "width": 640}

        super().__init__(**kwargs)
        self.size = size
        self.resample = cv2.INTER_LINEAR
        self.data_format = ChannelDimension.LAST
        self.stride = stride

    def resize(
        self,
        image: np.ndarray,
        size: Tuple[int, int],
        resample=cv2.INTER_LINEAR,
    ) -> np.ndarray:
        image = cv2.resize(
            image,
            (size[1], size[0]),
            interpolation=resample,
        ).astype(np.uint8)

        return image

    # Copied from transformers.models.detr.image_processing_detr.DetrImageProcessor._pad_image
    def pad(
        self,
        image: np.ndarray,
        output_size: Tuple[int, int],
        constant_values: Union[float, Iterable[float]] = 114,
        data_format: Optional[ChannelDimension] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)
        output_height, output_width = output_size

        pad_bottom = output_height - input_height
        pad_right = output_width - input_width
        padding = ((0, pad_bottom), (0, pad_right))
        padded_image = pad(
            image,
            padding,
            mode=PaddingMode.CONSTANT,
            constant_values=constant_values,
            data_format=data_format,
            input_data_format=input_data_format,
        )

        return padded_image

    def preprocess(
        self,
        images: ImageInput,
        return_tensors: Optional[Union[TensorType, str]] = None,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> BatchFeature:
        data_format = data_format if data_format is not None else self.data_format
        self.data_format = data_format

        images = make_list_of_images(images)

        # All transformations expect numpy arrays
        images = [to_numpy_array(image) for image in images]

        # We assume that all images have the same channel dimension format.
        target_sizes = []
        padded_images = []
        for image in images:
            input_data_format = infer_channel_dimension_format(images[0])
            # Swap the channel order (RGB <-> BGR)
            image = flip_channel_order(image, input_data_format=input_data_format)

            input_height, input_width = get_image_size(image, channel_dim=input_data_format)
            target_sizes.append(image.shape)

            # Aspect-ratio-preserving resize, then bottom/right pad with 114 to the model input size
            ratio = min(self.size["height"] / input_height, self.size["width"] / input_width)
            size = (int(ratio * input_height), int(ratio * input_width))
            resized_image = self.resize(image, size=size, resample=self.resample)

            padded_img = self.pad(
                resized_image,
                output_size=(self.size["height"], self.size["width"]),
                data_format=data_format,
                input_data_format=input_data_format,
            )
            padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
            padded_images.append(padded_img)

        data = {"pixel_values": padded_images, "target_sizes": target_sizes}
        encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)

        return encoded_inputs

    def post_process_object_detection(
        self,
        outputs,
        threshold: float = 0.1,
        nms_threshold: float = 0.45,
        target_sizes: Union[TensorType, List[Tuple]] = None,
        agnostic_nms=True,
        merge_nms=False,
        max_detections=1000,
        data_format: Union[str, ChannelDimension] = None,
    ):
        data_format = data_format if data_format is not None else self.data_format

        if merge_nms:
            raise ValueError("Merge NMS is not yet supported!")

        outputs = list(outputs.values())
        if not isinstance(outputs[0], torch.Tensor):
            outputs = [torch.Tensor(out) for out in outputs]

        if data_format == ChannelDimension.LAST:
            outputs = [torch.permute(out, (0, 3, 1, 2)) for out in outputs]

        # Decode the raw head outputs into boxes in the padded model-input coordinate space
        predictions = postprocess(outputs, (self.size["height"], self.size["width"]), torch.Tensor(self.stride))

        dets = non_max_suppression(
            predictions,
            threshold,
            nms_threshold,
            agnostic=agnostic_nms,
            max_detections=max_detections,
        )

        results = []
        for i, det in enumerate(dets):
            if target_sizes is not None:
                # Rescale boxes from the padded model input back to the original image size
                input_height, input_width, _ = target_sizes[i]
                ratio = min(self.size["height"] / input_height, self.size["width"] / input_width)
                det[:, :4] /= ratio

            results.append({"scores": det[:, 4], "labels": det[:, 5], "boxes": det[:, :4]})

        return results
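

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original module):
    # it exercises the preprocessing half on a dummy HWC uint8 image. A full
    # detection pipeline would additionally run a YOLOX model on `pixel_values`
    # and pass its raw per-stride outputs, as a dict, to
    # `post_process_object_detection` together with `target_sizes`; the model
    # itself is out of scope here.
    processor = YoloXImageProcessor()
    dummy_image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
    inputs = processor.preprocess(dummy_image, return_tensors="np")
    # The image is resized with its aspect ratio preserved, then padded on the
    # bottom/right with the value 114 up to 640x640 (channels last).
    print(inputs["pixel_values"].shape)  # expected: (1, 640, 640, 3)
    print(inputs["target_sizes"])  # original (height, width, channels) per image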