def sample()

in scripts/datasets.py [0:0]


    def sample(self, n=1000, env="E0", split="train"):
        """Draw a labelled two-class batch of invariant + shortcut features.

        ``n // 2`` rows are generated per class, so an odd ``n`` yields
        ``n - 1`` rows in total. Each row is the concatenation of:

        * ``self.dim_inv`` invariant features: standard-normal noise scaled
          by 0.1, centred at ``+0.1`` for the first class and ``-0.1`` for
          the second;
        * ``self.dim_spu`` shortcut features: standard-normal noise scaled
          by 0.1, centred at ``+self.envs[env]`` for the first class and
          ``-self.envs[env]`` for the second.

        When ``split == "test"`` the shortcut columns are shuffled across
        rows with a random permutation, while the invariant columns and the
        labels stay in place. Any other ``split`` value leaves the rows
        untouched. Finally the feature matrix is right-multiplied by
        ``self.scramble``.

        Args:
            n: Requested batch size; ``n // 2`` rows are drawn per class.
            env: Key into ``self.envs`` selecting the shortcut centre
                (must broadcast against ``(n // 2, self.dim_spu)``).
            split: ``"test"`` enables the shortcut shuffle.

        Returns:
            Tuple ``(inputs, outputs)``: the scrambled features and a
            column of labels, zeros for the first ``n // 2`` rows and ones
            for the remaining ``n // 2`` rows.
        """
        per_class = n // 2
        margin = .1
        noise_scale = .1

        # Random draws happen in a fixed order (invariant class 0, invariant
        # class 1, shortcut class 0, shortcut class 1) so that seeded runs
        # are reproducible.
        centre = torch.Tensor([[margin] * self.dim_inv])
        inv_first = torch.randn(per_class, self.dim_inv) * noise_scale + centre
        inv_second = torch.randn(per_class, self.dim_inv) * noise_scale - centre

        spu_first = (torch.randn(per_class, self.dim_spu) * noise_scale
                     + self.envs[env])
        spu_second = (torch.randn(per_class, self.dim_spu) * noise_scale
                      - self.envs[env])

        # Stack the two classes row-wise, keeping the column groups apart
        # until the optional shuffle has been applied.
        invariant = torch.cat((inv_first, inv_second), dim=0)
        shortcuts = torch.cat((spu_first, spu_second), dim=0)

        if split == "test":
            # Reassign shortcut rows at random; invariant columns and labels
            # are not permuted.
            shortcuts = shortcuts[torch.randperm(len(shortcuts))]

        features = torch.cat((invariant, shortcuts), dim=-1)
        inputs = features @ self.scramble

        labels = (torch.zeros(per_class, 1), torch.ones(per_class, 1))
        outputs = torch.cat(labels, dim=0)

        return inputs, outputs