scripts/dataset/dataset.py (33 lines of code) (raw):

# Copyright (c) Facebook, Inc. and its affiliates.

import fnmatch
import glob
import os
from itertools import chain
from os import path

import torch
from PIL import Image
from torch.utils.data import Dataset


class SegmentationDataset(Dataset):
    """Flat directory of images to feed to a segmentation model.

    Every regular, non-hidden file directly inside ``in_dir`` (no recursion)
    whose name matches one of ``_EXTENSIONS`` becomes one sample. Names are
    matched case-insensitively, so ``IMG_0001.JPG`` is picked up as well.
    Samples are ordered by path so that indices are stable across runs and
    machines.

    Args:
        in_dir: directory to scan (``str`` or path-like).
        transform: callable applied to each RGB-converted ``PIL.Image`` in
            ``__getitem__``; its output is returned under the ``"img"`` key.
    """

    # Glob-style file-name patterns accepted as images (case-insensitive).
    _EXTENSIONS = ["*.jpg", "*.jpeg", "*.png"]

    def __init__(self, in_dir, transform):
        super().__init__()
        self.in_dir = in_dir
        self.transform = transform

        patterns = [pattern.lower() for pattern in self._EXTENSIONS]
        # Escape the directory so glob metacharacters in its name ("[", "*",
        # "?") are taken literally; "*" skips hidden files, as before.
        candidates = glob.glob(path.join(glob.escape(os.fspath(in_dir)), "*"))

        # One entry per image: {"idx": file name without extension,
        # "path": full path}. Sorted so sample order is deterministic.
        self.images = []
        for img_path in sorted(candidates):
            if not path.isfile(img_path):
                continue  # e.g. a directory named "foo.jpg"
            name_with_ext = path.basename(img_path)
            lowered = name_with_ext.lower()
            if not any(fnmatch.fnmatchcase(lowered, pattern) for pattern in patterns):
                continue
            idx, _ = path.splitext(name_with_ext)
            self.images.append({"idx": idx, "path": img_path})

    def __len__(self):
        """Return the number of images found in ``in_dir``."""
        return len(self.images)

    def __getitem__(self, item):
        """Load, RGB-convert and transform the ``item``-th image.

        Returns:
            ``{"img": transform(rgb_image), "meta": {"idx": idx, "size": size}}``
            where ``size`` is the original PIL ``(width, height)``, read
            before any transform is applied.
        """
        entry = self.images[item]
        with Image.open(entry["path"]) as img_raw:
            size = img_raw.size
            img = self.transform(img_raw.convert(mode="RGB"))
        return {"img": img, "meta": {"idx": entry["idx"], "size": size}}


def segmentation_collate(items):
    """Collate ``SegmentationDataset`` samples into a batch.

    Images are stacked along a new leading dimension, so every ``"img"`` must
    be a tensor of identical shape. Metadata dicts are returned as a plain
    list in the same order as ``items``.
    """
    imgs = torch.stack([item["img"] for item in items])
    metas = [item["meta"] for item in items]
    return {"img": imgs, "meta": metas}