tools/smolvlm_local_inference/SmolVLM_video_inference.py
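"""Local SmolVLM video inference: sample frames from a video at roughly 1 fps
(capped at max_frames), center-crop them to 384x384, and ask the model a
question about the clip."""
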
import torch
from transformers import AutoProcessor, Idefics3ForConditionalGeneration
from PIL import Image
import cv2
import numpy as np
from typing import List
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class VideoFrameExtractor:
    """Samples frames from a video and prepares them as square PIL images."""

    def __init__(self, max_frames: int = 50):
        self.max_frames = max_frames

    def resize_and_center_crop(self, image: Image.Image, target_size: int) -> Image.Image:
        # Get current dimensions
        width, height = image.size

        # Calculate new dimensions keeping aspect ratio
        if width < height:
            new_width = target_size
            new_height = int(height * (target_size / width))
        else:
            new_height = target_size
            new_width = int(width * (target_size / height))

        # Resize
        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

        # Center crop
        left = (new_width - target_size) // 2
        top = (new_height - target_size) // 2
        right = left + target_size
        bottom = top + target_size

        return image.crop((left, top, right, bottom))

    def extract_frames(self, video_path: str) -> List[Image.Image]:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")

        # Get video properties
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Guard against videos that report an FPS of 0, which would break the stride below
        fps = max(int(cap.get(cv2.CAP_PROP_FPS)), 1)

        # Calculate frame indices to extract (1fps)
        frame_indices = list(range(0, total_frames, fps))

        # If we have more frames than max_frames, sample evenly
        if len(frame_indices) > self.max_frames:
            indices = np.linspace(0, len(frame_indices) - 1, self.max_frames, dtype=int)
            frame_indices = [frame_indices[i] for i in indices]

        frames = []
        for frame_idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if ret:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(frame)
                pil_image = self.resize_and_center_crop(pil_image, 384)
                frames.append(pil_image)

        cap.release()
        return frames


def load_model(checkpoint_path: str, base_model_id: str = "HuggingFaceTB/SmolVLM-Instruct", device: str = "cuda"):
    # Load processor from original model
    processor = AutoProcessor.from_pretrained(base_model_id)

    if checkpoint_path:
        # Load fine-tuned model from checkpoint
        model = Idefics3ForConditionalGeneration.from_pretrained(
            checkpoint_path,
            torch_dtype=torch.bfloat16,
            device_map=device
        )
    else:
        # Fall back to the base instruct model
        model = Idefics3ForConditionalGeneration.from_pretrained(
            base_model_id,
            torch_dtype=torch.bfloat16,
            device_map=device
        )

    # Configure processor for video frames
    processor.image_processor.size = (384, 384)
    processor.image_processor.do_resize = False
    processor.image_processor.do_image_splitting = False

    return model, processor


def generate_response(model, processor, video_path: str, question: str, max_frames: int = 50):
    # Extract frames
    frame_extractor = VideoFrameExtractor(max_frames)
    frames = frame_extractor.extract_frames(video_path)
    logger.info(f"Extracted {len(frames)} frames from video")

    # Create prompt with frames
    image_tokens = [{"type": "image"} for _ in range(len(frames))]
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Answer briefly."},
                *image_tokens,
                {"type": "text", "text": question}
            ]
        }
    ]

    # Process inputs
    inputs = processor(
        text=processor.apply_chat_template(messages, add_generation_prompt=True),
        images=[img for img in frames],
        return_tensors="pt"
    ).to(model.device)

    # Generate response
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        num_beams=5,
        temperature=0.7,
        do_sample=True,
        use_cache=True
    )

    # Decode response
    response = processor.decode(outputs[0], skip_special_tokens=True)
    return response


def main():
    # Configuration
    # checkpoint_path = "/path/to/your/checkpoint"
    checkpoint_path = None
    base_model_id = "HuggingFaceTB/SmolVLM-Instruct"
    video_path = "/path/to/video.mp4"
    question = "Describe the video"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load model
    logger.info("Loading model...")
    model, processor = load_model(checkpoint_path, base_model_id, device)

    # Generate response
    logger.info("Generating response...")
    response = generate_response(model, processor, video_path, question)

    # Print results
    print("Question:", question)
    print("Response:", response)


if __name__ == "__main__":
    main()
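
# Example usage (the checkpoint and video paths above are placeholders; edit them for your setup):
#   python SmolVLM_video_inference.py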