08_bqml/bqml_nonlinear.ipynb (632 lines of code) (raw):

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Non-linear classification models\n",
    "\n",
    "Run this notebook in Vertex Workbench. In this notebook, we will start from the same features as\n",
    "in the [logistic regression notebook](bqml_logistic.ipynb) but use non-linear machine learning methods.\n",
    "\n",
    "The models in this notebook will take longer to train than the linear models.\n",
    "\n",
    "The next cell loads the BigQuery `%%bigquery` cell magic used throughout the notebook (it is a no-op if the extension is already loaded)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext google.cloud.bigquery"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## XGBoost\n",
    "\n",
    "XGBoost is usually a very good model for structured data. It's a good next step after logistic regression.\n",
    "This will take ~10 minutes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bigquery\n",
    "CREATE OR REPLACE MODEL dsongcp.arr_delay_airports_xgboost\n",
    "OPTIONS(input_label_cols=['ontime'],\n",
    "        model_type='boosted_tree_classifier',\n",
    "        data_split_method='custom',\n",
    "        data_split_col='is_eval_day')\n",
    "AS\n",
    "\n",
    "SELECT\n",
    "  IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance,\n",
    "  origin,\n",
    "  dest,\n",
    "  IF(is_train_day = 'True', False, True) AS is_eval_day\n",
    "FROM dsongcp.flights_tzcorr f\n",
    "JOIN dsongcp.trainday t\n",
    "ON f.FL_DATE = t.FL_DATE\n",
    "WHERE\n",
    "  f.CANCELLED = False AND\n",
    "  f.DIVERTED = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And evaluate this model as before, on the held-out days (`is_train_day = 'False'`) with a classification threshold of 0.7.\n",
    "Despite their `cancel` names, the output columns are about late arrivals:\n",
    "\n",
    "* `correct_cancel`: fraction of flights that arrived late that the model predicted as late\n",
    "* `total_noncancel`: number of flights the model predicted as on time\n",
    "* `correct_noncancel`: fraction of the on-time predictions that were correct\n",
    "* `total_cancel`: number of flights that actually arrived late\n",
    "* `rmse`: root-mean-square error of the predicted probability of `ontime`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Query complete after 0.00s: 100%|██████████| 3/3 [00:00<00:00, 1494.23query/s] \n",
      "Downloading: 100%|██████████| 1/1 [00:01<00:00,  1.03s/rows]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>correct_cancel</th>\n",
       "      <th>total_noncancel</th>\n",
       "      <th>correct_noncancel</th>\n",
       "      <th>total_cancel</th>\n",
       "      <th>rmse</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.838981</td>\n",
       "      <td>1304078</td>\n",
       "      <td>0.964965</td>\n",
       "      <td>283750</td>\n",
       "      <td>0.207227</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   correct_cancel  total_noncancel  correct_noncancel  total_cancel      rmse\n",
       "0        0.838981          1304078           0.964965        283750  0.207227"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "\n",
    "WITH predictions AS (\n",
    "SELECT\n",
    "    *\n",
    "FROM ML.PREDICT(MODEL dsongcp.arr_delay_airports_xgboost,\n",
    "  (\n",
    "SELECT\n",
    "  IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance,\n",
    "  origin,\n",
    "  dest,\n",
    "  IF(is_train_day = 'True', False, True) AS is_eval_day\n",
    "FROM dsongcp.flights_tzcorr f\n",
    "JOIN dsongcp.trainday t\n",
    "ON f.FL_DATE = t.FL_DATE\n",
    "WHERE\n",
    "  f.CANCELLED = False AND\n",
    "  f.DIVERTED = False AND\n",
    "  t.is_train_day = 'False'\n",
    "  ),\n",
    "  STRUCT(0.7 AS threshold))),\n",
    "\n",
    "stats AS (\n",
    "SELECT\n",
    "    COUNTIF(ontime != 'ontime' AND ontime = predicted_ontime) AS correct_cancel\n",
    "    , COUNTIF(predicted_ontime = 'ontime') AS total_noncancel\n",
    "    , COUNTIF(ontime = 'ontime' AND ontime = predicted_ontime) AS correct_noncancel\n",
    "    , COUNTIF(ontime != 'ontime') AS total_cancel\n",
    "    , SQRT(SUM((IF(ontime = 'ontime', 1, 0) - p.prob) * (IF(ontime = 'ontime', 1, 0) - p.prob))/COUNT(*)) AS rmse\n",
    "FROM predictions, UNNEST(predicted_ontime_probs) p\n",
    "WHERE p.label = 'ontime'\n",
    ")\n",
    "\n",
    "SELECT\n",
    "    correct_cancel / total_cancel AS correct_cancel\n",
    "    , total_noncancel\n",
    "    , correct_noncancel / total_noncancel AS correct_noncancel\n",
    "    , total_cancel\n",
    "    , rmse\n",
    "FROM stats"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also use the model to predict for a single flight:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Query complete after 0.00s: 100%|██████████| 1/1 [00:00<00:00, 427.95query/s] \n",
      "Downloading: 100%|██████████| 1/1 [00:01<00:00,  1.18s/rows]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>predicted_ontime</th>\n",
       "      <th>predicted_ontime_probs</th>\n",
       "      <th>dep_delay</th>\n",
       "      <th>taxi_out</th>\n",
       "      <th>distance</th>\n",
       "      <th>origin</th>\n",
       "      <th>dest</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ontime</td>\n",
       "      <td>[{'label': 'ontime', 'prob': 0.868686914443969...</td>\n",
       "      <td>12.0</td>\n",
       "      <td>14.0</td>\n",
       "      <td>802</td>\n",
       "      <td>DFW</td>\n",
       "      <td>ORD</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   predicted_ontime                             predicted_ontime_probs  \\\n",
       "0            ontime  [{'label': 'ontime', 'prob': 0.868686914443969...   \n",
       "\n",
       "   dep_delay  taxi_out  distance  origin  dest  \n",
       "0       12.0      14.0       802     DFW   ORD  "
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "SELECT * FROM ML.PREDICT(MODEL dsongcp.arr_delay_airports_xgboost,\n",
    "  (\n",
    "SELECT 12.0 AS dep_delay, 14.0 AS taxi_out, 802 AS distance, 'DFW' AS origin, 'ORD' AS dest\n",
    "  ))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameter tuning [Optional]\n",
    "\n",
    "Let's tune two things: the MAX_TREE_DEPTH (default=6) and L2 regularization (default=1.0).\n",
    "We run 5 trials, searching `max_tree_depth` in the range [2, 10] and `l2_reg` in the range [0.5, 3.0].\n",
    "\n",
    "**This section will take ~60 minutes. You can skip it.**\n",
    "\n",
    "Note that `is_eval_day` is now a string column with 3 possible values:\n",
    "\n",
    "* `TRAIN` and `EVAL` split the training days roughly 80/20 *by day*. The split uses a hash of the flight date rather than `RAND()`. This makes it reproducible across re-runs and keeps flights from the same day on the same side of the split, just like the train/test split.\n",
    "* `TEST` marks the held-out days."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bigquery\n",
    "CREATE OR REPLACE MODEL dsongcp.arr_delay_airports_xgh\n",
    "OPTIONS(input_label_cols=['ontime'],\n",
    "        model_type='boosted_tree_classifier',\n",
    "        num_trials=5, l2_reg=hparam_range(0.5, 3.0), max_tree_depth=hparam_range(2, 10),\n",
    "        data_split_method='custom',\n",
    "        data_split_col='is_eval_day')\n",
    "AS\n",
    "\n",
    "SELECT\n",
    "  IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance,\n",
    "  origin,\n",
    "  dest,\n",
    "  -- Deterministic ~80/20 TRAIN/EVAL split of the training days, bucketed by date so that\n",
    "  -- all flights of a day land on the same side. The '-hparam' salt keeps these buckets\n",
    "  -- from lining up with any other date hash (e.g. one used to build trainday).\n",
    "  IF(is_train_day = 'True',\n",
    "     IF(ABS(MOD(FARM_FINGERPRINT(CONCAT(CAST(f.FL_DATE AS STRING), '-hparam')), 10)) < 8, 'TRAIN', 'EVAL'),\n",
    "     'TEST') AS is_eval_day\n",
    "FROM dsongcp.flights_tzcorr f\n",
    "JOIN dsongcp.trainday t\n",
    "ON f.FL_DATE = t.FL_DATE\n",
    "WHERE\n",
    "  f.CANCELLED = False AND\n",
    "  f.DIVERTED = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's look at the three best trials:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Query complete after 0.00s: 100%|██████████| 2/2 [00:00<00:00, 911.51query/s] \n",
      "Downloading: 100%|██████████| 3/3 [00:01<00:00,  2.95rows/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>l2_reg</th>\n",
       "      <th>max_tree_depth</th>\n",
       "      <th>eval_loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2.536659</td>\n",
       "      <td>10</td>\n",
       "      <td>0.155262</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2.113224</td>\n",
       "      <td>10</td>\n",
       "      <td>0.155313</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.887189</td>\n",
       "      <td>10</td>\n",
       "      <td>0.155314</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     l2_reg  max_tree_depth  eval_loss\n",
       "0  2.536659              10   0.155262\n",
       "1  2.113224              10   0.155313\n",
       "2  0.887189              10   0.155314"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "SELECT hyperparameters.l2_reg, hyperparameters.max_tree_depth, eval_loss\n",
    "FROM ML.TRIAL_INFO(MODEL dsongcp.arr_delay_airports_xgh)\n",
    "ORDER BY eval_loss ASC LIMIT 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Evaluate the tuned model on the held-out days, exactly as before:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Query complete after 0.00s: 100%|██████████| 3/3 [00:00<00:00, 1079.99query/s] \n",
      "Downloading: 100%|██████████| 1/1 [00:02<00:00,  2.70s/rows]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>correct_cancel</th>\n",
       "      <th>total_noncancel</th>\n",
       "      <th>correct_noncancel</th>\n",
       "      <th>total_cancel</th>\n",
       "      <th>rmse</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.841952</td>\n",
       "      <td>1305703</td>\n",
       "      <td>0.965654</td>\n",
       "      <td>283750</td>\n",
       "      <td>0.204358</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   correct_cancel  total_noncancel  correct_noncancel  total_cancel      rmse\n",
       "0        0.841952          1305703           0.965654        283750  0.204358"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%bigquery\n",
    "\n",
    "WITH predictions AS (\n",
    "SELECT\n",
    "    *\n",
    "FROM ML.PREDICT(MODEL dsongcp.arr_delay_airports_xgh,\n",
    "  (\n",
    "SELECT\n",
    "  IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance,\n",
    "  origin,\n",
    "  dest,\n",
    "  IF(is_train_day = 'True', False, True) AS is_eval_day\n",
    "FROM dsongcp.flights_tzcorr f\n",
    "JOIN dsongcp.trainday t\n",
    "ON f.FL_DATE = t.FL_DATE\n",
    "WHERE\n",
    "  f.CANCELLED = False AND\n",
    "  f.DIVERTED = False AND\n",
    "  t.is_train_day = 'False'\n",
    "  ),\n",
    "  STRUCT(0.7 AS threshold))),\n",
    "\n",
    "stats AS (\n",
    "SELECT\n",
    "    COUNTIF(ontime != 'ontime' AND ontime = predicted_ontime) AS correct_cancel\n",
    "    , COUNTIF(predicted_ontime = 'ontime') AS total_noncancel\n",
    "    , COUNTIF(ontime = 'ontime' AND ontime = predicted_ontime) AS correct_noncancel\n",
    "    , COUNTIF(ontime != 'ontime') AS total_cancel\n",
    "    , SQRT(SUM((IF(ontime = 'ontime', 1, 0) - p.prob) * (IF(ontime = 'ontime', 1, 0) - p.prob))/COUNT(*)) AS rmse\n",
    "FROM predictions, UNNEST(predicted_ontime_probs) p\n",
    "WHERE p.label = 'ontime'\n",
    ")\n",
    "\n",
    "SELECT\n",
    "    correct_cancel / total_cancel AS correct_cancel\n",
    "    , total_noncancel\n",
    "    , correct_noncancel / total_noncancel AS correct_noncancel\n",
    "    , total_cancel\n",
    "    , rmse\n",
    "FROM stats"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## AutoML (optional)\n",
    "\n",
    "Let's try AutoML Tables, which should give us close to state-of-the-art performance.\n",
    "Note, however, that AutoML does not support a custom data split, so we cannot really\n",
    "compare performance across methods. We still train it only on the training days\n",
    "and evaluate it on the same held-out days as the other models, but how it splits\n",
    "the training days internally is not under our control.\n",
    "\n",
    "**This will take ~60 minutes. You can skip this step.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bigquery\n",
    "CREATE OR REPLACE MODEL dsongcp.arr_delay_airports_automl\n",
    "OPTIONS(input_label_cols=['ontime'],\n",
    "        model_type='automl_classifier')\n",
    "AS\n",
    "\n",
    "SELECT\n",
    "  IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance,\n",
    "  origin,\n",
    "  dest\n",
    "FROM dsongcp.flights_tzcorr f\n",
    "JOIN dsongcp.trainday t\n",
    "ON f.FL_DATE = t.FL_DATE\n",
    "WHERE\n",
    "  f.CANCELLED = False AND\n",
    "  f.DIVERTED = False AND\n",
    "  is_train_day = 'True'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bigquery\n",
    "\n",
    "WITH predictions AS (\n",
    "SELECT\n",
    "    *\n",
    "FROM ML.PREDICT(MODEL dsongcp.arr_delay_airports_automl,\n",
    "  (\n",
    "SELECT\n",
    "  IF(arr_delay < 15, 'ontime', 'late') AS ontime,\n",
    "  dep_delay,\n",
    "  taxi_out,\n",
    "  distance,\n",
    "  origin,\n",
    "  dest,\n",
    "  IF(is_train_day = 'True', False, True) AS is_eval_day\n",
    "FROM dsongcp.flights_tzcorr f\n",
    "JOIN dsongcp.trainday t\n",
    "ON f.FL_DATE = t.FL_DATE\n",
    "WHERE\n",
    "  f.CANCELLED = False AND\n",
    "  f.DIVERTED = False AND\n",
    "  t.is_train_day = 'False'\n",
    "  ),\n",
    "  STRUCT(0.7 AS threshold))),\n",
    "\n",
    "stats AS (\n",
    "SELECT\n",
    "    COUNTIF(ontime != 'ontime' AND ontime = predicted_ontime) AS correct_cancel\n",
    "    , COUNTIF(predicted_ontime = 'ontime') AS total_noncancel\n",
    "    , COUNTIF(ontime = 'ontime' AND ontime = predicted_ontime) AS correct_noncancel\n",
    "    , COUNTIF(ontime != 'ontime') AS total_cancel\n",
    "    , SQRT(SUM((IF(ontime = 'ontime', 1, 0) - p.prob) * (IF(ontime = 'ontime', 1, 0) - p.prob))/COUNT(*)) AS rmse\n",
    "FROM predictions, UNNEST(predicted_ontime_probs) p\n",
    "WHERE p.label = 'ontime'\n",
    ")\n",
    "\n",
    "SELECT\n",
    "    correct_cancel / total_cancel AS correct_cancel\n",
    "    , total_noncancel\n",
    "    , correct_noncancel / total_noncancel AS correct_noncancel\n",
    "    , total_cancel\n",
    "    , rmse\n",
    "FROM stats"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright 2021 Google Inc. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
   ]
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "",
   "name": "managed-notebooks.m82",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/managed-notebooks:m82"
  },
  "kernelspec": {
   "display_name": "",
   "name": ""
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}