example/ssd/dataset/mscoco.py (58 lines of code) (raw):

import os
import numpy as np
from imdb import Imdb
from pycocotools.coco import COCO


class Coco(Imdb):
    """
    Implementation of Imdb for MSCOCO dataset: http://mscoco.org

    Parameters:
    ----------
    anno_file : str
        annotation file for coco, a json file
    image_dir : str
        image directory for coco images
    shuffle : bool
        whether initially shuffle image list
    names : str
        forwarded to ``_load_class_names`` together with the ``names``
        directory located next to this module
        (presumably the class-name file inside that directory -- confirm
        against Imdb._load_class_names)
    """
    def __init__(self, anno_file, image_dir, shuffle=True, names='mscoco.names'):
        assert os.path.isfile(anno_file), "Invalid annotation file: " + anno_file
        # the dataset name is derived from the annotation file name (no extension)
        basename = os.path.splitext(os.path.basename(anno_file))[0]
        super(Coco, self).__init__('coco_' + basename)
        self.image_dir = image_dir

        names_dir = os.path.join(os.path.dirname(__file__), 'names')
        self.classes = self._load_class_names(names, names_dir)
        self.num_classes = len(self.classes)
        # populates self.image_set_index and self.labels
        self._load_all(anno_file, shuffle)
        self.num_images = len(self.image_set_index)

    def image_path_from_index(self, index):
        """
        given image index, find out full path

        Parameters:
        ----------
        index: int
            index of a specific image

        Returns:
        ----------
        full path of this image, i.e. <image_dir>/images/<entry of image_set_index>

        Raises:
        ----------
        AssertionError
            if the image list has not been loaded or the file does not exist
        """
        assert self.image_set_index is not None, "Dataset not initialized"
        name = self.image_set_index[index]
        image_file = os.path.join(self.image_dir, 'images', name)
        assert os.path.isfile(image_file), 'Path does not exist: {}'.format(image_file)
        return image_file

    def label_from_index(self, index):
        """
        given image index, return preprocessed ground-truth

        Parameters:
        ----------
        index: int
            index of a specific image

        Returns:
        ----------
        ground-truths of this image: a numpy array with one row per
        annotation, laid out as [category_id, xmin, ymin, xmax, ymax, 0]

        Raises:
        ----------
        AssertionError
            if labels have not been loaded
        """
        assert self.labels is not None, "Labels not processed"
        return self.labels[index]

    def _load_all(self, anno_file, shuffle):
        """
        initialize all entries given annotation json file

        Fills ``self.image_set_index`` (relative image paths) and
        ``self.labels`` (one numpy array per image); both lists are kept
        index-aligned. Images without any annotation are skipped.

        Parameters:
        ----------
        anno_file: str
            annotation json file
        shuffle: bool
            whether to shuffle image list
        """
        image_set_index = []
        labels = []
        coco = COCO(anno_file)
        img_ids = coco.getImgIds()
        for img_id in img_ids:
            # filename
            image_info = coco.loadImgs(img_id)[0]
            filename = image_info["file_name"]
            # the second '_'-separated token of the file name is used as the
            # sub-directory (e.g. COCO 2014 naming: COCO_train2014_<id>.jpg);
            # a file name without '_' raises IndexError here
            subdir = filename.split('_')[1]
            height = image_info["height"]
            width = image_info["width"]
            # label
            anno_ids = coco.getAnnIds(imgIds=img_id)
            annos = coco.loadAnns(anno_ids)
            label = []
            for anno in annos:
                # NOTE(review): category_id is stored verbatim, not remapped;
                # verify it lines up with the indices of self.classes
                cat_id = int(anno["category_id"])
                bbox = anno["bbox"]
                assert len(bbox) == 4
                # bbox is read as [x, y, box_width, box_height] in pixels and
                # converted to corner coordinates normalized by the image size
                xmin = float(bbox[0]) / width
                ymin = float(bbox[1]) / height
                xmax = xmin + float(bbox[2]) / width
                ymax = ymin + float(bbox[3]) / height
                # trailing 0: presumably a "difficult" flag -- TODO confirm
                label.append([cat_id, xmin, ymin, xmax, ymax, 0])
            if label:
                # append to both lists together so they stay index-aligned
                labels.append(np.array(label))
                image_set_index.append(os.path.join(subdir, filename))

        if shuffle:
            import random
            # random.shuffle needs a mutable sequence; on Python 3 a bare
            # range() is immutable and would raise TypeError
            indices = list(range(len(image_set_index)))
            random.shuffle(indices)
            # apply the same permutation to images and labels
            image_set_index = [image_set_index[i] for i in indices]
            labels = [labels[i] for i in indices]
        # store the results
        self.image_set_index = image_set_index
        self.labels = labels