in train.py [0:0]
def get_data(hps, sess):
if hps.image_size == -1:
hps.image_size = {'mnist': 32, 'cifar10': 32, 'imagenet-oord': 64,
'imagenet': 256, 'celeba': 256, 'lsun_realnvp': 64, 'lsun': 256}[hps.problem]
if hps.n_test == -1:
hps.n_test = {'mnist': 10000, 'cifar10': 10000, 'imagenet-oord': 50000, 'imagenet': 50000,
'celeba': 3000, 'lsun_realnvp': 300*hvd.size(), 'lsun': 300*hvd.size()}[hps.problem]
hps.n_y = {'mnist': 10, 'cifar10': 10, 'imagenet-oord': 1000,
'imagenet': 1000, 'celeba': 1, 'lsun_realnvp': 1, 'lsun': 1}[hps.problem]
if hps.data_dir == "":
hps.data_dir = {'mnist': None, 'cifar10': None, 'imagenet-oord': '/mnt/host/imagenet-oord-tfr', 'imagenet': '/mnt/host/imagenet-tfr',
'celeba': '/mnt/host/celeba-reshard-tfr', 'lsun_realnvp': '/mnt/host/lsun_realnvp', 'lsun': '/mnt/host/lsun'}[hps.problem]
if hps.problem == 'lsun_realnvp':
hps.rnd_crop = True
else:
hps.rnd_crop = False
if hps.category:
hps.data_dir += ('/%s' % hps.category)
# Use anchor_size to rescale batch size based on image_size
s = hps.anchor_size
hps.local_batch_train = hps.n_batch_train * \
s * s // (hps.image_size * hps.image_size)
hps.local_batch_test = {64: 50, 32: 25, 16: 10, 8: 5, 4: 2, 2: 2, 1: 1}[
hps.local_batch_train] # round down to closest divisor of 50
hps.local_batch_init = hps.n_batch_init * \
s * s // (hps.image_size * hps.image_size)
print("Rank {} Batch sizes Train {} Test {} Init {}".format(
hvd.rank(), hps.local_batch_train, hps.local_batch_test, hps.local_batch_init))
if hps.problem in ['imagenet-oord', 'imagenet', 'celeba', 'lsun_realnvp', 'lsun']:
hps.direct_iterator = True
import data_loaders.get_data as v
train_iterator, test_iterator, data_init = \
v.get_data(sess, hps.data_dir, hvd.size(), hvd.rank(), hps.pmap, hps.fmap, hps.local_batch_train,
hps.local_batch_test, hps.local_batch_init, hps.image_size, hps.rnd_crop)
elif hps.problem in ['mnist', 'cifar10']:
hps.direct_iterator = False
import data_loaders.get_mnist_cifar as v
train_iterator, test_iterator, data_init = \
v.get_data(hps.problem, hvd.size(), hvd.rank(), hps.dal, hps.local_batch_train,
hps.local_batch_test, hps.local_batch_init, hps.image_size)
else:
raise Exception()
return train_iterator, test_iterator, data_init