in detection/run_tf_detector_batch.py [0:0]
def main():
    """Command-line entry point.

    Parses arguments, enumerates the images to score (a single image, a JSON
    list of paths, or a directory), optionally restores results from a
    checkpoint, runs the TF detector over the remaining images (periodically
    checkpointing if requested), and writes the final results to a .json file.
    """

    parser = argparse.ArgumentParser(
        description='Module to run a TF animal detection model on lots of images')
    parser.add_argument(
        'detector_file',
        help='Path to .pb TensorFlow detector model file')
    parser.add_argument(
        'image_file',
        help='Path to a single image file, a JSON file containing a list of paths to images, or a directory')
    parser.add_argument(
        'output_file',
        help='Path to output JSON results file, should end with a .json extension')
    parser.add_argument(
        '--recursive',
        action='store_true',
        help='Recurse into directories, only meaningful if image_file points to a directory')
    parser.add_argument(
        '--output_relative_filenames',
        action='store_true',
        help='Output relative file names, only meaningful if image_file points to a directory')
    parser.add_argument(
        '--use_image_queue',
        action='store_true',
        help='Pre-load images, may help keep your GPU busy; does not currently support checkpointing. Useful if you have a very fast GPU and a very slow disk.')
    parser.add_argument(
        '--threshold',
        type=float,
        default=TFDetector.DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD,
        # Format the real default into the help text so it can't drift out of
        # sync with TFDetector.DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD
        help="Confidence threshold between 0 and 1.0, don't include boxes below this confidence in the output file. Default is {}".format(
            TFDetector.DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD))
    parser.add_argument(
        '--checkpoint_frequency',
        type=int,
        default=-1,
        help='Write results to a temporary file every N images; default is -1, which disables this feature')
    parser.add_argument(
        '--checkpoint_path',
        type=str,
        default=None,
        help='File name to which checkpoints will be written if checkpoint_frequency is > 0')
    parser.add_argument(
        '--resume_from_checkpoint',
        help='Path to a JSON checkpoint file to resume from, must be in same directory as output_file')
    parser.add_argument(
        '--ncores',
        type=int,
        default=0,
        help='Number of cores to use; only applies to CPU-based inference, does not support checkpointing when ncores > 1')

    # With no arguments at all, show usage instead of an argparse error
    if len(sys.argv[1:]) == 0:
        parser.print_help()
        parser.exit()

    args = parser.parse_args()

    assert os.path.exists(args.detector_file), 'Specified detector_file does not exist'
    assert 0.0 < args.threshold <= 1.0, 'Confidence threshold needs to be between 0 and 1'  # Python chained comparison
    assert args.output_file.endswith('.json'), 'output_file specified needs to end with .json'
    if args.checkpoint_frequency != -1:
        assert args.checkpoint_frequency > 0, 'Checkpoint_frequency needs to be > 0 or == -1'
    if args.output_relative_filenames:
        assert os.path.isdir(args.image_file), 'image_file must be a directory when --output_relative_filenames is set'

    if os.path.exists(args.output_file):
        print('Warning: output_file {} already exists and will be overwritten'.format(args.output_file))

    # Load the checkpoint if available
    #
    # Relative file names are only output at the end; all file paths in the checkpoint are
    # still full paths.
    if args.resume_from_checkpoint:
        assert os.path.exists(args.resume_from_checkpoint), 'File at resume_from_checkpoint specified does not exist'
        with open(args.resume_from_checkpoint) as f:
            saved = json.load(f)
        assert 'images' in saved, \
            'The file saved as checkpoint does not have the correct fields; cannot be restored'
        results = saved['images']
        print('Restored {} entries from the checkpoint'.format(len(results)))
    else:
        results = []

    # Find the images to score; images can be a directory, may need to recurse
    if os.path.isdir(args.image_file):
        image_file_names = ImagePathUtils.find_images(args.image_file, args.recursive)
        print('{} image files found in the input directory'.format(len(image_file_names)))
    # A json list of image paths
    elif os.path.isfile(args.image_file) and args.image_file.endswith('.json'):
        with open(args.image_file) as f:
            image_file_names = json.load(f)
        print('{} image files found in the json list'.format(len(image_file_names)))
    # A single image file
    elif os.path.isfile(args.image_file) and ImagePathUtils.is_image_file(args.image_file):
        image_file_names = [args.image_file]
        print('A single image at {} is the input file'.format(args.image_file))
    else:
        raise ValueError('image_file specified is not a directory, a json list, or an image file, '
                         '(or does not have recognizable extensions).')

    assert len(image_file_names) > 0, 'Specified image_file does not point to valid image files'
    assert os.path.exists(image_file_names[0]), \
        'The first image to be scored does not exist at {}'.format(image_file_names[0])

    output_dir = os.path.dirname(args.output_file)
    if len(output_dir) > 0:
        os.makedirs(output_dir, exist_ok=True)
    assert not os.path.isdir(args.output_file), 'Specified output file is a directory'

    # Test that we can write to the output_file's dir if checkpointing requested
    if args.checkpoint_frequency != -1:
        if args.checkpoint_path is not None:
            checkpoint_path = args.checkpoint_path
        else:
            checkpoint_path = os.path.join(
                output_dir,
                'checkpoint_{}.json'.format(datetime.utcnow().strftime("%Y%m%d%H%M%S")))

        # The writability probe below truncates checkpoint_path; refuse to run
        # if that would destroy the very checkpoint we just resumed from (the
        # restored results only live in memory until the first periodic write)
        assert checkpoint_path != args.resume_from_checkpoint, \
            'checkpoint_path cannot be the same file as resume_from_checkpoint'

        # Confirm that we can write to the checkpoint path, rather than failing after 10000 images
        with open(checkpoint_path, 'w') as f:
            json.dump({'images': []}, f)
        print('The checkpoint file will be written to {}'.format(checkpoint_path))
    else:
        checkpoint_path = None

    start_time = time.time()

    results = load_and_run_detector_batch(model_file=args.detector_file,
                                          image_file_names=image_file_names,
                                          checkpoint_path=checkpoint_path,
                                          confidence_threshold=args.threshold,
                                          checkpoint_frequency=args.checkpoint_frequency,
                                          results=results,
                                          n_cores=args.ncores,
                                          use_image_queue=args.use_image_queue)

    elapsed = time.time() - start_time
    print('Finished inference in {}'.format(humanfriendly.format_timespan(elapsed)))

    relative_path_base = None
    if args.output_relative_filenames:
        relative_path_base = args.image_file
    write_results_to_file(results, args.output_file, relative_path_base=relative_path_base)

    # The run completed, so the intermediate checkpoint is no longer needed
    if checkpoint_path is not None:
        os.remove(checkpoint_path)
        print('Deleted checkpoint file {}'.format(checkpoint_path))

    print('Done!')