in lib/core/get_gt_perturbed_proposals.py [0:0]
def _unique_row_inds(rows):
    """Return the indices of the first occurrence of each distinct row.

    Rows are compared by exact value (``tuple(row)``). The returned indices
    are in increasing order, so ``rows[inds, :]`` keeps the original relative
    order of the retained rows.
    """
    seen = set()
    inds = []
    for idx, row in enumerate(rows):
        key = tuple(row)
        if key not in seen:
            seen.add(key)
            inds.append(idx)
    return inds


def _gt_boxes_with_classes(entry, boxes_key, classes_key):
    """Build a float32 array of rows ``[0, box(4 cols), class]`` for one role.

    Only rows whose label in ``entry[classes_key]`` is > 0 are kept. Column 0
    is a batch index and is always 0.
    """
    gt_inds = np.where(entry[classes_key] > 0)[0]
    gt_boxes = np.zeros((len(gt_inds), 6), dtype=np.float32)
    gt_boxes[:, 1:5] = entry[boxes_key][gt_inds, :].astype(np.float32)
    gt_boxes[:, 5] = entry[classes_key][gt_inds]
    return gt_boxes


def _role_blobs(entry, boxes_key, classes_key, im_width, im_height):
    """Compute the proposal blobs for one role (subject or object).

    Returns:
        (unique_all_rois, unique_gt_inds). ``unique_all_rois`` is a float32
        array of 5-column rows ``[0, box]``. It stacks the de-duplicated
        perturbed rois first, then the de-duplicated GT boxes (class column
        dropped). ``unique_gt_inds`` is a list of row indices into the
        filtered (label > 0) GT-box array, one per distinct GT row.
    """
    gt_boxes = _gt_boxes_with_classes(entry, boxes_key, classes_key)
    unique_gt_inds = _unique_row_inds(gt_boxes)
    unique_gt_boxes = gt_boxes[unique_gt_inds, :]

    perturbed = _augment_gt_boxes_by_perturbation(
        unique_gt_boxes[:, 1:5], im_width, im_height)
    rois = np.zeros((perturbed.shape[0], 5), dtype=np.float32)
    rois[:, 1:5] = perturbed  # column 0 (batch index) stays 0
    unique_rois = rois[_unique_row_inds(rois), :]

    unique_all_rois = np.vstack((unique_rois, unique_gt_boxes[:, :-1]))
    return unique_all_rois, unique_gt_inds


def get_gt_perturbed_proposals(gt_roidb):
    """Load or generate GT-perturbed subject/object proposals for each image.

    The result is cached at
    ``<cfg.DATA_DIR>/proposals/vg/gt_perturbed_proposals_flipped.pkl``. If
    that file exists, it is loaded and returned as-is. Otherwise proposals
    are generated from ``gt_roidb``, written to that path (creating the
    directory if needed) and returned.

    Args:
        gt_roidb: sequence of per-image dicts. Each is read for the keys
            'gt_sbj_classes', 'gt_obj_classes', 'sbj_boxes', 'obj_boxes',
            'width' and 'height'. The box arrays are indexed as 2-D with 4
            coordinate columns. Coordinate convention (presumably x1, y1,
            x2, y2) is not visible here; confirm against the roidb builder.

    Returns:
        A list with one dict per image, holding these keys:
        - 'unique_all_rois_sbj', 'unique_all_rois_obj': float32 arrays of
          5-column rows ``[0, box]``. De-duplicated perturbed rois come
          first, followed by the de-duplicated GT boxes.
        - 'unique_sbj_gt_inds', 'unique_obj_gt_inds': lists of row indices
          (into the label > 0 filtered GT boxes) of the distinct GT rows.
    """
    data_dir = os.path.join(cfg.DATA_DIR, 'proposals')
    proposal_file_path = os.path.join(data_dir, 'vg')
    proposal_name = 'gt_perturbed_proposals_flipped.pkl'
    proposal_file = os.path.join(proposal_file_path, proposal_name)
    logger.info('proposal file: {}'.format(proposal_file))

    if os.path.exists(proposal_file):
        logger.info('Loading existing proposals...')
        # NOTE: pickle can execute arbitrary code on load; this cache file
        # must come from a trusted source.
        with open(proposal_file, 'rb') as fid:
            return pickle.load(fid)

    logger.info('Generating gt perturbed proposals...')
    num_images = len(gt_roidb)
    proposals = []
    for im_i, entry in enumerate(gt_roidb):
        logger.info('Preparing roidb {}/{}'.format(im_i, num_images))
        im_width = float(entry['width'])
        im_height = float(entry['height'])
        # Subject before object: this keeps the original call order into
        # _augment_gt_boxes_by_perturbation, in case it consumes shared RNG
        # state (not visible here).
        unique_all_rois_sbj, unique_sbj_gt_inds = _role_blobs(
            entry, 'sbj_boxes', 'gt_sbj_classes', im_width, im_height)
        unique_all_rois_obj, unique_obj_gt_inds = _role_blobs(
            entry, 'obj_boxes', 'gt_obj_classes', im_width, im_height)
        proposals.append({
            'unique_all_rois_sbj': unique_all_rois_sbj,
            'unique_all_rois_obj': unique_all_rois_obj,
            'unique_sbj_gt_inds': unique_sbj_gt_inds,
            'unique_obj_gt_inds': unique_obj_gt_inds,
        })

    # Create the cache directory if missing (Py2-compatible: no exist_ok).
    if not os.path.isdir(proposal_file_path):
        os.makedirs(proposal_file_path)
    with open(proposal_file, 'wb') as fid:
        pickle.dump(proposals, fid, pickle.HIGHEST_PROTOCOL)
    logger.info('Wrote shdet gt perturbed proposals to {}'.format(
        os.path.abspath(proposal_file)))
    return proposals