in models/vision/detection/awsdet/core/anchor/anchor_target.py [0:0]
def _get_targets_single(self,
flat_anchors,
valid_flags,
gt_bboxes,
gt_labels,
img_shape,
unmap_outputs=True):
"""Compute regression and classification targets for anchors in
a single image.
Args:
flat_anchors: Multi-level anchors of the image, which are
concatenated into a single tensor of shape (num_anchors ,4)
valid_flags: Multi level valid flags of the image,
which are concatenated into a single tensor of
shape (num_anchors,).
gt_bboxes: Ground truth bboxes of the image,
shape (num_gts, 4).
gt_labels: Ground truth labels of each box, shape (num_gts,). If not None then assign
these labels to positive anchors
img_shape: shape of the image (unpadded)
unmap_outputs: Whether to map outputs back to the original
set of anchors.
Returns:
target_matches: (num_anchors,) 1 = positive anchor, -1 = negative anchor, 0 = neutral anchor
bboxes_targets: (num_anchors, 4)
bbox_inside_weights: (num_anchors, 4)
bbox_outside_weights: (num_anchors, 4)
"""
gt_bboxes, _ = trim_zeros(gt_bboxes)
# 1. Filter anchors to valid area
inside_flags = self._anchor_inside_flags(flat_anchors, valid_flags, img_shape)
# TODO: handle scenario where all flags are False
anchors = tf.boolean_mask(flat_anchors, inside_flags)
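# keep the total anchor count so targets can be unmapped back to the full anchor set at the end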
num_anchors = tf.shape(flat_anchors)[0]
# 2. Find IoUs
num_valid_anchors = tf.shape(anchors)[0]
target_matches = -tf.ones((num_valid_anchors,), tf.int32)
overlaps = geometry.compute_overlaps(anchors, gt_bboxes)
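# overlaps: (num_valid_anchors, num_gts) IoU matrix between valid anchors and GT boxes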
# a. best GT index for each anchor
argmax_overlaps = tf.argmax(overlaps, axis=1, output_type=tf.int32)
max_overlaps = tf.reduce_max(overlaps, axis=1)
# b. best anchor index for each GT (non-deterministic in case of ties)
gt_argmax_overlaps = tf.argmax(overlaps, axis=0, output_type=tf.int32)
# 3. Assign labels
bg_cond = tf.math.less(max_overlaps, self.neg_iou_thr)
fg_cond = tf.math.greater_equal(max_overlaps, self.pos_iou_thr)
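# anchors with max IoU in [neg_iou_thr, pos_iou_thr) keep label -1 and are ignored by both losses
# (unless rescued below as the best anchor for some GT)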
target_matches = tf.where(bg_cond, tf.zeros_like(target_matches), target_matches)
gt_indices = tf.expand_dims(gt_argmax_overlaps, axis=1)
if gt_labels is None: # RPN will have gt labels set to None
gt_labels = tf.ones(tf.shape(gt_indices)[0], dtype=tf.int32)
target_matches = tf.tensor_scatter_nd_update(target_matches, gt_indices, gt_labels)
target_matches = tf.where(fg_cond, tf.ones_like(target_matches), target_matches)
else:
gt_labels = gt_labels[:tf.shape(gt_indices)[0]] # get rid of padded labels (-1)
target_matches = tf.where(fg_cond, tf.gather(gt_labels, argmax_overlaps), target_matches)
if self.allow_low_quality_matches:
# allow anchors with lower overlap to be matched, generally in earlier proposal stages
# a. find the highest overlap value for each GT (note: this max may be below pos_iou_thr)
gt_max_overlaps = tf.reduce_max(overlaps, axis=0)
# b. assign all anchors that are unassigned but have overlap equal to max overlap for that GT
low_quality_matches = tf.math.equal(tf.expand_dims(gt_max_overlaps, 0), overlaps)
unassigned_anchors = tf.equal(target_matches, -1)
unassigned_matches = tf.where(tf.math.logical_and(tf.expand_dims(unassigned_anchors, -1), low_quality_matches))
unassigned_indices = tf.expand_dims(tf.cast(unassigned_matches[:, 0], tf.int64), -1)
unassigned_labels = tf.gather(gt_labels, unassigned_matches[:, 1])
target_matches = tf.tensor_scatter_nd_update(target_matches, unassigned_indices, unassigned_labels)
# 4. Subsample candidates when there are more than the config allows
# (only if num_samples > 0, e.g. in the two-stage case)
if self.num_samples > 0:
fg_inds = tf.where(tf.equal(target_matches, 1))[:, 0]
max_pos_samples = tf.cast(self.positive_fraction * self.num_samples, tf.int32)
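# randomly keep at most max_pos_samples positives; the surplus is set back to -1 (ignored)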
if tf.greater(tf.size(fg_inds), max_pos_samples):
fg_inds = tf.random.shuffle(fg_inds)
disable_inds = fg_inds[max_pos_samples:]
fg_inds = fg_inds[:max_pos_samples]
disable_inds = tf.expand_dims(disable_inds, axis=1)
disable_labels = -tf.ones(tf.shape(disable_inds)[0], dtype=tf.int32)
target_matches = tf.tensor_scatter_nd_update(target_matches, disable_inds, disable_labels)
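# fill the remaining sample budget with background anchors and disable the surplus negatives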
num_fg = tf.reduce_sum(tf.cast(tf.equal(target_matches, 1), tf.int32))
num_bg = self.num_samples - num_fg
bg_inds = tf.where(tf.equal(target_matches, 0))[:, 0]
if tf.greater(tf.size(bg_inds), num_bg):
bg_inds = tf.random.shuffle(bg_inds)
disable_inds = bg_inds[num_bg:]
bg_inds = bg_inds[:num_bg]
disable_inds = tf.expand_dims(disable_inds, axis=1)
disable_labels = -tf.ones(tf.shape(disable_inds)[0], dtype=tf.int32)
target_matches = tf.tensor_scatter_nd_update(target_matches, disable_inds, disable_labels)
# 5. Calculate deltas for chosen targets based on GT (encode)
bboxes_targets = transforms.bbox2delta(anchors, tf.gather(gt_bboxes, argmax_overlaps),
target_means=self.target_means,
target_stds=self.target_stds)
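# deltas are computed for every valid anchor; the weights below zero out non-positive anchors in the loss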
# Regression weights
bbox_inside_weights = tf.zeros((tf.shape(anchors)[0], 4), dtype=tf.float32)
# use > 0 rather than == 1 so that class-valued positive labels are matched as well
match_indices = tf.where(tf.math.greater(target_matches, 0))
updates = tf.ones([tf.shape(match_indices)[0], 4], bbox_inside_weights.dtype)
bbox_inside_weights = tf.tensor_scatter_nd_update(bbox_inside_weights,
match_indices, updates)
bbox_outside_weights = tf.zeros((tf.shape(anchors)[0], 4), dtype=tf.float32)
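# outside weights spread a uniform 1/num_examples over non-ignored anchors
# (num_examples counts fg+bg in the sampled two-stage case, fg only otherwise)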
if self.num_samples > 0:
num_examples = tf.reduce_sum(tf.cast(target_matches >= 0, bbox_outside_weights.dtype))
else:
num_examples = tf.reduce_sum(tf.cast(target_matches > 0, bbox_outside_weights.dtype))
num_fg = num_examples
num_bg = 0 # in RetinaNet we only care about positive anchors
out_indices = tf.where(target_matches >= 0)
updates = tf.ones([tf.shape(out_indices)[0], 4], bbox_outside_weights.dtype) * 1.0 / num_examples
bbox_outside_weights = tf.tensor_scatter_nd_update(bbox_outside_weights,
out_indices, updates)
# map results back to the full anchor set; anchors outside the valid area get the fill value (-1 for labels, 0 otherwise)
selected_anchor_idx = tf.where(inside_flags)[:, 0]
return (tf.stop_gradient(_unmap(target_matches, num_anchors, selected_anchor_idx, -1)),
tf.stop_gradient(_unmap(bboxes_targets, num_anchors, selected_anchor_idx, 0)),
tf.stop_gradient(_unmap(bbox_inside_weights, num_anchors, selected_anchor_idx, 0)),
tf.stop_gradient(_unmap(bbox_outside_weights, num_anchors, selected_anchor_idx, 0)),
num_fg, num_bg)