# vision/m4/sourcing/data_collection/visualization/pair_visualization.py
import time
import plotly.figure_factory as ff
import streamlit as st
import yaml
from datasets import load_dataset
from humanfriendly import format_timespan
from m4.sourcing.data_collection.processors import (
DOMTreeSimplificator,
PairFiltering,
PreExtractionSimplificator,
TextMediaPairsExtractor,
)
class Visualization:
    """Streamlit app to inspect the text-image pair extraction pipeline.

    Loads filtering thresholds from a YAML config, extracts text-image pairs
    from a sample of C4 HTML documents, and displays distribution statistics
    before and after filtering, to help balance precision and recall of the
    extraction strategy.
    """

    def __init__(self, path_config_filter_text_image_pairs):
        """Load the default filtering thresholds from the given YAML file.

        The config is a local, trusted file of plain scalars; `safe_load`
        parses it identically to `yaml.load(..., Loader=yaml.FullLoader)`
        while refusing arbitrary-object construction.
        """
        with open(path_config_filter_text_image_pairs) as f:
            self.filtering_params = yaml.safe_load(f)

    def visualization(self):
        """Render the full page: extraction, filtering, and statistics sections."""
        self.set_title()
        self.choose_extraction_options()
        self.extraction()
        self.statistics_without_filtering()
        self.choose_filtering_options()
        self.filtering()
        self.get_statistics_on_extracted_pairs()
        self.analysis_discarded_pairs()

    def set_title(self):
        """Display the page title."""
        st.title("Visualization to help balance precision and recall of the extraction strategy")

    def choose_extraction_options(self):
        """Render the widgets selecting how many docs to process and whether to compute CLIP scores."""
        st.header("Extraction options")
        self.num_docs = st.number_input(
            "Number of documents to perform the extraction on", min_value=1, max_value=1000, value=100
        )
        self.should_compute_clip_scores = st.checkbox(
            "Compute CLIP scores during the extraction (Warning: way slower when checked)", value=False
        )

    def extraction(self):
        """Run (or fetch from cache) the extraction and store the resulting pairs."""
        self.extracted_pairs = Visualization.get_extracted_pairs(
            num_docs=self.num_docs,
            should_compute_clip_scores=self.should_compute_clip_scores,
        )

    @staticmethod
    @st.cache(suppress_st_warning=True)
    def get_extracted_pairs(num_docs, should_compute_clip_scores):
        """Stream `num_docs` HTML documents and extract text-image pairs from each.

        Cached by Streamlit on (num_docs, should_compute_clip_scores) so reruns
        triggered by widget interaction do not redo the expensive extraction.

        Returns a flat list of pair dicts (one per extracted image).
        """
        st.header("Extraction")

        def load_examples(num_docs):
            # Streaming avoids downloading the whole dataset; only the first
            # `num_docs` examples are materialized.
            dataset = load_dataset(
                "bs-modeling-metadata/c4-en-html-with-metadata",
                streaming=True,
                split="train",
                use_auth_token=True,
            )
            return list(dataset.take(num_docs))

        docs = load_examples(num_docs)
        dom_tree_simplificator = DOMTreeSimplificator(
            strip_multiple_linebreaks=True,
            strip_multiple_spaces=True,
            remove_html_comments=True,
            replace_line_break_tags=True,
            unwrap_tags=True,
            strip_tags=True,
            strip_special_divs=True,
            remove_dates=True,
            remove_empty_leaves=True,
            unnest_nodes=True,
            remake_tree=True,
        )
        pre_extraction_simplificator = PreExtractionSimplificator(
            only_text_image_nodes=True,
            format_texts=True,
            merge_consecutive_text_nodes=True,
        )
        extractor = TextMediaPairsExtractor(
            dom_tree_simplificator=dom_tree_simplificator,
            pre_extraction_simplificator=pre_extraction_simplificator,
            also_extract_images_not_in_simplified_dom_tree=True,
            extract_clip_scores=should_compute_clip_scores,
        )
        st.markdown("Extraction progress bar")
        extraction_progress_bar = st.progress(0.0)

        def extract_pairs(doc, ind):
            extracted_pairs = extractor(html_str=doc["html"], page_url=doc["url"])
            extraction_progress_bar.progress((ind + 1) / num_docs)
            return extracted_pairs

        start_extraction_time = time.time()
        extracted_pairs = [extract_pairs(doc, ind) for ind, doc in enumerate(docs)]
        # Flatten the per-document lists into a single list of pairs.
        extracted_pairs = [sub_el for el in extracted_pairs for sub_el in el]
        end_extraction_time = time.time()
        extraction_time = format_timespan(round(end_extraction_time - start_extraction_time))
        st.markdown(f"Extraction done in {extraction_time}")
        st.balloons()
        return extracted_pairs

    @staticmethod
    def plot_distributions(hist_data, group_labels, bin_size, title):
        """hist_data is a list of statistics lists (works up to 3, if more is needed, add some colors
        and annotation colors)"""

        def check_same_number_list(list_):
            # Useful to check if a matrix is singular, otherwise Plotly raises an error
            # Lists without any element or with only one are automatically discarded
            if not list_:
                return True
            num = list_[0]
            for el in list_:
                if el != num:
                    return False
            return True

        colors = ["#200CCF", "#DB2A2A", "#0CCF6D"][: len(hist_data)]
        annotation_color = ["#11066B", "#B30707", "#0B7540"][: len(hist_data)]
        # Keep only series with at least two distinct values; constant or empty
        # series make the KDE matrix singular and Plotly raises an error.
        # (Filtering replaces the original delete-while-iterating loop.)
        kept_series = [
            (data, label, color, ann_color)
            for data, label, color, ann_color in zip(hist_data, group_labels, colors, annotation_color)
            if not check_same_number_list(data)
        ]
        if not kept_series:
            # Fix: ff.create_distplot raises on empty input; inform the user instead.
            st.markdown(f"*{title}*: not enough distinct values to plot a distribution")
            return
        hist_data, group_labels, colors, annotation_color = (list(tup) for tup in zip(*kept_series))
        fig = ff.create_distplot(hist_data, group_labels, colors=colors, bin_size=bin_size, show_rug=False)
        fig["layout"].update(title=title)
        for data, color in zip(hist_data, annotation_color):
            # Dashed vertical line at the mean of each series, annotated "x̄".
            fig.add_vline(
                x=sum(data) / len(data), line_width=2, line_dash="dash", line_color=color, annotation_text="x̄"
            )
        st.plotly_chart(fig, use_container_width=True)

    @staticmethod
    def truncate(list_, max_val):
        """Clip every element of `list_` to at most `max_val` (keeps plots readable)."""
        return [el if el < max_val else max_val for el in list_]

    def statistics_without_filtering(self):
        """Plot distributions (image sizes, text lengths, CLIP scores) over all extracted pairs."""
        st.header("Statistics (without filtering)")

        def image_sizes(
            max_width=1_000,
            bin_size_width=10,
            max_height=1_000,
            bin_size_height=10,
            max_num_pixels=1_000_000,
            bin_size_num_pixels=1_000,
        ):
            # Three side-by-side columns: widths, heights, number of pixels.
            st.subheader("Image sizes")
            col1, col2, col3 = st.columns(3)
            with col1:
                original_widths = Visualization.truncate(
                    [pair["original_width"] for pair in self.extracted_pairs], max_val=max_width
                )
                Visualization.plot_distributions(
                    [original_widths],
                    ["All images"],
                    bin_size=bin_size_width,
                    title="Distribution of original widths of images",
                )
            with col2:
                original_heights = Visualization.truncate(
                    [pair["original_height"] for pair in self.extracted_pairs], max_val=max_height
                )
                Visualization.plot_distributions(
                    [original_heights],
                    ["All pairs"],
                    bin_size=bin_size_height,
                    title="Distribution of original heights of images",
                )
            with col3:
                original_num_pixels = Visualization.truncate(
                    [pair["original_width"] * pair["original_height"] for pair in self.extracted_pairs],
                    max_val=max_num_pixels,
                )
                Visualization.plot_distributions(
                    [original_num_pixels],
                    ["All images"],
                    bin_size=bin_size_num_pixels,
                    title="Distribution of numbers of pixels of images",
                )

        def text_lengths():
            # Word counts for each of the three candidate text sources.
            st.subheader("Text lengths")
            Visualization.plot_distributions(
                [
                    [len(pair[text_key].split(" ")) for pair in self.extracted_pairs if text_key in pair]
                    for text_key in ["formatted_filename", "alt_text", "extracted_text"]
                ],
                ["Formatted filename", "Alt text", "Extracted text"],
                bin_size=1,
                title="Distribution of numbers of words",
            )

        def clip_scores():
            # Only available when CLIP scores were computed during extraction.
            if self.should_compute_clip_scores:
                st.subheader("CLIP scores")
                Visualization.plot_distributions(
                    [
                        [
                            pair[f"clip_score_image_{text_key}"]
                            for pair in self.extracted_pairs
                            if f"clip_score_image_{text_key}" in pair
                        ]
                        for text_key in ["formatted_filename", "alt_text", "extracted_text"]
                    ],
                    ["Formatted filename", "Alt text", "Extracted text"],
                    bin_size=0.02,
                    title="Distribution of CLIP scores",
                )

        image_sizes()
        text_lengths()
        clip_scores()

    def choose_filtering_options(self):
        """Render the widgets selecting which filters to apply and their thresholds.

        Defaults come from the YAML config loaded in `__init__`; the selected
        values are stored on `self` for use by `filtering`.
        """
        st.header("Filtering options")
        text_keys = ["Formatted filename", "Alt text", "Extracted text"]
        text_key = st.selectbox("Choose the type of text to pair with images", text_keys, index=2)
        # e.g. "Extracted text" -> "extracted_text", the key used in pair dicts.
        self.text_key = text_key.lower().replace(" ", "_")
        st.write("-----")
        self.should_remove_images_not_in_simplified_dom_trees = st.checkbox(
            "Remove images not in simplified DOM trees", value=False
        )
        st.write("-----")
        self.should_remove_images_not_in_valid_formats = st.checkbox("Remove images not in valid formats", value=False)
        if self.should_remove_images_not_in_valid_formats:
            self.valid_formats = st.multiselect(
                "Valid formats",
                options=list(self.filtering_params["valid_formats"]),
                default=self.filtering_params["valid_formats"],
            )
        st.write("-----")
        self.should_remove_images_not_in_valid_sizes = st.checkbox("Remove images not in valid sizes", value=False)
        if self.should_remove_images_not_in_valid_sizes:
            col1, col2, col3, col4 = st.columns(4)
            with col1:
                self.original_width_min_cutoff = st.number_input(
                    "Minimum original width",
                    min_value=1,
                    max_value=None,
                    value=self.filtering_params["original_width_min_cutoff"],
                    step=1,
                )
                self.rendered_width_min_cutoff = st.number_input(
                    "Minimum rendered width",
                    min_value=1,
                    max_value=None,
                    value=self.filtering_params["rendered_width_min_cutoff"],
                    step=1,
                )
            with col2:
                self.original_width_max_cutoff = st.number_input(
                    "Maximum original width",
                    min_value=1,
                    max_value=None,
                    value=self.filtering_params["original_width_max_cutoff"],
                    step=1,
                )
                self.rendered_width_max_cutoff = st.number_input(
                    "Maximum rendered width",
                    min_value=1,
                    max_value=None,
                    value=self.filtering_params["rendered_width_max_cutoff"],
                    step=1,
                )
            with col3:
                self.original_height_min_cutoff = st.number_input(
                    "Minimum original height",
                    min_value=1,
                    max_value=None,
                    value=self.filtering_params["original_height_min_cutoff"],
                    step=1,
                )
                self.rendered_height_min_cutoff = st.number_input(
                    "Minimum rendered height",
                    min_value=1,
                    max_value=None,
                    value=self.filtering_params["rendered_height_min_cutoff"],
                    step=1,
                )
            with col4:
                self.original_height_max_cutoff = st.number_input(
                    "Maximum original height",
                    min_value=1,
                    max_value=None,
                    value=self.filtering_params["original_height_max_cutoff"],
                    step=1,
                )
                self.rendered_height_max_cutoff = st.number_input(
                    "Maximum rendered height",
                    min_value=1,
                    max_value=None,
                    value=self.filtering_params["rendered_height_max_cutoff"],
                    step=1,
                )
            # float(...) because st.number_input requires value to match the
            # (float) type implied by min_value/step.
            self.aspect_ratio_max_cutoff = st.number_input(
                "Maximum aspect ratio",
                min_value=1.0,
                max_value=None,
                value=float(self.filtering_params["aspect_ratio_max_cutoff"]),
                step=0.5,
            )
        st.write("-----")
        self.should_remove_texts_not_in_valid_number_words = st.checkbox(
            "Remove texts not having a valid number of words", value=False
        )
        if self.should_remove_texts_not_in_valid_number_words:
            col1, col2 = st.columns(2)
            with col1:
                self.number_words_min_cutoff = st.number_input(
                    "Minimum number of words",
                    min_value=0,
                    max_value=None,
                    value=self.filtering_params["number_words_min_cutoff"],
                    step=1,
                )
            with col2:
                self.number_words_max_cutoff = st.number_input(
                    "Maximum number of words",
                    min_value=0,
                    max_value=None,
                    value=self.filtering_params["number_words_max_cutoff"],
                    step=1,
                )
        st.write("-----")
        self.should_remove_texts_with_too_high_special_character_ratio = st.checkbox(
            "Remove texts with a too high special character ratio", value=False
        )
        if self.should_remove_texts_with_too_high_special_character_ratio:
            self.special_character_ratio_max_cutoff = st.number_input(
                "Maximum special character ratio",
                min_value=0.0,
                max_value=1.0,
                value=self.filtering_params["special_character_ratio_max_cutoff"],
                step=0.01,
            )
        st.write("-----")
        self.should_remove_texts_with_too_high_repetition_ratio = st.checkbox(
            "Remove texts with a too high repetition ratio", value=False
        )
        if self.should_remove_texts_with_too_high_repetition_ratio:
            self.repetition_ratio_max_cutoff = st.number_input(
                "Maximum repetition ratio",
                min_value=0.0,
                max_value=1.0,
                value=self.filtering_params["repetition_ratio_max_cutoff"],
                step=0.01,
            )
        st.write("-----")
        self.should_remove_pairs_with_too_low_clip_score = st.checkbox(
            "Remove pairs with a too low CLIP score", value=False
        )
        if self.should_remove_pairs_with_too_low_clip_score:
            self.clip_score_min_cutoff = st.number_input(
                "Minimum CLIP score",
                min_value=0.0,
                max_value=1.0,
                value=self.filtering_params["clip_score_min_cutoff"],
                step=0.01,
            )

    def filtering(self):
        """Partition the extracted pairs into retained and discarded, per the chosen filters."""

        def should_keep_pair(pair):
            # pair = media_info
            # A pair without the chosen text key is always discarded.
            if self.text_key not in pair:
                return False
            if self.should_remove_images_not_in_simplified_dom_trees:
                if not PairFiltering.check_image_in_simplified_dom_tree(pair):
                    return False
            if self.should_remove_images_not_in_valid_formats:
                if not PairFiltering.check_format(pair, self.valid_formats):
                    return False
            if self.should_remove_images_not_in_valid_sizes:
                if not PairFiltering.check_size_image(
                    pair,
                    self.original_width_min_cutoff,
                    self.original_width_max_cutoff,
                    self.original_height_min_cutoff,
                    self.original_height_max_cutoff,
                    self.rendered_width_min_cutoff,
                    self.rendered_width_max_cutoff,
                    self.rendered_height_min_cutoff,
                    self.rendered_height_max_cutoff,
                    self.aspect_ratio_max_cutoff,
                ):
                    return False
            if self.should_remove_texts_not_in_valid_number_words:
                if not PairFiltering.check_number_words(
                    pair, self.text_key, self.number_words_min_cutoff, self.number_words_max_cutoff
                ):
                    return False
            if self.should_remove_texts_with_too_high_special_character_ratio:
                if not PairFiltering.check_special_character_ratio(
                    pair, self.text_key, self.special_character_ratio_max_cutoff
                ):
                    return False
            if self.should_remove_texts_with_too_high_repetition_ratio:
                if not PairFiltering.check_repetition_ratio(pair, self.text_key, self.repetition_ratio_max_cutoff):
                    return False
            if self.should_remove_pairs_with_too_low_clip_score:
                if not PairFiltering.check_clip_score(pair, self.text_key, self.clip_score_min_cutoff):
                    return False
            return True

        # Evaluate each pair once, then split into the two groups.
        all_pairs = [[pair, should_keep_pair(pair)] for pair in self.extracted_pairs]
        self.retained_pairs = [pair for pair, keep_pair in all_pairs if keep_pair]
        self.discarded_pairs = [pair for pair, keep_pair in all_pairs if not keep_pair]

    def get_statistics_on_extracted_pairs(self):
        """Plot retained-vs-discarded distributions for counts, sizes, text lengths, CLIP scores."""
        st.header("Statistics for retained and discarded pairs")

        def number_pairs():
            st.subheader("Number of pairs")
            num_extracted = len(self.extracted_pairs)
            if num_extracted == 0:
                # Fix: avoid ZeroDivisionError when extraction yielded no pairs.
                st.markdown("No pair extracted")
                return
            st.markdown(
                "*Retained pairs*: "
                f"**{len(self.retained_pairs)}/{num_extracted} "
                f"({round(len(self.retained_pairs)/num_extracted*100, 1)}%)**"
            )
            st.markdown(
                "*Discarded pairs*: "
                f"**{len(self.discarded_pairs)}/{num_extracted} "
                f"({round(len(self.discarded_pairs)/num_extracted*100, 1)}%)**"
            )

        def image_sizes(
            max_width=1_000,
            bin_size_width=10,
            max_height=1_000,
            bin_size_height=10,
            max_num_pixels=1_000_000,
            bin_size_num_pixels=1_000,
        ):
            st.subheader("Image sizes")
            col1, col2, col3 = st.columns(3)
            with col1:
                original_widths_retained_pairs = Visualization.truncate(
                    [pair["original_width"] for pair in self.retained_pairs], max_val=max_width
                )
                original_widths_discarded_pairs = Visualization.truncate(
                    [pair["original_width"] for pair in self.discarded_pairs], max_val=max_width
                )
                Visualization.plot_distributions(
                    [original_widths_retained_pairs, original_widths_discarded_pairs],
                    ["Retained pairs", "Discarded pairs"],
                    bin_size=bin_size_width,
                    title="Distribution of original widths of images",
                )
            with col2:
                original_heights_retained_pairs = Visualization.truncate(
                    [pair["original_height"] for pair in self.retained_pairs], max_val=max_height
                )
                original_heights_discarded_pairs = Visualization.truncate(
                    [pair["original_height"] for pair in self.discarded_pairs], max_val=max_height
                )
                Visualization.plot_distributions(
                    [original_heights_retained_pairs, original_heights_discarded_pairs],
                    ["Retained pairs", "Discarded pairs"],
                    bin_size=bin_size_height,
                    title="Distribution of original heights of images",
                )
            with col3:
                original_num_pixels_retained_pairs = Visualization.truncate(
                    [pair["original_width"] * pair["original_height"] for pair in self.retained_pairs],
                    max_val=max_num_pixels,
                )
                original_num_pixels_discarded_pairs = Visualization.truncate(
                    [pair["original_width"] * pair["original_height"] for pair in self.discarded_pairs],
                    max_val=max_num_pixels,
                )
                Visualization.plot_distributions(
                    [original_num_pixels_retained_pairs, original_num_pixels_discarded_pairs],
                    ["Retained pairs", "Discarded pairs"],
                    bin_size=bin_size_num_pixels,
                    title="Distribution of numbers of pixels of images",
                )

        def text_lengths():
            st.subheader("Text lengths")
            num_words_retained_pairs = [
                len(pair[self.text_key].split(" ")) for pair in self.retained_pairs if self.text_key in pair
            ]
            num_words_discarded_pairs = [
                len(pair[self.text_key].split(" ")) for pair in self.discarded_pairs if self.text_key in pair
            ]
            Visualization.plot_distributions(
                [num_words_retained_pairs, num_words_discarded_pairs],
                ["Retained pairs", "Discarded pairs"],
                bin_size=1,
                title=f"Distribution of numbers of words in the {self.text_key.replace('_', ' ')}",
            )

        def clip_scores():
            if self.should_compute_clip_scores:
                st.subheader("CLIP scores")
                clip_scores_retained_pairs = [
                    pair[f"clip_score_image_{self.text_key}"]
                    for pair in self.retained_pairs
                    if f"clip_score_image_{self.text_key}" in pair
                ]
                clip_scores_discarded_pairs = [
                    pair[f"clip_score_image_{self.text_key}"]
                    for pair in self.discarded_pairs
                    if f"clip_score_image_{self.text_key}" in pair
                ]
                Visualization.plot_distributions(
                    [clip_scores_retained_pairs, clip_scores_discarded_pairs],
                    ["Retained pairs", "Discarded pairs"],
                    bin_size=0.02,
                    title=f"Distribution of CLIP scores for the {self.text_key.replace('_', ' ')}",
                )

        number_pairs()
        image_sizes()
        text_lengths()
        clip_scores()

    def analysis_discarded_pairs(self):
        """Break down how many discarded pairs each active filter would reject.

        Note: a pair can be counted by several filters, so percentages here do
        not necessarily sum to 100%.
        """
        num_discarded_tot = len(self.discarded_pairs)
        # Fix: guard against division by zero when no pair was extracted.
        num_extracted = len(self.extracted_pairs)
        perc_discarded_tot = round(num_discarded_tot / num_extracted * 100, 1) if num_extracted else 0.0
        st.header(
            f"Analysis of discarded pairs: {num_discarded_tot}/{num_extracted} ({perc_discarded_tot}%)"
        )
        if not self.discarded_pairs:
            st.markdown("No pair discarded")
        else:
            num_discarded_filter = len([1 for pair in self.discarded_pairs if self.text_key not in pair])
            perc_discarded_filter = round(num_discarded_filter / num_discarded_tot * 100, 1)
            st.markdown(
                "Discarded because of the *chosen type of text not being in pairs*:"
                f" **{num_discarded_filter}/{num_discarded_tot} ({perc_discarded_filter}%)**"
            )

            def display_discarded_by_filter(should_use_filter, func_filter, msg_filter):
                # Only report filters the user actually enabled; their cutoff
                # attributes are set by `choose_filtering_options` only then.
                if should_use_filter:
                    num_discarded_filter = len([1 for pair in self.discarded_pairs if not func_filter(pair)])
                    perc_discarded_filter = round(num_discarded_filter / num_discarded_tot * 100, 1)
                    st.markdown(
                        f"Discarded by the filter on *{msg_filter}*:"
                        f" **{num_discarded_filter}/{num_discarded_tot} ({perc_discarded_filter}%)**"
                    )

            display_discarded_by_filter(
                self.should_remove_images_not_in_simplified_dom_trees,
                lambda pair: PairFiltering.check_image_in_simplified_dom_tree(pair),
                "not being in simplified DOM trees",
            )
            display_discarded_by_filter(
                self.should_remove_images_not_in_valid_formats,
                lambda pair: PairFiltering.check_format(pair, self.valid_formats),
                "not being in valid formats",
            )
            display_discarded_by_filter(
                self.should_remove_images_not_in_valid_sizes,
                lambda pair: PairFiltering.check_size_image(
                    pair,
                    self.original_width_min_cutoff,
                    self.original_width_max_cutoff,
                    self.original_height_min_cutoff,
                    self.original_height_max_cutoff,
                    self.rendered_width_min_cutoff,
                    self.rendered_width_max_cutoff,
                    self.rendered_height_min_cutoff,
                    self.rendered_height_max_cutoff,
                    self.aspect_ratio_max_cutoff,
                ),
                "not being in valid image sizes",
            )
            display_discarded_by_filter(
                self.should_remove_texts_not_in_valid_number_words,
                lambda pair: PairFiltering.check_number_words(
                    pair, self.text_key, self.number_words_min_cutoff, self.number_words_max_cutoff
                ),
                "not having a valid number of words",
            )
            display_discarded_by_filter(
                self.should_remove_texts_with_too_high_special_character_ratio,
                lambda pair: PairFiltering.check_special_character_ratio(
                    pair, self.text_key, self.special_character_ratio_max_cutoff
                ),
                "having a too high special character ratio",
            )
            display_discarded_by_filter(
                self.should_remove_texts_with_too_high_repetition_ratio,
                lambda pair: PairFiltering.check_repetition_ratio(
                    pair, self.text_key, self.repetition_ratio_max_cutoff
                ),
                "having a too high repetition ratio",
            )
            display_discarded_by_filter(
                self.should_remove_pairs_with_too_low_clip_score,
                lambda pair: PairFiltering.check_clip_score(pair, self.text_key, self.clip_score_min_cutoff),
                "having a too low CLIP score",
            )
if __name__ == "__main__":
    # Wide layout so the multi-column statistics sections have room.
    st.set_page_config(layout="wide")
    # Path is relative to the working directory the Streamlit app is launched from.
    path_config_filter_text_image_pairs = "./m4/sourcing/data_collection/configs/config_filter_text_image_pairs.yaml"
    visualization = Visualization(path_config_filter_text_image_pairs=path_config_filter_text_image_pairs)
    visualization.visualization()