in lib/roi_data/fast_rcnn_rel.py [0:0]
def _low_shot_fg_inds(fg_inds, gt_labels, gt_assignment, low_shot_helper,
                      label):
    """Return the entries of ``fg_inds`` whose assigned gt class is low-shot.

    Args:
        fg_inds: indices (into the ROI array) of foreground ROIs.
        gt_labels: per-gt-box labels. ``label - 1`` is what gets passed to
            ``low_shot_helper`` (presumably converting to a class id without
            the background slot -- TODO confirm).
        gt_assignment: for every ROI, the index of its best-overlapping gt box.
        low_shot_helper: object exposing ``check_low_shot_s`` and
            ``check_low_shot_o``, each taking a 3-element list.
        label: ``'sbj'`` or ``'obj'``; selects which check is applied.

    Returns:
        An int32 array (possibly empty) holding the selected entries of
        ``fg_inds``, in their original order.

    Raises:
        NotImplementedError: if ``label`` is neither ``'sbj'`` nor ``'obj'``.
    """
    pos_labels = gt_labels[gt_assignment[fg_inds]] - 1
    if label == 'sbj':
        def is_low_shot(cls):
            return low_shot_helper.check_low_shot_s([cls, -1, -1])
    elif label == 'obj':
        def is_low_shot(cls):
            return low_shot_helper.check_low_shot_o([-1, -1, cls])
    else:
        raise NotImplementedError(
            'Unsupported branch label: {}'.format(label))
    return np.array(
        [ind for ind, cls in zip(fg_inds, pos_labels) if is_low_shot(cls)],
        dtype=np.int32)


