def save_files()

in 5. MLOps SageMaker Project/sagemaker-workshop-preprocess-seedcode-v1/pipelines/preprocess/preprocess.py [0:0]


def save_files(base_dir: str, data_df: pd.DataFrame, data_fg: pd.DataFrame,
               val_size=0.2, test_size=0.05, current_host=None):
    """Split ``data_df`` into train/validation/test sets and write them as CSV.

    The split is done in two steps with ``train_test_split`` (both calls use
    ``random_state=42``):

    1. ``data_df`` is split into a training part and a holdout part, with
       ``val_size`` passed as ``test_size``.
    2. The holdout part is split again into validation and test parts, with
       ``test_size`` passed as ``test_size``.

    NOTE(review): assuming scikit-learn semantics, ``test_size`` is therefore
    a fraction of the *holdout*, not of the full dataset (defaults give a
    test set of roughly 0.05 * 0.2 of the rows) -- confirm this is intended.

    Each part is written without header and without index to
    ``{base_dir}/<split>/<split>_{current_host}_{id}.csv`` where ``<split>``
    is ``train``, ``validation`` or ``test`` and ``id`` is a random 8-character
    hex string shared by the three files of one call.

    Args:
        base_dir: Root directory under which the split sub-directories live.
        data_df: The data to split and save.
        data_fg: Not used by this function; kept in the signature so existing
            callers keep working.
        val_size: Value forwarded as ``test_size`` to the first split.
        test_size: Value forwarded as ``test_size`` to the second split.
        current_host: Embedded verbatim in the file names (``None`` is
            rendered as the string ``"None"``).

    Returns:
        None. The function is called for its side effect of writing files.
    """
    # Imported locally so the block does not depend on a module-level import.
    import os

    logger.info(f"Splitting {len(data_df)} rows of data into train, val, test.")

    train_df, holdout_df = train_test_split(data_df, test_size=val_size, random_state=42)
    val_df, test_df = train_test_split(holdout_df, test_size=test_size, random_state=42)

    logger.info(f"Writing out datasets to {base_dir}")
    tmp_id = uuid.uuid4().hex[:8]

    # The directory name doubles as the file-name prefix for each split.
    splits = (
        ("train", train_df),
        ("validation", val_df),
        ("test", test_df),
    )
    for split_name, split_df in splits:
        out_dir = f"{base_dir}/{split_name}"
        # to_csv fails if the target directory is missing, so create it
        # first; exist_ok keeps this a no-op when it is already there.
        os.makedirs(out_dir, exist_ok=True)
        split_df.to_csv(
            f"{out_dir}/{split_name}_{current_host}_{tmp_id}.csv",
            header=False,
            index=False,
        )