example/ssd/dataset/mscoco.py (58 lines of code) (raw):
import os
import numpy as np
from imdb import Imdb
from pycocotools.coco import COCO
class Coco(Imdb):
    """
    Implementation of Imdb for MSCOCO dataset: http://mscoco.org

    Parameters:
    ----------
    anno_file : str
        annotation file for coco, a json file
    image_dir : str
        image directory for coco images
    shuffle : bool
        whether initially shuffle image list
    names : str
        class-names file; handed to `_load_class_names` together with the
        `names` directory that sits next to this module
    """
    def __init__(self, anno_file, image_dir, shuffle=True, names='mscoco.names'):
        assert os.path.isfile(anno_file), "Invalid annotation file: " + anno_file
        # The dataset name is derived from the annotation file name without
        # its extension, prefixed with 'coco_'.
        basename = os.path.splitext(os.path.basename(anno_file))[0]
        super(Coco, self).__init__('coco_' + basename)
        self.image_dir = image_dir

        self.classes = self._load_class_names(
            names, os.path.join(os.path.dirname(__file__), 'names'))
        self.num_classes = len(self.classes)
        self._load_all(anno_file, shuffle)
        self.num_images = len(self.image_set_index)

    def image_path_from_index(self, index):
        """
        given image index, find out full path

        Parameters:
        ----------
        index: int
            index of a specific image

        Returns:
        ----------
        full path of this image, i.e. <image_dir>/images/<entry of image_set_index>
        """
        assert self.image_set_index is not None, "Dataset not initialized"
        name = self.image_set_index[index]
        image_file = os.path.join(self.image_dir, 'images', name)
        assert os.path.isfile(image_file), 'Path does not exist: {}'.format(image_file)
        return image_file

    def label_from_index(self, index):
        """
        given image index, return preprocessed ground-truth

        Parameters:
        ----------
        index: int
            index of a specific image

        Returns:
        ----------
        ground-truths of this image (the array built in `_load_all`)
        """
        assert self.labels is not None, "Labels not processed"
        return self.labels[index]

    def _load_all(self, anno_file, shuffle):
        """
        initialize all entries given annotation json file

        Fills `self.image_set_index` with relative image paths and
        `self.labels` with one numpy array per image. Each row of a label
        array is [category_id, xmin, ymin, xmax, ymax, 0], where the box
        coordinates are divided by the image width/height. Images without
        any annotation are skipped, so both lists stay aligned.

        Parameters:
        ----------
        anno_file: str
            annotation json file
        shuffle: bool
            whether to shuffle image list
        """
        image_set_index = []
        labels = []
        coco = COCO(anno_file)
        img_ids = coco.getImgIds()
        for img_id in img_ids:
            # filename
            image_info = coco.loadImgs(img_id)[0]
            filename = image_info["file_name"]
            # The second '_'-separated token of the file name is used as the
            # sub-directory.
            # NOTE(review): file names without an underscore raise IndexError
            # here -- confirm all supported annotation files follow this
            # naming scheme.
            subdir = filename.split('_')[1]
            height = image_info["height"]
            width = image_info["width"]
            # label
            anno_ids = coco.getAnnIds(imgIds=img_id)
            annos = coco.loadAnns(anno_ids)
            label = []
            for anno in annos:
                # NOTE(review): the category id is stored as-is; verify that
                # it matches the indexing of self.classes.
                cat_id = int(anno["category_id"])
                bbox = anno["bbox"]
                assert len(bbox) == 4
                # bbox is consumed as [x, y, width, height]; convert it to
                # corner coordinates relative to the image size.
                xmin = float(bbox[0]) / width
                ymin = float(bbox[1]) / height
                xmax = xmin + float(bbox[2]) / width
                ymax = ymin + float(bbox[3]) / height
                # trailing 0: presumably a "difficult" flag -- TODO confirm
                label.append([cat_id, xmin, ymin, xmax, ymax, 0])
            if label:
                # Append to both lists together so labels[i] always belongs
                # to image_set_index[i].
                labels.append(np.array(label))
                image_set_index.append(os.path.join(subdir, filename))

        if shuffle:
            import random
            # random.shuffle needs a mutable sequence; on Python 3 range() is
            # immutable, so materialize the indices as a list first.
            indices = list(range(len(image_set_index)))
            random.shuffle(indices)
            # Apply the same permutation to both lists to keep them aligned.
            image_set_index = [image_set_index[i] for i in indices]
            labels = [labels[i] for i in indices]
        # store the results
        self.image_set_index = image_set_index
        self.labels = labels