in domainbed/datasets.py [0:0]
def __init__(self, root, test_envs, augment, hparams):
    """Build one ``ImageFolder`` dataset per environment directory in ``root``.

    Every immediate sub-directory of ``root`` is treated as one environment.
    Environments are ordered by sorted directory name, and that position is
    the index compared against ``test_envs``.

    Args:
        root: Directory whose immediate sub-directories are the environments.
        test_envs: Container of environment indices (positions in the sorted
            directory list) that always receive the plain, non-augmenting
            transform.
        augment: When truthy, environments whose index is not in
            ``test_envs`` use the augmentation pipeline; otherwise every
            environment uses the plain transform.
        hparams: Accepted for interface compatibility; not read by this
            method.

    Raises:
        ValueError: If ``root`` contains no sub-directories. Previously this
            surfaced as an uninformative ``IndexError`` when reading
            ``self.datasets[-1]``.
    """
    super().__init__()

    # The context manager closes the directory iterator even if is_dir()
    # raises part-way through the listing.
    with os.scandir(root) as entries:
        environments = sorted(
            entry.name for entry in entries if entry.is_dir())

    if not environments:
        raise ValueError(
            "No environment sub-directories found in {!r}".format(root))

    # Both pipelines end with the same per-channel normalisation, so build
    # it once and share it.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Plain pipeline: fixed resize only, no randomness.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    # Augmentation pipeline: random crop/flip/colour perturbations. The
    # random-resized crop already yields 224x224 output, so no separate
    # resize step is needed.
    augment_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
        transforms.RandomGrayscale(),
        transforms.ToTensor(),
        normalize,
    ])

    self.datasets = []
    for i, environment in enumerate(environments):
        # Environments listed in test_envs are never augmented.
        use_augmentation = augment and (i not in test_envs)
        env_transform = augment_transform if use_augmentation else transform
        self.datasets.append(
            ImageFolder(os.path.join(root, environment),
                        transform=env_transform))

    self.input_shape = (3, 224, 224,)
    # The class count is taken from the last environment only; this
    # presumably relies on every environment sharing the same class
    # folders -- TODO confirm against the dataset layout.
    self.num_classes = len(self.datasets[-1].classes)