# video_processing/reference_video_similarity.py
def compute_video_embedding(frames, model, preprocessor, device, dtype):
"""
Compute video embeddings. `frames` can either be frames of a single video or a list of list of
frames from multiple videos.
"""
if not frames:
return None
if isinstance(frames[0], list):
video_embeddings = []
flat_frames = []
video_lengths = []
for video in frames:
video_lengths.append(len(video))
flat_frames.extend(video)
all_input = preprocessor(images=flat_frames, return_tensors="pt").to(device)
with torch.no_grad(), torch.autocast(torch.device(device).type, dtype=dtype):
embeddings = model(**all_input).pooler_output
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
embeddings = embeddings.cpu()
# Group the embeddings back by video
index = 0
for length in video_lengths:
video_emb = embeddings[index : index + length].mean(dim=0)
video_emb = video_emb / video_emb.norm()
video_embeddings.append(video_emb.numpy())
index += length
return video_embeddings
else:
all_input = preprocessor(images=frames, return_tensors="pt").to(device)
with torch.no_grad(), torch.autocast(torch.device(device).type, dtype=dtype):
embeddings = model(**all_input).pooler_output
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
embeddings = embeddings.cpu()
video_embedding = embeddings.mean(dim=0)
video_embedding = video_embedding / video_embedding.norm()
return video_embedding.numpy()