def split_dataset()

in src/data_utils.py [0:0]


def split_dataset(test_set_names, bucket, folder):
    """Split an S3 CSV dataset into training and test sets by region.

    Lists every ``.csv`` object under ``folder`` in ``bucket``, routes files
    whose name contains any substring in ``test_set_names`` to the test set
    and all others to the training set, then writes four CSVs back to S3:
    ``train_with_coord.csv`` / ``test_with_coord.csv`` (full columns, for
    plotting) and ``train.csv`` / ``test.csv`` (with the ``x``/``y``
    coordinate columns dropped).

    Args:
        test_set_names: list of region-name substrings assigned to the test set
        bucket: S3 bucket name
        folder: folder name (key prefix) within the S3 bucket

    Returns:
        None. All results are written to S3 as side effects.
    """
    s3_client = boto3.client("s3")
    items = s3_client.list_objects_v2(Bucket=bucket, Prefix=folder)

    # .get() guards against a prefix with no objects, where "Contents" is absent.
    csv_files = [
        item["Key"].split("/")[-1]
        for item in items.get("Contents", [])
        if item["Key"].endswith(".csv")
    ]

    # Partition in one pass instead of removing from the list while iterating
    # it (the original approach skipped the element after each removal).
    list_train, list_test = _partition_by_patterns(csv_files, test_set_names)

    df_train = pd.concat(
        [pd.read_csv(f"s3://{bucket}/{folder}/{item}") for item in list_train], axis=0
    )
    df_test = pd.concat(
        [pd.read_csv(f"s3://{bucket}/{folder}/{item}") for item in list_test], axis=0
    )

    # save dataframes with coordinates for plotting
    df_train.to_csv(f"s3://{bucket}/{folder}/train_with_coord.csv", index=False)
    df_test.to_csv(f"s3://{bucket}/{folder}/test_with_coord.csv", index=False)

    # remove the coordinates for training
    df_train = df_train.drop(["x", "y"], axis=1)
    df_test = df_test.drop(["x", "y"], axis=1)
    df_train.to_csv(f"s3://{bucket}/{folder}/train.csv", index=False)
    df_test.to_csv(f"s3://{bucket}/{folder}/test.csv", index=False)


def _partition_by_patterns(file_names, patterns):
    """Split *file_names* into ``(train, test)`` lists.

    A name goes to the test list when any of *patterns* is a substring of it
    (each name is assigned exactly once, even if several patterns match);
    otherwise it goes to the train list. Input order is preserved.
    """
    train, test = [], []
    for name in file_names:
        if any(pattern in name for pattern in patterns):
            test.append(name)
        else:
            train.append(name)
    return train, test