video_processing/add_captions.py (69 lines of code) (raw):

"""Add Florence-2 captions for video key frames to a parquet index.

For every row of ``--parquet-path``, ``separate_key_frames_from_row`` returns
the row's key frames together with ``first``, ``mid`` and ``last`` frames.
``first`` is always captioned; ``mid`` and ``last`` are captioned only when the
helper returns a truthy value for them. Every Florence-2 task result is
collected into a list-valued column (one entry per captioned frame), those
columns are joined onto the input rows, and the result is written to
``--parquet-out-path``.
"""

import pathlib
from argparse import ArgumentParser

import pandas as pd
from PIL import Image  # noqa: F401  (not used in this script)
from tqdm import tqdm

from modules import load_florence, run, separate_key_frames_from_row

parser = ArgumentParser()
parser.add_argument(
    "--path",
    type=str,
    required=True,
    help="Base path handed to separate_key_frames_from_row for every row.",
)
parser.add_argument(
    "--parquet-path",
    type=str,
    required=True,
    help="Input parquet file; one row per video (must contain a 'file' column).",
)
parser.add_argument(
    "--parquet-out-path",
    type=str,
    required=True,
    help="Where to write the input rows plus the caption columns.",
)
parser.add_argument(
    "--device", type=str, required=True, help="Device passed to load_florence."
)
parser.add_argument(
    "--dtype", type=str, required=True, help="dtype passed to load_florence."
)
args = parser.parse_args()

path = pathlib.Path(args.path)
parquet_path = pathlib.Path(args.parquet_path)
parquet_out_path = pathlib.Path(args.parquet_out_path)
device = args.device
dtype = args.dtype

load_florence(
    hf_hub_or_path="microsoft/Florence-2-large",
    device=device,
    dtype=dtype,
)

df = pd.read_parquet(parquet_path)

# Task prompts sent to `run` for every frame (order kept as originally passed).
task_prompt = [
    "<DENSE_REGION_CAPTION>",
    "<OCR_WITH_REGION>",
    "<CAPTION>",
    "<DETAILED_CAPTION>",
]

# Maps each task prompt to the output column collecting its per-frame result.
# Insertion order defines the column order of the caption frame.
TASK_COLUMNS = {
    "<CAPTION>": "caption",
    "<DETAILED_CAPTION>": "detailed_caption",
    "<DENSE_REGION_CAPTION>": "region_caption",
    "<OCR_WITH_REGION>": "ocr",
}


def _caption_frames(frames):
    """Run all Florence-2 tasks on each frame and group the results by column.

    Args:
        frames: Frames to caption, in output order.

    Returns:
        A dict mapping every column in ``TASK_COLUMNS`` to a list holding that
        task's result for each frame, in the same order as ``frames``.
    """
    columns = {column: [] for column in TASK_COLUMNS.values()}
    for frame in frames:
        # `run` returns a mapping keyed by task prompt.
        result = run(frame, task_prompt=task_prompt)
        for prompt, column in TASK_COLUMNS.items():
            columns[column].append(result[prompt])
    return columns


data = []
with tqdm(total=len(df)) as pbar:
    for _, video_row in df.iterrows():
        pbar.set_description(video_row["file"])
        key_frames, first, mid, last = separate_key_frames_from_row(path, video_row)
        pbar.set_postfix_str(f"{len(key_frames)} key frames")
        # `first` is always captioned; `mid`/`last` only when truthy
        # (presumably None when the video has too few key frames - confirm).
        frames = [first] + [frame for frame in (mid, last) if frame]
        data.append(_caption_frames(frames))
        pbar.update()

caption_df = pd.DataFrame(data)
print(caption_df)

# `caption_df` rows follow `df`'s row order but carry a fresh RangeIndex, while
# `join` aligns on index labels. Join positionally, then restore the original
# index, so captions stay on the right rows even when the parquet stores a
# non-default index.
original_index = df.index
df = df.reset_index(drop=True).join(caption_df)
df.index = original_index
print(df)

df.to_parquet(parquet_out_path)