# `__init__` of the WILDS base dataset class
# Source: Dassl.pytorch/dassl/data/datasets/dg/wilds/wilds_base.py


    def __init__(self, cfg):
        """Build train/val/test splits for a WILDS dataset.

        Expects the subclass to have set ``self.dataset_dir`` (e.g.
        ``"iwildcam_v2.0"``) before this runs; the WILDS dataset name is
        the part of that string before the first underscore. Splits are
        cached to a pickle inside the dataset directory so subsequent
        runs skip the download/parse step.

        Args:
            cfg: experiment config; reads ``cfg.DATASET.ROOT`` and
                ``cfg.DATASET.NUM_SHOTS``.

        Raises:
            TypeError: if ``load_classnames()`` does not return a dict.
        """
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        # Derive the WILDS dataset name (e.g. "iwildcam") BEFORE
        # self.dataset_dir is overwritten with the absolute path below —
        # do not reorder these two statements.
        name = self.dataset_dir.split("_")[0]
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.preloaded = osp.join(self.dataset_dir, "zhou_preloaded.pkl")

        self.label_to_name = self.load_classnames()
        # Explicit check rather than `assert`: asserts are stripped
        # under `python -O`, silently skipping the validation.
        if not isinstance(self.label_to_name, dict):
            raise TypeError(
                "load_classnames() must return a dict, got "
                f"{type(self.label_to_name).__name__}"
            )

        if osp.exists(self.preloaded):
            # Cache hit: reuse splits pickled by a previous run.
            # NOTE(security): pickle.load can execute arbitrary code —
            # only load this cache from trusted directories.
            with open(self.preloaded, "rb") as file:
                dataset = pickle.load(file)
            train = dataset["train"]
            val = dataset["val"]
            test = dataset["test"]
        else:
            dataset = wilds_get_dataset(
                dataset=name, root_dir=root, download=True
            )
            train = self.read_data(dataset.get_subset("train"))
            val = self.read_data(dataset.get_subset("val"))
            test = self.read_data(dataset.get_subset("test"))

            # Cache the parsed splits to save data-loading time next run
            preloaded = {"train": train, "val": val, "test": test}
            with open(self.preloaded, "wb") as file:
                pickle.dump(preloaded, file, protocol=pickle.HIGHEST_PROTOCOL)

        # Few-shot learning: subsample the training set per domain,
        # then flatten the per-domain groups back into a single list.
        num_shots = cfg.DATASET.NUM_SHOTS
        if num_shots > 0:
            groups = self.split_dataset_by_domain(train)
            groups = self.generate_fewshot_dataset(
                *groups.values(), num_shots=num_shots
            )
            train = [item for group in groups for item in group]

        super().__init__(train_x=train, val=val, test=test)