in data.py [0:0]
def set_up_data(H):
    """Load the dataset named by ``H.dataset`` and build its preprocessing.

    Picks per-dataset normalization constants for the model input and for the
    loss target, and records image geometry on ``H``.

    Args:
        H: hyperparameter namespace. Reads ``dataset``, ``data_root`` and
            ``test_eval``; writes ``image_size`` and ``image_channels``.

    Returns:
        ``(H, train_data, valid_data, preprocess_func)``. ``valid_data`` wraps
        the test split when ``H.test_eval`` is true, otherwise the validation
        split. ``preprocess_func`` maps a batch to ``(inp, out)``.

    Raises:
        ValueError: if ``H.dataset`` is not a recognised dataset name.
    """
    # Default loss-target mapping: (x - 127.5) / 127.5 sends [0, 255] to
    # [-1, 1]. Assumes the loaders yield raw 8-bit pixel values -- TODO confirm.
    shift_loss = -127.5
    scale_loss = 1. / 127.5
    if H.dataset == 'imagenet32':
        trX, vaX, teX = imagenet32(H.data_root)
        H.image_size = 32
        H.image_channels = 3
        shift = -116.2373
        scale = 1. / 69.37404
    elif H.dataset == 'imagenet64':
        trX, vaX, teX = imagenet64(H.data_root)
        H.image_size = 64
        H.image_channels = 3
        shift = -115.92961967
        scale = 1. / 69.37404
    elif H.dataset == 'ffhq_256':
        trX, vaX, teX = ffhq256(H.data_root)
        H.image_size = 256
        H.image_channels = 3
        shift = -112.8666757481
        scale = 1. / 69.84780273
    elif H.dataset == 'ffhq_1024':
        trX, vaX, teX = ffhq1024(H.data_root)
        H.image_size = 1024
        H.image_channels = 3
        # This split is read through ImageFolder + ToTensor (see below), which
        # yields floats in [0, 1]. All constants here are on that scale, and
        # (x - 0.5) * 2 maps the loss target from [0, 1] to [-1, 1].
        shift = -0.4387
        scale = 1.0 / 0.2743
        shift_loss = -0.5
        scale_loss = 2.0
    elif H.dataset == 'cifar10':
        (trX, _), (vaX, _), (teX, _) = cifar10(H.data_root, one_hot=False)
        H.image_size = 32
        H.image_channels = 3
        shift = -120.63838
        scale = 1. / 64.16736
    else:
        raise ValueError(f'unknown dataset: {H.dataset}')

    # Only ffhq_256 has its loss target quantized to 5 bits.
    do_low_bit = H.dataset == 'ffhq_256'

    if H.test_eval:
        print('DOING TEST')
        eval_dataset = teX
    else:
        eval_dataset = vaX

    def _as_broadcast_tensor(value):
        # Shape (1, 1, 1, 1) so the constant broadcasts over a 4-D batch.
        return torch.tensor([value]).cuda().view(1, 1, 1, 1)

    shift = _as_broadcast_tensor(shift)
    scale = _as_broadcast_tensor(scale)
    shift_loss = _as_broadcast_tensor(shift_loss)
    scale_loss = _as_broadcast_tensor(scale_loss)

    if H.dataset == 'ffhq_1024':
        # For this dataset trX / eval_dataset are directories consumed by
        # ImageFolder; the other loaders return in-memory arrays.
        train_data = ImageFolder(trX, transforms.ToTensor())
        valid_data = ImageFolder(eval_dataset, transforms.ToTensor())
        untranspose = True
    else:
        train_data = TensorDataset(torch.as_tensor(trX))
        valid_data = TensorDataset(torch.as_tensor(eval_dataset))
        untranspose = False

    def preprocess_func(x):
        """Take a data batch and return ``(inp, out)``.

        ``x[0]`` is the image tensor. ``inp`` is the normalized model input
        and ``out`` is the input processed for the loss. The caller's ``x`` is
        not modified.
        """
        images = x[0]
        if untranspose:
            # ToTensor produces channel-first images, so the batch is
            # (N, C, H, W); move channels last. The in-memory datasets are
            # presumably already (N, H, W, C) -- confirm against the loaders.
            images = images.permute(0, 2, 3, 1)
        inp = images.cuda(non_blocking=True).float()
        out = inp.clone()
        inp.add_(shift).mul_(scale)
        if do_low_bit:
            # 5 bits of precision: round down to multiples of 8 (32 levels).
            out.mul_(1. / 8.).floor_().mul_(8.)
        out.add_(shift_loss).mul_(scale_loss)
        return inp, out

    return H, train_data, valid_data, preprocess_func