in datasets.py [0:0]
def __init__(self, data_path, split, subsample_what=None, duplicates=None):
    """Configure the Waterbirds image dataset and delegate to the parent.

    Args:
        data_path: Base directory. The image folder
            ``waterbirds/waterbird_complete95_forest2water2/`` and the file
            ``metadata_waterbirds.csv`` are both resolved relative to it.
        split: Passed through unchanged to the parent constructor.
        subsample_what: Passed through unchanged to the parent constructor.
        duplicates: Passed through unchanged to the parent constructor.

    After the parent constructor returns, ``self.data_type`` is set to
    ``"images"``.
    """
    # Build both locations from the same base directory. The trailing slash
    # on the image folder is kept exactly as the original wrote it.
    image_root = os.path.join(
        data_path, "waterbirds/waterbird_complete95_forest2water2/"
    )
    metadata_csv = os.path.join(data_path, "metadata_waterbirds.csv")

    # Resize each side to crop * 256/224 (which evaluates to 256), then
    # center-crop back down to the 224x224 crop size.
    crop_side = 224
    resize_side = int(crop_side * (256 / 224))
    pipeline = [
        transforms.Resize((resize_side, resize_side)),
        transforms.CenterCrop(crop_side),
        transforms.ToTensor(),
        # Per-channel mean / std. These values match the commonly used
        # ImageNet normalization constants.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]

    super().__init__(
        split,
        image_root,
        metadata_csv,
        transforms.Compose(pipeline),
        subsample_what,
        duplicates,
    )
    self.data_type = "images"