def assign_human_masks()

in phosa/global_opt.py [0:0]


    def assign_human_masks(self, masks_human=None, min_overlap=0.5):
        """
        Uses a greedy matching algorithm to assign masks to human instances. The
        assigned human masks are used to compute the ordinal depth loss.

        If the human predictor uses the same instances as the segmentation algorithm,
        then this greedy assignment is unnecessary as the human instances will already
        have corresponding masks.

        1. Compute IOU between all human silhouettes and human masks
        2. Sort IOUs
        3. Assign people to masks in order, skipping people and masks that
            have already been assigned.

        Args:
            masks_human: Human bitmask tensor from instance segmentation algorithm,
                indexed by mask along dim 0 and combined with the rendered person
                silhouettes via ``&``/``|`` (presumably bool with shape
                (N_m, IMAGE_SIZE, IMAGE_SIZE) on the renderer's device -- confirm
                against caller). If None, no person receives a mask.
            min_overlap (float): Minimum IOU threshold to assign the human mask to a
                human instance.

        Returns:
            Bool CUDA tensor of shape (N_h, IMAGE_SIZE, IMAGE_SIZE), where N_h is the
            number of person instances returned by ``get_verts_person``. Row i holds
            the mask assigned to person i, or is all False if no mask matched it with
            IOU >= min_overlap.
        """
        verts_person = self.get_verts_person()
        num_people = verts_person.shape[0]
        if masks_human is None or num_people == 0:
            # Nothing to match: every person gets an empty mask. Bool is used on
            # every return path so the dtype never depends on the inputs.
            return torch.zeros(
                num_people, IMAGE_SIZE, IMAGE_SIZE, dtype=torch.bool
            ).cuda()

        f = self.faces_person
        # Silhouettes are only used for IOU matching and are cast to bool, so no
        # autograd graph is needed for these renders.
        with torch.no_grad():
            person_silhouettes = torch.cat(
                [
                    self.renderer(v.unsqueeze(0), f, mode="silhouettes")
                    for v in verts_person
                ]
            ).bool()

        # Pairwise IOU via broadcasting; iou[i, j] compares person i with mask j.
        intersection = masks_human.unsqueeze(0) & person_silhouettes.unsqueeze(1)
        union = masks_human.unsqueeze(0) | person_silhouettes.unsqueeze(1)
        # Clamp the union so an empty silhouette paired with an empty mask gives
        # IOU 0 instead of NaN (NaN compares False against min_overlap and would
        # otherwise slip past the threshold check below).
        iou = (
            intersection.sum(dim=(2, 3)).float()
            / union.sum(dim=(2, 3)).clamp(min=1).float()
        )
        iou = iou.cpu().numpy()

        # (person, mask) index pairs ordered from highest to lowest IOU.
        # https://stackoverflow.com/questions/30577375
        order = np.argsort(-iou.ravel())
        best_indices = np.column_stack(np.unravel_index(order, iou.shape))

        # If no match found, mask will just be empty, incurring 0 loss for depth.
        # Allocate on the masks' device to avoid a cross-device copy per assignment.
        human_masks = torch.zeros(
            num_people,
            IMAGE_SIZE,
            IMAGE_SIZE,
            dtype=torch.bool,
            device=masks_human.device,
        )
        human_indices_used = set()
        mask_indices_used = set()
        for human_index, mask_index in best_indices:
            if iou[human_index, mask_index] < min_overlap:
                # Pairs are sorted by descending IOU, so no later pair can pass.
                break
            if human_index in human_indices_used or mask_index in mask_indices_used:
                continue
            human_masks[human_index] = masks_human[mask_index]
            human_indices_used.add(human_index)
            mask_indices_used.add(mask_index)
        return human_masks.cuda()