def postprocess()

in optimum/amd/ryzenai/models/yolox/image_processing_yolox.py [0:0]


def postprocess(outputs, img_size, strides):
    """Decode raw per-level YOLOX head outputs into image-space predictions.

    Every tensor in ``outputs`` keeps its first two dimensions. All remaining
    dimensions are flattened, and the last two axes are swapped, giving one
    row per spatial cell. The levels are then concatenated along the cell axis
    into a freshly allocated tensor, so the input tensors are left untouched.

    The channels of each row are decoded as follows:

    * channels 0-1: ``(value + cell_xy) * stride``. This treats them as offsets
      relative to the grid cell.
    * channels 2-3: ``exp(value) * stride``. This treats them as log-scale sizes.
    * channels 4+: ``sigmoid(value)``. These are presumably objectness and class
      logits; TODO confirm against the model head.

    Args:
        outputs: Sequence of tensors, one per feature level, in the same order
            as ``strides``. Presumably each is laid out as (batch, channels, H, W);
            TODO confirm.
        img_size: Indexable pair. ``img_size[0]`` is used as the height and
            ``img_size[1]`` as the width when sizing each level's grid
            (``size // stride``).
        strides: 1-D tensor of per-level strides. Its ``device`` and ``dtype``
            are reused for the grid and stride tensors built here.

    Returns:
        Tensor of shape (dim0, total_cells, channels) with the decoded values.
        Within each level, cells are ordered row by row, with x varying fastest.
    """
    device, dtype = strides.device, strides.dtype

    # One row per spatial cell: (d0, d1, ...) -> (d0, d1, cells) -> (d0, cells, d1).
    flattened = [level.reshape(*level.shape[:2], -1).transpose(2, 1) for level in outputs]
    decoded = torch.cat(flattened, dim=1)
    decoded[..., 4:] = decoded[..., 4:].sigmoid()

    grid_chunks = []
    stride_chunks = []
    for stride in strides:
        grid_h = img_size[0] // stride
        grid_w = img_size[1] // stride
        # "xy" indexing yields (grid_h, grid_w) tensors where cols holds the x index
        # and rows holds the y index. Stacking (x, y) and flattening row-major puts
        # x fastest, matching the cell order produced by the reshape above.
        cols, rows = torch.meshgrid(
            torch.arange(grid_w, device=device, dtype=dtype),
            torch.arange(grid_h, device=device, dtype=dtype),
            indexing="xy",
        )
        cell_xy = torch.stack((cols, rows), 2).reshape(1, -1, 2)
        grid_chunks.append(cell_xy)
        # Per-cell stride column, broadcast later against the (x, y) and (w, h) channel pairs.
        stride_chunks.append(
            torch.full((*cell_xy.shape[:2], 1), stride, dtype=dtype, device=device)
        )

    grid_xy = torch.cat(grid_chunks, 1)
    stride_map = torch.cat(stride_chunks, 1)
    decoded[..., :2] = (decoded[..., :2] + grid_xy) * stride_map
    decoded[..., 2:4] = torch.exp(decoded[..., 2:4]) * stride_map

    return decoded