in domainbed/datasets.py [0:0]
def __init__(self, dataset, metadata_name, test_envs, augment, hparams):
    """Build one ``WILDSEnvironment`` per metadata value of ``dataset``.

    Args:
        dataset: Source dataset; forwarded to ``self.metadata_values`` and
            to every ``WILDSEnvironment``. Its ``n_classes`` attribute
            becomes ``self.num_classes``.
        metadata_name: Name of the metadata field used to split the data
            into environments (forwarded as-is).
        test_envs: Container of environment indices that must never
            receive the augmented pipeline.
        augment: When truthy, every environment whose index is not in
            ``test_envs`` uses the augmented pipeline.
        hparams: Accepted for interface compatibility; this method does
            not read it.
    """
    super().__init__()

    side = 224

    def make_normalize():
        # Fresh Normalize per pipeline, so the two pipelines share no
        # transform objects.
        # NOTE(review): these look like the usual ImageNet channel
        # statistics -- confirm against the backbone in use.
        return transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Deterministic pipeline: resize, tensor conversion, normalization.
    plain_pipeline = transforms.Compose([
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        make_normalize(),
    ])

    # Augmented pipeline: same resize, then random crop / flip / colour
    # jitter / grayscale before tensor conversion and normalization.
    augmented_pipeline = transforms.Compose([
        transforms.Resize((side, side)),
        transforms.RandomResizedCrop(side, scale=(0.7, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
        transforms.RandomGrayscale(),
        transforms.ToTensor(),
        make_normalize(),
    ])

    self.datasets = []
    values = self.metadata_values(dataset, metadata_name)
    for env_index, value in enumerate(values):
        # Augmentation is applied only when requested, and never to an
        # environment listed in test_envs.
        chosen = (
            augmented_pipeline
            if augment and env_index not in test_envs
            else plain_pipeline
        )
        self.datasets.append(
            WILDSEnvironment(dataset, metadata_name, value, chosen))

    self.input_shape = (3, side, side)
    self.num_classes = dataset.n_classes