libraries/python/imagenet_test_map.py (100 lines of code) (raw):

#!/usr/bin/env python

##############################################################################
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
##############################################################################

# This is a library used to map image file names to their labels when the
# images and labels are stored in the imagenet dataset hierarchy (one
# sub-directory per category). Optionally the images are shuffled.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import os
import random


parser = argparse.ArgumentParser(description="Map the images with labels")
parser.add_argument(
    "--image-dir",
    required=True,
    help="The directory of the images in imagenet database format.",
)
parser.add_argument(
    "--label-file",
    required=True,
    help="The file of the labels in imagenet database format.",
)
parser.add_argument(
    "--limit-per-category",
    type=int,
    help="Limit the number of files to get from each folder.",
)
parser.add_argument(
    "--limit-files", type=int, help="Limit the total number of files to the test set."
)
parser.add_argument(
    "--output-image-file", help="The file containing the absolute path of all images."
)
parser.add_argument(
    "--output-label-file",
    required=True,
    help="The file containing the labels for the images.",
)
parser.add_argument(
    "--shuffle", action="store_true", help="Shuffle the images and labels."
)


def _ensure_parent_dir(path):
    """Create the directory that will contain *path*, including missing ancestors."""
    parent = os.path.dirname(os.path.abspath(path))
    if not os.path.isdir(parent):
        # makedirs (not mkdir) so nested, not-yet-existing output paths work.
        os.makedirs(parent)


class ImageLableMap(object):
    """Map the images of an imagenet-style directory tree to their labels.

    (The class name's spelling is kept as-is for backward compatibility.)
    """

    def __init__(self, argv=None):
        """Parse the command line and prepare the output directories.

        Args:
            argv: argument list to parse; ``None`` (the default) parses
                ``sys.argv[1:]``, which is what the script entry point uses.

        Raises:
            AssertionError: if ``--label-file`` is not a file or
                ``--image-dir`` is not a directory.
        """
        self.args = parser.parse_args(argv)
        assert os.path.isfile(
            self.args.label_file
        ), "Label file {} doesn't exist".format(self.args.label_file)
        assert os.path.isdir(
            self.args.image_dir
        ), "Image directory {} doesn't exist".format(self.args.image_dir)
        if self.args.output_image_file:
            _ensure_parent_dir(self.args.output_image_file)
        _ensure_parent_dir(self.args.output_label_file)

    def _readLabelFile(self):
        """Return ``[(rel_dir, label), ...]`` parsed from ``--label-file``.

        Each non-empty line is ``<relative_dir>,<label>[,...]``; only the first
        two comma-separated fields are used (any further fields are ignored).

        Raises:
            ValueError: if the file is empty or an entry has fewer than two
                comma-separated fields (e.g. a blank line in the middle).
        """
        with open(self.args.label_file, "r") as f:
            content = f.read().strip()
        if not content:
            raise ValueError("Label file {} is empty".format(self.args.label_file))
        mapping = []
        for entry_no, line in enumerate(content.split("\n"), 1):
            fields = line.strip().split(",")
            if len(fields) < 2:
                raise ValueError(
                    "Label file {} entry {}: expected '<dir>,<label>', got {!r}".format(
                        self.args.label_file, entry_no, line
                    )
                )
            mapping.append((fields[0], fields[1]))
        return mapping

    def mapImageLabels(self):
        """Build the image/label mapping and write the output file(s).

        Each entry of ``--label-file`` names one sub-directory of
        ``--image-dir`` and its label; the entry's 0-based position becomes
        the class index. Every regular file in that sub-directory is recorded.

        Writes ``--output-label-file`` as newline-separated
        ``<index>,<label>,<absolute path>`` lines and, when given,
        ``--output-image-file`` as newline-separated absolute paths in the
        same order. Neither file ends with a trailing newline.

        Raises:
            ValueError: if the label file is empty or malformed.
            AssertionError: if a listed sub-directory does not exist.
        """
        image_root = os.path.abspath(self.args.image_dir)
        all_images_map = []
        for idx, (rel_dir, label) in enumerate(self._readLabelFile()):
            category_dir = os.path.join(image_root, rel_dir)
            assert os.path.isdir(category_dir), "image dir {} doesn't exist".format(
                category_dir
            )
            # os.listdir() order is arbitrary; sort it so un-shuffled output and
            # --limit-per-category selections are reproducible across machines.
            # Sub-directories are skipped because they are not images.
            filenames = sorted(
                name
                for name in os.listdir(category_dir)
                if os.path.isfile(os.path.join(category_dir, name))
            )
            images_map = [
                {
                    "path": os.path.join(category_dir, name),
                    "index": idx,
                    "label": label,
                }
                for name in filenames
            ]
            # Shuffle before truncating so the per-category limit keeps a
            # random sample rather than the first files.
            if self.args.shuffle:
                random.shuffle(images_map)
            # An unset (or 0) limit means "keep everything".
            if self.args.limit_per_category:
                images_map = images_map[: self.args.limit_per_category]
            all_images_map.extend(images_map)

        # Shuffle across categories before the global limit so it is not
        # biased toward the categories listed first in the label file.
        if self.args.shuffle:
            random.shuffle(all_images_map)
        if self.args.limit_files:
            all_images_map = all_images_map[: self.args.limit_files]

        if self.args.output_image_file:
            with open(self.args.output_image_file, "w") as f:
                f.write("\n".join(item["path"] for item in all_images_map))
        with open(self.args.output_label_file, "w") as f:
            f.write(
                "\n".join(
                    str(item["index"]) + "," + item["label"] + "," + item["path"]
                    for item in all_images_map
                )
            )


if __name__ == "__main__":
    app = ImageLableMap()
    app.mapImageLabels()