tools/prepare_data/create_voc_low_shot_challenge_samples.py (99 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
################################################################################
"""
This script creates the low-shot data samples used for VOC SVM training.
"""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import argparse
import json
import logging
import os
import random
import sys
import numpy as np
# Create the logger. basicConfig runs at import time and sets up the root
# logger: INFO-level (and above) records are written to stdout, each prefixed
# with the level name, the source file name and the line number.
FORMAT = '[%(levelname)s: %(filename)s: %(lineno)4d]: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT, stream=sys.stdout)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def load_json(file_path, ground_truth=True):
    """Load a per-image label JSON file into a dense matrix.

    The file must contain a JSON object mapping image ids to objects that map
    class names to a numeric value. Image ids are sorted, and the class names
    are taken (sorted) from the entry of the first sorted image id.

    Args:
        file_path: path of the JSON file to read.
        ground_truth: if True the matrix dtype is np.int32, otherwise
            np.float64.

    Returns:
        A tuple ``(output, img_ids, cls_names)`` where ``output`` has shape
        ``(len(img_ids), len(cls_names))`` and ``output[i][j]`` is the value
        stored for ``img_ids[i]`` / ``cls_names[j]``. For an empty JSON object
        the result is an empty ``(0, 0)`` matrix and two empty lists.

    Raises:
        AssertionError: if ``file_path`` does not exist.
        KeyError: if an image entry lacks one of the class names found in the
            first image entry.
    """
    assert os.path.exists(file_path), '{} does not exist'.format(file_path)
    with open(file_path, 'r') as fp:
        data = json.load(fp)
    dtype = np.int32 if ground_truth else np.float64
    img_ids = sorted(data.keys())
    if not img_ids:
        # Nothing to index into: without this guard img_ids[0] below would
        # raise IndexError on an empty file.
        return np.empty((0, 0), dtype=dtype), [], []
    cls_names = sorted(data[img_ids[0]].keys())
    output = np.empty((len(img_ids), len(cls_names)), dtype=dtype)
    for row, img_id in enumerate(img_ids):
        img_labels = data[img_id]
        for col, cls_name in enumerate(cls_names):
            output[row][col] = img_labels[cls_name]
    return output, img_ids, cls_names
def save_json(input_data, img_ids, cls_names, output_file):
    """Write a label matrix to ``output_file`` as nested JSON.

    The written object maps every id in ``img_ids`` to an object that maps
    every name in ``cls_names`` to ``int(input_data[row][col])``, where
    ``row`` / ``col`` are the positions of the id / name in their lists.
    """
    # Index explicitly (rather than zipping) so a matrix smaller than the
    # id/name lists still fails loudly instead of being truncated.
    labels_by_image = {
        image_id: {
            class_name: int(input_data[row][col])
            for col, class_name in enumerate(cls_names)
        }
        for row, image_id in enumerate(img_ids)
    }
    logger.info('Saving file: {}'.format(output_file))
    with open(output_file, 'w') as handle:
        json.dump(labels_by_image, handle)
def sample_symbol(input_targets, output_target, symbol, num):
    """Randomly copy up to ``num`` occurrences of ``symbol`` per class column.

    For every column of ``input_targets``, the rows whose value equals
    ``symbol`` are collected, ``num`` of them are drawn without replacement
    with ``random.sample``, and ``symbol`` is written at those positions of
    ``output_target``.

    If a column holds fewer than ``num`` matching rows, all matching rows are
    used and a warning is logged; asking ``random.sample`` for more items
    than the population holds would otherwise raise ValueError and abort the
    whole run.

    Args:
        input_targets: 2D array indexed as ``[row, class]``.
        output_target: 2D array indexed the same way; modified in place.
        symbol: the label value to look for and to write.
        num: number of rows to sample for each class column.

    Returns:
        ``output_target`` (the same object that was passed in).
    """
    logger.info('Sampling symbol: {} for num: {}'.format(symbol, num))
    num_classes = input_targets.shape[1]
    for idx in range(num_classes):
        symbol_data = np.where(input_targets[:, idx] == symbol)[0]
        available = len(symbol_data)
        if available < num:
            logger.warning(
                'Class index {}: only {} entries with symbol {} available, '
                'requested {}; using all of them'.format(
                    idx, available, symbol, num))
        # Cap the draw at the population size; a negative num is left as is
        # so random.sample still rejects it.
        sampled = random.sample(list(symbol_data), min(num, available))
        for index in sampled:
            output_target[index, idx] = symbol
    return output_target
def generate_independent_sample(opts, targets, img_ids, cls_names):
    """Draw and save one low-shot label matrix per (sample, k) combination.

    ``opts.k_values`` is a comma-separated list of integers. For each of the
    ``opts.num_samples`` rounds and each k, a matrix shaped like ``targets``
    is created and filled through ``sample_symbol``, then written into
    ``opts.output_path`` as ``<prefix>_sample<i>_k<k>.json`` (via
    ``save_json``) and ``<prefix>_sample<i>_k<k>.npy``. ``<prefix>`` is the
    part of the ``opts.targets_data_file`` name after the last '/' and before
    its first '.'.
    """
    shot_counts = [int(token) for token in opts.k_values.split(',')]
    # targets is N x num_classes (e.g. N x 20).
    class_count = targets.shape[1]
    for sample_no in range(1, opts.num_samples + 1):
        for shots in shot_counts:
            logger.info(
                'Sampling: {} time for k-value: {}'.format(sample_no, shots))
            # Start from a matrix of -1 (ignore label), then mark k positives
            # and (num_classes - 1) * k negatives per class.
            labels = np.full(targets.shape, -1, dtype=np.int32)
            labels = sample_symbol(targets, labels, 1, shots)
            labels = sample_symbol(
                targets, labels, 0, (class_count - 1) * shots)
            file_name = opts.targets_data_file.split('/')[-1]
            prefix = file_name.split('.')[0]
            # Both output files share the same path up to the extension.
            stem = os.path.join(
                opts.output_path,
                '{}_sample{}_k{}'.format(prefix, sample_no, shots))
            save_json(labels, img_ids, cls_names, stem + '.json')
            npy_output_file = stem + '.npy'
            logger.info('Saving npy file: {}'.format(npy_output_file))
            np.save(npy_output_file, labels)
    logger.info('Done!!')
def main():
    """Parse the command line and generate the low-shot samples.

    Prints the help text and exits with status 1 when the script is run
    without any argument. ``--targets_data_file`` and ``--output_path`` are
    required: both are used as file-system paths later on, so a missing value
    is rejected by argparse up front instead of failing on ``None``
    mid-run.
    """
    parser = argparse.ArgumentParser(
        description='Sample Low shot data for VOC')
    parser.add_argument(
        '--targets_data_file',
        type=str,
        default=None,
        required=True,
        help='Json file containing image labels')
    parser.add_argument(
        '--output_path',
        type=str,
        default=None,
        required=True,
        help='path where low-shot samples should be saved')
    parser.add_argument(
        '--k_values',
        type=str,
        default='1,2,4,8,16,32,64,96',
        help='Low-shot k-values for svm testing.')
    parser.add_argument(
        '--num_samples',
        type=int,
        default=5,
        help='Number of independent samples.')
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    opts = parser.parse_args()
    targets, img_ids, cls_names = load_json(opts.targets_data_file)
    generate_independent_sample(opts, targets, img_ids, cls_names)


if __name__ == '__main__':
    main()