scripts/dataset/transform.py (16 lines of code) (raw):

# Copyright (c) Facebook, Inc. and its affiliates.

from PIL import Image
from torchvision.transforms import functional as tfn


class SegmentationTransform:
    """Rescale an image to a fixed longest side and normalize it as a tensor.

    Args:
        longest_max_size: Target length, in pixels, of the longest image side
            after rescaling.
        rgb_mean: Sequence of per-channel means subtracted from the tensor.
        rgb_std: Sequence of per-channel standard deviations the tensor is
            divided by.
    """

    def __init__(self, longest_max_size, rgb_mean, rgb_std):
        self.longest_max_size = longest_max_size
        self.rgb_mean = rgb_mean
        self.rgb_std = rgb_std

    def __call__(self, img):
        """Apply the transform to a single image.

        Args:
            img: Image object exposing ``size`` and ``resize`` (a PIL image).

        Returns:
            The tensor produced by torchvision's ``to_tensor``, with
            ``rgb_mean`` subtracted and then divided by ``rgb_std``, both
            broadcast along the leading (channel) dimension.
        """
        # Scaling
        longest = max(img.size[0], img.size[1])
        if longest != self.longest_max_size:
            # Use exact floor division rather than truncating
            # ``dim * (target / longest)``: the float quotient can round just
            # below the true ratio, which made the longest side come out one
            # pixel short (e.g. 3136 -> 511 for a target of 512).
            # Clamp to 1 so extreme aspect ratios never request a zero-sized
            # side, which ``resize`` rejects.
            out_size = tuple(
                max(1, int(dim * self.longest_max_size // longest))
                for dim in img.size
            )
            img = img.resize(out_size, resample=Image.BILINEAR)

        # Convert to torch and normalize
        img = tfn.to_tensor(img)
        # ``new_tensor`` builds the statistics with the same dtype and device
        # as ``img``; ``view(-1, 1, 1)`` lets them broadcast per channel.
        img.sub_(img.new_tensor(self.rgb_mean).view(-1, 1, 1))
        img.div_(img.new_tensor(self.rgb_std).view(-1, 1, 1))
        return img