def main()

in src/mlmax/preprocessing.py [0:0]


def main(args):
    """Run the preprocessing step in either ``"train"`` or ``"infer"`` mode.

    Reads the raw input at ``os.path.join(args.data_dir, args.data_input)``
    and writes feature (and, where available, label) CSVs via ``write_data``.
    The paths are relative; the local-run example below suggests they land
    under ``args.data_dir`` — confirm against ``write_data``.

    * ``"train"``: split the data, fit the preprocessing on the training split
      only, transform both splits, and write ``train/train_features.csv``,
      ``train/train_labels.csv``, ``test/test_features.csv`` and
      ``test/test_labels.csv``.
    * ``"infer"``: transform the whole input without fitting anything here,
      and write ``test/test_features.csv``. ``test/test_labels.csv`` is
      written only when the target column is present in the input.

    To run locally::

        DATA=s3://sagemaker-sample-data-us-east-1/processing/census/census-income.csv
        aws s3 cp $DATA /tmp/input/
        mkdir /tmp/{train,test,model}
        python preprocessing.py --mode "train" --data-dir /tmp

    Args:
        args: Parsed command-line namespace. This function reads ``mode``,
            ``data_dir`` and ``data_input`` and passes ``args`` on to the
            helper functions.

    Returns:
        In ``"infer"`` mode, the transformed features. In ``"train"`` mode, a
        ``(train_features, test_features)`` tuple.

    Raises:
        ValueError: If ``args.mode`` is neither ``"train"`` nor ``"infer"``.
    """
    if args.mode not in ("train", "infer"):
        # Fail fast, before reading any data. Previously an unrecognised mode
        # fell through both branches and silently returned None, no output.
        raise ValueError(
            f"Unsupported mode {args.mode!r}; expected 'train' or 'infer'."
        )

    input_data_path = os.path.join(args.data_dir, args.data_input)
    df = read_data(input_data_path)

    if args.mode == "infer":
        # No fitted preprocessor is passed here; presumably transform loads a
        # previously fitted one when none is given — confirm in transform().
        test_features = transform(df, args)
        write_data(test_features, args, "test/test_features.csv")
        # Labels are optional at inference time; write them only if present.
        if target_col in df.columns:
            write_data(df[target_col], args, "test/test_labels.csv")
        return test_features

    # args.mode == "train" (guaranteed by the validation above).
    X_train, X_test, y_train, y_test = split_data(df, args)
    # Fit on the training split only so test data does not influence the
    # fitted preprocessing; then apply the same fit to both splits.
    preprocess = fit(X_train, args)
    train_features = transform(X_train, args, preprocess)
    test_features = transform(X_test, args, preprocess)
    write_data(train_features, args, "train/train_features.csv")
    write_data(y_train, args, "train/train_labels.csv")
    write_data(test_features, args, "test/test_features.csv")
    write_data(y_test, args, "test/test_labels.csv")
    return train_features, test_features