in scripts/datasets.py [0:0]
def sample(self, n=1000, env="E0", split="train"):
    """Draw ``n`` synthetic (inputs, outputs) pairs from environment ``env``.

    Generative process, as implemented below:

    * ``cause``  = standard-normal noise of shape ``(n, self.dim_inv)``
      multiplied by ``self.envs[env]``.
    * ``target`` = ``cause @ self.wxy`` plus fresh noise scaled by the same
      ``self.envs[env]`` factor.
    * ``effect`` = ``target @ self.wyz`` plus unit-scale standard-normal noise
      of shape ``(n, self.dim_spu)``.
    * When ``split == "test"``, the rows of ``effect`` are randomly permuted,
      so each sample's spurious features no longer come from its own target.

    Args:
        n: Number of samples (rows) to draw.
        env: Key into ``self.envs`` selecting the noise scale
            (presumably a dict of env-name -> float; confirm against caller).
        split: ``"test"`` shuffles the spurious block; any other value
            leaves it aligned with the target.

    Returns:
        ``(inputs, outputs)``. ``inputs`` is the column-concatenation
        ``[cause, effect]`` right-multiplied by ``self.scramble``, and
        ``outputs`` is the row-wise sum of ``target`` with shape ``(n, 1)``.

    Note:
        All randomness comes from torch's global RNG. The draws happen in
        a fixed order (cause noise, target noise, effect noise, then the
        optional permutation), so results are reproducible under
        ``torch.manual_seed``.
    """
    noise_scale = self.envs[env]

    # Invariant block and latent target share the environment's noise scale.
    cause = torch.randn(n, self.dim_inv) * noise_scale
    target_noise = torch.randn(n, self.dim_inv) * noise_scale
    target = cause @ self.wxy + target_noise

    # Spurious block is driven by the target, with unit-scale noise.
    effect = target @ self.wyz + torch.randn(n, self.dim_spu)

    # Test split: decorrelate the spurious block from each row's target.
    if split == "test":
        shuffled_rows = torch.randperm(effect.shape[0])
        effect = effect[shuffled_rows]

    features = torch.cat([cause, effect], dim=-1)
    inputs = features @ self.scramble
    outputs = target.sum(dim=1, keepdim=True)
    return inputs, outputs