def get_train_data()

in cifar_train.py [0:0]


def get_train_data(train_dir):
    """Load the saved training features and labels from *train_dir*.

    Reads ``X_train.npy`` and then ``y_train.npy`` from the given directory
    using ``np.load``, logs the shape of each, and returns them.

    Args:
        train_dir: Path of the directory holding ``X_train.npy`` and
            ``y_train.npy``.

    Returns:
        A tuple ``(X_train, y_train)`` with the objects returned by
        ``np.load`` for the two files, in that order.

    Raises:
        Any exception raised by ``np.load`` (for example, when one of the
        files is missing) propagates to the caller unchanged.
    """
    features_path = os.path.join(train_dir, 'X_train.npy')
    labels_path = os.path.join(train_dir, 'y_train.npy')

    # Tuple expressions evaluate left to right, so features load before labels.
    features, labels = np.load(features_path), np.load(labels_path)

    logger.info(f'X_train: {features.shape} | y_train: {labels.shape}')
    return features, labels