data_loaders/get_mnist_cifar.py:
import numpy as np


def downsample(x, resolution):
    # Block-average an NHWC float32 batch down to (resolution, resolution).
    assert x.dtype == np.float32
    assert x.shape[1] % resolution == 0
    assert x.shape[2] % resolution == 0
    if x.shape[1] == x.shape[2] == resolution:
        return x
    s = x.shape
    x = np.reshape(x, [s[0], resolution, s[1] // resolution,
                       resolution, s[2] // resolution, s[3]])
    x = np.mean(x, (2, 4))
    return x
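
# Usage sketch (illustrative shapes, not from this repo): downsample() mean-pools
# square NHWC float32 images, e.g. 32x32 -> 16x16 by averaging 2x2 blocks.
#
#   batch = np.random.rand(8, 32, 32, 3).astype(np.float32)
#   small = downsample(batch, 16)    # shape (8, 16, 16, 3)
#   same = downsample(small, 16)     # already at target resolution: returned unchanged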


def x_to_uint8(x):
    x = np.clip(np.floor(x), 0, 255)
    return x.astype(np.uint8)
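
# Usage sketch (values are illustrative): x_to_uint8() floors, clips to [0, 255],
# and casts, turning the float pixel means from downsample() into storable bytes.
#
#   x_to_uint8(np.array([[-3.2, 17.9, 300.0]], dtype=np.float32))
#   # -> array([[  0,  17, 255]], dtype=uint8)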


def shard(data, shards, rank):
    # Deterministic shards
    x, y = data
    assert x.shape[0] == y.shape[0]
    assert x.shape[0] % shards == 0
    assert 0 <= rank < shards
    size = x.shape[0] // shards
    ind = rank * size
    return x[ind:ind + size], y[ind:ind + size]
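
# Usage sketch (shard/rank values are illustrative, e.g. as supplied by MPI/Horovod):
# every rank receives the same contiguous, non-overlapping slice on every run.
#
#   x = np.arange(60000 * 4).reshape(60000, 4)
#   y = np.arange(60000)
#   x_r, y_r = shard((x, y), shards=4, rank=1)   # rows 15000..29999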


def get_data(problem, shards, rank, data_augmentation_level,
             n_batch_train, n_batch_test, n_batch_init, resolution):
    if problem == 'mnist':
        from keras.datasets import mnist
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
        y_train = np.reshape(y_train, [-1])
        y_test = np.reshape(y_test, [-1])
        # Pad 28x28 digits with zeros to make them 32x32
        # ('minimum' padding uses the array minimum, which is 0 for MNIST)
        x_train = np.lib.pad(x_train, ((0, 0), (2, 2), (2, 2)), 'minimum')
        x_test = np.lib.pad(x_test, ((0, 0), (2, 2), (2, 2)), 'minimum')
        # Replicate the grayscale channel into 3 channels
        x_train = np.tile(np.reshape(x_train, (-1, 32, 32, 1)), (1, 1, 1, 3))
        x_test = np.tile(np.reshape(x_test, (-1, 32, 32, 1)), (1, 1, 1, 3))
    elif problem == 'cifar10':
        from keras.datasets import cifar10
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()
        y_train = np.reshape(y_train, [-1])
        y_test = np.reshape(y_test, [-1])
    else:
        raise Exception()

    print('n_train:', x_train.shape[0], 'n_test:', x_test.shape[0])

    # Shard before any shuffling
    x_train, y_train = shard((x_train, y_train), shards, rank)
    x_test, y_test = shard((x_test, y_test), shards, rank)
    print('n_shard_train:', x_train.shape[0], 'n_shard_test:', x_test.shape[0])

    from keras.preprocessing.image import ImageDataGenerator
    datagen_test = ImageDataGenerator()
    if data_augmentation_level == 0:
        datagen_train = ImageDataGenerator()
    else:
        if problem == 'mnist':
            datagen_train = ImageDataGenerator(
                width_shift_range=0.1,
                height_shift_range=0.1
            )
        elif problem == 'cifar10':
            if data_augmentation_level == 1:
                datagen_train = ImageDataGenerator(
                    width_shift_range=0.1,
                    height_shift_range=0.1
                )
            elif data_augmentation_level == 2:
                datagen_train = ImageDataGenerator(
                    width_shift_range=0.1,
                    height_shift_range=0.1,
                    horizontal_flip=True,
                    rotation_range=15,  # degrees rotation
                    zoom_range=0.1,
                    shear_range=0.02,
                )
            else:
                raise Exception()
        else:
            raise Exception()

    datagen_train.fit(x_train)
    datagen_test.fit(x_test)

    train_flow = datagen_train.flow(x_train, y_train, n_batch_train)
    test_flow = datagen_test.flow(x_test, y_test, n_batch_test, shuffle=False)

    def make_iterator(flow, resolution):
        def iterator():
            x_full, y = flow.next()
            x_full = x_full.astype(np.float32)
            x = downsample(x_full, resolution)
            x = x_to_uint8(x)
            return x, y
        return iterator

    # init_iterator = make_iterator(train_flow, resolution)
    train_iterator = make_iterator(train_flow, resolution)
    test_iterator = make_iterator(test_flow, resolution)

    # Get data for initialization
    data_init = make_batch(train_iterator, n_batch_train, n_batch_init)

    return train_iterator, test_iterator, data_init
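
# Usage sketch (parameter values below are assumptions, not this repo's defaults):
# get_data() returns two zero-argument iterators plus an init batch; each iterator
# call yields one (uint8 images, labels) minibatch from the Keras flow.
#
#   train_iter, test_iter, data_init = get_data(
#       problem='cifar10', shards=1, rank=0, data_augmentation_level=1,
#       n_batch_train=64, n_batch_test=50, n_batch_init=256, resolution=32)
#   x, y = train_iter()              # x: (64, 32, 32, 3) uint8, y: (64,) labels
#   assert data_init['x'].shape[0] == 256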


def make_batch(iterator, iterator_batch_size, required_batch_size):
    ib, rb = iterator_batch_size, required_batch_size
    # assert rb % ib == 0
    k = int(np.ceil(rb / ib))
    xs, ys = [], []
    for i in range(k):
        x, y = iterator()
        xs.append(x)
        ys.append(y)
    x, y = np.concatenate(xs)[:rb], np.concatenate(ys)[:rb]
    return {'x': x, 'y': y}
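
# Usage sketch (the toy iterator is an assumption for illustration): make_batch()
# draws ceil(required / iterator_batch) minibatches, concatenates them, and trims
# to exactly required_batch_size rows.
#
#   def toy_iterator():
#       return np.zeros((8, 32, 32, 3), np.uint8), np.zeros(8, np.int64)
#   batch = make_batch(toy_iterator, iterator_batch_size=8, required_batch_size=20)
#   # batch['x'].shape == (20, 32, 32, 3), batch['y'].shape == (20,)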