def join_categories()

in experiments/google/cloud/ml/applied/categories/category.py [0:0]


def join_categories(ids: list[str]) -> dict[str, list[str]]:
    """Given list of product IDs, join category names.

    Issues one query against ``table_product`` selecting ``column_id`` plus the
    first ``category_depth`` entries of ``column_categories``, then collects the
    non-empty category values of each returned row in column order.

    Args:
        ids: list of product IDs used to join against master product table.
            An empty list returns an empty mapping without issuing a query.

    Returns:
        dict mapping product IDs to category name. The category name will be
        a list of strings e.g. ['level 1 category', 'level 2 category'].
        IDs with no matching row are absent. The returned object is a
        ``defaultdict(list)``, so indexing a missing ID yields (and inserts)
        an empty list.

    Raises:
        ValueError: if a selected category column is null/empty for a product
            and ``allow_trailing_nulls`` is false, or if the top level
            category is null/empty even though trailing nulls are allowed.
    """
    categories = defaultdict(list)
    if not ids:
        # ``IN ()`` is not valid SQL, so there is nothing to query for.
        return categories

    # The same column list drives both the SELECT and the row loop, so the
    # loop never reads a column that was not selected.
    selected_columns = column_categories[:category_depth]

    # Quote each ID individually; rewriting the brackets of ``str(ids)`` would
    # also rewrite any '[' or ']' that appears inside an ID.
    # NOTE(review): IDs are still interpolated into the SQL text. If they can
    # come from untrusted input, switch to BigQuery query parameters
    # (``IN UNNEST(@ids)`` with an array parameter).
    id_literals = ", ".join(repr(product_id) for product_id in ids)

    query = f"""
    SELECT
        {column_id},
        {','.join(selected_columns)}
    FROM
        `{table_product}`
    WHERE
        {column_id} IN ({id_literals})
    """
    rows = bq_client.query(query).result()

    for row in rows:
        product_id = row[column_id]
        for col in selected_columns:
            value = row[col]
            if value:
                categories[product_id].append(value)
                continue

            # Null/empty category from here on.
            if not allow_trailing_nulls:
                raise ValueError(
                    f"Column {col} for product {product_id} is null. To allow nulls update app.toml"
                )
            if col == column_categories[0]:
                raise ValueError(
                    f"Top level category {col} for product {product_id} is null"
                )
            # Trailing null: keep the categories gathered so far for this
            # product and move on to the next row.
            break
    return categories