08_bqml/bqml_logistic.ipynb (1,515 lines of code) (raw):
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<h1> Training flight delay model in BigQuery ML </h1>\n",
"\n",
"Run this notebook in Vertex Workbench. In this notebook, we will use BigQuery ML to train the same model that we did in Spark ML.\n",
"\n",
"Note how much easier this is ... (and also much more scaleable)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Verify dataset\n",
"\n",
"Let's make sure that we have the traindays and flights data in BigQuery. If you don't, please follow steps in the README.md in this directory."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Query complete after 0.05s: 100%|██████████| 2/2 [00:00<00:00, 383.69query/s] \n",
"Downloading: 100%|██████████| 5/5 [00:01<00:00, 4.14rows/s]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>FL_DATE</th>\n",
" <th>UNIQUE_CARRIER</th>\n",
" <th>ORIGIN_AIRPORT_SEQ_ID</th>\n",
" <th>ORIGIN</th>\n",
" <th>DEST_AIRPORT_SEQ_ID</th>\n",
" <th>DEST</th>\n",
" <th>CRS_DEP_TIME</th>\n",
" <th>DEP_TIME</th>\n",
" <th>DEP_DELAY</th>\n",
" <th>TAXI_OUT</th>\n",
" <th>...</th>\n",
" <th>ARR_DELAY</th>\n",
" <th>CANCELLED</th>\n",
" <th>DIVERTED</th>\n",
" <th>DISTANCE</th>\n",
" <th>DEP_AIRPORT_LAT</th>\n",
" <th>DEP_AIRPORT_LON</th>\n",
" <th>DEP_AIRPORT_TZOFFSET</th>\n",
" <th>ARR_AIRPORT_LAT</th>\n",
" <th>ARR_AIRPORT_LON</th>\n",
" <th>ARR_AIRPORT_TZOFFSET</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2015-12-09</td>\n",
" <td>AS</td>\n",
" <td>1029904</td>\n",
" <td>ANC</td>\n",
" <td>1055102</td>\n",
" <td>BET</td>\n",
" <td>2015-12-10 04:00:00+00:00</td>\n",
" <td>2015-12-10 03:55:00+00:00</td>\n",
" <td>-5.0</td>\n",
" <td>10.0</td>\n",
" <td>...</td>\n",
" <td>-3.0</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>399.0</td>\n",
" <td>61.174167</td>\n",
" <td>-149.998056</td>\n",
" <td>-32400.0</td>\n",
" <td>60.778611</td>\n",
" <td>-161.837222</td>\n",
" <td>-32400.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2015-02-07</td>\n",
" <td>B6</td>\n",
" <td>1169703</td>\n",
" <td>FLL</td>\n",
" <td>1484304</td>\n",
" <td>SJU</td>\n",
" <td>2015-02-07 13:50:00+00:00</td>\n",
" <td>2015-02-07 14:23:00+00:00</td>\n",
" <td>33.0</td>\n",
" <td>13.0</td>\n",
" <td>...</td>\n",
" <td>22.0</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>1046.0</td>\n",
" <td>26.072500</td>\n",
" <td>-80.152778</td>\n",
" <td>-18000.0</td>\n",
" <td>18.439444</td>\n",
" <td>-66.002222</td>\n",
" <td>-14400.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2015-10-29</td>\n",
" <td>AS</td>\n",
" <td>1482803</td>\n",
" <td>SIT</td>\n",
" <td>1252303</td>\n",
" <td>JNU</td>\n",
" <td>2015-10-29 14:00:00+00:00</td>\n",
" <td>2015-10-29 13:57:00+00:00</td>\n",
" <td>-3.0</td>\n",
" <td>8.0</td>\n",
" <td>...</td>\n",
" <td>10.0</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>95.0</td>\n",
" <td>57.046944</td>\n",
" <td>-135.361111</td>\n",
" <td>-28800.0</td>\n",
" <td>58.354722</td>\n",
" <td>-134.574722</td>\n",
" <td>-28800.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2015-10-02</td>\n",
" <td>HA</td>\n",
" <td>1298202</td>\n",
" <td>LIH</td>\n",
" <td>1383002</td>\n",
" <td>OGG</td>\n",
" <td>2015-10-02 22:53:00+00:00</td>\n",
" <td>2015-10-02 22:58:00+00:00</td>\n",
" <td>5.0</td>\n",
" <td>9.0</td>\n",
" <td>...</td>\n",
" <td>8.0</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>201.0</td>\n",
" <td>21.976111</td>\n",
" <td>-159.338889</td>\n",
" <td>-36000.0</td>\n",
" <td>20.898611</td>\n",
" <td>-156.430556</td>\n",
" <td>-36000.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2015-06-30</td>\n",
" <td>AS</td>\n",
" <td>1075403</td>\n",
" <td>BRW</td>\n",
" <td>1470903</td>\n",
" <td>SCC</td>\n",
" <td>2015-07-01 04:20:00+00:00</td>\n",
" <td>2015-07-01 04:16:00+00:00</td>\n",
" <td>-4.0</td>\n",
" <td>5.0</td>\n",
" <td>...</td>\n",
" <td>-8.0</td>\n",
" <td>False</td>\n",
" <td>False</td>\n",
" <td>204.0</td>\n",
" <td>71.284722</td>\n",
" <td>-156.768611</td>\n",
" <td>-28800.0</td>\n",
" <td>70.195556</td>\n",
" <td>-148.465833</td>\n",
" <td>-28800.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 25 columns</p>\n",
"</div>"
],
"text/plain": [
" FL_DATE UNIQUE_CARRIER ORIGIN_AIRPORT_SEQ_ID ORIGIN DEST_AIRPORT_SEQ_ID \\\n",
"0 2015-12-09 AS 1029904 ANC 1055102 \n",
"1 2015-02-07 B6 1169703 FLL 1484304 \n",
"2 2015-10-29 AS 1482803 SIT 1252303 \n",
"3 2015-10-02 HA 1298202 LIH 1383002 \n",
"4 2015-06-30 AS 1075403 BRW 1470903 \n",
"\n",
" DEST CRS_DEP_TIME DEP_TIME DEP_DELAY \\\n",
"0 BET 2015-12-10 04:00:00+00:00 2015-12-10 03:55:00+00:00 -5.0 \n",
"1 SJU 2015-02-07 13:50:00+00:00 2015-02-07 14:23:00+00:00 33.0 \n",
"2 JNU 2015-10-29 14:00:00+00:00 2015-10-29 13:57:00+00:00 -3.0 \n",
"3 OGG 2015-10-02 22:53:00+00:00 2015-10-02 22:58:00+00:00 5.0 \n",
"4 SCC 2015-07-01 04:20:00+00:00 2015-07-01 04:16:00+00:00 -4.0 \n",
"\n",
" TAXI_OUT ... ARR_DELAY CANCELLED DIVERTED DISTANCE DEP_AIRPORT_LAT \\\n",
"0 10.0 ... -3.0 False False 399.0 61.174167 \n",
"1 13.0 ... 22.0 False False 1046.0 26.072500 \n",
"2 8.0 ... 10.0 False False 95.0 57.046944 \n",
"3 9.0 ... 8.0 False False 201.0 21.976111 \n",
"4 5.0 ... -8.0 False False 204.0 71.284722 \n",
"\n",
" DEP_AIRPORT_LON DEP_AIRPORT_TZOFFSET ARR_AIRPORT_LAT ARR_AIRPORT_LON \\\n",
"0 -149.998056 -32400.0 60.778611 -161.837222 \n",
"1 -80.152778 -18000.0 18.439444 -66.002222 \n",
"2 -135.361111 -28800.0 58.354722 -134.574722 \n",
"3 -159.338889 -36000.0 20.898611 -156.430556 \n",
"4 -156.768611 -28800.0 70.195556 -148.465833 \n",
"\n",
" ARR_AIRPORT_TZOFFSET \n",
"0 -32400.0 \n",
"1 -14400.0 \n",
"2 -28800.0 \n",
"3 -36000.0 \n",
"4 -28800.0 \n",
"\n",
"[5 rows x 25 columns]"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%bigquery\n",
"SELECT * FROM dsongcp.flights_tzcorr\n",
"LIMIT 5"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Query complete after 0.00s: 100%|██████████| 1/1 [00:00<00:00, 383.64query/s] \n",
"Downloading: 100%|██████████| 5/5 [00:01<00:00, 4.52rows/s]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>FL_DATE</th>\n",
" <th>is_train_day</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2015-01-01</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2015-01-02</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2015-01-03</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>2015-01-04</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>2015-01-05</td>\n",
" <td>True</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" FL_DATE is_train_day\n",
"0 2015-01-01 True\n",
"1 2015-01-02 False\n",
"2 2015-01-03 False\n",
"3 2015-01-04 True\n",
"4 2015-01-05 True"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%bigquery\n",
"SELECT * FROM dsongcp.trainday\n",
"LIMIT 5"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Logistic regression\n",
"\n",
"Let's use SQL to craft the dataset just the way we want it."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Query complete after 0.00s: 100%|██████████| 3/3 [00:00<00:00, 1109.31query/s] \n",
"Downloading: 100%|██████████| 5/5 [00:01<00:00, 4.66rows/s]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>ontime</th>\n",
" <th>dep_delay</th>\n",
" <th>taxi_out</th>\n",
" <th>distance</th>\n",
" <th>is_eval_day</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>ontime</td>\n",
" <td>-5.0</td>\n",
" <td>10.0</td>\n",
" <td>399.0</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>late</td>\n",
" <td>33.0</td>\n",
" <td>13.0</td>\n",
" <td>1046.0</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>ontime</td>\n",
" <td>-3.0</td>\n",
" <td>8.0</td>\n",
" <td>95.0</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>ontime</td>\n",
" <td>5.0</td>\n",
" <td>9.0</td>\n",
" <td>201.0</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>ontime</td>\n",
" <td>-4.0</td>\n",
" <td>5.0</td>\n",
" <td>204.0</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" ontime dep_delay taxi_out distance is_eval_day\n",
"0 ontime -5.0 10.0 399.0 True\n",
"1 late 33.0 13.0 1046.0 False\n",
"2 ontime -3.0 8.0 95.0 False\n",
"3 ontime 5.0 9.0 201.0 True\n",
"4 ontime -4.0 5.0 204.0 False"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%bigquery\n",
"SELECT\n",
" IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
" dep_delay,\n",
" taxi_out,\n",
" distance,\n",
" IF(is_train_day = 'True', False, True) AS is_eval_day\n",
"FROM dsongcp.flights_tzcorr f\n",
"JOIN dsongcp.trainday t\n",
"ON f.FL_DATE = t.FL_DATE\n",
"WHERE\n",
" f.CANCELLED = False AND\n",
" f.DIVERTED = False\n",
"LIMIT 5"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Train the model:\n",
"* ontime column in the label\n",
"* the model is a logistic regression\n",
"* use the is_eval_day column to split the data\n",
"* remaining columns are features\n",
"\n",
"This will take about 10 minutes."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Query complete after 0.00s: 100%|██████████| 3/3 [00:00<00:00, 1061.58query/s] \n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"Empty DataFrame\n",
"Columns: []\n",
"Index: []"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%bigquery\n",
"CREATE OR REPLACE MODEL dsongcp.arr_delay_lm\n",
"OPTIONS(input_label_cols=['ontime'], \n",
" model_type='logistic_reg', \n",
" data_split_method='custom',\n",
" data_split_col='is_eval_day')\n",
"AS\n",
"\n",
"SELECT\n",
" IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
" dep_delay,\n",
" taxi_out,\n",
" distance,\n",
" IF(is_train_day = 'True', False, True) AS is_eval_day\n",
"FROM dsongcp.flights_tzcorr f\n",
"JOIN dsongcp.trainday t\n",
"ON f.FL_DATE = t.FL_DATE\n",
"WHERE\n",
" f.CANCELLED = False AND\n",
" f.DIVERTED = False"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Query complete after 0.00s: 100%|██████████| 1/1 [00:00<00:00, 644.29query/s] \n",
"Downloading: 100%|██████████| 20/20 [00:01<00:00, 18.82rows/s]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>training_run</th>\n",
" <th>iteration</th>\n",
" <th>loss</th>\n",
" <th>eval_loss</th>\n",
" <th>learning_rate</th>\n",
" <th>duration_ms</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>19</td>\n",
" <td>0.000003</td>\n",
" <td>0.000004</td>\n",
" <td>104857.6</td>\n",
" <td>3306</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" <td>18</td>\n",
" <td>0.000007</td>\n",
" <td>0.000007</td>\n",
" <td>52428.8</td>\n",
" <td>3350</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0</td>\n",
" <td>17</td>\n",
" <td>0.000013</td>\n",
" <td>0.000012</td>\n",
" <td>26214.4</td>\n",
" <td>2941</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0</td>\n",
" <td>16</td>\n",
" <td>0.000027</td>\n",
" <td>0.000020</td>\n",
" <td>13107.2</td>\n",
" <td>2921</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>15</td>\n",
" <td>0.000054</td>\n",
" <td>0.000033</td>\n",
" <td>6553.6</td>\n",
" <td>3370</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0</td>\n",
" <td>14</td>\n",
" <td>0.000108</td>\n",
" <td>0.000058</td>\n",
" <td>3276.8</td>\n",
" <td>2693</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>0</td>\n",
" <td>13</td>\n",
" <td>0.000216</td>\n",
" <td>0.000102</td>\n",
" <td>1638.4</td>\n",
" <td>3356</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>0</td>\n",
" <td>12</td>\n",
" <td>0.000432</td>\n",
" <td>0.000182</td>\n",
" <td>819.2</td>\n",
" <td>3068</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>0</td>\n",
" <td>11</td>\n",
" <td>0.000865</td>\n",
" <td>0.000333</td>\n",
" <td>409.6</td>\n",
" <td>3291</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>0</td>\n",
" <td>10</td>\n",
" <td>0.001734</td>\n",
" <td>0.000625</td>\n",
" <td>204.8</td>\n",
" <td>3051</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>0</td>\n",
" <td>9</td>\n",
" <td>0.003484</td>\n",
" <td>0.001207</td>\n",
" <td>102.4</td>\n",
" <td>3412</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>0</td>\n",
" <td>8</td>\n",
" <td>0.007037</td>\n",
" <td>0.002421</td>\n",
" <td>51.2</td>\n",
" <td>3210</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>0</td>\n",
" <td>7</td>\n",
" <td>0.014348</td>\n",
" <td>0.005100</td>\n",
" <td>25.6</td>\n",
" <td>3731</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>0</td>\n",
" <td>6</td>\n",
" <td>0.029768</td>\n",
" <td>0.011594</td>\n",
" <td>12.8</td>\n",
" <td>3482</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>0</td>\n",
" <td>5</td>\n",
" <td>0.063422</td>\n",
" <td>0.029168</td>\n",
" <td>6.4</td>\n",
" <td>3982</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>0</td>\n",
" <td>4</td>\n",
" <td>0.136566</td>\n",
" <td>0.085203</td>\n",
" <td>3.2</td>\n",
" <td>4235</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>0.267786</td>\n",
" <td>0.225114</td>\n",
" <td>1.6</td>\n",
" <td>4171</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>0.426662</td>\n",
" <td>0.409772</td>\n",
" <td>0.8</td>\n",
" <td>3742</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>0.558553</td>\n",
" <td>0.555572</td>\n",
" <td>0.4</td>\n",
" <td>4570</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>0.644285</td>\n",
" <td>0.644522</td>\n",
" <td>0.2</td>\n",
" <td>3581</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" training_run iteration loss eval_loss learning_rate duration_ms\n",
"0 0 19 0.000003 0.000004 104857.6 3306\n",
"1 0 18 0.000007 0.000007 52428.8 3350\n",
"2 0 17 0.000013 0.000012 26214.4 2941\n",
"3 0 16 0.000027 0.000020 13107.2 2921\n",
"4 0 15 0.000054 0.000033 6553.6 3370\n",
"5 0 14 0.000108 0.000058 3276.8 2693\n",
"6 0 13 0.000216 0.000102 1638.4 3356\n",
"7 0 12 0.000432 0.000182 819.2 3068\n",
"8 0 11 0.000865 0.000333 409.6 3291\n",
"9 0 10 0.001734 0.000625 204.8 3051\n",
"10 0 9 0.003484 0.001207 102.4 3412\n",
"11 0 8 0.007037 0.002421 51.2 3210\n",
"12 0 7 0.014348 0.005100 25.6 3731\n",
"13 0 6 0.029768 0.011594 12.8 3482\n",
"14 0 5 0.063422 0.029168 6.4 3982\n",
"15 0 4 0.136566 0.085203 3.2 4235\n",
"16 0 3 0.267786 0.225114 1.6 4171\n",
"17 0 2 0.426662 0.409772 0.8 3742\n",
"18 0 1 0.558553 0.555572 0.4 4570\n",
"19 0 0 0.644285 0.644522 0.2 3581"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%bigquery\n",
"SELECT * FROM ML.TRAINING_INFO(MODEL dsongcp.arr_delay_lm)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Query complete after 0.00s: 100%|██████████| 1/1 [00:00<00:00, 512.13query/s] \n",
"Downloading: 100%|██████████| 4/4 [00:01<00:00, 3.72rows/s]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>processed_input</th>\n",
" <th>weight</th>\n",
" <th>category_weights</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>dep_delay</td>\n",
" <td>-0.132984</td>\n",
" <td>[]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>taxi_out</td>\n",
" <td>-0.121715</td>\n",
" <td>[]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>distance</td>\n",
" <td>0.000223</td>\n",
" <td>[]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>__INTERCEPT__</td>\n",
" <td>4.762572</td>\n",
" <td>[]</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" processed_input weight category_weights\n",
"0 dep_delay -0.132984 []\n",
"1 taxi_out -0.121715 []\n",
"2 distance 0.000223 []\n",
"3 __INTERCEPT__ 4.762572 []"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%bigquery\n",
"SELECT * FROM ML.WEIGHTS(MODEL dsongcp.arr_delay_lm)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Query complete after 0.00s: 100%|██████████| 2/2 [00:00<00:00, 815.93query/s] \n",
"Downloading: 100%|██████████| 1/1 [00:01<00:00, 1.12s/rows]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>predicted_ontime</th>\n",
" <th>predicted_ontime_probs</th>\n",
" <th>dep_delay</th>\n",
" <th>taxi_out</th>\n",
" <th>distance</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>ontime</td>\n",
" <td>[{'label': 'ontime', 'prob': 0.850350772376706...</td>\n",
" <td>12.0</td>\n",
" <td>14.0</td>\n",
" <td>1231</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" predicted_ontime predicted_ontime_probs \\\n",
"0 ontime [{'label': 'ontime', 'prob': 0.850350772376706... \n",
"\n",
" dep_delay taxi_out distance \n",
"0 12.0 14.0 1231 "
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%bigquery\n",
"SELECT * FROM ML.PREDICT(MODEL dsongcp.arr_delay_lm,\n",
" (\n",
"SELECT 12.0 AS dep_delay, 14.0 AS taxi_out, 1231 AS distance\n",
" ))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Query complete after 0.00s: 100%|██████████| 7/7 [00:00<00:00, 2886.65query/s] \n",
"Downloading: 100%|██████████| 1/1 [00:02<00:00, 2.39s/rows]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>accuracy</th>\n",
" <th>f1_score</th>\n",
" <th>log_loss</th>\n",
" <th>roc_auc</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.964337</td>\n",
" <td>0.956535</td>\n",
" <td>0.935174</td>\n",
" <td>0.96042</td>\n",
" <td>0.167233</td>\n",
" <td>0.956248</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" precision recall accuracy f1_score log_loss roc_auc\n",
"0 0.964337 0.956535 0.935174 0.96042 0.167233 0.956248"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%bigquery\n",
"SELECT * \n",
"FROM ML.EVALUATE(MODEL dsongcp.arr_delay_lm,\n",
" (\n",
" \n",
"SELECT\n",
" IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
" dep_delay,\n",
" taxi_out,\n",
" distance\n",
"FROM dsongcp.flights_tzcorr f\n",
"JOIN dsongcp.trainday t\n",
"ON f.FL_DATE = t.FL_DATE\n",
"WHERE\n",
" f.CANCELLED = False AND\n",
" f.DIVERTED = False AND\n",
" is_train_day = 'False'\n",
" \n",
" ),\n",
" STRUCT(0.7 AS threshold))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Do same metrics as in Spark code\n",
"\n",
"We are using ML.PREDICT and computing the necessary stats"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Query complete after 0.00s: 100%|██████████| 4/4 [00:00<00:00, 1431.14query/s] \n",
"Downloading: 100%|██████████| 1/1 [00:01<00:00, 1.14s/rows]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>correct_cancel</th>\n",
" <th>total_noncancel</th>\n",
" <th>correct_noncancel</th>\n",
" <th>total_cancel</th>\n",
" <th>rmse</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.836363</td>\n",
" <td>1301948</td>\n",
" <td>0.964337</td>\n",
" <td>283750</td>\n",
" <td>0.213091</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" correct_cancel total_noncancel correct_noncancel total_cancel rmse\n",
"0 0.836363 1301948 0.964337 283750 0.213091"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%bigquery\n",
"\n",
"WITH predictions AS (\n",
"SELECT \n",
" *\n",
"FROM ML.PREDICT(MODEL dsongcp.arr_delay_lm,\n",
" (\n",
"SELECT\n",
" IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
" dep_delay,\n",
" taxi_out,\n",
" distance\n",
"FROM dsongcp.flights_tzcorr f\n",
"JOIN dsongcp.trainday t\n",
"ON f.FL_DATE = t.FL_DATE\n",
"WHERE\n",
" f.CANCELLED = False AND \n",
" f.DIVERTED = False AND\n",
" t.is_train_day = 'False'\n",
" ),\n",
" STRUCT(0.7 AS threshold))),\n",
"\n",
"stats AS (\n",
"SELECT \n",
" COUNTIF(ontime != 'ontime' AND ontime = predicted_ontime) AS correct_cancel\n",
" , COUNTIF(predicted_ontime = 'ontime') AS total_noncancel\n",
" , COUNTIF(ontime = 'ontime' AND ontime = predicted_ontime) AS correct_noncancel\n",
" , COUNTIF(ontime != 'ontime') AS total_cancel\n",
" , SQRT(SUM((IF(ontime = 'ontime', 1, 0) - p.prob) * (IF(ontime = 'ontime', 1, 0) - p.prob))/COUNT(*)) AS rmse\n",
"FROM predictions, UNNEST(predicted_ontime_probs) p\n",
"WHERE p.label = 'ontime'\n",
")\n",
"\n",
"SELECT\n",
" correct_cancel / total_cancel AS correct_cancel\n",
" , total_noncancel\n",
" , correct_noncancel / total_noncancel AS correct_noncancel\n",
" , total_cancel\n",
" , rmse\n",
"FROM stats"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Add airport info\n",
"\n",
"Add airport information to model (note two additional columns: origin and dest). This seemingly simple change adds two categorical variables that, when one-hot-encoded, adds 600+ new columns to the model. BigQuery ML doesn't skip a beat ...\n",
"\n",
"This query will take ~10 minutes"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Query complete after 0.00s: 100%|██████████| 3/3 [00:00<00:00, 1256.16query/s] \n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"Empty DataFrame\n",
"Columns: []\n",
"Index: []"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%bigquery\n",
"CREATE OR REPLACE MODEL dsongcp.arr_delay_airports_lm\n",
"OPTIONS(input_label_cols=['ontime'], \n",
" model_type='logistic_reg', \n",
" data_split_method='custom',\n",
" data_split_col='is_eval_day')\n",
"AS\n",
"\n",
"SELECT\n",
" IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
" dep_delay,\n",
" taxi_out,\n",
" distance,\n",
" origin,\n",
" dest,\n",
" IF(is_train_day = 'True', False, True) AS is_eval_day\n",
"FROM dsongcp.flights_tzcorr f\n",
"JOIN dsongcp.trainday t\n",
"ON f.FL_DATE = t.FL_DATE\n",
"WHERE\n",
" f.CANCELLED = False AND \n",
" f.DIVERTED = False"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Query complete after 0.00s: 100%|██████████| 7/7 [00:00<00:00, 2760.97query/s] \n",
"Downloading: 100%|██████████| 1/1 [00:01<00:00, 1.11s/rows]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>precision</th>\n",
" <th>recall</th>\n",
" <th>accuracy</th>\n",
" <th>f1_score</th>\n",
" <th>log_loss</th>\n",
" <th>roc_auc</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.967151</td>\n",
" <td>0.957706</td>\n",
" <td>0.938477</td>\n",
" <td>0.962405</td>\n",
" <td>0.165557</td>\n",
" <td>0.960821</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" precision recall accuracy f1_score log_loss roc_auc\n",
"0 0.967151 0.957706 0.938477 0.962405 0.165557 0.960821"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%bigquery\n",
"SELECT * \n",
"FROM ML.EVALUATE(MODEL dsongcp.arr_delay_airports_lm,\n",
" (\n",
"SELECT\n",
" IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
" dep_delay,\n",
" taxi_out,\n",
" distance,\n",
" origin,\n",
" dest,\n",
" IF(is_train_day = 'True', False, True) AS is_eval_day\n",
"FROM dsongcp.flights_tzcorr f\n",
"JOIN dsongcp.trainday t\n",
"ON f.FL_DATE = t.FL_DATE\n",
"WHERE\n",
" f.CANCELLED = False AND \n",
" f.DIVERTED = False AND\n",
" t.is_train_day = 'False'\n",
" ),\n",
" STRUCT(0.7 AS threshold))"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Query complete after 0.00s: 100%|██████████| 4/4 [00:00<00:00, 934.66query/s] \n",
"Downloading: 100%|██████████| 1/1 [00:00<00:00, 1.01rows/s]\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>correct_cancel</th>\n",
" <th>total_noncancel</th>\n",
" <th>correct_noncancel</th>\n",
" <th>total_cancel</th>\n",
" <th>rmse</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.84953</td>\n",
" <td>1299749</td>\n",
" <td>0.967151</td>\n",
" <td>283750</td>\n",
" <td>0.209839</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" correct_cancel total_noncancel correct_noncancel total_cancel rmse\n",
"0 0.84953 1299749 0.967151 283750 0.209839"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%bigquery\n",
"\n",
"WITH predictions AS (\n",
"SELECT \n",
" *\n",
"FROM ML.PREDICT(MODEL dsongcp.arr_delay_airports_lm,\n",
" (\n",
"SELECT\n",
" IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
" dep_delay,\n",
" taxi_out,\n",
" distance,\n",
" origin,\n",
" dest,\n",
" IF(is_train_day = 'True', False, True) AS is_eval_day\n",
"FROM dsongcp.flights_tzcorr f\n",
"JOIN dsongcp.trainday t\n",
"ON f.FL_DATE = t.FL_DATE\n",
"WHERE\n",
" f.CANCELLED = False AND \n",
" f.DIVERTED = False AND\n",
" t.is_train_day = 'False'\n",
" ),\n",
" STRUCT(0.7 AS threshold))),\n",
"\n",
"stats AS (\n",
"SELECT \n",
" COUNTIF(ontime != 'ontime' AND ontime = predicted_ontime) AS correct_cancel\n",
" , COUNTIF(predicted_ontime = 'ontime') AS total_noncancel\n",
" , COUNTIF(ontime = 'ontime' AND ontime = predicted_ontime) AS correct_noncancel\n",
" , COUNTIF(ontime != 'ontime') AS total_cancel\n",
" , SQRT(SUM((IF(ontime = 'ontime', 1, 0) - p.prob) * (IF(ontime = 'ontime', 1, 0) - p.prob))/COUNT(*)) AS rmse\n",
"FROM predictions, UNNEST(predicted_ontime_probs) p\n",
"WHERE p.label = 'ontime'\n",
")\n",
"\n",
"SELECT\n",
" correct_cancel / total_cancel AS correct_cancel\n",
" , total_noncancel\n",
" , correct_noncancel / total_noncancel AS correct_noncancel\n",
" , total_cancel\n",
" , rmse\n",
"FROM stats"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the addition of the airports information has improved both the AUC and the RMSE"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright 2019-2021 Google Inc. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
]
}
],
"metadata": {
"environment": {
"kernel": "python3",
"name": "managed-notebooks.m82",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/base-cu110:latest"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}