in scripts/datasets.py [0:0]
def sample(self, n=1000, env="E0", split="train"):
    """Draw ``n`` scrambled examples and binary labels from environment ``env``.

    Each row is assigned to one of four groups by sampling, with
    replacement, from weights ``[p*s, (1-p)*s, p*(1-s), (1-p)*(1-s)]``.
    ``p`` and ``s`` are read from ``self.envs[env]``.

    Each row's features are built from the group's rows of ``self.avg_fg``
    and ``self.avg_bg``. Gaussian noise (std ``1/sqrt(10)``) is added to
    those rows, the results are scaled by ``self.snr_fg`` and
    ``self.snr_bg``, and the two blocks are concatenated column-wise.

    Args:
        n: Number of rows to draw.
        env: Key into ``self.envs``; the entry must provide ``"p"`` and
            ``"s"``.
        split: When ``"test"``, every column after the first
            ``self.dim_inv`` columns is permuted across rows. Any other
            value (e.g. the default ``"train"``) leaves the rows as drawn.

    Returns:
        A tuple ``(inputs, outputs)``:

        - ``inputs`` is the concatenated feature matrix right-multiplied by
          ``self.scramble``.
        - ``outputs`` is an ``(n, 1)`` float tensor. It is 1.0 where the
          first ``self.dim_inv`` unscrambled features sum to a positive
          value, and 0.0 otherwise.
    """
    env_cfg = self.envs[env]
    p, s = env_cfg["p"], env_cfg["s"]

    # Elementwise product yields [p*s, (1-p)*s, p*(1-s), (1-p)*(1-s)].
    p_weights = torch.Tensor([p, 1 - p, p, 1 - p])
    s_weights = torch.Tensor([s, s, 1 - s, 1 - s])
    group = torch.multinomial(p_weights * s_weights, n, replacement=True)

    # Both feature blocks index their mean tables with the same per-row
    # group, so avg_fg/avg_bg rows are picked in lockstep (both tables must
    # have at least 4 rows since group takes values 0..3).
    noise_scale = math.sqrt(10)
    fg = (
        torch.randn(n, self.dim_inv) / noise_scale + self.avg_fg[group]
    ) * self.snr_fg
    bg = (
        torch.randn(n, self.dim_spu) / noise_scale + self.avg_bg[group]
    ) * self.snr_bg
    x = torch.cat([fg, bg], dim=-1)

    if split == "test":
        # Permute only the trailing block across rows. This decouples it
        # from each row's group while leaving the first dim_inv columns
        # (and therefore the label) intact.
        row_order = torch.randperm(x.shape[0])
        x[:, self.dim_inv:] = x[row_order, self.dim_inv:]

    inputs = torch.matmul(x, self.scramble)
    # The label depends only on the unscrambled leading dim_inv columns.
    outputs = (x[:, :self.dim_inv].sum(dim=1, keepdim=True) > 0).float()
    return inputs, outputs