def cifar10()

in datasets.py [0:0]


def cifar10(data_dir, one_hot=True, test_size=None, random_state=11172018):
    """Load CIFAR-10 batches from ``data_dir`` and split off a validation set.

    Reads ``data_batch_1`` .. ``data_batch_5`` and ``test_batch`` from
    ``data_dir`` via ``unpickle_cifar10``; each unpickled batch is indexed with
    the ``'data'`` and ``'labels'`` keys. A validation split is carved out of
    the five training batches with ``train_test_split``.

    Args:
        data_dir: Directory containing the pickled batch files.
        one_hot: If True, labels are returned as float32 one-hot rows of
            width 10. Otherwise they are reshaped to a column of shape
            ``(N, 1)``.
        test_size: Size of the validation split, forwarded to
            ``train_test_split``. ``None`` (the default) means 5000.
        random_state: Seed forwarded to ``train_test_split`` so the split is
            reproducible. Defaults to the historical value 11172018.

    Returns:
        ``(trX, trY), (vaX, vaY), (teX, teY)``. Each image array has one row
        of 3072 values per example. Each row is reordered from the raw
        (3, 32, 32) layout to (32, 32, 3) before being flattened again.
    """
    if test_size is None:
        # Substitute the default only when nothing was passed; an explicit
        # value (even a falsy one) is forwarded unchanged.
        test_size = 5000

    def _load(name):
        # Unpickle one batch file from data_dir.
        return unpickle_cifar10(os.path.join(data_dir, name))

    def _to_pixel_major(x):
        # View each flat row as (3, 32, 32), move the size-3 axis last, and
        # flatten back to 3072 values so rows are laid out as (32, 32, 3).
        return x.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).reshape([-1, 3072])

    def _encode_labels(y):
        # Encode a label vector either as one-hot float32 rows or as a column.
        if one_hot:
            # NOTE(review): assumes integer labels in [0, 10) — confirm
            # against the data files.
            return np.eye(10, dtype=np.float32)[y]
        return np.reshape(y, [-1, 1])

    tr_data = [_load('data_batch_%d' % i) for i in range(1, 6)]
    # np.vstack requires a sequence (list/tuple). Passing a generator was
    # deprecated in NumPy 1.16 and is an error in recent releases.
    trX = np.vstack([data['data'] for data in tr_data])
    trY = np.asarray(flatten([data['labels'] for data in tr_data]))

    te_data = _load('test_batch')
    teX = np.asarray(te_data['data'])
    teY = np.asarray(te_data['labels'])

    trX = _to_pixel_major(trX)
    teX = _to_pixel_major(teX)

    trX, vaX, trY, vaY = train_test_split(
        trX, trY, test_size=test_size, random_state=random_state)

    return ((trX, _encode_labels(trY)),
            (vaX, _encode_labels(vaY)),
            (teX, _encode_labels(teY)))