def postprocess()

in optimum/amd/ryzenai/models/yolov8/image_processing_yolov8.py [0:0]


def postprocess(inputs, reg_max=16, num_classes=80, stride=(8, 16, 32)):
    """Decode raw YOLOv8 detection-head outputs into boxes and class scores.

    Args:
        inputs: Sequence of per-level head outputs, one per entry in ``stride``.
            Each is unpacked as a 4-D tensor ``(batch, channels, ny, nx)``
            (presumably channel-first - TODO confirm) and is flattened to
            ``(batch, num_classes + 4 * reg_max, ny * nx)``.
        reg_max: Controls the box channel count: the first ``4 * reg_max``
            channels of each level are passed to ``dfl``.
        num_classes: Number of class-logit channels following the box channels.
        stride: Per-level stride values. Decoded boxes of each level are
            multiplied by that level's stride (presumably mapping grid units to
            input-image pixels). Any sequence works; a tuple is used as the
            default to avoid a shared mutable default argument.

    Returns:
        Tensor that concatenates, along dim 1, the decoded boxes
        ``(center, size)`` scaled by stride and ``sigmoid`` of the class logits.
        Presumably shaped ``(batch, 4 + num_classes, total_anchors)`` - depends
        on what ``dfl`` and ``make_anchor`` return; TODO confirm.

    Raises:
        ValueError: If the number of input levels differs from ``len(stride)``.
    """
    # Anchors/strides are built per level while the box/class tensors cover
    # every input, so a count mismatch would otherwise surface later as an
    # opaque shape error.
    if len(inputs) != len(stride):
        raise ValueError(
            f"expected one stride per input level, got {len(inputs)} inputs "
            f"and {len(stride)} strides"
        )

    no = num_classes + reg_max * 4
    batch = inputs[0].shape[0]

    # Flatten each level to (batch, no, ny*nx), join levels along the last axis,
    # then split channels into box-distribution and class-logit parts.
    # reshape (rather than view) also accepts non-contiguous inputs.
    box, cls = torch.cat([xi.reshape(batch, no, -1) for xi in inputs], 2).split(
        (reg_max * 4, num_classes), 1
    )

    anchors, strides = [], []
    for xi, level_stride in zip(inputs, stride):
        _, _, ny, nx = xi.shape
        anchors.append(make_anchor(xi, ny, nx))
        # One stride value per grid cell of this level, matching its dtype/device.
        strides.append(
            torch.full((ny * nx, 1), level_stride, dtype=xi.dtype, device=xi.device)
        )

    anchors = torch.cat(anchors).transpose(0, 1).unsqueeze(0)
    strides = torch.cat(strides).transpose(0, 1)

    # First half of the dfl output is subtracted from the anchors, second half
    # added, giving the two box corners.
    lt, rb = dfl(box).chunk(2, 1)
    x1_y1 = anchors - lt
    x2_y2 = anchors + rb

    # Convert corners to (center, size) and scale by the per-anchor stride.
    dbox = torch.cat(((x2_y2 + x1_y1) / 2, x2_y2 - x1_y1), dim=1) * strides

    return torch.cat((dbox, cls.sigmoid()), 1)