def explode_search_counts()

in mozetl/search/aggregates.py [0:0]


def explode_search_counts(main_summary):
    """Flatten per-client search counts into one row per (engine, source, count).

    The result is the union of up to four row sets. Every row set ends with the
    same column layout: the ``main_summary`` columns, minus the source columns
    listed in ``_drop_source_columns``, followed by ``engine``, ``source`` and
    ``count``.

    1. One row per element of the ``search_counts`` array whose ``count`` is
       below ``MAX_CLIENT_SEARCH_COUNT``.
    2. Rows produced by ``get_ad_click_count`` from the
       ``scalar_parent_browser_search_ad_clicks`` column, if it exists.
    3. Rows produced by ``get_search_with_ads_count`` from the
       ``scalar_parent_browser_search_with_ads`` column, if it exists.
    4. One row per input row whose ``search_counts`` is null, with null
       ``engine``/``source`` and a ``count`` of 0.

    :param main_summary: DataFrame with a ``search_counts`` array column whose
        elements have ``engine``, ``source`` and ``count`` fields.
    :return: DataFrame with the layout described above. The ``search_counts``
        and ad scalar columns are dropped.
    """

    def _get_search_fields(exploded_col_name):
        """Return dotted references to the engine/source/count sub-fields."""
        return [
            exploded_col_name + "." + field for field in ["engine", "source", "count"]
        ]

    def _drop_source_columns(base):
        """Drop the raw search/ad source columns from ``base``."""
        derived = base
        for source_col in [
            "search_counts",
            "scalar_parent_browser_search_ad_clicks",
            "scalar_parent_browser_search_with_ads",
        ]:
            derived = derived.drop(source_col)
        return derived

    def _select_counts(main_summary, col_name, count_col=None):
        """Explode a struct array into ``col_name`` and flatten its fields.

        For ``"single_search_count"`` the array is ``search_counts``, with
        counts at or above ``MAX_CLIENT_SEARCH_COUNT`` filtered out. Otherwise
        ``count_col`` (an array-of-struct column expression) is exploded.
        """
        if col_name == "single_search_count":
            derived = main_summary.withColumn(
                "single_search_count", explode(col("search_counts"))
            ).filter("single_search_count.count < %s" % MAX_CLIENT_SEARCH_COUNT)
        else:
            derived = main_summary
        if count_col is not None:
            derived = derived.withColumn(col_name, explode(count_col))
        # Append the flattened sub-fields, then drop the struct column and the
        # raw source columns so every branch ends with engine, source, count.
        derived = _drop_source_columns(
            derived.select(["*"] + _get_search_fields(col_name)).drop(col_name)
        )
        return derived

    def _get_ad_counts(scalar_name, col_name, udf_function):
        """Turn ad scalar column ``scalar_name`` into engine/source/count rows."""
        count_udf = udf(
            udf_function,
            ArrayType(
                StructType(
                    [
                        StructField("engine", StringType(), False),
                        StructField("source", StringType(), False),
                        StructField("count", LongType(), False),
                    ]
                )
            ),
        )
        return _select_counts(main_summary, col_name, count_udf(scalar_name))

    exploded_search_counts = _select_counts(main_summary, "single_search_count")

    # Older generated versions of main_summary may not have the ad columns, and
    # that's ok. Each column is optional on its own, so a missing one does not
    # prevent the other from being counted.
    optional_ad_counts = [
        (
            "scalar_parent_browser_search_ad_clicks",
            "ad_click_count",
            get_ad_click_count,
        ),
        (
            "scalar_parent_browser_search_with_ads",
            "search_with_ads_count",
            get_search_with_ads_count,
        ),
    ]
    for scalar_name, col_name, udf_function in optional_ad_counts:
        try:
            ad_counts = _get_ad_counts(scalar_name, col_name, udf_function)
        except AnalysisException:
            # Column not present in this version of main_summary.
            continue
        # union() matches columns by position; every branch shares one layout.
        exploded_search_counts = exploded_search_counts.union(ad_counts)

    zero_search_users = _drop_source_columns(
        main_summary.where(col("search_counts").isNull())
        .withColumn("engine", lit(None))
        .withColumn("source", lit(None))
        # Using 0 instead of None for search_count makes many queries easier
        # (e.g. average searches per user)
        .withColumn("count", lit(0))
    )

    return exploded_search_counts.union(zero_search_users)