optimum/amd/ryzenai/models/yolov8/image_processing_yolov8.py (138 lines of code) (raw):

# Copyright 2023 The HuggingFace Team. All rights reserved. # Licensed under the MIT License. from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch from transformers.image_processing_utils import BaseImageProcessor, BatchFeature from transformers.image_transforms import ( rescale, ) from transformers.image_utils import ( ChannelDimension, ImageInput, infer_channel_dimension_format, make_list_of_images, to_numpy_array, ) from transformers.utils import TensorType from ..detection_utils import non_max_suppression, scale_coords from ..image_transforms import letterbox_image def make_anchor(input, ny, nx, grid_cell_offset=0.5): t, d = input.dtype, input.device y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t) y, x = torch.meshgrid(y + grid_cell_offset, x + grid_cell_offset, indexing="ij") return torch.stack((x, y), -1).view(-1, 2) def dfl(x, c1=16): b, c, a = x.shape weights = torch.arange(c1, dtype=torch.float).view(1, c1, 1, 1) inter = x.view(b, 4, c1, a).transpose(2, 1).softmax(1) return (inter * weights).sum(dim=1, keepdim=True).view(b, 4, a) def postprocess(inputs, reg_max=16, num_classes=80, stride=[8, 16, 32]): nl = len(stride) no = num_classes + reg_max * 4 box, cls = torch.cat([xi.view(inputs[0].shape[0], no, -1) for xi in inputs], 2).split( (reg_max * 4, num_classes), 1 ) distance = dfl(box).chunk(2, 1) anchors, strides = [], [] for i in range(nl): _, _, ny, nx = inputs[i].shape anchor = make_anchor(inputs[i], ny, nx) ustride = torch.full((ny * nx, 1), stride[i], dtype=inputs[i].dtype, device=inputs[i].device) anchors.append(anchor) strides.append(ustride) anchors = torch.cat(anchors).transpose(0, 1).unsqueeze(0) strides = torch.cat(strides).transpose(0, 1) distance = dfl(box).chunk(2, 1) x1_y1 = anchors - distance[0] x2_y2 = anchors + distance[1] dbox = torch.cat(((x2_y2 + x1_y1) / 2, x2_y2 - x1_y1), dim=1) * strides y = torch.cat((dbox, cls.sigmoid()), 1) return y class YoloV8ImageProcessor(BaseImageProcessor): model_input_names = ["pixel_values"] def __init__( self, size: Dict[str, int] = None, rescale_factor: Union[int, float] = 1 / 255.0, num_classes: int = 80, stride: List[int] = [8, 16, 32], reg_max: int = 16, **kwargs, ): size = size if size is not None else {"height": 640, "width": 640} super().__init__(**kwargs) self.size = size self.data_format = ChannelDimension.LAST self.rescale_factor = rescale_factor self.num_classes = num_classes self.stride = stride self.reg_max = reg_max def preprocess( self, images: ImageInput, return_tensors: Optional[Union[TensorType, str]] = None, data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> BatchFeature: data_format = data_format if data_format is not None else self.data_format self.data_format = data_format images = make_list_of_images(images) # All transformations expect numpy arrays images = [to_numpy_array(image) for image in images] # We assume that all images have the same channel dimension format. input_data_format = infer_channel_dimension_format(images[0]) preprocessed_images = [] target_sizes = [] for image in images: if input_data_format == ChannelDimension.FIRST: image = image.transpose((2, 0, 1)) input_data_format = ChannelDimension.LAST target_sizes.append(image.shape) image = letterbox_image( image, [self.size["height"], self.size["width"]], input_data_format=input_data_format, ) image = np.ascontiguousarray(image, dtype=np.float32) image = rescale( image=image, scale=self.rescale_factor, data_format=data_format, input_data_format=input_data_format ) preprocessed_images.append(image) data = {"pixel_values": preprocessed_images, "target_sizes": target_sizes} encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) return encoded_inputs def post_process_object_detection( self, outputs, threshold: float = 0.25, nms_threshold: float = 0.7, target_sizes: Union[TensorType, List[Tuple]] = None, agnostic_nms=False, merge_nms=False, max_detections=1000, data_format: Union[str, ChannelDimension] = None, ): data_format = data_format if data_format is not None else self.data_format if merge_nms: raise ValueError("Merge NMS is not yet supported!") outputs = list(outputs.values()) if not isinstance(outputs[0], torch.Tensor): outputs = [torch.tensor(out) for out in outputs] if data_format == ChannelDimension.LAST: outputs = [torch.permute(out, (0, 3, 1, 2)) for out in outputs] predictions = postprocess(outputs, num_classes=self.num_classes, reg_max=self.reg_max, stride=self.stride) dets = non_max_suppression( predictions.transpose(2, 1), threshold, nms_threshold, agnostic=agnostic_nms, class_conf_start_index=4, max_detections=max_detections, ) results = [] for i, det in enumerate(dets): if target_sizes is not None: det[:, :4] = scale_coords( (self.size["height"], self.size["width"]), target_sizes[i], det[:, :4], ).round() results.append({"scores": det[:, 4], "labels": det[:, 5], "boxes": det[:, :4]}) return results