# extract_search_counts()
#
# from mozetl/clientsdaily/rollup.py


def extract_search_counts(frame):
    """
    :frame DataFrame conforming to main_summary's schema.

    :return one row for each row in frame, replacing the nullable
    array-of-structs column "search_counts" with seven columns
        "search_count_{access_point}_sum":
    one for each valid SEARCH_ACCESS_POINT, plus one named "all"
    which is always a sum of the other six.

    All seven columns default to 0 and will be 0 if search_counts
    was NULL or an empty array.
    Note that the Mozilla term of art "search access point", referring to
    GUI elements, is named "source" in main_summary.

    This routine is hairy because it generates a lot of SQL and Spark
    pseudo-SQL; see inline comments.

    TODO:
      Replace (JOIN with WHERE NULL) with fillna() to an array literal.
      Maybe use a PIVOT.
    """
    # Work on a narrow projection keyed by document_id so the final join
    # can reattach the aggregated columns to the full-width input rows.
    two_columns = frame.select(F.col("document_id").alias("did"), "search_counts")
    # First, each row becomes N rows, N == len(search_counts).
    # explode() emits nothing for NULL or empty arrays; those input rows
    # are re-added with zero counts by the `nulls` branch below.
    exploded = two_columns.select(
        "did", F.explode("search_counts").alias("search_struct")
    )
    # Remove any rows where the values are corrupt
    exploded = exploded.where("search_struct.count > -1").where(
        "search_struct.source in %s" % str(tuple(SEARCH_ACCESS_POINTS))
    )  # This in clause looks like:
    # "search_struct.source in (
    # 'abouthome', 'contextmenu', 'newtab', 'searchbar', 'system', 'urlbar'
    # )"
    # NOTE(review): an input row whose search_counts contains ONLY corrupt
    # entries loses all its exploded rows here and is then dropped by the
    # inner join at the bottom -- confirm whether such rows occur upstream
    # and whether dropping them is acceptable.

    # Now we have clean search_count structs. Next block:
    # For each of the form Row(engine=u'engine', source=SAP, count=n):
    #    SELECT
    #     n as search_count_all,
    #     n as search_count_SAP, (one of the 6 above, such as 'newtab')
    #     0 as search_count_OTHER1
    #     ...
    #     0 as search_count_OTHER5
    if_template = "IF(search_struct.source = '{}', search_struct.count, 0)"
    if_expressions = [
        F.expr(if_template.format(sap)).alias(SEARCH_ACCESS_COLUMN_TEMPLATE.format(sap))
        for sap in SEARCH_ACCESS_POINTS
    ]
    unpacked = exploded.select(
        "did", F.expr("search_struct.count").alias("search_count_atom"), *if_expressions
    )

    # Collapse the exploded search_counts rows into a single output row
    # per document: sum each per-SAP column plus the overall count.
    grouping_dict = dict([(c, "sum") for c in SEARCH_ACCESS_COLUMNS])
    grouping_dict["search_count_atom"] = "sum"
    grouped = unpacked.groupBy("did").agg(grouping_dict)
    extracted = grouped.select(
        "did",
        F.col("sum(search_count_atom)").alias("search_count_all"),
        *[F.col("sum({})".format(c)).alias(c) for c in SEARCH_ACCESS_COLUMNS],
    )
    # Create a homologous all-zero output row for each input row where
    # search_counts is NULL *or empty*.
    # BUG FIX: previously only "search_counts is NULL" matched here, so a
    # row with a non-NULL empty array appeared in neither branch and was
    # silently dropped by the inner join below, contradicting the
    # "one row for each row in frame" contract documented above.
    nulls = (
        two_columns.select("did")
        .where("search_counts is NULL or size(search_counts) = 0")
        .select(
            "did",
            F.lit(0).alias("search_count_all"),
            *[F.lit(0).alias(c) for c in SEARCH_ACCESS_COLUMNS],
        )
    )
    # unionAll is a deprecated alias of union() in Spark >= 2.0; kept as-is
    # for compatibility with older Spark versions this job may run on.
    intermediate = extracted.unionAll(nulls)
    # Inner join: reattach the seven flat columns to the original rows.
    result = frame.join(intermediate, frame.document_id == intermediate.did)
    return result