# join_categories()
#
# Defined in experiments/legacy/backend/category.py [0:0]


def join_categories(
    ids: list[str],
    category_depth: int = config.CATEGORY_DEPTH,
    allow_trailing_nulls: bool = config.ALLOW_TRAILING_NULLS,
) -> dict[str, list[str]]:
    """Given list of product IDs, join category names.

    Looks the IDs up in ``config.PRODUCT_REFERENCE_TABLE`` and collects, for
    each matching row, the first ``category_depth`` category columns listed
    in ``config.COLUMN_CATEGORIES``.

    Args:
        ids: list of product IDs used to join against master product table
        category_depth: number of levels in category hierarchy to return
        allow_trailing_nulls: when True, a null (or otherwise falsy) category
            below the top level ends that product's category list early
            instead of raising. The top level category must always be set.

    Returns:
        dict mapping product IDs to category name. The category name will be
        a list of strings e.g. ['level 1 category', 'level 2 category'].
        IDs with no row in the product table get no entry. The mapping is a
        ``defaultdict(list)``, so indexing a missing ID yields an empty list.

    Raises:
        ValueError: if a category column is null and ``allow_trailing_nulls``
            is False, or if the top level category is null.
    """
    categories = defaultdict(list)
    id_values = list(ids)
    if not id_values:
        # Nothing to look up; skip the round trip (and the invalid `IN ()`
        # the query would otherwise need).
        return categories

    # Function-scope import keeps this block self-contained.
    # NOTE(review): assumes bq_client is a google.cloud.bigquery.Client.
    from google.cloud import bigquery

    # Only the columns actually selected may be read back from each row.
    columns = list(config.COLUMN_CATEGORIES[:category_depth])
    select_list = ',\n        '.join([config.COLUMN_ID, *columns])

    # IDs are bound as a query parameter rather than spliced into the SQL
    # text, so quotes or brackets inside an ID cannot alter the query.
    # Column and table names come from config and cannot be parameterized.
    query = f"""
    SELECT
        {select_list}
    FROM
        `{config.PRODUCT_REFERENCE_TABLE}`
    WHERE
        {config.COLUMN_ID} IN UNNEST(@ids)
    """
    job_config = bigquery.QueryJobConfig(
        query_parameters=[
            bigquery.ArrayQueryParameter('ids', 'STRING', id_values),
        ]
    )
    rows = bq_client.query(query, job_config=job_config).result()

    for row in rows:
        product_id = row[config.COLUMN_ID]
        for level, col in enumerate(columns):
            value = row[col]
            if value:
                categories[product_id].append(value)
                continue
            if not allow_trailing_nulls:
                raise ValueError(f'Column {col} for product {product_id} is null. To allow nulls update config.py')
            if level == 0:
                raise ValueError(f'Top level category {col} for product {product_id} is null')
            # Trailing null: keep the levels gathered so far for this product.
            break
    return categories