## Non-linear classification models

Run this notebook in Vertex Workbench. We start from the same features as in the
[logistic regression notebook](bqml_logistic.ipynb) but use non-linear machine learning methods.

The models in this notebook will take longer to train than the linear models.


## xgboost

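xgboost (gradient-boosted trees) usually performs very well on structured data, which makes it a good next step after logistic regression.
In BigQuery ML, the corresponding model type is `boosted_tree_classifier`. Training will take ~10 minutes.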

In [None]:
%%bigquery
CREATE OR REPLACE MODEL dsongcp.arr_delay_airports_xgboost
OPTIONS(input_label_cols=['ontime'], 
        model_type='boosted_tree_classifier',
        data_split_method='custom',
        data_split_col='is_eval_day')
AS

SELECT
  IF(arr_delay < 15, 'ontime', 'late') AS ontime,
  dep_delay,
  taxi_out,
  distance,
  origin,
  dest,
  -- custom split: rows where is_eval_day is TRUE are held out for evaluation
  IF(is_train_day = 'True', False, True) AS is_eval_day
FROM dsongcp.flights_tzcorr f
JOIN dsongcp.trainday t
ON f.FL_DATE = t.FL_DATE
WHERE
  f.CANCELLED = False AND 
  f.DIVERTED = False

Executing query with job ID: 9984af86-cf61-455d-8408-9a3f3b77b81e
Query executing: 183.61s

Now evaluate this model on the held-out days, as before, using a probability threshold of 0.7:

In [3]:
%%bigquery

WITH predictions AS (
SELECT 
  *
FROM ML.PREDICT(MODEL dsongcp.arr_delay_airports_xgboost,
                 (
SELECT
  IF(arr_delay < 15, 'ontime', 'late') AS ontime,
  dep_delay,
  taxi_out,
  distance,
  origin,
  dest,
  IF(is_train_day = 'True', False, True) AS is_eval_day
FROM dsongcp.flights_tzcorr f
JOIN dsongcp.trainday t
ON f.FL_DATE = t.FL_DATE
WHERE
  f.CANCELLED = False AND 
  f.DIVERTED = False AND
  t.is_train_day = 'False'
                 ),
                 STRUCT(0.7 AS threshold))),

stats AS (
SELECT 
  COUNTIF(ontime != 'ontime' AND ontime = predicted_ontime) AS correct_cancel
  , COUNTIF(predicted_ontime = 'ontime') AS total_noncancel
  , COUNTIF(ontime = 'ontime' AND ontime = predicted_ontime) AS correct_noncancel
  , COUNTIF(ontime != 'ontime') AS total_cancel
  , SQRT(SUM((IF(ontime = 'ontime', 1, 0) - p.prob) * (IF(ontime = 'ontime', 1, 0) - p.prob))/COUNT(*)) AS rmse
FROM predictions, UNNEST(predicted_ontime_probs) p
WHERE p.label = 'ontime'
)

SELECT
   correct_cancel / total_cancel AS correct_cancel
   , total_noncancel
   , correct_noncancel / total_noncancel AS correct_noncancel
   , total_cancel
   , rmse
FROM stats



| correct_cancel | total_noncancel | correct_noncancel | total_cancel | rmse |
|---|---|---|---|---|
| 0.838981 | 1304078 | 0.964965 | 283750 | 0.207227 |
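To make the four counts and the RMSE in the `stats` query concrete, here is a minimal Python sketch that recomputes them from a handful of made-up rows. The function name `evaluate` and the toy data are assumptions for illustration only; the formulas are meant to mirror the SQL above, not to replace it.

```python
import math

def evaluate(rows):
    """Recompute the metrics of the `stats` query on in-memory rows.

    Each row is (actual_label, predicted_label, prob_ontime), where
    prob_ontime is the predicted probability of the 'ontime' label.
    """
    # Late flights that were also predicted late
    correct_cancel = sum(1 for a, p, _ in rows if a != "ontime" and a == p)
    # All flights that were actually late
    total_cancel = sum(1 for a, _, _ in rows if a != "ontime")
    # On-time flights that were also predicted on time
    correct_noncancel = sum(1 for a, p, _ in rows if a == "ontime" and a == p)
    # All flights predicted on time (note: counted on predictions, not actuals)
    total_noncancel = sum(1 for _, p, _ in rows if p == "ontime")
    # Squared error between the 0/1 outcome and the predicted probability
    sq_err = sum(((1.0 if a == "ontime" else 0.0) - prob) ** 2 for a, _, prob in rows)
    return {
        "correct_cancel": correct_cancel / total_cancel,
        "total_noncancel": total_noncancel,
        "correct_noncancel": correct_noncancel / total_noncancel,
        "total_cancel": total_cancel,
        "rmse": math.sqrt(sq_err / len(rows)),
    }

# Hypothetical rows, not taken from the flights data
toy_rows = [
    ("ontime", "ontime", 0.9),
    ("ontime", "ontime", 0.8),
    ("ontime", "late", 0.6),
    ("late", "late", 0.2),
    ("late", "ontime", 0.75),
]
metrics = evaluate(toy_rows)
print(metrics)
```

Read this way, `correct_cancel` is the share of late flights that the model flags as late, while `correct_noncancel` is the share of 'ontime' predictions that turn out to be right.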


In [5]:
%%bigquery
SELECT * FROM ML.PREDICT(MODEL dsongcp.arr_delay_airports_xgboost,
                        (
SELECT 12.0 AS dep_delay, 14.0 AS taxi_out, 802 AS distance, 'DFW' AS origin, 'ORD' AS dest
                        ))



| predicted_ontime | predicted_ontime_probs | dep_delay | taxi_out | distance | origin | dest |
|---|---|---|---|---|---|---|
| ontime | [{'label': 'ontime', 'prob': 0.868686914443969... | 12.0 | 14.0 | 802 | DFW | ORD |
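The `predicted_ontime_probs` column holds one `{'label', 'prob'}` entry per class. As a rough sketch of how you might read the on-time probability out of such a value once the result is in Python, consider the following. The helper `prob_of`, the rounded probabilities, and the 0.7 decision rule applied here are all assumptions for illustration; in particular, the 'late' probability is not shown in the output above and is simply taken to be the complement.

```python
def prob_of(label, probs):
    """Return the probability attached to `label` in a list of {'label', 'prob'} dicts."""
    for entry in probs:
        if entry["label"] == label:
            return entry["prob"]
    raise KeyError(label)

# Illustrative values only: rounded from the output above, with the
# 'late' probability assumed to be the complement of 'ontime'.
predicted_ontime_probs = [
    {"label": "ontime", "prob": 0.87},
    {"label": "late", "prob": 0.13},
]

p_ontime = prob_of("ontime", predicted_ontime_probs)

# Hypothetical decision rule: call the flight on time only if the
# on-time probability is at least 0.7.
decision = "ontime" if p_ontime >= 0.7 else "late"
print(p_ontime, decision)
```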


## Hyperparameter tuning [Optional]

Let's tune two things over 5 trials: the maximum tree depth (`max_tree_depth`, default=6) and the L2 regularization (`l2_reg`, default=1.0).

**This section will take ~60 minutes. You can skip it.**

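Note that is_eval_day is now a string column with 3 possible values: 'TRAIN', 'EVAL', and 'TEST'.
Rows from training days are randomly assigned to 'TRAIN' (about 80%) or 'EVAL' (about 20%); all other rows are 'TEST'.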

In [None]:
%%bigquery
CREATE OR REPLACE MODEL dsongcp.arr_delay_airports_xgh
OPTIONS(input_label_cols=['ontime'], 
        model_type='boosted_tree_classifier',
        num_trials=5, l2_reg=hparam_range(0.5, 3.0), max_tree_depth=hparam_range(2, 10),
        data_split_method='custom',
        data_split_col='is_eval_day')
AS

SELECT
  IF(arr_delay < 15, 'ontime', 'late') AS ontime,
  dep_delay,
  taxi_out,
  distance,
  origin,
  dest,
  IF(is_train_day = 'True', 
     IF(RAND() < 0.8, 'TRAIN', 'EVAL'), 
     'TEST') AS is_eval_day
FROM dsongcp.flights_tzcorr f
JOIN dsongcp.trainday t
ON f.FL_DATE = t.FL_DATE
WHERE
  f.CANCELLED = False AND 
  f.DIVERTED = False

Executing query with job ID: bea5283d-efdc-444c-8fee-3ef81f435d80
Query executing: 3554.37s
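The three-way split in the query above can be sketched in plain Python as follows. This is only an illustration of the `IF(is_train_day = 'True', IF(RAND() < 0.8, 'TRAIN', 'EVAL'), 'TEST')` expression; the function name `assign_split`, the fixed seed, and the synthetic list of rows are assumptions, not part of the notebook.

```python
import random

def assign_split(is_train_day, rng, eval_fraction=0.2):
    """Mirror IF(is_train_day = 'True', IF(RAND() < 0.8, 'TRAIN', 'EVAL'), 'TEST')."""
    if is_train_day != "True":
        return "TEST"
    # Each training-day row is assigned independently, row by row
    return "TRAIN" if rng.random() < 1.0 - eval_fraction else "EVAL"

# A fixed seed keeps this sketch repeatable; RAND() in the query takes no seed
rng = random.Random(42)

# Synthetic rows: 8000 from training days, 2000 from held-out days
days = ["True"] * 8000 + ["False"] * 2000
splits = [assign_split(d, rng) for d in days]
counts = {name: splits.count(name) for name in ("TRAIN", "EVAL", "TEST")}
print(counts)
```

Because the random draw happens per row, flights from the same training day can land in both 'TRAIN' and 'EVAL', whereas 'TEST' always consists of whole held-out days.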

In [3]:
%%bigquery
SELECT hyperparameters.l2_reg, hyperparameters.max_tree_depth, eval_loss
FROM ML.TRIAL_INFO(MODEL dsongcp.arr_delay_airports_xgh)
ORDER BY eval_loss ASC LIMIT 3



| l2_reg | max_tree_depth | eval_loss |
|---|---|---|
| 2.536659 | 10 | 0.155262 |
| 2.113224 | 10 | 0.155313 |
| 0.887189 | 10 | 0.155314 |
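One quick check on tuning results is whether the best trial sits at the edge of the search range. The sketch below copies the three rows above and the ranges from the `hparam_range` calls into Python by hand; the variable names are assumptions. A best value at a boundary may hint that a wider range is worth trying, but this notebook does not verify that.

```python
# Top three trials, copied from the ML.TRIAL_INFO output above
trials = [
    {"l2_reg": 2.536659, "max_tree_depth": 10, "eval_loss": 0.155262},
    {"l2_reg": 2.113224, "max_tree_depth": 10, "eval_loss": 0.155313},
    {"l2_reg": 0.887189, "max_tree_depth": 10, "eval_loss": 0.155314},
]

# Search ranges, copied from the hparam_range(...) options in CREATE MODEL
search_space = {"l2_reg": (0.5, 3.0), "max_tree_depth": (2, 10)}

# Lowest eval_loss wins
best = min(trials, key=lambda t: t["eval_loss"])

# Hyperparameters whose best value equals the lower or upper bound of its range
at_boundary = [
    name for name, (lo, hi) in search_space.items() if best[name] in (lo, hi)
]

# Gap in eval_loss between the best and the third-best trial
loss_spread = trials[-1]["eval_loss"] - best["eval_loss"]

print(best, at_boundary, f"{loss_spread:.6f}")
```

Here all three top trials use `max_tree_depth` = 10, the upper bound of the range, while `l2_reg` varies widely for nearly identical `eval_loss`.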


In [4]:
%%bigquery

WITH predictions AS (
SELECT 
  *
FROM ML.PREDICT(MODEL dsongcp.arr_delay_airports_xgh,
                 (
SELECT
  IF(arr_delay < 15, 'ontime', 'late') AS ontime,
  dep_delay,
  taxi_out,
  distance,
  origin,
  dest,
  IF(is_train_day = 'True', False, True) AS is_eval_day
FROM dsongcp.flights_tzcorr f
JOIN dsongcp.trainday t
ON f.FL_DATE = t.FL_DATE
WHERE
  f.CANCELLED = False AND 
  f.DIVERTED = False AND
  t.is_train_day = 'False'
                 ),
                 STRUCT(0.7 AS threshold))),

stats AS (
SELECT 
  COUNTIF(ontime != 'ontime' AND ontime = predicted_ontime) AS correct_cancel
  , COUNTIF(predicted_ontime = 'ontime') AS total_noncancel
  , COUNTIF(ontime = 'ontime' AND ontime = predicted_ontime) AS correct_noncancel
  , COUNTIF(ontime != 'ontime') AS total_cancel
  , SQRT(SUM((IF(ontime = 'ontime', 1, 0) - p.prob) * (IF(ontime = 'ontime', 1, 0) - p.prob))/COUNT(*)) AS rmse
FROM predictions, UNNEST(predicted_ontime_probs) p
WHERE p.label = 'ontime'
)

SELECT
   correct_cancel / total_cancel AS correct_cancel
   , total_noncancel
   , correct_noncancel / total_noncancel AS correct_noncancel
   , total_cancel
   , rmse
FROM stats



| correct_cancel | total_noncancel | correct_noncancel | total_cancel | rmse |
|---|---|---|---|---|
| 0.841952 | 1305703 | 0.965654 | 283750 | 0.204358 |
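To see how much the tuned model moves the needle relative to the default xgboost model, the two result rows can be compared directly. The sketch below copies the numbers from the two evaluation outputs in this notebook by hand; the dictionary names are assumptions for illustration.

```python
# Metrics copied from the evaluation output of arr_delay_airports_xgboost
default_model = {"correct_cancel": 0.838981, "correct_noncancel": 0.964965, "rmse": 0.207227}

# Metrics copied from the evaluation output of arr_delay_airports_xgh
tuned_model = {"correct_cancel": 0.841952, "correct_noncancel": 0.965654, "rmse": 0.204358}

# Positive deltas are improvements for the two rates; a negative delta is an
# improvement for rmse.
deltas = {name: tuned_model[name] - default_model[name] for name in default_model}

for name, delta in deltas.items():
    print(f"{name}: {delta:+.6f}")
```

All three metrics move in the right direction, though each change is in the third decimal place or smaller.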


## AutoML [Optional]

Let's try AutoML Tables, which should give us close to state-of-the-art performance.
Note, however, that because AutoML does not support a custom data split, we cannot really
compare performance across methods.

**This will take ~60 minutes. You can skip this step.**

In [None]:
%%bigquery
CREATE OR REPLACE MODEL dsongcp.arr_delay_airports_automl
OPTIONS(input_label_cols=['ontime'], 
        model_type='automl_classifier')
AS

SELECT
  IF(arr_delay < 15, 'ontime', 'late') AS ontime,
  dep_delay,
  taxi_out,
  distance,
  origin,
  dest
FROM dsongcp.flights_tzcorr f
JOIN dsongcp.trainday t
ON f.FL_DATE = t.FL_DATE
WHERE
  f.CANCELLED = False AND 
  f.DIVERTED = False AND
  is_train_day = 'True'

In [None]:
%%bigquery

WITH predictions AS (
SELECT 
  *
FROM ML.PREDICT(MODEL dsongcp.arr_delay_airports_automl,
                 (
SELECT
  IF(arr_delay < 15, 'ontime', 'late') AS ontime,
  dep_delay,
  taxi_out,
  distance,
  origin,
  dest,
  IF(is_train_day = 'True', False, True) AS is_eval_day
FROM dsongcp.flights_tzcorr f
JOIN dsongcp.trainday t
ON f.FL_DATE = t.FL_DATE
WHERE
  f.CANCELLED = False AND 
  f.DIVERTED = False AND
  t.is_train_day = 'False'
                 ),
                 STRUCT(0.7 AS threshold))),

stats AS (
SELECT 
  COUNTIF(ontime != 'ontime' AND ontime = predicted_ontime) AS correct_cancel
  , COUNTIF(predicted_ontime = 'ontime') AS total_noncancel
  , COUNTIF(ontime = 'ontime' AND ontime = predicted_ontime) AS correct_noncancel
  , COUNTIF(ontime != 'ontime') AS total_cancel
  , SQRT(SUM((IF(ontime = 'ontime', 1, 0) - p.prob) * (IF(ontime = 'ontime', 1, 0) - p.prob))/COUNT(*)) AS rmse
FROM predictions, UNNEST(predicted_ontime_probs) p
WHERE p.label = 'ontime'
)

SELECT
   correct_cancel / total_cancel AS correct_cancel
   , total_noncancel
   , correct_noncancel / total_noncancel AS correct_noncancel
   , total_cancel
   , rmse
FROM stats

Copyright 2021 Google Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.