def __init__()

in scripts/datasets.py [0:0]


    def __init__(self, dim_inv, dim_spu, n_envs):
        """Record the problem sizes, build the environment table and draw weights.

        Args:
            dim_inv: Stored as ``self.dim_inv``. ``wxy`` is a
                ``dim_inv x dim_inv`` matrix and ``wyz`` has ``dim_inv`` rows.
            dim_spu: Stored as ``self.dim_spu``. ``wyz`` has ``dim_spu``
                columns.
            n_envs: How many environments to register in ``self.envs``.
                Values below 2 leave the table empty.
        """
        total_dim = dim_inv + dim_spu

        self.dim_inv = dim_inv
        self.dim_spu = dim_spu
        self.dim = total_dim
        self.task = "regression"

        # Identity matrix spanning all ``dim_inv + dim_spu`` coordinates.
        self.scramble = torch.eye(total_dim)

        # Environment name -> scalar. The first three entries are fixed
        # constants; the rest are sampled.
        # NOTE(review): the scalar presumably is a per-environment noise
        # variance -- confirm against the code that consumes ``self.envs``.
        env_table = {}
        if n_envs >= 2:
            env_table.update(E0=0.1, E1=1.5)
        if n_envs >= 3:
            env_table["E2"] = 2
        if n_envs > 3:
            for env_id in range(3, n_envs):
                # Exponent is uniform on [-2, 1), so the value is log-uniform
                # on [0.01, 10).
                exponent = torch.zeros(1).uniform_(-2, 1).item()
                env_table[f"E{env_id}"] = 10 ** exponent
        self.envs = env_table

        # Keep this draw order (wxy, then wyz, both after the environment
        # samples) so results under a fixed torch seed stay the same.
        self.wxy = torch.randn(dim_inv, dim_inv) / dim_inv
        self.wyz = torch.randn(dim_inv, dim_spu) / dim_spu