06_dataproc/bayes_on_spark.py (53 lines of code) (raw):

#!/usr/bin/env python3 # Copyright 2021 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import pandas as pd import numpy as np from pyspark.sql import SparkSession import pyspark.sql.functions as F def run_bayes(BUCKET): spark = SparkSession \ .builder \ .appName("Bayes classification using Spark") \ .getOrCreate() # read flights data inputs = 'gs://{}/flights/tzcorr/all_flights-*'.format(BUCKET) # FULL flights = spark.read.json(inputs) flights.createOrReplaceTempView('flights') # which days are training days? traindays = spark.read \ .option("header", "true") \ .option("inferSchema", "true") \ .csv('gs://{}/flights/trainday.csv'.format(BUCKET)) traindays.createOrReplaceTempView('traindays') # create training dataset statement = """ SELECT f.FL_DATE AS date, CAST(distance AS FLOAT) AS distance, dep_delay, IF(arr_delay < 15, 1, 0) AS ontime FROM flights f JOIN traindays t ON f.FL_DATE == t.FL_DATE WHERE t.is_train_day AND f.dep_delay IS NOT NULL ORDER BY f.dep_delay DESC """ flights = spark.sql(statement) # quantiles distthresh = flights.approxQuantile('distance', list(np.arange(0, 1.0, 0.2)), 0.02) distthresh[-1] = float('inf') delaythresh = range(10, 20) logging.info("Computed distance thresholds: {}".format(distthresh)) # bayes in each bin df = pd.DataFrame(columns=['dist_thresh', 'delay_thresh', 'frac_ontime']) for m in range(0, len(distthresh) - 1): for n in range(0, len(delaythresh) - 1): bdf = flights[(flights['distance'] >= distthresh[m]) & (flights['distance'] < distthresh[m + 1]) & (flights['dep_delay'] >= delaythresh[n]) & (flights['dep_delay'] < delaythresh[n + 1])] ontime_frac = bdf.agg(F.sum('ontime')).collect()[0][0] / bdf.agg(F.count('ontime')).collect()[0][0] print(m, n, ontime_frac) df = df.append({ 'dist_thresh': distthresh[m], 'delay_thresh': delaythresh[n], 'frac_ontime': ontime_frac }, ignore_index=True) # lookup table df['score'] = abs(df['frac_ontime'] - 0.7) bayes = df.sort_values(['score']).groupby('dist_thresh').head(1).sort_values('dist_thresh') bayes.to_csv('gs://{}/flights/bayes.csv'.format(BUCKET), index=False) logging.info("Wrote lookup table: {}".format(bayes.head())) if __name__ == '__main__': import argparse parser = argparse.ArgumentParser(description='Create Bayes lookup table') parser.add_argument('--bucket', help='GCS bucket to read/write data', required=True) parser.add_argument('--debug', dest='debug', action='store_true', help='Specify if you want debug messages') args = parser.parse_args() if args.debug: logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.DEBUG) else: logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO) run_bayes(args.bucket)