def _get_targets_single()

in models/vision/detection/awsdet/core/anchor/anchor_target.py


    def _get_targets_single(self,
                            flat_anchors,
                            valid_flags,
                            gt_bboxes,
                            gt_labels,
                            img_shape,
                            unmap_outputs=True):
        """Compute regression and classification targets for anchors in
            a single image.
        Args:
            flat_anchors: Multi-level anchors of the image, which are
                concatenated into a single tensor of shape (num_anchors ,4)
            valid_flags: Multi level valid flags of the image,
                which are concatenated into a single tensor of
                    shape (num_anchors,).
            gt_bboxes: Ground truth bboxes of the image,
                shape (num_gts, 4).
            gt_labels: Ground truth labels of each box, shape (num_gts,). If not None then assign
                these labels to positive anchors
            img_shape: shape of the image (unpadded)
            unmap_outputs: Whether to map outputs back to the original
                set of anchors.
        Returns:
            target_matches: (num_anchors,) 1 = positive anchor, -1 = negative anchor, 0 = neutral anchor 
            bboxes_targets: (num_anchors, 4)
            bbox_inside_weights: (num_anchors, 4)
            bbox_outside_weights: (num_anchors, 4)
        """
        gt_bboxes, _ = trim_zeros(gt_bboxes)  # drop zero-padded GT rows
        # 1. Filter anchors to valid area
        inside_flags = self._anchor_inside_flags(flat_anchors, valid_flags, img_shape)
        # TODO: handle scenario where all flags are False
        anchors = tf.boolean_mask(flat_anchors, inside_flags)
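        # num_anchors is the full (pre-filter) count, needed to unmap results at the end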
        num_anchors = tf.shape(flat_anchors)[0]

        # 2. Find IoUs
        num_valid_anchors = tf.shape(anchors)[0]
        target_matches = -tf.ones((num_valid_anchors,), tf.int32)  # start all anchors as ignored (-1)
        overlaps = geometry.compute_overlaps(anchors, gt_bboxes)
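        # overlaps: IoU matrix of shape (num_valid_anchors, num_gts)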
        # a. best GT index for each anchor
        argmax_overlaps = tf.argmax(overlaps, axis=1, output_type=tf.int32)
        max_overlaps = tf.reduce_max(overlaps, axis=1)
        # b. best anchor index for each GT (non deterministic in case of ties)
        gt_argmax_overlaps = tf.argmax(overlaps, axis=0, output_type=tf.int32)

        # 3. Assign labels
        bg_cond = tf.math.less(max_overlaps, self.neg_iou_thr)
        fg_cond = tf.math.greater_equal(max_overlaps, self.pos_iou_thr)
        target_matches = tf.where(bg_cond, tf.zeros_like(target_matches), target_matches)
        gt_indices = tf.expand_dims(gt_argmax_overlaps, axis=1)
        if gt_labels is None: # RPN will have gt labels set to None
            gt_labels = tf.ones(tf.shape(gt_indices)[0], dtype=tf.int32)
            target_matches = tf.tensor_scatter_nd_update(target_matches, gt_indices, gt_labels) 
            target_matches = tf.where(fg_cond, tf.ones_like(target_matches), target_matches)
        else:
            gt_labels = gt_labels[:tf.shape(gt_indices)[0]] # get rid of padded labels (-1)
            target_matches = tf.where(fg_cond, tf.gather(gt_labels, argmax_overlaps), target_matches)
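            # positives take the class label of their best-matching GT,
            # so target_matches > 0 marks a positive anchor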

        if self.allow_low_quality_matches:
            # we allow lower-overlap anchors, generally in earlier proposal stages
            # a. find the highest overlap value for each GT (note this max may be lower than pos_iou_thr)
            gt_max_overlaps = tf.reduce_max(overlaps, axis=0)
            # b. assign all anchors that are unassigned but whose overlap equals the max overlap for that GT
            low_quality_matches = tf.math.equal(tf.expand_dims(gt_max_overlaps, 0), overlaps)
            unassigned_anchors = tf.equal(target_matches, -1)
            unassigned_matches = tf.where(tf.math.logical_and(tf.expand_dims(unassigned_anchors, -1), low_quality_matches))
            unassigned_indices = tf.expand_dims(unassigned_matches[:, 0], -1)
            unassigned_labels = tf.gather(gt_labels, unassigned_matches[:, 1])
            target_matches = tf.tensor_scatter_nd_update(target_matches, unassigned_indices, unassigned_labels)
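            # in the RPN case gt_labels was replaced by ones above, so
            # low-quality matches are simply marked positive (1)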


        # 4. Subsample if there are more candidates than the config requires
        #    (only if num_samples > 0, e.g. in two-stage detectors)
        if self.num_samples > 0:
            fg_inds = tf.where(tf.equal(target_matches, 1))[:, 0]
            max_pos_samples = tf.cast(self.positive_fraction * self.num_samples, tf.int32)
            if tf.greater(tf.size(fg_inds), max_pos_samples):
                fg_inds = tf.random.shuffle(fg_inds)
                disable_inds = fg_inds[max_pos_samples:]
                fg_inds = fg_inds[:max_pos_samples]
                disable_inds = tf.expand_dims(disable_inds, axis=1)
                disable_labels = -tf.ones(tf.shape(disable_inds)[0], dtype=tf.int32)
                target_matches = tf.tensor_scatter_nd_update(target_matches, disable_inds, disable_labels)
            num_fg = tf.reduce_sum(tf.cast(tf.equal(target_matches, 1), tf.int32))
            num_bg = self.num_samples - num_fg 
            bg_inds = tf.where(tf.equal(target_matches, 0))[:, 0]
            if tf.greater(tf.size(bg_inds), num_bg):
                bg_inds = tf.random.shuffle(bg_inds)
                disable_inds = bg_inds[num_bg:]
                bg_inds = bg_inds[:num_bg]
                disable_inds = tf.expand_dims(disable_inds, axis=1)
                disable_labels = -tf.ones(tf.shape(disable_inds)[0], dtype=tf.int32)
                target_matches = tf.tensor_scatter_nd_update(target_matches, disable_inds, disable_labels)

        # 5. Calculate deltas for chosen targets based on GT (encode)
        bboxes_targets = transforms.bbox2delta(anchors, tf.gather(gt_bboxes, argmax_overlaps),
                                               target_means=self.target_means,
                                               target_stds=self.target_stds)
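        # bbox2delta presumably encodes the standard R-CNN box parameterization
        # (an assumption about transforms.bbox2delta, stated here for reference):
        #   dx = (gx - ax) / aw,  dy = (gy - ay) / ah,
        #   dw = log(gw / aw),    dh = log(gh / ah)
        # followed by normalization: (delta - target_means) / target_stds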

        # Regression weights: inside weights gate the loss so that only
        # positive anchors contribute to the regression term
        bbox_inside_weights = tf.zeros((tf.shape(anchors)[0], 4), dtype=tf.float32)
        # use > 0 (not == 1) since positives carry class labels in the multi-class case
        match_indices = tf.where(tf.math.greater(target_matches, 0))
        updates = tf.ones([tf.shape(match_indices)[0], 4], bbox_inside_weights.dtype)
        bbox_inside_weights = tf.tensor_scatter_nd_update(bbox_inside_weights,
                                                          match_indices, updates)

        # Outside weights normalize the regression loss: by the total number of
        # sampled examples (two-stage) or by the number of positives (single-stage)
        bbox_outside_weights = tf.zeros((tf.shape(anchors)[0], 4), dtype=tf.float32)
        if self.num_samples > 0:
            num_examples = tf.reduce_sum(tf.cast(target_matches >= 0, bbox_outside_weights.dtype))
        else:
            num_examples = tf.reduce_sum(tf.cast(target_matches > 0, bbox_outside_weights.dtype))
            num_fg = num_examples
            num_bg = 0  # in RetinaNet we only care about positive anchors
        # guard against division by zero when no examples are selected
        num_examples = tf.maximum(num_examples, 1.0)
        out_indices = tf.where(target_matches >= 0)
        updates = tf.ones([tf.shape(out_indices)[0], 4], bbox_outside_weights.dtype) / num_examples
        bbox_outside_weights = tf.tensor_scatter_nd_update(bbox_outside_weights,
                                                           out_indices, updates)
        # unmap back to the full anchor set; positions outside the valid
        # region are filled with the given fill value
        selected_anchor_idx = tf.where(inside_flags)[:, 0]
        return (tf.stop_gradient(_unmap(target_matches, num_anchors, selected_anchor_idx, -1)),
                tf.stop_gradient(_unmap(bboxes_targets, num_anchors, selected_anchor_idx, 0)),
                tf.stop_gradient(_unmap(bbox_inside_weights, num_anchors, selected_anchor_idx, 0)),
                tf.stop_gradient(_unmap(bbox_outside_weights, num_anchors, selected_anchor_idx, 0)),
                num_fg, num_bg)
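
A minimal sketch of how this method might be exercised in isolation. The `anchor_target` instance, the toy tensors, and the config values below are illustrative assumptions, not taken from the repo:

# hypothetical smoke test, eager TF2
import tensorflow as tf

flat_anchors = tf.constant([[0., 0., 10., 10.],
                            [5., 5., 15., 15.],
                            [20., 20., 30., 30.],
                            [-5., -5., 5., 5.]], tf.float32)
valid_flags = tf.constant([1, 1, 1, 1], tf.int32)
gt_bboxes = tf.constant([[0., 0., 9., 9.],
                         [19., 19., 31., 31.]], tf.float32)
gt_labels = None      # RPN-style: positives are labeled 1
img_shape = (32, 32)

# assuming `anchor_target` is an AnchorTarget-style instance configured with
# pos_iou_thr, neg_iou_thr, num_samples, positive_fraction, etc.
(matches, deltas, inside_w, outside_w,
 num_fg, num_bg) = anchor_target._get_targets_single(
    flat_anchors, valid_flags, gt_bboxes, gt_labels, img_shape)
# matches: (4,) with >0 positive, 0 negative, -1 ignored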