# video_processing/reference_video_similarity.py
import argparse
import os

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import SiglipImageProcessor, SiglipVisionModel
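
# ---------------------------------------------------------------------------
# The three helpers `main` relies on are not shown in this section. The
# sketches below are hedged reconstructions of what `main` appears to assume:
# `get_frames` looks PyAV-based (callers invoke `.to_image()` on each frame,
# which matches `av.VideoFrame`), and both `compute_*` helpers must return
# unit-norm numpy vectors, since `main` compares them with a plain dot
# product. Batch sizes, pooling choices, and error handling here are guesses.
# ---------------------------------------------------------------------------


def get_frames(video_path, frames_per_batch=16):
    """Sketch: yield lists of decoded PyAV frames from a video file."""
    import av  # assumed dependency; `frame.to_image()` in `main` matches PyAV

    container = av.open(video_path)
    batch = []
    for frame in container.decode(video=0):
        batch.append(frame)
        if len(batch) == frames_per_batch:
            yield batch
            batch = []
    if batch:
        yield batch
    container.close()


def _embed_images(images, model, preprocessor, device, dtype):
    """Sketch (hypothetical helper): mean-pool SigLIP pooled outputs over a
    list of PIL images into one unit-norm numpy vector."""
    inputs = preprocessor(images=images, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(pixel_values=inputs.pixel_values.to(dtype))
    # Cast to fp32 before numpy conversion (bf16 tensors cannot convert directly).
    emb = outputs.pooler_output.float().mean(dim=0).cpu().numpy()
    norm = np.linalg.norm(emb)
    return emb / norm if norm > 0 else None


def compute_reference_embedding(ref_path, model, preprocessor, device, dtype):
    """Sketch: embed a reference image, or the first frame batch of a
    reference video, as a unit-norm vector; return None on failure."""
    from PIL import Image  # assumed dependency for still-image references

    try:
        if ref_path.lower().endswith((".png", ".jpg", ".jpeg")):
            images = [Image.open(ref_path).convert("RGB")]
        else:
            images = [f.to_image() for f in next(get_frames(ref_path))]
    except Exception as e:
        print(f"Failed to read reference {ref_path}: {e}")
        return None
    return _embed_images(images, model, preprocessor, device, dtype)


def compute_video_embedding(batch_frames, model, preprocessor, device, dtype):
    """Sketch: embed each video's frame list; returns a list aligned with
    `batch_frames` where failed items are None."""
    return [_embed_images(frames, model, preprocessor, device, dtype) for frames in batch_frames]
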
def main(args):
    # List video files in the folder (supports common video extensions).
    video_extensions = (".mp4", ".avi", ".mov", ".mkv")
    video_files = [
        os.path.join(args.videos_folder, f)
        for f in os.listdir(args.videos_folder)
        if f.lower().endswith(video_extensions)
    ]
    print(f"Total video files: {len(video_files)}")
    assert video_files, f"No video files found in {args.videos_folder}"
    # Load model.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # bf16 requires Ampere (compute capability >= 8); older GPUs fall back to fp16.
    dtype = (
        torch.bfloat16
        if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
        else torch.float16
    )
    model = SiglipVisionModel.from_pretrained(
        "google/siglip-so400m-patch14-384",
        torch_dtype=dtype,  # flash_attention_2 only supports fp16/bf16 weights
        attn_implementation="flash_attention_2",
    ).to(device)
    preprocessor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")
    # Process each reference file and average their embeddings.
    ref_embeddings = []
    if os.path.isdir(args.reference):
        allow_extensions = video_extensions + (".png", ".jpg", ".jpeg")
        reference = [
            os.path.join(args.reference, f)
            for f in os.listdir(args.reference)
            if f.lower().endswith(allow_extensions)
        ]
    else:
        reference = args.reference.split(",")
    assert reference, "No reference files provided."
    for ref in reference:
        emb = compute_reference_embedding(ref, model, preprocessor, device, dtype)
        if emb is not None:
            ref_embeddings.append(emb)
        else:
            print(f"Could not compute embedding for reference: {ref}")
    if not ref_embeddings:
        print("No valid reference embeddings found!")
        return
    # Average the reference embeddings and re-normalize to unit length.
    ref_embedding = np.mean(ref_embeddings, axis=0)
    ref_embedding = ref_embedding / np.linalg.norm(ref_embedding)
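    # With both sides unit-normalized (assuming compute_video_embedding also
    # returns unit-norm vectors), the dot products below are cosine
    # similarities in [-1, 1].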
    results = []
    batch_frames = []  # To collect frames for a batch of videos.
    batch_paths = []  # To keep track of corresponding video paths.
    pbar = tqdm(video_files, desc="Computing video embeddings")
    for video_path in pbar:
        pbar.set_postfix_str(f"{video_path}")
        frames_generator = get_frames(video_path)
        try:
            # Generators are already iterators, so call next() directly.
            frames_batch = next(frames_generator)
        except StopIteration:
            print(f"Could not extract frames from {video_path}")
            continue
        frames = [frame.to_image() for frame in frames_batch]
        if not frames:
            print(f"Could not extract frames from {video_path}")
            continue
        frames = frames[: args.max_num_frames]
        batch_frames.append(frames)
        batch_paths.append(video_path)
        if len(batch_frames) == args.batch_size:
            video_embeddings = compute_video_embedding(batch_frames, model, preprocessor, device, dtype)
            for path, video_embedding in zip(batch_paths, video_embeddings):
                if video_embedding is not None:
                    similarity = np.dot(ref_embedding, video_embedding)
                    results.append((os.path.basename(path), similarity))
            batch_frames = []
            batch_paths = []
    # Flush the final partial batch.
    if batch_frames:
        video_embeddings = compute_video_embedding(batch_frames, model, preprocessor, device, dtype)
        for path, video_embedding in zip(batch_paths, video_embeddings):
            if video_embedding is not None:
                similarity = np.dot(ref_embedding, video_embedding)
                results.append((os.path.basename(path), similarity))
    # Sort videos by similarity score (higher means more similar).
    results.sort(key=lambda x: x[1], reverse=True)
    # Write results to a parquet file.
    df = pd.DataFrame(results, columns=["video_path", "similarity"])
    df.to_parquet(args.parquet_out_path, index=False)
    print(f"\nResults saved to {args.parquet_out_path}")