in scripts/datasets.py [0:0]
def __init__(self, dim_inv, dim_spu, n_envs):
    """Set up a regression example over ``dim_inv + dim_spu`` feature dims.

    Attributes populated on ``self``:
      * ``scramble``: a ``(dim_inv + dim_spu)``-sized identity matrix;
      * ``dim_inv``, ``dim_spu`` and their sum ``dim``;
      * ``task``: the string ``"regression"``;
      * ``envs``: dict mapping ``"E<k>"`` to one scalar per environment
        (presumably a noise variance -- TODO confirm against the sampling
        code). ``E0=0.1`` and ``E1=1.5`` are added when ``n_envs >= 2``,
        ``E2=2`` when ``n_envs >= 3``, and each further environment ``k``
        (``3 <= k < n_envs``) gets ``10 ** u`` with ``u`` drawn uniformly
        from ``[-2, 1)``. With ``n_envs < 2`` the dict stays empty.
      * ``wxy``: ``randn(dim_inv, dim_inv) / dim_inv``;
      * ``wyz``: ``randn(dim_inv, dim_spu) / dim_spu``.

    Randomness comes from torch's global RNG, consumed in this order: the
    extra environment values, then ``wxy``, then ``wyz``. Seeding torch
    beforehand therefore makes the instance reproducible.
    """
    total_dim = dim_inv + dim_spu
    self.scramble = torch.eye(total_dim)
    self.dim_inv = dim_inv
    self.dim_spu = dim_spu
    self.dim = total_dim
    self.task = "regression"

    envs = {}
    if n_envs >= 2:
        envs.update({"E0": 0.1, "E1": 1.5})
    if n_envs >= 3:
        envs["E2"] = 2
    if n_envs > 3:
        # Remaining environments: log-uniform scalar in [1e-2, 1e1).
        for idx in range(3, n_envs):
            exponent = torch.zeros(1).uniform_(-2, 1).item()
            envs[f"E{idx}"] = 10 ** exponent
    self.envs = envs

    n_inv, n_spu = self.dim_inv, self.dim_spu
    self.wxy = torch.randn(n_inv, n_inv) / n_inv
    self.wyz = torch.randn(n_inv, n_spu) / n_spu