def sample()

in scripts/datasets.py [0:0]


    def sample(self, n=1000, env="E0", split="train"):
        """Draw ``n`` labelled samples from environment ``env``.

        Every sample is assigned one of four groups, drawn with replacement
        using the (unnormalised) weights
        ``[p*s, (1-p)*s, p*(1-s), (1-p)*(1-s)]``. Here ``p`` and ``s`` are
        read from ``self.envs[env]``.

        The invariant features are Gaussian noise with std ``1/sqrt(10)``
        around ``self.avg_fg[group]``, scaled by ``self.snr_fg``. The
        spurious features are built the same way from ``self.avg_bg`` and
        ``self.snr_bg``.

        Args:
            n: number of samples to draw.
            env: key into ``self.envs``. An unknown key raises ``KeyError``.
            split: when equal to ``"test"``, the spurious feature rows are
                randomly permuted across samples, decoupling them from each
                sample's group. Any other value leaves them as drawn.

        Returns:
            A pair ``(inputs, outputs)``:

            - ``inputs`` is the concatenation ``[invariant | spurious]``
              right-multiplied by ``self.scramble``.
            - ``outputs`` is an ``(n, 1)`` float tensor holding 1.0 where the
              invariant features of a row sum to a positive value, and 0.0
              otherwise.
        """
        env_params = self.envs[env]
        p, s = env_params["p"], env_params["s"]

        # Element-wise product [p, 1-p, p, 1-p] * [s, s, 1-s, 1-s], computed
        # on float tensors (not in Python floats) so the weights are unchanged.
        match_w = torch.Tensor([p, 1 - p] * 2)
        side_w = torch.Tensor([s] * 2 + [1 - s] * 2)
        groups = torch.multinomial(match_w * side_w, n, True)

        root_ten = math.sqrt(10)
        # RNG order matters for reproducibility: invariant noise is drawn
        # before spurious noise.
        invariant = (
            torch.randn(n, self.dim_inv) / root_ten + self.avg_fg[groups]
        ) * self.snr_fg
        spurious = (
            torch.randn(n, self.dim_spu) / root_ten + self.avg_bg[groups]
        ) * self.snr_bg

        if split == "test":
            # Shuffle spurious rows across samples (after both noise draws).
            spurious = spurious[torch.randperm(len(spurious))]

        features = torch.cat((invariant, spurious), -1)
        labels = (
            features[:, :self.dim_inv].sum(dim=1, keepdim=True) > 0
        ).float()
        return features @ self.scramble, labels