Dassl.pytorch/dassl/data/datasets/dg/wilds/wilds_base.py (85 lines of code) (raw):
import logging # isort:skip
logging.disable(logging.WARNING) # isort:skip

import os
import pickle
import logging
import os.path as osp

from wilds import get_dataset as wilds_get_dataset

from dassl.data.datasets import Datum, DatasetBase
class WILDSBase(DatasetBase):
    """Base class for datasets obtained through ``wilds.get_dataset``.

    Subclasses set ``dataset_dir`` and implement ``load_classnames``. On
    first use the train/val/test subsets are converted into lists of
    ``Datum`` and pickled to ``zhou_preloaded.pkl`` inside the dataset
    folder; later instantiations load that pickle instead of re-reading
    the dataset.
    """

    # Folder name under cfg.DATASET.ROOT. The part before the first "_" is
    # passed to wilds as the dataset name.
    dataset_dir = ""
    # When True, read_data() remaps the domain ids of each subset to
    # consecutive integers starting at 0.
    relabel_domain = True

    def __init__(self, cfg):
        """Load (or build and cache) the splits, then apply few-shot sampling.

        Args:
            cfg: config object; ``cfg.DATASET.ROOT`` and
                ``cfg.DATASET.NUM_SHOTS`` are read.
        """
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        name = self.dataset_dir.split("_")[0]
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.preloaded = osp.join(self.dataset_dir, "zhou_preloaded.pkl")

        # read_data() looks class names up in this mapping, so it must be
        # available before any subset is read.
        self.label_to_name = self.load_classnames()
        assert isinstance(self.label_to_name, dict)

        if osp.exists(self.preloaded):
            # NOTE(security): pickle.load can execute arbitrary code. Only
            # use cache files that were produced locally by this class.
            with open(self.preloaded, "rb") as file:
                dataset = pickle.load(file)
            train = dataset["train"]
            val = dataset["val"]
            test = dataset["test"]
        else:
            dataset = wilds_get_dataset(
                dataset=name, root_dir=root, download=True
            )
            subset_train = dataset.get_subset("train")
            subset_val = dataset.get_subset("val")
            subset_test = dataset.get_subset("test")

            train = self.read_data(subset_train)
            val = self.read_data(subset_val)
            test = self.read_data(subset_test)

            # Save time for data loading next time
            preloaded = {"train": train, "val": val, "test": test}
            self._write_cache(preloaded)

        # Few-shot learning
        k = cfg.DATASET.NUM_SHOTS
        if k > 0:
            groups = self.split_dataset_by_domain(train)
            groups = list(groups.values())
            # NOTE(review): the loop below assumes generate_fewshot_dataset
            # returns one list per domain group -- confirm this also holds
            # when there is only a single training domain.
            groups = self.generate_fewshot_dataset(*groups, num_shots=k)
            train = []
            for group in groups:
                train.extend(group)

        super().__init__(train_x=train, val=val, test=test)

    def _write_cache(self, splits):
        """Pickle ``splits`` to ``self.preloaded`` without leaving partial files.

        The data is written to a temporary file in the same directory and
        then moved into place with ``os.replace``, so an interrupted run
        never leaves a truncated cache that later runs would try to load.
        """
        tmp_path = "{}.{}.tmp".format(self.preloaded, os.getpid())
        try:
            with open(tmp_path, "wb") as file:
                pickle.dump(splits, file, protocol=pickle.HIGHEST_PROTOCOL)
            os.replace(tmp_path, self.preloaded)
        finally:
            # Only still present if dumping or replacing failed.
            if osp.exists(tmp_path):
                os.remove(tmp_path)

    def load_classnames(self):
        """Return a dict mapping integer labels to class names.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    def get_image_path(self, dataset, idx):
        """Return the path of sample ``idx``, joined onto ``self.dataset_dir``."""
        image_name = dataset._input_array[idx]
        image_path = osp.join(self.dataset_dir, image_name)
        return image_path

    def get_label(self, dataset, idx):
        """Return ``dataset.y_array[idx]`` as an int."""
        return int(dataset.y_array[idx])

    def get_domain(self, dataset, idx):
        """Return the first metadata field of sample ``idx`` as an int."""
        return int(dataset.metadata_array[idx][0])

    def read_data(self, subset):
        """Convert a subset into a list of ``Datum``.

        Args:
            subset: object exposing ``indices`` and the underlying
                ``dataset`` those indices refer to.

        Returns:
            list of ``Datum``, one per index and in index order. When
            ``relabel_domain`` is True, domain ids are remapped to
            ``0..N-1`` following the sorted order of the original ids.
        """
        dataset = subset.dataset
        items = []

        for idx in subset.indices:
            image_path = self.get_image_path(dataset, idx)
            label = self.get_label(dataset, idx)
            domain = self.get_domain(dataset, idx)
            classname = self.label_to_name[label]
            items.append(
                Datum(
                    impath=image_path,
                    label=label,
                    domain=domain,
                    classname=classname,
                )
            )

        if not self.relabel_domain:
            return items

        # Sort the ids so the mapping does not depend on set iteration
        # order and is identical from run to run.
        domains = sorted({item.domain for item in items})
        mapping = {domain: i for i, domain in enumerate(domains)}

        return [
            Datum(
                impath=item.impath,
                label=item.label,
                domain=mapping[item.domain],
                classname=item.classname,
            )
            for item in items
        ]