in datasets.py
def __init__(self, datasets, pmf, seed=42):
    assert len(pmf) == len(datasets)
    if seed is None:
        raise ValueError("seed can't be None")
    self.datasets = datasets
    self.pmf = pmf
    # Assumed convention (inferred from the checks below): the first dataset
    # is the main (reference) dataset, and the remaining ones are auxiliary.
    self.ref = datasets[0]
    self.oth = datasets[1:]
    # Basic sanity checks: every auxiliary dataset must expose the same
    # data-layout attributes as the main dataset, with identical values.
    attrs = (
        "orig_shape",
        "shape",
        "ctx",
        "num_embeddings",
        "embedding_sizes",
        "n_vocab",
        "x_emb",
    )
    for attr in attrs:
        assert hasattr(self.ref, attr), f"{attr} is missing in the main dataset."
        ref_attr = getattr(self.ref, attr)
        setattr(self, attr, ref_attr)
        for oth in self.oth:
            assert hasattr(oth, attr), f"{attr} is missing in an auxiliary dataset."
            oth_attr = getattr(oth, attr)
            assert type(ref_attr) is type(oth_attr)
            if isinstance(ref_attr, np.ndarray):
                assert (ref_attr == oth_attr).all(), f"expected {attr} to be the same."
            else:
                assert ref_attr == oth_attr, f"expected {attr} to be the same."
    # Model selection and evaluation are performed with the main dataset, so
    # copy its bookkeeping attributes; any that are absent default to None.
    attrs = (
        "H",
        "logprint",
        "vaX",
        "vaY",
        "teX",
        "teY",
        "n_classes",
        "full_dataset_valid",
        "full_dataset_train",
        "iters_per_epoch",
    )
    for attr in attrs:
        setattr(self, attr, getattr(self.ref, attr, None))
    # Mix the training sets according to pmf. auxX tags each example with the
    # index of its source dataset; using the same seed for both subsamplers
    # keeps the data and its source labels aligned.
    trX = [ds.trX for ds in datasets]
    auxX = [np.zeros_like(tr[:, 0:1]) + idx for idx, tr in enumerate(trX)]
    self.trX = JankySubsampler(trX, pmf, seed=seed)
    self.auxX = JankySubsampler(auxX, pmf, seed=seed)
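
# --- Illustrative sketch, not part of datasets.py -----------------------------
# JankySubsampler is defined elsewhere in this repo; the stand-in below is a
# minimal sketch of the behavior the constructor above relies on: an
# array-like view that draws each position from one of several arrays
# according to pmf, deterministically given a seed. All names here
# (PMFSubsampler, the demo arrays) are hypothetical.
import numpy as np


class PMFSubsampler:
    """Array-like pmf-weighted mixture of several equally-long arrays.

    Two instances built with the same seed, pmf, and lengths make identical
    per-position source choices, which is what would keep self.trX and
    self.auxX aligned in the constructor above.
    """

    def __init__(self, datas, pmf, seed=42):
        assert len(datas) == len(pmf)
        n = min(len(d) for d in datas)
        rng = np.random.RandomState(seed)
        # choice[i] is the index of the source array that position i reads from.
        self.choice = rng.choice(len(datas), size=n, p=pmf)
        self.datas = datas

    def __len__(self):
        return len(self.choice)

    def __getitem__(self, i):
        return self.datas[self.choice[i]][i]


if __name__ == "__main__":
    main_trX = np.zeros((1000, 4))
    aux_trX = np.ones((1000, 4))
    mixed = PMFSubsampler([main_trX, aux_trX], pmf=[0.9, 0.1], seed=42)
    # Roughly 10% of rows come from aux_trX.
    print(sum(mixed[i][0] for i in range(len(mixed))) / len(mixed))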