libraries/python/coco/generate_im_info.py (100 lines of code) (raw):
#!/usr/bin/env python
##############################################################################
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
##############################################################################
# This is the script to generate the im_info blob used in MaskRCNN2Go model
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import json
import logging
import sys
import numpy as np
# Layout of every log record: level, time, source file and line, then message.
FORMAT = (
    "%(levelname)s %(asctime)s "
    "%(filename)s:%(lineno)4d: %(message)s"
)

# Send records of every level (DEBUG and above) to stdout, with a
# time-of-day-only timestamp.
logging.basicConfig(
    stream=sys.stdout,
    format=FORMAT,
    datefmt="%H:%M:%S",
    level=logging.DEBUG,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _build_parser():
    """Create the command-line parser for this script.

    Defines one optional flag (``--batch-size``, default -1) and four
    required flags (``--dataset-file``, ``--min-size``, ``--max-size``,
    ``--output-file``), registered in that order.
    """
    cli = argparse.ArgumentParser(description="Load and extract coco dataset")

    # Optional: how the inputs are grouped into batches.
    cli.add_argument(
        "--batch-size",
        default=-1,
        type=int,
        help=(
            "The batch size of the input data. If less than zero, all inputs "
            "are in one batch. Otherwise, the number of inputs must be multiples "
            "of the batch size."
        ),
    )

    # Required: where the image annotations are read from.
    cli.add_argument(
        "--dataset-file",
        required=True,
        type=str,
        help="The file of the dataset containing image annotations",
    )

    # Required: the two bounds used when scaling each image.
    cli.add_argument(
        "--min-size",
        required=True,
        type=int,
        help="The minimum size to scale the input image.",
    )
    cli.add_argument(
        "--max-size",
        required=True,
        type=int,
        help="The maximum size to scale the input image.",
    )

    # Required: where the generated im_info rows are written.
    cli.add_argument(
        "--output-file",
        required=True,
        type=str,
        help="The output file containing the info for im_info blob.",
    )
    return cli


parser = _build_parser()
class ImInfo(object):
    """Generate the rows of the im_info blob for a list of images.

    The dataset file is read as one JSON object per line; each object must
    provide ``height`` and ``width``. For every image the resize scale is
    derived from ``args.min_size`` and ``args.max_size`` and one row
    ``height, width, scale`` is written to ``args.output_file``.
    """

    def __init__(self, args):
        """Store the parsed options.

        Args:
            args: Namespace-like object. The attributes read by this class
                are ``batch_size``, ``dataset_file``, ``min_size``,
                ``max_size`` and ``output_file``.
        """
        self.args = args

    def run(self):
        """Read the dataset file and write the im_info rows.

        The output starts with a header line ``<number of rows>, 3`` followed
        by one ``height, width, scale`` line per image, in input order.

        Raises:
            AssertionError: If the number of images is not a multiple of the
                batch size, or if a scaled image violates the size bounds.
        """
        with open(self.args.dataset_file, "r") as f:
            # Skip blank lines (e.g. a trailing empty line) rather than
            # letting json.loads fail on them.
            imgs = [json.loads(line) for line in f if line.strip()]

        # A non-positive batch size means "everything in one batch".
        batch_size = self.args.batch_size if self.args.batch_size > 0 else len(imgs)
        # batch_size is 0 only for an empty dataset with a non-positive
        # requested batch size; guard the division so that case produces an
        # empty output (header only), as it already does for a positive
        # batch size.
        num_batches = len(imgs) // batch_size if batch_size else 0
        # Raised explicitly (not via `assert`) so the check survives `-O`.
        if len(imgs) != num_batches * batch_size:
            raise AssertionError(
                "number of inputs {} is not a multiple of the batch size {}".format(
                    len(imgs), batch_size
                )
            )

        # Rows are written flat, in input order; batching only constrains
        # the total count checked above.
        im_infos = [self._im_info(img) for img in imgs]

        with open(self.args.output_file, "w") as f:
            f.write("{}, {}\n".format(len(im_infos), 3))
            for im_info in im_infos:
                f.write("{}\n".format(", ".join([str(v) for v in im_info])))

    def _im_info(self, img):
        """Return ``[height, width, scale]`` for one image annotation."""
        min_size = self.args.min_size
        max_size = self.args.max_size

        im_scale = self.getScale(img["height"], img["width"])
        height = int(np.round(img["height"] * im_scale))
        width = int(np.round(img["width"] * im_scale))

        # Sanity checks on the scaling result: neither side may exceed
        # max_size, and one side must sit exactly on the bound that
        # determined the scale.
        assert height <= max_size, "height {} is more than the max_size {}".format(
            height, max_size
        )
        assert width <= max_size, "width {} is more than the max_size {}".format(
            width, max_size
        )
        if height < min_size or width < min_size:
            # The scale was capped by max_size.
            assert height == max_size or width == max_size, (
                "scaled size {}x{} is below min_size {} but no side equals "
                "max_size {}".format(height, width, min_size, max_size)
            )
        else:
            assert height == min_size or width == min_size, (
                "scaled size {}x{} has no side equal to min_size {}".format(
                    height, width, min_size
                )
            )
        return [height, width, im_scale]

    def getScale(self, height, width):
        """Return the scale factor for an image of the given size.

        The shorter side is scaled to ``args.min_size`` unless that would
        make the (rounded) longer side exceed ``args.max_size``, in which
        case the longer side is scaled to ``args.max_size`` instead.
        """
        min_size = self.args.min_size
        max_size = self.args.max_size
        im_min_size = min(height, width)
        im_max_size = max(height, width)
        im_scale = float(min_size) / float(im_min_size)
        if np.round(im_scale * im_max_size) > max_size:
            im_scale = float(max_size) / float(im_max_size)
        return im_scale
if __name__ == "__main__":
    # Script entry point: parse the command line, then generate and write
    # the im_info rows.
    args = parser.parse_args()
    ImInfo(args).run()