# def get_data()
#
# in data_loaders/get_mnist_cifar.py [0:0]


def _load_mnist_cifar(problem):
    """Load the raw train/test arrays for ``problem`` through ``keras.datasets``.

    For 'mnist' the images are zero-padded by 2 pixels on each side of the
    two trailing axes, reshaped to ``(-1, 32, 32, 1)`` and tiled 3x along the
    last axis. For 'cifar10' the images are returned as loaded. Labels are
    flattened to 1-D in both cases.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``.

    Raises:
        ValueError: if ``problem`` is neither 'mnist' nor 'cifar10'.
    """
    if problem == 'mnist':
        # Imported lazily so keras is only required when data is requested.
        from keras.datasets import mnist
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
        # Pad with zeros to make 32x32. np.pad is the public entry point, and
        # mode='constant' makes the zero padding explicit.
        pad_width = ((0, 0), (2, 2), (2, 2))
        x_train = np.pad(x_train, pad_width, mode='constant', constant_values=0)
        x_test = np.pad(x_test, pad_width, mode='constant', constant_values=0)
        # Add a channel axis, then repeat it three times.
        x_train = np.tile(np.reshape(x_train, (-1, 32, 32, 1)), (1, 1, 1, 3))
        x_test = np.tile(np.reshape(x_test, (-1, 32, 32, 1)), (1, 1, 1, 3))
    elif problem == 'cifar10':
        from keras.datasets import cifar10
        (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    else:
        raise ValueError(
            "Unknown problem {!r}; expected 'mnist' or 'cifar10'".format(problem))

    y_train = np.reshape(y_train, [-1])
    y_test = np.reshape(y_test, [-1])
    return (x_train, y_train), (x_test, y_test)


def _make_train_datagen(problem, data_augmentation_level):
    """Build the training ``ImageDataGenerator`` for the requested augmentation.

    Level 0 yields a generator with no augmentation for any problem. For
    'mnist' every non-zero level uses width/height shifts of 0.1. For
    'cifar10' level 1 uses the same shifts and level 2 additionally enables
    horizontal flips, rotation, zoom and shear.

    Raises:
        ValueError: for a 'cifar10' level other than 0, 1 or 2, or for an
            unknown ``problem`` combined with a non-zero level.
    """
    from keras.preprocessing.image import ImageDataGenerator

    if data_augmentation_level == 0:
        return ImageDataGenerator()

    if problem == 'mnist':
        return ImageDataGenerator(
            width_shift_range=0.1,
            height_shift_range=0.1,
        )

    if problem == 'cifar10':
        if data_augmentation_level == 1:
            return ImageDataGenerator(
                width_shift_range=0.1,
                height_shift_range=0.1,
            )
        if data_augmentation_level == 2:
            return ImageDataGenerator(
                width_shift_range=0.1,
                height_shift_range=0.1,
                horizontal_flip=True,
                rotation_range=15,  # degrees rotation
                zoom_range=0.1,
                shear_range=0.02,
            )
        raise ValueError(
            "Unsupported data_augmentation_level {!r} for 'cifar10'; "
            "expected 0, 1 or 2".format(data_augmentation_level))

    raise ValueError(
        "Unknown problem {!r}; expected 'mnist' or 'cifar10'".format(problem))


def get_data(problem, shards, rank, data_augmentation_level, n_batch_train, n_batch_test, n_batch_init, resolution):
    """Create train/test batch iterators and an initialization batch.

    Args:
        problem: dataset name, 'mnist' or 'cifar10'.
        shards: forwarded to ``shard`` together with ``rank`` to select this
            worker's slice of the train and test sets.
        rank: forwarded to ``shard``; see ``shards``.
        data_augmentation_level: 0 disables augmentation; see
            ``_make_train_datagen`` for the other levels.
        n_batch_train: batch size of the training flow.
        n_batch_test: batch size of the test flow.
        n_batch_init: forwarded to ``make_batch`` when building ``data_init``.
        resolution: forwarded to ``downsample`` for every batch drawn.

    Returns:
        ``(train_iterator, test_iterator, data_init)``. Each iterator is a
        zero-argument callable returning the next ``(x, y)`` batch, where
        ``x`` has gone through ``downsample`` and ``x_to_uint8``.
        ``data_init`` is whatever ``make_batch`` returns for the training
        iterator.

    Raises:
        ValueError: for an unknown ``problem`` or an unsupported
            ``data_augmentation_level``.
    """
    (x_train, y_train), (x_test, y_test) = _load_mnist_cifar(problem)

    print('n_train:', x_train.shape[0], 'n_test:', x_test.shape[0])

    # Shard before any shuffling
    x_train, y_train = shard((x_train, y_train), shards, rank)
    x_test, y_test = shard((x_test, y_test), shards, rank)

    print('n_shard_train:', x_train.shape[0], 'n_shard_test:', x_test.shape[0])

    from keras.preprocessing.image import ImageDataGenerator
    datagen_test = ImageDataGenerator()
    datagen_train = _make_train_datagen(problem, data_augmentation_level)

    datagen_train.fit(x_train)
    datagen_test.fit(x_test)
    train_flow = datagen_train.flow(x_train, y_train, batch_size=n_batch_train)
    # The test flow is not shuffled, so batches come out in dataset order.
    test_flow = datagen_test.flow(
        x_test, y_test, batch_size=n_batch_test, shuffle=False)

    def make_iterator(flow, resolution):
        """Wrap ``flow`` in a callable that returns one processed batch."""
        def iterator():
            x_full, y = flow.next()
            x_full = x_full.astype(np.float32)
            x = downsample(x_full, resolution)
            x = x_to_uint8(x)
            return x, y

        return iterator

    train_iterator = make_iterator(train_flow, resolution)
    test_iterator = make_iterator(test_flow, resolution)

    # Get data for initialization
    data_init = make_batch(train_iterator, n_batch_train, n_batch_init)

    return train_iterator, test_iterator, data_init