scripts/dataset/transform.py (16 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates.
from PIL import Image
from torchvision.transforms import functional as tfn
class SegmentationTransform:
    """Resize a PIL image so its longest side matches a target, then normalize it.

    Calling an instance rescales the image (preserving aspect ratio) so that
    its longest side is ``longest_max_size`` pixels, converts it with
    ``tfn.to_tensor``, subtracts the per-channel ``rgb_mean`` and divides by
    the per-channel ``rgb_std``.
    """

    def __init__(self, longest_max_size, rgb_mean, rgb_std):
        """Store the transform parameters.

        Args:
            longest_max_size: Target length, in pixels, of the image's longest
                side after resizing.
            rgb_mean: Sequence with one value per channel, subtracted from the
                tensor. Presumably expressed in the same scale as the
                ``to_tensor`` output -- confirm against callers.
            rgb_std: Sequence with one value per channel; the mean-subtracted
                tensor is divided by it.
        """
        self.longest_max_size = longest_max_size
        self.rgb_mean = rgb_mean
        self.rgb_std = rgb_std

    def __call__(self, img):
        """Resize and normalize ``img``.

        Args:
            img: A PIL image; ``img.size`` is ``(width, height)``.

        Returns:
            The normalized channel-first tensor produced from ``img``.
        """
        # Scaling
        scale = self.longest_max_size / float(max(img.size))
        # Round rather than truncate: ``dim * scale`` can fall just below an
        # integer because of floating-point error (e.g. 49 * (1 / 49) < 1),
        # which would leave the longest side one pixel short of the target.
        # Clamp to 1 so an extremely elongated image never gets a zero-length
        # side, which is not a valid resize target.
        out_size = tuple(max(1, round(dim * scale)) for dim in img.size)
        if out_size != img.size:
            img = img.resize(out_size, resample=Image.BILINEAR)

        # Convert to torch and normalize
        img = tfn.to_tensor(img)
        # new_tensor builds the statistics with img's dtype and device;
        # view(-1, 1, 1) broadcasts one value per channel over height/width.
        img.sub_(img.new_tensor(self.rgb_mean).view(-1, 1, 1))
        img.div_(img.new_tensor(self.rgb_std).view(-1, 1, 1))
        return img