build_obelics/07_02_nsfw_image_visualization.py (68 lines of code) (raw):

# Streamlit page that helps choose NSFW filtering thresholds.
#
# It loads an image dataset that already carries per-image NSFW scores
# ("hentai", "porn", "sexy"), plots a histogram of each score, and renders a
# table of thumbnails sorted in descending order by a user-selected score.

import base64
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from datasets import load_from_disk
from PIL import Image


if __name__ == "__main__":
    st.set_page_config(layout="wide")
    st.title("Visualization to help choosing the NSFW filtering parameters")

    path_image_dataset_with_nsfw_scores = (
        # Find at s3://m4-datasets/trash/image_dataset_25k_with_nsfw_scores/
        "./large_files/image_dataset_25k_with_nsfw_scores"
    )
    image_dataset_with_nsfw_scores = load_from_disk(path_image_dataset_with_nsfw_scores)
    num_available_examples = image_dataset_with_nsfw_scores.num_rows

    st.header("Number of images to consider")
    num_considered_examples = st.number_input(
        "Choose the number of images to consider",
        min_value=0,
        max_value=num_available_examples,
        # Clamp the default: Streamlit rejects a default value above `max_value`,
        # which would happen for a dataset holding fewer than 1000 rows.
        value=min(1_000, num_available_examples),
        help=f"Enter a number between 0 and {num_available_examples}",
    )
    image_dataset_with_nsfw_scores = image_dataset_with_nsfw_scores.select(range(num_considered_examples))

    # Read the score column once and derive every per-category list from it.
    nsfw_scores = image_dataset_with_nsfw_scores["nsfw_scores"]
    stats = {"images": image_dataset_with_nsfw_scores["image"]}
    for category in ("hentai", "porn", "sexy"):
        stats[f"{category}_score"] = [round(scores[category], 2) for scores in nsfw_scores]
    # Every key except "images" is a statistic to plot.
    score_names = [name for name in stats if name != "images"]

    st.header("Distribution of the statistics")
    # At least one bin is required: a histogram with 0 bins raises an error.
    bins = st.number_input("Number of bins", min_value=1, max_value=100, value=25)
    fig, axes = plt.subplots(len(score_names), 1)
    for axis, stat_name in zip(axes, score_names):
        axis.hist(stats[stat_name], bins=bins)
        axis.set_title(f"{stat_name}")
    fig.suptitle("Histograms of the statistics")
    fig.set_figheight(3 * len(score_names))
    st.pyplot(fig)
    # Release the figure; otherwise pyplot keeps one more figure alive per rerun.
    plt.close(fig)

    st.header("A closer look at the data")
    stat_sort = st.selectbox("Descending sort by", options=["hentai score", "porn score", "sexy score"])
    stat_sort = stat_sort.replace(" ", "_")
    # Indices that order the selected (rounded) score from highest to lowest.
    idx_sort = np.argsort(stats[stat_sort])[::-1].tolist()
    # Apply the same permutation to every column so rows stay aligned.
    stats = {name: [values[idx] for idx in idx_sort] for name, values in stats.items()}

    # Image modes the PNG writer accepts; anything else is converted to RGB first.
    png_writable_modes = {"1", "L", "LA", "I", "I;16", "I;16B", "P", "RGB", "RGBA"}

    def transform_img(img):
        """Return an HTML ``<img>`` tag embedding a PNG thumbnail of ``img``.

        Args:
            img: Encoded image bytes (anything ``PIL.Image.open`` can read).

        Returns:
            A string holding an ``<img>`` tag whose source is a base64 data URI
            of the image shrunk to fit within 50x50 pixels.
        """
        thumbnail = Image.open(BytesIO(img))
        thumbnail.thumbnail((50, 50))
        if thumbnail.mode not in png_writable_modes:
            # e.g. CMYK images cannot be written as PNG.
            thumbnail = thumbnail.convert("RGB")
        with BytesIO() as buffer:
            thumbnail.save(buffer, "png")
            base_64_encoding = base64.b64encode(buffer.getvalue()).decode()
        return f'<img src="data:image/png;base64,{base_64_encoding}">'

    # Pop and re-insert so the thumbnails become the last column of the table.
    stats["images"] = [transform_img(img) for img in stats.pop("images")]

    data_frame = pd.DataFrame(stats)
    # escape=False keeps the <img> tags as HTML instead of escaped text.
    html_data_frame = data_frame.to_html(escape=False)
    st.markdown(html_data_frame, unsafe_allow_html=True)