vision/m4/sourcing/data_collection/visualization/pair_stat_dashboard.py (359 lines of code) (raw):
import sys
import time
from pathlib import Path
import altair as alt
import numpy as np
import pandas as pd
import plotly.figure_factory as ff
import streamlit as st
from datasets import load_dataset
from tqdm import tqdm
# Useful to add the path to the m4 module to Streamlit
sys.path.append(str(Path(sys.path[0]).parent.absolute().parent.absolute().parent.absolute().parent.absolute()))
from m4.sourcing.data_collection.processors import (
DOMTreeSimplificator,
PreExtractionSimplificator,
TextMediaPairsExtractor,
)
def check_image_quality(media_info):
"""
Args_ : Media Node
Returns :
img_has_good_quality: Boolean indictating there is an image with good quality (defined by its height, width, and aspect ratio)
w: image width
h: image height
"""
w, h = media_info["original_width"], media_info["original_height"]
img_width_plus_height = w + h
img_has_good_quality = False
aspect_ratio = h / (w + 1e-8)
if w > 64 and h > 64 and 1 / 5 < aspect_ratio < 5:
img_has_good_quality = True
return img_has_good_quality, img_width_plus_height
def check_text(media_info):
"""
Args_ : Media Node
Returns :
has_text: Boolean indictating if there is a text that corresponds to the media
txt_dict: Dictionary mapping each text_length to its text type (filename, alt-text, extracted_text)
Note:
All variables are set to 0 if they don't exist in the media node
"""
has_text = False
txt_dict = {"formatted_filename": 0, "alt_text": 0, "extracted_text": 0}
for text_type in ["formatted_filename", "alt_text", "extracted_text"]:
try:
curr_txt_len = len(media_info[text_type])
txt_dict[text_type] = curr_txt_len
if curr_txt_len > 0:
has_text = True
except Exception:
pass
return has_text, txt_dict
def check_CLIP(media_info):
"""
Args_ : Media Node
Returns :
clip_score_max_per_img: Max CLIP score per Image
clip_nbr_per_img: Number of CLIP scores for a given image
clip_dict: Dictionary mapping each CLIP score to its text type (filename, alt-text, extracted_text).
Note:
All variables are set to 0 if they don't exist in the media node
"""
clip_score_max_per_img = 0
clip_nbr_per_img = 0
clip_dict = {"formatted_filename": 0, "alt_text": 0, "extracted_text": 0}
clip_var_list = ["formatted_filename", "alt_text", "extracted_text"]
for i, clip_type in enumerate(
["clip_score_image_formatted_filename", "clip_score_image_alt_text", "clip_score_image_extracted_text"]
):
try:
curr_clip = media_info[clip_type]
clip_dict[clip_var_list[i]] = curr_clip
if curr_clip > clip_score_max_per_img:
clip_score_max_per_img = curr_clip
if curr_clip > 0:
clip_nbr_per_img += 1
except Exception:
pass
return clip_score_max_per_img, clip_nbr_per_img, clip_dict
def update_df_metrics_and_lists_for_extraction_method(
media_info, aggregate_metrics_df, image_centric_df, text_centric_df, extraction_method_name
):
"""_summary_
Given a Media_Node and the Extraction_Method_Name used to get this Media_Node,
this function uses the Media_Node's values to update the 2D Dataframes' numbers
and append values to the 3D Dataframes' lists.
"""
img_has_good_quality, img_width_plus_height = check_image_quality(media_info)
has_text, txt_dict = check_text(media_info)
aggregate_metrics_df[extraction_method_name]["images_nbr"] += 1
aggregate_metrics_df[extraction_method_name]["images_of_quality_nbr"] += 1 if img_has_good_quality else 0
aggregate_metrics_df[extraction_method_name]["images_with_txt_pair_nbr"] += 1 if has_text else 0
image_centric_df[extraction_method_name]["images_width_plus_height"].append(img_width_plus_height)
if use_clip_scores:
clip_score_max_per_img, clip_nbr_per_img, clip_dict = check_CLIP(media_info)
image_centric_df[extraction_method_name]["clip_max_per_img"].append(clip_score_max_per_img)
image_centric_df[extraction_method_name]["clip_nbr_per_img"].append(clip_nbr_per_img)
for key in ["formatted_filename", "alt_text", "extracted_text"]:
text_centric_df[extraction_method_name]["len_" + key].append(txt_dict[key])
text_centric_df[extraction_method_name]["clip_" + key].append(clip_dict[key])
else:
for key in ["formatted_filename", "alt_text", "extracted_text"]:
text_centric_df[extraction_method_name]["len_" + key].append(txt_dict[key])
text_centric_df[extraction_method_name]["clip_" + key].append(0)
return aggregate_metrics_df, image_centric_df, text_centric_df
def get_extraction_evaluation_metrics(
num_docs_to_consider=100,
use_clip_scores=True,
):
"""_summary_
Args:
num_docs_to_consider (int, optional): _description_. Defaults to 100.
use_clip_scores (bool, optional): _description_. Defaults to True.
Returns:
_type_: _description_
"""
dataset = load_dataset(
"bs-modeling-metadata/c4-en-html-with-metadata",
streaming=True,
split="train",
use_auth_token=True,
)
dataset = list(dataset.take(num_docs_to_consider))
dom_tree_simplificator = DOMTreeSimplificator(
strip_multiple_linebreaks=True,
strip_multiple_spaces=True,
remove_html_comments=True,
replace_line_break_tags=True,
unwrap_tags=True,
strip_tags=True,
strip_special_divs=True,
remove_dates=True,
remove_empty_leaves=True,
unnest_nodes=True,
remake_tree=True,
)
pre_extraction_simplificator = PreExtractionSimplificator(
only_text_image_nodes=True,
format_texts=True,
merge_consecutive_text_nodes=True,
)
extractor = TextMediaPairsExtractor(
dom_tree_simplificator=dom_tree_simplificator,
pre_extraction_simplificator=pre_extraction_simplificator,
also_extract_images_not_in_simplified_dom_tree=True,
extract_clip_scores=use_clip_scores,
)
# List of all extraction methods considered
extraction_methods_names = ["DOM", "Residual"]
# Time variable
avg_extraction_time = 0
# Dataframe containing metrics aggregated on the fly
aggregate_metrics_df = pd.DataFrame(
0,
columns=extraction_methods_names,
index=["images_nbr", "images_of_quality_nbr", "images_with_txt_pair_nbr"],
)
# Dataframe containing lists of datapoints collected for each image
image_centric_df = pd.DataFrame(
columns=extraction_methods_names, index=["clip_nbr_per_img", "clip_max_per_img", "images_width_plus_height"]
)
# Dataframe containing lists of datapoints collected for each text
text_centric_df = pd.DataFrame(
columns=extraction_methods_names,
index=[
"len_formatted_filename",
"len_alt_text",
"len_extracted_text",
"clip_formatted_filename",
"clip_alt_text",
"clip_extracted_text",
],
)
# Best way I found so far to set-up 3D dataframes
for col in image_centric_df:
for row in list(image_centric_df.index.values):
image_centric_df[col][row] = []
for col in text_centric_df:
for row in list(text_centric_df.index.values):
text_centric_df[col][row] = []
# For each page in the dataset, extract media content and update metrics
for i, example in enumerate(tqdm(dataset)):
html_str = example["html"]
url = example["url"]
start_time_extraction = time.time()
DOM_and_residual_content = extractor(html_str, url)
end_time_extraction = time.time()
avg_extraction_time += end_time_extraction - start_time_extraction
# Set-up all the different extractions methods and list them
# TODO: Add Filtered extraction_method
simple_DOM_tree_imgs = [
media_info for media_info in DOM_and_residual_content if media_info["image_in_simplified_dom_tree"]
]
residual_images = [
media_info for media_info in DOM_and_residual_content if not media_info["image_in_simplified_dom_tree"]
]
extraction_methods_dict = {
extraction_methods_names[0]: simple_DOM_tree_imgs,
extraction_methods_names[1]: residual_images,
}
# Each of those lists will lead to an update of the column
# corresponding to their extraction method in each dataframe
for extraction_method_name in extraction_methods_dict:
media_list = extraction_methods_dict[extraction_method_name]
for media_info in media_list:
(
aggregate_metrics_df,
image_centric_df,
text_centric_df,
) = update_df_metrics_and_lists_for_extraction_method(
media_info, aggregate_metrics_df, image_centric_df, text_centric_df, extraction_method_name
)
avg_extraction_time = avg_extraction_time / len(dataset)
return extraction_methods_names, aggregate_metrics_df, image_centric_df, text_centric_df, avg_extraction_time
class Visualization:
def __init__(self, num_docs, use_clip_scores=True):
self.num_docs = num_docs
self.use_clip_scores = use_clip_scores
(
self.extraction_methods_names,
self.aggregate_metrics_df,
self.image_centric_df,
self.text_centric_df,
self.avg_extraction_time,
) = get_extraction_evaluation_metrics(
num_docs_to_consider=self.num_docs,
use_clip_scores=self.use_clip_scores,
)
self.df_aggregate_metric_names = list(self.aggregate_metrics_df.index.values)
self.df_image_centric_metric_names = list(self.image_centric_df.index.values)
self.df_text_centric_metric_names = list(self.text_centric_df.index.values)
def visualize(self):
self.image_text_pair_recall()
self.images_quality_recall()
self.display_bar_charts(
header="Image quality and quantity comparisons",
list_metric_to_compare=["images_nbr", "images_of_quality_nbr", "images_width_plus_height"],
)
self.display_distribution_plot(
list_extraction_methods=self.extraction_methods_names,
list_metric_to_compare=["images_width_plus_height"],
title="Distribution of Images' Dimensions (width + height)",
bin_size=[50, 50],
max_value=2500,
)
if self.use_clip_scores:
self.display_distribution_plot(
list_extraction_methods=self.extraction_methods_names,
list_metric_to_compare=["clip_max_per_img"],
title="Distribution of Max CLIP Scores",
bin_size=[0.02, 0.02],
max_value=None,
)
self.display_distribution_plot(
list_extraction_methods=["DOM"],
list_metric_to_compare=["clip_formatted_filename", "clip_alt_text", "clip_extracted_text"],
title="Distribution of CLIP Score per Text Type in DOM extraction",
bin_size=[0.02, 0.02, 0.02],
max_value=None,
)
self.display_distribution_plot(
list_extraction_methods=["DOM"],
list_metric_to_compare=["len_formatted_filename", "len_alt_text", "len_extracted_text"],
title="Distribution of Text Types Lengths in DOM extraction",
bin_size=[5, 5, 5],
max_value=200,
)
def image_text_pair_recall(self):
st.header("Most important stats:")
avg_extraction_time = self.avg_extraction_time
DOM_Recall = self.get_percentage_ratio(
ratio_numerator=self.aggregate_metrics_df["DOM"]["images_with_txt_pair_nbr"],
denominator=self.aggregate_metrics_df["DOM"]["images_nbr"],
)
res_Recall = self.get_percentage_ratio(
ratio_numerator=self.aggregate_metrics_df["Residual"]["images_with_txt_pair_nbr"],
denominator=self.aggregate_metrics_df["Residual"]["images_nbr"],
)
DOM_General_Recall = self.get_percentage_ratio(
ratio_numerator=self.aggregate_metrics_df["DOM"]["images_with_txt_pair_nbr"],
denominator=self.aggregate_metrics_df["DOM"]["images_nbr"]
+ self.aggregate_metrics_df["Residual"]["images_nbr"],
)
st.write(f"Avg Extraction Time per page: {avg_extraction_time:.2f}sec")
col1, col2, col3 = st.columns(3)
col1.metric(label="DOM Recall over DOM Images", value=f"{DOM_Recall:.2f}%")
col2.metric(label="Residual Recall over Residual Images", value=f"{res_Recall:.2f}%")
col3.metric(label="DOM Recall over All Images", value=f"{DOM_General_Recall:.2f}%")
def images_quality_recall(self):
col1, col2, col3 = st.columns(3)
DOM_quality_images__DOM_images = self.get_percentage_ratio(
ratio_numerator=self.aggregate_metrics_df["DOM"]["images_of_quality_nbr"],
denominator=self.aggregate_metrics_df["DOM"]["images_nbr"],
)
res_quality_images__res_images = self.get_percentage_ratio(
ratio_numerator=self.aggregate_metrics_df["Residual"]["images_of_quality_nbr"],
denominator=self.aggregate_metrics_df["Residual"]["images_nbr"],
)
DOM_quality_images__All_images = self.get_percentage_ratio(
ratio_numerator=self.aggregate_metrics_df["DOM"]["images_of_quality_nbr"],
denominator=self.aggregate_metrics_df["DOM"]["images_of_quality_nbr"]
+ self.aggregate_metrics_df["Residual"]["images_of_quality_nbr"],
)
col1.metric(label="DOM Quality Images / DOM images", value=f"{DOM_quality_images__DOM_images:.2f}%")
col2.metric(label="Residual Quality Images / Residual images", value=f"{res_quality_images__res_images:.2f}%")
col3.metric(label="DOM Quality Images / All Quality images", value=f"{DOM_quality_images__All_images:.2f}%")
# Helper Methods
def get_percentage_ratio(self, ratio_numerator=0, denominator=1):
ratio_in_percentage = ratio_numerator / denominator * 100
return ratio_in_percentage
def get_bar_chart_from_aggregate_metrics(
self, metrics_values, metrics_names, x_label_categories, x_label, y_label, bar_size=30
):
df = pd.DataFrame(
{
y_label: metrics_values,
x_label: x_label_categories,
"Caption": metrics_names,
}
)
chart = alt.Chart(df).mark_bar(size=bar_size).encode(x=f"{x_label}:N", y=f"{y_label}:Q", color="Caption:N")
return chart
def get_chart_infos_from_3D_df(self, df, col_list, rows_list, reduction="mean"):
lists = []
names_list = []
for col in col_list:
for row in rows_list:
if reduction == "mean":
lists.append(np.mean(df[col][row]))
list_name = col + "_avg_" + row
names_list.append(list_name)
return np.array(lists), np.array(names_list)
def get_dist_infos_from_3D_df(self, df, col_list, rows_list, max_value=None):
lists = []
names_list = []
for col in col_list:
for row in rows_list:
if max_value:
list_to_append = [el if el < max_value else max_value for el in df[col][row]]
lists.append(list_to_append)
else:
lists.append(df[col][row])
list_name = col + "_" + row
names_list.append(list_name)
return lists, names_list
def display_distribution_plot(
self, list_extraction_methods, list_metric_to_compare, title=None, bin_size=[1, 1], max_value=None
):
if list_metric_to_compare[0] in self.df_image_centric_metric_names:
lists_of_metrics_to_plot, list_of_metric_names_to_plot = self.get_dist_infos_from_3D_df(
self.image_centric_df,
col_list=list_extraction_methods,
rows_list=list_metric_to_compare,
max_value=max_value,
)
elif list_metric_to_compare[0] in self.df_text_centric_metric_names:
lists_of_metrics_to_plot, list_of_metric_names_to_plot = self.get_dist_infos_from_3D_df(
self.text_centric_df,
col_list=list_extraction_methods,
rows_list=list_metric_to_compare,
max_value=max_value,
)
fig = ff.create_distplot(lists_of_metrics_to_plot, list_of_metric_names_to_plot, bin_size=bin_size)
if title:
fig.update_layout(title_text=title)
st.plotly_chart(fig, use_container_width=True)
def display_bar_charts(self, header, list_metric_to_compare):
"""
Given a list of metrics to compare, makes one bar chart per metric and compares over
all extraction methods.
Each bar chart has its own column, so it is better to put no more than 3 metrics.
"""
charts = []
for metric_to_compare in list_metric_to_compare:
if metric_to_compare in self.df_aggregate_metric_names:
chart = self.get_bar_chart_from_aggregate_metrics(
[self.aggregate_metrics_df[col][metric_to_compare] for col in self.aggregate_metrics_df],
[col + "_" + metric_to_compare for col in self.aggregate_metrics_df],
x_label_categories=self.extraction_methods_names,
x_label="Extraction methods",
y_label=metric_to_compare,
bar_size=30,
)
else:
if metric_to_compare in self.df_image_centric_metric_names:
metrics_list, metrics_name_list = self.get_chart_infos_from_3D_df(
self.image_centric_df, col_list=self.extraction_methods_names, rows_list=[metric_to_compare]
)
elif metric_to_compare in self.df_text_centric_metric_names:
metrics_list, metrics_name_list = self.get_chart_infos_from_3D_df(
self.text_centric_df, col_list=self.extraction_methods_names, rows_list=[metric_to_compare]
)
chart = self.get_bar_chart_from_aggregate_metrics(
metrics_list,
metrics_name_list,
x_label_categories=self.extraction_methods_names,
x_label="Extraction methods",
y_label=metric_to_compare,
bar_size=30,
)
charts.append(chart)
if header is not None:
st.header(header)
columns = st.columns(len(list_metric_to_compare))
for chart_idx, column in enumerate(columns):
with column:
st.altair_chart(charts[chart_idx], use_container_width=True)
if __name__ == "__main__":
num_docs = 10
use_clip_scores = True
visualization = Visualization(num_docs=num_docs, use_clip_scores=use_clip_scores)
st.set_page_config(layout="wide")
visualization.visualize()