def map_to_parent_categories()

in contentselection/oracle.py [0:0]


def map_to_parent_categories(df, taxonomy):
    """
    Map each inferred category in the DataFrame to its top-level parent category
    in the hierarchical taxonomy.

    A name matches if it appears anywhere in the taxonomy:
    - as a top-level key, which maps to itself;
    - as a nested dictionary key at any depth;
    - as an item of a list/tuple/set used to hold leaf names.

    When a name occurs under several top-level categories, the first one in the
    taxonomy's iteration order wins.

    :param df: DataFrame containing video data with an 'inferred_category' column.
        The DataFrame is modified in place and also returned.
    :param taxonomy: A nested dictionary representing the hierarchical taxonomy.
        ``None`` or an empty dict maps every row to ``None``.
    :return: DataFrame with an added 'parent_category' column holding the
        top-level parent category, or ``None`` when the inferred category does
        not occur in the taxonomy.
    """
    # Containers whose items are treated as child names rather than as
    # opaque values (strings are deliberately excluded).
    nested_types = (dict, list, tuple, set, frozenset)

    def _iter_names(node):
        """Yield every category name below ``node`` in depth-first pre-order."""
        if isinstance(node, dict):
            for name, children in node.items():
                yield name
                yield from _iter_names(children)
        elif isinstance(node, nested_types):
            for item in node:
                if isinstance(item, nested_types):
                    yield from _iter_names(item)
                else:
                    yield item
        # Any other value (None, a description string, a number) has no
        # children to search.

    # Build the name -> top-level lookup once, instead of re-walking the whole
    # taxonomy for every row. setdefault keeps the first match, mirroring a
    # search that stops at the first top-level category containing the name.
    top_level_of = {}
    for top_category, children in (taxonomy or {}).items():
        top_level_of.setdefault(top_category, top_category)
        for name in _iter_names(children):
            try:
                top_level_of.setdefault(name, top_category)
            except TypeError:
                # An unhashable item inside a leaf list can never equal a
                # DataFrame cell usable as a key; skip it.
                continue

    def _lookup(value):
        """Return the top-level category for ``value``, or None if unknown."""
        try:
            return top_level_of.get(value)
        except TypeError:
            # Unhashable cell values (e.g. lists) never equal a taxonomy name.
            return None

    df['parent_category'] = df['inferred_category'].map(_lookup)

    return df