jobs/ltv_daily.py

import argparse
import os
from itertools import chain

import numpy as np
import pandas as pd
import pyspark.sql.functions as F
from google.cloud import bigquery
from lifetimes import BetaGeoFitter
from pyspark.sql import SparkSession
from pyspark.sql.types import DoubleType

# Metrics modeled independently; each is a struct column with
# frequency/recency/T fields in the source table.
PRED_METRICS = [
    "days_seen",
    "days_searched",
    "days_tagged_searched",
    "days_clicked_ads",
    "days_searched_with_ads",
]


def train_metric(d, metric, plot=True, penalty=0):
    """Fit a BG/NBD model for a single metric and compare actual vs. simulated counts."""
    frequency = metric + "_frequency"
    recency = metric + "_recency"
    T = metric + "_T"

    # Keep clients with at least one active day; BG/NBD frequency counts repeat
    # events, so subtract the first observation.
    train = d
    train = train[(train[frequency] > 0) & (train[recency] >= 0)]
    train[frequency] = train[frequency] - 1

    bgf = BetaGeoFitter(penalizer_coef=penalty)
    bgf.fit(train[frequency], train[recency], train[T])

    # Simulate data from the fitted model and compare the frequency
    # distributions of the actual and simulated clients.
    n = bgf.data.shape[0]
    simulated_data = bgf.generate_new_data(size=n)

    model_counts = pd.DataFrame(
        bgf.data["frequency"].value_counts().sort_index().iloc[:28]
    )
    simulated_counts = pd.DataFrame(
        simulated_data["frequency"].value_counts().sort_index().iloc[:28]
    )
    combined_counts = model_counts.merge(
        simulated_counts, how="outer", left_index=True, right_index=True
    ).fillna(0)
    combined_counts.columns = ["Actual", "Model"]

    if plot:
        combined_counts.plot.bar()
        display()

    return combined_counts, bgf


def catch_none(x):
    if x is None:
        return 0
    return x


def ltv_predict(t, frequency, recency, T, model):
    pred = model.conditional_expected_number_of_purchases_up_to_time(
        t, catch_none(frequency), catch_none(recency), catch_none(T)
    )
    # Cap the prediction at the horizon; a client cannot be active more than t days.
    if pred > t:
        return float(t)
    return float(pred)


def main(
    submission_date,
    project_id,
    dataset_id,
    source_qualified_table_id,
    intermediate_table_id,
    model_input_table_id,
    model_output_table_id,
    temporary_gcs_bucket,
    training_sample,
    prediction_days,
):
    """Model the lifetime-value (LTV) of clients based on search activity.

    This reads a single partition from a source table into an intermediate
    table in an analysis dataset. The table is transformed for modeling. The
    model inputs and predictions are then stored into separate tables in the
    same dataset.
    """
    print(f"Running ltv_daily job for {submission_date}")

    bq = bigquery.Client()
    table_ref = bq.dataset(dataset_id, project=project_id).table(intermediate_table_id)

    # Define the job configuration for the query: set the query parameters and
    # the output destination for the materialized dataset.
    job_config = bigquery.QueryJobConfig()
    job_config.query_parameters = [
        bigquery.ScalarQueryParameter("submission_date", "STRING", submission_date)
    ]
    job_config.destination = table_ref
    job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE

    query = f"""
        SELECT
            *
        FROM
            `{source_qualified_table_id}`
        WHERE
            submission_date = @submission_date
    """
    query_job = bq.query(query, job_config=job_config)
    query_job.result()

    spark = SparkSession.builder.getOrCreate()
    search_rfm_full = (
        spark.read.format("bigquery")
        .option("table", f"{project_id}.{dataset_id}.{intermediate_table_id}")
        .load()
    )

    # Pull the frequency/recency/T fields out of each metric struct as flat columns.
    columns = [
        [
            F.col(metric + ".frequency").alias(metric + "_frequency"),
            F.col(metric + ".recency").alias(metric + "_recency"),
            F.col(metric + ".T").alias(metric + "_T"),
        ]
        for metric in PRED_METRICS
    ]
    # flatten list
    columns = [item for sublist in columns for item in sublist]

    prediction_prefix = "prediction_"
    p_alive_prefix = "p_alive_"
    model_perf_data = pd.DataFrame()
    model_pred_data = None

    # Train on a sample of clients; predictions are made over the full dataset.
    search_rfm_ds = search_rfm_full.limit(training_sample).select(columns).toPandas()

    for metric in PRED_METRICS:
        # train and extract model performance
        model_perf, model = train_metric(search_rfm_ds, metric, plot=False, penalty=0.8)
        model_perf["pct"] = model_perf.Model / (model_perf.Actual + 1) - 1
        model_perf["metric"] = metric
        model_perf["date"] = submission_date
        model_perf_data = pd.concat([model_perf_data, model_perf])

        # make predictions using model
        @F.udf(DoubleType())
        def ltv_predict_metric(metric, model=model):
            import lifetimes

            return ltv_predict(
                prediction_days, metric.frequency, metric.recency, metric.T, model
            )

        @F.udf(DoubleType())
        def ltv_prob_alive(metric):
            import lifetimes

            p_alive = float(
                model.conditional_probability_alive(
                    catch_none(metric.frequency),
                    catch_none(metric.recency),
                    catch_none(metric.T),
                )
            )
            # Lifetimes returns 1.0 if frequency==0
            # https://github.com/CamDavidsonPilon/lifetimes/blob/master/lifetimes/fitters/beta_geo_fitter.py#L293
            if p_alive >= 1.0:
                return 0.0
            return p_alive

        # go back to full sample here
        predictions = search_rfm_full.select(
            "*",
            ltv_predict_metric(metric).alias(prediction_prefix + metric),
            ltv_prob_alive(metric).alias(p_alive_prefix + metric),
        )
        if model_pred_data is None:
            model_pred_data = predictions
        else:
            model_pred_data = model_pred_data.join(
                predictions.select(
                    "client_id",
                    prediction_prefix + metric,
                    p_alive_prefix + metric,
                ),
                on="client_id",
            )

    # Collect the per-metric prediction columns into map<metric, value> columns.
    predictions = F.create_map(
        list(
            chain(
                *((F.lit(name), F.col(prediction_prefix + name)) for name in PRED_METRICS)
            )
        )
    ).alias("predictions")

    p_alive = F.create_map(
        list(
            chain(
                *((F.lit(name), F.col(p_alive_prefix + name)) for name in PRED_METRICS)
            )
        )
    ).alias("p_alive")

    model_perf_data["active_days"] = model_perf_data.index
    model_perf_data_sdf = spark.createDataFrame(model_perf_data).withColumn(
        "date", F.to_date("date")
    )

    ds_nodash = submission_date.replace("-", "")
    (
        model_perf_data_sdf.write.format("bigquery")
        .option("table", f"{project_id}.{dataset_id}.{model_input_table_id}${ds_nodash}")
        .option("temporaryGcsBucket", temporary_gcs_bucket)
        .option("partitionField", "date")
        .mode("overwrite")
        .save()
    )

    (
        model_pred_data
        # Add prediction columns as maps
        .withColumn("predictions", predictions)
        .withColumn("p_alive", p_alive)
        # Drop top-level prediction columns
        .drop(
            *list(
                chain(
                    *[[prediction_prefix + n, p_alive_prefix + n] for n in PRED_METRICS]
                )
            )
        )
        # Overwrite BQ partition
        .write.format("bigquery")
        .option("table", f"{project_id}.{dataset_id}.{model_output_table_id}${ds_nodash}")
        .option("temporaryGcsBucket", temporary_gcs_bucket)
        .option("partitionField", "submission_date")
        .option("clusteredFields", "sample_id,client_id")
        .option("allowFieldAddition", "true")
        .mode("overwrite")
        .save()
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument("--submission-date", help="date in YYYY-MM-DD")
    parser.add_argument("--training-sample", type=int, default=500_000)
    parser.add_argument("--prediction-days", type=int, default=28)
    parser.add_argument("--project-id", default="moz-fx-data-bq-data-science")
    parser.add_argument(
        "--source-qualified-table-id",
        default="moz-fx-data-shared-prod.search.search_rfm",
    )
    parser.add_argument("--dataset-id", default="bmiroglio")
    parser.add_argument("--intermediate-table-id", default="search_rfm_day")
    parser.add_argument("--model-input-table-id", default="ltv_daily_model_perf_script")
    parser.add_argument("--model-output-table-id", default="ltv_daily_script")
    parser.add_argument(
        "--temporary-gcs-bucket", default="moz-fx-data-bq-data-science-bmiroglio"
    )
    args = parser.parse_args()
    main(**vars(args))
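
The job is driven by the argparse interface above. As a minimal sketch, an equivalent direct call to main() looks like the following; the submission date is illustrative only, and the remaining values simply repeat the parser defaults defined in the script:

    main(
        submission_date="2021-01-01",  # illustrative date, not a default
        project_id="moz-fx-data-bq-data-science",
        dataset_id="bmiroglio",
        source_qualified_table_id="moz-fx-data-shared-prod.search.search_rfm",
        intermediate_table_id="search_rfm_day",
        model_input_table_id="ltv_daily_model_perf_script",
        model_output_table_id="ltv_daily_script",
        temporary_gcs_bucket="moz-fx-data-bq-data-science-bmiroglio",
        training_sample=500_000,
        prediction_days=28,
    )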
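
For context, the modeling core used by train_metric() and the prediction UDFs is the BG/NBD model from the lifetimes package. The sketch below shows that API in isolation, using the CDNOW sample data shipped with lifetimes rather than the search RFM table; the frequency/recency/T values passed to the conditional methods are made up for illustration:

    from lifetimes import BetaGeoFitter
    from lifetimes.datasets import load_cdnow_summary

    # Sample RFM summary data with frequency, recency, and T columns.
    rfm = load_cdnow_summary(index_col=[0])

    bgf = BetaGeoFitter(penalizer_coef=0.8)
    bgf.fit(rfm["frequency"], rfm["recency"], rfm["T"])

    # Expected number of additional active periods over the next 28 time units
    # for a client with frequency=3, recency=10, T=28, capped at the horizon
    # the same way ltv_predict() caps its result.
    expected = float(
        bgf.conditional_expected_number_of_purchases_up_to_time(28, 3, 10, 28)
    )
    expected = min(expected, 28.0)

    # Probability that the same client is still "alive" (active).
    p_alive = float(bgf.conditional_probability_alive(3, 10, 28))
    print(expected, p_alive)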