vision/m4/sourcing/data_collection/visualization/wikipedia/global_visualization.py (322 lines of code) (raw):

import os
import pathlib
import random
from html import escape

import pandas as pd
import streamlit as st
from datasets import load_dataset
from jinja2 import Template

from m4.sourcing.data_collection.processors import (
    DOMTreeSimplificator,
    PreExtractionSimplificator,
    TextMediaPairsExtractor,
)
from m4.sourcing.data_collection.utils import InterestingAttributesSetCategory, make_selectolax_tree


# Dataset loading script used when no ``dataset_path`` is given to ``Visualization``.
DEFAULT_DATASET_PATH = "/home/lucile/data/wikipedia/html_enterprise/script/wikipedia.py"

# Columns of the extracted-image records shown in the "Extracted content" section.
# Building the DataFrame with this fixed column list fills missing keys with NaN.
EXTRACTED_IMAGE_COLUMNS = [
    "src",
    "unformatted_src",
    "format",
    "rendered_width",
    "rendered_height",
    "original_width",
    "original_height",
    "formatted_filename",
    "alt_text",
    "extracted_text",
    "clip_score_image_formatted_filename",
    "clip_score_image_alt_text",
    "clip_score_image_extracted_text",
    "image_in_simplified_dom_tree",
]


def _code_view_html(code):
    """Return ``code`` wrapped in a <pre> element with its markup escaped so it is displayed verbatim.

    Escaping replaces the former ``<xmp>`` wrapper, which a page containing ``</xmp>`` could break out of.
    """
    return f"<pre>{escape(code)}</pre>"


