in optimum/amd/ryzenai/models/yolox/image_processing_yolox.py [0:0]
def postprocess(outputs, img_size, strides):
    """Decode multi-level YOLOX head outputs into image-space predictions.

    Each detection row is laid out as ``(cx, cy, w, h, objectness, class...)``:
    centers are predicted as offsets within a grid cell, sizes in log-space
    relative to the level's stride (established by the decode math below).

    Args:
        outputs: per-level raw head tensors; assumed ``(batch, channels, H, W)``
            with ``channels >= 5`` — TODO(review): confirm against the caller.
        img_size: ``(height, width)`` of the network input; each level's grid
            is ``img_size // stride``.
        strides: 1-D tensor of downsampling strides, one per level, ordered to
            match ``outputs``.

    Returns:
        Tensor of shape ``(batch, sum(H_i * W_i), channels)`` with boxes in
        absolute pixel coordinates and sigmoid-activated scores.
    """
    device = strides.device
    dtype = strides.dtype

    # Flatten each level from (B, C, H, W) to (B, H*W, C) and concatenate all
    # levels along the anchor axis.
    flattened = [lvl.reshape(*lvl.shape[:2], -1).transpose(2, 1) for lvl in outputs]
    preds = torch.cat(flattened, dim=1)

    # Channels 4.. hold objectness / class logits; squash to probabilities.
    preds[..., 4:] = preds[..., 4:].sigmoid()

    level_grids = []
    level_strides = []
    for stride in strides:
        hsize = img_size[0] // stride
        wsize = img_size[1] // stride
        # indexing="xy" yields (hsize, wsize) maps with x varying fastest, so
        # the flattened grid enumerates cells row by row as (x, y) pairs.
        xs, ys = torch.meshgrid(
            torch.arange(wsize, device=device, dtype=dtype),
            torch.arange(hsize, device=device, dtype=dtype),
            indexing="xy",
        )
        cells = torch.stack((xs, ys), 2).reshape(1, -1, 2)
        level_grids.append(cells)
        # One stride value per cell, broadcastable against the (1, N, 2) grid.
        level_strides.append(
            torch.full((*cells.shape[:2], 1), stride, dtype=dtype, device=device)
        )

    all_grids = torch.cat(level_grids, dim=1)
    all_strides = torch.cat(level_strides, dim=1)

    # Box centers: in-cell offset + cell index, scaled back to pixels.
    preds[..., :2] = (preds[..., :2] + all_grids) * all_strides
    # Box sizes: log-space prediction, exponentiated and scaled by the stride.
    preds[..., 2:4] = torch.exp(preds[..., 2:4]) * all_strides
    return preds