in scripts/datasets.py [0:0]
def __init__(self, dim_inv, dim_spu, n_envs):
    """Configure the cow/camel classification problem and its environments.

    Args:
        dim_inv: width of the foreground (animal) feature block.
        dim_spu: width of the background (context) feature block.
        n_envs: number of environments to register. Values below 2 leave
            ``self.envs`` empty. 2 and 3 use fixed parameters. Every
            environment past the third gets parameters drawn from torch's
            global RNG.

    Sets:
        scramble: ``torch.eye(dim_inv + dim_spu)``.
        dim_inv, dim_spu: the constructor arguments.
        dim: ``dim_inv + dim_spu``.
        task: the string ``"classification"``.
        envs: mapping ``"E<k>" -> {"p": float, "s": float}``, in creation order.
        snr_fg, snr_bg: ``1e-2`` and ``1``.
        avg_fg: ``(4, dim_inv)`` tensor whose rows are +1, +1, -1, -1.
        avg_bg: ``(4, dim_spu)`` tensor whose rows are +1, -1, -1, +1.
    """
    total = dim_inv + dim_spu
    self.scramble = torch.eye(total)
    self.dim_inv, self.dim_spu, self.dim = dim_inv, dim_spu, total
    self.task = "classification"

    # Fixed environments: (name, p, s, smallest n_envs that includes it).
    # NOTE(review): n_envs < 2 yields no environments at all; confirm no
    # caller relies on "E0" existing in that case.
    fixed_envs = (
        ("E0", 0.95, 0.3, 2),
        ("E1", 0.97, 0.5, 2),
        ("E2", 0.99, 0.7, 3),
    )
    self.envs = {
        name: {"p": p, "s": s}
        for name, p, s, needed in fixed_envs
        if n_envs >= needed
    }

    # Extra environments beyond the third get random parameters.
    if n_envs > 3:
        for k in range(3, n_envs):
            # p is drawn before s; keep this order so seeded runs reproduce.
            p_k = torch.zeros(1).uniform_(0.9, 1).item()
            s_k = torch.zeros(1).uniform_(0.3, 0.7).item()
            self.envs[f"E{k}"] = {"p": p_k, "s": s_k}

    # snr_fg is 100x smaller than snr_bg: the foreground is the noisier signal.
    self.snr_fg, self.snr_bg = 1e-2, 1

    # foreground (fg) denotes the animal (cow / camel): rows +, +, -, -
    fg_row = torch.ones(1, self.dim_inv)
    self.avg_fg = torch.cat([fg_row, fg_row, -fg_row, -fg_row])

    # background (bg) denotes the context (grass / sand): rows +, -, -, +
    bg_row = torch.ones(1, self.dim_spu)
    self.avg_bg = torch.cat([bg_row, -bg_row, -bg_row, bg_row])