def assign_interaction_pairs_parts()

in phosa/global_opt.py [0:0]


    def assign_interaction_pairs_parts(self, verts_person, verts_object):
        """
        Assigns the (person part, object part) pairs that are interacting.

        This is computed separately from the loss function because there are potential
        speed improvements to re-using stale interaction pairs across multiple
        iterations (although not currently being done).

        A part of a person and a part of an object are interacting if their 3D
        bounding boxes overlap:
            * Check if the X-Y bounding boxes overlap by projecting to the image plane
              (with some expansion given by ``self.expansion_parts``, presumably set
              from BBOX_EXPANSION_PARTS), and
            * Check if Z overlaps by thresholding the Z distance (``compute_dist_z``)
              against ``self.thresh``.

        Only part combinations listed in ``self.interaction_map_parts`` (object part ->
        iterable of person parts) are considered.

        Args:
            verts_person (N_p x V_p x 3).
            verts_object (N_o x V_o x 3).

        Returns:
            interaction_pairs_parts:
                List[Tuple(person_index, person_part, object_index, object_part)]
                The same list is stored on ``self.interaction_pairs_parts``. It is
                assigned only after all pairs are computed, so an exception leaves
                the previous value in place rather than a partial list.
        """
        with torch.no_grad():
            bboxes_person = [
                project_bbox(v, self.renderer, self.labels_person, self.expansion_parts)
                for v in verts_person
            ]
            bboxes_object = [
                project_bbox(v, self.renderer, self.labels_object, self.expansion_parts)
                for v in verts_object
            ]
            pairs = []
            for i_p, i_o in itertools.product(
                range(len(verts_person)), range(len(verts_object))
            ):
                for part_object, parts_person in self.interaction_map_parts.items():
                    for part_person in parts_person:
                        bbox_object = bboxes_object[i_o][part_object]
                        bbox_person = bboxes_person[i_p][part_person]
                        # A pair is rejected whenever the projected boxes do not
                        # overlap, so skip the vertex slicing and Z-distance
                        # computation in that case.
                        if not check_overlap(bbox_object, bbox_person):
                            continue
                        z_dist = compute_dist_z(
                            verts_object[i_o][self.labels_object[part_object]],
                            verts_person[i_p][self.labels_person[part_person]],
                        )
                        if z_dist < self.thresh:
                            pairs.append((i_p, part_person, i_o, part_object))
            self.interaction_pairs_parts = pairs
            return self.interaction_pairs_parts