# Dassl.pytorch/dassl/data/datasets/dg/wilds/wilds_base.py
def __init__(self, cfg):
    """Build the WILDS dataset splits and hand them to the Dassl base class.

    Loads train/val/test splits from a local pickle cache when available;
    otherwise downloads/reads the raw WILDS dataset, converts each subset
    via ``self.read_data``, and writes the cache for subsequent runs.
    Optionally subsamples the training split for few-shot learning.

    Args:
        cfg: Dassl config node; reads ``cfg.DATASET.ROOT`` (dataset root
            directory) and ``cfg.DATASET.NUM_SHOTS`` (few-shot size,
            <= 0 disables few-shot subsampling).

    Raises:
        TypeError: if the subclass's ``load_classnames()`` does not
            return a dict.
    """
    root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
    # NOTE(review): self.dataset_dir is assumed to be a class attribute set
    # by the subclass (e.g. "iwildcam_v2.0"); the WILDS dataset name is its
    # prefix before the first underscore. Must be read BEFORE the rebind
    # below turns it into an absolute path — confirm against subclasses.
    name = self.dataset_dir.split("_")[0]
    self.dataset_dir = osp.join(root, self.dataset_dir)
    self.preloaded = osp.join(self.dataset_dir, "zhou_preloaded.pkl")

    self.label_to_name = self.load_classnames()
    # Explicit check instead of `assert`: asserts are stripped under -O,
    # and a wrong return type here would fail much later and less clearly.
    if not isinstance(self.label_to_name, dict):
        raise TypeError(
            "load_classnames() must return a dict, got "
            f"{type(self.label_to_name).__name__}"
        )

    if osp.exists(self.preloaded):
        # Fast path: reuse the cache written by a previous run.
        with open(self.preloaded, "rb") as f:
            dataset = pickle.load(f)
        train = dataset["train"]
        val = dataset["val"]
        test = dataset["test"]
    else:
        dataset = wilds_get_dataset(
            dataset=name, root_dir=root, download=True
        )
        subset_train = dataset.get_subset("train")
        subset_val = dataset.get_subset("val")
        subset_test = dataset.get_subset("test")
        train = self.read_data(subset_train)
        val = self.read_data(subset_val)
        test = self.read_data(subset_test)
        # Save time for data loading next time.
        preloaded = {"train": train, "val": val, "test": test}
        with open(self.preloaded, "wb") as f:
            pickle.dump(preloaded, f, protocol=pickle.HIGHEST_PROTOCOL)

    # Few-shot learning: sample k shots per domain group, then flatten
    # the groups back into a single training list.
    k = cfg.DATASET.NUM_SHOTS
    if k > 0:
        groups = list(self.split_dataset_by_domain(train).values())
        groups = self.generate_fewshot_dataset(*groups, num_shots=k)
        train = [item for group in groups for item in group]

    super().__init__(train_x=train, val=val, test=test)