video_processing/add_nsfw_score.py (33 lines of code) (raw):

"""Add NSFW labels for each video's key frames to a parquet table.

For every row of ``--parquet-path``, the key frames are obtained via
``separate_key_frames_from_row`` (which yields the key frames plus a
first/mid/last frame, any of which may be ``None``).  The available
first/mid/last frames are passed to ``run_nsfw``.  The resulting labels are
stored in a new ``nsfw_status`` column, and the augmented table is written
to ``--parquet-out-path``.
"""

import pathlib
from argparse import ArgumentParser

import pandas as pd
from PIL import Image  # noqa: F401  (kept from the original; currently unused)
from tqdm import tqdm

from modules import load_nsfw, run_nsfw, separate_key_frames_from_row


def _parse_args():
    """Parse and return the command-line arguments."""
    parser = ArgumentParser()
    parser.add_argument("--path", type=str, required=True)
    parser.add_argument("--parquet-path", type=str, required=True)
    parser.add_argument("--parquet-out-path", type=str, required=True)
    parser.add_argument("--device", type=str, required=True)
    return parser.parse_args()


def _score_rows(df, path):
    """Return one list of NSFW labels per row of ``df``, in row order.

    Args:
        df: Table read from the input parquet; each row must have a
            ``"file"`` column (used for the progress description).
        path: Base directory passed to ``separate_key_frames_from_row``.

    Returns:
        A list with ``len(df)`` entries, each being the list of labels that
        ``run_nsfw`` produced for that row's available frames.
    """
    statuses = []
    with tqdm(total=len(df)) as pbar:
        for _, row in df.iterrows():
            pbar.set_description(row["file"])
            key_frames, first, mid, last = separate_key_frames_from_row(path, row)
            pbar.set_postfix_str(f"{len(key_frames)} key frames")
            # Drop frames the helper could not provide (returned as None).
            frames = [frame for frame in (first, mid, last) if frame is not None]
            statuses.append(list(run_nsfw(frames)))
            pbar.update()
    return statuses


def main():
    """Score every row of the input parquet and write the augmented table."""
    args = _parse_args()
    path = pathlib.Path(args.path)
    parquet_path = pathlib.Path(args.parquet_path)
    parquet_out_path = pathlib.Path(args.parquet_out_path)

    # Presumably loads the NSFW model onto the device for run_nsfw; verify in
    # `modules`.
    load_nsfw(args.device)
    df = pd.read_parquet(parquet_path)

    statuses = _score_rows(df, path)
    # Assign positionally: `join` against a freshly built RangeIndex would
    # misalign (or duplicate) rows whenever the parquet's index is not 0..n-1.
    df = df.assign(nsfw_status=statuses)
    print(df[["nsfw_status"]])
    print(df)
    df.to_parquet(parquet_out_path)


if __name__ == "__main__":
    main()