def __init__()

in datasets.py
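Constructor for the dataset-mixture wrapper: it checks that every member dataset agrees with the main (reference) dataset on core shape and vocabulary attributes, mirrors those attributes onto the wrapper, and builds pmf-weighted training streams with JankySubsampler. Note that `self.ref` and `self.oth` are not set here; they are presumably properties defined elsewhere on the class (the first dataset and the remaining ones, respectively). The code also assumes `numpy` is imported as `np` and that `JankySubsampler` is defined in the same module.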


    def __init__(self, datasets, pmf, seed=42):
        assert len(pmf) == len(datasets), "pmf must assign one probability per dataset"
        if seed is None:
            # A concrete seed is required: the trX and auxX subsamplers built
            # below must draw identical dataset choices to stay row-aligned.
            raise ValueError("seed can't be None")
        self.datasets = datasets
        self.pmf = pmf
        # Sanity-check that all datasets agree on the core shape and
        # vocabulary attributes, and mirror them onto this wrapper.
        attrs = (
            "orig_shape",
            "shape",
            "ctx",
            "num_embeddings",
            "embedding_sizes",
            "n_vocab",
            "x_emb",
        )
        for attr in attrs:
            assert hasattr(self.ref, attr), f"{attr} is missing in the main dataset."
            ref_attr = getattr(self.ref, attr)
            setattr(self, attr, ref_attr)
            for oth in self.oth:
                assert hasattr(oth, attr), f"{attr} is missing in an auxiliary dataset."
                oth_attr = getattr(oth, attr)
                assert type(ref_attr) is type(oth_attr), f"expected {attr} to have the same type."
                if isinstance(ref_attr, np.ndarray):
                    # Array equality is element-wise, so reduce with .all().
                    assert (ref_attr == oth_attr).all(), f"expected {attr} to be the same."
                else:
                    assert ref_attr == oth_attr, f"expected {attr} to be the same."
        # Model selection and evaluation use the main (reference) dataset,
        # so copy its validation/test splits and bookkeeping from it.
        attrs = (
            "H",
            "logprint",
            "vaX",
            "vaY",
            "teX",
            "teY",
            "n_classes",
            "full_dataset_valid",
            "full_dataset_train",
            "iters_per_epoch",
        )
        for attr in attrs:
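            # These attributes may be absent on some datasets, hence the None default.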
            setattr(self, attr, getattr(self.ref, attr, None))
        trX = [ds.trX for ds in datasets]
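        # One auxiliary label per training row: the index of the dataset the
        # row came from, shaped (N, 1) and with trX's dtype.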
        auxX = [np.zeros_like(tr[:, 0:1]) + idx for idx, tr in enumerate(trX)]
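        # Sharing the seed makes both subsamplers pick the same per-row
        # dataset sequence, keeping inputs and dataset labels aligned.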
        self.trX = JankySubsampler(trX, pmf, seed=seed)
        self.auxX = JankySubsampler(auxX, pmf, seed=seed)
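
JankySubsampler itself is defined elsewhere in datasets.py and its interface is not shown here. As a rough mental model only (the class name, `sample` method, and behavior below are assumptions, not the actual implementation), a pmf-weighted subsampler over parallel arrays might look like this:

    import numpy as np

    class PmfSubsamplerSketch:
        """Hypothetical stand-in for JankySubsampler: draws rows from several
        arrays in proportion to pmf, with a seeded RNG so that two instances
        created with the same seed make identical dataset choices."""

        def __init__(self, arrays, pmf, seed=42):
            assert len(arrays) == len(pmf)
            self.arrays = arrays
            self.pmf = np.asarray(pmf, dtype=np.float64)
            self.rng = np.random.RandomState(seed)

        def sample(self, n):
            # Pick a source dataset for each of the n rows, then a random
            # row within the chosen dataset.
            which = self.rng.choice(len(self.arrays), size=n, p=self.pmf)
            rows = [self.arrays[w][self.rng.randint(len(self.arrays[w]))]
                    for w in which]
            return np.stack(rows)

Under this model, the trX and auxX subsamplers built above with the same seed consume identical RNG streams (the parallel arrays have matching lengths), so every sampled input stays paired with the correct dataset-index label; that would explain why the constructor rejects seed=None.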