def get_train_data()

in src/train.py [0:0]


def get_train_data(train, train_buff, batch_size, anomalyNumber, validNumber):
    """Load the training arrays and split them into normal and anomaly sets.

    Reads ``train_x.npy`` and ``train_y.npy`` from the directory ``train``,
    rescales the inputs, partitions the samples by label and wraps each
    partition's inputs in a shuffled, batched ``tf.data.Dataset``. The shape
    of every partition is printed to stdout.

    Args:
        train: Directory containing ``train_x.npy`` and ``train_y.npy``
            (a string or any ``os.PathLike``).
        train_buff: Buffer size passed to ``Dataset.shuffle``.
        batch_size: Batch size passed to ``Dataset.batch``.
        anomalyNumber: Label value compared with ``==`` against the labels;
            matching samples form the anomaly set.
        validNumber: Label value(s) tested with ``np.isin``; matching samples
            form the normal set.

    Returns:
        A 6-tuple ``(train_x_normal, train_x_normal_dataset, train_y_normal,
        train_x_anomaly, train_x_anomaly_dataset, train_y_anomaly)``. The
        ``*_dataset`` entries contain inputs only (no labels).
    """
    # Imported locally so the block does not depend on the file's import list.
    import os

    train_x = np.load(os.path.join(train, 'train_x.npy'))
    train_y = np.load(os.path.join(train, 'train_y.npy'))

    # Scale inputs to float32 in [0, 1].
    # NOTE(review): assumes the stored values are 8-bit (0-255) -- confirm.
    train_x = train_x.astype('float32') / 255

    # Index 0 of np.where is the sample (first) axis, so these are row indices
    # even if the label array carries extra dimensions.
    train_validIdxs = np.where(np.isin(train_y, validNumber))[0]
    train_anomalyIdxs = np.where(train_y == anomalyNumber)[0]

    train_x_normal = train_x[train_validIdxs]
    train_y_normal = train_y[train_validIdxs]

    train_x_anomaly = train_x[train_anomalyIdxs]
    train_y_anomaly = train_y[train_anomalyIdxs]

    print('train normal x: ', np.shape(train_x_normal))
    print('train normal y: ', np.shape(train_y_normal))
    print('train anomaly x: ', np.shape(train_x_anomaly))
    print('train anomaly y: ', np.shape(train_y_anomaly))

    # Shuffle before batching so each batch is drawn from shuffled samples.
    train_x_normal_dataset = (
        tf.data.Dataset.from_tensor_slices(train_x_normal)
        .shuffle(train_buff)
        .batch(batch_size)
    )
    train_x_anomaly_dataset = (
        tf.data.Dataset.from_tensor_slices(train_x_anomaly)
        .shuffle(train_buff)
        .batch(batch_size)
    )

    return (
        train_x_normal,
        train_x_normal_dataset,
        train_y_normal,
        train_x_anomaly,
        train_x_anomaly_dataset,
        train_y_anomaly,
    )