def sample()

in scripts/datasets.py [0:0]


    def sample(self, n=1000, env="E0", split="train"):
        """Draw ``n`` (inputs, outputs) pairs from environment ``env``.

        With ``sigma = self.envs[env]`` the generative chain is:

        * ``x = sigma * N(0, 1)`` with shape ``(n, self.dim_inv)``;
        * ``y = x @ self.wxy + sigma * N(0, 1)``;
        * ``z = y @ self.wyz + N(0, 1)`` with ``self.dim_spu`` columns.
          The noise on ``z`` is not scaled by ``sigma``.

        When ``split == "test"`` the rows of ``z`` are randomly permuted. This
        breaks the per-sample link between ``z`` and ``y`` while ``x`` and ``y``
        stay aligned. Any other ``split`` value leaves ``z`` untouched.

        Args:
            n: Number of samples to draw.
            env: Key into ``self.envs``; its value multiplies the noise on x and y.
            split: ``"test"`` shuffles ``z`` across rows; anything else does not.

        Returns:
            Tuple ``(inputs, outputs)``:

            * ``inputs`` is the column-wise concatenation ``[x, z]`` multiplied
              on the right by ``self.scramble``.
            * ``outputs`` is the row-wise sum of ``y``, with a kept trailing
              dimension of size 1.
        """
        sigma = self.envs[env]
        n_inv, n_spu = self.dim_inv, self.dim_spu

        # The three draws (and the optional permutation below) happen in a fixed
        # order, so a seeded generator always yields the same samples.
        cause = torch.randn(n, n_inv) * sigma
        target = cause @ self.wxy + sigma * torch.randn(n, n_inv)
        spurious = target @ self.wyz + torch.randn(n, n_spu)

        if split == "test":
            # Decouple the spurious block from the target row-by-row.
            order = torch.randperm(spurious.shape[0])
            spurious = spurious[order]

        features = torch.cat([cause, spurious], dim=-1)
        return features @ self.scramble, target.sum(dim=1, keepdim=True)