def generate_pos_hn_example()

in siammot/modelling/track_head/EMM/target_sampler.py [0:0]


    def generate_pos_hn_example(self, proposals, gts):
        """
        Build positive and hard-negative tracking training examples per image.

        For every identity id in an image's source ground truth:

        * positives come from ``self.generate_pos``. Their "pair" boxes are a
          deep copy of the positive source boxes, and their target boxes are
          the identity's target box duplicated once per positive.
        * hard negatives come from ``self.generate_hn_pair``. The identity's
          source and target boxes are duplicated once per hard-negative pair
          box.

        ``self.sample_examples`` then keeps
        ``int(self.proposals_per_image * self.pos_ratio)`` positives and
        ``int(self.proposals_per_image * self.hn_ratio)`` hard negatives.

        Args:
            proposals: one proposal box list per image; each must expose
                ``.bbox``.
            gts: one ground-truth box list per image carrying an ``'ids'``
                field. Target-frame ground truths are derived from a deep copy
                of ``gts`` via ``self.track_utils.swap_pairs``.

        Returns:
            ``(track_source, track_pair, track_target)``: three lists holding
            one concatenated box list per image, with the sampled positives
            placed before the sampled hard negatives.
        """
        # Both views are built from deep copies of the incoming ground truths.
        source_gts = copy.deepcopy(gts)
        target_gts = self.track_utils.swap_pairs(copy.deepcopy(gts))

        track_source, track_target, track_pair = [], [], []
        for src_gt, tar_gt, proposal in zip(source_gts, target_gts, proposals):
            # Slot 0 / 1 / 2 of each tuple collects source / pair / target boxes.
            pos_lists = ([], [], [])
            hn_lists = ([], [], [])

            proposal_heights = proposal.bbox[:, 3] - proposal.bbox[:, 1]
            gt_heights = src_gt.bbox[:, 3] - src_gt.bbox[:, 1]
            src_ids = src_gt.get_field('ids')
            tar_ids = tar_gt.get_field('ids')

            for idx, track_id in enumerate(src_ids):
                gt_src_box = src_gt[src_ids == track_id]
                gt_tar_box = self.get_target_box(tar_gt, tar_ids == track_id)

                # Helper calls keep the original order; they may draw random numbers.
                pos_boxes = self.generate_pos(gt_src_box, proposal)
                pos_triplet = (
                    pos_boxes,
                    copy.deepcopy(pos_boxes),
                    self.duplicate_boxlist(gt_tar_box, len(pos_boxes)),
                )

                hn_pair_boxes = self.generate_hn_pair(
                    gt_src_box, proposal, gt_heights[idx], proposal_heights)
                hn_count = len(hn_pair_boxes)
                hn_triplet = (
                    self.duplicate_boxlist(gt_src_box, hn_count),
                    hn_pair_boxes,
                    self.duplicate_boxlist(gt_tar_box, hn_count),
                )

                for bucket, boxes in zip(pos_lists, pos_triplet):
                    bucket.append(boxes)
                for bucket, boxes in zip(hn_lists, hn_triplet):
                    bucket.append(boxes)

            num_pos = int(self.proposals_per_image * self.pos_ratio)
            num_hn = int(self.proposals_per_image * self.hn_ratio)
            pos_sampled = self.sample_examples(*pos_lists, num_pos)
            hn_sampled = self.sample_examples(*hn_lists, num_hn)

            # Merge positives then hard negatives for source, pair and target.
            for slot, collected in enumerate((track_source, track_pair, track_target)):
                collected.append(cat_boxlist([pos_sampled[slot], hn_sampled[slot]]))

        return track_source, track_pair, track_target