class Visualization:
    """Streamlit app visualizing DOM tree simplification, web document rendering and text-image pair extraction.

    Args:
        num_docs: Maximum number of documents loaded from the ``train`` split of the dataset.
        dom_viz_template_path: Path to a Jinja2 template rendered with a ``body_html`` variable to draw a DOM tree.
        dataset_path: Path/name passed to ``datasets.load_dataset``. Defaults to ``DEFAULT_DATASET_PATH``.

    Raises:
        ValueError: If no document could be loaded from the dataset.
    """

    def __init__(self, num_docs, dom_viz_template_path, dataset_path=DEFAULT_DATASET_PATH):
        @st.experimental_memo  # it is caching but is incredibly slow when N is big.
        def load_examples(num_docs, dataset_path):
            dataset = load_dataset(dataset_path)["train"]
            # Alternative source: load_from_disk("/home/lucile/data/wikipedia/test/arrow")
            # Clamp to the dataset size so a small dataset does not raise IndexError.
            return [dataset[i] for i in range(min(num_docs, len(dataset)))]

        self.examples = load_examples(num_docs, dataset_path)
        if not self.examples:
            raise ValueError(f"No documents could be loaded from {dataset_path!r}.")
        # Number of documents actually available; used as the bound of the document picker.
        self.num_docs = len(self.examples)

        with open(dom_viz_template_path, "r", encoding="utf-8") as file:
            self.dom_viz_template = Template(file.read())

        self.dom_tree_simplificator = DOMTreeSimplificator(
            strip_multiple_linebreaks=True,
            strip_multiple_spaces=True,
            remove_html_comments=True,
            replace_line_break_tags=True,
            unwrap_tags=True,
            strip_tags=False,
            strip_special_divs=True,
            remove_dates=True,
            remove_empty_leaves=True,
            unnest_nodes=True,
            remake_tree=True,
            preserve_img_children=True,
            remove_everything_after_node_id=["Notes", "References", "See_also", "Further_reading", "External_links"],
            css_rules=[
                "[class~='locmap']",
                "[class~='reference']",
                "[role='presentation']",
                "[role~='note']",
            ],
            interesting_attributes_set_cat=InterestingAttributesSetCategory.WIKIPEDIA,
        )
        self.pre_extraction_simplificator_not_merge_texts = PreExtractionSimplificator(
            only_text_image_nodes=True,
            format_texts=True,
            merge_consecutive_text_nodes=False,
            interesting_attributes_set_cat=InterestingAttributesSetCategory.WIKIPEDIA,
        )
        self.pre_extraction_simplificator_merge_texts = PreExtractionSimplificator(
            only_text_image_nodes=True,
            format_texts=True,
            merge_consecutive_text_nodes=True,
            interesting_attributes_set_cat=InterestingAttributesSetCategory.WIKIPEDIA,
        )
        self.extractor = TextMediaPairsExtractor(
            dom_tree_simplificator=self.dom_tree_simplificator,
            pre_extraction_simplificator=self.pre_extraction_simplificator_merge_texts,
            also_extract_images_not_in_simplified_dom_tree=True,
            extract_clip_scores=True,
        )

    def visualization(self):
        """Render the whole page: title, mode selector, document selector, then the selected mode's view."""
        st.title(
            "Visualization of DOM tree simplification strategies, "
            "web document rendering, and text-image pair extractions"
        )
        self.choose_mode()
        self.choose_example()
        self.simplification_mode()
        self.extraction_mode()

    def choose_mode(self):
        """Let the user pick between the "Simplification" and "Extraction" views (stored in ``self.mode``)."""
        st.header("Mode")
        self.mode = st.selectbox(
            label="Choose a mode",
            options=["Simplification", "Extraction"],
            index=1,
        )

    def choose_example(self):
        """Let the user pick a document by index (or at random) and store it in ``self.current_example``."""
        st.header("Document")
        if st.button("Select a random document"):
            dct_idx = random.randint(a=0, b=self.num_docs - 1)
        else:
            dct_idx = 0
        idx = st.number_input(
            f"Select a document among the first {self.num_docs} ones",
            min_value=0,
            max_value=self.num_docs - 1,
            value=dct_idx,
            step=1,
            help=f"Index between 0 and {self.num_docs-1}",
        )
        self.current_example = self.examples[int(idx)]

    def get_dom_viz_html(self, html):
        """Return the DOM-tree visualization of ``html``'s <body> (with <script> tags stripped)."""
        tree = make_selectolax_tree(html)
        tree.strip_tags(["script"])
        return self.dom_viz_template.render(body_html=tree.body.html)

    def simplification_mode(self):
        """Render the "Simplification" view: raw vs. simplified page, DOM tree and HTML source, side by side."""
        if self.mode != "Simplification":
            return

        current_html = self.current_example["html"]
        current_url = self.current_example["url"]
        simplified_current_html = self.dom_tree_simplificator(current_html, type_return="str")

        st.header("Rendered webpage")
        st.markdown(f"Webpage url: {current_url}")
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Raw html rendering")
            st.components.v1.html(current_html, height=450, scrolling=True)
        with col2:
            st.subheader("Simplified html rendering")
            st.components.v1.html(simplified_current_html, height=450, scrolling=True)

        st.header("DOM trees")
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Raw DOM tree")
            st.components.v1.html(self.get_dom_viz_html(current_html), height=600, scrolling=True)
        with col2:
            st.subheader("Simplified DOM tree")
            st.components.v1.html(self.get_dom_viz_html(simplified_current_html), height=600, scrolling=True)

        st.header("HTML codes")
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Raw HTML code")
            st.components.v1.html(_code_view_html(current_html), height=450, scrolling=True)
        with col2:
            st.subheader("Simplified HTML code")
            st.components.v1.html(_code_view_html(simplified_current_html), height=450, scrolling=True)

    @staticmethod
    def _render_unmerged_nodes(list_nodes):
        """Render text/image nodes as one Markdown document, with a heading each time the tree level changes."""
        # Re-index the distinct node levels to 1..K so the heading depth stays small.
        reduce_levels = {
            level: rank for rank, level in enumerate(sorted({node.level for node in list_nodes}), start=1)
        }
        last_level = None
        parts = []
        for node in list_nodes:
            if node.tag not in ("-text", "img"):
                continue
            current_level = reduce_levels[node.level]
            if last_level != current_level:
                # Markdown only supports six heading levels.
                parts.append("#" * min(current_level, 6) + f" Level: {current_level}\n\n")
                last_level = current_level
            path_in_tree_str = [tag for tag, _ in node.path_in_tree]
            parts.append(f"**{'/'.join(path_in_tree_str)}**\n\n")
            if node.tag == "-text":
                parts.append(f"{node.text}\n\n")
            else:
                parts.append(f"![img]({node.media_info['src']})\n\n")
        st.markdown("".join(parts))

    @staticmethod
    def _render_merged_nodes(list_nodes):
        """Render text nodes as plain text and image nodes as Markdown images, in document order."""
        for node in list_nodes:
            if node.tag == "-text":
                st.text(f"{node.text}\n\n")
            elif node.tag == "img":
                st.markdown(f"![img]({node.media_info['src']})\n\n")

    @staticmethod
    def _image_details_html(record):
        """Build the HTML summary of one extracted-image record.

        Every scraped value is HTML-escaped because it comes from the crawled page.
        """
        link = record["src"]
        src = f'<a target="_blank" href="{escape(link)}">{escape(link.split("/")[-1])}</a>'

        def text(key):
            return escape(str(record[key]))

        fields = [
            ("Source", src),
            ("Unformatted source", text("unformatted_src")),
            ("Format", text("format")),
            ("Rendered width", text("rendered_width")),
            ("Rendered height", text("rendered_height")),
            ("Original width", text("original_width")),
            ("Original height", text("original_height")),
            ("Formatted filename", text("formatted_filename")),
            ("Alt-text", text("alt_text")),
            ("Extracted text", text("extracted_text")),
            ("Clip score image/formatted filename", f"{record['clip_score_image_formatted_filename']:.4f}"),
            ("Clip score image/alt-text", f"{record['clip_score_image_alt_text']:.4f}"),
            ("Clip score image/extracted text", f"{record['clip_score_image_extracted_text']:.4f}"),
            ("Image in simplified DOM tree", text("image_in_simplified_dom_tree")),
        ]
        return "".join(f"<p><strong>{label}</strong>: {value}</p>" for label, value in fields)

    def extraction_mode(self):
        """Render the "Extraction" view: page renderings, extracted images, simplified DOM tree and HTML."""
        if self.mode != "Extraction":
            return

        current_html = self.current_example["html"]
        current_url = self.current_example["url"]
        simplified_current_html_tree = self.dom_tree_simplificator(current_html, type_return="selectolax_tree")
        # Serialize before running the pre-extraction simplificators on the same tree.
        simplified_current_html = simplified_current_html_tree.html
        current_list_nodes_not_merge_texts = self.pre_extraction_simplificator_not_merge_texts(
            simplified_current_html_tree, page_url=current_url
        )
        current_list_nodes_merge_texts = self.pre_extraction_simplificator_merge_texts(
            simplified_current_html_tree, page_url=current_url
        )
        extracted_images = self.extractor(html_str=current_html, page_url=current_url)

        # For simplicity, only doing this replacement on the extracted images.
        # Doing that before the extraction (i.e. in the DOM simplification) would be possible but would require
        # more significant changes.
        replacement_dict = {
            elem["unformatted_src"]: elem["src"]
            for elem in extracted_images
            if elem["src"] != elem["unformatted_src"]
        }

        def replace_relative_paths(html_string):
            """Rewrite image sources in ``html_string`` to their formatted (absolute) form."""
            for unformatted_src, src in replacement_dict.items():
                html_string = html_string.replace(unformatted_src, src)
            return html_string

        def display_rendered_webpages():
            st.header("Rendered webpage")
            st.markdown(f"Webpage url: {current_url}")
            display_raw_html_rendering = st.checkbox("Raw html rendering", value=True)
            display_simplified_html_rendering = st.checkbox("Simplified html rendering", value=True)
            col1, col2 = st.columns(2)
            with col1:
                display_pre_extraction_visualization = st.checkbox(
                    "Web document rendering (pre-extraction visualization)", value=True
                )
            # Only read when the pre-extraction view is shown; the default keeps the name always bound.
            merge_text_nodes = True
            with col2:
                if display_pre_extraction_visualization:
                    merge_text_nodes = st.checkbox("Merge text nodes", value=True)

            list_display_pages = [
                page_to_display
                for should_display_page, page_to_display in (
                    (display_raw_html_rendering, "raw_html_rendering"),
                    (display_simplified_html_rendering, "simplified_html_rendering"),
                    (display_pre_extraction_visualization, "pre_extraction_visualization"),
                )
                if should_display_page
            ]

            def display_specific_rendered_webpage(page_to_display, col):
                with col:
                    if page_to_display == "raw_html_rendering":
                        st.subheader("Raw html rendering")
                        st.components.v1.html(replace_relative_paths(current_html), height=800, scrolling=True)
                    elif page_to_display == "simplified_html_rendering":
                        st.subheader("Simplified html rendering")
                        st.components.v1.html(
                            replace_relative_paths(simplified_current_html), height=800, scrolling=True
                        )
                    elif page_to_display == "pre_extraction_visualization":
                        st.subheader("Web document rendering (pre-extraction visualization)")
                        if merge_text_nodes:
                            self._render_merged_nodes(current_list_nodes_merge_texts)
                        else:
                            self._render_unmerged_nodes(current_list_nodes_not_merge_texts)

            if list_display_pages:
                columns = st.columns(len(list_display_pages))
                for page_to_display, col in zip(list_display_pages, columns):
                    display_specific_rendered_webpage(page_to_display, col)

        def display_extraction():
            st.header("Extracted content")
            if not extracted_images:
                st.write("No extracted content")
                return
            df = pd.DataFrame(extracted_images, columns=EXTRACTED_IMAGE_COLUMNS)
            for record in df.to_dict(orient="records"):
                col1, col2 = st.columns(2)
                with col1:
                    st.image(record["src"], width=500, use_column_width=True)
                with col2:
                    st.components.v1.html(self._image_details_html(record), height=500, scrolling=True)
                st.write("-----")

        def display_dom_tree():
            st.header("Simplified DOM tree")
            simplified_rendered_dom = self.get_dom_viz_html(simplified_current_html)
            st.components.v1.html(simplified_rendered_dom, height=600, scrolling=True)

        def display_html_code():
            st.header("Simplified HTML code")
            st.components.v1.html(_code_view_html(simplified_current_html), height=450, scrolling=True)

        display_rendered_webpages()
        display_extraction()
        display_dom_tree()
        display_html_code()


if __name__ == "__main__":
    st.set_page_config(layout="wide")
    num_docs = 1_000
    dom_viz_template_path = os.path.join(pathlib.Path(__file__).parent.resolve(), "assets/DOM_tree_viz.html")
    visualization = Visualization(num_docs=num_docs, dom_viz_template_path=dom_viz_template_path)
    visualization.visualization()