# video_processing/reference_video_similarity.py
import os
import torch
from transformers import SiglipVisionModel, SiglipImageProcessor
from PIL import Image
from tqdm import tqdm
import numpy as np
import argparse
import pandas as pd
from modules import get_frames


def compute_video_embedding(frames, model, preprocessor, device, dtype):
    """
    Compute video embeddings. `frames` can be either the frames of a single video or a list of
    lists of frames from multiple videos.
    """
    if not frames:
        return None
    if isinstance(frames[0], list):
        video_embeddings = []
        flat_frames = []
        video_lengths = []
        for video in frames:
            video_lengths.append(len(video))
            flat_frames.extend(video)
        all_input = preprocessor(images=flat_frames, return_tensors="pt").to(device)
        with torch.no_grad(), torch.autocast(torch.device(device).type, dtype=dtype):
            embeddings = model(**all_input).pooler_output
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
            embeddings = embeddings.cpu()
        # Group the embeddings back by video.
        index = 0
        for length in video_lengths:
            video_emb = embeddings[index : index + length].mean(dim=0)
            video_emb = video_emb / video_emb.norm()
            video_embeddings.append(video_emb.numpy())
            index += length
        return video_embeddings
    else:
        all_input = preprocessor(images=frames, return_tensors="pt").to(device)
        with torch.no_grad(), torch.autocast(torch.device(device).type, dtype=dtype):
            embeddings = model(**all_input).pooler_output
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
            embeddings = embeddings.cpu()
        video_embedding = embeddings.mean(dim=0)
        video_embedding = video_embedding / video_embedding.norm()
        return video_embedding.numpy()
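# Minimal usage sketch (hypothetical frame lists; assumes `model`/`preprocessor` are the SigLIP
# modules loaded as in `main` below):
#   single_emb = compute_video_embedding(frames_a, model, preprocessor, device, dtype)
#   batch_embs = compute_video_embedding([frames_a, frames_b], model, preprocessor, device, dtype)
# Each returned vector is unit-norm, so a plain dot product between two of them is a cosine similarity.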
def compute_image_embedding(image_path, model, preprocessor, device, dtype):
    """
    Computes an embedding for a single image.
    """
    image = Image.open(image_path).convert("RGB")
    image_input = preprocessor(images=image, return_tensors="pt").to(device)
    with torch.no_grad(), torch.autocast(torch.device(device).type, dtype=dtype):
        embedding = model(**image_input).pooler_output
        embedding = embedding / embedding.norm(dim=-1, keepdim=True)
    return embedding.cpu().numpy().flatten()
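# E.g. (hypothetical path):
#   img_emb = compute_image_embedding("query.jpg", model, preprocessor, device, dtype)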
def compute_reference_embedding(ref_path, model, preprocessor, device, dtype):
    """
    Computes the embedding for a reference file (image or video).
    """
    video_extensions = (".mp4", ".avi", ".mov", ".mkv")
    if ref_path.lower().endswith(video_extensions):
        frames = get_frames(ref_path)
        frames = next(iter(frames))
        frames = [frame.to_image() for frame in frames]
        return compute_video_embedding(frames, model, preprocessor, device, dtype)
    else:
        return compute_image_embedding(ref_path, model, preprocessor, device, dtype)
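# Works for either media type, e.g. (hypothetical paths):
#   ref_emb = compute_reference_embedding("reference.png", model, preprocessor, device, dtype)
#   ref_emb = compute_reference_embedding("reference.mp4", model, preprocessor, device, dtype)
# Both return a unit-norm numpy vector in the same SigLIP embedding space.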
@torch.no_grad()
@torch.inference_mode()
def main(args):
    # List video files in the folder (supports common video extensions).
    video_extensions = (".mp4", ".avi", ".mov", ".mkv")
    video_files = [
        os.path.join(args.videos_folder, f)
        for f in os.listdir(args.videos_folder)
        if f.lower().endswith(video_extensions)
    ]
    print(f"Total video files: {len(video_files)}")
    assert video_files

    # Load model.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = (
        torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    )
    model = SiglipVisionModel.from_pretrained(
        "google/siglip-so400m-patch14-384", attn_implementation="flash_attention_2"
    ).to(device)
    preprocessor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
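    # Note: attn_implementation="flash_attention_2" assumes the `flash-attn` package is installed
    # and an Ampere-or-newer GPU; if that is not available, "sdpa" should be a drop-in alternative.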
    # Process each reference file and average their embeddings.
    ref_embeddings = []
    if os.path.isdir(args.reference):
        allow_extensions = video_extensions + (".png", ".jpg", ".jpeg")
        reference = [
            os.path.join(args.reference, f)
            for f in os.listdir(args.reference)
            if f.lower().endswith(allow_extensions)
        ]
    else:
        reference = args.reference.split(",")
    assert reference

    for ref in reference:
        emb = compute_reference_embedding(ref, model, preprocessor, device, dtype)
        if emb is not None:
            ref_embeddings.append(emb)
        else:
            print(f"Could not compute embedding for reference: {ref}")
    if len(ref_embeddings) == 0:
        print("No valid reference embeddings found!")
        return
    ref_embedding = np.mean(ref_embeddings, axis=0)
    ref_embedding = ref_embedding / np.linalg.norm(ref_embedding)
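    # The mean of unit vectors is generally not unit-length, so re-normalize it; the dot products
    # computed below then remain proper cosine similarities.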
    results = []
    batch_frames = []  # To collect frames for a batch of videos.
    batch_paths = []  # To keep track of the corresponding video paths.
    pbar = tqdm(video_files, desc="Computing video embeddings.")
    for video_path in pbar:
        pbar.set_postfix_str(f"{video_path}")
        frames_generator = get_frames(video_path)
        try:
            frames_batch = next(iter(frames_generator))
        except StopIteration:
            print(f"Could not extract frames from {video_path}")
            continue
        frames = [frame.to_image() for frame in frames_batch]
        if not frames:
            print(f"Could not extract frames from {video_path}")
            continue
        frames = frames[: args.max_num_frames]
        batch_frames.append(frames)
        batch_paths.append(video_path)

        if len(batch_frames) == args.batch_size:
            video_embeddings = compute_video_embedding(batch_frames, model, preprocessor, device, dtype)
            for path, video_embedding in zip(batch_paths, video_embeddings):
                if video_embedding is not None:
                    similarity = np.dot(ref_embedding, video_embedding)
                    results.append((os.path.basename(path), similarity))
            batch_frames = []
            batch_paths = []
    # Process the remaining videos that did not fill a full batch.
    if batch_frames:
        video_embeddings = compute_video_embedding(batch_frames, model, preprocessor, device, dtype)
        for path, video_embedding in zip(batch_paths, video_embeddings):
            if video_embedding is not None:
                similarity = np.dot(ref_embedding, video_embedding)
                results.append((os.path.basename(path), similarity))

    # Sort videos by similarity score (higher means more similar).
    results.sort(key=lambda x: x[1], reverse=True)

    # Write results to a parquet file.
    df = pd.DataFrame(results, columns=["video_path", "similarity"])
    df.to_parquet(args.parquet_out_path, index=False)
    print(f"\nResults saved to {args.parquet_out_path}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--videos_folder",
        type=str,
        required=True,
        help="Path to the folder containing the videos to rank.",
    )
    parser.add_argument(
        "--reference",
        type=str,
        required=True,
        help="Reference image/video file(s): a comma-separated list of paths or a directory.",
    )
    parser.add_argument(
        "--max_num_frames",
        type=int,
        default=24,
        help="Maximum number of frames to use per video.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="How many videos to embed per batch.",
    )
    parser.add_argument(
        "--parquet_out_path",
        type=str,
        default="results.parquet",
        help="Path to the output parquet file.",
    )
    args = parser.parse_args()
    main(args)
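# Example invocation (a sketch; the paths are placeholders):
#   python video_processing/reference_video_similarity.py \
#       --videos_folder /path/to/videos \
#       --reference reference.png \
#       --parquet_out_path results.parquet
# The output parquet is sorted by descending SigLIP cosine similarity to the reference.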