def compute_video_embedding()

in video_processing/reference_video_similarity.py


import torch


def compute_video_embedding(frames, model, preprocessor, device, dtype):
    """
    Compute video embeddings. `frames` can either be frames of a single video or a list of list of
    frames from multiple videos.
    """
    if not frames:
        return None

    if isinstance(frames[0], list):
        # Multiple videos: flatten all frames into one batch so the model
        # runs a single forward pass, recording each video's frame count so
        # the per-frame embeddings can be regrouped afterwards.
        video_embeddings = []
        flat_frames = []
        video_lengths = []

        for video in frames:
            video_lengths.append(len(video))
            flat_frames.extend(video)
        all_input = preprocessor(images=flat_frames, return_tensors="pt").to(device)
        with torch.no_grad(), torch.autocast(torch.device(device).type, dtype=dtype):
            embeddings = model(**all_input).pooler_output
            # L2-normalize each frame embedding so the per-video mean below
            # averages unit vectors.
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
        embeddings = embeddings.cpu()

        # Group the embeddings back by video: mean-pool each video's frames,
        # then renormalize so every video embedding is again a unit vector.
        index = 0
        for length in video_lengths:
            video_emb = embeddings[index : index + length].mean(dim=0)
            video_emb = video_emb / video_emb.norm()
            video_embeddings.append(video_emb.numpy())
            index += length

        return video_embeddings
    else:
        # Single video: embed all frames in one batch, mean-pool, and
        # renormalize to a unit vector.
        all_input = preprocessor(images=frames, return_tensors="pt").to(device)
        with torch.no_grad(), torch.autocast(torch.device(device).type, dtype=dtype):
            embeddings = model(**all_input).pooler_output
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
        embeddings = embeddings.cpu()

        video_embedding = embeddings.mean(dim=0)
        video_embedding = video_embedding / video_embedding.norm()
        return video_embedding.numpy()
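
A minimal usage sketch, assuming a Hugging Face vision encoder that exposes
`pooler_output` (the `facebook/dinov2-base` checkpoint and the frame paths
below are placeholders, not taken from this repository):

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
# CPU autocast is most reliable with bfloat16; use float16 on GPU.
dtype = torch.float16 if device == "cuda" else torch.bfloat16

# Placeholder checkpoint: any encoder returning `pooler_output` works.
model = AutoModel.from_pretrained("facebook/dinov2-base").to(device).eval()
preprocessor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")

# Placeholder frame paths, for illustration only.
frames = [Image.open(p).convert("RGB") for p in ["frame_0.png", "frame_1.png"]]
embedding = compute_video_embedding(frames, model, preprocessor, device, dtype)

Because every returned embedding is unit-normalized, the similarity between two
videos reduces to a dot product: `float(emb_a @ emb_b)` gives their cosine
similarity directly.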