def main()

in video_processing/reference_video_similarity.py
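Ranks all videos in a folder by their SigLIP embedding similarity to one or more reference images/videos, then writes the ranked scores to a parquet file.

The body references module-level names; the imports below are what the function needs, while get_frames, compute_reference_embedding, and compute_video_embedding are helpers defined elsewhere in the same file:

import os

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import SiglipImageProcessor, SiglipVisionModel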


def main(args):
    # List video files in the folder (supports common video extensions)
    video_extensions = (".mp4", ".avi", ".mov", ".mkv")
    video_files = [
        os.path.join(args.videos_folder, f)
        for f in os.listdir(args.videos_folder)
        if f.lower().endswith(video_extensions)
    ]
    print(f"Total video files: {len(video_files)}")
    assert video_files, f"No video files found in {args.videos_folder}"

    # Load model.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = (
        torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    )
    model = SiglipVisionModel.from_pretrained(
        "google/siglip-so400m-patch14-384",
        # flash_attention_2 supports only fp16/bf16, so load the weights in `dtype`.
        torch_dtype=dtype,
        attn_implementation="flash_attention_2",
    ).to(device)
    preprocessor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")

    # Process each reference file and average their embeddings.
    ref_embeddings = []
    if os.path.isdir(args.reference):
        allow_extensions = video_extensions + (".png", ".jpg", ".jpeg")
        reference = [
            os.path.join(args.reference, f)
            for f in os.listdir(args.reference)
            if f.lower().endswith(allow_extensions)
        ]
    else:
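        # Treat args.reference as a comma-separated list of file paths.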
        reference = args.reference.split(",")

    assert reference, "No reference files provided."

    for ref in reference:
        emb = compute_reference_embedding(ref, model, preprocessor, device, dtype)
        if emb is not None:
            ref_embeddings.append(emb)
        else:
            print(f"Could not compute embedding for reference: {ref}")

    if len(ref_embeddings) == 0:
        print("No valid reference embeddings found!")
        return

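    # Average the reference embeddings and re-normalize to unit length so the
    # dot products below are cosine similarities.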
    ref_embedding = np.mean(ref_embeddings, axis=0)
    ref_embedding = ref_embedding / np.linalg.norm(ref_embedding)

    results = []
    batch_frames = []  # To collect frames for a batch of videos
    batch_paths = []  # To keep track of corresponding video paths
    pbar = tqdm(video_files, desc="Computing video embeddings")

    for video_path in pbar:
        pbar.set_postfix_str(video_path)

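        # get_frames yields batches of frames; only the first batch is used.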
        frames_generator = get_frames(video_path)
        try:
            frames_batch = next(frames_generator)
        except StopIteration:
            print(f"Could not extract frames from {video_path}")
            continue

        frames = [frame.to_image() for frame in frames_batch]
        if not frames:
            print(f"Could not extract frames from {video_path}")
            continue

        frames = frames[: args.max_num_frames]
        batch_frames.append(frames)
        batch_paths.append(video_path)

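        # Once a full batch of videos has accumulated, embed and score it.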
        if len(batch_frames) == args.batch_size:
            video_embeddings = compute_video_embedding(batch_frames, model, preprocessor, device, dtype)
            for path, video_embedding in zip(batch_paths, video_embeddings):
                if video_embedding is not None:
                    similarity = np.dot(ref_embedding, video_embedding)
                    results.append((os.path.basename(path), similarity))
            batch_frames = []
            batch_paths = []

    # Flush the final, possibly incomplete batch.
    if batch_frames:
        video_embeddings = compute_video_embedding(batch_frames, model, preprocessor, device, dtype)
        for path, video_embedding in zip(batch_paths, video_embeddings):
            if video_embedding is not None:
                similarity = np.dot(ref_embedding, video_embedding)
                results.append((os.path.basename(path), similarity))

    # Sort videos by similarity score (higher means more similar).
    results.sort(key=lambda x: x[1], reverse=True)

    # Write results to a parquet file; "video_path" holds the video's basename.
    df = pd.DataFrame(results, columns=["video_path", "similarity"])
    df.to_parquet(args.parquet_out_path, index=False)

    print(f"\nResults saved to {args.parquet_out_path}")