def combined_roidb_for_training()

in lib/datasets/roidb_rel.py [0:0]


def _add_max_overlap_fields(entry, prefix):
    """Attach the row-wise max overlap value and class for one overlap matrix.

    Reads ``entry['gt_<prefix>_overlaps']``, densifies it with ``.toarray()``
    (presumably a scipy sparse matrix -- confirm against the imdb's
    ``gt_roidb()``), and stores on ``entry``:

    * ``'<prefix>_max_classes'``: argmax over the class axis (columns),
    * ``'<prefix>_max_overlaps'``: max over the class axis (columns).

    Args:
        entry: one roidb dict; mutated in place.
        prefix: key prefix, one of ``'sbj'``, ``'obj'`` or ``'rel'``.

    Raises:
        AssertionError: if a row whose max overlap is 0 is not assigned class
            0 (background), or a row with positive max overlap is assigned
            class 0.
    """
    # A dense array is needed for max/argmax over the class axis.
    overlaps = entry['gt_{}_overlaps'.format(prefix)].toarray()
    max_overlaps = overlaps.max(axis=1)
    max_classes = overlaps.argmax(axis=1)
    entry['{}_max_classes'.format(prefix)] = max_classes
    entry['{}_max_overlaps'.format(prefix)] = max_overlaps
    # Sanity checks: zero overlap => background (class 0);
    # positive overlap => a foreground (non-zero) class.
    assert np.all(max_classes[max_overlaps == 0] == 0)
    assert np.all(max_classes[max_overlaps > 0] != 0)


def combined_roidb_for_training(dataset_names, proposal_files):
    """Load, configure and merge the relationship roidbs used for training.

    Each dataset's configured gt roidb is loaded from
    ``cfg.DATA_DIR/roidb_cache/<name>_configured_gt_roidb.pkl`` when that file
    exists. Otherwise it is built from ``get_imdb(name)`` and written to that
    path. Flipped entries are appended afterwards when
    ``cfg.TRAIN.USE_FLIPPED`` is set.

    Args:
        dataset_names: colon-separated dataset names, e.g. ``'a:b'``.
        proposal_files: colon-separated proposal file paths (one per
            dataset), or ``None``. Only the count is checked against
            ``dataset_names``; the paths are not otherwise used here.

    Returns:
        The merged roidb after ``filter_for_training``. If
        ``cfg.TRAIN.PROPOSAL_FILE`` is non-empty, a ``(roidb, proposals)``
        tuple is returned instead, where ``proposals`` is unpickled from
        that file.
    """
    def build_roidb(dataset_name):
        """Build the configured gt roidb for one dataset from its imdb."""
        ds = get_imdb(dataset_name)
        roidb = ds.gt_roidb()
        logger.info('loading widths and appending them')
        widths, heights = ds.get_widths_and_heights()
        for i, entry in enumerate(roidb):
            logger.info('creating roidb for image {}'.format(i + 1))
            entry['width'] = widths[i]
            entry['height'] = heights[i]
            entry['image'] = ds.image_path_at(i)
            for prefix in ('sbj', 'obj', 'rel'):
                _add_max_overlap_fields(entry, prefix)
        return roidb

    def get_roidb(dataset_name, proposal_file):
        # proposal_file is accepted for call-site symmetry but is unused.
        logger.info('loading roidb for {}'.format(dataset_name))

        roidb_file = os.path.join(cfg.DATA_DIR, 'roidb_cache', dataset_name +
                                  '_configured_gt_roidb.pkl')
        if os.path.exists(roidb_file):
            # NOTE: the cache is unpickled; it must come from a trusted
            # location (it is normally written by the branch below).
            with open(roidb_file, 'rb') as fid:
                roidb = cPickle.load(fid)
            logger.info('len(roidb): {}'.format(len(roidb)))
            logger.info('{} configured gt roidb loaded from {}'.format(
                dataset_name, roidb_file))
        else:
            roidb = build_roidb(dataset_name)
            logger.info('len(roidb): {}'.format(len(roidb)))
            # The cache directory may not exist yet on a fresh setup.
            cache_dir = os.path.dirname(roidb_file)
            if not os.path.isdir(cache_dir):
                try:
                    os.makedirs(cache_dir)
                except OSError:
                    # Tolerate another process creating it concurrently.
                    if not os.path.isdir(cache_dir):
                        raise
            with open(roidb_file, 'wb') as fid:
                cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL)
            logger.info('wrote configured gt roidb to {}'.format(roidb_file))

        # Flipping runs after the cache is written, so the cache only ever
        # holds the unflipped roidb.
        if cfg.TRAIN.USE_FLIPPED:
            logger.info('Appending horizontally-flipped training examples...')
            extend_with_flipped_entries(roidb)
        logger.info('Loaded dataset: {:s}'.format(dataset_name))
        return roidb

    dataset_names = dataset_names.split(':')
    if proposal_files is not None:
        proposal_files = proposal_files.split(':')
    else:
        proposal_files = [None] * len(dataset_names)
    assert len(dataset_names) == len(proposal_files)
    roidbs = [get_roidb(*args) for args in zip(dataset_names, proposal_files)]
    # Concatenate all datasets into the first roidb list.
    roidb = roidbs[0]
    for r in roidbs[1:]:
        roidb.extend(r)

    roidb = filter_for_training(roidb)

    if cfg.TRAIN.PROPOSAL_FILE != '':
        with open(cfg.TRAIN.PROPOSAL_FILE, 'rb') as fid:
            proposals = cPickle.load(fid)
        return roidb, proposals
    return roidb