in scripts/datasets.py [0:0]
def sample(self, n=1000, env="E0", split="train"):
    """Draw ``n`` labelled points from environment ``env``.

    Each point concatenates ``self.dim_inv`` invariant features with
    ``self.dim_spu`` shortcut (spurious) features.
    - Class-0 points have invariant features centred at ``+sep`` and
      shortcut features centred at ``+self.envs[env]``.
    - Class-1 points use the negated centres for both feature groups.
    - Every feature gets Gaussian noise with standard deviation 0.1.

    Args:
        n: total number of points to return. Class 0 receives
            ``n - n // 2`` points and class 1 receives ``n // 2`` points,
            so exactly ``n`` rows come back even when ``n`` is odd.
        env: key into ``self.envs`` selecting the shortcut centre.
        split: when ``"test"``, the shortcut features are shuffled across
            rows, breaking their alignment with the labels. Any other value
            leaves them aligned.

    Returns:
        ``(inputs, outputs)``:
        - ``inputs`` is ``x @ self.scramble``, where ``x`` has shape
          ``(n, self.dim_inv + self.dim_spu)``.
        - ``outputs`` is an ``(n, 1)`` float tensor of labels: zeros for
          the first ``n - n // 2`` rows and ones for the remaining rows.
    """
    # Previously both classes used n // 2 rows, silently dropping one point
    # for odd n. For even n the row counts, and therefore the sequence of
    # random draws, are exactly as before.
    m_0 = n - n // 2
    m_1 = n // 2
    sep = .1
    noise = .1

    invariant_centre = torch.full((1, self.dim_inv), sep)
    # NOTE(review): assumes self.envs[env] broadcasts against
    # (rows, self.dim_spu), e.g. shape (1, dim_spu). TODO: confirm.
    shortcut_centre = self.envs[env]

    # Draw order matches the original (inv_0, inv_1, spu_0, spu_1) so that
    # seeded runs stay reproducible.
    invariant_0 = torch.randn(m_0, self.dim_inv) * noise + invariant_centre
    invariant_1 = torch.randn(m_1, self.dim_inv) * noise - invariant_centre
    shortcuts_0 = torch.randn(m_0, self.dim_spu) * noise + shortcut_centre
    shortcuts_1 = torch.randn(m_1, self.dim_spu) * noise - shortcut_centre

    # Class-0 rows first, then class-1 rows; columns are [invariant | shortcut].
    x = torch.cat((torch.cat((invariant_0, shortcuts_0), -1),
                   torch.cat((invariant_1, shortcuts_1), -1)))

    if split == "test":
        # Permute only the shortcut columns across rows. The right-hand
        # side uses advanced indexing and is a copy, so the in-place
        # assignment does not alias.
        x[:, self.dim_inv:] = x[torch.randperm(len(x)), self.dim_inv:]

    inputs = x @ self.scramble
    outputs = torch.cat((torch.zeros(m_0, 1), torch.ones(m_1, 1)))
    return inputs, outputs