in optimum/amd/ryzenai/models/yolov8/image_processing_yolov8.py [0:0]
def postprocess(inputs, reg_max=16, num_classes=80, stride=(8, 16, 32)):
    """Decode per-level detection-head outputs into boxes and class scores.

    Args:
        inputs: Sequence of per-level head output tensors, one per entry of
            ``stride`` and in the same order. Each is unpacked as
            ``(batch, channels, ny, nx)`` and is viewed as
            ``(batch, num_classes + 4 * reg_max, -1)``.
        reg_max: Number of regression channels per box side; the first
            ``4 * reg_max`` channels are treated as box data.
        num_classes: Number of class-logit channels following the box channels.
        stride: Per-level stride values; each level's decoded boxes are
            multiplied by its stride. Any sequence is accepted.

    Returns:
        Tensor of shape ``(batch, 4 + num_classes, N)`` where ``N`` is the
        total number of grid cells over all levels. Rows 0-3 hold
        ``((x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1)`` scaled by the
        cell's stride. The remaining rows are sigmoid class scores.

    Raises:
        ValueError: If ``inputs`` and ``stride`` have different lengths.
    """
    if len(inputs) != len(stride):
        raise ValueError(
            f"Expected one input per stride level, got {len(inputs)} inputs "
            f"and {len(stride)} strides"
        )

    batch = inputs[0].shape[0]
    no = num_classes + reg_max * 4  # channels per grid cell: box bins + class logits
    # Flatten every level's grid, concatenate along the cell axis, then split
    # channels into box data and class logits.
    box, cls = torch.cat([xi.view(batch, no, -1) for xi in inputs], 2).split(
        (reg_max * 4, num_classes), 1
    )

    anchors, strides = [], []
    for xi, level_stride in zip(inputs, stride):
        _, _, ny, nx = xi.shape
        # NOTE(review): make_anchor is defined elsewhere; presumably it returns
        # (ny * nx, 2) anchor coordinates for this level — confirm.
        anchors.append(make_anchor(xi, ny, nx))
        # One stride value per grid cell so boxes can be rescaled per level.
        strides.append(
            torch.full((ny * nx, 1), level_stride, dtype=xi.dtype, device=xi.device)
        )
    anchors = torch.cat(anchors).transpose(0, 1).unsqueeze(0)
    strides = torch.cat(strides).transpose(0, 1)

    # NOTE(review): dfl is defined elsewhere; presumably it reduces the
    # per-side distributions to 4 distances per cell — confirm. It is
    # evaluated exactly once here.
    lt, rb = dfl(box).chunk(2, 1)
    x1_y1 = anchors - lt
    x2_y2 = anchors + rb
    # Convert corners to (center, size) and rescale by each cell's stride.
    dbox = torch.cat(((x2_y2 + x1_y1) / 2, x2_y2 - x1_y1), dim=1) * strides
    return torch.cat((dbox, cls.sigmoid()), 1)