def postprocess()

in optimum/amd/ryzenai/models/yolov3/image_processing_yolov3.py [0:0]


def postprocess(inputs, anchors, num_classes=80, stride=(8, 16, 32), shapes=(80, 40, 20)):
    """Decode raw YOLOv3 head outputs into one flat tensor of per-box predictions.

    For each level, the raw tensor is reshaped to ``(bs, na, ny, nx, no)`` with
    ``no = num_classes + 5``. Its last axis is then decoded as:

    * ``[0:2]`` -> ``(sigmoid(raw) + grid) * stride[level]``
    * ``[2:4]`` -> ``exp(raw) * anchor_grid``
    * ``[4:]``  -> ``sigmoid(raw)``

    ``grid`` and ``anchor_grid`` come from ``make_grid``.

    Args:
        inputs: Sequence (list or tuple) of raw head outputs, one per level.
            Each is a 4-D tensor unpacked as ``(bs, channels, ny, nx)``, where
            ``channels`` must be a multiple of ``num_classes + 5``. Entry ``i``
            is decoded with ``anchors[nl - 1 - i]`` and ``stride[nl - 1 - i]``.
            So ``inputs`` is expected in the reverse order of ``anchors`` /
            ``stride``; with the default ``stride``, level 0 uses stride 32.
            The caller's sequence and tensors are not modified.
        anchors: Per-level anchor definitions. ``len(anchors)`` sets the number
            of levels ``nl`` that are decoded. Each entry is passed unchanged to
            ``make_grid``.
        num_classes: Number of class scores per box.
        stride: Per-level multiplier applied to the decoded xy values.
        shapes: Unused; kept only for backward compatibility of the signature.

    Returns:
        Tensor of shape ``(bs, total_boxes, num_classes + 5)``, concatenating
        all levels along dim 1 in the order of ``inputs``.
    """
    nl = len(anchors)  # number of detection levels to decode
    # Values per box: xy (2) + wh (2) + num_classes + 1 trailing scores
    # (conventionally objectness + class scores -- confirm against the model export).
    no = num_classes + 5

    outputs = []
    for i in range(nl):
        # Heads are consumed in reverse order of `anchors`/`stride`. This was
        # previously hard-coded as `2 - i`, which is only valid for 3 levels.
        level = nl - 1 - i
        bs, _, ny, nx = inputs[i].shape
        grid, anchor_grid = make_grid(anchors[level], nx, ny)

        # Reshape (bs, na * no, ny, nx) -> (bs, na, ny, nx, no).
        # - The anchor count `na` is inferred from the channels rather than
        #   assumed to equal `nl`.
        # - The result is kept in a local so the caller's `inputs` is never
        #   reassigned (this also lets `inputs` be a tuple).
        pred = inputs[i].view(bs, -1, no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

        xy = (torch.sigmoid(pred[..., 0:2]) + grid) * stride[level]
        wh = torch.exp(pred[..., 2:4]) * anchor_grid
        # Out-of-place sigmoid: `.contiguous()` can return a view of the
        # caller's tensor, which an in-place `sigmoid_` would overwrite.
        conf = torch.sigmoid(pred[..., 4:])

        y = torch.cat((xy, wh, conf), -1)
        outputs.append(y.view(bs, -1, no))
    return torch.cat(outputs, 1)