def __init__()

in datasets.py [0:0]


    def __init__(self, data_path, split, subsample_what=None, duplicates=None):
        """Set up the Waterbirds image dataset for one split.

        Images are looked up under
        ``<data_path>/waterbirds/waterbird_complete95_forest2water2/``.
        Metadata is read from ``<data_path>/metadata_waterbirds.csv``.
        Both paths are handed to the parent constructor together with a
        fixed preprocessing pipeline.

        Args:
            data_path: Directory that contains the ``waterbirds`` folder and
                the metadata CSV.
            split: Forwarded unchanged to the parent constructor.
            subsample_what: Forwarded unchanged to the parent constructor.
            duplicates: Forwarded unchanged to the parent constructor.
        """
        image_dir = os.path.join(
            data_path, "waterbirds/waterbird_complete95_forest2water2/"
        )
        metadata_csv = os.path.join(data_path, "metadata_waterbirds.csv")

        # Resize to a square slightly larger than the crop, then center-crop.
        # The expression below evaluates to 256 and is kept in this exact form
        # so the resulting integer is unchanged.
        crop_size = 224
        resize_side = int(crop_size * (256 / crop_size))

        # Per-channel normalization constants. These match the commonly used
        # ImageNet statistics, presumably for an ImageNet-pretrained
        # backbone (TODO confirm against the model code).
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]

        steps = [
            transforms.Resize((resize_side, resize_side)),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]
        preprocess = transforms.Compose(steps)

        super().__init__(
            split, image_dir, metadata_csv, preprocess, subsample_what, duplicates
        )
        # Assigned only after the parent constructor has returned.
        self.data_type = "images"