import argparse
from itertools import chain

import pandas as pd
import pyspark.sql.functions as F
from google.cloud import bigquery
from pyspark.sql import SparkSession
from pyspark.sql.types import DoubleType

from lifetimes import BetaGeoFitter

PRED_METRICS = [
    "days_seen",
    "days_searched",
    "days_tagged_searched",
    "days_clicked_ads",
    "days_searched_with_ads",
]

def train_metric(d, metric, plot=True, penalty=0):
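    """Fit a BG/NBD (BetaGeoFitter) model to one metric's RFM columns.

    Returns a frame comparing actual vs. model-simulated frequency counts
    (a fit diagnostic) together with the fitted model.
    """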
    frequency = metric + "_frequency"
    recency = metric + "_recency"
    T = metric + "_T"
    # keep clients with repeat activity and a valid recency, then shift
    # frequency down by one so it counts repeat events, as lifetimes expects;
    # copy so the caller's frame is not mutated
    train = d[(d[frequency] > 0) & (d[recency] >= 0)].copy()
    train[frequency] = train[frequency] - 1

    bgf = BetaGeoFitter(penalizer_coef=penalty)
    bgf.fit(train[frequency], train[recency], train[T])
    n = bgf.data.shape[0]
    simulated_data = bgf.generate_new_data(size=n)

    model_counts = pd.DataFrame(
        bgf.data["frequency"].value_counts().sort_index().iloc[:28]
    )
    simulated_counts = pd.DataFrame(
        simulated_data["frequency"].value_counts().sort_index().iloc[:28]
    )
    combined_counts = model_counts.merge(
        simulated_counts, how="outer", left_index=True, right_index=True
    ).fillna(0)
    combined_counts.columns = ["Actual", "Model"]
    if plot:
        # imported lazily so the batch job does not require matplotlib
        import matplotlib.pyplot as plt
        combined_counts.plot.bar()
        plt.show()
    return combined_counts, bgf


def catch_none(x):
    """Coerce missing RFM values to 0 so model inputs are always numeric."""
    if x is None:
        return 0
    return x


def ltv_predict(t, frequency, recency, T, model):
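    """Predict expected active days over the next t days, capped at t."""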
    pred = model.conditional_expected_number_of_purchases_up_to_time(
        t, catch_none(frequency), catch_none(recency), catch_none(T)
    )

    if pred > t:
        return float(t)
    return float(pred)


def main(
    submission_date,
    project_id,
    dataset_id,
    source_qualified_table_id,
    intermediate_table_id,
    model_input_table_id,
    model_output_table_id,
    temporary_gcs_bucket,
    training_sample,
    prediction_days,
):
    """Model the lifetime-value (LTV) of clients based on search activity.

    This reads a single partition from a source table into an intermediate table
    in an analysis dataset. The table is transformed for modeling. The model
    inputs and predictions are then stored into separate tables in the same
    dataset.
    """
    print(f"Running ltv_daily job for {submission_date}")

    bq = bigquery.Client()
    table_ref = bq.dataset(dataset_id, project=project_id).table(intermediate_table_id)

    # define the job configuration for the query: set the date parameter and
    # the destination table for the materialized partition
    job_config = bigquery.QueryJobConfig()
    job_config.query_parameters = [
        bigquery.ScalarQueryParameter("submission_date", "STRING", submission_date)
    ]
    job_config.destination = table_ref
    job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE

    query = f"""
    SELECT
        *
    FROM
        `{source_qualified_table_id}`
    WHERE
        submission_date = @submission_date
    """
    query_job = bq.query(query, job_config=job_config)
    query_job.result()

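    # read the materialized day of data back in through the Spark BigQuery connector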
    spark = SparkSession.builder.getOrCreate()
    search_rfm_full = (
        spark.read.format("bigquery")
        .option("table", f"{project_id}.{dataset_id}.{intermediate_table_id}")
        .load()
    )

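    # each metric is stored as a struct of (frequency, recency, T); flatten the
    # structs into top-level columns for modeling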
    columns = [
        [
            F.col(f"{metric}.frequency").alias(f"{metric}_frequency"),
            F.col(f"{metric}.recency").alias(f"{metric}_recency"),
            F.col(f"{metric}.T").alias(f"{metric}_T"),
        ]
        for metric in PRED_METRICS
    ]

    # flatten list
    columns = [item for sublist in columns for item in sublist]

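    # column-name prefixes for the per-metric model outputs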
    prediction_prefix = "prediction_"
    p_alive_prefix = "p_alive_"

    model_perf_data = pd.DataFrame()
    model_pred_data = None
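    # fit in pandas on at most training_sample rows; scoring runs on the full dataset later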
    search_rfm_ds = search_rfm_full.limit(training_sample).select(columns).toPandas()
    for metric in PRED_METRICS:
        # train and extract model performance
        model_perf, model = train_metric(search_rfm_ds, metric, plot=False, penalty=0.8)
        # relative difference between modeled and actual counts (+1 avoids division by zero)
        model_perf["pct"] = model_perf.Model / (model_perf.Actual + 1) - 1
        model_perf["metric"] = metric
        model_perf["date"] = submission_date
        model_perf_data = pd.concat([model_perf_data, model_perf])

        # make predictions using the fitted model; binding model as a default
        # argument captures this iteration's model at UDF definition time
        @F.udf(DoubleType())
        def ltv_predict_metric(metric, model=model):
            # ensure lifetimes is importable on the Spark executors
            import lifetimes

            return ltv_predict(
                prediction_days, metric.frequency, metric.recency, metric.T, model
            )

        @F.udf(DoubleType())
        def ltv_prob_alive(metric, model=model):
            # ensure lifetimes is importable on the Spark executors
            import lifetimes

            p_alive = float(
                model.conditional_probability_alive(
                    catch_none(metric.frequency),
                    catch_none(metric.recency),
                    catch_none(metric.T),
                )
            )

            # lifetimes returns p_alive == 1.0 when frequency == 0, so treat
            # those clients as not alive rather than certainly alive
            # https://github.com/CamDavidsonPilon/lifetimes/blob/master/lifetimes/fitters/beta_geo_fitter.py#L293
            if p_alive >= 1.0:
                return 0.0
            return p_alive

        # go back to the full dataset for scoring (training used a sample)
        predictions = search_rfm_full.select(
            "*",
            ltv_predict_metric(metric).alias(prediction_prefix + metric),
            ltv_prob_alive(metric).alias(p_alive_prefix + metric),
        )

        if model_pred_data is None:
            model_pred_data = predictions
        else:
            model_pred_data = model_pred_data.join(
                predictions.select(
                    "client_id",
                    prediction_prefix + metric,
                    p_alive_prefix + metric,
                ),
                on="client_id",
            )

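    # fold the per-metric prediction and p_alive columns into map columns keyed by metric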
    predictions = F.create_map(
        list(
            chain(
                *((F.lit(name), F.col(prediction_prefix + name)) for name in PRED_METRICS)
            )
        )
    ).alias("predictions")

    p_alive = F.create_map(
        list(
            chain(
                *((F.lit(name), F.col(p_alive_prefix + name)) for name in PRED_METRICS)
            )
        )
    ).alias("p_alive")

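    # the diagnostics frame is indexed by active-day count; surface it as a column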
    model_perf_data["active_days"] = model_perf_data.index
    model_perf_data_sdf = spark.createDataFrame(model_perf_data).withColumn(
        "date", F.to_date("date")
    )

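    # BigQuery partition decorators (table$YYYYMMDD) need the date without dashes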
    ds_nodash = submission_date.replace("-", "")
    (
        model_perf_data_sdf.write.format("bigquery")
        .option("table", f"{project_id}.{dataset_id}.{model_input_table_id}${ds_nodash}")
        .option("temporaryGcsBucket", temporary_gcs_bucket)
        .option("partitionField", "date")
        .mode("overwrite")
        .save()
    )

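    # write per-client predictions, overwriting only this day's partition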
    (
        model_pred_data

        # Add prediction columns as maps
        .withColumn("predictions", predictions)
        .withColumn("p_alive", p_alive)

        # Drop top-level prediction columns
        .drop(
            *chain(*[[prediction_prefix + n, p_alive_prefix + n] for n in PRED_METRICS])
        )

        # Overwrite BQ partition
        .write.format("bigquery")
        .option("table", f"{project_id}.{dataset_id}.{model_output_table_id}${ds_nodash}")
        .option("temporaryGcsBucket", temporary_gcs_bucket)
        .option("partitionField", "submission_date")
        .option("clusteredFields", "sample_id,client_id")
        .option("allowFieldAddition", "true")
        .mode("overwrite")
        .save()
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(__doc__)
    parser.add_argument("--submission-date", help="date in YYYY-MM-DD")
    parser.add_argument("--training-sample", type=int, default=500_000)
    parser.add_argument("--prediction-days", type=int, default=28)
    parser.add_argument("--project-id", default="moz-fx-data-bq-data-science")
    parser.add_argument(
        "--source-qualified-table-id",
        default="moz-fx-data-shared-prod.search.search_rfm",
    )
    parser.add_argument("--dataset-id", default="bmiroglio")
    parser.add_argument("--intermediate-table-id", default="search_rfm_day")
    parser.add_argument("--model-input-table-id", default="ltv_daily_model_perf_script")
    parser.add_argument("--model-output-table-id", default="ltv_daily_script")
    parser.add_argument(
        "--temporary-gcs-bucket", default="moz-fx-data-bq-data-science-bmiroglio"
    )
    args = parser.parse_args()

    main(**vars(args))
