courses/machine_learning/asl/05_review/2_sample_dataset.ipynb (541 lines of code) (raw):
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Creating a Sampled Dataset\n",
"\n",
"**Learning Objectives**\n",
"- Sample the natality dataset to create train/eval/test sets\n",
"- Preprocess the data in Pandas dataframe"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Introduction\n",
"\n",
"In this notebook we'll read data from BigQuery into our notebook to preprocess the data within a Pandas dataframe. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"PROJECT = \"cloud-training-demos\" # Replace with your PROJECT\n",
"BUCKET = \"cloud-training-bucket\" # Replace with your BUCKET\n",
"REGION = \"us-central1\" # Choose an available region for Cloud MLE\n",
"TFVERSION = \"1.14\" # TF version for CMLE to use"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"BUCKET\"] = BUCKET\n",
"os.environ[\"PROJECT\"] = PROJECT\n",
"os.environ[\"REGION\"] = REGION\n",
"os.environ[\"TFVERSION\"] = TFVERSION"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%bash\n",
"if ! gsutil ls | grep -q gs://${BUCKET}/; then\n",
" gsutil mb -l ${REGION} gs://${BUCKET}\n",
"fi"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create ML datasets by sampling using BigQuery\n",
"\n",
"We'll begin by sampling the BigQuery data to create smaller datasets."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create SQL query using natality data after the year 2000\n",
"query_string = \"\"\"\n",
"WITH\n",
" CTE_hash_cols_fixed AS (\n",
" SELECT\n",
" weight_pounds,\n",
" is_male,\n",
" mother_age,\n",
" plurality,\n",
" gestation_weeks,\n",
" year,\n",
" month,\n",
" CASE\n",
" WHEN day IS NULL AND wday IS NULL THEN 0\n",
" ELSE\n",
" CASE\n",
" WHEN day IS NULL THEN wday\n",
" ELSE\n",
" wday\n",
" END\n",
" END\n",
" AS date,\n",
" IFNULL(state,\n",
" \"Unknown\") AS state,\n",
" IFNULL(mother_birth_state,\n",
" \"Unknown\") AS mother_birth_state\n",
" FROM\n",
" publicdata.samples.natality\n",
" WHERE\n",
" year > 2000)\n",
"\n",
"SELECT\n",
" weight_pounds,\n",
" is_male,\n",
" mother_age,\n",
" plurality,\n",
" gestation_weeks,\n",
" FARM_FINGERPRINT(CONCAT(CAST(year AS STRING), CAST(month AS STRING), CAST(date AS STRING), CAST(state AS STRING), CAST(mother_birth_state AS STRING))) AS hashvalues\n",
"FROM\n",
" CTE_hash_cols_fixed\n",
"\"\"\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are only a limited number of years, months, days, and states in the dataset. Let's see what the hash values are.\n",
"\n",
"We'll call BigQuery but group by the hashcolumn and see the number of records for each group. This will enable us to get the correct train/eval/test percentages"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from google.cloud import bigquery\n",
"bq = bigquery.Client(project = PROJECT)\n",
"\n",
"df = bq.query(\"SELECT hashvalues, COUNT(weight_pounds) AS num_babies FROM (\" \n",
" + query_string + \n",
" \") GROUP BY hashvalues\").to_dataframe()\n",
"\n",
"print(\"There are {} unique hashvalues.\".format(len(df)))\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can make a query to check if our bucketing values result in the correct sizes of each of our dataset splits and then adjust accordingly"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sampling_percentages_query = \"\"\"\n",
"WITH\n",
" -- Get label, features, and column that we are going to use to split into buckets on\n",
" CTE_hash_cols_fixed AS (\n",
" SELECT\n",
" weight_pounds,\n",
" is_male,\n",
" mother_age,\n",
" plurality,\n",
" gestation_weeks,\n",
" year,\n",
" month,\n",
" CASE\n",
" WHEN day IS NULL AND wday IS NULL THEN 0\n",
" ELSE\n",
" CASE\n",
" WHEN day IS NULL THEN wday\n",
" ELSE\n",
" wday\n",
" END\n",
" END\n",
" AS date,\n",
" IFNULL(state,\n",
" \"Unknown\") AS state,\n",
" IFNULL(mother_birth_state,\n",
" \"Unknown\") AS mother_birth_state\n",
" FROM\n",
" publicdata.samples.natality\n",
" WHERE\n",
" year > 2000),\n",
" CTE_data AS (\n",
" SELECT\n",
" weight_pounds,\n",
" is_male,\n",
" mother_age,\n",
" plurality,\n",
" gestation_weeks,\n",
" FARM_FINGERPRINT(CONCAT(CAST(year AS STRING), CAST(month AS STRING), CAST(date AS STRING), CAST(state AS STRING), CAST(mother_birth_state AS STRING))) AS hashvalues\n",
" FROM\n",
" CTE_hash_cols_fixed),\n",
" -- Get the counts of each of the unique hashs of our splitting column\n",
" CTE_first_bucketing AS (\n",
" SELECT\n",
" hashvalues,\n",
" COUNT(*) AS num_records\n",
" FROM\n",
" CTE_data\n",
" GROUP BY\n",
" hashvalues ),\n",
" -- Get the number of records in each of the hash buckets\n",
" CTE_second_bucketing AS (\n",
" SELECT\n",
" ABS(MOD(hashvalues, {0})) AS bucket_index,\n",
" SUM(num_records) AS num_records\n",
" FROM\n",
" CTE_first_bucketing\n",
" GROUP BY\n",
" ABS(MOD(hashvalues, {0}))),\n",
" -- Calculate the overall percentages\n",
" CTE_percentages AS (\n",
" SELECT\n",
" bucket_index,\n",
" num_records,\n",
" CAST(num_records AS FLOAT64) / (\n",
" SELECT\n",
" SUM(num_records)\n",
" FROM\n",
" CTE_second_bucketing) AS percent_records\n",
" FROM\n",
" CTE_second_bucketing ),\n",
" -- Choose which of the hash buckets will be used for training and pull in their statistics\n",
" CTE_train AS (\n",
" SELECT\n",
" *,\n",
" \"train\" AS dataset_name\n",
" FROM\n",
" CTE_percentages\n",
" WHERE\n",
" bucket_index >= 0\n",
" AND bucket_index < {1}),\n",
" -- Choose which of the hash buckets will be used for validation and pull in their statistics\n",
" CTE_eval AS (\n",
" SELECT\n",
" *,\n",
" \"eval\" AS dataset_name\n",
" FROM\n",
" CTE_percentages\n",
" WHERE\n",
" bucket_index >= {1}\n",
" AND bucket_index < {2}),\n",
" -- Choose which of the hash buckets will be used for testing and pull in their statistics\n",
" CTE_test AS (\n",
" SELECT\n",
" *,\n",
" \"test\" AS dataset_name\n",
" FROM\n",
" CTE_percentages\n",
" WHERE\n",
" bucket_index >= {2}\n",
" AND bucket_index < {0}),\n",
" -- Union the training, validation, and testing dataset statistics\n",
" CTE_union AS (\n",
" SELECT\n",
" 0 AS dataset_id,\n",
" *\n",
" FROM\n",
" CTE_train\n",
" UNION ALL\n",
" SELECT\n",
" 1 AS dataset_id,\n",
" *\n",
" FROM\n",
" CTE_eval\n",
" UNION ALL\n",
" SELECT\n",
" 2 AS dataset_id,\n",
" *\n",
" FROM\n",
" CTE_test ),\n",
" -- Show final splitting and associated statistics\n",
" CTE_split AS (\n",
" SELECT\n",
" dataset_id,\n",
" dataset_name,\n",
" SUM(num_records) AS num_records,\n",
" SUM(percent_records) AS percent_records\n",
" FROM\n",
" CTE_union\n",
" GROUP BY\n",
" dataset_id,\n",
" dataset_name )\n",
"SELECT\n",
" *\n",
"FROM\n",
" CTE_split\n",
"ORDER BY\n",
" dataset_id\n",
"\"\"\"\n",
"\n",
"modulo_divisor = 100\n",
"train_percent = 80.0\n",
"eval_percent = 10.0\n",
"\n",
"train_buckets = int(modulo_divisor * train_percent / 100.0)\n",
"eval_buckets = int(modulo_divisor * eval_percent / 100.0)\n",
"\n",
"df = bq.query(sampling_percentages_query.format(modulo_divisor, train_buckets, train_buckets + eval_buckets)).to_dataframe()\n",
"df.head()"
]
},
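{
"cell_type": "markdown",
"metadata": {},
"source": [
"The bucket assignment used in the query above can be sanity-checked in plain Python. This is a minimal sketch rather than part of the pipeline: `assign_split` is a hypothetical helper, the hash values are made up, and it assumes BigQuery's `MOD` returns a result with the sign of the dividend, which is why the query wraps it in `ABS`.\n",
"\n",
"```python\n",
"def assign_split(hashvalue, modulo_divisor=100, train_buckets=80, eval_buckets=10):\n",
"    # Hypothetical helper mirroring ABS(MOD(hashvalues, modulo_divisor))\n",
"    bucket_index = abs(hashvalue) % modulo_divisor\n",
"    if bucket_index < train_buckets:\n",
"        return \"train\"\n",
"    if bucket_index < train_buckets + eval_buckets:\n",
"        return \"eval\"\n",
"    return \"test\"\n",
"\n",
"# Made-up hash values, including a negative one\n",
"for h in [-4812345678901234579, 85, 1234567890123456792]:\n",
"    print(h, assign_split(h))\n",
"```\n",
"\n",
"Because every row that shares a hash value lands in the same bucket, no hash group can straddle two splits."
]
},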
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's a way to get a well-distributed portion of the data in such a way that the train/eval/test sets do not overlap. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added every_n so that we can now subsample from each of the hash values to get approximately the record counts we want\n",
"every_n = 500\n",
"\n",
"train_query = \"SELECT * FROM ({0}) WHERE ABS(MOD(hashvalues, {1} * 100)) < 80\".format(query_string, every_n)\n",
"eval_query = \"SELECT * FROM ({0}) WHERE ABS(MOD(hashvalues, {1} * 100)) >= 80 AND ABS(MOD(hashvalues, {1} * 100)) < 90\".format(query_string, every_n)\n",
"test_query = \"SELECT * FROM ({0}) WHERE ABS(MOD(hashvalues, {1} * 100)) >= 90 AND ABS(MOD(hashvalues, {1} * 100)) < 100\".format(query_string, every_n)\n",
"\n",
"train_df = bq.query(train_query).to_dataframe()\n",
"eval_df = bq.query(eval_query).to_dataframe()\n",
"test_df = bq.query(test_query).to_dataframe()\n",
"\n",
"print(\"There are {} examples in the train dataset.\".format(len(train_df)))\n",
"print(\"There are {} examples in the validation dataset.\".format(len(eval_df)))\n",
"print(\"There are {} examples in the test dataset.\".format(len(test_df)))"
]
},
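{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why `every_n` shrinks the sample, note that the modulo divisor becomes `every_n * 100`, while each query still keeps only a fixed range of remainders. The sketch below works through that arithmetic; the resulting fractions apply to hash buckets, so the share of records is only approximate because hash groups differ in size.\n",
"\n",
"```python\n",
"every_n = 500\n",
"divisor = every_n * 100  # 50,000 possible remainders\n",
"\n",
"# Remainder ranges kept by the train, eval, and test queries\n",
"splits = {\"train\": range(0, 80), \"eval\": range(80, 90), \"test\": range(90, 100)}\n",
"\n",
"for name, remainders in splits.items():\n",
"    fraction = len(remainders) / divisor\n",
"    print(\"{}: about {:.4%} of hash buckets\".format(name, fraction))\n",
"```\n",
"\n",
"The three ranges are disjoint, so the 80/10/10 ratio between the splits is preserved while the overall volume drops by roughly a factor of `every_n`."
]
},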
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preprocess data using Pandas\n",
"\n",
"We'll perform a few preprocessing steps to the data in our dataset. Let's add extra rows to simulate the lack of ultrasound. That is we'll duplicate some rows and make the `is_male` field be `Unknown`. Also, if there is more than child we'll change the `plurality` to `Multiple(2+)`. While we're at it, We'll also change the plurality column to be a string. We'll perform these operations below. \n",
"\n",
"Let's start by examining the training dataset as is."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Also, notice that there are some very important numeric fields that are missing in some rows (the count in Pandas doesn't count missing data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_df.describe()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It is always crucial to clean raw data before using in machine learning, so we have a preprocessing step. We'll define a `preprocess` function below. Note that the mother's age is an input to our model so users will have to provide the mother's age; otherwise, our service won't work. The features we use for our model were chosen because they are such good predictors and because they are easy enough to collect."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"def preprocess(df):\n",
" # Clean up data\n",
" # Remove what we don\"t want to use for training\n",
" df = df[df.weight_pounds > 0]\n",
" df = df[df.mother_age > 0]\n",
" df = df[df.gestation_weeks > 0]\n",
" df = df[df.plurality > 0]\n",
"\n",
" # Modify plurality field to be a string\n",
" twins_etc = dict(zip([1,2,3,4,5],\n",
" [\"Single(1)\", \"Twins(2)\", \"Triplets(3)\", \"Quadruplets(4)\", \"Quintuplets(5)\"]))\n",
" df[\"plurality\"].replace(twins_etc, inplace = True)\n",
"\n",
" # Now create extra rows to simulate lack of ultrasound\n",
" no_ultrasound = df.copy(deep = True)\n",
" no_ultrasound.loc[no_ultrasound[\"plurality\"] != \"Single(1)\", \"plurality\"] = \"Multiple(2+)\"\n",
" no_ultrasound[\"is_male\"] = \"Unknown\"\n",
"\n",
" # Concatenate both datasets together and shuffle\n",
" return pd.concat([df, no_ultrasound]).sample(frac=1).reset_index(drop=True)"
]
},
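{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an illustration of what these steps do, here is a small self-contained sketch on made-up rows (not taken from the natality table). It follows the same idea as `preprocess` but skips the shuffle so the output is easier to read.\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# Toy rows with made-up values; the last one has a missing label\n",
"toy = pd.DataFrame({\n",
"    \"weight_pounds\": [7.5, 5.1, None],\n",
"    \"is_male\": [True, False, True],\n",
"    \"mother_age\": [30, 25, 28],\n",
"    \"plurality\": [1, 2, 1],\n",
"    \"gestation_weeks\": [39, 36, 40],\n",
"})\n",
"\n",
"# Drop unusable rows, then relabel plurality as strings\n",
"clean = toy[toy.weight_pounds > 0].copy()\n",
"clean[\"plurality\"] = clean[\"plurality\"].replace({1: \"Single(1)\", 2: \"Twins(2)\"})\n",
"\n",
"# Duplicate the rows as if no ultrasound had been performed\n",
"no_ultrasound = clean.copy(deep=True)\n",
"no_ultrasound.loc[no_ultrasound[\"plurality\"] != \"Single(1)\", \"plurality\"] = \"Multiple(2+)\"\n",
"no_ultrasound[\"is_male\"] = \"Unknown\"\n",
"\n",
"combined = pd.concat([clean, no_ultrasound]).reset_index(drop=True)\n",
"print(combined)\n",
"```\n",
"\n",
"The row with the missing weight is dropped, and each remaining row appears twice: once with its original values and once with `is_male` set to `Unknown`."
]
},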
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's process the train/eval/test set and see a small sample of the training data after our preprocessing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_df = preprocess(train_df)\n",
"eval_df = preprocess(eval_df)\n",
"test_df = preprocess(test_df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_df.tail()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look again at a summary of the dataset. Note that we only see numeric columns, so `plurality` does not show up."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_df.describe()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Write to .csv files \n",
"\n",
"In the final versions, we want to read from files, not Pandas dataframes. So, we write the Pandas dataframes out as csv files. Using csv files gives us the advantage of shuffling during read. This is important for distributed training because some workers might be slower than others, and shuffling the data helps prevent the same data from being assigned to the slow workers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"columns = \"weight_pounds,is_male,mother_age,plurality,gestation_weeks\".split(',')\n",
"train_df.to_csv(path_or_buf = \"train.csv\", columns = columns, header = False, index = False)\n",
"eval_df.to_csv(path_or_buf = \"eval.csv\", columns = columns, header = False, index = False)\n",
"test_df.to_csv(path_or_buf = \"test.csv\", columns = columns, header = False, index = False)"
]
},
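{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the files are written with `header = False`, any reader has to supply the column names itself. The sketch below shows one way to do that with Pandas; it writes a tiny made-up file to a temporary directory so that it does not touch the files created above.\n",
"\n",
"```python\n",
"import os\n",
"import tempfile\n",
"import pandas as pd\n",
"\n",
"columns = \"weight_pounds,is_male,mother_age,plurality,gestation_weeks\".split(\",\")\n",
"\n",
"# Write a tiny headerless file (made-up rows) to a temporary directory\n",
"path = os.path.join(tempfile.mkdtemp(), \"sample.csv\")\n",
"sample = pd.DataFrame([[7.5, \"True\", 30, \"Single(1)\", 39],\n",
"                       [5.1, \"Unknown\", 25, \"Multiple(2+)\", 36]], columns=columns)\n",
"sample.to_csv(path, columns=columns, header=False, index=False)\n",
"\n",
"# header=None tells Pandas the first line is data; names restores the labels\n",
"restored = pd.read_csv(path, header=None, names=columns)\n",
"print(restored.dtypes)\n",
"```\n",
"\n",
"Keeping the column order in a single list, as the cell above does, makes it harder for the writer and the reader to drift apart."
]
},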
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%bash\n",
"wc -l *.csv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%bash\n",
"head *.csv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%bash\n",
"tail *.csv"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright 2017-2018 Google Inc. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}