in code/colored_mnist/main.py [0:0]
def make_environment(images, labels, e, label_noise=0.25, device="cuda"):
    """Build one Colored-MNIST environment from MNIST digits.

    Each image is downsampled 2x and duplicated into two channels. One of
    those channels is then zeroed to "color" the digit. The binary label is
    ``digit < 5``, flipped with probability ``label_noise``. The color equals
    that (noisy) label, flipped with probability ``e``, so ``e`` sets how
    strongly color tracks the label in this environment.

    Args:
        images: Tensor reshapeable to ``(-1, 28, 28)``. Presumably raw pixel
            intensities in [0, 255] (e.g. uint8), since the output is divided
            by 255 -- confirm against the caller. The caller's tensor is not
            modified.
        labels: Tensor of digit classes, one per image.
        e: Probability of flipping the color relative to the (noisy) label.
        label_noise: Probability of flipping the binary label. The default
            0.25 is the value that used to be hard-coded.
        device: Device the returned tensors are moved to. The default
            ``"cuda"`` keeps the original behavior; pass ``"cpu"`` to run
            without a GPU.

    Returns:
        dict with ``'images'``: float tensor of shape ``(N, 2, 14, 14)``
        scaled by 1/255, and ``'labels'``: float tensor of shape ``(N, 1)``
        with values in {0, 1}.
    """
    # Draw noise on the labels' device so the XOR below never mixes devices.
    # On CPU this uses the same default generator as a plain torch.rand call.
    noise_device = labels.device

    def torch_bernoulli(p, size):
        return (torch.rand(size, device=noise_device) < p).float()

    def torch_xor(a, b):
        return (a - b).abs()  # Assumes both inputs are either 0 or 1

    # 2x subsample for computational convenience
    images = images.reshape((-1, 28, 28))[:, ::2, ::2]
    # Assign a binary label based on the digit; flip it with probability label_noise
    labels = (labels < 5).float()
    labels = torch_xor(labels, torch_bernoulli(label_noise, len(labels)))
    # Assign a color based on the label; flip the color with probability e
    colors = torch_xor(labels, torch_bernoulli(e, len(labels)))
    # Apply the color by zeroing out the other channel. torch.stack allocates
    # a new tensor, so the in-place write below never touches the caller's data.
    images = torch.stack([images, images], dim=1)
    rows = torch.arange(len(images), device=images.device)
    images[rows, (1 - colors).long().to(images.device)] = 0
    return {
        'images': (images.float() / 255.).to(device),
        'labels': labels[:, None].to(device),
    }