def _sample_rois_pos_neg_for_one_branch(
        all_rois, gt_boxes, gt_labels, gt_vecs, low_shot_helper, label):
    """Sample foreground and background ROIs for one branch ('sbj' or 'obj').

    Args:
        all_rois: candidate ROIs. Columns 1:5 are used as boxes (column 0 is
            presumably the batch index -- TODO confirm).
        gt_boxes: ground-truth boxes; columns 0:4 are used.
        gt_labels: label of every gt box.
        gt_vecs: per-gt-box vectors; rows are gathered for the fg ROIs.
        low_shot_helper: used only when an oversampling option is enabled.
        label: ``'sbj'`` or ``'obj'``; only consulted when
            ``cfg.TRAIN.OVERSAMPLE_SO2`` or ``cfg.TRAIN.OVERSAMPLE_SO`` is set.

    Returns:
        ``(rois, pos_vecs, all_labels, neg_affinity_mask, pos_affinity_mask)``,
        plus ``(low_shot_ends, regular_starts)`` when
        ``cfg.TRAIN.OVERSAMPLE_SO`` is set.

        - ``rois``: the kept rows of ``all_rois``, fg first and then bg.
        - ``pos_vecs``: the ``gt_vecs`` rows assigned to the fg ROIs.
        - ``all_labels``: float32 labels, with the assigned gt label for fg
          ROIs and 0 for bg ROIs.
        - ``neg_affinity_mask``: float32 array of shape (num_fg, num_kept),
          where entry [i, j] is 1 iff fg ROI i and kept ROI j have different
          labels.
        - ``pos_affinity_mask``: float32 array of shape (num_fg, num_fg),
          where entry [i, j] is 1 iff fg ROIs i and j share a label.

    Raises:
        NotImplementedError: if an oversampling option is enabled and
            ``label`` is not 'sbj' or 'obj'.
    """
    rois_per_image = int(cfg.TRAIN.BATCH_SIZE_PER_IM)
    fg_rois_per_image = int(
        np.round(cfg.TRAIN.FG_FRACTION * rois_per_image))

    overlaps = box_utils.bbox_overlaps(
        all_rois[:, 1:5].astype(dtype=np.float32, copy=False),
        gt_boxes[:, :4].astype(dtype=np.float32, copy=False))
    max_overlaps = overlaps.max(axis=1)
    gt_assignment = overlaps.argmax(axis=1)

    # ROIs that (numerically) coincide with a gt box are always kept as fg.
    gt_inds = np.where(max_overlaps >= 0.99999)[0]
    pos_inds = np.where((max_overlaps >= cfg.TRAIN.FG_THRESH) &
                        (max_overlaps < 0.99999))[0]

    # Fill what is left of the fg quota with non-gt positives. Clamp the
    # quota at 0: with more gt-matching ROIs than fg_rois_per_image, the
    # unclamped size was negative and npr.choice raised ValueError.
    num_pos_to_keep = max(fg_rois_per_image - gt_inds.size, 0)
    if pos_inds.size > num_pos_to_keep:
        pos_inds = npr.choice(pos_inds, size=num_pos_to_keep, replace=False)
    fg_inds = np.append(pos_inds, gt_inds)

    if cfg.TRAIN.OVERSAMPLE_SO2:
        # Duplicate low-shot ROIs so they are more likely to survive the
        # subsampling below.
        low_shot_inds = _low_shot_fg_inds(
            fg_inds, gt_labels, gt_assignment, low_shot_helper, label)
        fg_inds = np.append(low_shot_inds, fg_inds)
        if fg_inds.size > fg_rois_per_image:
            fg_inds = npr.choice(
                fg_inds, size=fg_rois_per_image, replace=False)

    bg_inds = np.where((max_overlaps < cfg.TRAIN.BG_THRESH_HI) &
                       (max_overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
    # Clamped at 0 so an oversized fg set cannot yield a negative sample size.
    bg_rois_per_this_image = max(min(rois_per_image - fg_inds.size,
                                     rois_per_image - fg_rois_per_image,
                                     bg_inds.size), 0)
    if bg_inds.size > 0:
        bg_inds = npr.choice(
            bg_inds, size=bg_rois_per_this_image, replace=False)

    if cfg.TRAIN.OVERSAMPLE_SO:
        # Prepend the low-shot fg ROIs so they form a contiguous prefix.
        # NOTE(review): a previous comment claimed a dummy ROI is placed at
        # the start so the low-shot part is never empty, but no dummy is
        # added here; low_shot_ends[0] can be 0. Confirm that downstream
        # consumers tolerate an empty low-shot slice.
        low_shot_inds = _low_shot_fg_inds(
            fg_inds, gt_labels, gt_assignment, low_shot_helper, label)
        fg_inds = np.append(low_shot_inds, fg_inds)
        # Boundary between the low-shot prefix and the regular fg ROIs;
        # presumably used as slice ends/starts downstream -- TODO confirm.
        low_shot_ends = np.array([low_shot_inds.size, -1], dtype=np.int32)
        regular_starts = np.array([low_shot_inds.size, 0], dtype=np.int32)

    keep_inds = np.append(fg_inds, bg_inds)
    rois = all_rois[keep_inds]
    pos_vecs = gt_vecs[gt_assignment[fg_inds]]
    all_labels = np.zeros(len(keep_inds), dtype=np.float32)
    all_labels[:fg_inds.size] = gt_labels[gt_assignment[fg_inds]]
    fg_labels = all_labels[:fg_inds.size]

    # Pairwise label comparisons via broadcasting: rows index fg ROIs.
    neg_affinity_mask = (
        fg_labels[:, np.newaxis] != all_labels[np.newaxis, :]
    ).astype(np.float32)
    pos_affinity_mask = (
        fg_labels[:, np.newaxis] == fg_labels[np.newaxis, :]
    ).astype(np.float32)

    if cfg.TRAIN.OVERSAMPLE_SO:
        return (rois, pos_vecs, all_labels, neg_affinity_mask,
                pos_affinity_mask, low_shot_ends, regular_starts)
    return rois, pos_vecs, all_labels, neg_affinity_mask, pos_affinity_mask