def _sample_rois_pos_neg_for_one_branch()

in lib/roi_data/fast_rcnn_rel.py [0:0]


def _low_shot_fg_inds(fg_inds, gt_labels, gt_assignment, low_shot_helper, label):
    """Return the entries of ``fg_inds`` whose assigned gt label is low-shot.

    Each foreground ROI's label is looked up through ``gt_assignment``,
    shifted down by one, and passed to ``low_shot_helper`` in the subject
    slot (``label == 'sbj'``) or the object slot (``label == 'obj'``) of a
    ``[s, p, o]`` triple whose other slots are ``-1``.

    Raises:
        NotImplementedError: if ``label`` is neither ``'sbj'`` nor ``'obj'``.
    """
    pos_labels = gt_labels[gt_assignment[fg_inds]] - 1
    if label == 'sbj':
        def is_low_shot(cls):
            return low_shot_helper.check_low_shot_s([cls, -1, -1])
    elif label == 'obj':
        def is_low_shot(cls):
            return low_shot_helper.check_low_shot_o([-1, -1, cls])
    else:
        raise NotImplementedError
    return np.array(
        [ind for ind, cls in zip(fg_inds, pos_labels) if is_low_shot(cls)],
        dtype=np.int32)


def _sample_rois_pos_neg_for_one_branch(
        all_rois, gt_boxes, gt_labels, gt_vecs, low_shot_helper, label):
    """Sample foreground/background ROIs for one (subject or object) branch.

    ROIs are split by their maximum overlap with the gt boxes into exact gt
    matches (overlap >= 0.99999), other positives (overlap in
    [cfg.TRAIN.FG_THRESH, 0.99999)) and backgrounds (overlap in
    [cfg.TRAIN.BG_THRESH_LO, cfg.TRAIN.BG_THRESH_HI)). Foregrounds are
    capped at ``round(cfg.TRAIN.FG_FRACTION * cfg.TRAIN.BATCH_SIZE_PER_IM)``
    and backgrounds fill (part of) the rest of the batch.

    Args:
        all_rois: 2-D array; columns 1:5 are used as the box coordinates.
        gt_boxes: 2-D array; the first four columns are used as the boxes.
        gt_labels: array indexed by gt index giving each gt box's label.
            Presumably 1-based with 0 reserved for background (labels are
            shifted by -1 before the low-shot check, and background ROIs
            get label 0 here) -- TODO confirm against the caller.
        gt_vecs: array indexed by gt index; rows are gathered per foreground
            ROI.
        low_shot_helper: object providing ``check_low_shot_s`` and
            ``check_low_shot_o``; only used when cfg.TRAIN.OVERSAMPLE_SO2 or
            cfg.TRAIN.OVERSAMPLE_SO is set.
        label: ``'sbj'`` or ``'obj'``; selects which low-shot check is used.

    Returns:
        ``(rois, pos_vecs, all_labels, neg_affinity_mask, pos_affinity_mask)``
        where, with F foreground and K kept (foreground + background) ROIs:
          * rois: the K selected rows of ``all_rois``, foregrounds first.
          * pos_vecs: ``gt_vecs`` rows for the F foreground ROIs.
          * all_labels: float32 array of length K, zero for backgrounds.
          * neg_affinity_mask: float32 (F, K), 1 where the labels differ.
          * pos_affinity_mask: float32 (F, F), 1 where the labels are equal.
        When cfg.TRAIN.OVERSAMPLE_SO is set, two int32 arrays
        ``low_shot_ends`` and ``regular_starts`` are appended to the tuple.

    Raises:
        NotImplementedError: if oversampling is enabled and ``label`` is
            neither ``'sbj'`` nor ``'obj'``.
    """
    rois_per_image = int(cfg.TRAIN.BATCH_SIZE_PER_IM)
    fg_rois_per_image = int(
        np.round(cfg.TRAIN.FG_FRACTION * rois_per_image))

    overlaps = box_utils.bbox_overlaps(
        all_rois[:, 1:5].astype(dtype=np.float32, copy=False),
        gt_boxes[:, :4].astype(dtype=np.float32, copy=False))
    max_overlaps = overlaps.max(axis=1)
    gt_assignment = overlaps.argmax(axis=1)

    # ROIs that (almost) exactly coincide with a gt box are always kept as
    # foreground candidates; other positives are subsampled to fit the budget.
    gt_inds = np.where((max_overlaps >= 0.99999))[0]
    pos_inds = np.where((max_overlaps >= cfg.TRAIN.FG_THRESH) &
                        (max_overlaps < 0.99999))[0]
    fg_rois_per_this_image = min(int(fg_rois_per_image),
                                 gt_inds.size + pos_inds.size)
    if pos_inds.size > 0 and \
       pos_inds.size > fg_rois_per_image - gt_inds.size:
        # The gt matches alone may already exceed the foreground budget;
        # clamp so npr.choice is never asked for a negative sample size.
        # The surplus is trimmed by the fg_rois_per_image cap further down.
        num_pos_to_keep = max(fg_rois_per_this_image - gt_inds.size, 0)
        pos_inds = npr.choice(pos_inds,
                              size=num_pos_to_keep,
                              replace=False)
    fg_inds = np.append(pos_inds, gt_inds)
    # duplicate low-shot predicates to increase their chances to be chosen
    if cfg.TRAIN.OVERSAMPLE_SO2:
        low_shot_inds = _low_shot_fg_inds(
            fg_inds, gt_labels, gt_assignment, low_shot_helper, label)
        fg_inds = np.append(low_shot_inds, fg_inds)
    if fg_inds.size > fg_rois_per_image:
        fg_inds = npr.choice(fg_inds, size=fg_rois_per_image, replace=False)

    bg_inds = np.where((max_overlaps < cfg.TRAIN.BG_THRESH_HI) &
                       (max_overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
    bg_rois_per_this_image = min(rois_per_image - fg_inds.size,
                                 rois_per_image - fg_rois_per_image,
                                 bg_inds.size)
    if bg_inds.size > 0:
        bg_inds = npr.choice(bg_inds, size=bg_rois_per_this_image, replace=False)

    if cfg.TRAIN.OVERSAMPLE_SO:
        # NOTE(review): an earlier comment here said low_shot_inds starts
        # with one dummy ROI so that it is never empty, but no dummy is added
        # in this function -- low_shot_inds can be empty. Confirm whether the
        # consumers of low_shot_ends / regular_starts rely on that.
        low_shot_inds = _low_shot_fg_inds(
            fg_inds, gt_labels, gt_assignment, low_shot_helper, label)
        # Low-shot duplicates are prepended, so they occupy the first
        # low_shot_inds.size foreground slots.
        fg_inds = np.append(low_shot_inds, fg_inds)
        low_shot_ends = np.array([low_shot_inds.size, -1], dtype=np.int32)
        regular_starts = np.array([low_shot_inds.size, 0], dtype=np.int32)

    keep_inds = np.append(fg_inds, bg_inds)
    rois = all_rois[keep_inds]

    pos_vecs = gt_vecs[gt_assignment[fg_inds]]

    # Foregrounds come first in keep_inds; backgrounds keep label 0.
    all_labels = np.zeros(len(keep_inds), dtype=np.float32)
    all_labels[:fg_inds.size] = gt_labels[gt_assignment[fg_inds]]

    # (F, K) mask: entry [i, j] is 1 when kept ROI j has a different label
    # than foreground ROI i.
    all_labels_horizontal_tile = np.tile(
        all_labels, (fg_inds.size, 1))
    all_labels_vertical_tile = np.tile(
        all_labels[:fg_inds.size], (keep_inds.size, 1)).transpose()
    neg_affinity_mask = \
        np.array(all_labels_horizontal_tile !=
                 all_labels_vertical_tile).astype(np.float32)

    # (F, F) mask: entry [i, j] is 1 when foreground ROIs i and j share a
    # label.
    pos_labels_horizontal_tile = np.tile(
        all_labels[:fg_inds.size], (fg_inds.size, 1))
    pos_labels_vertical_tile = np.tile(
        all_labels[:fg_inds.size], (fg_inds.size, 1)).transpose()
    pos_affinity_mask = \
        np.array(pos_labels_horizontal_tile ==
                 pos_labels_vertical_tile).astype(np.float32)

    if cfg.TRAIN.OVERSAMPLE_SO:
        return rois, pos_vecs, all_labels, neg_affinity_mask, pos_affinity_mask, \
            low_shot_ends, regular_starts
    else:
        return rois, pos_vecs, all_labels, neg_affinity_mask, pos_affinity_mask